In [1]:
# Install required packages
!pip install linformer
!pip install vit_pytorch

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from torch.optim.lr_scheduler import StepLR
from tqdm.notebook import tqdm
from vit_pytorch.efficient import ViT
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.metrics import confusion_matrix
import torch.utils.data as data
import torchvision
from torchvision import transforms

Collecting linformer
  Downloading linformer-0.2.3-py3-none-any.whl (6.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->linformer)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->linformer)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->linformer)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->linformer)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->linformer)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->linformer)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collect

In [2]:
# Report whether PyTorch can see a CUDA device
_has_cuda = torch.cuda.is_available()
print(_has_cuda)

# Run on the GPU when one is available, otherwise fall back to the CPU
device = 'cuda' if _has_cuda else 'cpu'

# Seed PyTorch's global RNG so runs are repeatable
torch.manual_seed(142)

False


<torch._C.Generator at 0x7ae463fc5090>

In [4]:
# Mount Google Drive at /content/drive (Colab-only); the dataset paths used
# later in this notebook live under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
# Hyperparameters

# -- image geometry / model head --
IMG_SIZE = 200      # side length images are resized to
patch_size = 20     # ViT patch side length
num_classes = 2     # size of the classifier output

# -- optimisation --
batch_size = 64
epochs = 100
lr = 1e-4           # Adam learning rate
gamma = 0.7         # decay factor handed to the StepLR scheduler

In [5]:
# Image pipeline: resize to a fixed IMG_SIZE x IMG_SIZE square, apply random
# flips and colour jitter as augmentation, then convert to a tensor.
# No Normalize step is applied.
# NOTE: this pipeline is stochastic, so any dataset it is attached to will
# yield slightly different tensors on every pass.
# (Earlier experimental pipelines that were kept here as dead string literals
# have been removed; recover them from version control if needed.)
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.ToTensor()
])


In [None]:
# Define paths.
# NOTE(review): this cell previously pointed val_dir at ".../test" and test_dir
# at ".../validation". Each variable now points at the folder matching its
# name -- confirm this matches the intended split on Drive.
train_dir = '/content/drive/MyDrive/dataset/training'
val_dir = '/content/drive/MyDrive/dataset/validation'
test_dir = '/content/drive/MyDrive/dataset/test'

# Evaluation must be deterministic: only resize + tensor conversion, without
# the random flips / colour jitter used for training.
eval_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])

# Load datasets (augmentation on the training split only)
train_ds = torchvision.datasets.ImageFolder(train_dir, transform=transform)
valid_ds = torchvision.datasets.ImageFolder(val_dir, transform=eval_transform)
test_ds = torchvision.datasets.ImageFolder(test_dir, transform=eval_transform)

# Data loaders. Only the training loader is shuffled; evaluation order is fixed.
train_loader = data.DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
valid_loader = data.DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=0)
test_loader = data.DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)


def _build_model():
    """Construct the Linformer backbone and the ViT wrapped around it.

    Both the training run and the later checkpoint reload go through this
    helper so the two architectures cannot drift apart (a mismatch makes
    ``load_state_dict`` fail on parameter shapes).

    Returns:
        (transformer, vit): the Linformer module and the ViT moved to ``device``.
    """
    transformer = Linformer(
        dim=256,
        # one token per patch plus one extra token
        seq_len=(IMG_SIZE // patch_size) ** 2 + 1,
        depth=24,
        heads=16,
        k=128,
    )
    vit = ViT(
        dim=256,
        image_size=IMG_SIZE,
        patch_size=patch_size,
        num_classes=num_classes,
        transformer=transformer,
        channels=3,
    ).to(device)
    return transformer, vit


# Linear Transformer + Vision Transformer Model
efficient_transformer, model = _build_model()

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer
optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)

# Learning rate scheduler: multiply the LR by `gamma` every 10 epochs
scheduler = StepLR(optimizer, step_size=10, gamma=gamma)

# Training loop
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    epoch_accuracy = 0.0
    # Loop variables are deliberately NOT called `data`: that name is the
    # `torch.utils.data` alias and shadowing it breaks re-running this cell.
    for images, labels in tqdm(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        # .item() turns the metrics into plain floats so the running totals
        # do not keep tensors (and their autograd history) alive.
        acc = (output.argmax(dim=1) == labels).float().mean().item()
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss.item() / len(train_loader)

    model.eval()
    with torch.no_grad():
        epoch_val_accuracy = 0.0
        epoch_val_loss = 0.0
        for images, labels in valid_loader:
            images, labels = images.to(device), labels.to(device)
            val_output = model(images)
            val_loss = criterion(val_output, labels)

            acc = (val_output.argmax(dim=1) == labels).float().mean().item()
            epoch_val_accuracy += acc / len(valid_loader)
            epoch_val_loss += val_loss.item() / len(valid_loader)

    # Advance the LR schedule once per epoch; without this call the StepLR
    # created above never changes the learning rate.
    scheduler.step()

    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}, Val Loss: {epoch_val_loss:.4f}, Val Accuracy: {epoch_val_accuracy:.4f}")

# Save the model
PATH = f"epochs_{epochs}img{IMG_SIZE}patch{patch_size}lr{lr}.pt"
torch.save(model.state_dict(), PATH)

# Load saved model into a freshly built network with the SAME architecture
# that produced the checkpoint; map_location lets a checkpoint saved on one
# device be loaded on another.
efficient_transformer, model = _build_model()
model.load_state_dict(torch.load(PATH, map_location=device))

# Function to calculate overall accuracy
def overall_accuracy(model, test_loader, criterion):
    """Evaluate ``model`` over every batch yielded by ``test_loader``.

    Args:
        model: network returning one row of class logits per input sample.
        test_loader: iterable of ``(inputs, labels)`` batches.
        criterion: loss callable invoked as ``criterion(logits, labels.long())``.

    Returns:
        A 4-tuple ``(test_loss, accuracy, y_proba_out, y_truth_out)``:
          * test_loss   -- sum of the per-batch loss values (not averaged).
          * accuracy    -- fraction of samples whose argmax equals the label.
          * y_proba_out -- 1-D numpy array with the softmax probability of
                           class index 1 for each sample.
          * y_truth_out -- 1-D numpy array with the label of each sample.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples.
    """
    model.eval()
    y_proba = []
    y_truth = []
    test_loss = 0
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in tqdm(test_loader):
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            test_loss += criterion(logits, labels.long()).item()

            # Convert logits to probabilities so that `y_proba` really holds
            # P(class 1) rather than an unnormalised score.
            probs = torch.softmax(logits, dim=1)
            y_proba.extend(probs[:, 1].tolist())
            y_truth.extend(labels.tolist())

            # Whole-batch comparison instead of a per-sample Python loop
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += logits.size(0)
    accuracy = correct / total
    y_proba_out = np.array(y_proba)
    y_truth_out = np.array(y_truth)
    return test_loss, accuracy, y_proba_out, y_truth_out

# Evaluate model on test data
loss, acc, y_proba, y_truth = overall_accuracy(model, test_loader, criterion)

print(f"Test Accuracy: {acc:.4f}")

# Print confusion matrix.
# With two classes, P(class 1) > 0.5 selects the same class as argmax over the
# logits. (An argmax over a single-column reshape of y_proba would label every
# sample as class 0.)
y_pred = (y_proba > 0.5).astype(int)
cm = confusion_matrix(y_truth, y_pred)
print(cm)

  0%|          | 0/19 [00:00<?, ?it/s]