In [1]:
import os
import shutil

# --- Configuration ---
# This is the path to your current dataset (e.g., 'processed_data/')
# It should contain 4 folders: 'COVID', 'Pneumonia', 'TB', 'Normal'
SOURCE_DIR = r"C:\Users\SUBRAT\MAFSL PROJECT\processed_data\processed_data"

# These are the new directories we will create
BASE_DIR = 'fsl_data'
TRAIN_DIR = os.path.join(BASE_DIR, 'meta_train')
TEST_DIR = os.path.join(BASE_DIR, 'meta_test')

# Define our class split
# Base classes for meta-training
BASE_CLASSES = ['Normal', 'Pneumonia']

# Novel classes for meta-testing
NOVEL_CLASSES = ['COVID', 'Tuberculosis']
# ---------------------

def setup_directories():
    """
    This function creates a new 'fsl_data' directory with a
    'meta_train' and 'meta_test' split.
    
    WARNING: It will delete and replace 'fsl_data' if it already exists.
    """
    
    if os.path.exists(BASE_DIR):
        print(f"Removing existing directory: {BASE_DIR}")
        shutil.rmtree(BASE_DIR)
        
    print(f"Creating new directory structure at: {BASE_DIR}")
    os.makedirs(TRAIN_DIR, exist_ok=True)
    os.makedirs(TEST_DIR, exist_ok=True)

    # 1. Copy Base Classes to meta_train
    print("Copying base classes (Normal, Pneumonia) to meta_train...")
    for cls in BASE_CLASSES:
        src = os.path.join(SOURCE_DIR, cls)
        dst = os.path.join(TRAIN_DIR, cls)
        if os.path.exists(src):
            shutil.copytree(src, dst)
        else:
            print(f"  Warning: Source folder not found at {src}")

    # 2. Copy Novel Classes to meta_test
    print("Copying novel classes (COVID, TB) to meta_test...")
    for cls in NOVEL_CLASSES:
        src = os.path.join(SOURCE_DIR, cls)
        dst = os.path.join(TEST_DIR, cls)
        if os.path.exists(src):
            shutil.copytree(src, dst)
        else:
            print(f"  Warning: Source folder not found at {src}")

    print("\nData setup complete.")
    print(f"Meta-Train (Base) classes in: {TRAIN_DIR}")
    print(f"Meta-Test (Novel) classes in: {TEST_DIR}")

if __name__ == "__main__":
    if not os.path.exists(SOURCE_DIR):
        print(f"Error: Source data directory not found at '{SOURCE_DIR}'")
        print("Please make sure your 'processed_data' folder is at that location.")
    else:
        setup_directories()

Removing existing directory: fsl_data
Creating new directory structure at: fsl_data
Copying base classes (Normal, Pneumonia) to meta_train...
Copying novel classes (COVID, TB) to meta_test...

Data setup complete.
Meta-Train (Base) classes in: fsl_data\meta_train
Meta-Test (Novel) classes in: fsl_data\meta_test


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Sampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
import torchxrayvision as xrv
import numpy as np

# --- 1. Model Definition ---
class PrototypicalNet(nn.Module):
    # --- THIS IS THE NEW CONSTRUCTOR ---
    # It now accepts an 'out_dim' (output dimension)
    def __init__(self, out_dim=256): 
        super(PrototypicalNet, self).__init__()
        
        full_model = xrv.models.DenseNet(weights="densenet121-res224-all")
        self.backbone = full_model.features
        
        # --- SOLUTION 1: FREEZE THE BACKBONE ---
        # We will trust its pre-trained weights as a powerful feature extractor
        for param in self.backbone.parameters():
            param.requires_grad = False
            
        self.pooling = nn.AdaptiveAvgPool2d((1, 1))
        
        # --- SOLUTION 1: ADD A SMALL, TRAINABLE EMBEDDING HEAD ---
        # The optimizer will now *only* train these parameters (1024 -> 256)
        # DenseNet-121 output is 1024 features
        self.embedding_head = nn.Linear(1024, out_dim)

    # --- UPDATED FORWARD METHOD ---
    def forward(self, x):
        
        # --- 1. Pass input through the *frozen* backbone ---
        # We use torch.no_grad() to ensure no gradients are computed
        with torch.no_grad(): 
            features = self.backbone(x)
            
        pooled = self.pooling(features).view(features.size(0), -1)
        
        # --- 2. Pass features through the *trainable* head ---
        # Gradients will be computed for this part
        embedding = self.embedding_head(pooled)
        return embedding

# --- 2. Data Transforms ---
def get_transforms():
    """
    Returns the correct transforms for training and evaluation.
    This includes the specific normalization for torchxrayvision models.
    """
    # This is the normalization specified by torchxrayvision
    XRV_MEAN = [0.5081]
    XRV_STD = [0.0893]

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1), # Ensure 1-channel
        transforms.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        transforms.ToTensor(),
        transforms.Normalize(mean=XRV_MEAN, std=XRV_STD)
    ])
    
    # Test transform does not have data augmentation
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize(mean=XRV_MEAN, std=XRV_STD)
    ])

    return train_transform, test_transform

# --- 3. Episodic Sampler (The "Loader Fix") ---
class EpisodicBatchSampler(Sampler):
    """
    A custom PyTorch Sampler to create episodic batches.
    This is the correct way to fix your broken data loading loop.
    """
    def __init__(self, data_targets, n_way, n_shot, n_query, episodes_per_epoch):
        super().__init__(data_targets)
        self.data_targets = data_targets
        self.n_way = n_way
        self.n_shot = n_shot
        self.n_query = n_query
        self.episodes_per_epoch = episodes_per_epoch

        # Create a dict of {class: [list of image indices]}
        self.class_indices = {}
        for idx, target in enumerate(self.data_targets):
            if target not in self.class_indices:
                self.class_indices[target] = []
            self.class_indices[target].append(idx)
            
        self.classes = list(self.class_indices.keys())
        
        if self.n_way > len(self.classes):
            raise ValueError(f"N_WAY ({self.n_way}) cannot be larger than the number of available classes ({len(self.classes)})")

    def __len__(self):
        return self.episodes_per_epoch

    def __iter__(self):
        for _ in range(self.episodes_per_epoch):
            episode_indices = []
            
            # 1. Select N_WAY classes at random
            try:
                selected_classes = np.random.choice(self.classes, self.n_way, replace=False)
            except ValueError:
                print("Warning: Not enough classes to sample from. Check N_WAY.")
                continue

            # 2. For each class, sample N_SHOT + N_QUERY images
            for cls in selected_classes:
                class_idx = self.class_indices[cls]
                
                # Check if we have enough samples, sample with replacement if not
                replace = len(class_idx) < (self.n_shot + self.n_query)
                
                try:
                    selected_idx = np.random.choice(class_idx, self.n_shot + self.n_query, replace=replace)
                    episode_indices.extend(selected_idx)
                except:
                    # This handles a rare case where a class might have 0 images
                    continue
                    
            if len(episode_indices) == self.n_way * (self.n_shot + self.n_query):
                yield episode_indices
            else:
                # This can happen if a class failed to sample
                # print("Skipping episode due to sample mismatch")
                continue


# --- 4. Prototypical Loss and Accuracy ---
def prototypical_loss(embeddings, labels, n_shot, n_query, n_way, device):
    """
    A robust prototypical loss function.
    Assumes batch is structured as [N_WAY * (N_SHOT + N_QUERY), EMBEDDING_DIM]
    and labels are the original ImageFolder labels.
    """
    
    # Reshape embeddings: [N_WAY, N_SHOT + N_QUERY, EMBEDDING_DIM]
    embeddings = embeddings.reshape(n_way, n_shot + n_query, -1)

    # Split into support and query sets
    support_embeddings = embeddings[:, :n_shot, :]
    query_embeddings = embeddings[:, n_shot:, :,]
    
    # Calculate prototypes by averaging support embeddings
    # [N_WAY, EMBEDDING_DIM]
    prototypes = support_embeddings.mean(dim=1)

    # Flatten query embeddings for cdist
    # [N_WAY * N_QUERY, EMBEDDING_DIM]
    query_embeddings = query_embeddings.reshape(n_way * n_query, -1)
    
    # Calculate euclidean distances
    # [N_WAY * N_QUERY, N_WAY]
    distances = torch.cdist(query_embeddings, prototypes)
    
    # Create target labels for the query set
    # [0, 0, ..., 1, 1, ..., N_WAY-1, ...]
    query_labels = torch.arange(n_way, device=device).repeat_interleave(n_query)

    # Calculate cross-entropy loss on negative distances
    # (using -distances makes it a "similarity" score)
    loss = F.cross_entropy(-distances, query_labels)
    
    # Calculate accuracy
    # Find the prototype with the minimum distance (closest match)
    _, predicted_labels = torch.min(distances, dim=1)
    accuracy = (predicted_labels == query_labels).float().mean()
    
    return loss, accuracy

In [3]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Sampler
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import torchxrayvision as xrv
import numpy as np

# --- Setup Hyperparameters ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Data and Model
META_TRAIN_DIR = 'fsl_data/meta_train'
MODEL_SAVE_PATH = 'fsl_backbone.pth'

# FSL parameters
N_WAY = 2  # 2 base classes (Normal, Pneumonia)
N_SHOT = 10 # 10 support images per class
N_QUERY = 10 # 10 query images per class

# Training parameters
EPOCHS = 20
EPISODES_PER_EPOCH = 500
LEARNING_RATE = 0.0001 # Use a small LR for fine-tuning

# --- 3. Load Data ---
train_transform, _ = get_transforms()
train_dataset = ImageFolder(META_TRAIN_DIR, transform=train_transform)

print(f"\nMeta-Train (Base) dataset loaded.")
print(f"Found {len(train_dataset)} images in {len(train_dataset.classes)} classes.")
print(f"Classes: {train_dataset.classes}")

# Create the Episodic Sampler
train_sampler = EpisodicBatchSampler(
    train_dataset.targets, 
    N_WAY, N_SHOT, N_QUERY, 
    EPISODES_PER_EPOCH
)

# Create the DataLoader
# IMPORTANT: batch_size=None tells the loader to use our custom batch sampler
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=2
)

# --- 4. Initialize Model and Optimizer ---
model = PrototypicalNet().to(DEVICE)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# --- 5. Main Meta-Training Loop ---
print("\nStarting meta-training...")
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    total_acc = 0.0
    
    # Use tqdm for a progress bar
    with tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}") as pbar:
        for batch in pbar:
            images, labels = batch
            images = images.to(DEVICE)
            
            # 1. Get embeddings from the model
            embeddings = model(images)
            
            # 2. Calculate loss and accuracy
            loss, acc = prototypical_loss(
                embeddings, labels, 
                N_SHOT, N_QUERY, N_WAY, 
                DEVICE
            )
            
            # 3. Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            total_acc += acc.item()
            
            # Update progress bar
            pbar.set_postfix(
                loss=f"{total_loss / (pbar.n + 1):.4f}", 
                acc=f"{total_acc / (pbar.n + 1):.4f}"
            )

    avg_loss = total_loss / len(train_loader)
    avg_acc = total_acc / len(train_loader)
    print(f"Epoch {epoch+1} Avg Loss: {avg_loss:.4f} | Avg Acc: {avg_acc:.4f}")

# --- 6. Save the Trained Model ---
# We save the state_dict of the meta-trained model
torch.save(model.state_dict(), MODEL_SAVE_PATH)
print(f"\nMeta-training complete. Model saved to {MODEL_SAVE_PATH}")

Using device: cpu

Meta-Train (Base) dataset loaded.
Found 100013 images in 2 classes.
Classes: ['Normal', 'Pneumonia']

Starting meta-training...


Epoch 1/20: 100%|██████████| 500/500 [30:49<00:00,  3.70s/it, acc=0.8206, loss=0.4323]


Epoch 1 Avg Loss: 0.4323 | Avg Acc: 0.8206


Epoch 2/20: 100%|██████████| 500/500 [30:21<00:00,  3.64s/it, acc=0.8194, loss=0.4269]


Epoch 2 Avg Loss: 0.4269 | Avg Acc: 0.8194


Epoch 3/20: 100%|██████████| 500/500 [29:23<00:00,  3.53s/it, acc=0.8229, loss=0.4175]


Epoch 3 Avg Loss: 0.4175 | Avg Acc: 0.8229


Epoch 4/20: 100%|██████████| 500/500 [29:48<00:00,  3.58s/it, acc=0.8221, loss=0.4156]


Epoch 4 Avg Loss: 0.4156 | Avg Acc: 0.8221


Epoch 5/20: 100%|██████████| 500/500 [29:40<00:00,  3.56s/it, acc=0.8295, loss=0.4106]


Epoch 5 Avg Loss: 0.4106 | Avg Acc: 0.8295


Epoch 6/20: 100%|██████████| 500/500 [29:44<00:00,  3.57s/it, acc=0.8212, loss=0.4164]


Epoch 6 Avg Loss: 0.4164 | Avg Acc: 0.8212


Epoch 7/20: 100%|██████████| 500/500 [30:15<00:00,  3.63s/it, acc=0.8251, loss=0.4126]


Epoch 7 Avg Loss: 0.4126 | Avg Acc: 0.8251


Epoch 8/20: 100%|██████████| 500/500 [31:33<00:00,  3.79s/it, acc=0.8244, loss=0.4126]


Epoch 8 Avg Loss: 0.4126 | Avg Acc: 0.8244


Epoch 9/20: 100%|██████████| 500/500 [31:24<00:00,  3.77s/it, acc=0.8210, loss=0.4147]


Epoch 9 Avg Loss: 0.4147 | Avg Acc: 0.8210


Epoch 10/20: 100%|██████████| 500/500 [31:08<00:00,  3.74s/it, acc=0.8227, loss=0.4238]


Epoch 10 Avg Loss: 0.4238 | Avg Acc: 0.8227


Epoch 11/20: 100%|██████████| 500/500 [31:02<00:00,  3.72s/it, acc=0.8267, loss=0.4076]


Epoch 11 Avg Loss: 0.4076 | Avg Acc: 0.8267


Epoch 12/20: 100%|██████████| 500/500 [30:05<00:00,  3.61s/it, acc=0.8245, loss=0.4128]


Epoch 12 Avg Loss: 0.4128 | Avg Acc: 0.8245


Epoch 13/20: 100%|██████████| 500/500 [29:57<00:00,  3.59s/it, acc=0.8256, loss=0.4145]


Epoch 13 Avg Loss: 0.4145 | Avg Acc: 0.8256


Epoch 14/20: 100%|██████████| 500/500 [29:38<00:00,  3.56s/it, acc=0.8250, loss=0.4108]


Epoch 14 Avg Loss: 0.4108 | Avg Acc: 0.8250


Epoch 15/20: 100%|██████████| 500/500 [29:39<00:00,  3.56s/it, acc=0.8246, loss=0.4140]


Epoch 15 Avg Loss: 0.4140 | Avg Acc: 0.8246


Epoch 16/20: 100%|██████████| 500/500 [29:48<00:00,  3.58s/it, acc=0.8208, loss=0.4154]


Epoch 16 Avg Loss: 0.4154 | Avg Acc: 0.8208


Epoch 17/20: 100%|██████████| 500/500 [29:33<00:00,  3.55s/it, acc=0.8221, loss=0.4201]


Epoch 17 Avg Loss: 0.4201 | Avg Acc: 0.8221


Epoch 18/20: 100%|██████████| 500/500 [29:23<00:00,  3.53s/it, acc=0.8308, loss=0.4119]


Epoch 18 Avg Loss: 0.4119 | Avg Acc: 0.8308


Epoch 19/20: 100%|██████████| 500/500 [29:25<00:00,  3.53s/it, acc=0.8247, loss=0.4103]


Epoch 19 Avg Loss: 0.4103 | Avg Acc: 0.8247


Epoch 20/20: 100%|██████████| 500/500 [29:48<00:00,  3.58s/it, acc=0.8284, loss=0.4116]


Epoch 20 Avg Loss: 0.4116 | Avg Acc: 0.8284

Meta-training complete. Model saved to fsl_backbone.pth


In [6]:
# --- TESTING ON NOVEL CLASSES ---
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {DEVICE}")

# Data and Model
META_TEST_DIR = 'fsl_data/meta_test'
MODEL_PATH = r'C:\Users\SUBRAT\MAFSL PROJECT\fsl_backbone.pth' # Load the model we just trained

# FSL parameters for testing
# We are testing on 2 *novel* classes (COVID, TB)
N_WAY = 2  
# We will test on 5-shot
N_SHOT = 10 
N_QUERY = 15 # Use more query images for a stable evaluation

# Evaluation parameters
TEST_EPISODES = 1000 # Run 1000 test episodes for a good average

# --- 3. Load Data ---
_, test_transform = get_transforms() # Use the test transform (no augmentation)
test_dataset = ImageFolder(META_TEST_DIR, transform=test_transform)

print(f"\nMeta-Test (Novel) dataset loaded.")
print(f"Found {len(test_dataset)} images in {len(test_dataset.classes)} classes.")
print(f"Classes: {test_dataset.classes}")

# Create the Episodic Sampler
test_sampler = EpisodicBatchSampler(
    data_targets=test_dataset.targets,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    episodes_per_epoch=TEST_EPISODES
)

test_loader = DataLoader(
    test_dataset,
    batch_sampler=test_sampler,
    num_workers=2
)

# --- 4. Load Meta-Trained Model ---
model = PrototypicalNet(out_dim=256).to(DEVICE)
try:
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    print(f"Successfully loaded pre-trained model from {MODEL_PATH}")
except Exception as e:
    print(f"Error loading model: {e}")
    print("Please make sure 'meta_train.py' ran successfully and saved the model.")
    exit()


model.eval()
total_loss = 0.0
all_accuracies = [] # Store all 1000 accuracies

with torch.no_grad():
    for (batch_images, batch_labels) in tqdm(test_loader, desc="Running Meta-Test"):
        batch_images = batch_images.to(DEVICE)
        
        embeddings = model(batch_images)
        loss, accuracy = prototypical_loss(
        embeddings, batch_labels, N_SHOT, N_QUERY, N_WAY, DEVICE )

        
        total_loss += loss.item()
        all_accuracies.append(accuracy.item()) # Store each accuracy

# --- Report Final Results (MODIFIED FOR CONFIDENCE INTERVAL) ---
avg_loss = total_loss / TEST_EPISODES

# Calculate final statistics
avg_acc = np.mean(all_accuracies)
std_dev = np.std(all_accuracies)

# 95% confidence interval = 1.96 * (std / sqrt(n))
# This is the "+/-" value
confidence_interval = 1.96 * (std_dev / np.sqrt(len(all_accuracies)))

print("\n" + "="*30)
print("--- Meta-Test Results ---")
print(f"Task: {N_WAY}-way, {N_SHOT}-shot (COVID vs TB)")
print(f"Episodes Run: {len(all_accuracies)}")
print(f"Average Loss: {avg_loss:.4f}")
print(f"Average Accuracy: {avg_acc * 100:.2f}%")
print(f"95% Confidence Interval: +/- {confidence_interval * 100:.2f}%")
print("="*30)
print(f"Final Reported Accuracy: {avg_acc * 100:.2f} ± {confidence_interval * 100:.2f}%")
print("="*30)

Using device: cpu

Meta-Test (Novel) dataset loaded.
Found 4316 images in 2 classes.
Classes: ['COVID', 'Tuberculosis']
Successfully loaded pre-trained model from C:\Users\SUBRAT\MAFSL PROJECT\fsl_backbone.pth


Running Meta-Test: 100%|██████████| 1000/1000 [1:02:46<00:00,  3.77s/it]


--- Meta-Test Results ---
Task: 2-way, 10-shot (COVID vs TB)
Episodes Run: 1000
Average Loss: 0.6689
Average Accuracy: 61.43%
95% Confidence Interval: +/- 0.65%
Final Reported Accuracy: 61.43 ± 0.65%



