## Recognition 
1. Open recognition.ipynb  
2. Execute cells in sequence  
3. Choosing the pretrained model with the dataset and model setting shows in the model name  

> If the pretrained model name contains a prefix "rotation", execting the cell "rotation-based oversampling". Otherwise, do not execute it.  

## Matching algorithm
1. Single: Executing "multi-transform matching" cell without any additional transformations.  
2. Mirror-concatenated: Executing "mirror-concatenated matching" cell.  
3. Multi-transform: Executing "multi-transform matching" cell with any additional transformations you want.  


In [1]:
!nvidia-smi

Thu Sep 29 21:22:22 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:01:00.0  On |                  N/A |
|  0%   44C    P8    42W / 320W |    684MiB / 10240MiB |     38%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import os
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.utils.data as data
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import timm
import cv2
import math
import torch.nn.functional as F
from torch.autograd import Variable
import random
plt.style.use("default")
    
class AddGaussianNoise(object):
    def __init__(self, mean=0., std=1., amplitude=(0,25)):
        self.std = std
        self.mean = mean
        self.amplitude = amplitude
        
    def __call__(self, img):
        img_arr = np.asarray(img)
        gauss = np.random.normal(self.mean,self.std,img_arr.shape) * random.randint(*self.amplitude)
        img = img_arr + gauss
        img[img>255] = 255
        return Image.fromarray(img.astype("uint8")).convert("RGB")
    
    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

batch_size_train = 32
batch_size_test = 32
image_size = 224

# baseline augmentation
# trainingTransform = transforms.Compose([
#     transforms.ToPILImage(),
#     transforms.RandomApply([transforms.ColorJitter(brightness=0.5)], p=0.6),
#     transforms.RandomApply([transforms.ColorJitter(contrast=0.5)], p=0.6),
#     # transforms.RandomApply([transforms.ColorJitter(saturation=0.25)], p=0.6),
#     transforms.RandomApply([transforms.ColorJitter(hue=0.25)], p=0.6),
#     transforms.RandomApply([transforms.RandomChoice([transforms.GaussianBlur(1), transforms.GaussianBlur(3), transforms.GaussianBlur(5)])], p=0.6),
#     transforms.RandomAdjustSharpness(1.5, p=0.6),
#     transforms.RandomApply([transforms.RandomResizedCrop(size=image_size, scale=(1.07, 1.14))], p=0.6),
#     transforms.RandomApply([transforms.RandomRotation([-5,5], expand=False)], p=0.6),
#     # transforms.RandomApply([transforms.RandomAffine([-5,5], translate=(0.01, 0.01))], p=0.6),
#     transforms.Resize(image_size),
#     transforms.ToTensor(),
# ])
# testingTransform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Resize(image_size)
# ])


    
# Define space search for training settings
from operator import itemgetter
(brightness,
brightness_p,
contrast,
contrast_p,
# saturation,
# saturation_p,
hue,
hue_p,
# blur_sigma,
# blur_p,
noise_amp,
noise_p,
# rotation_degree,
# rotation_p,
# translatation,
# translatation_p
) = itemgetter(
    'brightness',
    'brightness_probability',
    'contrast',
    'contrast_probability',
    # 'saturation',
    # 'saturation_probability',
    'hue',
    'hue_probability',
    # 'blur_sigma',
    # 'blur_probability',
    'noise_amplitude',
    'noise_probability',
    # 'rotation_degree',
    # 'rotation_probability',
    # 'translatation',
    # 'translatation_probability',
)(torch.load("optuna/best_params-reorder-third-50.pt"))

trainingTransform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.RandomApply([transforms.ColorJitter(brightness=brightness)], p=brightness_p),
    transforms.RandomApply([transforms.ColorJitter(contrast=contrast)], p=contrast_p),
    # transforms.RandomApply([transforms.ColorJitter(saturation=saturation)], p=saturation_p),
    transforms.RandomApply([transforms.ColorJitter(hue=hue)], p=hue_p),
    # transforms.RandomApply([transforms.RandomChoice([transforms.GaussianBlur(3, sigma=(0.1, blur_sigma)), 
    #                                                  transforms.GaussianBlur(5, sigma=(0.1, blur_sigma))])], p=blur_p),
    transforms.RandomApply([AddGaussianNoise(0 , 1, (0, noise_amp))], p=noise_p),
    # transforms.RandomApply([transforms.RandomAffine(degrees=rotation_degree)], p=rotation_p),
    # transforms.RandomApply([transforms.RandomAffine(0, translate=(translatation, translatation))], p=translatation_p),
    transforms.Resize(image_size),
    transforms.ToTensor(),
])
testingTransform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize(image_size),
    transforms.ToTensor(),
])

In [3]:
## Tongji dataset
# contain both session of specific indices
class TongjiTrainingDataset(Dataset):
    '''
    all images of selected indices 
    '''
    def __init__(self, root, indices, transforms):
        # 圖片所在的資料夾
        self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms

        self.fnames = []
        self.labels = []
        for c in self.indices:
            for i in range(c*10, c*10+10):
                self.fnames.append(os.path.join(self.root, 'session1/{:05d}.tiff'.format(i+1)))
                self.fnames.append(os.path.join(self.root, 'session2/{:05d}.tiff'.format(i+1)))
                # 左右手視為不同的類別
                self.labels.append(c)
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        # 利用路徑讀取圖片
        img = Image.open(self.fnames[idx])
        # 將輸入的圖片轉換成符合預訓練模型的形式
        img = self.transforms(img)
        # 補足3個channel
        # img = img.repeat(3,1,1)
        # 圖片相對應的 label
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
# contain first session of all indices, and second session of not selected
class TongjiTuningDataset(Dataset):
    '''
    testing set include half of the select indices, and this is the remain(useless)
    '''
    def __init__(self, root, indices, transforms):
        self.root = root
        # 註冊的類別編號
        self.indices = indices
        self.transforms = transforms

        self.fnames = []
        self.labels = []
        for i in range(6000):
            c = int(i/10)
            self.fnames.append(os.path.join(self.root, 'session1/{:05d}.tiff'.format(i+1)))
            self.labels.append(c)
            if c not in self.indices:
                self.fnames.append(os.path.join(self.root, 'session2/{:05d}.tiff'.format(i+1)))
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img = Image.open(self.fnames[idx])
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)

# only contain one session
class TongjiTestingDataset(Dataset):
    '''
    half of the selected indices
    '''
    def __init__(self, root, indices, mode, transforms):
        # 圖片所在的資料夾
        if mode == "probe":
            self.root = os.path.join(root, "session2")
        else:
            self.root = os.path.join(root, "session1")
        # self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms

        self.fnames = []
        self.labels = []
        for c in self.indices:
            for i in range(c*10, c*10+10):
                self.fnames.append(os.path.join(self.root, '{:05d}.tiff'.format(i+1)))
                # 左右手視為不同的類別
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        # 利用路徑讀取圖片
        img = Image.open(self.fnames[idx])
        # 將輸入的圖片轉換成符合預訓練模型的形式
        img = self.transforms(img)
        # 補足3個channel
        # img = img.repeat(3,1,1)
        # 圖片相對應的 label
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
    
# flexibly spliting support set and query set
class TongjiFewShotDataset(Dataset):
    '''
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, num_samples, mode, transforms):
        # 圖片所在的資料夾
        self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms
        if num_samples > 20:
            raise BaseException("Number of samples larger than the limit")
        else:
            self.session1 = 10 if num_samples > 10 else num_samples
            self.session2 = num_samples - self.session1
            if mode == "probe":
                self.session1, self.session2 = self.session2, self.session1
                self.session1 = range(10-self.session1, 10)
                self.session2 = range(10-self.session2, 10)
            else:
                self.session1 = range(self.session1)
                self.session2 = range(self.session2)

        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, 'session1/{:05d}.tiff'.format(c*10+i+1)))
                self.labels.append(c)
            # get images from session2
            for i in self.session2: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, 'session2/{:05d}.tiff'.format(c*10+i+1)))
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        # 利用路徑讀取圖片
        img = Image.open(self.fnames[idx])
        img = np.asarray(img)
        # 將輸入的圖片轉換成符合預訓練模型的形式
        img = self.transforms(img)
        # 補足3個channel
        # img = img.repeat(3,1,1)
        # 圖片相對應的 label
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
    
# flexibly spliting support set and query set
class TongjiRotationCopyDataset(Dataset):
    '''
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, num_samples, mode, transforms):
        # 圖片所在的資料夾
        self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms
        self.num_samples = num_samples
        if num_samples > 20:
            raise BaseException("Number of samples larger than the limit")
        else:
            self.session1 = 10 if num_samples > 10 else num_samples
            self.session2 = num_samples - self.session1
            if mode == "probe":
                self.session1, self.session2 = self.session2, self.session1
                self.session1 = range(10-self.session1, 10)
                self.session2 = range(10-self.session2, 10)
            else:
                self.session1 = range(self.session1)
                self.session2 = range(self.session2)

        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, 'session1/{:05d}.tiff'.format(c*10+i+1)))
                self.labels.append(c)
            # get images from session2
            for i in self.session2: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, 'session2/{:05d}.tiff'.format(c*10+i+1)))
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        i = idx % (len(self.indices) * self.num_samples)
        quotient = int(idx / (len(self.indices) * self.num_samples))
        img = Image.open(self.fnames[i])
        img = np.asarray(img)
        if quotient > 0:
            img = np.rot90(img, quotient, (0,1)) # will rotate 1,2,3 times
        img = self.transforms(img)
        label = self.labels[i]
        return img, label + 600 * quotient
    
    def __len__(self):
        return len(self.fnames)*4

In [4]:
## PolyU dataset
# contain both session of specific indices
class PolyUTrainingDataset(Dataset):
    '''
    all images of selected indices 
    '''
    def __init__(self, root, indices, transforms):
        # 圖片所在的資料夾
        self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms

        self.fnames = [[],[],[]]  # R,G,B
        self.labels = []
        for c in self.indices:
            for i in range(6):
                self.fnames[0].append(os.path.join(self.root, 'Multispectral_R/{:03d}/1_{:02d}_s.bmp'.format(c+1, i+1)))
                self.fnames[1].append(os.path.join(self.root, 'Multispectral_G/{:03d}/1_{:02d}_s.bmp'.format(c+1, i+1)))
                self.fnames[2].append(os.path.join(self.root, 'Multispectral_B/{:03d}/1_{:02d}_s.bmp'.format(c+1, i+1)))
                self.fnames[0].append(os.path.join(self.root, 'Multispectral_R/{:03d}/2_{:02d}_s.bmp'.format(c+1, i+1)))
                self.fnames[1].append(os.path.join(self.root, 'Multispectral_G/{:03d}/2_{:02d}_s.bmp'.format(c+1, i+1)))
                self.fnames[2].append(os.path.join(self.root, 'Multispectral_B/{:03d}/2_{:02d}_s.bmp'.format(c+1, i+1)))
                # 2 sessions
                self.labels.append(c)
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img_R = Image.open(self.fnames[0][idx])
        img_G = Image.open(self.fnames[1][idx])
        img_B = Image.open(self.fnames[2][idx])
        img = np.dstack((img_R,img_G,img_B))
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.labels)
    
# only contain one session
class PolyUTestingDataset(Dataset):
    '''
    half of the selected indices
    '''
    def __init__(self, root, indices, mode, transforms):
        self.root = root
        # 需要的類別編號
        self.indices = indices
        self.transforms = transforms
        # 決定session
        if mode == "probe":
            self.session = 2
        else:
            self.session = 1

        self.fnames = [[],[],[]]  # R,G,B
        self.labels = []
        for c in self.indices:
            for i in range(6):
                self.fnames[0].append(os.path.join(self.root, 'Multispectral_R/{:03d}/{}_{:02d}_s.bmp'.format(c+1, self.session, i+1)))
                self.fnames[1].append(os.path.join(self.root, 'Multispectral_G/{:03d}/{}_{:02d}_s.bmp'.format(c+1, self.session, i+1)))
                self.fnames[2].append(os.path.join(self.root, 'Multispectral_B/{:03d}/{}_{:02d}_s.bmp'.format(c+1, self.session, i+1)))
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img_R = Image.open(self.fnames[0][idx])
        img_G = Image.open(self.fnames[1][idx])
        img_B = Image.open(self.fnames[2][idx])
        img = np.dstack((img_R,img_G,img_B))
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.labels)
    
# flexibly spliting support set and query set
class PolyUFewShotDataset(Dataset):
    '''
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, num_samples, mode, transforms):
        self.root = root
        self.indices = indices
        self.transforms = transforms
        if num_samples > 12:
            raise BaseException("Number of samples larger than the limit")
        else:
            session1 = 6 if num_samples > 6 else num_samples
            session2 = num_samples - session1
            if mode == "probe":
                session1, session2 = session2, session1
                self.session1 = range(6-session1, 6)
                self.session2 = range(6-session2, 6)
            else:
                self.session1 = range(session1)
                self.session2 = range(session2)

        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1:
                fname = []
                for channel in "RGB":
                    fname.append(os.path.join(self.root, 'Multispectral_{}/{:03d}/{}_{:02d}_s.bmp'.format(channel, c+1, 1, i+1)))
                self.fnames.append(fname)
                self.labels.append(c)
            # get images from session2
            for i in self.session2:
                fname = []
                for channel in "RGB":
                    fname.append(os.path.join(self.root, 'Multispectral_{}/{:03d}/{}_{:02d}_s.bmp'.format(channel, c+1, 2, i+1)))
                self.fnames.append(fname)
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        imgs = []
        for i in self.fnames[idx]:
            imgs.append(Image.open(i))
        img = np.dstack(imgs)
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
    
# flexibly spliting support set and query set
class PolyURotationCopyDataset(Dataset):
    '''
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, num_samples, mode, transforms):
        self.root = root
        self.indices = indices
        self.transforms = transforms
        self.num_samples = num_samples
        if num_samples > 12:
            raise BaseException("Number of samples larger than the limit")
        else:
            session1 = 6 if num_samples > 6 else num_samples
            session2 = num_samples - session1
            if mode == "probe":
                session1, session2 = session2, session1
                self.session1 = range(6-session1, 6)
                self.session2 = range(6-session2, 6)
            else:
                self.session1 = range(session1)
                self.session2 = range(session2)

        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1:
                fname = []
                for channel in "RGB":
                    fname.append(os.path.join(self.root, 'Multispectral_{}/{:03d}/{}_{:02d}_s.bmp'.format(channel, c+1, 1, i+1)))
                self.fnames.append(fname)
                self.labels.append(c)
            # get images from session2
            for i in self.session2:
                fname = []
                for channel in "RGB":
                    fname.append(os.path.join(self.root, 'Multispectral_{}/{:03d}/{}_{:02d}_s.bmp'.format(channel, c+1, 2, i+1)))
                self.fnames.append(fname)
                self.labels.append(c)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        imgs = []
        i = idx % (len(self.indices) * self.num_samples)
        quotient = int(idx / (len(self.indices) * self.num_samples))
        for path in self.fnames[i]:
            imgs.append(Image.open(path))
        img = np.dstack(imgs)
        if quotient > 0:
            img = np.rot90(img, quotient, (0,1)) # will rotate 1,2,3 times
        img = self.transforms(img)
        label = self.labels[i]
        return img, label + 500 * quotient
    
    def __len__(self):
        return len(self.fnames)*4

In [5]:
## MPD dataset
# contain all data of specific phone
class MPDTrainingDataset(Dataset):
    '''
    all images of selected indices 
    '''
    def __init__(self, root, indices, phone, transforms):
        self.root = root
        self.indices = indices
        self.phone = phone
        self.transforms = transforms

        self.fnames = []
        self.labels = []
        for c in self.indices:
            for i in range(10):
                for p in self.phone:
                    self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, p, "l", i+1))) # 左手session1
                    self.labels.append(2*c)
                    self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, p, "l", i+1))) # 左手session2
                    self.labels.append(2*c)
                    self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, p, "r", i+1))) # 右手session1
                    self.labels.append(2*c+1)
                    self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, p, "r", i+1))) # 右手session2
                    self.labels.append(2*c+1)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img = Image.open(self.fnames[idx])
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)

# flexibly spliting support set and query set
class MPDFewShotDataset(Dataset):
    '''
    can only use for single phone
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, num_samples, mode, transforms):
        self.root = root
        self.indices = indices
        self.transforms = transforms
        self.num_samples = num_samples
        if num_samples > 40:
            raise BaseException("Number of samples larger than the limit")
        else:
            if mode == "probe": # count backward
                self.session = [2,1]
                self.phone = "mh"
                self.samples = range(9, -1, -1)
            else:
                self.session = [1,2]
                self.phone = "hm"
                self.samples = range(10)

        
        self.fnames = []
        self.labels = []
        for c in self.indices:
            count = 0
            for s in self.session:
                for i in self.samples:
                    for p in self.phone:
                        if count >= self.num_samples:
                            break
                        self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, s, p, "l", i+1))) # 左手session1
                        self.labels.append(2*c)
                        self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, s, p, "r", i+1))) # 右手session1
                        self.labels.append(2*c+1)
                        count += 1
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img = Image.open(self.fnames[idx])
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
    
class MPDFewShotSingleDataset(Dataset):
    '''
    can only use for single phone
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, phone, num_samples, mode, transforms):
        self.root = root
        self.indices = indices
        self.phone = phone
        self.transforms = transforms
        self.num_samples = num_samples
        if num_samples > 20:
            raise BaseException("Number of samples larger than the limit")
        else:
            session1 = 10 if num_samples > 10 else num_samples
            session2 = num_samples - session1
            if mode == "probe": # count backward
                session1, session2 = session2, session1
                self.session1 = range(10-session1, 10)
                self.session2 = range(10-session2, 10)
            else:
                self.session1 = range(session1)
                self.session2 = range(session2)

        
        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, self.phone, "l", i+1))) # 左手session1
                self.labels.append(2*c)
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, self.phone, "r", i+1))) # 右手session1
                self.labels.append(2*c+1)
            # get images from session2
            for i in self.session2: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, self.phone, "l", i+1))) # 左手session2
                self.labels.append(2*c)
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, self.phone, "r", i+1))) # 右手session2
                self.labels.append(2*c+1)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):
        img = Image.open(self.fnames[idx])
        img = np.asarray(img)
        img = self.transforms(img)
        label = self.labels[idx]
        return img, label
    
    def __len__(self):
        return len(self.fnames)
    
class MPDFewShotSingleCopyDataset(Dataset):
    '''
    can only use for single phone
    mode == gallery, get num_samples start from 1 to 20 of each class
    mode == probe, get num_samples start from 20 to 1 of each class
    '''
    def __init__(self, root, indices, phone, num_samples, mode, transforms):
        self.root = root
        self.indices = indices
        self.phone = phone
        self.transforms = transforms
        self.num_samples = num_samples
        if num_samples > 20:
            raise BaseException("Number of samples larger than the limit")
        else:
            session1 = 10 if num_samples > 10 else num_samples
            session2 = num_samples - session1
            if mode == "probe": # count backward
                session1, session2 = session2, session1
                self.session1 = range(10-session1, 10)
                self.session2 = range(10-session2, 10)
            else:
                self.session1 = range(session1)
                self.session2 = range(session2)

        
        self.fnames = []
        self.labels = []
        for c in self.indices:
            # get images from session1
            for i in self.session1: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, self.phone, "l", i+1))) # 左手session1
                self.labels.append(2*c)
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 1, self.phone, "r", i+1))) # 右手session1
                self.labels.append(2*c+1)
            # get images from session2
            for i in self.session2: # [0,1,2,3,4,5,6,7,8,9]
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, self.phone, "l", i+1))) # 左手session2
                self.labels.append(2*c)
                self.fnames.append(os.path.join(self.root, '{:03d}_{}_{}_{}_{:02d}_ROI.jpeg'.format(c+1, 2, self.phone, "r", i+1))) # 右手session2
                self.labels.append(2*c+1)
        self.labels = torch.Tensor(self.labels).long()

    def __getitem__(self, idx):    
        i = idx % (len(self.indices)*2 * self.num_samples)
        quotient = int(idx / (len(self.indices)*2 * self.num_samples))
        img = Image.open(self.fnames[i])
        img = np.asarray(img)
        if quotient > 0:
            img = np.rot90(img, quotient, (0,1)) # will rotate 1,2,3 times
        img = self.transforms(img)
        label = self.labels[i]
        return img, label + 400 * quotient
    
    def __len__(self):
        return len(self.fnames)*4

In [6]:
## ResNet 
class BasicConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0, bias=False):
        super(BasicConv2d, self).__init__()
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=0.001)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
    
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

# add batch norm
class ResNet20_basic(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        super().__init__()
        self.conv1 = BasicConv2d(3, 64, kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=1)
        self.conv2 = BasicConv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=1)
        self.conv3 = BasicConv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1)
        self.conv4 = BasicConv2d(256, 512, kernel_size=3, stride=2, padding=1)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1)
        
        self.fc = nn.Linear(512*14*14 , num_classes)


    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None  
        layers = []
        layers.append(block(planes, planes, stride, downsample))
        for _ in range(1, blocks):
            layers.append(block(planes, planes))

        return nn.Sequential(*layers)
    
    
    def forward(self, x):
        x = self.conv1(x)           # 112x112
        x = self.layer1(x)          # 
        x = self.conv2(x)           # 56x56
        x = self.layer2(x)          # 
        x = self.conv3(x)           # 28x28
        x = self.layer3(x)          # 
        x = self.conv4(x)           # 14x14
        x = self.layer4(x)          # 

        x = torch.flatten(x, 1)     # remove 1 X 1 grid and make vector of tensor shape 
        x = self.fc(x)

        return x
    
def resnet20_basic(num_classes):
    layers=[1, 2, 4, 1]
    model = ResNet20_basic(BasicBlock, layers, num_classes)
    return model

In [7]:
## Loss functions
def l2_norm(input, axis = 1):
    norm = torch.norm(input, 2, axis, True)
    output = torch.div(input, norm)
    return output

class CurricularFace(nn.Module):
    def __init__(self, in_features, out_features, s = 64., m = 0.5, centers=False):
        super(CurricularFace, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.m = m
        self.s = s
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        self.threshold = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        self.register_buffer('t', torch.zeros(1))
        nn.init.normal_(self.weight, std=0.01)

    def classify(self, x, centers=None):   
        weights = self.weight
        if torch.is_tensor(centers):
            weights = centers     
        logits = F.linear(F.normalize(x), F.normalize(weights))
        return self.s * logits

    def forward(self, embbedings, label, centers=None):
        weights = self.weight
        if torch.is_tensor(centers) and centers.shape == self.weight.shape:
            weights = centers
        
        embbedings = l2_norm(embbedings, axis = 1)
        kernel_norm = l2_norm(weights, axis = 1)
        cos_theta = torch.mm(embbedings, torch.transpose(kernel_norm, 0, 1))
        cos_theta = cos_theta.clamp(-1, 1)  # for numerical stability
        with torch.no_grad():
            origin_cos = cos_theta.clone()
        target_logit = cos_theta[torch.arange(0, embbedings.size(0)), label].view(-1, 1)

        sin_theta = torch.sqrt(1.0 - torch.pow(target_logit, 2))
        cos_theta_m = target_logit * self.cos_m - sin_theta * self.sin_m #cos(target+margin)
        mask = cos_theta > cos_theta_m
        final_target_logit = torch.where(target_logit > self.threshold, cos_theta_m, target_logit - self.mm)

        hard_example = cos_theta[mask]
        with torch.no_grad():
            self.t = target_logit.mean() * 0.01 + (1 - 0.01) * self.t
        cos_theta[mask] = hard_example * (self.t + hard_example)
        cos_theta.scatter_(1, label.view(-1, 1).long(), final_target_logit)
        margin_output = cos_theta * self.s
        original_logits = origin_cos * self.s
        return margin_output, original_logits

class ArcFace(torch.nn.Module):
    """ ArcFace (https://arxiv.org/pdf/1801.07698v1.pdf):
    """
    def __init__(self, in_features=128, out_features=10575, s=32.0, m=0.50, easy_margin=False, centers=False):
        super(ArcFace, self).__init__()
        self.in_feature = in_features
        self.out_feature = out_features
        self.s = s
        self.m = m
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        nn.init.xavier_uniform_(self.weight)
        self.easy_margin = easy_margin
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)

        # make the function cos(theta+m) monotonic decreasing while theta in [0°,180°]
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        # cos(theta)
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        with torch.no_grad():
            origin_cos = cosine.clone()
        # cos(theta + m)
        sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        #one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s
        margin_output = origin_cos * self.s
        return margin_output, output
    
class LMCL(nn.Module):
    def __init__(self, in_features=128, out_features=600, s=30.0, m=0.65, centers=False):
        super(LMCL, self).__init__()
        self.in_feature = in_features
        self.out_feature = out_features
        self.s = s
        self.m = m
        if torch.is_tensor(centers) and centers.shape == (out_features, in_features):
            self.weight = nn.Parameter(centers)
        else:
            self.weight = nn.Parameter(torch.randn(out_features, in_features))
            # nn.init.xavier_uniform_(self.weight)
            # nn.init.kaiming_uniform_(self.weight)
            # nn.init.normal_(self.weight, std=0.01)

    def classify(self, x, centers=None):   
        weights = self.weight
        if torch.is_tensor(centers):
            weights = centers     
        logits = F.linear(F.normalize(x), F.normalize(weights))
        return self.s * logits
    
    def forward(self, x, label, centers=None):
        weights = self.weight
        if torch.is_tensor(centers) and centers.shape == self.weight.shape:
            weights = centers
        # else:
        #     weights = self.centers()
        cosine = F.linear(F.normalize(x), F.normalize(weights))
        with torch.no_grad():
            origin_cos = cosine.clone()
            
        # one_hot = torch.zeros(cosine.size(), device='cuda' if torch.cuda.is_available() else 'cpu')
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1.0)

        margin_output = self.s * (cosine - one_hot * self.m)
        original_logits = self.s * origin_cos
        return margin_output, original_logits
    
class CenterLoss(nn.Module):
    def __init__(self):
        super(CenterLoss, self).__init__()

    def forward(self, feats, label, centers):
        center = centers[label]
        dist = (feats-center).pow(2).sum(dim=-1) / 2
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
        return loss

class CenterHuberLoss(nn.Module):
    def __init__(self, delta=1.0):
        super(CenterHuberLoss, self).__init__()
        self.HuberLoss = nn.HuberLoss(delta=delta)
        self.delta = delta

    def forward(self, feats, label, centers):
        center = centers[label]
        # dist = (feats-center).pow(2).sum(dim=-1) / 2
        dist = self.HuberLoss(feats, center)
        loss = torch.clamp(dist, min=1e-12, max=1e+12).mean(dim=-1)
        return loss
    
class FocalLoss(nn.Module):
    def __init__(self, gamma = 2, eps = 1e-7):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.eps = eps
        self.ce = nn.CrossEntropyLoss()

    def forward(self, input, target):
        logp = self.ce(input, target)
        p = torch.exp(-logp)
        loss = (1 - p) ** self.gamma * logp
        return loss.mean()

In [9]:
dataset_params = {}
dataset_params["Tongji"] = ["../data/TongJi/palmprint/ROI_RGB", 600]
dataset_params["PolyU"] = ["../data/PolyU", 500]
dataset_params["MPD"] = ["../data/MPD/ROI", 200]
dataset_params["MPD_h"] = dataset_params["MPD_m"] = dataset_params["MPD"]

def buildDatasets(dataset, shot=5, val_ratio=0.2, test_ratio=0.2, indices=False):
    if indices:
        training_class_indices, testing_class_indices = indices
    elif test_ratio == 0:
        training_class_indices = range(dataset_params[dataset][1])
        testing_class_indices = range(dataset_params[dataset][1])
    else:
        training_class_indices, testing_class_indices = data.random_split(
            range(dataset_params[dataset][1]), [int(dataset_params[dataset][1]*(1-test_ratio)), int(dataset_params[dataset][1]*test_ratio)])
    
    if(dataset == "Tongji"):
        num_val_samples = int(20*val_ratio)
        trainingDataset = TongjiFewShotDataset(dataset_params[dataset][0], training_class_indices, 20, "gallery", trainingTransform)
        validationDataset = TongjiFewShotDataset(dataset_params[dataset][0], training_class_indices, 20, "probe", testingTransform)
        galleryDataset = TongjiFewShotDataset(dataset_params[dataset][0], testing_class_indices, shot, "gallery", testingTransform)
        probeDataset = TongjiFewShotDataset(dataset_params[dataset][0], testing_class_indices, 20-shot, "probe", testingTransform)
    elif(dataset == "PolyU"):
        num_val_samples = int(12*val_ratio)
        trainingDataset = PolyUFewShotDataset(dataset_params[dataset][0], training_class_indices, 12, "gallery", trainingTransform)
        validationDataset = PolyUFewShotDataset(dataset_params[dataset][0], training_class_indices, 12, "probe", testingTransform)
        galleryDataset = PolyUFewShotDataset(dataset_params[dataset][0], testing_class_indices, shot, "gallery", testingTransform)
        probeDataset = PolyUFewShotDataset(dataset_params[dataset][0], testing_class_indices, 12-shot, "probe", testingTransform)
    elif(dataset == "MPD_h"):
        num_val_samples = int(20*val_ratio)
        trainingDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], training_class_indices, "h", 20, "gallery", trainingTransform)
        validationDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], training_class_indices, "h", 20, "probe", testingTransform)
        galleryDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], testing_class_indices, "h", shot, "gallery", testingTransform)
        probeDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], testing_class_indices, "h", 20-shot, "probe", testingTransform)
    elif(dataset == "MPD_m"):
        num_val_samples = int(20*val_ratio)
        trainingDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], training_class_indices, "m", 20, "gallery", trainingTransform)
        validationDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], training_class_indices, "m", 20, "probe", testingTransform)
        galleryDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], testing_class_indices, "m", shot, "gallery", testingTransform)
        probeDataset = MPDFewShotSingleDataset(dataset_params[dataset][0], testing_class_indices, "m", 20-shot, "probe", testingTransform)
    # elif(dataset == "MPD"):
    #     trainingDataset = MPDTrainingDataset(dataset_params[dataset][0], training_class_indices, "hm", trainingTransform)
    #     galleryDataset = MPDFewShotDataset(dataset_params[dataset][0], testing_class_indices, shot, "gallery", testingTransform)
    #     probeDataset = MPDFewShotDataset(dataset_params[dataset][0], testing_class_indices, 40-shot, "probe", testingTransform)
    else:
        print(dataset)
        
    return trainingDataset, validationDataset, galleryDataset, probeDataset, training_class_indices, testing_class_indices

def buildDataloaders(trainingDataset, validationDataset, galleryDataset, probeDataset, batch_size_train = 55, batch_size_test = 128):
    trainingDataloader = DataLoader(trainingDataset, batch_size=batch_size_train, shuffle=True)
    validationDataloader = DataLoader(validationDataset, batch_size=batch_size_test, shuffle=False)
    galleryDataloader = DataLoader(galleryDataset, batch_size=batch_size_test, shuffle=False)
    probeDataloader = DataLoader(probeDataset, batch_size=batch_size_test, shuffle=False)
    return trainingDataloader, validationDataloader, galleryDataloader, probeDataloader

In [10]:
class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()
        
    def forward(self, x):
        return x
    
def initModel(model_type, head_type, num_classes, feature_dim, loss_func, lamb, lr=0.01, l2=0, mm=0.9, feature_norm=False):
    if model_type == "ResNet20_basic":
        model = resnet20_basic(feature_dim)
    elif model_type == "ResNet18":
        model = models.resnet18(pretrained = False)
        model.avgpool = Identity()
        model.fc = nn.Linear(model.fc.in_features*7*7, feature_dim) 
    elif model_type == "ResNet18_default":
        model = models.resnet18(pretrained = False)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "pretrained_ResNet18":
        model = models.resnet18(pretrained = True)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "reduced_ResNet18":
        model = models.resnet18(pretrained = True)
        model.layer4 = Identity()
        model.fc = nn.Linear(256, feature_dim) 
    elif model_type == "SE_ResNeXt26d_pretrained":
        model = timm.create_model('seresnext26d_32x4d', pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "SE_ResNeXt26d":
        model = timm.create_model('seresnext26d_32x4d')
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "ResNeSt26d_pretrained":
        model = timm.create_model('resnest26d', pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "Reduced_ResNeSt26d_pretrained":
        model = timm.create_model('resnest26d', pretrained=True)
        model.layer4 = Identity()
        model.fc = nn.Linear(1024, feature_dim) 
    elif model_type == "Double_Reduced_ResNeSt26d_pretrained":
        model = timm.create_model('resnest26d', pretrained=True)
        model.layer4 = Identity()
        model.layer3 = Identity()
        model.fc = nn.Linear(512, feature_dim) 
    elif model_type == "ResNeSt26d":
        model = timm.create_model('resnest26d')
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "ResNeSt14d_pretrained":
        model = timm.create_model('resnest14d', pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "ResNeSt50d_pretrained":
        model = timm.create_model('resnest50d', pretrained=True)
        model.fc = nn.Linear(model.fc.in_features, feature_dim) 
    elif model_type == "Reduced_ResNeSt50d_pretrained":
        model = timm.create_model('resnest50d', pretrained=True)
        model.layer4 = Identity()
        model.fc = nn.Linear(1024, feature_dim)
    elif model_type == "Double_Reduced_ResNeSt50d_pretrained":
        model = timm.create_model('resnest50d', pretrained=True)
        model.layer3 = Identity()
        model.layer4 = Identity()
        model.fc = nn.Linear(512, feature_dim) 
    else:
        raise BaseException("Invalid model type")

    if head_type == "LMCL":
        head = LMCL(in_features=feature_dim, out_features=num_classes, s=30.0, m=0.65, centers=False)
        # head = LMCL_loss(num_classes, feature_dim, s=30.00, m=0.65)
    elif head_type == "CurricularFace":
        head = CurricularFace(in_features=feature_dim, out_features=num_classes, s=30.0, m=0.65, centers=False)
    elif head_type == "ArcFace":
        head = ArcFace(in_features=feature_dim, out_features=num_classes, s=30.0, m=0.65, centers=False)
    else:
        raise BaseException("Invalid loss function type")
        
    if loss_func == "CE+Center":
        criterion = {"Softmax": nn.CrossEntropyLoss(), "CenterLoss": CenterLoss(), "Lambda": lamb}
    elif loss_func == "Focal+Center":
        criterion = {"Softmax": FocalLoss(gamma=2), "CenterLoss": CenterLoss(), "Lambda": lamb}
    elif loss_func == "Focal+Huber":
        criterion = {"Softmax": FocalLoss(gamma=2), "CenterLoss": CenterHuberLoss(delta=12), "Lambda": lamb}
    elif loss_func == "CE+Huber":
        criterion = {"Softmax": nn.CrossEntropyLoss(), "CenterLoss": CenterHuberLoss(delta=10), "Lambda": lamb}
    else:
        raise BaseException("Invalid loss function type")
    

    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=mm)
    optimizer4center = torch.optim.Adam(head.parameters(), lr=0.1, weight_decay=l2)
    return model, head, criterion, optimizer, optimizer4center

In [11]:
## initial dataset parameters
import torchvision.models as models
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau

## dataset parameters
dataset_choices = ["PolyU","Tongji","MPD_h","MPD_m"]
dataset = dataset_choices[0]
val_ratio = 0
test_ratio = 0.1
shot = 5
num_classes_source = dataset_params[dataset][1]*2 if dataset.startswith("MPD") else dataset_params[dataset][1]
dataset_type = f'{dataset}' if test_ratio else f'{dataset}_full'

print("------ dataset parameters ------")
print("dataset: ", dataset)
print("val ratio: ", val_ratio)
print("test ratio: ", test_ratio)
print("number of shots: ", shot)

(trainingDataset_source, 
validationDataset_source, 
galleryDataset_source, 
probeDataset_source, 
training_class_indices_source, 
testing_class_indices_source) = buildDatasets(dataset, shot, val_ratio, test_ratio)
trainingDataloader_source, validationDataloader_source, galleryDataloader_source, probeDataloader_source = buildDataloaders(
    trainingDataset_source, validationDataset_source, galleryDataset_source, probeDataset_source, batch_size_train = 32, batch_size_test = 128)
print("------ source dataset parameters ------")
print("training classes: ", len(training_class_indices_source))
print("testing classes: ", len(testing_class_indices_source))
print("training samples: ", len(trainingDataset_source))
print("validation samples: ", len(validationDataset_source))
print("gallery samples: ", len(galleryDataset_source))
print("probe samples: ", len(probeDataset_source))

print("------ end dataset parameters ------")

------ dataset parameters ------
dataset:  PolyU
val ratio:  0
test ratio:  0.1
number of shots:  5
------ source dataset parameters ------
training classes:  450
testing classes:  50
training samples:  5400
validation samples:  5400
gallery samples:  250
probe samples:  350
------ end dataset parameters ------


In [14]:
dataset_type += "-optuna-third"

In [688]:
## rotation-based oversampling
trainingDataset_source = PolyURotationCopyDataset("../data/PolyU", training_class_indices_source, 12, "gallery", trainingTransform)
# trainingDataset_source = TongjiRotationCopyDataset("../data/TongJi/palmprint/ROI_RGB", training_class_indices_source, 20, "gallery", trainingTransform)
# trainingDataset_source = MPDFewShotSingleCopyDataset("../data/MPD/ROI", training_class_indices_source, "h", 20, "gallery", trainingTransform)
# trainingDataset_source = MPDFewShotSingleCopyDataset("../data/MPD/ROI", training_class_indices_source, "m", 20, "gallery", trainingTransform)
dataset_type += "-rotation"

trainingDataloader_source = DataLoader(trainingDataset_source, batch_size=batch_size_train, shuffle=True)
num_classes_source *= 4
print("training classes: ", num_classes_source)

training classes:  2000


In [12]:
## initial model parameters
epoch = 0
log_train_acc = []
log_train_loss = []
log_val_acc = []
log_val_loss = []
log_test_acc = []
log_test_loss = []
log_gradient_norm = []
log_max_feature_norm = []

feature_dim = 128
## model type
# model_type = "ResNet20_basic"
# model_type = "ResNet18"
# model_type = "ResNet18_default"
# model_type = "pretrained_ResNet18"
# model_type = "reduced_ResNet18"
# model_type = "SE_ResNeXt26d_pretrained"
# model_type = "SE_ResNeXt26d"
# model_type = "ResNeSt26d"
# model_type = "ResNeSt26d_pretrained"
# model_type = "Reduced_ResNeSt26d_pretrained"
# model_type = "Double_Reduced_ResNeSt26d_pretrained"
# model_type = "ResNeSt50d_pretrained"
model_type = "Reduced_ResNeSt50d_pretrained"
# model_type = "Double_Reduced_ResNeSt50d_pretrained"

## loss functinos
head_type = "LMCL"
# head_type = "CurricularFace"
# head_type = "ArcFace"

# loss_func = "CE+Center"
# loss_func = "Focal+Center"
loss_func = "Focal+Huber"
# loss_func = "CE+Huber"
# loss_func = "CE+Circle"
lamb = 1

# optimzer
lr = 0.001
mm = 0.9
l2 = 0

model, head, criterion, optimizer, optimzer4center = initModel(
    model_type, head_type, num_classes_source, feature_dim, loss_func, lamb, lr, l2, mm)
prefix = "{}-{}s-{}-{}emb-{}-{}-{}l-{}lr-{}mm-{}l2".format(dataset_type, shot, model_type, feature_dim, head_type, loss_func, lamb, lr, mm, l2)

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("------ model parameters ------")
print("model type: ", model_type)
print("loss function: ", loss_func, ", lambda: ", lamb)
print("learning rate: ", lr, ", L2 Norm(weight decay): ", l2, ", momentum: ", mm)
print("sameple model name: {}-{}e.pt".format(prefix, epoch))
print('parameter total:{}, trainable:{}'.format(total, trainable))
print("------ end model parameters ------")

------ model parameters ------
model type:  Reduced_ResNeSt50d_pretrained
loss function:  Focal+Huber , lambda:  1
learning rate:  0.001 , L2 Norm(weight decay):  0 , momentum:  0.9
sameple model name: PolyU-5s-Reduced_ResNeSt50d_pretrained-128emb-LMCL-Focal+Huber-1l-0.001lr-0.9mm-0l2-0e.pt
parameter total:9412608, trainable:9412608
------ end model parameters ------


In [13]:
## load model
checkpoint = torch.load("model/PolyU-optuna-third-5s-Reduced_ResNeSt50d_pretrained-128emb-LMCL-Focal+Huber-1l-0.001lr-0.9mm-0l2-20e.pt")
epoch = checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
head.load_state_dict(checkpoint['head_state_dict'])
criterion = checkpoint['criterion']
training_class_indices = checkpoint['training_class_indices']
testing_class_indices = checkpoint['testing_class_indices']
trainingTransform = checkpoint['trainingTransform']
testingTransform = checkpoint['testingTransform']


In [14]:
val_ratio = 0
(trainingDataset_source, 
 validationDataset_source, 
 galleryDataset_source, 
 probeDataset_source, 
 training_class_indices_source, 
 testing_class_indices_source) = buildDatasets(
    dataset, shot, val_ratio, test_ratio, (training_class_indices, testing_class_indices))
trainingDataloader_source, validationDataloader_source, galleryDataloader_source, probeDataloader_source = buildDataloaders(
    trainingDataset_source, validationDataset_source, galleryDataset_source, probeDataset_source, batch_size_train = 32, batch_size_test = 128)
print("------ source dataset parameters ------")
print("training classes: ", len(training_class_indices_source))
print("testing classes: ", len(testing_class_indices_source))
print("training samples: ", len(trainingDataset_source))
print("validation samples: ", len(validationDataset_source))
print("gallery samples: ", len(galleryDataset_source))
print("probe samples: ", len(probeDataset_source))

------ source dataset parameters ------
training classes:  450
testing classes:  50
training samples:  5400
validation samples:  5400
gallery samples:  250
probe samples:  350


In [15]:
## target dataset parameters
dataset_choices = ["PolyU","Tongji","MPD_h","MPD_m"]
dataset_target = dataset_choices[2]
val_ratio = 0
test_ratio = 0
shot = 5
num_classes_target = dataset_params[dataset_target][1]*2 if dataset_target.startswith("MPD") else dataset_params[dataset_target][1]

print("------ dataset parameters ------")
print("dataset: ", dataset_target)
print("val ratio: ", val_ratio)
print("test ratio: ", test_ratio)
print("number of shots: ", shot)

(trainingDataset_target, 
validationDataset_target, 
galleryDataset_target, 
probeDataset_target, 
training_class_indices_target, 
testing_class_indices_target) = buildDatasets(dataset_target, shot, val_ratio, test_ratio)
trainingDataloader_target, validationDataloader_target, galleryDataloader_target, probeDataloader_target = buildDataloaders(
    trainingDataset_target, validationDataset_target, galleryDataset_target, probeDataset_target, batch_size_train = 32, batch_size_test = 128)
print("------ target dataset parameters ------")
print("training classes: ", len(training_class_indices_target))
print("testing classes: ", len(testing_class_indices_target))
print("training samples: ", len(trainingDataset_target))
print("validation samples: ", len(validationDataset_target))
print("gallery samples: ", len(galleryDataset_target))
print("probe samples: ", len(probeDataset_target))

print("------ end dataset parameters ------")

------ dataset parameters ------
dataset:  MPD_h
val ratio:  0
test ratio:  0
number of shots:  5
------ target dataset parameters ------
training classes:  200
testing classes:  200
training samples:  8000
validation samples:  8000
gallery samples:  2000
probe samples:  6000
------ end dataset parameters ------


In [16]:
# Mirror-concatenated matching
model= model.cuda()
model.eval()
gallery_feature_loader = torch.Tensor().cuda()
gallery_label_loader = torch.Tensor().long()
probe_feature_loader = torch.Tensor().cuda()
probe_label_loader = torch.Tensor().long()
with torch.no_grad():
    for i, (img, labels) in enumerate(galleryDataloader_target):
        inputs = img.cuda()
        feats = model(inputs)
        
        # inputs_mirror = torch.fliplr(img).cuda()
        inputs_mirror = torch.flip(img, (-2,)).cuda()
        feats_mirror = model(inputs_mirror)
        feats = torch.cat([feats, feats_mirror], 1)
        
        gallery_feature_loader = torch.cat([gallery_feature_loader, feats], 0)
        gallery_label_loader = torch.cat([gallery_label_loader, labels], 0)
    

    for i, (img, labels) in enumerate(probeDataloader_target):
        inputs = img.cuda()
        feats = model(inputs)
        
        inputs_mirror = torch.flip(img, (-2,)).cuda()
        feats_mirror = model(inputs_mirror)
        # feats =  feats_mirror
        feats = torch.cat([feats, feats_mirror], 1)
        
        probe_feature_loader = torch.cat([probe_feature_loader, feats], 0)
        probe_label_loader = torch.cat([probe_label_loader, labels], 0)
   

match_scores = torch.Tensor()
mathces = torch.Tensor()
probe_label_loader = probe_label_loader.cpu()
cos = nn.CosineSimilarity()
test_acc = 0.0
fail = 0
fail_loader = []
for i,p in enumerate(probe_feature_loader):
    cosine = cos(p, gallery_feature_loader)
    test_pred = torch.max(cosine, 0).indices.item()
    test_acc += gallery_label_loader[test_pred] == probe_label_loader[i]
    if gallery_label_loader[test_pred] != probe_label_loader[i]:
        true_indexes = np.where(gallery_label_loader == probe_label_loader[i])
        fail_loader.append([test_pred, i, cosine[test_pred].item(), cosine[true_indexes].cpu()])

    # open-set verification: match scores for every gallery entry to every probe input
    # match_scores = torch.cat([match_scores, cosine.cpu()], 0)
    # mathces = torch.cat([mathces, (probe_label_loader[i] == gallery_label_loader)], 0)
    
test_acc /= len(probe_feature_loader)
print("Test Acc: %3.6f" % (test_acc))
print(len(fail_loader))

Test Acc: 0.900500
597


In [19]:
# multi-transform matching
model= model.cuda()
model.eval()
gallery_feature_loader = torch.Tensor().cuda()
gallery_label_loader = torch.Tensor()
probe_feature_loader = torch.Tensor().cuda()
probe_label_loader = torch.Tensor()
with torch.no_grad():
    for i, (img, labels) in enumerate(galleryDataloader_target):
        images = []
        images.append(img.cuda()) # original
        # images.append(torch.flip(img, (-1,)).cuda()) # horizantal
        # images.append(torch.flip(img, (-2,)).cuda()) # vertical
        # images.append(torch.flip(img, (-1,-2)).cuda()) # h+v = 180 degree
        # images.append(torch.rot90(img, 1, (-1,-2)).cuda()) # 90 degree
        # images.append(torch.rot90(img, 3, (-1,-2)).cuda()) # 270 degree

        feats = []
        for image in images:
            feats.append(model(image))
        feats = torch.stack(feats, dim=1)
        
        gallery_feature_loader = torch.cat([gallery_feature_loader, feats], 0)
        gallery_label_loader = torch.cat([gallery_label_loader, labels], 0)
    for i, (img, labels) in enumerate(probeDataloader_target):
        images = []
        images.append(img.cuda()) # original
        # images.append(torch.flip(img, (-1,)).cuda()) # horizantal
        # images.append(torch.flip(img, (-2,)).cuda()) # vertical
        # images.append(torch.flip(img, (-1,-2)).cuda()) # h+v = 180 degree
        # images.append(torch.rot90(img, 1, (-1,-2)).cuda()) # 90 degree
        # images.append(torch.rot90(img, 3, (-1,-2)).cuda()) # 270 degree

        feats = []
        for image in images:
            feats.append(model(image))
        feats = torch.stack(feats, dim=1)
        
        probe_feature_loader = torch.cat([probe_feature_loader, feats], 0)
        probe_label_loader = torch.cat([probe_label_loader, labels], 0)

match_scores = torch.Tensor()
mathces = torch.Tensor()
cos = nn.CosineSimilarity(dim=-1)
test_acc = 0.0
fail = 0
fail_loader = []
for i,p in enumerate(probe_feature_loader):
    cosine = cos(p, gallery_feature_loader)
    
    # average all transformation type
    similarity_mean = cosine.sum(dim=-1) / (cosine!=0).sum(1)
    
    test_pred = torch.max(similarity_mean, 0).indices.item()
    test_acc += gallery_label_loader[test_pred] == probe_label_loader[i]
    
    if gallery_label_loader[test_pred] != probe_label_loader[i]:
        true_indexes = np.where(gallery_label_loader == probe_label_loader[i])
        fail_loader.append([test_pred, i, cosine[test_pred].cpu(), cosine[true_indexes].cpu()])
        
    # open-set verification: match scores for every gallery entry to every probe input
    # match_scores = torch.cat([match_scores, similarity_mean.cpu()], 0)
    # mathces = torch.cat([mathces, (probe_label_loader[i] == gallery_label_loader)], 0)
    
test_acc /= len(probe_feature_loader)
print("Test Acc: %3.6f" % (test_acc))
print("Number of disabled gallery mirror template: %d" % ((gallery_feature_loader.norm(dim=-1) == 0).sum()))
print("Number of disabled probe mirror template: %d" % ((probe_feature_loader.norm(dim=-1) == 0).sum()))
print("Number of failed: %d" % (len(fail_loader)))

Test Acc: 0.852667
Number of disabled gallery mirror template: 0
Number of disabled probe mirror template: 0
Number of failed: 884


In [472]:
print(fail_loader[1])
print(gallery_label_loader[fail_loader[0][0]], probe_label_loader[fail_loader[0][1]])


[tensor(406, device='cuda:0'), 22, tensor(0.6359, device='cuda:0'), tensor([0.5847, 0.6338, 0.6024, 0.6092, 0.6068], device='cuda:0')]
tensor(169.) tensor(1.)
