In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch import Tensor
from torch.nn import init
from torch.nn.parameter import Parameter
from torch.nn.functional import softmax

In [31]:
from math import ceil

In [None]:
class multi_shallow_embedding(nn.Module):
    
    def __init__(self, num_nodes, k_neighs, num_graphs):
        super().__init__()
        
        self.num_nodes = num_nodes
        self.k = k_neighs
        self.num_graphs = num_graphs

        self.emb_s = Parameter(Tensor(num_graphs, num_nodes, 1))
        self.emb_t = Parameter(Tensor(num_graphs, 1, num_nodes))
        
    def reset_parameters(self):
        init.xavier_uniform_(self.emb_s)
        init.xavier_uniform_(self.emb_t)
        
        
    def forward(self, device):
        
        # adj: [G, N, N]
        adj = torch.matmul(self.emb_s, self.emb_t).to(device)
        
        # remove self-loop
        adj = adj.clone()
        idx = torch.arange(self.num_nodes, dtype=torch.long, device=device)
        adj[:, idx, idx] = float('-inf')
        
        # top-k-edge adj
        adj_flat = adj.reshape(self.num_graphs, -1)
        indices = adj_flat.topk(k=self.k)[1].reshape(-1)
        
        idx = torch.tensor([ i//self.k for i in range(indices.size(0)) ], device=device)
        
        adj_flat = torch.zeros_like(adj_flat).clone()
        adj_flat[idx, indices] = 1.
        adj = adj_flat.reshape_as(adj)
        
        return adj

class Group_Linear(nn.Module):
    
    def __init__(self, in_channels, out_channels, groups=1, bias=False):
        super().__init__()
                
        self.out_channels = out_channels
        self.groups = groups
        
        self.group_mlp = nn.Conv2d(in_channels * groups, out_channels * groups, kernel_size=(1, 1), groups=groups, bias=bias)
        
        self.reset_parameters()
        
    def reset_parameters(self):
        self.group_mlp.reset_parameters()
        
        
    def forward(self, x: Tensor, is_reshape: False):
        """
        Args:
            x (Tensor): [B, C, N, F] (if not is_reshape), [B, C, G, N, F//G] (if is_reshape)
        """
        B = x.size(0)
        C = x.size(1)
        N = x.size(-2)
        G = self.groups
        
        if not is_reshape:
            # x: [B, C_in, G, N, F//G]
            x = x.reshape(B, C, N, G, -1).transpose(2, 3)
        # x: [B, G*C_in, N, F//G]
        x = x.transpose(1, 2).reshape(B, G*C, N, -1)
        
        out = self.group_mlp(x)
        out = out.reshape(B, G, self.out_channels, N, -1).transpose(1, 2)
        
        # out: [B, C_out, G, N, F//G]
        return out


class DenseGCNConv2d(nn.Module):
    
    def __init__(self, in_channels, out_channels, groups=1, bias=True):
        super().__init__()
        
        self.in_channels = in_channels
        self.out_channels = out_channels
        
        self.lin = Group_Linear(in_channels, out_channels, groups, bias=False)
        
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
            
        self.reset_parameters()
        
    def reset_parameters(self):
        self.lin.reset_parameters()
        init.zeros_(self.bias)
        
    def norm(self, adj: Tensor, add_loop):
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            adj[:, idx, idx] += 1
        
        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)
        
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        
        return adj
        
        
    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [B, G, N, N]
        """
        adj = self.norm(adj, add_loop).unsqueeze(1)

        # x: [B, C, G, N, F//G]
        x = self.lin(x, False)
        
        out = torch.matmul(adj, x)
        
        # out: [B, C, N, F]
        B, C, _, N, _ = out.size()
        out = out.transpose(2, 3).reshape(B, C, N, -1)
        
        if self.bias is not None:
            out = out.transpose(1, -1) + self.bias
            out = out.transpose(1, -1)
        
        return out


class DenseGINConv2d(nn.Module):
    
    def __init__(self, in_channels, out_channels, groups=1, eps=0, train_eps=True):
        super().__init__()
        
        # TODO: Multi-layer model
        self.mlp = Group_Linear(in_channels, out_channels, groups, bias=False)
        
        self.init_eps = eps
        if train_eps:
            self.eps = Parameter(Tensor([eps]))
        else:
            self.register_buffer('eps', Tensor([eps]))
            
        self.reset_parameters()
            
    def reset_parameters(self):
        self.mlp.reset_parameters()
        self.eps.data.fill_(self.init_eps)
        
    def norm(self, adj: Tensor, add_loop):
        if add_loop:
            adj = adj.clone()
            idx = torch.arange(adj.size(-1), dtype=torch.long, device=adj.device)
            adj[..., idx, idx] += 1
        
        deg_inv_sqrt = adj.sum(-1).clamp(min=1).pow(-0.5)
        
        adj = deg_inv_sqrt.unsqueeze(-1) * adj * deg_inv_sqrt.unsqueeze(-2)
        
        return adj
        
        
    def forward(self, x: Tensor, adj: Tensor, add_loop=True):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [G, N, N]
        """
        B, C, N, _ = x.size()
        G = adj.size(0)
        
        # adj-norm
        adj = self.norm(adj, add_loop=False)
        
        # x: [B, C, G, N, F//G]
        x = x.reshape(B, C, N, G, -1).transpose(2, 3)
        
        out = torch.matmul(adj, x)
        
        # DYNAMIC
        x_pre = x[:, :, :-1, ...]
        
        # out = x[:, :, 1:, ...] + x_pre
        out[:, :, 1:, ...] = out[:, :, 1:, ...] + x_pre
        # out = torch.cat( [x[:, :, 0, ...].unsqueeze(2), out], dim=2 )
        
        if add_loop:
            out = (1 + self.eps) * x + out
        
        # out: [B, C, G, N, F//G]
        out = self.mlp(out, True)
        
        # out: [B, C, N, F]
        C = out.size(1)
        out = out.transpose(2, 3).reshape(B, C, N, -1)
        
        return out


# class Dense_TimeDiffPool2d(nn.Module):
    
#     def __init__(self, pre_nodes, pooled_nodes, kern_size, padding):
#         super().__init__()
        
#         # TODO: add Normalization
#         # self.time_conv = nn.Conv2d(pre_nodes, pooled_nodes, (1, kern_size), padding=(0, padding))
#         self.time_conv = nn.Conv2d(pre_nodes, pooled_nodes, (1, kern_size), padding=(0, padding),stride=1)

        
#         self.re_param = Parameter(Tensor(kern_size, 1))
        
#     def reset_parameters(self):
#         self.time_conv.reset_parameters()
#         init.kaiming_uniform_(self.re_param, nonlinearity='relu')
        
        
#     def forward(self, x: Tensor, adj: Tensor):
#         """
#         Args:
#             x (Tensor): [B, C, N, F]
#             adj (Tensor): [G, N, N]
#         """
#         B, G, N, F = x.size(0), adj.size(0), adj.size(1), x.size(-1)
#         xlater = x.reshape(B, C, N, G, -1).transpose(2, 3) 
#         # print('x shape at start in diffpool', x.shape)
#         x = x.transpose(1, 2)
#         out = self.time_conv(x)
#         out = out.transpose(1, 2)
#         # print('x shape at end of time pool in diffpool', out.shape)
        
#         # Expand the adjacency matrix to include the batch dimension.
#         adj_expanded = adj.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, G, N, N]
#         print('adj_expanded shape ', adj_expanded.shape)




        
#         # s: [ N^(l+1), N^l, 1, K ]     
#         s = torch.matmul(self.time_conv.weight, self.re_param).view(out.size(-2), -1)
#         # TODO: fully-connect, how to decrease time complexity
#         out_adj = torch.matmul(torch.matmul(s, adj), s.transpose(0, 1))
#         print('out adj shape in diffpool', out_adj.shape)

       
#         return out, out_adj

In [290]:
class Dense_TimeDiffPool2d(nn.Module):
    
    def __init__(self, pre_nodes, pooled_nodes, kern_size, padding):
        super().__init__()
        
        # TODO: add Normalization
        self.time_conv = nn.Conv2d(pre_nodes, pooled_nodes, (1, kern_size), padding=(0, padding))
        self.re_param = Parameter(Tensor(kern_size, 1))

        ##added for attention in this module
        # Adjusted attention layers for query, key 
        # Todo (and optionally value) -does it make sense
        self.query = nn.Linear(18, 18)
        self.key = nn.Linear(18, 18)


        self.reset_parameters()
        
    def reset_parameters(self):
        self.time_conv.reset_parameters()
        init.kaiming_uniform_(self.re_param, nonlinearity='relu')

        ######################################## for the attention
        init.xavier_uniform_(self.query.weight)
        init.xavier_uniform_(self.key.weight)
        
        
    def forward(self, x: Tensor, adj: Tensor):
        """
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [G, N, N]
        """
        B, C, N, F = x.shape ## added lines  # Assuming x is [B, C, N, F]
        #################################################################
        x = x.transpose(1, 2)
        out = self.time_conv(x)
        out = out.transpose(1, 2)

        out_features = out
        ##################################################################

        # s: [ N^(l+1), N^l, 1, K ]
        s = torch.matmul(self.time_conv.weight, self.re_param).view(out.size(-2), -1)

        # TODO: fully-connect, how to decrease time complexity
        out_adj = torch.matmul(torch.matmul(s, adj), s.transpose(0, 1))
        ###################################################################
        
        G = adj.size(0)
        # Reshape and transpose to include the graph dimension G
        # xlater = out_features.reshape(B, C, self.G, N, -1).transpose(2, 3)  # [B, C, N, G, F']
        out_features = out_features.view(B, C, N, G, -1).permute(0, 1, 3, 2, 4)  # Reshape to [B, C, G, N, F']
        out_features_flat = out_features.reshape(B, C, G, N, -1)
        # Apply query and key transformations
        queries = self.query(out_features_flat)  # Shape: [B, C, G, N, pooled_nodes]
        keys = self.key(out_features_flat)  # Shape: [B, C, G, N, pooled_nodes]

        # Calculate attention scores
        attention_scores = torch.einsum('bcgnl,bcgml->bcgnm', (queries, keys))  # [B, C, G, N, N]
        print('att score shape before sfmax', attention_scores.shape)
        attention_scores = softmax(attention_scores, dim=-1)  # Normalize over the last dimension
        print('att score shape', attention_scores.shape)
        attention_scores = torch.mean(attention_scores, dim=1)

        # Combine attention_scores with the adjacency matrix
        # Assuming adj is already [B, G, N, N]
        adj_expanded = out_adj.unsqueeze(0).repeat(B, 1, 1, 1)
        print('adj_expanded shape', adj_expanded.shape)
        weighted_adj = attention_scores * adj_expanded  # Element-wise multiplication
     
        return out, out_adj, weighted_adj

In [291]:

class GNNStack(nn.Module):
    """ The stack layers of GNN.

    """

    def __init__(self, gnn_model_type, num_layers, groups, pool_ratio, kern_size, 
                 in_dim, hidden_dim, out_dim, 
                 seq_len, num_nodes, num_classes, dropout=0.5, 
                 # activation=nn.ReLU()
                 activation=nn.SELU()
                
                ):

        super().__init__()
        
        # TODO: Sparsity Analysis
        k_neighs = self.num_nodes = num_nodes
        
        self.num_graphs = groups
        
        self.num_feats = seq_len
        if seq_len % groups:
            self.num_feats += ( groups - seq_len % groups )
        self.g_constr = multi_shallow_embedding(num_nodes, k_neighs, self.num_graphs)
        
        gnn_model, heads = self.build_gnn_model(gnn_model_type)
        
        assert num_layers >= 1, 'Error: Number of layers is invalid.'
        assert num_layers == len(kern_size), 'Error: Number of kernel_size should equal to number of layers.'
        paddings = [ (k - 1) // 2 for k in kern_size ]
        
        self.tconvs = nn.ModuleList(
            [nn.Conv2d(1, in_dim, (1, kern_size[0]), padding=(0, paddings[0]))] + 
            [nn.Conv2d(heads * in_dim, hidden_dim, (1, kern_size[layer+1]), padding=(0, paddings[layer+1])) for layer in range(num_layers - 2)] + 
            [nn.Conv2d(heads * hidden_dim, out_dim, (1, kern_size[-1]), padding=(0, paddings[-1]))]
        )
        
        self.gconvs = nn.ModuleList(
            [gnn_model(in_dim, heads * in_dim, groups)] + 
            [gnn_model(hidden_dim, heads * hidden_dim, groups) for _ in range(num_layers - 2)] + 
            [gnn_model(out_dim, heads * out_dim, groups)]
        )
        
        self.bns = nn.ModuleList(
            [nn.BatchNorm2d(heads * in_dim)] + 
            [nn.BatchNorm2d(heads * hidden_dim) for _ in range(num_layers - 2)] + 
            [nn.BatchNorm2d(heads * out_dim)]
        )
        
        self.left_num_nodes = []
        for layer in range(num_layers + 1):
            left_node = round( num_nodes * (1 - (pool_ratio*layer)) )
            if left_node > 0:
                self.left_num_nodes.append(left_node)
            else:
                self.left_num_nodes.append(1)
        self.diffpool = nn.ModuleList(
            [Dense_TimeDiffPool2d(self.left_num_nodes[layer], self.left_num_nodes[layer+1], kern_size[layer], paddings[layer]) for layer in range(num_layers - 1)] + 
            [Dense_TimeDiffPool2d(self.left_num_nodes[-2], self.left_num_nodes[-1], kern_size[-1], paddings[-1])]
        )
        
        self.num_layers = num_layers
        self.dropout = dropout
        self.activation = activation
        
        self.softmax = nn.Softmax(dim=-1)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        # self.global_pool = nn.AdaptiveAvgPool2d((1,None))
        
        self.linear = nn.Linear(heads * out_dim, num_classes)
        
        self.reset_parameters()
        
        
    def reset_parameters(self):
        for tconv, gconv, bn, pool in zip(self.tconvs, self.gconvs, self.bns, self.diffpool):
            tconv.reset_parameters()
            gconv.reset_parameters()
            bn.reset_parameters()
            pool.reset_parameters()
        
        self.linear.reset_parameters()
        
        
    def build_gnn_model(self, model_type):
        if model_type == 'dyGCN2d':
            return DenseGCNConv2d, 1
        if model_type == 'dyGIN2d':
            return DenseGINConv2d, 1
        if model_type == 'dyGAT2d':
            return DenseGATConv2d, 1
        

    # def forward(self, inputs: Tensor):
        
    #     if inputs.size(-1) % self.num_graphs:
    #         pad_size = (self.num_graphs - inputs.size(-1) % self.num_graphs) / 2
    #         x = F.pad(inputs, (int(pad_size), ceil(pad_size)), mode='constant', value=0.0)
    #     else:
    #         x = inputs
            
    #     adj = self.g_constr(x.device)
        
    #     for tconv, gconv, bn, pool in zip(self.tconvs, self.gconvs, self.bns, self.diffpool):
            
    #         x, adj = pool( gconv( tconv(x), adj ), adj )
            
    #         x = self.activation( bn(x) )
            
    #         x = F.dropout(x, p=self.dropout, training=self.training)
        
    #     # print('before pool',x.shape)
    #     out = self.global_pool(x)
    #     # print('after pool',out.shape)
    #     # out = out.view(out.size(0),out.size(-1),-1)
    #     out = out.view(out.size(0),-1)
    #     # print('after reshape',out.shape)
    #     out = self.linear(out)
    #     # print('logits shape',out.shape)
    #     # break
    #     return out
    def forward(self, inputs: Tensor):
        feature_time_maps = []
        adj_matrices = []
    
        if inputs.size(-1) % self.num_graphs:
            pad_size = (self.num_graphs - inputs.size(-1) % self.num_graphs) / 2
            x = F.pad(inputs, (int(pad_size), ceil(pad_size)), mode='constant', value=0.0)
        else:
            x = inputs
        # print('x before anything',x.shape)
        
    
        adj = self.g_constr(x.device)
        # print('adj adjacency shape',adj.shape)
        
        layer_index = 0
        for tconv, gconv, bn, pool in zip(self.tconvs, self.gconvs, self.bns, self.diffpool):
            # print('x before time convolution', x.shape, 'layer_index',layer_index)
            x = tconv(x)  # Apply time convolution
            # print('x after time convolution', x.shape, 'layer_index',layer_index)
            # sys.exit()
            # print('x before graph convolution', x.shape, 'layer_index',layer_index)
            x = gconv(x, adj)  # Apply graph convolution
            # print('x after graph convolution', x.shape, 'layer_index',layer_index)
            # break
            # feature_time_maps.append(x.detach())  # Save feature representation
            # adj_matrices.append(adj.detach())  # Save adjacency matrix after pooling
            x, adj, awadj = pool(x, adj)  # Apply pooling, which might modify x and adj
            # print('x after time diff pool', x.shape, 'layer_index',layer_index)
            # print('adj after time diff pool', adj.shape, 'layer_index',layer_index)
            # feature_time_maps.append(x.detach())  # Save feature representation
            # adj_matrices.append(adj.detach())  # Save adjacency matrix after pooling
    
            x = self.activation(bn(x))
            x = F.dropout(x, p=self.dropout, training=self.training)

            layer_index += 1
    
        out = self.global_pool(x)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        # print(out.shape)
        # sys.exit()
    
        return out, feature_time_maps, adj_matrices


In [292]:
args = {
    'arch': 'dyGIN2d', #what other models I can put here?? dyGCN2d, dyGIN2d
    'dataset': 'Mortality', # "AtrialFibrillation" # 'Mortality', # 'MIMIC3'
    'num_layers': 2,  # the number of GNN layers  3
    'groups': 16,  # the number of time series groups (num_graphs)
    'pool_ratio': 0,  # the ratio of pooling for nodes
    'kern_size': [3,3],  # list of time conv kernel size for each layer [9,5,3]
    'in_dim': 4,  # input dimensions of GNN stacks
    'hidden_dim': 4,  # hidden dimensions of GNN stacks
    'out_dim': 4,  # output dimensions of GNN stacks
    'workers': 4,  # number of data loading workers
    'epochs': 30,  # number of total epochs to run
    'batch_size': 4,  # mini-batch size, this is the total batch size of all GPUs
    'val_batch_size': 4,  # validation batch size
    'lr': 0.0001,  # initial learning rate
    'weight_decay': 1e-4,  # weight decay
    'evaluate': False,  # evaluate model on validation set
    'seed': 2,  # seed for initializing training
    'gpu': 0,  # GPU id to use
    'use_benchmark': True,  # use benchmark
    'tag': 'date',  # the tag for identifying the log and model files
    'loss':'bce'
}

In [293]:
model = GNNStack(gnn_model_type=args['arch'], num_layers=args['num_layers'], 
                     groups=args['groups'], pool_ratio=args['pool_ratio'], kern_size=args['kern_size'], 
                     in_dim=args['in_dim'], hidden_dim=args['hidden_dim'], out_dim=args['out_dim'], 
                     seq_len=288, num_nodes=231, num_classes=2)

In [294]:
# model
model.eval()

GNNStack(
  (g_constr): multi_shallow_embedding()
  (tconvs): ModuleList(
    (0): Conv2d(1, 4, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
    (1): Conv2d(4, 4, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
  )
  (gconvs): ModuleList(
    (0-1): 2 x DenseGINConv2d(
      (mlp): Group_Linear(
        (group_mlp): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), groups=16, bias=False)
      )
    )
  )
  (bns): ModuleList(
    (0-1): 2 x BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (diffpool): ModuleList(
    (0-1): 2 x Dense_TimeDiffPool2d(
      (time_conv): Conv2d(231, 231, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1))
      (query): Linear(in_features=18, out_features=18, bias=True)
      (key): Linear(in_features=18, out_features=18, bias=True)
    )
  )
  (activation): SELU()
  (softmax): Softmax(dim=-1)
  (global_pool): AdaptiveAvgPool2d(output_size=1)
  (linear): Linear(in_features=4, out_features=2, bias=True)
)

In [295]:
### make random pytorch tensor of (32, 1, 231, 288) (B, C, N, F)
rt = torch.rand(32, 1, 231, 288)

In [296]:
out, feature_time_maps, adj_matrices = model(rt)

att score shape before sfmax torch.Size([32, 4, 16, 231, 231])
att score shape torch.Size([32, 4, 16, 231, 231])
adj_expanded shape torch.Size([32, 16, 231, 231])
att score shape before sfmax torch.Size([32, 4, 16, 231, 231])
att score shape torch.Size([32, 4, 16, 231, 231])
adj_expanded shape torch.Size([32, 16, 231, 231])


In [116]:
# out.shape, len(feature_time_maps), len(adj_matrices)
# (torch.Size([32, 2]), 2, 2)

In [52]:
out.shape, feature_time_maps[0].shape, feature_time_maps[1].shape, adj_matrices[0].shape,  adj_matrices[1].shape

(torch.Size([32, 2]),
 torch.Size([32, 16, 231, 288]),
 torch.Size([32, 16, 208, 288]),
 torch.Size([16, 231, 231]),
 torch.Size([16, 208, 208]))

In [None]:
231-3+2

In [193]:
temp_adj = torch.rand(16, 231, 231)
temp_batched_adj = torch.rand(32, 16, 231, 231)
temp_s = torch.rand(231, 231)

In [195]:
out_adj = torch.matmul(torch.matmul(temp_s, temp_adj), temp_s.transpose(0, 1))

In [196]:
out_adj.shape

torch.Size([16, 231, 231])

In [201]:
# out_adj

In [197]:
out_adj = torch.matmul(torch.matmul(temp_s, temp_batched_adj), temp_s.transpose(0, 1))

In [199]:
out_adj.shape

torch.Size([32, 16, 231, 231])

In [None]:
        Args:
            x (Tensor): [B, C, N, F]
            adj (Tensor): [B, G, N, N]
        """
        adj = self.norm(adj, add_loop).unsqueeze(1)

        # x: [B, C, G, N, F//G]
        x = self.lin(x, False)
        
        out = torch.matmul(adj, x)

In [None]:
        # x: [B, C, G, N, F//G]
        x = x.reshape(B, C, N, G, -1).transpose(2, 3)

In [212]:
xtem= torch.rand(32, 16, 231, 288)

In [215]:
xtem = xtem.reshape(32, 16, 231, 16, -1)
# .transpose(2, 3)

In [214]:
xtem.shape

torch.Size([32, 16, 231, 16, 18])