In [None]:
 !pip install kornia



In [None]:
%%writefile Retrieval.py

from utils import *
from tqdm import tqdm

def pqDist_one(C, N_books, g_x, q_x):
    """Asymmetric PQ distance from one continuous query to every encoded gallery item.

    Args:
        C: codebook matrix of shape (N_words, N_books * L_word); book ``j`` occupies
           columns ``j*L_word:(j+1)*L_word``.
        N_books: number of codebooks (sub-spaces).
        g_x: integer gallery codes of shape (N_gallery, N_books); entry ``[n, j]``
             is the codeword index of item ``n`` in book ``j``.
        q_x: 1-D query feature vector of length N_books * L_word.

    Returns:
        1-D float32 CPU tensor of length N_gallery. Entry ``n`` is the sum over books
        of the (non-squared) L2 distance between the query sub-vector and item
        ``n``'s codeword in that book.
    """
    n_words, code_dim = C.shape
    L_word = int(code_dim / N_books)
    used = N_books * L_word

    # Only used for ranking, so no autograd graph is needed. This also makes the
    # function safe to call outside an outer no_grad block.
    with T.no_grad():
        # Distance lookup table of shape (N_books, n_words), computed in one vectorised
        # call instead of N_books * n_words scalar norms written element by element.
        q_sub = q_x[:used].reshape(N_books, 1, L_word)
        c_sub = C[:, :used].reshape(n_words, N_books, L_word).transpose(0, 1)
        table = T.norm(q_sub - c_sub, p=2, dim=2).to(device='cpu', dtype=T.float32)

        codes = g_x.detach().cpu().long()
        dist = T.zeros(codes.shape[0], dtype=T.float32)
        # Accumulate book by book, in the same order as before.
        for j in range(N_books):
            dist = dist + table[j, codes[:, j]]
    return dist

def Indexing(C, N_books, X):
    """Hard-assign every sub-vector of X to its nearest codeword.

    C has N_books * L_word columns, where book ``b`` occupies columns
    ``b*L_word:(b+1)*L_word``. For each book the squared Euclidean distance
    (``squared_distances``) between X's sub-vectors and the book's codewords is
    computed, and the arg-min codeword index is kept.

    Returns:
        Integer tensor of shape (X.shape[0], N_books) holding the chosen codeword
        index for each row of X and each book.
    """
    _, code_dim = C.shape
    sub_len = int(code_dim / N_books)
    x_parts = T.split(X, sub_len, 1)
    c_parts = T.split(C, sub_len, 1)
    nearest_per_book = [
        T.argmin(squared_distances(x_parts[book], c_parts[book]), dim=1).reshape(-1, 1)
        for book in range(N_books)
    ]
    return T.cat(nearest_per_book, dim=1)

def Evaluate_mAP(C, N_books, gallery_codes, query_codes, gallery_labels, query_labels, device, TOP_K=None):
    """Mean average precision of PQ retrieval, optionally restricted to the top-K results.

    A gallery item is relevant to a query when their label vectors share at least
    one label (dot product > 0). Gallery items are ranked by ascending
    ``pqDist_one`` distance and truncated to ``TOP_K`` (``None`` keeps all). AP is
    the mean precision at the ranks of the relevant items in that list. Queries
    with no relevant item in the list contribute 0 but still count in the
    denominator.

    Args:
        C: codebook passed through to ``pqDist_one``.
        N_books: number of codebooks.
        gallery_codes: integer codes of the gallery items.
        query_codes: continuous query features, one row per query.
        gallery_labels: (N_gallery, num_cls) label matrix.
        query_labels: (N_query, num_cls) label matrix.
        device: device used for the score and index tensors.
        TOP_K: number of top-ranked items to evaluate, or None for all.

    Returns:
        The mAP as a 0-d tensor, or the float 0.0 when no query has a relevant
        item (or there are no queries).
    """
    num_query = query_labels.shape[0]
    mean_AP = 0.0
    if num_query == 0:
        return mean_AP

    with T.no_grad():
        with tqdm(total=num_query, desc="Evaluate mAP", bar_format='{desc:<15}{percentage:3.0f}%|{bar:10}{r_bar}') as pbar:
            for i in range(num_query):
                # Advance first so that skipped queries (see `continue` below) still count.
                pbar.update(1)

                # Relevance of every gallery item to this query.
                retrieval = (query_labels[i, :] @ gallery_labels.t() > 0).float()

                # Re-order by ascending PQ distance and keep the top-K.
                retrieval = retrieval[T.argsort(pqDist_one(C, N_books, gallery_codes, query_codes[i]))][:TOP_K]

                retrieval_cnt = retrieval.sum().int().item()

                # No relevant item retrieved: AP is 0 for this query.
                if retrieval_cnt == 0:
                    continue

                # k-th relevant item (1-based) divided by its 1-based rank gives precision@rank.
                score = T.linspace(1, retrieval_cnt, retrieval_cnt).to(device)
                index = (T.nonzero(retrieval == 1, as_tuple=False).squeeze() + 1.0).float().to(device)

                mean_AP += (score / index).mean()

        mean_AP = mean_AP / num_query
    return mean_AP

def DoRetrieval(device, args, net, C):
    """Encode the CIFAR-10 train (gallery) and test (query) splits with ``net`` and return their mAP.

    Gallery features are hard-quantized with ``Indexing`` against codebook ``C``.
    Query features (``outputs[0]``) stay continuous, so the distance used by
    ``Evaluate_mAP`` is asymmetric. Integer labels become one-hot vectors of
    length ``args.num_cls``.

    Args:
        device: device the network and batches run on.
        args: namespace providing data_dir, if_download, batch_size, num_workers,
            N_books, num_cls and Top_N.
        net: model whose forward output has the features at index 0.
        C: codebook tensor passed to ``Indexing`` and ``Evaluate_mAP``.

    Returns:
        The value returned by ``Evaluate_mAP``.

    The network is put in eval mode for encoding. Its previous train/eval mode is
    restored before returning, so a calling training loop is not left with
    BatchNorm in eval mode.
    """
    print("Do Retrieval!")

    trainset = torchvision.datasets.CIFAR10(root=args.data_dir, train=True, download=args.if_download, transform=transforms.ToTensor())
    Gallery_loader = T.utils.data.DataLoader(trainset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    testset = torchvision.datasets.CIFAR10(root=args.data_dir, train=False, download=args.if_download, transform=transforms.ToTensor())
    Query_loader = T.utils.data.DataLoader(testset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

    was_training = net.training
    net.eval()
    try:
        with T.no_grad():
            # Collect per-batch results and concatenate once (avoids quadratic re-copying).
            gallery_c_parts, gallery_y_parts = [], []
            with tqdm(total=len(Gallery_loader), desc="Build Gallery", bar_format='{desc:<15}{percentage:3.0f}%|{bar:10}{r_bar}') as pbar:
                for data in Gallery_loader:
                    gallery_x_batch, gallery_y_batch = data[0].to(device), data[1].to(device)
                    outputs = net(gallery_x_batch)
                    gallery_c_parts.append(Indexing(C, args.N_books, outputs[0]))
                    # Build the one-hot table on the same device as the label indices.
                    gallery_y_parts.append(T.eye(args.num_cls, device=device)[gallery_y_batch])
                    pbar.update(1)
            gallery_c = T.cat(gallery_c_parts, 0)
            gallery_y = T.cat(gallery_y_parts, 0)

            query_c_parts, query_y_parts = [], []
            with tqdm(total=len(Query_loader), desc="Compute Query", bar_format='{desc:<15}{percentage:3.0f}%|{bar:10}{r_bar}') as pbar:
                for data in Query_loader:
                    query_x_batch, query_y_batch = data[0].to(device), data[1].to(device)
                    outputs = net(query_x_batch)
                    query_c_parts.append(outputs[0])
                    query_y_parts.append(T.eye(args.num_cls, device=device)[query_y_batch])
                    pbar.update(1)
            query_c = T.cat(query_c_parts, 0)
            query_y = T.cat(query_y_parts, 0)
    finally:
        net.train(was_training)

    mAP = Evaluate_mAP(C, args.N_books, gallery_c.type(T.int), query_c, gallery_y, query_y, device, args.Top_N)
    return mAP

Overwriting Retrieval.py


In [None]:
%%writefile config.py

#0: YouTube Faces (YTF), 1: FaceScrub (FS), 2: VGGFace2 (VGGF), 3: CASIA-WebFace (CW)
#NOTE: the index list above looks like a leftover; none of these datasets is configured below.

# CIFAR-10 preset with a 224-pixel input size.
cifar = {
    'NB_CLS': 10,
    'input_size': 224,
}

# CIFAR-10 preset at native 32-pixel resolution, with gallery/query list files.
cifar32 = {
    'Gallery_img_dir': './data/cifar',
    'Gallery_txt_dir': './data/cifar_Train.txt',
    'Query_img_dir': './data/cifar',
    'Query_txt_dir': './data/cifar_Query.txt',
    'NB_CLS': 10,
    'input_size': 32,
}

# Downsampled ImageNet (32 px, 1000 classes) training preset.
ImageNet32 = {
    'Train_img_dir': './data/ImageNet_32',
    'Train_txt_dir': './data/ImageNet32.txt',
    'NB_CLS': 1000,
    'input_size': 32,
}

Overwriting config.py


In [None]:
%%writefile models.py

import torch as T
import torch.nn as nn
import torch.nn.functional as F

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the stride
    or channel count changes. In that case it is a strided 1x1 convolution
    followed by BatchNorm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        # Keep the registration order (conv1, bn1, conv2, bn2, shortcut). It fixes
        # the state_dict keys, the parameter order seen by optimizers, and the
        # order in which weight initialisation consumes the RNG.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))

class ResNet_Baseline(nn.Module):
    """ResNet backbone: 3x3 stem, four residual stages (64/128/256/512 channels),
    pooling, and a 512-unit linear layer followed by ReLU.

    NOTE(review): this is a copy of ``ResNet_Baseline`` in utils.py; the visible
    notebook code imports the utils.py version, not this one.
    """
    def __init__(self, block, num_blocks):
        super(ResNet_Baseline, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64
        # Stem expects 3-channel input.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # num_blocks[i] is the number of blocks in stage i; stages 2-4 downsample by stride 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc_out = nn.Linear(512*block.expansion, 512)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the remaining blocks use stride 1.

        Side effect: updates self.in_planes to this stage's output channel count.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ReLU(fc_out(pooled features)), shape (batch, 512)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Kernel = feature-map width (size()[3]); this is global average pooling
        # only when height == width. NOTE(review): assumes square inputs.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc_out(out))
        return out

Overwriting models.py


In [None]:
%%writefile utils.py
import os
import argparse
import torch as T
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torchvision
from torch.autograd import Variable
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from PIL import Image
import kornia.augmentation as Kg


#Define the soft quantization function. For each codebook, the weight of a codeword is the
#exponential of the negative, tau_q-scaled squared Euclidean distance between the feature
#sub-vector and that codeword (centroid), divided by the sum of these exponentials over all
#codewords in the codebook (i.e. a softmax). The soft-quantized sub-vector is the weighted sum
#of the codewords, so the closest codeword contributes the most.

def Soft_Quantization(X, C, N_books, tau_q):
    """Differentiable product quantization of X against codebook matrix C.

    Args:
        X: 2-D feature tensor; its columns are split into N_books sub-vectors of
           length L_word = C.size(1) / N_books.
        C: codebook matrix (N_words, N_books * L_word); book ``b`` uses columns
           ``b*L_word:(b+1)*L_word``.
        N_books: number of codebooks.
        tau_q: scale applied to the negative squared distances before the softmax.
               Larger values give a harder assignment.

    Returns:
        Tensor with one row per row of X: the per-book softmax-weighted codeword
        combinations, concatenated along dim 1.
    """
    sub_len = int(C.size()[1] / N_books)
    x_chunks = T.split(X, sub_len, dim=1)
    c_chunks = T.split(C, sub_len, dim=1)

    quantized_books = []
    for book in range(N_books):
        dist = squared_distances(x_chunks[book], c_chunks[book])
        weights = F.softmax(dist * (-tau_q), dim=-1)
        # (N, N_words) @ (N_words, L_word): weighted sum of this book's codewords.
        quantized_books.append(weights @ c_chunks[book])
    return T.cat(quantized_books, dim=1)

#Pairwise squared Euclidean distances between the rows of x and the rows of y.

def squared_distances(x, y=None):
    '''
    Pairwise squared Euclidean distances.

    Input: x is an N x d tensor
           y is an optional M x d tensor; if it is omitted (None), y = x is used.
    Output: N x M tensor where dist[i, j] = ||x[i, :] - y[j, :]||^2
    '''
    if y is None:
        y = x
    # Broadcast (N, 1, d) - (1, M, d) -> (N, M, d), then reduce over the feature axis.
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    return T.sum(diff * diff, -1)



#Basic residual building block: two 3x3 convolutions with batch norm plus a shortcut connection.

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut connection.

    The shortcut is the identity (an empty ``nn.Sequential``) unless the stride
    or channel count changes. In that case it is a strided 1x1 convolution
    followed by BatchNorm.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        out_planes = self.expansion * planes
        # Keep the registration order (conv1, bn1, conv2, bn2, shortcut). It fixes
        # the state_dict keys, the parameter order seen by optimizers, and the
        # order in which weight initialisation consumes the RNG.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != out_planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + self.shortcut(x))

#construct a baseline residual network

class ResNet_Baseline(nn.Module):
    """ResNet backbone: 3x3 stem, four residual stages (64/128/256/512 channels),
    pooling, and a 512-unit linear layer followed by ReLU.
    """
    def __init__(self, block, num_blocks):
        super(ResNet_Baseline, self).__init__()
        # Running channel count, updated by _make_layer as stages are built.
        self.in_planes = 64
        # Stem expects 3-channel input.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # num_blocks[i] is the number of blocks in stage i; stages 2-4 downsample by stride 2.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.fc_out = nn.Linear(512*block.expansion, 512)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses `stride`, the remaining blocks use stride 1.

        Side effect: updates self.in_planes to this stage's output channel count.
        """
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return ReLU(fc_out(pooled features)), shape (batch, 512)."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Kernel = feature-map width (size()[3]); this is global average pooling
        # only when height == width. NOTE(review): assumes square inputs.
        out = F.avg_pool2d(out, out.size()[3])
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc_out(out))
        return out

Overwriting utils.py


In [None]:
%%writefile main_SPQ.py
from torch.optim.lr_scheduler import CosineAnnealingLR
from utils import *
from Retrieval import *


#argument parser
def get_args_parser():
    """Build the SPQ argument parser (``add_help=False`` so it can be used as a parent parser).

    Returns:
        argparse.ArgumentParser defining the data, model, quantization and
        evaluation options used by ``train_SPQ`` and ``DoRetrieval``.
    """
    def str2bool(value):
        # argparse's `type=bool` treats every non-empty string (including "False")
        # as True, so parse boolean flags explicitly.
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError("Boolean value expected, got %r." % value)

    parser = argparse.ArgumentParser('SPQ', add_help=False)

    parser.add_argument('--gpu_id', default="0", type=str, help="""Define GPU id.""")
    parser.add_argument('--if_download', default=False, type=str2bool, help="""Whether to download the dataset or not.""")
    parser.add_argument('--data_dir', default="./data", type=str, help="""Path of the dataset to be installed.""")
    parser.add_argument('--batch_size', default=256, type=int, help="""Training mini-batch size.""")
    parser.add_argument('--num_workers', default=12, type=int, help="""Number of data loading workers per GPU.""")
    parser.add_argument('--input_size', default=32, type=int, help="""Input image size, default is set to CIFAR10.""")

    parser.add_argument('--N_books', default=8, type=int, help="""The number of the codebooks.""")
    parser.add_argument('--N_words', default=16, type=int, help="""The number of the codewords. It should be a power of two.""")
    parser.add_argument('--L_word', default=16, type=int, help="""Dimensionality of the codeword.""")
    parser.add_argument('--soft_quantization_scale', default=5.0, type=float, help="""Soft-quantization scaling parameter.""")
    parser.add_argument('--contrastive_temperature', default=0.5, type=float, help="""Contrastive learning Temperature scaling parameter.""")

    parser.add_argument('--num_cls', default=10, type=int, help="""The number of classes in the dataset for evaluation, default is set to CIFAR10""")
    parser.add_argument('--eval_epoch', default=100, type=int, help="""Compute mAP for Every N-th epoch.""")
    parser.add_argument('--output_dir', default=".", type=str, help="""Path to save logs and checkpoints.""")
    parser.add_argument('--Top_N', default=1000, type=int, help="""Top N number of images to be retrieved for evaluation.""")

    return parser


class CQCLoss(T.nn.Module):
    """Cross quantized contrastive loss (NT-Xent style) between features and soft codes.

    For two views a and b of a batch of size B, the 2B set [Xa; Zb] is scored with
    pairwise cosine similarity. Each row's positive is the partner at offset +/-B,
    and every other non-diagonal entry is a negative. The same is done for
    [Xb; Za]. Both cross-entropies (summed over rows, with the positive in column 0)
    are added and divided by 2B.
    """

    def __init__(self, device, batch_size, tau_cqc):
        super(CQCLoss, self).__init__()
        self.batch_size = batch_size
        self.tau_cqc = tau_cqc
        self.device = device
        self.COSSIM = T.nn.CosineSimilarity(dim=-1)
        self.CE = T.nn.CrossEntropyLoss(reduction="sum")
        self.get_corr_mask = self._get_correlated_mask().type(T.bool)

    def _get_correlated_mask(self):
        """Boolean (2B, 2B) mask: False on the diagonal and at the positive partner (offset +/-B), True elsewhere."""
        two_b = 2 * self.batch_size
        rows = T.arange(two_b)
        partners = (rows + self.batch_size) % two_b
        mask = T.ones(two_b, two_b, dtype=T.bool)
        mask[rows, rows] = False
        mask[rows, partners] = False
        return mask.to(self.device)

    def _scaled_logits(self, stacked):
        # Rows: [positive similarity, negatives...] divided by the temperature.
        two_b = 2 * self.batch_size
        sim = self.COSSIM(stacked.unsqueeze(1), stacked.unsqueeze(0))
        upper = T.diag(sim, self.batch_size)
        lower = T.diag(sim, -self.batch_size)
        positives = T.cat([upper, lower]).view(two_b, 1)
        negatives = sim[self.get_corr_mask].view(two_b, -1)
        return T.cat((positives, negatives), dim=1) / self.tau_cqc

    def forward(self, Xa, Xb, Za, Zb):
        logits_ab = self._scaled_logits(T.cat([Xa, Zb], dim=0))
        logits_ba = self._scaled_logits(T.cat([Xb, Za], dim=0))

        # The positive always sits in column 0.
        targets = T.zeros(2 * self.batch_size).to(self.device).long()

        total = self.CE(logits_ab, targets) + self.CE(logits_ba, targets)
        return total / (2 * self.batch_size)

def train_SPQ(args):
    """Train SPQ (ResNet-18 backbone + soft-quantization head) on CIFAR-10.

    Resumes from a checkpoint, then trains with the CQC contrastive loss summed
    over every ordered pair of `n_views` augmented views of each batch. The
    checkpoint path comes from ``args.resume`` if that attribute exists, otherwise
    the hard-coded ``'32_0.7247_checkpoint.pth'``; a falsy value skips resuming.
    Every ``args.eval_epoch`` epochs it runs ``DoRetrieval`` and saves a checkpoint
    to ``args.output_dir`` whenever the mAP improves.

    Args:
        args: namespace from ``get_args_parser`` (gpu_id, input_size, data_dir,
            batch_size, num_workers, N_books, N_words, L_word,
            soft_quantization_scale, contrastive_temperature, eval_epoch,
            output_dir, ...).
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id
    device = T.device('cuda')

    sz = args.input_size
    data_dir = args.data_dir
    batch_size = args.batch_size

    N_books = args.N_books
    N_words = args.N_words
    L_word = args.L_word
    tau_q = args.soft_quantization_scale
    tau_cqc = args.contrastive_temperature

    # Each codebook stores one index in [0, N_words), i.e. log2(N_words) bits.
    # (sqrt(N_words) only coincides with this for N_words == 16.)
    N_bits = int(N_books * np.log2(N_words))
    print('\033[91m' + '%d'%N_bits +  '-bit to retrieval' + '\033[0m')

    # Gaussian blur needs an odd kernel size. "| 1" bumps an even size up by one
    # (22 -> 23 for 224 px) and leaves odd sizes unchanged (3 for 32 px).
    blur_k = int(0.1 * sz) | 1

    #Define the data augmentation following the setup of SimCLR
    Augmentation = nn.Sequential(
        Kg.RandomResizedCrop(size=(sz, sz)),
        Kg.RandomHorizontalFlip(p=0.5),
        Kg.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8),
        Kg.RandomGrayscale(p=0.2),
        Kg.RandomGaussianBlur((blur_k, blur_k), (0.1, 2.0), p=0.5))

    transform = transforms.ToTensor()

    trainset = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=args.if_download, transform=transform)
    trainloader = T.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=args.num_workers)

    class Quantization_Head(nn.Module):
        def __init__(self, N_words, N_books, L_word, tau_q):
            super(Quantization_Head, self).__init__()
            self.fc = nn.Linear(512, N_books * L_word)
            nn.init.xavier_uniform_(self.fc.weight)

            # Codebooks
            self.C = T.nn.Parameter(Variable((T.randn(N_words, N_books * L_word)).type(T.float32), requires_grad=True))
            nn.init.xavier_uniform_(self.C)

            self.N_books = N_books
            self.L_word = L_word
            self.tau_q = tau_q

        def forward(self, input):
            X = self.fc(input)
            Z = Soft_Quantization(X, self.C, self.N_books, self.tau_q)
            return X, Z

    Q = Quantization_Head(N_words, N_books, L_word, tau_q)
    net = nn.Sequential(ResNet_Baseline(BasicBlock, [2, 2, 2, 2]), Q)

    net.cuda(device)
    resume_path = getattr(args, 'resume', '32_0.7247_checkpoint.pth')
    checkpoint = T.load(resume_path) if resume_path else None

    criterion = CQCLoss(device, batch_size, tau_cqc)

    optimizer = T.optim.Adam(net.parameters(), lr=3e-4, weight_decay=10e-6)
    scheduler = CosineAnnealingLR(optimizer, T_max=len(trainloader), eta_min=0, last_epoch=-1)

    if checkpoint is not None:
        net.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    MAX_mAP = 0.0
    mAP = 0.0
    n_views = 5        # augmented views per batch; the loss sums over all ordered view pairs
    store_loss = 0.0   # last logged loss; defined up front so checkpoint saving never sees an unbound name

    for epoch in range(5000):  # loop over the dataset multiple times
        # DoRetrieval may have switched the network to eval mode; BatchNorm must train here.
        net.train()

        print('Epoch: %d, Learning rate: %.4f' % (epoch, scheduler.get_last_lr()[0]))
        running_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            inputs = data[0].to(device)
            views = [Augmentation(inputs) for _ in range(n_views)]

            optimizer.zero_grad()

            encoded = [net(view) for view in views]
            Xlist = [X for X, _ in encoded]
            Zlist = [Z for _, Z in encoded]

            loss = 0
            for k in range(n_views):
                for l in range(n_views):
                    if k != l:
                        loss = loss + criterion(Xlist[k], Xlist[l], Zlist[k], Zlist[l])

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

            if (i+1) % 10 == 0:    # print every 10 mini-batches
                # NOTE: normalised by 10 batches * n_views, not by the n_views*(n_views-1) pair terms.
                store_loss = running_loss / (10*n_views)
                print('[%3d] loss: %.4f, mAP: %.4f, MAX mAP: %.4f' %
                    (i+1, store_loss, mAP, MAX_mAP))
                running_loss = 0.0

        if epoch >= 10:
            scheduler.step()

        if (epoch+1) % args.eval_epoch == 0:
            mAP = DoRetrieval(device, args, net, Q.C)
            if mAP > MAX_mAP:
                Result_path = os.path.join(args.output_dir, "%d_%.4f_checkpoint.pth"%(N_bits, mAP))
                T.save({
                  'epoch': epoch,
                  'model_state_dict': net.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'loss': store_loss,
                  }, Result_path)
                MAX_mAP = mAP

if __name__ == '__main__':
    # Wrap the shared option definitions in a parser that also provides -h/--help.
    cli = argparse.ArgumentParser('SPQ', parents=[get_args_parser()])
    train_SPQ(cli.parse_args())

Overwriting main_SPQ.py


In [None]:
!python main_SPQ.py --gpu_id=0 --batch_size=256 --N_books=8 --N_words=16 --eval_epoch=5 --if_download=True

[91m32-bit to retrieval[0m
Files already downloaded and verified
  cpuset_checked))
Epoch: 0, Learning rate: 0.0003
[ 10] loss: 38.3867, mAP: 0.0000, MAX mAP: 0.0000
[ 20] loss: 38.1224, mAP: 0.0000, MAX mAP: 0.0000
[ 30] loss: 37.8955, mAP: 0.0000, MAX mAP: 0.0000
[ 40] loss: 37.7422, mAP: 0.0000, MAX mAP: 0.0000
[ 50] loss: 37.7395, mAP: 0.0000, MAX mAP: 0.0000
[ 60] loss: 37.6646, mAP: 0.0000, MAX mAP: 0.0000
[ 70] loss: 37.6509, mAP: 0.0000, MAX mAP: 0.0000
[ 80] loss: 37.6385, mAP: 0.0000, MAX mAP: 0.0000
[ 90] loss: 37.6104, mAP: 0.0000, MAX mAP: 0.0000
[100] loss: 37.5648, mAP: 0.0000, MAX mAP: 0.0000
[110] loss: 37.5526, mAP: 0.0000, MAX mAP: 0.0000
[120] loss: 37.5766, mAP: 0.0000, MAX mAP: 0.0000
[130] loss: 37.5397, mAP: 0.0000, MAX mAP: 0.0000
[140] loss: 37.5865, mAP: 0.0000, MAX mAP: 0.0000
[150] loss: 37.5886, mAP: 0.0000, MAX mAP: 0.0000
[160] loss: 37.5284, mAP: 0.0000, MAX mAP: 0.0000
[170] loss: 37.6345, mAP: 0.0000, MAX mAP: 0.0000
[180] loss: 37.5035, mAP: 0.0000