In [1]:
import numpy as np
from torch.utils.data import DataLoader
from torch.optim import SGD

from pytorch_common.additional_configs import BaseDatasetConfig, BaseModelConfig
from pytorch_common.config import load_pytorch_common_config
from pytorch_common.datasets import create_dataset
from pytorch_common.metrics import get_loss_eval_criteria
from pytorch_common.models import create_model
from pytorch_common.train_utils import train_model, get_all_predictions, EarlyStopping
from pytorch_common.utils import get_model_performance_trackers

## Create/load your own config here

In [2]:
# Create your own config (or load from a yaml file)
config_dict = {
    "batch_size_per_gpu": 5,
    "device": "cpu",
    "epochs": 5,
    "lr": 1e-3,
    "eval_criteria": ["accuracy", "precision", "recall", "f1", "auc"],
    "disable_checkpointing": False,
    "use_early_stopping": True,
}

## Merge it with pytorch_common default config

In [3]:
# Load the deault pytorch_common config, and then override it with your own custom one
config = load_pytorch_common_config(config_dict)

## Define your training objects here

In [4]:
# Create your own objects here
dataset_config = BaseDatasetConfig({"size": 10, "dim": 1, "num_classes": 2})
model_config = BaseModelConfig({"in_dim": 1, "num_classes": 2})

dataset = create_dataset("multi_class_dataset", dataset_config)
train_loader = DataLoader(dataset, batch_size=config.train_batch_size)
np.random.shuffle(dataset.data.target)  # Shuffle just to randomize data
val_loader = DataLoader(dataset, batch_size=config.eval_batch_size)

model = create_model("single_layer_classifier", model_config)
optimizer = SGD(model.parameters(), lr=config.lr)

2020-07-23 22:37:11,837: INFO: models_dl.py: print_model: SingleLayerClassifier(
  (fc): Linear(in_features=1, out_features=2, bias=True)
)
2020-07-23 22:37:11,838: INFO: utils.py: get_trainable_params: Number of trainable/total parameters in SingleLayerClassifier: 4/4


## Use pytorch_common to define early stopping, batch decoupling functions, etc.

In [5]:
# Use `pytorch_common` to get loss/eval criteria, initialize loggers, and train the model
early_stopping = EarlyStopping(criterion=config.early_stopping_criterion, patience=3)
loss_criterion_train, loss_criterion_eval, eval_criteria = get_loss_eval_criteria(config, reduction="mean")
train_logger, val_logger = get_model_performance_trackers(config)

## Train the model!

In [6]:
return_dict = train_model(
    model=model,
    config=config,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    loss_criterion_train=loss_criterion_train,
    loss_criterion_eval=loss_criterion_eval,
    eval_criteria=eval_criteria,
    train_logger=train_logger,
    val_logger=val_logger,
    epochs=config.epochs,
    early_stopping=early_stopping,
)

2020-07-23 22:37:11,868: INFO: utils.py: log_epoch_metrics: 
[1mTRAIN Epoch: 1	Average loss: 1.0289, accuracy: 0.4000, precision: 0.4000, recall: 1.0000, f1: 0.5714, auc: 0.6667[0m

2020-07-23 22:37:11,875: INFO: utils.py: log_epoch_metrics: 
[1mVAL   Epoch: 1	Average loss: 1.0283, accuracy: 0.4000, precision: 0.4000, recall: 1.0000, f1: 0.5714, auc: 0.6667[0m

2020-07-23 22:37:11,876: INFO: train_utils.py: train_model: Computing best epoch and adding to validation logger...
2020-07-23 22:37:11,876: INFO: train_utils.py: train_model: Done.
2020-07-23 22:37:11,878: INFO: train_utils.py: train_model: Replacing current best model checkpoint...
2020-07-23 22:37:11,879: INFO: train_utils.py: save_model: Saving state checkpoint '/Users/mrana/pytorch_common/checkpoints/checkpoint-state-single_layer_classifier-epoch_1.pt'...
2020-07-23 22:37:11,881: INFO: train_utils.py: save_model: Done.
2020-07-23 22:37:11,882: INFO: train_utils.py: train_model: Done.
2020-07-23 22:37:11,897: INFO: utils

Function 'pytorch_common.train_utils.perform_one_epoch' took 8.04ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 7.29ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 5.75ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 7.73ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.37ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 5.84ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.88ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.43ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.18ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.88ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 5.62ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 6.86ms
Function 'pytorch_common.train_utils.train_model' took 124.78ms


## Inspect results

In [7]:
return_dict.keys()

dict_keys(['model', 'best_model', 'train_logger', 'val_logger', 'optimizer', 'scheduler', 'stop_epoch', 'best_epoch', 'best_checkpoint_file'])

In [8]:
best_epoch = return_dict["val_logger"].best_epoch
best_epoch

1

In [9]:
return_dict["val_logger"].get_eval_metrics(epoch=best_epoch)

OrderedDict([('accuracy', 0.4),
             ('precision', 0.4),
             ('recall', 1.0),
             ('f1', 0.5714285714285715),
             ('auc', 0.6666666666666667)])

In [10]:
return_dict["val_logger"].get_eval_metrics_df()

Unnamed: 0,epoch,accuracy,precision,recall,f1,auc
0,1,0.4,0.4,1.0,0.571429,0.666667
1,2,0.4,0.4,1.0,0.571429,0.666667
2,3,0.4,0.4,1.0,0.571429,0.666667
3,4,0.4,0.4,1.0,0.571429,0.666667


In [11]:
return_dict["train_logger"].loss_hist

OrderedDict([(1, [0.8854662179946899, 1.1723929643630981]),
             (2, [0.884765625, 1.1714625358581543]),
             (3, [0.8840659856796265, 1.170533537864685]),
             (4, [0.8833677172660828, 1.1696058511734009])])

## Test model

In [12]:
# Create dummy test data
np.random.shuffle(dataset.data.target)
test_loader = DataLoader(dataset, batch_size=config.test_batch_size)

In [13]:
outputs_hist, preds_hist, probs_hist = get_all_predictions(
    model=return_dict["best_model"],
    dataloader=test_loader,
    device=config.device,
    threshold_prob=0.8,
)

2020-07-23 22:37:12,040: INFO: train_utils.py: perform_one_epoch: 5/10 (50%) complete.
2020-07-23 22:37:12,042: INFO: train_utils.py: perform_one_epoch: 10/10 (100%) complete.


Function 'pytorch_common.train_utils.perform_one_epoch' took 6.28ms


In [14]:
probs_hist

tensor([0.7838, 0.8669, 0.8254, 0.5460, 0.8069, 0.8846, 0.8112, 0.7981, 0.8199,
        0.8383])

In [15]:
preds_hist

tensor([0, 1, 1, 0, 1, 1, 1, 0, 1, 1])