In [10]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import cv2
import os
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch import optim
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
from torchvision import transforms
from utils import hand_crop


In [48]:
class DatasetPlus(Dataset):
    """Image dataset pairing the ``.jpg`` files of a directory with CSV labels.

    Each image is read with OpenCV, converted from BGR to RGB, resized to
    ``(width, hight)`` and divided by 255.  Its label is taken from the
    ``Label`` column of the CSV row whose ``ID`` equals the integer parsed
    from the file name (``img_name[6:-4]``: the text between a six-character
    prefix and the four-character extension).

    Args:
        root_img: Directory containing the ``.jpg`` images.
        root_data: Path of the CSV file with ``ID`` and ``Label`` columns.
        width: Target image width in pixels.
        hight: Target image height in pixels (spelling kept so existing
            keyword callers continue to work).
        transform: Optional callable applied to the preprocessed image.
    """

    def __init__(self, root_img, root_data, width, hight, transform=None):
        self.root_img = root_img
        self.root_data = root_data
        self.width = width
        self.hight = hight
        self.transform = transform
        # labels are stored in a csv file
        self.labels = pd.read_csv(self.root_data)
        # Sorted so that an index refers to the same file on every run.
        self.imgs = sorted(
            name for name in os.listdir(self.root_img) if name.endswith('.jpg'))
        self.len = len(self.imgs)

    def __len__(self):
        """Return the number of ``.jpg`` files found in ``root_img``."""
        return self.len

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th image file.

        Returns:
            The preprocessed image (after ``transform`` when one was given)
            and the label as a float32 scalar tensor.

        Raises:
            IndexError: ``idx`` is outside the list of image files.
            KeyError: the CSV holds no non-null label for the image's ID.
            FileNotFoundError: OpenCV could not read the image file.
        """
        img_name = self.imgs[idx]

        # Resolve the label first: it is cheap, so a labelling problem
        # surfaces before any image decoding is attempted.
        img_id = int(img_name[6:-4])
        matches = self.labels.loc[self.labels['ID'] == img_id, 'Label'].dropna()
        if matches.empty:
            # Deliberately not IndexError: plain ``for x in dataset`` iteration
            # treats an IndexError from __getitem__ as the end of the sequence
            # and would silently drop the remaining samples.
            raise KeyError(f'no label with ID {img_id} for image {img_name!r}')
        label = torch.tensor(matches.to_numpy()[0], dtype=torch.float32)

        img_path = os.path.join(self.root_img, img_name)
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'could not read image: {img_path}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype(np.float32)
        # The flag must be passed by keyword: the third positional parameter
        # of cv2.resize is ``dst``, not ``interpolation``.
        img = cv2.resize(img, (self.width, self.hight),
                         interpolation=cv2.INTER_AREA)
        img = img / 255.0

        if self.transform is not None:
            img = self.transform(img)

        return img, label

In [64]:
class Net(nn.Module):
    """Small CNN: two conv + max-pool stages followed by three linear layers.

    Accepts 3-channel input of spatial size ``h`` x ``w`` and produces one
    sigmoid-activated value per sample.
    """

    def __init__(self, h, w):
        super().__init__()

        def reduced(size):
            # Each stage: a 5x5 convolution without padding trims 4 pixels,
            # then 2x2 pooling halves (floor) what is left.  Two stages.
            for _ in range(2):
                size = (size - 4) // 2
            return size

        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 16 feature maps of the reduced spatial size feed the first linear layer.
        self.fc1 = nn.Linear(16 * reduced(h) * reduced(w), 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 1)

    def forward(self, x):
        """Run the network on a batch and return a ``(batch, 1)`` tensor."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        # Keep the batch dimension, flatten everything else.
        x = torch.flatten(x, 1)
        for fc in (self.fc1, self.fc2):
            x = F.relu(fc(x))
        return torch.sigmoid(self.fc3(x))

In [65]:
# Dataset locations: the image folder and the CSV file holding the labels.
root_img = 'data/images/'
root_label = 'data/metadata/PSL_dataset.csv'
# ToTensor is the only transform applied on top of the dataset's own preprocessing.
transform = transforms.Compose([transforms.ToTensor()])
# Every image is resized to 224 x 224 by the dataset.
ds = DatasetPlus(root_img, root_label, width=224, hight=224, transform=transform)

In [74]:
# Network sized for the 224 x 224 images produced by the dataset.
model = Net(h=224, w=224)

# Mini-batches of four samples, reshuffled on every pass.
trainloader = DataLoader(dataset=ds, batch_size=4, shuffle=True)

# Mean-squared error between the sigmoid output and the label.
# NOTE(review): the recorded training log shows batch losses above 1.0, which
# squared error against an output in (0, 1) can only produce if some labels
# lie outside [0, 1] - confirm the label range suits this output layer.
criterion = nn.MSELoss()

optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [75]:
def train_model(epochs, net=None, loader=None, loss_fn=None, opt=None, log_every=1):
    """Train a network and report the running loss.

    Called as ``train_model(epochs)`` it behaves as before: it trains the
    notebook-level ``model`` on ``trainloader`` with ``criterion`` and
    ``optimizer`` and prints the loss of every batch.  The optional arguments
    allow the same loop to be used without relying on notebook globals.

    Args:
        epochs: Number of full passes over ``loader``.
        net: Module to train; defaults to the notebook-level ``model``.
        loader: Iterable yielding ``(img, label)`` batches; defaults to the
            notebook-level ``trainloader``.
        loss_fn: Callable invoked as ``loss_fn(prediction, label)``; defaults
            to the notebook-level ``criterion``.
        opt: Optimizer updating ``net``'s parameters; defaults to the
            notebook-level ``optimizer``.
        log_every: Print the mean loss of the last ``log_every`` batches each
            time that many batches have been processed.  ``1`` (default)
            prints every batch; values below 1 disable printing.

    Returns:
        list of float: the mean batch loss of each epoch (NaN for an epoch in
        which ``loader`` yielded no batches).
    """
    # Fall back to the notebook globals only for arguments that were omitted.
    net = model if net is None else net
    loader = trainloader if loader is None else loader
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt

    net.train()
    history = []
    for epoch in range(epochs):
        epoch_total = 0.0   # sum of batch losses over the whole epoch
        window_total = 0.0  # sum of batch losses since the last print
        n_batches = 0
        for i, (img, label) in enumerate(loader):
            opt.zero_grad()
            # Flatten the (batch, 1) network output to match the label shape.
            yhat = net(img).view(-1)
            loss = loss_fn(yhat, label)
            loss.backward()
            opt.step()

            batch_loss = loss.item()
            epoch_total += batch_loss
            window_total += batch_loss
            n_batches += 1

            if log_every >= 1 and (i + 1) % log_every == 0:
                print(f'[{epoch + 1}, {i + 1:5d}] loss: {window_total / log_every:.3f}')
                window_total = 0.0

        history.append(epoch_total / n_batches if n_batches else float('nan'))
    return history
        

In [76]:
train_model(epochs=5)

[1,     1] loss: 1.249
[1,     2] loss: 1.190
[1,     3] loss: 1.129
[1,     4] loss: 1.062
[1,     5] loss: 0.579
[1,     6] loss: 0.528
[1,     7] loss: 0.497
[1,     8] loss: 0.836
[1,     9] loss: 1.141
[1,    10] loss: 0.391
[1,    11] loss: 0.028
[1,    12] loss: 0.671
[1,    13] loss: 0.323
[1,    14] loss: 0.925
[1,    15] loss: 0.578
[1,    16] loss: 0.851
[1,    17] loss: 0.835
[1,    18] loss: 0.542
[1,    19] loss: 0.803
[1,    20] loss: 0.533
[1,    21] loss: 0.263
[1,    22] loss: 0.517
[1,    23] loss: 0.768
[1,    24] loss: 0.767
[1,    25] loss: 0.254
[1,    26] loss: 0.762
[1,    27] loss: 0.758
[1,    28] loss: 0.506
[1,    29] loss: 0.252
[1,    30] loss: 1.007
[2,     1] loss: 0.251
[2,     2] loss: 0.251
[2,     3] loss: 0.251
[2,     4] loss: 0.502
[2,     5] loss: 0.752
[2,     6] loss: 0.501
[2,     7] loss: 0.752
[2,     8] loss: 0.751
[2,     9] loss: 0.501
[2,    10] loss: 0.751
[2,    11] loss: 0.250
[2,    12] loss: 0.501
[2,    13] loss: 0.250
[2,    14] 