# In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
import torchvision.models as models


class VA(nn.Module):
    """View-adaptive classifier: re-render a skeleton map from a learned viewpoint, then classify it.

    A small CNN (conv1/conv2 -> pool -> fc) regresses six view parameters per
    sample: three rotation angles (alpha, beta, gamma) and three translations
    (x, y, z). ``transform_view`` applies them to the input map. The result
    is classified by a ResNet-50 whose final layer is replaced so that it
    outputs ``num_classes`` scores.
    """

    def __init__(self, num_classes=60):
        super(VA, self).__init__()
        self.num_classes = num_classes
        # View-parameter regressor. Both convolutions are unpadded, stride 2.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=5, stride=2, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=5, stride=2, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        # Despite the attribute name this is *max* pooling; the name is kept
        # so existing references to ``self.avepool`` keep working.
        self.avepool = nn.MaxPool2d(7)
        # 6272 = 128 * 7 * 7. For a 224x224 input the spatial size goes
        # 224 -> 110 (conv1) -> 53 (conv2) -> 7 (pool). Input sizes that do
        # not pool down to 7x7 will not match this layer.
        self.fc = nn.Linear(6272, 6)
        self.classifier = models.resnet50(pretrained=True)
        self.init_weight()

    def forward(self, inp, maxmin):
        """Predict view parameters, re-render ``inp`` with them, and classify it.

        Args:
            inp: skeleton-map batch. It must have 3 channels (conv1 and the
                3x3 rotation require it); presumably (B, 3, 224, 224), see
                the note on ``self.fc``. TODO confirm against the data loader.
            maxmin: per-sample bounds; column 0 is read as the maximum and
                column 1 as the minimum (see ``transform_view``).

        Returns:
            Tuple ``(scores, transformed, params)``:
            - ``scores``: the classifier output.
            - ``transformed``: the re-rendered input, as a NumPy array.
            - ``params``: the predicted (alpha, beta, gamma, tx, ty, tz), as
              a NumPy array.
            Both arrays are detached CPU copies.
        """
        feat = self.relu1(self.bn1(self.conv1(inp)))
        feat = self.relu2(self.bn2(self.conv2(feat)))
        feat = self.avepool(feat)
        # Three rotation angles followed by three translation components.
        params = self.fc(feat.view(feat.size(0), -1))

        x = self.transform_view(inp, params, maxmin)
        transformed = x.detach().cpu().numpy()
        scores = self.classifier(x)

        return scores, transformed, params.detach().cpu().numpy()

    def init_weight(self):
        """Initialise the regressor and replace the ResNet-50 head.

        - Conv weights use Xavier-uniform initialisation.
        - BatchNorm layers are reset to scale 1 and shift 0.
        - The view regressor ``fc`` is zeroed.
        - ``classifier.fc`` becomes a fresh ``num_classes``-way linear layer.
        """
        for layer in [self.conv1, self.conv2]:
            for name, param in layer.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                if 'bias' in name:
                    # Unreachable for the current layers (bias=False); kept in
                    # case biases are re-enabled.
                    param.data.zero_()
        for layer in [self.bn1, self.bn2]:
            layer.weight.data.fill_(1)
            layer.bias.data.fill_(0)
            # NOTE(review): PyTorch's BatchNorm momentum weights the *new*
            # batch statistic: running = (1 - m) * running + m * batch. That
            # is the opposite of Keras's convention. With 0.99 the running
            # stats track almost only the latest batch. If this was ported
            # from Keras (momentum=0.99, eps=1e-3), the PyTorch equivalent
            # is 0.01. Confirm before changing, because it alters training.
            layer.momentum = 0.99
            layer.eps = 1e-3

        # An all-zero regressor outputs zero angles and zero translation, so
        # transform_view starts out as the identity mapping.
        self.fc.bias.data.zero_()
        self.fc.weight.data.zero_()

        num_ftrs = self.classifier.fc.in_features
        self.classifier.fc = nn.Linear(num_ftrs, self.num_classes)

    def get_rotation_matrix(self, rot):
        """Build per-sample 3x3 rotation matrices from three angles.

        Args:
            rot: (B, 3) tensor of angles in radians (alpha, beta, gamma),
                about the x, y and z axes respectively.

        Returns:
            (B, 3, 3) tensor ``Rz @ Ry @ Rx``. Each factor is the transpose
            of the textbook right-handed rotation matrix for its angle.
        """
        cos_r, sin_r = rot.cos(), rot.sin()
        # Constant (B, 1) columns with rot's dtype and device; they carry no
        # gradient.
        zeros = rot.new_zeros(rot.size(0), 1)
        ones = rot.new_ones(rot.size(0), 1)

        def _matrix(rows):
            # Three rows of three (B, 1) tensors -> one (B, 3, 3) tensor.
            return torch.cat([torch.stack(row, dim=-1) for row in rows], dim=1)

        ca, cb, cg = cos_r[:, 0:1], cos_r[:, 1:2], cos_r[:, 2:3]
        sa, sb, sg = sin_r[:, 0:1], sin_r[:, 1:2], sin_r[:, 2:3]

        rx = _matrix(((ones, zeros, zeros),
                      (zeros, ca, sa),
                      (zeros, -sa, ca)))
        ry = _matrix(((cb, zeros, -sb),
                      (zeros, ones, zeros),
                      (sb, zeros, cb)))
        rz = _matrix(((cg, sg, zeros),
                      (-sg, cg, zeros),
                      (zeros, zeros, ones)))
        return rz.matmul(ry).matmul(rx)

    def transform_view(self, x, params, maxmin):
        """Re-render a batch of skeleton maps under the predicted view change.

        The arithmetic assumes each pixel stores a 3-D coordinate normalised
        per sample as ``p = 255 * (c - min) / (max - min)``. Under that
        reading, the new coordinates are ``c' = R (c + t)``. Re-normalised
        with the same bounds, this expands to::

            p' = R p + 255 * (R (min + t) - min) / (max - min)

        That expanded form is what is computed below.

        Args:
            x: (B, 3, H, W) tensor.
            params: (B, 6) tensor; columns 0-2 are rotation angles and
                columns 3-5 are the translation ``t``.
            maxmin: (B, >=2) tensor; column 0 is ``max`` and column 1 is
                ``min``. If ``max == min`` for a sample, the division yields
                inf/nan.

        Returns:
            Tensor with the same (B, 3, H, W) shape as ``x``.
        """
        batch, channels, height, width = x.shape
        rot_matrix = self.get_rotation_matrix(params[:, 0:3])
        trans = params[:, 3:6]

        # Flatten the spatial grid: each pixel becomes one 3-vector column.
        x = x.contiguous().view(batch, channels, height * width)

        # Per-sample scalar bounds, repeated across the three channels.
        maxi = maxmin[:, 0].contiguous().view(-1, 1).repeat(1, 3)
        mini = maxmin[:, 1].contiguous().view(-1, 1).repeat(1, 3)

        term1 = torch.matmul(rot_matrix, x)
        shifted = torch.matmul(rot_matrix, (mini + trans).unsqueeze(-1)).squeeze(-1)
        # Out-of-place ops so no autograd intermediate is modified in place.
        term2 = (shifted - mini) / (maxi - mini) * 255

        x = term1 + term2.unsqueeze(-1)
        # Restore the caller's spatial layout instead of assuming 224x224.
        return x.view(batch, channels, height, width)

