In [1]:
from __future__ import division, print_function

import argparse
import copy
from copy import deepcopy
import datetime
import logging
import math
import os
import random
import sys
import time
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from eval_metrics import ErrorRateAt95Recall
from models import HardNet, ResNet, SOSNet32x32
from dataset import TripletPhotoTour, BrownTest
from losses import loss_HardNet, loss_L2Net
from utils import cv2_scale, np_reshape, read_yaml

# Configure the root logger at INFO level so the logging.info(...) progress
# messages issued by the cells below are emitted.
logging.basicConfig(level=logging.INFO)

In [2]:
# Log a banner marking the start of the run.
logging.info('\n\n================ IMAGE MATCHING CHALLENGE 2020 ==================\n\n')
# Load the experiment settings from configs.yml. The cells below index the
# result like a nested mapping (e.g. configs['dataset']['mean']).
configs = read_yaml('configs.yml')
# Disabled alternative that derived the output directory from the experiment name:
# models_output = os.path.join('models', configs['experiment_name'])
# Make sure the checkpoint directory exists before training writes into it;
# exist_ok=True makes re-running this cell harmless.
os.makedirs(configs['model_dir'], exist_ok=True)

INFO:root:





In [3]:
# Reproducibility setup: select the GPU (when CUDA is enabled) and seed every
# random number generator used by this notebook with one shared value.
SEED = 0

if configs['use_cuda']:
    # Restrict this process to the configured GPU. This only takes effect if
    # it runs before the first CUDA call of the process.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(configs['gpu_id'])
    torch.cuda.manual_seed_all(SEED)

# Ask cuDNN for deterministic kernels and keep its autotuner off: benchmark
# mode may pick different algorithms from run to run, which would defeat the
# seeding below.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
def _save_checkpoint(model, epoch, batch_id=None):
    """Save the model weights to a timestamped file inside configs['model_dir'].

    The file name is ``checkpoint_<timestamp>_<epoch>.pth``, with
    ``_<batch_id>`` inserted before the extension when ``batch_id`` is given.
    The stored dict holds ``epoch + 1`` under 'epoch' and the model's
    ``state_dict()`` under 'state_dict'. The path is logged and returned.
    """
    timestamp = datetime.datetime.now().strftime("%y-%m-%d_%H:%M:%S")
    suffix = f'{epoch}' if batch_id is None else f'{epoch}_{batch_id}'
    model_checkpoint = os.path.join(configs['model_dir'], f'checkpoint_{timestamp}_{suffix}.pth')
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()}, model_checkpoint)
    logging.info(model_checkpoint)
    return model_checkpoint


def train(train_loader, model, optimizer, epoch):
    """Run one training epoch with the HardNet loss.

    Args:
        train_loader: iterable with a length, yielding ``(data_a, data_p)`` batches.
        model: network applied separately to both halves of each batch.
        optimizer: optimizer stepped once per processed batch.
        epoch: epoch index; only used for checkpoint file names and metadata.

    Side effects:
        Logs the loss every 10 batches, writes a checkpoint every 100 batches
        and another one at the end of the epoch.
    """
    model.train()
    n_batches = len(train_loader)
    loss = None  # stays None when no batch gets processed (n_batches <= 1)

    # Wrapping the loader (rather than the enumerate object) lets tqdm see the
    # loader's length and display a total.
    for batch_id, data in enumerate(tqdm(train_loader)):
        # The final batch is always skipped.
        # NOTE(review): presumably because it can be smaller than batch_size - confirm.
        if batch_id + 1 == n_batches:
            continue
        data_a, data_p = data

        if configs['use_cuda']:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        out_a = model(data_a)
        out_p = model(data_p)

        loss = loss_HardNet(out_a, out_p)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_id % 10 == 0:
            logging.info(f'{batch_id}/{len(train_loader)} - Loss: {loss.item()}')
        if batch_id % 100 == 0:
            _save_checkpoint(model, epoch, batch_id)

    # Only report a final loss when at least one batch was actually trained on;
    # otherwise there is no loss value to report.
    if loss is not None:
        logging.info(f'{len(train_loader)}/{len(train_loader)} - Loss: {loss.item()}')

    _save_checkpoint(model, epoch)


def test(test_loader, model, epoch):
    """Evaluate the model on a pair dataset and log its FPR95 score.

    Each batch of ``test_loader`` is unpacked as ``(data_a, data_p, label)``.
    The Euclidean distance between the two model outputs is collected for
    every pair, and the inverse distance is passed as the score to
    ``ErrorRateAt95Recall``. ``epoch`` is accepted but not used. Nothing is
    returned; the metric is only logged.
    """
    model.eval()

    label_chunks = []
    distance_chunks = []
    for batch in tqdm(test_loader):
        data_a, data_p, label = batch

        if configs['use_cuda']:
            data_a = data_a.cuda()
            data_p = data_p.cuda()

        with torch.no_grad():
            out_a = model(data_a)
            out_p = model(data_p)

        # Row-wise L2 distance between the two outputs.
        squared_diff = (out_a - out_p) ** 2
        dists = torch.sqrt(torch.sum(squared_diff, 1))

        distance_chunks.append(dists.data.cpu().numpy().reshape(-1, 1))
        label_chunks.append(label.data.cpu().numpy().reshape(-1, 1))

    # Flatten the per-batch columns into one vector per quantity; the reshape
    # also checks that exactly len(dataset) pairs were seen.
    num_tests = len(test_loader.dataset)
    labels = np.vstack(label_chunks).reshape(num_tests)
    distances = np.vstack(distance_chunks).reshape(num_tests)

    # Smaller distance means a higher score; the epsilon avoids division by zero.
    scores = 1.0 / (distances + 1e-8)
    fpr95 = ErrorRateAt95Recall(labels, scores)
    logging.info('\33[91mTest set: Accuracy(FPR95): {:.8f}\n\33[0m'.format(fpr95))


def create_transform():
    """Build the preprocessing pipeline shared by the train and test datasets.

    The pipeline converts the input to a tensor and then normalises it with
    the single-channel mean/std read from ``configs['dataset']``.
    """
    dataset_cfg = configs['dataset']
    # The cv2_scale / np_reshape Lambda steps that used to precede ToTensor
    # are currently disabled.
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((dataset_cfg['mean'],), (dataset_cfg['std'],)),
    ]
    return transforms.Compose(steps)

def create_trainloader(root: str, train_scenes: list):
    """Return a shuffled DataLoader over a TripletPhotoTour training set.

    ``root`` and ``train_scenes`` are forwarded to ``TripletPhotoTour``;
    batch size and worker count are read from ``configs``.
    """
    train_set = TripletPhotoTour(
        root=root,
        transform=create_transform(),
        train_scenes=train_scenes,
    )
    loader_options = dict(
        batch_size=configs['batch_size'],
        shuffle=True,
        num_workers=configs['n_workers'],
    )
    return DataLoader(train_set, **loader_options)

def create_testloader(root: str, test_scene: str, is_challenge_data: bool):
    """Return an unshuffled DataLoader for one evaluation scene.

    When ``is_challenge_data`` is truthy the scene is loaded through
    ``TripletPhotoTour`` with ``train=False``; otherwise through ``BrownTest``.
    Batch size and worker count are read from ``configs``.
    """
    if not is_challenge_data:
        eval_set = BrownTest(root=root, scene=test_scene, transform=create_transform())
    else:
        eval_set = TripletPhotoTour(
            root=root,
            transform=create_transform(),
            test_scene=test_scene,
            train=False,
        )

    loader_options = dict(
        batch_size=configs['batch_size'],
        shuffle=False,
        num_workers=configs['n_workers'],
    )
    return DataLoader(eval_set, **loader_options)

In [5]:
# Training loader over the scenes listed in configs['dataset']['train_scenes'],
# read from the challenge data root.
train_dataloader = create_trainloader(root=configs['dataset']['challenge_root'],
                                      train_scenes=configs['dataset']['train_scenes'])

INFO:root:Load brandenburg_gate dataset with n_points: 62398
INFO:root:num_3d_points: 62398
INFO:root:generate 62398 triplets (simple add a negative patch to a matched patch)
100%|██████████| 62398/62398 [00:07<00:00, 8893.91it/s] 


In [6]:
# Build one evaluation loader per challenge test scene and keep each loader
# together with its scene name as a (scene, dataloader) pair.
test_dataloaders = []
for scene in configs['dataset']['test_scenes']['challenge']:
    logging.info(f'load test set: {scene}')
    dataloader = create_testloader(root=configs['dataset']['challenge_root'],
                                   test_scene=scene,
                                   is_challenge_data=True)
    test_dataloaders.append((scene, dataloader))

INFO:root:load test set: sacre_coeur
INFO:root:load test set: st_peters_square
INFO:root:load test set: reichstag


In [7]:
# Toy 3x3 inputs (one row per sample) for trying the loss helpers by hand.
# Each literal mixes in one float so the tensors are floating point, which
# requires_grad=True needs.
anchor = torch.tensor([
    [1, 3, 1],
    [2.0, 1, 3],
    [1, -1, -5]], requires_grad=True)
positive = torch.tensor([
    [-1, 5, 2],
    [-1.0, 0, 1],
    [2, 2, 3]], requires_grad=True)
# Disabled debug print. NOTE(review): `s` and `t` are not defined in any
# visible cell, and distance_matrix_vector is only defined in a later cell.
# print(distance_matrix_vector(s, t) * distance_matrix_vector(s, t))

In [8]:
# loss_sosnet(anchor, positive)

In [9]:
def distance_matrix_vector(anchor: torch.Tensor, positive: torch.Tensor, is_the_same=False) -> torch.Tensor:
    """Pairwise Euclidean distances between the rows of two 2-D tensors.

    Args:
        anchor: tensor of shape (n, dim).
        positive: tensor of shape (m, dim).
        is_the_same: when True, 1 is added to the diagonal of the *squared*
            distance matrix before the square root is taken. The loss code
            uses this when comparing a batch with itself, so the diagonal
            ends up near 1 instead of near 0.

    Returns:
        Tensor of shape (n, m) where entry (i, j) is
        ``sqrt(||anchor[i] - positive[j]||^2 + 1e-6)``.
    """
    eps = 1e-6

    # ||a - p||^2 = ||a||^2 + ||p||^2 - 2 * a.p, evaluated with broadcasting.
    anchor_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(1)        # (n, 1)
    positive_sq = torch.sum(positive * positive, dim=1).unsqueeze(0)  # (1, m)
    sq_dist = anchor_sq + positive_sq - 2.0 * torch.mm(anchor, torch.t(positive))

    # Floating-point cancellation can leave tiny negative values here; clamp
    # them so the square root below can never produce NaN.
    sq_dist = torch.clamp(sq_dist, min=0.0)

    if is_the_same:
        # The identity is a constant: build it on the inputs' device/dtype and
        # keep it out of the autograd graph.
        eye = torch.eye(sq_dist.size(0), sq_dist.size(1),
                        dtype=sq_dist.dtype, device=sq_dist.device)
        sq_dist = sq_dist + eye

    # eps keeps the gradient of sqrt finite where the squared distance is 0.
    return torch.sqrt(sq_dist + eps)

def get_distance_matrix_without_min_on_diag(dist_matrix: torch.Tensor) -> torch.Tensor:
    """Push the diagonal and near-zero entries out of reach of a row/column min.

    Args:
        dist_matrix: tensor of shape (batch_size, batch_size).

    Returns:
        A new tensor of the same shape in which 10 has been added to every
        diagonal entry and, after that, another 10 to every entry still below
        0.008. The input is not modified.
    """
    penalty = 10
    duplicate_threshold = 0.008

    # Constant identity on the input's device/dtype, so the function works
    # for CUDA tensors and adds no spurious leaf to the autograd graph.
    eye = torch.eye(dist_matrix.size(0), dist_matrix.size(1),
                    dtype=dist_matrix.dtype, device=dist_matrix.device)
    penalised = dist_matrix + eye * penalty

    # Entries this close to zero are treated as the same patch showing up as a
    # negative; penalise them too so they are never picked as hardest negative.
    near_duplicates = penalised.lt(duplicate_threshold).type_as(penalised)
    return penalised + near_duplicates * penalty
    
    
def loss_sosnet(anchor, positive):
    """Margin loss on the hardest in-batch negative plus a second-order term.

    Args:
        anchor: (n, dim) tensor; row i is matched with row i of ``positive``.
        positive: (n, dim) tensor.

    Returns:
        Scalar tensor ``fos_loss + sos_loss`` where

        * ``fos_loss`` is the mean of ``max(0, 1 + d(a_i, p_i) - min_neg_i)``;
          ``min_neg_i`` is the smallest masked distance found for index i in
          the anchor-positive matrix (row and column), the anchor-anchor
          matrix and the positive-positive matrix.
        * ``sos_loss`` is the mean over rows of the L2 norm of the difference
          between the anchor-anchor and positive-positive distance rows,
          restricted to the union of each row's nearest neighbours in either
          matrix.
    """
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.dim() == 2, "Inputd must be a 2D matrix."
    n = anchor.shape[0]
    eps = 1e-8
    # Up to 8 neighbours per row; capped so topk also works for batches
    # smaller than 8.
    n_neighbours = min(8, n)

    dist_aa = get_distance_matrix_without_min_on_diag(
        distance_matrix_vector(anchor, anchor, True) + eps)
    dist_pp = get_distance_matrix_without_min_on_diag(
        distance_matrix_vector(positive, positive, True) + eps)

    dist_ap_raw = distance_matrix_vector(anchor, positive) + eps
    pos = torch.diag(dist_ap_raw)  # distance of every matching pair
    dist_ap = get_distance_matrix_without_min_on_diag(dist_ap_raw)

    # Hardest negative per index across all three distance matrices.
    min_neg_cross = torch.min(torch.min(dist_ap, 1)[0], torch.min(dist_ap, 0)[0])
    min_neg_within = torch.min(torch.min(dist_aa, 1)[0], torch.min(dist_pp, 1)[0])
    min_neg = torch.min(min_neg_cross, min_neg_within)

    fos_loss = torch.mean(torch.clamp(1 + pos - min_neg, min=0.0))

    # Neighbour selection is a constant 0/1 mask: build it without autograd,
    # on the same device/dtype as the distances, and fill it with scatter_
    # instead of a Python loop over rows.
    with torch.no_grad():
        _, idx_a = torch.topk(dist_aa, k=n_neighbours, dim=1, largest=False)
        _, idx_p = torch.topk(dist_pp, k=n_neighbours, dim=1, largest=False)
        mask = torch.zeros(n, n, dtype=dist_aa.dtype, device=dist_aa.device)
        mask.scatter_(1, idx_a, 1.0)
        mask.scatter_(1, idx_p, 1.0)

    diff = dist_aa - dist_pp
    row_sq = torch.sum(mask * diff * diff, dim=1)
    # eps keeps the gradient of sqrt finite for rows whose masked sum is 0.
    sos_loss = torch.mean(torch.sqrt(row_sq + eps))
    return fos_loss + sos_loss

In [None]:
# Ad-hoc training loop: a fresh HardNet optimised with Adam on the SOSNet-style
# loss, printing the batch index before and the loss value after every step.
model = HardNet()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=configs['lr'],
    betas=(0.9, 0.999),
    eps=1e-08,
    weight_decay=0,
    amsgrad=False,
)
for batch_id, data in enumerate(train_dataloader):
    # The final batch of the loader is never trained on.
    if batch_id + 1 == len(train_dataloader):
        continue
    print(batch_id)
    data_a, data_p = data
    out_a, out_p = model(data_a), model(data_p)
    loss = loss_sosnet(out_a, out_p)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # .item() already returns a plain Python number.
    print(loss.item())

0
1.4206409454345703
1
1.417121171951294
2
1.3818728923797607
3
1.3628431558609009
4
1.3414639234542847
5
1.3379077911376953
6
1.3174974918365479
7
1.3137112855911255
8
1.3036553859710693
9
1.2810842990875244
10
1.286557912826538
11
1.286529302597046
12
1.2682174444198608
13


In [11]:
# Rebind `model` to a freshly constructed SOSNet32x32.
# NOTE(review): in notebook order, `optimizer` was built from the parameters of
# the previous HardNet instance and is not rebuilt here.
model = SOSNet32x32()

In [12]:
# Pull a single batch from the training loader for interactive inspection.
# The builtin next() is used because Python 3 iterators implement __next__
# and are not guaranteed to expose a .next() method.
s1 = next(iter(train_dataloader))

In [13]:
# Split the batch into its two elements (train() above unpacks batches from
# its loader the same way, as data_a / data_p).
a, b = s1

In [14]:
# Forward pass of the first batch element through the current `model`.
out_a = model(a)

In [16]:
# Shape of the per-row sum of squares of the model output (one value per row).
torch.sum(out_a * out_a, 1).shape

torch.Size([1024])