In [1]:
import torch
from torch.utils.data import random_split, DataLoader
import torchvision.transforms as transforms
import torchvision

transform = transforms.Compose([
    transforms.RandomRotation(degrees=15),  
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), 
    transforms.RandomResizedCrop(size=(32, 32), scale=(0.8, 1.0)),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),  
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)) 
])

train_ds = torchvision.datasets.CIFAR100("/home/eagle/Projects/dl_from_scratch/cifar10", train=True, download=True, transform=transform)
test_ds = torchvision.datasets.CIFAR100("/home/eagle/Projects/dl_from_scratch/cifar10", train=False, download=True, transform=transform)

train_size = int(0.8 * len(train_ds))  # 80% for training
val_size = len(train_ds) - train_size  # 20% for validation

# Split the train_dataset into train and val
train_ds, val_ds = random_split(train_ds, [train_size, val_size])

batch_size = 512
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=6)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True, num_workers=6)

Files already downloaded and verified
Files already downloaded and verified


In [2]:
import torch
import torch.nn as nn
import math
from pytorch_model_summary import summary

class FeedForwardBlock(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.linear_2(self.dropout(self.gelu(self.linear_1(x))))

class MultiHeadAttentionBlock(nn.Module):
    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.h = h
        assert d_model % h == 0, "d_model is not appropriate size"

        self.d_k = d_model // h
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]

        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            attention_scores.masked_fill(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim = -1)
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        return (attention_scores @ value), attention_scores
        
    def forward(self, q, v, k, mask):
        query = self.w_q(q)
        key = self.w_k(k)
        value = self.w_v(v)

        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        x, self_attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)
        return self.w_o(x)

class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, h, d_ff, dropout):
        super().__init__()
        self.attention = MultiHeadAttentionBlock(d_model, h, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.ff = FeedForwardBlock(d_model, d_ff, dropout)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        x = x + self.attention(self.norm1(x), self.norm1(x), self.norm1(x), mask)
        x = x + self.ff(self.norm2(x))
        return x

class DEIT(nn.Module):
    def __init__(self, img_size=224, patch_size=16, d_model=192, num_classes=1000, num_layers=12, h=3, d_ff=768, dropout=0.1):
        super().__init__()
        self.patch_size = patch_size
        num_patches = (img_size // patch_size) ** 2
        self.d_model = d_model

        self.patch_embed = nn.Conv2d(3, d_model, kernel_size=patch_size, stride=patch_size)

        self.class_token = nn.Parameter(torch.randn(1, 1, d_model))
        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + 1, d_model))
        self.dropout = nn.Dropout(dropout)

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, h, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        batch_size = x.size(0)

        x = self.patch_embed(x).flatten(2).transpose(1, 2)

        class_token = self.class_token.expand(batch_size, -1, -1)
        x = torch.cat((class_token, x), dim=1)

        x = x + self.pos_embed
        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        x = self.norm(x[:, 0])  # Only use the class token
        return self.classifier(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = DEIT(img_size=32, patch_size=4, d_model=192, num_classes=100, num_layers=4, h=3, d_ff=768, dropout=0.1).to(device)
print(summary(model, torch.zeros(1, 3, 32,32).to(device), show_input=False, show_hierarchical=False))

---------------------------------------------------------------------------------
                Layer (type)        Output Shape         Param #     Tr. Param #
                    Conv2d-1      [1, 192, 8, 8]           9,408           9,408
                   Dropout-2        [1, 65, 192]               0               0
   TransformerEncoderLayer-3        [1, 65, 192]         341,684         341,684
   TransformerEncoderLayer-4        [1, 65, 192]         341,684         341,684
   TransformerEncoderLayer-5        [1, 65, 192]         341,684         341,684
   TransformerEncoderLayer-6        [1, 65, 192]         341,684         341,684
                 LayerNorm-7            [1, 192]             384             384
                    Linear-8            [1, 100]          19,300          19,300
Total params: 1,395,828
Trainable params: 1,395,828
Non-trainable params: 0
---------------------------------------------------------------------------------


In [3]:
from tqdm import tqdm
import torch.optim as optim

loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.01)
scaler = torch.cuda.amp.GradScaler()

num_epochs = 100

for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    
    # Add tqdm to the training loop
    train_loader_tqdm = tqdm(train_loader, desc=f"Epoch [{epoch + 1}/{num_epochs}] - Training", leave=False)
    
    for inputs, labels in train_loader_tqdm:
        inputs, labels = inputs.to(device), labels.to(device)  
        optimizer.zero_grad()

        # Forward pass with mixed precision
        with torch.cuda.amp.autocast():  
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)

        # Backward pass
        scaler.scale(loss).backward()  # Scale the loss for stable gradients
        scaler.step(optimizer)  # Update the parameters
        scaler.update()  # Update the scaler

        running_loss += loss.item()   
        
        # Update tqdm description with running loss
        train_loader_tqdm.set_postfix(loss=loss.item())

    # Calculate average loss for the epoch
    avg_loss = running_loss / len(train_loader)

    # Validation phase
    model.eval()  # Set the model to evaluation mode
    running_val_loss = 0.0
    correct = 0
    total = 0

    # Add tqdm to the validation loop
    val_loader_tqdm = tqdm(val_loader, desc="Validation", leave=False)

    with torch.no_grad():  # Disable gradient calculation
        for inputs, labels in val_loader_tqdm:  # Assuming val_loader is defined
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.cuda.amp.autocast():  # Enable autocasting for validation
                outputs = model(inputs)
                loss = loss_fn(outputs, labels)
                
            running_val_loss += loss.item()

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Update tqdm description with running validation loss
            val_loader_tqdm.set_postfix(val_loss=loss.item())
    
    # Calculate average validation loss and accuracy
    avg_val_loss = running_val_loss / len(val_loader)
    accuracy = correct / total * 100  # Convert to percentage

    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Training Loss: {avg_loss:.4f}, "
          f"Validation Loss: {avg_val_loss:.4f}, "
          f"Validation Accuracy: {accuracy:.2f}%")

print("Training complete.")

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():  # Enable autocasting for validation
                                                                                

Epoch [1/100], Training Loss: 4.4306, Validation Loss: 4.2432, Validation Accuracy: 4.87%


                                                                                

Epoch [2/100], Training Loss: 4.1112, Validation Loss: 3.9193, Validation Accuracy: 10.36%


                                                                                

Epoch [3/100], Training Loss: 3.8542, Validation Loss: 3.7122, Validation Accuracy: 13.75%


                                                                                

Epoch [4/100], Training Loss: 3.6964, Validation Loss: 3.5741, Validation Accuracy: 15.87%


                                                                                

Epoch [5/100], Training Loss: 3.5691, Validation Loss: 3.4694, Validation Accuracy: 16.81%


                                                                                

Epoch [6/100], Training Loss: 3.4662, Validation Loss: 3.3724, Validation Accuracy: 18.48%


                                                                                

Epoch [7/100], Training Loss: 3.3794, Validation Loss: 3.2944, Validation Accuracy: 20.44%


                                                                                

Epoch [8/100], Training Loss: 3.3058, Validation Loss: 3.2364, Validation Accuracy: 20.93%


                                                                                

Epoch [9/100], Training Loss: 3.2428, Validation Loss: 3.1891, Validation Accuracy: 22.46%


                                                                                

Epoch [10/100], Training Loss: 3.1788, Validation Loss: 3.1202, Validation Accuracy: 23.77%


                                                                                

Epoch [11/100], Training Loss: 3.1254, Validation Loss: 3.1185, Validation Accuracy: 23.27%


                                                                                

Epoch [12/100], Training Loss: 3.0837, Validation Loss: 3.0176, Validation Accuracy: 25.51%


                                                                                

Epoch [13/100], Training Loss: 3.0370, Validation Loss: 3.0091, Validation Accuracy: 25.20%


                                                                                

Epoch [14/100], Training Loss: 2.9970, Validation Loss: 2.9757, Validation Accuracy: 26.20%


                                                                                

Epoch [15/100], Training Loss: 2.9596, Validation Loss: 2.9391, Validation Accuracy: 27.33%


                                                                                

Epoch [16/100], Training Loss: 2.9241, Validation Loss: 2.9250, Validation Accuracy: 27.57%


                                                                                

Epoch [17/100], Training Loss: 2.8833, Validation Loss: 2.8949, Validation Accuracy: 27.67%


                                                                                

Epoch [18/100], Training Loss: 2.8562, Validation Loss: 2.8545, Validation Accuracy: 28.85%


                                                                                

Epoch [19/100], Training Loss: 2.8208, Validation Loss: 2.8431, Validation Accuracy: 28.71%


                                                                                

Epoch [20/100], Training Loss: 2.7932, Validation Loss: 2.8201, Validation Accuracy: 29.15%


                                                                                

Epoch [21/100], Training Loss: 2.7704, Validation Loss: 2.8044, Validation Accuracy: 30.22%


                                                                                

Epoch [22/100], Training Loss: 2.7427, Validation Loss: 2.7887, Validation Accuracy: 30.10%


                                                                                

Epoch [23/100], Training Loss: 2.7276, Validation Loss: 2.7508, Validation Accuracy: 30.78%


                                                                                

Epoch [24/100], Training Loss: 2.6950, Validation Loss: 2.7499, Validation Accuracy: 30.36%


                                                                                

Epoch [25/100], Training Loss: 2.6673, Validation Loss: 2.7370, Validation Accuracy: 31.01%


                                                                                

Epoch [26/100], Training Loss: 2.6381, Validation Loss: 2.7259, Validation Accuracy: 31.25%


                                                                                

Epoch [27/100], Training Loss: 2.6163, Validation Loss: 2.7024, Validation Accuracy: 31.64%


                                                                                

Epoch [28/100], Training Loss: 2.6054, Validation Loss: 2.6784, Validation Accuracy: 32.34%


                                                                                

Epoch [29/100], Training Loss: 2.5718, Validation Loss: 2.6699, Validation Accuracy: 32.61%


                                                                                

Epoch [30/100], Training Loss: 2.5569, Validation Loss: 2.6556, Validation Accuracy: 32.66%


                                                                                

Epoch [31/100], Training Loss: 2.5274, Validation Loss: 2.6447, Validation Accuracy: 33.52%


                                                                                

Epoch [32/100], Training Loss: 2.5127, Validation Loss: 2.6305, Validation Accuracy: 33.33%


                                                                                

Epoch [33/100], Training Loss: 2.4874, Validation Loss: 2.6166, Validation Accuracy: 33.66%


                                                                                

Epoch [34/100], Training Loss: 2.4712, Validation Loss: 2.6186, Validation Accuracy: 33.56%


                                                                                

Epoch [35/100], Training Loss: 2.4462, Validation Loss: 2.6092, Validation Accuracy: 34.29%


                                                                                

KeyboardInterrupt: 

In [None]:
model.eval()  # Set the model to evaluation mode
running_test_loss = 0.0
correct = 0
total = 0

# Add tqdm to the test loop
test_loader_tqdm = tqdm(test_loader, desc="Testing", leave=False)

with torch.no_grad():  # Disable gradient calculation
    for inputs, labels in test_loader_tqdm:  # Assuming test_loader is defined
        inputs, labels = inputs.to(device), labels.to(device)

        with torch.cuda.amp.autocast():  # Enable autocasting for testing
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
        
        running_test_loss += loss.item()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update tqdm description with running test loss
        test_loader_tqdm.set_postfix(test_loss=loss.item())

# Calculate average test loss and accuracy
avg_test_loss = running_test_loss / len(test_loader)
test_accuracy = correct / total * 100  # Convert to percentage

print(f"Test Loss: {avg_test_loss:.4f}, "
      f"Test Accuracy: {test_accuracy:.2f}%")

print("Testing complete.")