In [1]:
import os

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics.functional import accuracy
from pytorch_lightning.loggers import TensorBoardLogger
from torchvision import transforms, datasets
import numpy as np
import timm
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
timm.list_models('resnet*', pretrained=True)

['resnet18',
 'resnet18d',
 'resnet26',
 'resnet26d',
 'resnet26t',
 'resnet32ts',
 'resnet33ts',
 'resnet34',
 'resnet34d',
 'resnet50',
 'resnet50_gn',
 'resnet50d',
 'resnet51q',
 'resnet61q',
 'resnet101',
 'resnet101d',
 'resnet152',
 'resnet152d',
 'resnet200d',
 'resnetblur50',
 'resnetrs50',
 'resnetrs101',
 'resnetrs152',
 'resnetrs200',
 'resnetrs270',
 'resnetrs350',
 'resnetrs420',
 'resnetv2_50',
 'resnetv2_50x1_bit_distilled',
 'resnetv2_50x1_bitm',
 'resnetv2_50x1_bitm_in21k',
 'resnetv2_50x3_bitm',
 'resnetv2_50x3_bitm_in21k',
 'resnetv2_101',
 'resnetv2_101x1_bitm',
 'resnetv2_101x1_bitm_in21k',
 'resnetv2_101x3_bitm',
 'resnetv2_101x3_bitm_in21k',
 'resnetv2_152x2_bit_teacher',
 'resnetv2_152x2_bit_teacher_384',
 'resnetv2_152x2_bitm',
 'resnetv2_152x2_bitm_in21k',
 'resnetv2_152x4_bitm',
 'resnetv2_152x4_bitm_in21k']

In [3]:
assert torch.cuda.device_count() >= 1

In [4]:
class HelenDataset(datasets.VisionDataset):
    def __init__(self, root, loader = datasets.folder.default_loader, transform = None, target_transform = None):
        super().__init__(root, transform=transform, target_transform=target_transform)
        annotation_files = os.listdir(os.path.join(root, 'annotation'))
        self.annotations = []
        for fname in annotation_files:
            with open(os.path.join(root, 'annotation', fname), "r") as file:
                lines = file.readlines()
                id = lines[0].strip()
                cords = []
                for i in lines[1:]:
                    cords.append([ float(j.strip()) for j in i.split(',') ])
                self.annotations.append({
                    'id': id,
                    'cords': cords
                })
        self.images = {}
        for fname in os.listdir(os.path.join(root, 'img')):
            id, _ = os.path.splitext(fname)
            self.images[id] = os.path.join(root, 'img', fname)
    
    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, index):
        annotation = self.annotations[index]
        img_name = self.images[annotation['id']]
        img = Image.open(img_name)
        old_size = np.array(img.size)
        new_size = np.array((512,512))
        scale = torch.Tensor(new_size / old_size)
        img = img.resize(new_size)
        cords = torch.Tensor(annotation['cords'])
        #print(cords)
        #print(scale.repeat((cords.shape[0], 1)))
        cords *= scale.repeat((cords.shape[0], 1))
        return self.transform(img), torch.flatten(cords)

In [5]:
class HelenDataModule(LightningDataModule):
    def __init__(self, data_dir: str = '../data/helen/', batch_size = 1, num_workers=12):
        super().__init__()
        self.batch_size = batch_size
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.13), (0.3081)) # TODO: just being lazy, I took them as they where from tutorial, calculate the real ones later
        ])
        self.target_transforms = None
        self.train = None
        self.val = None
        self.test = None
    def prepare_data(self):
        pass

    def _split(self, dataset, prop):
        a = int(len(dataset) * prop)
        b = len(dataset) - a
        return random_split(dataset, (a, b))
    
    def setup(self, stage):
        if self.train:
            return
        
        self.dataset = HelenDataset(self.data_dir, transform=self.transforms, target_transform=self.target_transforms)
        
        self.train, rest = self._split(self.dataset, 0.8)
        self.test, self.val = self._split(rest, 0.5)
    
    def train_dataloader(self):
        return DataLoader(self.train, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def val_dataloader(self):
        return DataLoader(self.val, batch_size=self.batch_size, num_workers=self.num_workers)
    
    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)

In [6]:
class FaceMorphingModel(LightningModule):
    def __init__(self, number_of_cordinates: int = 388, base_model: str = 'resnet50'):
        super().__init__()
        self.model = timm.create_model(base_model, pretrained=True)
        in_channels = self.model.get_classifier().in_features
        self.model.fc = nn.Linear(in_features=in_channels, out_features=number_of_cordinates, bias=True)
    
    def forward(self, x):
        return self.model(x)
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        self.trainer.logger.experiment.add_scalar('loss', loss.item(), batch_idx)
        #self.log('loss', loss.item())
        return loss
        
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return loss
    
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.mse_loss(y_hat, y)
        return loss
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [7]:
data = HelenDataModule(batch_size=6)

In [8]:
logger = TensorBoardLogger('tb_log', name='facemorphing')

In [9]:
trainer = Trainer(gpus=1, max_epochs=10, logger=logger)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs


In [10]:
model = FaceMorphingModel()

In [11]:
trainer.fit(model, data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | model | ResNet | 24.3 M
---------------------------------
24.3 M    Trainable params
0         Non-trainable params
24.3 M    Total params
97.212    Total estimated model params size (MB)


Epoch 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                  | 311/350 [05:01<00:37,  1.03it/s, loss=3.85e+03, v_num=2]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                                                                                                                                                                                       | 0/39 [00:00<?, ?it/s][A
Epoch 0:  89%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                 | 313/350 [05:02<00:35,  1.03it/s, loss=3.85e+03, v_num=2][A
Validating:   5%|█████████▊                                                                                                                                                                                     | 2/39 [00:01<00:18,  2.06it/s]

In [12]:
trainer.test(model, data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:11<00:00,  3.57it/s]--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{}
--------------------------------------------------------------------------------
Testing: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:11<00:00,  3.30it/s]


[{}]

In [13]:
trainer.validate(model, data)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:11<00:00,  3.57it/s]--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{}
--------------------------------------------------------------------------------
Validating: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 39/39 [00:11<00:00,  3.38it/s]


[{}]