In [1]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm

from datset import FacesDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.models import resnet18
from torch import nn

import torch.optim as optim

In [2]:
def show_datapoint(image, label):
    l = label.reshape(4, 2)
    X = l[:, 0]
    Y = l[:, 1]
    plt.imshow(image.reshape(96,96), cmap='Greys')
    plt.scatter(X, Y, c='black')
    plt.show()

def custom_loss(output, label):
    loss = 0
    for i in range(0, 8, 2):
        loss += torch.linalg.norm(output[i:i+1] - label[i:i+1])**2
    return loss / 4

In [13]:
model = torch.load('./model.pth')

In [14]:
print(model)

ResNet(
  (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [16]:
train_df = pd.read_csv("training.csv")

conds = np.where((train_df.isnull().sum(axis=0) < 100))
train_df = train_df.iloc[:, conds[0]]
train_df = train_df.dropna(axis=0)


traindataset = FacesDataset(train_df.iloc[:500])
traindataloader = DataLoader(traindataset, batch_size=64, shuffle=False, num_workers=4)

In [None]:
criterion = custom_loss
optimizer = optim.SGD(model.parameters(), lr=0.0005, momentum=0.9)

for epoch in range(1000):
    epoch_loss = []

    for image, label in tqdm(traindataloader):

        model.float()
        output = model(image.float())

        loss = criterion(output, label.float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_loss.append(loss.item())



    mean_loss = sum(epoch_loss)/len(epoch_loss)
    print(f'Epoch loss: {mean_loss}')

100%|██████████| 8/8 [00:04<00:00,  1.73it/s]


Epoch loss: 45.603787422180176


100%|██████████| 8/8 [00:04<00:00,  1.73it/s]


Epoch loss: 48.06929302215576


100%|██████████| 8/8 [00:04<00:00,  1.61it/s]


Epoch loss: 45.19218921661377


100%|██████████| 8/8 [00:05<00:00,  1.44it/s]


Epoch loss: 42.69088816642761


100%|██████████| 8/8 [00:06<00:00,  1.28it/s]


Epoch loss: 43.28927230834961


100%|██████████| 8/8 [00:06<00:00,  1.20it/s]


Epoch loss: 43.26063311100006


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 43.10670483112335


100%|██████████| 8/8 [00:06<00:00,  1.28it/s]


Epoch loss: 42.77736043930054


100%|██████████| 8/8 [00:06<00:00,  1.26it/s]


Epoch loss: 42.85496163368225


100%|██████████| 8/8 [00:05<00:00,  1.37it/s]


Epoch loss: 42.79262638092041


100%|██████████| 8/8 [00:05<00:00,  1.37it/s]


Epoch loss: 42.69676923751831


100%|██████████| 8/8 [00:05<00:00,  1.35it/s]


Epoch loss: 42.55170726776123


100%|██████████| 8/8 [00:05<00:00,  1.33it/s]


Epoch loss: 42.49425530433655


100%|██████████| 8/8 [00:05<00:00,  1.35it/s]


Epoch loss: 42.41399836540222


100%|██████████| 8/8 [00:05<00:00,  1.36it/s]


Epoch loss: 42.331578493118286


100%|██████████| 8/8 [00:05<00:00,  1.35it/s]


Epoch loss: 42.25087785720825


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 42.176053047180176


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 42.09950113296509


100%|██████████| 8/8 [00:05<00:00,  1.35it/s]


Epoch loss: 42.02175784111023


100%|██████████| 8/8 [00:05<00:00,  1.36it/s]


Epoch loss: 41.94133019447327


100%|██████████| 8/8 [00:05<00:00,  1.38it/s]


Epoch loss: 41.868797302246094


100%|██████████| 8/8 [00:05<00:00,  1.36it/s]


Epoch loss: 41.802106976509094


100%|██████████| 8/8 [00:06<00:00,  1.33it/s]


Epoch loss: 41.737876653671265


100%|██████████| 8/8 [00:05<00:00,  1.38it/s]


Epoch loss: 41.664140939712524


100%|██████████| 8/8 [00:06<00:00,  1.33it/s]


Epoch loss: 41.59256649017334


100%|██████████| 8/8 [00:05<00:00,  1.36it/s]


Epoch loss: 41.52998328208923


100%|██████████| 8/8 [00:05<00:00,  1.37it/s]


Epoch loss: 41.461145997047424


100%|██████████| 8/8 [00:05<00:00,  1.37it/s]


Epoch loss: 41.39738059043884


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 41.329360008239746


100%|██████████| 8/8 [00:06<00:00,  1.33it/s]


Epoch loss: 41.26472055912018


100%|██████████| 8/8 [00:06<00:00,  1.32it/s]


Epoch loss: 41.20114707946777


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 41.13448226451874


100%|██████████| 8/8 [00:06<00:00,  1.30it/s]


Epoch loss: 41.06927514076233


100%|██████████| 8/8 [00:07<00:00,  1.07it/s]


Epoch loss: 41.00735950469971


100%|██████████| 8/8 [00:07<00:00,  1.10it/s]


Epoch loss: 40.9475474357605


100%|██████████| 8/8 [00:06<00:00,  1.25it/s]


Epoch loss: 40.88540768623352


100%|██████████| 8/8 [00:06<00:00,  1.29it/s]


Epoch loss: 40.82126021385193


100%|██████████| 8/8 [00:06<00:00,  1.28it/s]


Epoch loss: 40.756348609924316


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 40.69901704788208


100%|██████████| 8/8 [00:06<00:00,  1.25it/s]


Epoch loss: 40.63607847690582


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 40.574687123298645


100%|██████████| 8/8 [00:05<00:00,  1.34it/s]


Epoch loss: 40.51020383834839


100%|██████████| 8/8 [00:05<00:00,  1.35it/s]


Epoch loss: 40.450173020362854


100%|██████████| 8/8 [00:05<00:00,  1.37it/s]


Epoch loss: 40.39639163017273


100%|██████████| 8/8 [00:06<00:00,  1.24it/s]


Epoch loss: 40.331371545791626


100%|██████████| 8/8 [00:06<00:00,  1.31it/s]


Epoch loss: 40.266706466674805


100%|██████████| 8/8 [00:06<00:00,  1.27it/s]


Epoch loss: 40.201621294021606


100%|██████████| 8/8 [00:06<00:00,  1.21it/s]


Epoch loss: 40.1381311416626


100%|██████████| 8/8 [00:06<00:00,  1.29it/s]


Epoch loss: 40.07254600524902


100%|██████████| 8/8 [00:06<00:00,  1.33it/s]


Epoch loss: 40.00542974472046


100%|██████████| 8/8 [00:05<00:00,  1.38it/s]


Epoch loss: 39.94177174568176


100%|██████████| 8/8 [00:07<00:00,  1.12it/s]


Epoch loss: 39.88302505016327


100%|██████████| 8/8 [00:06<00:00,  1.23it/s]


Epoch loss: 39.815815567970276


100%|██████████| 8/8 [00:06<00:00,  1.14it/s]


Epoch loss: 39.7592910528183


100%|██████████| 8/8 [00:06<00:00,  1.25it/s]


Epoch loss: 39.692258477211


100%|██████████| 8/8 [00:06<00:00,  1.17it/s]


Epoch loss: 39.63267183303833


100%|██████████| 8/8 [00:07<00:00,  1.04it/s]


Epoch loss: 39.57444500923157


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]


Epoch loss: 39.5112841129303


100%|██████████| 8/8 [00:08<00:00,  1.03s/it]


Epoch loss: 39.4549617767334


100%|██████████| 8/8 [00:08<00:00,  1.06s/it]


Epoch loss: 39.396817088127136


100%|██████████| 8/8 [00:08<00:00,  1.02s/it]


Epoch loss: 39.33440673351288


100%|██████████| 8/8 [00:07<00:00,  1.04it/s]


Epoch loss: 39.271034479141235


100%|██████████| 8/8 [00:08<00:00,  1.00s/it]


Epoch loss: 39.20835018157959


100%|██████████| 8/8 [00:07<00:00,  1.06it/s]


Epoch loss: 39.132060050964355


100%|██████████| 8/8 [00:07<00:00,  1.08it/s]


Epoch loss: 39.068570375442505


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]


Epoch loss: 39.00487732887268


100%|██████████| 8/8 [00:06<00:00,  1.18it/s]


Epoch loss: 38.946125507354736


100%|██████████| 8/8 [00:07<00:00,  1.07it/s]


Epoch loss: 38.88885474205017


100%|██████████| 8/8 [00:07<00:00,  1.02it/s]


Epoch loss: 38.82313799858093


100%|██████████| 8/8 [00:07<00:00,  1.06it/s]


Epoch loss: 38.75698900222778


100%|██████████| 8/8 [00:07<00:00,  1.03it/s]


Epoch loss: 38.68799710273743


100%|██████████| 8/8 [00:07<00:00,  1.12it/s]


Epoch loss: 38.625680446624756


100%|██████████| 8/8 [00:08<00:00,  1.05s/it]


Epoch loss: 38.566195011138916


100%|██████████| 8/8 [00:07<00:00,  1.03it/s]


Epoch loss: 38.50521731376648


100%|██████████| 8/8 [00:07<00:00,  1.12it/s]


Epoch loss: 38.44160771369934


100%|██████████| 8/8 [00:08<00:00,  1.01s/it]


Epoch loss: 38.3783198595047


100%|██████████| 8/8 [00:09<00:00,  1.17s/it]


Epoch loss: 38.313145875930786


 12%|█▎        | 1/8 [00:02<00:17,  2.43s/it]