# In [2]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torchvision.models as models


class VA(nn.Module):
    """The layer for transforming the skeleton to the observed viewpoints.

    A small conv subnetwork (conv1/bn1/relu1 -> conv2/bn2/relu2 -> pool -> fc)
    regresses 6 transform parameters per sample from the input. ``_transform``
    (defined elsewhere) applies them to the input together with ``maxmin``.
    A torchvision ResNet-50, whose final ``fc`` layer is replaced to output
    ``num_classes`` values, then classifies the transformed input.

    Args:
        num_classes: Number of outputs of the replaced ResNet-50 ``fc`` head.
    """

    def __init__(self, num_classes=60):
        super().__init__()
        self.num_classes = num_classes

        # Viewpoint-regression subnetwork.
        self.conv1 = nn.Conv2d(3, 128, kernel_size=5, stride=2, bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=5, stride=2, bias=False)
        self.bn2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        # NOTE: despite its name this is *max* pooling (kernel 7, stride 7).
        # The attribute name is kept so existing code using ``avepool`` works.
        self.avepool = nn.MaxPool2d(7)
        # 6272 = 128 channels * 7 * 7. Assuming a 224x224 input (TODO confirm):
        # conv1 -> 110x110, conv2 -> 53x53, pool -> 7x7. Other spatial sizes
        # will not match this layer's in_features.
        self.fc = nn.Linear(128 * 7 * 7, 6)

        # ImageNet-pretrained ResNet-50; its head is replaced in init_weight().
        self.classifier = models.resnet50(pretrained=True)
        self.init_weight()

    def forward(self, x1, maxmin):
        """Predict a viewpoint transform for ``x1``, apply it, and classify.

        Args:
            x1: Input batch fed to ``conv1`` (3 input channels).
            maxmin: Passed unchanged to ``_transform``.

        Returns:
            A tuple ``(logits, transformed, trans)``:
            ``logits`` is the classifier output tensor; ``transformed`` is the
            transformed input and ``trans`` the predicted transform parameters
            (6 per sample). Both are detached CPU NumPy arrays.
        """
        x = self.conv1(x1)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.avepool(x)

        x = x.view(x.size(0), -1)
        trans = self.fc(x)

        # ``detach`` is the supported replacement for ``.data``. It drops
        # autograd history so ``.numpy()`` works on tensors that require grad.
        trans_np = trans.detach().cpu().numpy()
        x = _transform(x1, trans, maxmin)

        transformed_np = x.detach().cpu().numpy()
        x = self.classifier(x)
        return x, transformed_np, trans_np

    def init_weight(self):
        """Initialise the regression subnetwork and replace the classifier head.

        - conv1/conv2 weights: Xavier-uniform. Both convs are built with
          ``bias=False``, so the bias branch only acts as a guard.
        - bn1/bn2: weight 1, bias 0, ``momentum=0.99``, ``eps=1e-3``.
        - fc: weight and bias set to zero, so ``trans`` is all zeros for every
          input until training updates it.
        - classifier.fc: replaced by a new ``nn.Linear`` with ``num_classes``
          outputs. Calling this method again re-creates that head.
        """
        for layer in (self.conv1, self.conv2):
            for name, param in layer.named_parameters():
                if 'weight' in name:
                    nn.init.xavier_uniform_(param)
                if 'bias' in name:
                    nn.init.zeros_(param)

        for layer in (self.bn1, self.bn2):
            nn.init.ones_(layer.weight)
            nn.init.zeros_(layer.bias)
            # NOTE(review): PyTorch updates running stats as
            # (1 - momentum) * running + momentum * batch, the opposite of the
            # Keras convention. Together with eps=1e-3 (a Keras default), 0.99
            # looks like a Keras value carried over. In PyTorch it makes the
            # running stats track almost only the latest batch; the PyTorch
            # equivalent of Keras 0.99 would be 0.01. Kept as-is to preserve
            # behaviour; confirm intent.
            layer.momentum = 0.99
            layer.eps = 1e-3

        nn.init.zeros_(self.fc.bias)
        nn.init.zeros_(self.fc.weight)

        num_ftrs = self.classifier.fc.in_features
        self.classifier.fc = nn.Linear(num_ftrs, self.num_classes)

