<a href="https://colab.research.google.com/github/quanganh1999/NF-Net-on-CIFAR/blob/main/Train_NF_ResNet_on_Cifar_10_using_PyTorch_Lightning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ⚙️ Imports and Setups

In [29]:
# Mount Google Drive at /content/drive so later cells can read and write files there.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [30]:
%cd drive/MyDrive/NF-Resnet

[Errno 2] No such file or directory: 'drive/MyDrive/NF-Resnet'
/content/drive/MyDrive/NF-Resnet


In [4]:
%%capture
# Install pytorch lighting
!pip install pytorch-lightning --quiet
# Install weights and biases
!pip install wandb --quiet

In [5]:
!pip install lightning-bolts["extra"] --quiet

[K     |████████████████████████████████| 256kB 10.1MB/s 
[K     |████████████████████████████████| 37.6MB 73kB/s 
[K     |████████████████████████████████| 22.3MB 1.5MB/s 
[?25h

In [6]:
!git clone https://github.com/rwightman/pytorch-image-models

fatal: destination path 'pytorch-image-models' already exists and is not an empty directory.


In [48]:
# regular imports
import sys
sys.path.append("pytorch-image-models")
import os
import re
import numpy as np

# pytorch related imports 
import torch
from torch import nn
from torch.nn import functional as F
import torchvision.models as models
from torchvision import transforms
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.datasets.utils import download_url

# import for nfnet
import timm

# lightning related imports
import pytorch_lightning as pl
from pytorch_lightning.metrics.functional import accuracy
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

# sklearn related imports
from sklearn.metrics import precision_recall_curve
from sklearn.preprocessing import label_binarize

# import wandb and login
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [14]:
# Force a new W&B login: prompts for an API key and stores it in the netrc file.
!wandb login --relogin

[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [49]:
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.datasets import CIFAR10
from pytorch_lightning.utilities.seed import seed_everything
import pl_bolts

In [50]:
# Set the global seed through Lightning's helper so runs are repeatable.
seed_everything(1999)

Global seed set to 1999


1999

## 🎨 Using DataModules - `Caltech101DataModule`

DataModules are a way of decoupling data-related hooks from the `LightningModule` so you can develop dataset agnostic models.

In [None]:
#@title
class Caltech101DataModule(pl.LightningDataModule):
    """LightningDataModule for the Caltech101 image-folder dataset.

    Downloads and extracts the archive into ``data_dir`` and exposes
    train/val/test dataloaders. Training samples go through the augmentation
    pipeline; validation and test samples go through the plain
    resize/center-crop pipeline.

    Args:
        batch_size: Batch size used by every dataloader.
        data_dir: Directory the archive is downloaded to and extracted in.
    """

    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

        # Channel statistics shared by both pipelines.
        normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

        # Augmentation policy (training only)
        self.augmentation = transforms.Compose([
            transforms.RandomResizedCrop(size=256, scale=(0.8, 1.0)),
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            normalize,
        ])
        # Deterministic preprocessing (validation / test)
        self.transform = transforms.Compose([
            transforms.Resize(size=256),
            transforms.CenterCrop(size=224),
            transforms.ToTensor(),
            normalize,
        ])

        self.num_classes = 102

    def prepare_data(self):
        """Download the Caltech101 archive into ``data_dir`` and extract it there."""
        # Imported lazily: the original code used `patoolib` without importing it
        # anywhere, which raised NameError. The package must be installed.
        import patoolib

        # source: https://figshare.com/articles/dataset/Caltech101_Image_Dataset/7007090
        url = 'https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/12855005/Caltech101ImageDataset.rar'
        download_url(url, self.data_dir)
        # The archive is saved under its URL basename inside data_dir, so resolve
        # it relative to data_dir instead of the current working directory.
        archive_path = os.path.join(self.data_dir, os.path.basename(url))
        patoolib.extract_archive(archive_path, outdir=self.data_dir)

    def setup(self, stage=None):
        """Create the train/val/test subsets (6500 / 1000 / remainder).

        Raises:
            ValueError: If the dataset holds fewer than 7500 images.
        """
        root = os.path.join(self.data_dir, 'Caltech101')
        # Two views over the same files. Assigning `.dataset.transform` on the
        # results of one `random_split` mutates a single shared dataset, so the
        # last assignment would win and augmentation would never be applied.
        augmented_view = ImageFolder(root, transform=self.augmentation)
        plain_view = ImageFolder(root, transform=self.transform)

        num_train, num_val = 6500, 1000
        total = len(plain_view)
        if total < num_train + num_val:
            raise ValueError(
                f"Expected at least {num_train + num_val} images in {root!r}, found {total}"
            )

        # One shared permutation keeps the three subsets disjoint.
        order = torch.randperm(total).tolist()
        val_end = num_train + num_val
        self.train = torch.utils.data.Subset(augmented_view, order[:num_train])
        self.val = torch.utils.data.Subset(plain_view, order[num_train:val_end])
        self.test = torch.utils.data.Subset(plain_view, order[val_end:])

    def train_dataloader(self):
        """Shuffled loader over the augmented training subset."""
        return DataLoader(self.train, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation subset."""
        return DataLoader(self.val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Loader over the test subset."""
        return DataLoader(self.test, batch_size=self.batch_size)

In [60]:
class CIFAR10Data(pl.LightningDataModule):
    """LightningDataModule for CIFAR-10 with a held-out validation split.

    The official training set is split 90/10 into train/validation with a
    fixed seed. The official test set is used only by ``test_dataloader``, so
    checkpoint selection never sees test data.

    Args:
        batch_size: Batch size used by every dataloader.
        data_dir: Directory where the CIFAR-10 files are stored/downloaded.
    """

    def __init__(self, batch_size, data_dir: str = './'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)

        # Training-time augmentation: padded random crop + horizontal flip.
        self.train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]
        )
        # Evaluation: normalisation only (images stay at 32x32).
        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(self.mean, self.std),
            ]
        )
        self.num_classes = 10

    def prepare_data(self):
        """Download both CIFAR-10 splits into ``data_dir`` if needed."""
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the datasets required for ``stage`` ('fit', 'test' or None for both)."""
        if stage == 'fit' or stage is None:
            # Two views of the training files: augmented for training, plain
            # for validation, indexed by disjoint subsets below.
            self.cifar_full_train = CIFAR10(self.data_dir, train=True, transform=self.train_transform)
            self.cifar_full_val = CIFAR10(self.data_dir, train=True, transform=self.test_transform)

            num_train = len(self.cifar_full_train)
            split = int(np.floor(0.1 * num_train))
            indices = list(range(num_train))
            # A local RandomState gives a fixed permutation without re-seeding
            # NumPy's global generator as a side effect.
            np.random.RandomState(1999).shuffle(indices)

            train_idx, valid_idx = indices[split:], indices[:split]
            self.train_sampler = SubsetRandomSampler(train_idx)
            self.valid_sampler = SubsetRandomSampler(valid_idx)
            # Fixed-order validation subset, so evaluation batches are stable.
            self.cifar_val = torch.utils.data.Subset(self.cifar_full_val, valid_idx)

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.cifar_test = CIFAR10(self.data_dir, train=False, transform=self.test_transform)

    def train_dataloader(self):
        """Loader over the 90% training split (the sampler shuffles every epoch)."""
        return DataLoader(self.cifar_full_train, batch_size=self.batch_size, sampler=self.train_sampler)

    def val_dataloader(self):
        """Loader over the 10% validation split carved from the training set.

        Previously this returned the test set, which leaked test data into
        checkpoint selection and failed after ``setup('fit')`` because
        ``cifar_test`` is only created for the test stage.
        """
        return DataLoader(self.cifar_val, batch_size=self.batch_size)

    def test_dataloader(self):
        """Loader over the official CIFAR-10 test set."""
        return DataLoader(self.cifar_test, batch_size=self.batch_size)

## 📲 Callbacks

#### 🚏 Earlystopping

In [61]:
# Early-stopping callback keyed on the logged `val_loss` (lower is better),
# configured with a patience of 3 and no verbose output.
early_stop_callback = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=False)

#### 🛃 Custom Callback - `ImagePredictionLogger`

In [62]:
class ImagePredictionLogger(Callback):
    """Log a fixed set of validation images with predictions to the W&B run.

    Args:
        val_samples: An ``(images, labels)`` pair of tensors taken from the
            validation dataloader.
        num_samples: Maximum number of images logged per validation epoch.
    """

    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        """Run the model on the stored samples and log captioned images."""
        # Slice first: only `num_samples` images are logged, so there is no
        # need to run the forward pass on the rest of the stored batch.
        imgs = self.val_imgs[:self.num_samples].to(device=pl_module.device)
        labels = self.val_labels[:self.num_samples]

        # Pure inference; make sure no autograd graph is built.
        with torch.no_grad():
            logits = pl_module(imgs)
        preds = torch.argmax(logits, -1)

        # `.tolist()` yields plain ints, so captions read "Pred:3, Label:3"
        # rather than embedding tensor reprs.
        trainer.logger.experiment.log({
            "examples": [
                wandb.Image(img, caption=f"Pred:{pred}, Label:{label}")
                for img, pred, label in zip(imgs, preds.tolist(), labels.tolist())
            ]
        })

#### 💾 Model Checkpoint Callback

In [63]:
# Where checkpoints are written and how their files are named.
# (An earlier naming scheme used `val_loss` in place of `val_acc`.)
MODEL_CKPT_PATH = './model'
MODEL_CKPT = 'model-{epoch:02d}-{val_acc:.2f}'

# Track the logged `val_acc` in 'max' mode and keep the top 3 checkpoints.
checkpoint_callback = ModelCheckpoint(
    dirpath=MODEL_CKPT_PATH,
    filename=MODEL_CKPT,
    monitor='val_acc',
    mode='max',
    save_top_k=3,
)



## 🎺 Define The Model

In [64]:
class LitModel(pl.LightningModule):
    """Image classifier wrapping a pretrained timm `eca_nfnet_l1` backbone.

    Args:
        input_shape: Shape of one input sample; stored as ``self.dim``.
        num_classes: Number of output classes for the classifier head.
        learning_rate: Stored and logged as a hyperparameter.
            NOTE(review): the optimizer below uses a hard-coded lr of 0.001
            and does not read this value - confirm which one is intended.
    """

    def __init__(self, input_shape, num_classes, learning_rate=2e-4):
        super().__init__()

        # log hyperparameters
        self.save_hyperparameters()
        self.learning_rate = learning_rate
        self.dim = input_shape
        self.num_classes = num_classes

        # Fine-tuning: pretrained weights with a fresh `num_classes` head.
        self.model = timm.create_model("eca_nfnet_l1", pretrained=True, num_classes=num_classes)

        # Alternative - train from scratch:
        #   self.model = timm.create_model('nf_resnet50', pretrained=False, num_classes=num_classes)
        # Alternative - feature extractor (freeze all layers except the last one):
        #   self.model = timm.create_model("eca_nfnet_l1", pretrained=True)
        #   for param in self.model.parameters():
        #       param.requires_grad = False
        #   self.model.reset_classifier(num_classes)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        return F.log_softmax(self.model(x), dim=1)

    def _shared_step(self, batch):
        """Compute (NLL loss, accuracy) for one batch; shared by all three stages."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        return loss, accuracy(preds, y)

    # logic for a single training step
    def training_step(self, batch, batch_idx):
        """Log step- and epoch-level training metrics and return the loss."""
        loss, acc = self._shared_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    # logic for a single validation step
    def validation_step(self, batch, batch_idx):
        """Log `val_loss` / `val_acc` and return the loss."""
        loss, acc = self._shared_step(batch)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', acc, prog_bar=True)
        return loss

    # logic for a single testing step
    def test_step(self, batch, batch_idx):
        """Log `test_loss` / `test_acc` and return the loss."""
        loss, acc = self._shared_step(batch)
        self.log('test_loss', loss, prog_bar=True)
        self.log('test_acc', acc, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """Nesterov SGD with linear warm-up followed by cosine annealing."""
        # Optimise this module's own parameters. The previous code reached for
        # a global `model` variable, which breaks (or silently optimises a
        # different instance) whenever that global is missing or stale.
        optimizer = torch.optim.SGD(
            self.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-5, nesterov=True
        )
        # NOTE(review): max_epochs is hard-coded and should stay in sync with
        # the Trainer's `max_epochs`.
        scheduler = pl_bolts.optimizers.lr_scheduler.LinearWarmupCosineAnnealingLR(
            optimizer, warmup_epochs=5, max_epochs=60
        )
        return [optimizer], [scheduler]


## ⚡ Train and Evaluate The Model

In [65]:
# Build the CIFAR-10 data module. `prepare_data` and `setup` are called by hand
# because the dataloaders are needed before the Trainer runs.
BATCH_SIZE = 64
dm = CIFAR10Data(batch_size=BATCH_SIZE)
dm.prepare_data()
dm.setup()

# Pull one validation batch; the ImagePredictionLogger callback reuses it to
# log image predictions.
val_samples = next(iter(dm.val_dataloader()))
val_imgs, val_labels = val_samples[:2]
val_imgs.shape, val_labels.shape

Files already downloaded and verified
Files already downloaded and verified


(torch.Size([64, 3, 32, 32]), torch.Size([64]))

In [None]:
# Model: 3x32x32 inputs, 10 output classes
model = LitModel((3, 32, 32), 10)

# Experiment tracking through Weights & Biases
wandb_logger = WandbLogger(project='nfnet', job_type='train-nf-resnet')

# Trainer: one GPU, 60 epochs, prediction logging + checkpointing.
# (`early_stop_callback` is deliberately left out of the callbacks.)
trainer = pl.Trainer(
    gpus=1,
    max_epochs=60,
    logger=wandb_logger,
    callbacks=[ImagePredictionLogger(val_samples), checkpoint_callback],
    progress_bar_refresh_rate=20,
)

# Fit on the data module ⚡🚅⚡
trainer.fit(model, dm)

# Score the held-out test set ⚡⚡
trainer.test()

# End the wandb run
wandb.finish()

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: Currently logged in as: [33mquanganhk62[0m (use `wandb login --relogin` to force relogin)



  | Name  | Type        | Params
--------------------------------------
0 | model | NormFreeNet | 38.4 M
--------------------------------------
38.4 M    Trainable params
0         Non-trainable params
38.4 M    Total params
153.462   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

In [47]:
# Test with a saved checkpoint.
# `load_from_checkpoint` is a classmethod that builds the model itself, so no
# throwaway `LitModel(...)` instance (and pretrained-weight load) is needed first.
model = LitModel.load_from_checkpoint(checkpoint_path='./model/model-epoch=09-val_loss=0.34.ckpt')

# A fresh trainer is sufficient for evaluation.
trainer = pl.Trainer(gpus=1)

# Evaluate the restored model on the data module's test set.
trainer.test(model, datamodule=dm)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': 0.8726999759674072, 'test_loss': 0.36484888195991516}
--------------------------------------------------------------------------------


[{'test_acc': 0.8726999759674072, 'test_loss': 0.36484888195991516}]

In [45]:
# List the checkpoint files currently saved under ./model.
%ls model/

'model-epoch=02-val_loss=0.31.ckpt'  'model-epoch=08-val_loss=0.35.ckpt'
'model-epoch=04-val_loss=0.31.ckpt'  'model-epoch=09-val_loss=0.34.ckpt'
'model-epoch=04-val_loss=0.33.ckpt'  'model-epoch=13-val_loss=0.35.ckpt'
'model-epoch=05-val_loss=0.31.ckpt'  'model-epoch=27-val_loss=0.67.ckpt'
'model-epoch=05-val_loss=0.35.ckpt'  'model-epoch=28-val_loss=0.67.ckpt'
'model-epoch=06-val_loss=0.31.ckpt'  'model-epoch=29-val_loss=0.67.ckpt'
