## Using Gradient Checkpointing in Models with MLflow Experiment Tracking

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from torchinfo import summary
from torchmetrics import Accuracy

import mlflow
import numpy as np

In [None]:
# Configure the MLflow client: first the tracking server that receives the
# data, then the experiment under which subsequent runs are recorded.
mlflow.set_tracking_uri(uri="http://127.0.0.1:5000")
mlflow.set_experiment(experiment_name="check-localhost-connection")

In [None]:
class CustomLoss(torch.autograd.Function):
    """Cross-entropy on predicted probabilities with a hand-written backward pass.

    The loss is ``sum(-y_true * log(clamp(y_pred))) / N`` where ``N`` is
    ``y_true.shape[0]``. Predictions are clamped to ``[_EPS, 1 - _EPS]`` so
    the logarithm never sees 0.
    """

    # Clamp bound shared by forward and backward so both describe the same function.
    _EPS = 1e-7

    @staticmethod
    def forward(ctx, y_pred, y_true):
        """Compute the scalar loss.

        Args:
            ctx: Autograd context used to stash tensors for ``backward``.
            y_pred: Predicted probabilities.
            y_true: Target weights, same shape as ``y_pred``.

        Returns:
            A 0-dim tensor holding the loss.
        """
        ctx.save_for_backward(y_pred, y_true)
        y_pred_clipped = torch.clamp(y_pred, CustomLoss._EPS, 1 - CustomLoss._EPS)
        return torch.sum(-y_true * torch.log(y_pred_clipped)) / y_true.shape[0]

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradient of the loss with respect to ``y_pred``.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Upstream gradient of the scalar loss.

        Returns:
            ``(grad_y_pred, None)``; no gradient is produced for ``y_true``.
        """
        y_pred, y_true = ctx.saved_tensors
        eps = CustomLoss._EPS
        n = y_true.shape[0]
        # Differentiate what forward actually computed: inside the clamp range
        # the derivative is -y_true / y_pred / N; outside it the clamp output
        # is constant, so the gradient is 0. Dividing by the clamped value
        # also keeps the discarded branch free of division by zero.
        y_pred_clipped = torch.clamp(y_pred, eps, 1 - eps)
        inside = (y_pred >= eps) & (y_pred <= 1 - eps)
        dy_pred = torch.where(
            inside,
            -y_true / y_pred_clipped / n,
            torch.zeros_like(y_pred),
        )
        return grad_output * dy_pred, None

In [None]:
# Validate CustomLoss.backward with torch.autograd.gradcheck.
# Both inputs are double precision; only A requires a gradient, because
# backward returns None for the target argument.
A = torch.rand(3, 3, dtype=torch.double).requires_grad_(True)
B = torch.rand(3, 3, dtype=torch.double)

crossentropyloss = CustomLoss.apply

torch.autograd.gradcheck(crossentropyloss, (A, B))

In [None]:
from torchvision.datasets import CIFAR10
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Transform applied to every image of both splits.
to_tensor = transforms.ToTensor()

# Download/cache location, relative to the working directory. A plain
# forward-slash path avoids the invalid "\d" escape sequence and behaves the
# same on every platform.
DATA_ROOT = "./data"

train_dataset = CIFAR10(root=DATA_ROOT, download=True, train=True, transform=to_tensor)
test_dataset = CIFAR10(root=DATA_ROOT, download=True, train=False, transform=to_tensor)

# Only the training split is shuffled; the test loader keeps dataset order.
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, pin_memory=True)

In [None]:
class CIFAR10_Model(nn.Module):
    """Small CNN classifier: two conv blocks followed by three linear layers.

    The first linear layer expects ``64 * 7 * 7`` features, which is what the
    two conv/pool blocks produce for 3x32x32 inputs. The second conv block is
    run through ``torch.utils.checkpoint.checkpoint`` so its intermediate
    activations are recomputed during the backward pass instead of stored.
    """

    def __init__(self):
        super().__init__()
        # 3x32x32 -> 32x30x30 (unpadded 3x3 conv) -> 64x30x30 -> 64x15x15 (pool).
        self.cnn_block_1 = nn.Sequential(
            nn.Conv2d(3, 32, 3),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_1 = nn.Dropout(0.25)
        # 64x15x15 -> 64x15x15 -> 64x15x15 -> 64x7x7 (pool floors 15 / 2).
        self.cnn_block_2 = nn.Sequential(
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.dropout_2 = nn.Dropout(0.25)
        # nn.Flatten (instead of a lambda attribute) keeps the module
        # picklable with the standard pickler and visible in model summaries.
        self.flatten = nn.Flatten(start_dim=1)
        self.linearize = nn.Sequential(
            nn.Linear(64 * 7 * 7, 512),
            nn.ReLU(),
        )
        self.dropout_3 = nn.Dropout(0.5)
        self.linear1 = nn.Linear(512, 84)
        self.out = nn.Linear(84, 10)

    def forward(self, img):
        """Return class logits for a batch of images.

        Args:
            img: Image batch; the layer sizes assume shape ``(N, 3, 32, 32)``.

        Returns:
            Tensor of shape ``(N, 10)`` with unnormalised class scores.
        """
        x = self.cnn_block_1(img)
        x = self.dropout_1(x)
        # Non-reentrant checkpointing is the variant recommended by PyTorch.
        x = checkpoint(self.cnn_block_2, x, use_reentrant=False)
        x = self.dropout_2(x)
        x = self.flatten(x)
        x = self.linearize(x)
        x = self.dropout_3(x)
        x = F.relu(self.linear1(x))
        x = self.out(x)
        return x

In [None]:
# Use the GPU when CUDA is available to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the network and move its parameters to the selected device.
model = CIFAR10_Model()
model = model.to(device)

In [None]:
# Scalar hyperparameters.
epochs = 10
batch_size = 64
lr = 1e-3

# Loss, metric and optimizer objects; the metric is moved to `device`.
loss_fn = nn.CrossEntropyLoss()
metric_fn = Accuracy("multiclass", num_classes=10)
metric_fn = metric_fn.to(device)
optim = torch.optim.AdamW(params=model.parameters(), lr=lr)

def train(dataloader, model, loss_fn, metrics_fn, optimizer, epoch, log_interval=100):
    """Train ``model`` for one epoch, logging loss and accuracy to MLflow.

    Args:
        dataloader: Iterable of ``(X, y)`` batches that supports ``len()``.
        model: Module to optimise; switched to training mode.
        loss_fn: Callable ``(pred, y) -> scalar loss tensor``.
        metrics_fn: Callable ``(pred, y) -> scalar accuracy``.
        optimizer: Optimizer that updates ``model``'s parameters.
        epoch: Zero-based epoch index, used to compute the MLflow step.
        log_interval: Log and print every this many batches.

    Batches are moved to the module-level ``device`` before the forward pass.
    """
    model.train()
    num_batches = len(dataloader)
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)
        accuracy = metrics_fn(pred, y)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        if batch % log_interval == 0:
            # Log plain floats rather than formatted strings.
            loss_value = loss.item()
            accuracy_value = float(accuracy)
            # Global batch index: strictly increasing across epochs, so
            # points from different epochs never land on the same step.
            step = epoch * num_batches + batch
            mlflow.log_metric("Loss", loss_value, step=step)
            mlflow.log_metric("Accuracy", accuracy_value, step=step)
            print(f"loss: {loss_value:2f} accuracy: {accuracy_value:2f} [{batch} / {num_batches}]")


def evaluate(dataloader, model, loss_fn, metrics_fn, epoch):
    """Evaluate ``model`` on ``dataloader`` and log the averages to MLflow.

    Args:
        dataloader: Iterable of ``(X, y)`` batches.
        model: Module to evaluate; switched to eval mode.
        loss_fn: Callable ``(pred, y) -> scalar loss tensor``; assumed to
            return the mean over the batch.
        metrics_fn: Callable ``(pred, y) -> scalar accuracy``; assumed to
            return the mean over the batch.
        epoch: Step under which ``eval_loss`` / ``eval_accuracy`` are logged.

    Batches are moved to the module-level ``device``. Per-batch means are
    weighted by batch size so a shorter final batch does not skew the result.
    """
    model.eval()
    total_loss, total_accuracy, num_samples = 0.0, 0.0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            batch_len = y.shape[0]
            total_loss += loss_fn(pred, y).item() * batch_len
            total_accuracy += float(metrics_fn(pred, y)) * batch_len
            num_samples += batch_len
    eval_loss = total_loss / num_samples
    eval_accuracy = total_accuracy / num_samples
    # Log plain floats rather than formatted strings.
    mlflow.log_metric("eval_loss", eval_loss, step=epoch)
    mlflow.log_metric("eval_accuracy", eval_accuracy, step=epoch)

    print(f"Eval metrics: \nAccuracy: {eval_accuracy:.2f}, Avg loss: {eval_loss:2f} \n")

In [None]:
# One MLflow run: log the configuration and a model summary, train and
# evaluate for `epochs` epochs, then log the trained model.
with mlflow.start_run() as run:
    # Read the values from the objects actually in use so the logged
    # parameters cannot drift from the real configuration.
    params = {
        "epochs": epochs,
        "learning_rate": lr,
        "batch_size": train_dataloader.batch_size,
        "loss_function": loss_fn.__class__.__name__,
        "metric_function": metric_fn.__class__.__name__,
        "optimizer": optim.__class__.__name__,
    }

    mlflow.log_params(params)

    # Explicit UTF-8 so the summary text is written the same way regardless
    # of the platform's default encoding.
    with open("model_summary.txt", "w", encoding="utf-8") as f:
        f.write(str(summary(model)))
    mlflow.log_artifact("model_summary.txt")

    for t in range(epochs):
        print(f"Epoch {t+1}\n----------------------------")
        train(train_dataloader, model, loss_fn, metric_fn, optim, epoch=t)
        # Pass the current epoch so each evaluation gets its own step
        # instead of every epoch being logged at step 0.
        evaluate(test_dataloader, model, loss_fn, metric_fn, epoch=t)

    mlflow.pytorch.log_model(model, "model")