In [1]:
import torch_geometric

import numpy as np
import pandas as pd
import torch
from torch.nn import Linear, LayerNorm, ReLU, Dropout
import torch.nn.functional as F
from torch_geometric.nn import ChebConv, NNConv, DeepGCNLayer, GATConv, DenseGCNConv, GCNConv, GraphConv
from torch_geometric.data import Data, DataLoader
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, accuracy_score
import scipy.sparse as sp

import warnings
warnings.filterwarnings("ignore")

# ref: https://medium.com/stanford-cs224w/fraud-detection-with-gat-edac49bda1a0

#### Import dataset

In [2]:
# import data 
df_features = pd.read_csv('../data/elliptic_txs_features.csv', header=None)
df_edges = pd.read_csv("../data/elliptic_txs_edgelist.csv")
df_classes =  pd.read_csv("../data/elliptic_txs_classes.csv")

df_classes['class'] = df_classes['class'].map({'unknown': 2, '1':1, '2':0})

# merging dataframes
df_merge = df_features.merge(df_classes, how='left', right_on="txId", left_on=0)
df_merge.drop(0, axis=1, inplace=True)

# check if there are duplicate txId
print("Number of duplicate txId: ", df_merge.duplicated(subset=['txId']).sum())


Number of duplicate txId:  0


In [3]:
# rename column 0 to time_step
df_merge.rename(columns={1: 'time_step'}, inplace=True)
display(df_merge.head())
display(df_edges.shape)

Unnamed: 0,time_step,2,3,4,5,6,7,8,9,10,...,159,160,161,162,163,164,165,166,txId,class
0,1,-0.171469,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162097,-0.167933,...,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,230425980,2
1,1,-0.171484,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162112,-0.167948,...,-0.979074,-0.978556,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,5530458,2
2,1,-0.172107,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162749,-0.168576,...,-0.979074,-0.978556,-0.098889,-0.106715,-0.131155,-0.183671,-0.120613,-0.119792,232022460,2
3,1,0.163054,1.96379,-0.646376,12.409294,-0.063725,9.782742,12.414558,-0.163645,-0.115831,...,0.241128,0.241406,1.072793,0.08553,-0.131155,0.677799,-0.120613,-0.119792,232438397,0
4,1,1.011523,-0.081127,-1.201369,1.153668,0.333276,1.312656,-0.061584,-0.163523,0.041399,...,0.517257,0.579382,0.018279,0.277775,0.326394,1.29375,0.178136,0.179117,230460314,2


(234355, 2)

In [None]:
df_merge.shape

(203769, 168)

#### Split dataset masks

In [None]:
edges = df_edges.copy()

# Setup trans ID to node ID mapping
nodes = df_merge['txId'].values
map_id = {j:i for i,j in enumerate(nodes)} # mapping nodes to indexes

# Map transction IDs to node Ids
edges.txId1 = edges.txId1.map(map_id) #get nodes idx1 from edges list and filtered data
edges.txId2 = edges.txId2.map(map_id)
edges = edges.astype(int)

# Reformat and convert to tensor
edge_index = np.array(edges.values).T 
edge_index = torch.tensor(edge_index, dtype=torch.long).contiguous()

print("shape of edge index is {}".format(edge_index.shape))

shape of edge index is torch.Size([2, 234355])


In [None]:
node_features = df_merge.drop(['txId'], axis=1).copy()
print("unique=",node_features["class"].unique())

# Retain known vs unknown IDs
all_classified_idx = node_features['class'].loc[node_features['class']!=2].index # filter on known labels
all_unclassified_idx = node_features['class'].loc[node_features['class']==2].index
all_classified_illicit_idx = node_features['class'].loc[node_features['class']==1].index # filter on illicit labels
all_classified_licit_idx = node_features['class'].loc[node_features['class']==0].index # filter on licit labels

# node_features = node_features.drop(columns=[0, 1, 'class'])
display(node_features.head())

unique= [2 0 1]


Unnamed: 0,time_step,2,3,4,5,6,7,8,9,10,...,158,159,160,161,162,163,164,165,166,class
0,1,-0.171469,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162097,-0.167933,...,-0.600999,1.46133,1.461369,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,2
1,1,-0.171484,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162112,-0.167948,...,0.673103,-0.979074,-0.978556,0.018279,-0.08749,-0.131155,-0.097524,-0.120613,-0.119792,2
2,1,-0.172107,-0.184668,-1.201369,-0.12197,-0.043875,-0.113002,-0.061584,-0.162749,-0.168576,...,0.439728,-0.979074,-0.978556,-0.098889,-0.106715,-0.131155,-0.183671,-0.120613,-0.119792,2
3,1,0.163054,1.96379,-0.646376,12.409294,-0.063725,9.782742,12.414558,-0.163645,-0.115831,...,-0.613614,0.241128,0.241406,1.072793,0.08553,-0.131155,0.677799,-0.120613,-0.119792,0
4,1,1.011523,-0.081127,-1.201369,1.153668,0.333276,1.312656,-0.061584,-0.163523,0.041399,...,-0.400422,0.517257,0.579382,0.018279,0.277775,0.326394,1.29375,0.178136,0.179117,2


In [None]:
df_merge.shape

(203769, 168)

In [None]:
train_classified_idx = node_features.loc[(node_features['time_step'] <= 34) & (node_features['class'] != 2)].index
test_classified_idx = node_features.loc[(node_features['time_step'] > 34) & (node_features['class'] != 2)].index
print("train_classified_idx.shape=",train_classified_idx.shape)
print("test_classified_idx.shape=",test_classified_idx.shape)

train_classified_idx.shape= (29894,)
test_classified_idx.shape= (16670,)


In [None]:
# node_features.drop(columns=['time_step'], inplace=True)
node_features.drop(columns=['class'], inplace=True)

# Convert to tensor
node_features_t = torch.tensor(np.array(node_features.values, dtype=np.double), dtype=torch.double)

In [None]:
# Define labels
labels = df_merge['class'].values

#create weights tensor with same shape of edge_index
weights = torch.tensor([1]* edge_index.shape[1] , dtype=torch.double) 

# Do train test split on classified_ids
train_idx = train_classified_idx
test_idx = test_classified_idx

# Create pyG dataset
data_graph = Data(x=node_features_t.float(), edge_index=edge_index, edge_attr=weights, 
                               y=torch.tensor(labels, dtype=torch.long))

# Add in the train and valid idx
data_graph.train_idx = train_idx
data_graph.test_idx = test_idx
data_graph

Data(x=[203769, 166], edge_index=[2, 234355], edge_attr=[234355], y=[203769], train_idx=Int64Index([     3,      9,     10,     11,     16,     17,     25,     27,
                29,     30,
            ...
            136232, 136233, 136234, 136236, 136239, 136241, 136243, 136249,
            136250, 136258],
           dtype='int64', length=29894), test_idx=Int64Index([136276, 136277, 136278, 136279, 136280, 136282, 136285, 136287,
            136288, 136291,
            ...
            203727, 203730, 203736, 203740, 203750, 203752, 203754, 203759,
            203763, 203766],
           dtype='int64', length=16670))

#### STAGN Model and Training

In [None]:
import torch
from torch_geometric.nn import GCNConv


class TGCN(torch.nn.Module):
    r"""An implementation of the Temporal Graph Convolutional Gated Recurrent Cell.
    For details see this paper: `"T-GCN: A Temporal Graph ConvolutionalNetwork for
    Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        improved (bool): Stronger self loops. Default is False.
        cached (bool): Caching the message weights. Default is False.
        add_self_loops (bool): Adding self-loops for smoothing. Default is True.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        improved: bool = False,
        cached: bool = False,
        add_self_loops: bool = True,
    ):
        super(TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops

        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):

        self.conv_z = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_reset_gate_parameters_and_layers(self):

        self.conv_r = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_candidate_state_parameters_and_layers(self):

        self.conv_h = GCNConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            improved=self.improved,
            cached=self.cached,
            add_self_loops=self.add_self_loops,
        )

        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            H = torch.zeros(X.shape[0], self.out_channels).to(X.device)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=1)
        Z = self.linear_z(Z)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], axis=1)
        R = self.linear_r(R)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], axis=1)
        H_tilde = self.linear_h(H_tilde)
        H_tilde = torch.tanh(H_tilde)
        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde
        return H

    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
        H: torch.FloatTensor = None,
    ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde)
        return H


class TGCN2(torch.nn.Module):
    r"""An implementation THAT SUPPORTS BATCHES of the Temporal Graph Convolutional Gated Recurrent Cell.
    For details see this paper: `"T-GCN: A Temporal Graph ConvolutionalNetwork for
    Traffic Prediction." <https://arxiv.org/abs/1811.05320>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        batch_size (int): Size of the batch.
        improved (bool): Stronger self loops. Default is False.
        cached (bool): Caching the message weights. Default is False.
        add_self_loops (bool): Adding self-loops for smoothing. Default is True.
    """

    def __init__(self, in_channels: int, out_channels: int, 
                 batch_size: int,  # this entry is unnecessary, kept only for backward compatibility
                 improved: bool = False, cached: bool = False, 
                 add_self_loops: bool = True):
        super(TGCN2, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.improved = improved
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.batch_size = batch_size  # not needed
        self._create_parameters_and_layers()

    def _create_update_gate_parameters_and_layers(self):
        self.conv_z = GCNConv(in_channels=self.in_channels,  out_channels=self.out_channels, improved=self.improved,
                              cached=self.cached, add_self_loops=self.add_self_loops )
        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_reset_gate_parameters_and_layers(self):
        self.conv_r = GCNConv(in_channels=self.in_channels, out_channels=self.out_channels, improved=self.improved,
                              cached=self.cached, add_self_loops=self.add_self_loops )
        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_candidate_state_parameters_and_layers(self):
        self.conv_h = GCNConv(in_channels=self.in_channels, out_channels=self.out_channels, improved=self.improved,
                              cached=self.cached, add_self_loops=self.add_self_loops )
        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _create_parameters_and_layers(self):
        self._create_update_gate_parameters_and_layers()
        self._create_reset_gate_parameters_and_layers()
        self._create_candidate_state_parameters_and_layers()

    def _set_hidden_state(self, X, H):
        if H is None:
            # can infer batch_size from X.shape, because X is [B, N, F]
            H = torch.zeros(X.shape[0], X.shape[1], self.out_channels).to(X.device) #(b, 207, 32)
        return H

    def _calculate_update_gate(self, X, edge_index, edge_weight, H):
        Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=2) # (b, 207, 64)
        Z = self.linear_z(Z) # (b, 207, 32)
        Z = torch.sigmoid(Z)

        return Z

    def _calculate_reset_gate(self, X, edge_index, edge_weight, H):
        R = torch.cat([self.conv_r(X, edge_index, edge_weight), H], axis=2) # (b, 207, 64)
        R = self.linear_r(R) # (b, 207, 32)
        R = torch.sigmoid(R)

        return R

    def _calculate_candidate_state(self, X, edge_index, edge_weight, H, R):
        H_tilde = torch.cat([self.conv_h(X, edge_index, edge_weight), H * R], axis=2) # (b, 207, 64)
        H_tilde = self.linear_h(H_tilde) # (b, 207, 32)
        H_tilde = torch.tanh(H_tilde)

        return H_tilde

    def _calculate_hidden_state(self, Z, H, H_tilde):
        H = Z * H + (1 - Z) * H_tilde   # # (b, 207, 32)
        return H

    def forward(self,X: torch.FloatTensor, edge_index: torch.LongTensor, edge_weight: torch.FloatTensor = None,
                H: torch.FloatTensor = None ) -> torch.FloatTensor:
        """
        Making a forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph. If the hidden state matrix is not present
        when the forward pass is called it is initialized with zeros.

        Arg types:
            * **X** *(PyTorch Float Tensor)* - Node features.
            * **edge_index** *(PyTorch Long Tensor)* - Graph edge indices.
            * **edge_weight** *(PyTorch Long Tensor, optional)* - Edge weight vector.
            * **H** *(PyTorch Float Tensor, optional)* - Hidden state matrix for all nodes.

        Return types:
            * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
        """
        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
        R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
        H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
        H = self._calculate_hidden_state(Z, H, H_tilde) # (b, 207, 32)
        return H

In [None]:
num_features = data_graph.num_node_features
device = torch.device('cpu')

In [None]:
def train(model, data, optimizer):
    model.train()
    data = data.to(device)
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    # out = out.reshape((data.x.shape[0]))
    # TODO :use weighted cross entropy loss
    loss = F.cross_entropy(out[data.train_idx], data.y[data.train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

def test(model, data):
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        out = model(data.x, data.edge_index)
        pred_scores = out[data.test_idx]
        pred = torch.argmax(pred_scores, dim=1)
        # print(sum([1 for i in pred.tolist() if i == 1]))
        y = data.y[data.test_idx]
        acc = accuracy_score(y.cpu(), pred.cpu())
        f1 = f1_score(y.cpu(), pred.cpu())
        precision = precision_score(y.cpu(), pred.cpu())
        recall = recall_score(y.cpu(), pred.cpu())
        roc = roc_auc_score(y.cpu(), pred.cpu())
        return acc, f1, precision, recall, roc

In [None]:
model_tgcn = TGCN(in_channels=num_features, out_channels=2)
num_epochs = 200
lr = 0.001
optimizer = torch.optim.Adam(model_tgcn.parameters(), lr=lr)

for epoch in range(num_epochs+1):
    loss = train(model_tgcn, data_graph, optimizer)
    acc, f1, precision, recall, roc = test(model_tgcn, data_graph)
    if epoch % 10 == 0:
        print(f'Epoch: {epoch}, Loss: {loss:.4f}, Acc: {acc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, recall: {recall:.4f}, roc: {roc:.4f}')
    

In [None]:
# test 
acc, f1, precision, recall, roc = test(model_tgcn, data_graph)
print(f"Accuracy: {acc:.4f}, F1: {f1:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, ROC: {roc:.4f}")

Accuracy: 0.9344, F1: 0.0000, Precision: 0.0000, Recall: 0.0000, ROC: 0.4996


In [None]:
class TemporalConv(torch.nn.Module):
    r"""Temporal convolution block applied to nodes in the STGCN Layer
    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting."
    <https://arxiv.org/abs/1709.04875>`_ Based off the temporal convolution
     introduced in "Convolutional Sequence to Sequence Learning"  <https://arxiv.org/abs/1709.04875>`_

    Args:
        in_channels (int): Number of input features.
        out_channels (int): Number of output features.
        kernel_size (int): Convolutional kernel size.
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int = 3):
        super(TemporalConv, self).__init__()
        self.conv_1 = torch.nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_2 = torch.nn.Conv2d(in_channels, out_channels, (1, kernel_size))
        self.conv_3 = torch.nn.Conv2d(in_channels, out_channels, (1, kernel_size))

    def forward(self, X: torch.FloatTensor) -> torch.FloatTensor:
        """Forward pass through temporal convolution block.

        Arg types:
            * **X** (torch.FloatTensor) -  Input data of shape
                (batch_size, input_time_steps, num_nodes, in_channels).

        Return types:
            * **H** (torch.FloatTensor) - Output data of shape
                (batch_size, in_channels, num_nodes, input_time_steps).
        """
        X = X.permute(0, 3, 2, 1)
        P = self.conv_1(X)
        Q = torch.sigmoid(self.conv_2(X))
        PQ = P * Q
        H = F.relu(PQ + self.conv_3(X))
        H = H.permute(0, 3, 2, 1)
        return H

class STConv(torch.nn.Module):
    r"""Spatio-temporal convolution block using ChebConv Graph Convolutions.
    For details see: `"Spatio-Temporal Graph Convolutional Networks:
    A Deep Learning Framework for Traffic Forecasting"
    <https://arxiv.org/abs/1709.04875>`_

    NB. The ST-Conv block contains two temporal convolutions (TemporalConv)
    with kernel size k. Hence for an input sequence of length m,
    the output sequence will be length m-2*(k-1).

    Args:
        in_channels (int): Number of input features.
        hidden_channels (int): Number of hidden units output by graph convolution block
        out_channels (int): Number of output features.
        kernel_size (int): Size of the kernel considered.
        K (int): Chebyshev filter size :math:`K`.
        normalization (str, optional): The normalization scheme for the graph
            Laplacian (default: :obj:`"sym"`):

            1. :obj:`None`: No normalization
            :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}`

            2. :obj:`"sym"`: Symmetric normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A}
            \mathbf{D}^{-1/2}`

            3. :obj:`"rw"`: Random-walk normalization
            :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}`

            You need to pass :obj:`lambda_max` to the :meth:`forward` method of
            this operator in case the normalization is non-symmetric.
            :obj:`\lambda_max` should be a :class:`torch.Tensor` of size
            :obj:`[num_graphs]` in a mini-batch scenario and a
            scalar/zero-dimensional tensor when operating on single graphs.
            You can pre-compute :obj:`lambda_max` via the
            :class:`torch_geometric.transforms.LaplacianLambdaMax` transform.
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)

    """

    def __init__(
        self,
        num_nodes: int,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        kernel_size: int,
        K: int,
        normalization: str = "sym",
        bias: bool = True,
    ):
        super(STConv, self).__init__()
        self.num_nodes = num_nodes
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.K = K
        self.normalization = normalization
        self.bias = bias

        self._temporal_conv1 = TemporalConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=kernel_size,
        )

        self._graph_conv = ChebConv(
            in_channels=hidden_channels,
            out_channels=hidden_channels,
            K=K,
            normalization=normalization,
            bias=bias,
        )

        self._temporal_conv2 = TemporalConv(
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
        )

        self._batch_norm = torch.nn.BatchNorm2d(num_nodes)

        
    def forward(
        self,
        X: torch.FloatTensor,
        edge_index: torch.LongTensor,
        edge_weight: torch.FloatTensor = None,
    ) -> torch.FloatTensor:

        r"""Forward pass. If edge weights are not present the forward pass
        defaults to an unweighted graph.

        Arg types:
            * **X** (PyTorch FloatTensor) - Sequence of node features of shape (Batch size X Input time steps X Num nodes X In channels).
            * **edge_index** (PyTorch LongTensor) - Graph edge indices.
            * **edge_weight** (PyTorch LongTensor, optional)- Edge weight vector.

        Return types:
            * **T** (PyTorch FloatTensor) - Sequence of node features.
        """
        T_0 = self._temporal_conv1(X)
        T = torch.zeros_like(T_0).to(T_0.device)
        for b in range(T_0.size(0)):
            for t in range(T_0.size(1)):
                T[b][t] = self._graph_conv(T_0[b][t], edge_index, edge_weight)

        T = F.relu(T)
        T = self._temporal_conv2(T)
        T = T.permute(0, 2, 1, 3)
        T = self._batch_norm(T)
        T = T.permute(0, 2, 1, 3)
        return T

In [None]:
def train_stconv(model, data, optimizer):
    model.train()
    data = data.to(device)
    optimizer.zero_grad()
    X = data_graph.x.unsqueeze(axis=0).unsqueeze(axis=0)
    out = model(X, data.edge_index)
    # out = out.reshape((data.x.shape[0]))
    out = out.squeeze()
    # TODO :use weighted cross entropy loss
    loss = F.cross_entropy(out[data.train_idx], data.y[data.train_idx])
    loss.backward()
    optimizer.step()
    return loss.item()

def test_stconv(model, data):
    model.eval()
    with torch.no_grad():
        data = data.to(device)
        X = data_graph.x.unsqueeze(axis=0).unsqueeze(axis=0)
        out = model(X, data.edge_index)
        out = out.squeeze()
        pred_scores = out[data.test_idx]
        pred = torch.argmax(pred_scores, dim=1)
        # print(sum([1 for i in pred.tolist() if i == 1]))
        y = data.y[data.test_idx]
        acc = accuracy_score(y.cpu(), pred.cpu())
        f1 = f1_score(y.cpu(), pred.cpu())
        precision = precision_score(y.cpu(), pred.cpu())
        recall = recall_score(y.cpu(), pred.cpu())
        roc = roc_auc_score(y.cpu(), pred.cpu())
        return acc, f1, precision, recall, roc

In [103]:
num_nodes = data_graph.x.shape[0]
model_stconv = STConv(num_nodes=num_nodes, in_channels=num_features, hidden_channels=128, out_channels=2, 
                      kernel_size=1,K=2)
optimizer_stconv = torch.optim.Adam(model_stconv.parameters(), lr=lr)

for epoch in range(20+1):
    loss = train_stconv(model_stconv, data_graph, optimizer_stconv)
    acc, f1, precision, recall, roc = test_stconv(model_stconv, data_graph)
    print(f'epoch: {epoch}, loss: {loss:.4f}, acc: {acc:.4f}, f1: {f1:.4f}, precision: {precision:.4f}, recall: {recall:.4f}, roc: {roc:.4f}')
    

epoch: 0, loss: 1.5432, acc: 0.6653, f1: 0.0894, precision: 0.0543, recall: 0.2530, roc: 0.4735
epoch: 1, loss: 0.8417, acc: 0.8831, f1: 0.0279, precision: 0.0304, recall: 0.0259, roc: 0.4843
epoch: 2, loss: 0.8019, acc: 0.7385, f1: 0.1181, precision: 0.0756, recall: 0.2696, roc: 0.5204
epoch: 3, loss: 1.0054, acc: 0.5893, f1: 0.1081, precision: 0.0629, recall: 0.3832, roc: 0.4934
epoch: 4, loss: 1.0929, acc: 0.6166, f1: 0.1050, precision: 0.0619, recall: 0.3463, roc: 0.4908
epoch: 5, loss: 1.0620, acc: 0.6160, f1: 0.1021, precision: 0.0602, recall: 0.3361, roc: 0.4858
epoch: 6, loss: 1.0573, acc: 0.5957, f1: 0.0968, precision: 0.0566, recall: 0.3333, roc: 0.4737
epoch: 7, loss: 1.0859, acc: 0.6056, f1: 0.0930, precision: 0.0547, recall: 0.3112, roc: 0.4686
epoch: 8, loss: 1.0771, acc: 0.6041, f1: 0.0894, precision: 0.0526, recall: 0.2992, roc: 0.4622
epoch: 9, loss: 1.0898, acc: 0.6065, f1: 0.0894, precision: 0.0526, recall: 0.2973, roc: 0.4626
epoch: 10, loss: 1.1016, acc: 0.6001, f1