# install

### Install Lightning 2.0 

In [1]:
%pip install -qqq lightning
%pip install -qqq timm torchmetrics wandb

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m718.6/718.6 kB[0m [31m18.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m553.5/553.5 kB[0m [31m19.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.5/129.5 kB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.7/69.7 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.4/66.4 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.9/66.9 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.8/57.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━

# import

In [2]:
import os
import gc

import time
from tqdm import tqdm

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

# torchvision
import torchvision
from torchvision import transforms, datasets

# wandb
import wandb
from lightning.pytorch.loggers import WandbLogger

# torchmetrics
import torchmetrics

# timm
import timm
from timm import create_model

# import Pytorch Lightning 2.0 
import lightning as L

# Fabric
from lightning.fabric import Fabric

# Config

In [3]:
config= {'model_name':'resnet18', 
         'seed': 2023,
         'bs': 32,
         'n_epochs': 10,
         'lr': 1e-3,
         'is_compiled': True, 
         'mode': 'default',
         'strategy': "auto",
         }

# Ramdom SEED

In [4]:
L.seed_everything(config['seed'])

INFO: Global seed set to 2023
INFO:lightning.fabric.utilities.seed:Global seed set to 2023


2023

# Fabric launch

In [5]:
# fabric = Fabric(accelerator="cuda", devices=8, strategy="ddp")
fabric = Fabric(accelerator= "auto", devices= "auto", strategy="auto")
fabric.launch()

# LitModel

In [6]:
class LitModel(L.LightningModule):
    def __init__(self, model_name, is_compiled, mode =None, batch_size = 32, lr = 1e-3):
        super().__init__()
        self.model = self.get_model(model_name, is_compiled, mode)
        self.batch_size = batch_size
        self.lr = lr

        # Lightning 2.0
        self.losses = []

        # Loss Function
        self.criterion = nn.CrossEntropyLoss()

        # torchmetrics modules
        self.f1 = torchmetrics.F1Score(task="multiclass", num_classes = 10)

    def forward(self, x):
        return self.model(x)

    def get_model(self, model_name, is_compiled, mode):
        
        model = create_model(model_name, pretrained=True, num_classes=10)

        if is_compiled:
            print(f"model_name: {model_name} | Compiled?: {is_compiled} | Compiled MODE: {mode}")
            compiled_model = torch.compile(model, mode = mode) #, backend="aot_ts_nvfuser") # backend="inductor")
            return compiled_model

        else:

            print(f"model_name: {model_name}")
            return model

    def shared_step(self, batch, mode):
        x, y = batch
        logits = self(x)
        loss = self.criterion(logits, y)

        self.losses.append(loss)

        # F1 Score
        f1 = self.f1(logits, y)

        return {'loss': loss, "f1": f1}
        
    def shared_epoch_end(self, mode):

        loss = torch.stack(self.losses).mean().item()

        metrics = {f'{mode}_loss_epoch': loss, f'{mode}_f1_epoch': self.f1.compute()}
        
        self.log_dict(metrics, prog_bar=True)

        return {'loss': loss, "f1": self.f1.compute()}

        # Reset
        self.f1.reset()
        self.losses.clear()

    def validation_step(self, batch, batch_idx):
        return self.shared_step(batch, "valid")
    
    # https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#validation-epoch-level-metrics
    def on_validation_epoch_end(self):
        return self.shared_epoch_end(mode = "valid")

    def training_step(self, batch, batch_idx):
        return self.shared_step(batch,"train")

    # reference: https://github.com/Lightning-AI/lightning/pull/16520
    def on_train_epoch_end(self):
        return self.shared_epoch_end(mode = "train")

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

    def train_dataloader(self):
        # Return your dataloader for training
        train_set = datasets.CIFAR10(root="~/data", train=True, download=True, transform=transforms.ToTensor())
        train_loader = torch.utils.data.DataLoader(train_set, batch_size= self.batch_size, shuffle=True, num_workers= os.cpu_count())
        return train_loader
    
    def val_dataloader(self):
        val_set = datasets.CIFAR10(root="~/val_data", train=False, download=True, transform=transforms.ToTensor())
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=self.batch_size, shuffle=False, num_workers= os.cpu_count())
        return val_loader

In [7]:
model = LitModel(model_name = config['model_name'],
                 is_compiled = config["is_compiled"],
                 mode = config["mode"],                  
                 batch_size = config['bs'],
                 lr = config['lr'])

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


model_name: resnet18 | Compiled?: True | Compiled MODE: default


# Wandb Logger

In [8]:
# Reference: https://lightning.ai/docs/pytorch/stable/visualize/logging_intermediate.html?highlight=wandb_logger

wandb_logger = WandbLogger( project= 'fabric_test', 
                            config=config,
                            job_type='Train',
                            name=  "[compiled] lightning_2",
                            anonymous='must')

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


# Trainer

In [9]:
trainer = L.Trainer(accelerator="auto", 
                    devices = -1, 
                    max_epochs= config['n_epochs'],
                    logger = wandb_logger,
                    strategy = config["strategy"]
                    )

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs


# Trainer fit!

In [10]:
trainer.fit(model)

INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:lightning.pytorch.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name      | Type              | Params
------------------------------------------------
0 | model     | OptimizedModule   | 11.2 M
1 | criterion | CrossEntropyLoss  | 0     
2 | f1        | MulticlassF1Score | 0     
------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)
INFO:lightning.pytorch.callbacks.model_summary:
  | Name      | Type              | Params
------------------------------------------------
0 | model     | OptimizedModule   | 11.2 M
1 | criterion | CrossEntropyLoss  | 0     
2 | f1        | MulticlassF1Score | 0     
------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.727    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/val_data/cifar-10-python.tar.gz



  0%|          | 0/170498071 [00:00<?, ?it/s][A
  0%|          | 65536/170498071 [00:00<07:35, 374449.70it/s][A
  0%|          | 229376/170498071 [00:00<04:12, 674791.00it/s][A
  0%|          | 557056/170498071 [00:00<01:54, 1488556.08it/s][A
  1%|          | 1277952/170498071 [00:00<00:52, 3243783.58it/s][A
  2%|▏         | 2621440/170498071 [00:00<00:26, 6345692.42it/s][A
  3%|▎         | 5341184/170498071 [00:00<00:13, 12659943.42it/s][A
  5%|▌         | 8650752/170498071 [00:00<00:08, 18686427.31it/s][A
  7%|▋         | 11960320/170498071 [00:00<00:06, 23003446.52it/s][A
  9%|▉         | 15237120/170498071 [00:01<00:05, 25923185.15it/s][A
 11%|█         | 18350080/170498071 [00:01<00:05, 27480379.64it/s][A
 13%|█▎        | 21528576/170498071 [00:01<00:05, 28300165.63it/s][A
 15%|█▍        | 24969216/170498071 [00:01<00:04, 30083622.79it/s][A
 16%|█▋        | 28016640/170498071 [00:01<00:04, 30088476.44it/s][A
 18%|█▊        | 31326208/170498071 [00:01<00:04, 30978624

Extracting /root/val_data/cifar-10-python.tar.gz to /root/val_data
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /root/data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:06<00:00, 28322556.85it/s]


Extracting /root/data/cifar-10-python.tar.gz to /root/data


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


# Finished

In [11]:
wandb.finish()

0,1
epoch,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇██
train_f1_epoch,▁▃▄▅▆▆▇▇██
train_loss_epoch,█▆▅▄▃▃▂▂▁▁
trainer/global_step,▁▁▂▂▃▃▃▃▄▄▅▅▆▆▆▆▇▇██
valid_f1_epoch,▁▃▄▅▆▆▇▇██
valid_loss_epoch,█▆▅▄▃▃▂▂▁▁

0,1
epoch,9.0
train_f1_epoch,0.8316
train_loss_epoch,0.5083
trainer/global_step,15629.0
valid_f1_epoch,0.8316
valid_loss_epoch,0.5083
