In [53]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [54]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [55]:
data_transform=transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])

In [56]:
if not os.path.exists("/home/23dcs505/data/2750"):
    print("No dataset found")
fulldata=ImageFolder(root='/home/23dcs505/data/2750', transform=data_transform)


In [57]:
from torch.utils.data import random_split

train_len=int((0.8)*len(fulldata))
test_len=len(fulldata)-(train_len)

train_data_set,test_data_set= random_split(fulldata,[train_len, test_len])

In [58]:
all_list=[0,1,2,3,4,5,6,7,8,9]

In [59]:
train_class_len=8

In [60]:
train_list=random.sample(all_list,train_class_len)
test_list=list(range(0,10))
strict_test_list=list(set(all_list) - set(train_list))


In [61]:
print(train_list)
print(test_list)
print(strict_test_list)

[7, 3, 1, 0, 8, 5, 4, 2]
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
[9, 6]


In [84]:
ways=5
shots=5
queries=5
strict_ways=len(strict_test_list)

In [63]:
from torch.utils.data import Subset

In [64]:
train_data_set.indices[0]

13256

In [65]:
def class_sorting(dataset, class_list):
    targets = dataset.dataset.targets

    indices= [i for i in dataset.indices if targets[i] in class_list]
    return Subset(dataset.dataset, indices)
    

In [66]:
train_data=class_sorting(train_data_set,train_list)
test_data=class_sorting(test_data_set,test_list)
strict_test_data=class_sorting(test_data,strict_test_list)

In [67]:
train_data[0]

(tensor([[[-0.5767, -0.5767, -0.5767,  ..., -0.7308, -0.7137, -0.7137],
          [-0.5767, -0.5767, -0.5767,  ..., -0.7308, -0.7137, -0.7137],
          [-0.5938, -0.5938, -0.5938,  ..., -0.7308, -0.7308, -0.7308],
          ...,
          [-0.2513, -0.2513, -0.2513,  ..., -0.4054, -0.3369, -0.3369],
          [-0.2342, -0.2342, -0.2342,  ..., -0.4226, -0.3541, -0.3541],
          [-0.2342, -0.2342, -0.2342,  ..., -0.4226, -0.3541, -0.3541]],
 
         [[-0.2850, -0.2850, -0.2850,  ..., -0.4601, -0.4426, -0.4426],
          [-0.2850, -0.2850, -0.2850,  ..., -0.4601, -0.4426, -0.4426],
          [-0.2850, -0.2850, -0.2850,  ..., -0.4601, -0.4426, -0.4426],
          ...,
          [-0.1099, -0.1099, -0.1099,  ..., -0.2150, -0.1450, -0.1450],
          [-0.0924, -0.0924, -0.0924,  ..., -0.2325, -0.1625, -0.1625],
          [-0.0924, -0.0924, -0.0924,  ..., -0.2325, -0.1625, -0.1625]],
 
         [[ 0.0953,  0.0953,  0.0953,  ..., -0.1487, -0.1312, -0.1312],
          [ 0.0953,  0.0953,

In [68]:

train_data.dataset

Dataset ImageFolder
    Number of datapoints: 27000
    Root location: /home/23dcs505/data/2750
    StandardTransform
Transform: Compose(
               Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
               ToTensor()
               Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
           )

In [69]:
train_data.indices[0]

13256

In [70]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    def __init__(self,data,way,shot,query,episode):
        super().__init__()
        self.data=data
        self.way=way
        self.shot=shot
        self.query=query
        self.episode=episode

        self.class_to_indices=self._build_class_index()
        self.classes=list(self.class_to_indices.keys())
        

    def _build_class_index(self):
        class_index={}

        targets=self.data.dataset.targets

        labels = [self.data.dataset.targets[i] for i in self.data.indices]
        


        for indexofsubset, indexoforiginal in enumerate(self.data.indices):
            label=targets[indexoforiginal]
            if label not in class_index:
                class_index[label]=[]
            class_index[label].append(indexofsubset)

        return class_index
        
    def __len__(self):
            return self.episode
        
    def __getitem__(self, idx):
        #print('hello')

        #print(f"Total available classes: {len(self.classes)}, requested way: {self.way}")


        selected_class=random.sample(self.classes,self.way)

        support_images, support_labels=[],[]
        query_images, query_labels=[],[]


        label_map={class_name: i for i, class_name in enumerate(selected_class)}

        for class_name in selected_class:
            all_indices_for_class=self.class_to_indices[class_name]

            selected_index=random.sample(all_indices_for_class,self.shot+self.query)

            support_index=selected_index[:self.shot]
            query_index=selected_index[self.shot:]

            for i in support_index:
                image,_=self.data[i]
                support_images.append(image)
                support_labels.append(torch.tensor(label_map[class_name]))
                
            for i in query_index:
                image,_=self.data[i]
                query_images.append(image)
                query_labels.append(torch.tensor(label_map[class_name]))
            
        return(
            torch.stack(support_images),
            torch.stack(support_labels),
            torch.stack(query_images),
            torch.stack(query_labels)
        )

In [71]:
def compute_prototypes(support_embeddings,support_labels,way):
    embedding_dimensions=support_embeddings.size(-1)
    prototypes=torch.zeros(way,embedding_dimensions).to(support_embeddings.device)

    for c in range(way):
        class_mask=(support_labels==c)
        class_embeddings=support_embeddings[class_mask]
        prototypes[c]=class_embeddings.mean(dim=0)
    return prototypes

def classify_queries(prototypes,query_embeddings):
    n_query=query_embeddings.size(0)
    way=prototypes.size(0)

    query_exp=query_embeddings.unsqueeze(1).expand(n_query,way,-1)
    prototypes_exp=prototypes.unsqueeze(0).expand(n_query,way,-1)

    distances=torch.sum((query_exp-prototypes_exp)**2,dim=2)

    logits=-distances
    return logits


In [72]:
import torch.optim as optim

few_dataset=create_dataset(
    data=train_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [73]:
few_dataloader=DataLoader(
    few_dataset,
    #batch_size=1,
    shuffle=True
)

In [74]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)

In [75]:
class VGGEmbedding(nn.Module):
    def __init__(self):
        super().__init__()


        # ##Code from the paper
        # features_list = list(vgg.features.children())
        # # Insert DropBlock after MaxPool at index 16
        # features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        # features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # ##END


        self.features=vgg.features
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

In [76]:
model=VGGEmbedding()

for param in model.features.parameters():
    param.requires_grad=False

model.classifier[3]=nn.Linear(model.classifier[3].in_features,256)

device=torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model=model.to(device)

optimizer=optim.Adam(model.parameters(),lr=1e-4)
loss_fn=nn.CrossEntropyLoss()

epochs=20


In [77]:
print("training on class :",train_list)

training on class : [7, 3, 1, 0, 8, 5, 4, 2]


In [78]:

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)

        optimizer.zero_grad()
        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        loss=loss_fn(logits,query_labels)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)

Epoch: 1 ------------- Loss= 1.3608693324029446 Acccuracy= 63.04
Epoch: 2 ------------- Loss= 0.9174705976247788 Acccuracy= 71.6
Epoch: 3 ------------- Loss= 0.7172941901534796 Acccuracy= 79.22
Epoch: 4 ------------- Loss= 0.5152536971867084 Acccuracy= 83.62
Epoch: 5 ------------- Loss= 0.4201317657344043 Acccuracy= 86.53999999999999
Epoch: 6 ------------- Loss= 0.3699075583915692 Acccuracy= 87.9
Epoch: 7 ------------- Loss= 0.3313070959947072 Acccuracy= 90.0
Epoch: 8 ------------- Loss= 0.3443203535955399 Acccuracy= 89.2
Epoch: 9 ------------- Loss= 0.3121330278739333 Acccuracy= 89.86
Epoch: 10 ------------- Loss= 0.2823126020957716 Acccuracy= 90.72
Epoch: 11 ------------- Loss= 0.28628839645069093 Acccuracy= 91.60000000000001
Epoch: 12 ------------- Loss= 0.29973291508620603 Acccuracy= 91.5
Epoch: 13 ------------- Loss= 0.2775481683155522 Acccuracy= 91.16
Epoch: 14 ------------- Loss= 0.2600843299902044 Acccuracy= 92.24
Epoch: 15 ------------- Loss= 0.2648161242174683 Acccuracy= 92.2

In [79]:
test_dataset=create_dataset(
    data=test_data,
    way=ways,
    shot=shots,
    query=queries,
    episode=200
)

In [80]:
test_dataloader=DataLoader(
    test_dataset,
    shuffle=True
)

In [81]:
print("testing on class :",test_list)

testing on class : [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]


In [82]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy=",accuracy)

Loss= 0.19731398384610657 Acccuracy= 91.84


In [85]:
strict_test_dataset=create_dataset(
    data=strict_test_data,
    way=strict_ways,
    shot=shots,
    query=queries,
    episode=200
)

In [86]:
strict_test_dataloader=DataLoader(
    strict_test_dataset,
    #batch_size=1,
    shuffle=True
)

In [87]:
print("testing on class :",strict_test_list)

testing on class : [9, 6]


In [88]:
model.eval()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)

        support_embeddings=model(support_images)
        query_embeddings=model(query_images)

        n_way=torch.unique(support_labels).size(0)
        prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
        logits=classify_queries(prototypes,query_embeddings)
        
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy=",accuracy)

Loss= 0.19731398384610657 Acccuracy= 92.75


**Stable Protypical Network**

In [92]:
import torch
import gc

# Delete all unused objects
gc.collect()

# Empty PyTorch CUDA cache
torch.cuda.empty_cache()

In [93]:
from dropblock import DropBlock2D

In [94]:
import torchvision.models as models
vgg=models.vgg16(pretrained=True)
class VGGEmbedding(nn.Module):
    def __init__(self,drop_prob=0.3, block_size=5):
        super().__init__()


        ##Code from the paper
        features_list = list(vgg.features.children())
        # Insert DropBlock after MaxPool at index 16
        features_list.insert(17, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        # Insert DropBlock after MaxPool at index 23 (now 24 due to previous insertion)
        features_list.insert(24, DropBlock2D(block_size=block_size, drop_prob=drop_prob))
        ##END


        self.features=vgg.features
        self.avgpool=vgg.avgpool

        self.classifier=nn.Sequential(*list(vgg.classifier.children())[:-1])

    def forward(self,x):
        x=self.features(x)
        x=self.avgpool(x)
        x=torch.flatten(x,1)
        x=self.classifier(x)
        return x

model=VGGEmbedding()

for param in model.features.parameters():
    param.requires_grad=False

model.classifier[3]=nn.Linear(model.classifier[3].in_features,256)

device=torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model=model.to(device)

optimizer=optim.Adam(model.parameters(),lr=1e-4)
loss_fn=nn.CrossEntropyLoss()

epochs=20

#From code for SPN
n_times=5
alpha=0.1

for epoch in range(epochs):
    model.train()
    total_loss, total_correct, total_queries= 0,0,0

    for episode in few_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)


        #For montecarlopass
        all_ce_losses = []
        all_query_logits = []

        for _ in range(n_times):

            #optimizer.zero_grad()
            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            ce_loss=loss_fn(logits,query_labels)
            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)
        
        total_ce_loss= torch.stack(all_ce_losses).mean()

        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=1)

        true_class_probs=stacked_probs[:, torch.arange(len(query_labels)),query_labels]
        variance_loss=torch.std(true_class_probs,dim=0).mean()

        total_combined_loss=total_ce_loss+alpha*variance_loss

        optimizer.zero_grad()
        total_combined_loss.backward()
        optimizer.step()

        mean_logits=stacked_logits.mean(dim=0)
        total_loss+=total_combined_loss.item()
        preds=torch.argmax(logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)
    
    avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy)




# test_dataset=create_dataset(
#     data=test_data,
#     way=ways,
#     shot=shots,
#     query=queries,
#     episode=200
# )
# test_dataloader=DataLoader(
#     test_dataset,
#     shuffle=True
# )

model.train()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)

        stacked_logits=[]

        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy=",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Epoch: 1 ------------- Loss= 0.8299506848305463 Acccuracy= 74.3
Epoch: 2 ------------- Loss= 0.3870365349203348 Acccuracy= 88.12
Epoch: 3 ------------- Loss= 0.3132565667852759 Acccuracy= 90.53999999999999
Epoch: 4 ------------- Loss= 0.25237558664754034 Acccuracy= 92.88
Epoch: 5 ------------- Loss= 0.2185698545537889 Acccuracy= 93.46
Epoch: 6 ------------- Loss= 0.17827789523638785 Acccuracy= 94.66
Epoch: 7 ------------- Loss= 0.1711210664920509 Acccuracy= 95.26
Epoch: 8 ------------- Loss= 0.1508194752689451 Acccuracy= 95.78
Epoch: 9 ------------- Loss= 0.12393793179653585 Acccuracy= 97.06
Epoch: 10 ------------- Loss= 0.12603611681610347 Acccuracy= 96.44
Epoch: 11 ------------- Loss= 0.11919277957640588 Acccuracy= 96.98
Epoch: 12 ------------- Loss= 0.11120697053149342 Acccuracy= 97.24000000000001
Epoch: 13 ------------- Loss= 0.12448192247189581 Acccuracy= 97.2
Epoch: 14 ------------- Loss= 0.09962486624717712 Acccuracy= 97.5
Epoch: 15 ------------- Loss= 0.10473587723448873 Acccur

In [95]:
model.train()
total_correct, total_queries= 0,0
with torch.no_grad():
    for episode in strict_test_dataloader:
        support_images, support_labels, query_images, query_labels=episode
        support_images=(support_images.squeeze(0)).to(device)
        query_images=(query_images.squeeze(0)).to(device)
        support_labels=(support_labels.view(-1)).to(device)
        query_labels=(query_labels.view(-1)).to(device)

        stacked_logits=[]

        for _ in range(n_times):

            support_embeddings=model(support_images)
            query_embeddings=model(query_images)

            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)

            stacked_logits.append(logits)
        
        mean_logits=torch.stack(stacked_logits).mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        torch.cuda.empty_cache()
    
    #avg_loss=total_loss/len(few_dataloader)
    accuracy=(total_correct/total_queries)*100
    print("Loss=",avg_loss,"Acccuracy=",accuracy)


    entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
    print("Mean Predictive Entropy =", entropy.item())


Loss= 0.0848819395992905 Acccuracy= 88.1
Mean Predictive Entropy = 0.2893742322921753
