In [1]:
import os
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
import torchvision.models as models

In [2]:
def gen_A(num_classes, t, adj_file):
    """Build the label-graph adjacency matrix.

    This implementation links every class only to itself: the result is the
    ``num_classes x num_classes`` identity matrix with dtype ``int32``.

    Args:
        num_classes: Number of graph nodes (one per class).
        t: Accepted for call compatibility; not consulted here.
        adj_file: Accepted for call compatibility; never opened here.

    Returns:
        numpy.ndarray of shape ``(num_classes, num_classes)`` and dtype int32.
    """
    # Ones on the diagonal, zeros elsewhere (self-loops only).
    return np.eye(num_classes, dtype=np.int32)

In [3]:
def gen_adj(A):
    """Normalise an adjacency matrix by its row sums.

    With ``D = diag(rowsum(A) ** -0.5)`` the result is ``(A @ D).T @ D``,
    which equals ``D @ A.T @ D``.

    Rows whose sum is zero (isolated nodes) get a scaling factor of 0 instead
    of ``inf``. Without that guard ``inf * 0`` inside the matrix products
    would fill the affected rows/columns with NaN. Rows with a negative sum
    still produce NaN, because a negative number has no real ``-0.5`` power.

    Args:
        A: Square 2-D adjacency tensor. Integer tensors are accepted and
            converted to float32 before any arithmetic.

    Returns:
        float32 tensor with the same shape as ``A``.
    """
    # matmul needs matching dtypes; this is a no-op for float32 input.
    A = A.float()
    inv_sqrt_deg = torch.pow(A.sum(1), -0.5)
    # 0 ** -0.5 == inf: zero it so isolated nodes yield zeros rather than NaN.
    inv_sqrt_deg = torch.where(
        torch.isinf(inv_sqrt_deg), torch.zeros_like(inv_sqrt_deg), inv_sqrt_deg
    )
    D = torch.diag(inv_sqrt_deg)
    adj = torch.matmul(torch.matmul(A, D).t(), D)
    return adj

In [4]:
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adj @ (x @ weight) [+ bias]``.

    Args:
        in_features: Width of each input node vector.
        out_features: Width of each output node vector.
        bias: When true, a learnable bias of shape ``(1, 1, out_features)`` is
            added to the output.

    Attributes:
        weight: Parameter of shape ``(in_features, out_features)``.
        bias: Parameter of shape ``(1, 1, out_features)``, or ``None``.
    """

    def __init__(self, in_features, out_features, bias=False):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        # Historical (misspelled) attribute name, kept as an alias so code
        # that reads it keeps working.
        self.in_feature = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(in_features, out_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(1, 1, out_features))
        else:
            self.register_parameter("bias", None)
        # torch.empty hands back uninitialised memory; fill it before use.
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight (and bias, if any) uniformly from ``[-b, b]``, ``b = 1/sqrt(out_features)``."""
        # max(1, ...) avoids a division by zero for a degenerate zero-width layer.
        bound = 1.0 / max(1, self.weight.size(1)) ** 0.5
        nn.init.uniform_(self.weight, -bound, bound)
        if self.bias is not None:
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, x, adj):
        """Apply the layer.

        Args:
            x: Node features whose last dimension is ``in_features``.
            adj: Adjacency matrix that left-multiplies the projected features.

        Returns:
            ``adj @ (x @ weight)``, plus ``bias`` when the layer has one.
            Because the bias is 3-D, adding it to a 2-D result broadcasts the
            output to 3-D with a leading dimension of 1.
        """
        support = torch.matmul(x, self.weight)
        output = torch.matmul(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

In [5]:
class resnet18_gcn(nn.Module):
    """CNN backbone whose pooled features are scored against GCN-generated class vectors.

    A two-layer GCN turns per-class input vectors into one vector per class;
    the image feature vector is then multiplied with those vectors to give one
    score per class.

    Args:
        model: Backbone module; its output is max-pooled (kernel 7, stride 7)
            and flattened to one vector per image.
        num_classes: Number of classes, i.e. nodes of the label graph.
        in_channels: Width of each per-class input vector fed to the GCN.
        t: Forwarded to ``gen_A`` and stored on the instance.
        adj_file: Forwarded to ``gen_A`` and stored on the instance.
        feature_dim: Width of the flattened backbone features and therefore of
            the second GCN layer's output. Defaults to 512, the channel count
            of a ResNet-18 feature map; the previous hard-coded 2048 could not
            be multiplied with 512-wide features. Pass 2048 to get the old
            layer width back.
    """

    def __init__(self, model, num_classes, in_channels=300, t=0, adj_file=None, feature_dim=512):
        super(resnet18_gcn, self).__init__()
        self.cnn = model
        self.num_classes = num_classes
        self.t = t
        self.adj_file = adj_file
        self.pooling = nn.MaxPool2d(7, 7)
        self.gc1 = GraphConvolution(in_channels, 1024)
        # Must equal the flattened feature width so forward()'s matmul lines up.
        self.gc2 = GraphConvolution(1024, feature_dim)
        self.relu = nn.LeakyReLU(0.2)

        _adj = gen_A(num_classes, t, adj_file)
        # Registered as a parameter, but forward() detaches the normalised
        # adjacency and get_config_optim() does not list it.
        self.A = nn.Parameter(torch.from_numpy(_adj).float())

        # Stored for callers; not used anywhere inside this class.
        self.image_normalization_mean = [0.485, 0.456, 0.406]
        self.image_normalization_std = [0.229, 0.224, 0.225]

    def forward(self, feature, inp):
        """Score each image against every class.

        Args:
            feature: Image batch passed to the backbone.
            inp: Per-class input vectors for the GCN
                (presumably ``(num_classes, in_channels)`` - confirm with caller).

        Returns:
            ``features @ class_vectors.T``; with the shapes assumed above this
            is ``(batch, num_classes)``.
        """
        features = self.cnn(feature)
        features = self.pooling(features)
        features = features.view(features.size(0), -1)

        # No gradient flows into A through the normalised adjacency.
        adj = gen_adj(self.A).detach()
        x = self.gc1(inp, adj)
        x = self.relu(x)
        x = self.gc2(x, adj)

        x = x.transpose(0, 1)
        x = torch.matmul(features, x)
        return x

    def get_config_optim(self, lr, lrp):
        """Return optimizer parameter groups.

        The backbone is trained at ``lr * lrp``; both GCN layers at ``lr``.
        """
        return [
                # The backbone lives in ``self.cnn``; there is no ``self.features``.
                {"params": self.cnn.parameters(), "lr": lr * lrp},
                {"params": self.gc1.parameters(), "lr": lr},
                {"params": self.gc2.parameters(), "lr": lr}
        ]

In [6]:
def get_resnet18_sp(num_classes, t, pretrained=False, adj_file=None, in_channels=300):
    """Create a ``resnet18_gcn`` on top of a truncated torchvision ResNet-18.

    Args:
        num_classes: Number of classes handed to ``resnet18_gcn``.
        t: Handed through to ``resnet18_gcn``.
        pretrained: Forwarded to ``models.resnet18``.
        adj_file: Handed through to ``resnet18_gcn``.
        in_channels: Width of the per-class input vectors for the GCN.

    Returns:
        The assembled ``resnet18_gcn`` module.
    """
    resnet = models.resnet18(pretrained=pretrained)
    # Keep every child module except the final two.
    kept_layers = list(resnet.children())[:-2]
    backbone = torch.nn.Sequential(*kept_layers)
    return resnet18_gcn(
        backbone,
        num_classes,
        in_channels=in_channels,
        t=t,
        adj_file=adj_file,
    )

In [7]:
# Load torchvision's ResNet-18 with pretrained weights, then keep every child
# module except the last two, wrapped in an nn.Sequential.
# NOTE(review): this repeats the backbone construction done inside
# get_resnet18_sp; consider building it in one place only.
model = models.resnet18(pretrained=True)
new_model = torch.nn.Sequential(*(list(model.children())[:-2]))

In [8]:
# Recreate the building blocks of resnet18_gcn as loose variables so the
# forward pass can be stepped through one cell at a time below.
cnn = new_model
num_classes = 7
# NOTE(review): gen_A as defined in this notebook ignores both `t` and `adj_file`.
t = 7
# NOTE(review): absolute, machine-specific path; the file name suggests word
# embeddings rather than an adjacency matrix - confirm the intended use.
adj_file = "/home/viper/Documents/GitHub/Notes/IRP/Code/wb/RAF_DB_glove_word2vec.npy"
pooling = nn.MaxPool2d(7, 7)
gc1 = GraphConvolution(300, 1024)
# Output width 512 so the class vectors can be multiplied with the
# 512-wide flattened image features computed in the following cells.
gc2 = GraphConvolution(1024, 512)
relu = nn.LeakyReLU(0.2)
_adj = gen_A(num_classes, t, adj_file)
A = nn.Parameter(torch.from_numpy(_adj).float())

In [9]:
# Random stand-in for a batch of one 3-channel 224x224 image (uniform in [0, 1)).
# NOTE(review): no seed is set, so the values differ on every run.
img = torch.rand(1, 3, 224, 224)

In [10]:
# Run the truncated backbone on the dummy image and display the feature-map shape.
features = cnn(img)
features.shape

torch.Size([1, 512, 7, 7])

In [11]:
# Max-pool with kernel 7 / stride 7 and display the resulting shape.
# NOTE(review): this rebinds `features`, so re-running the cell on its own
# pools the already-pooled tensor instead of the CNN output.
features = pooling(features)
features.shape

torch.Size([1, 512, 1, 1])

In [12]:
# Flatten everything after the batch dimension into one vector per image.
features = features.view(features.size(0), -1)
features.shape

torch.Size([1, 512])

In [13]:
# Normalise the adjacency parameter and detach the result from autograd.
adj = gen_adj(A).detach()
adj.shape

torch.Size([7, 7])

In [14]:
# Random stand-in for the GCN input: 7 rows (one per class, matching
# num_classes) of width 300 (gc1's input width).
inp = torch.rand(7,300)

In [15]:
# First graph-convolution layer: 300 -> 1024 features per class node.
x  = gc1(inp, adj)
x.shape

torch.Size([7, 1024])

In [16]:
# LeakyReLU (negative slope 0.2) between the two GCN layers; shape is unchanged.
# NOTE(review): `x` is rebound here and in the next cells, so these cells are
# only correct when run once each, in order.
x = relu(x)
x.shape

torch.Size([7, 1024])

In [17]:
# Second graph-convolution layer: 1024 -> 512 features per class node.
x = gc2(x, adj)
x.shape

torch.Size([7, 512])

In [18]:
# Swap the two axes so the per-class vectors become columns, ready to be
# right-multiplied onto the image features.
x = x.transpose(0, 1)
x.shape

torch.Size([512, 7])

In [19]:
# Multiply the image features with the per-class vectors: one score per
# (image, class) pair.
x = torch.matmul(features, x)
x.shape

torch.Size([1, 7])

In [21]:
# Load a pickled object through np.load.
# NOTE(review): allow_pickle=True unpickles the file, which can execute
# arbitrary code - only use it on files from a trusted source.
# NOTE(review): absolute, machine-specific path; the execution count also
# jumps from 19 to 21, so a cell was run and removed in between.
v = np.load("/home/viper/Documents/GitHub/EXP/ML-GCN/data/coco/coco_adj.pkl", allow_pickle=True)

In [22]:
# Display the loaded object; the recorded output is a dict with the keys
# 'nums' and 'adj', each holding a NumPy array.
v

{'nums': array([ 2243,  1171,  3924,  1618,  1804,  1884,   668,  2539,  3844,
         2287,  2241,  2098,  3734,  5968,  5028,  1340,  2791,  2080,
         8606,  1186,  2818,  3322,  8950,  3159,  3170,  1389,  6518,
         8378,  3041,  1062,  1518,  1205,  2537,  1511,  1798,   128,
         4861,  2068,   821,  1471,  1625,  3097,  2475,  1089,  2442,
         1290,  1216,  2003,   481, 45174,  2202,  3084,  1671,  2180,
         1645,   673,  1105,  3291,  2511,  2209,  1170,  2493,  2986,
         1214,  1631,  2343,  1510,  2368,  2667,   151,  2317,   700,
         2893,  2464,  4321,  3191,  2749,  2530,  1771,  1324]),
 'adj': array([[  0,   0,  46, ...,   0,   0,   0],
        [  0,   0,  20, ...,  49,  45,   0],
        [ 46,  20,   0, ...,  42,  35,   1],
        ...,
        [  0,  49,  42, ...,   0, 197,   1],
        [  0,  45,  35, ..., 197,   0,   0],
        [  0,   0,   1, ...,   1,   0,   0]])}