In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models
from sklearn.model_selection import StratifiedKFold
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
from collections import Counter
from sklearn.ensemble import RandomForestClassifier

In [None]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Load torchvision's ViT-B/16 with its default pretrained weights, then move it
# to the selected device in a separate step.
pretrained_vit_weights = models.ViT_B_16_Weights.DEFAULT
pretrained_vit = models.vit_b_16(weights=pretrained_vit_weights)
pretrained_vit = pretrained_vit.to(device)

In [None]:
# Preprocessing applied to every image before it is fed to the pretrained ViT.
# The torchvision ViT-B/16 default weights are documented as expecting inputs
# scaled to [0, 1] and then normalized with the ImageNet channel statistics,
# so those are used here instead of a generic 0.5 / 0.5 normalization.
pretrained_vit_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
])

In [None]:
# For Changing Trainable Layers
class VitFeatures(nn.Module):
    """Wrap a ViT so that ``forward`` returns the [CLS] token embedding.

    On construction the wrapped model is modified in place:

    * every existing parameter is frozen (``requires_grad = False``);
    * parameters of all ``nn.LayerNorm`` modules are re-enabled;
    * ``vit.heads`` is replaced by a small ``hidden -> 2 -> 1 -> Sigmoid``
      stack (its fresh parameters are trainable). ``forward`` does not call
      this head; it only affects the parameter statistics printed below.

    Args:
        vit_model: A ViT exposing ``conv_proj``, ``class_token``, ``encoder``
            and ``heads`` attributes (as used in ``forward``/``__init__``).
        num_trainable_layers: Accepted for interface compatibility. The
            current freezing scheme (LayerNorm-only) does not use it.
    """

    def __init__(self, vit_model, num_trainable_layers=2):
        super().__init__()
        self.vit = vit_model

        # Freeze all parameters first
        for param in self.vit.parameters():
            param.requires_grad = False

        # Then re-enable gradients only for the LayerNorm parameters.
        for module in self.vit.modules():
            if isinstance(module, nn.LayerNorm):
                for param in module.parameters():
                    param.requires_grad = True

        # Modify the heads structure. Use the model's own embedding width when
        # it advertises one; fall back to 768 (the previously hard-coded value).
        hidden_dim = getattr(self.vit, "hidden_dim", 768)
        self.vit.heads = nn.Sequential(
            nn.Linear(in_features=hidden_dim, out_features=2),
            nn.Linear(in_features=2, out_features=1),
            nn.Sigmoid()
        )

        # Print trainable layers info
        total_params = sum(p.numel() for p in self.vit.parameters())
        trainable_params = sum(p.numel() for p in self.vit.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")
        print(f"Percentage of trainable parameters: {100 * trainable_params / total_params:.2f}%")

    def forward(self, x):
        """Return the encoder output at token index 0 for each input image."""
        # Patch embedding: conv output -> (batch, num_patches, hidden)
        x = self.vit.conv_proj(x)
        x = x.flatten(2).transpose(1, 2)
        # Prepend the class token to every sequence in the batch.
        x = torch.cat([self.vit.class_token.expand(x.shape[0], -1, -1), x], dim=1)
        # NOTE(review): no positional embedding is added here; this assumes the
        # encoder adds it internally (as torchvision's encoder is understood to
        # do) - confirm before passing a different ViT implementation.
        x = self.vit.encoder(x)
        return x[:, 0]  # Return the [CLS] token features


In [None]:
# Full Dataset
def create_dataloaders_with_cross_validation(
    dataset_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_splits: int = 5,
    num_workers: int = os.cpu_count() or 0
):
    """Build one train/test DataLoader pair per stratified cross-validation fold.

    Args:
        dataset_dir: Root directory read with ``datasets.ImageFolder``.
        transform: Transform applied to every image of the dataset.
        batch_size: Batch size used by every DataLoader.
        num_splits: Number of stratified folds.
        num_workers: Worker processes per DataLoader. Defaults to the CPU
            count, or 0 (load in the main process) when the count is unknown.

    Returns:
        ``(train_dataloaders, test_dataloaders, class_names)`` where the two
        lists have ``num_splits`` entries each (index ``i`` of both lists
        belongs to the same fold) and ``class_names`` comes from the
        ImageFolder dataset.
    """
    # Use ImageFolder to create the dataset
    full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)
    class_names = full_dataset.classes
    targets = full_dataset.targets

    # Fixed seed so the fold assignment is the same on every run.
    skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)

    # Pinned host memory only pays off when batches are copied to a CUDA
    # device; requesting it without CUDA just produces a warning.
    pin_memory = torch.cuda.is_available()

    train_dataloaders = []
    test_dataloaders = []

    # The split depends only on the labels, so a placeholder X is sufficient.
    for train_indices, test_indices in skf.split(np.zeros(len(targets)), targets):
        # Both subsets share the same underlying dataset (and transform).
        train_dataset = torch.utils.data.Subset(full_dataset, train_indices.tolist())
        test_dataset = torch.utils.data.Subset(full_dataset, test_indices.tolist())

        train_dataloaders.append(
            DataLoader(
                train_dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        )
        test_dataloaders.append(
            DataLoader(
                test_dataset,
                batch_size=batch_size,
                shuffle=False,
                num_workers=num_workers,
                pin_memory=pin_memory,
            )
        )

    return train_dataloaders, test_dataloaders, class_names


In [None]:
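# Half Dataset (disabled alternative: redefines create_dataloaders_with_cross_validation with per-class subsampling)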
# NUM_WORKERS = os.cpu_count()
# def create_dataloaders_with_cross_validation(
#     dataset_dir: str,
#     transform: transforms.Compose,
#     batch_size: int,
#     sampling_ratio: float = 0.5,  # Added sampling ratio parameter
#     num_splits: int = 5,
#     num_workers: int = NUM_WORKERS
# ):
#     # Use ImageFolder to create the dataset
#     full_dataset = datasets.ImageFolder(dataset_dir, transform=transform)
    
#     # Get indices for each class
#     class_indices = {i: [] for i in range(len(full_dataset.classes))}
#     for idx, (_, label) in enumerate(full_dataset):
#         class_indices[label].append(idx)
    
#     # Randomly sample indices from each class
#     sampled_indices = []
#     for class_idx, indices in class_indices.items():
#         n_samples = int(len(indices) * sampling_ratio)
#         sampled_indices.extend(np.random.choice(indices, size=n_samples, replace=False))
    
#     # Shuffle the sampled indices
#     np.random.shuffle(sampled_indices)
    
#     # Create a subset of the dataset with only sampled indices
#     sampled_dataset = torch.utils.data.Subset(full_dataset, sampled_indices)
#     sampled_targets = [full_dataset.targets[i] for i in sampled_indices]
    
#     print(f"Original dataset size: {len(full_dataset)}")
#     print(f"Sampled dataset size: {len(sampled_dataset)} ({sampling_ratio*100}%)")
    
#     # Initialize StratifiedKFold for cross-validation
#     skf = StratifiedKFold(n_splits=num_splits, shuffle=True, random_state=42)
#     # Initialize lists to store train and test data loaders for each split
#     train_dataloaders = []
#     test_dataloaders = []
#     # Get class names
#     class_names = full_dataset.classes
    
#     for train_indices, test_indices in skf.split(range(len(sampled_dataset)), sampled_targets):
#         # Create train and test datasets for the current split
#         train_dataset = torch.utils.data.Subset(sampled_dataset, train_indices)
#         test_dataset = torch.utils.data.Subset(sampled_dataset, test_indices)
#         # Turn images into data loaders
#         train_dataloader = DataLoader(
#             train_dataset,
#             batch_size=batch_size,
#             shuffle=True,
#             num_workers=num_workers,
#             pin_memory=True,
#         )
#         test_dataloader = DataLoader(
#             test_dataset,
#             batch_size=batch_size,
#             shuffle=False,
#             num_workers=num_workers,
#             pin_memory=True,
#         )
#         train_dataloaders.append(train_dataloader)
#         test_dataloaders.append(test_dataloader)
#     return train_dataloaders, test_dataloaders, class_names

In [None]:
def extract_features(dataloader, model, device):
    """
    Run every batch of ``dataloader`` through ``model`` and collect the outputs.

    The model is switched to eval mode and gradients are disabled for the pass.
    Returns a tuple ``(features, labels)`` of NumPy arrays: the per-batch model
    outputs stacked vertically, and the per-batch labels concatenated.
    """
    model.eval()
    feature_chunks = []
    label_chunks = []

    with torch.no_grad():
        for images, targets in dataloader:
            outputs = model(images.to(device))
            # Bring results back to host memory before converting to NumPy.
            feature_chunks.append(outputs.cpu().numpy())
            label_chunks.append(targets.numpy())

    stacked_features = np.vstack(feature_chunks)
    stacked_labels = np.hstack(label_chunks)
    return stacked_features, stacked_labels

In [None]:
def evaluate_metrics(y_true, y_pred):
    """
    Calculate all metrics for a given set of predictions.

    Args:
        y_true: Ground-truth labels (array-like).
        y_pred: Predicted labels (array-like, same length as ``y_true``).

    Returns:
        dict with keys ``'accuracy'``, ``'precision'``, ``'recall'`` and
        ``'f1'``. Accuracy is the fraction of exact matches; the other three
        are computed with ``average="weighted"``.
    """
    # Coerce to arrays so the element-wise comparison below also works when
    # plain Python lists are passed (list == list would yield a single bool).
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    return {
        'accuracy': (y_pred == y_true).mean(),
        'precision': precision_score(y_true, y_pred, average="weighted"),
        'recall': recall_score(y_true, y_pred, average="weighted"),
        'f1': f1_score(y_true, y_pred, average="weighted")
    }

In [None]:
# SVM Model
def train_svm_with_cross_validation(train_dataloaders, test_dataloaders, vit_model, device):
    """Fit and evaluate an RBF-SVM on ViT features for every cross-validation fold.

    For each fold, features are extracted from the train and test loaders with
    the wrapped ViT (eval mode, no gradients), a ``StandardScaler`` + ``SVC``
    pipeline is fitted on the train features, and metrics are computed on both
    sets. Per-fold metrics and the average accuracies are printed.

    Args:
        train_dataloaders: One training DataLoader per fold.
        test_dataloaders: One test DataLoader per fold, aligned by index with
            ``train_dataloaders``.
        vit_model: ViT model passed to ``VitFeatures`` (which modifies it in place).
        device: Device the feature extractor and the batches are moved to.

    Returns:
        List with one dict per fold: ``{'split': int, 'train': metrics,
        'test': metrics}`` where ``metrics`` is the dict returned by
        ``evaluate_metrics``.

    Raises:
        ValueError: If the two dataloader lists differ in length.
    """
    if len(train_dataloaders) != len(test_dataloaders):
        raise ValueError(
            "train_dataloaders and test_dataloaders must have the same length, "
            f"got {len(train_dataloaders)} and {len(test_dataloaders)}"
        )

    all_results = []

    # The ViT is only used as a fixed feature extractor here: extract_features
    # runs it in eval mode without gradients, so no ViT weights are updated.
    feature_extractor = VitFeatures(vit_model, num_trainable_layers=2).to(device)

    fold_loaders = zip(train_dataloaders, test_dataloaders)
    for split, (train_loader, test_loader) in enumerate(fold_loaders, start=1):
        print(f"\nTraining on Split {split}")

        # Extract features
        train_features, train_labels = extract_features(train_loader, feature_extractor, device)
        test_features, test_labels = extract_features(test_loader, feature_extractor, device)

        # Train SVM. Probability calibration is not enabled: only hard
        # predictions are used below, and enabling it would add an extra
        # internal cross-validated fit for every fold.
        svm_classifier = make_pipeline(StandardScaler(), SVC(kernel='rbf', C=1, gamma=0.0001))
        svm_classifier.fit(train_features, train_labels)

        # Get predictions
        train_pred = svm_classifier.predict(train_features)
        test_pred = svm_classifier.predict(test_features)

        # Calculate metrics
        train_metrics = evaluate_metrics(train_labels, train_pred)
        test_metrics = evaluate_metrics(test_labels, test_pred)

        # Print results
        print(f"Split {split} Results:")
        print(f"Train Metrics - Accuracy: {train_metrics['accuracy']*100:.2f}%, "
              f"Precision: {train_metrics['precision']:.4f}, "
              f"Recall: {train_metrics['recall']:.4f}, "
              f"F1: {train_metrics['f1']:.4f}")
        print(f"Test Metrics  - Accuracy: {test_metrics['accuracy']*100:.2f}%, "
              f"Precision: {test_metrics['precision']:.4f}, "
              f"Recall: {test_metrics['recall']:.4f}, "
              f"F1: {test_metrics['f1']:.4f}")

        # Store results
        all_results.append({
            'split': split,
            'train': train_metrics,
            'test': test_metrics
        })

    # Print average results across all splits (skipped when there were no
    # folds, since a mean over zero folds is undefined).
    if all_results:
        avg_train_acc = np.mean([r['train']['accuracy'] for r in all_results])
        avg_test_acc = np.mean([r['test']['accuracy'] for r in all_results])
        print(f"\nAverage Results Across All Splits:")
        print(f"Average Train Accuracy: {avg_train_acc*100:.2f}%")
        print(f"Average Test Accuracy: {avg_test_acc*100:.2f}%")

    return all_results


In [None]:
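# RF Model (disabled; it reuses the name train_svm_with_cross_validation, so uncommenting it would shadow the SVM version above)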
# def train_svm_with_cross_validation(train_dataloaders, test_dataloaders, vit_model, device, n_estimators=32):
#     all_results = []

#     # Create feature extractor with only last 2 layers trainable
#     feature_extractor = VitFeatures(vit_model).to(device)

#     for split in range(len(train_dataloaders)):
#         print(f"\nTraining on Split {split+1}")

#         # Extract features
#         train_features, train_labels = extract_features(train_dataloaders[split], feature_extractor, device)
#         test_features, test_labels = extract_features(test_dataloaders[split], feature_extractor, device)

#         # Train Random Forest 
#         rf_classifier = RandomForestClassifier(n_estimators=32, max_leaf_nodes=16, n_jobs=-1, random_state=42)  
#         rf_classifier.fit(train_features, train_labels)

#         # Get predictions
#         train_pred = rf_classifier.predict(train_features)
#         test_pred = rf_classifier.predict(test_features)

#         # Calculate metrics
#         train_metrics = evaluate_metrics(train_labels, train_pred)
#         test_metrics = evaluate_metrics(test_labels, test_pred)

#         # Print results
#         print(f"Split {split+1} Results:")
#         print(f"Train Metrics - Accuracy: {train_metrics['accuracy']*100:.2f}%, "
#               f"Precision: {train_metrics['precision']:.4f}, "
#               f"Recall: {train_metrics['recall']:.4f}, "
#               f"F1: {train_metrics['f1']:.4f}")
#         print(f"Test Metrics  - Accuracy: {test_metrics['accuracy']*100:.2f}%, "
#               f"Precision: {test_metrics['precision']:.4f}, "
#               f"Recall: {test_metrics['recall']:.4f}, "
#               f"F1: {test_metrics['f1']:.4f}")

#         # Store results
#         all_results.append({
#             'split': split + 1,
#             'train': train_metrics,
#             'test': test_metrics
#         })

#     # Print average results across all splits
#     avg_train_acc = np.mean([r['train']['accuracy'] for r in all_results])
#     avg_test_acc = np.mean([r['test']['accuracy'] for r in all_results])
#     print(f"\nAverage Results Across All Splits:")
#     print(f"Average Train Accuracy: {avg_train_acc*100:.2f}%")
#     print(f"Average Test Accuracy: {avg_test_acc*100:.2f}%")

#     return all_results


In [None]:
# Root folder of the image dataset (read with ImageFolder further below).
# Defaults to the Kaggle input path; set the CMID_DATASET_DIR environment
# variable to point the notebook at a copy stored elsewhere.
dataset_dir = os.environ.get("CMID_DATASET_DIR", "/kaggle/input/cmid-dataset/CMID")

In [None]:
# Full Dataset
# Build the per-fold loaders (5 stratified folds, batches of 32 images).
# Note: despite the singular names, the first two results are lists with one
# DataLoader per fold.
(
    train_dataloader_pretrained,
    test_dataloader_pretrained,
    class_names,
) = create_dataloaders_with_cross_validation(
    dataset_dir=dataset_dir,
    transform=pretrained_vit_transforms,
    batch_size=32,
    num_splits=5,
)

In [None]:
# Half Dataset
# train_dataloader_pretrained, test_dataloader_pretrained, class_names = create_dataloaders_with_cross_validation(
#     dataset_dir=dataset_dir,
#     transform=pretrained_vit_transforms,
#     batch_size=32,
#     sampling_ratio=0.5,  # Use 50% of the data
#     num_splits=5
# )

In [None]:
# Run the ViT-feature + SVM evaluation over every fold and keep the per-fold metrics.
svm_results = train_svm_with_cross_validation(
    train_dataloaders=train_dataloader_pretrained,
    test_dataloaders=test_dataloader_pretrained,
    vit_model=pretrained_vit,
    device=device,
)

In [None]:
# rf_results = train_svm_with_cross_validation(
#     train_dataloader_pretrained,
#     test_dataloader_pretrained,
#     pretrained_vit,
#     device,
#     n_estimators=32
# )