In [1]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [2]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [3]:
# Standard ImageNet per-channel statistics, the normalisation torchvision's
# pretrained models are trained with.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Every image: resize to 224x224, convert to a float tensor, then normalise per channel.
data_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ]
)

In [4]:
# ImageFolder root: one sub-directory per class. Override with the DATA_DIR
# environment variable instead of editing this absolute path.
DATA_DIR = os.environ.get("DATA_DIR", "/home/23dcs505/data/2750")
if not os.path.isdir(DATA_DIR):
    # Fail here with a clear message instead of printing and letting ImageFolder crash.
    raise FileNotFoundError(f"No dataset found at {DATA_DIR}")
fulldata = ImageFolder(root=DATA_DIR, transform=data_transform)


In [5]:
from torch.utils.data import random_split

# 80/20 train/test split of the full dataset. The split is seeded so the same
# images land in each part on every run (unseeded, every restart changed it).
SPLIT_SEED = 42
train_len = int(0.8 * len(fulldata))
test_len = len(fulldata) - train_len

train_data_set, test_data_set = random_split(
    fulldata,
    [train_len, test_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

In [6]:
# ImageFolder label ids allowed in each split. NOTE(review): the test classes
# include every training class, so test episodes can contain classes seen in training.
train_class = list(range(8))
test_class = list(range(10))

In [7]:
from torch.utils.data import Subset

In [8]:
# Inspect the first original-dataset index that random_split put in the training split.
train_data_set.indices[0]

1463

In [9]:
def class_sorting(dataset, class_list):
    """Restrict a Subset to the samples whose label is in ``class_list``.

    Args:
        dataset: Subset-like object exposing ``.dataset`` (with a ``.targets``
            sequence of labels) and ``.indices`` (positions into that dataset).
        class_list: labels to keep; membership is tested with ``in``.

    Returns:
        A new ``Subset`` over the *underlying* dataset (not over ``dataset``),
        holding the original indices of the kept samples in their original order.
    """
    base = dataset.dataset
    labels = base.targets
    kept = []
    for original_index in dataset.indices:
        if labels[original_index] in class_list:
            kept.append(original_index)
    return Subset(base, kept)
    

In [10]:
# Keep only the allowed classes in each split; indices still refer to `fulldata`.
train_data = class_sorting(dataset=train_data_set, class_list=train_class)
test_data = class_sorting(dataset=test_data_set, class_list=test_class)

In [11]:
# Inspect one (image_tensor, label) pair from the filtered training subset.
train_data[0]

(tensor([[[ 2.1290,  2.1290,  2.1290,  ..., -0.4739, -0.4568, -0.4568],
          [ 2.1290,  2.1290,  2.1290,  ..., -0.4739, -0.4568, -0.4568],
          [ 2.1290,  2.1290,  2.1290,  ..., -0.4739, -0.4568, -0.4568],
          ...,
          [ 2.2318,  2.2318,  2.2318,  ...,  1.8379,  1.8379,  1.8379],
          [ 2.2318,  2.2318,  2.2318,  ...,  1.8379,  1.8379,  1.8379],
          [ 2.2318,  2.2318,  2.2318,  ...,  1.8379,  1.8379,  1.8379]],
 
         [[ 1.1331,  1.1331,  1.1331,  ..., -0.1450, -0.1275, -0.1275],
          [ 1.1331,  1.1331,  1.1331,  ..., -0.1450, -0.1275, -0.1275],
          [ 1.1331,  1.1331,  1.1331,  ..., -0.1450, -0.1275, -0.1275],
          ...,
          [ 1.3431,  1.3431,  1.3431,  ...,  0.9930,  0.9930,  0.9930],
          [ 1.3431,  1.3431,  1.3431,  ...,  0.9930,  0.9930,  0.9930],
          [ 1.3431,  1.3431,  1.3431,  ...,  0.9930,  0.9930,  0.9930]],
 
         [[ 0.8797,  0.8797,  0.8797,  ...,  0.0082,  0.0256,  0.0256],
          [ 0.8797,  0.8797,

In [12]:

# Show the underlying ImageFolder (size, root, transform) that the filtered training Subset indexes into.
train_data.dataset

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /home/23dcs505/data/2750
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [13]:
# class_sorting keeps original-dataset indices, so this equals train_data_set.indices[0]
# whenever that first sample's class is in train_class.
train_data.indices[0]

1463

In [14]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    """Episodic N-way K-shot sampler over a torch ``Subset``.

    Every item is a freshly sampled few-shot episode: ``way`` classes are drawn at
    random, and for each of them ``shot`` support and ``query`` query samples are drawn
    without replacement. The ``idx`` passed to ``__getitem__`` is ignored, so episodes
    are random rather than tied to an index; ``len()`` only sets how many episodes one
    pass over a DataLoader yields.

    Args:
        data: Subset-like object with ``.dataset.targets`` (labels of the underlying
            dataset), ``.indices`` (positions into it) and ``__getitem__`` returning
            ``(image, label)`` for a position *within the subset*.
        way: number of classes per episode.
        shot: support samples per class.
        query: query samples per class.
        episode: number of episodes reported by ``__len__``.
    """

    def __init__(self, data, way, shot, query, episode):
        super().__init__()
        self.data = data
        self.way = way
        self.shot = shot
        self.query = query
        self.episode = episode

        # label -> positions *within the subset* (valid arguments to self.data[i]).
        self.class_to_indices = self._build_class_index()
        self.classes = list(self.class_to_indices.keys())

    def _build_class_index(self):
        """Group subset positions by label, preserving subset order."""
        targets = self.data.dataset.targets
        class_index = {}
        for subset_position, original_index in enumerate(self.data.indices):
            class_index.setdefault(targets[original_index], []).append(subset_position)
        return class_index

    def __len__(self):
        return self.episode

    def __getitem__(self, idx):
        """Sample one episode (``idx`` is unused).

        Returns:
            ``(support_images, support_labels, query_images, query_labels)``. Images are
            stacked along a new leading dimension. Labels are int64 tensors of
            episode-local class ids ``0..way-1``, grouped by class.

        Raises:
            ValueError: if fewer than ``way`` classes exist, or a selected class has
                fewer than ``shot + query`` samples.
        """
        if self.way > len(self.classes):
            raise ValueError(
                f"Requested {self.way}-way episodes but only "
                f"{len(self.classes)} classes are available"
            )
        per_class = self.shot + self.query
        selected_classes = random.sample(self.classes, self.way)

        support_images, support_labels = [], []
        query_images, query_labels = [], []
        # Episode-local label = position of the class in the random selection.
        for episode_label, class_name in enumerate(selected_classes):
            candidates = self.class_to_indices[class_name]
            if len(candidates) < per_class:
                raise ValueError(
                    f"Class {class_name!r} has {len(candidates)} samples; "
                    f"{per_class} (shot + query) are needed"
                )
            chosen = random.sample(candidates, per_class)

            for position in chosen[:self.shot]:
                image, _ = self.data[position]
                support_images.append(image)
                support_labels.append(episode_label)

            for position in chosen[self.shot:]:
                image, _ = self.data[position]
                query_images.append(image)
                query_labels.append(episode_label)

        return (
            torch.stack(support_images),
            torch.tensor(support_labels, dtype=torch.long),
            torch.stack(query_images),
            torch.tensor(query_labels, dtype=torch.long),
        )

In [15]:
def compute_prototypes(support_embeddings, support_labels, way):
    """Per-class mean embedding (prototype) of the support set.

    Args:
        support_embeddings: 2-D tensor (n_support, dim).
        support_labels: 1-D integer tensor (n_support,) of class ids; ids outside
            ``[0, way)`` are ignored.
        way: number of prototypes to produce (classes ``0..way-1``).

    Returns:
        Tensor (way, dim) with the same dtype and device as ``support_embeddings``.
        Row ``c`` is the mean of the support embeddings labelled ``c``; a class with
        no support samples yields a NaN row (as ``mean`` of an empty selection does).
    """
    class_ids = torch.arange(way, device=support_labels.device)
    # (way, n_support) membership matrix: row c marks the support samples of class c.
    membership = (support_labels.unsqueeze(0) == class_ids.unsqueeze(1)).to(
        device=support_embeddings.device, dtype=support_embeddings.dtype
    )
    counts = membership.sum(dim=1, keepdim=True)
    # Masked sum divided by count = per-class mean; keeps the input dtype instead of
    # forcing float32, and stays differentiable w.r.t. the embeddings.
    return (membership @ support_embeddings) / counts

def classify_queries(prototypes, query_embeddings):
    """Score each query against each prototype by negative squared Euclidean distance.

    Args:
        prototypes: 2-D tensor (way, dim).
        query_embeddings: 2-D tensor (n_query, dim).

    Returns:
        Logits of shape (n_query, way); larger means closer, so argmax over dim 1
        picks the nearest prototype.
    """
    # (n_query, 1, dim) - (1, way, dim) broadcasts to (n_query, way, dim).
    diff = query_embeddings[:, None, :] - prototypes[None, :, :]
    return diff.pow(2).sum(dim=2).neg()


In [16]:
import torch.optim as optim

# Episodic training sampler: 5-way, 3-shot episodes with 5 queries per class,
# 200 episodes per epoch.
episode_config = dict(way=5, shot=3, query=5, episode=200)
few_dataset = create_dataset(data=train_data, **episode_config)

In [17]:
# One episode per step: batch_size=1 (the DataLoader default) adds a leading
# dimension of size 1 that the training loops squeeze away.
few_dataloader = DataLoader(few_dataset, batch_size=1, shuffle=True)

In [18]:
import torchvision.models as models
# ImageNet-pretrained VGG-16; VGGEmbedding below reuses its layers.
# NOTE(review): `pretrained=True` is deprecated from torchvision 0.13 in favour of
# `weights=models.VGG16_Weights.IMAGENET1K_V1`; switch once the installed version is confirmed.
vgg=models.vgg16(pretrained=True)

In [19]:
class VGGEmbedding(nn.Module):
    """Embedding network built from the globally loaded VGG-16 ``vgg``.

    ``features`` and ``avgpool`` are ``vgg``'s own modules (shared, not copied). The
    classifier keeps every layer of ``vgg.classifier`` except the final Linear, so the
    output is whatever the remaining classifier layers produce.
    """

    def __init__(self):
        super().__init__()
        self.features = vgg.features
        self.avgpool = vgg.avgpool
        # Drop only the last classifier layer (the original output projection).
        classifier_layers = list(vgg.classifier.children())
        self.classifier = nn.Sequential(*classifier_layers[:-1])

    def forward(self, x):
        pooled = self.avgpool(self.features(x))
        return self.classifier(torch.flatten(pooled, 1))

In [21]:
# Baseline prototypical network: frozen convolutional trunk, trainable classifier head
# whose second Linear layer is replaced to emit 256-d embeddings.
model = VGGEmbedding()
model.features.requires_grad_(False)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 256)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
epochs = 20


In [None]:

# Episodic training: per episode, embed support and query images, build class
# prototypes from the support set and classify queries by distance to them.
for epoch in range(epochs):
    model.train()
    total_loss = 0
    total_correct = 0
    total_queries = 0

    for support_images, support_labels, query_images, query_labels in few_dataloader:
        # The DataLoader prepends a batch dimension of size 1 to every episode.
        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        optimizer.zero_grad()
        support_embeddings = model(support_images)
        query_embeddings = model(query_images)

        n_way = support_labels.unique().numel()
        prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
        logits = classify_queries(prototypes, query_embeddings)

        loss = loss_fn(logits, query_labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        preds = logits.argmax(dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.numel()

    avg_loss = total_loss / len(few_dataloader)
    accuracy = (total_correct / total_queries) * 100
    print("Epoch:", epoch + 1, "-------------", "Loss=", avg_loss, "Acccuracy=", accuracy)

Epoch: 1 ------------- Loss= 1.8508246785402298 Acccuracy= 50.28
Epoch: 2 ------------- Loss= 1.449593073129654 Acccuracy= 50.160000000000004
Epoch: 3 ------------- Loss= 1.3273231509327887 Acccuracy= 54.900000000000006
Epoch: 4 ------------- Loss= 1.164746471941471 Acccuracy= 60.699999999999996
Epoch: 5 ------------- Loss= 0.9329160003364086 Acccuracy= 69.86
Epoch: 6 ------------- Loss= 0.785576047860086 Acccuracy= 75.08
Epoch: 7 ------------- Loss= 0.6679948568344116 Acccuracy= 78.74
Epoch: 8 ------------- Loss= 0.7251726684719324 Acccuracy= 78.44
Epoch: 9 ------------- Loss= 0.6248822312429547 Acccuracy= 80.82000000000001
Epoch: 10 ------------- Loss= 0.6131166943255812 Acccuracy= 80.76
Epoch: 11 ------------- Loss= 0.5130507605150342 Acccuracy= 84.82


In [None]:
# Evaluation sampler with the same episode shape as training, drawn from the test split.
test_dataset = create_dataset(data=test_data, way=5, shot=3, query=5, episode=200)

In [None]:
# Evaluate on episodes from the held-out split. Previously this wrapped
# `few_dataset`, so "test" accuracy was measured on training episodes.
test_dataloader = DataLoader(
    test_dataset,
    shuffle=True
)

In [None]:
# Baseline evaluation: accuracy of nearest-prototype classification on test episodes.
model.eval()  # inference mode (disables dropout)
total_correct, total_queries = 0, 0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels = episode
        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        support_embeddings = model(support_images)
        query_embeddings = model(query_images)

        n_way = torch.unique(support_labels).size(0)
        prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
        logits = classify_queries(prototypes, query_embeddings)

        preds = torch.argmax(logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

    accuracy = (total_correct / total_queries) * 100
    # Only accuracy is measured here; `avg_loss` would be the stale training loss
    # from the last epoch, so it is no longer printed as if it were a test result.
    print("Acccuracy=", accuracy)

Epoch: 20 ------------- Loss= 0.11725419672846328 Acccuracy= 98.42


**Stable Prototypical Network**

In [None]:
import torch
import gc

# gc.collect() only frees objects nothing refers to, so first drop the global
# references that still hold the first experiment's model, optimizer state and
# last-episode tensors (possibly on the GPU).
model = optimizer = vgg = None
support_images = query_images = support_labels = query_labels = None
support_embeddings = query_embeddings = prototypes = logits = preds = loss = None

# Delete all unused objects
gc.collect()

# Return cached, now-unused CUDA memory blocks to the driver.
torch.cuda.empty_cache()

In [None]:
from dropblock import DropBlock2D

In [None]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)


class VGGEmbedding(nn.Module):
    """VGG-16 embedding network with two DropBlock2D layers in the convolutional trunk.

    Built from the globally loaded ``vgg`` (its layer modules are shared, not copied).
    DropBlock is inserted right after the MaxPool2d layers at original positions 16
    and 23 of ``vgg.features``. The classifier keeps every layer except the final Linear.

    Args:
        drop_prob: DropBlock drop probability.
        block_size: DropBlock block size.

    Raises:
        ValueError: if ``vgg.features`` does not have MaxPool2d at positions 16 and 23.
    """

    def __init__(self, drop_prob=0.3, block_size=5):
        super().__init__()

        ##Code from the paper
        features_list = list(vgg.features.children())
        # Fail loudly if the backbone layout differs from the one the indices assume.
        for pool_index in (16, 23):
            if not isinstance(features_list[pool_index], nn.MaxPool2d):
                raise ValueError(
                    f"expected MaxPool2d at vgg.features[{pool_index}], "
                    f"got {type(features_list[pool_index]).__name__}"
                )
        # Insert DropBlock after MaxPool at index 16
        features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # The MaxPool originally at 23 is now at 24, so "after" it is index 25;
        # inserting at 24 (as before) placed DropBlock *before* that pool.
        features_list.insert(25, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        ##END

        # Use the modified layer list. Previously `vgg.features` was assigned here,
        # which silently discarded both DropBlock layers.
        self.features = nn.Sequential(*features_list)
        self.avgpool = vgg.avgpool

        # Drop only the last classifier layer (the original output projection).
        self.classifier = nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# Stable-PN model: frozen (DropBlock-capable) trunk, trainable classifier head whose
# second Linear layer is replaced to emit 256-d embeddings.
model = VGGEmbedding()
model.features.requires_grad_(False)
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 256)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
epochs = 20

#From code for SPN
# Each episode is passed through the network `n_times` times in train mode, so
# stochastic layers (dropout, and any DropBlock in the model) give a different
# embedding each pass. The loss is mean cross-entropy plus `alpha` times the
# across-pass spread of the true-class probability.
n_times = 5   # Monte-Carlo forward passes per episode
alpha = 0.1   # weight of the prediction-variability penalty

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries = 0, 0, 0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels = episode
        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        # The episode's class count does not change across the Monte-Carlo passes.
        n_way = torch.unique(support_labels).size(0)

        all_ce_losses = []
        all_query_logits = []
        for _ in range(n_times):
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)
            prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
            logits = classify_queries(prototypes, query_embeddings)

            all_ce_losses.append(loss_fn(logits, query_labels))
            all_query_logits.append(logits)

        total_ce_loss = torch.stack(all_ce_losses).mean()

        # (n_times, n_query, way)
        stacked_logits = torch.stack(all_query_logits)
        # Softmax over the class axis. The previous dim=1 normalised across queries.
        stacked_probs = torch.softmax(stacked_logits, dim=-1)

        # Probability of each query's true class on every pass: (n_times, n_query).
        query_positions = torch.arange(len(query_labels), device=query_labels.device)
        true_class_probs = stacked_probs[:, query_positions, query_labels]
        # Spread across passes, averaged over queries (std is undefined for one pass).
        if n_times > 1:
            variance_loss = torch.std(true_class_probs, dim=0).mean()
        else:
            variance_loss = true_class_probs.new_zeros(())

        total_combined_loss = total_ce_loss + alpha * variance_loss

        optimizer.zero_grad()
        total_combined_loss.backward()
        optimizer.step()

        # Score with the Monte-Carlo average rather than only the last pass's logits.
        mean_logits = stacked_logits.detach().mean(dim=0)
        total_loss += total_combined_loss.item()
        preds = torch.argmax(mean_logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

    avg_loss = total_loss / len(few_dataloader)
    accuracy = (total_correct / total_queries) * 100
    print("Epoch:", epoch + 1, "-------------", "Loss=", avg_loss, "Acccuracy=", accuracy)




test_dataset = create_dataset(
    data=test_data,
    way=5,
    shot=3,
    query=5,
    episode=200
)
# Evaluate on episodes from the held-out split. Previously this wrapped
# `few_dataset`, so "test" accuracy was measured on training episodes.
test_dataloader = DataLoader(
    test_dataset,
    shuffle=True
)

# Monte-Carlo evaluation. The model is deliberately left in train mode so stochastic
# layers (e.g. dropout) stay active, and each episode's prediction averages the logits
# of `n_times` passes. torch.no_grad() ensures no gradients are tracked.
model.train()
total_correct, total_queries = 0, 0
total_entropy = 0.0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels = episode
        support_images = support_images.squeeze(0).to(device)
        query_images = query_images.squeeze(0).to(device)
        support_labels = support_labels.view(-1).to(device)
        query_labels = query_labels.view(-1).to(device)

        # The episode's class count does not change across the Monte-Carlo passes.
        n_way = torch.unique(support_labels).size(0)

        stacked_logits = []
        for _ in range(n_times):
            support_embeddings = model(support_images)
            query_embeddings = model(query_images)
            prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
            logits = classify_queries(prototypes, query_embeddings)
            stacked_logits.append(logits)

        mean_logits = torch.stack(stacked_logits).mean(dim=0)
        preds = torch.argmax(mean_logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

        # Entropy of the averaged prediction, summed over this episode's queries.
        episode_entropy = -(
            torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)
        ).sum(dim=1)
        total_entropy += episode_entropy.sum().item()

        torch.cuda.empty_cache()

    accuracy = (total_correct / total_queries) * 100
    # Only accuracy is measured here; `avg_loss` would be the stale training loss.
    print("Acccuracy=", accuracy)

    # Averaged over every query of every episode (previously only the last episode).
    entropy = total_entropy / total_queries
    print("Mean Predictive Entropy =", entropy)


RuntimeError: CUDA out of memory. Tried to allocate 392.00 MiB (GPU 0; 31.73 GiB total capacity; 27.89 GiB already allocated; 2.44 MiB free; 27.94 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Show per-GPU memory usage and running processes (useful after the CUDA out-of-memory error above).
!nvidia-smi

Sat Jun 21 04:00:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.144.03             Driver Version: 550.144.03     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla V100-SXM2-32GB           On  |   00000000:06:00.0 Off |                    0 |
| N/A   48C    P0            179W /  300W |    2480MiB /  32768MiB |     90%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  Tesla V100-SXM2-32GB           On  |   00