In [30]:
import torch
print("PyTorch sees", torch.cuda.device_count(), "GPUs")


PyTorch sees 8 GPUs


In [31]:
import os
print("CPU cores available:", os.cpu_count())


CPU cores available: 80


In [32]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [33]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [34]:
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

In [35]:
if not os.path.exists("/home/23dcs505/data/2750"):
    print("No dataset found")
fulldata=ImageFolder(root='/home/23dcs505/data/2750', transform=data_transform)


In [36]:
from torch.utils.data import random_split

train_len=int((0.8)*len(fulldata))
test_len=len(fulldata)-(train_len)

train_data_set,test_data_set= random_split(fulldata,[train_len, test_len])

In [37]:
all_list=[0,1,2,3,4,5,6,7,8,9]

In [38]:
train_class_len=8

In [39]:
train_list=random.sample(all_list,train_class_len)
test_list=list(range(0,10))
strict_test_list=list(set(all_list) - set(train_list))


In [40]:
print(train_list)
print(test_list)
print(strict_test_list)

[9, 7, 3, 8, 5, 4, 6, 1]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[0, 2]


In [41]:
ways=5
shots=5
queries=5
strict_ways=len(strict_test_list)

In [42]:
from torch.utils.data import Subset

In [43]:
train_data_set.indices[0]

1225

In [44]:
def class_sorting(dataset, class_list):
    targets = dataset.dataset.targets

    indices= [i for i in dataset.indices if targets[i] in class_list]
    return Subset(dataset.dataset, indices)
    

In [45]:
train_data=class_sorting(train_data_set,train_list)
test_data=class_sorting(test_data_set,test_list)
strict_test_data=class_sorting(test_data,strict_test_list)

In [46]:
train_data[0]

(tensor([[[-0.5424, -0.5424, -0.5424,  ..., -1.1589, -1.1760, -1.1760],
          [-0.5424, -0.5424, -0.5424,  ..., -1.1589, -1.1760, -1.1760],
          [-0.5424, -0.5424, -0.5424,  ..., -1.1589, -1.1760, -1.1760],
          ...,
          [-0.1828, -0.1828, -0.1828,  ..., -0.7822, -0.8164, -0.8164],
          [-0.1828, -0.1828, -0.1828,  ..., -0.8507, -0.8849, -0.8849],
          [-0.1828, -0.1828, -0.1828,  ..., -0.8507, -0.8849, -0.8849]],
 
         [[-0.2850, -0.2850, -0.2850,  ..., -0.5651, -0.5651, -0.5651],
          [-0.2850, -0.2850, -0.2850,  ..., -0.5651, -0.5651, -0.5651],
          [-0.2850, -0.2850, -0.2850,  ..., -0.5651, -0.5651, -0.5651],
          ...,
          [-0.0049, -0.0049, -0.0049,  ..., -0.5126, -0.5301, -0.5301],
          [-0.0049, -0.0049, -0.0049,  ..., -0.5826, -0.6001, -0.6001],
          [-0.0049, -0.0049, -0.0049,  ..., -0.5826, -0.6001, -0.6001]],
 
         [[ 0.1302,  0.1302,  0.1302,  ..., -0.2358, -0.2358, -0.2358],
          [ 0.1302,  0.1302,

In [47]:

train_data.dataset

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /home/23dcs505/data/2750
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [48]:
train_data.indices[0]

19324

In [49]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    def __init__(self,data,way,shot,query,episode):
        super().__init__()
        self.data=data
        self.way=way
        self.shot=shot
        self.query=query
        self.episode=episode

        self.class_to_indices=self._build_class_index()
        self.classes=list(self.class_to_indices.keys())
        

    def _build_class_index(self):
        class_index={}

        targets=self.data.dataset.targets

        labels = [self.data.dataset.targets[i] for i in self.data.indices]
        


        for indexofsubset, indexoforiginal in enumerate(self.data.indices):
            label=targets[indexoforiginal]
            if label not in class_index:
                class_index[label]=[]
            class_index[label].append(indexofsubset)

        return class_index
        
    def __len__(self):
            return self.episode
        
    def __getitem__(self, idx):
        #print('hello')

        #print(f"Total available classes: {len(self.classes)}, requested way: {self.way}")


        selected_class=random.sample(self.classes,self.way)

        support_images, support_labels=[],[]
        query_images, query_labels=[],[]


        label_map={class_name: i for i, class_name in enumerate(selected_class)}

        for class_name in selected_class:
            all_indices_for_class=self.class_to_indices[class_name]

            selected_index=random.sample(all_indices_for_class,self.shot+self.query)

            support_index=selected_index[:self.shot]
            query_index=selected_index[self.shot:]

            for i in support_index:
                image,_=self.data[i]
                support_images.append(image)
                support_labels.append(torch.tensor(label_map[class_name]))
                
            for i in query_index:
                image,_=self.data[i]
                query_images.append(image)
                query_labels.append(torch.tensor(label_map[class_name]))
            
        return(
            torch.stack(support_images),
            torch.stack(support_labels),
            torch.stack(query_images),
            torch.stack(query_labels)
        )

In [50]:
def compute_prototypes(support_embeddings,support_labels,way):
    embedding_dimensions=support_embeddings.size(-1)
    prototypes=torch.zeros(way,embedding_dimensions).to(support_embeddings.device)

    for c in range(way):
        class_mask=(support_labels==c)
        class_embeddings=support_embeddings[class_mask]
        prototypes[c]=class_embeddings.mean(dim=0)
    return prototypes

def classify_queries(prototypes,query_embeddings):
    n_query=query_embeddings.size(0)
    way=prototypes.size(0)

    query_exp=query_embeddings.unsqueeze(1).expand(n_query,way,-1)
    prototypes_exp=prototypes.unsqueeze(0).expand(n_query,way,-1)

    distances=torch.sum((query_exp-prototypes_exp)**2,dim=2)

    logits=-distances
    return logits


In [51]:
import torch.optim as optim

few_dataset=create_dataset(
    data=train_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [52]:
few_dataloader=DataLoader(
    few_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [53]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)

In [54]:
class VGGEmbedding(nn.Module):
    def __init__(self):
        super().__init__()


        # ##Code from the paper
        # features_list = list(vgg.features.children())
        # # Insert DropBlock after MaxPool at index 16
        # features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        # features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # ##END


        self.features=vgg.features
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

In [55]:
model=VGGEmbedding()

for param in model.features.parameters():
    param.requires_grad=False

model.classifier[3]=nn.Linear(model.classifier[3].in_features,256)

device=torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model=model.to(device)

optimizer=optim.Adam(model.parameters(),lr=1e-4)
loss_fn=nn.CrossEntropyLoss()

epochs=20


In [56]:
print("training on class :",train_list)

training on class : [9, 7, 3, 8, 5, 4, 6, 1]


In [57]:
import torch
import gc

gc.collect()  # Python garbage collection
torch.cuda.empty_cache()  # Clear cache for current device


In [58]:

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        optimizer.zero_grad()
        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        loss=loss_fn(logits,query_labels)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

Epoch: 1 ------------- Loss= 1.3288108696043492 Acccuracy= 64.42
Epoch: 2 ------------- Loss= 0.9727111332863569 Acccuracy= 69.14
Epoch: 3 ------------- Loss= 0.7913329622149468 Acccuracy= 74.32
Epoch: 4 ------------- Loss= 0.5078389469906688 Acccuracy= 83.72
Epoch: 5 ------------- Loss= 0.4828438464878127 Acccuracy= 85.06
Epoch: 6 ------------- Loss= 0.40637346596457063 Acccuracy= 86.7
Epoch: 7 ------------- Loss= 0.3734011799842119 Acccuracy= 88.42
Epoch: 8 ------------- Loss= 0.3254919455142226 Acccuracy= 89.92
Epoch: 9 ------------- Loss= 0.2800640887510963 Acccuracy= 91.28
Epoch: 10 ------------- Loss= 0.28580211638240144 Acccuracy= 91.22
Epoch: 11 ------------- Loss= 0.24228225934552028 Acccuracy= 92.22
Epoch: 12 ------------- Loss= 0.2796203244908247 Acccuracy= 91.42
Epoch: 13 ------------- Loss= 0.21232156127109192 Acccuracy= 93.2
Epoch: 14 ------------- Loss= 0.2211283789924346 Acccuracy= 93.60000000000001
Epoch: 15 ------------- Loss= 0.2396185733319726 Acccuracy= 93.14
Epoch

In [59]:
test_dataset=create_dataset(
    data=test_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [60]:
test_dataloader=DataLoader(
    test_dataset,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [61]:
print("testing on class :",test_list)

testing on class : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [62]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(test_list),"Class =",accuracy)

Loss= 0.2114605195436161 Acccuracy on 10 Class = 89.34


In [63]:
strict_test_dataset=create_dataset(
    data=strict_test_data,
    way=strict_ways,
    shot=shots,
    query=queries,
    episode=200
)

In [64]:
strict_test_dataloader=DataLoader(
    strict_test_dataset,
    #batch_size=1,
    shuffle=True,
    num_workers=8, 
    pin_memory=True
)

In [65]:
print("testing on class :",strict_test_list)

testing on class : [0, 2]


In [66]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(strict_test_list),"Class =",accuracy)

Loss= 0.2114605195436161 Acccuracy on 2 Class = 79.9


**Stable Protypical Network**

In [67]:
import torch
import gc

# Delete all unused objects
gc.collect()

# Empty PyTorch CUDA cache
torch.cuda.empty_cache()

In [68]:
from dropblock import DropBlock2D

In [69]:
from torch.cuda.amp import autocast, GradScaler

In [70]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)
class VGGEmbedding(nn.Module):
    def __init__(self,drop_prob=0.3, block_size=5):
        super().__init__()


        ##Code from the paper
        features_list = list(vgg.features.children())
        # Insert DropBlock after MaxPool at index 16
        features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        ##END


        self.features=nn.Sequential(*features_list)
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

model=VGGEmbedding()

for name, param in model.features.named_parameters():
    if 'dropblock' not in name:
        param.requires_grad=False

model.classifier[3]=nn.Linear(model.classifier[3].in_features,256)

device=torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
model=model.to(device)

optimizer=optim.Adam(model.parameters(),lr=1e-4)
loss_fn=nn.CrossEntropyLoss()

epochs=20

#From code for SPN
n_times=5
alpha=1.0

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    from tqdm.notebook import tqdm
    progress_bar=tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)


    for episode in progress_bar:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)


        optimizer.zero_grad(set_to_none=True)
        #For montecarlopass
        
        all_ce_losses = []
        all_query_logits = []

        for _ in range(n_times):
            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            ce_loss=loss_fn(logits,query_labels)
            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)
            
        total_ce_loss= torch.stack(all_ce_losses).mean()

        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        true_class_probs = stacked_probs[
            torch.arange(n_times)[:, None],
            torch.arange(len(query_labels)),
            query_labels
        ]
            
        variance_loss=torch.std(true_class_probs,dim=0).mean()
        total_combined_loss=total_ce_loss+alpha*variance_loss

        

        
        total_combined_loss.backward()
        optimizer.step()

        mean_logits=stacked_logits.mean(dim=0)
        total_loss+=total_combined_loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        progress_bar.set_postfix(Loss=f"{total_combined_loss.item():4f}",Acc=f"{avg_acc_till}&")
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

print("Testing........... Started.......................")

model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(test_list),"Class =",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Epoch 1/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 1 ------------- Loss= 1.3751522636413573 Acccuracy= 62.08


Epoch 2/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 2 ------------- Loss= 0.7726405822485686 Acccuracy= 80.60000000000001


Epoch 3/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 3 ------------- Loss= 0.5885326896235347 Acccuracy= 86.16


Epoch 4/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 4 ------------- Loss= 0.5136173798143864 Acccuracy= 88.68


Epoch 5/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 5 ------------- Loss= 0.4416167311742902 Acccuracy= 90.24


Epoch 6/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 6 ------------- Loss= 0.40828229662030935 Acccuracy= 90.28


Epoch 7/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 7 ------------- Loss= 0.3621763958223164 Acccuracy= 91.94


Epoch 8/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 8 ------------- Loss= 0.3311185437254608 Acccuracy= 92.80000000000001


Epoch 9/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 9 ------------- Loss= 0.3214541762880981 Acccuracy= 93.28


Epoch 10/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 10 ------------- Loss= 0.2960783926397562 Acccuracy= 94.19999999999999


Epoch 11/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 11 ------------- Loss= 0.30289030304178594 Acccuracy= 93.94


Epoch 12/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 12 ------------- Loss= 0.28507953859865665 Acccuracy= 94.16


Epoch 13/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 13 ------------- Loss= 0.2626989331096411 Acccuracy= 95.36


Epoch 14/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 14 ------------- Loss= 0.2493012420507148 Acccuracy= 94.92


Epoch 15/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 15 ------------- Loss= 0.246361143309623 Acccuracy= 95.26


Epoch 16/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 16 ------------- Loss= 0.22956619290169328 Acccuracy= 95.64


Epoch 17/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 17 ------------- Loss= 0.19867029588902368 Acccuracy= 96.78


Epoch 18/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 18 ------------- Loss= 0.20530921302270144 Acccuracy= 96.26


Epoch 19/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 19 ------------- Loss= 0.2003226502519101 Acccuracy= 96.17999999999999


Epoch 20/20:   0%|          | 0/200 [00:00<?, ?it/s]

Epoch: 20 ------------- Loss= 0.19385274865198882 Acccuracy= 96.6
Testing........... Started.......................
Loss= 0.19385274865198882 Acccuracy on 10 Class = 88.78
Mean Predictive Entropy = 0.12722106277942657


In [72]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        stacked_logits=[]



        model.train()
        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        model.eval()
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy on", len(strict_test_list),"Class =",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Loss= 0.19385274865198882 Acccuracy on 2 Class = 78.14999999999999
Mean Predictive Entropy = 0.5437802076339722
