In [1]:
import torch
import numpy
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose
import matplotlib.pyplot as plt
import cv2
import json
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
# Display the selected compute device as the cell output.
device

device(type='cuda')

In [3]:
from spiga.models.spiga import SPIGA
from spiga.inference.config import ModelConfig
from spiga.data.loaders.alignments import AlignmentsDataset, get_dataset
from spiga.data.loaders.dl_config import AlignConfig
from spiga.models.cnn.cnn_multitask import MultitaskCNN
from spiga.models.spiga import SPIGA
from spiga.data.loaders.dataloader import get_dataloader

# Dataset

In [4]:
# Two 'wflw' alignment configs: one with AlignConfig's default mode, one with mode='test'.
data_cfg, test_cfg = AlignConfig('wflw'), AlignConfig('wflw', mode='test')

In [5]:
# One dataset per config, built in the same order as before (data_cfg first, then test_cfg).
train_set, test_set = get_dataset(data_cfg), get_dataset(test_cfg)

In [6]:
# One loader per config. The first argument (64 / 1000) is presumably the batch size
# -- confirm against spiga's get_dataloader signature.
train_loader, test_loader = get_dataloader(64, data_cfg), get_dataloader(1000, test_cfg)

# Training

In [7]:
class AdaptiveWingLoss(nn.Module):
    """Adaptive Wing loss for heatmap regression (Wang et al., ICCV 2019).

    With ``d = |target - pred|`` and ``y = target``, each element contributes::

        omega * log(1 + (d / epsilon) ** (alpha - y))    if d <  theta
        A * d - C                                        if d >= theta

    ``A`` is the slope of the logarithmic branch at ``d = theta`` and ``C`` is
    the offset that makes both branches take the same value there, so the
    loss is continuous and smooth across the threshold.

    :param omega: overall scale of the loss
    :param theta: error threshold that separates the two branches
    :param epsilon: normalisation applied to the error inside the log branch
    :param alpha: base exponent; the effective exponent is ``alpha - target``
    """

    def __init__(self, omega=14, theta=0.5, epsilon=1, alpha=2.1):
        super(AdaptiveWingLoss, self).__init__()
        self.omega = omega
        self.theta = theta
        self.epsilon = epsilon
        self.alpha = alpha

    def forward(self, pred, target):
        '''
        Average the element-wise Adaptive Wing loss over every element.

        :param pred: BxNxHxH
        :param target: BxNxHxH
        :return: scalar tensor holding the mean loss
        '''
        abs_err = (target - pred).abs()
        small = abs_err < self.theta
        large = abs_err >= self.theta

        err_small, y_small = abs_err[small], target[small]
        err_large, y_large = abs_err[large], target[large]

        # Logarithmic branch. The error is normalised by epsilon (not omega):
        # A and C below are derived from (theta / epsilon), so any other
        # normaliser makes the loss jump at |error| == theta.
        loss_small = self.omega * torch.log(
            1 + torch.pow(err_small / self.epsilon, self.alpha - y_small))

        # Linear branch: slope A and offset C match the log branch at theta.
        ratio = self.theta / self.epsilon
        exponent = self.alpha - y_large
        A = (self.omega
             * (1 / (1 + torch.pow(ratio, exponent)))
             * exponent
             * torch.pow(ratio, exponent - 1)
             * (1 / self.epsilon))
        C = self.theta * A - self.omega * torch.log(1 + torch.pow(ratio, exponent))
        loss_large = A * err_large - C

        n_elements = loss_small.numel() + loss_large.numel()
        return (loss_small.sum() + loss_large.sum()) / n_elements

In [8]:
def train(epoch, model, optimizer, train_loader, train_losses, train_counter, num_stacks=4):
    """Run one training epoch with a stack-weighted multi-term loss.

    For every batch, the loss of stack ``i`` is
    ``(coord + (points + edge)) * 2**i`` and the per-stack losses are summed
    before a single backward pass and optimiser step.

    :param epoch: 1-based epoch number, used for logging and for the sample counter
    :param model: network whose output is indexed as ``output['VisualField'][i]``
        and ``output['Heatmaps'][i][1]``
    :param optimizer: optimiser that updates ``model``
    :param train_loader: iterable of ``(data, target)`` batches exposing ``.dataset``
    :param train_losses: list that receives the logged loss values (mutated in place)
    :param train_counter: list that receives the number of samples seen at each
        logged step (mutated in place)
    :param num_stacks: number of stacked outputs to supervise (previously hard-coded to 4)
    """
    model.train()

    # Build the criteria once and *call* them on tensors. The previous code
    # passed the tensors to the loss constructors, which never computes a loss.
    coord_criterion = nn.SmoothL1Loss()
    awing_criterion = AdaptiveWingLoss()

    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        data, target = data.to(device), target.to(device)
        output = model(data)

        total_loss = 0
        for i in range(num_stacks):
            # NOTE(review): all three terms are compared against the same
            # `target`, as before. Confirm whether each head needs its own
            # ground truth and that these output keys match the model.
            loss_coord = coord_criterion(output['VisualField'][i], target)
            loss_edge = awing_criterion(output['Heatmaps'][i][1], target)
            # NOTE(review): this used stack 0 on every iteration, which looks
            # like a copy-paste slip inside a per-stack loop; now uses stack i.
            loss_points = awing_criterion(output['VisualField'][i], target)
            total_loss += (loss_coord + (loss_points + loss_edge)) * pow(2, i)
        total_loss.backward()
        optimizer.step()

        if batch_idx % 10 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), total_loss.item()))
            train_losses.append(total_loss.item())
            # Use the loader's own batch size rather than a hard-coded 64 so
            # the counter stays correct if the loader is rebuilt differently.
            batch_size = getattr(train_loader, 'batch_size', None) or len(data)
            train_counter.append(
                (batch_idx * batch_size) + ((epoch - 1) * len(train_loader.dataset)))

def test(epoch, model, test_loader, test_losses):
    """Evaluate ``model`` on ``test_loader`` and report loss and accuracy.

    Accumulates the cross-entropy loss and the number of argmax predictions
    that equal the target, appends the per-sample average loss to
    ``test_losses`` and prints a one-line summary. ``epoch`` is accepted but
    not used.

    NOTE(review): cross-entropy and argmax accuracy assume ``model(data)``
    returns a tensor of class scores. Confirm this fits the model in use
    before calling.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss()
    loss_sum = 0
    n_correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # The criterion returns a batch mean; scale it back to a batch sum
            # so the final division yields a per-sample average.
            loss_sum += criterion(output, target).item() * data.size(0)
            predicted = output.argmax(dim=1, keepdim=True)
            n_correct += predicted.eq(target.view_as(predicted)).sum().item()

        n_samples = len(test_loader.dataset)
        mean_loss = loss_sum / n_samples
        test_losses.append(mean_loss)

        # Report loss and accuracy on the test set.
        accuracy_pct = 100. * n_correct / n_samples
        print(f'\nTest set: Average loss: {mean_loss:.4f}, '
              f'Accuracy: {n_correct}/{n_samples} ({accuracy_pct:.0f}%)\n')

In [15]:
# Peek at the batch with index 3 without loading the whole loader into a list.
# enumerate() yields (index, batch) pairs; the previous unpacking named the
# index `image`, so the cell printed the index instead of the batch.
for batch_idx, batch in enumerate(train_loader):
    if batch_idx == 3:
        print(batch)
        break

In [17]:
# Scratch check: multiplying a list by an int repeats its contents.
[1, 2, 3] * 4

[1, 2, 3, 1, 2, 3, 1, 2, 3, 1, 2, 3]

In [97]:
# Training setup: epoch budget, the multitask CNN (built with pose_req=False)
# moved onto the selected device, and an AdamW optimiser over its parameters.
n_epochs = 5
model = MultitaskCNN(pose_req=False)
model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=0.01)

In [98]:
# Empty history buffers; each name is bound to its own independent list.
train_losses, train_counter = [], []
test_losses, test_counter = [], []

In [None]:
# Train for n_epochs epochs, numbered 1..n_epochs. The previous loop iterated
# over the tuple (1, n_epochs + 1), which runs only two passes, labelled
# 1 and n_epochs + 1.
for epoch in range(1, n_epochs + 1):
    train(epoch, model, optimizer, train_loader, train_losses, train_counter)

In [83]:
# Build a SPIGA model and load the 'spiga_wflw.pt' checkpoint into it.
# map_location='cpu' lets a checkpoint saved from GPU tensors load on a
# CPU-only machine; load_state_dict then copies the values into the model's
# own parameters.
# NOTE(security): torch.load unpickles the file -- only load checkpoints from
# a trusted source.
full_model = SPIGA()
full_model.load_state_dict(torch.load('spiga_wflw.pt', map_location='cpu'))

<All keys matched successfully>