In [1]:
import torchvision as tv
import phototour
import torch
from tqdm import tqdm 
import numpy as np
import torch.nn as nn
import math 

lib_train = phototour.PhotoTour('.','liberty', download=True, train=True, mode = 'triplets', augment = True, nsamples=409600)
yos_train = phototour.PhotoTour('.','yosemite', download=True, train=True, mode = 'triplets', augment = True)
nd_train = phototour.PhotoTour('.','notredame', download=True, train=True, mode = 'triplets', augment = True)

eval_db = phototour.PhotoTour('.','yosemite', download=True, train=False)
# train_db = torch.utils.data.ConcatDataset((lib_train, yos_train))
train_db = nd_train
train_name = 'notredame'

# Found cached data ./liberty.pt
# Found cached data ./yosemite.pt
# Found cached data ./notredame.pt
# Found cached data ./yosemite.pt


In [2]:
import tfeat_model
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

tfeat = tfeat_model.TNet()
tfeat = tfeat.cuda()

# this kind of works
optimizer = optim.SGD(tfeat.parameters(), lr=0.1, weight_decay=0.0001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.8)

seed=42
torch.manual_seed(seed)
np.random.seed(seed)
# cv2.setRNGSeed(seed)

train_loader = torch.utils.data.DataLoader(train_db,
                                             batch_size=300, shuffle=False,
                                             num_workers=30)

eval_loader = torch.utils.data.DataLoader(eval_db,
                                             batch_size=1024, shuffle=False,
                                             num_workers=32)

In [None]:
fpr_per_epoch = []

for e in range(300):
    tfeat.train()
    for batch_idx, (data_a, data_p, data_n) in tqdm(enumerate(train_loader)):
        data_a = data_a.unsqueeze(1).float().cuda()
        data_p = data_p.unsqueeze(1).float().cuda()
        data_n = data_n.unsqueeze(1).float().cuda()
        out_a, out_p, out_n = tfeat(data_a), tfeat(data_p), tfeat(data_n)
        loss = F.triplet_margin_loss(out_a, out_p, out_n, margin=2, swap=True) 
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
            
    tfeat.eval()
    l = np.empty((0,))
    d = np.empty((0,))
    #evaluate the network after each epoch
    for batch_idx, (data_l, data_r, lbls) in enumerate(eval_loader):
        data_l = data_l.unsqueeze(1).float().cuda()
        data_r = data_r.unsqueeze(1).float().cuda()
        out_l, out_r = tfeat(data_l), tfeat(data_r)
        dists = torch.norm(out_l - out_r, 2, 1).detach().cpu().numpy()
        l = np.hstack((l,lbls.numpy()))
        d = np.hstack((d,dists))
        
    # FPR95 code from Yurun Tian
    d = torch.from_numpy(d)
    l = torch.from_numpy(l)
    dist_pos = d[l==1]
    dist_neg = d[l!=1]
    dist_pos,indice = torch.sort(dist_pos)
    loc_thr = int(np.ceil(dist_pos.numel() * 0.95))
    thr = dist_pos[loc_thr]
    fpr95 = float(dist_neg.le(thr).sum())/dist_neg.numel()
    print(e,fpr95)
    fpr_per_epoch.append([e,fpr95])
    scheduler.step()
    np.savetxt('fpr.txt', np.array(fpr_per_epoch), delimiter=',') 
        


3334it [06:21,  8.75it/s]


0 0.1657


3334it [06:25,  8.65it/s]


1 0.12372


3334it [06:20,  8.76it/s]


2 0.1015


3334it [06:26,  8.64it/s]


3 0.1019


3334it [06:23,  8.69it/s]


4 0.10572


3334it [06:22,  8.72it/s]


5 0.09724


3334it [06:21,  8.74it/s]


6 0.10236


3334it [06:21,  8.73it/s]


7 0.08238


3334it [06:19,  8.78it/s]


8 0.08986


3334it [06:21,  8.75it/s]


9 0.09476


3334it [06:20,  8.77it/s]


10 0.08732


3334it [06:20,  8.77it/s]


11 0.07934


3334it [06:20,  8.77it/s]


12 0.08422


3334it [06:20,  8.76it/s]


13 0.07902


3334it [06:19,  8.78it/s]


14 0.0802


3334it [06:23,  8.69it/s]


15 0.07846


3334it [06:19,  8.79it/s]


16 0.07706


3334it [06:18,  8.80it/s]


17 0.0855


3334it [06:20,  8.77it/s]


18 0.07988


3334it [06:21,  8.73it/s]


19 0.079


3334it [06:22,  8.71it/s]


20 0.08532


3334it [06:21,  8.73it/s]


21 0.07216


3334it [06:23,  8.70it/s]


22 0.08076


3334it [06:22,  8.71it/s]


23 0.07728


3334it [06:21,  8.75it/s]


24 0.07908


3334it [06:22,  8.72it/s]


25 0.07762


3334it [06:21,  8.73it/s]


26 0.07134


3334it [06:34,  8.44it/s]


27 0.07572


3334it [06:21,  8.74it/s]


28 0.07106


3334it [06:23,  8.69it/s]


29 0.07926


3334it [06:20,  8.76it/s]


30 0.07034


3334it [06:42,  8.29it/s]


31 0.07322


3334it [06:20,  8.76it/s]


32 0.07258


3334it [06:32,  8.49it/s]


33 0.07342


3334it [06:24,  8.67it/s]


34 0.07228


3334it [06:21,  8.74it/s]


35 0.07366


3334it [06:20,  8.76it/s]


36 0.07084


3334it [06:19,  8.78it/s]


37 0.07102


3334it [06:22,  8.71it/s]


38 0.0727


3334it [06:23,  8.69it/s]


39 0.06934


3334it [06:19,  8.78it/s]


40 0.07024


3334it [06:18,  8.81it/s]


41 0.07024


3334it [06:24,  8.67it/s]


42 0.0699


3334it [06:19,  8.78it/s]


43 0.07


3334it [06:23,  8.70it/s]


44 0.06696


3334it [06:21,  8.73it/s]


45 0.07304


3334it [06:20,  8.76it/s]


46 0.0667


1919it [03:43,  8.60it/s]

In [None]:
torch.save(tfeat.state_dict(), train_name+'-tfeat.params')