## Set up imports

In [None]:
# NOTE: `from __future__ import absolute_import` was removed. It must be the
# first statement of a module/cell (placing it after other imports is a
# SyntaxError), and absolute imports are already mandatory on Python 3.
import os
import os.path as osp
import sys


def add_path(path):
    """Prepend ``path`` to ``sys.path`` unless it is already present.

    Inserting at index 0 gives the project's own packages priority over any
    installed package of the same name.

    Args:
        path: Directory to make importable.
    """
    if path not in sys.path:
        sys.path.insert(0, path)


# ``osp.dirname('.')`` is the empty string, so resolve against the absolute
# current working directory instead (typically the notebook's folder under
# Jupyter). Absolute, normalised entries stay valid even if the process later
# changes directory, and the membership check in add_path compares like with like.
this_dir = osp.abspath('.')
lib_path = osp.normpath(osp.join(this_dir, '..', 'lib'))
add_path(lib_path)
add_path(osp.normpath(osp.join(this_dir, '..')))

In [None]:
from models import build_model
from config import config
from config import _update_config_from_file
from dataset import build_dataloader
from core.function import only_forward
from metrics import global_metrics

In [None]:
import torch
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import copy

## Set the config
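- Open the inference config file for the model variant you want to use (global, local, or global + local; the matching `*-inference.yaml` files are listed in the next cell) and set its `DATASET.VAL_IMGS` field to the folder containing the images you want embeddings for.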

In [None]:
# Pick exactly ONE model-checkpoint / inference-config pair below.
# Paths use forward slashes: they work on every OS, whereas backslashes in a
# plain (non-raw) string literal are escape characters ("\e", "\g" are invalid
# escapes) and are not path separators on Linux/macOS.

# for only global model use:
# model_file = '../models/global.pth'
# _update_config_from_file(config, '../experiments/global-inference.yaml')

# for only local model use:
# model_file = '../models/local.pth'
# _update_config_from_file(config, '../experiments/local-inference.yaml')

# for global + local model use:
model_file = '../models/global_plus_local.pth'
_update_config_from_file(config, '../experiments/global_plus_local-inference.yaml')

## Set up the model

In [None]:
# Build the model architecture from the active `config`; the trained weights
# are loaded separately from `model_file`.
model = build_model(config)

In [None]:
# Load the checkpoint onto the CPU first so it loads regardless of which device
# it was saved from, then move the model to the GPU only when one exists
# (a hard-coded 'cuda' device crashes on CPU-only machines).
# strict=True raises on any missing/unexpected key, catching a
# checkpoint/config mismatch early.
# NOTE(review): only_forward may itself move batches to CUDA; the CPU fallback
# is unverified against that helper.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
state_dict = torch.load(model_file, map_location="cpu")
model.load_state_dict(state_dict, strict=True)
model.to(device)
# Inference only: put dropout / batch-norm layers into evaluation mode.
# Harmless if only_forward also does this.
model.eval()
model

## Set up the dataloader

In [None]:
# Build the loader whose images are embedded below; its `.dataset` exposes
# `imgs` (image paths) and `transforms`, which the visualisation cell reuses.
# NOTE(review): the positional False presumably selects the non-training /
# validation pipeline (the result is named valid_loader), and the images
# presumably come from DATASET.VAL_IMGS — confirm against build_dataloader.
valid_loader = build_dataloader(config, False)

## Forward pass

In [None]:
# Run the model over every batch of valid_loader. Downstream cells treat
# `results` as a mapping and look up the optional keys 'global', 'posori'
# and 'embs'; which of them are present depends on the loaded model.
results = only_forward(config, valid_loader, model)

In [None]:
# Report the shape of every output the forward pass produced, skipping the
# ones this model variant did not emit.
for key, description in (
    ('global', "Model generated Global embeddings with shape:"),
    ('posori', "Model performed Minutiae extraction with shape:"),
    ('embs', "Model generated corresponding Minutiae embeddings with shape:"),
):
    if key in results:
        print(description, results[key].shape)

## Visualisations

In [None]:
# Heat-map of the first (up to) 50 rows of the global embedding matrix,
# restricted to its first 10 dimensions.
if 'global' in results:
    global_embs = results['global']
    n_show = min(50, global_embs.shape[0])
    plt.matshow(global_embs[:n_show, :10])
    plt.title('Global Embeddings')
    plt.colorbar()
    plt.show()

In [None]:
# Overlay the predicted minutiae on the first (up to) 10 images.
# Only meaningful when the model predicts local information ('posori' output);
# otherwise the cell is skipped instead of raising KeyError.
# NOTE(review): assumes each posori row is (x, y, theta) with x normalised by
# image width, y by image height and theta to [0, 1] — confirm against the
# dataset/model code.
arrow_len = 15  # length, in image pixels, of the orientation tick per minutia
if 'posori' in results:
    val_ds = valid_loader.dataset
    for img_idx in range(min(10, len(val_ds.imgs))):
        path = val_ds.imgs[img_idx]
        img = Image.open(path).convert('RGB')
        img = val_ds.transforms(img)[0]
        height, width = img.shape[0], img.shape[1]
        plt.imshow(img, cmap='gray')

        mnts = copy.deepcopy(results['posori'][img_idx])
        # Column 0 is plotted as the horizontal axis, so it scales with the
        # width; column 1 is vertical, so it scales with the height. Swapping
        # them misplaces points on non-square images.
        mnts[:, 0] *= width
        mnts[:, 1] *= height
        mnts[:, 2] *= 2 * np.pi
        plt.scatter(mnts[:, 0], mnts[:, 1], color='red', s=30, alpha=0.75)

        # Separate index name so the outer image index is not shadowed.
        # 2*pi - theta flips the angle because the image y-axis points down.
        for m in range(mnts.shape[0]):
            x, y, theta = mnts[m, 0], mnts[m, 1], mnts[m, 2]
            plt.plot((x, x + arrow_len * np.cos(2 * np.pi - theta)),
                     (y, y + arrow_len * np.sin(2 * np.pi - theta)),
                     color='red', alpha=0.75)

        # Set once per image (also when no minutiae were predicted).
        plt.title(path)
        plt.show()