In [1]:
import os
import numpy as np
# from torch.utils.data import DataLoader, Dataset

from preprocess import resize_input, train_test_split, read_raw
from ear_triplet import EarTriplet

import cv2
import torch


In [2]:
# Subject IDs that belong to the training split (parsed as floats by np.loadtxt;
# presumably one integer ID per line — verify against the file).
TRAIN_SUBJECTS_FILE = "./train_500.txt"
train_subjects = np.loadtxt(TRAIN_SUBJECTS_FILE)

In [3]:
from PIL import Image  # decodes the ear images in the loading cell

# Root of the UERC ear dataset, relative to this notebook; it contains one
# sub-directory per subject.
data_path = "../UERC"

In [4]:
# Load every image of each training subject into memory, keyed by the subject's
# directory name (directory names are parsed as integer subject IDs).
ear_data = sorted(os.listdir(data_path))  # sorted -> reproducible order across machines
train_subject_ids = {int(s) for s in train_subjects}  # set gives O(1) membership checks

ear_imgs = {}
for person in ear_data:
    person_dir = os.path.join(data_path, person)
    # Skip stray files (e.g. .DS_Store) that would otherwise crash int(person).
    if not person.isdigit() or not os.path.isdir(person_dir):
        continue
    if int(person) not in train_subject_ids:
        continue

    imgs = []
    for img_name in sorted(os.listdir(person_dir)):
        img_path = os.path.join(person_dir, img_name)
        try:
            # `with` closes the file handle that Image.open keeps open;
            # convert("RGB") guarantees 3 channels so cvtColor cannot fail on
            # grayscale/RGBA/palette images.
            with Image.open(img_path) as im:
                rgb = np.asarray(im.convert("RGB"))
        except Exception as e:
            # Skip only the unreadable file instead of dropping the whole subject.
            print(f"Skipping {img_path}: {e}")
            continue
        # NOTE(review): PIL already returns RGB, so COLOR_BGR2RGB actually swaps
        # the channels to BGR (the cv2.imread convention). Kept unchanged because
        # downstream preprocessing was built on this ordering — confirm.
        imgs.append(cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB))

    if imgs:  # a subject with no loadable images cannot be split or trained on
        ear_imgs[person] = imgs

In [5]:
def split_triplets(X, y, rng=None):
    """Build aligned (anchor, positive, negative) triplets from a labelled dataset.

    For every class, its instances are shuffled and ``k = n_instances // 2`` is
    computed. The first ``k`` become anchors and the next ``k`` become positives,
    so anchor ``i`` and positive ``i`` always share a label. With an odd count,
    the last instance is unused. For each of those ``k`` anchors, one random
    instance of a randomly chosen *other* class becomes the negative.

    Args:
        X: Samples, converted with ``np.array`` and indexed by position. All
            samples should therefore share one shape.
        y: Labels aligned with ``X``.
        rng: Optional randomness source exposing ``shuffle`` and ``choice``,
            e.g. ``np.random.default_rng(seed)``. Defaults to the global
            ``np.random`` state.

    Returns:
        ``(anchor_data, anchor_labels), (positive_data, positive_labels),
        (negative_data, negative_labels)``, aligned element-wise.

    Raises:
        ValueError: if a class needs negatives but it is the only class.
    """
    rng = np.random if rng is None else rng
    X, y = np.array(X), np.array(y)
    classes = np.unique(y)
    print("Classes: ", len(classes))

    # label_to_indices[label] -> positions of all instances of that label, shuffled
    label_to_indices = {label: np.where(y == label)[0] for label in classes}
    for indices in label_to_indices.values():
        rng.shuffle(indices)

    an, pos, neg = [], [], []
    for label in classes:
        indices = label_to_indices[label]
        subarray_size = len(indices) // 2
        if subarray_size == 0:
            continue  # a single-instance class cannot form an anchor/positive pair

        # Anchor and positive are disjoint halves of the same class.
        an.extend(indices[:subarray_size])
        pos.extend(indices[subarray_size:2 * subarray_size])

        # Negatives: random classes drawn from all *other* classes, then one
        # random instance from each. Excluding `label` up front replaces the old
        # rejection loop, which never terminated when subarray_size equalled the
        # number of classes. Sample with replacement only when there are fewer
        # other classes than negatives needed.
        other_classes = classes[classes != label]
        if len(other_classes) == 0:
            raise ValueError(
                f"Cannot pick negatives for class {label!r}: it is the only class"
            )
        negs = rng.choice(
            other_classes,
            size=subarray_size,
            replace=subarray_size > len(other_classes),
        )
        for neg_label in negs:
            neg.extend(rng.choice(label_to_indices[neg_label], size=1))

    print(len(an), len(pos), len(neg))

    # Integer index arrays (also valid when empty) select the split data.
    an, pos, neg = (np.asarray(idx, dtype=int) for idx in (an, pos, neg))
    anchor_data, anchor_labels = X[an], y[an]
    positive_data, positive_labels = X[pos], y[pos]
    negative_data, negative_labels = X[neg], y[neg]

    return (anchor_data, anchor_labels), (positive_data, positive_labels), (negative_data, negative_labels)


In [6]:
# Split the subject -> images mapping into train/eval samples and their labels.
(X_train, X_eval,
 y_train, y_eval) = train_test_split(ear_imgs)

In [7]:
# Number of training samples produced by the split (displayed as cell output).
n_train = len(X_train)
n_train

5174

In [9]:
# Seed the global NumPy RNG so the random triplet sampling is reproducible.
TRIPLET_SEED = 0
np.random.seed(TRIPLET_SEED)

# Resize into a new variable so re-running this cell never resizes X_train twice.
X_train_resized = resize_input(X_train, mode="train")

train_triplets = split_triplets(X_train_resized, y_train)
train_anc, train_pos, train_neg = train_triplets

# Each part is a (data, labels) pair, so unpacking yields EarTriplet's positional
# order: anchor data/labels, positive data/labels, negative data/labels.
train_dataset = EarTriplet(*train_anc, *train_pos, *train_neg)

train_lengths = tuple(
    len(getattr(train_dataset, attr))
    for attr in (
        "anchor_data", "anchor_labels",
        "positive_data", "positive_labels",
        "negative_data", "negative_labels",
    )
)
assert len(set(train_lengths)) == 1, f"misaligned triplet parts: {train_lengths}"
train_lengths




Classes:  362
2474 2474 2474


(2474, 2474, 2474, 2474, 2474, 2474)

In [None]:
# torch.save does not create missing directories, so make sure "data/" exists.
# NOTE(review): this pickles the whole EarTriplet object; loading it needs
# ear_triplet importable and, on torch >= 2.6, torch.load(..., weights_only=False).
os.makedirs("data", exist_ok=True)
torch.save(train_dataset, "data/train_dataset_500.pt")

: 

: 

In [None]:
# Seed the global NumPy RNG so the eval triplets are reproducible even when this
# cell is run on its own.
EVAL_TRIPLET_SEED = 1
np.random.seed(EVAL_TRIPLET_SEED)

# Resize into a new variable so re-running this cell never resizes X_eval twice.
X_eval_resized = resize_input(X_eval, mode="test")

eval_triplets = split_triplets(X_eval_resized, y_eval)
eval_anc, eval_pos, eval_neg = eval_triplets

# (data, labels) pairs unpack into EarTriplet's positional argument order.
eval_dataset = EarTriplet(*eval_anc, *eval_pos, *eval_neg)

eval_lengths = tuple(
    len(getattr(eval_dataset, attr))
    for attr in (
        "anchor_data", "anchor_labels",
        "positive_data", "positive_labels",
        "negative_data", "negative_labels",
    )
)
assert len(set(eval_lengths)) == 1, f"misaligned triplet parts: {eval_lengths}"


In [None]:
# torch.save does not create missing directories, so make sure "data/" exists.
# NOTE(review): this pickles the whole EarTriplet object; loading it needs
# ear_triplet importable and, on torch >= 2.6, torch.load(..., weights_only=False).
os.makedirs("data", exist_ok=True)
torch.save(eval_dataset, "data/eval_dataset_500.pt")
