In [34]:
import os

import numpy as np
import scipy.io as scio
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode
from torchvision.io import read_image, write_jpeg
from torchvision.transforms.functional import crop

In [35]:
# Emotion category names. A name's position in this list is its integer class
# index, so the order must not change.
emotions = [
    'Affection', 'Anger', 'Annoyance', 'Anticipation', 'Aversion',
    'Confidence', 'Disapproval', 'Disconnection', 'Disquietment',
    'Doubt/Confusion', 'Embarrassment', 'Engagement', 'Esteem',
    'Excitement', 'Fatigue', 'Fear', 'Happiness', 'Pain', 'Peace',
    'Pleasure', 'Sadness', 'Sensitivity', 'Suffering', 'Surprise',
    'Sympathy', 'Yearning',
]

# Category name -> class index.
emotion_encode = {name: index for index, name in enumerate(emotions)}
# Class index -> category name (inverse of ``emotion_encode``).
emotion_decode = dict(enumerate(emotions))

In [36]:
class EmoticDataset(Dataset):
    """Map-style dataset yielding (subject crop, context image, multi-hot label).

    Annotations come from the ``"train"`` struct array of a MATLAB ``.mat``
    file; records whose folder is ``"framesdb/images"`` are excluded.

    Each item is a tuple ``(subject_img, context_img, label)``:

    * ``subject_img`` - float tensor of the body-bbox crop, resized to
      ``subject_size``.
    * ``context_img`` - float tensor of the whole image, resized to
      ``context_size``.
    * ``label`` - float64 NumPy vector of length ``len(emotions)`` with 1.0 at
      every annotated category.

    Pixel values are only cast to float; they are not rescaled or normalized.
    """

    def __init__(self, subject_size, context_size, anns_dir, img_dir):
        """Load and filter the annotations and build the resize transforms.

        Args:
            subject_size: Target size passed to ``transforms.Resize`` for the
                person crop.
            context_size: Target size passed to ``transforms.Resize`` for the
                full image.
            anns_dir: Path to the annotations ``.mat`` file.
            img_dir: Root directory that each record's ``folder`` / ``filename``
                is resolved against.
        """
        records = scio.loadmat(anns_dir)["train"][0]
        # A boolean mask keeps the structured dtype intact and avoids pushing
        # object-field records through np.fromiter.
        keep = np.array(
            [rec["folder"].item() != "framesdb/images" for rec in records],
            dtype=bool,
        )
        self.anns = records[keep]
        self.img_dir = img_dir
        self.subject_transform = transforms.Resize(subject_size)
        self.context_transform = transforms.Resize(context_size)

    def __len__(self):
        """Return the number of annotation records kept after filtering."""
        return len(self.anns)

    def _image_path(self, ann):
        """Build the image path for ``ann`` relative to ``self.img_dir``."""
        return os.path.join(
            self.img_dir, ann["folder"].item(), ann["filename"].item()
        )

    def __getitem__(self, idx):
        """Return ``(subject_img, context_img, label)`` for record ``idx``."""
        ann = self.anns[idx]

        # Decode to exactly three channels so grayscale / RGBA files still
        # collate into one batch and match the 3-channel conv stem.
        context_img = read_image(self._image_path(ann), mode=ImageReadMode.RGB)

        # The box is consumed as (x1, y1, x2, y2); crop() expects
        # (top, left, height, width).
        # NOTE(review): the index path picks the first entry only - presumably
        # the first annotated person; confirm for multi-person images.
        bbox = ann["person"]["body_bbox"][0][0][0].astype(int)
        x1, y1, x2, y2 = (int(v) for v in bbox[:4])
        subject_img = crop(context_img, y1, x1, y2 - y1, x2 - x1)

        # Multi-hot target: 1.0 at the index of every annotated category.
        # NOTE(review): np.zeros defaults to float64; cast to float32 before a
        # loss that requires the target dtype to match float32 model outputs.
        label = np.zeros(len(emotions))
        categories = ann["person"]["annotations_categories"][0][0].item()[0][0]
        for category in categories:
            label[emotion_encode[category.item()]] = 1.

        subject_img = self.subject_transform(subject_img.float())
        context_img = self.context_transform(context_img.float())

        return subject_img, context_img, label

In [5]:
# Resize targets for the person crop and for the full context image.
subject_size, context_size = (50, 50), (200, 200)

In [38]:
# Build the training dataset from the annotation file and image root, then
# wrap it in a loader that yields batches of 16 in dataset order.
train_data = EmoticDataset(
    subject_size=subject_size,
    context_size=context_size,
    anns_dir="../data/Annotations/Annotations.mat",
    img_dir="../data/cvpr_emotic/",
)
train_dataloader = DataLoader(dataset=train_data, batch_size=16)

In [113]:
def _conv_relu_bn(in_channels, out_channels, kernel_size, **conv_kwargs):
    """Return the [Conv2d, ReLU, BatchNorm2d] triple repeated through the branch."""
    return [
        nn.Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    ]


def net_branch(in_channels=3):
    """Build one convolutional feature-extraction branch.

    Each 2-D convolution is factorized into a pair of 1-D convolutions
    (vertical then horizontal, or the reverse), each followed by ReLU and
    BatchNorm. The branch has two such stages, each ending in a 3x3 max-pool
    with stride 2.

    The layers are kept in a single flat ``nn.Sequential`` so that module
    indices (and therefore ``state_dict`` keys) are ``0`` .. ``13``.

    Args:
        in_channels: Number of channels in the input image. Defaults to 3.

    Returns:
        nn.Sequential: Maps ``(N, in_channels, H, W)`` to ``(N, 256, H', W')``;
        for example 50x50 inputs give 1x1 maps and 200x200 inputs give 11x11.
    """
    return nn.Sequential(
        # Stage 1: strided 11x1 then 1x11 convolutions (stride 4 per axis).
        *_conv_relu_bn(in_channels, 96, (11, 1), stride=(4, 1)),
        *_conv_relu_bn(96, 96, (1, 11), stride=(1, 4)),
        nn.MaxPool2d(3, stride=2),
        # Stage 2: size-preserving 1x5 then 5x1 convolutions.
        *_conv_relu_bn(96, 256, (1, 5), padding="same"),
        *_conv_relu_bn(256, 256, (5, 1), padding="same"),
        nn.MaxPool2d(3, stride=2),
    )

In [118]:
class Net(nn.Module):
    """Two-branch model: separate conv branches for subject and context.

    Both inputs go through their own ``net_branch()``; the resulting feature
    maps are flattened per sample, concatenated along the feature axis and
    passed to ``fusion``.
    """

    def __init__(self):
        super().__init__()

        self.subject = net_branch()
        self.context = net_branch()

        # No fusion layers yet: an empty Sequential returns its input as-is.
        self.fusion = nn.Sequential()

    def forward(self, s, c):
        """Return fused features for subject batch ``s`` and context batch ``c``."""
        subject_features = self.subject(s).flatten(start_dim=1)
        context_features = self.context(c).flatten(start_dim=1)

        fused = torch.cat([subject_features, context_features], dim=1)
        return self.fusion(fused)

In [119]:
# Instantiate the two-branch model with freshly initialized weights.
n = Net()

In [120]:
# Smoke test: pull a single batch and run it through the model.
# The collated batch is (subject images, context images, labels); only the
# first two entries are model inputs.
first = next(iter(train_dataloader))
save = n(*first[:2])

In [121]:
# Shape of the subject-crop batch (first element of the collated batch).
first[0].shape

torch.Size([16, 3, 50, 50])

In [122]:
# Shape of the model output: per-sample concatenation of the flattened
# subject and context branch features.
save.shape

torch.Size([16, 31232])