Classification task


In [3]:
# !pip install einops
# !pip install --upgrade pytorch-lightning
# !pip install wandb


Imports

In [4]:
import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import CIFAR10

import pytorch_lightning as pl
import wandb


Data Module


In [5]:
def accuracy(preds, target):
    return (preds == target).float().mean()

class CIFAR10DataModule(pl.LightningDataModule):
    def __init__(self, batch_size, data_dir: str = '/data'):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        self.dims = (3, 32, 32)
        self.num_classes = 10

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def prepare_data(self):
        CIFAR10(self.data_dir, train=True, download=True)
        CIFAR10(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            cifar_full = CIFAR10(self.data_dir, train=True, transform=self.transform)
            self.train_dataset, self.val_dataset = random_split(cifar_full, [45000, 5000])
        if stage == 'test' or stage is None:
            self.test_dataset = CIFAR10(self.data_dir, train=False, transform=self.transform)


Lora Layer

In [6]:
class LinearWithLoRA(torch.nn.Module):
    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha
        )

    def forward(self, x):
        return self.linear(x) + self.lora(x)

class LoRALayer(torch.nn.Module):
    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        #this choice ensures that the initial values in A are not too large.
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        self.A = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev,requires_grad=True)
        #LoRALayer does not impact the original weights because AB=0 if B=0. in the begining of the training
        self.B = torch.nn.Parameter(torch.zeros(rank, out_dim),requires_grad=True)

        self.alpha = alpha

    def forward(self, x):
        x = self.alpha * (x @ self.A @ self.B)
        return x

In [7]:
# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class FeedForward(pl.LightningModule):
    def __init__(self, dim, hidden_dim, rank, alpha=32, dropout = 0.05):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            LinearWithLoRA(nn.Linear(dim, hidden_dim), rank, alpha),
            nn.GELU(),
            nn.Dropout(dropout),
            LinearWithLoRA(nn.Linear(hidden_dim, dim), rank, alpha),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class Attention(pl.LightningModule):
    def __init__(self, dim, heads = 8, dim_head = 64, rank=32, alpha=32, dropout = 0.05):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            LinearWithLoRA(nn.Linear(inner_dim, dim), rank, alpha),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)

        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(pl.LightningModule):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.05):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout),
                FeedForward(dim, mlp_dim, dropout = dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return self.norm(x)

class Transformer(pl.LightningModule):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, rank, alpha=32, dropout = 0.05):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, rank=rank, alpha=alpha, dropout=dropout),
                FeedForward(dim, mlp_dim, rank=rank, alpha=alpha, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)

class ViT(pl.LightningModule):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, rank, pool='cls', channels=3, dim_head=64, alpha=32, dropout = 0.05, emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            LinearWithLoRA(nn.Linear(patch_dim, dim), rank, alpha),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, rank, alpha, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = accuracy(preds, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True, logger=True)
        self.log('train_acc', acc, on_step=True, on_epoch=True, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer



#MAIN

In [15]:
v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.05,
    emb_dropout = 0.1,
    rank = 32,  # Rank for LoRALayer
    alpha = 32 # Alpha for LoRALayer
)
# Freeze all parameters except for the parameters of the LoRALayer
for name, param in v.named_parameters():

    if 'lora' in name and not name.startswith("mlp_head"):  # Modify this condition to target the LoRALayer parameters
        param.requires_grad = True
    else:
        param.requires_grad = False
# Print out parameter names and their requires_grad attributes
# for name, param in v.named_parameters():
#     print(name, param.requires_grad)
img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)


In [16]:
# Main
dm = CIFAR10DataModule(batch_size=32)
dm.prepare_data()
dm.setup(stage='fit')
dm.setup(stage='test')

# input_shape = (3, 32, 32)
# model = LitModel(input_shape=input_shape, num_classes=dm.num_classes)
model=v
#init wantb
wandb.init(project="Clean ViT")
wandb.watch(v)
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    monitor='val_loss',
    dirpath='/content/checkpoints',
    filename='cifar10-{epoch:02d}-{val_loss:.2f}',
    save_top_k=1,
    mode='min',
    save_weights_only=False  # Save entire model
)

#start training
trainer = pl.Trainer(max_epochs=50, logger=pl.loggers.WandbLogger(), callbacks=[checkpoint_callback])
trainer.fit(model, dm)
trainer.test()

Files already downloaded and verified
Files already downloaded and verified


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

INFO:pytorch_lightning.utilities.rank_zero:GPU available: False, used: False
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


Files already downloaded and verified
Files already downloaded and verified


INFO:pytorch_lightning.callbacks.model_summary:
  | Name               | Type        | Params
---------------------------------------------------
0 | to_patch_embedding | Sequential  | 3.3 M 
1 | dropout            | Dropout     | 0     
2 | transformer        | Transformer | 52.0 M
3 | to_latent          | Identity    | 0     
4 | mlp_head           | Linear      | 1.0 M 
  | other params       | n/a         | 67.6 K
---------------------------------------------------
1.7 M     Trainable params
54.6 M    Non-trainable params
56.3 M    Total params
225.337   Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.10/dist-packages/pytorch_lightning/callbacks/model_checkpoint.py:383: `ModelCheckpoint(monitor='val_loss')` could not find the monitored key in the returned metrics: ['train_loss', 'train_loss_step', 'train_acc', 'train_acc_step', 'train_loss_epoch', 'train_acc_epoch', 'epoch', 'step']. HINT: Did you call `log('val_loss', value)` in the `LightningModule`?
/usr/local/lib/python3.10/dist-packages/pytorch_lightning/trainer/call.py:54: Detected KeyboardInterrupt, attempting graceful shutdown...


ValueError: `.test(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.