In [6]:
# other imports
import glob
import os
import sys
from os.path import join

import numpy as np
from PIL import Image

# Make the parent directory importable so the local `training` package imported below resolves.
sys.path.append('..')

# PyTorch imports
import torch
from torch import cuda
from torch.autograd import Variable

# own scripts imports
from training.helpers import get_device
import training.ds_transformations as td

In [7]:
class Config:
    """
    Central place for every setting the embedding-generation cells read.

    All values are class attributes and are accessed directly on the class
    (e.g. ``Config.DATASET_DIR``); the notebook never instantiates it.
    """
    # Device chosen by the project helper `get_device`; the model is mapped onto it when loaded.
    DEVICE = get_device()

    # Input images: one sub-directory per label, each holding that label's images.
    DATASET_DIR = '../../ear_dataset'
    # Serialized model passed to `torch.load` (despite the name, this is a file path).
    MODEL_DIR = '../../models/ve_g_margin_2,0.pt'
    # Flag forwarded to `td.get_transform` / `td.get_resize`; presumably selects a smaller input size.
    is_small_resize = False
    # Output folder: one `<label>.npy` embeddings file is written here per label.
    DATABASE_FOLDER = '../../embeddings/'

In [8]:
# Load the model that will be used to create the embeddings.
# NOTE(review): `torch.load` unpickles the whole model object, which can execute arbitrary code,
# so only load trusted files. PyTorch >= 2.6 defaults to `weights_only=True`, which rejects fully
# pickled models; on such versions `weights_only=False` must be passed (trusted files only).
model = torch.load(Config.MODEL_DIR, map_location=torch.device(Config.DEVICE))
# The pickled module keeps the train/eval mode it was saved in; force inference mode so layers
# such as dropout / batch-norm behave deterministically while producing embeddings
# (assumes `model` is a torch.nn.Module — it is called like one further below).
model.eval()
# Specify a set of transformations to be applied to all captured images before creating embeddings.
transformation = td.get_transform('siamese_valid_and_test', Config.is_small_resize)

In [9]:
def pipeline(input_, preprocess, size=None, device=None):
    """
    Turn a single image into a model-ready, batched float tensor.

    Processing steps:
    1. convert the input to a grayscale ("L") image
    2. apply ``preprocess`` (in this notebook: the transformations from ``td.get_transform``)
    3. reshape the result to ``(leading, size[0], size[1])``
    4. prepend a batch axis, giving ``(1, leading, size[0], size[1])``
    5. cast to float32 and move the tensor to ``device``

    :param input_: image object exposing ``convert(mode)`` (e.g. a ``PIL.Image.Image``).
    :param preprocess: callable mapping the grayscale image to a ``torch.Tensor``.
    :param size: optional 2-item spatial size; defaults to ``td.get_resize(Config.is_small_resize)``
        (presumably ``(height, width)`` — TODO confirm against ``td.get_resize``).
    :param device: optional target device (str or ``torch.device``); defaults to ``Config.DEVICE``,
        the device the model is loaded onto, so input and model always agree.
    :return: float32 tensor of shape ``(1, leading, size[0], size[1])`` on ``device``.
    """
    if size is None:
        size = td.get_resize(Config.is_small_resize)
    if device is None:
        device = Config.DEVICE

    gray = input_.convert("L")
    tensor = preprocess(gray)
    # Collapse any leading dims into one axis, then add the batch axis. This is equivalent to the
    # former reshape(-1, h, w, 1).permute(3, 0, 1, 2).
    batch = tensor.reshape(-1, size[0], size[1]).unsqueeze(0)
    # Follow the configured device instead of `cuda.is_available()`, which could disagree with
    # the device the model was mapped to.
    return batch.to(device=device, dtype=torch.float32)

In [12]:
# Here, each image is converted into an embedding.
# First, the images are preprocessed, then passed through the network to obtain an embedding.
# Finally, the embeddings are saved in our embeddings database (one `<label>.npy` file per label).
os.makedirs(Config.DATABASE_FOLDER, exist_ok=True)

# Only sub-directories are labels; stray files (e.g. `.DS_Store`) would otherwise produce empty
# `.npy` files. Sorting makes the processing order reproducible across file systems.
labels = sorted(
    entry for entry in os.listdir(Config.DATASET_DIR)
    if os.path.isdir(join(Config.DATASET_DIR, entry))
)

# Inference only: no_grad avoids building an autograd graph and yields tensors that can be
# converted to NumPy (converting a tensor that requires grad raises a RuntimeError).
with torch.no_grad():
    for label in labels:
        label_embeddings = []
        for filename in sorted(glob.glob(join(Config.DATASET_DIR, label, '*'))):
            # The context manager closes the image's file handle once it has been preprocessed.
            with Image.open(filename) as img:
                img_processed = pipeline(img, transformation)
            label_embeddings.append(model(img_processed).cpu().numpy())

        # One entry per image, stacked along axis 0 (same layout as before).
        embeddings = np.array(label_embeddings)
        np.save(join(Config.DATABASE_FOLDER, label + '.npy'), embeddings)

