# This notebook implements a simple real-vs-fake image classifier. Its goal is to investigate the reason(s) our models overfit.

In [None]:
import torch
import torchvision
import torch.nn as nn
from torch import optim
from torchvision import transforms, models, datasets
from torch.utils.data import DataLoader, Dataset

import cv2, glob, numpy as np, pandas as pd

# Run on the GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def parent_dir(item):
    """Return a one-element list holding the name of *item*'s parent directory.

    *item* is a '/'-separated path; the second-to-last segment is returned.
    """
    segments = item.split('/')
    return [segments[-2]]

# Maps a sample's parent-directory name to its binary target.
label_dict = dict(real=0, fake=1)

class RebagKhakiDataset(Dataset):
    """Image dataset laid out as ``<folder>/<label>/<image>``.

    Each file's label is the name of its parent directory, mapped through
    ``label_dict``.

    Args:
        folder: root directory whose immediate subdirectories are label names.
        transform: optional callable applied to the RGB numpy image read by
            OpenCV. When omitted, the image is converted to a ``(C, H, W)``
            tensor with unscaled pixel values.
    """

    def __init__(self, folder, transform=None):
        self.folder = folder
        # glob's order depends on the filesystem; sort so that non-shuffled
        # loaders (val/test) iterate in a reproducible order.
        self.items = sorted(glob.glob(f'{self.folder}/*/*'))
        self.transform = transform

    def __getitem__(self, ix):
        """Return ``(image, target)`` for sample *ix* as float tensors on ``device``.

        ``target`` has shape ``(1,)``.

        Raises:
            OSError: if OpenCV cannot read the file as an image.
        """
        item = self.items[ix]
        img = cv2.imread(item)
        if img is None:
            # cv2.imread signals an unreadable / non-image file by returning None.
            raise OSError(f"could not read image file: {item}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            img = self.transform(img)
        else:
            # HWC numpy array -> CHW tensor so that `.float()` below works.
            img = torch.from_numpy(img).permute(2, 0, 1)
        target = label_dict[parent_dir(item)[0]]
        return img.float().to(device), torch.tensor([target]).float().to(device)

    def __len__(self):
        return len(self.items)

In [None]:
img_size = 256


def _make_tfms(size):
    """Build the deterministic preprocessing pipeline.

    OpenCV RGB array -> PIL image -> single-channel grayscale ->
    ``size x size`` -> float tensor in [0, 1] -> normalized to [-1, 1].
    """
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Grayscale(),
        transforms.Resize((size, size)),
        # ToTensor already yields a float tensor, so no ConvertImageDtype step.
        transforms.ToTensor(),
        # One-element tuples: one mean/std per channel of the grayscale image.
        transforms.Normalize((0.5,), (0.5,)),
    ])


# No augmentation is applied to training images yet, so both splits share the
# same pipeline; random transforms for regularization would go in trn_tfms.
trn_tfms = _make_tfms(img_size)
val_tfms = _make_tfms(img_size)

In [None]:
# Dataset root: one sub-folder per split, each holding one sub-folder per label.
folder = "/content/drive/MyDrive/khaki/"
trn_ds, val_ds = (
    RebagKhakiDataset(folder=f"{folder}train", transform=trn_tfms),
    RebagKhakiDataset(folder=f"{folder}val", transform=val_tfms),
)

# Shuffle only the training data; validation order stays fixed.
trn_dl = DataLoader(dataset=trn_ds, batch_size=64, shuffle=True)
val_dl = DataLoader(dataset=val_ds, batch_size=64, shuffle=False)

In [None]:
def convBlock(ni, no, p=0.2):
    """Return a Dropout -> 3x3 Conv -> ReLU -> BatchNorm block.

    Args:
        ni: number of input channels.
        no: number of output channels.
        p: dropout probability applied to the block's input (default 0.2,
            the value previously hard-coded); exposed so the regularization
            strength can be tuned.

    A 3x3 kernel with ``padding=1`` (reflect padding) and the default stride
    of 1 preserves the spatial size, so stacking these blocks never
    downsamples the feature map.
    """
    return nn.Sequential(
        nn.Dropout(p),
        nn.Conv2d(ni, no, kernel_size=3, padding=1, padding_mode='reflect'),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(no),
    )

In [None]:
class RebagClassifier(nn.Module):
    """Small CNN producing one sigmoid probability per image (target 1 = "fake").

    Three ``convBlock`` layers (1 -> 4 -> 8 -> 8 channels) keep the full
    ``img_size x img_size`` resolution. An MLP then maps the flattened features
    to a single output of shape ``(B, 1)``.

    Args:
        img_size: side length of the square, single-channel input (default
            256, the value previously hard-coded).

    NOTE(review): with no pooling, the first Linear layer has
    ``8 * img_size**2 * 500`` weights (~262M at 256px). That is by far the
    largest source of capacity in the model and a likely driver of overfitting.
    """

    def __init__(self, img_size=256):
        super().__init__()
        self.features = nn.Sequential(
            convBlock(1, 4),
            convBlock(4, 8),
            convBlock(8, 8),
            nn.Flatten(),
            nn.Linear(8 * img_size * img_size, 500), nn.ReLU(inplace=True),
            nn.Linear(500, 500), nn.ReLU(inplace=True),
            nn.Linear(500, 1), nn.Sigmoid(),
        )

    def forward(self, x):
        """Return per-sample probabilities of shape ``(B, 1)``.

        ``x`` is assumed to be ``(B, 1, img_size, img_size)``; TODO confirm
        against callers.
        """
        return self.features(x)

In [None]:
def train_batch(x, y, model, opt, loss_fn):
    """Run one optimization step on batch ``(x, y)`` and return its loss.

    Args:
        x: input batch.
        y: targets in the format ``loss_fn`` expects.
        model: module to train; it is switched to train mode.
        opt: optimizer over ``model``'s parameters; this is the one stepped.
        loss_fn: callable ``loss_fn(prediction, y)`` returning a scalar tensor.

    Returns:
        The batch loss as a Python float.
    """
    model.train()
    # Clear any gradients left over from earlier backward passes.
    opt.zero_grad()
    prediction = model(x)
    batch_loss = loss_fn(prediction, y)
    batch_loss.backward()
    # Step the optimizer that was passed in, not a same-named global.
    opt.step()
    return batch_loss.item()

In [None]:
@torch.no_grad()
def accuracy(x, y, model, threshold=0.5):
    """Return per-sample correctness of *model* on batch ``(x, y)``.

    Args:
        x: input batch.
        y: 0/1 float targets with the same shape as the model output.
        model: module producing probabilities (e.g. sigmoid outputs).
        threshold: a sample is predicted positive when its probability is
            strictly greater than this. Defaults to 0.5, the cut-off used for
            the test-set predictions, so reported train/val accuracy is
            comparable with test results.

    Returns:
        Nested Python list of bools mirroring the shape of ``y``.

    Note: leaves *model* in eval mode.
    """
    model.eval()
    prediction = model(x)
    is_correct = (prediction > threshold) == y
    return is_correct.cpu().numpy().tolist()

In [None]:
model = RebagClassifier()
model = model.to(device)
# BCELoss expects probabilities, which the model's final Sigmoid provides.
loss_fn = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Provides Report (progress logger). NOTE(review): the star import also
# re-exports many other names into the notebook namespace.
from torch_snippets import *


def _epoch_accuracy(dl, model):
    """Return the mean per-sample accuracy of *model* over every batch in *dl*."""
    correct = []
    for x, y in dl:
        correct.extend(accuracy(x, y, model))
    return np.mean(correct)


train_losses, train_accuracies = [], []
val_accuracies = []
n_epochs = 100
log = Report(n_epochs)

for epoch in range(n_epochs):
    # All metrics below describe the end of this epoch, so they share one
    # log position. Deriving it from a loader index divided by len(trn_dl)
    # misplaced the val point (e.g. "EPOCH: 99.176") and raised NameError
    # on an empty loader.
    pos = epoch + 1

    # One optimization pass over the training set.
    train_epoch_losses = []
    for x, y in trn_dl:
        train_epoch_losses.append(train_batch(x, y, model, optimizer, loss_fn))
    train_epoch_loss = np.mean(train_epoch_losses)
    log.record(pos, trn_loss=train_epoch_loss, end='\r')

    # Evaluation passes (accuracy() runs in eval mode without gradients).
    train_epoch_accuracy = _epoch_accuracy(trn_dl, model)
    log.record(pos, trn_acc=train_epoch_accuracy, end='\r')

    val_epoch_accuracy = _epoch_accuracy(val_dl, model)
    log.record(pos, val_acc=val_epoch_accuracy, end='\r')

    train_losses.append(train_epoch_loss)
    train_accuracies.append(train_epoch_accuracy)
    val_accuracies.append(val_epoch_accuracy)

EPOCH: 99.176	val_acc: 0.697	(4931.01s - 40.95s remaining)

In [None]:
# Evaluate the test split with the validation (evaluation) pipeline, never the
# training one, which is where any random augmentation would be added.
test_ds = RebagKhakiDataset(folder=folder+"test", transform=val_tfms)

# batch_size=1 because each prediction/label is read with .item() when evaluating.
test_dl = DataLoader(test_ds, shuffle=False, batch_size=1)

In [None]:
# Collect hard 0/1 predictions and ground-truth labels for the test set.
# Store the *prediction* itself, not whether it was correct: `confusion`
# compares these two lists element-wise as prediction vs. truth.
labels = []
preds = []

# Disable Dropout and use BatchNorm running statistics; no gradients needed.
model.eval()
with torch.no_grad():
    for img, label in test_dl:
        img, label = img.to(device), label.to(device)
        pred = (model(img) > 0.5).float()
        preds.append(pred.item())
        labels.append(label.item())

In [None]:
def confusion(prediction, truth):
    """Return the confusion-matrix counts ``(TP, FP, TN, FN)``.

    ``prediction`` and ``truth`` are same-shaped tensors (or anything
    ``torch.as_tensor`` accepts, e.g. lists) of binary 0/1 values; any
    non-zero entry is treated as positive. The function counts the positions
    where ``prediction`` and ``truth`` are
    - 1 and 1 (True Positive)
    - 1 and 0 (False Positive)
    - 0 and 0 (True Negative)
    - 0 and 1 (False Negative)

    Returns:
        Tuple of four Python ints.
    """
    # Boolean masks rather than element-wise division, so entries that are
    # not exactly 0 or 1 are still counted in exactly one cell.
    pred_pos = torch.as_tensor(prediction) != 0
    true_pos = torch.as_tensor(truth) != 0

    true_positives = torch.sum(pred_pos & true_pos).item()
    false_positives = torch.sum(pred_pos & ~true_pos).item()
    true_negatives = torch.sum(~pred_pos & ~true_pos).item()
    false_negatives = torch.sum(~pred_pos & true_pos).item()

    return true_positives, false_positives, true_negatives, false_negatives

In [None]:
confusion(truth=torch.Tensor(labels), prediction=torch.Tensor(preds))

(1, 132, 1, 59)