<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/pmicl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%writefile /content/oasis_dataset.py
import os
import zipfile
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from sklearn.model_selection import train_test_split

ZIP_PATH = '/content/oaisis.zip'
EXTRACT_DIR = '/content/oasis_data/'

def extract_zip():
    try:
        if not os.path.exists(ZIP_PATH):
            raise FileNotFoundError(f"{ZIP_PATH} not found. Please upload the file to Colab.")
        os.makedirs(EXTRACT_DIR, exist_ok=True)
        if not os.listdir(EXTRACT_DIR):
            with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
                zip_ref.extractall(EXTRACT_DIR)
            print(f"Extracted {ZIP_PATH} to {EXTRACT_DIR}")
        else:
            print(f"Directory {EXTRACT_DIR} already contains files, skipping extraction.")
    except Exception as e:
        print(f"Error extracting ZIP file: {e}")
        raise

extract_zip()

class PatchExtractor:
    def __init__(self, patch_size=32, K=80, spatial_threshold=5):
        self.patch_size = patch_size
        self.K = K
        self.spatial_threshold = spatial_threshold

    def extract_patches(self, image, prob_map):
        patches = []
        coords = []
        prob_map_copy = prob_map.copy()
        for _ in range(self.K):
            max_prob_idx = np.argmax(prob_map_copy)
            y, x = np.unravel_index(max_prob_idx, prob_map_copy.shape)
            patch = self.get_patch(image, (x, y))
            patches.append(patch)
            coords.append((x, y))
            prob_map_copy = self.mask_neighbors(prob_map_copy, (x, y))
        return patches, coords

    def get_patch(self, image, center):
        half_size = self.patch_size // 2
        x, y = center
        patch = image[
            max(0, y - half_size):y + half_size,
            max(0, x - half_size):x + half_size
        ]
        if patch.shape[0] < self.patch_size or patch.shape[1] < self.patch_size:
            patch = np.pad(patch, [(0, max(0, self.patch_size - patch.shape[0])),
                                   (0, max(0, self.patch_size - patch.shape[1]))],
                           mode='constant')
        return patch

    def mask_neighbors(self, prob_map, center):
        half_size = self.spatial_threshold
        x, y = center
        prob_map[
            max(0, y - half_size):y + half_size + 1,
            max(0, x - half_size):x + half_size + 1
        ] = 0
        return prob_map

    def dynamic_sample(self, image, N=20):
        candidate_patches = []
        candidate_coords = []
        stride = self.patch_size // 2
        h, w = image.shape
        for y in range(0, h - self.patch_size + 1, stride):
            for x in range(0, w - self.patch_size + 1, stride):
                patch = image[y:y+self.patch_size, x:x+self.patch_size]
                candidate_patches.append(patch)
                candidate_coords.append((x, y))
        indices = np.random.choice(len(candidate_patches), min(N, len(candidate_patches)), replace=False)
        return [candidate_patches[i] for i in indices], [candidate_coords[i] for i in indices]

class OasisDataset(Dataset):
    def __init__(self, data_dir, patch_size=32, n_sampled_patches=20, subset_size=1000):
        self.data_dir = os.path.join(data_dir, 'Data')
        self.patch_size = patch_size
        self.n_sampled_patches = n_sampled_patches
        self.patch_extractor = PatchExtractor(patch_size=patch_size)
        self.class_map = {
            'Non Demented': 0,
            'Very mild Dementia': 1,
            'Mild Dementia': 2,
            'Moderate Dementia': 3
        }
        self.image_paths = []
        self.labels = []
        self.subset_size = subset_size

        print(f"Looking for images in {self.data_dir}")
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Data directory {self.data_dir} not found")

        available_dirs = os.listdir(self.data_dir)
        print(f"Available directories: {available_dirs}")

        for class_name in self.class_map:
            matching_dir = next((d for d in available_dirs if d.lower() == class_name.lower()), None)
            if not matching_dir:
                print(f"Warning: No directory found for {class_name}")
                continue
            class_dir = os.path.join(self.data_dir, matching_dir)
            print(f"Checking directory: {class_dir}")
            img_files = [f for f in os.listdir(class_dir) if f.lower().endswith('.jpg')]
            img_paths = [os.path.join(class_dir, f) for f in img_files]
            print(f"Found {len(img_paths)} .jpg files in {class_dir}")
            for img_path in img_paths:
                try:
                    with Image.open(img_path) as img:
                        img.verify()
                    self.image_paths.append(img_path)
                    self.labels.append(self.class_map[class_name])
                except Exception as e:
                    print(f"Warning: Failed to load {img_path}: {e}")

        if not self.image_paths:
            raise ValueError(f"No valid images found in {self.data_dir}.")

        total_images = len(self.image_paths)
        if self.subset_size > total_images:
            self.subset_size = total_images
            print(f"Subset size adjusted to {self.subset_size} (total available images)")

        if self.subset_size < total_images:
            X_subset, _, y_subset, _ = train_test_split(
                self.image_paths, self.labels, train_size=self.subset_size, stratify=self.labels, random_state=42
            )
            self.image_paths = X_subset
            self.labels = y_subset

        print(f"Loaded {len(self.image_paths)} images: "
              f"{len([l for l in self.labels if l == 0])} CN, "
              f"{len([l for l in self.labels if l == 1])} MCI, "
              f"{len([l for l in self.labels if l == 2])} Mild, "
              f"{len([l for l in self.labels if l == 3])} Moderate")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            image = np.array(Image.open(img_path).convert('L')) / 255.0
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise
        prob_map = np.random.rand(image.shape[0], image.shape[1])
        patches, coords = self.patch_extractor.extract_patches(image, prob_map)
        sampled_patches, sampled_coords = self.patch_extractor.dynamic_sample(image, self.n_sampled_patches)
        patches_tensor = torch.tensor(sampled_patches, dtype=torch.float).unsqueeze(1)
        label_tensor = torch.tensor(label, dtype=torch.long)
        return patches_tensor, label_tensor, coords

def get_dataloader(data_dir, batch_size=2, subset_size=1000):
    dataset = OasisDataset(data_dir, patch_size=32, n_sampled_patches=20, subset_size=subset_size)
    class_counts = np.bincount(dataset.labels)
    weights = 1.0 / class_counts[dataset.labels]
    sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, num_workers=0)
    return dataloader

Writing /content/oasis_dataset.py


In [2]:
from oasis_dataset import OasisDataset
import numpy as np
try:
    dataset = OasisDataset('/content/oasis_data/', subset_size=1000)
    print(f"Class distribution: {np.bincount(dataset.labels)}")
except Exception as e:
    print(f"Dataset error: {e}")

Extracted /content/oaisis.zip to /content/oasis_data/
Looking for images in /content/oasis_data/Data
Available directories: ['Very mild Dementia', 'Non Demented', 'Mild Dementia', 'Moderate Dementia']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1000 images: 778 CN, 159 MCI, 58 Mild, 5 Moderate
Class distribution: [778 159  58   5]


In [3]:
from oasis_dataset import get_dataloader
try:
    dataloader = get_dataloader('/content/oasis_data/', batch_size=2)
    for patches, labels, coords in dataloader:
        print("Patches shape:", patches.shape)  # Expected: [2, 20, 1, 32, 32]
        print("Labels shape:", labels.shape)   # Expected: [2]
        print("Coords length:", len(coords))   # Expected: 2
        break
except Exception as e:
    print(f"DataLoader error: {e}")

Looking for images in /content/oasis_data/Data
Available directories: ['Very mild Dementia', 'Non Demented', 'Mild Dementia', 'Moderate Dementia']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1000 images: 778 CN, 159 MCI, 58 Mild, 5 Moderate
Patches shape: torch.Size([2, 20, 1, 32, 32])
Labels shape: torch.Size([2])
Coords length: 80


  patches_tensor = torch.tensor(sampled_patches, dtype=torch.float).unsqueeze(1)


#pmicl_model

In [4]:
%%writefile /content/pmicl_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import numpy as np

# Check PyTorch version for compatibility
TORCH_VERSION = torch.__version__.split('+')[0]
if not torch.__version__.startswith(('2.3', '2.4', '2.5', '2.6')):
    print(f"Warning: PyTorch version {TORCH_VERSION} detected. This code is tested with PyTorch 2.3.x to 2.6.x.")

# Import torch-geometric with error handling
try:
    from torch_geometric.nn import GCNConv, global_mean_pool
except ImportError:
    print("ERROR: torch-geometric is not installed. Please run the following in a Colab cell:")
    print("!pip install torch-geometric")
    print("!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cpu.html")
    sys.exit(1)

from sklearn.metrics.pairwise import euclidean_distances

# Configuration
EMBED_DIM = 128
NUM_CLASSES = 4  # Updated for CN, MCI, Mild, Moderate
TEMPERATURE = 0.1
PROTOTYPE_LOSS_WEIGHT = 0.5
GRAPH_LOSS_WEIGHT = 0.3

# Graph Constructor for mi-Graph
class GraphConstructor:
    def __init__(self):
        pass

    def build_graph(self, patches):
        try:
            # Check if patches is already a NumPy array
            if isinstance(patches, np.ndarray):
                patch_features = patches
            else:
                # Assume patches is a PyTorch tensor
                patch_features = patches.cpu().numpy()
            dist_matrix = euclidean_distances(patch_features)
            edge_index = []
            edge_weight = []
            for i in range(len(patch_features)):
                for j in range(i + 1, len(patch_features)):
                    edge_index.append([i, j])
                    edge_index.append([j, i])
                    edge_weight.append(dist_matrix[i, j])
                    edge_weight.append(dist_matrix[i, j])
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_weight = torch.tensor(edge_weight, dtype=torch.float)
            return edge_index, edge_weight
        except Exception as e:
            print(f"Error building graph: {e}")
            raise

# 2D CNN for Patch Encoding
class PatchEncoder(nn.Module):
    def __init__(self, in_channels=1, embed_dim=EMBED_DIM):
        super(PatchEncoder, self).__init__()
        self.embed_dim = embed_dim
        self.conv_layers = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.fc = nn.Linear(64 * 4 * 4, embed_dim)

    def forward(self, patches):
        B, N, C, H, W = patches.shape
        patches_flat = patches.view(B * N, C, H, W)
        features = self.conv_layers(patches_flat)
        features_flat = features.view(B * N, -1)
        embeddings = self.fc(features_flat)
        embeddings_normalized = F.normalize(embeddings, p=2, dim=1)
        return embeddings_normalized.view(B, N, self.embed_dim)

# GNN for Graph Processing (mi-Graph)
class GNN(nn.Module):
    def __init__(self):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(EMBED_DIM, 64)
        self.conv2 = GCNConv(64, 32)

    def forward(self, x, edge_index, batch):
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        return x

# Attention-Based MIL Aggregator (PMICL)
class AttentionMIL(nn.Module):
    def __init__(self, embed_dim=EMBED_DIM):
        super(AttentionMIL, self).__init__()
        self.attention = nn.Sequential(
            nn.Linear(embed_dim, embed_dim // 2),
            nn.Tanh(),
            nn.Linear(embed_dim // 2, 1)
        )

    def forward(self, patch_embeddings):
        u = self.attention(patch_embeddings)
        a = F.softmax(u, dim=1)
        bag_embedding = torch.sum(a * patch_embeddings, dim=1)
        return bag_embedding, a.squeeze(-1)

# PMICL Model with mi-Graph Integration
class PMICL(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(PMICL, self).__init__()
        self.encoder = PatchEncoder()
        self.gnn = GNN()
        self.attention_mil = AttentionMIL()
        self.fc = nn.Linear(EMBED_DIM + 32, num_classes)
        self.prototypes = nn.Parameter(torch.randn(num_classes, EMBED_DIM))
        nn.init.xavier_uniform_(self.prototypes)

    def forward(self, patches, edge_index, batch, bag_labels=None):
        patch_embeddings = self.encoder(patches)
        B, N, D = patch_embeddings.shape
        patch_embeddings_flat = patch_embeddings.view(B * N, D)
        graph_features = self.gnn(patch_embeddings_flat, edge_index, batch)
        bag_embedding, attention_weights = self.attention_mil(patch_embeddings)
        combined_features = torch.cat([bag_embedding, graph_features], dim=1)
        logits = self.fc(combined_features)
        proto_loss = torch.tensor(0.0, device=patches.device)
        if bag_labels is not None:
            sim_matrix = torch.matmul(patch_embeddings_flat, self.prototypes.t())
            logits_proto = sim_matrix / TEMPERATURE
            patch_labels = bag_labels.unsqueeze(1).repeat(1, N).view(B * N)
            proto_loss = F.cross_entropy(logits_proto, patch_labels)
        return logits, proto_loss, attention_weights, patch_embeddings

# Graph Loss for Feature Smoothness
def graph_loss(patch_embeddings, edge_index, edge_weight):
    diff = patch_embeddings[edge_index[0]] - patch_embeddings[edge_index[1]]
    loss = (edge_weight * (diff ** 2).sum(dim=1)).mean()
    return loss

Writing /content/pmicl_model.py


#Train_eval


In [5]:
%%writefile /content/train_eval.py
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
import numpy as np
from pmicl_model import PMICL, GraphConstructor, graph_loss
from oasis_dataset import get_dataloader

# Configuration
DEVICE = torch.device("cpu")  # Use CPU to avoid CUDA issues
BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 10
EMBED_DIM = 128
PROTOTYPE_LOSS_WEIGHT = 0.5
GRAPH_LOSS_WEIGHT = 0.3
NUM_CLASSES = 4  # Updated for CN, MCI, Mild, Moderate

# Training Loop
def train_epoch(model, dataloader, optimizer, device):
    model.train()
    total_loss = 0
    total_cls_loss = 0
    total_proto_loss = 0
    total_graph_loss = 0
    graph_constructor = GraphConstructor()

    for patches, labels, coords in dataloader:
        patches, labels = patches.to(device), labels.to(device)
        B, N, C, H, W = patches.shape

        # Build graph
        patch_features = patches.view(B * N, -1).cpu().numpy()
        edge_index, edge_weight = graph_constructor.build_graph(patch_features)
        edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
        batch = torch.repeat_interleave(torch.arange(B, device=device), N)

        # Forward pass
        optimizer.zero_grad()
        logits, proto_loss, attention_weights, patch_embeddings = model(patches, edge_index, batch, labels)

        # Losses
        cls_loss = F.cross_entropy(logits, labels)
        graph = graph_loss(patch_embeddings.view(B * N, EMBED_DIM), edge_index, edge_weight)
        loss = cls_loss + PROTOTYPE_LOSS_WEIGHT * proto_loss + GRAPH_LOSS_WEIGHT * graph

        # Backward pass
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_cls_loss += cls_loss.item()
        total_proto_loss += proto_loss.item()
        total_graph_loss += graph.item()

    return (total_loss / len(dataloader), total_cls_loss / len(dataloader),
            total_proto_loss / len(dataloader), total_graph_loss / len(dataloader))

# Evaluation
def evaluate(model, dataloader, device):
    model.eval()
    preds = []
    true_labels = []
    graph_constructor = GraphConstructor()

    with torch.no_grad():
        for patches, labels, coords in dataloader:
            patches, labels = patches.to(device), labels.to(device)
            B, N, C, H, W = patches.shape
            patch_features = patches.view(B * N, -1).cpu().numpy()
            edge_index, edge_weight = graph_constructor.build_graph(patch_features)
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            logits, _, _, _ = model(patches, edge_index, batch)
            preds.append(torch.argmax(logits, dim=1).cpu().numpy())
            true_labels.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    true_labels = np.concatenate(true_labels)
    acc = accuracy_score(true_labels, preds)
    f1 = f1_score(true_labels, preds, average='macro')  # Macro for multi-class
    auc = roc_auc_score(true_labels, preds, multi_class='ovr')  # OVR for multi-class
    return acc, f1, auc

# Visualize Attention Weights
def visualize_attention(image, coords, attention_weights, patch_size=32):
    attention_map = np.zeros(image.shape)
    for (x, y), weight in zip(coords, attention_weights):
        attention_map[y:y+patch_size, x:x+patch_size] += weight
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image, cmap='gray')
    plt.title("Original Slice")
    plt.subplot(1, 2, 2)
    plt.imshow(attention_map, cmap='hot')
    plt.title("Attention Map")
    plt.savefig('/content/attention_map.png')
    plt.close()

# Main
if __name__ == "__main__":
    # Initialize DataLoader
    dataloader = get_dataloader('/content/oasis_data/', BATCH_SIZE)

    # Initialize Model
    model = PMICL(num_classes=NUM_CLASSES).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Training
    for epoch in range(NUM_EPOCHS):
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}")
        losses = train_epoch(model, dataloader, optimizer, DEVICE)
        print(f"Loss: {losses[0]:.4f}, Cls: {losses[1]:.4f}, Proto: {losses[2]:.4f}, Graph: {losses[3]:.4f}")

        # Evaluate every epoch
        acc, f1, auc = evaluate(model, dataloader, DEVICE)
        print(f"Val: Acc={acc:.4f}, F1={f1:.4f}, AUC={auc:.4f}")

    # Visualize attention for a sample
    model.eval()
    with torch.no_grad():
        patches, labels, coords = next(iter(dataloader))
        from PIL import Image
        image = np.array(Image.open(dataloader.dataset.image_paths[0]).convert('L')) / 255.0
        patches = patches.to(DEVICE)
        B, N, C, H, W = patches.shape
        patch_features = patches.view(B * N, -1).cpu().numpy()
        edge_index, edge_weight = GraphConstructor().build_graph(patch_features)
        edge_index, edge_weight = edge_index.to(DEVICE), edge_weight.to(DEVICE)
        batch = torch.repeat_interleave(torch.arange(B, device=DEVICE), N)

        _, _, attention_weights, _ = model(patches, edge_index, batch)
        visualize_attention(image, coords[0], attention_weights[0].cpu().numpy())

Writing /content/train_eval.py


In [6]:
!rm -rf /content/__pycache__

In [7]:
!pip install torch torchvision numpy pillow scikit-learn matplotlib torch-geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cpu.html

Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuff

In [8]:
!grep -A 5 "class PMICL" /content/pmicl_model.py
!grep -A 5 "def build_graph" /content/pmicl_model.py

class PMICL(nn.Module):
    def __init__(self, num_classes=NUM_CLASSES):
        super(PMICL, self).__init__()
        self.encoder = PatchEncoder()
        self.gnn = GNN()
        self.attention_mil = AttentionMIL()
    def build_graph(self, patches):
        try:
            # Check if patches is already a NumPy array
            if isinstance(patches, np.ndarray):
                patch_features = patches
            else:


In [9]:
from pmicl_model import PMICL, GraphConstructor, graph_loss
print("Imported PMICL, GraphConstructor, and graph_loss successfully")

Imported PMICL, GraphConstructor, and graph_loss successfully


In [11]:
%run /content/train_eval.py

Looking for images in /content/oasis_data/Data
Available directories: ['Very mild Dementia', 'Non Demented', 'Mild Dementia', 'Moderate Dementia']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1000 images: 778 CN, 159 MCI, 58 Mild, 5 Moderate
Epoch 1/10
Loss: 2.1391, Cls: 1.3881, Proto: 1.4724, Graph: 0.0494
Val: Acc=0.2750, F1=0.1078, AUC=0.4919
Epoch 2/10
Loss: 2.1015, Cls: 1.3857, Proto: 1.4273, Graph: 0.0070
Val: Acc=0.2500, F1=0.1000, AUC=0.5031
Epoch 3/10
Loss: 2.1049, Cls: 1.3876, Proto: 1.4323, Graph: 0.0039
Val: Acc

<Figure size 640x480 with 0 Axes>