<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/pmicl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision numpy pillow scikit-learn matplotlib torch-geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cpu.html

Looking in links: https://data.pyg.org/whl/torch-2.6.0+cpu.html


#Dataset_setting

In [2]:
%%writefile /content/oasis_dataset.py
import os
import zipfile
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

ZIP_PATH = '/content/oaisis.zip'
EXTRACT_DIR = '/content/oasis_data/'

def extract_zip():
    try:
        if not os.path.exists(ZIP_PATH):
            raise FileNotFoundError(f"{ZIP_PATH} not found. Please upload the file to Colab.")
        os.makedirs(EXTRACT_DIR, exist_ok=True)
        if not os.listdir(EXTRACT_DIR):
            with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
                zip_ref.extractall(EXTRACT_DIR)
            print(f"Extracted {ZIP_PATH} to {EXTRACT_DIR}")
        else:
            print(f"Directory {EXTRACT_DIR} already contains files, skipping extraction.")
    except Exception as e:
        print(f"Error extracting ZIP file: {e}")
        raise

extract_zip()

class PatchExtractor:
    def __init__(self, patch_size=64, K=80, spatial_threshold=5):
        self.patch_size = patch_size
        self.K = K
        self.spatial_threshold = spatial_threshold

    def extract_patches(self, image, prob_map):
        patches = []
        coords = []
        prob_map_copy = prob_map.copy()
        for _ in range(self.K):
            max_prob_idx = np.argmax(prob_map_copy)
            y, x = np.unravel_index(max_prob_idx, prob_map_copy.shape)
            patch = self.get_patch(image, (x, y))
            patches.append(patch)
            coords.append((x, y))
            prob_map_copy = self.mask_neighbors(prob_map_copy, (x, y))
        return patches, coords

    def get_patch(self, image, center):
        half_size = self.patch_size // 2
        x, y = center
        patch = image[
            max(0, y - half_size):y + half_size,
            max(0, x - half_size):x + half_size
        ]
        if patch.shape[0] < self.patch_size or patch.shape[1] < self.patch_size:
            patch = np.pad(patch, [(0, max(0, self.patch_size - patch.shape[0])),
                                   (0, max(0, self.patch_size - patch.shape[1]))],
                           mode='constant')
        return patch

    def mask_neighbors(self, prob_map, center):
        half_size = self.spatial_threshold
        x, y = center
        prob_map[
            max(0, y - half_size):y + half_size + 1,
            max(0, x - half_size):x + half_size + 1
        ] = 0
        return prob_map

    def dynamic_sample(self, image, N=40):
        candidate_patches = []
        candidate_coords = []
        stride = self.patch_size // 2
        h, w = image.shape
        for y in range(0, h - self.patch_size + 1, stride):
            for x in range(0, w - self.patch_size + 1, stride):
                patch = image[y:y+self.patch_size, x:x+self.patch_size]
                candidate_patches.append(patch)
                candidate_coords.append((x, y))
        indices = np.random.choice(len(candidate_patches),
                                   min(N, len(candidate_patches)),
                                   replace=False)
        return [candidate_patches[i] for i in indices], [candidate_coords[i] for i in indices]

class OasisDataset(Dataset):
    def __init__(self, data_dir, patch_size=64, n_sampled_patches=40):
        self.data_dir = os.path.join(data_dir, 'Data')
        self.patch_size = patch_size
        self.n_sampled_patches = n_sampled_patches
        self.patch_extractor = PatchExtractor(patch_size=patch_size)
        self.class_map = {
            'Non Demented': 0,
            'Very mild Dementia': 1,
            'Mild Dementia': 2,
            'Moderate Dementia': 3
        }
        self.image_paths = []
        self.labels = []

        print(f"Looking for images in {self.data_dir}")
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Data directory {self.data_dir} not found")

        available_dirs = os.listdir(self.data_dir)
        print(f"Available directories: {available_dirs}")

        class_valid_paths = {}
        for class_name in self.class_map:
            matching_dir = next((d for d in available_dirs if d.lower() == class_name.lower()), None)
            if not matching_dir:
                print(f"Warning: No directory found for {class_name}")
                continue
            class_dir = os.path.join(self.data_dir, matching_dir)
            print(f"Checking directory: {class_dir}")
            img_files = [f for f in os.listdir(class_dir) if f.lower().endswith('.jpg')]
            img_paths = [os.path.join(class_dir, f) for f in img_files]
            print(f"Found {len(img_paths)} .jpg files in {class_dir}")

            valid_paths = []
            for img_path in img_paths:
                try:
                    with Image.open(img_path) as img:
                        img.verify()
                    valid_paths.append(img_path)
                except Exception as e:
                    print(f"Warning: Failed to load {img_path}: {e}")
            class_valid_paths[class_name] = valid_paths

        if not class_valid_paths:
            raise ValueError(f"No valid images found in {self.data_dir}.")

        min_images = min(len(paths) for paths in class_valid_paths.values())
        self.images_per_class = min_images

        for class_name in self.class_map:
            if class_name not in class_valid_paths:
                continue
            valid_paths = class_valid_paths[class_name]
            sampled_paths = np.random.choice(valid_paths,
                                             self.images_per_class,
                                             replace=False)
            self.image_paths.extend(sampled_paths)
            self.labels.extend([self.class_map[class_name]] * self.images_per_class)

        print(f"Loaded {len(self.image_paths)} images: "
              f"{len([l for l in self.labels if l == 0])} CN, "
              f"{len([l for l in self.labels if l == 1])} MCI, "
              f"{len([l for l in self.labels if l == 2])} Mild, "
              f"{len([l for l in self.labels if l == 3])} Moderate")

        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.RandomRotation(10),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            image = np.array(Image.open(img_path).convert('L')) / 255.0
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise
        prob_map = np.random.rand(image.shape[0], image.shape[1])
        patches, coords = self.patch_extractor.extract_patches(image, prob_map)
        sampled_patches, sampled_coords = self.patch_extractor.dynamic_sample(
            image, self.n_sampled_patches)
        patches_tensor = torch.stack([self.transform(patch) for patch in sampled_patches])
        label_tensor = torch.tensor(label, dtype=torch.long)
        return patches_tensor, label_tensor, sampled_coords


def custom_collate_fn(batch):
    patches = torch.stack([item[0] for item in batch])
    labels = torch.stack([item[1] for item in batch])
    coords = [item[2] for item in batch]
    return patches, labels, coords


def get_dataloader(data_dir, batch_size=2):
    dataset = OasisDataset(data_dir, patch_size=64, n_sampled_patches=40)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,      # finite, randomized per epoch
        num_workers=0,
        collate_fn=custom_collate_fn
    )


Overwriting /content/oasis_dataset.py


In [3]:
from oasis_dataset import OasisDataset
import numpy as np
try:
    dataset = OasisDataset('/content/oasis_data/', subset_size=1000)
    print(f"Class distribution: {np.bincount(dataset.labels)}")
except Exception as e:
    print(f"Dataset error: {e}")

Directory /content/oasis_data/ already contains files, skipping extraction.
Dataset error: OasisDataset.__init__() got an unexpected keyword argument 'subset_size'


In [4]:
from oasis_dataset import get_dataloader
try:
    dataloader = get_dataloader('/content/oasis_data/', batch_size=2)
    for patches, labels, coords in dataloader:
        print("Patches shape:", patches.shape)  # Expected: [2, 20, 1, 32, 32]
        print("Labels shape:", labels.shape)   # Expected: [2]
        print("Coords length:", len(coords))   # Expected: 2
        break
except Exception as e:
    print(f"DataLoader error: {e}")

Looking for images in /content/oasis_data/Data
Available directories: ['Mild Dementia', 'Non Demented', 'Moderate Dementia', 'Very mild Dementia']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1952 images: 488 CN, 488 MCI, 488 Mild, 488 Moderate
Patches shape: torch.Size([2, 40, 1, 64, 64])
Labels shape: torch.Size([2])
Coords length: 2


#pmicl_model

In [5]:
%%writefile /content/pmicl_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
import torch.nn.init as init

class GraphConstructor:
    def __init__(self, k=10):
        self.k = k

    def build_graph(self, features):
        from sklearn.neighbors import kneighbors_graph
        adj = kneighbors_graph(features, n_neighbors=self.k, mode='distance', include_self=False)
        edge_index = torch.tensor(adj.nonzero(), dtype=torch.long)
        edge_weight = torch.tensor(adj[adj.nonzero()], dtype=torch.float)
        edge_weight = 1.0 / (edge_weight + 1e-6)  # Inverse distance
        return edge_index, edge_weight

def graph_loss(embeddings, edge_index, edge_weight):
    print(f"graph_loss: embeddings shape: {embeddings.shape}, edge_index max: {edge_index.max()}")
    source = embeddings[edge_index[0]]
    target = embeddings[edge_index[1]]
    similarity = F.cosine_similarity(source, target, dim=-1)
    print(f"Cosine similarity (sample): {similarity[:5]}")
    loss = -torch.mean(edge_weight * similarity)
    print(f"Graph loss: {loss.item()}")
    return loss

class PMICL(nn.Module):
    def __init__(self, num_classes=4, embed_dim=128, num_prototypes=10):
        super(PMICL, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 8 * 8, embed_dim)
        self.gcn = GCNConv(embed_dim, embed_dim)
        self.fc2 = nn.Linear(embed_dim, num_classes)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, embed_dim))
        self.dropout = nn.Dropout(0.5)

        # Kaiming initialization for fc1 and fc2, normal for prototypes
        init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='relu')
        init.normal_(self.prototypes, mean=0.0, std=0.01)

    def forward(self, patches, edge_index, batch, labels=None):
        B, N, C, H, W = patches.shape
        x = patches.view(B * N, C, H, W)
        x = self.bn1(F.relu(self.conv1(x)))
        x = self.pool(x)
        x = self.bn2(F.relu(self.conv2(x)))
        x = self.pool(x)
        x = self.bn3(F.relu(self.conv3(x)))
        x = self.pool(x)
        x = x.view(B * N, -1)
        x = self.dropout(F.relu(self.fc1(x)))

        print(f"Before GCN - embeddings shape: {x.shape}, Min: {x.min()}, Max: {x.max()}")
        x_gcn = F.relu(self.gcn(x, edge_index))
        print(f"GCN embeddings - Min: {x_gcn.min()}, Max: {x_gcn.max()}")

        attention_weights = F.softmax(torch.matmul(x_gcn, x_gcn.t()) / (x_gcn.shape[-1] ** 0.5), dim=-1)
        x = torch.matmul(attention_weights, x_gcn)
        print(f"After Attention - Min: {x.min()}, Max: {x.max()}")

        graph_loss_val = graph_loss(x, edge_index, torch.ones(edge_index.shape[1], device=x.device))

        x = x.view(B, N, -1)
        x = x.mean(dim=1)
        logits = self.fc2(x)
        cls_loss = F.cross_entropy(logits, labels) if labels is not None else torch.tensor(0.0, device=logits.device)
        proto_dist = torch.cdist(x, self.prototypes)
        proto_loss = proto_dist.mean()

        return logits, cls_loss, proto_loss, graph_loss_val, attention_weights


Overwriting /content/pmicl_model.py


#Train


In [6]:
%%writefile /content/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F  # Fix for NameError
import torch.optim as optim
from torch.utils.data import DataLoader
from oasis_dataset import get_dataloader
from pmicl_model import PMICL, GraphConstructor

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
NUM_EPOCHS = 5
BATCH_SIZE = 2
LEARNING_RATE = 1e-2  # Increased from 5e-3
NUM_CLASSES = 4
PATIENCE = 5

def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    cls_loss_total = 0
    proto_loss_total = 0
    graph_loss_total = 0
    correct = 0
    total = 0
    graph_constructor = GraphConstructor(k=5)

    for patches, labels, coords in dataloader:
        patches, labels = patches.to(device), labels.to(device)
        B, N, C, H, W = patches.shape
        with torch.no_grad():
            x = patches.view(B * N, C, H, W)
            x = model.bn1(F.relu(model.conv1(x)))
            x = model.pool(x)
            x = model.bn2(F.relu(model.conv2(x)))
            x = model.pool(x)
            x = model.bn3(F.relu(model.conv3(x)))
            x = model.pool(x)
            x = x.view(B * N, -1)
            x = F.relu(model.fc1(x))
        patch_features = x.cpu().numpy()
        edge_index, edge_weight = graph_constructor.build_graph(patch_features)
        edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
        batch = torch.repeat_interleave(torch.arange(B, device=device), N)

        optimizer.zero_grad()
        logits, cls_loss, proto_loss, graph_loss, _ = model(patches, edge_index, batch, labels)
        print(f"Graph loss (batch): {graph_loss.item():.4f}")

        loss = 0.55 * cls_loss + 0.4 * proto_loss + 0.05 * graph_loss  # Reduced graph loss weight
        loss.backward()

        # Monitor gradient norms
        total_grad_norm = 0
        for p in model.parameters():
            if p.grad is not None:
                total_grad_norm += p.grad.norm().item() ** 2
        total_grad_norm = total_grad_norm ** 0.5
        print(f"Gradients - GCN weights: {model.gcn.lin.weight.grad}")
        print(f"Total gradient norm: {total_grad_norm:.6f}")

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        total_loss += loss.item()
        cls_loss_total += cls_loss.item()
        proto_loss_total += proto_loss.item()
        graph_loss_total += graph_loss.item()

        preds = torch.argmax(logits, dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

    train_acc = correct / total
    print(f"Graph Loss (Epoch Total): {graph_loss_total / len(dataloader):.4f}")
    return (total_loss / len(dataloader), cls_loss_total / len(dataloader),
            proto_loss_total / len(dataloader), graph_loss_total / len(dataloader), train_acc)

def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    graph_constructor = GraphConstructor(k=5)

    with torch.no_grad():
        for patches, labels, coords in dataloader:
            patches, labels = patches.to(device), labels.to(device)
            B, N, C, H, W = patches.shape
            x = patches.view(B * N, C, H, W)
            x = model.bn1(F.relu(model.conv1(x)))
            x = model.pool(x)
            x = model.bn2(F.relu(model.conv2(x)))
            x = model.pool(x)
            x = model.bn3(F.relu(model.conv3(x)))
            x = model.pool(x)
            x = x.view(B * N, -1)
            x = F.relu(model.fc1(x))
            patch_features = x.cpu().numpy()
            edge_index, edge_weight = graph_constructor.build_graph(patch_features)
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            logits, _, _, _, _ = model(patches, edge_index, batch)
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    return correct / total

def main():
    dataloader = get_dataloader('/content/oasis_data/', batch_size=BATCH_SIZE)
    model = PMICL(num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)

    best_val_acc = 0
    patience_counter = 0

    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        loss, cls_loss, proto_loss, graph_loss, train_acc = train_epoch(model, dataloader, optimizer, DEVICE)
        print(f"Loss: {loss:.4f}, Cls: {cls_loss:.4f}, Proto: {proto_loss:.4f}, Graph: {graph_loss:.4f}, Train Acc: {train_acc:.4f}")

        val_acc = evaluate(model, dataloader, DEVICE)
        print(f"Val Acc: {val_acc:.4f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= PATIENCE:
                print("Early stopping triggered")
                break

if __name__ == "__main__":
    main()

Overwriting /content/train.py


In [7]:
!rm -rf /content/__pycache__

In [None]:
%run /content/train.py

Looking for images in /content/oasis_data/Data
Available directories: ['Mild Dementia', 'Non Demented', 'Moderate Dementia', 'Very mild Dementia']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1952 images: 488 CN, 488 MCI, 488 Mild, 488 Moderate
Epoch 1/5


  edge_index = torch.tensor(adj.nonzero(), dtype=torch.long)


Before GCN - embeddings shape: torch.Size([80, 128]), Min: 0.0, Max: 300.1095886230469
GCN embeddings - Min: 0.0, Max: 133.39749145507812
After Attention - Min: 0.0, Max: 133.39749145507812
graph_loss: embeddings shape: torch.Size([80, 128]), edge_index max: 79
Cosine similarity (sample): tensor([0.8900, 0.8900, 0.8900, 0.3943, 0.8900], grad_fn=<SliceBackward0>)
Graph loss: -0.7880587577819824
Graph loss (batch): -0.7881
Gradients - GCN weights: tensor([[ 3.5685e-02,  3.2288e-01,  9.0518e-02,  ...,  4.6428e-01,
         -4.6841e-12,  0.0000e+00],
        [-4.4381e+00, -1.1026e-08, -1.5231e-01,  ..., -4.1010e-10,
         -4.1534e-01, -2.7913e-02],
        [ 8.6499e+00,  1.2290e+00,  5.1658e-01,  ...,  2.2316e+00,
          1.3149e-01,  5.9684e-02],
        ...,
        [ 0.0000e+00,  0.0000e+00, -1.6083e-20,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 2.9020e+00,  3.2730e-01,  4.0778e-01,  ...,  6.0154e-01,
          2.4037e-01,  1.6121e-02],
        [ 4.2460e-01, 

#Evaluation


In [None]:
%%writefile /content/evaluate_visualize.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from torch.utils.data import DataLoader
from oasis_dataset import get_dataloader
from pmicl_model import PMICL, GraphConstructor

# Device configuration
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
BATCH_SIZE = 2
NUM_CLASSES = 4

def evaluate(model, dataloader, device):
    model.eval()
    preds = []
    true_labels = []
    graph_constructor = GraphConstructor()

    with torch.no_grad():
        for patches, labels, coords in dataloader:
            patches, labels = patches.to(device), labels.to(device)
            B, N, C, H, W = patches.shape
            patch_features = patches.view(B * N, -1).cpu().numpy()
            edge_index, edge_weight = graph_constructor.build_graph(patch_features)
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            logits, _, _, _, _ = model(patches, edge_index, batch)
            preds.append(torch.softmax(logits, dim=1).cpu().numpy())
            true_labels.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    true_labels = np.concatenate(true_labels)

    acc = accuracy_score(true_labels, np.argmax(preds, axis=1))
    f1 = f1_score(true_labels, np.argmax(preds, axis=1), average='macro')
    auc = roc_auc_score(true_labels, preds, multi_class='ovr')
    return acc, f1, auc, true_labels, np.argmax(preds, axis=1)

def generate_attention_map(model, dataloader, device):
    model.eval()
    graph_constructor = GraphConstructor()

    with torch.no_grad():
        for patches, _, coords in dataloader:
            patches = patches.to(device)
            B, N, C, H, W = patches.shape
            patch_features = patches.view(B * N, -1).cpu().numpy()
            edge_index, edge_weight = graph_constructor.build_graph(patch_features)
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            _, _, _, _, attention_weights = model(patches, edge_index, batch)
            attention_weights = attention_weights.cpu().numpy()

            attention_per_image = attention_weights[:N, :N]
            plt.figure(figsize=(10, 10))
            plt.imshow(attention_per_image, cmap='hot')
            plt.colorbar()
            plt.title("Attention Map")
            plt.savefig('/content/attention_map.png')
            plt.close()
            break

def plot_confusion_matrix(true_labels, pred_labels, class_names):
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('/content/confusion_matrix.png')
    plt.close()

def plot_metrics(train_accs, val_accs, val_f1s, val_aucs, losses):
    epochs = range(1, len(train_accs) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_accs, label='Training Accuracy')
    plt.plot(epochs, val_accs, label='Validation Accuracy')
    plt.title('Accuracy Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/accuracy_plot.png')
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, val_accs, label='Validation Accuracy')
    plt.plot(epochs, val_f1s, label='Validation F1 Score')
    plt.plot(epochs, val_aucs, label='Validation AUC')
    plt.title('Validation Metrics Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/validation_metrics_plot.png')
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, losses, label='Training Loss')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/loss_plot.png')
    plt.close()

def main():
    dataloader = get_dataloader('/content/oasis_data/', batch_size=BATCH_SIZE)
    model = PMICL(num_classes=NUM_CLASSES).to(DEVICE)
    class_names = ['Non Demented', 'Very mild Dementia', 'Mild Dementia', 'Moderate Dementia']

    # Example metrics (replace with actual training results)
    train_accs = [0.5, 0.6, 0.7]  # Placeholder
    val_accs = [0.4, 0.5, 0.6]    # Placeholder
    val_f1s = [0.3, 0.4, 0.5]     # Placeholder
    val_aucs = [0.6, 0.7, 0.8]    # Placeholder
    losses = [1.0, 0.8, 0.6]      # Placeholder

    acc, f1, auc, true_labels, pred_labels = evaluate(model, dataloader, DEVICE)
    print(f"Val: Acc={acc:.4f}, F1={f1:.4f}, AUC={auc:.4f}")

    plot_confusion_matrix(true_labels, pred_labels, class_names)
    plot_metrics(train_accs, val_accs, val_f1s, val_aucs, losses)
    generate_attention_map(model, dataloader, DEVICE)

if __name__ == "__main__":
    main()