# Imports

In [None]:
import wandb
from kaggle_secrets import UserSecretsClient
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
from torchvision import datasets, transforms, utils
import os
import random
from collections import defaultdict
from PIL import Image
from PIL import ImageFile
# Ask PIL to tolerate truncated image files when loading them.
ImageFile.LOAD_TRUNCATED_IMAGES = True


# Fetch the Weights & Biases API key from the Kaggle secret store
# (stored under the name "wandb-api-key").
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("wandb-api-key")

# Root folder of the dataset; IntelImageDataset appends
# 'seg_train/seg_train' or 'seg_test/seg_test' to it.
root_dir = '/kaggle/input/intel-image-classification'

# Authenticate with wandb before any run is created.
wandb.login(key=wandb_api_key)

# Hyperparameters

In [None]:
# Number of samples per mini-batch (used by both data loaders).
batch_size = 32
# Default number of training epochs for train().
n_epochs = 500
# Learning rate handed to the Adam optimizer in train().
learning_rate = 1e-3
# Probability used by the nn.Dropout layers in attention and feed-forward.
dropout = 0.2
# Maximum number of images loaded per class folder.
limit_per_class = 0  # 0 to disable
# Seed for the shuffle of the dataset file listing.
seed = 42
# A checkpoint is written whenever `epoch % save_every == 0`.
save_every = 1
# Patches per image side: each image is cut into n_patches ** 2 patches.
n_patches = 32
# Width of every token embedding inside the transformer.
n_embedding = 128

# WandB

In [None]:
# Name recorded in the run configuration under the "model" key.
model_name = "ViT-IntelImage"

# Start a wandb run in the "deep-learning" project and record the
# hyperparameters defined above as the run configuration.
wandb.init(
    project="deep-learning",
    config={
        "model": model_name,
        "batch_size": batch_size,
        "n_epochs": n_epochs,
        "n_patches": n_patches,
        "n_embedding": n_embedding,
        "learning_rate": learning_rate,
        "dropout": dropout,
        "seed": seed,
        "limit_per_class": limit_per_class,
    }
)

# IntelImageDataset

In [None]:
class IntelImageDataset(data.Dataset):
    """Image-folder dataset for the Intel Image Classification data.

    The split directory (``seg_train/seg_train`` or ``seg_test/seg_test``
    under ``root_dir``) is expected to contain one sub-folder per class.
    Every readable image in a class folder becomes one sample whose label
    is the index of that folder in alphabetical order.

    Args:
        root_dir: dataset root folder.
        train: ``True`` for the training split (random augmentation is
            applied), ``False`` for the test split (deterministic
            preprocessing only).
        seed: seed used to shuffle the sample listing.
        limit_per_class: maximum number of images kept per class;
            ``0`` disables the limit.

    Attributes:
        image_paths: sequence of image file paths (shuffled).
        labels: sequence of integer labels aligned with ``image_paths``.
        classes: class folder names, indexed by label.
        number_of_classes: number of class folders found.
    """

    def __init__(self, root_dir=root_dir, train=True, seed=seed, limit_per_class=limit_per_class):
        super().__init__()

        # params dataset
        self.root_dir = root_dir
        self.train = train
        self.limit_per_class = limit_per_class

        # data and labels
        self.image_paths = []
        self.labels = []

        self.number_of_classes = 0
        self.classes = []

        self.transform = self._build_transform(train)

        data_dir = os.path.join(root_dir, 'seg_train/seg_train' if train else 'seg_test/seg_test')

        # os.listdir() returns entries in arbitrary order. Sorting guarantees
        # that a class name maps to the same label index in every instance
        # (in particular in both the train and the test dataset).
        for class_name in sorted(os.listdir(data_dir)):
            class_dir = os.path.join(data_dir, class_name)
            if not os.path.isdir(class_dir):
                continue  # Skip if not a directory

            count = 0
            for filename in sorted(os.listdir(class_dir)):
                if limit_per_class != 0 and count >= limit_per_class:
                    break

                img_path = os.path.join(class_dir, filename)
                try:
                    # The context manager closes the file handle after the
                    # integrity check instead of leaking it.
                    with Image.open(img_path) as candidate:
                        candidate.verify()
                except (IOError, SyntaxError):
                    print(
                        'Corrupted image or non-image file detected and skipped:', filename)
                    continue

                self.image_paths.append(img_path)
                self.labels.append(self.number_of_classes)
                count += 1

            self.number_of_classes += 1
            self.classes.append(class_name)

        # Shuffle paths and labels together so they stay aligned.
        random.seed(seed)
        combined = list(zip(self.image_paths, self.labels))
        random.shuffle(combined)
        # zip(*[]) yields nothing to unpack, so handle the empty dataset explicitly.
        self.image_paths, self.labels = zip(*combined) if combined else ((), ())

    @staticmethod
    def _build_transform(train):
        """Return the preprocessing pipeline for the requested split.

        Random augmentation is only applied to the training split so that
        evaluation on the test split is deterministic.
        """
        resize = [transforms.Resize((256, 256))]
        to_normalized_tensor = [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        if not train:
            return transforms.Compose(resize + to_normalized_tensor)

        augmentation = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(degrees=30),  # degrees = range of rotation
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # parameters are ranges
            transforms.RandomGrayscale(p=0.1),  # p = probability of applying the transform
        ]
        return transforms.Compose(resize + augmentation + to_normalized_tensor)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load, convert to RGB and transform the sample at ``index``.

        Returns:
            ``(image, label)`` where ``image`` is the output of
            ``self.transform`` and ``label`` is the integer class index.

        Raises:
            RuntimeError: if the image file cannot be opened or converted.
        """
        image_path = self.image_paths[index]
        label = self.labels[index]

        try:
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
        except (IOError, SyntaxError) as err:
            # Fail with the offending path instead of continuing with an
            # unbound `image` variable.
            raise RuntimeError(
                f"Failed to load image as RGB: {image_path} (label {label})"
            ) from err

        return self.transform(image), label

# ViT

In [None]:
# Use the GPU when PyTorch reports one as available, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(device)

In [None]:
def patchify(images, n_patches):
    """Cut a batch of images into flattened, non-overlapping patches.

    Args:
        images: tensor of shape (N, C, H, W).
        n_patches: number of patches along each image side. Both H and W
            must be divisible by it.

    Returns:
        Tensor of shape (N, n_patches ** 2, C * (H // n_patches) * (W // n_patches))
        in the default floating dtype, located on the same device as
        ``images``. Patches are ordered row by row (index ``i * n_patches + j``
        for grid cell ``(i, j)``) and each patch is flattened in
        (channel, row, column) order.

    Raises:
        ValueError: if ``n_patches`` is not positive or does not divide the
            image height and width.
    """
    n, c, h, w = images.shape
    if n_patches <= 0 or h % n_patches != 0 or w % n_patches != 0:
        raise ValueError(
            f"Image size {h}x{w} cannot be split into {n_patches}x{n_patches} equal patches"
        )

    patch_h = h // n_patches
    patch_w = w // n_patches

    # (N, C, H, W) -> (N, C, rows, patch_h, cols, patch_w)
    grid = images.reshape(n, c, n_patches, patch_h, n_patches, patch_w)
    # Bring the grid coordinates forward: (N, rows, cols, C, patch_h, patch_w)
    grid = grid.permute(0, 2, 4, 1, 3, 5)
    patches = grid.reshape(n, n_patches ** 2, c * patch_h * patch_w)

    # The result is returned in the default floating dtype regardless of the
    # input dtype, so downstream float layers can consume it directly.
    return patches.to(torch.get_default_dtype())

In [None]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention where each head sees one slice of the token.

    Head ``i`` operates only on features ``[i * head_size, (i + 1) * head_size)``
    of every token, with its own query/key/value linear maps, and the head
    outputs are concatenated along the feature dimension.

    Args:
        n_heads: number of attention heads.
        head_size: number of features handled by each head.
        n_patches: unused; kept so the constructor signature stays stable.
        n_embedding: unused; kept so the constructor signature stays stable.
        dropout_p: dropout probability. When ``None`` (the default) the
            module-level ``dropout`` hyperparameter is used.
    """

    def __init__(self, n_heads, head_size, n_patches, n_embedding, dropout_p=None):
        super(MultiHeadAttention, self).__init__()
        if dropout_p is None:
            # Fall back to the module-level hyperparameter.
            dropout_p = dropout

        self.n_heads = n_heads
        self.head_size = head_size

        # Submodule names and creation order are kept unchanged so that
        # seeded initialisation and saved state dicts remain compatible.
        self.q_mappings = nn.ModuleList([nn.Linear(self.head_size, self.head_size) for _ in range(self.n_heads)])
        self.k_mappings = nn.ModuleList([nn.Linear(self.head_size, self.head_size) for _ in range(self.n_heads)])
        self.v_mappings = nn.ModuleList([nn.Linear(self.head_size, self.head_size) for _ in range(self.n_heads)])
        self.head_dropout = nn.ModuleList([nn.Dropout(dropout_p) for _ in range(self.n_heads)])

        self.softmax = nn.Softmax(dim=-1)
        # Not used by forward(); retained as an attribute for compatibility.
        self.dropout = nn.Dropout(dropout_p)

    def forward(self, sequences):
        """Apply self-attention to a batch of token sequences.

        Args:
            sequences: tensor of shape (B, L, D) with
                B = batch size, L = sequence length, D = token dimension.

        Returns:
            Tensor of shape (B, L, n_heads * head_size).
        """
        scale = self.head_size ** 0.5
        head_outputs = []

        # The whole batch is processed at once for each head; only the
        # (small) loop over heads remains in Python.
        for head in range(self.n_heads):
            start = head * self.head_size
            # Each head operates on its own feature slice: (B, L, S)
            chunk = sequences[..., start:start + self.head_size]

            q = self.q_mappings[head](chunk)  # (B, L, S)
            k = self.k_mappings[head](chunk)  # (B, L, S)
            v = self.v_mappings[head](chunk)  # (B, L, S)

            # (B, L, S) x (B, S, L) --> (B, L, L)
            attention = self.softmax(q @ k.transpose(-2, -1) / scale)
            attention = self.head_dropout[head](attention)

            # (B, L, L) x (B, L, S) --> (B, L, S)
            head_outputs.append(attention @ v)

        # Concatenate the N heads along the feature axis: back to (B, L, D)
        return torch.cat(head_outputs, dim=-1)

In [None]:
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Linear -> Dropout.

    The hidden layer is four times as wide as the embedding.

    Args:
        n_embedding: size of the input and output feature dimension.
        dropout_p: dropout probability applied to the output. When ``None``
            (the default) the module-level ``dropout`` hyperparameter is used.
    """

    def __init__(self, n_embedding, dropout_p=None):
        super().__init__()
        if dropout_p is None:
            # Fall back to the module-level hyperparameter.
            dropout_p = dropout

        hidden = 4 * n_embedding
        self.net = nn.Sequential(
            nn.Linear(n_embedding, hidden),
            nn.GELU(),
            nn.Linear(hidden, n_embedding),
            nn.Dropout(dropout_p),
        )

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        return self.net(x)

In [None]:
class Block(nn.Module):
    """Transformer encoder block with pre-normalisation residual branches.

    Computes ``h = x + sa(ln1(x))`` followed by ``h = h + ff(ln2(h))``.

    Args:
        n_embedding: token embedding size; must be divisible by ``n_heads``.
        n_heads: number of attention heads.
        n_patches: forwarded to ``MultiHeadAttention``.

    Raises:
        ValueError: if ``n_heads`` is not positive or does not divide
            ``n_embedding``.
    """

    def __init__(self, n_embedding, n_heads, n_patches):
        super(Block, self).__init__()

        # With an indivisible width the heads would cover fewer than
        # n_embedding features and the residual addition would fail later
        # with an opaque shape error, so reject it here.
        if n_heads <= 0 or n_embedding % n_heads != 0:
            raise ValueError(
                f"n_embedding ({n_embedding}) must be divisible by a positive n_heads ({n_heads})"
            )

        self.n_embedding = n_embedding
        self.n_heads = n_heads
        self.n_patches = n_patches

        head_size = self.n_embedding // n_heads

        self.sa = MultiHeadAttention(self.n_heads, head_size, self.n_patches, self.n_embedding)
        self.ff = FeedForward(self.n_embedding)
        self.ln1 = nn.LayerNorm(self.n_embedding)
        self.ln2 = nn.LayerNorm(self.n_embedding)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with a residual."""
        h = x
        h = h + self.sa(self.ln1(h))
        h = h + self.ff(self.ln2(h))
        return h

In [None]:
class VisualTransformer(nn.Module):
    """Vision Transformer classifier built on patch embeddings.

    The image is cut into ``n_patches ** 2`` patches, each patch is linearly
    embedded, a learnable classification token is prepended, learned
    positional embeddings are added, and the sequence goes through three
    encoder blocks plus a final LayerNorm. The classification token is then
    mapped to ``output_size`` class scores.

    Args:
        dimensions: ``(channels, height, width)`` of the input images.
        n_patches: number of patches per image side; must divide both the
            height and the width.
        n_embedding: token embedding size.
        n_heads: number of attention heads per block.
        output_size: number of classes.

    Raises:
        ValueError: if the image size is not divisible by ``n_patches``.
    """

    def __init__(self, dimensions, n_patches, n_embedding, n_heads, output_size):
        # Super constructor
        super(VisualTransformer, self).__init__()

        # Attributes
        self.dimensions = dimensions
        self.n_patches = n_patches
        self.n_heads = n_heads
        self.n_embedding = n_embedding
        self.output_size = output_size

        channels, height, width = dimensions[0], dimensions[1], dimensions[2]
        if n_patches <= 0 or height % n_patches != 0 or width % n_patches != 0:
            raise ValueError(
                f"Image size {height}x{width} is not divisible into {n_patches}x{n_patches} patches"
            )
        # Integer patch size (height, width) in pixels.
        self.patch_size = (height // n_patches, width // n_patches)

        # 1) Linear mapper
        self.input_d = channels * self.patch_size[0] * self.patch_size[1]
        self.linear_mapper = nn.Linear(self.input_d, self.n_embedding)

        # 2) Learnable classification token
        self.class_token = nn.Parameter(torch.rand(1, self.n_embedding))

        # 3) Positional embedding (one per patch plus the classification token)
        self.positional_embeddings = nn.Embedding(self.n_patches**2 + 1, self.n_embedding)

        # 4) Transformer encoder blocks
        self.blocks = nn.Sequential(
            Block(self.n_embedding, self.n_heads, self.n_patches),
            Block(self.n_embedding, self.n_heads, self.n_patches),
            Block(self.n_embedding, self.n_heads, self.n_patches),
            nn.LayerNorm(self.n_embedding)
        )

        # 5) Classification head.
        # It outputs raw logits: nn.CrossEntropyLoss applies log-softmax
        # itself, so a Softmax layer here would be applied twice and flatten
        # the gradients. The Linear layer stays at index 0 of the Sequential
        # so existing "lm_head.0.*" state-dict keys still match.
        self.lm_head = nn.Sequential(
            nn.Linear(self.n_embedding, self.output_size),
        )

    def forward(self, images):
        """Classify a batch of images.

        Args:
            images: tensor of shape (B, C, H, W) matching ``dimensions``.

        Returns:
            Tensor of shape (B, output_size) with unnormalised class scores
            (logits). Apply ``softmax(dim=-1)`` to obtain probabilities.
        """
        # Work on the device the model's parameters live on rather than on
        # a module-level global.
        model_device = self.linear_mapper.weight.device

        # Patchify
        patches = patchify(images, self.n_patches).to(model_device)

        # Token embedding: (B, P, E)
        tok_emb = self.linear_mapper(patches)

        # Prepend the classification token to every sequence: (B, P + 1, E)
        batch = tok_emb.shape[0]
        class_tokens = self.class_token.unsqueeze(0).expand(batch, -1, -1)
        tok_emb = torch.cat((class_tokens, tok_emb), dim=1)

        # Positional embedding: (1, P + 1, E), broadcast over the batch
        positions = torch.arange(self.n_patches**2 + 1, device=model_device)
        pos_emb = self.positional_embeddings(positions).unsqueeze(0)

        # Sum token and positional embedding
        h = tok_emb + pos_emb

        # Transformer blocks
        h = self.blocks(h)

        # Keep the classification token only
        h = h[:, 0]

        # Final head
        return self.lm_head(h)

# Train loop

In [None]:
def train(model, trainloader, testloader, n_epochs=n_epochs, learning_rate=learning_rate):
    """Train ``model`` with Adam and evaluate it after every epoch.

    Each epoch runs one pass over ``trainloader`` in training mode, then one
    pass over ``testloader`` in evaluation mode. The mean losses and the
    test accuracy are printed and logged to wandb, and a checkpoint is saved
    whenever ``epoch % save_every == 0``.

    Args:
        model: network mapping a batch of inputs to per-class scores.
            NOTE: nn.CrossEntropyLoss applies log-softmax internally, so the
            model is expected to return raw logits.
        trainloader: iterable of ``(inputs, labels)`` training batches.
        testloader: iterable of ``(inputs, labels)`` evaluation batches; its
            ``dataset`` length is used to compute the accuracy.
        n_epochs: number of epochs to run.
        learning_rate: learning rate of the Adam optimizer.

    Returns:
        ``(train_avg_loss, test_avg_loss, test_accuracy)``: three lists with
        one scalar tensor per epoch. The model is left in evaluation mode.
    """

    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    train_avg_loss = []
    test_avg_loss = []
    test_accuracy = []

    for i in range(n_epochs):

        print(f"Epoch : {i}")

        train_losses = []
        test_losses = []

        # train
        # Enable training behaviour (e.g. dropout) for the optimisation pass.
        model.train()
        for x, y in trainloader:
            # send to device
            x = x.to(device)
            y = y.to(device)

            # predict
            pred = model(x)
            loss = criterion(pred, y)
            train_losses.append(loss.detach())

            # step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # test
        # Switch to evaluation behaviour so dropout is disabled and the
        # reported test loss / accuracy are deterministic for a given model.
        model.eval()
        with torch.no_grad():
            correct = 0

            for x, y in testloader:
                x = x.to(device)
                y = y.to(device)

                pred = model(x)
                loss = criterion(pred, y)
                test_losses.append(loss.detach())

                y_pred = pred.argmax(dim=-1)
                correct = correct + (y_pred == y).sum()

            accuracy = (correct / len(testloader.dataset))

        train_loss = torch.stack(train_losses).mean()
        test_loss = torch.stack(test_losses).mean()

        print(f"train_losses : {train_loss}")
        print(f"test_losses : {test_loss}")
        print(f"accuracy : {accuracy}")

        wandb.log({
            "epoch": i,
            "train loss": train_loss,
            "test loss": test_loss,
            "accuracy": accuracy,
        })

        if i % save_every == 0:
            torch.save(model.state_dict(), f"epoch_{i}_model.pt")
            wandb.save(f"epoch_{i}_model.pt")

        train_avg_loss.append(train_loss)
        test_avg_loss.append(test_loss)
        test_accuracy.append(accuracy)

    return train_avg_loss, test_avg_loss, test_accuracy

# Create dataset / dataloader

In [None]:
# Instantiate the train and test sets. Both use the default root_dir, seed
# and limit_per_class defined earlier in the notebook.

# train: reads from 'seg_train/seg_train'
train_dataset = IntelImageDataset(train=True)

# test: reads from 'seg_test/seg_test'
test_dataset = IntelImageDataset(train=False)

In [None]:
# Instantiate the corresponding data loaders (two worker processes each).

# train: batches are reshuffled every epoch
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# test: fixed order, no shuffling
test_loader = data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

# Define Model

In [None]:
# 256 x 256 matches the Resize((256, 256)) applied by IntelImageDataset.
input_features = [3, 256, 256] # Channels (assuming RGB images), Height, Width
# One output per class folder found in the training split.
output_features = train_dataset.number_of_classes

# The literal 4 is the n_heads argument (number of attention heads per block).
network = VisualTransformer(input_features, n_patches, n_embedding, 4, output_features).to(device)
# Report the parameter count in thousands, then the module structure.
print("Parameters:", sum(p.numel() for p in network.parameters())/1e3, 'K parameters')
print(network)

# Train Model

In [None]:
# Run the training loop; it returns one entry per epoch for the mean
# training loss, the mean test loss and the test accuracy.
train_avg_loss, test_avg_loss, test_accuracy = train(model=network,
                                                     trainloader=train_loader,
                                                     testloader=test_loader,
                                                     n_epochs=n_epochs,
                                                     learning_rate=learning_rate
                                                     )

# Plot

In [None]:
# Turn the per-epoch histories returned by train() into NumPy arrays.
train_avg_loss_np, test_avg_loss_np, test_accuracy_np = (
    torch.tensor(history).detach().cpu().numpy()
    for history in (train_avg_loss, test_avg_loss, test_accuracy)
)

plt.figure(figsize=(12, 4))

# Left panel: training and testing loss per epoch.
plt.subplot(1, 2, 1)
plt.title('Training and Testing Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.plot(train_avg_loss_np, label='Training Loss')
plt.plot(test_avg_loss_np, label='Testing Loss')
plt.legend()

# Right panel: test accuracy per epoch.
plt.subplot(1, 2, 2)
plt.title('Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.plot(test_accuracy_np, label='Test Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

# Finish wandb run

In [None]:
# Explicitly close the wandb run started by wandb.init() above
# (necessary in notebooks).
wandb.finish()