In [None]:
# Make Google Drive -- and the Thesis project folder stored on it -- visible to
# this Colab runtime. Every project path used later lives under this mount point.
from google.colab import drive

MOUNT_POINT = "/content/drive"
drive.mount(MOUNT_POINT)

# Imports

In [None]:
# path
pth = '/content/drive/MyDrive/Colab Notebooks/Thesis'

%cd /content/drive/My Drive/Colab Notebooks/Thesis/SupervisedLearning

from train import *
from visualize_data import *
from utils import *

%cd /content/drive/My Drive/Colab Notebooks/Thesis

# State-Action Pair Guideline

Specify the desired state-action pair:

e.g. Before Pickup, Before Discard: Draw (`bpbd_draw`)

## Model Name Guideline

Specify the following parameters:

- **Data Selection**:
 - **`state-action pair`**
   - possible state-action pairs
```python
state_action_pair = {'all': 'all', # all actions
                    'bpbd': 'draw', # actions 2/3 
                    'apbd': ['discard', 'knock'], # actions 6-57, 58-109
                    'apad': 'knock_bin'}
```
 - **`model_name`**
   - name of the model based on **pruned states** and **chosen action**
 - **`numGames`**
   - Number of games used to train the model
$$\text{numGames} \in \{2000, 6000, 8000\}$$

 - **`pruneStatesList`**
   - which states to omit when training the model
 ```
 pruneStatesList is a list containing any subset of {'currHand','topCard','deadCard','oppCard','unknownCard'}
 ```
 - **`actionChoice`**
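   - which action(s) the model is trained on
```
actionChoice is one of {'all','draw','discard','knock'}
```
   - *Note:* the parameter cell below lists `{'all','draw_pickup','discard','knock','knock_bin'}` instead; confirm which names `load_train_data` accepts.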
 - **`balance`** (T/F, **`default = False`**)
   - balance data by smallest class

--- 

- **Model Parameters**:

| Parameter                  | Type             |Default|
| -------------------------- |:----------------:| -----:|
| **batch_size**             | int              | 1000  |
| **learning_rate** (`lr`)   | float            | 0.001 |
| **epochs** (`epochs`)      | int              | 100   |
| **pre_train**              | bool (T/F)       | False |
| **model_PT**               | str (model path) | null  |
| **device**                 | str (cpu/cuda)   | cpu   |

# State-Action Pair

## Model Name

### Parameters

In [2]:
# state_action pair (a key/value of `state_action_pair` in the guideline above)
state = 'all'
action = 'all'

# model name (passed to create_dir to build this model's directories)
model_name = 'all_states_discard_only'

# Number of games whose data is used for training
numGames = 8000

# prunable states
# {'currHand','topCard','deadCard','oppCard','unknownCard'} or empty list if None
pruneStatesList = []

# choosable actions
# {'all','draw_pickup','discard','knock','knock_bin'}
# NOTE(review): the guideline above lists {'all','draw','discard','knock'} instead;
# confirm which names load_train_data actually accepts.
actionChoice = 'discard'

# Balance classes
balance = False

# Training parameters
batch_size = 1000
lr = 0.001
epochs = 100

# Pretrain model
pre_train = False
model_PT = ''

# Reproducibility: seed every RNG the pipeline may draw from (data split/shuffle,
# weight init) so re-running the notebook gives comparable models and plots.
SEED = 42
import random
import numpy as np
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# device for model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

### Create Directories

In [3]:
# Resolve the directories the rest of the notebook uses: data_pth (data loading),
# model_pth (checkpoints) and plot_pth (figures), keyed by the chosen pair and model name.
data_pth, model_pth, plot_pth = create_dir(
    pth,
    state, action,
    model_name,
)

### Load, Prune, and Split Training Data

In [None]:
# Load `numGames` games of training data, drop the states in `pruneStatesList`,
# keep the action(s) chosen by `actionChoice`, optionally balance the classes,
# and get back the train/validation loaders plus the class list used for plotting.
train_loader, val_loader, classes = load_train_data(
    data_pth,
    plot_pth,
    numGames,
    batch_size,
    state,
    action,
    pruneStatesList,
    actionChoice,
    balance,
    visualize=True,  # presumably also plots the loaded data into plot_pth -- TODO confirm
)

### Train Model

In [None]:
# Train, then checkpoint the three models train() returns. The test-set cells below
# treat them as: the model after all epochs, the max-validation-accuracy model and
# the min-validation-loss model.
model, model_acc, model_loss = train(
    train_loader, val_loader, plot_pth,
    batch_size, lr, epochs,
    verbose=True,
    pre_train=pre_train,
    model_PT=model_PT,
    device=device,
)

# Whole model objects are pickled here (not state_dicts); the reload cell relies on that.
for _obj, _path in (
    (model, '{}/model.pt'.format(model_pth)),
    (model_acc, '{}/model_acc.pt'.format(model_pth)),
    (model_loss, '{}/model_loss.pt'.format(model_pth)),
):
    torch.save(_obj, _path)

#### Confusion Matrix

##### Load Models

In [5]:
# Reload the three checkpoints written by the training cell, mapped onto `device`.
# They were saved with torch.save(model, ...), i.e. as whole pickled model objects,
# so weights_only=False is required: since PyTorch 2.6 torch.load defaults to
# weights_only=True, which refuses to unpickle full model objects.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints you created.
model = torch.load('{}/model.pt'.format(model_pth), map_location=device, weights_only=False)
model_acc = torch.load('{}/model_acc.pt'.format(model_pth), map_location=device, weights_only=False)
model_loss = torch.load('{}/model_loss.pt'.format(model_pth), map_location=device, weights_only=False)

##### Train Set

In [None]:
# Confusion matrix on the training set. Label it with the number of games actually
# used for training; a hard-coded 8000 goes stale as soon as `numGames` changes.
currGames = numGames
plot_cm(plot_pth, classes, model, train_loader, device, numGames=currGames)

##### Validation Set

In [None]:
# Confusion matrix on the validation split of the training data; labelled with the
# configured `numGames` rather than a hard-coded 8000 that can drift from it.
currGames = numGames
plot_cm(plot_pth, classes, model, val_loader, device, numGames=currGames, mode='val')

##### Test Set (6k)

In [None]:
# Confusion matrix on the 6k-game test data.
# The test set's class list is kept under its own name so that running this cell does
# not overwrite the training `classes` used by the train/validation cells above.
# NOTE(review): if numGames is ever set to 6000, confirm this data is disjoint from
# the training data before reading this as a test result.
currGames = 6000
test_loader_6k, classes_6k = load_test_data(data_pth, currGames, state, action,
                                            pruneStatesList, actionChoice)
plot_cm(plot_pth, classes_6k, model, test_loader_6k, device, numGames=currGames)

##### Test Set (2k)

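Evaluate all three saved models on the 2k-game test set: the model after all epochs, the model with the highest validation accuracy, and the model with the lowest validation loss.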

In [None]:
# Build the 2k-game test loader shared by the three evaluation cells below.
# NOTE: this rebinds `classes` (first set by load_train_data); the cells below plot with it.
currGames = 2000
test_loader_2k, classes = load_test_data(
    data_pth, currGames,
    state, action,
    pruneStatesList, actionChoice,
)

###### after all epochs

In [None]:
# Model after all epochs, on the 2k-game test set. The game count is pinned here
# instead of reading `currGames`, which holds whatever the last-run cell assigned
# (e.g. 6000 if the 6k cell ran after the 2k loader was built).
plot_cm(plot_pth, classes, model, test_loader_2k, device, numGames=2000)

###### max validation accuracy

In [None]:
# Max-validation-accuracy model on the 2k-game test set; the game count is pinned so
# the label cannot pick up a stale `currGames` from an out-of-order cell run.
plot_cm(plot_pth, classes, model_acc, test_loader_2k, device, numGames=2000, mode='acc')

###### min validation loss

In [None]:
# Min-validation-loss model on the 2k-game test set; the game count is pinned so
# the label cannot pick up a stale `currGames` from an out-of-order cell run.
plot_cm(plot_pth, classes, model_loss, test_loader_2k, device, numGames=2000, mode='loss')