<a href="https://colab.research.google.com/github/pvrancx/pytorch_utils/blob/master/pytorch_utils.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install requirements

In [0]:
# Install PyTorch and torchvision (versions are not pinned).
# NOTE(review): `!pip3` runs whichever pip3 is first on PATH, which may not be the
# kernel's interpreter; `%pip install` targets the running kernel directly.
!pip3 install torch torchvision

Mount Google Drive so that results (such as model checkpoints) persist beyond the Colab session.

In [0]:
# Mount Google Drive so files written under the mount point outlive this Colab VM.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

In [0]:
# Create the project folder inside the Drive mount. The path is relative, so it only
# lands on Drive when the working directory is /content; -p makes re-runs a no-op.
! mkdir -p "gdrive/My Drive/cifar10"

In [0]:
% cd gdrive/My Drive/cifar10

Clone the `pytorch_utils` helper repository from GitHub (or update an existing clone).

In [0]:
! git clone https://github.com/pvrancx/pytorch_utils.git

In [0]:
% cd pytorch_utils

In [0]:
# Pull the latest changes for the current branch of the utilities checkout.
! git pull

Set up the experiment: imports, device selection, model/optimizer factory and training callbacks.

In [0]:
import torch
import torch.nn as nn
from torch.utils.data.dataloader import DataLoader
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models.resnet import Bottleneck, ResNet
from torch.optim.lr_scheduler import ReduceLROnPlateau, LambdaLR

In [0]:
from torchutils.experiment import Experiment, Config
from torchutils.dataloaders import cifar10_loader
from torchutils.metrics import accuracy, ValidationMetric
from torchutils.train import fit
from torchutils.callbacks import ModelSaverCallback, LoggerCallback

In [0]:
# Train on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [0]:
def get_experiment(lr=1e-3, max_epochs=200):
  """Build a fresh ResNet-50 + Adam experiment for 10-class CIFAR-10.

  Args:
    lr: Initial Adam learning rate. The multiplicative factors returned by
      ``lr_schedule`` (applied through ``LambdaLR`` in ``main``) scale this value.
    max_epochs: Training length passed to ``Config``.

  Returns:
    An ``Experiment`` bundling the model, its optimizer, a cross-entropy loss and
    a ``Config`` that targets the notebook-level ``device``.
  """
  # ResNet-50 layout: Bottleneck blocks in stages of [3, 4, 6, 3], 10-way output head.
  # NOTE(review): torchvision's ResNet uses an ImageNet-style stem (7x7 stride-2 conv
  # + max-pool), which shrinks 32x32 CIFAR inputs very early — confirm this is intended.
  resnet50 = ResNet(Bottleneck, [3, 4, 6, 3], num_classes=10)
  optimizer = torch.optim.Adam(resnet50.parameters(), lr=lr)
  return Experiment(
      model=resnet50,
      optimizer=optimizer,
      loss_fn=nn.CrossEntropyLoss(),
      # Learning-rate scheduling is supplied separately via fit(lr_schedulers=...).
      lr_scheduler=None,
      config=Config(device=device, max_epochs=max_epochs)
  )


In [0]:
# Checkpoint folder on Drive; the same path is given to ModelSaverCallback in
# get_callbacks. -p makes this a no-op when the folder already exists.
! mkdir -p "/content/gdrive/My Drive/cifar10/checkpoints"

In [0]:
def get_callbacks(exp, data,
                  checkpoint_dir="/content/gdrive/My Drive/cifar10/checkpoints",
                  save_frequency=10):
    """Build the callbacks handed to ``fit``.

    Args:
      exp: The experiment being trained. Not used here; kept so existing
        callers' signature stays valid.
      data: Loader bundle returned by ``cifar10_loader``; its ``test`` split is
        used to compute validation accuracy.
      checkpoint_dir: Directory passed to ``ModelSaverCallback``. Created if it
        does not exist, so this function does not rely on a separate mkdir cell.
      save_frequency: Forwarded to ``ModelSaverCallback`` as ``frequency``.

    Returns:
      A list with a logger, a periodic model saver and a validation-accuracy metric.
    """
    import os

    os.makedirs(checkpoint_dir, exist_ok=True)
    return [
        LoggerCallback(),
        ModelSaverCallback(checkpoint_dir, frequency=save_frequency),
        ValidationMetric(accuracy, data.test, name='validation accuracy'),
    ]

Fit the model using a step learning-rate schedule.

In [0]:
def lr_schedule(epoch):
  """Return the multiplicative learning-rate factor for ``LambdaLR`` at ``epoch``.

  The factor stays at 1 up to epoch 80 and then drops in steps once the epoch
  exceeds 80, 120, 160 and 180.
  """
  # (threshold, factor) pairs, checked from the latest drop to the earliest.
  milestones = (
      (180, 0.5e-3),
      (160, 1e-3),
      (120, 1e-2),
      (80, 1e-1),
  )
  for threshold, factor in milestones:
    if epoch > threshold:
      return factor
  return 1.

In [0]:
def main(data_dir='../data', batch_size=128, seed=None):
  """Train the CIFAR-10 ResNet-50 experiment end to end.

  Args:
    data_dir: Dataset location passed to ``cifar10_loader``. The default is
      relative to the current working directory, so it depends on the ``%cd``
      cells having been run first.
    batch_size: Mini-batch size for the data loaders.
    seed: If not ``None``, seeds PyTorch's RNGs before the model is built so
      weight initialisation (and torch-driven shuffling) is repeatable; cuDNN
      kernels may still be nondeterministic. ``None`` leaves training unseeded.
  """
  if seed is not None:
    torch.manual_seed(seed)
  experiment = get_experiment()
  data = cifar10_loader(data_dir, batch_size=batch_size)
  # Only the epoch-based step schedule is used. LambdaLR sets
  # lr = initial_lr * lr_schedule(epoch) on every step, recomputing from the
  # optimizer's initial lr. Any reduction made by a ReduceLROnPlateau on the same
  # optimizer would therefore be overwritten and never take lasting effect.
  fit(
      exp=experiment,
      data=data,
      callbacks=get_callbacks(experiment, data),
      lr_schedulers=[LambdaLR(experiment.optimizer, lr_schedule)],
  )


In [0]:
# Launch training. This is long-running: get_experiment configures max_epochs=200.
main()