<a href="https://colab.research.google.com/github/tousifo/ml_notebooks/blob/main/pmicl.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install torch torchvision numpy pillow scikit-learn matplotlib torch-geometric
!pip install torch-scatter torch-sparse torch-cluster torch-spline-conv -f https://data.pyg.org/whl/torch-2.6.0+cpu.html

Looking in links: https://data.pyg.org/whl/torch-2.6.0+cpu.html


#Dataset_setting

In [2]:
%%writefile /content/oasis_dataset.py
import os
import zipfile
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms

ZIP_PATH = '/content/oaisis.zip'
EXTRACT_DIR = '/content/oasis_data/'

def extract_zip():
    try:
        if not os.path.exists(ZIP_PATH):
            raise FileNotFoundError(f"{ZIP_PATH} not found. Please upload the file to Colab.")
        os.makedirs(EXTRACT_DIR, exist_ok=True)
        if not os.listdir(EXTRACT_DIR):
            with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
                zip_ref.extractall(EXTRACT_DIR)
            print(f"Extracted {ZIP_PATH} to {EXTRACT_DIR}")
        else:
            print(f"Directory {EXTRACT_DIR} already contains files, skipping extraction.")
    except Exception as e:
        print(f"Error extracting ZIP file: {e}")
        raise

extract_zip()

class PatchExtractor:
    def __init__(self, patch_size=32, K=80, spatial_threshold=5):
        self.patch_size = patch_size
        self.K = K
        self.spatial_threshold = spatial_threshold

    def extract_patches(self, image, prob_map):
        patches = []
        coords = []
        prob_map_copy = prob_map.copy()
        for _ in range(self.K):
            max_prob_idx = np.argmax(prob_map_copy)
            y, x = np.unravel_index(max_prob_idx, prob_map_copy.shape)
            patch = self.get_patch(image, (x, y))
            patches.append(patch)
            coords.append((x, y))
            prob_map_copy = self.mask_neighbors(prob_map_copy, (x, y))
        return patches, coords

    def get_patch(self, image, center):
        half_size = self.patch_size // 2
        x, y = center
        patch = image[
            max(0, y - half_size):y + half_size,
            max(0, x - half_size):x + half_size
        ]
        if patch.shape[0] < self.patch_size or patch.shape[1] < self.patch_size:
            patch = np.pad(patch, [(0, max(0, self.patch_size - patch.shape[0])),
                                   (0, max(0, self.patch_size - patch.shape[1]))],
                           mode='constant')
        return patch

    def mask_neighbors(self, prob_map, center):
        half_size = self.spatial_threshold
        x, y = center
        prob_map[
            max(0, y - half_size):y + half_size + 1,
            max(0, x - half_size):x + half_size + 1
        ] = 0
        return prob_map

    def dynamic_sample(self, image, N=20):
        candidate_patches = []
        candidate_coords = []
        stride = self.patch_size // 2
        h, w = image.shape
        for y in range(0, h - self.patch_size + 1, stride):
            for x in range(0, w - self.patch_size + 1, stride):
                patch = image[y:y+self.patch_size, x:x+self.patch_size]
                candidate_patches.append(patch)
                candidate_coords.append((x, y))
        indices = np.random.choice(len(candidate_patches),
                                   min(N, len(candidate_patches)),
                                   replace=False)
        return [candidate_patches[i] for i in indices], [candidate_coords[i] for i in indices]

class OasisDataset(Dataset):
    def __init__(self, data_dir, patch_size=32, n_sampled_patches=20, subset_size=None, augment=False):
        self.data_dir = os.path.join(data_dir, 'Data')
        self.patch_size = patch_size
        self.n_sampled_patches = n_sampled_patches
        self.subset_size = subset_size
        self.augment = augment
        self.patch_extractor = PatchExtractor(patch_size=patch_size)
        self.class_map = {
            'Non Demented': 0,
            'Very mild Dementia': 1,
            'Mild Dementia': 2,
            'Moderate Dementia': 3
        }
        self.image_paths = []
        self.labels = []

        print(f"Looking for images in {self.data_dir}")
        if not os.path.exists(self.data_dir):
            raise FileNotFoundError(f"Data directory {self.data_dir} not found")

        available_dirs = os.listdir(self.data_dir)
        print(f"Available directories: {available_dirs}")

        class_valid_paths = {}
        for class_name in self.class_map:
            matching_dir = next((d for d in available_dirs if d.lower() == class_name.lower()), None)
            if not matching_dir:
                print(f"Warning: No directory found for {class_name}")
                continue
            class_dir = os.path.join(self.data_dir, matching_dir)
            print(f"Checking directory: {class_dir}")
            img_files = [f for f in os.listdir(class_dir) if f.lower().endswith('.jpg')]
            img_paths = [os.path.join(class_dir, f) for f in img_files]
            print(f"Found {len(img_paths)} .jpg files in {class_dir}")

            valid_paths = []
            for img_path in img_paths:
                try:
                    with Image.open(img_path) as img:
                        img.verify()
                    valid_paths.append(img_path)
                except Exception as e:
                    print(f"Warning: Failed to load {img_path}: {e}")
            class_valid_paths[class_name] = valid_paths

        if not class_valid_paths:
            raise ValueError(f"No valid images found in {self.data_dir}.")

        min_images = min(len(paths) for paths in class_valid_paths.values())
        images_per_class = self.subset_size if self.subset_size is not None else min_images
        images_per_class = min(images_per_class, min_images)

        for class_name in self.class_map:
            if class_name not in class_valid_paths:
                continue
            valid_paths = class_valid_paths[class_name]
            sampled_paths = np.random.choice(valid_paths,
                                            images_per_class,
                                            replace=False)
            self.image_paths.extend(sampled_paths)
            self.labels.extend([self.class_map[class_name]] * images_per_class)

        print(f"Loaded {len(self.image_paths)} images: "
              f"{len([l for l in self.labels if l == 0])} CN, "
              f"{len([l for l in self.labels if l == 1])} MCI, "
              f"{len([l for l in self.labels if l == 2])} Mild, "
              f"{len([l for l in self.labels if l == 3])} Moderate")

        transform_list = [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])  # Normalize for stability
        ]
        if self.augment:
            transform_list.insert(1, transforms.RandomRotation(15))  # Stronger augmentation
            transform_list.insert(2, transforms.RandomHorizontalFlip())
            transform_list.insert(3, transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)))

        self.transform = transforms.Compose(transform_list)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        try:
            image = np.array(Image.open(img_path).convert('L')) / 255.0
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            raise
        prob_map = np.random.rand(image.shape[0], image.shape[1])
        patches, coords = self.patch_extractor.extract_patches(image, prob_map)
        sampled_patches, sampled_coords = self.patch_extractor.dynamic_sample(
            image, self.n_sampled_patches)
        patches_tensor = torch.stack([self.transform(patch) for patch in sampled_patches])
        label_tensor = torch.tensor(label, dtype=torch.long)
        print(f"Patches tensor shape: {patches_tensor.shape}, Min: {patches_tensor.min():.4f}, Max: {patches_tensor.max():.4f}")
        return patches_tensor, label_tensor, sampled_coords

def custom_collate_fn(batch):
    patches = torch.stack([item[0] for item in batch])
    labels = torch.stack([item[1] for item in batch])
    coords = [item[2] for item in batch]
    return patches, labels, coords

def get_dataloader(data_dir, batch_size=2, subset_size=None, augment=False):
    dataset = OasisDataset(data_dir, patch_size=32, n_sampled_patches=20, subset_size=subset_size, augment=augment)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
        collate_fn=custom_collate_fn,
        pin_memory=True if torch.cuda.is_available() else False
    )

Overwriting /content/oasis_dataset.py


#Test_dataloader

In [3]:
from oasis_dataset import get_dataloader
try:
    dataloader = get_dataloader('/content/oasis_data/', batch_size=2)
    for patches, labels, coords in dataloader:
        print("Patches shape:", patches.shape)  # Expected: [2, 20, 1, 32, 32]
        print("Labels shape:", labels.shape)   # Expected: [2]
        print("Coords length:", len(coords))   # Expected: 2
        break
except Exception as e:
    print(f"DataLoader error: {e}")

Directory /content/oasis_data/ already contains files, skipping extraction.
Looking for images in /content/oasis_data/Data
Available directories: ['Mild Dementia', 'Moderate Dementia', 'Very mild Dementia', 'Non Demented']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1952 images: 488 CN, 488 MCI, 488 Mild, 488 Moderate
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.8980Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.7412

Patches tensor shape: torch.Size([20, 1, 32, 32]), Min

#pmicl_model

In [4]:
%%writefile /content/pmicl_model.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv
import torch.nn.init as init
import scipy.sparse as sp

class GraphConstructor:
    def __init__(self, k=15):
        self.k = k

    def build_graph(self, features):
        from sklearn.neighbors import kneighbors_graph
        adj = kneighbors_graph(features, n_neighbors=self.k, mode='connectivity', include_self=False)
        adj = sp.csr_matrix(adj)
        num_nodes = adj.shape[0]
        edge_index = torch.tensor(adj.nonzero(), dtype=torch.long)
        edge_weight = torch.ones(edge_index.shape[1], dtype=torch.float)
        print(f"Build graph: num_nodes={num_nodes}, edge_index_shape={edge_index.shape}, edge_weight_min={edge_weight.min():.4f}, max={edge_weight.max():.4f}")
        return edge_index, edge_weight

def graph_loss(embeddings, edge_index, edge_weight):
    embeddings = F.normalize(embeddings, p=2, dim=-1)
    print(f"graph_loss: embeddings shape: {embeddings.shape}, edge_index max: {edge_index.max()}")
    source = embeddings[edge_index[0]]
    target = embeddings[edge_index[1]]
    similarity = F.cosine_similarity(source, target, dim=-1)
    similarity = torch.clamp(similarity, min=-0.99, max=0.99)
    print(f"Cosine similarity (sample): {similarity[:5]}")
    pos_loss = -torch.mean(edge_weight * similarity)

    # Contrastive regularization
    neg_samples = torch.randperm(embeddings.size(0))[:embeddings.size(0)//2]
    neg_sim = F.cosine_similarity(embeddings[neg_samples], embeddings[neg_samples.roll(1)], dim=-1)
    neg_loss = torch.mean(torch.clamp(neg_sim, min=-0.99, max=0.99))

    loss = pos_loss + 0.1 * neg_loss
    print(f"Graph loss: Pos={pos_loss.item():.4f}, Neg={neg_loss.item():.4f}, Total={loss.item():.4f}")
    return loss

class PMICL(nn.Module):
    def __init__(self, num_classes=4, embed_dim=128, num_prototypes=10, num_heads=4):
        super(PMICL, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(256 * 4 * 4, embed_dim)
        self.gat1 = GATConv(embed_dim, embed_dim // num_heads, heads=num_heads, dropout=0.3)
        self.gat2 = GATConv(embed_dim, embed_dim // num_heads, heads=num_heads, dropout=0.3)
        self.attention = nn.MultiheadAttention(embed_dim, num_heads, dropout=0.3)
        self.fc2 = nn.Linear(embed_dim, num_classes)
        self.prototypes = nn.Parameter(torch.randn(num_prototypes, embed_dim))
        self.dropout = nn.Dropout(0.5)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, GATConv):
                nn.init.xavier_uniform_(m.lin.weight)
                # Initialize attention weights (att_src, att_dst)
                if hasattr(m, 'att_src') and m.att_src is not None:
                    nn.init.xavier_uniform_(m.att_src)
                    print("Initialized GAT att_src")
                if hasattr(m, 'att_dst') and m.att_dst is not None:
                    nn.init.xavier_uniform_(m.att_dst)
                    print("Initialized GAT att_dst")
        nn.init.normal_(self.prototypes, mean=0.0, std=0.01)

    def forward(self, patches, edge_index, batch, labels=None):
        B, N, C, H, W = patches.shape
        x = patches.view(B * N, C, H, W)
        x = self.bn1(F.leaky_relu(self.conv1(x), negative_slope=0.1))
        print(f"After conv1: Min={x.min().item():.4f}, Max={x.max().item():.4f}")
        x = self.pool(x)
        x = self.bn2(F.leaky_relu(self.conv2(x), negative_slope=0.1))
        print(f"After conv2: Min={x.min().item():.4f}, Max={x.max().item():.4f}")
        x = self.pool(x)
        x = self.bn3(F.leaky_relu(self.conv3(x), negative_slope=0.1))
        print(f"After conv3: Min={x.min().item():.4f}, Max={x.max().item():.4f}")
        x = self.pool(x)
        x = x.view(B * N, -1)
        x = self.dropout(F.leaky_relu(self.fc1(x), negative_slope=0.1))
        print(f"Before GAT - embeddings shape: {x.shape}, Min: {x.min():.4f}, Max: {x.max():.4f}")
        graph_constructor = GraphConstructor(k=15)
        edge_index, edge_weight = graph_constructor.build_graph(x.detach().cpu().numpy())
        edge_index, edge_weight = edge_index.to(x.device), edge_weight.to(x.device)
        x = F.leaky_relu(self.gat1(x, edge_index), negative_slope=0.1)
        print(f"GAT1 embeddings - Min: {x.min():.4f}, Max: {x.max():.4f}")
        x = self.dropout(F.leaky_relu(self.gat2(x, edge_index), negative_slope=0.1))
        print(f"GAT2 embeddings - Min: {x.min():.4f}, Max: {x.max():.4f}")
        x = x.unsqueeze(0)
        x, _ = self.attention(x, x, x)
        x = x.squeeze(0)
        print(f"After Attention - Min: {x.min():.4f}, Max: {x.max():.4f}")
        if not torch.isfinite(x).all():
            print(f"NaN detected in embeddings: {x}")
        graph_loss_val = graph_loss(x, edge_index, edge_weight)
        x = x.view(B, N, -1).mean(dim=1)
        logits = self.fc2(x)
        cls_loss = F.cross_entropy(logits, labels) if labels is not None else torch.tensor(0.0, device=x.device)
        proto_dist = torch.cdist(x, self.prototypes)
        proto_loss = proto_dist.mean()
        return logits, cls_loss, proto_loss, graph_loss_val, None

Overwriting /content/pmicl_model.py


#Train


In [5]:
%%writefile /content/train.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import CyclicLR
from oasis_dataset import get_dataloader
from pmicl_model import PMICL

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")
NUM_EPOCHS = 30  # Increased for more training time
BATCH_SIZE = 8
BASE_LR = 1e-4  # Lower base for warm-up
MAX_LR = 5e-4   # Max for cyclical LR
NUM_CLASSES = 4
PATIENCE = 8    # Increased patience
WARMUP_EPOCHS = 5

def train_epoch(model, dataloader, optimizer, scheduler, device, scaler):
    model.train()
    total_loss = 0
    cls_loss_total = 0
    proto_loss_total = 0
    graph_loss_total = 0
    correct = 0
    total = 0

    for batch_idx, (patches, labels, coords) in enumerate(dataloader):
        try:
            patches, labels = patches.to(device), labels.to(device)
            B, N, C, H, W = patches.shape
            print(f"Batch {batch_idx+1}/{len(dataloader)}: Patches shape: {patches.shape}, Min: {patches.min():.4f}, Max: {patches.max():.4f}")

            optimizer.zero_grad()
            with autocast(device_type='cuda' if torch.cuda.is_available() else 'cpu'):
                logits, cls_loss, proto_loss, graph_loss, _ = model(patches, None, None, labels)

            if not torch.isfinite(cls_loss) or not torch.isfinite(proto_loss) or not torch.isfinite(graph_loss):
                print(f"Batch {batch_idx+1} - Invalid loss: Cls={cls_loss.item()}, Proto={proto_loss.item()}, Graph={graph_loss.item()}")
                continue

            print(f"Batch {batch_idx+1} - Cls loss: {cls_loss.item():.4f}, Proto loss: {proto_loss.item():.4f}, Graph loss: {graph_loss.item():.4f}")

            loss = 0.3 * cls_loss + 0.3 * proto_loss + 0.4 * graph_loss  # Increased graph weight
            scaler.scale(loss).backward()

            print(f"Batch {batch_idx+1} - Gradient Check:")
            for name, param in model.named_parameters():
                if param.grad is None:
                    print(f"  {name}: No gradient")
                elif not torch.isfinite(param.grad).all():
                    print(f"  {name}: NaN/Inf gradient")
                else:
                    print(f"  {name}: Max gradient {param.grad.abs().max():.6f}")

            if (batch_idx + 1) % 10 == 0:
                total_grad_norm = sum(p.grad.norm().item() ** 2 for p in model.parameters() if p.grad is not None and torch.isfinite(p.grad).all()) ** 0.5
                conv1_grad = model.conv1.weight.grad.norm().item() if model.conv1.weight.grad is not None else 0.0
                conv2_grad = model.conv2.weight.grad.norm().item() if model.conv2.weight.grad is not None else 0.0
                conv3_grad = model.conv3.weight.grad.norm().item() if model.conv3.weight.grad is not None else 0.0
                gat1_grad = sum(p.grad.norm().item() for p in model.gat1.parameters() if p.grad is not None and torch.isfinite(p.grad).all())
                print(f"Batch {batch_idx+1} - Gradients: conv1={conv1_grad:.4f}, conv2={conv2_grad:.4f}, conv3={conv3_grad:.4f}, gat1={gat1_grad:.4f}")
                print(f"Batch {batch_idx+1} - Total gradient norm: {total_grad_norm:.6f}")

            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)  # Stricter clipping
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()
            cls_loss_total += cls_loss.item()
            proto_loss_total += proto_loss.item()
            graph_loss_total += graph_loss.item()

            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            print(f"Batch {batch_idx+1}/{len(dataloader)} processed")
        except Exception as e:
            print(f"Error in batch {batch_idx+1}: {e}")
            raise

        if isinstance(scheduler, CyclicLR):
            scheduler.step()

    print("Epoch batch processing completed")
    train_acc = correct / total
    print(f"Epoch Losses - Total: {total_loss / len(dataloader):.4f}, "
          f"Cls: {cls_loss_total / len(dataloader):.4f}, "
          f"Proto: {proto_loss_total / len(dataloader):.4f}, "
          f"Graph: {graph_loss_total / len(dataloader):.4f}")
    return (total_loss / len(dataloader), cls_loss_total / len(dataloader),
            proto_loss_total / len(dataloader), graph_loss_total / len(dataloader), train_acc)

def evaluate(model, dataloader, device):
    model.eval()
    correct = 0
    total = 0
    total_loss = 0
    with torch.no_grad():
        for patches, labels, coords in dataloader:
            patches, labels = patches.to(device), labels.to(device)
            logits, cls_loss, proto_loss, graph_loss, _ = model(patches, None, None, labels)
            loss = 0.3 * cls_loss + 0.3 * proto_loss + 0.4 * graph_loss
            total_loss += loss.item()
            preds = torch.argmax(logits, dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total, total_loss / len(dataloader)

def main():
    try:
        train_dataloader = get_dataloader('/content/oasis_data/', batch_size=BATCH_SIZE, augment=True)
        val_dataloader = get_dataloader('/content/oasis_data/', batch_size=BATCH_SIZE, augment=False)  # Separate validation
        print(f"Train dataset size: {len(train_dataloader.dataset)}, Batches: {len(train_dataloader)}")
        print(f"Val dataset size: {len(val_dataloader.dataset)}, Batches: {len(val_dataloader)}")
        model = PMICL(num_classes=NUM_CLASSES).to(DEVICE)
        optimizer = optim.AdamW(model.parameters(), lr=BASE_LR, weight_decay=1e-2)  # AdamW with L2
        scheduler = CyclicLR(optimizer, base_lr=BASE_LR, max_lr=MAX_LR, step_size_up=len(train_dataloader) * WARMUP_EPOCHS, mode='triangular')
        scaler = GradScaler()

        best_val_loss = float('inf')
        patience_counter = 0

        for epoch in range(NUM_EPOCHS):
            print(f"Starting Epoch {epoch+1}/{NUM_EPOCHS}")
            loss, cls_loss, proto_loss, graph_loss, train_acc = train_epoch(model, train_dataloader, optimizer, scheduler, device=DEVICE, scaler=scaler)
            print(f"Epoch {epoch+1} - Loss: {loss:.4f}, Cls: {cls_loss:.4f}, Proto: {proto_loss:.4f}, Graph: {graph_loss:.4f}, Train Acc: {train_acc:.4f}")

            val_acc, val_loss = evaluate(model, val_dataloader, DEVICE)
            print(f"Epoch {epoch+1} - Validation Accuracy: {val_acc:.4f}, Validation Loss: {val_loss:.4f}")

            print(f"Current learning rate: {optimizer.param_groups[0]['lr']}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= PATIENCE:
                    print("Early stopping triggered")
                    break
        print("Training completed!")
    except Exception as e:
        print(f"Error in main loop: {e}")
        raise

if __name__ == "__main__":
    main()

Overwriting /content/train.py


In [6]:
!rm -rf /content/__pycache__

In [None]:
%run /content/train.py

Using device: cuda
Looking for images in /content/oasis_data/Data
Available directories: ['Mild Dementia', 'Moderate Dementia', 'Very mild Dementia', 'Non Demented']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demented
Checking directory: /content/oasis_data/Data/Very mild Dementia
Found 13725 .jpg files in /content/oasis_data/Data/Very mild Dementia
Checking directory: /content/oasis_data/Data/Mild Dementia
Found 5002 .jpg files in /content/oasis_data/Data/Mild Dementia
Checking directory: /content/oasis_data/Data/Moderate Dementia
Found 488 .jpg files in /content/oasis_data/Data/Moderate Dementia
Loaded 1952 images: 488 CN, 488 MCI, 488 Mild, 488 Moderate
Looking for images in /content/oasis_data/Data
Available directories: ['Mild Dementia', 'Moderate Dementia', 'Very mild Dementia', 'Non Demented']
Checking directory: /content/oasis_data/Data/Non Demented
Found 67222 .jpg files in /content/oasis_data/Data/Non Demen

  edge_index = torch.tensor(adj.nonzero(), dtype=torch.long)


Build graph: num_nodes=160, edge_index_shape=torch.Size([2, 2400]), edge_weight_min=1.0000, max=1.0000
GAT1 embeddings - Min: -9.0938, Max: 85.8036
GAT2 embeddings - Min: -19.8602, Max: 175.9937
After Attention - Min: -110.3125, Max: 98.2500
graph_loss: embeddings shape: torch.Size([160, 128]), edge_index max: 159
Cosine similarity (sample): tensor([ 0.2301,  0.1439,  0.0504,  0.2329, -0.2174], device='cuda:0',
       grad_fn=<SliceBackward0>)
Graph loss: Pos=-0.1353, Neg=0.0945, Total=-0.1258
Batch 1 - Cls loss: 30.8906, Proto loss: 76.4920, Graph loss: -0.1258
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.7412
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.7569
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.9451
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.8745
Patches tensor shape: torch.Size([20, 1, 32, 32]), Min: -1.0000, Max: 0.7647
Batch 1 - Gradient Check:
  prototypes: Max gr



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
GAT1 embeddings - Min: -0.4050, Max: 0.1576
GAT2 embeddings - Min: -0.0392, Max: 0.1756
After Attention - Min: -0.1829, Max: 0.1394
graph_loss: embeddings shape: torch.Size([160, 128]), edge_index max: 159
Cosine similarity (sample): tensor([0.9900, 0.9900, 0.9900, 0.9900, 0.9900], device='cuda:0',
       grad_fn=<SliceBackward0>)
Graph loss: Pos=-0.9900, Neg=0.9899, Total=-0.8910
Batch 198 - Cls loss: 1.3779, Proto loss: 0.0052, Graph loss: -0.8910
Batch 198 - Gradient Check:
  prototypes: Max gradient 0.879972
  conv1.weight: Max gradient 0.167847
  conv1.bias: Max gradient 0.106079
  bn1.weight: Max gradient 0.003550
  bn1.bias: Max gradient 0.003090
  conv2.weight: Max gradient 0.008217
  conv2.bias: Max gradient 0.002962
  bn2.weight: Max gradient 0.005906
  bn2.bias: Max gradient 0.003091
  conv3.weight: Max gradient 0.010185
  conv3.bias: Max gradient 0.002110
  bn3.weight: Max gradient 0.017617
  bn3.bias: Max gra

#Evaluation


In [None]:
%%writefile /content/evaluate_visualize.py
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix
from torch.utils.data import DataLoader
from oasis_dataset import get_dataloader
from pmicl_model import PMICL, GraphConstructor

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 8
NUM_CLASSES = 4

def evaluate(model, dataloader, device):
    model.eval()
    preds = []
    true_labels = []
    graph_constructor = GraphConstructor(k=5)

    with torch.no_grad():
        for patches, labels, coords in dataloader:
            patches, labels = patches.to(device), labels.to(device)
            B, N, C, H, W = patches.shape
            x = patches.view(B * N, C, H, W)
            x = model.bn1(F.relu(model.conv1(x)))
            x = model.pool(x)
            x = model.bn2(F.relu(model.conv2(x)))
            x = model.pool(x)
            x = model.bn3(F.relu(model.conv3(x)))
            x = model.pool(x)
            x = x.view(B * N, -1)
            x = F.relu(model.fc1(x))
            edge_index, edge_weight = graph_constructor.build_graph(x.detach().cpu().numpy())
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            logits, _, _, _, _ = model(patches, edge_index, batch)
            preds.append(torch.softmax(logits, dim=1).cpu().numpy())
            true_labels.append(labels.cpu().numpy())

    preds = np.concatenate(preds)
    true_labels = np.concatenate(true_labels)

    acc = accuracy_score(true_labels, np.argmax(preds, axis=1))
    f1 = f1_score(true_labels, np.argmax(preds, axis=1), average='macro')
    auc = roc_auc_score(true_labels, preds, multi_class='ovr')
    return acc, f1, auc, true_labels, np.argmax(preds, axis=1)

def generate_attention_map(model, dataloader, device):
    model.eval()
    graph_constructor = GraphConstructor(k=5)

    with torch.no_grad():
        for patches, _, coords in dataloader:
            patches = patches.to(device)
            B, N, C, H, W = patches.shape
            x = patches.view(B * N, C, H, W)
            x = model.bn1(F.relu(model.conv1(x)))
            x = model.pool(x)
            x = model.bn2(F.relu(model.conv2(x)))
            x = model.pool(x)
            x = model.bn3(F.relu(model.conv3(x)))
            x = model.pool(x)
            x = x.view(B * N, -1)
            x = F.relu(model.fc1(x))
            edge_index, edge_weight = graph_constructor.build_graph(x.detach().cpu().numpy())
            edge_index, edge_weight = edge_index.to(device), edge_weight.to(device)
            batch = torch.repeat_interleave(torch.arange(B, device=device), N)

            _, _, _, _, attention_weights = model(patches, edge_index, batch)
            attention_weights = attention_weights.cpu().numpy()

            attention_per_image = attention_weights[:N, :N]
            plt.figure(figsize=(10, 10))
            plt.imshow(attention_per_image, cmap='hot')
            plt.colorbar()
            plt.title("Attention Map")
            plt.savefig('/content/attention_map.png')
            plt.close()
            break

def plot_confusion_matrix(true_labels, pred_labels, class_names):
    cm = confusion_matrix(true_labels, pred_labels)
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.savefig('/content/confusion_matrix.png')
    plt.close()

def plot_metrics(train_accs, val_accs, val_f1s, val_aucs, losses):
    epochs = range(1, len(train_accs) + 1)

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, train_accs, label='Training Accuracy')
    plt.plot(epochs, val_accs, label='Validation Accuracy')
    plt.title('Accuracy Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/accuracy_plot.png')
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, val_accs, label='Validation Accuracy')
    plt.plot(epochs, val_f1s, label='Validation F1 Score')
    plt.plot(epochs, val_aucs, label='Validation AUC')
    plt.title('Validation Metrics Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Score')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/validation_metrics_plot.png')
    plt.close()

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, losses, label='Training Loss')
    plt.title('Training Loss Over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('/content/loss_plot.png')
    plt.close()

def main():
    dataloader = get_dataloader('/content/oasis_data/', batch_size=BATCH_SIZE)
    model = PMICL(num_classes=NUM_CLASSES).to(DEVICE)
    class_names = ['Non Demented', 'Very mild Dementia', 'Mild Dementia', 'Moderate Dementia']

    acc, f1, auc, true_labels, pred_labels = evaluate(model, dataloader, DEVICE)
    print(f"Val: Acc={acc:.4f}, F1={f1:.4f}, AUC={auc:.4f}")

    plot_confusion_matrix(true_labels, pred_labels, class_names)
    generate_attention_map(model, dataloader, DEVICE)

if __name__ == "__main__":
    main()