In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
import matplotlib.pyplot as plt
from torchsummary import summary
import pickle
import gzip

In [2]:
class ClusterBlock(nn.Module):
    def __init__(self, n_input, n_output, v ,T, k_cluster):
        super(ClusterBlock, self).__init__( )
        self.n_input = n_input
        self.n_output = n_output
        self.v = v  # number of node
        self.T = T
        self.k_cluster = k_cluster
        self.linear1 = nn.Linear(n_input , 1)  # Dense
        self.linear2 = nn.Linear(T ,k_cluster ) # clustering
        self.softmax = nn.Softmax(dim =-1)
        # support = 3 , each support has k GAT
        self.attentions1 = [GraphAttentionLayer(T, n_input, n_output) for _ in range(k_cluster)]
        self.attentions2 = [GraphAttentionLayer(T, n_input, n_output) for _ in range(k_cluster)]
        self.attentions3 = [GraphAttentionLayer(T, n_input, n_output) for _ in range(k_cluster)]
        self.dropout = nn.Dropout(p=0.5)
        
    def forward(self, x, graph_list):  # x.shape = (b,v,T,f)
        b, *_ = x.size()
        # squeeze f
        out = self.linear1(x)    #  out.shape = (b , v ,T,1)
        out = out.view(-1, self.v, self.T)  #  out.shape = (b , v ,T)

        # clustering
        soft_cluster = self.softmax(self.linear2(out))   # out.shape = (b,v,k)
        
        # soft to hard        
#         m = torch.transpose(torch.max(soft_cluster, -1)[0].repeat(2,1),0,1)
        max_mat = torch.max(soft_cluster, -1)[0].unsqueeze(-1).repeat(1, 1, self.k_cluster)
        one_mat = torch.ones(b, self.v, self.k_cluster)
        zero_mat = torch.zeros(b, self.v, self.k_cluster)
        hard_cluster = torch.where(soft_cluster-max_mat>=0, one_mat, zero_mat)
        
        #graph attention
        gout = 0
        for graph in graph_list:   # graph_list : (3,3 ,n , n )
            graph1, graph2, graph3 = graph[0], graph[1], graph[2]
#             cluster_mask_out = torch.zeros(3, b, self.v, self.T*self.output)
            for i in range(self.k_cluster):
                out1 = torch.mul(self.attentions1[i](x, graph1), hard_cluster[:, :, i].unsqueeze(-1))
                out2 = torch.mul(self.attentions2[i](x, graph1), hard_cluster[:, :, i].unsqueeze(-1))
                out3 = torch.mul(self.attentions2[i](x, graph1), hard_cluster[:, :, i].unsqueeze(-1))
                out = F.relu((out1 + out2  + out3)/3 )
            gout += out
#                 cluster_mask_out[0] += out1
#                 cluster_mask_out[1] += out2
#                 cluster_mask_out[2] += out3
#             gout.append(torch.sum(cluster_mask_out, dim=0))

        return gout/3 , hard_cluster 

In [26]:
class GraphAttentionLayer(nn.Module):
    """
    Simple GAT layer, similar to https://arxiv.org/abs/1710.10903
    """

    def __init__(self, T , in_features, out_features, concat=False):
        super(GraphAttentionLayer, self).__init__()
        self.dropout = nn.Dropout(p=0.5)
        self.in_features = in_features
        self.out_features = out_features
        self.concat = concat
        self.W = nn.Parameter(torch.zeros(size=(T*in_features, T*out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)
        self.a = nn.Parameter(torch.zeros(size=(2*T*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU()

    def forward(self, input, adj): # input.shape = ( b,v, t ,f1)
        B , N , T = input.size()[0], input.size()[1], input.size()[2]     
        
        # h = xw , (b,N,t*f1)  * ( t*f1 ,  t*f2) -> (b ,N , t*f2)
        input = input.view(B*N, -1)
        h = torch.matmul(input, self.W).view(B , N , -1)
        
        # h.repeat(1,1,N) : (b, N, N*t*f2)  
        # h.repeat(1,N,1) : (b ,N*N , t*f2)  
        # output : (b , N , N, 2*t*f2)
        a_input = torch.cat([h.repeat(1, 1, N).view(B, N * N, -1), h.repeat(1, N, 1)], dim=1).view(B, N, N, 2*T*self.out_features)
        # e = a_input* a ,  (b,N,N,2*t*f2) * ( 2*t*f2 ,1) -> ( b ,N ,N )
        e = self.leakyrelu(torch.matmul(a_input, self.a).squeeze(-1))
        zero_vec = -9e15*torch.ones_like(e)
        # adj : (n, n) -> ( b , n ,n )
        b_adj = adj.unsqueeze(0).repeat(B, 1, 1)
        
        attention = torch.where(b_adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=-1)
#         attention = self.dropout(attention, self.dropout, training=self.training)
        
        # (b,n,n) * (b,n ,t*f2)  -> (b, n , t*f2)
        h_prime = torch.bmm(attention, h)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

    def __repr__(self):
        return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'


In [27]:
test_data = torch.rand(10 , 20 , 48 ,1)  # b,v,t, f
test_label = torch.rand(10, 20 , 48 ,1)
graph = torch.rand(20,20)
graph_list = torch.rand(3,3,20,20)

In [28]:
T = 48
in_features = 1 
out_features = 1
model_layer = GraphAttentionLayer(T , in_features, out_features)
model_layer(test_data,graph)

tensor([[[ 0.1944,  0.5102, -0.3534,  ..., -0.5934, -0.3720, -1.7551],
         [ 0.2738,  0.5211, -0.3906,  ..., -0.6415, -0.4070, -1.7689],
         [ 0.2683,  0.5204, -0.3881,  ..., -0.6382, -0.4046, -1.7679],
         ...,
         [ 0.2265,  0.5083, -0.3236,  ..., -0.6340, -0.4281, -1.8062],
         [ 0.2265,  0.5083, -0.3236,  ..., -0.6340, -0.4281, -1.8062],
         [ 0.2265,  0.5083, -0.3236,  ..., -0.6340, -0.4281, -1.8062]],

        [[ 0.1704,  0.6645, -0.6333,  ..., -0.5211, -0.6062, -1.8961],
         [ 0.3105,  0.7658, -0.6013,  ..., -0.6110, -0.6656, -1.7635],
         [ 0.1354,  0.6392, -0.6413,  ..., -0.4986, -0.5913, -1.9292],
         ...,
         [ 0.3333,  0.7678, -0.5493,  ..., -0.5518, -0.6182, -1.8569],
         [ 0.3333,  0.7678, -0.5493,  ..., -0.5518, -0.6182, -1.8569],
         [ 0.3333,  0.7678, -0.5493,  ..., -0.5518, -0.6182, -1.8569]],

        [[ 0.2952,  0.5778, -0.4222,  ..., -0.6484, -0.5660, -1.7679],
         [ 0.3128,  0.5388, -0.3102,  ..., -0

In [29]:
T  = 48
n_input = 1 
n_output = 1 
v = 20 
k_cluster = 3

In [30]:
model = ClusterBlock(n_input, n_output, v ,T, k_cluster)

In [32]:
model(test_data, graph_list)[0].shape 

torch.Size([10, 20, 48])