In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

from PIL import Image
import matplotlib.pyplot as plt

import random

In [2]:
from sklearn.model_selection import train_test_split

# Single place to repoint the notebook at another copy of the dataset.
DATA_DIR = "/home/viper/Downloads"
# Shared seed so both splits below are reproducible.
SPLIT_SEED = 42

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code. Only load .npy files that come from a trusted source.
X = np.load(DATA_DIR + "/X_basic.npy", allow_pickle=True)
Y = np.load(DATA_DIR + "/Y_basic.npy", allow_pickle=True)

# First split: 80 % train, 20 % hold-out.
X_train, X_holdout, Y_train, Y_holdout = train_test_split(
    X, Y, test_size=0.2, random_state=SPLIT_SEED
)
# Second split: the hold-out is halved into test and validation (10 % each of
# the full dataset).
X_test, X_valid, Y_test, Y_valid = train_test_split(
    X_holdout, Y_holdout, test_size=0.5, random_state=SPLIT_SEED
)

In [166]:
class RAN(nn.Module):
    """Attention head that fuses per-crop features into one prediction.

    Each sample is a group of ``num_crops`` image crops. Every crop is passed
    through ``net``; the resulting feature vectors are combined with a
    two-stage attention scheme (a per-crop "self attention" weight, then a
    "relation attention" weight computed from each crop feature concatenated
    with the attended global feature) and classified with a linear layer.

    Args:
        net: Feature extractor applied to a ``(N, 3, 224, 224)`` batch. Its
            output must contain ``feat_dim`` values per image
            (e.g. ``(N, feat_dim, 1, 1)``).
        num_crops: Number of crops grouped into one sample. Defaults to 6.
        feat_dim: Length of the feature vector ``net`` yields per crop.
            Defaults to 512.
        num_classes: Number of output logits. Defaults to 7.
    """

    def __init__(self, net, num_crops=6, feat_dim=512, num_classes=7):
        super(RAN, self).__init__()
        self.num_crops = num_crops
        self.feat_dim = feat_dim
        self.feature = net
        # Not referenced in forward(); kept so the attribute stays available.
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Layers are created in a fixed order (alpha, beta, fc) so that seeded
        # initialisation is reproducible.
        self.alpha = nn.Sequential(nn.Linear(feat_dim, 1),
                                   nn.Sigmoid())
        self.beta = nn.Sequential(nn.Linear(2 * feat_dim, 1),
                                  nn.Sigmoid())
        self.fc = nn.Linear(2 * feat_dim, num_classes)

    def forward(self, x):
        """Compute class logits for a batch of crop groups.

        Args:
            x: Tensor that can be reshaped to
                ``(B * num_crops, 3, 224, 224)``.

        Returns:
            Tensor of shape ``(B, num_classes)``.
        """
        n, d = self.num_crops, self.feat_dim

        # Feature extraction: one vector per crop -> [B, n, d].
        x = x.reshape(-1, 3, 224, 224)
        crop_feats = self.feature(x).flatten(1).reshape(-1, n, d)

        # Self attention: sigmoid score per crop, normalised over the crops.
        mu = self.alpha(crop_feats.reshape(-1, d)).view(-1, n, 1)
        mu = F.softmax(mu, dim=1)
        # Attention-weighted mean of the crop features -> [B, 1, d].
        Fm = (crop_feats * mu).sum(dim=1, keepdim=True) / mu.sum(dim=1, keepdim=True)

        # Relation attention: score each crop against the global feature.
        paired = torch.cat((crop_feats, Fm.repeat(1, n, 1)), dim=2)  # [B, n, 2d]
        vi = F.softmax(self.beta(paired).view(-1, n), dim=1)         # [B, n]

        # Combine both attention weights. Squeeze only the trailing axis so a
        # batch of size 1 (or a single crop) keeps its [B, n] layout.
        weights = (mu.squeeze(2) * vi).unsqueeze(2)                  # [B, n, 1]
        pooled = (paired * weights).sum(dim=1) / weights.sum(dim=1)  # [B, 2d]

        return self.fc(pooled)
        
        

In [167]:
# Use every child module of an ImageNet-pretrained ResNet-18 except the last
# one (the classification layer in torchvision's ResNet) as the per-crop
# feature extractor, then wrap it in the attention head.
resnet_layers = list(models.resnet18(pretrained=True).children())
backbone = nn.Sequential(*resnet_layers[:-1])
model = RAN(backbone)

In [209]:
class CaffeCrop(object):
    """Crop a region around a shifted image centre and resize it to a square.

    In the ``'train'`` phase the crop size and the crop centre are randomly
    jittered (+/-2 % scale, +/-1 % translation); in the ``'test'`` phase the
    crop is deterministic.

    Args:
        phase: Either ``'train'`` or ``'test'``.
        final_size: Side length in pixels of the returned image.
        crop_size: Nominal side length in pixels of the cropped region.
        crop_center_x_offset: Horizontal shift of the crop centre relative to
            the image centre, in pixels.
        crop_center_y_offset: Vertical shift of the crop centre relative to
            the image centre, in pixels.
    """

    def __init__(self, phase, final_size=224, crop_size=110,
                 crop_center_x_offset=0, crop_center_y_offset=15):
        assert phase in ('train', 'test'), "phase must be 'train' or 'test'"
        self.phase = phase
        self.final_size = final_size
        self.crop_size = crop_size
        self.crop_center_x_offset = crop_center_x_offset
        self.crop_center_y_offset = crop_center_y_offset

    @staticmethod
    def _jitter(amplitude):
        """Return a random value in ``[-amplitude, amplitude]``."""
        return (random.randint(0, 1000) / 500 - 1) * amplitude

    def __call__(self, img):
        """Crop and resize ``img``.

        Args:
            img: Image object exposing ``width``, ``height``, ``crop(box)``
                and ``resize(size)`` (e.g. a ``PIL.Image.Image``).

        Returns:
            The cropped region resized to ``(final_size, final_size)``.
        """
        if self.phase == 'train':
            scale_aug, trans_aug = 0.02, 0.01
        else:
            scale_aug = trans_aug = 0.0

        # Random draws happen in a fixed order (height, width, x, y) so that
        # seeding `random` gives reproducible crops.
        crop_height = self.crop_size * (1 + self._jitter(scale_aug))
        crop_width = self.crop_size * (1 + self._jitter(scale_aug))
        center_x = (img.width / 2 + self.crop_center_x_offset) * (1 + self._jitter(trans_aug))
        center_y = (img.height / 2 + self.crop_center_y_offset) * (1 + self._jitter(trans_aug))

        # Shrink the crop so it stays inside the image on every side; the
        # 0.5 px margin keeps the box strictly within the bounds.
        if center_x < crop_width / 2:
            crop_width = center_x * 2 - 0.5
        if center_y < crop_height / 2:
            crop_height = center_y * 2 - 0.5
        if center_x + crop_width / 2 >= img.width:
            crop_width = (img.width - center_x) * 2 - 0.5
        if center_y + crop_height / 2 >= img.height:
            crop_height = (img.height - center_y) * 2 - 0.5

        # (left, upper, right, lower); the lower edge uses the crop HEIGHT.
        crop_box = (center_x - crop_width / 2, center_y - crop_height / 2,
                    center_x + crop_width / 2, center_y + crop_height / 2)

        # Resize the cropped region, not the full input image.
        cropped = img.crop(crop_box)
        return cropped.resize((self.final_size, self.final_size))

In [210]:
# Open the sample face image used by the cells that follow.
face_path = "/home/viper/Downloads/face_f.jpg"
img = Image.open(face_path)