In [None]:
from __future__ import division, print_function

import argparse
import copy
from copy import deepcopy
import datetime
import logging
import math
import os
import random
import sys
import time
from tqdm import tqdm

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from eval_metrics import ErrorRateAt95Recall
from models import HardNet, ResNet, SOSNet32x32
from dataset import TripletPhotoTour, BrownTest
from losses import loss_HardNet, loss_L2Net
from utils import cv2_scale, np_reshape, read_yaml

# Configure the root logger at INFO level so the logging.info() progress
# messages emitted by the cells below are not filtered out.
logging.basicConfig(level=logging.INFO)

In [None]:
# Banner marking the start of a run in the log output.
logging.info('\n\n================ IMAGE MATCHING CHALLENGE 2020 ==================\n\n')
# Load the experiment settings; later cells read them through `configs`.
configs = read_yaml('configs.yml')
# Unused alternative output location, kept for reference:
# models_output = os.path.join('models', configs['experiment_name'])
# Checkpoints are written under configs['model_dir']; create it up front
# (exist_ok=True makes re-running this cell harmless).
os.makedirs(configs['model_dir'], exist_ok=True)

In [None]:
# Reproducibility setup: pin the visible GPU and seed every RNG in use.
if configs['use_cuda']:
    os.environ['CUDA_VISIBLE_DEVICES'] = str(configs['gpu_id'])
    torch.cuda.manual_seed_all(0)
# Ask cuDNN for deterministic kernels and keep its autotuner off. The cell used
# to set `cudnn.benchmark = True` first and then overwrite the very same
# attribute with False; only the final (deterministic) setting is kept.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)

In [None]:
def _save_checkpoint(model, epoch, batch_id=None):
    """Save the model weights to a timestamped file in configs['model_dir'].

    The file is named ``checkpoint_<timestamp>_<epoch>.pth``, with an extra
    ``_<batch_id>`` before the extension for mid-epoch checkpoints. The saved
    dict stores ``epoch + 1`` under 'epoch' and ``model.state_dict()`` under
    'state_dict'. Returns the path that was written.
    """
    timestamp = datetime.datetime.now().strftime("%y-%m-%d_%H:%M:%S")
    suffix = f'{epoch}' if batch_id is None else f'{epoch}_{batch_id}'
    model_checkpoint = os.path.join(configs['model_dir'], f'checkpoint_{timestamp}_{suffix}.pth')
    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()}, model_checkpoint)
    logging.info(model_checkpoint)
    return model_checkpoint


def train(train_loader, model, optimizer, epoch):
    """Run one pass over ``train_loader``, optimising ``model`` with loss_HardNet.

    Args:
        train_loader: sized iterable yielding ``(anchor, positive)`` batches.
        model: network applied to both halves of every batch.
        optimizer: optimizer stepped once per processed batch.
        epoch: epoch index, used only for checkpoint naming/metadata.

    The loss is logged every 10 batches; a checkpoint is written every 100
    batches (batch 0 included) and once more after the loop.
    """
    model.train()
    num_batches = len(train_loader)
    last_loss = None  # stays None when no batch is actually trained on

    # Wrapping the loader itself (rather than enumerate(...)) lets tqdm read
    # its length and display a real progress total.
    for batch_id, data in enumerate(tqdm(train_loader)):
        # The final batch is skipped on purpose.
        # NOTE(review): presumably because it may be smaller than batch_size — confirm.
        if batch_id + 1 == num_batches:
            continue
        data_a, data_p = data

        if configs['use_cuda']:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        out_a = model(data_a)
        out_p = model(data_p)

        loss = loss_HardNet(out_a, out_p)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        last_loss = loss

        if batch_id % 10 == 0:
            logging.info(f'{batch_id}/{num_batches} - Loss: {loss.item()}')
        if batch_id % 100 == 0:
            _save_checkpoint(model, epoch, batch_id)

    # With fewer than two batches nothing was trained, so there is no loss to
    # report; the original code raised UnboundLocalError here.
    if last_loss is not None:
        logging.info(f'{num_batches}/{num_batches} - Loss: {last_loss.item()}')
    else:
        logging.warning('train(): no batch was processed, skipping final loss report')

    _save_checkpoint(model, epoch)


def test(test_loader, model, epoch):
    """Evaluate ``model`` on ``test_loader`` and log the FPR95 metric.

    Each batch provides two inputs and a label tensor. The Euclidean distance
    between the two model outputs is collected per sample, and
    ErrorRateAt95Recall is called with the labels and the inverse distances
    (1 / (distance + 1e-8)). ``epoch`` is accepted but not used. Returns None.
    """
    model.eval()

    label_chunks = []
    distance_chunks = []
    for data_a, data_p, label in tqdm(test_loader):
        if configs['use_cuda']:
            data_a = data_a.cuda()
            data_p = data_p.cuda()

        # Forward passes only; no autograd graph is recorded.
        with torch.no_grad():
            out_a = model(data_a)
            out_p = model(data_p)

        # Per-sample euclidean distance between the two outputs.
        diff = out_a - out_p
        dists = torch.sqrt(torch.sum(diff ** 2, 1))
        distance_chunks.append(dists.detach().cpu().numpy().reshape(-1, 1))
        label_chunks.append(label.detach().cpu().numpy().reshape(-1, 1))

    # Flatten the per-batch column vectors into one entry per test sample.
    num_tests = len(test_loader.dataset)
    labels = np.vstack(label_chunks).reshape(num_tests)
    distances = np.vstack(distance_chunks).reshape(num_tests)

    fpr95 = ErrorRateAt95Recall(labels, 1.0 / (distances + 1e-8))
    logging.info('\33[91mTest set: Accuracy(FPR95): {:.8f}\n\33[0m'.format(fpr95))

    return


def create_transform():
    """Build the preprocessing pipeline applied to every patch.

    Converts the input to a tensor, then normalises it with the mean and std
    stored under configs['dataset'].
    """
    dataset_cfg = configs['dataset']
    steps = [
        # Disabled steps, kept for reference:
        # transforms.Lambda(cv2_scale),
        # transforms.Lambda(np_reshape),
        transforms.ToTensor(),
        transforms.Normalize((dataset_cfg['mean'],), (dataset_cfg['std'],)),
    ]
    return transforms.Compose(steps)

def create_trainloader(root: str, train_scenes: list):
    """Return a shuffling DataLoader over TripletPhotoTour for ``train_scenes``.

    Batch size and worker count come from configs['batch_size'] and
    configs['n_workers'].
    """
    triplets = TripletPhotoTour(
        root=root,
        transform=create_transform(),
        train_scenes=train_scenes,
    )
    loader_options = {
        'batch_size': configs['batch_size'],
        'shuffle': True,
        'num_workers': configs['n_workers'],
    }
    return DataLoader(triplets, **loader_options)

def create_testloader(root: str, test_scene: str, is_challenge_data: bool):
    """Return a non-shuffling DataLoader for one evaluation scene.

    When ``is_challenge_data`` is truthy the scene is loaded through
    TripletPhotoTour with ``train=False``; otherwise through BrownTest.
    Batch size and worker count come from ``configs``.
    """
    if not is_challenge_data:
        scene_dataset = BrownTest(root=root,
                                  scene=test_scene,
                                  transform=create_transform())
    else:
        scene_dataset = TripletPhotoTour(root=root,
                                         transform=create_transform(),
                                         test_scene=test_scene,
                                         train=False)
    loader_options = {
        'batch_size': configs['batch_size'],
        'shuffle': False,
        'num_workers': configs['n_workers'],
    }
    return DataLoader(scene_dataset, **loader_options)

In [None]:
# Training loader over the scenes listed in configs['dataset']['train_scenes'].
train_dataloader = create_trainloader(
    root=configs['dataset']['challenge_root'],
    train_scenes=configs['dataset']['train_scenes'],
)

In [None]:
# One (scene name, DataLoader) pair per challenge test scene, in config order.
test_dataloaders = []
for scene in configs['dataset']['test_scenes']['challenge']:
    logging.info(f'load test set: {scene}')
    dataloader = create_testloader(
        root=configs['dataset']['challenge_root'],
        test_scene=scene,
        is_challenge_data=True,
    )
    test_dataloaders.append((scene, dataloader))

In [None]:
# Adam needs the parameters of an existing network. No earlier cell defines
# `model`, so it is created here; without this the cell raises NameError on a
# fresh kernel (Restart & Run All).
# NOTE: the optimizer keeps references to THIS model's parameters. Any cell
# that rebinds `model` must also rebuild the optimizer, or step() will keep
# updating the old network.
model = HardNet()
optimizer = torch.optim.Adam(model.parameters(),
                             lr=configs['lr'],
                             betas=(0.9, 0.999),
                             eps=1e-08,
                             weight_decay=0,
                             amsgrad=False)

In [None]:
# Tiny hand-written 3 x 3 float inputs with gradients enabled.
# NOTE(review): presumably meant for checking the distance / loss functions
# defined below by hand — confirm.
anchor = torch.tensor(
    [[1.0, 3.0, 1.0],
     [2.0, 1.0, 3.0],
     [1.0, -1.0, -5.0]],
    requires_grad=True,
)
positive = torch.tensor(
    [[-1.0, 5.0, 2.0],
     [-1.0, 0.0, 1.0],
     [2.0, 2.0, 3.0]],
    requires_grad=True,
)

In [None]:
def distance_matrix_vector(anchor: torch.Tensor, positive: torch.Tensor) -> torch.Tensor:
    """Pairwise Euclidean distances between the rows of two 2-D tensors.

    Args:
        anchor: tensor of shape (n, dim).
        positive: tensor of shape (m, dim).

    Returns:
        Tensor of shape (n, m) whose entry (i, j) is
        ``sqrt(||anchor[i] - positive[j]||^2 + 1e-6)``.
    """
    eps = 1e-6
    anchor_sq = torch.sum(anchor * anchor, dim=1).unsqueeze(-1)       # (n, 1)
    positive_sq = torch.sum(positive * positive, dim=1).unsqueeze(0)  # (1, m)

    # ||a - p||^2 = ||a||^2 + ||p||^2 - 2 a.p ; broadcasting expands the norms.
    sq_dist = anchor_sq + positive_sq - 2.0 * torch.mm(anchor, torch.t(positive))

    # Floating-point cancellation can push a (near-)zero squared distance
    # below zero; clamp so sqrt never sees a negative value and returns NaN.
    # eps keeps the gradient of sqrt finite at zero distance.
    return torch.sqrt(torch.clamp(sq_dist, min=0.0) + eps)


def get_distance_matrix_without_min_on_diag(dist_matrix: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``dist_matrix`` in which no diagonal entry and no
    near-zero entry can be picked as a row/column minimum.

    Args:
        dist_matrix: square (batch_size x batch_size) distance matrix.

    Returns:
        Tensor of the same shape: 10 is added to every diagonal entry, and a
        further 10 to every entry that is still below 0.008 afterwards.
        The input is not modified.
    """
    # Create the identity on the same device as the input so the addition also
    # works for CUDA tensors. It is a constant, so it needs no gradient.
    eye = torch.eye(dist_matrix.size(1), device=dist_matrix.device)

    # Push the diagonal (matching pairs) out of reach of a min() search.
    dist_without_min_on_diag = dist_matrix + eye * 10

    # Filter out same patches that occur in the distance matrix as negatives:
    # anything closer than 0.008 is pushed away by 10 as well.
    too_close = ~dist_without_min_on_diag.ge(0.008)
    penalty = too_close.type_as(dist_without_min_on_diag) * 10
    return dist_without_min_on_diag + penalty
    
    
def loss_sosnet(anchor, positive):
    """Hardest-in-batch margin loss over anchor/positive descriptor pairs.

    For every index ``i`` the positive distance is ``d(anchor[i], positive[i])``.
    The negative distance is the smallest of the nearest non-matching entries
    found in four places:
      * row ``i`` of the anchor-positive distance matrix,
      * column ``i`` of the anchor-positive distance matrix,
      * row ``i`` of the anchor-anchor distance matrix,
      * row ``i`` of the positive-positive distance matrix.
    The result is ``mean(max(0, 1 + pos - min_neg))``, i.e. a margin of 1.

    Args:
        anchor: 2-D tensor (batch_size x dim).
        positive: 2-D tensor of the same size as ``anchor``.

    Returns:
        Scalar tensor holding the mean loss over the batch.
    """
    assert anchor.size() == positive.size(), "Input sizes between anchor and positive must be equal."
    assert anchor.dim() == 2, "Inputs must be a 2D matrix."
    eps = 1e-8

    # anchor <-> anchor distances
    dist_matrix_a = distance_matrix_vector(anchor, anchor) + eps
    dist_without_min_on_diag_a = get_distance_matrix_without_min_on_diag(dist_matrix_a)

    # anchor <-> positive distances; the diagonal holds the matching pairs
    dist_matrix = distance_matrix_vector(anchor, positive) + eps
    pos = torch.diag(dist_matrix)
    dist_without_min_on_diag = get_distance_matrix_without_min_on_diag(dist_matrix)

    # positive <-> positive distances
    dist_matrix_p = distance_matrix_vector(positive, positive) + eps
    dist_without_min_on_diag_p = get_distance_matrix_without_min_on_diag(dist_matrix_p)

    # closest non-matching entry per index in each of the matrices
    min_neg_a = torch.min(dist_without_min_on_diag_a, 1)[0]
    min_neg1 = torch.min(dist_without_min_on_diag, 1)[0]
    min_neg2 = torch.min(dist_without_min_on_diag, 0)[0]
    min_neg_p = torch.min(dist_without_min_on_diag_p, 1)[0]

    min_neg = torch.min(torch.min(min_neg1, min_neg2), torch.min(min_neg_a, min_neg_p))

    loss = torch.clamp(1 + pos - min_neg, min=0.0)
    loss = torch.mean(loss)

    return loss

In [None]:
# Debug loop: train a fresh HardNet with loss_sosnet and print the first
# conv-layer weights before and after every optimizer step.
model = HardNet()
# The optimizer must be bound to THIS model instance. An optimizer created
# for an earlier `model` object would step (and zero the grads of) that other
# network, so the weights printed below would never change.
optimizer = torch.optim.Adam(model.parameters(),
                             lr=configs['lr'],
                             betas=(0.9, 0.999),
                             eps=1e-08,
                             weight_decay=0,
                             amsgrad=False)
cache = {}
for batch_id, data in enumerate(train_dataloader):
    # Skip the final batch of the loader.
    if batch_id + 1 == len(train_dataloader):
        continue
    print(batch_id)
    data_a, data_p = data
    # Keep the inputs together with a frozen copy of the network as it was
    # BEFORE this step, so a batch can be replayed later. The original stored
    # the bound method `model.state_dict`, which is neither a snapshot nor
    # callable as a model.
    # NOTE: one model copy per batch costs memory; trim `cache` for long runs.
    cache[batch_id] = (data_a, data_p, copy.deepcopy(model))
    out_a = model(data_a)
    out_p = model(data_p)
    print(out_a.shape, out_a[0][0:5])
    print(out_p.shape, out_p[0][0:5])
    print('*************')
    loss = loss_sosnet(out_a, out_p)
    optimizer.zero_grad()
    loss.backward()
    print(loss)
    print(model.features[0].weight.grad[0][0][0])
    print(model.features[0].weight[0][0][0])
    optimizer.step()
    print(model.features[0].weight[0][0][0])
    print('==============')

In [None]:
# Pull the entry cached for batch index 3 and work on a deep copy of its
# third element, leaving the cached object itself untouched.
data_a, data_p, M = cache[3]
model_tmp = deepcopy(M)

In [None]:
# Run both halves of the cached batch through the copied object.
out_a, out_p = model_tmp(data_a), model_tmp(data_p)