In [1]:
import json
import os

import torch
import numpy as np
from PIL import Image
from sklearn.model_selection import train_test_split
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms

from Datasets.Morph2.DataParser import DataParser
#from Datasets.Morph2.Morph2RecognitionDataset import Morph2RecognitionDataset
from Datasets.Morph2.Morph2RecognitionIdxDataset import Morph2RecognitionIdxDataset 
from Models.ArcMarginClassifier import ArcMarginClassifier
from Optimizers.RangerLars import RangerLars
from Training.train_recognition_model import train_recognition_model

# Run configuration: loader batch size and epoch count.
BATCH_SIZE = 32
NUM_EPOCHS = 50

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)
torch.cuda.empty_cache()

# Load the aligned Morph2 data from its HDF5 file.
data_parser = DataParser('./Datasets/Morph2/aligned_data/aligned_dataset_with_metadata_uint8.hdf5')
data_parser.initialize_data()

# Each label entry is a JSON string; collect the distinct 'id_num' values per split.
ids_train = np.unique([json.loads(record)['id_num'] for record in data_parser.y_train])
ids_test = np.unique([json.loads(record)['id_num'] for record in data_parser.y_test])

# The train/test split provided by the parser is used as-is (no re-splitting here).
x_train, y_train = data_parser.x_train, data_parser.y_train
x_test, y_test = data_parser.x_test, data_parser.y_test

def _make_preprocessing():
    """Return a fresh resize -> to-tensor -> normalize pipeline.

    The random augmentations that were tried earlier (random resized crop,
    horizontal flip, color jitter, random affine, random erasing) are disabled,
    so the train and test datasets receive the same preprocessing steps.
    """
    steps = [
        transforms.Resize(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


# Each dataset gets its own pipeline instance, images, labels and id list.
train_ds = Morph2RecognitionIdxDataset(x_train, y_train, ids_train, _make_preprocessing())

test_ds = Morph2RecognitionIdxDataset(x_test, y_test, ids_test, _make_preprocessing())

# Map split name -> dataset; 'val' is backed by the test split.
image_datasets = dict(train=train_ds, val=test_ds)

# One loader per split. shuffle=False keeps batches in dataset order, and
# num_workers=0 loads the data in the main process.
data_loaders = {}
for split_name, split_ds in image_datasets.items():
    data_loaders[split_name] = DataLoader(
        split_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

# Number of samples in each split.
dataset_sizes = {split_name: len(split_ds) for split_name, split_ds in image_datasets.items()}

# Create model and parameters
# The classifier is constructed with the number of distinct training ids.
model = ArcMarginClassifier(len(ids_train))
model.to(device)

# NOTE(review): absolute, machine-specific path; consider making it configurable
# so the notebook can be re-run on another machine.
pretrained_model_path = 'F:/age_estimation_with_error_estimator/weights/Morph2_recognition/vgg16/RangerLars_unfreeze_at_15_lr_1e2_steplr_01_batchsize_64'
pretrained_model_file = os.path.join(pretrained_model_path, "weights.pt")

# map_location remaps the checkpoint tensors onto `device`, so weights that were
# saved from a GPU run still load when this notebook falls back to the CPU.
state_dict = torch.load(pretrained_model_file, map_location=device)

# strict=True: every key must match; the result is shown as the cell output.
model.load_state_dict(state_dict, strict=True)



cuda:0


<All keys matched successfully>

In [2]:
# Display the `features` sub-module of the model's base network; the cell
# output lists its layers in order.
model.base_net.features

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_si

In [3]:
import torchextractor as tx

# Wrap the model so that a forward pass returns both the model output and the
# activations captured at the named sub-module(s).
captured_layers = ["base_net.classifier.5"]
model_ext = tx.Extractor(model, captured_layers)

In [4]:
#model_ext.to(device)

In [5]:
from tqdm import tqdm

In [6]:
# Extract one (1, 4096) embedding per training image, in loader order.

# Inference only: switch to eval mode so layers that behave differently during
# training (e.g. dropout) cannot perturb the captured activations.
model_ext.eval()

face2emb_arr_trn = []
for batch in tqdm(data_loaders['train']):
    faces = batch['image'].to(device)
    with torch.no_grad():
        _, features = model_ext(faces)

    # Copy the whole batch to host memory once instead of once per sample,
    # then store each row with the same (1, 4096) shape as before.
    batch_emb = features['base_net.classifier.5'].cpu().numpy()
    face2emb_arr_trn.extend(batch_emb.reshape(len(batch_emb), 1, 4096))


100%|██████████████████████████████████████████████████████████████████████████████| 1384/1384 [03:09<00:00,  7.30it/s]


In [7]:
# Stack the per-image training embeddings into one array and write it to the
# working directory.
trn_embeddings = np.array(face2emb_arr_trn)
np.save('face2emb_arr_trn_recog.npy', trn_embeddings)

In [8]:
# Extract one (1, 4096) embedding per validation image, in loader order.

# Inference only: switch to eval mode so layers that behave differently during
# training (e.g. dropout) cannot perturb the captured activations.
model_ext.eval()

face2emb_arr_vld = []
for batch in tqdm(data_loaders['val']):
    faces = batch['image'].to(device)
    with torch.no_grad():
        _, features = model_ext(faces)

    # Copy the whole batch to host memory once instead of once per sample,
    # then store each row with the same (1, 4096) shape as before.
    batch_emb = features['base_net.classifier.5'].cpu().numpy()
    face2emb_arr_vld.extend(batch_emb.reshape(len(batch_emb), 1, 4096))


100%|████████████████████████████████████████████████████████████████████████████████| 332/332 [00:44<00:00,  7.41it/s]


In [9]:
# Stack the per-image validation embeddings into one array and write it to the
# working directory.
vld_embeddings = np.array(face2emb_arr_vld)
np.save('face2emb_arr_vld_recog.npy', vld_embeddings)

In [10]:
# Reload the saved training embeddings for a round-trip check.
# The file is expected to hold a regular numeric array, so pickle support stays
# off: unpickling can execute arbitrary code if the file has been tampered with.
face2emb_arr_trn_r = np.load('face2emb_arr_trn_recog.npy', allow_pickle=False)

In [11]:
# Round-trip check: the reloaded training embeddings must match the in-memory ones.
# Lengths are compared first so that a longer reloaded array is reported as a
# mismatch instead of raising IndexError while indexing the in-memory list.
eq = len(face2emb_arr_trn_r) == len(face2emb_arr_trn) and all(
    np.array_equal(saved_emb, original_emb)
    for saved_emb, original_emb in zip(face2emb_arr_trn_r, face2emb_arr_trn)
)

print("trn equal" if eq else "trn not eq")

trn equal


In [12]:
# Reload the saved validation embeddings for a round-trip check.
# The file is expected to hold a regular numeric array, so pickle support stays
# off: unpickling can execute arbitrary code if the file has been tampered with.
face2emb_arr_vld_r = np.load('face2emb_arr_vld_recog.npy', allow_pickle=False)

In [13]:
# Round-trip check: the reloaded validation embeddings must match the in-memory ones.
# Lengths are compared first so that a longer reloaded array is reported as a
# mismatch instead of raising IndexError while indexing the in-memory list.
# (The stray "trn not eq" message that used to be printed from inside this
# validation loop is gone; only the vld verdict is reported.)
eq = len(face2emb_arr_vld_r) == len(face2emb_arr_vld) and all(
    np.array_equal(saved_emb, original_emb)
    for saved_emb, original_emb in zip(face2emb_arr_vld_r, face2emb_arr_vld)
)

print("vld equal" if eq else "vld not eq")

vld equal
