In [None]:
# First cell - Imports
import torch
import torch.nn as nn
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from sklearn.model_selection import StratifiedKFold
from tqdm.auto import tqdm
import os
from typing import Dict, List, Tuple
import numpy as np

In [None]:
# Change which layers are trainable
# Second cell - Custom ViT Model
class CustomViT(nn.Module):
    """Wrap a pretrained ViT backbone for binary classification.

    All backbone weights are frozen except the parameters of ``nn.LayerNorm``
    modules, and the backbone's ``heads`` attribute is replaced by a small
    trainable head that ends in a sigmoid, so the forward pass returns one
    probability per sample with shape ``(batch, 1)``.

    Args:
        base_model: Backbone module whose ``heads`` attribute receives the
            embedding. If it exposes a ``hidden_dim`` attribute (presumably
            the case for torchvision's ViT models - verify for other
            backbones) that width is used for the head; otherwise 768 is
            assumed.
    """

    def __init__(self, base_model):
        super().__init__()
        self.vit = base_model

        # Freeze every pretrained parameter first ...
        for parameter in self.vit.parameters():
            parameter.requires_grad = False

        # ... then re-enable gradients only for the LayerNorm layers.
        for module in self.vit.modules():
            if isinstance(module, nn.LayerNorm):
                for param in module.parameters():
                    param.requires_grad = True

        # Width of the embedding fed to the head. Reading it from the backbone
        # keeps the wrapper usable with backbones other than the 768-wide one.
        embed_dim = getattr(self.vit, "hidden_dim", 768)

        # Replace the classification head.
        # NOTE(review): there is no non-linearity between the two Linear
        # layers, so together they form a single affine map; the structure is
        # kept unchanged to preserve the original architecture.
        self.vit.heads = nn.Sequential(
            nn.Linear(in_features=embed_dim, out_features=2),
            nn.Linear(in_features=2, out_features=1),
            nn.Sigmoid()
        )

        # Freshly created layers already require grad; set it explicitly so
        # the intent is visible next to the freezing logic above.
        for param in self.vit.heads.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run the wrapped backbone (including the replaced head) on ``x``."""
        return self.vit(x)

In [None]:
# Third cell - Setup functions
def setup_model(device):
    """Build the ViT-B/16 based ``CustomViT`` on ``device`` and list what trains.

    Loads torchvision's default pretrained ViT-B/16 weights, wraps the
    backbone in ``CustomViT``, prints the name of every parameter that still
    requires gradients, and returns the wrapped model.
    """
    weights = torchvision.models.ViT_B_16_Weights.DEFAULT
    backbone = torchvision.models.vit_b_16(weights=weights)
    backbone = backbone.to(device)

    # The wrapper creates a new head, so the whole model is moved again.
    model = CustomViT(backbone)
    model = model.to(device)

    # Sanity check: report which parameters will be updated by training.
    trainable_names = [
        name for name, param in model.named_parameters() if param.requires_grad
    ]
    for name in trainable_names:
        print(f"Trainable layer: {name}")

    return model

In [None]:
# Fourth cell - Training and testing functions
def train_step_binary(model: torch.nn.Module,
                     dataloader: torch.utils.data.DataLoader,
                     loss_fn: torch.nn.Module,
                     optimizer: torch.optim.Optimizer,
                     device: torch.device) -> Tuple[float, float]:
    """Train ``model`` for one epoch on a binary classification task.

    The model is expected to output probabilities (one value per sample);
    predictions above 0.5 count as the positive class.

    Args:
        model: Model to train; switched to training mode.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        loss_fn: Loss taking ``(predictions, float_labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean loss per batch, accuracy over all samples)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_samples = 0

    for X, y in dataloader:
        X = X.to(device)
        y = y.to(device).float().reshape(-1)

        # Forward pass. Flatten to (batch,) explicitly: a bare squeeze() turns
        # a batch of one sample into a 0-d tensor, which then mismatches the
        # label shape in the loss and has no len().
        y_pred = model(X).reshape(-1)

        # Calculate loss
        loss = loss_fn(y_pred, y)
        total_loss += loss.item()
        num_batches += 1

        # Backpropagate and update the trainable parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Count correct predictions per sample so that a smaller final batch
        # does not get the same weight as a full one.
        y_pred_class = (y_pred > 0.5).float()
        num_correct += (y_pred_class == y).sum().item()
        num_samples += y.numel()

    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")

    train_loss = total_loss / num_batches
    train_acc = num_correct / num_samples if num_samples else 0.0
    return train_loss, train_acc

def test_step_binary(model: torch.nn.Module,
                    dataloader: torch.utils.data.DataLoader,
                    loss_fn: torch.nn.Module,
                    device: torch.device) -> Tuple[float, float]:
    model.eval()
    test_loss, test_acc = 0, 0
    
    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device).float()
            
            # Forward pass
            test_pred = model(X)
            test_pred = test_pred.squeeze()
            
            # Calculate loss
            loss = loss_fn(test_pred, y)
            test_loss += loss.item()
            
            # Calculate accuracy
            test_pred_class = (test_pred > 0.5).float()
            test_acc += (test_pred_class == y).sum().item()/len(test_pred)
            
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

In [None]:
# Dataloaders built from a per-class random subsample of the dataset (sampling_ratio defaults to 0.5, i.e. half)
# os.cpu_count() may return None; fall back to 0 (load data in the main process).
NUM_WORKERS = os.cpu_count() or 0
def create_dataloaders_with_cross_validation(
    dataset_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    sampling_ratio: float = 0.5,  # Added sampling ratio parameter
    num_splits: int = 5,
    num_workers: int = NUM_WORKERS
):
    """Build stratified k-fold train/test dataloaders over a class-wise subsample.

    The images under ``dataset_dir`` are read with ``ImageFolder``. From every
    class, ``int(class_size * sampling_ratio)`` images are drawn at random
    without replacement (using NumPy's global RNG, so seed ``np.random``
    beforehand for reproducible subsets). The subsample is then split with
    ``StratifiedKFold`` (shuffled, ``random_state=42``).

    Args:
        dataset_dir: Root directory in ``ImageFolder`` layout.
        transform: Transform applied to every image.
        batch_size: Batch size for all dataloaders.
        sampling_ratio: Fraction of each class to keep, in ``(0, 1]``.
        num_splits: Number of cross-validation folds.
        num_workers: Worker processes per dataloader.

    Returns:
        ``(train_dataloaders, test_dataloaders, class_names)`` where the two
        lists hold one dataloader per fold.

    Raises:
        ValueError: If ``sampling_ratio`` is outside ``(0, 1]``.
    """
    if not 0 < sampling_ratio <= 1:
        raise ValueError(
            f"sampling_ratio must be in (0, 1], got {sampling_ratio}"
        )

    # Use ImageFolder to create the dataset
    full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)

    # Read labels from the dataset's ``targets`` list instead of iterating the
    # dataset itself: iterating would load and transform every image just to
    # discover its label.
    all_targets = np.asarray(full_dataset.targets)

    # Randomly sample indices from each class
    sampled_indices = []
    for class_idx in range(len(full_dataset.classes)):
        indices = np.flatnonzero(all_targets == class_idx)
        n_samples = int(len(indices) * sampling_ratio)
        chosen = np.random.choice(indices, size=n_samples, replace=False)
        sampled_indices.extend(chosen.tolist())

    # Shuffle the sampled indices
    np.random.shuffle(sampled_indices)

    # Create a subset of the dataset with only sampled indices
    sampled_dataset = torch.utils.data.Subset(full_dataset, sampled_indices)
    sampled_targets = [full_dataset.targets[i] for i in sampled_indices]

    print(f"Original dataset size: {len(full_dataset)}")
    print(f"Sampled dataset size: {len(sampled_dataset)} ({sampling_ratio*100}%)")

    # Initialize StratifiedKFold for cross-validation
    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)
    train_dataloaders = []
    test_dataloaders = []
    class_names = full_dataset.classes

    # Settings shared by the train and test loaders of every fold.
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    for train_indices, test_indices in skf.split(range(len(sampled_dataset)), sampled_targets):
        # Fold indices refer to positions inside ``sampled_dataset``.
        train_dataset = torch.utils.data.Subset(sampled_dataset, train_indices.tolist())
        test_dataset = torch.utils.data.Subset(sampled_dataset, test_indices.tolist())

        # Only the training data is reshuffled each epoch.
        train_dataloaders.append(DataLoader(train_dataset, shuffle=True, **loader_kwargs))
        test_dataloaders.append(DataLoader(test_dataset, shuffle=False, **loader_kwargs))

    return train_dataloaders, test_dataloaders, class_names


In [None]:
# For Full Dataset (variant without subsampling, kept commented out for reference)
# Fifth cell - Dataloader and training functions
# def create_dataloaders_with_cross_validation(
#     dataset_dir: str,
#     transform: transforms.Compose,
#     batch_size: int,
#     num_splits: int = 5,
#     num_workers: int = os.cpu_count()
# ):
#     full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)
#     skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)
    
#     train_dataloaders = []
#     test_dataloaders = []
#     class_names = full_dataset.classes
    
#     for train_indices, test_indices in skf.split(full_dataset.imgs, full_dataset.targets):
#         train_dataset = torch.utils.data.Subset(full_dataset, train_indices)
#         test_dataset = torch.utils.data.Subset(full_dataset, test_indices)
        
#         train_dataloader = DataLoader(
#             train_dataset,
#             batch_size=batch_size,
#             shuffle=True,
#             num_workers=num_workers,
#             pin_memory=True,
#         )
#         test_dataloader = DataLoader(
#             test_dataset,
#             batch_size=batch_size,
#             shuffle=False,
#             num_workers=num_workers,
#             pin_memory=True,
#         )
        
#         train_dataloaders.append(train_dataloader)
#         test_dataloaders.append(test_dataloader)
    
#     return train_dataloaders, test_dataloaders, class_names

def train_with_cross_validation(model, train_dataloaders, test_dataloaders, 
                              optimizer, loss_fn, epochs, device,
                              train_step_fn, test_step_fn):
    """Train and evaluate ``model`` independently on every cross-validation split.

    Before each split the model weights and the optimizer state are restored
    to what they were when this function was called. Without that reset the
    model entering split ``k`` has already been trained on data that belongs
    to split ``k``'s test fold, which inflates the reported test metrics.

    Args:
        model: ``torch.nn.Module`` to train (must support ``state_dict`` /
            ``load_state_dict``).
        train_dataloaders: One training dataloader per split.
        test_dataloaders: One test dataloader per split, same order.
        optimizer: Optimizer bound to ``model``'s parameters.
        loss_fn: Loss passed through to the step functions.
        epochs: Number of epochs to run per split.
        device: Device passed through to the step functions.
        train_step_fn: Callable ``(model, dataloader, loss_fn, optimizer,
            device) -> (loss, acc)`` run once per epoch.
        test_step_fn: Callable ``(model, dataloader, loss_fn, device) ->
            (loss, acc)`` run once per epoch.

    Returns:
        A list with one dict per split; each dict maps ``"train_loss"``,
        ``"train_acc"``, ``"test_loss"`` and ``"test_acc"`` to per-epoch lists.
        On return, ``model`` holds the weights trained on the last split.
    """
    import copy  # local import: only needed for the state snapshots below

    # Snapshot the starting point so every split can begin from it.
    initial_model_state = copy.deepcopy(model.state_dict())
    initial_optimizer_state = copy.deepcopy(optimizer.state_dict())

    all_results = []

    for split in range(len(train_dataloaders)):
        # Reset weights and optimizer statistics for this split. The optimizer
        # snapshot is copied again so training cannot modify it in place.
        model.load_state_dict(initial_model_state)
        optimizer.load_state_dict(copy.deepcopy(initial_optimizer_state))

        results = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

        for epoch in tqdm(range(epochs)):
            train_loss, train_acc = train_step_fn(
                model=model,
                dataloader=train_dataloaders[split],
                loss_fn=loss_fn,
                optimizer=optimizer,
                device=device
            )

            test_loss, test_acc = test_step_fn(
                model=model,
                dataloader=test_dataloaders[split],
                loss_fn=loss_fn,
                device=device
            )

            print(
                f"Split: {split+1} | Epoch: {epoch+1} | "
                f"train_loss: {train_loss:.4f} | "
                f"train_acc: {train_acc:.4f} | "
                f"test_loss: {test_loss:.4f} | "
                f"test_acc: {test_acc:.4f}"
            )

            results["train_loss"].append(train_loss)
            results["train_acc"].append(train_acc)
            results["test_loss"].append(test_loss)
            results["test_acc"].append(test_acc)

        all_results.append(results)

    return all_results

In [None]:
# Mount Google Drive at /content/gdrive; the dataset path used by the training
# cell lives under this mount point.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime - confirm before running this notebook elsewhere.
from google.colab import drive
drive.mount("/content/gdrive")

In [None]:
# Sixth cell - Setup and run training
# Seed the RNGs first: the head layers are randomly initialised and the
# dataloader helper subsamples with np.random, so without seeds no two runs
# use the same model start or the same images.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set dataset directory
dataset_dir = "/content/gdrive/MyDrive/Colab Notebooks/CMID/"  # Your dataset path

# Create model
model = setup_model(device)

# Get the preprocessing pipeline that matches the pretrained weights.
# Stored under its own name: assigning it to `transforms` would rebind the
# imported torchvision.transforms module name and break re-running the cell
# that defines create_dataloaders_with_cross_validation (its signature
# references transforms.Compose).
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT
vit_transforms = pretrained_vit_weights.transforms()

# Create dataloaders
train_dataloaders, test_dataloaders, class_names = create_dataloaders_with_cross_validation(
    dataset_dir=dataset_dir,
    transform=vit_transforms,
    batch_size=32,
    num_splits=5,
    sampling_ratio=0.5
)

# Setup loss and optimizer.
# BCELoss expects probabilities, which the model's final Sigmoid provides.
loss_fn = nn.BCELoss()
# Only hand the optimizer the parameters that are actually trainable.
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad],
    lr=1e-3
)

# Train with cross validation
results = train_with_cross_validation(
    model=model,
    train_dataloaders=train_dataloaders,
    test_dataloaders=test_dataloaders,
    optimizer=optimizer,
    loss_fn=loss_fn,
    epochs=100,
    device=device,
    train_step_fn=train_step_binary,
    test_step_fn=test_step_binary
)


In [None]:
# from torchinfo import summary
# # Print model summary
# summary(model=model,
#         input_size=(32, 3, 224, 224),
#         col_names=["input_size", "output_size", "num_params", "trainable"],
#         col_width=20,
#         row_settings=["var_names"])