In [1]:
import torch
from torch.optim import SGD
from torch.utils.data import DataLoader

from pytorch_common.additional_configs import BaseDatasetConfig, BaseModelConfig
from pytorch_common.config import load_pytorch_common_config
from pytorch_common.datasets import create_dataset
from pytorch_common.metrics import get_loss_eval_criteria
from pytorch_common.models import create_model
from pytorch_common.train_utils import train_model
from pytorch_common.utils import get_model_performance_trackers

In [2]:
# Fix the random seed up front so that Restart & Run All reproduces the same model
# initialisation (and hence the same losses/accuracies) every time.
# NOTE(review): this seeds torch's RNG only; if `create_dataset` draws from `random`/numpy,
# seed those as well -- TODO confirm against pytorch_common.datasets.
SEED = 0
torch.manual_seed(SEED)

# Create your own config (or load it from a YAML file); these keys override the defaults.
config_dict = {
    "batch_size_per_gpu": 5,
    "device": "cpu",
    "epochs": 2,
    "lr": 1e-3,
    # Keep checkpointing on: training below saves the best model's state to disk.
    "disable_checkpointing": False,
}

# Load the default pytorch_common config, and then override it with your own custom one
config = load_pytorch_common_config(config_dict)

In [3]:
# Create your own objects here: a toy dataset, the loaders that feed it, the model, and its optimizer.

# Shared by the dataset and the model so the model's input/output sizes cannot drift
# out of sync with the data it is trained on.
NUM_FEATURES, NUM_CLASSES = 1, 2

dataset_config = BaseDatasetConfig({"size": 5, "dim": NUM_FEATURES, "num_classes": NUM_CLASSES})
model_config = BaseModelConfig({"in_dim": NUM_FEATURES, "num_classes": NUM_CLASSES})
dataset = create_dataset("multi_class_dataset", dataset_config)

# Shuffle the training data each epoch; evaluation order doesn't matter, so it stays fixed.
# NOTE: for this demo, validation reuses the training dataset -- use a held-out split in real use.
train_loader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)
val_loader = DataLoader(dataset, batch_size=config.eval_batch_size)

model = create_model("single_layer_classifier", model_config)
optimizer = SGD(model.parameters(), lr=config.lr)

2020-07-23 22:10:22,970: INFO: models_dl.py: print_model: SingleLayerClassifier(
  (fc): Linear(in_features=1, out_features=2, bias=True)
)
2020-07-23 22:10:22,971: INFO: utils.py: get_trainable_params: Number of trainable/total parameters in SingleLayerClassifier: 4/4


In [4]:
# Ask `pytorch_common` for the train/eval losses and the evaluation metrics (mean-reduced),
# plus one performance tracker per split, then hand everything to its training loop.
loss_criterion_train, loss_criterion_eval, eval_criteria = get_loss_eval_criteria(config, reduction="mean")
train_logger, val_logger = get_model_performance_trackers(config)

return_dict = train_model(
    model,
    config,
    # data
    train_loader,
    val_loader,
    # optimisation
    optimizer,
    loss_criterion_train,
    loss_criterion_eval,
    eval_criteria,
    # logging
    train_logger,
    val_logger,
)

2020-07-23 22:10:22,985: INFO: utils.py: log_epoch_metrics: 
[1mTRAIN Epoch: 1	Average loss: 1.1719, accuracy: 0.2000[0m

2020-07-23 22:10:22,988: INFO: utils.py: log_epoch_metrics: 
[1mVAL   Epoch: 1	Average loss: 1.1711, accuracy: 0.2000[0m

2020-07-23 22:10:22,989: INFO: train_utils.py: train_model: Computing best epoch and adding to validation logger...
2020-07-23 22:10:22,990: INFO: train_utils.py: train_model: Done.
2020-07-23 22:10:22,991: INFO: train_utils.py: train_model: Replacing current best model checkpoint...
2020-07-23 22:10:22,991: INFO: train_utils.py: save_model: Saving state checkpoint '/Users/mrana/pytorch_common/checkpoints/checkpoint-state-single_layer_classifier-epoch_1.pt'...
2020-07-23 22:10:22,994: INFO: train_utils.py: save_model: Done.
2020-07-23 22:10:22,995: INFO: train_utils.py: train_model: Done.
2020-07-23 22:10:23,002: INFO: utils.py: log_epoch_metrics: 
[1mTRAIN Epoch: 2	Average loss: 1.1711, accuracy: 0.2000[0m

2020-07-23 22:10:23,006: INFO: u

Function 'pytorch_common.train_utils.perform_one_epoch' took 4.21ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 1.91ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 2.62ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 4.47ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 1.81ms
Function 'pytorch_common.train_utils.perform_one_epoch' took 2.46ms
Function 'pytorch_common.train_utils.train_model' took 54.96ms
