In [None]:
import torch
# Report how many CUDA devices PyTorch can use from this kernel.
visible_gpu_count = torch.cuda.device_count()
print("PyTorch sees", visible_gpu_count, "GPUs")


In [None]:
import os
# Report the CPU count returned by os.cpu_count().
cpu_core_count = os.cpu_count()
print("CPU cores available:", cpu_core_count)


In [None]:
import torchvision
import torch
import os
import random
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [None]:
from torchvision.datasets.utils import download_and_extract_archive
from torchvision.datasets.folder import ImageFolder

In [None]:
# Preprocessing applied to every image: resize to 224x224, convert to a tensor,
# then normalise each channel with the mean/std values listed below.
data_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.485, 0.456, 0.406],
            [0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Root folder handed to ImageFolder; kept in one place so the check and the
# loader cannot drift apart.
data_root = "/home/23dcs505/data/2750"
if not os.path.isdir(data_root):
    # Fail fast with a clear message instead of printing a notice and then
    # crashing inside ImageFolder on the very next line.
    raise FileNotFoundError(f"No dataset found at {data_root}")
fulldata = ImageFolder(root=data_root, transform=data_transform)


In [None]:
from torch.utils.data import random_split

# 80/20 train/test split of the full dataset.
train_len = int(0.8 * len(fulldata))
test_len = len(fulldata) - train_len

# A seeded generator keeps the split identical across kernel restarts, so a
# checkpoint reloaded later is never evaluated on images it was trained on.
split_generator = torch.Generator().manual_seed(42)
train_data_set, test_data_set = random_split(
    fulldata, [train_len, test_len], generator=split_generator
)

In [None]:
# Label ids 0..9 that the class splits below are drawn from.
all_list = list(range(10))

In [None]:
# Number of classes sampled from all_list into train_list; the remaining
# classes end up in strict_test_list.
train_class_len=5

In [None]:
# Draw the training classes from a dedicated, seeded RNG so the seen/unseen
# class split is the same on every run (needed for a saved checkpoint to be
# evaluated against the classes it really has not seen).
class_split_rng = random.Random(42)
train_list = class_split_rng.sample(all_list, train_class_len)
# The general test set uses every class (seen and unseen).
test_list = list(all_list)
# The strict test set uses only classes absent from training, in ascending order.
strict_test_list = sorted(set(all_list) - set(train_list))


In [None]:
# Show the three class lists: training, full test, and strict (unseen-only) test.
for class_ids in (train_list, test_list, strict_test_list):
    print(class_ids)

In [None]:
# Episode configuration passed to create_dataset further down.
# Classes sampled per episode.
ways=5
# Support images per class in an episode.
shots=5
# Query images per class in an episode.
queries=5
# Number of ways used for the strict (unseen-classes-only) test episodes.
strict_ways=ways
# CUDA device index used when building `device`.
gpu_num=1

In [None]:
from torch.utils.data import Subset

In [None]:
# Inspection only: first index stored by the training split returned from random_split.
train_data_set.indices[0]

In [None]:
def class_sorting(dataset, class_list):
    """Restrict a Subset to the samples whose label is in ``class_list``.

    Args:
        dataset: a ``Subset``-like object exposing ``indices`` and a wrapped
            ``dataset`` that has a ``targets`` sequence.
        class_list: container of labels to keep.

    Returns:
        A new ``Subset`` over the same wrapped dataset holding only the
        original indices (in their existing order) whose target is in
        ``class_list``.
    """
    base = dataset.dataset
    labels = base.targets

    kept = []
    for original_idx in dataset.indices:
        if labels[original_idx] in class_list:
            kept.append(original_idx)
    return Subset(base, kept)
    

In [None]:
# Build the three class-filtered subsets: training classes from the train
# split, all classes from the test split, and unseen classes from the test split.
train_data, test_data, strict_test_data = (
    class_sorting(split, wanted_classes)
    for split, wanted_classes in (
        (train_data_set, train_list),
        (test_data_set, test_list),
        (test_data_set, strict_test_list),
    )
)

In [None]:
# Inspection only: display the first item of the class-filtered training subset.
train_data[0]

In [None]:

# Inspection only: the dataset object wrapped by the training Subset.
train_data.dataset

In [None]:
# Inspection only: first index kept by class_sorting for the training subset.
train_data.indices[0]

In [None]:
def mask_patches_tensor(image, patch_size=9, num_patches=124):
    """Return a copy of ``image`` with randomly placed square patches zeroed.

    Args:
        image: 3-D tensor; the last two dimensions are treated as height and width.
        patch_size: side length of every zeroed square.
        num_patches: number of squares drawn (squares may overlap).

    Returns:
        A new tensor of the same shape; ``image`` itself is left untouched.
    """
    _channels, height, width = image.shape
    # Exclusive upper bounds for the top-left corner so each patch fits entirely.
    top_limit = height - patch_size + 1
    left_limit = width - patch_size + 1

    result = image.clone()
    for _ in range(num_patches):
        row = torch.randint(0, top_limit, (1,)).item()
        col = torch.randint(0, left_limit, (1,)).item()
        result[:, row:row + patch_size, col:col + patch_size] = 0.0

    return result

In [None]:
from torch.utils.data import Dataset

class create_dataset(Dataset):
    """Episodic few-shot dataset built on top of a ``Subset``.

    Every item is one freshly sampled episode: ``way`` classes are drawn at
    random, and for each class ``shot`` support samples and ``query`` query
    samples are drawn without replacement. Support images are also returned in
    a block-masked version together with the masks that were applied.

    Args:
        data: ``Subset``-like object exposing ``indices`` and
            ``dataset.targets``; ``data[i]`` must return ``(image, label)``.
        way: number of classes per episode.
        shot: support samples per class.
        query: query samples per class.
        episode: value reported by ``len()`` (episodes per pass).
    """

    def __init__(self,data,way,shot,query,episode):
        super().__init__()
        self.data=data
        self.way=way
        self.shot=shot
        self.query=query
        self.episode=episode

        # label -> positions inside ``data`` (not indices of the wrapped dataset)
        self.class_to_indices=self._build_class_index()
        self.classes=list(self.class_to_indices.keys())

    @staticmethod
    def block_mask(img, patch_size=8, mask_ratio=0.1):
        """Zero out a random subset of non-overlapping grid patches.

        The image is cut into a grid of ``patch_size`` x ``patch_size`` cells
        and ``int(mask_ratio * number_of_cells)`` cells are masked.

        Args:
            img: 3-D tensor; the last two dimensions are height and width.
            patch_size: side length of a grid cell.
            mask_ratio: fraction of grid cells to mask.

        Returns:
            ``(masked_img, mask)`` where ``mask`` has shape ``(1, H, W)`` and
            holds 1 on masked pixels and 0 elsewhere, and
            ``masked_img = img * (1 - mask)``.
        """
        _, H, W = img.shape
        num_patches_h = H // patch_size
        num_patches_w = W // patch_size
        num_mask = int(mask_ratio * num_patches_h * num_patches_w)

        # Choose random grid cells to mask
        patch_indices = [(i, j) for i in range(num_patches_h) for j in range(num_patches_w)]
        masked_indices = random.sample(patch_indices, num_mask)

        # Allocate the mask with the image's dtype and device so the function
        # also works for GPU or non-float32 images without a device mismatch
        # or a silent dtype promotion.
        mask = img.new_zeros((1, H, W))

        for i, j in masked_indices:
            h_start = i * patch_size
            w_start = j * patch_size
            mask[:, h_start:h_start+patch_size, w_start:w_start+patch_size] = 1.0

        # The multiplication already yields a new tensor; ``img`` is not modified.
        masked_img = img * (1 - mask)
        return masked_img, mask

    def _build_class_index(self):
        """Map every label to the positions in ``self.data`` carrying that label."""
        targets=self.data.dataset.targets

        class_index={}
        for indexofsubset, indexoforiginal in enumerate(self.data.indices):
            class_index.setdefault(targets[indexoforiginal], []).append(indexofsubset)

        return class_index

    def __len__(self):
        """Number of episodes per pass over the dataset."""
        return self.episode

    def __getitem__(self, idx):
        """Sample one episode; ``idx`` is ignored because episodes are random.

        Returns:
            Tuple of stacked tensors
            ``(masked_support, masks, support_images, support_labels,
            query_images, query_labels)``. Labels are episode-local ids in
            ``[0, way)`` assigned in the order the classes were drawn.

        Raises:
            ValueError: if fewer than ``way`` classes are available, or a drawn
                class has fewer than ``shot + query`` samples.
        """
        if self.way > len(self.classes):
            raise ValueError(
                f"way={self.way} exceeds the {len(self.classes)} available classes"
            )
        selected_class=random.sample(self.classes,self.way)
        per_class=self.shot+self.query

        reconstruct_images, support_images, support_labels=[],[],[]
        query_images, query_labels=[],[]
        mask=[]

        # Episode-local label for every drawn class.
        label_map={class_name: i for i, class_name in enumerate(selected_class)}

        for class_name in selected_class:
            all_indices_for_class=self.class_to_indices[class_name]
            if len(all_indices_for_class) < per_class:
                raise ValueError(
                    f"class {class_name!r} has only {len(all_indices_for_class)} samples, "
                    f"but shot + query = {per_class} are required"
                )

            selected_index=random.sample(all_indices_for_class,per_class)
            episode_label=label_map[class_name]

            for i in selected_index[:self.shot]:
                image,_=self.data[i]
                support_images.append(image)

                masked_image,masks = self.block_mask(image)
                reconstruct_images.append(masked_image)
                mask.append(masks)
                support_labels.append(torch.tensor(episode_label))

            for i in selected_index[self.shot:]:
                image,_=self.data[i]
                query_images.append(image)
                query_labels.append(torch.tensor(episode_label))

        return(
            torch.stack(reconstruct_images),
            torch.stack(mask),
            torch.stack(support_images),
            torch.stack(support_labels),
            torch.stack(query_images),
            torch.stack(query_labels)
        )

In [None]:
def compute_prototypes(support_embeddings,support_labels,way):
    """Average the support embeddings of each class into one prototype.

    Args:
        support_embeddings: tensor whose last dimension is the embedding size;
            indexed along its first dimension by the label mask.
        support_labels: tensor of class ids compared against ``0 .. way-1``.
        way: number of classes, i.e. number of prototype rows.

    Returns:
        Tensor of shape ``(way, embedding_size)`` on the embeddings' device;
        row ``c`` is the mean of the embeddings labelled ``c``.
    """
    feature_dim = support_embeddings.shape[-1]
    protos = torch.zeros(way, feature_dim, device=support_embeddings.device)

    for class_id in range(way):
        members = support_embeddings[support_labels == class_id]
        protos[class_id] = members.mean(dim=0)
    return protos

def classify_queries(prototypes,query_embeddings):
    """Score every query against every prototype.

    Args:
        prototypes: tensor of shape ``(way, dim)``.
        query_embeddings: tensor of shape ``(n_query, dim)``.

    Returns:
        Logits of shape ``(n_query, way)``: the negated squared Euclidean
        distance between each query and each prototype.
    """
    # Broadcast to (n_query, way, dim) pairs, then reduce over the embedding axis.
    pairwise_diff = query_embeddings[:, None, :] - prototypes[None, :, :]
    squared_dist = pairwise_diff.pow(2).sum(dim=2)
    return -squared_dist


In [None]:
import torch.optim as optim

# Episodic training set: 200 episodes sampled from the training-class subset.
few_dataset = create_dataset(train_data, ways, shots, queries, episode=200)

In [None]:
# Loader over training episodes; each item of few_dataset is already a full episode.
few_dataloader = DataLoader(
    dataset=few_dataset,
    num_workers=8,
    pin_memory=True,
    shuffle=True,
    #batch_size=1,
)

In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

In [None]:


# class VGGEmbedding(nn.Module):
    
#     def __init__(self, embedding_dim=256):
#         super().__init__()
        
#         vgg = models.vgg16(pretrained=True)
        
#         self.features = vgg.features
#         self.avgpool = vgg.avgpool
        
#         in_features = vgg.classifier[0].in_features
#         self.classifier = nn.Sequential(
#             nn.Linear(in_features, 4096),
#             nn.ReLU(True),
#             nn.Linear(4096, 4096),
#             nn.ReLU(True),
#             nn.Linear(4096, embedding_dim)
#         )

#     def forward(self, x):
#         x = self.features(x)
#         x = self.avgpool(x)
#         x = torch.flatten(x, 1)
#         x = self.classifier(x)
#         return x



# model= VGGEmbedding(embedding_dim=256)


# for param in model.parameters():
#     param.requires_grad = False


# for param in model.features[24:].parameters():
#     param.requires_grad = True
    

# for param in model.classifier.parameters():
#     param.requires_grad = True


# device = torch.device(f"cuda:{gpu_num}" if torch.cuda.is_available() else "cpu")
# model=model.to(device)

# trainable_params = filter(lambda p: p.requires_grad, model.parameters())
# optimizer = optim.Adam(trainable_params, lr=1e-4)
# loss_fn = nn.CrossEntropyLoss()

# epochs=20

In [None]:
# Episodic test set: 200 episodes sampled from the test split across all classes in test_list.
test_dataset = create_dataset(test_data, ways, shots, queries, episode=200)

In [None]:
# Loader over test episodes; each item of test_dataset is already a full episode.
test_dataloader = DataLoader(
    dataset=test_dataset,
    num_workers=8,
    pin_memory=True,
    shuffle=True,
)

In [None]:
# Show which class ids the general test episodes are drawn from.
print("testing on class :",test_list)

In [None]:
# Episodic strict test set: 200 episodes sampled only from classes absent from training.
strict_test_dataset = create_dataset(strict_test_data, strict_ways, shots, queries, episode=200)

In [None]:
# Loader over strict-test episodes; each item of strict_test_dataset is already a full episode.
strict_test_dataloader = DataLoader(
    dataset=strict_test_dataset,
    num_workers=8,
    pin_memory=True,
    shuffle=True,
    #batch_size=1,
)

**Stable Prototypical Network**

In [None]:
import torch
import gc

# Run a garbage-collection pass so unreachable Python objects are released
# before the model is built.
gc.collect()

# Ask PyTorch to release the cached CUDA memory it is holding but not using.
torch.cuda.empty_cache()

In [None]:
from dropblock import DropBlock2D

In [None]:
from torch.cuda.amp import autocast, GradScaler

In [None]:
import torch
import torch.nn as nn
from torchvision import models

class DropBlock2D(nn.Module):
    """Simplified DropBlock regulariser for 4-D feature maps.

    In training mode a random seed mask is drawn per sample (one mask shared by
    all channels), every seed is grown into a ``block_size`` x ``block_size``
    square with a max-pool, those squares are zeroed, and the surviving
    activations are rescaled by ``total_positions / kept_positions``.
    In eval mode, or when ``drop_prob`` is 0, the input is returned unchanged.

    Args:
        drop_prob: target drop probability; divided by ``block_size ** 2`` to
            obtain the per-position seed probability.
        block_size: side length of each dropped square.
    """

    def __init__(self, drop_prob=0.1, block_size=3):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` (indexed as a 4-D tensor) when training."""
        if not self.training or self.drop_prob == 0.:
            return x

        gamma = self._compute_gamma(x)
        seeds = (torch.rand(x.shape[0], 1, x.shape[2], x.shape[3], device=x.device) < gamma).float()
        mask = self._compute_block_mask(seeds)
        countM = mask.numel()
        # Guard against a mask that dropped everything: dividing by zero would
        # turn the whole output into NaN instead of zeros.
        count_ones = mask.sum().clamp(min=1.0)
        return mask * x * (countM / count_ones)

    def _compute_block_mask(self, mask):
        """Grow each seed into a square and return the keep-mask (1 = keep)."""
        block_mask = nn.functional.max_pool2d(
            input=mask,
            kernel_size=(self.block_size, self.block_size),
            stride=(1, 1),
            padding=self.block_size // 2
        )
        if self.block_size % 2 == 0:
            # With an even kernel the padded max-pool output is one row and one
            # column larger than the input; crop it back so it matches ``x``.
            block_mask = block_mask[:, :, :-1, :-1]
        return 1 - block_mask

    def _compute_gamma(self, x):
        """Per-position seed probability; ``x`` is accepted but not used."""
        return self.drop_prob / (self.block_size ** 2)

# Encoder based on ResNet-50
class Encoder(nn.Module):
    """ResNet-50 feature extractor with DropBlock and a 1x1 channel bottleneck.

    The pretrained backbone is frozen except for ``layer4``. A DropBlock2D
    layer follows ``layer3`` and another follows ``layer4``. A trainable
    1x1 conv + BatchNorm + ReLU bottleneck then reduces 2048 channels to 512,
    the width the decoder and embedding head are built for.

    Args:
        drop_prob: drop probability passed to both DropBlock2D layers.
        block_size: block size passed to both DropBlock2D layers.
    """

    def __init__(self, drop_prob=0.3, block_size=3):
        super().__init__()
        backbone = models.resnet50(pretrained=True)

        # Freeze the whole backbone, then re-enable gradients for layer4 only.
        backbone.requires_grad_(False)
        backbone.layer4.requires_grad_(True)

        # Stem and residual stages in forward order, with DropBlock inserted
        # after layer3 and after layer4.
        stages = [
            backbone.conv1,
            backbone.bn1,
            backbone.relu,
            backbone.maxpool,
            backbone.layer1,
            backbone.layer2,
            backbone.layer3,
            DropBlock2D(drop_prob=drop_prob, block_size=block_size),
            backbone.layer4,
            DropBlock2D(drop_prob=drop_prob, block_size=block_size),
        ]
        self.feature_extractor = nn.Sequential(*stages)

        # 2048 -> 512 channel reduction so downstream modules see 512 channels.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )
        # Explicitly keep the new bottleneck trainable.
        self.bottleneck.requires_grad_(True)

    def forward(self, x):
        """Return the bottlenecked feature map with 512 channels."""
        return self.bottleneck(self.feature_extractor(x))


In [None]:
# class Encoder(nn.Module):
#     def __init__(self):
#         super().__init__()
#         self.vgg = models.vgg16(pretrained = True)

#         for param in self.vgg.features.parameters():
#             param.requires_grad = False

#         self.feature_extractor = nn.Sequential(*list(self.vgg.features.children())) # Use vgg.features

#     def forward(self,x):
#         x = self.feature_extractor(x) # Output shape: [B, 512, 7, 7]
#         return x

In [None]:
class Decoder(nn.Module):
    """Upsampling decoder: 512-channel feature map -> 3-channel image.

    Five stride-2 transposed convolutions each double the spatial size
    (kernel 3, padding 1, output_padding 1 gives ``out = 2 * in``), so a 7x7
    input becomes 224x224. A final 3x3 convolution maps to 3 channels and a
    Tanh bounds the output to [-1, 1].

    NOTE(review): elsewhere in this notebook the reconstruction targets are
    Normalize()-d images, whose values can fall outside [-1, 1] -- confirm the
    Tanh output range is intended.
    """

    def __init__(self):
        super().__init__()
        # Channel width before/after every upsampling stage.
        channel_plan = (512, 256, 128, 64, 32, 16)

        layers = []
        for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:]):
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            layers.append(nn.ReLU())
        # Keep output channels 3 (RGB)
        layers.append(nn.Conv2d(16, 3, kernel_size=3, padding=1))
        layers.append(nn.Tanh())

        self.decoder = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x``; a [B, 512, 7, 7] input yields [B, 3, 224, 224]."""
        return self.decoder(x)


In [None]:
class MaskedAutoencoder(nn.Module):
    """Encoder/decoder pair with an extra embedding head on the latent map.

    Args:
        encoder: module mapping an image batch to a 512-channel feature map.
        decoder: module mapping that feature map back to an image batch.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        # Global-average-pool the latent map, flatten, and project 512 -> 256.
        self.embedding_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 256),
        )

    def forward(self, masked_img, mask=None):
        """Encode once, then decode and embed the same latent.

        Args:
            masked_img: input image batch.
            mask: accepted for call-site symmetry; not used by the model.

        Returns:
            ``(reconstruction, embedding)``.
        """
        features = self.encoder(masked_img)
        return self.decoder(features), self.embedding_head(features)

In [None]:
import torch.nn.functional as F


In [None]:
# Use the configured GPU when it exists; fall back to GPU 0 when gpu_num is out
# of range (e.g. a single-GPU machine), and to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    _gpu_index = gpu_num if 0 <= gpu_num < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{_gpu_index}")
else:
    device = torch.device("cpu")

In [None]:
# Build the two halves, wrap them, then move the whole model to the target
# device (this moves the shared encoder/decoder modules as well).
encoder = Encoder()
decoder = Decoder()
model = MaskedAutoencoder(encoder, decoder)
model = model.to(device)

# Optimizer over all model parameters.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
import itertools


In [None]:
def compute_psnr(mse,max_val=1.0):
    """Peak signal-to-noise ratio in dB for a given mean squared error.

    Args:
        mse: tensor holding the mean squared error.
        max_val: peak signal value used as the numerator.

    Returns:
        ``20 * log10(max_val / sqrt(mse))`` as a tensor.
    """
    rms_error = torch.sqrt(mse)
    return 20 * torch.log10(max_val / rms_error)


In [None]:
def compute_rmse(img1, img2):
    """Root mean squared error over all elements of two tensors."""
    squared_error = (img1 - img2) ** 2
    return squared_error.mean().sqrt()


In [None]:
import torch.nn.functional as F

def identity_loss_fn(embedding_a, embedding_b):
    """Mean absolute (L1) difference between two embedding tensors."""
    # Alternative that was tried: F.mse_loss(embedding_a, embedding_b)
    return F.l1_loss(input=embedding_a, target=embedding_b)
    

In [None]:
def masked_loss(recon, target, mask):
    """Mean squared error restricted to the masked positions.

    Args:
        recon: reconstructed batch.
        target: reference batch, same shape as ``recon``.
        mask: weights (1 = count this position); when its dimension 1 has
            size 1 it is expanded to the shape of ``recon``.

    Returns:
        Sum of squared errors over weighted positions divided by the sum of
        the (expanded) mask, with a small epsilon against division by zero.
    """
    weights = mask.float()
    if weights.shape[1] == 1:
        # Share the single-channel mask across every channel of ``recon``.
        weights = weights.expand_as(recon)

    squared_error_sum = F.mse_loss(recon * weights, target * weights, reduction='sum')
    return squared_error_sum / (weights.sum() + 1e-8)

In [None]:
# Draw one episode directly from the dataset for visual inspection: masked
# support images, their masks and the unmasked support images (labels and
# query tensors are discarded).
images,mask,support_images,_,_,_=few_dataset[0]

In [None]:
import matplotlib.pyplot as plt

In [None]:
# Undo the channel normalisation applied by data_transform so imshow receives
# values in [0, 1] instead of clipping the normalised tensor.
# Masked pixels (0 in normalised space) therefore show up as the mean colour.
_norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
_preview = (images[2].cpu() * _norm_std + _norm_mean).clamp(0.0, 1.0)

plt.imshow(_preview.permute(1, 2, 0))
plt.show()

In [None]:
# Joint training of the masked autoencoder and the prototypical classifier,
# followed by an evaluation pass over test_dataloader after every epoch.
from tqdm.notebook import tqdm  # imported once instead of twice per epoch

loss_fn = nn.CrossEntropyLoss()
image_loss=nn.L1Loss()
triplet_loss_fn = nn.TripletMarginLoss(margin=1.5, p=2)  # only needed by the disabled triplet variant
model = model.to(device)
epochs=20

#From code for SPN
recon_weight=5   # weight of the reconstruction term in the combined loss
n_times=15       # stochastic forward passes per episode
alpha=0.01       # weight of the prediction-spread term
best_accuracy = 0.0

checkpoint_path = "/home/23dcs505/model_recon/5w5s_resnet.pth"
# Create the checkpoint folder up front so a missing directory cannot abort
# the run after a full epoch of training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

for epoch in range(epochs):
    # ------------------------------ training ------------------------------
    model.train()
    total_loss, total_correct, total_queries,total_final_psnr,total_final_recon_loss= 0,0,0,0,0

    progress_bar=tqdm(few_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)

    for episode in progress_bar:
        images,mask,support_images, support_labels, query_images, query_labels=episode
        # The loader adds a leading batch dimension of size 1; drop it.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        all_ce_losses = []
        all_query_logits = []
        all_psnr=[]
        all_reconstruct_loss=[]

        # Repeat the forward pass n_times; DropBlock makes every pass different.
        for _ in range(n_times):
            _,support_embeddings=model(support_images)
            _,query_embeddings=model(query_images)
            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)
            ce_loss=loss_fn(logits,query_labels)

            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)

            # Reconstruction branch: masked-region MSE plus full-image L1.
            reconstructed_image,_=model(images,mask)
            recon_loss= masked_loss(reconstructed_image, support_images,mask)
            recon_loss=recon_loss+image_loss(reconstructed_image,support_images)
            mse_loss=F.mse_loss(reconstructed_image,support_images)
            psnr=compute_psnr(mse_loss, max_val=1.0)

            all_reconstruct_loss.append(recon_loss)
            all_psnr.append(psnr)

        total_ce_loss= torch.stack(all_ce_losses).mean()
        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        # Probability assigned to the true class, per pass and per query.
        true_class_probs = stacked_probs[
            torch.arange(n_times, device=stacked_probs.device)[:, None],
            torch.arange(len(query_labels), device=stacked_probs.device),
            query_labels
        ]

        total_recon_loss=torch.stack(all_reconstruct_loss).mean()
        total_psnr=torch.stack(all_psnr).mean()

        # Penalise disagreement between the stochastic passes.
        variance_loss=torch.std(true_class_probs,dim=0).sum()
        total_combined_loss=(recon_weight * total_recon_loss)+(total_ce_loss)+(alpha*variance_loss)

        total_combined_loss.backward()
        optimizer.step()

        total_final_psnr+=total_psnr.item()
        total_final_recon_loss+=total_recon_loss.item()
        total_loss+=total_combined_loss.item()
        mean_logits=stacked_logits.mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        # ".4f" (four decimals) replaces the original "4f", which only set a field width.
        progress_bar.set_postfix(Phase="Training",Loss=f"{total_combined_loss.item():.4f}",Acc=f"{avg_acc_till}",PSNR=f"{total_psnr.item():.4f}&",Ce_Loss=f"{total_ce_loss.item():.4f}",Recon_Loss=f"{total_recon_loss.item():.4f}")

    num_train_episodes = len(few_dataloader)
    avg_recon_loss = total_final_recon_loss / num_train_episodes
    avg_psnr = total_final_psnr / num_train_episodes
    avg_loss=total_loss/num_train_episodes
    accuracy=(total_correct/total_queries)*100
    print("Training ","Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

    # ----------------------------- evaluation -----------------------------
    model.eval()
    total_loss, total_correct, total_queries,total_final_psnr,total_final_recon_loss= 0,0,0,0,0
    progress_bar=tqdm(test_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)

    # No parameter update happens below; just release the training gradients.
    optimizer.zero_grad(set_to_none=True)

    with torch.no_grad():
        for episode in progress_bar:
            images,mask,support_images, support_labels, query_images, query_labels=episode
            images = images.squeeze(0).to(device, non_blocking=True)
            mask = mask.squeeze(0).to(device, non_blocking=True)
            support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
            query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
            support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
            query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

            all_ce_losses = []
            all_query_logits = []
            all_psnr=[]
            all_reconstruct_loss=[]

            # Train mode keeps DropBlock stochastic for the n_times passes.
            # NOTE(review): it also puts the BatchNorm layers in training mode,
            # so their running statistics are updated from test episodes --
            # confirm this is intended.
            model.train()
            for _ in range(n_times):
                _,support_embeddings=model(support_images)
                _,query_embeddings=model(query_images)
                n_way=torch.unique(support_labels).size(0)
                prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
                logits=classify_queries(prototypes,query_embeddings)
                ce_loss=loss_fn(logits,query_labels)

                all_ce_losses.append(ce_loss)
                all_query_logits.append(logits)

                reconstructed_image,_=model(images,mask)
                recon_loss= masked_loss(reconstructed_image, support_images,mask)
                mse_loss=F.mse_loss(reconstructed_image,support_images)
                psnr=compute_psnr(mse_loss, max_val=1.0)

                all_reconstruct_loss.append(recon_loss)
                all_psnr.append(psnr)
            model.eval()

            total_ce_loss= torch.stack(all_ce_losses).mean()
            stacked_logits=torch.stack(all_query_logits)
            stacked_probs=torch.softmax(stacked_logits,dim=-1)

            true_class_probs = stacked_probs[
                torch.arange(n_times, device=stacked_probs.device)[:, None],
                torch.arange(len(query_labels), device=stacked_probs.device),
                query_labels
            ]

            total_recon_loss=torch.stack(all_reconstruct_loss).mean()
            total_psnr=torch.stack(all_psnr).mean()

            variance_loss=torch.std(true_class_probs,dim=0).sum()
            total_combined_loss=(recon_weight * total_recon_loss)+(total_ce_loss)+(alpha*variance_loss)

            total_final_psnr+=total_psnr.item()
            total_final_recon_loss+=total_recon_loss.item()
            total_loss+=total_combined_loss.item()
            mean_logits=stacked_logits.mean(dim=0)
            preds=torch.argmax(mean_logits,dim=1)
            total_correct+=(preds==query_labels).sum().item()
            total_queries+=query_labels.size(0)

            avg_acc_till=(total_correct/total_queries)*100
            # Show the episode's mean PSNR as a number rather than the raw
            # tensor of the last pass.
            progress_bar.set_postfix(Phase="Testing",Loss=f"{total_combined_loss.item():.4f}",Acc=f"{avg_acc_till}",PSNR=f"{total_psnr.item():.4f}&")

        # Average over the loader that was actually iterated (the test loader).
        num_test_episodes = len(test_dataloader)
        avg_recon_loss = total_final_recon_loss / num_test_episodes
        avg_psnr = total_final_psnr / num_test_episodes
        avg_loss=total_loss/num_test_episodes
        accuracy=(total_correct/total_queries)*100
        print("Testing ","Epoch:",epoch+1,"-------------","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

    # NOTE(review): the checkpoint is selected on test-episode accuracy, so the
    # reported test accuracy of the saved model is optimistically biased.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), checkpoint_path)
    
    # # Optional: Save the latest model too
    # torch.save(model.state_dict(), "/home/23dcs505/model_recon/latest_model_w_5.pth")


In [None]:
# Display the reconstruction-loss weight used in the combined loss.
recon_weight

In [None]:
#torch.save(model.state_dict(), 'model_ce_variance_spn_reconstruction.pth')

In [None]:
# Display the weight applied to variance_loss in the combined loss.
alpha

In [None]:
# Evaluate the current model on test_dataloader (episodes drawn from all
# classes in test_list).
from tqdm.notebook import tqdm

model.eval()
total_loss, total_correct, total_queries,total_final_psnr,total_final_recon_loss= 0,0,0,0,0
progress_bar=tqdm(test_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)

# No parameter update happens below; just release any leftover gradients.
optimizer.zero_grad(set_to_none=True)

with torch.no_grad():
    for episode in progress_bar:
        images,mask,support_images, support_labels, query_images, query_labels=episode
        # The loader adds a leading batch dimension of size 1; drop it.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        all_ce_losses = []
        all_query_logits = []
        all_psnr=[]
        all_reconstruct_loss=[]

        # Train mode keeps DropBlock stochastic for the n_times passes.
        # NOTE(review): it also puts the BatchNorm layers in training mode, so
        # their running statistics are updated from test episodes -- confirm
        # this is intended.
        model.train()
        for _ in range(n_times):
            _,support_embeddings=model(support_images)
            _,query_embeddings=model(query_images)
            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)
            ce_loss=loss_fn(logits,query_labels)

            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)

            reconstructed_image,_=model(images,mask)
            recon_loss= masked_loss(reconstructed_image, support_images,mask)
            mse_loss=F.mse_loss(reconstructed_image,support_images)
            psnr=compute_psnr(mse_loss, max_val=1.0)

            all_reconstruct_loss.append(recon_loss)
            all_psnr.append(psnr)
        model.eval()

        total_ce_loss= torch.stack(all_ce_losses).mean()
        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        # Probability assigned to the true class, per pass and per query.
        true_class_probs = stacked_probs[
            torch.arange(n_times, device=stacked_probs.device)[:, None],
            torch.arange(len(query_labels), device=stacked_probs.device),
            query_labels
        ]

        total_recon_loss=torch.stack(all_reconstruct_loss).mean()
        total_psnr=torch.stack(all_psnr).mean()

        variance_loss=torch.std(true_class_probs,dim=0).sum()
        total_combined_loss=(recon_weight * total_recon_loss)+(total_ce_loss)+(alpha*variance_loss)

        total_final_psnr+=total_psnr.item()
        total_final_recon_loss+=total_recon_loss.item()
        total_loss+=total_combined_loss.item()
        mean_logits=stacked_logits.mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        # Show the episode's mean PSNR as a number rather than the raw tensor
        # of the last pass; ".4f" replaces the original "4f" width-only spec.
        progress_bar.set_postfix(Loss=f"{total_combined_loss.item():.4f}",Acc=f"{avg_acc_till}",PSNR=f"{total_psnr.item():.4f}&")

    # Average over the loader that was actually iterated (the test loader).
    num_test_episodes = len(test_dataloader)
    avg_recon_loss = total_final_recon_loss / num_test_episodes
    avg_psnr = total_final_psnr / num_test_episodes
    avg_loss=total_loss/num_test_episodes
    accuracy=(total_correct/total_queries)*100
    print("Testing On 10 Classes(Seen + Unseen)","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

In [None]:
# Evaluate the current model on strict_test_dataloader (episodes drawn only
# from classes in strict_test_list).
from tqdm.notebook import tqdm

model.eval()
total_loss, total_correct, total_queries,total_final_psnr,total_final_recon_loss= 0,0,0,0,0
progress_bar=tqdm(strict_test_dataloader, desc=f"Epoch {epoch+1}/{epochs}",leave=False)

# Evaluation must never touch the weights: the stray optimizer.step() that
# used to run once per episode has been removed. Leftover gradients are
# released once here.
optimizer.zero_grad(set_to_none=True)

with torch.no_grad():
    for episode in progress_bar:
        images,mask,support_images, support_labels, query_images, query_labels=episode
        # The loader adds a leading batch dimension of size 1; drop it.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images=(support_images.squeeze(0)).to(device, non_blocking=True)
        query_images=(query_images.squeeze(0)).to(device, non_blocking=True)
        support_labels=(support_labels.view(-1)).to(device, non_blocking=True)
        query_labels=(query_labels.view(-1)).to(device, non_blocking=True)

        all_ce_losses = []
        all_query_logits = []
        all_psnr=[]
        all_reconstruct_loss=[]

        # Train mode keeps DropBlock stochastic for the n_times passes.
        # NOTE(review): it also puts the BatchNorm layers in training mode, so
        # their running statistics are updated from test episodes -- confirm
        # this is intended.
        model.train()
        for _ in range(n_times):
            _,support_embeddings=model(support_images)
            _,query_embeddings=model(query_images)
            n_way=torch.unique(support_labels).size(0)
            prototypes=compute_prototypes(support_embeddings,support_labels,n_way)
            logits=classify_queries(prototypes,query_embeddings)
            ce_loss=loss_fn(logits,query_labels)

            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)

            reconstructed_image,_=model(images,mask)
            recon_loss= masked_loss(reconstructed_image, support_images,mask)
            mse_loss=F.mse_loss(reconstructed_image,support_images)
            psnr=compute_psnr(mse_loss, max_val=1.0)

            all_reconstruct_loss.append(recon_loss)
            all_psnr.append(psnr)
        model.eval()

        total_ce_loss= torch.stack(all_ce_losses).mean()
        stacked_logits=torch.stack(all_query_logits)
        stacked_probs=torch.softmax(stacked_logits,dim=-1)

        # Probability assigned to the true class, per pass and per query.
        true_class_probs = stacked_probs[
            torch.arange(n_times, device=stacked_probs.device)[:, None],
            torch.arange(len(query_labels), device=stacked_probs.device),
            query_labels
        ]

        total_recon_loss=torch.stack(all_reconstruct_loss).mean()
        total_psnr=torch.stack(all_psnr).mean()

        variance_loss=torch.std(true_class_probs,dim=0).sum()
        total_combined_loss=(recon_weight * total_recon_loss)+(total_ce_loss)+(alpha*variance_loss)

        total_final_psnr+=total_psnr.item()
        total_final_recon_loss+=total_recon_loss.item()
        total_loss+=total_combined_loss.item()
        mean_logits=stacked_logits.mean(dim=0)
        preds=torch.argmax(mean_logits,dim=1)
        total_correct+=(preds==query_labels).sum().item()
        total_queries+=query_labels.size(0)

        avg_acc_till=(total_correct/total_queries)*100
        # Show the episode's mean PSNR as a number rather than the raw tensor
        # of the last pass; ".4f" replaces the original "4f" width-only spec.
        progress_bar.set_postfix(Loss=f"{total_combined_loss.item():.4f}",Acc=f"{avg_acc_till}",PSNR=f"{total_psnr.item():.4f}&")

    # Average over the loader that was actually iterated (the strict test loader).
    num_strict_episodes = len(strict_test_dataloader)
    avg_recon_loss = total_final_recon_loss / num_strict_episodes
    avg_psnr = total_final_psnr / num_strict_episodes
    avg_loss=total_loss/num_strict_episodes
    accuracy=(total_correct/total_queries)*100
    print("Testing On 5 Classes(Unseen)","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

In [None]:
# Restore the checkpoint written by the training loop. map_location remaps the
# stored tensors onto the current `device`, so loading also works when the GPU
# the checkpoint was saved from is not available in this session.
model.load_state_dict(
    torch.load("/home/23dcs505/model_recon/5w5s_resnet.pth", map_location=device)
)

In [None]:
# Manually set `epoch`; the evaluation cell below only uses it in the
# progress-bar description (f"Epoch {epoch+1}/{epochs}").
epoch=20

In [None]:
from tqdm.notebook import tqdm

# Evaluate on `test_dataloader`: every episode is forwarded `n_times` and the
# logits are averaged before taking the prediction. Nothing is optimised here,
# so no optimizer call is needed inside the loop.
model.eval()
total_loss, total_correct, total_queries, total_final_psnr, total_final_recon_loss = 0, 0, 0, 0, 0
progress_bar = tqdm(test_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

with torch.no_grad():
    for episode in progress_bar:
        images, mask, support_images, support_labels, query_images, query_labels = episode
        # The leading dimension is squeezed away; presumably the loader yields
        # one episode per batch -- TODO confirm.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images = support_images.squeeze(0).to(device, non_blocking=True)
        query_images = query_images.squeeze(0).to(device, non_blocking=True)
        support_labels = support_labels.view(-1).to(device, non_blocking=True)
        query_labels = query_labels.view(-1).to(device, non_blocking=True)

        all_ce_losses = []
        all_query_logits = []
        all_psnr = []
        all_reconstruct_loss = []

        # Train mode is switched on for the repeated passes (presumably to keep
        # stochastic layers such as dropout active -- TODO confirm).
        # NOTE(review): train mode also makes any BatchNorm layer use batch
        # statistics and update its running averages, even under no_grad. If
        # the model has such layers, this evaluation mutates the model; confirm
        # that this is intended.
        model.train()
        for _ in range(n_times):
            _, support_embeddings = model(support_images)
            _, query_embeddings = model(query_images)
            n_way = torch.unique(support_labels).size(0)
            prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
            logits = classify_queries(prototypes, query_embeddings)
            ce_loss = loss_fn(logits, query_labels)

            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)

            reconstructed_image, _ = model(images, mask)
            recon_loss = masked_loss(reconstructed_image, support_images, mask)
            mse_loss = F.mse_loss(reconstructed_image, support_images)
            # NOTE(review): an MSE value is passed here; verify that the
            # compute_psnr in scope expects MSE rather than RMSE.
            psnr = compute_psnr(mse_loss, max_val=1.0)

            all_reconstruct_loss.append(recon_loss)
            all_psnr.append(psnr)
        model.eval()

        total_ce_loss = torch.stack(all_ce_losses).mean()
        stacked_logits = torch.stack(all_query_logits)
        stacked_probs = torch.softmax(stacked_logits, dim=-1)

        # Probability assigned to the true class, per pass and per query.
        true_class_probs = stacked_probs[
            torch.arange(n_times)[:, None],
            torch.arange(len(query_labels)),
            query_labels
        ]

        total_recon_loss = torch.stack(all_reconstruct_loss).mean()
        total_psnr = torch.stack(all_psnr).mean()

        # Spread of the true-class probability across the repeated passes.
        variance_loss = torch.std(true_class_probs, dim=0).sum()
        total_combined_loss = (recon_weight * total_recon_loss) + (total_ce_loss) + (alpha * variance_loss)

        total_final_psnr += total_psnr.item()
        total_final_recon_loss += total_recon_loss.item()
        total_loss += total_combined_loss.item()
        mean_logits = stacked_logits.mean(dim=0)
        preds = torch.argmax(mean_logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

        avg_acc_till = (total_correct / total_queries) * 100
        progress_bar.set_postfix(
            Loss=f"{total_combined_loss.item():.4f}",
            Acc=f"{avg_acc_till}",
            PSNR=f"{total_psnr.item():.2f}",
        )

    # Average over the episodes that were actually evaluated, i.e. the loader
    # iterated above.
    num_episodes = len(test_dataloader)
    avg_recon_loss = total_final_recon_loss / num_episodes
    avg_psnr = total_final_psnr / num_episodes
    avg_loss = total_loss / num_episodes
    accuracy = (total_correct / total_queries) * 100
    print("Testing On 10 Classes(Seen + Unseen)","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

In [None]:
# entropy = -(torch.softmax(mean_logits, dim=1) * torch.log_softmax(mean_logits, dim=1)).sum(dim=1).mean()
# print("Mean Predictive Entropy =", entropy.item())

In [None]:
from tqdm.notebook import tqdm

# Evaluate on `strict_test_dataloader`: every episode is forwarded `n_times`
# and the logits are averaged before taking the prediction. Nothing is
# optimised here, so no optimizer call is needed inside the loop.
model.eval()
total_loss, total_correct, total_queries, total_final_psnr, total_final_recon_loss = 0, 0, 0, 0, 0
progress_bar = tqdm(strict_test_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

with torch.no_grad():
    for episode in progress_bar:
        images, mask, support_images, support_labels, query_images, query_labels = episode
        # The leading dimension is squeezed away; presumably the loader yields
        # one episode per batch -- TODO confirm.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images = support_images.squeeze(0).to(device, non_blocking=True)
        query_images = query_images.squeeze(0).to(device, non_blocking=True)
        support_labels = support_labels.view(-1).to(device, non_blocking=True)
        query_labels = query_labels.view(-1).to(device, non_blocking=True)

        all_ce_losses = []
        all_query_logits = []
        all_psnr = []
        all_reconstruct_loss = []

        # Train mode is switched on for the repeated passes (presumably to keep
        # stochastic layers such as dropout active -- TODO confirm).
        # NOTE(review): train mode also makes any BatchNorm layer use batch
        # statistics and update its running averages, even under no_grad. If
        # the model has such layers, this evaluation mutates the model; confirm
        # that this is intended.
        model.train()
        for _ in range(n_times):
            _, support_embeddings = model(support_images)
            _, query_embeddings = model(query_images)
            n_way = torch.unique(support_labels).size(0)
            prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
            logits = classify_queries(prototypes, query_embeddings)
            ce_loss = loss_fn(logits, query_labels)

            all_ce_losses.append(ce_loss)
            all_query_logits.append(logits)

            reconstructed_image, _ = model(images, mask)
            recon_loss = masked_loss(reconstructed_image, support_images, mask)
            mse_loss = F.mse_loss(reconstructed_image, support_images)
            # NOTE(review): an MSE value is passed here; verify that the
            # compute_psnr in scope expects MSE rather than RMSE.
            psnr = compute_psnr(mse_loss, max_val=1.0)

            all_reconstruct_loss.append(recon_loss)
            all_psnr.append(psnr)
        model.eval()

        total_ce_loss = torch.stack(all_ce_losses).mean()
        stacked_logits = torch.stack(all_query_logits)
        stacked_probs = torch.softmax(stacked_logits, dim=-1)

        # Probability assigned to the true class, per pass and per query.
        true_class_probs = stacked_probs[
            torch.arange(n_times)[:, None],
            torch.arange(len(query_labels)),
            query_labels
        ]

        total_recon_loss = torch.stack(all_reconstruct_loss).mean()
        total_psnr = torch.stack(all_psnr).mean()

        # Spread of the true-class probability across the repeated passes.
        variance_loss = torch.std(true_class_probs, dim=0).sum()
        total_combined_loss = (recon_weight * total_recon_loss) + (total_ce_loss) + (alpha * variance_loss)

        total_final_psnr += total_psnr.item()
        total_final_recon_loss += total_recon_loss.item()
        total_loss += total_combined_loss.item()
        mean_logits = stacked_logits.mean(dim=0)
        preds = torch.argmax(mean_logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

        avg_acc_till = (total_correct / total_queries) * 100
        progress_bar.set_postfix(
            Loss=f"{total_combined_loss.item():.4f}",
            Acc=f"{avg_acc_till}",
            PSNR=f"{total_psnr.item():.2f}",
        )

    # Average over the episodes that were actually evaluated, i.e. the loader
    # iterated above.
    num_episodes = len(strict_test_dataloader)
    avg_recon_loss = total_final_recon_loss / num_episodes
    avg_psnr = total_final_psnr / num_episodes
    avg_loss = total_loss / num_episodes
    accuracy = (total_correct / total_queries) * 100
    print("Testing On 5 Classes(Unseen)","Loss=",avg_loss,"Acccuracy=",accuracy,"Recon Loss:",avg_recon_loss, "PSNR:",avg_psnr)

In [None]:
import matplotlib.pyplot as plt

def show_images(masked, recon, original):
    """Plot the masked input, its reconstruction and the original side by side.

    Each argument is a channel-first image tensor; it is moved to the CPU,
    permuted to channel-last and converted to a NumPy array before being
    handed to ``imshow``. No un-normalisation or clipping is applied.
    """
    # Convert every tensor first so a bad input fails before a figure exists.
    arrays = [t.cpu().permute(1, 2, 0).numpy() for t in (masked, recon, original)]
    titles = ('Masked Input', 'Reconstruction', 'Original')

    fig, axs = plt.subplots(1, 3, figsize=(12, 4))
    for panel, picture, caption in zip(axs, arrays, titles):
        panel.imshow(picture)
        panel.set_title(caption)
        panel.axis('off')
    plt.show()


In [None]:
def compute_psnr(rmse, max_val=1.0):
    """Return the peak signal-to-noise ratio in dB as a tensor.

    Computes ``20 * log10(max_val / rmse)``.

    Args:
        rmse: Root-mean-square error, as a tensor or a plain Python number.
            This must be the *root* of the MSE; passing the MSE itself
            doubles the returned dB value.
        max_val: Largest value a pixel can take in the compared images.

    Returns:
        A tensor holding the PSNR (``inf`` when ``rmse`` is zero).
    """
    # as_tensor leaves tensors untouched and lifts Python scalars, which
    # torch.log10 would otherwise reject.
    rmse = torch.as_tensor(rmse)
    return 20 * torch.log10(max_val / rmse)

In [None]:
# NOTE(review): `r` is drawn but never used; the episode index below is fixed at 10.
r=random.randint(1,5000)
model.eval()
with torch.no_grad():
    # One item of few_dataset unpacks into six tensors; the first three are
    # used as masked inputs, masks and target images.
    # presumably each holds several images, (num_images, C, H, W) -- TODO confirm
    masked_img, mask, img, _, _, _ = few_dataset[10]
    masked_img = masked_img.to(device)
    mask = mask.to(device)
    img = img.to(device)

    recon, _ = model(masked_img, mask)
    # Visualise the first image only.
    show_images(masked_img[0], recon[0], img[0])

    reconstruction_loss_fn = nn.L1Loss()
    recon_loss = reconstruction_loss_fn(recon[0], img[0])

    # compute_psnr expects a root-mean-square error, so derive it from the MSE
    # instead of feeding it the L1 loss.
    # NOTE(review): the images appear to be mean/std-normalised, so
    # max_val=1.0 may not match their real value range -- confirm.
    rmse = torch.sqrt(nn.MSELoss()(recon[0], img[0]))
    psnr = compute_psnr(rmse, max_val=1.0)
    print(psnr.item())


In [None]:
import torch
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF

# Per-channel mean and std used to undo the input normalization; the values
# match the transforms.Normalize(...) call in data_transform near the top of
# the notebook.
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])

def unnormalize(tensor, mean, std):
    """Reverse a per-channel normalization: ``pixel * std + mean``.

    Args:
        tensor: Floating-point image tensor laid out as ``(C, H, W)`` or
            ``(N, C, H, W)``.
        mean: Per-channel means (tensor or sequence of ``C`` numbers).
        std: Per-channel standard deviations (tensor or sequence of ``C``
            numbers).

    Returns:
        A new tensor of the same shape; the input is left unmodified.
    """
    # Reshape the statistics to (C, 1, 1) so they broadcast over the spatial
    # dimensions (and over an optional leading batch dimension), and align
    # them with the input's dtype and device.
    mean = torch.as_tensor(mean, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=tensor.dtype, device=tensor.device).view(-1, 1, 1)
    return tensor * std + mean

# --- In your visualization code ---
def show_images(masked_input, reconstruction, original):
    """Show masked input, reconstruction and original in one row of three panels.

    Every tensor is moved to the CPU and un-normalized with the module-level
    ``mean`` / ``std`` before plotting; values are clamped to [0, 1] for
    display.
    """
    captions = ("Masked Input", "Reconstruction", "Original")

    # Bring everything to the CPU first, then undo the normalization.
    on_cpu = [t.cpu() for t in (masked_input, reconstruction, original)]
    restored = [unnormalize(t, mean, std) for t in on_cpu]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for panel, caption, picture in zip(axes, captions, restored):
        # Channel-last layout for imshow; clamp guards against values that
        # fall slightly outside the displayable range.
        panel.imshow(picture.permute(1, 2, 0).clamp(0, 1))
        panel.set_title(caption)
        panel.axis('off')

    plt.show()

# Now call this function with your tensors
# show_images(masked_img[0], recon[0], img[0])

In [None]:
# Show a random image from the episode reconstructed in the earlier
# visualisation cell (relies on masked_img / recon / img defined there).
# The index is drawn from the actual number of images, so index 0 is
# reachable and no episode size is hard-coded.
n = random.randrange(recon.size(0))
show_images(masked_img[n], recon[n], img[n])

reconstruction_loss_fn = nn.L1Loss()
recon_loss = reconstruction_loss_fn(recon[n], img[n])

# compute_psnr expects a root-mean-square error, so derive it from the MSE
# instead of feeding it the L1 loss.
rmse = torch.sqrt(nn.MSELoss()(recon[n], img[n]))
psnr = compute_psnr(rmse, max_val=1.0)
print(psnr.item())

In [None]:
from tqdm.notebook import tqdm

# Single deterministic forward pass per episode over `test_dataloader`.
print("Testing........... Started.......................")
model.eval()
total_loss, total_correct, total_queries = 0, 0, 0
total_recon_loss = 0
total_psnr = 0
total_entropy = 0

progress_bar = tqdm(test_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

with torch.no_grad():
    for episode in progress_bar:
        images, mask, support_images, support_labels, query_images, query_labels = episode
        # The leading dimension is squeezed away; presumably the loader yields
        # one episode per batch -- TODO confirm.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images = support_images.squeeze(0).to(device, non_blocking=True)
        query_images = query_images.squeeze(0).to(device, non_blocking=True)
        support_labels = support_labels.view(-1).to(device, non_blocking=True)
        query_labels = query_labels.view(-1).to(device, non_blocking=True)

        _, support_embeddings = model(support_images)
        _, query_embeddings = model(query_images)
        n_way = torch.unique(support_labels).size(0)
        prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
        logits = classify_queries(prototypes, query_embeddings)
        ce_loss = loss_fn(logits, query_labels)

        reconstructed_image, _ = model(images, mask)
        recon_loss = masked_loss(reconstructed_image, support_images, mask)
        mse_loss = F.mse_loss(reconstructed_image, support_images)
        # compute_psnr takes a root-mean-square error, hence the sqrt.
        # NOTE(review): the images appear to be mean/std-normalised, so
        # max_val=1.0 may not match their real value range -- confirm.
        psnr = compute_psnr(torch.sqrt(mse_loss), max_val=1.0).item()
        total_recon_loss += recon_loss.item()
        total_psnr += psnr

        # There is only one pass per episode, so there is no across-pass
        # variance term; the loss is reconstruction + cross-entropy.
        total_combined_loss = (recon_weight * recon_loss) + (ce_loss)
        total_loss += total_combined_loss.item()

        preds = torch.argmax(logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

        # Predictive entropy of this episode, accumulated so the reported
        # figure is a mean over all episodes rather than the last one only.
        episode_entropy = -(torch.softmax(logits, dim=1) * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
        total_entropy += episode_entropy.item()

        avg_acc_till = (total_correct / total_queries) * 100
        progress_bar.set_postfix(
            Loss=f"{total_combined_loss.item():.4f}",
            Acc=f"{avg_acc_till}",
            PSNR=f"{psnr:.2f}",
        )

    # Average over the episodes that were actually evaluated, i.e. the loader
    # iterated above.
    num_episodes = len(test_dataloader)
    avg_recon_loss = total_recon_loss / num_episodes
    avg_psnr = total_psnr / num_episodes
    avg_loss = total_loss / num_episodes
    accuracy = (total_correct / total_queries) * 100
    accuracy_spn_0 = (total_correct / total_queries) * 100
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on all", len(test_list),"Class =",accuracy_spn_0,"%","PSNR=",avg_psnr)

    entropy = total_entropy / num_episodes
    print("Mean Predictive Entropy =", entropy)


In [None]:
from tqdm.notebook import tqdm

# Single deterministic forward pass per episode over `strict_test_dataloader`.
print("Testing........... Started.......................")
model.eval()
total_loss, total_correct, total_queries = 0, 0, 0
total_recon_loss = 0
total_psnr = 0
total_entropy = 0

progress_bar = tqdm(strict_test_dataloader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

with torch.no_grad():
    for episode in progress_bar:
        images, mask, support_images, support_labels, query_images, query_labels = episode
        # The leading dimension is squeezed away; presumably the loader yields
        # one episode per batch -- TODO confirm.
        images = images.squeeze(0).to(device, non_blocking=True)
        mask = mask.squeeze(0).to(device, non_blocking=True)
        support_images = support_images.squeeze(0).to(device, non_blocking=True)
        query_images = query_images.squeeze(0).to(device, non_blocking=True)
        support_labels = support_labels.view(-1).to(device, non_blocking=True)
        query_labels = query_labels.view(-1).to(device, non_blocking=True)

        _, support_embeddings = model(support_images)
        _, query_embeddings = model(query_images)
        n_way = torch.unique(support_labels).size(0)
        prototypes = compute_prototypes(support_embeddings, support_labels, n_way)
        logits = classify_queries(prototypes, query_embeddings)
        ce_loss = loss_fn(logits, query_labels)

        reconstructed_image, _ = model(images, mask)
        recon_loss = masked_loss(reconstructed_image, support_images, mask)
        mse_loss = F.mse_loss(reconstructed_image, support_images)
        # compute_psnr takes a root-mean-square error, hence the sqrt.
        # NOTE(review): the images appear to be mean/std-normalised, so
        # max_val=1.0 may not match their real value range -- confirm.
        psnr = compute_psnr(torch.sqrt(mse_loss), max_val=1.0).item()
        total_recon_loss += recon_loss.item()
        total_psnr += psnr

        # Loss for a single pass: reconstruction + cross-entropy.
        total_combined_loss = (recon_weight * recon_loss) + (ce_loss)
        total_loss += total_combined_loss.item()

        preds = torch.argmax(logits, dim=1)
        total_correct += (preds == query_labels).sum().item()
        total_queries += query_labels.size(0)

        # Predictive entropy of this episode, accumulated so the reported
        # figure is a mean over all episodes rather than the last one only.
        episode_entropy = -(torch.softmax(logits, dim=1) * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
        total_entropy += episode_entropy.item()

        avg_acc_till = (total_correct / total_queries) * 100
        progress_bar.set_postfix(
            Loss=f"{total_combined_loss.item():.4f}",
            Acc=f"{avg_acc_till}",
            PSNR=f"{psnr:.2f}",
        )

    # Average over the episodes that were actually evaluated, i.e. the loader
    # iterated above.
    num_episodes = len(strict_test_dataloader)
    avg_recon_loss = total_recon_loss / num_episodes
    avg_psnr = total_psnr / num_episodes
    avg_loss = total_loss / num_episodes
    accuracy = (total_correct / total_queries) * 100
    accuracy_spn_0 = (total_correct / total_queries) * 100
    # Report the class count of the strict (held-out) split, not that of the
    # full test split.
    print(ways,"way",shots,"shot:","Loss=",avg_loss,"Acccuracy of Stable Prototypical Network on all", len(strict_test_list),"Class =",accuracy_spn_0,"%","PSNR=",avg_psnr)

    entropy = total_entropy / num_episodes
    print("Mean Predictive Entropy =", entropy)
