In [None]:
import random
from collections import Counter

import numpy as np, pandas as pd, matplotlib.pyplot as plt, cv2, os
from fastai.vision.all import *
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
from sklearn.metrics import f1_score
import albumentations as A
# Choose the compute device once; later cells move models/batches with .to(device).
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print('device:', device)

In [None]:
def get_m_symbols(classes_dir, m=1, to_run=False, seed=None):
    """Randomly sample ``m`` image files from every class sub-folder of ``classes_dir``.

    ``classes_dir`` must contain one sub-folder per class whose name parses with
    ``int()``; that integer is used as the label. Only files directly inside a
    class folder are candidates (nested folders are ignored, so they can no
    longer shift labels). Only files whose extension maps to an ``image/*`` MIME
    type are eligible, so stray files such as ``.DS_Store`` are never picked.

    Args:
        classes_dir: root directory holding one sub-folder per class.
        m: number of files drawn (without replacement) from each class.
        to_run: unused; kept so existing calls keep working.
        seed: optional seed for a private ``random.Random``. When ``None``, the
            global ``random`` state is used, so ``random.seed(...)`` still applies.

    Returns:
        ``(symbols, targets)``: flat, parallel lists of file paths and int
        labels, grouped by class in sorted folder-name order.

    Raises:
        ValueError: a class folder has fewer than ``m`` eligible images, or a
            class folder name is not an integer.
    """
    import mimetypes
    import random

    def _is_image(fname):
        # Extension maps to an image/* MIME type. Intended to mirror the filter that
        # get_image_files applies to the query set (NOTE(review): confirm against installed fastai).
        mime, _ = mimetypes.guess_type(fname)
        return mime is not None and mime.startswith('image/')

    print(f"Getting {m} symbols randomly per class")
    sampler = random if seed is None else random.Random(seed)

    # Only the immediate sub-directories are classes; sorting makes seeded runs reproducible.
    with os.scandir(classes_dir) as entries:
        class_names = sorted(e.name for e in entries if e.is_dir())

    symbols, targets = [], []
    for class_name in class_names:
        class_pth = os.path.join(classes_dir, class_name)
        label = int(class_name)
        candidates = sorted(
            f for f in os.listdir(class_pth)
            if os.path.isfile(os.path.join(class_pth, f)) and _is_image(f)
        )
        picked = sampler.sample(candidates, m)
        symbols.extend(os.path.join(class_pth, f) for f in picked)
        targets.extend([label] * len(picked))

    print(f"Selected {len(symbols)} symbols and len of targets is also {len(targets)}")
    print('******** DONE *************')
    return symbols, targets


def get_support_set_dict(support_symbols, support_targets):
    """Group support image paths by label.

    Returns ``{label: [path, ...]}``. Labels appear in first-seen order and paths
    keep their input order within each label.
    """
    grouped = {}
    for label, pth in zip(support_targets, support_symbols):
        grouped.setdefault(label, []).append(pth)
    return grouped

def augment_data(support_set_dict, num_per_class, augmentation_scheme):
    """Expand a few-shot support set into ``num_per_class`` augmented images per label.

    Works for any number of shots per class (1, 2, 5, ...). For each label,
    ``num_per_class`` paths are drawn with replacement (``random.choice``). Each one
    is read with ``cv2.imread`` (OpenCV's default colour mode, BGR channel order)
    and passed through ``augmentation_scheme``.

    Args:
        support_set_dict: mapping ``label -> list of image paths``.
        num_per_class: number of augmented samples to generate per label.
        augmentation_scheme: albumentations-style callable, invoked as
            ``augmentation_scheme(image=im)`` and returning a dict with an ``'image'`` key.

    Returns:
        ``(augmented_images, augmented_labels)``: parallel lists, grouped by label
        in the dict's iteration order.

    Raises:
        FileNotFoundError: ``cv2.imread`` could not read a support path. OpenCV
            returns None instead of raising, so the check is made explicit here.
    """
    import random

    augmented_images = []
    augmented_labels = []
    for label, paths in support_set_dict.items():
        for _ in range(num_per_class):
            rand_pth = random.choice(paths)
            im = cv2.imread(str(rand_pth))
            if im is None:
                raise FileNotFoundError(f"cv2.imread could not read image: {rand_pth}")
            augmented_images.append(augmentation_scheme(image=im)['image'])
            augmented_labels.append(label)
    return augmented_images, augmented_labels

class One_shot_Dataset_st_list(Dataset):
    """PID symbol dataset backed by in-memory image arrays.

    ``symbols[idx]`` is an array accepted by ``Image.fromarray`` (for example the
    images produced by ``augment_data``). ``targets[idx]`` is its label.
    ``transform``, when given, is applied to the PIL image.
    """

    def __init__(self, symbols, targets, transform=None):
        self.symbols, self.targets, self.transform = symbols, targets, transform

    def __len__(self):
        return len(self.symbols)

    def __getitem__(self, idx):
        img = Image.fromarray(self.symbols[idx])
        img = self.transform(img) if self.transform else img
        return img, self.targets[idx]
    
class One_shot_Dataset_test(Dataset):
    """PID symbol dataset backed by image file paths.

    ``symbols[idx]`` is a path (``str`` or ``pathlib.Path``) read with
    ``cv2.imread``. ``targets[idx]`` is returned unchanged as the label.
    ``transform``, when given, is applied to the PIL image.
    """

    def __init__(self, symbols, targets, transform=None):
        self.symbols = symbols
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.symbols)

    def __getitem__(self, idx):
        pth = str(self.symbols[idx])
        symbol_image = cv2.imread(pth)
        # cv2.imread reports a missing/undecodable file by returning None; fail here with
        # the offending path instead of with an opaque error inside Image.fromarray.
        if symbol_image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {pth}")
        symbol_image = Image.fromarray(symbol_image)
        if self.transform:
            symbol_image = self.transform(symbol_image)
        return symbol_image, self.targets[idx]
    


#  accuracy
def calculate_accuracy(output, target):
    """Fraction of rows of ``output`` whose arg-max over dim 1 equals ``target``.

    Returns a 0-d float tensor: number of correct rows divided by ``output.size(0)``.
    """
    predicted = torch.max(output, dim=1).indices
    n_correct = predicted.eq(target).float().sum()
    n_rows = output.size(0)
    return n_correct / n_rows


### Inferencing #####

def extract_embeddings(model, data_loader):
    model = model.to(device)
    model = nn.Sequential(*list(model.children())[:-1]) 
    for p in model.parameters():
        p.requires_grad = False
    model.eval()
    embeddings = []
    labels = []
    with torch.no_grad():
        for batch_ims, batch_labels, _, _ in data_loader:
            outputs = model(batch_ims.to(device))
            embeddings.append(outputs)
            labels.append(batch_labels)
    embeddings = torch.cat(embeddings)
    embeddings = embeddings.squeeze()
    labels = torch.cat(labels)
    return embeddings, labels

class One_shot_Dataset_hw(Dataset):
    """PID symbol dataset from image paths that also reports each image's size.

    Items are ``(image, label, h, w)``. ``h, w`` are the height and width of the
    array returned by ``cv2.imread``, measured before ``transform`` is applied.
    """

    def __init__(self, symbols, targets, transform=None):
        self.symbols = symbols
        self.targets = targets
        self.transform = transform

    def __len__(self):
        return len(self.symbols)

    def __getitem__(self, idx):
        # str() lets pathlib.Path entries through, as One_shot_Dataset_test does.
        pth = str(self.symbols[idx])
        symbol_image = cv2.imread(pth)
        # cv2.imread returns None for a missing/undecodable file instead of raising.
        if symbol_image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {pth}")
        h, w = symbol_image.shape[:2]
        symbol_image = Image.fromarray(symbol_image)
        if self.transform:
            symbol_image = self.transform(symbol_image)
        return symbol_image, self.targets[idx], h, w


In [None]:
# Dataset root: one sub-folder per class, named by its integer label (parsed with int()).
classes_dir = "../Datasets/asupid/"

In [None]:
m = 1  # support images (shots) per class
# Remove the sampled support images from the query set. Otherwise each support image is
# matched to its own identical embedding and counts as a free correct prediction.
EXCLUDE_SUPPORT_FROM_QUERIES = True

# Support set: m random images from each class folder.
support_symbols, support_targets = get_m_symbols(classes_dir, m=m, to_run=True)
support_set_dict = get_support_set_dict(support_symbols, support_targets)

# Query (test) set: every image under classes_dir, labelled by its parent folder name.
# The directory is listed once and reused for both paths and labels.
all_image_files = get_image_files(classes_dir)
support_paths = {os.path.normpath(str(p)) for p in support_symbols}
test_symbols = [
    o for o in all_image_files
    if not (EXCLUDE_SUPPORT_FROM_QUERIES and os.path.normpath(str(o)) in support_paths)
]
test_targets = [int(o.parent.name) for o in test_symbols]

# Resize to 224x224 and standardise with the ImageNet channel statistics that the
# pretrained ResNet-50 backbone expects.
to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

batch_size = 64

train_dataset = One_shot_Dataset_test(symbols=support_symbols, targets=support_targets, transform=to_tensor)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = One_shot_Dataset_test(test_symbols, test_targets, transform=to_tensor)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# Peek at one batch from each loader. Images are shown after ImageNet normalisation,
# so their colours and contrast look shifted.
ims, targs = next(iter(train_loader))
print(ims.shape, targs.shape)
show_image_batch((ims, targs), items=32, cols=8, figsize=(8, 4))

ims, targs = next(iter(test_loader))
print(ims.shape, targs.shape)
show_image_batch((ims, targs), items=32, cols=8, figsize=(8, 4))

In [None]:
# Frozen ImageNet ResNet-50 used as a feature extractor: keep every child except the
# final `fc` classifier, so the forward pass ends at the global average pool.
model = torchvision.models.resnet50(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1]).to(device).eval()
model.requires_grad_(False)  # inference only; no gradients needed
print('ready')

In [None]:
# Embed the support images. With m=1 there is exactly one embedding per class (the class
# "prototype"). For m>1 every support image is kept, giving nearest-neighbour matching
# rather than a per-class average.
prototypes = []
proto_labels = []
with torch.no_grad():
    for inputs, labels in train_loader:
        outputs = model(inputs.to(device))
        # flatten(1) keeps the batch axis even for a one-image batch (squeeze() would drop
        # it and break torch.cat below).
        prototypes.append(torch.flatten(outputs, 1))
        proto_labels.append(labels)
prototypes = torch.cat(prototypes)
proto_labels = torch.cat(proto_labels)

In [None]:
# Embed every query image with the same frozen backbone.
query_embeddings = []
query_labels = []
with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs.to(device))
        # flatten(1) keeps the batch axis even when the last batch holds a single image.
        query_embeddings.append(torch.flatten(outputs, 1))
        query_labels.append(labels)
query_embeddings = torch.cat(query_embeddings)
query_labels = torch.cat(query_labels)

In [None]:
# L2-normalise every embedding so that a plain dot product equals cosine similarity.
prototypes = F.normalize(prototypes, p=2, dim=1)              # (n_support, 2048)
query_embeddings = F.normalize(query_embeddings, p=2, dim=1)  # (n_queries, 2048)

# Cosine similarity of every query to every support embedding.
cosine_similarity = query_embeddings @ prototypes.T  # (n_queries, n_support)

# For each query: the most similar support embedding (index) and its similarity (value).
max_similarities, closest_prototype_indices = torch.max(cosine_similarity, dim=1)  # (n_queries,)

In [None]:
# Map each query's best-matching support index to that support image's class label.
pred_labels = proto_labels[closest_prototype_indices.to(proto_labels.device)].tolist()

In [None]:
from sklearn.metrics import f1_score

In [None]:
# Macro-averaged F1: every class weighs equally regardless of how many queries it has.
f1_score(y_true=query_labels, y_pred=pred_labels, average='macro')

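## Without Augmentations

Repeat the 1-shot evaluation over five random support draws. Each query image receives the label of the most cosine-similar support-image embedding from the frozen ResNet-50 backbone.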

In [None]:
# Five independent 1-shot trials without augmentation.
# Everything that does not depend on the sampled support set (preprocessing, backbone and
# the embeddings of every image) is computed once. Each trial only samples a support set
# and slices the cached embeddings.
N_TRIALS = 5
m = 1
batch_size = 64
# Remove each trial's support images from its queries. Otherwise every support image is
# retrieved by its own identical embedding and inflates the score.
EXCLUDE_SUPPORT_FROM_QUERIES = True

to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics expected by the pretrained ResNet-50.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Frozen ResNet-50 backbone without its fc classifier; loaded once, not once per trial.
model = torchvision.models.resnet50(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1])
model = model.to(device)
for p in model.parameters():
    p.requires_grad = False
model.eval()
print('ready')

# Embed every image under classes_dir once; rows are L2-normalised for cosine similarity.
all_symbols = get_image_files(classes_dir)
all_targets = [int(o.parent.name) for o in all_symbols]
all_loader = DataLoader(One_shot_Dataset_test(all_symbols, all_targets, transform=to_tensor),
                        batch_size=batch_size, shuffle=False)
all_embeddings = []
with torch.no_grad():
    for inputs, _ in all_loader:
        # flatten(1) keeps the batch axis even for a one-image batch.
        all_embeddings.append(torch.flatten(model(inputs.to(device)), 1))
all_embeddings = F.normalize(torch.cat(all_embeddings), p=2, dim=1)
all_labels = torch.tensor(all_targets)
row_of_path = {os.path.normpath(str(o)): i for i, o in enumerate(all_symbols)}

for n_time in range(N_TRIALS):
    random.seed(n_time)  # reproducible support draw per trial

    support_symbols, support_targets = get_m_symbols(classes_dir, m=m, to_run=True)
    # Support images live in the same folder tree, so each should have a cached row. A
    # KeyError means get_m_symbols returned a file that get_image_files does not list.
    support_rows = torch.tensor([row_of_path[os.path.normpath(str(p))] for p in support_symbols])

    query_mask = torch.ones(len(all_targets), dtype=torch.bool)
    if EXCLUDE_SUPPORT_FROM_QUERIES:
        query_mask[support_rows] = False

    prototypes = all_embeddings[support_rows.to(all_embeddings.device)]
    proto_labels = all_labels[support_rows]
    print(prototypes.shape, len(proto_labels))

    query_embeddings = all_embeddings[query_mask.to(all_embeddings.device)]
    query_labels = all_labels[query_mask]
    print(query_embeddings.shape, len(query_labels))

    # Rows are already unit-length, so the matrix product is cosine similarity.
    cosine_similarity = query_embeddings @ prototypes.T  # (n_queries, n_support)
    closest_prototype_indices = torch.argmax(cosine_similarity, dim=1).cpu()
    pred_labels = proto_labels[closest_prototype_indices].tolist()

    print(n_time, ' F1 score: ', f1_score(query_labels.tolist(), pred_labels, average='macro'))

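## With Augmentations

Same protocol, but each class's support image is first expanded into 100 augmented copies (flips, small distortions, blur, noise, and morphological dilation/erosion). Queries are matched to the nearest augmented copy.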

In [None]:
# Five 1-shot trials WITH augmentation. Each class's sampled support image is expanded into
# `num_per_class` augmented copies, and every query takes the label of its most similar
# augmented copy (1-nearest-neighbour; copies are not averaged into one prototype).
# This cell defines everything it uses instead of relying on test sets or transforms
# left over from earlier cells.
N_TRIALS = 5
m = 1
num_per_class = 100
batch_size = 64
# Remove each trial's raw support images from its queries; an augmented copy can be
# nearly identical to its source image, which would inflate the score.
EXCLUDE_SUPPORT_FROM_QUERIES = True

# The augmentation pipeline is the same in every trial, so it is built once.
regular_augs = A.Compose([
    A.Resize(height=224, width=224),
    #A.D4(p=1), ### present in old
    A.HorizontalFlip(p=0.5),
    A.VerticalFlip(p=0.5),
    A.GridDistortion(distort_limit=(-0.1, 0.1), p=0.5),
    A.ShiftScaleRotate(rotate_limit=(-10, 10), interpolation=0, p=0.5),  # interpolation=0: nearest-neighbour
    A.RandomBrightnessContrast(p=0.5),
    A.AdvancedBlur(p=0.2),
    A.Defocus(radius=(3, 5), alias_blur=(0.05, 0.15), p=0.15),
    #A.Equalize(p=1), ########## Binarise #####
    A.GaussNoise(p=0.15),
    A.GaussianBlur(p=0.2),
    A.GlassBlur(p=0.1),  # simulate scanned drawings
    A.Morphological(scale=(2, 3), operation='dilation', p=0.2),
    A.Morphological(scale=(2, 3), operation='erosion', p=0.2),
])
albs = {'REGULAR': regular_augs}
alb_key = 'REGULAR'
alb_transform = albs[alb_key]

to_tensor = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # ImageNet channel statistics expected by the pretrained ResNet-50.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Frozen ResNet-50 backbone without its fc classifier; loaded once, not once per trial.
model = torchvision.models.resnet50(pretrained=True)
model = nn.Sequential(*list(model.children())[:-1])
model = model.to(device)
for p in model.parameters():
    p.requires_grad = False
model.eval()
print('ready')

# Query pool: embed every image under classes_dir once (rows L2-normalised).
all_symbols = get_image_files(classes_dir)
all_targets = [int(o.parent.name) for o in all_symbols]
all_loader = DataLoader(One_shot_Dataset_test(all_symbols, all_targets, transform=to_tensor),
                        batch_size=batch_size, shuffle=False)
all_embeddings = []
with torch.no_grad():
    for inputs, _ in all_loader:
        # flatten(1) keeps the batch axis even for a one-image batch.
        all_embeddings.append(torch.flatten(model(inputs.to(device)), 1))
all_embeddings = F.normalize(torch.cat(all_embeddings), p=2, dim=1)
all_labels = torch.tensor(all_targets)
row_of_path = {os.path.normpath(str(o)): i for i, o in enumerate(all_symbols)}

for n_time in range(N_TRIALS):
    # Seed the support draw and augment_data's random.choice.
    # NOTE(review): depending on the albumentations version, transforms may use their own
    # RNG rather than random/np.random. TODO confirm the augmentations are reproducible.
    random.seed(n_time)
    np.random.seed(n_time)

    support_symbols, support_targets = get_m_symbols(classes_dir, m=m, to_run=True)
    support_set_dict = get_support_set_dict(support_symbols, support_targets)

    train_syms, train_targs = augment_data(support_set_dict, num_per_class, alb_transform)
    print('Total Train images and labels generated are: ', len(train_syms), len(train_targs))

    train_dataset = One_shot_Dataset_st_list(symbols=train_syms, targets=train_targs, transform=to_tensor)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    prototypes = []
    proto_labels = []
    with torch.no_grad():
        for inputs, labels in train_loader:
            prototypes.append(torch.flatten(model(inputs.to(device)), 1))
            proto_labels.append(labels)
    prototypes = F.normalize(torch.cat(prototypes), p=2, dim=1)
    proto_labels = torch.cat(proto_labels)
    print(prototypes.shape, len(proto_labels))

    query_mask = torch.ones(len(all_targets), dtype=torch.bool)
    if EXCLUDE_SUPPORT_FROM_QUERIES:
        # A KeyError means get_m_symbols returned a file that get_image_files does not list.
        support_rows = torch.tensor([row_of_path[os.path.normpath(str(p))] for p in support_symbols])
        query_mask[support_rows] = False
    query_embeddings = all_embeddings[query_mask.to(all_embeddings.device)]
    query_labels = all_labels[query_mask]
    print(query_embeddings.shape, len(query_labels))

    # Both sides are unit-length, so the matrix product is cosine similarity.
    cosine_similarity = query_embeddings @ prototypes.T  # (n_queries, n_augmented_support)
    closest_prototype_indices = torch.argmax(cosine_similarity, dim=1).cpu()
    pred_labels = proto_labels[closest_prototype_indices].tolist()

    print(n_time, ' F1 score: ', f1_score(query_labels.tolist(), pred_labels, average='macro'))