In [1]:
import torch
print("PyTorch sees", torch.cuda.device_count(), "GPUs")


PyTorch sees 8 GPUs


In [2]:
import os
print("CPU cores available:", os.cpu_count())


CPU cores available: 80


In [3]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [4]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [5]:
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

In [6]:
if not os.path.exists("/home/23dcs505/data/2750"):
    print("No dataset found")
fulldata=ImageFolder(root='/home/23dcs505/data/2750', transform=data_transform)


In [7]:
from torch.utils.data import random_split

train_len=int((0.8)*len(fulldata))
test_len=len(fulldata)-(train_len)

train_data_set,test_data_set= random_split(fulldata,[train_len, test_len])

In [8]:
all_list=[0,1,2,3,4,5,6,7,8,9]

In [9]:
train_class_len=3

In [10]:
train_list=random.sample(all_list,train_class_len)
test_list=list(range(0,10))
strict_test_list=list(set(all_list) - set(train_list))


In [11]:
print(train_list)
print(test_list)
print(strict_test_list)

[5, 9, 0]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[1, 2, 3, 4, 6, 7, 8]


In [12]:
ways=3
shots=5
queries=5
strict_ways=len(strict_test_list)

In [13]:
from torch.utils.data import Subset

In [14]:
train_data_set.indices[0]

16085

In [15]:
def class_sorting(dataset, class_list):
    targets = dataset.dataset.targets

    indices= [i for i in dataset.indices if targets[i] in class_list]
    return Subset(dataset.dataset, indices)
    

In [16]:
train_data=class_sorting(train_data_set,train_list)
test_data=class_sorting(test_data_set,test_list)
strict_test_data=class_sorting(test_data,strict_test_list)

In [17]:
train_data[0]

(tensor([[[-1.6213, -1.6213, -1.6213,  ..., -1.6384, -1.6384, -1.6384],
          [-1.6213, -1.6213, -1.6213,  ..., -1.6384, -1.6384, -1.6384],
          [-1.6213, -1.6213, -1.6213,  ..., -1.6384, -1.6384, -1.6384],
          ...,
          [-1.6213, -1.6213, -1.6213,  ..., -1.6042, -1.6042, -1.6042],
          [-1.6213, -1.6213, -1.6213,  ..., -1.6042, -1.6042, -1.6042],
          [-1.6213, -1.6213, -1.6213,  ..., -1.6042, -1.6042, -1.6042]],
 
         [[-1.1604, -1.1604, -1.1604,  ..., -1.1604, -1.1604, -1.1604],
          [-1.1604, -1.1604, -1.1604,  ..., -1.1604, -1.1604, -1.1604],
          [-1.1604, -1.1604, -1.1604,  ..., -1.1604, -1.1604, -1.1604],
          ...,
          [-1.1429, -1.1429, -1.1429,  ..., -1.1604, -1.1604, -1.1604],
          [-1.1429, -1.1429, -1.1429,  ..., -1.1604, -1.1604, -1.1604],
          [-1.1429, -1.1429, -1.1429,  ..., -1.1604, -1.1604, -1.1604]],
 
         [[-0.4624, -0.4624, -0.4624,  ..., -0.5147, -0.5147, -0.5147],
          [-0.4624, -0.4624,

In [18]:

train_data.dataset

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /home/23dcs505/data/2750
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [19]:
train_data.indices[0]

25019

In [20]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    def __init__(self,data,way,shot,query,episode):
        super().__init__()
        self.data=data
        self.way=way
        self.shot=shot
        self.query=query
        self.episode=episode

        self.class_to_indices=self._build_class_index()
        self.classes=list(self.class_to_indices.keys())
        

    def _build_class_index(self):
        class_index={}

        targets=self.data.dataset.targets

        labels = [self.data.dataset.targets[i] for i in self.data.indices]
        


        for indexofsubset, indexoforiginal in enumerate(self.data.indices):
            label=targets[indexoforiginal]
            if label not in class_index:
                class_index[label]=[]
            class_index[label].append(indexofsubset)

        return class_index
        
    def __len__(self):
            return self.episode
        
    def __getitem__(self, idx):
        #print('hello')

        #print(f"Total available classes: {len(self.classes)}, requested way: {self.way}")


        selected_class=random.sample(self.classes,self.way)

        support_images, support_labels=[],[]
        query_images, query_labels=[],[]


        label_map={class_name: i for i, class_name in enumerate(selected_class)}

        for class_name in selected_class:
            all_indices_for_class=self.class_to_indices[class_name]

            selected_index=random.sample(all_indices_for_class,self.shot+self.query)

            support_index=selected_index[:self.shot]
            query_index=selected_index[self.shot:]

            for i in support_index:
                image,_=self.data[i]
                support_images.append(image)
                support_labels.append(torch.tensor(label_map[class_name]))
                
            for i in query_index:
                image,_=self.data[i]
                query_images.append(image)
                query_labels.append(torch.tensor(label_map[class_name]))
            
        return(
            torch.stack(support_images),
            torch.stack(support_labels),
            torch.stack(query_images),
            torch.stack(query_labels)
        )

In [21]:
def compute_prototypes(support_embeddings,support_labels,way):
    embedding_dimensions=support_embeddings.size(-1)
    prototypes=torch.zeros(way,embedding_dimensions).to(support_embeddings.device)

    for c in range(way):
        class_mask=(support_labels==c)
        class_embeddings=support_embeddings[class_mask]
        prototypes[c]=class_embeddings.mean(dim=0)
    return prototypes

def classify_queries(prototypes,query_embeddings):
    n_query=query_embeddings.size(0)
    way=prototypes.size(0)

    query_exp=query_embeddings.unsqueeze(1).expand(n_query,way,-1)
    prototypes_exp=prototypes.unsqueeze(0).expand(n_query,way,-1)

    distances=torch.sum((query_exp-prototypes_exp)**2,dim=2)

    logits=-distances
    return logits


In [22]:
import torch.optim as optim

few_dataset=create_dataset(
    data=train_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [23]:
few_dataloader=DataLoader(
    few_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [24]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)

In [25]:
class VGGEmbedding(nn.Module):
    def __init__(self):
        super().__init__()


        # ##Code from the paper
        # features_list = list(vgg.features.children())
        # # Insert DropBlock after MaxPool at index 16
        # features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        # features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # ##END


        self.features=vgg.features
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

In [26]:
model=VGGEmbedding()


# 1. Freeze all layers first
for param in model.parameters():
    param.requires_grad = False

# 2. Unfreeze the last convolutional block (block 5)
# The VGG16 features list has 31 layers. Block 5 starts at index 24.
for param in model.features[24:].parameters():
    param.requires_grad = True

# 3. Unfreeze the classifier and replace the last layer
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 256)

for param in model.classifier.parameters():
    param.requires_grad = True

device=torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model=model.to(device)

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

epochs=20


In [27]:
print("training on class :",train_list)

training on class : [5, 9, 0]


In [28]:
import torch
import gc

gc.collect()  # Python garbage collection
torch.cuda.empty_cache()  # Clear cache for current device


In [29]:

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        optimizer.zero_grad()
        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        loss=loss_fn(logits,query_labels)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

Epoch: 1 ------------- Loss= 0.22790205305151176 Acccuracy= 93.76666666666667
Epoch: 2 ------------- Loss= 0.07753565326440366 Acccuracy= 97.36666666666667
Epoch: 3 ------------- Loss= 0.09333320378897952 Acccuracy= 97.16666666666667
Epoch: 4 ------------- Loss= 0.10415965393103484 Acccuracy= 97.66666666666667
Epoch: 5 ------------- Loss= 0.04421461799503391 Acccuracy= 98.73333333333333
Epoch: 6 ------------- Loss= 0.04410038689420844 Acccuracy= 98.9
Epoch: 7 ------------- Loss= 0.02829293905401755 Acccuracy= 99.33333333333333
Epoch: 8 ------------- Loss= 0.015706479383538863 Acccuracy= 99.66666666666667
Epoch: 9 ------------- Loss= 0.054929374042822235 Acccuracy= 98.86666666666667
Epoch: 10 ------------- Loss= 0.030878585080907895 Acccuracy= 99.03333333333333
Epoch: 11 ------------- Loss= 0.11672872548387801 Acccuracy= 97.8
Epoch: 12 ------------- Loss= 0.04096086663189631 Acccuracy= 98.96666666666667
Epoch: 13 ------------- Loss= 0.013386597071989853 Acccuracy= 99.63333333333333
Epoc

In [30]:
test_dataset=create_dataset(
    data=test_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [31]:
test_dataloader=DataLoader(
    test_dataset,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [32]:
print("testing on class :",test_list)

testing on class : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [33]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(test_list),"Class =",accuracy)

Loss= 0.028705596570492452 Acccuracy on 10 Class = 66.43333333333334


In [34]:
strict_test_dataset=create_dataset(
    data=strict_test_data,
    way=strict_ways,
    shot=shots,
    query=queries,
    episode=200
)

In [35]:
strict_test_dataloader=DataLoader(
    strict_test_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [36]:
print("testing on class :",strict_test_list)

testing on class : [1, 2, 3, 4, 6, 7, 8]


In [37]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(strict_test_list),"Class =",accuracy)

Loss= 0.028705596570492452 Acccuracy on 7 Class = 29.271428571428572


**Stable Protypical Network**

In [38]:
import torch
import gc

# Delete all unused objects
gc.collect()

# Empty PyTorch CUDA cache
torch.cuda.empty_cache()

In [39]:
from dropblock import DropBlock2D

In [40]:
from torch.cuda.amp import autocast, GradScaler

In [41]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)
class VGGEmbedding(nn.Module):
    def __init__(self,drop_prob=0.3, block_size=5):
        super().__init__()


        ##Code from the paper
        features_list = list(vgg.features.children())
        # Insert DropBlock after MaxPool at index 16
        features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        ##END


        self.features=nn.Sequential(*features_list)
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

model=VGGEmbedding()

for param in model.parameters():
    param.requires_grad = False

for layer in model.features[26:]: # Layers from the original Block 5 onwards
    # Check if the layer has parameters (e.g., Conv2d does, MaxPool2d doesn't)
    if hasattr(layer, 'parameters'):
        for param in layer.parameters():
            param.requires_grad = True

# 3. Unfreeze the classifier and ensure the final layer is replaced
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 256)
for param in model.classifier.parameters():
    param.requires_grad = True

# 4. Set up optimizer with ONLY the trainable parameters
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model = model.to(device)

trainable_params = filter(lambda p: p.requires_grad, model.parameters())
optimizer = optim.Adam(trainable_params, lr=1e-5) # Use a smaller LR for fine-tuning
loss_fn = nn.CrossEntropyLoss()


epochs=20

#From code for SPN
n_times=5
alpha=1.0

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    from tqdm.notebook import tqdm
    progress_bar=tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)


    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)


        optimizer.zero_grad(set_to_none=True)
        #For montecarlopass
        
        all_ce_losses = []
        all_query_logits = []

        for _ in range(n_times):
            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            ce_loss=loss_fn(logits,query_labels)
            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)
            
        total_ce_loss= torch.stack(all_ce_losses).sum()

        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        true_class_probs = stacked_probs[
            torch.arange(n_times)[:, None],
            torch.arange(len(query_labels)),
            query_labels
        ]
            
        variance_loss=torch.std(true_class_probs,dim=0).mean()
        total_combined_loss=total_ce_loss+alpha*variance_loss

        

        
        total_combined_loss.backward()
        optimizer.step()

        mean_logits=stacked_logits.mean(dim=0)
        total_loss+=total_combined_loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        progress_bar.set_postfix(Loss=f"{total_combined_loss.item():4f}",Acc=f"{avg_acc_till}&")
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

print("Testing........... Started.......................")

model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(test_list),"Class =",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Epoch 1/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1 ------------- Loss= 2.538295323550701 Acccuracy= 84.26666666666667


Epoch 2/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2 ------------- Loss= 0.8072832859493793 Acccuracy= 94.6


Epoch 3/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3 ------------- Loss= 0.5867871993314475 Acccuracy= 96.83333333333334


Epoch 4/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4 ------------- Loss= 0.3597478022542782 Acccuracy= 98.13333333333333


Epoch 5/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5 ------------- Loss= 0.3092515488702338 Acccuracy= 98.0


Epoch 6/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6 ------------- Loss= 0.23891570526640862 Acccuracy= 98.46666666666667


Epoch 7/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7 ------------- Loss= 0.21240038631862262 Acccuracy= 98.73333333333333


Epoch 8/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8 ------------- Loss= 0.14517389148371876 Acccuracy= 99.13333333333333


Epoch 9/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9 ------------- Loss= 0.19114242820593064 Acccuracy= 99.23333333333333


Epoch 10/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10 ------------- Loss= 0.10323280909447931 Acccuracy= 99.5


Epoch 11/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11 ------------- Loss= 0.09992190586743163 Acccuracy= 99.23333333333333


Epoch 12/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12 ------------- Loss= 0.1424202796656391 Acccuracy= 99.26666666666667


Epoch 13/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13 ------------- Loss= 0.10664852139554569 Acccuracy= 99.1


Epoch 14/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14 ------------- Loss= 0.068071213621879 Acccuracy= 99.76666666666667


Epoch 15/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15 ------------- Loss= 0.06351958364786697 Acccuracy= 99.7


Epoch 16/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16 ------------- Loss= 0.05090420080083277 Acccuracy= 99.76666666666667


Epoch 17/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17 ------------- Loss= 0.044024755224787666 Acccuracy= 99.83333333333333


Epoch 18/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18 ------------- Loss= 0.060671954464414736 Acccuracy= 99.7


Epoch 19/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19 ------------- Loss= 0.04098398191972592 Acccuracy= 99.76666666666667


Epoch 20/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20 ------------- Loss= 0.03762057566924341 Acccuracy= 99.66666666666667
Testing........... Started.......................
Loss= 0.03762057566924341 Acccuracy on 10 Class = 71.43333333333334
Mean Predictive Entropy = 0.6568796038627625


In [42]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(strict_test_list),"Class =",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Loss= 0.03762057566924341 Acccuracy on 7 Class = 33.68571428571428
Mean Predictive Entropy = 1.2036287784576416
