In [1]:
import pandas as pd
import gzip
import array
import numpy as np
import networkx as nx

In [2]:
from tqdm.notebook import tqdm

In [3]:
from collections import deque
from scipy.sparse import coo, coo_matrix

# Data utils

In [4]:
def parse(path):
  g = gzip.open(path, 'rb')
  for l in g:
    yield eval(l)

def getDF(path):
  i = 0
  df = {}
  for d in parse(path):
    df[i] = d
    i += 1
  return pd.DataFrame.from_dict(df, orient='index')

def readImageFeatures(path):
  f = open(path, 'rb')
  while True:
    asin = f.read(10)
    if asin == '': break
    a = array.array('f')
    try:
      a.fromfile(f, 4096)
    except:
      break
    yield asin, a.tolist()

def related_convert_to_list(x):
  r = set([])
  if x.isna()['related']:
    return r
  else:
    x = x['related']
    for key in x:
      for item in x[key]:
        r.add(item)
    return r

# Utils

## General

In [54]:
def normalize(mx):
    """Row-normalize sparse matrix"""
    rowsum = np.array(mx.sum(1))
    r_inv = np.power(rowsum, -1).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    return mx

def sparse_mx_to_torch_sparse_tensor(sparse_mx):
    """Convert a scipy sparse matrix to a torch sparse tensor."""
    sparse_mx = sparse_mx.tocoo().astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    return torch.sparse.FloatTensor(indices, values, shape)

def train_val_test_split(num_total, split_train=0.6, split_val=0.2):
  idx_list = [i for i in range(num_total)]
  np.random.shuffle(idx_list)

  num_train = int(split_train * num_total)
  num_val = int((split_train + split_val) * num_total)

  idx_train = idx_list[:num_train]
  idx_val = idx_list[num_train:num_val]
  idx_test = idx_list[num_val:]

  return idx_train, idx_val, idx_test

def accuracy(output, labels):
  correct_false = 0
  correct_true = 0
  total_false = 0
  total_true = 0
  for o, l in zip(output, labels):
    if l == 0:
      total_false += 1
      if o == 0:
        correct_false += 1
    elif l == 1:
      total_true += 1
      if o == 1:
        correct_true += 1
  if total_true == 0:
    total_true = 1
  if total_false == 0:
    total_false = 1
  return (correct_true + correct_false)/(total_true + total_false), correct_true/total_true, total_true, correct_false/total_false, total_false

In [81]:
category = 'Musical_Instruments'
min_reviews = 3
ratings_df = pd.read_csv('ratings_{}.csv'.format(category), header=None)
ratings_df = ratings_df[ratings_df.groupby(0)[1].transform('count')>=min_reviews]
ratings_df = ratings_df.reset_index(drop=True)
movies_ratings_df = pd.read_csv('ratings_Movies_and_TV.csv', header = None)
meta_df = getDF('meta_{}.json.gz'.format(category))
meta_df = meta_df[['asin', 'related']]
meta_df = meta_df.set_index('asin')
meta_df = meta_df.apply(lambda x: related_convert_to_list(x), axis=1)

image_features_index = {feature[0]:i for i,feature in enumerate(readImageFeatures('image_features_{}.b'.format(category)))}
image_features = []
for feature in readImageFeatures('image_features_{}.b'.format(category)):
    image_features.append(coo_matrix(feature[1]))

## Graph (data) loader

In [86]:
from numpy import zeros, ones
from numpy.random import shuffle
import networkx as nx
import pandas as pd

class GraphLoader():
    def __init__(self, category, batch_size=20, min_reviews=3):
        self.batch_size = batch_size

#         ratings_df = pd.read_csv('ratings_{}.csv'.format(category), header=None)
#         ratings_df = ratings_df[ratings_df.groupby(0)[1].transform('count')>=min_reviews]
#         ratings_df = ratings_df.reset_index(drop=True)

        self.ratings_df_reviewer = ratings_df.set_index(0)
        self.ratings_df_asin = ratings_df.set_index(1)
        self.ratings_df_index = {id:i for i,id in enumerate(self.ratings_df_reviewer.index)}
        self.ratings_df_index_i = self.ratings_df_reviewer.index

#         movies_ratings_df = pd.read_csv('ratings_Movies_and_TV.csv', header = None)

        # Generate labels
        self.labels = zeros(len(self.ratings_df_index_i))
        for user in self.ratings_df_index_i.intersection(movies_ratings_df[0]):
            self.labels[self.ratings_df_index[user]] = 1

#         meta_df = getDF('meta_{}.json.gz'.format(category))
#         meta_df = meta_df[['asin', 'related']]
#         meta_df = meta_df.set_index('asin')
#         meta_df = meta_df.apply(lambda x: related_convert_to_list(x), axis=1)
        self.meta_df = meta_df
        
#         self.image_features_index = {feature[0]:i for i,feature in enumerate(readImageFeatures('image_features_{}.b'.format(category)))}
#         self.image_features = []
#         for feature in readImageFeatures('image_features_{}.b'.format(category)):
#             self.image_features.append(coo_matrix(feature[1]))
        self.image_features_index = image_features_index
        self.image_features = image_features    
        
        self.idx_train, self.idx_val, self.idx_test = self.train_val_test_split(len(self.ratings_df_index_i))
            
        self.mode = 'train'
        self.idx = self.idx_train

    def train_val_test_split(self, num_total, split_train=0.6, split_val=0.2):
        idx_list = [i for i in range(num_total)]
        shuffle(idx_list)

        num_train = int(split_train * num_total)
        num_val = int((split_train + split_val) * num_total)

        idx_train = idx_list[:num_train]
        print(sum(self.labels[idx_train])/len(self.labels[idx_train]))
        idx_val = idx_list[num_train:num_val]
        print(sum(self.labels[idx_val])/len(self.labels[idx_val]))
        idx_test = idx_list[num_val:]
        print(sum(self.labels[idx_test])/len(self.labels[idx_test]))

        return idx_train, idx_val, idx_test

    def __iter__(self):
        if self.mode == 'train':
            shuffle(self.idx_train)
            self.idx = self.idx_train
        elif self.mode == 'val':
            self.idx = self.idx_val
        else:
            self.idx = self.idx_test
        i = 0
        for i in range(len(self.idx)):
          yield self.get_graph_around_user(self.idx[i])
        #     if i % self.batch_size == 0:
        #         adj_, feat_, label_ = self.get_graph_around_user(idx[i])
        #         adj = adj_.unsqueeze(0)
        #         feat = feat_.unsqueeze(0)
        #         label = label_.unsqueeze(0)
        #     else:
        #         adj_, feat_, label_ = self.get_graph_around_user(idx[i])
        #         adj = torch.vstack((adj, adj_.unsqueeze(0)))
        #         feat = torch.vstack((feat, feat_.unsqueeze(0)))
        #         label = torch.vstack((label, label_.unsqueeze(0)))
            
        #     if i % self.batch_size == self.batch_size - 1:
        #         yield adj, feat, label

    def __len__(self):
        # Number of users
        return len(self.idx)

    def get_graph_around_user(self, user_idx):
        '''
        Return a adjacency matrix and feature matrix of graph surrounding the user
        '''
        user = self.ratings_df_index_i[user_idx]
        features = zeros((50, 4096))
        G = nx.Graph()
        Q = deque()
        Q.append((user, 0)) # 0 for user
        idx = 0

        while Q and G.number_of_nodes() < 50:
            node_name, node_type = Q.popleft()
            if node_name not in G._node:
                if node_type == 1 and bytes(node_name, 'utf-8') in self.image_features_index.keys():
                    features[G.number_of_nodes(), :] = self.image_features[self.image_features_index[bytes(node_name, 'utf-8')]].todense()
                G.add_node(node_name)

            if node_type == 0: # User
                products = self.ratings_df_reviewer.loc[node_name][1]
                self.add_to_queue_or_graph(Q, G, node_name, products, 1)

            else:
                if node_name in self.meta_df.index:
                    products = self.meta_df.loc[node_name]
                    self.add_to_queue_or_graph(Q, G, node_name, products, 1)

                if node_name in self.ratings_df_asin.index:
                    users = self.ratings_df_asin.loc[node_name][0]
                    self.add_to_queue_or_graph(Q, G, node_name, users, 0)

        adj = nx.linalg.graphmatrix.adjacency_matrix(G).todense()
        adj_pad = np.zeros((50, 50))
        adj_pad[:adj.shape[0], :adj.shape[1]] = adj
        return torch.Tensor(adj_pad), torch.Tensor(features), torch.LongTensor([self.labels[user_idx]])

    def add_to_queue_or_graph(self, Q, G, source, to_add, indicator):
        if isinstance(to_add, str):
            if to_add not in G._node:
                Q.append((to_add, indicator))
            else:
                G.add_edge(source, to_add)
        else:
            for t in to_add:
                self.add_to_queue_or_graph(Q, G, source, t, indicator)

In [87]:
graph_loader = GraphLoader('Musical_Instruments')

0.07041272228611872
0.06509189834782032
0.06830999933647403


# Train GNN

## Layer

In [88]:
import math

import torch

from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module


class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        support = torch.matmul(input, self.weight)
        output = torch.sparse.mm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'

## Model

In [89]:
import torch.nn as nn
import torch.nn.functional as F

from tqdm import tqdm

class GCN(nn.Module):
    def __init__(self, nfeat, nclass, dropout):
        """ 3 layers of GCNs with output dimensions equal to 32, 48, 64 respectively and average all node features """
        """ Final classifier with 2 fully connected layers and hidden dimension set to 32 """
        """ Activation function - ReLu (Mutag) """

        super(GCN, self).__init__()

        self.dropout = dropout

        self.gc1 = GraphConvolution(nfeat, 1024)
        self.gc2 = GraphConvolution(1024, 512)
        self.gc3 = GraphConvolution(512, 256)
        self.gc4 = GraphConvolution(256, 128)
        self.gc5 = GraphConvolution(128, 64)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(32, nclass)

    def forward(self, x, adj):

        x = F.relu(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc2(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc3(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc4(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu(self.gc5(x, adj))

        # prev = 0
        # y = []
        # for idx in idx_map:
        #   y.append(torch.mean(x[prev:idx_map[idx]], 0))
        #   prev = idx_map[idx]
        # y = torch.stack(y, 0)

        # x = x[:num_customers]

        # y = []
        # for X in x:
        #   X = F.relu(self.fc1(X))
        #   X = F.dropout(X, self.dropout, training=self.training)
        #   X = F.softmax(self.fc2(X), dim=0)
        #   y.append(X)
        # y = torch.stack(y, 0)

        y = x[0]

        y = F.relu(self.fc1(y))
        y = F.dropout(y, self.dropout, training=self.training)
        y = F.softmax(self.fc2(y), dim=0)

        return y

In [90]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

class FocalLoss(nn.Module):
    def __init__(self, gamma=0, alpha=None, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha
        if isinstance(alpha,(float,int)): self.alpha = torch.Tensor([alpha,1-alpha])
        if isinstance(alpha,list): self.alpha = torch.Tensor(alpha)
        self.size_average = size_average

    def forward(self, input, target):
        if input.dim()>2:
            input = input.view(input.size(0),input.size(1),-1)  # N,C,H,W => N,C,H*W
            input = input.transpose(1,2)    # N,C,H*W => N,H*W,C
            input = input.contiguous().view(-1,input.size(2))   # N,H*W,C => N*H*W,C
        target = target.view(-1,1)

        logpt = F.log_softmax(input)
        logpt = logpt.gather(1,target)
        logpt = logpt.view(-1)
        pt = Variable(logpt.data.exp())

        if self.alpha is not None:
            if self.alpha.type()!=input.data.type():
                self.alpha = self.alpha.type_as(input.data)
            at = self.alpha.gather(0,target.data.view(-1))
            logpt = logpt * Variable(at)

        loss = -1 * (1-pt)**self.gamma * logpt
        if self.size_average: return loss.mean()
        else: return loss.sum()

In [93]:
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np

import torch
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import f1_score

# Parameters

class Object(object):
    pass

args = Object()
args.epochs = 80
args.seed = 100
args.cuda = torch.cuda.is_available()
args.lr = 0.0001
args.dropout = 0.1
args.weight_decay = 5e-4
args.batch_size = 100

np.random.seed(args.seed)

torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)

model = GCN(nfeat=4096,
            nclass=2,
            dropout=args.dropout)


optimizer = optim.Adam(model.parameters(),
                       lr=args.lr, weight_decay=args.weight_decay)

weights = torch.Tensor([1, 16])

if args.cuda:
    model.cuda()
    weights.cuda()

def train(epoch):
  t = time.time()
  model.train()
  optimizer.zero_grad()
  graph_loader.mode = 'train'
  total_loss_train = 0
  output_list = []
  labels_list = []

  for i, data in enumerate(tqdm(graph_loader, total=len(graph_loader), position=0, leave=True)):
      adj, feat, label = data
      if args.cuda:
          adj = adj.cuda()
          feat = feat.cuda()
          label = label.cuda()

      if i % args.batch_size == 0:
          current = model.forward(feat, adj)
          output = current.unsqueeze(0)
          labels = label
      else:
          current = model.forward(feat, adj)
          output = torch.vstack((output, current.unsqueeze(0)))
          labels = torch.vstack((labels, label))

      output_list.append(int(torch.argmax(current)))
      labels_list.append(int(label))

      if i % args.batch_size == args.batch_size - 1:
          labels = torch.flatten(labels)
          loss_train = F.cross_entropy(
              output, labels, weight=weights)
          total_loss_train += float(loss_train)
          # acc_train = accuracy(output, labels)
          loss_train.backward()
          optimizer.step()
        
  f1_train = f1_score(labels_list, output_list)
  acc_train = accuracy(output_list, labels_list)
  # model.eval()
  
  print('\nEpoch: {:04d}'.format(epoch+1),
      # 'loss_train: {:.4f}'.format(loss_train.item()),
      'total_loss_train: {:.4f}'.format(total_loss_train),
      'f1_train: {:.4f}'.format(f1_train),
      'acc_train: {:.4f} ({:.4f} of {} true|| {:.4f} of {} false)'.format(acc_train[0], acc_train[1], acc_train[2], acc_train[3], acc_train[4]),
      # 'loss_val: {:.4f}'.format(loss_val.item()),
      # 'avg_loss_val: {:.4f}'.format(total_loss_val/len(idx_val)),
      # 'f1_val: {:.4f}'.format(f1_val),
      # 'acc_val: {:.4f}'.format(acc_val.item()),
      'time: {:.4f}s'.format(time.time() - t))

class EarlyStopping():
    def __init__(self, patience = 10, min_loss = 0.5, hit_min_before_stopping = False):
        self.patience = patience
        self.counter = 0
        self.hit_min_before_stopping = hit_min_before_stopping
        if hit_min_before_stopping:
            self.min_loss = min_loss
        self.best_loss = None
        self.early_stop = False
        
    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
        elif loss > self.best_loss:
            self.counter += 1
            if self.counter > self.patience:
              if self.hit_min_before_stopping == True and loss > self.min_loss:
                print("Cannot hit mean loss, will continue")
                self.counter -= self.patience
              else:
                self.early_stop = True
        else:
            self.best_loss = loss
            counter = 0

In [None]:
for i in range(args.epochs):
  train(i)
  if i % 5 == 4:
    torch.save(model.state_dict(), 'ckpt_{}.pth'.format(i))

  9%|██████▊                                                                    | 8198/90424 [15:19<2:00:37, 11.36it/s]

In [None]:
from torch.nn.modules.module import Module
from torch.nn import Linear
from torch.nn import ReLU6
from torch.nn import Sequential
import random

import copy

MAX_NUM_NODES = 100 # for mutag
random.seed(200)

class Generator(Module):
    def __init__(self, 
                 C: list,
                 c=0,
                 hyp1=1, 
                 hyp2=2, 
                 start=None,
                 nfeat=7,
                 dropout=0.1):
        """ 
        :param C: Candidate set of nodes (list)
        :param start: Starting node (defaults to randomised node)
        """
        super(Generator, self).__init__()

        self.nfeat = nfeat
        self.dropout = dropout
        self.c = c

        self.fc = Linear(nfeat, 8)
        self.gc1 = GraphConvolution(8, 16)
        self.gc2 = GraphConvolution(16, 24)
        self.gc3 = GraphConvolution(24, 32)

        # MLP1
        # 2 FC layers with hidden dimension 16
        self.mlp1 = Sequential(Linear(32, 16),
                               Linear(16, 1))

        # MLP2
        # 2 FC layers with hidden dimension 24
        self.mlp2 = Sequential(Linear(64, 24),
                               Linear(24, 1))

        # Hyperparameters
        self.hyp1 = hyp1
        self.hyp2 = hyp2
        self.candidate_set = C
        
        # Default starting node (if any)
        if start is not None:
          self.start = start
          self.random_start = False
        else:
          self.start = random.choice(np.arange(0, len(self.candidate_set)))
          self.random_start = True

        # Load GCN for calculating reward
        self.model = GCN(nfeat=features_list[0].shape[1],
                         nclass=labels.max().item() + 1,
                         dropout=args.dropout)
        
        self.model.load_state_dict(torch.load(PATH))
        for param in self.model.parameters():
          param.requires_grad = False

        self.reset_graph()
        
    def reset_graph(self):
        """
        Reset g.G to default graph with only start node
        """
        if self.random_start == True:
            self.start = random.choice(np.arange(0, len(self.candidate_set)))

        mask_start = torch.BoolTensor([False if i == 0 else True for i in range(MAX_NUM_NODES + len(self.candidate_set))])
        
        adj = torch.zeros((MAX_NUM_NODES + len(self.candidate_set), MAX_NUM_NODES + len(self.candidate_set)), dtype=torch.float32)

        feat = torch.zeros((MAX_NUM_NODES + len(self.candidate_set), len(self.candidate_set)), dtype=torch.float32)
        feat[0, self.start] = 1
        feat[np.arange(-len(self.candidate_set), 0), np.arange(0, len(self.candidate_set))] = 1

        degrees = torch.zeros(MAX_NUM_NODES)

        self.G = {'adj': adj, 'feat': feat, 'degrees': degrees, 'num_nodes': 1, 'mask_start': mask_start}

    def calculate_loss(self, Rt, p_start, a_start, p_end, a_end, G_t_1):
        """
        Calculated from cross entropy loss (Lce) and reward function (Rt)
        where loss = -Rt*(Lce_start + Lce_end)
        """

        Lce_start = F.cross_entropy(torch.reshape(p_start, (1, 35)), a_start.unsqueeze(0))
        Lce_end = F.cross_entropy(torch.reshape(p_end, (1, 35)), a_end.unsqueeze(0))

        return -Rt*(Lce_start + Lce_end)

    def calculate_reward(self, G_t_1):
        """
        Rtr     Calculated from graph rules to encourage generated graphs to be valid
                1. Only one edge to be added between any two nodes
                2. Generated graph cannot contain more nodes than predefined maximum node number
                3. (For chemical) Degree cannot exceed valency
                If generated graph violates graph rule, Rtr = -1

        Rtf     Feedback from trained model
        """

        rtr = self.check_graph_rules(G_t_1)

        rtf = self.calculate_reward_feedback(G_t_1)
        rtf_sum = 0
        for m in range(rollout):
            p_start, a_start, p_end, a_end, G_t_1 = self.forward(G_t_1)
            rtf_sum += self.calculate_reward_feedback(G_t_1)
        rtf = rtf + rtf_sum * self.hyp1 / rollout

        return rtf + self.hyp2 * rtr

    def calculate_reward_feedback(self, G_t_1):
        """
        p(f(G_t_1) = c) - 1/l
        where l denotes number of possible classes for f
        """
        f = self.model(G_t_1['feat'], G_t_1['adj'], None)
        return f[self.c] - 1/len(f)

    def check_graph_rules(self, G_t_1):
        """
        For mutag, node degrees cannot exceed valency
        """
        idx = 0

        for d in G_t_1['degrees']:
          if d is not 0:
            node_id = torch.argmax(G_t_1['feat'][idx]) # Eg. [0, 1, 0, 0] -> 1
            node = self.candidate_set[node_id]  # Eg ['C.4', 'F.2', 'Br.7'][1] = 'F.2'
            max_valency = int(node.split('.')[1]) # Eg. C.4 -> ['C', '4'] -> 4

            # If any node degree exceeds its valency, return -1
            if max_valency < d:
                return -1

        return 0
        
    def forward(self, G_in):
        G = copy.deepcopy(G_in)

        x = G['feat'].detach().clone()
        adj = G['adj'].detach().clone()

        x = F.relu6(self.fc(x))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc1(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc2(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.relu6(self.gc3(x, adj))
        x = F.dropout(x, self.dropout, training=self.training)

        p_start = self.mlp1(x)
        p_start = p_start.masked_fill(G['mask_start'].unsqueeze(1), 0)
        p_start = F.softmax(p_start, dim=0)
        a_start_idx = torch.argmax(p_start.masked_fill(G['mask_start'].unsqueeze(1), -1))
        
        # broadcast
        x1, x2 = torch.broadcast_tensors(x, x[a_start_idx])
        x = torch.cat((x1, x2), 1) # cat increases dim from 32 to 64

        mask_end = torch.BoolTensor([True for i in range(MAX_NUM_NODES + len(self.candidate_set))])
        mask_end[MAX_NUM_NODES:] = False
        mask_end[:G['num_nodes']] = False
        mask_end[a_start_idx] = True

        p_end = self.mlp2(x)
        p_end = p_end.masked_fill(mask_end.unsqueeze(1), 0)
        p_end = F.softmax(p_end, dim=0)
        a_end_idx = torch.argmax(p_end.masked_fill(mask_end.unsqueeze(1), -1))

        # Return new G
        # If a_end_idx is not masked, node exists in graph, no new node added
        if G['mask_start'][a_end_idx] == False:
            G['adj'][a_end_idx][a_start_idx] += 1
            G['adj'][a_start_idx][a_end_idx] += 1
            
            # Update degrees
            G['degrees'][a_start_idx] += 1
            G['degrees'][G['num_nodes']] += 1
        else:
            # Add node
            G['feat'][G['num_nodes']] = G['feat'][a_end_idx]
            # Add edge
            G['adj'][G['num_nodes']][a_start_idx] += 1
            G['adj'][a_start_idx][G['num_nodes']] += 1
            # Update degrees
            G['degrees'][a_start_idx] += 1
            G['degrees'][G['num_nodes']] += 1

            # Update start mask
            G_mask_start_copy = G['mask_start'].detach().clone()
            G_mask_start_copy[G['num_nodes']] = False
            G['mask_start'] = G_mask_start_copy
            
            G['num_nodes'] += 1

        return p_start, a_start_idx, p_end, a_end_idx, G