In [1]:
# %%capture
# !pip install gdown --force-reinstall
# !gdown --id 1TEOjzSBN6UZc1WQwb_huEhdFl1XyZel4
# !unzip KDEF_CROPPED_ALIGNED.zip

In [2]:
# %%capture
# !gdown --id 1Bd87admxOZvbIOAyTkGEntsEz3fyMt7H

In [3]:
# %%capture
# !pip install mlxtend
# !pip install dlib
# !pip install scikit-image

In [4]:
import torch
torch.multiprocessing.set_start_method('spawn')

In [5]:
import imageio
import matplotlib.pyplot as plt
from mlxtend.image import extract_face_landmarks
import cv2

# MagFace

In [6]:
import torch
from torch import nn
# from torchvision.models.utils import load_state_dict_from_url

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError(
                'BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError(
                "Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.bn1 = nn.BatchNorm2d(inplanes, eps=2e-05, momentum=0.9)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=2e-05, momentum=0.9)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=2e-05, momentum=0.9)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity

        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self, block, layers, num_classes=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None):
        super(IResNet, self).__init__()

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=2e-05, momentum=0.9)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        self.bn2 = nn.BatchNorm2d(
            512 * block.expansion, eps=2e-05, momentum=0.9)
        self.dropout = nn.Dropout2d(p=0.4, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_classes)
        self.features = nn.BatchNorm1d(num_classes, eps=2e-05, momentum=0.9)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion,
                               eps=2e-05, momentum=0.9),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                            self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=self.groups,
                                base_width=self.base_width, dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn2(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        x = self.features(x)

        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    # if pretrained:
    # state_dict = load_state_dict_from_url(model_urls[arch],
    #                                        progress=progress)
    # model.load_state_dict(state_dict)
    return model


# def iresnet18(pretrained=False, progress=True, **kwargs):
#     return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained, progress,
#                     **kwargs)


# def iresnet34(pretrained=False, progress=True, **kwargs):
#     return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained, progress,
#                     **kwargs)


# def iresnet50(pretrained=False, progress=True, **kwargs):
#     return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained, progress,
#                     **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained, progress, **kwargs)

In [7]:
from collections import OrderedDict
import sys

def clean_dict_inf(model, state_dict):
    _state_dict = OrderedDict()
    for k, v in state_dict.items():
        # # assert k[0:1] == 'features.module.'
        # new_k = 'features.'+'.'.join(k.split('.')[2:])
        new_k = '.'.join(k.split('.')[2:])
        if new_k in model.state_dict().keys() and v.size() == model.state_dict()[new_k].size():
            _state_dict[new_k] = v
        
        new_kk = '.'.join(k.split('.')[1:])
        if new_kk in model.state_dict().keys() and v.size() == model.state_dict()[new_kk].size():
            _state_dict[new_kk] = v
            
    num_model = len(model.state_dict().keys())
    num_ckpt = len(_state_dict.keys())
    if num_model != num_ckpt:
        sys.exit("=> Not all weights loaded, model params: {}, loaded params: {}".format(
            num_model, num_ckpt))
    return _state_dict
    
    
def load_dict_inf(model, resume, cpu_mode=False):
    if os.path.isfile(resume):
        print(f'=> loading pth from {resume} ...')
        if cpu_mode:
            checkpoint = torch.load(resume, map_location=torch.device("cpu"))
        else:
            checkpoint = torch.load(resume)
        _state_dict = clean_dict_inf(model, checkpoint['state_dict'])
        model_dict = model.state_dict()
        model_dict.update(_state_dict)
        model.load_state_dict(model_dict)
        # delete to release more space
        del checkpoint
        del _state_dict
    else:
        sys.exit("=> No checkpoint found at '{}'".format(resume))
    return model


# Dataset Preparation

In [8]:
# !rm -rf /media/soroushh/Storage2/matrices
# !mkdir -p /media/soroushh/Storage2/matrices/training
# !mkdir -p /media/soroushh/Storage2/matrices/evaluation

# magface = iresnet100(pretrained=False, num_classes=512)
# magface = load_dict_inf(magface, "./magface_epoch_00025.pth")
# magface = magface.to("cuda")
# magface.eval()

# def get_landmarks(image_paths, image_size):
#     database = {}
#     for image_path in tqdm.tqdm(image_paths):
#         filename = image_path.split("/")[-1].split(".")[0]
#         if not bool(re.match(r"\w{2}\d{2}\w{2}\w+", filename)):
#             continue

#         name = re.findall(r"(\w{2}\d{2})\w{2}\w+", filename)[0]
#         pose_label = re.findall(r"\w{2}\d{2}\w{2}(\w+)", filename)[0]
#         emotion_label = re.findall(r"\w{2}\d{2}(\w{2})\w+", filename)[0]

#         # img = cv2.imread(image_path)
#         img = cv2.imread(image_path.replace("CROPPED_ALIGNED", "KDEF")) # get real images from KDEF for landmarks
#         img = cv2.resize(img, (image_size, image_size))
#         img = cv2.copyMakeBorder(img, 10, 10, 10, 10, cv2.BORDER_CONSTANT)
#         pose_img = get_pose_image(img, image_size)

#         if pose_img is None:
#             continue

#         person_db = database.get(name, [])
#         person_db.append((emotion_label, pose_label, pose_img, image_path))
#         database[name] = person_db
        
#     return database

# def get_pose_image(img, image_size):
#     landmarks = extract_face_landmarks(img)

#     if np.all(landmarks == 0) or landmarks is None:
#         return None

#     landmarks[:, 0][landmarks[:, 0] >= image_size] = image_size - 1
#     landmarks[:, 0][landmarks[:, 0] < 0] = 0
#     landmarks[:, 1][landmarks[:, 1] >= image_size] = image_size - 1
#     landmarks[:, 1][landmarks[:, 1] < 0] = 0

#     # landmarks = landmarks[[30, 40, 46, 48, 56]]
#     # print(landmarks.shape)

#     pose_img = np.zeros((image_size, image_size))
#     pose_img[landmarks[:, 1], landmarks[:, 0]] = 1
#     pose_img = pose_img[:, :, np.newaxis]
#     pose_img = pose_img[landmarks[:, 1].min():landmarks[:, 1].max(), landmarks[:, 0].min():landmarks[:, 0].max()]
#     pose_img = cv2.resize(pose_img, (image_size, image_size))
#     pose_img = np.where(pose_img > 0.5, 1, 0)

#     # plt.imshow(pose_img)
#     # plt.show()

#     pose_img = pose_img[:, :, np.newaxis]

#     return pose_img

# def get_magface_embedding(image_path):
#     img = cv2.imread(image_path)

#     if img.shape[:2] != [112, 112]:
#         img = cv2.resize(img, (112, 112))

#     img = img[np.newaxis] / 255.
#     img = torch.Tensor(img).to("cuda").to(torch.float32)
#     img = img.permute(0, 3, 1, 2)
#     embedding = magface(img)
#     embedding = embedding.detach().cpu().numpy()
#     embedding = np.squeeze(embedding)

#     return embedding

# def prepare_final_database(database):
#     db_keys = list(database.keys())

#     final_database = []
#     for person1_name in db_keys:
#         person1 = database[person1_name]
#         for triple_to_reconstruct_index, triple_to_reconstruct in enumerate(person1):
#             pose_img_to_reconstruct = triple_to_reconstruct[2]

#             available_poses_input = []
#             available_poses_output = []
#             for person2_name in db_keys:
#                 person2 = database[person2_name]
#                 available_poses_input = list(filter(lambda index: person2[index][0] == triple_to_reconstruct[0], range(len(person2))))
#                 available_poses_output = list(filter(lambda index: person2[index][0] == triple_to_reconstruct[0] and person2[index][1] == triple_to_reconstruct[1], range(len(person2))))

#                 for triple_to_change_input_index in available_poses_input:
#                     for triple_to_change_output_index in available_poses_output:
#                         final_database.append((person1_name, triple_to_reconstruct_index, person2_name, triple_to_change_input_index, triple_to_change_output_index))

#                 if len(final_database) >= 30000:
#                     break

#             if len(final_database) >= 30000:
#                 break

#         if len(final_database) >= 30000:
#             break
            
#     return final_database

# def save_files(database, final_database, type_):
#     for idx in tqdm.tqdm(range(len(final_database))):
#         tup = final_database[idx]

#         if os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_pose_img_to_reconstruct.npy') and os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_input.npy') and os.path.isfile(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_output.npy'):
#             continue

#         pose_img_to_reconstruct = database[tup[0]][tup[1]][2]
#         image_path = database[tup[2]][tup[3]][3]
#         embedding_input = get_magface_embedding(image_path)
#         image_path = database[tup[2]][tup[4]][3]
#         embedding_output = get_magface_embedding(image_path)

#         with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_pose_img_to_reconstruct.npy', 'wb') as f:
#             np.save(f, pose_img_to_reconstruct)

#         with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_input.npy', 'wb') as f:
#             np.save(f, embedding_input)

#         with open(f'/media/soroushh/Storage2/matrices/{type_}/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_output.npy', 'wb') as f:
#             np.save(f, embedding_output)

#     with open(f'/media/soroushh/Storage2/database_{type_}.pickle', 'wb') as handle:
#         pickle.dump(database, handle, protocol=pickle.HIGHEST_PROTOCOL)


# people_paths = glob.glob(os.path.join("./CROPPED_ALIGNED/*"))
# people_paths = np.array(people_paths)
# np.random.shuffle(people_paths)
# people_paths = people_paths.tolist()

# index = int(0.8 * len(people_paths))
# trainset = people_paths[:index]
# evalset = people_paths[index:]

# image_paths = [img_path for person_path in trainset for img_path in glob.glob(person_path + "/*.JPG")]
# database = get_landmarks(image_paths, 112)
# final_database = prepare_final_database(database)
# save_files(database, final_database, "training")

# image_paths = [img_path for person_path in evalset for img_path in glob.glob(person_path + "/*.JPG")]
# database = get_landmarks(image_paths, 112)
# final_database = prepare_final_database(database)
# save_files(database, final_database, "evaluation")

In [9]:
import glob
import numpy as np
import torch
import os
import tqdm
import re
import pickle
import random


class CustomDataSet(torch.utils.data.Dataset):

    def __init__(self, type_="train"):
        self.type_ = type_
        
        all_files = list(set([re.findall(r"(.+)_(\d+)_(.+)_(\d+)_(\d+)_.+", filename)[0] for filename in os.listdir(f"/media/soroushh/Storage2/matrices/training")]))
        if self.type_ == "train":
            all_files = all_files[:int((2/3) * len(all_files))]
        else:
            all_files = all_files[int((2/3) * len(all_files)):]
        
        self.final_database = []
        for person1_name, triple_to_reconstruct_index, person2_name, triple_to_change_input_index, triple_to_change_output_index in all_files:
            triple_to_reconstruct_index = int(triple_to_reconstruct_index)
            triple_to_change_input_index = int(triple_to_change_input_index)
            triple_to_change_output_index = int(triple_to_change_output_index)
            self.final_database.append((person1_name, triple_to_reconstruct_index, person2_name, triple_to_change_input_index, triple_to_change_output_index))
            
        with open(f'/media/soroushh/Storage2/database_training.pickle', 'rb') as handle:
            self.database = pickle.load(handle)
        
        ids = list(self.database.keys())
        self.ids_dict = {id:idx for idx, id in enumerate(list(self.database.keys()))}
        self.emotions_dict = {emo:idx for idx, emo in enumerate(np.unique([triplet[0] for id in ids for triplet in self.database[id]]).tolist())}
        self.poses_dict = {pose:idx for idx, pose in enumerate(np.unique([triplet[1] for id in ids for triplet in self.database[id]]).tolist())}

    def __len__(self):
        return len(self.final_database)
        
    def __getitem__(self, idx):
        tup = self.final_database[idx]
        
        with open(f'/media/soroushh/Storage2/matrices/training/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_pose_img_to_reconstruct.npy', 'rb') as f:
            pose_img_to_reconstruct = np.load(f)
            
        with open(f'/media/soroushh/Storage2/matrices/training/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_input.npy', 'rb') as f:
            embedding_input = np.load(f)
            
        with open(f'/media/soroushh/Storage2/matrices/training/{tup[0]}_{tup[1]}_{tup[2]}_{tup[3]}_{tup[4]}_embedding_output.npy', 'rb') as f:
            embedding_output = np.load(f)
            
        expected_pose_label = self.database[tup[2]][tup[4]][1]
        input_id = tup[2]
        input_emo = self.database[tup[2]][tup[3]][0]
        
        while True:
            available_negative_choices = list(filter(lambda item: item[2] != input_id, self.final_database))
            if len(available_negative_choices) == 0:
                continue
                
            neg_tup = random.choice(available_negative_choices)
            break
            
        with open(f'/media/soroushh/Storage2/matrices/training/{neg_tup[0]}_{neg_tup[1]}_{neg_tup[2]}_{neg_tup[3]}_{neg_tup[4]}_embedding_input.npy', 'rb') as f:
            negative_embedding_input = np.load(f)
            
        return pose_img_to_reconstruct, embedding_input, embedding_output, negative_embedding_input

In [10]:
train_dataset = CustomDataSet(type_="train")
val_dataset = CustomDataSet(type_="val")

In [11]:
len(train_dataset), len(val_dataset)

(20001, 10001)

In [12]:
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=256, shuffle=True)

# Network Arch.

In [13]:
def img_to_patch(x, patch_size, flatten_channels=True):
    """
    Inputs:
        x - torch.Tensor representing the image of shape [B, C, H, W]
        patch_size - Number of pixels per dimension of the patches (integer)
        flatten_channels - If True, the patches will be returned in a flattened format
                           as a feature vector instead of a image grid.
    """
    B, C, H, W = x.shape
    x = x.reshape(B, C, H//patch_size, patch_size, W//patch_size, patch_size)
    x = x.permute(0, 2, 4, 1, 3, 5) # [B, H', W', C, p_H, p_W]
    x = x.flatten(1,2)              # [B, H'*W', C, p_H, p_W]
    if flatten_channels:
        x = x.flatten(2,4)          # [B, H'*W', C*p_H*p_W]
    return x


class AttentionBlock(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_heads, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of input and attention feature vectors
            hidden_dim - Dimensionality of hidden layer in feed-forward network
                         (usually 2-4x larger than embed_dim)
            num_heads - Number of heads to use in the Multi-Head Attention block
            dropout - Amount of dropout to apply in the feed-forward network
        """
        super().__init__()

        self.layer_norm_1 = nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)
        self.layer_norm_2 = nn.LayerNorm(embed_dim)
        self.linear = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, embed_dim),
            nn.Dropout(dropout)
        )


    def forward(self, x):
        inp_x = self.layer_norm_1(x)
        x = x + self.attn(inp_x, inp_x, inp_x)[0]
        x = x + self.linear(self.layer_norm_2(x))
        return x


class VisionTransformer(nn.Module):

    def __init__(self, embed_dim, hidden_dim, num_channels, num_heads, num_layers, num_classes, patch_size, num_patches, dropout=0.0):
        """
        Inputs:
            embed_dim - Dimensionality of the input feature vectors to the Transformer
            hidden_dim - Dimensionality of the hidden layer in the feed-forward networks
                         within the Transformer
            num_channels - Number of channels of the input (3 for RGB)
            num_heads - Number of heads to use in the Multi-Head Attention block
            num_layers - Number of layers to use in the Transformer
            num_classes - Number of classes to predict
            patch_size - Number of pixels that the patches have per dimension
            num_patches - Maximum number of patches an image can have
            dropout - Amount of dropout to apply in the feed-forward network and
                      on the input encoding
        """
        super().__init__()

        self.patch_size = patch_size

        # Layers/Networks
        self.input_layer = nn.Linear(num_channels*(patch_size**2), embed_dim)
        self.transformer = nn.Sequential(*[AttentionBlock(embed_dim, hidden_dim, num_heads, dropout=dropout) for _ in range(num_layers)])
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )
        self.dropout = nn.Dropout(dropout)

        # Parameters/Embeddings
        self.cls_token = nn.Parameter(torch.randn(1,1,embed_dim))
        self.pos_embedding = nn.Parameter(torch.randn(1,1+num_patches,embed_dim))


    def forward(self, x):
        # Preprocess input
        x = img_to_patch(x, self.patch_size)
        B, T, _ = x.shape
        x = self.input_layer(x)

        # Add CLS token and positional encoding
        cls_token = self.cls_token.repeat(B, 1, 1)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.pos_embedding[:,:T+1]

        # Apply Transforrmer
        x = self.dropout(x)
        x = x.transpose(0, 1)
        x = self.transformer(x)

        # Perform classification prediction
        cls = x[0]
        out = self.mlp_head(cls)
        # out = cls
        
        return out

In [14]:
import math
import torch.nn.functional as F


# class Reshape(torch.nn.Module):
    
#     def __init__(self, *args):
#         super(Reshape, self).__init__()
#         self.shape = args

#     def forward(self, x):
#         return x.view(self.shape)

    
# def scaled_dot_product(q, k, v, mask=None):
#     d_k = q.size()[-1]
#     attn_logits = torch.matmul(q, k.transpose(-2, -1))
#     attn_logits = attn_logits / math.sqrt(d_k)
#     if mask is not None:
#         attn_logits = attn_logits.masked_fill(mask == 0, -9e15)
        
#     attention = F.softmax(attn_logits, dim=-1)
#     values = torch.matmul(attention, v)
    
#     return values, attention


# class MultiheadAttention(torch.nn.Module):

#     def __init__(self, input_dim, embed_dim, num_heads):
#         super().__init__()
#         assert embed_dim % num_heads == 0, "Embedding dimension must be 0 modulo number of heads."

#         self.embed_dim = embed_dim
#         self.num_heads = num_heads
#         self.head_dim = embed_dim // num_heads

#         # Stack all weight matrices 1...h together for efficiency
#         # Note that in many implementations you see "bias=False" which is optional
#         self.qkv_proj = torch.nn.Linear(input_dim, 3*embed_dim)
#         self.o_proj = torch.nn.Linear(embed_dim, embed_dim)

#         self._reset_parameters()

#     def _reset_parameters(self):
#         # Original Transformer initialization, see PyTorch documentation
#         torch.nn.init.xavier_uniform_(self.qkv_proj.weight)
#         self.qkv_proj.bias.data.fill_(0)
#         torch.nn.init.xavier_uniform_(self.o_proj.weight)
#         self.o_proj.bias.data.fill_(0)

#     def forward(self, x, mask=None, return_attention=False):
#         batch_size, seq_length, embed_dim = x.size()
#         qkv = self.qkv_proj(x)

#         # Separate Q, K, V from linear output
#         qkv = qkv.reshape(batch_size, seq_length, self.num_heads, 3*self.head_dim)
#         qkv = qkv.permute(0, 2, 1, 3) # [Batch, Head, SeqLen, Dims]
#         q, k, v = qkv.chunk(3, dim=-1)

#         # Determine value outputs
#         values, attention = scaled_dot_product(q, k, v, mask=mask)
#         values = values.permute(0, 2, 1, 3) # [Batch, SeqLen, Head, Dims]
#         values = values.reshape(batch_size, seq_length, embed_dim)
#         o = self.o_proj(values)

#         if return_attention:
#             return o, attention
#         else:
#             return o


class EmbeddingGeneratorDecoder(torch.nn.Module):
    
    def __init__(self):
        super(EmbeddingGeneratorDecoder, self).__init__()

        # self.downsample1 = torch.nn.Conv1d(2, 8, 3, stride=1, padding=1)
        # self.downsample2 = torch.nn.Conv1d(8, 64, 3, stride=1, padding=1)
        
        # self.multihead_attention = MultiheadAttention(input_dim=128, embed_dim=128, num_heads=4)
        
        self.vit = VisionTransformer(**{
                                        'embed_dim': 128,
                                        'hidden_dim': 256,
                                        'num_heads': 4,
                                        'num_layers': 4,
                                        'patch_size': 8,
                                        'num_channels': 1,
                                        'num_patches': 32,
                                        'num_classes': 512,
                                        'dropout': 0.4
                                    })
        
        self.normalizer = nn.LayerNorm(1024)
        
#         self.upsample1 = torch.nn.Conv1d(64, 8, 3, stride=1, padding=1)
#         self.upsample2 = torch.nn.Conv1d(8, 1, 3, stride=1, padding=1)
        
#         self.upsample_block = torch.nn.Upsample(scale_factor=2)
        
        # self.batchnorm1 = torch.nn.BatchNorm1d(1)
        # self.batchnorm8 = torch.nn.BatchNorm1d(8)
        # self.batchnorm64 = torch.nn.BatchNorm1d(64)
        
        # self.features = torch.nn.Sequential(
        #     torch.nn.Linear(512, 512), 
        #     torch.nn.BatchNorm1d(512),
        # )
        
#     def feature_extraction_downsample(self, x): # n, 2, 512
#         skip_connections = []
        
#         x = self.downsample1(x) # n, 8, 512
#         # x = self.batchnorm8(x)
#         x = F.dropout(x, p=0.4)
#         x = F.relu(x)
#         skip_connections.append(x)
#         x = F.max_pool1d(x, kernel_size=2) # n, 8, 256 
#         # x = F.avg_pool1d(x, kernel_size=2) # n, 8, 256 
        
#         x = self.downsample2(x) # n, 64, 256
#         # x = self.batchnorm64(x)
#         x = F.dropout(x, p=0.4)
#         x = F.relu(x)
#         skip_connections.append(x)
#         x = F.max_pool1d(x, kernel_size=2) # n, 64, 128
#         # x = F.avg_pool1d(x, kernel_size=2) # n, 8, 256
        
#         return x, skip_connections
    
#     def feature_extraction_upsample(self, x, skip_connections): # n, 64, 128
#         skip_connections = list(reversed(skip_connections))
        
#         x = self.upsample_block(x) # n, 64, 256
#         x = x + skip_connections[0]
#         x = self.upsample1(x) # n, 8, 256
#         # x = self.batchnorm8(x)
#         x = F.dropout(x, p=0.4)
#         x = F.relu(x)
        
#         x = self.upsample_block(x) # n, 8, 512
#         x = x + skip_connections[1]
#         x = self.upsample2(x) # n, 1, 512
#         # x = self.batchnorm1(x)
#         x = F.dropout(x, p=0.4)
#         x = F.relu(x)
        
#         return x
    
    def forward(self, emb1, emb2):
#         emb1 = emb1.unsqueeze(1)
#         emb2 = emb2.unsqueeze(1)
        
#         comb = torch.cat([emb1, emb2], 1)
#         # comb = comb.unsqueeze(1)
#         comb, skip_connections = self.feature_extraction_downsample(comb)
#         # att_comb = self.multihead_attention(comb)
#         att_comb = comb
#         att_comb = self.feature_extraction_upsample(att_comb, skip_connections)
#         att_comb = att_comb.squeeze(1)
#         embedding = self.features(att_comb)
        
        comb = torch.cat([emb1, emb2], 1)
        comb = self.normalizer(comb)
        comb = torch.reshape(comb, (-1, 32, 32))
        comb = comb.unsqueeze(1)
        embedding = self.vit(comb)

        # comb = comb.unsqueeze(1)
        # comb, skip_connections = self.feature_extraction_downsample(comb)
        # # att_comb = self.multihead_attention(comb)
        # att_comb = comb
        # att_comb = self.feature_extraction_upsample(att_comb, skip_connections)
        # att_comb = att_comb.squeeze(1)
        
        # embedding = self.features(att_comb)
        
        return embedding


class ConvAutoencoder(torch.nn.Module):
    
    def __init__(self):
        super(ConvAutoencoder, self).__init__()
        
        self.decoder_emb_generator = EmbeddingGeneratorDecoder()
        
        self.encoder_layer1 = torch.nn.Conv2d(1, 4, 3, stride=1, padding=1)
        self.encoder_layer2 = torch.nn.Conv2d(4, 16, 3, stride=1, padding=1)
        self.encoder_layer3 = torch.nn.Conv2d(16, 64, 3, stride=1, padding=1)
        self.encoder_layer4 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.encoder_layer5 = torch.nn.Conv2d(128, 256, 3, stride=1, padding=1)
        self.encoder_layer6 = torch.nn.Conv2d(256, 512, 3, stride=1, padding=1)
        
        self.upsample2 = torch.nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample3 = torch.nn.Upsample(scale_factor=3, mode='nearest')
        self.upsample4 = torch.nn.Upsample(scale_factor=4, mode='nearest')
        
        self.decoder_layer1 = torch.nn.Conv2d(512, 256, 2, stride=3, padding=2)
        self.decoder_layer2 = torch.nn.Conv2d(256, 128, 3, stride=1, padding=0)
        self.decoder_layer3 = torch.nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.decoder_layer4 = torch.nn.Conv2d(64, 16, 3, stride=1, padding=1)
        self.decoder_layer5 = torch.nn.Conv2d(16, 4, 3, stride=1, padding=1)
        self.decoder_layer6 = torch.nn.Conv2d(4, 1, 3, stride=1, padding=1)
        
        # self.bn8 = torch.nn.BatchNorm2d(8)
        # self.bn64 = torch.nn.BatchNorm2d(64)
        # self.bn256 = torch.nn.BatchNorm2d(256)
        # self.bn512 = torch.nn.BatchNorm2d(512)
        
    def encoder(self, x): # 1, 112, 112
        x = self.encoder_layer1(x) # 4, 112, 112
        x = F.dropout(x, p=0.4)
        # x = self.bn8(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 4, 56, 56
        
        x = self.encoder_layer2(x) # 16, 56, 56
        x = F.dropout(x, p=0.4)
        # x = self.bn64(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 16, 28, 28

        x = self.encoder_layer3(x) # 64, 28, 28
        x = F.dropout(x, p=0.4)
        # x = self.bn256(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 64, 14, 14
        
        x = self.encoder_layer4(x) #128, 14, 14
        x = F.dropout(x, p=0.4)
        # x = self.bn512(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 128, 7, 7
        
        x = self.encoder_layer5(x) #256, 7, 7
        x = F.dropout(x, p=0.4)
        # x = self.bn512(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 256, 3, 3
        
        x = self.encoder_layer6(x) #512, 3, 3
        x = F.dropout(x, p=0.4)
        # x = self.bn512(x)
        x = F.relu(x)
        x = F.max_pool2d(x, kernel_size=2) # 512, 1, 1
        
        x = torch.reshape(x, (-1, 512))
        
        return x
        
    def decoder_pose_reconstructor(self, x):
        x = torch.reshape(x, (-1, 512, 1, 1))
        
        x = self.upsample4(x) # 512, 4, 4
        x = self.decoder_layer1(x) # 256, 3, 3
        x = F.dropout(x, p=0.4)
        # x = self.bn256(x)
        x = F.relu(x)
        
        x = self.upsample3(x) # 256, 9, 9
        x = self.decoder_layer2(x) # 128, 7, 7
        x = F.dropout(x, p=0.4)
        # x = self.bn256(x)
        x = F.relu(x)
        
        x = self.upsample2(x) # 128, 14, 14
        x = self.decoder_layer3(x) # 64, 14, 14
        x = F.dropout(x, p=0.4)
        # x = self.bn256(x)
        x = F.relu(x)
        
        x = self.upsample2(x) # 64, 28, 28
        x = self.decoder_layer4(x) # 16, 28, 28
        x = F.dropout(x, p=0.4)
        # x = self.bn64(x)
        x = F.relu(x)
        
        x = self.upsample2(x) # 16, 56, 56
        x = self.decoder_layer5(x) # 4, 56, 56
        x = F.dropout(x, p=0.4)
        # x = self.bn8(x)
        x = F.relu(x)
        
        x = self.upsample2(x) # 4, 112, 112
        x = self.decoder_layer6(x) # 1, 112, 112
        x = torch.sigmoid(x)
        
        return x
        
    def forward(self, pose_img, magface_embedding):
        coded = self.encoder(pose_img)
        reconstructed_pose = self.decoder_pose_reconstructor(coded)
        
        # Feature fusion
        generated_embedding = self.decoder_emb_generator(coded, magface_embedding)
        
        # fc_out = self.fc_head(generated_embedding)
        
#         id_predictions = self.id_head(fc_out)
#         id_predictions = F.softmax(id_predictions, dim=1)
        id_predictions = None
        
#         emo_predictions = self.emotion_head(fc_out)
#         emo_predictions = F.softmax(emo_predictions, dim=1)
        emo_predictions = None
        
        return reconstructed_pose, generated_embedding, (id_predictions, emo_predictions)

# Train phase

In [15]:
net = ConvAutoencoder()
net = net.to("cuda")

state_dict = torch.load("./augmentor.pt")
net.load_state_dict(state_dict)

criterion_bce = torch.nn.BCELoss(reduction='mean')
criterion_mse = torch.nn.MSELoss()
# criterion_huber = torch.nn.HuberLoss()
# criterion_mae = torch.nn.L1Loss()
# criterion_cross = torch.nn.CrossEntropyLoss()
criterion_triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)

optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5, min_lr=1e-6)

In [None]:
import datetime


NUM_EPOCHS = 300

historical_loss = []
historical_val_loss = []

best_loss = np.inf
for epoch in range(NUM_EPOCHS):
    # training
    losses1 = []
    losses2 = []
    # losses3 = []
    # losses4 = []
    losses = []
    for data in tqdm.tqdm(train_dataloader):
        pose_img = data[0]
        pose_img = pose_img.permute(0, 3, 1, 2)
        pose_img = pose_img.to("cuda").to(torch.float32)
        
        input_emb = data[1]
        input_emb = input_emb.to("cuda").to(torch.float32)
        
        output_emb = data[2]
        output_emb = output_emb.to("cuda").to(torch.float32)
        
        negative_embedding = data[3]
        negative_embedding = negative_embedding.to("cuda").to(torch.float32)
        
        # expected_pose_label, input_id, input_emo = data[3:]
        # expected_pose_label = expected_pose_label.to("cuda")
        # input_id = input_id.to("cuda")
        # input_emo = input_emo.to("cuda")
        
        optimizer.zero_grad()
        
        reconstructed_pose, generated_emb, (id_predictions, emo_predictions) = net(pose_img, input_emb)
        
        loss1 = criterion_bce(reconstructed_pose, pose_img)
        losses1.append(loss1.item())
        
        # loss2 = criterion_mse(generated_emb, output_emb)
        loss2 = criterion_triplet_loss(generated_emb, output_emb, negative_embedding)
        losses2.append(loss2.item())
        
#         loss3 = criterion_cross(id_predictions, input_id)
#         losses3.append(loss3.item())
        
#         loss4 = criterion_cross(emo_predictions, input_emo)
#         losses4.append(loss4.item())
        
        loss = loss1 + loss2 #+ loss3 + loss4
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        
    # validation
    with torch.no_grad():
        val_losses1 = []
        val_losses2 = []
        # val_losses3 = []
        # val_losses4 = []
        val_losses = []
        for val_data in tqdm.tqdm(val_dataloader):
            pose_img = val_data[0]
            pose_img = pose_img.permute(0, 3, 1, 2)
            pose_img = pose_img.to("cuda").to(torch.float32)
            
            input_emb = val_data[1]
            input_emb = input_emb.to("cuda").to(torch.float32)
            
            output_emb = val_data[2]
            output_emb = output_emb.to("cuda").to(torch.float32)
            
            negative_embedding = val_data[3]
            negative_embedding = negative_embedding.to("cuda").to(torch.float32)
            
            # expected_pose_label, input_id, input_emo = val_data[3:]
            # expected_pose_label = expected_pose_label.to("cuda")
            # input_id = input_id.to("cuda")
            # input_emo = input_emo.to("cuda")
            
            reconstructed_pose, generated_emb, (id_predictions, emo_predictions) = net(pose_img, input_emb)
            
            val_loss1 = criterion_bce(reconstructed_pose, pose_img)
            val_losses1.append(val_loss1.item())
            
            # val_loss2 = criterion_mse(generated_emb, output_emb)
            val_loss2 = criterion_triplet_loss(generated_emb, output_emb, negative_embedding)
            val_losses2.append(val_loss2.item())
            
#             val_loss3 = criterion_cross(id_predictions, input_id)
#             val_losses3.append(val_loss3.item())
            
#             val_loss4 = criterion_cross(emo_predictions, input_emo)
#             val_losses4.append(val_loss4.item())
            
            val_loss = val_loss1 + val_loss2 #+ val_loss3 + val_loss4
            val_losses.append(val_loss.item())
            
    scheduler.step(val_loss)
    
    current_lr = optimizer.param_groups[0]['lr']
    
    print("Epoch", epoch+1, "/", NUM_EPOCHS, 
          "Loss", round(np.mean(losses), 4), f"({round(np.mean(losses1), 4)}, {round(np.mean(losses2), 4)})", 
          "Val_Loss", round(np.mean(val_losses), 4), f"({round(np.mean(val_losses1), 4)}, {round(np.mean(val_losses2), 4)})", 
          "LR", current_lr)
    
    # print("Loss", np.mean(losses), f"({np.mean(losses1)}, {np.mean(losses2)}, {np.mean(losses3)}, {np.mean(losses4)})")
    # print("Val_Loss", np.mean(val_losses), f"({np.mean(val_losses1)}, {np.mean(val_losses2)}, {np.mean(val_losses3)}, {np.mean(val_losses4)})")
    # historical_loss.append((np.mean(losses), np.mean(losses1), np.mean(losses2), np.mean(losses3), np.mean(losses4)))
    # historical_val_loss.append((np.mean(val_losses), np.mean(val_losses1), np.mean(val_losses2), np.mean(val_losses3), np.mean(val_losses4)))
    historical_loss.append((np.mean(losses), np.mean(losses1), np.mean(losses2)))
    historical_val_loss.append((np.mean(val_losses), np.mean(val_losses1), np.mean(val_losses2)))
    
    if historical_val_loss[-1][0] < best_loss:
        best_loss = historical_val_loss[-1][0]
        torch.save(net.state_dict(), "./augmentor.pt")
        print(f"Model saved @ {datetime.datetime.now()}, Loss = {round(historical_loss[-1][0], 4)}, Val_Loss = {round(historical_val_loss[-1][0], 4)}")

100%|██████████| 79/79 [34:47<00:00, 26.42s/it]
100%|██████████| 40/40 [13:34<00:00, 20.37s/it]


Epoch 1 / 300 Loss 0.7688 (0.2335, 0.5353) Val_Loss 0.1472 (0.1208, 0.0264) LR 0.001
Model saved @ 2022-02-03 12:19:26.902091, Loss = 0.7688, Val_Loss = 0.1472


100%|██████████| 79/79 [25:46<00:00, 19.58s/it]
100%|██████████| 40/40 [12:46<00:00, 19.15s/it]


Epoch 2 / 300 Loss 0.1261 (0.1065, 0.0196) Val_Loss 0.1122 (0.0979, 0.0143) LR 0.001
Model saved @ 2022-02-03 12:57:59.835312, Loss = 0.1261, Val_Loss = 0.1122


100%|██████████| 79/79 [26:13<00:00, 19.92s/it]
100%|██████████| 40/40 [11:45<00:00, 17.64s/it]


Epoch 3 / 300 Loss 0.1081 (0.0941, 0.0141) Val_Loss 0.1031 (0.091, 0.0121) LR 0.001
Model saved @ 2022-02-03 13:35:58.807871, Loss = 0.1081, Val_Loss = 0.1031


100%|██████████| 79/79 [23:50<00:00, 18.11s/it]
100%|██████████| 40/40 [10:55<00:00, 16.39s/it]


Epoch 4 / 300 Loss 0.0981 (0.0883, 0.0097) Val_Loss 0.0955 (0.0863, 0.0092) LR 0.001
Model saved @ 2022-02-03 14:10:44.981844, Loss = 0.0981, Val_Loss = 0.0955


100%|██████████| 79/79 [23:27<00:00, 17.82s/it]
100%|██████████| 40/40 [10:53<00:00, 16.35s/it]


Epoch 5 / 300 Loss 0.0925 (0.0846, 0.0079) Val_Loss 0.0927 (0.0831, 0.0096) LR 0.001
Model saved @ 2022-02-03 14:45:06.556934, Loss = 0.0925, Val_Loss = 0.0927


100%|██████████| 79/79 [22:20<00:00, 16.97s/it]
100%|██████████| 40/40 [10:43<00:00, 16.09s/it]


Epoch 6 / 300 Loss 0.0911 (0.0819, 0.0092) Val_Loss 0.09 (0.0813, 0.0086) LR 0.001
Model saved @ 2022-02-03 15:18:11.090157, Loss = 0.0911, Val_Loss = 0.09


100%|██████████| 79/79 [22:14<00:00, 16.89s/it]
100%|██████████| 40/40 [10:05<00:00, 15.13s/it]


Epoch 7 / 300 Loss 0.0856 (0.0782, 0.0074) Val_Loss 0.0841 (0.0738, 0.0103) LR 0.001
Model saved @ 2022-02-03 15:50:30.322123, Loss = 0.0856, Val_Loss = 0.0841


100%|██████████| 79/79 [21:16<00:00, 16.16s/it]
100%|██████████| 40/40 [10:20<00:00, 15.52s/it]


Epoch 8 / 300 Loss 0.0799 (0.0728, 0.0071) Val_Loss 0.0801 (0.0725, 0.0077) LR 0.001
Model saved @ 2022-02-03 16:22:07.536163, Loss = 0.0799, Val_Loss = 0.0801


100%|██████████| 79/79 [21:24<00:00, 16.26s/it]
100%|██████████| 40/40 [10:10<00:00, 15.27s/it]


Epoch 9 / 300 Loss 0.0776 (0.0717, 0.006) Val_Loss 0.0799 (0.0714, 0.0086) LR 0.001
Model saved @ 2022-02-03 16:53:43.024240, Loss = 0.0776, Val_Loss = 0.0799


100%|██████████| 79/79 [21:28<00:00, 16.31s/it]
100%|██████████| 40/40 [10:09<00:00, 15.23s/it]


Epoch 10 / 300 Loss 0.0783 (0.0709, 0.0073) Val_Loss 0.0792 (0.0707, 0.0085) LR 0.001
Model saved @ 2022-02-03 17:25:20.513782, Loss = 0.0783, Val_Loss = 0.0792


100%|██████████| 79/79 [22:14<00:00, 16.90s/it]
100%|██████████| 40/40 [10:10<00:00, 15.27s/it]


Epoch 11 / 300 Loss 0.0774 (0.0702, 0.0072) Val_Loss 0.0761 (0.0702, 0.0058) LR 0.001
Model saved @ 2022-02-03 17:57:46.179825, Loss = 0.0774, Val_Loss = 0.0761


100%|██████████| 79/79 [21:34<00:00, 16.38s/it]
100%|██████████| 40/40 [10:01<00:00, 15.04s/it]


Epoch 12 / 300 Loss 0.0768 (0.0697, 0.0071) Val_Loss 0.0743 (0.0695, 0.0048) LR 0.001
Model saved @ 2022-02-03 18:29:21.978877, Loss = 0.0768, Val_Loss = 0.0743


100%|██████████| 79/79 [21:43<00:00, 16.50s/it]
100%|██████████| 40/40 [10:32<00:00, 15.82s/it]


Epoch 13 / 300 Loss 0.0762 (0.0691, 0.0071) Val_Loss 0.0759 (0.0691, 0.0068) LR 0.001


100%|██████████| 79/79 [22:18<00:00, 16.94s/it]
100%|██████████| 40/40 [10:20<00:00, 15.50s/it]


Epoch 14 / 300 Loss 0.078 (0.0688, 0.0092) Val_Loss 0.0757 (0.0687, 0.007) LR 0.001


100%|██████████| 79/79 [22:09<00:00, 16.83s/it]
100%|██████████| 40/40 [10:28<00:00, 15.70s/it]


Epoch 15 / 300 Loss 0.0741 (0.0681, 0.0059) Val_Loss 0.0763 (0.0687, 0.0076) LR 0.001


100%|██████████| 79/79 [22:33<00:00, 17.13s/it]
100%|██████████| 40/40 [10:29<00:00, 15.73s/it]


Epoch 16 / 300 Loss 0.0751 (0.0678, 0.0073) Val_Loss 0.0733 (0.0678, 0.0055) LR 0.001
Model saved @ 2022-02-03 20:39:57.214447, Loss = 0.0751, Val_Loss = 0.0733


100%|██████████| 79/79 [22:38<00:00, 17.19s/it]
100%|██████████| 40/40 [10:33<00:00, 15.84s/it]


Epoch 17 / 300 Loss 0.0738 (0.0676, 0.0063) Val_Loss 0.0724 (0.0673, 0.0051) LR 0.001
Model saved @ 2022-02-03 21:13:08.955905, Loss = 0.0738, Val_Loss = 0.0724


100%|██████████| 79/79 [22:08<00:00, 16.81s/it]
100%|██████████| 40/40 [10:08<00:00, 15.22s/it]


Epoch 18 / 300 Loss 0.0725 (0.0671, 0.0054) Val_Loss 0.0753 (0.0672, 0.0081) LR 0.001


100%|██████████| 79/79 [21:48<00:00, 16.56s/it]
100%|██████████| 40/40 [10:31<00:00, 15.78s/it]


Epoch 19 / 300 Loss 0.0727 (0.0669, 0.0058) Val_Loss 0.0743 (0.0668, 0.0075) LR 0.001


100%|██████████| 79/79 [22:41<00:00, 17.23s/it]
100%|██████████| 40/40 [10:37<00:00, 15.94s/it]


Epoch 20 / 300 Loss 0.0725 (0.0664, 0.0061) Val_Loss 0.0723 (0.0662, 0.0061) LR 0.001
Model saved @ 2022-02-03 22:51:04.389689, Loss = 0.0725, Val_Loss = 0.0723


100%|██████████| 79/79 [23:17<00:00, 17.70s/it]
100%|██████████| 40/40 [11:52<00:00, 17.82s/it]


Epoch 21 / 300 Loss 0.0745 (0.0662, 0.0083) Val_Loss 0.0721 (0.0662, 0.0059) LR 0.001
Model saved @ 2022-02-03 23:26:15.266503, Loss = 0.0745, Val_Loss = 0.0721


100%|██████████| 79/79 [23:50<00:00, 18.11s/it]
100%|██████████| 40/40 [11:08<00:00, 16.71s/it]


Epoch 22 / 300 Loss 0.072 (0.066, 0.006) Val_Loss 0.0706 (0.0663, 0.0043) LR 0.001
Model saved @ 2022-02-04 00:01:14.316333, Loss = 0.072, Val_Loss = 0.0706


100%|██████████| 79/79 [25:05<00:00, 19.06s/it]
100%|██████████| 40/40 [11:36<00:00, 17.41s/it]


Epoch 23 / 300 Loss 0.0714 (0.0658, 0.0056) Val_Loss 0.0731 (0.066, 0.0071) LR 0.001


100%|██████████| 79/79 [23:28<00:00, 17.83s/it]
100%|██████████| 40/40 [10:44<00:00, 16.12s/it]


Epoch 24 / 300 Loss 0.0719 (0.0655, 0.0064) Val_Loss 0.0698 (0.0652, 0.0046) LR 0.001
Model saved @ 2022-02-04 01:12:09.872223, Loss = 0.0719, Val_Loss = 0.0698


100%|██████████| 79/79 [23:24<00:00, 17.78s/it]
100%|██████████| 40/40 [11:17<00:00, 16.93s/it]


Epoch 25 / 300 Loss 0.0704 (0.0653, 0.0051) Val_Loss 0.0729 (0.065, 0.0078) LR 0.001


100%|██████████| 79/79 [23:24<00:00, 17.78s/it]
100%|██████████| 40/40 [11:38<00:00, 17.46s/it]


Epoch 26 / 300 Loss 0.0714 (0.065, 0.0064) Val_Loss 0.076 (0.0717, 0.0043) LR 0.001


100%|██████████| 79/79 [25:04<00:00, 19.04s/it]
100%|██████████| 40/40 [11:58<00:00, 17.95s/it]


Epoch 27 / 300 Loss 0.0735 (0.0661, 0.0075) Val_Loss 0.0721 (0.0649, 0.0072) LR 0.0005


100%|██████████| 79/79 [24:00<00:00, 18.24s/it]
100%|██████████| 40/40 [10:29<00:00, 15.73s/it]


Epoch 28 / 300 Loss 0.0711 (0.0644, 0.0067) Val_Loss 0.0692 (0.0644, 0.0048) LR 0.0005
Model saved @ 2022-02-04 03:33:26.714579, Loss = 0.0711, Val_Loss = 0.0692


100%|██████████| 79/79 [21:54<00:00, 16.64s/it]
100%|██████████| 40/40 [10:14<00:00, 15.37s/it]


Epoch 29 / 300 Loss 0.0686 (0.0641, 0.0045) Val_Loss 0.0685 (0.0641, 0.0044) LR 0.0005
Model saved @ 2022-02-04 04:05:36.410545, Loss = 0.0686, Val_Loss = 0.0685


100%|██████████| 79/79 [21:40<00:00, 16.46s/it]
100%|██████████| 40/40 [10:27<00:00, 15.69s/it]


Epoch 30 / 300 Loss 0.0695 (0.0639, 0.0056) Val_Loss 0.0687 (0.0639, 0.0048) LR 0.0005


100%|██████████| 79/79 [22:22<00:00, 16.99s/it]
100%|██████████| 40/40 [10:10<00:00, 15.26s/it]


Epoch 31 / 300 Loss 0.0684 (0.0638, 0.0046) Val_Loss 0.0696 (0.0637, 0.0059) LR 0.0005


100%|██████████| 79/79 [21:28<00:00, 16.31s/it]
100%|██████████| 40/40 [10:11<00:00, 15.30s/it]


Epoch 32 / 300 Loss 0.068 (0.0636, 0.0043) Val_Loss 0.0684 (0.0636, 0.0048) LR 0.0005
Model saved @ 2022-02-04 05:41:57.964864, Loss = 0.068, Val_Loss = 0.0684


100%|██████████| 79/79 [21:23<00:00, 16.25s/it]
100%|██████████| 40/40 [10:05<00:00, 15.13s/it]


Epoch 33 / 300 Loss 0.0685 (0.0634, 0.0051) Val_Loss 0.0678 (0.0635, 0.0044) LR 0.0005
Model saved @ 2022-02-04 06:13:26.739934, Loss = 0.0685, Val_Loss = 0.0678


100%|██████████| 79/79 [22:05<00:00, 16.78s/it]
100%|██████████| 40/40 [10:23<00:00, 15.59s/it]


Epoch 34 / 300 Loss 0.0684 (0.0632, 0.0052) Val_Loss 0.0669 (0.0633, 0.0037) LR 0.0005
Model saved @ 2022-02-04 06:45:55.669696, Loss = 0.0684, Val_Loss = 0.0669


100%|██████████| 79/79 [22:08<00:00, 16.81s/it]
100%|██████████| 40/40 [10:16<00:00, 15.42s/it]


Epoch 35 / 300 Loss 0.0682 (0.0631, 0.0051) Val_Loss 0.0693 (0.0631, 0.0061) LR 0.00025


100%|██████████| 79/79 [22:12<00:00, 16.87s/it]
100%|██████████| 40/40 [10:21<00:00, 15.53s/it]


Epoch 36 / 300 Loss 0.0672 (0.0627, 0.0045) Val_Loss 0.0683 (0.0626, 0.0057) LR 0.00025


100%|██████████| 79/79 [22:32<00:00, 17.13s/it]
100%|██████████| 40/40 [10:29<00:00, 15.73s/it]


Epoch 37 / 300 Loss 0.0669 (0.0624, 0.0045) Val_Loss 0.0669 (0.0626, 0.0043) LR 0.00025
Model saved @ 2022-02-04 08:23:56.361638, Loss = 0.0669, Val_Loss = 0.0669


100%|██████████| 79/79 [22:09<00:00, 16.83s/it]
100%|██████████| 40/40 [10:29<00:00, 15.73s/it]


Epoch 38 / 300 Loss 0.0662 (0.0624, 0.0038) Val_Loss 0.067 (0.0625, 0.0045) LR 0.00025


100%|██████████| 79/79 [21:48<00:00, 16.56s/it]
100%|██████████| 40/40 [10:05<00:00, 15.13s/it]


Epoch 39 / 300 Loss 0.0666 (0.0623, 0.0043) Val_Loss 0.0662 (0.0625, 0.0037) LR 0.00025
Model saved @ 2022-02-04 09:28:28.679849, Loss = 0.0666, Val_Loss = 0.0662


100%|██████████| 79/79 [21:56<00:00, 16.66s/it]
100%|██████████| 40/40 [10:15<00:00, 15.38s/it]


Epoch 40 / 300 Loss 0.0677 (0.0622, 0.0054) Val_Loss 0.0658 (0.0623, 0.0036) LR 0.00025
Model saved @ 2022-02-04 10:00:40.411237, Loss = 0.0677, Val_Loss = 0.0658


100%|██████████| 79/79 [22:26<00:00, 17.04s/it]
100%|██████████| 40/40 [10:29<00:00, 15.74s/it]


Epoch 41 / 300 Loss 0.0662 (0.0621, 0.004) Val_Loss 0.0676 (0.0622, 0.0054) LR 0.000125


 18%|█▊        | 14/79 [04:20<19:09, 17.69s/it]

In [None]:
print("Loss", round(loss.item(), 4), "Val_Loss", round(val_loss.item(), 4), "LR", current_lr)

In [None]:
# torch.save(net.state_dict(), "./augmentor.pt")

# Evaluation phase

In [None]:
net = ConvAutoencoder()
net = net.to("cuda")
state_dict = torch.load("./augmentor.pt")
net.load_state_dict(state_dict)
net.eval()

In [None]:
with torch.no_grad():
    for data in train_dataloader:
        pose_img = data[0]
        pose_img = pose_img.permute(0, 3, 1, 2)
        pose_img = pose_img.to("cuda").to(torch.float32)
        input_emb = data[1]
        input_emb = input_emb.to("cuda").to(torch.float32)
        output_emb = data[2]
        output_emb = output_emb.to("cuda").to(torch.float32)
        
        features = net.encoder(pose_img)
        
        reconstructed_data = net.decoder_pose_reconstructor(features)
        
        reconstructed_data = reconstructed_data.permute(0, 2, 3, 1).cpu().numpy()
        orig_data = pose_img.permute(0, 2, 3, 1).cpu().numpy()
        
        for i in range(reconstructed_data.shape[0]):
            plt.imshow(orig_data[i, :, :, 0])
            plt.show()
            plt.imshow(reconstructed_data[i, :, :, 0])
            plt.show()
            
        break

In [None]:
with torch.no_grad():
    for data in val_dataloader:
        pose_img = data[0]
        pose_img = pose_img.permute(0, 3, 1, 2)
        pose_img = pose_img.to("cuda").to(torch.float32)
        input_emb = data[1]
        input_emb = input_emb.to("cuda").to(torch.float32)
        output_emb = data[2]
        output_emb = output_emb.to("cuda").to(torch.float32)
        
        features, _ = net.encoder(pose_img)
        
        reconstructed_data = net.decoder_pose_reconstructor(features)
        
        reconstructed_data = reconstructed_data.permute(0, 2, 3, 1).cpu().numpy()
        orig_data = pose_img.permute(0, 2, 3, 1).cpu().numpy()
        
        for i in range(reconstructed_data.shape[0]):
            plt.imshow(orig_data[i, :, :, 0])
            plt.show()
            plt.imshow(reconstructed_data[i, :, :, 0])
            plt.show()
            
        break