In [50]:
import os

import lightning as L
import torch
import torchmetrics
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import models, transforms, datasets

from models.v1block import V1Block

In [51]:
class LightningTrainer(L.LightningModule):
    """Lightning wrapper that trains and validates a multiclass classifier.

    Args:
        model: Module called on each input batch; its output is passed to
            ``loss_fn`` and to the accuracy metrics.
        loss_fn: Callable invoked as ``loss_fn(predictions, targets)``.
        num_classes: Number of classes for the accuracy metrics (default 10).
        lr: Learning rate handed to Adam (default 1e-5).
    """

    def __init__(self, model, loss_fn, num_classes=10, lr=1e-5):
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.lr = lr
        # Separate metric objects so training and validation state never mix.
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the loss for one training batch and log loss and accuracy.

        Returns:
            The loss value used by Lightning for the backward pass.
        """
        x, y = batch
        # Call the module itself (not .forward) so registered hooks run.
        yhat = self.model(x)
        loss = self.loss_fn(yhat, y)
        self.log("train_loss", loss)
        # Update the metric, then log the Metric object so Lightning computes
        # the epoch value from accumulated state and resets it each epoch.
        self.train_acc(yhat, y)
        self.log("train_acc", self.train_acc, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log loss and accuracy for one validation batch."""
        x, y = batch
        yhat = self.model(x)
        self.log("val_loss", self.loss_fn(yhat, y))
        self.val_acc(yhat, y)
        self.log("val_acc", self.val_acc, on_epoch=True)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        return optim.Adam(self.model.parameters(), lr=self.lr)

In [52]:
class MiniV1Net(nn.Module):
    """Small convolutional classifier with a V1Block front end.

    The feature extractor is a ``V1Block`` followed by two stride-2
    3x3 convolutions, each stage followed by batch norm and ReLU. A single
    linear layer maps the flattened features to class scores.

    Args:
        num_classes: Number of outputs of the final linear layer (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.layers = nn.Sequential(
            V1Block(3, 64, image_size=32),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
        )
        # The head expects 256 channels on a 2x2 grid.
        # NOTE(review): this presumes V1Block emits an 8x8 map for 32x32
        # inputs (each stride-2 conv then halves it) — confirm against V1Block.
        self.output_layer = nn.Linear(256 * 2 * 2, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, num_classes)`` for ``x``."""
        x = self.layers(x)
        # flatten() keeps the batch axis and, unlike view(), also works when
        # the activations are not contiguous in memory.
        return self.output_layer(x.flatten(start_dim=1))

In [59]:
# Dataset location: set the CIFAR10_ROOT environment variable to override the
# default, so the notebook is not tied to one user's home directory.
DATA_ROOT = os.environ.get("CIFAR10_ROOT", "/Users/gursi/Desktop/data")
BATCH_SIZE = 32
NUM_WORKERS = 9

# Both splits use the same image-to-tensor conversion.
to_tensor = transforms.ToTensor()

train_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=True, download=True, transform=to_tensor
)
test_dataset = datasets.CIFAR10(
    root=DATA_ROOT, train=False, download=True, transform=to_tensor
)

# persistent_workers is only valid when worker processes are used, so tie it
# to NUM_WORKERS; setting NUM_WORKERS = 0 then works without further edits.
loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    persistent_workers=NUM_WORKERS > 0,
)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# The test split is not shuffled; a later cell passes it as validation data.
test_loader = DataLoader(test_dataset, **loader_kwargs)

model = MiniV1Net()

Files already downloaded and verified
Files already downloaded and verified


In [60]:
# Pair the network with a cross-entropy loss inside the Lightning module.
training_module = LightningTrainer(
    model=model,
    loss_fn=nn.CrossEntropyLoss(),
)

# "auto" lets Lightning choose the accelerator present on this machine, so the
# notebook also runs where Apple's MPS backend is unavailable.
trainer = L.Trainer(
    max_epochs=100,
    accelerator="auto",
)

# The CIFAR-10 test split doubles as the validation set here.
trainer.fit(
    model=training_module,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params
-------------------------------------------------
0 | model     | MiniV1Net          | 1.4 M 
1 | loss_fn   | CrossEntropyLoss   | 0     
2 | train_acc | MulticlassAccuracy | 0     
3 | val_acc   | MulticlassAccuracy | 0     
-------------------------------------------------
396 K     Trainable params
960 K     Non-trainable params
1.4 M     Total params
5.426     Total estimated model params size (MB)


Epoch 30:  26%|██▌       | 401/1563 [00:14<00:42, 27.47it/s, v_num=5]      

/Users/gursi/miniforge3/envs/ml/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


In [65]:
# Shape check for self-attention: 256-dim embeddings, 8 heads, batch-first input.
layer = nn.MultiheadAttention(256, 8, batch_first=True)
x = torch.randn(32, 1000, 256)

# Call the module (not .forward) so hooks run. Only the output shape is
# inspected, so skip gradient tracking and the attention-weight tensor.
with torch.no_grad():
    attn_out, _ = layer(x, x, x, need_weights=False)
attn_out.shape

torch.Size([32, 1000, 256])