# Library

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt

# Create Model

In [None]:
# Define Vision Transformer (ViT) from scratch
class VisionTransformer(nn.Module):
    """Vision Transformer (ViT) image classifier built from standard PyTorch layers.

    The input image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches. Each patch is flattened and linearly projected to ``dim``. A
    learnable class token is prepended, learnable positional embeddings are
    added, and the token sequence is passed through a ``depth``-layer
    Transformer encoder. The encoder output at the class-token position is
    classified by a LayerNorm + Linear head.

    Args:
        image_size: Side length (pixels) of the square input images.
        patch_size: Side length of each square patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of Transformer encoder layers.
        heads: Attention heads per encoder layer (``dim`` must be divisible by it).
        mlp_dim: Hidden width of each encoder layer's feed-forward block.
        dropout: Dropout probability inside the encoder layers and on the
            token embeddings.
        channels: Number of input image channels (3 for RGB).
    """

    def __init__(self, image_size=224, patch_size=16, num_classes=2, dim=768, depth=12, heads=12, mlp_dim=3072, dropout=0.1, channels=3):
        super().__init__()
        assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
        # Kept so forward() slices patches with the configured geometry instead
        # of hard-coded values.
        self.patch_size = patch_size
        self.channels = channels
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size * patch_size

        # Linear projection of flattened patches to token embeddings.
        self.patch_embed = nn.Linear(patch_dim, dim)

        # Learnable positional embeddings: one per patch plus one for the class token.
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))

        # Learnable class token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Transformer encoder. batch_first=True is required because forward()
        # builds tokens as (batch, seq, dim). With the default (seq, batch, dim)
        # layout, attention would mix different images instead of patches.
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, heads, mlp_dim, dropout, batch_first=True),
            num_layers=depth
        )

        # Classification head applied to the class-token output.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # Dropout on the token embeddings (after positional embeddings are added).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the images ``x``.

        ``x`` is indexed as (batch, channels, height, width). Height and width
        are expected to equal the ``image_size`` the model was built with, so
        that the patch count matches the positional embeddings.
        """
        batch_size = x.shape[0]
        p = self.patch_size

        # (B, C, H, W) -> (B, C, H/p, W/p, p, p): non-overlapping p x p windows.
        patches = x.unfold(2, p, p).unfold(3, p, p)
        # -> (B, H/p, W/p, C, p, p) so each patch's values are contiguous, then
        # flatten to (B, num_patches, C*p*p).
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(batch_size, -1, self.channels * p * p)

        # Project each flattened patch to a dim-wide token.
        tokens = self.patch_embed(patches)

        # Prepend the class token, add positional embeddings, apply embedding dropout.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, tokens), dim=1)
        x = x + self.pos_embedding
        x = self.dropout(x)

        x = self.transformer(x)

        # Classify from the class-token output (sequence position 0).
        cls_output = x[:, 0]
        return self.mlp_head(cls_output)

# Setup Device

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Define Dataset Transforms

In [None]:
# Preprocessing shared by both splits: resize to 224x224, convert to a float
# tensor, then normalize each channel with (x - 0.5) / 0.5.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocessing_steps)

# Load Dataset

In [None]:
# Build one ImageFolder dataset per split, in (train, test) order, both using
# the same preprocessing transform.
train_data, test_data = (
    datasets.ImageFolder(root=split_root, transform=transform)
    for split_root in ('Dataset/Train', 'Dataset/Test')
)

In [None]:
# Batch both splits in groups of 32. Only the training split is reshuffled each
# epoch; the test split is read in a fixed order.
train_loader, test_loader = (
    DataLoader(dataset=split, batch_size=32, shuffle=reshuffle)
    for split, reshuffle in ((train_data, True), (test_data, False))
)

# Initialize Model

In [None]:
# Build a two-class ViT with its default architecture and place it on the
# selected device (Module.to returns the module itself).
model = VisionTransformer(num_classes=2)
model = model.to(device)

# Cross-entropy over the raw logits; Adam starting at a learning rate of 1e-4.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

# Learning-Rate Scheduler and Checkpointing

In [None]:
# Multiply the learning rate by 0.1 once the monitored loss (the value passed
# to scheduler.step) has failed to decrease for more than `patience` = 3
# consecutive steps.
# NOTE(review): `verbose` is deprecated in recent PyTorch releases (emits a
# warning); consider logging scheduler.get_last_lr() instead. Confirm against
# the installed torch version.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.1,
    patience=3,
    verbose=True,
)

# Checkpoint helper
def save_checkpoint(epoch, model, optimizer, loss, filepath):
    """Write model/optimizer state and bookkeeping to ``filepath`` via ``torch.save``.

    Args:
        epoch: Index of the epoch just finished (the training loop passes its
            zero-based loop index); it is printed and stored as ``epoch + 1``.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        loss: Loss value recorded alongside the weights.
        filepath: Destination path of the checkpoint file.
    """
    print(f"Saving model checkpoint at epoch {epoch+1} with loss {loss:.4f}...")
    torch.save(
        {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        filepath,
    )

# Folder that receives the best-so-far checkpoints. It is created up front;
# an already-existing directory is not an error.
checkpoint_dir = './checkpoints'
os.makedirs(name=checkpoint_dir, exist_ok=True)

# Train Model

In [None]:
# Per-epoch history consumed by the plotting cell, plus the lowest test loss
# seen so far. It starts at +inf so any finite first test loss counts as an
# improvement.
train_losses, test_losses, test_accuracies = [], [], []
best_loss = float('inf')

In [None]:
# Training loop: one optimisation pass over the training split per epoch,
# followed by an evaluation pass over the test split.
#
# NOTE(review): the test split also drives checkpoint selection and the LR
# scheduler, so the reported test metrics are not an unbiased estimate of
# generalisation. A separate validation split would be needed for that.


def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one training pass over ``loader``; return the per-sample mean loss.

    Assumes ``criterion`` uses mean reduction (nn.CrossEntropyLoss's default),
    so each batch loss is re-weighted by its batch size. Without this, a short
    final batch would count as much as a full one.
    """
    model.train()
    loss_sum = 0.0
    sample_count = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        loss_sum += loss.item() * n
        sample_count += n
    return loss_sum / sample_count


def _evaluate(model, loader, criterion, device):
    """Return ``(mean_loss, accuracy_percent)`` over ``loader`` without updating weights.

    Loss is averaged per sample (same mean-reduction assumption as
    ``_train_one_epoch``). Accuracy compares the arg-max logit with the label.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            n = labels.size(0)
            loss_sum += criterion(outputs, labels).item() * n
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += n
    return loss_sum / total, 100 * correct / total


num_epochs = 20
for epoch in range(num_epochs):
    avg_loss = _train_one_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(avg_loss)  # Store train loss
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_loss:.4f}")

    avg_test_loss, accuracy = _evaluate(model, test_loader, criterion, device)
    test_losses.append(avg_test_loss)  # Store test loss
    test_accuracies.append(accuracy)  # Store accuracy
    print(f'Epoch [{epoch+1}/{num_epochs}], Test Loss: {avg_test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

    # Save checkpoint if the test loss improves
    if avg_test_loss < best_loss:
        best_loss = avg_test_loss
        checkpoint_path = os.path.join(checkpoint_dir, f"best_model_epoch_{epoch+1}.pt")
        save_checkpoint(epoch, model, optimizer, best_loss, checkpoint_path)

    # Step the LR scheduler on the monitored (test) loss once per epoch.
    scheduler.step(avg_test_loss)

# Plot Result

In [None]:
# Plot the per-epoch history collected during training: train/test loss on
# the left, test accuracy on the right. Epochs are numbered from 1 on the
# x-axis so they line up with the "Epoch [k/N]" training log.


def _epoch_axis(values):
    """Return 1-based epoch numbers, one per entry in ``values``."""
    # Built per series so plotting still works if training was interrupted
    # between appending the train and test metrics of an epoch.
    return range(1, len(values) + 1)


fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(10, 5))

loss_ax.plot(_epoch_axis(train_losses), train_losses, label='Train Loss')
loss_ax.plot(_epoch_axis(test_losses), test_losses, label='Test Loss')
loss_ax.set_title('Loss per Epoch')
loss_ax.set_xlabel('Epochs')
loss_ax.set_ylabel('Loss')
loss_ax.legend()

acc_ax.plot(_epoch_axis(test_accuracies), test_accuracies, label='Test Accuracy')
acc_ax.set_title('Accuracy per Epoch')
acc_ax.set_xlabel('Epochs')
acc_ax.set_ylabel('Accuracy (%)')
acc_ax.legend()

fig.tight_layout()
plt.show()