In [1]:
from torchvision.models import vit_h_14, ViT_H_14_Weights
from sklearn.utils import compute_class_weight
from torchvision import datasets, transforms
from torch.amp import autocast, GradScaler
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from torch.optim import lr_scheduler
from tqdm.notebook import tqdm
import torch.nn as nn
import pandas as pd
import numpy as np
import warnings
import torch
import time
import os
warnings.filterwarnings("ignore", message=".*flash attention.*")

In [2]:
train_data_directory = 'data/train'
test_data_directory = 'data/test'
weights_directory = 'weights'
predictions_directory = 'predictions'

os.makedirs(weights_directory, exist_ok=True)
os.makedirs(predictions_directory, exist_ok=True)

model_name = 'vit_h_14'
model_image_width = 518
num_folds = 5
batch_size = 1
num_epochs = 4
learning_rate = 0.000001

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
train_transforms = transforms.Compose([
    v2.CenterCrop(800),
    v2.RandomHorizontalFlip(),
    v2.RandomVerticalFlip(),
    v2.RandomRotation(10),
    v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    v2.GaussianNoise(),
    v2.Resize((model_image_width, model_image_width)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

In [5]:
val_transforms = transforms.Compose([
    v2.CenterCrop(800),
    v2.Resize((model_image_width, model_image_width)),
    v2.ToImage(),
    v2.ToDtype(torch.float32, scale=True),
    v2.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

In [None]:
full_dataset = datasets.ImageFolder(root=train_data_directory)

class_names = full_dataset.classes
num_classes = len(class_names)

timestamp = time.strftime('%Y%m%d-%H%M%S')
model_folder = os.path.join(weights_directory, f"{model_name}_{timestamp}")
os.makedirs(model_folder, exist_ok=True)

performance_reports = []
report_path = os.path.join(model_folder, 'performance_report.csv')

kfold = KFold(n_splits=num_folds, shuffle=True)

In [7]:
start_time = time.time()
for fold, (train_idx, val_idx) in enumerate(kfold.split(full_dataset)):
    model = vit_h_14(weights=ViT_H_14_Weights.DEFAULT)
    model.heads.head = nn.Linear(model.heads.head.in_features, num_classes)
    model = model.to(device)

    train_subset = torch.utils.data.Subset(full_dataset, train_idx)
    val_subset = torch.utils.data.Subset(full_dataset, val_idx)

    train_subset.dataset.transform = train_transforms
    val_subset.dataset.transform = val_transforms

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

    train_labels = [train_subset.dataset.targets[i] for i in train_subset.indices]
    class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
    class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)
    
    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 4)
    scaler = GradScaler()

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0
        
        for index, (images, labels) in enumerate(tqdm(train_loader, desc=f'Training Epoch {epoch+1}/{num_epochs}')):
            images, labels = images.to(device), labels.to(device)
            
            # Forward pass
            with autocast("cuda"):
                outputs = model(images)
                loss = criterion(outputs, labels)
            
            # Backward and optimize
            optimizer.zero_grad()
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            scheduler.step(epoch +  index / len(train_loader))
            
            train_loss += loss.item()
        
        # Validation loop
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        
        with torch.no_grad():
            for images, labels in tqdm(val_loader, desc=f'Validation Epoch {epoch+1}/{num_epochs}'):
                images, labels = images.to(device), labels.to(device)

                with autocast("cuda"):
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                
                    _, predictions = torch.max(outputs, 1)
                    total += labels.size(0)
                    correct += (predictions == labels).sum().item()
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        accuracy = 100 * correct / total
        time_elapsed = time.time() - start_time
        
        print(f'Epoch [{epoch+1}/{num_epochs}], '
              f'Train Loss: {avg_train_loss:.4f}, '
              f'Val Loss: {avg_val_loss:.4f}, '
              f'Accuracy: {accuracy:.2f}% '
              f'Time elapsed: {time_elapsed:.2f} seconds')
        
        performance_reports.append({
            'Batch Size': batch_size,
            'Fold': fold + 1,
            'Epoch': epoch + 1,
            'LR': scheduler.get_last_lr()[0],
            'Train Loss': avg_train_loss,
            'Validation Loss': avg_val_loss,
            'Accuracy': accuracy,
            'Time Elapsed': time_elapsed
        })

        report_df = pd.DataFrame(performance_reports)
        report_df.to_csv(report_path, index=False)
    
    weights_path = os.path.join(model_folder, f'model_weights_{fold + 1}.pth')
    torch.save(model.state_dict(), weights_path)

Training Epoch 1/4:   0%|          | 0/18880 [00:00<?, ?it/s]

Validation Epoch 1/4:   0%|          | 0/4721 [00:00<?, ?it/s]

Epoch [1/4], Train Loss: 0.3984, Val Loss: 0.3312, Accuracy: 88.77% Time elapsed: 56422.71 seconds


Training Epoch 2/4:   0%|          | 0/18880 [00:00<?, ?it/s]

In [8]:
test_dataset = datasets.ImageFolder(root=test_data_directory, transform=val_transforms)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
idx_to_class = {v: k for k, v in test_loader.dataset.class_to_idx.items()}

model.eval()
test_predictions = []
image_paths = []

with torch.no_grad():
    for images, batch_indices in tqdm(test_loader, desc='Test'):
        images = images.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        labels = [idx_to_class[prediction] for prediction in predictions.cpu().numpy()]
        test_predictions.extend(labels)
        image_paths.extend([test_loader.dataset.samples[idx][0] for idx in batch_indices])

Test:   0%|          | 0/2619 [00:00<?, ?it/s]

In [9]:
test_predictions_df = pd.DataFrame({
    'Image': [path.split('\\')[-1] for path in image_paths],
    'ImagePath': image_paths,
    'PredictedClass': test_predictions,
    'ActualClass': [path.split('\\')[-2] for path in image_paths]
})
test_predictions_filename = f"{model_name}_{timestamp}_test_predictions.csv"
test_predictions_path = os.path.join(predictions_directory, test_predictions_filename)
test_predictions_df.to_csv(test_predictions_path, index=False)