In [70]:
import torch
print("PyTorch sees", torch.cuda.device_count(), "GPUs")


PyTorch sees 8 GPUs


In [71]:
import os
print("CPU cores available:", os.cpu_count())


CPU cores available: 80


In [72]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [73]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [74]:
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

In [75]:
if not os.path.exists("/home/23dcs505/data/2750"):
    print("No dataset found")
fulldata=ImageFolder(root='/home/23dcs505/data/2750', transform=data_transform)


In [76]:
from torch.utils.data import random_split

train_len=int((0.8)*len(fulldata))
test_len=len(fulldata)-(train_len)

train_data_set,test_data_set= random_split(fulldata,[train_len, test_len])

In [77]:
all_list=[0,1,2,3,4,5,6,7,8,9]

In [78]:
train_class_len=5

In [79]:
train_list=random.sample(all_list,train_class_len)
test_list=list(range(0,10))
strict_test_list=list(set(all_list) - set(train_list))


In [80]:
print(train_list)
print(test_list)
print(strict_test_list)

[7, 0, 5, 6, 3]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[1, 2, 4, 8, 9]


In [81]:
ways=5
shots=5
queries=5
strict_ways=len(strict_test_list)
gpu_num=2

In [82]:
from torch.utils.data import Subset

In [83]:
train_data_set.indices[0]

25027

In [84]:
def class_sorting(dataset, class_list):
    targets = dataset.dataset.targets

    indices= [i for i in dataset.indices if targets[i] in class_list]
    return Subset(dataset.dataset, indices)
    

In [85]:
train_data=class_sorting(train_data_set,train_list)
test_data=class_sorting(test_data_set,test_list)
strict_test_data=class_sorting(test_data,strict_test_list)

In [86]:
train_data[0]

(tensor([[[ 0.7248,  0.7248,  0.7248,  ..., -0.2342, -0.2342, -0.2342],
          [ 0.7248,  0.7248,  0.7248,  ..., -0.2342, -0.2342, -0.2342],
          [ 0.7248,  0.7248,  0.7248,  ..., -0.2342, -0.2342, -0.2342],
          ...,
          [ 1.8550,  1.8550,  1.8379,  ...,  0.4337,  0.3994,  0.3994],
          [ 1.8722,  1.8722,  1.8550,  ...,  0.4851,  0.4508,  0.4508],
          [ 1.8722,  1.8722,  1.8550,  ...,  0.4851,  0.4508,  0.4508]],
 
         [[ 0.5903,  0.5903,  0.5903,  ..., -0.1450, -0.1450, -0.1450],
          [ 0.5903,  0.5903,  0.5903,  ..., -0.1450, -0.1450, -0.1450],
          [ 0.5903,  0.5903,  0.5903,  ..., -0.1450, -0.1450, -0.1450],
          ...,
          [ 1.4132,  1.4132,  1.4132,  ...,  0.2577,  0.2227,  0.2227],
          [ 1.4307,  1.4307,  1.4307,  ...,  0.2927,  0.2752,  0.2752],
          [ 1.4307,  1.4307,  1.4307,  ...,  0.2927,  0.2752,  0.2752]],
 
         [[ 0.5834,  0.5834,  0.5659,  ...,  0.1476,  0.1651,  0.1651],
          [ 0.5834,  0.5834,

In [87]:

train_data.dataset

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /home/23dcs505/data/2750
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [88]:
train_data.indices[0]

18075

In [89]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    def __init__(self,data,way,shot,query,episode):
        super().__init__()
        self.data=data
        self.way=way
        self.shot=shot
        self.query=query
        self.episode=episode

        self.class_to_indices=self._build_class_index()
        self.classes=list(self.class_to_indices.keys())
        

    def _build_class_index(self):
        class_index={}

        targets=self.data.dataset.targets

        labels = [self.data.dataset.targets[i] for i in self.data.indices]
        


        for indexofsubset, indexoforiginal in enumerate(self.data.indices):
            label=targets[indexoforiginal]
            if label not in class_index:
                class_index[label]=[]
            class_index[label].append(indexofsubset)

        return class_index
        
    def __len__(self):
            return self.episode
        
    def __getitem__(self, idx):
        #print('hello')

        #print(f"Total available classes: {len(self.classes)}, requested way: {self.way}")


        selected_class=random.sample(self.classes,self.way)

        support_images, support_labels=[],[]
        query_images, query_labels=[],[]


        label_map={class_name: i for i, class_name in enumerate(selected_class)}

        for class_name in selected_class:
            all_indices_for_class=self.class_to_indices[class_name]

            selected_index=random.sample(all_indices_for_class,self.shot+self.query)

            support_index=selected_index[:self.shot]
            query_index=selected_index[self.shot:]

            for i in support_index:
                image,_=self.data[i]
                support_images.append(image)
                support_labels.append(torch.tensor(label_map[class_name]))
                
            for i in query_index:
                image,_=self.data[i]
                query_images.append(image)
                query_labels.append(torch.tensor(label_map[class_name]))
            
        return(
            torch.stack(support_images),
            torch.stack(support_labels),
            torch.stack(query_images),
            torch.stack(query_labels)
        )

In [90]:
def compute_prototypes(support_embeddings,support_labels,way):
    embedding_dimensions=support_embeddings.size(-1)
    prototypes=torch.zeros(way,embedding_dimensions).to(support_embeddings.device)

    for c in range(way):
        class_mask=(support_labels==c)
        class_embeddings=support_embeddings[class_mask]
        prototypes[c]=class_embeddings.mean(dim=0)
    return prototypes

def classify_queries(prototypes,query_embeddings):
    n_query=query_embeddings.size(0)
    way=prototypes.size(0)

    query_exp=query_embeddings.unsqueeze(1).expand(n_query,way,-1)
    prototypes_exp=prototypes.unsqueeze(0).expand(n_query,way,-1)

    distances=torch.sum((query_exp-prototypes_exp)**2,dim=2)

    logits=-distances
    return logits


In [91]:
import torch.optim as optim

few_dataset=create_dataset(
    data=train_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [92]:
few_dataloader=DataLoader(
    few_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [93]:
import torch
import torch.nn as nn
import torchvision.models as models

class VGGEmbedding(nn.Module):
    
    def __init__(self, embedding_dim=256):
        super().__init__()
        
        vgg = models.vgg16(pretrained=True)
        
        self.features = vgg.features
        self.avgpool = vgg.avgpool
        
        in_features = vgg.classifier[0].in_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, embedding_dim)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x



model= VGGEmbedding(embedding_dim=256)


for param in model.parameters():
    param.requires_grad = False


for param in model.features[24:].parameters():
    param.requires_grad = True
    

for param in model.classifier.parameters():
    param.requires_grad = True


device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")
model=model.to(device)

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

epochs=20

In [94]:
print("training on class :",train_list)

training on class : [7, 0, 5, 6, 3]


In [95]:
import torch
import gc

gc.collect()  # Python garbage collection
torch.cuda.empty_cache()  # Clear cache for current device


In [96]:

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        optimizer.zero_grad()
        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        loss=loss_fn(logits,query_labels)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

Epoch: 1 ------------- Loss= 0.21087220745786908 Acccuracy= 94.12
Epoch: 2 ------------- Loss= 0.11109117899221019 Acccuracy= 97.04
Epoch: 3 ------------- Loss= 0.07731493228719558 Acccuracy= 97.78
Epoch: 4 ------------- Loss= 0.06289379324025333 Acccuracy= 98.36
Epoch: 5 ------------- Loss= 0.05937879223136406 Acccuracy= 98.28
Epoch: 6 ------------- Loss= 0.034718381384473494 Acccuracy= 99.14
Epoch: 7 ------------- Loss= 0.04964110765995429 Acccuracy= 98.61999999999999
Epoch: 8 ------------- Loss= 0.05351178248384713 Acccuracy= 98.76


KeyboardInterrupt: 

In [None]:
test_dataset=create_dataset(
    data=test_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [None]:
test_dataloader=DataLoader(
    test_dataset,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [None]:
print("testing on class :",test_list)

testing on class : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [None]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy_0=(total_correct/total_queries)*100
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Prototypical Network on all", len(test_list),"Class =",accuracy_0,"%")


Loss= 0.02886864752079802 Acccuracy on 10 Class = 78.52


In [None]:
strict_test_dataset=create_dataset(
    data=strict_test_data,
    way=strict_ways,
    shot=shots,
    query=queries,
    episode=200
)

In [None]:
strict_test_dataloader=DataLoader(
    strict_test_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [None]:
print("testing on class :",strict_test_list)

testing on class : [1, 2, 3, 6, 9]


In [None]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy_1=(total_correct/total_queries)*100
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Prototypical Network on Unseen", len(strict_test_list),"Class =",accuracy_1,"%")

5 way 3 shot Loss= 0.02886864752079802 Acccuracy on 5 Class = 60.58


**Stable Protypical Network**

In [None]:
import torch
import gc

# Delete all unused objects
gc.collect()

# Empty PyTorch CUDA cache
torch.cuda.empty_cache()

In [None]:
from dropblock import DropBlock2D

In [None]:
from torch.cuda.amp import autocast, GradScaler

In [None]:

class VGGEmbedding(nn.Module):
    def __init__(self, embedding_dim=256, drop_prob=0.3, block_size=3):
        super().__init__()
        vgg = models.vgg16(pretrained=True)
        
        features_list = []
        for layer in vgg.features:
            features_list.append(layer)
            if isinstance(layer, nn.MaxPool2d):
                if len(features_list) == 17 or len(features_list) == 24:
                    features_list.append(DropBlock2D(block_size=block_size, drop_prob=drop_prob))

        self.features = nn.Sequential(*features_list)
        self.avgpool = vgg.avgpool
        
       
        in_features = vgg.classifier[0].in_features
        
        
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, embedding_dim) 
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

model = VGGEmbedding(embedding_dim=256, drop_prob=0.3, block_size=3)

for param in model.parameters():
    param.requires_grad = False

for param in model.features[26:].parameters():
    param.requires_grad = True
    
for param in model.classifier.parameters():
    param.requires_grad = True


trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-5) # Use a smaller LR for fine-tuning

loss_fn = nn.CrossEntropyLoss()

device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")
model = model.to(device)


epochs=20

#From code for SPN
n_times=15
alpha=0.01

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    from tqdm.notebook import tqdm
    progress_bar=tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)


    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)


        optimizer.zero_grad(set_to_none=True)
        #For montecarlopass
        
        all_ce_losses = []
        all_query_logits = []

        for _ in range(n_times):
            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            ce_loss=loss_fn(logits,query_labels)
            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)
            
        total_ce_loss= torch.stack(all_ce_losses).mean()

        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        true_class_probs = stacked_probs[
            torch.arange(n_times)[:, None],
            torch.arange(len(query_labels)),
            query_labels
        ]
            
        variance_loss=torch.std(true_class_probs,dim=0).sum()
        total_combined_loss=total_ce_loss+alpha*variance_loss

        

        
        total_combined_loss.backward()
        optimizer.step()


        total_loss+=total_combined_loss.item()
        mean_logits=stacked_logits.mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        progress_bar.set_postfix(Loss=f"{total_combined_loss.item():4f}",Acc=f"{avg_acc_till}&")
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)





Epoch 1/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1 ------------- Loss= 0.46970345072448255 Acccuracy= 83.94


Epoch 2/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2 ------------- Loss= 0.11294508117251098 Acccuracy= 97.98


Epoch 3/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3 ------------- Loss= 0.07424557770602405 Acccuracy= 98.78


Epoch 4/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4 ------------- Loss= 0.055583062248770146 Acccuracy= 99.16


Epoch 5/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5 ------------- Loss= 0.038318783300346694 Acccuracy= 99.64


Epoch 6/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6 ------------- Loss= 0.030326761977048592 Acccuracy= 99.7


Epoch 7/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7 ------------- Loss= 0.019098240821185754 Acccuracy= 99.8


Epoch 8/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8 ------------- Loss= 0.01500507457232743 Acccuracy= 99.9


Epoch 9/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9 ------------- Loss= 0.012381813611900725 Acccuracy= 99.96000000000001


Epoch 10/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10 ------------- Loss= 0.009562951392472313 Acccuracy= 99.98


Epoch 11/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11 ------------- Loss= 0.006080969567829015 Acccuracy= 100.0


Epoch 12/20:   0%|          | 0/200 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
alpha

0.01

In [None]:
print("Testing........... Started.......................")
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy_spn_0=(total_correct/total_queries)*100
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on all", len(test_list),"Class =",accuracy_spn_0,"%")


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Testing........... Started.......................
Loss= 0.006080969567829015 Acccuracy on 10 Class = 83.52000000000001
Mean Predictive Entropy = 0.5312187671661377


In [None]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy_spn_1=(total_correct/total_queries)*100
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on Unseen", len(strict_test_list),"Class =",accuracy_spn_1,"%")


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())
    


Loss= 0.006080969567829015 Acccuracy on 5 Class = 67.7
Mean Predictive Entropy = 0.4826270043849945


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

class VGGEmbeddingEncoder(nn.Module):
    """
    A simple VGG16-based feature extractor with no dropout.
    This will serve as the base encoder for all subsequent methods.
    """
    def __init__(self, embedding_dim=256):
        super().__init__()
        vgg = models.vgg16(pretrained=True)
        self.features = vgg.features
        self.avgpool = vgg.avgpool
        
        in_features = vgg.classifier[0].in_features
        self.classifier = nn.Sequential(
            nn.Linear(in_features, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, embedding_dim)
        )

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

print("Reusable VGGEmbeddingEncoder class defined.")

In [None]:
class ContrastiveLoss(nn.Module):
    
    def __init__(self, margin=2.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        euclidean_distance = nn.functional.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((label) * torch.pow(euclidean_distance, 2) + 
                                      (1-label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

print("Siamese Network loss function defined.")

In [None]:

import torch.optim as optim
from tqdm.notebook import tqdm

print("--- Starting Siamese Network Training ---")

# 1. Initialize Model, Loss, and Optimizer
model = VGGEmbeddingEncoder(embedding_dim=256).to(device)

# Fine-tuning setup: freeze early layers, unfreeze block 5 and classifier
for param in model.parameters():
    param.requires_grad = False
for param in model.features[24:].parameters():
    param.requires_grad = True
for param in model.classifier.parameters():
    param.requires_grad = True

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-4)
loss_fn = ContrastiveLoss().to(device)
epochs = 20

# 2. Training Loop
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    
    progress_bar = tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    
    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels = episode
        
        # Combine support and query to create a larger pool for pairing
        all_images = torch.cat([support_images.squeeze(0), query_images.squeeze(0)], dim=0)
        all_labels = torch.cat([support_labels.view(-1), query_labels.view(-1)], dim=0)

        optimizer.zero_grad()
        
        # --- On-the-fly pair generation ---
        # For simplicity, we'll create a small number of pairs per episode.
        # A more advanced implementation would have a dedicated pair-based loader.
        img1_list, img2_list, labels_list = [], [], []
        for i in range(all_images.size(0)):
            # Create one positive and one negative pair for each image
            anchor_img, anchor_label = all_images[i], all_labels[i]
            
            # Find positive sample
            positive_indices = (all_labels == anchor_label).nonzero(as_tuple=True)[0]
            positive_idx = positive_indices[positive_indices != i][0]
            
            # Find negative sample
            negative_indices = (all_labels != anchor_label).nonzero(as_tuple=True)[0]
            negative_idx = random.choice(negative_indices)

            # Add positive pair
            img1_list.append(anchor_img)
            img2_list.append(all_images[positive_idx])
            labels_list.append(torch.tensor(1.0))

            # Add negative pair
            img1_list.append(anchor_img)
            img2_list.append(all_images[negative_idx])
            labels_list.append(torch.tensor(0.0))

        img1 = torch.stack(img1_list).to(device)
        img2 = torch.stack(img2_list).to(device)
        pair_labels = torch.stack(labels_list).to(device)
        
        # Forward pass
        output1 = model(img1)
        output2 = model(img2)
        
        loss = loss_fn(output1, output2, pair_labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        progress_bar.set_postfix(Loss=f"{loss.item():.4f}")

    epoch_loss = running_loss / len(few_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss:.4f}")

In [None]:
siamese_shots = 1


test_dataset = create_dataset(
    data=test_data,
    way=ways,
    shot=siamese_shots, 
    query=queries, 
    episode=200
)
test_dataloader = DataLoader(test_dataset, shuffle=True, num_workers=8, pin_memory=True)

model.eval()
total_correct, total_queries = 0, 0

with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels = episode

        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        # In a Siamese test, the support set acts as the class prototypes
        support_embeddings = model(support_images)
        query_embeddings = model(query_images)
        
        # Since it's 1-shot, the embeddings are the prototypes
        prototypes = support_embeddings

        # Classify queries based on distance to support images (prototypes)
        logits = classify_queries(prototypes, query_embeddings)
        preds = torch.argmax(logits, dim=1)
        
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

accuracy_siamese_0 = (total_correct / total_queries) * 100
print(ways,"way",shots,"siamese_shots:","Loss=",avg_loss,"Acccuracy of Siamese Network on all", len(test_list),"Class =",accuracy_siamese_0,"%")

In [None]:
siamese_shots = 1


strict_test_dataset = create_dataset(
    data=strict_test_data,
    way=ways,
    shot=siamese_shots, 
    query=queries, 
    episode=200
)
strict_test_dataloader = DataLoader(strict_test_dataset, shuffle=True, num_workers=8, pin_memory=True)

model.eval()
total_correct, total_queries = 0, 0

with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels = episode

        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        # In a Siamese test, the support set acts as the class prototypes
        support_embeddings = model(support_images)
        query_embeddings = model(query_images)
        
        # Since it's 1-shot, the embeddings are the prototypes
        prototypes = support_embeddings

        # Classify queries based on distance to support images (prototypes)
        logits = classify_queries(prototypes, query_embeddings)
        preds = torch.argmax(logits, dim=1)
        
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

accuracy_siamese_1 = (total_correct / total_queries) * 100
print(ways,"way",shots,"siamese_shots:","Loss=",avg_loss,"Acccuracy of Siamese Network on Unseen", len(strict_test_list),"Class =",accuracy_siamese_1,"%")

In [None]:
import torch.optim as optim
from tqdm.notebook import tqdm
import copy

print("--- Starting MAML Training ---")


model = VGGEmbeddingEncoder(embedding_dim=ways).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5) # Outer loop optimizer

epochs = 20
inner_update_steps = 5 # Number of gradient steps in the inner loop
inner_lr = 0.01 # Learning rate for the inner loop adaptation
loss_fn = nn.CrossEntropyLoss()

# 2. MAML Training Loop
for epoch in range(epochs):
    model.train()
    total_meta_loss = 0.0
    total_query_acc = 0.0
    
    progress_bar = tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    
    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels = episode
        support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
        query_images, query_labels = query_images.squeeze(0).to(device), query_labels.view(-1).to(device)
        
        optimizer.zero_grad()
        
        
        fast_model = copy.deepcopy(model)
        fast_optimizer = optim.SGD(fast_model.parameters(), lr=inner_lr)
        
        
        for _ in range(inner_update_steps):
            support_preds = fast_model(support_images)
            inner_loss = loss_fn(support_preds, support_labels)
            fast_optimizer.zero_grad()
            inner_loss.backward(retain_graph=True)
            fast_optimizer.step()
        
        
        query_preds = fast_model(query_images)
        meta_loss = loss_fn(query_preds, query_labels)
        
        
        meta_loss.backward()
        optimizer.step()
        
        total_loss += meta_loss.item()
        total_acc += (query_preds.argmax(dim=1) == query_labels).sum().item() / len(query_labels)
        progress_bar.set_postfix(MetaLoss=f"{meta_loss.item():.4f}")

    avg_meta_loss = total_meta_loss / len(few_dataloader)
    accuracy = (total_query_acc / len(few_dataloader)) * 100
    print(f"Epoch {epoch+1}/{epochs}, Meta-Loss: {avg_meta_loss:.4f}, Query Accuracy: {accuracy:.2f}%")

In [None]:
# --- MAML: Evaluation Loop ---
print("--- Evaluating MAML ---")


test_dataset = create_dataset(
    data=test_data,
    way=ways, 
    shot=shots, 
    query=queries,
    episode=200
)
test_dataloader = DataLoader(test_dataset, shuffle=True, num_workers=8, pin_memory=True)


model.eval()
total_correct, total_queries = 0, 0

#with torch.no_grad():
for episode in test_dataloader:
    support_images, support_labels, query_images, query_labels = episode
    support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
    query_images, query_labels = query_images.squeeze(0).to(device), query_labels.view(-1).to(device)

    # Create a temporary model and adapt it to the new, unseen task
    test_fast_model = copy.deepcopy(model)
    test_fast_model.train() # Enable gradients for adaptation
    test_fast_optimizer = optim.SGD(test_fast_model.parameters(), lr=inner_lr)

    # Adapt the model using the support set of the unseen task
    for _ in range(inner_update_steps * 2): # Use more update steps for testing
        support_preds = test_fast_model(support_images)
        inner_loss = loss_fn(support_preds, support_labels)
        test_fast_optimizer.zero_grad()
        inner_loss.backward()
        test_fast_optimizer.step()
    
    # Evaluate the adapted model on the query set
    test_fast_model.eval()
    query_preds = test_fast_model(query_images)
    preds = torch.argmax(query_preds, dim=1)
    
    total_correct += (preds == query_labels).sum().item()
    total_queries += query_labels.size(0)

accuracy_maml_0 = (total_correct / total_queries) * 100
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of MAML Network on all", len(test_list),"Class =",accuracy_maml_0,"%")

In [None]:
# --- MAML: Evaluation Loop ---
print("--- Evaluating MAML ---")


strict_test_dataset = create_dataset(
    data=strict_test_data,
    way=ways, 
    shot=shots, 
    query=queries,
    episode=200
)
strict_test_dataloader = DataLoader(strict_test_dataset, shuffle=True, num_workers=8, pin_memory=True)


model.eval()
total_correct, total_queries = 0, 0

#with torch.no_grad():
for episode in strict_test_dataloader:
    support_images, support_labels, query_images, query_labels = episode
    support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
    query_images, query_labels = query_images.squeeze(0).to(device), query_labels.view(-1).to(device)

    
    test_fast_model = copy.deepcopy(model)
    test_fast_model.train() 
    test_fast_optimizer = optim.SGD(test_fast_model.parameters(), lr=inner_lr)

    
    for _ in range(inner_update_steps * 2): 
        support_preds = test_fast_model(support_images)
        inner_loss = loss_fn(support_preds, support_labels)
        test_fast_optimizer.zero_grad()
        inner_loss.backward()
        test_fast_optimizer.step()
    
    # Evaluate the adapted model on the query set
    test_fast_model.eval()
    query_preds = test_fast_model(query_images)
    preds = torch.argmax(query_preds, dim=1)
    
    total_correct += (preds == query_labels).sum().item()
    total_queries += query_labels.size(0)

accuracy_maml_1 = (total_correct / total_queries) * 100
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of MAML Network on Unseen", len(strict_test_list),"Class =",accuracy_maml_1,"%")

In [None]:
# --- Reptile: Training Loop ---
import torch.optim as optim
from tqdm.notebook import tqdm
import copy

print("--- Starting Reptile Training ---")

model = VGGEmbeddingEncoder(embedding_dim=ways).to(device)

epochs = 20
inner_update_steps_reptile = 5 # Number of gradient steps on the task
inner_lr_reptile = 0.01
meta_step_size = 0.5 # How far to move the meta-model towards the task-specific weights
loss_fn_reptile = nn.CrossEntropyLoss()

# 2. Reptile Training Loop
for epoch in range(epochs):
    model.train()
    total_loss_reptile = 0.0
    
    progress_bar = tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)
    
    # Store original weights before the episode batch
    original_weights = {name: param.clone() for name, param in model.named_parameters()}
    
    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels = episode
        support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
        
        # Create a temporary model to train on the task
        task_model = copy.deepcopy(model)
        task_optimizer = optim.SGD(model.parameters(), lr=inner_lr_reptile)

        # Train the task_model on the support set
        for _ in range(inner_update_steps_reptile):
            support_preds = task_model(support_images)
            task_loss = loss_fn_reptile(support_preds, support_labels)
            task_optimizer.zero_grad()
            task_loss.backward()
            task_optimizer.step()
        
        total_loss_reptile += task_loss.item() # For logging purposes
        
        # --- Reptile Meta-Update ---
        # Interpolate the meta-model's weights towards the task-model's new weights
        with torch.no_grad():
            for name, param in model.named_parameters():
                task_param = task_model.state_dict()[name]
                param.data.lerp_(task_param, meta_step_size)
        
        progress_bar.set_postfix(TaskLoss=f"{task_loss.item():.4f}")

    avg_task_loss = total_loss_reptile / len(few_dataloader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Task Loss: {avg_task_loss:.4f}")

In [None]:
model.eval()
total_correct, total_queries = 0, 0

#with torch.no_grad():
for episode in test_dataloader:
    support_images, support_labels, query_images, query_labels = episode
    support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
    query_images, query_labels = query_images.squeeze(0).to(device), query_labels.view(-1).to(device)

    # Create a temporary model and fine-tune it on the new, unseen task
    test_task_model = copy.deepcopy(model)
    test_task_model.train() # Enable gradients for fine-tuning
    test_task_optimizer = optim.SGD(test_task_model.parameters(), lr=inner_lr_reptile)

    # Fine-tune the model using the support set of the unseen task
    for _ in range(inner_update_steps_reptile * 2): # Use more update steps for testing
        support_preds = test_task_model(support_images)
        inner_loss = loss_fn_reptile(support_preds, support_labels)
        test_task_optimizer.zero_grad()
        inner_loss.backward()
        test_task_optimizer.step()
    
    # Evaluate the fine-tuned model on the query set
    test_task_model.eval()
    query_preds = test_task_model(query_images)
    preds = torch.argmax(query_preds, dim=1)
    
    total_correct += (preds == query_labels).sum().item()
    total_queries += query_labels.size(0)

accuracy_reptile_0 = (total_correct / total_queries) * 100
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Reptile Network on all", len(strict_test_list),"Class =",accuracy_reptile_0,"%")

In [None]:
model.eval()
total_correct, total_queries = 0, 0

#with torch.no_grad():
for episode in strict_test_dataloader:
    support_images, support_labels, query_images, query_labels = episode
    support_images, support_labels = support_images.squeeze(0).to(device), support_labels.view(-1).to(device)
    query_images, query_labels = query_images.squeeze(0).to(device), query_labels.view(-1).to(device)

    # Create a temporary model and fine-tune it on the new, unseen task
    test_task_model = copy.deepcopy(model)
    test_task_model.train() # Enable gradients for fine-tuning
    test_task_optimizer = optim.SGD(test_task_model.parameters(), lr=inner_lr_reptile)

    # Fine-tune the model using the support set of the unseen task
    for _ in range(inner_update_steps_reptile * 2): # Use more update steps for testing
        support_preds = test_task_model(support_images)
        inner_loss = loss_fn_reptile(support_preds, support_labels)
        test_task_optimizer.zero_grad()
        inner_loss.backward()
        test_task_optimizer.step()
    
    # Evaluate the fine-tuned model on the query set
    test_task_model.eval()
    query_preds = test_task_model(query_images)
    preds = torch.argmax(query_preds, dim=1)
    
    total_correct += (preds == query_labels).sum().item()
    total_queries += query_labels.size(0)

accuracy_reptile_1 = (total_correct / total_queries) * 100
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Reptile Network on Unseen", len(strict_test_list),"Class =",accuracy_reptile_1,"%")

In [None]:
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Prototypical Network on all", len(test_list),"Class =",accuracy_0,"%")
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Prototypical Network on Unseen", len(strict_test_list),"Class =",accuracy_1,"%")
print('\n\n\n')
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on all", len(test_list),"Class =",accuracy_spn_0,"%")
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on Unseen", len(strict_test_list),"Class =",accuracy_spn_1,"%")
print('\n\n\n')
print(ways,"way",shots,"siamese_shots:","Loss=",avg_loss,"Acccuracy of Siamese Network on all", len(test_list),"Class =",accuracy_siamese_0,"%")
print(ways,"way",shots,"siamese_shots:","Loss=",avg_loss,"Acccuracy of Siamese Network on Unseen", len(strict_test_list),"Class =",accuracy_siamese_1,"%")
print('\n\n\n')
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of MAML Network on all", len(test_list),"Class =",accuracy_maml_0,"%")
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of MAML Network on Unseen", len(strict_test_list),"Class =",accuracy_maml_1,"%")
print('\n\n\n')
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Reptile Network on all", len(strict_test_list),"Class =",accuracy_reptile_0,"%")
print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Reptile Network on Unseen", len(strict_test_list),"Class =",accuracy_reptile_1,"%")

In [None]:
print('\n\n\n')


/n/n/n/n
