In [1]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
import torchvision.models as models

In [2]:
def gen_A(num_classes, t, adj_file):
    _adj = np.identity(num_classes, np.int32)
    return _adj

In [10]:
def gen_adj(A):
    D = torch.pow(A.sum(1).float(), -0.5)
    D = torch.diag(D)
    adj = torch.matmul(torch.matmul(A, D).t(), D)
    return adj

In [4]:
class GraphConvolution(nn.Module):
    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_feature = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.Tensor(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(1, 1, out_features))
        else:
            self.register_parameter("bias", None)
    
    def forward(self, x, adj):
        support = torch.matmul(x, self.weight)
        output = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

In [5]:
class resnet18_gcn(nn.Module):
    def __init__(self, model, num_classes, in_channels=300, t=0, adj_file=None):
        super(resnet18_gcn, self).__init__()
        self.cnn = model
        self.num_classes = num_classes
        self.t = t
        self.adj_file = adj_file
        self.pooling = nn.MaxPool2d(7, 7)
        self.gc1 = GraphConvolution(in_channels, 1024)
        self.gc2 = GraphConvolution(1024, 2048)
        self.relu = nn.LeakyReLU(0.2)
        
        _adj = gen_A(num_classes, t, adj_file)
        self.A = nn.Parameter(torch.from_numpy(_adj).float())
        
        self.image_normalization_mean = [0.485, 0.456, 0.406]
        self.image_normalization_std = [0.229, 0.224, 0.225]
        
    def forward(self, feature, inp):
        features = self.cnn(feature)
        features = self.pooling(features)
        features = features.view(features.size(0), -1)
        
        
        inp = inp[0]
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)
        
        x = x.tranpose(0, 1)
        x = torch.matmul(feature, x)
        return x
    
    def get_config_optim(self, lr, lrp):
        return [
                {"params": self.features.parameters(), "lr": lr * lrp},
                {"params": self.gc1.parameters(),      "lr": lr},
                {"params": self.gc2.parameters(),      "lr": lr}
        ]

In [6]:
def get_resnet18_sp(num_classes, t, pretrained=False, adj_file=None, in_channels=300):
    model = models.resnet18(pretrained=pretrained)
    new_model = torch.nn.Sequential(*(list(model.children())[:-2]))
    return resnet18_gcn(new_model, num_classes, t=t, adj_file=adj_file, in_channels=in_channels)

In [7]:
model = get_resnet18_sp(num_classes=7, t=0.4, pretrained=True, adj_file=None, in_channels=300)

In [8]:
feature = torch.rand(1, 3, 224, 224)
inp = np.load("/home/viper/Documents/Notes/IRP/Code/wb/RAF_DB_glove_word2vec.npy")

In [12]:
inp = torch.Tensor(inp)

In [13]:
res = model(feature, inp)

RuntimeError: size mismatch, got 7, 7x7,1024