In [1]:
import os, random, tarfile
import sys, argparse
import matplotlib.pyplot as plt
import h5py as h5

import numpy as np
from PIL import Image
from skimage import io, transform

from collections import OrderedDict

import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F

import torch.utils.data as data
from torchvision.datasets import ImageFolder
from torchvision.datasets import CIFAR10
from torchvision.datasets.utils import check_integrity, download_url
import torchvision.transforms as transforms
from torch.optim import lr_scheduler

import time

from scipy.io import loadmat # for loading .mat files
#from scipy.io import savemat # for saving .mat files

from sklearn.metrics.cluster import normalized_mutual_info_score#, homogeneity_score
import sklearn
from sklearn.cluster import KMeans

#######

import pymanopt
from pymanopt.manifolds import Product, Euclidean,  Grassmann
from pymanopt import Problem
from pymanopt.solvers import ConjugateGradient#, SteepestDescent

import math


In [2]:
class Args:
    meaningless_dummy_variable=999

In [3]:
dataset_name='CUB'#CUB,Cars,SOP
loss_name='gumlLoss'

# ### Uncomment the below for command-line arguments
# parser = argparse.ArgumentParser()
# LookupChoices = type('', (argparse.Action, ), dict(__call__ = lambda a, p, n, v, o: setattr(n, a.dest, a.choices[v])))
# parser.add_argument('--emb_dim', default = 128, type = int)
# parser.add_argument('--epochs', default = 5, type = int)
# parser.add_argument('--batch', default = 42, type = int)
# parser.add_argument('--lr', default = 0.01, type = float)
# parser.add_argument('--margin', default = 0.5, type = float)
# parser.add_argument('--step-size', default = 50, type = int)
# parser.add_argument('--gamma', default = 0.1, type = float)
# #parser.add_argument('--gpu-id', default='0', type=str)
# opts = parser.parse_args()


opts=Args()

# Loss term related hyperparameters
opts.margin = 0.5 # for triplet loss, and semi-hard mining
opts.alpha = 45 # for guml loss

# Embedding size
opts.emb_dim = 128

# Optimization related hyperparameters
opts.epochs = 20
opts.batch = 120

opts.lr = 0.01
opts.step_size = 50
opts.gamma = 0.1

print('Following parameters are used=> emb_dim: ',opts.emb_dim,' n_epochs: ',opts.epochs,' batch_sz: ',opts.batch)

emb_dim = opts.emb_dim

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#print(device)


Following parameters are used=> emb_dim:  128  n_epochs:  20  batch_sz:  120


In [4]:
def _check_integrity(img_folder, integrity_test_list):
    for fentry in (integrity_test_list):
        filename, md5 = fentry[0], fentry[1]
        fpath = img_folder + filename
        if not check_integrity(fpath, md5):
            return False
    return True

In [5]:
def pil_loader(path):
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')

In [6]:
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
                                 md5=None, remove_finished=False):
    download_root = os.path.expanduser(download_root)
    if extract_root is None:
        extract_root = download_root
    if not filename:
        filename = os.path.basename(url)

    download_url(url, download_root, filename, md5)

    archive = os.path.join(download_root, filename)
    print("Extracting {} to {}".format(archive, extract_root))
    
    with tarfile.open(archive, 'r:gz') as tar: tar.extractall(path=extract_root)


In [7]:
# =====================================================================================
class CUB2011():
    root = '../data/'
    url = 'http://www.vision.caltech.edu/visipedia-data/CUB-200-2011/CUB_200_2011.tgz'
    filename = 'CUB_200_2011.tgz'
    tgz_md5 = '97eceeb196236b17998738112f37df78'

    num_training_classes = 100
    name = 'CUB_200_2011'
    triplet_mode = False
    mined_data = None

    integrity_test_list = [
    ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'],
    ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'],
    ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'],
    ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe']
    ]


    def __init__(self, transform=None, download=False, train = True, **kwargs):
        if download and not _check_integrity(self.root+'/CUB_200_2011/images/', self.integrity_test_list):
            download_and_extract_archive(self.url, self.root, filename=self.filename, md5=self.tgz_md5)

        if not _check_integrity(self.root+'/CUB_200_2011/images/', self.integrity_test_list):
            raise RuntimeError('Dataset not found or corrupted. You can use download=True to download it')
        else:
            print('Dataset found, and is proper!') 

        self.transform = transform
        self.train = train

        self.classes = [x.split()[-1] for x in open(self.root+self.name+"/classes.txt", "r").readlines()]
        self.class_to_idx = {_class: i for i, _class in enumerate(self.classes)}
        self.classes = self.classes[:self.num_training_classes] if train else self.classes[self.num_training_classes:]

        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
        
        
        i = open(self.root+self.name+"/images.txt", "r").readlines() 
        # <image_id,image_file_name>: image_id (example number in dataset) #image_id:1,2,...N
        l = open(self.root+self.name+"/image_class_labels.txt", "r").readlines()
        # <image_id,class_label> #print('l:\n',l) #class_label:1,2,...C
        self.imgs = [(_[0].split()[1], int(_[1].split()[1])-1) for _ in zip(i,l)]
        # <image_file_name,class_label>: class_label:0,1,... due to -1, for python based indexing

        
        self.imgs = [(self.root+self.name+"/images/"+image_file_path, class_label_ind) for image_file_path, class_label_ind in self.imgs if ((class_label_ind-self.num_training_classes)<0) == self.train ]
        #<exact_train_image_path,class_label> // only for the training data        
        
        self.loader = pil_loader

    def __getitem__(self, index):

        # triplet (anchor, pos, neg) mode
        if self.triplet_mode:
            perm = random.sample(self.mined_data['anchors'].keys(), 1)[0] # pick a random anchor

            # sort negatives based on Euclidean distance to anchor
            q_x = self.embeddings[perm,:]
            n_x = self.embeddings[self.mined_data['negpool'][perm],:] # get the embeddings of the list of negatives for the selected anchor
            dists = ((q_x - n_x)**2).sum(axis=1) # broadcasting
            n_idx = np.argsort(dists) 

            rand_pos = np.random.randint(0, len(self.mined_data['pospool'][perm])) # random pos from the pool
            rand_neg = n_idx[np.random.randint(0, np.min((10, len(self.mined_data['negpool'][perm]))))]  # random pick from 10 Euclidean-NN in the pool

            a_path, a_target = self.imgs[int(self.mined_data['anchors'][perm])] # now select the actual anchor image
            p_path, p_target = self.imgs[int(self.mined_data['pospool'][perm][rand_pos])]
            n_path, n_target = self.imgs[int(self.mined_data['negpool'][perm][rand_neg])]

#             p_w = self.mined_data['posweight'][perm][rand_pos] #wts not needed to be given as input
#             n_w = self.mined_data['negweight'][perm][rand_neg]

            a_img, p_img, n_img = self.loader(a_path), self.loader(p_path), self.loader(n_path) 

            if self.transform is not None:
                a_img, p_img, n_img = self.transform(a_img), self.transform(p_img), self.transform(n_img)

            return a_img, p_img, n_img#, p_w, n_w

        # single image mode
        else:
            path, target = self.imgs[index] # if index is given, one get directly get the (image,label)-tuple
            img = self.loader(path)
            if self.transform is not None:
                img = self.transform(img)

            return img, target


    def __len__(self):
        return len(self.imgs)


In [8]:
class inception_v1_googlenet(nn.Sequential):
    output_size = 1024
    input_side = 227
    rescale = 255.0
    rgb_mean = [122.7717, 115.9465, 102.9801]
    rgb_std = [1, 1, 1]

    def __init__(self):
        super(inception_v1_googlenet, self).__init__(OrderedDict([
            ('conv1', nn.Sequential(OrderedDict([
            ('7x7_s2', nn.Conv2d(3, 64, (7, 7), (2, 2), (3, 3))),
            ('relu1', nn.ReLU(True)),
            ('pool1', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)),
            ('lrn1', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1))
            ]))),

            ('conv2', nn.Sequential(OrderedDict([
            ('3x3_reduce', nn.Conv2d(64, 64, (1, 1), (1, 1), (0, 0))),
            ('relu1', nn.ReLU(True)),
            ('3x3', nn.Conv2d(64, 192, (3, 3), (1, 1), (1, 1))),
            ('relu2', nn.ReLU(True)),
            ('lrn2', nn.CrossMapLRN2d(5, 0.0001, 0.75, 1)),
            ('pool2', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True))
            ]))),

            ('inception_3a', InceptionModule(192, 64, 96, 128, 16, 32, 32)),
            ('inception_3b', InceptionModule(256, 128, 128, 192, 32, 96, 64)),

            ('pool3', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)),

            ('inception_4a', InceptionModule(480, 192, 96, 208, 16, 48, 64)),
            ('inception_4b', InceptionModule(512, 160, 112, 224, 24, 64, 64)),
            ('inception_4c', InceptionModule(512, 128, 128, 256, 24, 64, 64)),
            ('inception_4d', InceptionModule(512, 112, 144, 288, 32, 64, 64)),
            ('inception_4e', InceptionModule(528, 256, 160, 320, 32, 128, 128)),

            ('pool4', nn.MaxPool2d((3, 3), (2, 2), ceil_mode = True)),

            ('inception_5a', InceptionModule(832, 256, 160, 320, 32, 128, 128)),
            ('inception_5b', InceptionModule(832, 384, 192, 384, 48, 128, 128)),

            ('pool5', nn.AvgPool2d((7, 7), (1, 1), ceil_mode = True)),

            #('drop5', nn.Dropout(0.4))
            ]))

class InceptionModule(nn.Module):
    def __init__(self, inplane, outplane_a1x1, outplane_b3x3_reduce, outplane_b3x3, outplane_c5x5_reduce, outplane_c5x5, outplane_pool_proj):
        super(InceptionModule, self).__init__()
        a = nn.Sequential(OrderedDict([
            ('1x1', nn.Conv2d(inplane, outplane_a1x1, (1, 1), (1, 1), (0, 0))),
            ('1x1_relu', nn.ReLU(True))
            ]))

        b = nn.Sequential(OrderedDict([
            ('3x3_reduce', nn.Conv2d(inplane, outplane_b3x3_reduce, (1, 1), (1, 1), (0, 0))),
            ('3x3_relu1', nn.ReLU(True)),
            ('3x3', nn.Conv2d(outplane_b3x3_reduce, outplane_b3x3, (3, 3), (1, 1), (1, 1))),
            ('3x3_relu2', nn.ReLU(True))
            ]))

        c = nn.Sequential(OrderedDict([
            ('5x5_reduce', nn.Conv2d(inplane, outplane_c5x5_reduce, (1, 1), (1, 1), (0, 0))),
            ('5x5_relu1', nn.ReLU(True)),
            ('5x5', nn.Conv2d(outplane_c5x5_reduce, outplane_c5x5, (5, 5), (1, 1), (2, 2))),
            ('5x5_relu2', nn.ReLU(True))
            ]))

        d = nn.Sequential(OrderedDict([
            ('pool_pool', nn.MaxPool2d((3, 3), (1, 1), (1, 1))),
            ('pool_proj', nn.Conv2d(inplane, outplane_pool_proj, (1, 1), (1, 1), (0, 0))),
            ('pool_relu', nn.ReLU(True))
            ]))

        for container in [a, b, c, d]:
            for name, module in container.named_children():
                self.add_module(name, module)

        self.branches = [a, b, c, d]

    def forward(self, input):
        return torch.cat([branch(input) for branch in self.branches], 1)

In [9]:
# def distance_vectors_pairwise(anchor, positive, negative , squared = True):
#     """Given batch of anchor descriptors and positive descriptors calculate distance matrix"""
#     eps = 1e-8

#     a_sq = torch.sum(anchor * anchor, dim=1)
#     p_sq = torch.sum(positive * positive, dim=1)
#     n_sq = torch.sum(negative * negative, dim=1)

#     d_a_p = a_sq + p_sq - 2*torch.sum(anchor * positive, dim = 1)
#     d_a_n = a_sq + n_sq - 2*torch.sum(anchor * negative, dim = 1)
#     d_p_n = p_sq + n_sq - 2*torch.sum(positive * negative, dim = 1)
    
#     if not squared:
#         d_a_p = torch.sqrt(d_a_p + eps)
#         d_a_n = torch.sqrt(d_a_n + eps)
#         d_p_n = torch.sqrt(d_p_n + eps)
        
#     return d_a_p, d_a_n, d_p_n



In [10]:
class CNNModel(nn.Module):
    def __init__(self, base_model, num_classes):#, embedding_size = 128, lr = 0.001):
        super(CNNModel, self).__init__()
        self.base_model = base_model
        self.num_classes = num_classes
        #self.embedder = nn.Linear(base_model.output_size, embedding_size)
        #self.lr = lr
        
    def forward(self, input):
        #return self.embedder(F.relu(self.base_model(input).view(len(input), -1)))
        return F.normalize(F.relu(self.base_model(input).view(len(input), -1)))
    
    #criterion = None
    


In [11]:
class TripletLoss(torch.nn.Module):
    """
    Triplet loss function.
    Based on: FaceNet: A unified embedding for face recognition and clustering

    """

    def __init__(self, margin=0.5):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, output3):
        anc_pos_distance = F.pairwise_distance(output1, output2, keepdim = True) # output1:(b,d),output2:(b,d)
        anc_neg_distance = F.pairwise_distance(output1, output3, keepdim = True) # output1:(b,d),output2:(b,d)
        
        #print(euclidean_distance.size()) # (b,1) pairwise distances tensor
#         temp=((1-label) * torch.pow(euclidean_distance, 2) +
#                                       (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        
        #print('temp:',temp.size()) ## (b,1) tensor
        loss_triplet = torch.mean(torch.pow(torch.clamp(anc_pos_distance+self.margin - anc_neg_distance, min=0.0), 2))
        #print('loss_triplet:',loss_triplet) # tensor of no size, but an item containing the scalar loss 

        return loss_triplet

############################################################################################################
# class WeightedTriplet(Model):
#     def forward(self, input):
#         return F.normalize(Model.forward(self, input))
    
#     def criterion(self, a_emb, p_emb, n_emb, margin = 1.0):
#     #def criterion(self, a_emb, p_emb, n_emb, p_w, n_w, margin = 1.0):
#         (d_a_p, d_a_n, _) = distance_vectors_pairwise(a_emb, p_emb, n_emb)
#         loss = torch.clamp(margin + d_a_p - d_a_n, min=0.0)
#         #loss = loss * p_w.float()# no need to give weights as i/p, can be computed
#         loss = torch.mean(loss)

#         return loss

In [12]:
class gumlLoss(torch.nn.Module):
    """
    GUML loss function.
    Based on: 

    """

    def __init__(self, alpha=45):
        super(gumlLoss, self).__init__()
        self.alpha = alpha

    def forward(self, embed1, embed2, embed3, R_tensor,L_tensor):

        xi_ancs=embed1.transpose(0,1)
        xi_poss=embed2.transpose(0,1)
        xi_negs=embed3.transpose(0,1) #(d,b) format, b:#triplets in mini-batch, i.e., batch size

        #print(xi_ancs.size(),xi_poss.size(),xi_negs.size())
        #print(xi_ancs.requires_grad,xi_poss.requires_grad,xi_negs.requires_grad)


        #print(R_tensor.size(),L_tensor.size())
        #print(R_tensor.requires_grad,L_tensor.requires_grad)

        xi_avgs=0.5*(xi_ancs+xi_poss) #dxT, T:#triplets
        #print(xi_avgs.size())
        #print(xi_avgs.requires_grad)

        num_triplets=xi_avgs.size(1)
        #print(num_triplets)


        RRT=R_tensor@(R_tensor.transpose(0,1))
        #print(RRT.size(),RRT.dtype)


        listA=xi_ancs.transpose(0,1)#.tolist()
        listB=xi_poss.transpose(0,1)#.tolist()



        #print(listA.dtype,type(listA))

        exps_plus=list(map(lambda elA, elB:
                   torch.exp(-(elA@RRT@elB)), 
                   listA.double(), listB.double()))
        #exps_plus=list(map(lambda elt: np.asscalar(elt), exps_plus))
        exps_plus=torch.tensor(exps_plus).reshape(1,len(exps_plus))

        #print(exps_plus,exps_plus.size())

        listA=xi_avgs.transpose(0,1)
        listB=xi_negs.transpose(0,1)

        exps_minus=list(map(lambda elA, elB:
                    torch.exp(-(elA@RRT@elB)),
                    listA.double(), listB.double()))
        #exps_minus=list(map(lambda elt: np.asscalar(elt), exps_minus))
        exps_minus=torch.tensor(exps_minus).reshape(1,len(exps_minus))

        w_is_plus = 1/(1+ exps_plus ) # weights
        w_is_minus = 1-1/(1+ exps_minus ) # weights

        w_is=0.5*(w_is_plus+w_is_minus) 
        #print(w_is,w_is.size())    

        LT=L_tensor.transpose(0,1)

        delZaZp=xi_ancs-xi_poss
        MhZaZp = LT @ delZaZp.double()
        d_aps=torch.sum(MhZaZp*MhZaZp,0)
        d_aps=d_aps.reshape(1,len(d_aps)) # 1xT
        #print(delZaZp.shape,MhZaZp.shape,d_aps.shape)

        delZnZm=xi_negs-xi_avgs
        MhZnZm = LT @ delZnZm.double()
        d_nms=torch.sum(MhZnZm*MhZnZm,0)
        d_nms=d_nms.reshape(1,len(d_nms)) # 1xT

        tan_sq_alpha=(math.tan(math.radians(self.alpha)))**2 # for angular 

        z_is=d_aps-4*tan_sq_alpha*d_nms # metric losses

        #print(z_is.shape)

        m_is=torch.log(1+torch.exp(z_is))
        
        w_is=w_is.to(device)
        m_is=m_is.to(device)

        f_is=-w_is*m_is
        #print(f_is.shape)

        loss_guml= (1/num_triplets)*torch.sum(torch.log(1+torch.exp(-f_is)))

        #print(loss_guml)
        return loss_guml

In [13]:
def get_dataset_embeddings(model, dataset, threads = 8): 
    embeddings_all, labels_all = [], []
    loader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=False, num_workers = threads)
    for batch_idx, batch in enumerate(loader):
        with torch.no_grad():
            images, labels = [torch.autograd.Variable(tensor.to(device)) for tensor in batch]
        embeddings_all.extend(model(images).data.cpu().numpy())
        labels_all.extend(labels.data.cpu().numpy())
    return np.asarray(embeddings_all), np.asarray(labels_all)


In [14]:
# Nearest Neighbor (NN) search and recall computation
def recall(embeddings, labels, K = 1): # embeddings: Nxd
    prod = torch.mm(embeddings, embeddings.t())
    norm = prod.diag().unsqueeze(1).expand_as(prod)
    D = norm + norm.t() - 2 * prod
    knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:]
    
#     print(type(embeddings),type(labels),type(prod),type(norm),type(D),type(knn_inds))
#     print(embeddings.size(),labels.size(),prod.size(),norm.size(),D.size(),knn_inds.size())
    
    return (labels.unsqueeze(-1).expand_as(knn_inds) == labels[knn_inds.contiguous().view(-1)].view_as(knn_inds)).max(1)[0].float().mean()


In [15]:
def eval_nmi(embedding, label,  normed_flag = False, fast_kmeans = False): # provide Nxd data, (N,) labels
    unique_id = np.unique(label)
    num_category = len(unique_id)
    if normed_flag:
        for i in range(embedding.shape[0]):
            embedding[i,:] = embedding[i,:]/np.sqrt(np.sum(embedding[i,:] ** 2)+1e-4)
    if fast_kmeans:
        kmeans = KMeans(n_clusters=num_category, n_init = 1)
    else:
        kmeans = KMeans(n_clusters=num_category)
    kmeans.fit(embedding)
    y_kmeans_pred = kmeans.predict(embedding)
    nmi = normalized_mutual_info_score(label, y_kmeans_pred)
    return nmi

In [16]:
# load triplet information from text files
def read_mined_data(anchors_fn, pos_fn, neg_fn):#, posw_fn, negw_fn):

    anchors, pos, neg = dict(), dict(), dict()
    #anchors, pos, neg, posw, negw = dict(), dict(), dict(), dict(), dict()
    with open(anchors_fn) as f:
        for idx,line in enumerate(f):
            anchors[idx] = int(line.strip())-1 #stripping to omit \n, and storing the anc id. -1 for python indexing

    with open(pos_fn) as posf, open(neg_fn) as negf:#, open(posw_fn) as poswf, open(negw_fn) as negwf:
        #for idx, (pos_line, neg_line, posw_line, negw_line) in enumerate(zip(posf, negf, poswf, negwf)):
        for idx, (pos_line, neg_line) in enumerate(zip(posf, negf)):
            pos[idx] = [x-1 for x in map(int,pos_line.strip().split(','))]
            neg[idx] = [x-1 for x in map(int,neg_line.strip().split(','))]
#             posw[idx] = [x for x in map(float,posw_line.strip().split(','))]
#             negw[idx] = [x for x in map(float,negw_line.strip().split(','))]

    return {'anchors':anchors,'pospool':pos,'negpool':neg}#,'posweight':posw,'negweight':negw}

In [17]:
class SemiHardMiner():
#     def __init__(self, opt):
#         self.par          = opt
#         self.name         = 'semihard'
#         self.margin       = vars(opt)['loss_'+opt.loss+'_margin']
    def __init__(self, opt):
        self.margin       = opt.margin

    def __call__(self, batch, labels, return_distances=False):
        if isinstance(labels, torch.Tensor): labels = labels.detach()#.numpy()
        bs = batch.size(0)
        #Return distance matrix for all elements in batch (BSxBS)
        distances = self.pdist(batch.detach()).detach().cpu().numpy()

        positives, negatives = [], []
        anchors = []
        for i in range(bs):
            l, d = labels[i], distances[i]
            neg = labels!=l; pos = labels==l

            if sum(pos)==1: # if there is only one example in the mini-batch from a class, drop that class
                continue
            
            anchors.append(i)
            pos[i] = 0
            p      = np.random.choice(np.where(pos)[0])
            positives.append(p)

            #Find negatives that violate tripet constraint semi-negatives
            neg_mask = np.logical_and(neg,d>d[p])
            neg_mask = np.logical_and(neg_mask,d<self.margin+d[p])
            if neg_mask.sum()>0:
                negatives.append(np.random.choice(np.where(neg_mask)[0]))
            else:
                negatives.append(np.random.choice(np.where(neg)[0]))

        sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)]

        if return_distances:
            return sampled_triplets, distances
        else:
            return sampled_triplets


    def pdist(self, A):
        prod = torch.mm(A, A.t())
        norm = prod.diag().unsqueeze(1).expand_as(prod)
        res = (norm + norm.t() - 2 * prod).clamp(min = 0)
        return res.clamp(min = 0).sqrt()

In [18]:
# main code starts here

main_root = '../'
data_dir=main_root+'data/'

### Keep the following lines commented, unless triplet ids are provided in text files

# # load training data picked by MoM
# mined_data = read_mined_data(data_dir+'anchors.txt',data_dir+'pos.txt',data_dir+'neg.txt')#,data_dir+'posw.txt',data_dir+'negw.txt')

# # print(mined_data.keys()) #dict_keys(['anchors', 'pospool', 'negpool', 'posweight', 'negweight'])
# # #anchors is a dictionary with a single value for a key. Rest are dictionaries with a list of values for a key.
# # for key in mined_data.keys():
# #     print(key,type(mined_data[key]))
# #     print(len(mined_data[key]),mined_data[key])


In [19]:
results_dir = main_root+'results/' 

log_filename = results_dir+dataset_name+'_'+loss_name+'_d_'+str(opts.emb_dim)+'_a_'+str(opts.alpha)+'_'+'log.txt'
print('logs saved in => ',log_filename)

log = open(log_filename, 'a')

# set random seeds
for set_random_seed in [random.seed, np.random.seed, torch.manual_seed]: set_random_seed(8)


logs saved in =>  ../results/CUB_gumlLoss_d_128_a_45_log.txt


In [20]:
# load base model
base_model = inception_v1_googlenet()
base_model_weights_path = os.path.join(data_dir+'inception_v1_googlenet.h5') # 'googlenet.h5', 'inception_v1_googlenet.h5'

temp=[]
f = h5.File(base_model_weights_path, "r")
gp_name=list(f.items())[0][0] #'data_0'
G1=f.get(gp_name) # G1: Group 1 // is a <HDF5 group "/data_0" (114 members)>
# G1_items=list(G1.items())

with open('inception_v1_params_names.txt') as fp: 
    for line in fp:
        #print(line.strip())
        param_name=line.strip()
        G1j=G1.get('/'+gp_name+'/'+param_name) 
        #G1j_items=list(G1j.items())
        param=np.array(G1j.get('data_0'))
        temp.append((param_name,param))
#print(type(temp),temp)

base_model.load_state_dict({k : torch.from_numpy(v) for k, v in temp})

print(base_model)

inception_v1_googlenet(
  (conv1): Sequential(
    (7x7_s2): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
    (relu1): ReLU(inplace=True)
    (pool1): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=0, dilation=1, ceil_mode=True)
    (lrn1): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
  )
  (conv2): Sequential(
    (3x3_reduce): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
    (relu1): ReLU(inplace=True)
    (3x3): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (relu2): ReLU(inplace=True)
    (lrn2): CrossMapLRN2d(5, alpha=0.0001, beta=0.75, k=1)
    (pool2): MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=0, dilation=1, ceil_mode=True)
  )
  (inception_3a): InceptionModule(
    (1x1): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1))
    (1x1_relu): ReLU(inplace=True)
    (3x3_reduce): Conv2d(192, 96, kernel_size=(1, 1), stride=(1, 1))
    (3x3_relu1): ReLU(inplace=True)
    (3x3): Conv2d(96, 128, kernel_size=(3, 3), strid

In [21]:
normalize = transforms.Compose([
 transforms.ToTensor(),
 transforms.Lambda(lambda x: x * base_model.rescale),
 transforms.Normalize(mean = base_model.rgb_mean, std = base_model.rgb_std),
 transforms.Lambda(lambda x: x[[2, 1, 0], ...])
])

dataset_train = CUB2011(train = True, transform = transforms.Compose([
 transforms.RandomResizedCrop(base_model.input_side),
 transforms.RandomHorizontalFlip(),
 normalize
]), download = False) # Train data: Classes 1-100

dataset_eval = CUB2011(train = False, transform = transforms.Compose([
 transforms.Resize(256),
 transforms.CenterCrop(base_model.input_side),
 normalize
]), download = False) # Test data: Classes 101-200

#print(dataset_train.num_training_classes,dataset_train.idx_to_class,dataset_train.imgs)
#print(dataset_eval.num_training_classes,dataset_eval.idx_to_class,dataset_eval.imgs)
#print(dataset_train.imgs)

#print(type(dataset_train.imgs[0]),dataset_train.imgs[0])

Dataset found, and is proper!
Dataset found, and is proper!


In [22]:
main_root = '../'
aas_labels_dir=main_root+'aas_labels_fgvc/'

if dataset_name=='CUB':
    aas_dataset='full_CUB_googlenet'
elif dataset_name=='Cars':
    aas_dataset='full_Cars_googlenet'
else:
    aas_dataset='SOP_googlenet'

op_filename=aas_labels_dir+'aas_labels_'+aas_dataset+'.mat'
annots = loadmat(op_filename)
#print(annots.keys())

pred_labels=annots['aas_labels']
#pred_labels=pred_labels.reshape(len(pred_labels),)

#print(len(pred_labels),len(np.unique(pred_labels)))

for i,pseudolabel in enumerate(pred_labels):
    #print(i,pseudolabel)
    dataset_train.imgs[i]=(dataset_train.imgs[i][0], int(pseudolabel))

# print(dataset_train.imgs)
# print(type(dataset_train.imgs[0]),dataset_train.imgs[0])

print('Pseudo-labels assignment done.')

Pseudo-labels assignment done.


In [23]:
# loss, optimizer, scheduler
#model = WeightedTriplet(base_model, dataset_train.num_training_classes, lr =opts.lr, embedding_size = emb_dim).to(device)

model=CNNModel(base_model, dataset_train.num_training_classes).to(device)

#criterion = TripletLoss(opts.margin)
criterion = gumlLoss(opts.alpha)

# if only those parameters need to be updated for which requires_grad=True
optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), weight_decay = 5e-4, lr = opts.lr, momentum = 0.9, dampening = 0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, **dict(step_size = opts.step_size, gamma = opts.gamma))

print('Model, optimizer and scheduler initialized.')

Model, optimizer and scheduler initialized.


In [24]:
# evaluate on test set for initial network

t1=time.time()
model.eval()
embeddings_all, labels_all = get_dataset_embeddings(model,dataset_eval)
print('Time elapsed for getting test embeddings:', time.time()-t1)


Time elapsed for getting test embeddings: 10.979093551635742


In [25]:
print(type(embeddings_all),embeddings_all.shape,type(labels_all),labels_all.shape)

<class 'numpy.ndarray'> (5924, 1024) <class 'numpy.ndarray'> (5924,)


In [34]:
if dataset_name=='SOP':
    k_arr=[1,10,100]
else:
    k_arr=[1,2,4,8]




In [27]:
# #model.eval() # set to true if embeddings are required
# dataset_train.triplet_mode = False

# loader_train = torch.utils.data.DataLoader(dataset_train, shuffle=True,
#                                            num_workers = 8, batch_size = opts.batch, drop_last = True)
# model.train()

# #print(loader_train)

# example_batch = next(iter(loader_train))

# # model.train() tells your model that you are training the model. 
# # So effectively layers like dropout, batchnorm etc. which behave different on the train 
# # and test procedures know what is going on and hence can behave accordingly.
# # You can call either model.eval() or model.train(mode=False) to tell that you are testing

# images, pseudolabels = example_batch

# # print(images.size(),pseudolabels.size(),pseudolabels)
# # img = images[10,:,:,:].numpy().transpose((1, 2, 0)) # visualize a random image in the batch
# # img = np.clip(img, 0, 1)
# # print('After transforms:')
# # plt.imshow(img)

# emb_ims=model(images.to(device))
# # print(emb_ims.size(), torch.sum(emb_ims*emb_ims, 1), pseudolabels)
# print(emb_ims.size())

# triplet_miner= SemiHardMiner(opts) 
# sampled_triplets = triplet_miner(emb_ims,pseudolabels) #contains indices of examples from a batch

# print('#triplets sampled:', len(sampled_triplets) )

In [28]:
# anc_ids=[]
# pos_ids=[]
# neg_ids=[]
# for trip in sampled_triplets:
#     anc_ids.append(trip[0])
#     pos_ids.append(trip[1])
#     neg_ids.append(trip[2])
# # print(sampled_triplets)
# # print(anc_ids)
# # print(pos_ids)
# # print(neg_ids)

# embed_a=emb_ims[anc_ids,:]
# embed_p=emb_ims[pos_ids,:]
# embed_n=emb_ims[neg_ids,:]

# #print(embed_a.size())

In [29]:
# margin=0.2

# batch,labels=emb_ims,pseudolabels
# if isinstance(labels, torch.Tensor): labels = labels.detach().numpy()
# bs = batch.size(0)

# print(bs,labels)
# #Return distance matrix for all elements in batch (BSxBS)
# distances = pdist(batch.detach()).detach().cpu().numpy()        
    
# print(distances.shape,distances) # bxb np array

# positives, negatives = [], []
# anchors = []
# for i in range(bs):
#     #print(i)
#     l, d = labels[i], distances[i]
#     neg = labels!=l; pos = labels==l
    
#     if sum(pos)==1: # if there is no more example in the mini-batch from the same class
#         continue
    
#     anchors.append(i)
#     pos[i] = 0
#     p      = np.random.choice(np.where(pos)[0])
#     positives.append(p)

#     #Find negatives that violate tripet constraint semi-negatives
#     neg_mask = np.logical_and(neg,d>d[p])
#     neg_mask = np.logical_and(neg_mask,d<margin+d[p])
#     if neg_mask.sum()>0:
#         negatives.append(np.random.choice(np.where(neg_mask)[0]))
#     else:
#         negatives.append(np.random.choice(np.where(neg)[0]))

# sampled_triplets = [[a, p, n] for a, p, n in zip(anchors, positives, negatives)]

# print(len(sampled_triplets))
# for trip in sampled_triplets:
#     #print(trip)
#     print(labels[trip[0]],labels[trip[1]],labels[trip[2]])

In [30]:
orig_dim=base_model.output_size
latent_dim=opts.emb_dim

manifold = Product([Euclidean(orig_dim, latent_dim), Grassmann(orig_dim, latent_dim)]) #list or tuples
#manifold = Product([Grassmann(orig_dim, latent_dim)]) #list or tuples

In [31]:
@pymanopt.function.PyTorch
def cost(R_tensor,L_tensor):

    xi_avgs=0.5*(xi_ancs+xi_poss) #dxT, T:#triplets
    #print('xi_avgs.size',xi_avgs.size())
    #print('xi_avgs.requires_grad',xi_avgs.requires_grad)

    num_triplets=xi_avgs.size(1)
    #print(num_triplets)

    
    R_tensor=R_tensor.to(device)
    RRT=R_tensor@(R_tensor.transpose(0,1))
    #print(RRT.size(),RRT.dtype)
    #print('type_R_tensor',type(R_tensor))


    listA=xi_ancs.transpose(0,1)#.tolist()
    listB=xi_poss.transpose(0,1)#.tolist()



    #print(listA.dtype,type(listA))

    exps_plus=list(map(lambda elA, elB:
           torch.exp(-(elA@RRT@elB)), 
           listA.double(), listB.double()))
    #exps_plus=list(map(lambda elt: np.asscalar(elt), exps_plus))
    exps_plus=torch.tensor(exps_plus).reshape(1,len(exps_plus))

    #print(exps_plus,exps_plus.size())

    listA=xi_avgs.transpose(0,1)
    listB=xi_negs.transpose(0,1)

    exps_minus=list(map(lambda elA, elB:
            torch.exp(-(elA@RRT@elB)),
            listA.double(), listB.double()))
    #exps_minus=list(map(lambda elt: np.asscalar(elt), exps_minus))
    exps_minus=torch.tensor(exps_minus).reshape(1,len(exps_minus))

    w_is_plus = 1/(1+ exps_plus ) # weights
    w_is_minus = 1-1/(1+ exps_minus ) # weights

    w_is=0.5*(w_is_plus+w_is_minus) 
    #print(w_is,w_is.size())    

    L_tensor=L_tensor.to(device)
    LT=L_tensor.transpose(0,1)

    delZaZp=xi_ancs-xi_poss
    MhZaZp = LT @ delZaZp.double()
    d_aps=torch.sum(MhZaZp*MhZaZp,0)
    d_aps=d_aps.reshape(1,len(d_aps)) # 1xT
    #print(delZaZp.shape,MhZaZp.shape,d_aps.shape)

    delZnZm=xi_negs-xi_avgs
    MhZnZm = LT @ delZnZm.double()
    d_nms=torch.sum(MhZnZm*MhZnZm,0)
    d_nms=d_nms.reshape(1,len(d_nms)) # 1xT

    tan_sq_alpha=(math.tan(math.radians(alpha)))**2 # for angular 

    z_is=d_aps-4*tan_sq_alpha*d_nms # metric losses

    #print(z_is.shape)

    m_is=torch.log(1+torch.exp(z_is))
    w_is=w_is.to(device)
    m_is=m_is.to(device)

    f_is=-w_is*m_is
    #print(f_is.shape)

    loss_guml= (1/num_triplets)*torch.sum(torch.log(1+torch.exp(-f_is)))

    #print(loss_guml)
    return loss_guml.cpu()

In [32]:
alpha=opts.alpha
problem = Problem(manifold=manifold, cost=cost, verbosity=0)#, egrad=egrad) # verbosity=0 for no o/p, 2 for most
# (3) Instantiate a Pymanopt solver
solver = ConjugateGradient( maxiter=5 , logverbosity=2) # logverbosity controls how much info is stored.

In [33]:
for epoch in range(opts.epochs):
    
    t1=time.time()
    
    dataset_train.triplet_mode = False
    loader_train = torch.utils.data.DataLoader(dataset_train, shuffle=True,
                                           num_workers = 8, batch_size = opts.batch, drop_last = True)
    model.train()
    # batch train
    scheduler.step()
    loss_all = []
    for batch_idx, batch in enumerate(loader_train):
    #for batch_idx, batch in enumerate(loader_train if model.criterion is not None else []):
        #a_images, p_images, n_images, p_w, n_w  = [torch.autograd.Variable(tensor.to(device)) for tensor in batch]
        #images, labels  = [torch.autograd.Variable(tensor.to(device)) for tensor in batch]
        images, labels  = batch
        
        # // Obtain embeddings for images, and perform semi-hard mining on the batch
        emb_ims=model(images.to(device))
        triplet_miner= SemiHardMiner(opts) 
        sampled_triplets = triplet_miner(emb_ims,labels) #contains indices of examples from a batch
        #print('#triplets sampled:', len(sampled_triplets) )
        
        if len(sampled_triplets)<1: # to ensure that the code does not break in case of no triplets
            continue # proceed to the next mini-batch
        
        #embed1, embed2, embed3 : anc, pos, neg tensors of shape (T,d). T:#triplets, d:embedding size
        anc_ids, pos_ids, neg_ids = [], [], []
        for trip in sampled_triplets:
            anc_ids.append(trip[0])
            pos_ids.append(trip[1])
            neg_ids.append(trip[2])

        embed1=emb_ims[anc_ids,:].to(device) # anchor tensors (T,d): T: #triplets, d:orig_dim
        embed2=emb_ims[pos_ids,:].to(device) # positive tensors
        embed3=emb_ims[neg_ids,:].to(device) # negative tensors
    
        ## Above embedding tensors are already l2 normalized
        
        xi_ancs=embed1.transpose(0,1).detach() # dxb tensors, b=T i.e. #triplets in mini-batch
        xi_poss=embed2.transpose(0,1).detach() # dxb
        xi_negs=embed3.transpose(0,1).detach() # dxb        
        
        ##>>>>>>>>>>>>>>>> Fix the remaining network, and learn (R,L) <<<<<<<<<<<<<<<<#
        ##>>>>>>>>>>>>>>>> learn (R,L) using Riemannian optimization (eg, RCGD) <<<<<<<<<<<<<<<<#
        if (epoch==0 and batch_idx==0):#for the very first mini-batch
            # let Pymanopt do the rest
            xopt, optlog = solver.solve(problem) # start with a random point on manifold
            R=xopt[0]
            L=xopt[1] #dxl
            #proj_matrix=L.T
        else:
            # let Pymanopt do the rest
            xopt,optlog = solver.solve(problem, x=(R_old,L_old))
            R=xopt[0]
            L=xopt[1]
            #proj_matrix=L.T
            #print('L_type_shape',type(L),L.shape)

        ## for the next mini-batch, start Riemannian optimization from the points where you leave now.   
        R_old=R
        L_old=L
            
        ##print(type(optlog['iterations']['f(x)']))        
        ## img = torch.from_numpy(img).float().to(device) # transform numpy arrays to PyTorch tensors
        
        ##>>>>>>>>>>>>>>>> Compute Loss with forward pass using current estimates <<<<<<<<<<<<<<<<#
        ## requires_grad=False by default. Therefore, won't get affected in backward pass.
        R_tensor=torch.from_numpy(R).to(device)
        L_tensor=torch.from_numpy(L).to(device)
        loss=criterion(embed1, embed2, embed3, R_tensor,L_tensor)

        ##>>>>>>>>>>>>>>>> Backpropagation <<<<<<<<<<<<<<<<#
        ## Parameters learned using Riemannian optimization won't get affected
        ## Rest of the network parameters will be updated (for eg, the last linear layer prior to L)
#         print('P Before backward:', net.fc1[2].weight.data) # uncomment to see changes in updation
        ## above remains same as obtained from previous mini-batch
#         print('L Before backward:', L.T) # uncomment to see changes in updation
        ## above gets changed as obtained from previous mini-batch, because of Riemmannian optimization
        ## for the current mini-batch
        
        loss_all.append(loss.data.item())
        optimizer.zero_grad()
        
        loss.backward()
        
        ##>>>>>>>>>>>>>>>> Update parameters using learned gradients <<<<<<<<<<<<<<<<#
        optimizer.step()


    print('loss epoch {}: {:.04f}'.format(epoch, np.mean(loss_all)))
    log.write('loss epoch {}: {:.04f}\n'.format(epoch, np.mean(loss_all)))
    
    print('Time elapsed for training the epoch:', time.time()-t1)
    
    
    # evaluate on test set 
    if epoch < 10 or (epoch + 1) % 5 == 0 or (epoch + 1) == opts.epochs:
        model.eval()
        embeddings_all, labels_all = get_dataset_embeddings(model,dataset_eval)
        
        
        rec = [recall(torch.Tensor(embeddings_all), torch.Tensor(list(labels_all)), x).item() for x in k_arr]

        nmi = eval_nmi(torch.Tensor(embeddings_all), torch.Tensor(labels_all.reshape(len(labels_all),))) # provide Nxd data, (N,) labels

        print('NMI recall@1,2,4,8 epoch {}: {:.06f} {:.06f} {:.06f} {:.06f} {:.06f}'.format(epoch, nmi*1e2, rec[0]*1e2, rec[1]*1e2, rec[2]*1e2, rec[3]*1e2))
        log.write('NMI recall@1,2,4,8 epoch {}: {:.06f} {:.06f} {:.06f} {:.06f} {:.06f}'.format(epoch, nmi*1e2, rec[0]*1e2, rec[1]*1e2, rec[2]*1e2, rec[3]*1e2))

# best_epoch=epoch
# torch.save({'epoch': best_epoch + 1, 'state_dict': model.state_dict()},'{}/checkpoint_{}_{}.pth'.format(data_dir, best_epoch, dataset_name))        



loss epoch 0: 0.7077
Time elapsed for training the epoch: 117.07206726074219
NMI recall@1,2,4,8 epoch 0: 55.547269 46.758947 59.503716 72.062796 82.849425
loss epoch 1: 0.7056
Time elapsed for training the epoch: 115.61474084854126
NMI recall@1,2,4,8 epoch 1: 55.141807 46.758947 59.773803 72.299123 82.950711
loss epoch 2: 0.7045
Time elapsed for training the epoch: 115.02356719970703
NMI recall@1,2,4,8 epoch 2: 55.633242 47.029033 59.773803 72.535449 82.815665
loss epoch 3: 0.7040
Time elapsed for training the epoch: 114.21500301361084
NMI recall@1,2,4,8 epoch 3: 55.978688 47.079676 59.925723 71.995276 82.815665
loss epoch 4: 0.7033
Time elapsed for training the epoch: 115.88131785392761
NMI recall@1,2,4,8 epoch 4: 55.590070 47.079676 59.740043 72.079676 82.849425
loss epoch 5: 0.7027
Time elapsed for training the epoch: 114.33155798912048
NMI recall@1,2,4,8 epoch 5: 55.762441 46.944633 60.077649 72.045916 82.461172
loss epoch 6: 0.7024
Time elapsed for training the epoch: 116.77263760