In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchxrayvision as xrv
import numpy as np
import torch.optim as optim
from tqdm import tqdm
import os

In [None]:
# --- 1. Model Definition (Hybrid w/ Transplanted Backbone) ---
class PrototypicalNet(nn.Module):
    """Prototypical-network encoder for chest X-rays.

    A frozen TorchXRayVision DenseNet-121 ``features`` module, followed by
    global average pooling and a trainable linear embedding head.

    Args:
        out_dim: Dimensionality of the output embedding.
        backbone_weights_path: Path to a ``state_dict`` for the backbone
            (the DenseNet ``features`` module), loaded on top of the stock
            ``densenet121-res224-all`` weights. Loading is best effort: on
            any failure a warning is printed and the stock weights are
            restored. Pass ``None`` to skip the attempt.
    """

    def __init__(self, out_dim=256,
                 backbone_weights_path='finetuned_backbone_from_imbalanced_model.pth'):
        super(PrototypicalNet, self).__init__()

        # A. Start from the pretrained TorchXRayVision DenseNet and keep only
        #    its convolutional feature extractor.
        model = xrv.models.DenseNet(weights="densenet121-res224-all")
        self.backbone = model.features

        # B. Optionally overwrite the backbone with fine-tuned weights.
        if backbone_weights_path is not None:
            # Snapshot the stock weights: load_state_dict can copy some tensors
            # before raising on a key/shape mismatch, so restore on failure to
            # make the "default weights" fallback actually hold.
            stock_state = {k: v.clone() for k, v in self.backbone.state_dict().items()}
            try:
                # map_location="cpu" lets a checkpoint saved on a GPU load on a
                # CPU-only machine; the caller's .to(device) moves it afterwards.
                backbone_weights = torch.load(backbone_weights_path, map_location="cpu")
                self.backbone.load_state_dict(backbone_weights)
                print("Successfully loaded fine-tuned backbone from imbalanced model.")
            except Exception as e:
                self.backbone.load_state_dict(stock_state)
                print(f"Warning: Could not load fine-tuned backbone. Using default weights. Error: {e}")

        # C. Freeze the backbone. requires_grad=False stops weight updates;
        #    eval mode (kept by train() below) stops BatchNorm statistic updates.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        self.pooling = nn.AdaptiveAvgPool2d((1, 1))

        # D. Trainable embedding head; 1024 is the channel count of
        #    DenseNet-121's final feature map.
        self.embedding_head = nn.Linear(1024, out_dim)

    def train(self, mode=True):
        """Set train/eval mode on the module while keeping the backbone in eval.

        Without this, ``model.train()`` would put the frozen backbone's
        BatchNorm layers back in training mode. Their running mean/var would
        then be overwritten by every (small, augmented) episode batch, even
        though no gradients reach the backbone.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, x):
        """Embed a batch of images.

        Args:
            x: Image batch passed straight to the DenseNet features module
                (presumably ``(B, 1, 224, 224)`` from ``get_transforms`` --
                confirm).

        Returns:
            ``(B, out_dim)`` embedding tensor.
        """
        # The backbone is frozen, so skip building its autograd graph.
        with torch.no_grad():
            features = self.backbone(x)
        # NOTE(review): the reference DenseNet forward applies ReLU to these
        # features before pooling; it is omitted here. Adding it would
        # invalidate heads already trained without it -- confirm intent.
        pooled = self.pooling(features).view(features.size(0), -1)
        return self.embedding_head(pooled)

In [None]:
# --- 2. Data Transforms (Same as before) ---
def get_transforms():
    """Return the ``(train_transform, test_transform)`` preprocessing pair.

    Both pipelines resize to 224x224, convert to a single grayscale channel,
    convert to a tensor and normalise with a fixed single-channel mean/std.
    The training pipeline additionally inserts a mild random affine
    augmentation (up to 10 degrees of rotation, 10% translation and 0.9-1.1
    scaling) just before tensor conversion.
    """
    xrv_mean, xrv_std = [0.5081], [0.0893]

    def build(*augmentations):
        # Shared skeleton; any augmentation steps sit between the grayscale
        # conversion and ToTensor.
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.Grayscale(num_output_channels=1),
            *augmentations,
            transforms.ToTensor(),
            transforms.Normalize(mean=xrv_mean, std=xrv_std),
        ])

    affine_jitter = transforms.RandomAffine(
        degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)
    )
    return build(affine_jitter), build()

In [None]:
# --- 3. Episodic Sampler (Same as before) ---
class EpisodicBatchSampler(Sampler):
    """Batch sampler that yields one N-way episode (support + query) per batch.

    Each yielded batch is a flat list of ``n_way * (n_shot + n_query)`` dataset
    indices (plain Python ints), laid out class-major. All
    ``n_shot + n_query`` indices of the first selected class come first, then
    those of the second, and so on. ``prototypical_loss`` relies on this layout.

    Classes are drawn without replacement using the global ``np.random`` state,
    so seeding ``np.random`` makes episodes reproducible. Within a class,
    samples are drawn without replacement unless the class has fewer than
    ``n_shot + n_query`` samples. In that case replacement is used and the same
    image may land in both the support and the query set; a warning is issued
    at construction time.

    Args:
        data_targets: Iterable of per-sample labels (e.g. ``ImageFolder.targets``).
            Tensor/numpy scalars are converted with ``.item()``.
        n_way: Number of classes per episode (>= 1).
        n_shot: Support samples per class (>= 1).
        n_query: Query samples per class (>= 0).
        episodes_per_epoch: Number of episodes yielded per iteration (>= 0).

    Raises:
        ValueError: If a size argument is below its minimum, or ``n_way``
            exceeds the number of distinct classes.
    """

    def __init__(self, data_targets, n_way, n_shot, n_query, episodes_per_epoch):
        # The base Sampler.__init__ stores nothing. Older PyTorch releases
        # require a data_source argument and newer ones warn when it is given,
        # so it is deliberately not called.
        for name, value, minimum in (
            ("n_way", n_way, 1),
            ("n_shot", n_shot, 1),
            ("n_query", n_query, 0),
            ("episodes_per_epoch", episodes_per_epoch, 0),
        ):
            if value < minimum:
                raise ValueError(f"{name} must be >= {minimum}, got {value}")

        self.data_targets = data_targets
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.episodes_per_epoch = episodes_per_epoch

        # Group dataset indices by label. .item() turns tensor/numpy scalars
        # into plain Python values that hash by value (a 0-d tensor hashes by
        # identity, which would make every sample its own class).
        self.class_indices = {}
        for idx, target in enumerate(self.data_targets):
            key = target.item() if hasattr(target, "item") else target
            self.class_indices.setdefault(key, []).append(idx)
        self.classes = list(self.class_indices.keys())

        if self.n_way > len(self.classes):
            raise ValueError(f"N_WAY ({self.n_way}) cannot be larger than the number of available classes ({len(self.classes)})")

        per_class = self.n_shot + self.n_query
        small_classes = [c for c in self.classes if len(self.class_indices[c]) < per_class]
        if small_classes:
            import warnings
            warnings.warn(
                f"Classes {small_classes} have fewer than n_shot + n_query = {per_class} "
                "samples; they will be sampled with replacement, so the same image can "
                "appear in both the support and query sets of an episode."
            )

    def __len__(self):
        # Exact: every iteration of __iter__ yields exactly one episode.
        return self.episodes_per_epoch

    def __iter__(self):
        per_class = self.n_shot + self.n_query
        for _ in range(self.episodes_per_epoch):
            # Choose n_way distinct classes by position in self.classes.
            class_positions = np.random.choice(len(self.classes), self.n_way, replace=False)
            episode = []
            for pos in class_positions:
                pool = self.class_indices[self.classes[pos]]
                # Fall back to replacement only when the class is too small.
                picks = np.random.choice(pool, per_class, replace=len(pool) < per_class)
                episode.extend(picks.tolist())
            yield episode

In [None]:
# --- 4. Prototypical Loss (UPDATED FOR COSINE SIMILARITY) ---
def prototypical_loss(embeddings, labels, n_shot, n_query, n_way, device, scale=1.0):
    """Cosine-similarity prototypical loss and accuracy for one episode.

    ``embeddings`` is reshaped to ``(n_way, n_shot + n_query, -1)``, so it must
    be laid out class-major. Within each class block the first ``n_shot``
    rows are support and the remaining ``n_query`` rows are queries (the
    layout produced by ``EpisodicBatchSampler``). Each class prototype is the
    mean of its support embeddings. Queries are scored by cosine similarity to
    every prototype, and the target of a query is its class's position within
    the episode.

    Args:
        embeddings: Tensor with ``n_way * (n_shot + n_query) * D`` elements,
            typically ``(n_way * (n_shot + n_query), D)``.
        labels: Dataset labels in the same order as ``embeddings``, or
            ``None``. Not used to compute the loss; when given, they are used
            to check that each class block holds a single label.
        n_shot: Support samples per class (>= 1).
        n_query: Query samples per class (>= 1).
        n_way: Classes per episode (>= 1).
        device: Device on which the query targets are created; must match
            the device of ``embeddings``.
        scale: Positive multiplier applied to the cosine similarities before
            the softmax. Cosine similarity lies in [-1, 1], so with the default
            ``1.0`` the softmax can never become confident (2-way:
            p(correct) <= ~0.88) and the loss cannot approach 0. Values around
            10 are commonly used. Accuracy is unaffected by ``scale``.

    Returns:
        ``(loss, accuracy)`` as scalar tensors.

    Raises:
        ValueError: On non-positive sizes or scale, an element count that does
            not fit the episode shape, or labels that are not class-major.
    """
    if min(n_way, n_shot, n_query) < 1:
        raise ValueError(
            f"n_way, n_shot and n_query must all be >= 1 (got {n_way}, {n_shot}, {n_query})"
        )
    if scale <= 0:
        raise ValueError(f"scale must be positive, got {scale}")

    per_class = n_shot + n_query
    episode_size = n_way * per_class
    if embeddings.numel() == 0 or embeddings.numel() % episode_size != 0:
        raise ValueError(
            f"embeddings with shape {tuple(embeddings.shape)} cannot be split into "
            f"{n_way} classes x {per_class} samples"
        )

    if labels is not None:
        labels = torch.as_tensor(labels)
        if labels.numel() != episode_size:
            raise ValueError(f"expected {episode_size} labels, got {labels.numel()}")
        grouped = labels.reshape(n_way, per_class)
        if not torch.equal(grouped, grouped[:, :1].expand_as(grouped)):
            raise ValueError(
                "labels are not grouped class-major: each run of n_shot + n_query "
                "consecutive samples must share one label"
            )

    episode = embeddings.reshape(n_way, per_class, -1)
    support_embeddings = episode[:, :n_shot, :]
    query_embeddings = episode[:, n_shot:, :].reshape(n_way * n_query, -1)

    prototypes = support_embeddings.mean(dim=1)

    # Cosine similarity = dot product of L2-normalised vectors; shape
    # (n_way * n_query, n_way), higher means more similar.
    similarities = F.normalize(query_embeddings, p=2, dim=1) @ F.normalize(prototypes, p=2, dim=1).t()

    query_labels = torch.arange(n_way, device=device).repeat_interleave(n_query)

    loss = F.cross_entropy(similarities * scale, query_labels)
    predicted_labels = similarities.argmax(dim=1)
    accuracy = (predicted_labels == query_labels).float().mean()

    return loss, accuracy

In [None]:
# --- 5. Meta-Training Function (UPDATED) ---
def _meta_train_epoch(model, loader, optimizer, device, n_shot, n_query, n_way, desc):
    """Train ``model`` for one pass over the episodic ``loader``.

    Returns:
        tuple: ``(avg_loss, avg_acc)`` averaged over the episodes the loader
        actually yielded (``nan`` for both if it yielded none).
    """
    model.train()
    # The backbone is frozen; keep its BatchNorm layers in eval mode so their
    # running statistics are not overwritten by augmented training episodes.
    model.backbone.eval()

    total_loss = 0.0
    total_acc = 0.0
    n_episodes = 0
    with tqdm(loader, desc=desc) as pbar:
        for images, labels in pbar:
            embeddings = model(images.to(device))
            loss, acc = prototypical_loss(
                embeddings, labels, n_shot, n_query, n_way, device
            )

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_acc += acc.item()
            n_episodes += 1
            # Count episodes ourselves: tqdm's iterator only syncs pbar.n when
            # it redraws, so pbar.n + 1 can lag behind the true episode count.
            pbar.set_postfix(
                loss=f"{total_loss / n_episodes:.4f}",
                acc=f"{total_acc / n_episodes:.4f}",
            )

    if n_episodes == 0:
        return float("nan"), float("nan")
    return total_loss / n_episodes, total_acc / n_episodes


def run_meta_training(seed=None):
    """Meta-train the embedding head of ``PrototypicalNet`` on random episodes.

    Reads an ``ImageFolder`` dataset from ``fsl_data/meta_train``. Trains for
    ``EPOCHS`` epochs of ``EPISODES_PER_EPOCH`` ``N_WAY``-way episodes with the
    cosine prototypical loss, prints per-epoch averages, and saves
    ``model.state_dict()`` to ``fsl_backbone_cosine.pth``.

    Args:
        seed: Optional integer. When given, seeds ``numpy`` (episode sampling)
            and ``torch`` (head initialisation, augmentation) so runs are
            repeatable, cuDNN nondeterminism aside.
    """
    print("--- Starting Meta-Training Phase (with Cosine Similarity) ---")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    META_TRAIN_DIR = 'fsl_data/meta_train'
    MODEL_SAVE_PATH = 'fsl_backbone_cosine.pth'  # read back by run_meta_testing()
    N_WAY = 2
    N_SHOT = 10
    N_QUERY = 10
    EPOCHS = 20
    EPISODES_PER_EPOCH = 500
    LEARNING_RATE = 0.0001

    train_transform, _ = get_transforms()
    train_dataset = ImageFolder(META_TRAIN_DIR, transform=train_transform)

    print("\nMeta-Train (Base) dataset loaded.")
    print(f"Found {len(train_dataset)} images in {len(train_dataset.classes)} classes.")
    print(f"Classes: {train_dataset.classes}")

    train_sampler = EpisodicBatchSampler(
        train_dataset.targets, N_WAY, N_SHOT, N_QUERY, EPISODES_PER_EPOCH
    )

    train_loader = DataLoader(
        train_dataset,
        batch_sampler=train_sampler,
        num_workers=2
    )

    model = PrototypicalNet().to(DEVICE)
    # Only the embedding head is trainable; handing Adam just those params
    # makes that explicit (frozen params never receive gradients anyway).
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=LEARNING_RATE)

    print("\nStarting meta-training...")
    for epoch in range(EPOCHS):
        avg_loss, avg_acc = _meta_train_epoch(
            model, train_loader, optimizer, DEVICE,
            N_SHOT, N_QUERY, N_WAY, desc=f"Epoch {epoch+1}/{EPOCHS}",
        )
        print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f} | Avg Acc: {avg_acc:.4f}")

    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"\nMeta-training complete. Model saved to {MODEL_SAVE_PATH}")

In [None]:
# --- 6. Meta-Testing Function (UPDATED) ---
def run_meta_testing(seed=None):
    """Evaluate the meta-trained model on random few-shot episodes.

    Loads the checkpoint that ``run_meta_training`` writes to
    ``fsl_backbone_cosine.pth``. Runs ``TEST_EPISODES`` random ``N_WAY``-way,
    ``N_SHOT``-shot episodes from ``fsl_data/meta_test`` and prints the mean
    episode accuracy with a normal-approximation 95% confidence interval
    (1.96 * sample std / sqrt(episodes run)). Returns early, after printing a
    message, if the checkpoint cannot be loaded or no episodes are run.

    Args:
        seed: Optional integer. When given, seeds ``numpy`` (episode sampling)
            and ``torch`` so the same episodes are drawn on every run.
    """
    print("\n--- Starting Meta-Testing Phase (with Cosine Similarity) ---")
    if seed is not None:
        np.random.seed(seed)
        torch.manual_seed(seed)
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {DEVICE}")

    META_TEST_DIR = 'fsl_data/meta_test'
    MODEL_PATH = 'fsl_backbone_cosine.pth'  # checkpoint written by run_meta_training()
    N_WAY = 2
    N_SHOT = 10
    N_QUERY = 15
    TEST_EPISODES = 1000

    _, test_transform = get_transforms()
    test_dataset = ImageFolder(META_TEST_DIR, transform=test_transform)

    print("\nMeta-Test (Novel) dataset loaded.")
    print(f"Found {len(test_dataset)} images in {len(test_dataset.classes)} classes.")
    print(f"Classes: {test_dataset.classes}")

    test_sampler = EpisodicBatchSampler(
        data_targets=test_dataset.targets,
        n_way=N_WAY,
        n_shot=N_SHOT,
        n_query=N_QUERY,
        episodes_per_epoch=TEST_EPISODES
    )

    test_loader = DataLoader(
        test_dataset,
        batch_sampler=test_sampler,
        num_workers=2
    )

    model = PrototypicalNet(out_dim=256).to(DEVICE)
    try:
        model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
        print(f"Successfully loaded pre-trained model from {MODEL_PATH}")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    model.eval()
    episode_losses = []
    episode_accuracies = []

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(test_loader, desc="Running Meta-Test"):
            embeddings = model(batch_images.to(DEVICE))
            loss, accuracy = prototypical_loss(
                embeddings, batch_labels, N_SHOT, N_QUERY, N_WAY, DEVICE
            )
            episode_losses.append(loss.item())
            episode_accuracies.append(accuracy.item())

    n_episodes = len(episode_accuracies)
    if n_episodes == 0:
        print("No meta-test episodes were run; nothing to report.")
        return

    # Average over the episodes actually run, not the requested count.
    avg_loss = float(np.mean(episode_losses))
    avg_acc = float(np.mean(episode_accuracies))
    # Sample standard deviation (ddof=1) is the estimator a CI needs; it is
    # undefined for a single episode, so report nan rather than a fake 0.
    std_dev = float(np.std(episode_accuracies, ddof=1)) if n_episodes > 1 else float("nan")
    confidence_interval = 1.96 * std_dev / np.sqrt(n_episodes)
    # Describe the task from the data rather than a hard-coded class pair.
    class_names = " vs ".join(test_dataset.classes)

    print("\n" + "="*30)
    print("--- Meta-Test Results ---")
    print(f"Task: {N_WAY}-way, {N_SHOT}-shot ({class_names})")
    print(f"Episodes Run: {n_episodes}")
    print(f"Average Loss: {avg_loss:.4f}")
    print(f"Average Accuracy: {avg_acc * 100:.2f}%")
    print(f"95% Confidence Interval: +/- {confidence_interval * 100:.2f}%")
    print("="*30)
    print(f"Final Reported Accuracy: {avg_acc * 100:.2f} ± {confidence_interval * 100:.2f}%")
    print("="*30)

In [None]:
# --- 7. Main Execution ---
if __name__ == "__main__":
    # Order matters: run_meta_testing() loads the checkpoint that
    # run_meta_training() saves (both use 'fsl_backbone_cosine.pth').
    for phase in (run_meta_training, run_meta_testing):
        phase()