In [1]:
import os
import numpy as np

# from torch.utils.data import DataLoader, Dataset

from preprocess import resize_input, train_test_split, read_raw
from ear_dataset import EarDataset

import cv2
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
n_classes = len(os.listdir("../UERC/"))
n_classes

1310

In [3]:
classes = np.arange(1, n_classes + 1)

In [4]:
# Randomly choose 30% of the classes
test_mask = np.random.choice([True, False], size=n_classes, p=[0.3, 0.7])

test = classes[test_mask]
train = classes[~test_mask]

len(train), len(test)

(910, 400)

In [5]:
# train_subjects = np.loadtxt("../train_subjects_mask.txt")
# test_subjects = np.loadtxt("../test_subjects_mask.txt")

In [6]:
train_subjects = train
test_subjects = test

In [7]:
np.savetxt("train_subjects_mask.txt", train)
np.savetxt("test_subjects_mask.txt", train)

In [8]:
from PIL import Image

data_path = "../UERC"

In [9]:
ear_data = os.listdir(data_path)

ear_imgs = {}
for person in ear_data:
    if int(person) not in train_subjects:
        continue

    imgs = os.listdir("%s/%s" % (data_path, person))
    try:
        ear_imgs[person] = [
            cv2.cvtColor(
                np.asarray(Image.open(f"{data_path}/{person}/{img}")), cv2.COLOR_BGR2RGB
            )
            for img in imgs
        ]
    except Exception as e:
        print(e)

In [10]:
X_train, X_eval, y_train, y_eval = train_test_split(ear_imgs)

In [11]:
len(X_train), len(y_train), len(X_eval), len(y_eval)

(120809, 120809, 52270, 52270)

In [12]:
len(X_train), len(X_eval)

(120809, 52270)

In [13]:
len(set(y_train)), len(set(y_eval))

(910, 910)

### Preprocess data


In [14]:
X_train = resize_input(X_train, tgt_size=64, mode="train")

train_dataset = EarDataset(X_train, y_train)

len(train_dataset.data), len(train_dataset.labels)



(120809, 120809)

In [15]:
ear_imgs.keys()

dict_keys(['0001', '0002', '0004', '0005', '0006', '0007', '0008', '0009', '0011', '0013', '0014', '0015', '0016', '0017', '0019', '0020', '0022', '0023', '0024', '0025', '0026', '0030', '0031', '0032', '0033', '0034', '0035', '0037', '0038', '0040', '0042', '0043', '0044', '0045', '0046', '0048', '0049', '0050', '0052', '0053', '0054', '0056', '0057', '0058', '0059', '0060', '0061', '0062', '0063', '0064', '0065', '0066', '0068', '0069', '0070', '0071', '0074', '0075', '0078', '0080', '0082', '0083', '0084', '0085', '0086', '0088', '0090', '0091', '0092', '0094', '0095', '0096', '0098', '0101', '0103', '0104', '0105', '0106', '0107', '0109', '0110', '0111', '0112', '0113', '0114', '0115', '0116', '0117', '0118', '0121', '0122', '0123', '0124', '0126', '0127', '0128', '0129', '0130', '0131', '0133', '0134', '0135', '0136', '0137', '0138', '0140', '0141', '0142', '0143', '0144', '0145', '0146', '0147', '0150', '0151', '0152', '0154', '0155', '0157', '0158', '0164', '0165', '0167', '0168

In [16]:
torch.save(train_dataset, "data/train_dataset.pt")

In [17]:
X_eval = resize_input(X_eval, tgt_size=64, mode="test")

eval_dataset = EarDataset(X_eval, y_eval)

In [18]:
len(eval_dataset.data), len(eval_dataset.labels)

(52270, 52270)

In [20]:
torch.save(eval_dataset, "data/eval_dataset.pt")