In [112]:
import math
import torch 
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torch import nn
from dataclasses import dataclass
import torch.nn.functional as F

In [113]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x

In [114]:
all_classes = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
desired_classes = all_classes[0:8]
desired_indices = [all_classes.index(cls) for cls in desired_classes]
desired_indices

[0, 1, 2, 3, 4, 5, 6, 7]

In [115]:
DATA_DIR="../data"
def get_cifar10_data_loader():
    """
    Get the CIFAR10 data loader
    """
    # define the transform
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # get the training and testing datasets
    train_dataset = CIFAR10(root=DATA_DIR, train=True, transform=transform, download=True)
    test_dataset = CIFAR10(root=DATA_DIR, train=False, transform=transform, download=True)
    train_set, val_set = torch.utils.data.random_split(train_dataset, [45000, 5000])

    return train_set, val_set, test_dataset

In [99]:
train_set, val_set, test_set = get_cifar10_data_loader()

Files already downloaded and verified
Files already downloaded and verified


In [100]:
def filter_dataset(dataset):
    filtered_indices = [i  for i,  (_, label) in enumerate(dataset) if label in desired_indices]
    return torch.utils.data.Subset(dataset, filtered_indices)

In [101]:
train_set = filter_dataset(train_set)
val_set = filter_dataset(val_set)
test_set = filter_dataset(test_set)
print(f"len of train set {len(train_set)} val set {len(val_set)} test set {len(test_set)}")

len of train set 36029 val set 3971 test set 8000


In [116]:
@dataclass
class ModelArgs:
    dim:int =  256
    hidden_dim:int = 512
    n_heads:int = 8
    n_layers:int = 6
    patch_size:int = 4
    n_channels = 3
    n_patches = 64
    n_classes = 10
    dropout = 0.2

In [117]:
class MultiHeadAttention(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        
        self.wq = nn.Linear(self.dim, self.n_heads*self.head_dim, bias=False)
        self.wk = nn.Linear(self.dim, self.n_heads*self.head_dim, bias=False)
        self.wv = nn.Linear(self.dim, self.n_heads*self.head_dim, bias=False)
        self.wo = nn.Linear(self.n_heads*self.head_dim, self.dim, bias=False)
    
    def forward(self, x):
        b, seq_len, dim = x.shape
        
        assert dim == self.dim, "dim is not matching"
        q = self.wq(x)
        k = self.wk(x)
        v = self.wv(x)
        
        q = q.contiguous().view(b, seq_len, self.n_heads, self.head_dim)
        k = k.contiguous().view(b, seq_len, self.n_heads, self.head_dim)
        v = v.contiguous().view(b, seq_len, self.n_heads, self.head_dim)
        
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1,2)
        
        attn = torch.matmul(q, k. transpose(2, 3)) / math.sqrt(self.head_dim)
        attn_scores = F.softmax(attn, dim = -1)
        
        out = torch.matmul(attn_scores, v)
        out = out.contiguous().view(b, seq_len, -1)
        
        return self.wo(out)        

In [118]:
class AttentionBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.layer_norm_1 = nn.LayerNorm(args.dim)
        self.attn = MultiHeadAttention(args)
        
        self.layer_norm_2 = nn.LayerNorm(args.dim)
        
        self.ffn = nn.Sequential(
            nn.Linear(args.dim, args.hidden_dim),
            nn.GELU(),
            nn.Dropout(args.dropout),
            nn.Linear(args.hidden_dim, args.dim),
            nn.Dropout(args.dropout)
        )
    
    def forward(self, x):
        x = x + self.attn(self.layer_norm_1(x))
        x = x + self.ffn(self.layer_norm_2(x))
        return x

In [119]:
class VisionTransformer(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.patch_size = args.patch_size
        
        self.input_layer = nn.Linear(args.n_channels * (args.patch_size ** 2), args.dim)
        attn_blocks = []
        for _ in range(args.n_layers):
            attn_blocks.append(AttentionBlock(args))
        
        self.transformer = nn.Sequential(*attn_blocks)
        
        self.mlp = nn.Sequential(
            nn.LayerNorm(args.dim),
            nn.Linear(args.dim, args.n_classes)
        )
        
        self.dropout = nn.Dropout(args.dropout)
        
        self.cls_token = nn.Parameter(torch.randn(1, 1, args.dim))
        self.pos_embedding = nn.Parameter(torch.randn(1, 1+args.n_patches, args.dim))
    
    def forward(self, x):
        x = img_to_patch(x, self.patch_size)
        b, seq_len, _ = x.shape
        x = self.input_layer(x)
        
        cls_token = self.cls_token.repeat(b, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        
        x = x + self.pos_embedding[:,:seq_len+1]
        
        x = self.dropout(x)
        x = self.transformer(x)
        # print("========== x shape =====", x.shape)
        x = x.transpose(0, 1)
        cls = x[0]
        out = self.mlp(cls)
        return out

In [120]:
args = ModelArgs()
args.dim

256

In [121]:
# Model, Loss and Optimizer
device = "cuda:0" if torch.cuda.is_available() else 0
args = ModelArgs()
model = VisionTransformer(args).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 150], gamma=0.1)

In [122]:
batch_size=64
num_workers = 16
# get the data loaders
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=batch_size, shuffle=True,
                                           num_workers=num_workers, drop_last=True)
val_loader = torch.utils.data.DataLoader(dataset=val_set, batch_size=batch_size, shuffle=True,
                                         num_workers=num_workers, drop_last=False)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=batch_size, shuffle=False,
                                          num_workers=num_workers, drop_last=False)

In [123]:
num_epochs = 200  # example value, adjust as needed

for epoch in range(num_epochs):
    
    # Training Phase
    model.train()
    total_loss = 0.0
    for i, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        
        # Zero the parameter gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = model(inputs)
        # print("==== outputs shape ===", outputs.shape)
        loss = criterion(outputs, labels)
        
        # Backward pass and optimize
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_train_loss = total_loss / len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}")

    # Validation Phase
    model.eval()
    total_val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:  # Assuming val_loader is defined elsewhere
            inputs, labels = inputs.to(device), labels.to(device)

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            total_val_loss += loss.item()

            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

    avg_val_loss = total_val_loss / len(val_loader)
    val_accuracy = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.2f}%")

    # Update the learning rate
    lr_scheduler.step()

print("Training complete!")


Epoch [1/200], Training Loss: 1.6599
Epoch [1/200], Validation Loss: 1.5112, Validation Accuracy: 43.21%
Epoch [2/200], Training Loss: 1.3780
Epoch [2/200], Validation Loss: 1.2998, Validation Accuracy: 51.67%
Epoch [3/200], Training Loss: 1.2626
Epoch [3/200], Validation Loss: 1.2616, Validation Accuracy: 55.20%
Epoch [4/200], Training Loss: 1.1883
Epoch [4/200], Validation Loss: 1.2068, Validation Accuracy: 56.06%
Epoch [5/200], Training Loss: 1.1228
Epoch [5/200], Validation Loss: 1.2128, Validation Accuracy: 58.47%
Epoch [6/200], Training Loss: 1.0702
Epoch [6/200], Validation Loss: 1.1508, Validation Accuracy: 58.93%
Epoch [7/200], Training Loss: 1.0143
Epoch [7/200], Validation Loss: 1.1241, Validation Accuracy: 59.28%
Epoch [8/200], Training Loss: 0.9619
Epoch [8/200], Validation Loss: 1.1036, Validation Accuracy: 59.81%
Epoch [9/200], Training Loss: 0.9168
Epoch [9/200], Validation Loss: 1.1206, Validation Accuracy: 60.09%
Epoch [10/200], Training Loss: 0.8715
Epoch [10/200], V

KeyboardInterrupt: 