In [1]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch
import random
import numpy as np
import scipy.io as scp
import torch.optim as optim
import torchvision.models as models
from dataset import train_dataset, test_dataset, val_dataset

In [2]:
def image_preprocessing(pil_image, short_side=256, crop_size=224):
    """Resize, centre-crop and normalise a PIL image into a CHW numpy array.

    Steps:
      1. Convert to RGB on a fresh copy of the image.
      2. Shrink so the short side is at most ``short_side`` pixels, keeping
         the aspect ratio.
      3. Crop a centred ``crop_size`` x ``crop_size`` window.
      4. Scale to [0, 1], then standardise with the per-channel mean/std
         constants below.
      5. Reorder the axes from [H, W, Ch] to [Ch, H, W].

    Args:
        pil_image: A PIL image (or an object exposing the same ``convert``,
            ``size``, ``width``, ``height``, ``thumbnail`` and ``crop`` API).
            It is NOT modified.
        short_side: Target length of the shorter side before cropping.
        crop_size: Side length of the centred square crop.

    Returns:
        A float numpy array of shape ``(3, crop_size, crop_size)``.

    Note:
        ``thumbnail`` only ever shrinks, so an image whose short side is
        already below ``short_side`` is not enlarged before the crop.
    """
    # `thumbnail` resizes in place; `convert` hands back a new image, so the
    # caller's object is left untouched. It also guarantees three channels,
    # which the (3,)-shaped mean/std broadcast below relies on.
    pil_image = pil_image.convert('RGB')

    # -------- Resize with aspect ratio maintained --------- #
    # Only the short axis is constrained; the other bound is effectively
    # unlimited so it never becomes the limiting dimension.
    unbounded = 10000000
    if pil_image.size[0] > pil_image.size[1]:
        pil_image.thumbnail((unbounded, short_side))
    else:
        pil_image.thumbnail((short_side, unbounded))

    # --------- Centre crop ----------- #
    # PIL boxes are (left, upper, right, lower) with the origin at top-left.
    left = (pil_image.width - crop_size) / 2
    upper = (pil_image.height - crop_size) / 2
    pil_image = pil_image.crop((left, upper, left + crop_size, upper + crop_size))

    # --------- Convert to np then normalise ----------- #
    np_image = np.array(pil_image) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    np_image = (np_image - mean) / std

    # --------- Transpose to fit PyTorch axes ----------#
    return np_image.transpose([2, 0, 1])

def imshow(pt_image, ax = None, title = None):
    '''
    Display an image that was normalised for PyTorch.

    Takes in a PyTorch-layout image with axes [Ch, H, W],
    converts it back to [H, W, Ch],
    undoes the mean/std normalisation,
    then draws it on the given (or a newly created) axes.

    Args:
        pt_image: array-like of shape [3, H, W] holding the normalised image.
        ax: object with ``set_title`` and ``imshow`` methods (normally a
            matplotlib Axes). When None, a new figure/axes pair is created.
        title: optional title set on the axes.

    Returns:
        The axes the image was drawn on.
    '''
    if ax is None:
        # `plt` is not imported at the top of this notebook, so import it here;
        # doing so lazily keeps the function usable with a caller-supplied ax
        # even where pyplot is unavailable.
        import matplotlib.pyplot as plt
        fig, ax = plt.subplots()

    # --------- Transpose ----------- #
    # np.asarray is a no-op for ndarrays and lets other array-likes through.
    plt_image = np.asarray(pt_image).transpose((1, 2, 0))

    # --------- Undo the preprocessing --------- #
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    plt_image = plt_image * std + mean

    if title is not None:
        ax.set_title(title)

    # Image needs to be clipped between 0 and 1 or it looks noisy
    plt_image = np.clip(plt_image, 0, 1)

    # this imshow is the Axes method, not the function defined here
    ax.imshow(plt_image)

    return ax

In [3]:
# Read the 'labels' entry of the Flowers-102 label file.
# The displayed result below is a uint8 array with a single row.
label_path = './data/flowers-102/imagelabels.mat'
label_arr = scp.loadmat(label_path)['labels']
# Bare last expression so the notebook renders the array.
label_arr

array([[77, 77, 77, ..., 62, 62, 62]], dtype=uint8)

In [4]:
# Read the dataset partition file and pull out the three index arrays it
# stores under the keys 'trnid', 'valid' and 'tstid'.
split_path = './data/flowers-102/setid.mat'
data_splits = scp.loadmat(split_path)

train_split = data_splits['trnid']
val_split = data_splits['valid']
test_split = data_splits['tstid']

# One shape per line, in train / val / test order.
print(train_split.shape, val_split.shape, test_split.shape, sep='\n')

(1, 1020)
(1, 1020)
(1, 6149)


In [5]:
from train import eval
from model import vgg, resnet, mobilenet

### Testing Model

In [7]:
# Evaluate the saved MobileNet weights on the test split.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
PATH = './saved_models/mobilenet_model_weights.pth'
BATCH_SIZE = 128

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

model, optimizer, criterion = mobilenet()
model.to(DEVICE)
# map_location remaps the stored tensors onto DEVICE, so a checkpoint saved
# on a GPU machine can still be loaded when only the CPU is available.
model.load_state_dict(torch.load(PATH, map_location=DEVICE))

# NOTE(review): judging by the recorded output, the first value returned by
# eval() is itself a 2-tuple (looks like accuracy, loss) — confirm against
# train.py and unpack it if only the accuracy should be printed.
test_accuracy, _ = eval(test_loader, model, criterion, DEVICE)
print(f'Test Accuracy: {test_accuracy}')

Test Accuracy: (0.827939502358107, 1.3834412675731036)


### Initialising Few-Shot Training Data

In [12]:
def train_model_for_episode(model, criterion, optimizer, support_loader, query_loader, device):
    """Run one few-shot episode: fit on the support set, score on the query set.

    The model takes one optimisation step per support batch, is then switched
    to evaluation mode (and left in it) and scored on the query batches.

    Args:
        model: the network to train; called as ``model(inputs)``.
        criterion: loss callable invoked as ``criterion(outputs, targets)``.
        optimizer: optimiser exposing ``zero_grad()`` and ``step()``.
        support_loader: iterable of ``(inputs, targets)`` batches to train on.
        query_loader: iterable of ``(inputs, targets)`` batches to evaluate on.
        device: device the batches are moved to before the forward pass.

    Returns:
        Fraction of query samples classified correctly, or ``0.0`` when the
        query loader yields no samples.
    """
    # ---- Adapt on the support set ----
    model.train()
    for support_inputs, support_targets in support_loader:
        support_inputs = support_inputs.to(device)
        support_targets = support_targets.to(device)

        loss = criterion(model(support_inputs), support_targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # ---- Evaluate on the query set ----
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for query_inputs, query_targets in query_loader:
            query_inputs = query_inputs.to(device)
            query_targets = query_targets.to(device)

            # Index of the largest output along dim 1 is the predicted class.
            predicted = model(query_inputs).argmax(dim=1)
            total_samples += query_targets.size(0)
            total_correct += (predicted == query_targets).sum().item()

    # An empty query set would otherwise raise ZeroDivisionError.
    if total_samples == 0:
        return 0.0

    return total_correct / total_samples

In [9]:
# def evaluate_model_on_test_set(model, test_loader, device):
#     model.eval()  # Set the model to evaluation mode
#     total_correct = 0
#     total_samples = 0
    
#     with torch.no_grad():
#         for batch in test_loader:
#             inputs, targets = batch
#             inputs, targets = inputs.to(device), targets.to(device)
            
#             # Forward pass
#             outputs = model(inputs)
#             _, predicted = torch.max(outputs, 1)
            
#             # Calculate accuracy for the batch
#             correct = (predicted == targets).sum().item()
#             total_correct += correct
#             total_samples += targets.size(0)
    
#     accuracy = total_correct / total_samples
#     return accuracy

In [13]:
# Define hyperparameters
N = 5  # Number of classes sampled per episode
K = 1  # Number of support-set images per class
Q = 2  # Number of query images per class
total_episodes = 102  # Total number of episodes for training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCH = 100
SEED = 42  # Fixed seed so the sampled episodes are reproducible

random.seed(SEED)
torch.manual_seed(SEED)

model, optimizer, criterion = mobilenet()
# Batches are moved to DEVICE inside train_model_for_episode, so the model
# has to live there too (same as in the testing cell).
model.to(DEVICE)

# Index the training set once: class label -> positions of its samples.
# Iterating the dataset loads every sample, so doing this a single time
# (instead of once per class per episode) is the expensive part done up front.
# int() makes the label usable as a dict key whether it is a Python int or a
# single-element tensor/array — assumes scalar class labels, TODO confirm.
indices_by_class = {}
for sample_idx, (_, label) in enumerate(train_dataset):
    indices_by_class.setdefault(int(label), []).append(sample_idx)

# Sorted list so that random.sample is deterministic for a given seed.
class_labels = sorted(indices_by_class)

class_to_idx = {class_label: idx for idx, class_label in enumerate(class_labels)}

# Creating the training loop
for episode in range(total_episodes):

    # Step 1: Sampling the N classes for this episode
    sampled_classes = random.sample(class_labels, N)

    # Step 2: Sampling support-set and query-set images
    support_set = []
    query_set = []
    for class_label in sampled_classes:
        # Draw K + Q distinct samples of this class in one go and split them,
        # so a query image can never also be one of the support images.
        picked = random.sample(indices_by_class[class_label], K + Q)
        support_indices, query_indices = picked[:K], picked[K:]

        # Organize support-set and query-set images
        support_set.extend(train_dataset[i] for i in support_indices)
        query_set.extend(train_dataset[i] for i in query_indices)

    support_loader = DataLoader(support_set, batch_size = 256, shuffle=True)
    query_loader = DataLoader(query_set, batch_size = 256, shuffle=False)

    # Step 3: Training the model for each episode
    # (model, criterion and optimizer come from mobilenet() above)
    accuracy = train_model_for_episode(model, criterion, optimizer, support_loader, query_loader, DEVICE)
    print(f'Accuracy for episode {episode}: {accuracy}')

# Step 4: Evaluating the model on the validation dataset
val_accuracy, _ = eval(val_loader, model, criterion, DEVICE)
print(f'Val Accuracy: {val_accuracy}')

Accuracy for episode 0: 0.0
Accuracy for episode 1: 0.0
Accuracy for episode 2: 0.0
Accuracy for episode 3: 0.0
Accuracy for episode 4: 0.0
Accuracy for episode 5: 0.2
Accuracy for episode 6: 0.0
Accuracy for episode 7: 0.0
Accuracy for episode 8: 0.0
Accuracy for episode 9: 0.0
Accuracy for episode 10: 0.0
Accuracy for episode 11: 0.0
Accuracy for episode 12: 0.0
Accuracy for episode 13: 0.0
Accuracy for episode 14: 0.0
Accuracy for episode 15: 0.0
Accuracy for episode 16: 0.0
Accuracy for episode 17: 0.0
Accuracy for episode 18: 0.0
Accuracy for episode 19: 0.0
Accuracy for episode 20: 0.2
Accuracy for episode 21: 0.0
Accuracy for episode 22: 0.0
Accuracy for episode 23: 0.0
Accuracy for episode 24: 0.0
Accuracy for episode 25: 0.0
Accuracy for episode 26: 0.0
Accuracy for episode 27: 0.1
Accuracy for episode 28: 0.2
Accuracy for episode 29: 0.0
Accuracy for episode 30: 0.0
Accuracy for episode 31: 0.2
Accuracy for episode 32: 0.0
Accuracy for episode 33: 0.0
Accuracy for episode 34: