In [2]:
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
import torch
import random
import numpy as np
import scipy.io as scp
import torch.optim as optim
import torchvision.models as models

In [3]:
train_transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                        [0.229, 0.224, 0.225])])

testval_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

train_dataset = torchvision.datasets.Flowers102(root='./data', split='train', download=True, transform=train_transform)
val_dataset = torchvision.datasets.Flowers102(root='./data', split='val', download=True, transform=testval_transform)
test_dataset = torchvision.datasets.Flowers102(root='./data', split='test', download=True, transform=testval_transform)

Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/102flowers.tgz to data/flowers-102/102flowers.tgz


100.0%


Extracting data/flowers-102/102flowers.tgz to data/flowers-102
Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/imagelabels.mat to data/flowers-102/imagelabels.mat


100.0%


Downloading https://thor.robots.ox.ac.uk/datasets/flowers-102/setid.mat to data/flowers-102/setid.mat


100.0%


In [4]:
def image_preprocessing(pil_image):    
    # -------- Resize with Aspect Ratio maintained--------- #
    # First fixing the short axes
    if pil_image.size[0] > pil_image.size[1]:
        pil_image.thumbnail((10000000, 256))
    else:
        pil_image.thumbnail((256, 100000000))
    
    # ---------Crop----------- #
    left_margin = (pil_image.width - 224) / 2
    bottom_margin = (pil_image.height - 224) / 2
    right_margin = left_margin + 224
    top_margin = bottom_margin + 224
    
    pil_image = pil_image.crop((left_margin, bottom_margin, right_margin, top_margin))
    
    # --------- Convert to np then Normalize ----------- #
    np_image = np.array(pil_image) / 255
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    np_image = (np_image -mean) / std
    
    # --------- Transpose to fit PyTorch Axes ----------#
    np_image = np_image.transpose([2, 0, 1])
    
    return np_image

def imshow(pt_image, ax = None, title = None):
    '''
    Takes in a PyTorch-compatible image with [Ch, H, W],
    Convert it back to [H, W, Ch], 
    Undo the preprocessing,
    then display it on a grid
    '''
    if ax is None:
        fig, ax = plt.subplots()
    
    # --------- Transpose ----------- #
    plt_image = pt_image.transpose((1, 2, 0))
    
    # --------- Undo the preprocessing --------- #
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    plt_image = plt_image * std + mean
    
    if title is not None:
        ax.set_title(title)
        
    # Image need to be clipped between 0 and 1 or it looks noisy
    plt_image = np.clip(plt_image, 0, 1)
    
    # this imshow is a function defined in the plt module
    ax.imshow(plt_image)
    
    return ax

In [5]:
label_path = './data/flowers-102/imagelabels.mat'
label_arr = scp.loadmat(label_path)['labels']
label_arr

array([[77, 77, 77, ..., 62, 62, 62]], dtype=uint8)

In [6]:
split_path = './data/flowers-102/setid.mat'
data_splits = scp.loadmat(split_path)
train_split = data_splits['trnid']
print(train_split.shape)
val_split = data_splits['valid']
print(val_split.shape)
test_split = data_splits['tstid']
print(test_split.shape)

(1, 1020)
(1, 1020)
(1, 6149)


In [7]:
def train(dataloader, model, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for idx, (data, target) in enumerate(dataloader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        total_loss += loss.item()
        loss.backward()
        optimizer.step()
    
    return total_loss/len(dataloader)

In [8]:
def eval(dataloader, model, device):
    model.eval()
    correct = 0
    with torch.no_grad():
        for idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)

            output = model(data)
            pred = output.argmax(dim=1)

            correct += pred.eq(target.view_as(pred)).sum().item() # compare predicted label to actual label
    return correct / len(dataloader.dataset)

In [9]:
def nn_setup(dropout=0.5, hidden_layer1 = 120,lr = 0.001):
    
    model = models.vgg16(pretrained=True)  
        
    for param in model.parameters():
        param.requires_grad = False

        from collections import OrderedDict
        classifier = nn.Sequential(OrderedDict([
                          ('fc1', nn.Linear(25088, 500)),
                          ('relu', nn.ReLU()),
                          ('dropout1', nn.Dropout(dropout)),
                          ('fc2', nn.Linear(500, 102)),
                          ('output', nn.LogSoftmax(dim=1))
                          ]))
        
        model.classifier = classifier
        criterion = nn.NLLLoss()
        optimizer = optim.Adam(model.classifier.parameters(), lr )
        
        return model , optimizer ,criterion

### Training Normal Model Below ( Haven't Run Yet )

In [60]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCH = 100
NUM_CLASSES = 5

# HYPERPARAMS TO TUNE
NUM_HIDDEN = 128
NUM_LAYERS = 1
BATCH_SIZE = 128
EARLY_STOP_THRESHOLD = 3
LR = 0.001
loss_list = []
accuracy_list = []
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
model,optimizer,criterion = nn_setup()
model.to(DEVICE)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
best_acc = 0
early_stop_count = 0

In [64]:
for epoch in range(1, NUM_EPOCH+1):
    train_loss = train(train_loader, model, criterion, optimizer, DEVICE)
    accuracy = eval(val_loader, model, DEVICE)
    print(f'Epoch {epoch}, Train Loss: {train_loss}, Val Accuracy: {accuracy}')
    if accuracy > best_acc:
        best_acc = accuracy
        early_stop_count = 0
    else:
        early_stop_count += 1
    if early_stop_count >= EARLY_STOP_THRESHOLD:
        print("Early Stopping...")        
        break
    scheduler.step()
test_accuracy = eval(test_loader, model, DEVICE)
print(f'Test Accuracy: {test_accuracy}')

Epoch 1, Train Loss: 4.567100822925568, Val Accuracy: 0.2803921568627451
Epoch 2, Train Loss: 3.4106183648109436, Val Accuracy: 0.47352941176470587
Epoch 3, Train Loss: 2.7541665732860565, Val Accuracy: 0.6127450980392157
Epoch 4, Train Loss: 2.277555912733078, Val Accuracy: 0.6764705882352942


KeyboardInterrupt: 

### Initialising Few Shot Training Data

In [10]:
def train_model_for_episode(model, criterion, optimizer, support_loader, query_loader, device):
    # Set the model in training mode
    model.train()
    
    # Iterate through support set
    for support_batch in support_loader:
        support_inputs, support_targets = support_batch
        support_inputs, support_targets = support_inputs.to(device), support_targets.to(device)

        # Forward pass
        support_outputs = model(support_inputs)
        
        # Compute loss
        loss = criterion(support_outputs, support_targets)
        
        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    # Evaluate on the query set
    model.eval()
    total_correct = 0
    total_samples = 0
    with torch.no_grad():
        for query_batch in query_loader:
            query_inputs, query_targets = query_batch
            query_inputs, query_targets = query_inputs.to(device), query_targets.to(device)

            query_outputs = model(query_inputs)
            _, predicted = torch.max(query_outputs, 1)
            total_samples += query_targets.size(0)
            total_correct += (predicted == query_targets).sum().item()
    
    accuracy = total_correct / total_samples
    print(f'Accuracy for ')
    return accuracy

In [11]:
def evaluate_model_on_test_set(model, test_loader, device):
    model.eval()  # Set the model to evaluation mode
    total_correct = 0
    total_samples = 0
    
    with torch.no_grad():
        for batch in test_loader:
            inputs, targets = batch
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            
            # Calculate accuracy for the batch
            correct = (predicted == targets).sum().item()
            total_correct += correct
            total_samples += targets.size(0)
    
    accuracy = total_correct / total_samples
    return accuracy

In [None]:
# Define hyperparameters
N = 5  # We initialise 5 Classes to be processed per episode
K = 1  # Number of support-set images per class
Q = 2  # Number of query images per class
total_episodes = 102  # Total number of episodes for training
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_EPOCH = 100
model,optimizer,criterion = nn_setup()

# Extracting Class Labels

class_labels = set()
for _, label in train_dataset:
    class_labels.add(label)

class_labels = list(class_labels) # Converting to list for easier manipulation

class_to_idx = {class_label: idx for idx, class_label in enumerate(class_labels)}

# Creating the training loop
for episode in range(total_episodes):
    
    sampled_classes = random.sample(class_labels, N)

    # Step 2: Sampling support-set and query-set images
    support_set = []
    query_set = []
    for class_name in sampled_classes:
        class_indices = [i for i, (_, label) in enumerate(train_dataset) if label == class_to_idx[class_name]]
        support_indices = random.sample(class_indices, K)
        query_indices = random.sample(class_indices, Q)
        
        # Organize support-set and query-set images
        support_set.extend([train_dataset[i] for i in support_indices])
        query_set.extend([train_dataset[i] for i in query_indices])
    
    support_loader = DataLoader(support_set, batch_size = 256, shuffle=True)
    query_loader = DataLoader(query_set, batch_size = 256, shuffle=False)

    # Step 3: Training the model for each episode
    train_model_for_episode(model, criterion, optimizer, support_loader, query_loader, DEVICE) # Using Model, Criterion and Optimiser from nn_setup()

# Step 4: Evaluating the model on the validation dataset
evaluate_model_on_test_set(model, val_loader, DEVICE)  