In [None]:
# ---- Data locations (Kaggle Omniglot layout) ----
train_data_path = '/kaggle/input/omniglot/images_background/'
test_data_path = '/kaggle/input/omniglot/images_evaluation/'

# ---- Task definition ----
n_way = 5        # classes per task (a 20-way setting is the alternative)
k_shot = 1       # support images per class
q_query = 5      # query images per class

# ---- Optimisation ----
outer_lr = 0.001         # learning rate of the meta (outer) optimizer
inner_lr = 0.04          # step size of the inner-loop adaptation
meta_batch_size = 32     # tasks per meta update
train_inner_step = 1     # inner-loop steps while training
eval_inner_step = 3      # inner-loop steps while evaluating
num_iterations = 1000    # meta iterations (500 for a shorter run)

# ---- Data loading / bookkeeping ----
num_workers = 0
valid_size = 0.2         # fraction of the training classes held out for validation
random_seed = 42
display_gap = 50

In [None]:
import os
import glob
import torch
import random
import collections
import numpy as np
import torch.optim as optim
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm
from torch import nn
import torchvision
from torch.utils.data import DataLoader, Subset
from torch.utils.data import random_split
from torch.utils.data.dataset import Dataset
from torchvision.transforms import transforms
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
%matplotlib inline

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Prefer repeatable cuDNN kernels over autotuned ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Seed every random number generator the notebook draws from.
# NOTE(review): PYTHONHASHSEED is presumably only honoured at interpreter
# start-up, so this assignment likely affects child processes only - confirm.
os.environ['PYTHONHASHSEED'] = str(random_seed)
torch.manual_seed(random_seed)
np.random.seed(random_seed)
random.seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)

In [None]:
class MAMLDataset(Dataset):
    """Abstract dataset whose items are randomly sampled few-shot tasks.

    Subclasses provide ``get_file_list`` (enumerate the available classes)
    and ``get_one_task_data`` (sample one task).

    Args:
        data_path: location handed to ``get_file_list``.
        n_way: number of classes in each task.
        k_shot: number of support samples per class.
        q_query: number of query samples per class.
    """

    def __init__(self, data_path, n_way=5, k_shot=1, q_query=1):
        # The class list is built first, exactly once, at construction time.
        self.file_list = self.get_file_list(data_path)
        self.n_way, self.k_shot, self.q_query = n_way, k_shot, q_query

    def get_file_list(self, data_path):
        """Return the list of classes found under ``data_path`` (abstract)."""
        raise NotImplementedError('get_file_list function not implemented!')

    def get_one_task_data(self):
        """Sample and return one task (abstract)."""
        raise NotImplementedError('get_one_task_data function not implemented!')

    def __len__(self):
        # Reported length is the number of entries in the class list. Items
        # are drawn at random on every access, so this only sets how many
        # tasks a loader draws per pass, not which tasks they are.
        return len(self.file_list)

    def __getitem__(self, index):
        # The index is deliberately ignored: every access samples a new task.
        return self.get_one_task_data()

class OmniglotDataset(MAMLDataset):
    """Samples N-way few-shot tasks from Omniglot ``character*`` directories."""

    def get_file_list(self, data_path):
        """
        Get a list of all classes.
        Args:
            data_path: Omniglot Data path

        Returns: sorted list of every ``character*`` path found below
            ``data_path`` (one entry per class)

        """
        # os.path.join keeps "**" as its own path component whether or not
        # data_path ends with a separator; glued onto a name it would only act
        # like "*". Sorting removes the filesystem-dependent glob order so the
        # seeded sampling/shuffling downstream is repeatable.
        pattern = os.path.join(data_path, "**", "character*")
        return sorted(glob.glob(pattern, recursive=True))

    @staticmethod
    def _load_image(img_path):
        """Load one image, resize it to 28, scale by 1/255 and add a leading axis."""
        with Image.open(img_path) as raw:
            image = transforms.Resize(size=28)(raw)
        image = np.array(image)
        return np.expand_dims(image / 255., axis=0)

    def get_one_task_data(self):
        """
        Get one task of maml data: a shuffled support set and a shuffled query set.

        ``n_way`` class directories are drawn at random and labelled
        ``0 .. n_way-1`` in draw order, so labels are local to the task.

        Returns: support images, support labels, query images, query labels
            (four numpy arrays)

        Raises:
            ValueError: from ``random.sample`` when there are fewer than
                ``n_way`` classes, or a class has fewer than
                ``k_shot + q_query`` images.
        """
        class_dirs = random.sample(self.file_list, self.n_way)
        support_data = []
        query_data = []

        for label, class_dir in enumerate(class_dirs):
            # Only the PNGs directly inside this class directory. The previous
            # pattern (class_dir + "**/*.png") also matched sibling directories
            # whose name merely starts with class_dir.
            img_list = sorted(glob.glob(os.path.join(glob.escape(class_dir), "*.png")))
            chosen = random.sample(img_list, self.k_shot + self.q_query)

            # The first k_shot samples form the support set, the rest the query set.
            support_data.extend((self._load_image(path), label)
                                for path in chosen[:self.k_shot])
            query_data.extend((self._load_image(path), label)
                              for path in chosen[self.k_shot:])

        # Shuffle so samples are not grouped by class.
        random.shuffle(support_data)
        random.shuffle(query_data)

        support_image = [image for image, _ in support_data]
        support_label = [label for _, label in support_data]
        query_image = [image for image, _ in query_data]
        query_label = [label for _, label in query_data]

        return np.array(support_image), np.array(support_label), np.array(query_image), np.array(query_label)


class OmniglotDatasetTrain(MAMLDataset):
    """Omniglot task sampler that can overwrite query labels with pseudo-labels.

    Args:
        data_path: Omniglot data path.
        n_way: number of classes per task.
        k_shot: support images per class.
        q_query: query images per class.
        classifier: model handed to ``generate_pseudo_labels`` (which must be
            defined elsewhere in the notebook). When ``None`` the ground-truth
            query labels are returned unchanged.
    """

    def __init__(self, data_path, n_way=5, k_shot=1, q_query=1, classifier=None):
        super().__init__(data_path, n_way, k_shot, q_query)
        self.classifier = classifier

    def get_file_list(self, data_path):
        """
        Get a list of all classes.
        Args:
            data_path: Omniglot Data path

        Returns: sorted list of every ``character*`` path found below
            ``data_path`` (one entry per class)

        """
        # os.path.join keeps "**" as its own path component whether or not
        # data_path ends with a separator. Sorting removes the
        # filesystem-dependent glob order so seeded runs are repeatable.
        pattern = os.path.join(data_path, "**", "character*")
        return sorted(glob.glob(pattern, recursive=True))

    @staticmethod
    def _load_image(img_path):
        """Load one image, resize it to 28, scale by 1/255 and add a leading axis."""
        with Image.open(img_path) as raw:
            image = transforms.Resize(size=28)(raw)
        image = np.array(image)
        return np.expand_dims(image / 255., axis=0)

    def get_one_task_data(self):
        """
        Generate one MAML task: a shuffled support set and a shuffled query set.

        Class directories are labelled ``0 .. n_way-1`` in draw order. When a
        classifier was supplied, query labels flagged as high-confidence by
        ``generate_pseudo_labels`` are replaced by its pseudo-labels.

        Returns: support images (float32), support labels (int64),
            query images (float32), query labels (int64) as numpy arrays
        """
        class_dirs = random.sample(self.file_list, self.n_way)
        support_data = []
        query_data = []

        for label, class_dir in enumerate(class_dirs):
            # Only the PNGs directly inside this class directory. The previous
            # pattern (class_dir + "**/*.png") also matched sibling directories
            # whose name merely starts with class_dir.
            img_list = sorted(glob.glob(os.path.join(glob.escape(class_dir), "*.png")))
            chosen = random.sample(img_list, self.k_shot + self.q_query)

            support_data.extend((self._load_image(path), label)
                                for path in chosen[:self.k_shot])
            query_data.extend((self._load_image(path), label)
                              for path in chosen[self.k_shot:])

        # Shuffle so samples are not grouped by class.
        random.shuffle(support_data)
        random.shuffle(query_data)

        # Same dtypes the earlier tensor round trip produced, without
        # building tensors from Python lists of arrays.
        support_images = np.array([image for image, _ in support_data], dtype=np.float32)
        support_labels = np.array([label for _, label in support_data], dtype=np.int64)
        query_images = np.array([image for image, _ in query_data], dtype=np.float32)
        query_labels = np.array([label for _, label in query_data], dtype=np.int64)

        if self.classifier is not None:
            # Replicate the single channel three times for the classifier.
            query_images_rgb = torch.from_numpy(query_images).repeat(1, 3, 1, 1)

            # Pseudo-labelling is inference only; keep autograd out of data loading.
            with torch.no_grad():
                pseudo_labels, high_confidence = generate_pseudo_labels(
                    self.classifier, query_images_rgb)
            pseudo_labels = np.array(pseudo_labels)

            # NOTE(review): query labels are task-local indices (0..n_way-1);
            # this assumes the pseudo-labels use the same label space - confirm.
            query_labels[high_confidence] = pseudo_labels[high_confidence]

        return support_images, support_labels, query_images, query_labels

    # def get_one_task_data(self):
    #     """
    #     Get ones task maml data, include one batch support images and labels, one batch query images and labels.
    #     Returns: support_data, query_data

    #     """
    #     img_dirs = random.sample(self.file_list, self.n_way)
    #     support_data = []
    #     query_data = []

    #     support_image = []
    #     support_label = []
    #     query_image = []
    #     query_label = []

    #     for label, img_dir in enumerate(img_dirs):
    #         img_list = [f for f in glob.glob(img_dir + "**/*.png", recursive=True)]
    #         images = random.sample(img_list, self.k_shot + self.q_query)

    #         # Read support set
    #         for img_path in images[:self.k_shot]:
    #             image = transforms.Resize(size=28)(Image.open(img_path))
    #             image = np.array(image)
    #             image = np.expand_dims(image / 255., axis=0)
    #             support_data.append((image, label))

    #         # Read query set
    #         for img_path in images[self.k_shot:]:
    #             image = transforms.Resize(size=28)(Image.open(img_path))
    #             image = np.array(image)
    #             image = np.expand_dims(image / 255., axis=0)
    #             query_data.append((image, label))

    #     # shuffle support set
    #     random.shuffle(support_data)
    #     for data in support_data:
    #         support_image.append(data[0])
    #         support_label.append(data[1])

    #     # shuffle query set
    #     random.shuffle(query_data)
    #     for data in query_data:
    #         query_image.append(data[0])
    #         query_label.append(data[1])

    #     query_label = np.array(query_label)

    #     # Now apply pseudo-labeling for unlabeled data (query set)
    #     query_tensor = torch.tensor(query_image).float()
        
    #     # Generate pseudo-labels for unlabeled data in the query set
    #     pseudo_labels, high_confidence = generate_pseudo_labels(self.classifier, query_tensor)

    #     # Update labels for high-confidence pseudo-labeled samples
    #     query_label[high_confidence] = pseudo_labels[high_confidence]

    #     return np.array(support_image), np.array(support_label), np.array(query_image), query_label

In [None]:
def get_dataset(
        train_data_path,
        test_data_path,
        n_way,
        k_shot,
        q_query,
        model,
        valid_fraction=None
):
    """
    Build the train, validation and test task datasets.

    Train and validation datasets are both created from ``train_data_path``
    and then given disjoint class lists by ``spilt_train_valid``.

    Args:
        train_data_path: directory with the meta-training classes
        test_data_path: directory with the meta-test classes
        n_way: number of classes per task
        k_shot: support images per class
        q_query: query images per class
        model: classifier passed to the train/validation datasets
        valid_fraction: proportion of training classes moved to the validation
            set; ``None`` (default) uses the notebook-level ``valid_size``

    Returns: train dataset, validation dataset, test dataset
    """
    if valid_fraction is None:
        # Keep the previous behaviour: read the notebook configuration value.
        valid_fraction = valid_size

    train_dataset = OmniglotDatasetTrain(train_data_path,
                                         n_way,
                                         k_shot,
                                         q_query,
                                         classifier=model)

    valid_dataset = OmniglotDatasetTrain(train_data_path,
                                         n_way,
                                         k_shot,
                                         q_query,
                                         classifier=model)

    # The test dataset always keeps its ground-truth labels.
    test_dataset = OmniglotDataset(test_data_path,
                                   n_way,
                                   k_shot,
                                   q_query)

    train_dataset, valid_dataset = spilt_train_valid(train_dataset,
                                                     valid_dataset,
                                                     valid_fraction)

    return train_dataset, valid_dataset, test_dataset


def spilt_train_valid(train_dataset, valid_dataset, valid_set_size):
    """
    Split the classes of the train dataset between train and valid datasets.

    The class list of ``train_dataset`` is shuffled in place; the leading part
    stays with ``train_dataset`` and the trailing part is assigned to
    ``valid_dataset``.

    Args:
        train_dataset: original train dataset
        valid_dataset: dataset that receives the held-out classes
        valid_set_size: held-out share, given as a proportion

    Returns: the same train and valid dataset objects, now with disjoint class lists
    """
    total = len(train_dataset)
    n_valid = int(valid_set_size * total)
    n_train = total - n_valid

    classes = train_dataset.file_list
    random.shuffle(classes)

    train_dataset.file_list, valid_dataset.file_list = classes[:n_train], classes[n_train:]
    return train_dataset, valid_dataset

In [None]:
class ConvBlock(nn.Module):
    """3x3 convolution (padding 1) -> batch norm -> ReLU -> 2x2 max pooling."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names double as state_dict keys; keep them stable.
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply the four stages in order; pooling halves both spatial sizes."""
        return self.max_pool(self.relu(self.bn(self.conv2d(x))))


def ConvBlockFunction(input, w, b, w_bn, b_bn):
    """Functional conv block driven by explicitly supplied parameters.

    Args:
        input: input feature map
        w, b: convolution weight and bias
        w_bn, b_bn: batch-norm scale and shift (may be ``None``)

    Returns: conv (padding 1) -> batch norm -> ReLU -> 2x2 max pooling of ``input``
    """
    conv_out = F.conv2d(input, w, b, padding=1)
    # No running statistics are passed and training=True, so normalisation
    # always uses the statistics of the current batch.
    normed = F.batch_norm(conv_out, None, None, weight=w_bn, bias=b_bn, training=True)
    return F.max_pool2d(F.relu(normed), kernel_size=2, stride=2)


class Classifier(nn.Module):
    """Four stacked conv blocks (64 channels each) followed by a linear layer."""

    def __init__(self, in_ch, n_way):
        super().__init__()
        self.conv1 = ConvBlock(in_ch, 64)
        self.conv2 = ConvBlock(64, 64)
        self.conv3 = ConvBlock(64, 64)
        self.conv4 = ConvBlock(64, 64)
        # The linear layer takes 64 features, so the flattened output of the
        # last block must have exactly 64 values per sample.
        self.logits = nn.Linear(64, n_way)

    def forward(self, x):
        """Standard forward pass using the module's own parameters."""
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            x = block(x)
        flat = x.view(x.shape[0], -1)
        return self.logits(flat)

    def functional_forward(self, x, params):
        """Forward pass that reads every weight from ``params``.

        Args:
            x: input batch
            params: mapping from parameter name (as produced by
                ``named_parameters``) to tensor; batch-norm entries are optional

        Returns: logits computed with the supplied weights
        """
        for idx in range(1, 5):
            x = ConvBlockFunction(
                x,
                params[f"conv{idx}.conv2d.weight"],
                params[f"conv{idx}.conv2d.bias"],
                params.get(f"conv{idx}.bn.weight"),
                params.get(f"conv{idx}.bn.bias"),
            )
        flat = x.view(x.shape[0], -1)
        return F.linear(flat, params["logits.weight"], params["logits.bias"])

In [None]:
def maml_train(model,
               support_images,
               support_labels,
               query_images,
               query_labels,
               inner_step,
               inner_lr,
               optimizer,
               loss_fn,
               is_train=True):
    """
    Run one MAML meta step (or a meta evaluation) over a batch of tasks.

    For every task the model parameters are adapted on the support set for
    ``inner_step`` gradient steps, and the adapted weights are scored on the
    query set. When ``is_train`` is true the mean query loss is backpropagated
    through the adaptation and ``optimizer`` takes one step.

    Args:
        model: module exposing ``named_parameters`` and
            ``functional_forward(x, params)``
        support_images: support images, one entry per task
        support_labels: support labels, one entry per task
        query_images: query images, one entry per task
        query_labels: query labels, one entry per task
        inner_step: number of adaptation steps on the support set
        inner_lr: step size of the adaptation
        optimizer: optimizer over the model parameters (used only if ``is_train``)
        loss_fn: loss applied to ``(logits, labels)``
        is_train: whether to update the model parameters

    Returns: meta loss (scalar tensor), meta accuracy (mean over tasks)
    """
    meta_loss = []
    meta_acc = []

    # Adapt and evaluate one task at a time
    for support_image, support_label, query_image, query_label \
        in zip(support_images, support_labels, query_images, query_labels):

        fast_weights = collections.OrderedDict(model.named_parameters())
        for _ in range(inner_step):
            support_logit = model.functional_forward(support_image, fast_weights)
            support_loss = loss_fn(support_logit, support_label)
            # The second-order graph is only needed when the meta loss will be
            # backpropagated; skipping it during evaluation saves memory
            # without changing the computed losses.
            grads = torch.autograd.grad(support_loss,
                                        tuple(fast_weights.values()),
                                        create_graph=is_train)
            fast_weights = collections.OrderedDict(
                (name, param - inner_lr * grad)
                for (name, param), grad in zip(fast_weights.items(), grads))

        # Score the adapted weights on the query set
        query_logit = model.functional_forward(query_image, fast_weights)
        query_prediction = torch.max(query_logit, dim=1)[1]

        query_loss = loss_fn(query_logit, query_label)
        query_acc = torch.eq(query_label, query_prediction).sum() / len(query_label)

        meta_loss.append(query_loss)
        meta_acc.append(query_acc.detach().cpu().numpy())

    meta_loss = torch.stack(meta_loss).mean()
    meta_acc = np.mean(meta_acc)

    if is_train:
        optimizer.zero_grad()
        meta_loss.backward()
        optimizer.step()

    return meta_loss, meta_acc

In [None]:
# train_tasks, valid_tasks, test_tasks = get_dataset(train_data_path,
#                                                    test_data_path,
#                                                    n_way,
#                                                    k_shot,
#                                                    q_query,
#                                                    model=ssl_model)

# train_loader = DataLoader(train_tasks, batch_size=meta_batch_size, 
#                             shuffle=True, drop_last=True,  num_workers=num_workers)

# valid_loader = DataLoader(valid_tasks, batch_size=meta_batch_size, 
#                             shuffle=True, drop_last=True, num_workers=num_workers)

# test_loader = DataLoader(test_tasks, batch_size=meta_batch_size, 
#                             shuffle=False, drop_last=True, num_workers=num_workers)

In [None]:
# Single-channel classifier for n_way classes, placed on the selected device
# (Module.to moves the parameters in place and returns the module itself).
maml_model = Classifier(in_ch=1, n_way=n_way).to(device)

# Outer-loop optimizer and the loss shared by the inner and outer loops.
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(maml_model.parameters(), lr=outer_lr)

In [None]:
# Restore the trained MAML weights. map_location remaps the stored tensors onto
# the device chosen above, so a checkpoint holding CUDA tensors can still be
# loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load('/kaggle/input/maml-5way-1shot-model/maml-para.pt',
                        map_location=device)
maml_model.load_state_dict(checkpoint)

In [None]:
# # ====================== evaluate model ====================
# model.load_state_dict(torch.load('maml-para.pt'))
# test_acc = []
# test_loss = []

# test_bar = tqdm(test_loader)
# model.eval()
# for support_images, support_labels, query_images, query_labels in test_bar:
#     test_bar.set_description("Testing")

#     support_images = support_images.float().to(device)
#     support_labels = support_labels.long().to(device)
#     query_images = query_images.float().to(device)
#     query_labels = query_labels.long().to(device)

#     loss, acc = maml_train(model, 
#                     support_images, 
#                     support_labels, 
#                     query_images, 
#                     query_labels,
#                     eval_inner_step, 
#                     inner_lr,
#                     optimizer, 
#                     loss_fn, 
#                     is_train=False)
#     test_loss.append(loss.item())
#     test_acc.append(acc)

# test_loss = np.mean(test_loss)
# test_acc = np.mean(test_acc)
# print('Meta Test Loss: {:.3f}, Meta Test Acc: {:.2f}%'.format(test_loss, 100 * test_acc))

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import cv2

# Module-level slots holding the most recent Grad-CAM captures: the output of
# the hooked layer and the gradient that flowed back into that output.
gradients = []
activations = []

def save_gradients(grad):
    """Tensor hook: remember the gradient flowing into the hooked output."""
    global gradients
    gradients = grad

def save_activations(module, input, output):
    """Forward hook: remember the output of the hooked layer."""
    global activations
    activations = output

def register_hooks(model):
    """Attach the Grad-CAM capture hooks to the last convolution of ``model``.

    The hooked layer is ``model.conv4.conv2d``. A forward hook stores its
    output and attaches ``save_gradients`` to that output as a tensor hook.
    A module backward hook cannot be used for this: it is invoked with
    ``(module, grad_input, grad_output)``, which does not fit the
    one-argument ``save_gradients``.

    Args:
        model: model exposing ``conv4.conv2d``

    Returns: the forward-hook handle; call ``.remove()`` on it to detach
    """
    last_conv_layer = model.conv4  # change this to visualise another block

    def _capture(module, inputs, output):
        save_activations(module, inputs, output)
        # Under torch.no_grad() the output has no graph to hook into.
        if output.requires_grad:
            output.register_hook(save_gradients)

    return last_conv_layer.conv2d.register_forward_hook(_capture)

In [None]:
def grad_cam(model, image, target_class=None):
    """Compute a Grad-CAM heatmap for one image.

    Relies on hooks that fill the module-level ``activations`` and
    ``gradients`` slots during the forward and backward pass (see
    ``register_hooks``).

    Args:
        model: classifier with the capture hooks registered
        image: single image without batch dimension (a batch axis is added here)
        target_class: class index whose score is explained; defaults to the
            class with the highest score

    Returns: 2-D float numpy array with values in [0, 1], sized like the
        last two dimensions of the batched input

    Raises:
        RuntimeError: if no activations/gradients were captured
    """
    global gradients, activations

    # NOTE(review): eval() makes BatchNorm layers use their running statistics,
    # whereas functional_forward in this notebook always normalises with batch
    # statistics - confirm the loaded running statistics are meaningful.
    model.eval()
    model_device = next(model.parameters()).device
    image = image.unsqueeze(0).to(model_device)  # add batch dimension

    # Drop captures from any earlier call so stale values can never be reused.
    gradients, activations = None, None

    output = model(image)
    if target_class is None:
        target_class = output.argmax(dim=1).item()

    # Backpropagate only the score of the target class.
    model.zero_grad()
    output[0, target_class].backward()

    if isinstance(gradients, (tuple, list)) and len(gradients) > 0:
        gradients = gradients[0]  # tolerate hooks that store a grad tuple
    if not torch.is_tensor(activations) or not torch.is_tensor(gradients):
        raise RuntimeError(
            "No activations/gradients captured; call register_hooks(model) "
            "before grad_cam().")

    grads = gradients.detach()
    acts = activations.detach()

    # Channel importance = gradient averaged over batch and spatial positions.
    weights = grads.mean(dim=(0, 2, 3))
    cams = F.relu((weights.view(-1, 1, 1) * acts[0]).sum(dim=0))

    # Rescale to [0, 1]; an all-zero map stays zero instead of becoming NaN.
    cams = cams - cams.min()
    peak = cams.max()
    if peak > 0:
        cams = cams / peak

    # Resize to the input's (height, width). interpolate takes the size in the
    # tensor's own order, so non-square inputs cannot be transposed by mistake.
    cams = F.interpolate(cams[None, None], size=tuple(image.shape[2:]),
                         mode='bilinear', align_corners=False)[0, 0]
    return cams.cpu().numpy()


In [None]:
def visualize_grad_cam(image, cams, target_class, label=None, save_path=None):
    """Overlay a Grad-CAM heatmap on an image, display it and optionally save it.

    Args:
        image: image tensor; squeezed to 2-D before plotting
        cams: 2-D heatmap with values in [0, 1]
        target_class: class index shown in the title as the predicted class
        label: true label for the title; ``None`` hides it (``0`` is shown)
        save_path: where to write the figure; nothing is written when falsy
    """
    # Convert the input image from tensor to numpy (remove singleton dimensions)
    image = image.squeeze().cpu().detach().numpy()

    # Normalize the image to [0, 1]; a constant image maps to zeros
    value_range = image.max() - image.min()
    if value_range > 0:
        image = (image - image.min()) / value_range
    else:
        image = np.zeros_like(image)

    # Colour-code the heatmap. OpenCV returns BGR while matplotlib expects RGB,
    # so convert; otherwise hot and cold colours are swapped.
    heatmap = np.uint8(255 * np.clip(np.nan_to_num(cams), 0, 1))
    heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Resize heatmap to match the image; cv2.resize takes (width, height)
    heatmap = cv2.resize(heatmap, (image.shape[1], image.shape[0]))

    # Superimpose heatmap on the image
    superimposed_img = heatmap * 0.4 + np.repeat(image[:, :, np.newaxis], 3, axis=2) * 255
    superimposed_img = np.uint8(np.clip(superimposed_img, 0, 255))

    fig = plt.figure(figsize=(10, 10))
    plt.imshow(superimposed_img)
    if label is not None:
        plt.title(f"True label: {label}, Predicted class: {target_class}")
    else:
        plt.title(f"Predicted class: {target_class}")
    plt.axis('off')

    # Save before showing: once the figure has been shown, saving via pyplot
    # can write an empty canvas.
    if save_path:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0)
        print(f"Saved Grad-CAM image to {save_path}")

    plt.show()
    plt.close(fig)  # release the figure when called in a loop

In [None]:
# maml_model = Classifier(in_ch=1, n_way=n_way)
# maml_model.load_state_dict(torch.load('maml_model.pth'))
# maml_model.to(device)

# # Register the hooks on the model
# register_hooks(maml_model)

# # Sample a few test images (from the test set)
# test_loader = DataLoader(test_tasks, batch_size=5, shuffle=False, drop_last=True)

# for support_images, support_labels, query_images, query_labels in test_loader:
#     # Choose a batch of images (for Grad-CAM, we need a single image, so select one image from the batch)
#     query_images = query_images.to(device)

#     for i in range(5):  # Show Grad-CAM for 5 images
#         image = query_images[i]
#         label = query_labels[i].item()

#         # Compute the Grad-CAM heatmap for the image
#         cams = grad_cam(maml_model, image, target_class=label)

#         # Visualize the heatmap overlaid on the input image
#         visualize_grad_cam(image, cams, target_class=label, label=label)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

# Grad-CAM demo on a few query images of one test task.
maml_model.eval()  # evaluation mode (affects BatchNorm)

# Register the capture hooks (after the weights have been loaded)
register_hooks(maml_model)

# Make the cell runnable on a fresh kernel: the cells that would define
# `test_tasks` and `output_dir` are not active in this notebook, so fall back
# to sensible defaults when they are missing.
if 'test_tasks' not in globals():
    test_tasks = OmniglotDataset(test_data_path, n_way, k_shot, q_query)
if 'output_dir' not in globals():
    output_dir = 'grad_cam_outputs'
os.makedirs(output_dir, exist_ok=True)

test_loader = DataLoader(test_tasks, batch_size=5, shuffle=False, drop_last=True)
num_examples = 5  # how many query images to visualise

for support_images, support_labels, query_images, query_labels in test_loader:
    # Every dataset item is a whole task, so the batch axis indexes tasks and
    # the next axis indexes images. Use the query set of the first task, cast
    # to float32 to match the model weights.
    task_images = query_images[0].float().to(device)
    task_labels = query_labels[0]

    # NOTE(review): labels are task-local indices and the model is not adapted
    # on the support set here - confirm target_class=label is what you want.
    for i in range(min(num_examples, len(task_images))):
        image = task_images[i]
        label = int(task_labels[i])

        # Compute Grad-CAM for the image
        cams = grad_cam(maml_model, image, target_class=label)
        save_path = os.path.join(output_dir, f"grad_cam_image_{i}_label_{label}.png")
        # Visualize the Grad-CAM output
        visualize_grad_cam(image, cams, target_class=label, label=label, save_path=save_path)

    break  # one batch is enough for the demo (remove to process the entire test set)
