<a href="https://colab.research.google.com/github/pxndey/dl-projects-2/blob/main/fewshot-mnist/fewshot2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install easyfsl



In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torchvision
from torchvision.transforms import Resize
import torch.nn.functional as F
from torchvision.models import resnet18
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
class ProtoNet(nn.Module):
    def __init__(self, backbone, output_channels, hidden_size=64):
        super(ProtoNet, self).__init__()
        self.backbone = backbone
        self.fc = nn.Linear(hidden_size * 7 * 7, output_channels)

    def forward(self, support_images, support_labels, query_images):
        # Extract the features of support and query images
        z_support = self.backbone(support_images)
        z_query = self.backbone(query_images)

        # Infer the number of different classes from the labels of the support set
        n_way = len(torch.unique(support_labels))
        # Prototype i is the mean of all instances of features corresponding to labels == i
        z_proto = torch.cat(
            [
                z_support[torch.nonzero(support_labels == label)].mean(0)
                for label in range(n_way)
            ]
        )

        # Compute the euclidean distance from queries to prototypes
        dists = torch.cdist(z_query, z_proto)

        # And here is the super complicated operation to transform those distances into classification scores!
        scores = -dists
        return scores


resnet18_one_channel = resnet18(pretrained=False)
resnet18_one_channel.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

model = ProtoNet(backbone=resnet18_one_channel, output_channels=10).to(device)



In [4]:

# Hyperparameters
num_epochs = 50
learning_rate = 0.001

# Load the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    Resize((224, 224))  # Add this line to resize the images to 224x224
])
mnist_train = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
mnist_test = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# DataLoaders for training and testing


convolutional_network = torchvision.models.resnet18(pretrained=True)
convolutional_network.fc = nn.Flatten()

# Instantiate the Prototypical Network

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training and testing loss lists for plotting
train_losses = []
test_losses = []

# Hyperparameter for early stopping
early_stop_patience = 10
best_test_loss = float('inf')
counter = 0




In [5]:
# import numpy as np
# for batch in train_loader:
#   support,labels = batch
#   support, labels = support.to(device), labels.to(device)
#   print(labels)
#   fig,axs = plt.subplots(2,5)
#   for i in range(2):
#     for j in range(5):
#       image = np.squeeze(support[i*5+j].cpu().numpy(), axis=0)
#       axs[i,j].imshow(image)
    # plt.imshow(image, cmap='gray')  # Assuming it's a grayscale image

#   break


In [14]:
num_classes = 10  # Number of classes in MNIST
num_shots = 5   # Number of shots per class
num_query = 1     # Number of query examples per class
train_loader = DataLoader(mnist_train, batch_size=num_shots * num_classes, shuffle=True)
test_loader = DataLoader(mnist_test, batch_size=num_query * num_classes, shuffle=True)

for batch in train_loader:
    # image = np.squeeze(batch[0][0].numpy(), axis=0)
    # print(batch[1][0])
    # plt.imshow(image, cmap='gray')  # Assuming it's a grayscale image
    # # break
    support,labels = batch
    support, labels = support.to(device), labels.to(device)
    # print(labels.shape)
    print(support.shape)
    support_images = support[:num_shots * (num_classes - 1)]
    print(support_images.shape)
    query_images = support[num_shots * (num_classes - 1):]
    print(query_images.shape)
    support_labels = labels[:num_shots * (num_classes - 1)]
    query_labels = labels[num_shots * (num_classes - 1):]
    # image = np.squeeze(query_images[0].cpu().numpy(), axis=0)
    # plt.imshow(image)
    print(support_labels.shape)
    print(query_labels.shape)
    break


torch.Size([50, 1, 224, 224])
torch.Size([45, 1, 224, 224])
torch.Size([5, 1, 224, 224])
torch.Size([45])
torch.Size([5])


In [15]:
 # Training and testing accuracy lists for plotting
train_accuracies = []
test_accuracies = []
# Training loop with early stopping and accuracy calculation
for epoch in range(num_epochs):
    running_loss = 0.0
    correct_train = 0
    total_train = 0
    for batch in train_loader:
        images, labels = batch
        images, labels =images.to(device), labels.to(device)

        # Split the support set into support images and query images
        support_images = support[:num_shots * (num_classes - 1)]
        # print(support_images.shape)
        query_images = support[num_shots * (num_classes - 1):]
        # print(query_images.shape)
        support_labels = labels[:num_shots * (num_classes - 1)]
        query_labels = labels[num_shots * (num_classes - 1):]

        optimizer.zero_grad()

        # Call the modified forward method
        outputs = model.forward(support_images, support_labels, query_images)

        # Flatten the outputs and labels to match the cross-entropy function requirements
        outputs = outputs.view(-1)
        labels = labels.view(-1)
        labels = labels.to(torch.float32)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

        # Calculate training accuracy
        _, predicted_train = torch.max(outputs.data, 1)
        total_train += labels.size(0)
        correct_train += (predicted_train == labels).sum().item()

    # Calculate average training loss and accuracy
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    train_accuracy = correct_train / total_train
    train_accuracies.append(train_accuracy)

    # Print and plot training loss and accuracy
    print(f'Epoch [{epoch+1}/{num_epochs}], Avg. Training Loss: {avg_train_loss:.4f}, Training Accuracy: {train_accuracy * 100:.2f}%')

    # Testing loop with early stopping and accuracy calculation
    model.eval()
    test_loss = 0.0
    correct_test = 0
    total_test = 0
    with torch.no_grad():
        for batch in test_loader:
            query_set, labels = batch
            query_set, labels = query_set.to(device), labels.to(device)

            # Call the modified forward method
            outputs = model.forward(support_images, support_labels, query_images)

            # Flatten the outputs and labels to match the cross-entropy function requirements
            outputs = outputs.view(-1, num_classes)
            labels = labels.view(-1)

            loss = criterion(outputs, labels)
            test_loss += loss.item()

            # Calculate testing accuracy
            _, predicted_test = torch.max(outputs.data, 1)
            total_test += labels.size(0)
            correct_test += (predicted_test == labels).sum().item()

    # Calculate average testing loss and accuracy
    avg_test_loss = test_loss / len(test_loader)
    test_losses.append(avg_test_loss)
    test_accuracy = correct_test / total_test
    test_accuracies.append(test_accuracy)

    # Print and plot testing loss and accuracy
    print(f'Avg. Test Loss: {avg_test_loss:.4f}, Test Accuracy: {test_accuracy * 100:.2f}%')

    # Early stopping check
    if avg_test_loss < best_test_loss:
        best_test_loss = avg_test_loss
        counter = 0
    else:
        counter += 1
        if counter >= early_stop_patience:
            print(f'Early stopping at epoch {epoch+1} as test loss has not improved for {early_stop_patience} consecutive epochs.')
            break

# ... (rest of the code)


IndexError: ignored

In [None]:
# Plot training and testing loss and accuracy graphs
plt.figure(figsize=(15, 5))

# Loss plots
plt.subplot(1, 2, 1)
plt.plot(range(1, epoch + 2), train_losses, label='Training Loss')
plt.plot(range(1, epoch + 2), [avg_test_loss] * (epoch + 1), label='Testing Loss')  # Repeat avg_test_loss for each epoch
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.title('Training and Testing Loss')
plt.legend()

# Accuracy plots
plt.subplot(1, 2, 2)
plt.plot(range(1, epoch + 2), train_accuracies, label='Training Accuracy')
plt.plot(range(1, epoch + 2), test_accuracies, label='Testing Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Testing Accuracy')
plt.legend()

plt.show()

In [None]:

# Function to randomly select and visualize an image from the test set
def visualize_random_image():
    model.eval()
    with torch.no_grad():
        # Select a random batch from the test loader
        batch = next(iter(test_loader))
        query_set, labels = batch
        query_set, labels = query_set.to(device), labels.to(device)

        # Make predictions
        outputs = model(query_set)
        _, predicted = torch.max(outputs.data, 1)

        # Select a random image from the batch
        index = np.random.choice(num_query * num_classes)
        image = query_set[index].cpu().numpy().squeeze()
        true_class = labels[index].item()
        predicted_class = predicted[index].item()

        # Visualize the image
        plt.imshow(image, cmap='gray')
        plt.title(f'True Class: {true_class}, Predicted Class: {predicted_class}')
        plt.show()



In [None]:
# Call the function to visualize a random image
visualize_random_image()