In [1]:
import os
import uuid
import multiprocessing
from functools import reduce

import cv2

import numpy as np

import torch
import torchvision.transforms as T
from torch.utils.data import DataLoader, Dataset, sampler

from matplotlib import pyplot as plt

from pdb import set_trace

from ddh import *
from dataset import *
from logger import *

In [2]:
# reset the data

undo_create_set("val")
undo_create_set("test")
create_set("val")
create_set("test")

In [3]:
# switch to the appropriate device

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [4]:
# add data transformation here

transform = T.Compose([
    T.ToPILImage(),
    T.Resize((32, 32)),
    T.ToTensor()
])

In [5]:
# build the dataset 

data_train = FaceScrubDataset(type="hash_label",
                              mode="train",
                              transform=transform)
data_val = FaceScrubDataset(type="hash_label",
                            mode="val",
                            transform=transform)
data_test = FaceScrubDataset(type="hash_label",
                             mode="test",
                             transform=transform)

In [6]:
# setting up the data loader

loader_params = {
    "shuffle": True,
    # "num_workers": multiprocessing.cpu_count() - 2
    "num_workers": 1
}

loader_train = DataLoader(data_train,
                          batch_size=256,
                          **loader_params)
loader_val = DataLoader(data_val,
                          batch_size=256,
                          **loader_params)
loader_test = DataLoader(data_test,
                          batch_size=256,
                          **loader_params)

In [7]:
# hyperparameters

num_epochs = 100

optimizer_params = {
    "lr": 1e-4,
    "weight_decay": 0.0
}
custom_params = {
    "beta": 1.0
}

In [8]:
# components
model = DiscriminativeDeepHashing()
model.to(device=device)
optimizer = optim.Adam(model.parameters(), **optimizer_params)

run_id = uuid.uuid4().hex.upper()[0:6]
checkpoint_path = os.getcwd() + "/models/{}_best_weights.pt".format(run_id)

In [9]:
# training code

with Logger(write_to_file=False) as logger:
    logger.write("Starting run for {}".format(run_id))

    for epoch in range(num_epochs):
        epoch_train_losses = []
        epoch_train_acc = []

        for num_iter, (X, y) in enumerate(loader_train):
            # set model to train mode
            model.train()
            # forward pass
            batch = X.to(device).float()
            batch_size = len(X)
            label = y.to(device).float()
            outputs = model(X)
            # calculating loss
            loss = F.binary_cross_entropy_with_logits(outputs, label)
            # loss to encourage code to be closer to -1 and 1
            loss += custom_params["beta"]* \
                        torch.abs((torch.abs(outputs)-1).sum()) / batch_size
            # back propagate
            optimizer.zero_grad()
            epoch_train_losses.append(loss.item())
            loss.backward()
            optimizer.step()
            # calculate accuracy over training data
            y = y.byte()
            output_hash = outputs.data > 0
            num_correct = ((output_hash == y).sum(1) == 0).sum().item()
            num_total = len(y)
            num_digits_correct = (output_hash == y).sum()
            num_digits_total = reduce(lambda x, y: x*y, output_hash.shape)

            print(
                "Epoch {} ".format(epoch) +
                "iter {}: ".format(num_iter) +
                "loss - {:.10f} ".format(loss.item()) +
                "correct/total - {}/{} ({}%)"
                    .format(num_correct, num_total,
                            num_correct * 100.0 / num_total) +
                "digits/total - {}/{} ({}%)"
                    .format(num_digits_correct, num_digits_total,
                            num_digits_correct * 100.0 / num_digits_total)
            )

    logger.write("Completed run for {}".format(run_id))

Starting run for FB8964
Epoch 0 iter 0: loss - 43.5558433533 correct/total - 0/64 (0.0%)digits/total - 1570/3072 (51%)
Epoch 1 iter 0: loss - 42.8694572449 correct/total - 0/64 (0.0%)digits/total - 1577/3072 (51%)
Epoch 2 iter 0: loss - 42.1569099426 correct/total - 0/64 (0.0%)digits/total - 1571/3072 (51%)
Epoch 3 iter 0: loss - 41.4229507446 correct/total - 0/64 (0.0%)digits/total - 1573/3072 (51%)
Epoch 4 iter 0: loss - 40.6597213745 correct/total - 0/64 (0.0%)digits/total - 1581/3072 (51%)
Epoch 5 iter 0: loss - 39.8700141907 correct/total - 0/64 (0.0%)digits/total - 1571/3072 (51%)
Epoch 6 iter 0: loss - 39.0543556213 correct/total - 0/64 (0.0%)digits/total - 1575/3072 (51%)
Epoch 7 iter 0: loss - 38.2154121399 correct/total - 0/64 (0.0%)digits/total - 1573/3072 (51%)
Epoch 8 iter 0: loss - 37.3534164429 correct/total - 0/64 (0.0%)digits/total - 1563/3072 (50%)
Epoch 9 iter 0: loss - 36.4646606445 correct/total - 0/64 (0.0%)digits/total - 1578/3072 (51%)
Epoch 10 iter 0: loss - 35

Epoch 86 iter 0: loss - 1.3617494106 correct/total - 0/64 (0.0%)digits/total - 1585/3072 (51%)
Epoch 87 iter 0: loss - 1.2270123959 correct/total - 0/64 (0.0%)digits/total - 1583/3072 (51%)
Epoch 88 iter 0: loss - 0.9354799986 correct/total - 0/64 (0.0%)digits/total - 1578/3072 (51%)
Epoch 89 iter 0: loss - 1.1718661785 correct/total - 0/64 (0.0%)digits/total - 1582/3072 (51%)
Epoch 90 iter 0: loss - 1.3912043571 correct/total - 0/64 (0.0%)digits/total - 1578/3072 (51%)
Epoch 91 iter 0: loss - 1.4117470980 correct/total - 0/64 (0.0%)digits/total - 1584/3072 (51%)
Epoch 92 iter 0: loss - 1.2605714798 correct/total - 0/64 (0.0%)digits/total - 1580/3072 (51%)
Epoch 93 iter 0: loss - 0.9489184618 correct/total - 0/64 (0.0%)digits/total - 1568/3072 (51%)
Epoch 94 iter 0: loss - 1.1711614132 correct/total - 0/64 (0.0%)digits/total - 1565/3072 (50%)
Epoch 95 iter 0: loss - 1.4025975466 correct/total - 0/64 (0.0%)digits/total - 1586/3072 (51%)
Epoch 96 iter 0: loss - 1.4447345734 correct/total