In [None]:
## Experiments with AHAMP, an extension of HAMP

In [None]:
import os
##
import pickle

import argparse
import math
import numpy as np
import pandas as pd
import json
from tqdm import tqdm
import scipy.sparse as sp
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.nn.functional as F
from sklearn.metrics import f1_score
import networkx as nx
import dgl
import dgl.function as fn
import dgl.nn.pytorch as dglnn
from dgl.nn.pytorch import edge_softmax
# from models import HAMP as Model
from models import PosVect, InterModalAttention
# Reproducibility: seed the torch RNG (weight init, mask init) and the NumPy RNG
# (train/val/test permutation). Some run-to-run variability may still remain.
SEED = 8
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU when one is visible; assign `device = torch.device("cpu")` to force CPU.
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")

import warnings
warnings.filterwarnings('ignore')  # NOTE(review): this silences every warning, including deprecations

In [None]:
# Experiment selection.
mode = 'adaptive'  # NOTE(review): not read anywhere else in this notebook
# Task: 'screen_genre_class' (predict genre_label of 'ui' nodes) or
#       'element_comp_class' (predict comp_label of 'element' nodes).
task = 'screen_genre_class'

# Data layout: <cwd>/data/<version_dir>/*.h5
home_dir = Path.cwd()
version_dir = 'rico_n'
main_data_dir = home_dir / 'data'
data_dir = main_data_dir / version_dir

In [None]:
# load data - features
def _read_h5(filename, key):
    """Return the pandas object stored under `key` in `data_dir / filename`."""
    return pd.read_hdf(data_dir / filename, key=key)

# Edge lists, one table per relation.
app2ui_edgelist = _read_h5('app2ui_edgelist.h5', 'edgelist')
ui2class_edgelist = _read_h5('ui2class_edgelist.h5', 'edgelist')
class2element_edgelist = _read_h5('class2element_edgelist.h5', 'edgelist')
element2element_edgelist = _read_h5('element2element_edgelist.h5', 'edgelist')

# Per-node-type feature tables.
app_description_features_df = _read_h5('app_description_features.h5', 'features')
ui_image_features = _read_h5('ui_image_features.h5', 'features')
ui_pos_features_df = _read_h5('ui_position_features.h5', 'features')
class_name_features = _read_h5('charngram_features.h5', 'features')
element_spatial_features_df = _read_h5('spatial_features.h5', 'features')
element_image_features_df = _read_h5('element_image_features.h5', 'features')

# Classification labels.
comp_labels = _read_h5('comp_labels.h5', 'labels')
genre_labels = _read_h5('genre_labels.h5', 'labels')

# process edgelists
def _edgelist_to_csc(edgelist, row_col, col_col, square=False):
    """Build a sparse CSC adjacency with a 1.0 for every row of `edgelist`.

    Rows are indexed by `edgelist[row_col]`, columns by `edgelist[col_col]`.
    The shape is (number of unique row ids, number of unique column ids), or
    square with the row count when `square=True`. SciPy sums repeated
    (row, col) pairs.

    NOTE(review): this assumes the encoded ids are contiguous 0..n-1; SciPy
    raises if an id is >= the inferred dimension.
    """
    rows = list(edgelist[row_col])
    cols = list(edgelist[col_col])
    n_rows = len(edgelist[row_col].unique())
    n_cols = n_rows if square else len(edgelist[col_col].unique())
    return sp.csc_matrix((np.ones(len(edgelist)), (rows, cols)), shape=(n_rows, n_cols))

# element -> element (row = target element, column = source element)
e2e_adj = _edgelist_to_csc(element2element_edgelist, 'target_element_encoded', 'source_element_encoded', square=True)
# element -> class
e2c_adj = _edgelist_to_csc(class2element_edgelist, 'target_element_encoded', 'class_name_encoded')
# ui -> class
u2c_adj = _edgelist_to_csc(ui2class_edgelist, 'ui_encoded', 'class_name_encoded')
# app -> ui
a2u_adj = _edgelist_to_csc(app2ui_edgelist, 'app_encoded', 'ui_encoded')
# process features
# Feature row i is attached to graph node i further below. Each feature table's id
# column must therefore equal, element-wise, the id sequence that `.unique()` returns
# for the matching edge-list column.
def _assert_aligned(feature_ids, edge_ids, what):
    """Raise AssertionError naming `what` unless `feature_ids == edge_ids` everywhere."""
    assert (feature_ids == edge_ids).all(), f'{what}: feature rows are not aligned with edge-list ids'

_assert_aligned(app_description_features_df.app_encoded, app2ui_edgelist.app_encoded.unique(), 'app description')
_assert_aligned(ui_image_features.ui_encoded, app2ui_edgelist.ui_encoded.unique(), 'ui image vs app2ui')
_assert_aligned(ui_image_features.ui_encoded, ui2class_edgelist.ui_encoded.unique(), 'ui image vs ui2class')
# NOTE(review): class ids are only checked against ui2class_edgelist; consider also checking
# class2element_edgelist.class_name_encoded if its first-appearance order is expected to match.
_assert_aligned(class_name_features.class_name_encoded, ui2class_edgelist.class_name_encoded.unique(), 'class name')
_assert_aligned(element_spatial_features_df.target_encoded, class2element_edgelist.target_element_encoded.unique(), 'element spatial vs class2element')
_assert_aligned(element_spatial_features_df.target_encoded, element2element_edgelist.target_element_encoded.unique(), 'element spatial vs element2element')
_assert_aligned(element_image_features_df.target_element_encoded, class2element_edgelist.target_element_encoded.unique(), 'element image vs class2element')
_assert_aligned(element_image_features_df.target_element_encoded, element2element_edgelist.target_element_encoded.unique(), 'element image vs element2element')
_assert_aligned(ui_image_features.ui_encoded, ui_pos_features_df.ui_encoded.unique(), 'ui position')

# Feature matrices: the leading id/metadata columns are skipped.
# NOTE(review): the column offsets (3, 2, 4, 2, 2, 3) are taken as-is; confirm against the .h5 schemas.
app_desc_vectors = app_description_features_df.iloc[:, 3:].values
ui_image_vectors = ui_image_features.iloc[:, 2:].values
class_name_vectors = class_name_features.iloc[:, 4:].values
element_spatial_vectors = element_spatial_features_df.iloc[:, 2:].values
element_image_vectors = element_image_features_df.iloc[:, 2:].values
ui_pos_vectors = ui_pos_features_df.iloc[:, 3:].values

# Heterogeneous UI graph. Each sparse adjacency contributes a forward relation and
# its transpose, so messages can flow both ways. `.nonzero()` on a SciPy sparse
# matrix returns the (row indices, column indices) of its nonzero entries, i.e.
# the (source, destination) edge list.
_relations = {
    ('element', 'fwd', 'element'): e2e_adj.nonzero(),
    ('element', 'bkwd', 'element'): e2e_adj.T.nonzero(),
    ('element', 'is', 'class'): e2c_adj.nonzero(),
    ('class', 'of', 'element'): e2c_adj.T.nonzero(),
    # ('class', 'selfloop', 'class'): c2c_adj.nonzero(),  # two-hop if necessary
    ('ui', 'composed-of', 'class'): u2c_adj.nonzero(),
    ('class', 'in', 'ui'): u2c_adj.T.nonzero(),
    ('app', 'inc', 'ui'): a2u_adj.nonzero(),
    ('ui', 'part-of', 'app'): a2u_adj.T.nonzero(),
}
# Node counts per type are read off the adjacency shapes.
_num_nodes_per_type = {
    'element': e2e_adj.shape[1],
    'class': e2c_adj.shape[1],
    'ui': a2u_adj.shape[1],
    'app': a2u_adj.shape[0],
}
G = dgl.heterograph(_relations, num_nodes_dict=_num_nodes_per_type)

# look-up dicts
# node_dict: node type -> integer id (its position in G.ntypes); reverse_node_dict inverts it.
node_dict = {}
reverse_node_dict = {}
for type_id, ntype in enumerate(G.ntypes):
    node_dict[ntype] = type_id
    reverse_node_dict[type_id] = ntype

# edge_dict: edge type name -> integer id, assigned in G.etypes order.
# Every edge of a relation also gets that id stored as edge data 'id'.
edge_dict = {}
for etype in G.etypes:
    edge_dict[etype] = len(edge_dict)
    G.edges[etype].data['id'] = torch.full((G.num_edges(etype),), edge_dict[etype], dtype=torch.long)
    
# Raw `node_ft` width per node type; AHAMP sizes each type's input projection from this.
node_feat = {'app': app_desc_vectors.shape[1], 'class': class_name_vectors.shape[1],
             'element': element_image_vectors.shape[1], 'ui': ui_image_vectors.shape[1]}

# Classification targets.
G.nodes['element'].data['comp_label'] = torch.tensor(comp_labels.comp_encoded.values)
G.nodes['ui'].data['genre_label'] = torch.tensor(genre_labels.genre_encoded.values)

# Content features, cast to float32 so they match AHAMP's nn.Linear projections whatever
# dtype the .h5 tables hold (a float64 or integer input would make the projection raise).
for _ntype, _vectors in (('element', element_image_vectors),
                         ('class', class_name_vectors),
                         ('ui', ui_image_vectors),
                         ('app', app_desc_vectors)):
    G.nodes[_ntype].data['node_ft'] = torch.tensor(_vectors).float()


def _column(values):
    """Float32 tensor of shape (len(values), 1) built from a 1-D array."""
    return torch.tensor(values, dtype=torch.float32).unsqueeze(1)


def _constant(ntype, value):
    """Float32 tensor of shape (G.num_nodes(ntype), 1) filled with `value`."""
    return torch.full((G.num_nodes(ntype), 1), float(value), dtype=torch.float32)


# Positional feature: only UI screens carry one; other node types get 0.
G.nodes['ui'].data['pos'] = torch.tensor(ui_pos_vectors, dtype=torch.float32)
for _ntype in ('element', 'class', 'app'):
    G.nodes[_ntype].data['pos'] = _constant(_ntype, 0)

# Hierarchy depth: app = 1, ui = 2, class = 3. Element depth is spatial column 4 offset by 3
# (presumably so elements sit below the class level). It is computed out of place, so
# element_spatial_vectors, which may be a view of element_spatial_features_df, is not mutated.
G.nodes['app'].data['depth'] = _constant('app', 1)
G.nodes['ui'].data['depth'] = _constant('ui', 2)
G.nodes['class'].data['depth'] = _constant('class', 3)
G.nodes['element'].data['depth'] = _column(element_spatial_vectors[:, 4] + 3)

# bound1..bound4 = spatial columns 0..3 for elements (presumably bounding-box values); 0 elsewhere.
for _i in range(4):
    _key = f'bound{_i + 1}'
    G.nodes['element'].data[_key] = _column(element_spatial_vectors[:, _i])
    for _ntype in ('class', 'ui', 'app'):
        G.nodes[_ntype].data[_key] = _constant(_ntype, 0)

In [None]:
class AHAMPLayer(nn.Module):
    """One heterogeneous attention message-passing layer with optional soft edge masks.

    For every relation (srctype, etype, dsttype):

    * source nodes are projected to keys/values and destination nodes to queries,
      using one Linear per node type;
    * keys and values get relation-specific per-head (d_k x d_k) transforms;
    * edges are scored with a scaled dot product times a per-relation, per-head
      weight, then normalised with an edge softmax over each destination's
      incoming edges;
    * the scores are optionally multiplied by ``sigmoid(edge_mask)``.

    Messages are summed per relation and averaged across relations. Each node
    type then goes through ``fc`` + dropout, a learnable gated residual, an
    optional LayerNorm and a per-type output Linear.

    Args:
        input_dim: width of incoming ``h[ntype]``. The gated residual adds it to a
            ``hidden_dim`` tensor, so it must equal ``hidden_dim``.
        hidden_dim: attention width, split into ``num_heads`` heads of
            ``hidden_dim // num_heads``. Must be divisible by ``num_heads`` for the
            final ``view``.
        output_dim: width of the representations returned by ``forward``.
        node_dict: node type -> index into the per-type module lists.
        edge_dict: edge type name -> index into the per-relation parameters.
        num_heads: number of attention heads.
        dropout: dropout probability applied after ``fc``.
        use_layer_norm: apply a per-type LayerNorm before the output Linear.
    """

    def __init__(self, input_dim, hidden_dim, output_dim,
                 node_dict, edge_dict,
                 num_heads, dropout = 0.3, use_layer_norm = True):
        super().__init__()

        self.input_dim = input_dim  # width of incoming h[ntype]
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim  # width of the returned representations

        self.node_dict = node_dict  # node type -> index into per-type ModuleLists
        self.edge_dict = edge_dict  # edge type -> index into per-relation parameters

        self.num_node_types = len(node_dict)
        self.num_edge_types = len(edge_dict)
        self.num_heads = num_heads
        self.d_k = hidden_dim // num_heads  # per-head width
        self.sqrt_dk = math.sqrt(self.d_k)  # dot-product scaling

        self.k_dense = nn.ModuleList()
        self.q_dense = nn.ModuleList()
        self.v_dense = nn.ModuleList()
        self.fc = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        self.dense = nn.ModuleList()
        self.use_layer_norm = use_layer_norm
        # Per-relation, per-head attention weight. NOTE(review): this parameter is 2-D, so
        # reset_parameters() overwrites the ones-initialisation with xavier_uniform_.
        self.canon_weights = nn.Parameter(torch.ones(self.num_edge_types, self.num_heads))
        # Per-relation, per-head key/value transforms; filled in by reset_parameters().
        self.att_weights = nn.Parameter(torch.empty(self.num_edge_types, self.num_heads, self.d_k, self.d_k))
        self.value_weights = nn.Parameter(torch.empty(self.num_edge_types, self.num_heads, self.d_k, self.d_k))
        # Per-node-type residual gate logit: alpha = sigmoid(res[n_id]).
        self.res = nn.Parameter(torch.ones(self.num_node_types))

        self.dropout = nn.Dropout(dropout)

        # Creation order is kept so the default (non-xavier) bias initialisation consumes
        # the RNG in the same order.
        for _ in range(self.num_node_types):
            self.k_dense.append(nn.Linear(input_dim, hidden_dim))
            self.q_dense.append(nn.Linear(input_dim, hidden_dim))
            self.v_dense.append(nn.Linear(input_dim, hidden_dim))
            self.fc.append(nn.Linear(hidden_dim, hidden_dim))
            if use_layer_norm:
                self.layer_norms.append(nn.LayerNorm(hidden_dim))
            self.dense.append(nn.Linear(hidden_dim, output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise every parameter with more than one dimension.

        1-D parameters (biases, LayerNorm affine terms, ``res``) keep their values.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, G, h, edge_mask=None):
        """Run one round of relation-aware attention message passing.

        Args:
            G: DGL heterograph whose edge-type names are keys of ``edge_dict``.
            h: dict node type -> tensor of shape (num_nodes, input_dim).
            edge_mask: optional dict canonical etype -> per-edge mask logits. It is
                only read for relations with at least one edge; ``sigmoid`` of it
                scales the normalised attention.

        Returns:
            dict node type -> tensor of shape (num_nodes, output_dim).
        """
        with G.local_scope():
            for srctype, etype, dsttype in G.canonical_etypes:
                sub_graph = G[srctype, etype, dsttype]
                src_id = self.node_dict[srctype]
                dst_id = self.node_dict[dsttype]
                e_id = self.edge_dict[etype]

                k = self.k_dense[src_id](h[srctype]).view(-1, self.num_heads, self.d_k)
                v = self.v_dense[src_id](h[srctype]).view(-1, self.num_heads, self.d_k)
                q = self.q_dense[dst_id](h[dsttype]).view(-1, self.num_heads, self.d_k)

                # Relation-specific per-head transforms of keys and values.
                k = torch.einsum("bij,ijk->bik", k, self.att_weights[e_id])
                v = torch.einsum("bij,ijk->bik", v, self.value_weights[e_id])

                # 'k' and 'q' are consumed right away by apply_edges, so reusing those keys
                # across relations is safe. Values are only consumed later, in
                # multi_update_all. A node type that is the source of several relations
                # would have a shared 'v' overwritten by the last relation, so each
                # relation stores its values under its own key.
                sub_graph.srcdata['k'] = k
                sub_graph.dstdata['q'] = q
                sub_graph.srcdata[f'v_{e_id}'] = v

                sub_graph.apply_edges(fn.v_dot_u('q', 'k', 't'))
                attn_score = sub_graph.edata.pop('t').sum(-1) * self.canon_weights[e_id] / self.sqrt_dk
                attn_score = edge_softmax(sub_graph, attn_score, norm_by='dst')

                if edge_mask is not None and sub_graph.num_edges() > 0:
                    # Soft edge gate in (0, 1), broadcast over heads.
                    edge_gate = edge_mask[srctype, etype, dsttype].sigmoid().unsqueeze(-1)
                    attn_score = attn_score * edge_gate

                sub_graph.edata['t'] = attn_score.unsqueeze(-1)

            # Sum weighted messages within each relation, then average across relations.
            G.multi_update_all(
                {etype: (fn.u_mul_e(f'v_{e_id}', 't', 'm'), fn.sum('m', 't'))
                 for etype, e_id in self.edge_dict.items()},
                cross_reducer='mean')

            final_h = {}
            for ntype in G.ntypes:
                n_id = self.node_dict[ntype]
                alpha = torch.sigmoid(self.res[n_id])
                t = G.nodes[ntype].data['t'].view(-1, self.hidden_dim)
                h_prime = self.dropout(self.fc[n_id](t))
                # Learnable gated residual between the new message and the input.
                h_prime = h_prime * alpha + h[ntype] * (1 - alpha)
                if self.use_layer_norm:
                    h_prime = self.layer_norms[n_id](h_prime)
                final_h[ntype] = self.dense[n_id](h_prime)

            return final_h

class AHAMP(nn.Module):
    """Multimodal heterogeneous graph model built from stacked AHAMPLayer blocks.

    For each node type, seven inputs are each encoded to ``hidden_dim`` and then
    fused by ``InterModalAttention``:

    * ``node_ft`` via a per-type ``nn.Linear(node_feat[ntype], hidden_dim)``;
    * ``pos``, ``depth`` and ``bound1``..``bound4`` via per-type ``PosVect(1, hidden_dim)``.

    The fused representations go through ``num_layers`` AHAMPLayer blocks. The
    representations of the selected node type are then mapped to ``output_dim``
    scores by a final Linear.

    Args:
        node_dict: node type -> index into the per-type module lists.
        edge_dict: edge type name -> index; forwarded to every AHAMPLayer.
        reverse_node_dict: index -> node type (inverse of ``node_dict``).
        node_feat: node type -> raw ``node_ft`` width.
        input_dim: stored but unused; input widths come from ``node_feat``.
        hidden_dim: common width after projection and inside every layer.
        output_dim: width of the final output.
        num_layers: number of AHAMPLayer blocks.
        num_heads: attention heads per layer.
        use_layer_norm: forwarded to every AHAMPLayer.
        act: activation applied to every per-modality encoding.
    """

    # Node-data keys encoded for every node type, in the order the encodings are stacked.
    _MODALITY_KEYS = ('node_ft', 'pos', 'depth', 'bound1', 'bound2', 'bound3', 'bound4')

    def __init__(self, node_dict, edge_dict, reverse_node_dict, node_feat, input_dim, hidden_dim, output_dim, num_layers, num_heads, use_layer_norm = True, act = F.gelu):
        super().__init__()
        self.node_dict = node_dict
        self.edge_dict = edge_dict
        self.layer = nn.ModuleList()
        self.input_dim = input_dim  # unused: per-type input widths come from node_feat
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.act = act

        self.projection = nn.ModuleList()

        self.pvect_pos = nn.ModuleList()
        self.pvect_depth = nn.ModuleList()
        self.pvect_bound1 = nn.ModuleList()
        self.pvect_bound2 = nn.ModuleList()
        self.pvect_bound3 = nn.ModuleList()
        self.pvect_bound4 = nn.ModuleList()

        # One encoder of each kind per node type, indexed by node_dict[ntype].
        for t in range(len(node_dict)):
            self.pvect_pos.append(PosVect(1, hidden_dim))
            self.pvect_depth.append(PosVect(1, hidden_dim))  # depth carries the hierarchy level
            self.pvect_bound1.append(PosVect(1, hidden_dim))
            self.pvect_bound2.append(PosVect(1, hidden_dim))
            self.pvect_bound3.append(PosVect(1, hidden_dim))
            self.pvect_bound4.append(PosVect(1, hidden_dim))

            # Project each node type's raw feature width down to the shared hidden_dim.
            in_dim = node_feat[reverse_node_dict[t]]
            self.projection.append(nn.Linear(in_dim, hidden_dim))

        self.intermodal_attention = InterModalAttention(in_size=hidden_dim)

        for _ in range(num_layers):
            # Inputs are already projected to hidden_dim, so every layer is hidden_dim -> hidden_dim.
            self.layer.append(AHAMPLayer(hidden_dim, hidden_dim, hidden_dim, node_dict, edge_dict, num_heads, use_layer_norm = use_layer_norm))

        self.out = nn.Linear(hidden_dim, output_dim)

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise every >1-D parameter, including those of all submodules.

        This also re-initialises the AHAMPLayer, PosVect and InterModalAttention
        weights created above.
        """
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, G, sel_node_type, feat_mask=None, edge_mask=None):
        """Encode and fuse every modality, run the layers, and score one node type.

        Args:
            G: heterograph holding the ``_MODALITY_KEYS`` node data for every node type.
            sel_node_type: node type whose representations are returned.
            feat_mask: optional dict ntype -> dict modality key -> mask logits. Each
                input is multiplied by ``sigmoid`` of its logits before encoding.
            edge_mask: optional per-relation edge-mask logits, forwarded to every layer.

        Returns:
            ``self.out`` applied to the final representations of ``sel_node_type``.
        """
        # Same order as _MODALITY_KEYS. Each modality uses its own encoder
        # (bound2..bound4 must not share pvect_bound1).
        encoders = (self.projection, self.pvect_pos, self.pvect_depth,
                    self.pvect_bound1, self.pvect_bound2, self.pvect_bound3, self.pvect_bound4)

        h = {}
        for ntype in G.ntypes:
            n_id = self.node_dict[ntype]
            encoded = []
            for key, encoder in zip(self._MODALITY_KEYS, encoders):
                x = G.nodes[ntype].data[key]
                if feat_mask is not None:
                    x = x * feat_mask[ntype][key].sigmoid()
                encoded.append(self.act(encoder[n_id](x)))
            # Stack along a new modality axis (dim=1) and let the inter-modal attention fuse it.
            # NOTE(review): the layers' residual requires the fused output to be (num_nodes, hidden_dim).
            h[ntype] = self.intermodal_attention(torch.stack(encoded, dim=1))

        for i in range(self.num_layers):
            h = self.layer[i](G, h, edge_mask)

        return self.out(h[sel_node_type])

def init_masks(graph, device):
    """Create learnable mask logits for node features and edges of ``graph``.

    Args:
        graph: heterograph whose node types carry the 2-D node data 'node_ft',
            'pos', 'depth' and 'bound1'..'bound4'.
        device: device on which the mask parameters are allocated.

    Returns:
        feat_mask: dict ntype -> dict feature name -> ``nn.Parameter`` of shape
            (1, feature width), drawn from N(0, 0.1**2).
        edge_mask: dict canonical etype -> ``nn.Parameter`` of shape (num_edges,),
            drawn from N(0, std**2) with std = gain('relu') * sqrt(2 / (2 * n)),
            where n is the node count of the relation subgraph. Relations without
            edges get no entry.
    """
    feature_names = ['node_ft', 'pos', 'depth', 'bound1', 'bound2', 'bound3', 'bound4']
    feat_std = 0.1

    feat_mask = {}
    for ntype in graph.ntypes:
        per_feature = {}
        for name in feature_names:
            _, width = graph.nodes[ntype].data[name].size()
            per_feature[name] = nn.Parameter(torch.randn(1, width, device=device) * feat_std)
        feat_mask[ntype] = per_feature

    relu_gain = nn.init.calculate_gain('relu')
    edge_mask = {}
    for canonical_etype in graph.canonical_etypes:
        rel_graph = graph[canonical_etype]
        n_edges = rel_graph.num_edges()
        if not n_edges > 0:
            continue
        std = relu_gain * math.sqrt(2.0 / (2 * rel_graph.num_nodes()))
        edge_mask[canonical_etype] = nn.Parameter(torch.randn(n_edges, device=device) * std)

    return feat_mask, edge_mask

In [None]:
# initialize model
n_epoch = 3000
input_dim = None  # not used by AHAMP: input widths come from node_feat
hidden_dim = 32
clip = 1.0  # max gradient norm for clip_grad_norm_
max_lr = 1e-3  # peak learning rate for OneCycleLR
eps = 1e-15  # keeps log() finite in the edge-mask entropy term
alpha1 = 0.005  # weight of the edge-mask sparsity (sum) penalty
alpha2 = 1.0  # weight of the edge-mask entropy penalty
beta1 = 1.0  # NOTE(review): not used anywhere in this notebook
beta2 = 0.1  # NOTE(review): not used anywhere in this notebook

# Choose the node type to classify, its labels, and the candidate node ids (`pid`)
# for the train/val/test split. COO row indices repeat once per stored edge, so they
# are de-duplicated with np.unique; otherwise the same node could land in train,
# val and test at once.
if task == 'screen_genre_class':
    selected_element = 'ui'
    labels = G.nodes['ui'].data['genre_label']
    pid = np.unique(u2c_adj.tocoo().row)
elif task == 'element_comp_class':
    selected_element = 'element'
    labels = G.nodes['element'].data['comp_label']
    pid = np.unique(e2e_adj.tocoo().row)
else:
    raise ValueError(f"Unknown task {task!r}; expected 'screen_genre_class' or 'element_comp_class'")
num_classes = len(labels.unique())

genre_labels_ = G.nodes['ui'].data['genre_label']
comp_labels_ = G.nodes['element'].data['comp_label']

In [None]:
feat_mask, edge_mask = init_masks(G, device)

# generate train/val/test split: 60% / 20% / 20% of the shuffled candidate ids in `pid`
train_end = int(0.6 * len(pid))
valid_end = int(0.8 * len(pid))
shuffle = np.random.permutation(pid)
train_idx = torch.tensor(shuffle[:train_end]).long()
val_idx = torch.tensor(shuffle[train_end:valid_end]).long()
test_idx = torch.tensor(shuffle[valid_end:]).long()

model = AHAMP(
    node_dict, edge_dict, reverse_node_dict, node_feat,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=num_classes,
    num_layers=2,
    num_heads=2,
    use_layer_norm=True,
).to(device)

# AdamW optimises the model weights together with the learnable edge-mask logits.
# feat_mask is created by init_masks but is not optimised here, nor passed to the
# model in the training loop.
params = [*model.parameters(), *edge_mask.values()]
optimizer = torch.optim.AdamW(params)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, total_steps=n_epoch, max_lr=max_lr)
G = G.to(device)

# Bookkeeping updated by the training loop.
best_val_acc = best_test_acc = 0
best_micro_f1 = best_macro_f1 = 0
train_step = 0
best_epoch = 0

In [None]:
# main training loop
# Full-graph training of the model plus edge-mask regularisation, with evaluation every 5 epochs.
# Model selection uses validation accuracy only. The reported test F1 scores are the ones
# measured at the epoch with the best validation accuracy, never the maximum over test epochs.
labels_dev = labels.to(device)  # hoisted: labels do not change between epochs
for epoch in tqdm(np.arange(n_epoch) + 1):
    model.train()
    logits = model(G, selected_element, edge_mask=edge_mask)
    # The loss is computed only for labeled nodes.
    loss = F.cross_entropy(logits[train_idx], labels_dev[train_idx])

    for canonical_etype in G.canonical_etypes:
        if G.num_edges(canonical_etype) > 0:
            mask_prob = edge_mask[canonical_etype].sigmoid()
            # Edge mask sparsity regularization
            loss = loss + alpha1 * torch.sum(mask_prob)
            # Edge mask entropy regularization (pushes each mask value towards 0 or 1)
            entropy = - mask_prob * torch.log(mask_prob + eps) - \
                (1 - mask_prob) * torch.log(1 - mask_prob + eps)
            loss = loss + alpha2 * entropy.mean()

    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    train_step += 1
    scheduler.step()

    if epoch % 5 == 0:
        model.eval()
        with torch.no_grad():  # evaluation needs no autograd graph
            logits = model(G, selected_element, edge_mask=edge_mask)
        pred = logits.argmax(1).cpu()
        train_acc = (pred[train_idx] == labels[train_idx]).float().mean()
        val_acc = (pred[val_idx] == labels[val_idx]).float().mean()
        test_acc = (pred[test_idx] == labels[test_idx]).float().mean()

        y_test = labels[test_idx].detach().cpu().numpy()
        test_micro_f1 = f1_score(y_test, pred[test_idx].numpy(), average='micro')
        test_macro_f1 = f1_score(y_test, pred[test_idx].numpy(), average='macro')

        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc
            best_micro_f1 = test_micro_f1
            best_macro_f1 = test_macro_f1
            best_epoch = epoch

print('='*100)
print(f'Test - Best - Micro F1: {best_micro_f1} | Macro F1: {best_macro_f1} | Best epoch: {best_epoch}')
print('='*100)