### Must install Lightning

In [1]:
!pip3 install lightning-bolts --quiet
!pip3 install torchmetrics
!pip3 install pytorch-lightning --quiet

!pip install kaggle

[K     |████████████████████████████████| 316 kB 5.3 MB/s 
[K     |████████████████████████████████| 584 kB 41.5 MB/s 
[K     |████████████████████████████████| 418 kB 59.2 MB/s 
[K     |████████████████████████████████| 140 kB 66.6 MB/s 
[K     |████████████████████████████████| 596 kB 23.3 MB/s 
[K     |████████████████████████████████| 1.1 MB 57.4 MB/s 
[K     |████████████████████████████████| 144 kB 16.6 MB/s 
[K     |████████████████████████████████| 94 kB 2.5 MB/s 
[K     |████████████████████████████████| 271 kB 55.5 MB/s 
[?25hLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from torchvision import transforms
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torchvision.models.resnet import resnet18
import pytorch_lightning as pl
from pytorch_lightning import Trainer, LightningModule
import torch.nn as nn
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import CIFAR10

In [3]:
from torchmetrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import ModelCheckpoint

In [4]:
EPOCHS = 200
LR = 0.1
MOMENTUM = 0.9
WEIGHT_DECAY = 5e-4
PRINT_FREQ = 50
TRAIN_BATCH=128
VAL_BATCH=128

In [5]:
GPU = 0

In [6]:
from google.colab import files

uploaded = files.upload()

for fn in uploaded.keys():
  print('User uploaded file "{name}" with length {length} bytes'.format(
      name=fn, length=len(uploaded[fn])))

Saving kaggle.json to kaggle.json
User uploaded file "kaggle.json" with length 70 bytes


In [7]:
# Then move kaggle.json into the folder where the API expects to find it.
!mkdir -p ~/.kaggle/ && mv kaggle.json ~/.kaggle/ && chmod 600 ~/.kaggle/kaggle.json

!kaggle datasets download -d mengcius/cinic10

Downloading cinic10.zip to /content
 98% 741M/754M [00:05<00:00, 229MB/s]
100% 754M/754M [00:05<00:00, 152MB/s]


In [8]:
!unzip cinic10.zip -d ./data/

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: ./data/valid/truck/n03632852_1173.png  
  inflating: ./data/valid/truck/n03632852_11766.png  
  inflating: ./data/valid/truck/n03632852_11887.png  
  inflating: ./data/valid/truck/n03632852_1191.png  
  inflating: ./data/valid/truck/n03632852_1192.png  
  inflating: ./data/valid/truck/n03632852_1195.png  
  inflating: ./data/valid/truck/n03632852_1207.png  
  inflating: ./data/valid/truck/n03632852_12128.png  
  inflating: ./data/valid/truck/n03632852_1214.png  
  inflating: ./data/valid/truck/n03632852_1237.png  
  inflating: ./data/valid/truck/n03632852_1248.png  
  inflating: ./data/valid/truck/n03632852_1261.png  
  inflating: ./data/valid/truck/n03632852_1264.png  
  inflating: ./data/valid/truck/n03632852_1292.png  
  inflating: ./data/valid/truck/n03632852_1313.png  
  inflating: ./data/valid/truck/n03632852_1319.png  
  inflating: ./data/valid/truck/n03632852_13311.png  
  inflating: ./data/valid/truc

### fill in the transform statements below

In [9]:
imagenet_mean_RGB = [0.47889522, 0.47227842, 0.43047404]
imagenet_std_RGB = [0.229, 0.224, 0.225]
cinic_mean_RGB = [0.47889522, 0.47227842, 0.43047404]
cinic_std_RGB = [0.24205776, 0.23828046, 0.25874835]
cifar_mean_RGB = [0.4914, 0.4822, 0.4465]
cifar_std_RGB = [0.2023, 0.1994, 0.2010]

In [10]:
class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, train_batch_size, val_batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        
        self.transform_train = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean_RGB, cifar_std_RGB),
        ])
        self.transform_val = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(cifar_mean_RGB, cifar_std_RGB),
        ])
        
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def prepare_data(self):
        # download 
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == 'fit' or stage is None:
#            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
#            self.cifar_train, self.cifar_val = random_split(cifar_full, [45000, 5000])
            self.cifar_train = CIFAR10(self.data_dir, train=True, transform=self.transform_train)
            self.cifar_val = CIFAR10(self.data_dir, train=False, transform=self.transform_val)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.transform_val)

    def train_dataloader(self):
        return DataLoader(self.cifar_train, batch_size=self.train_batch_size, num_workers = 2, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.cifar_val, batch_size=self.val_batch_size, num_workers = 2)

    def test_dataloader(self):
        return DataLoader(self.cifar_test, batch_size=self.batch_size, num_workers = 2)

In [11]:
dm = CIFAR10DataModule(TRAIN_BATCH, VAL_BATCH)
dm.prepare_data()
dm.setup()

  rank_zero_deprecation("DataModule property `dims` was deprecated in v1.5 and will be removed in v1.7.")


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified


In [12]:
MODEL_CKPT_PATH = 'model/'
MODEL_CKPT = 'model/model-{epoch:02d}-{val_loss:.2f}'

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filename=MODEL_CKPT ,
    save_top_k=3,
    mode='min')

In [13]:
# Samples required by the custom ImagePredictionLogger callback to log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[0], val_samples[1]
val_imgs.shape, val_labels.shape

(torch.Size([128, 3, 32, 32]), torch.Size([128]))

In [14]:
early_stop_callback = EarlyStopping(
   monitor='val_loss',
   patience=3,
   verbose=False,
   mode='min'
)

### Complete the training, validation, and optimizer methods below

In [15]:
class LitResnet18(LightningModule):
    def __init__(self, learning_rate, momentum, weight_decay):
        super().__init__()
        self.nn = resnet18(pretrained = False, progress  = True)
        self.nn.fc = nn.Linear(self.nn.fc.in_features, 10)
        self.lr = learning_rate
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.criterion = nn.CrossEntropyLoss().cuda(GPU)
    
    def forward(self, x):
        return self.nn.forward(x)
    
    def training_step(self, batch, batch_idx):
        x,y = batch
        logits = self.nn(x)
        loss = self.criterion(logits, y)
        # training metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=False)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=False)
        if batch_idx % PRINT_FREQ == 0:
          print("train step! " + str(batch_idx) + " train loss: " + str(loss.item()) + " train acc " + str(acc.item()))        
        return loss     
        
        
    def validation_step(self, batch, batch_idx):
        x,y = batch
        logits = self.nn(x)
        loss = self.criterion(logits, y) 
        # validation metrics
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        if batch_idx % PRINT_FREQ == 0:
          print("val step! " + str(batch_idx) + " val loss: " + str(loss.item()) + " val acc " + str(acc.item()))
        return loss  
        
        
        
    def configure_optimizers(self):
        optimizer = torch.optim.SGD(model.parameters(), self.lr, momentum=self.momentum, weight_decay=self.weight_decay)
        return optimizer

In [16]:
# model = resnet18(pretrained = False, progress  = True)
model = LitResnet18(LR, MOMENTUM, WEIGHT_DECAY)


In [17]:
# Initialize a trainer
trainer = pl.Trainer(max_epochs=EPOCHS,
                     progress_bar_refresh_rate=20, 
                     gpus=1, 
                     logger=None,
                     callbacks=[early_stop_callback],
                     checkpoint_callback=checkpoint_callback)

  f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
  f"Setting `Trainer(progress_bar_refresh_rate={progress_bar_refresh_rate})` is deprecated in v1.5 and"
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [18]:
trainer.fit(model, dm)

Files already downloaded and verified
Files already downloaded and verified


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | nn        | ResNet           | 11.2 M
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

val step! 0 val loss: 2.3216991424560547 val acc 0.1484375


                not been set for this class (_ResultMetric). The property determines if `update` by
                default needs access to the full metric state. If this is not the case, significant speedups can be
                achieved and we recommend setting this to `False`.
                We provide an checking function
                `from torchmetrics.utilities import check_forward_no_full_state`
                that can be used to check if the `full_state_update=True` (old and potential slower behaviour,
                default for now) or if `full_state_update=False` can be used safely.
                


Training: 0it [00:00, ?it/s]

train step! 0 train loss: 2.5748090744018555 train acc 0.109375
train step! 50 train loss: 2.3691768646240234 train acc 0.2890625
train step! 100 train loss: 2.274785280227661 train acc 0.3515625
train step! 150 train loss: 1.8171049356460571 train acc 0.3125
train step! 200 train loss: 1.7158052921295166 train acc 0.3515625
train step! 250 train loss: 1.7050533294677734 train acc 0.359375
train step! 300 train loss: 1.5424573421478271 train acc 0.4765625
train step! 350 train loss: 1.583848476409912 train acc 0.4140625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.3348746299743652 val acc 0.4765625
val step! 50 val loss: 1.3264224529266357 val acc 0.5546875
train step! 0 train loss: 1.49345862865448 train acc 0.4296875
train step! 50 train loss: 1.3401082754135132 train acc 0.515625
train step! 100 train loss: 1.3030517101287842 train acc 0.5546875
train step! 150 train loss: 1.2380108833312988 train acc 0.5546875
train step! 200 train loss: 1.349328637123108 train acc 0.5078125
train step! 250 train loss: 1.1462666988372803 train acc 0.5703125
train step! 300 train loss: 1.3463410139083862 train acc 0.546875
train step! 350 train loss: 1.2203919887542725 train acc 0.609375


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.1648815870285034 val acc 0.59375
val step! 50 val loss: 1.2800347805023193 val acc 0.5703125
train step! 0 train loss: 1.0608619451522827 train acc 0.6328125
train step! 50 train loss: 0.9892431497573853 train acc 0.6640625
train step! 100 train loss: 1.0749876499176025 train acc 0.6171875
train step! 150 train loss: 1.2160907983779907 train acc 0.578125
train step! 200 train loss: 1.150248408317566 train acc 0.546875
train step! 250 train loss: 1.0083627700805664 train acc 0.6484375
train step! 300 train loss: 1.036720633506775 train acc 0.6328125
train step! 350 train loss: 1.0287127494812012 train acc 0.671875


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.9946984648704529 val acc 0.6484375
val step! 50 val loss: 1.1975936889648438 val acc 0.640625
train step! 0 train loss: 0.8068895936012268 train acc 0.6953125
train step! 50 train loss: 0.8298383355140686 train acc 0.6796875
train step! 100 train loss: 1.0130627155303955 train acc 0.6015625
train step! 150 train loss: 0.9890507459640503 train acc 0.6875
train step! 200 train loss: 0.8836997151374817 train acc 0.6796875
train step! 250 train loss: 0.8462570309638977 train acc 0.6640625
train step! 300 train loss: 1.104025959968567 train acc 0.625
train step! 350 train loss: 1.0965566635131836 train acc 0.625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 1.0929768085479736 val acc 0.6171875
val step! 50 val loss: 1.0281367301940918 val acc 0.6953125
train step! 0 train loss: 0.7736379504203796 train acc 0.7265625
train step! 50 train loss: 0.7172626256942749 train acc 0.765625
train step! 100 train loss: 0.8387313485145569 train acc 0.703125
train step! 150 train loss: 1.0124753713607788 train acc 0.6484375
train step! 200 train loss: 0.7818409204483032 train acc 0.734375
train step! 250 train loss: 0.981384813785553 train acc 0.6328125
train step! 300 train loss: 0.8283544182777405 train acc 0.734375
train step! 350 train loss: 0.75108402967453 train acc 0.734375


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.7615908980369568 val acc 0.765625
val step! 50 val loss: 0.8695994019508362 val acc 0.6953125
train step! 0 train loss: 0.8113320469856262 train acc 0.703125
train step! 50 train loss: 0.5621874332427979 train acc 0.796875
train step! 100 train loss: 0.9382079839706421 train acc 0.6796875
train step! 150 train loss: 0.832574188709259 train acc 0.7109375
train step! 200 train loss: 0.8816530704498291 train acc 0.671875
train step! 250 train loss: 0.8630558252334595 train acc 0.6875
train step! 300 train loss: 0.6957355737686157 train acc 0.7578125
train step! 350 train loss: 0.6731017827987671 train acc 0.765625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.7729185223579407 val acc 0.71875
val step! 50 val loss: 0.745276153087616 val acc 0.703125
train step! 0 train loss: 0.6905215978622437 train acc 0.7890625
train step! 50 train loss: 0.6645482778549194 train acc 0.7890625
train step! 100 train loss: 0.5888908505439758 train acc 0.828125
train step! 150 train loss: 0.583372950553894 train acc 0.796875
train step! 200 train loss: 0.796809732913971 train acc 0.7265625
train step! 250 train loss: 0.659849226474762 train acc 0.8125
train step! 300 train loss: 0.8720200657844543 train acc 0.7109375
train step! 350 train loss: 0.646828830242157 train acc 0.765625


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.8305903673171997 val acc 0.7421875
val step! 50 val loss: 0.9003006815910339 val acc 0.7109375
train step! 0 train loss: 0.6182538270950317 train acc 0.7734375
train step! 50 train loss: 0.7443426847457886 train acc 0.6953125
train step! 100 train loss: 0.7058830857276917 train acc 0.765625
train step! 150 train loss: 0.5815715193748474 train acc 0.765625
train step! 200 train loss: 0.6560128927230835 train acc 0.75
train step! 250 train loss: 0.6869444847106934 train acc 0.78125
train step! 300 train loss: 0.6986615657806396 train acc 0.765625
train step! 350 train loss: 0.6816555857658386 train acc 0.78125


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.7705528140068054 val acc 0.7578125
val step! 50 val loss: 0.6352909207344055 val acc 0.78125
train step! 0 train loss: 0.6029937267303467 train acc 0.7890625
train step! 50 train loss: 0.4072580933570862 train acc 0.8515625
train step! 100 train loss: 0.5008268356323242 train acc 0.84375
train step! 150 train loss: 0.454405814409256 train acc 0.84375
train step! 200 train loss: 0.6796107888221741 train acc 0.75
train step! 250 train loss: 0.8035030961036682 train acc 0.765625
train step! 300 train loss: 0.5510991215705872 train acc 0.8046875
train step! 350 train loss: 0.5595366358757019 train acc 0.8046875


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.7187463641166687 val acc 0.734375
val step! 50 val loss: 0.792082667350769 val acc 0.765625
train step! 0 train loss: 0.38063058257102966 train acc 0.8671875
train step! 50 train loss: 0.49629607796669006 train acc 0.8046875
train step! 100 train loss: 0.8459348678588867 train acc 0.7265625
train step! 150 train loss: 0.47931644320487976 train acc 0.8125
train step! 200 train loss: 0.721611499786377 train acc 0.7421875
train step! 250 train loss: 0.7879101037979126 train acc 0.734375
train step! 300 train loss: 0.38726818561553955 train acc 0.859375
train step! 350 train loss: 0.8008633852005005 train acc 0.703125


Validation: 0it [00:00, ?it/s]

val step! 0 val loss: 0.7680980563163757 val acc 0.7421875
val step! 50 val loss: 0.7471539378166199 val acc 0.7265625
train step! 0 train loss: 0.4366177022457123 train acc 0.8515625
train step! 50 train loss: 0.48352888226509094 train acc 0.8515625
train step! 100 train loss: 0.663245677947998 train acc 0.7890625
train step! 150 train loss: 0.5197433233261108 train acc 0.8125
train step! 200 train loss: 0.6245307326316833 train acc 0.8125
train step! 250 train loss: 0.6342105865478516 train acc 0.78125
train step! 300 train loss: 0.6228864789009094 train acc 0.78125
train step! 350 train loss: 0.6454299092292786 train acc 0.7578125


  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
