In [60]:
import torch
from torch.utils.data import Dataset
import numpy as np

# Randomly split training and testing datasets
np.random.seed(67)
torch.manual_seed(67)

<torch._C.Generator at 0x7f1adc089950>

In [61]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [62]:
t = torch.cuda.get_device_properties(0).total_memory
r = torch.cuda.memory_reserved(0)
a = torch.cuda.memory_allocated(0)
f = r-a  # free inside reserved

t, r, f

(11554717696, 2969567232, 2794906112)

In [63]:
from tqdm import tqdm
import math
from glob import glob
from pathlib import Path

In [64]:
from concurrent.futures import ProcessPoolExecutor
# from loky import ProcessPoolExecutor  # for Windows users

def parallel(func, iterable):
    e = ProcessPoolExecutor()
    return e.map(func, iterable)

In [65]:
from PIL import Image

def verify_image(fn):
    "Confirm that `fn` can be opened"
    try:
        im = Image.open(fn)
        im.draft(im.mode, (32,32))
        im.load()
        return True
    except: return False

In [66]:
sample_paths = [Path(g) for g in glob("./data/new_image_crops/*")]
input_paths = np.array([(Path(path) / "0.jpg", Path(path) / "1.jpg") for path in sample_paths]).flatten()

In [67]:
num_of_samples = len(input_paths)

In [68]:
is_valid = parallel(verify_image, input_paths)

In [69]:
np.array([valid for valid in is_valid]).sum()

1190

In [70]:
from PIL import Image
import numpy as np

mean_rgb = (131.0912, 103.8827, 91.4953)

def load_image_for_feature_extraction(path='', shape=None):
    '''
    Referenced from VGGFace2 Paper:
    Q. Cao, L. Shen, W. Xie, O. M. Parkhi, and A. Zisserman, “VGGFace2: A dataset for recognising faces across pose and age,” arXiv:1710.08092 [cs], May 2018
    '''
    short_size = 224.0
    crop_size = shape
    img = Image.open(path)
    im_shape = np.array(img.size)    # in the format of (width, height, *)
    img = img.convert('RGB')

    ratio = float(short_size) / np.min(im_shape)
    img = img.resize(size=(int(np.ceil(im_shape[0] * ratio)),   # width
                           int(np.ceil(im_shape[1] * ratio))),  # height
                     resample=Image.BILINEAR)

    x = np.array(img)  # image has been transposed into (height, width)
    newshape = x.shape[:2]
    h_start = (newshape[0] - crop_size[0])//2
    w_start = (newshape[1] - crop_size[1])//2
    x = x[h_start:h_start+crop_size[0], w_start:w_start+crop_size[1]]
    
    # normalize colors to prevent overfitting on color differences 
    x = x - mean_rgb
    
    # returns transformed image, and original image
    return x

In [71]:
import warnings
import os
image_size = (224,224,3)

np.random.seed(67)

def generate_batch(batch_size=16, shuffle=False):
    total_samples = len(sample_paths)
    
    if shuffle:
        idx = np.random.permutation(total_samples)
    else:
        idx = np.arange(total_samples)
        
    
    for ndx in range(0, total_samples, batch_size):
        batch_start = ndx
        batch_end = np.min([ndx + batch_size, total_samples])
        batch_idx = idx[batch_start: batch_end]
        
        batch_paths = np.array(sample_paths)[batch_idx]
        
        batch_images = []
        batch_image2idx = []
               
        for i, (nid, path) in enumerate(zip(batch_idx, batch_paths)):
            sub_image_paths = os.listdir(path)
            
            if(len(sub_image_paths) != 2):
                warnings.warn(f"{path} has {len(sub_image_paths)} files")
            else:
                
                batch_images.append(load_image_for_feature_extraction(path / sub_image_paths[0], image_size))
                batch_images.append(load_image_for_feature_extraction(path / sub_image_paths[1], image_size))
                batch_image2idx.append(nid)
                batch_image2idx.append(nid)
        
        yield np.stack(batch_images), np.stack(batch_image2idx), (batch_start, batch_end)

In [72]:
from saved_model.prepare_resnet50 import prepare_resnet_model

resnet_model = prepare_resnet_model("./saved_model/resnet50_ft_weight.pkl")

In [73]:
num_of_features = 2048

features = torch.empty((len(input_paths), num_of_features)).cpu()

In [74]:
batch_size = 16
num_of_batches = math.ceil(len(sample_paths) / batch_size)

for batch_images, batch_image2idx, batch_num in tqdm(generate_batch(batch_size=batch_size), total=num_of_batches):
    batch_start = batch_num[0] * 2
    batch_end = np.min([batch_num[1] * 2, len(input_paths)])
    
    x = torch.Tensor(batch_images.transpose(0, 3, 1, 2))  # nx3x224x224
    x = x.to(device)
    feat = resnet_model(x).cpu().detach()
    
    features[batch_start:batch_end, :] = feat

100%|███████████████████████████████████████████| 38/38 [00:05<00:00,  7.43it/s]


In [75]:
features[600]

tensor([2.9841, 0.3229, 0.1657,  ..., 0.0000, 1.5662, 0.0187])

In [76]:
class CustomDataset(Dataset):
    def __init__(self, indexes):
        self.indexes = indexes
        
    def __len__(self):
        return len(self.indexes)
    
    def __getitem__(self, i):
        index = self.indexes[i]
        return features[index], labels[index]

In [77]:
# Randomly split training and testing datasets
np.random.seed(67)

num_of_samples = len(input_paths)
print(f"Total number of samples: {num_of_samples}")

idx = np.random.permutation(range(num_of_samples))
cut = int(0.8 * num_of_samples)
train_idx = idx[:cut]
valid_idx = idx[cut:]

Total number of samples: 1190


In [78]:
from torch.utils.data import DataLoader

labels = np.resize(np.array([0, 1]), len(input_paths))

batch_size = 16

train_ds = CustomDataset(train_idx)
valid_ds = CustomDataset(valid_idx)

train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
valid_dl = DataLoader(valid_ds, batch_size=batch_size)

In [79]:
batch_features, batch_labels = next(iter(train_dl))

In [81]:
batch_features.shape

torch.Size([16, 2048])

In [82]:
import torch.nn as nn

# Define a simple binary classifier that takes a 2048 feature long tensor as input
class BinaryClassifier(nn.Module):
    def __init__(self):
        super(BinaryClassifier, self).__init__()        
        
        # Number of input features is 2048
        self.layer_1 = nn.Linear(2048, 2048)
        self.layer_2 = nn.Linear(2048, 1)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        
    def forward(self, inputs):
        x = self.relu(self.layer_1(inputs))
        x = self.dropout(x)
        x = self.layer_2(x)
        
        return x

In [83]:
import torch.optim as optim

LEARNING_RATE = 0.0001

model = BinaryClassifier()
loss_function = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [84]:
def binary_acc(y_pred, y_test):
    # Transform outputs to 0 and 1
    y_pred_tag = torch.round(torch.sigmoid(y_pred))

    # Calculate percentage of correct predictions
    correct_results_sum = (y_pred_tag == y_test).sum().float()
    acc = correct_results_sum/y_test.shape[0]
    
    return acc

In [94]:
import traceback

EPOCHS = 30

losses = []
val_losses = []
accuracies = []
val_accuracies = []

# Move model to GPU if possible
model = model.to(device)
# Tells PyTorch we are in training mode
model.train()


for e in range(EPOCHS):

    # Set loss and accuracy to zero at start of each epoch
    epoch_training_loss = 0
    epoch_training_accuracy = 0
    epoch_valid_loss = 0
    epoch_valid_accuracy = 0

    with tqdm(train_dl, unit="batch") as tepoch:
        for x_batch, y_batch in tepoch:
            tepoch.set_description(f"Epoch {e}")
            # Transfer the tensors to the GPU if possible
            x_batch = x_batch.to(device, dtype=torch.float)
            y_batch = y_batch.to(device, dtype=torch.float)

            # Zero out gradients before backpropagation (PyTorch cumulates the gradient otherwise)
            optimizer.zero_grad()

            # Predict a minibatch of outputs
            y_pred = model(x_batch)

            # Calculate the loss (unsqueeze adds a dimension to y)



            loss = loss_function(y_pred, y_batch.unsqueeze(1))
            training_acc = binary_acc(y_pred, y_batch.unsqueeze(1))

            # Backpropagation. Gradients are calculated
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            batch_acc = training_acc.item()
            epoch_training_loss += batch_loss
            epoch_training_accuracy += batch_acc
            losses.append(batch_loss)
            accuracies.append(batch_acc)

            # tepoch.set_postfix(loss=loss.item(), accuracy=100. * training_acc.item())

    for x_batch, y_batch in valid_dl:
        x_batch = x_batch.to(device, dtype=torch.float)
        y_batch = y_batch.to(device, dtype=torch.float)

        valid_y_pred = model(x_batch)
        valid_loss = loss_function(valid_y_pred, y_batch.unsqueeze(1))
        valid_acc = binary_acc(valid_y_pred, y_batch.unsqueeze(1))

        batch_valid_loss = valid_loss.item()
        batch_valid_accuracy = valid_acc.item()
        epoch_valid_loss += batch_valid_loss
        epoch_valid_accuracy += batch_valid_accuracy
        val_losses.append(batch_valid_loss)
        val_accuracies.append(batch_valid_accuracy)

    avg_train_loss = epoch_training_loss/len(train_dl)
    avg_valid_loss = epoch_training_loss/len(valid_dl)

    avg_train_accuracy = epoch_training_accuracy/len(train_dl)
    avg_valid_accuracy = epoch_valid_accuracy/len(valid_dl)

    print(f'End of Epoch {e}: | Training Loss: {avg_train_loss:.5f} | Training accuracy: {avg_train_accuracy} | Validation Loss: {avg_valid_loss} | Validation Accuracy: {avg_valid_accuracy}')

    

Epoch 0: 100%|██████████████████████████████| 60/60 [00:00<00:00, 385.33batch/s]


End of Epoch 0: | Training Loss: 0.02948 | Training accuracy: 0.9989583333333333 | Validation Loss: 0.11792726715405782 | Validation Accuracy: 0.4607142865657806


Epoch 1: 100%|██████████████████████████████| 60/60 [00:00<00:00, 438.32batch/s]


End of Epoch 1: | Training Loss: 0.02501 | Training accuracy: 0.9979166666666667 | Validation Loss: 0.10005000693102678 | Validation Accuracy: 0.47321428656578063


Epoch 2: 100%|██████████████████████████████| 60/60 [00:00<00:00, 440.19batch/s]


End of Epoch 2: | Training Loss: 0.01834 | Training accuracy: 1.0 | Validation Loss: 0.07336405261109273 | Validation Accuracy: 0.4648809532324473


Epoch 3: 100%|██████████████████████████████| 60/60 [00:00<00:00, 411.09batch/s]


End of Epoch 3: | Training Loss: 0.01516 | Training accuracy: 1.0 | Validation Loss: 0.060644253715872766 | Validation Accuracy: 0.4648809532324473


Epoch 4: 100%|██████████████████████████████| 60/60 [00:00<00:00, 448.05batch/s]


End of Epoch 4: | Training Loss: 0.01291 | Training accuracy: 1.0 | Validation Loss: 0.051637458894401786 | Validation Accuracy: 0.45654761989911397


Epoch 5: 100%|██████████████████████████████| 60/60 [00:00<00:00, 424.60batch/s]


End of Epoch 5: | Training Loss: 0.01047 | Training accuracy: 1.0 | Validation Loss: 0.04187389245877663 | Validation Accuracy: 0.4648809532324473


Epoch 6: 100%|██████████████████████████████| 60/60 [00:00<00:00, 443.31batch/s]


End of Epoch 6: | Training Loss: 0.00954 | Training accuracy: 1.0 | Validation Loss: 0.03815561728551984 | Validation Accuracy: 0.4523809532324473


Epoch 7: 100%|██████████████████████████████| 60/60 [00:00<00:00, 441.99batch/s]


End of Epoch 7: | Training Loss: 0.00783 | Training accuracy: 1.0 | Validation Loss: 0.03132592967400948 | Validation Accuracy: 0.4648809532324473


Epoch 8: 100%|██████████████████████████████| 60/60 [00:00<00:00, 443.44batch/s]


End of Epoch 8: | Training Loss: 0.00718 | Training accuracy: 1.0 | Validation Loss: 0.02870678637797634 | Validation Accuracy: 0.4648809532324473


Epoch 9: 100%|██████████████████████████████| 60/60 [00:00<00:00, 441.62batch/s]


End of Epoch 9: | Training Loss: 0.00639 | Training accuracy: 1.0 | Validation Loss: 0.025556146011998255 | Validation Accuracy: 0.4648809532324473


Epoch 10: 100%|█████████████████████████████| 60/60 [00:00<00:00, 442.14batch/s]


End of Epoch 10: | Training Loss: 0.00570 | Training accuracy: 1.0 | Validation Loss: 0.022786633297801017 | Validation Accuracy: 0.46964285771052044


Epoch 11: 100%|█████████████████████████████| 60/60 [00:00<00:00, 446.42batch/s]


End of Epoch 11: | Training Loss: 0.00516 | Training accuracy: 1.0 | Validation Loss: 0.02064001321171721 | Validation Accuracy: 0.4607142865657806


Epoch 12: 100%|█████████████████████████████| 60/60 [00:00<00:00, 449.76batch/s]


End of Epoch 12: | Training Loss: 0.00464 | Training accuracy: 1.0 | Validation Loss: 0.018567743059247733 | Validation Accuracy: 0.48571428656578064


Epoch 13: 100%|█████████████████████████████| 60/60 [00:00<00:00, 448.22batch/s]


End of Epoch 13: | Training Loss: 0.00412 | Training accuracy: 1.0 | Validation Loss: 0.016484972182661296 | Validation Accuracy: 0.4648809532324473


Epoch 14: 100%|█████████████████████████████| 60/60 [00:00<00:00, 446.77batch/s]


End of Epoch 14: | Training Loss: 0.00364 | Training accuracy: 1.0 | Validation Loss: 0.014551196433603763 | Validation Accuracy: 0.469047619899114


Epoch 15: 100%|█████████████████████████████| 60/60 [00:00<00:00, 449.06batch/s]


End of Epoch 15: | Training Loss: 0.00333 | Training accuracy: 1.0 | Validation Loss: 0.013332185770074527 | Validation Accuracy: 0.47321428656578063


Epoch 16: 100%|█████████████████████████████| 60/60 [00:00<00:00, 448.16batch/s]


End of Epoch 16: | Training Loss: 0.00302 | Training accuracy: 1.0 | Validation Loss: 0.012072455851982038 | Validation Accuracy: 0.45714285771052043


Epoch 17: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.63batch/s]


End of Epoch 17: | Training Loss: 0.00278 | Training accuracy: 1.0 | Validation Loss: 0.011124632973223924 | Validation Accuracy: 0.469047619899114


Epoch 18: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.86batch/s]


End of Epoch 18: | Training Loss: 0.00255 | Training accuracy: 1.0 | Validation Loss: 0.010217606912677486 | Validation Accuracy: 0.47321428656578063


Epoch 19: 100%|█████████████████████████████| 60/60 [00:00<00:00, 448.61batch/s]


End of Epoch 19: | Training Loss: 0.00243 | Training accuracy: 1.0 | Validation Loss: 0.009735634985069434 | Validation Accuracy: 0.4821428577105204


Epoch 20: 100%|█████████████████████████████| 60/60 [00:00<00:00, 448.39batch/s]


End of Epoch 20: | Training Loss: 0.00221 | Training accuracy: 1.0 | Validation Loss: 0.00883532288329055 | Validation Accuracy: 0.45654761989911397


Epoch 21: 100%|█████████████████████████████| 60/60 [00:00<00:00, 448.14batch/s]


End of Epoch 21: | Training Loss: 0.00199 | Training accuracy: 1.0 | Validation Loss: 0.007943761100371679 | Validation Accuracy: 0.4773809532324473


Epoch 22: 100%|█████████████████████████████| 60/60 [00:00<00:00, 443.71batch/s]


End of Epoch 22: | Training Loss: 0.00189 | Training accuracy: 1.0 | Validation Loss: 0.007565606959785024 | Validation Accuracy: 0.47797619104385375


Epoch 23: 100%|█████████████████████████████| 60/60 [00:00<00:00, 440.89batch/s]


End of Epoch 23: | Training Loss: 0.00176 | Training accuracy: 1.0 | Validation Loss: 0.007044565522422394 | Validation Accuracy: 0.469047619899114


Epoch 24: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.96batch/s]


End of Epoch 24: | Training Loss: 0.00163 | Training accuracy: 1.0 | Validation Loss: 0.006538837992896636 | Validation Accuracy: 0.469047619899114


Epoch 25: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.21batch/s]


End of Epoch 25: | Training Loss: 0.00151 | Training accuracy: 1.0 | Validation Loss: 0.006049358802071462 | Validation Accuracy: 0.4863095243771871


Epoch 26: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.71batch/s]


End of Epoch 26: | Training Loss: 0.00140 | Training accuracy: 1.0 | Validation Loss: 0.005593762034550309 | Validation Accuracy: 0.46547619104385374


Epoch 27: 100%|█████████████████████████████| 60/60 [00:00<00:00, 444.32batch/s]


End of Epoch 27: | Training Loss: 0.00133 | Training accuracy: 1.0 | Validation Loss: 0.005329897149931639 | Validation Accuracy: 0.4773809532324473


Epoch 28: 100%|█████████████████████████████| 60/60 [00:00<00:00, 443.47batch/s]


End of Epoch 28: | Training Loss: 0.00123 | Training accuracy: 1.0 | Validation Loss: 0.004904818313661963 | Validation Accuracy: 0.4773809532324473


Epoch 29: 100%|█████████████████████████████| 60/60 [00:00<00:00, 443.64batch/s]


End of Epoch 29: | Training Loss: 0.00114 | Training accuracy: 1.0 | Validation Loss: 0.004550012670612584 | Validation Accuracy: 0.4821428577105204
