In [None]:
## Building image classification using a Vision Transformer (ViT)
## Article: Developing Building Exposure Models Using Computer Vision and Deep Learning
#
# Authors: Sukh Sagar Shukla, Amit Bhatiya, Dhanya J, Saman Ghaffarian, Roberto Gentile
#
# Description:
# This script provides an example implementation of fine-tuning a pretrained ViT-B/16 model for a building image classification task.
# Please refer to the article for further details.

# Expected folder structure of the dataset (one sub-folder per class; the folder names become the class labels):
# ├── Class1/
# │   ├── image1.jpg
# │   ├── image2.jpg
# │   ├── image3.jpg
# │   └── ...
# ├── Class2/
# │   ├── image1.jpg
# │   ├── image2.jpg
# │   ├── image3.jpg
# │   └── ...
# └── ...

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from pathlib import Path
import time
import json

class ViTTrainer:
    """Fine-tune an ImageNet-pretrained torchvision ViT-B/16 on an ImageFolder dataset.

    The images under ``data_dir`` are split once (seed 42) into train/val/test
    subsets. Training uses:
    - class-balanced sampling;
    - a class-weighted, label-smoothed cross-entropy loss;
    - AdamW with a 10x lower learning rate for the unfrozen encoder blocks than
      for the new head;
    - cosine annealing with warm restarts (stepped once per epoch);
    - mixed precision on CUDA and gradient clipping;
    - early stopping on validation accuracy.
    """

    def __init__(
        self,
        data_dir,
        num_classes,
        batch_size=32,  # Reduced batch size for better generalization
        learning_rate=5e-6,  # Reduced learning rate
        train_split=0.7,  # Updated to 70%
        val_split=0.15,   # Added validation split 15%
        test_split=0.15,  # Added test split 15%
        image_size=224,
        num_workers=10,
        checkpoint_dir='Path to save the checkpoint',  # Provide the directory path to save the checkpoints
        log_dir='Path to save training logs ',  # Provide the directory path to save the training logs
    ):
        """Validate the configuration, then build the data loaders, model and optimizer.

        Args:
            data_dir: Root folder containing one sub-folder of images per class.
            num_classes: Number of output classes. Must equal the number of class
                folders found in ``data_dir``.
            batch_size: Mini-batch size used by all three loaders.
            learning_rate: Learning rate of the new head. The unfrozen encoder
                blocks use 0.1x this value.
            train_split: Fraction of images used for training (0 < x < 1).
            val_split: Fraction of images used for validation (0 < x < 1).
            test_split: Nominal test fraction. The test set always receives every
                image not assigned to train/val. A mismatching value only triggers
                a printed note.
            image_size: Side length (pixels) that images are resized to.
            num_workers: DataLoader worker processes.
            checkpoint_dir: Directory for ``.pth`` checkpoints (created if missing).
            log_dir: Directory for ``training_metrics.json`` (created if missing).

        Raises:
            ValueError: If the split fractions are invalid, ``num_classes`` does not
                match the dataset, or a split is too small (see ``_setup_data``).
        """
        # Fail fast instead of crashing after a full training run
        # (an empty val/test split would divide by zero later).
        if not (0 < train_split < 1 and 0 < val_split < 1 and train_split + val_split < 1):
            raise ValueError(
                "Expected 0 < train_split, 0 < val_split and train_split + val_split < 1 "
                f"(got train_split={train_split}, val_split={val_split})"
            )
        if abs(train_split + val_split + test_split - 1.0) > 1e-6:
            print(
                f"Note: test_split={test_split} is not used directly; the test set receives "
                f"the remaining {1.0 - train_split - val_split:.2f} of the data."
            )

        self.data_dir = data_dir
        self.num_classes = num_classes
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.train_split = train_split
        self.val_split = val_split
        self.test_split = test_split
        self.image_size = image_size
        self.num_workers = num_workers

        # parents=True so nested output paths work on a fresh machine.
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {self.device}")

        self._setup_data()
        self._setup_model()
        self._setup_training()

        # Per-epoch history, written to training_metrics.json by train().
        self.train_losses = []
        self.train_accuracies = []
        self.val_losses = []
        self.val_accuracies = []

    def _autocast(self):
        """Return an autocast context that is enabled on CUDA and is a no-op elsewhere."""
        return torch.autocast(device_type=self.device.type, enabled=self.device.type == 'cuda')

    def _setup_data(self):
        """Load the ImageFolder dataset, split it 3 ways, and build the DataLoaders.

        Raises:
            ValueError: If the number of class folders differs from ``num_classes``,
                or a split is too small (fewer training images than one batch, or
                an empty val/test split).
        """
        # Deterministic preprocessing only (the data is assumed to be pre-augmented).
        # mean/std are the ImageNet statistics expected by the pretrained weights.
        data_transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        full_dataset = datasets.ImageFolder(self.data_dir, transform=data_transform)
        found_classes = len(full_dataset.classes)
        if found_classes != self.num_classes:
            raise ValueError(
                f"num_classes={self.num_classes} but {found_classes} class folders "
                f"were found in {self.data_dir}"
            )
        self.class_names = full_dataset.classes

        total_size = len(full_dataset)
        train_size = int(self.train_split * total_size)
        val_size = int(self.val_split * total_size)
        test_size = total_size - train_size - val_size  # remainder, so every sample is used

        print(f"Dataset split - Total: {total_size}, Train: {train_size}, Val: {val_size}, Test: {test_size}")

        # With drop_last=True, fewer than batch_size training images gives zero batches.
        if train_size < self.batch_size or val_size == 0 or test_size == 0:
            raise ValueError(
                "Dataset too small for the requested splits/batch size "
                f"(train={train_size}, val={val_size}, test={test_size}, batch_size={self.batch_size})"
            )

        generator = torch.Generator().manual_seed(42)
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            full_dataset, [train_size, val_size, test_size], generator=generator
        )

        # Class-balanced sampling: each training image is drawn with probability
        # proportional to 1 / (size of its class within the training split).
        train_labels = torch.as_tensor([full_dataset.targets[i] for i in self.train_dataset.indices])
        class_counts = torch.bincount(train_labels, minlength=self.num_classes)
        class_weights = 1.0 / class_counts.clamp(min=1).float()
        sample_weights = class_weights[train_labels]
        self.sampler = torch.utils.data.WeightedRandomSampler(
            sample_weights, len(sample_weights), replacement=True
        )

        # Pinned host memory only helps (and only avoids a warning) when copying to a GPU.
        loader_kwargs = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            pin_memory=self.device.type == 'cuda',
        )
        # drop_last keeps every optimisation step at the full batch size.
        self.train_loader = DataLoader(
            self.train_dataset, sampler=self.sampler, drop_last=True, **loader_kwargs
        )
        self.val_loader = DataLoader(self.val_dataset, shuffle=False, **loader_kwargs)
        self.test_loader = DataLoader(self.test_dataset, shuffle=False, **loader_kwargs)

    def _setup_model(self):
        """Build ViT-B/16 with ImageNet weights and a new MLP head, then freeze all but the top blocks."""
        self.model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        num_features = self.model.heads.head.in_features

        # Replacement head: LayerNorm -> two GELU hidden layers with dropout -> class logits.
        self.model.heads = nn.Sequential(
            nn.LayerNorm(num_features),
            nn.Linear(num_features, 2048),
            nn.GELU(),
            nn.Dropout(0.6),
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.4),
            nn.Linear(1024, self.num_classes)
        )

        for param in self.model.parameters():
            param.requires_grad = False

        # Unfreeze the head and encoder blocks 5-11 (the last 7 of ViT-B/16's 12).
        # torchvision names these blocks `encoder.layers.encoder_layer_<i>`, not the
        # timm-style `blocks.<i>`, so matching parameter names on "blocks.N" never
        # unfroze any encoder block. Selecting the modules directly avoids relying on
        # parameter-name strings.
        for module in [self.model.heads, *self.model.encoder.layers[5:]]:
            for param in module.parameters():
                param.requires_grad = True

        self.model = self.model.to(self.device)

    def _setup_training(self):
        """Create the loss, the AdamW optimizer (two LR groups), the LR scheduler and the AMP scaler."""
        # Inverse-frequency class weights from the *training* subset only, so val/test
        # label counts do not influence training. minlength and clamp keep the weight
        # vector at num_classes entries even if a class is absent from the train split.
        # NOTE(review): the train loader already balances classes via
        # WeightedRandomSampler; weighting the loss as well compensates for imbalance
        # twice and may over-favour rare classes. Consider keeping only one of them.
        train_targets = torch.as_tensor(
            [self.train_dataset.dataset.targets[i] for i in self.train_dataset.indices]
        )
        class_counts = torch.bincount(train_targets, minlength=self.num_classes)
        class_weights = (1.0 / class_counts.clamp(min=1).float()).to(self.device)
        self.criterion = nn.CrossEntropyLoss(weight=class_weights, label_smoothing=0.1)

        # Pretrained encoder blocks get a smaller LR than the freshly initialised head.
        encoder_params = []
        head_params = []
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                (head_params if name.startswith('heads.') else encoder_params).append(param)

        self.optimizer = optim.AdamW([
            {'params': encoder_params, 'lr': self.learning_rate * 0.1},
            {'params': head_params, 'lr': self.learning_rate}
        ], weight_decay=0.02)

        # Stepped once per epoch in train_epoch(), so T_0/T_mult are in epochs:
        # restarts after 5, then 10, then 20 more epochs, ...
        self.scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
            self.optimizer,
            T_0=5,
            T_mult=2,
            eta_min=1e-6
        )

        # Loss scaling only makes sense with CUDA autocast; disabled elsewhere (no warning).
        self.scaler = GradScaler(enabled=self.device.type == 'cuda')

    def train_epoch(self):
        """Run one training epoch.

        Returns:
            (mean batch loss, training accuracy in percent).
        """
        self.model.train()
        total_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (images, targets) in enumerate(pbar):
            images = images.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)

            self.optimizer.zero_grad(set_to_none=True)

            with self._autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

            self.scaler.scale(loss).backward()
            # Unscale first so the clipping threshold applies to the true gradients.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.scaler.step(self.optimizer)
            self.scaler.update()

            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            pbar.set_postfix({
                'loss': total_loss / (batch_idx + 1),
                'acc': 100. * correct / total
            })

        self.scheduler.step()
        return total_loss / len(self.train_loader), 100. * correct / total

    @torch.no_grad()
    def _evaluate(self, loader, desc):
        """Evaluate the model on ``loader`` without gradients.

        Returns:
            (mean batch loss, accuracy in percent, per-class correct counts,
            per-class sample counts). The counts are long tensors of length
            ``num_classes`` on ``self.device``.
        """
        self.model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        class_correct = torch.zeros(self.num_classes, dtype=torch.long, device=self.device)
        class_total = torch.zeros(self.num_classes, dtype=torch.long, device=self.device)

        for images, targets in tqdm(loader, desc=desc):
            images = images.to(self.device, non_blocking=True)
            targets = targets.to(self.device, non_blocking=True)

            with self._autocast():
                outputs = self.model(images)
                loss = self.criterion(outputs, targets)

            total_loss += loss.item()
            hits = outputs.argmax(dim=1).eq(targets)
            total += targets.size(0)
            correct += hits.sum().item()
            # Vectorised per-class tallies (no per-class loop / host sync per batch).
            class_total += torch.bincount(targets, minlength=self.num_classes)
            class_correct += torch.bincount(targets[hits], minlength=self.num_classes)

        return total_loss / len(loader), 100. * correct / total, class_correct, class_total

    def _print_class_accuracy(self, class_correct, class_total, prefix):
        """Print ``<prefix><class index>: <accuracy>%`` for every class that has samples."""
        for i in range(self.num_classes):
            count = class_total[i].item()
            if count > 0:
                print(f'{prefix}{i}: {100.0 * class_correct[i].item() / count:.2f}%')

    def validate(self):
        """Evaluate on the validation split and print the per-class accuracy.

        Returns:
            (mean batch loss, validation accuracy in percent).
        """
        loss, acc, class_correct, class_total = self._evaluate(self.val_loader, 'Validating')
        self._print_class_accuracy(class_correct, class_total, 'Accuracy of class ')
        return loss, acc

    def test(self):
        """Test the model on the test set and print overall and per-class results.

        Returns:
            (mean batch loss, test accuracy in percent).
        """
        loss, acc, class_correct, class_total = self._evaluate(self.test_loader, 'Testing')
        print("\nTest Results:")
        print(f"Test Loss: {loss:.4f}")
        print(f"Test Accuracy: {acc:.2f}%")
        self._print_class_accuracy(class_correct, class_total, 'Test Accuracy of class ')
        return loss, acc

    def train(self, num_epochs=50, patience=10):
        """Train with early stopping, then test the best-validation weights.

        Args:
            num_epochs: Maximum number of epochs.
            patience: Stop after this many consecutive epochs without a new best
                validation accuracy.

        Returns:
            (best validation accuracy %, test accuracy % of the best-validation model).
        """
        print("Starting training...")
        start_time = time.time()
        best_val_acc = 0
        best_state = None
        no_improve = 0

        for epoch in range(num_epochs):
            print(f"\nEpoch {epoch + 1}/{num_epochs}")

            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            self.train_losses.append(train_loss)
            self.train_accuracies.append(train_acc)
            self.val_losses.append(val_loss)
            self.val_accuracies.append(val_acc)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                # CPU snapshot of the best weights so the final test evaluates them,
                # not whatever the model drifted to before early stopping fired.
                best_state = {k: v.detach().cpu().clone() for k, v in self.model.state_dict().items()}
                self.save_checkpoint(epoch, val_loss, val_acc, is_best=True)
                no_improve = 0
                print("New best model saved!")
            else:
                no_improve += 1
                if no_improve >= patience:
                    print("Early stopping triggered!")
                    break

            # Periodic checkpoint every 5 epochs.
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, val_loss, val_acc)

        metrics = {
            'train_losses': self.train_losses,
            'train_accuracies': self.train_accuracies,
            'val_losses': self.val_losses,
            'val_accuracies': self.val_accuracies
        }
        with open(self.log_dir / 'training_metrics.json', 'w') as f:
            json.dump(metrics, f)

        elapsed_min = (time.time() - start_time) / 60
        print(f"\nTraining completed. Best validation accuracy: {best_val_acc:.2f}%")
        print(f"Training time: {elapsed_min:.1f} min")

        if best_state is not None:
            self.model.load_state_dict(best_state)

        print("\nRunning final test evaluation...")
        _, test_acc = self.test()

        return best_val_acc, test_acc

    def save_checkpoint(self, epoch, val_loss, val_acc, is_best=False):
        """Save model, optimizer, scheduler and grad-scaler state plus validation metrics.

        Best checkpoints overwrite ``best_model.pth``. Periodic ones go to
        ``checkpoint_epoch_<epoch>.pth``, where ``epoch`` is the 0-based epoch index.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'scaler_state_dict': self.scaler.state_dict(),  # needed to resume AMP training
            'val_loss': val_loss,
            'val_acc': val_acc
        }
        filename = 'best_model.pth' if is_best else f'checkpoint_epoch_{epoch}.pth'
        torch.save(checkpoint, self.checkpoint_dir / filename)



In [None]:
# Usage: edit the settings below, then run this cell to train and evaluate.
if __name__ == "__main__":
    config = dict(
        data_dir="Path to your data directory",  # dataset root: one sub-folder of images per class
        num_classes=24,
        batch_size=32,
        learning_rate=1e-4,
        train_split=0.7,   # 70% of the images for training
        val_split=0.15,    # 15% for validation
        test_split=0.15,   # 15% for testing
    )
    trainer = ViTTrainer(**config)
    best_val_acc, test_acc = trainer.train(num_epochs=30)