In [1]:
import json
import torch
import torch.nn.functional as F

import sys
import argparse
import logging
import os

from utils.log_helper import logger_init
from modeling.modeling_llm import LLMGNP
from utils.data_util import LLMGNP_DataLoader
from utils import utils

2024-01-27 16:08:23.098032: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
parser = argparse.ArgumentParser(description="LLMGNP")

### Dataset
parser.add_argument('--dataset', default='medqa', help='dataset name')
parser.add_argument('--data_dir', default='data', type=str, help='Path to the data directory')
parser.add_argument('--ent_emb_paths', default='umls/ent_emb_blbertL.npy', help='sources for entity embeddings')
parser.add_argument('--pad_node', default=-1, type=int)
#propmt
parser.add_argument('--context_node', default=True, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--context_pad', default=-2, type=int)
parser.add_argument('--add_edge_attr', default='no', choices=['no', 'solo_label', 'semantic'])
parser.add_argument('--use_char_options_format', default=False, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--prompt_context', default=True, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--prompt_Lanswer', default=False, type=utils.bool_flag, nargs='?', const=True)
#subgraph
parser.add_argument('--num_subgraphs', default=4, type=int)
parser.add_argument('--sub_graphs_choice', default='all', choices=['all', 'solo_question', 'solo_option'])
parser.add_argument('--num_nodes', default=200, type=int)

###Tune-P
parser.add_argument('--learning_rate', default=1e-4, type=float, help='learning rate')
parser.add_argument('--epochs', default=50, type=int, help='total number of training epochs to perform.')
parser.add_argument('--batch_size', default=8, type=int)
#Opt
parser.add_argument('--optim', default='adamw', type=str)
parser.add_argument('--lr_schedule', default='fixed', choices=['fixed', 'warmup_linear', 'warmup_constant'], help='learning rate scheduler')
parser.add_argument('--warmup_ratio', type=float, default=0.1)
parser.add_argument('--max_grad_norm', default=1.0, type=float, help='max grad norm (0 to disable)')
parser.add_argument('--weight_decay', default=1e-2, type=float, help='l2 weight decay strength')

#Task_weight
parser.add_argument('--llm_task', type=float, default=1.0, help='Task weight for the LLM')
parser.add_argument('--lp_task', type=float, default=0.1, help='Task weight for the LinkPred task')

#-------Model-------
###LLM Model
parser.add_argument('--model_name',  default='google/flan-t5-small', help='encoder type')
parser.add_argument('--max_seq_len', default=512, type=int)
parser.add_argument('--max_tag_len', default=128, type=int)
parser.add_argument('--llm_frozen', default=True, type=utils.bool_flag, nargs='?', const=True)
#LLM decoder modity
parser.add_argument('--xattn_heads', default=8, type=int)
parser.add_argument('--xattn_after', default=True, type=utils.bool_flag, nargs='?', const=True)
###GNP Model
parser.add_argument('--add_gnp', default=True, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--cross_gnp', default=False, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--cross_gnp_num', default=5, type=int, help='attn_heads of the GNN layers')
#initial_embed
parser.add_argument('--random_ent_emb', default=False, type=utils.bool_flag, nargs='?', const=True, help='Whether to use randomly initialized learnable entity embeddings or not.')
parser.add_argument('--freeze_ent_emb', default=True, type=utils.bool_flag, nargs='?', const=True, help='Whether to freeze the entity embedding layer.')
#FFN
parser.add_argument('--ffn_mult', default=4, type=int)
#Gnn
parser.add_argument('--gnn_layers', default=4, type=int, help='numbers of the GNN layers')
parser.add_argument('--gnn_dim', default=1024, type=int, help='dimension of the GNN layers')
parser.add_argument('--gnn_heads', default=2, type=int, help='attn_heads of the GNN layers')
parser.add_argument('--gnn_norm', default=False, type=utils.bool_flag, nargs='?', const=True, help='encoder type')
parser.add_argument('--gnn_residual', default='no', choices=['no', 'simple', 'linear'])
parser.add_argument('--cross_gnp_choice', default=1, type=int, help='attn_heads of the GNN layers')
#Self-Attention
parser.add_argument('--self_attn_layers', default=1, type=int, help='numbers of the self-attention layers')
parser.add_argument('--self_attn_heads', default=2, type=int, help='attn_heads of the self-attention layers')
#Cross-Attention
parser.add_argument('--cross_attn_layers', default=1, type=int, help='numbers of the CMP layers')
parser.add_argument('--cross_attn_heads', default=2, type=int, help='attn_heads of the GNN layers')

###Link Prediction
parser.add_argument('--is_lp', default=True, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--link_drop_probability', type=float, default=0.1, help='To specify #target positive triples for LinkPred')
parser.add_argument('--link_negative_sample_size', type=int, default=64, help='')
parser.add_argument('--link_negative_adversarial_sampling', type=utils.bool_flag, default=True, help='')
parser.add_argument('--link_negative_adversarial_sampling_temperature', type=float, default=1, help='')
parser.add_argument('--link_regularizer_weight', type=float, default=0.01, help='')
parser.add_argument('--scaled_distmult', type=utils.bool_flag, default=False, help='')
#-------Model end-------

#Activation
parser.add_argument('--activation', default='gelu', type=str)
parser.add_argument('--pre_norm', default=True, type=utils.bool_flag, nargs='?', const=True)
parser.add_argument('--init_range', default=0.02, type=float, help='stddev when initializing with normal distribution')
parser.add_argument('--dropout_emb', type=float, default=0.2, help='dropout for GNN layers')
parser.add_argument('--dropout_ffn', type=float, default=0.2, help='dropout for GNN layers')
parser.add_argument('--dropout_gnn', type=float, default=0.2, help='dropout for GNN layers')
parser.add_argument('--dropout_self_attn', type=float, default=0.1, help='dropout for GNN layers')
parser.add_argument('--dropout_cross_attn', type=float, default=0.1, help='dropout for GNN layers')

#save and log
parser.add_argument('--save_dir', default=f'./saved_models/', help='model output directory')
parser.add_argument('--save_model', default=1, type=float, help="0: do not save model checkpoints. 1: save if best dev. 2: save always")
parser.add_argument('--load_model_path', default=None, help="The model checkpoint to load in the evaluation mode.")

#Test-Debug use
parser.add_argument('--log_interval', default=10, type=int)
parser.add_argument('--seed', default=1, type=int, help='random seed')
parser.add_argument("--resume_checkpoint", default=None, type=str,
                    help="The checkpoint to resume training from.")
parser.add_argument('--mode', default='train', choices=['train', 'test'], help='run training or evaluation')

args=parser.parse_args([])


In [3]:
from train import ModelConfig
import numpy as np

args.dataset = 'medqa'
args.sub_graphs_choice = 'all'

args.mode = 'test'
args.model_name="google/flan-t5-base"
args.llm_frozen=True

args.add_gnp=True
args.cross_gnp=False

args.load_model_path = None
args.context_node = False

args.cross_gnp_choice = 1
args.gnn_residual = 'linear'
args.add_edge_attr = 'no'


 
config = ModelConfig(vars(args))

def get_device_and_set_seed(seed):
    """ Set all seeds to make results reproducible """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    np.random.seed(seed)
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda:0" if use_cuda else "cpu")
    return device

config.device = get_device_and_set_seed(0)

| num_concepts: 297927 |
[2024-01-27 16:08:49] - INFO:  ### Right now config 
[2024-01-27 16:08:49] - INFO: ### dataset = medqa
[2024-01-27 16:08:49] - INFO: ### data_dir = data
[2024-01-27 16:08:49] - INFO: ### ent_emb_paths = umls/ent_emb_blbertL.npy
[2024-01-27 16:08:49] - INFO: ### pad_node = -1
[2024-01-27 16:08:49] - INFO: ### context_node = False
[2024-01-27 16:08:49] - INFO: ### context_pad = -2
[2024-01-27 16:08:49] - INFO: ### add_edge_attr = no
[2024-01-27 16:08:49] - INFO: ### use_char_options_format = False
[2024-01-27 16:08:49] - INFO: ### prompt_context = True
[2024-01-27 16:08:49] - INFO: ### prompt_Lanswer = False
[2024-01-27 16:08:49] - INFO: ### num_subgraphs = 5
[2024-01-27 16:08:49] - INFO: ### sub_graphs_choice = all
[2024-01-27 16:08:49] - INFO: ### num_nodes = 200
[2024-01-27 16:08:49] - INFO: ### learning_rate = 0.0001
[2024-01-27 16:08:49] - INFO: ### epochs = 50
[2024-01-27 16:08:49] - INFO: ### batch_size = 8
[2024-01-27 16:08:49] - INFO: ### optim = adamw
[

In [4]:
#model
model = LLMGNP(config)
tokenizer = model.llm_tokenizer

Initializing LLM...
LLM Initialized
Initializing GNP module...
 ### Initializing embedding for CustomizedEmbedding...
 ### CustomizedEmbedding Initialized
 ### Initializing w_relation for DistMultDecoder...
 ### DistMultDecoder Initialized
GNP Initialized


In [6]:
from torchinfo import summary

summary(model)

Layer (type:depth-idx)                                                 Param #
LLMGNP                                                                 --
├─T5ForConditionalGeneration: 1-1                                      --
│    └─ModuleDict: 2-1                                                 --
│    └─Embedding: 2-2                                                  (24,674,304)
│    └─T5Stack: 2-3                                                    24,674,304
│    │    └─ModuleDict: 3-1                                            --
│    │    └─Embedding: 3-2                                             (recursive)
│    │    └─ModuleList: 3-3                                            (84,953,472)
│    │    └─T5LayerNorm: 3-4                                           (768)
│    │    └─Dropout: 3-5                                               --
│    └─T5Stack: 2-4                                                    24,674,304
│    │    └─ModuleDict: 3-6                                

In [None]:
#data
dataset = LLMGNP_DataLoader(config)

In [42]:
#subgraph
data_g = []
with open('./data/medqa/subgraphed/train.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        json_obj = json.loads(line)
        data_g.append(json_obj)

In [43]:
from itertools import chain

for data in data_g:
    data[:] = list(chain(*data))

In [44]:
res = {'nodes':[], 'edges': [], 'edge_types':[]}

for data in data_g:
    for each in data:
        if not each['edges']:
            each.update(res)

In [45]:
with open('train.jsonl', 'w') as fout:
    for dic in data_g:
        fout.write(json.dumps(dic) + '\n')

In [13]:
#debug each batch
batch = dataset.train_dataset[0]
lminputs, lmlabels, gnndata = batch

In [6]:
#LLM test
labels = lmlabels['input_ids']
labels[labels == model.llm_tokenizer.pad_token_id] = -100
labels_masks = lmlabels['attention_mask']

output = model.llm_model(input_ids = lminputs['input_ids'],
                                    attention_mask = lminputs['attention_mask'],
                                    labels =labels,
                                    decoder_attention_mask=labels_masks)
preds = F.softmax(output.logits, dim=-1).argmax(dim=-1)
print(tokenizer.batch_decode(preds, skip_special_tokens=True))
print(output.loss)

['yes', 'yes', 'no', 'no', 'yes', 'yes', 'no', 'yes']
tensor(1.3692)


In [6]:
#Model test
logits, lm_loss, lp_loss = model(batch, 'train')
preds = F.softmax(logits, dim=-1).argmax(dim=-1)
print(tokenizer.batch_decode(preds, skip_special_tokens=True))
print(lm_loss)
print(lp_loss)

['no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes']
tensor(1.4679, grad_fn=<NllLossBackward0>)
(tensor(0.6969, grad_fn=<AddBackward0>), tensor(0.6934, grad_fn=<NegBackward0>), tensor(0.7004, grad_fn=<NegBackward0>))


In [2]:
for i in range(5):
    # print(i)

    if i != 5 - 1 :
        print(i)

0
1
2
3


### data debug

In [82]:
data_g = []
with open('/Users/soulofshadow/Downloads/UMLS/data/bioasq/subgraphed/test.jsonl', 'r', encoding='utf-8') as file:
    for line in file:
        json_obj = json.loads(line)
        data_g.append(json_obj)

In [None]:
E = len(H.edge_index[0])
positions = torch.arange(E)

In [92]:
sample = None
for i in range(len(data_g)):
    one_sample_graphs = []
    for j in range(len(data_g[0])):
        sample = data_g[i][j]
    break

In [None]:
sample

In [145]:
nodes = sample['nodes'][:200 - 1]
edge_list = []
edge_attr = []

In [147]:
for node in nodes[1:]:
    edge_list.append([nodes[0], node])
    edge_attr.append(-1)
for index, [a, b] in enumerate(sample['edges']):
    if a in nodes and b in nodes:
        edge_list.append([a, b])
        edge_attr.append(sample['edge_types'][index])


In [148]:
source_nodes = [s for s, t in edge_list]
target_nodes = [t for s, t in edge_list]
node_id_to_index = {node_id: idx for idx, node_id in enumerate(nodes)}
edge_index = torch.tensor([
    [node_id_to_index[s] for s in source_nodes],
    [node_id_to_index[t] for t in target_nodes]
], dtype=torch.long)

In [149]:
nodes = torch.tensor(nodes, dtype=torch.long)
if nodes.size(0) < 200:
    pad_length = 200 - nodes.size(0)
    nodes = torch.cat([nodes, torch.tensor([-1] * pad_length, dtype=torch.long)], dim=0)


In [152]:
edge_index.size()

torch.Size([2, 75])

In [150]:
edge_attr = [x-98 if x > (98 - 1) else x for x in edge_attr]
edge_attr = [new_dict[x] for x in edge_attr]
edge_attr = tokenizer(edge_attr, return_tensors='pt', padding=True, add_special_tokens=False)

In [151]:
edge_attr['input_ids'].size()

torch.Size([75, 12])

### GNP debug

In [13]:
T, all_hidden_states = model.get_embedding(lminputs, gnndata)
T_mask = lminputs['attention_mask']

In [14]:
batch_num = T.size(0) * config.num_subgraphs #batch_size * num_subgraphs
        
#init H + GNN
H = model.GNP.batch_choice_of_graph(gnndata)
H.x = model.GNP.ent_embed_init(H.x)
H.x = model.GNP.dropout_emb(H.x)
H_mask = H.mask
H_mask = H_mask.view(batch_num, config.num_nodes)

In [19]:
E = len(H.edge_index[0])
positions = torch.arange(E)

In [45]:
effective_nodes = set()
for item in H.edge_index.flatten():
    effective_nodes.add(item.item()) 
effective_nodes = list(effective_nodes)

In [None]:
[[effective_nodes[idx] for idx in row] for row in neg_nodes_indices]

In [46]:
neg_nodes_indices = torch.randint(0, len(effective_nodes), (10, 10)) 
print(neg_nodes_indices)

tensor([[ 60, 287, 101, 422, 651, 308, 116, 293, 159, 213],
        [ 85, 151, 192, 402, 715, 733, 753, 532, 292, 691],
        [262, 153, 720, 868, 902, 610, 496, 592, 461, 535],
        [ 58, 917, 650, 172, 862, 467, 910, 879, 545, 746],
        [409, 894,   9, 517, 310, 318, 175, 775, 313, 687],
        [283, 184, 506, 325, 801, 263, 801, 547, 216,  28],
        [274, 802, 280, 745, 932,  15, 189, 867, 785, 144],
        [730, 639, 260, 413, 517, 706, 847, 255, 307, 356],
        [890, 693, 663, 789, 918, 746, 133,  84, 507, 902],
        [471, 457, 275, 634, 376, 327, 378, 848, 265, 627]])


In [47]:
[[effective_nodes[idx] for idx in row] for row in neg_nodes_indices]

[[60, 287, 101, 485, 810, 308, 116, 293, 159, 213],
 [85, 151, 192, 465, 874, 1005, 1025, 595, 292, 850],
 [262, 153, 879, 1140, 1225, 673, 559, 655, 524, 598],
 [58, 1240, 809, 172, 1134, 530, 1233, 1202, 608, 1018],
 [472, 1217, 9, 580, 310, 318, 175, 1047, 313, 846],
 [283, 184, 569, 325, 1073, 263, 1073, 610, 216, 28],
 [274, 1074, 280, 1017, 1411, 15, 189, 1139, 1057, 144],
 [1002, 702, 260, 476, 580, 865, 1119, 255, 307, 419],
 [1213, 852, 822, 1061, 1241, 1018, 133, 84, 570, 1225],
 [534, 520, 275, 697, 439, 327, 441, 1120, 265, 690]]

In [37]:
H.edge_index

tensor([[   0,    0,    0,  ..., 1412, 1412, 1414],
        [   1,    2,    3,  ..., 1414, 1418, 1417]])

In [44]:
H.x.size()

torch.Size([1600, 1024])

In [42]:
H.x[14]

tensor([0., 0., 0.,  ..., 0., 0., 0.])

In [27]:
is_special = H.edge_lp[H.edge_lp == config.num_relations]

In [None]:
is_special

In [24]:
is_special = H.edge_lp == config.num_relations

In [17]:
positions = positions[H.edge_lp == config.num_relations]

1419

In [31]:
nodes_input = H.x.view(batch_num, config.num_nodes, -1)

for gnn_block, lm_hidden_states in zip(model.GNP.gnn_blocks, all_hidden_states[-len(model.GNP.gnn_blocks):]):
    H.x = gnn_block(H, lm_hidden_states, T_mask)

In [33]:
nodes_after_cross = H.x.view(batch_num, config.num_nodes, -1)
nodes_after_cross = model.GNP.activation_gat(nodes_after_cross)
nodes_after_cross = model.GNP.activation_residual(model.GNP.Vh(nodes_input) + model.GNP.Vx(nodes_after_cross))
#
H_node_final = nodes_after_cross

In [27]:
i = 0 #[0,1,2,3,4]
gnn_block = model.GNP.gnn_blocks[i]
lm_hidden_states = all_hidden_states[-(5-i)]

origin_size = H.x.size(0)

In [26]:
H.x

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  1.1256,  0.3548,  ..., -0.8073, -1.0626,  1.2089],
        [-0.0791,  0.0000, -0.0707,  ..., -0.0000, -0.0000,  1.1810],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]])

In [28]:
TT = torch.repeat_interleave(lm_hidden_states, repeats=config.num_subgraphs, dim=0)
TT_mask = torch.repeat_interleave(T_mask, repeats=config.num_subgraphs, dim=0)

HH = gnn_block.GNN(H)
HH = HH.view(TT.size(0), config.num_nodes, -1)
print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))


tensor([ 0.0170,  0.0224,  0.0295,  ...,  0.0125, -0.0368,  0.0369],
       grad_fn=<SelectBackward0>)
tensor([0.6052, 0.5988, 0.8457,  ..., 0.5642, 1.1095, 0.1331],
       grad_fn=<SelectBackward0>)
tensor(2.9801, grad_fn=<MaxBackward1>)


In [29]:
context_node_gnn_feats = HH[:, 0, :].clone() # [bs, node_dim]
context_node_gnn_feats = gnn_block.norm(context_node_gnn_feats)
context_node_lm_feats = gnn_block.cross_attn(context_node_gnn_feats.unsqueeze(1), 
                                        TT, 
                                        TT_mask)
context_node_lm_feats = context_node_lm_feats.squeeze(1) # [bs, node_dim]
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
context_node_feats = gnn_block.merge_ffn(context_node_feats)

# residual link
context_node_feats = context_node_feats + context_node_gnn_feats
HH[:, 0, :] = context_node_feats

print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

tensor([  67.6272, -126.7188,   61.7799,  ...,  424.8896, -303.9418,
         573.8397], grad_fn=<SelectBackward0>)
tensor([0.6052, 0.5988, 0.8457,  ..., 0.5642, 1.1095, 0.1331],
       grad_fn=<SelectBackward0>)
tensor(731.6631, grad_fn=<MaxBackward1>)


In [30]:
H.x = HH.view(origin_size, -1)

In [31]:
i = 1 #[0,1,2,3,4]
gnn_block = model.GNP.gnn_blocks[i]
lm_hidden_states = all_hidden_states[-(5-i)]

origin_size = H.x.size(0)

In [32]:
TT = torch.repeat_interleave(lm_hidden_states, repeats=config.num_subgraphs, dim=0)
TT_mask = torch.repeat_interleave(T_mask, repeats=config.num_subgraphs, dim=0)

HH = gnn_block.GNN(H)
HH = HH.view(TT.size(0), config.num_nodes, -1)
print(HH[0][0])
print(HH[0][2])
print(torch.max(HH))

tensor([-328.4462, -144.5006, -417.6876,  ..., -333.5429,  113.1736,
        -143.5552], grad_fn=<SelectBackward0>)
tensor([-0.1583, -0.1150, -0.0586,  ..., -0.2677,  0.2841, -0.4118],
       grad_fn=<SelectBackward0>)
tensor(928.3207, grad_fn=<MaxBackward1>)


In [17]:
print(HH[0][2])

tensor([ 0.0454, -0.0207, -0.1212,  ..., -0.2516, -0.7245,  1.0667],
       grad_fn=<SelectBackward0>)


In [21]:
H.edge_index

tensor([[   0,    0,    0,  ..., 1419, 1419, 1419],
        [   1,    2,    3,  ..., 1420, 1421, 1426]])

In [77]:
context_node_gnn_feats = HH[:, 0, :].clone() # [bs, node_dim]
context_node_gnn_feats = gnn_block.norm(context_node_gnn_feats)
context_node_lm_feats = gnn_block.cross_attn(context_node_gnn_feats.unsqueeze(1), 
                                        TT, 
                                        TT_mask)
context_node_lm_feats = context_node_lm_feats.squeeze(1) # [bs, node_dim]
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
context_node_feats = gnn_block.merge_ffn(context_node_feats)

# residual link
context_node_feats = context_node_feats + context_node_gnn_feats
HH[:, 0, :] = context_node_feats

print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

tensor([  4.5510,   8.2308, -18.0366,  ...,  -4.0156,  -2.6247,  12.5429],
       grad_fn=<SelectBackward0>)
tensor([-0.3618, -0.2496,  0.4299,  ..., -0.4742, -0.0483, -0.2053],
       grad_fn=<SelectBackward0>)
tensor(863.4166, grad_fn=<MaxBackward1>)


In [78]:
H.x = HH.view(origin_size, -1)

In [79]:
i = 3 #[0,1,2,3,4]
gnn_block = model.GNP.gnn_blocks[i]
lm_hidden_states = all_hidden_states[-(5-i)]

origin_size = H.x.size(0)

TT = torch.repeat_interleave(lm_hidden_states, repeats=config.num_subgraphs, dim=0)
TT_mask = torch.repeat_interleave(T_mask, repeats=config.num_subgraphs, dim=0)

HH = gnn_block.GNN(H)
HH = HH.view(TT.size(0), config.num_nodes, -1)
print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

context_node_gnn_feats = HH[:, 0, :].clone() # [bs, node_dim]
context_node_gnn_feats = gnn_block.norm(context_node_gnn_feats)
context_node_lm_feats = gnn_block.cross_attn(context_node_gnn_feats.unsqueeze(1), 
                                        TT, 
                                        TT_mask)
context_node_lm_feats = context_node_lm_feats.squeeze(1) # [bs, node_dim]
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
context_node_feats = gnn_block.merge_ffn(context_node_feats)

# residual link
context_node_feats = context_node_feats + context_node_gnn_feats
HH[:, 0, :] = context_node_feats

print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

H.x = HH.view(origin_size, -1)

tensor([-14.5090, -10.2454,  -5.1613,  ...,  35.2490, -18.4673,   9.3669],
       grad_fn=<SelectBackward0>)
tensor([ 0.0727, -0.2229,  0.0707,  ...,  0.0817, -0.4731,  0.5370],
       grad_fn=<SelectBackward0>)
tensor(1096.0328, grad_fn=<MaxBackward1>)
tensor([-584.8746,  217.7348,  771.8218,  ...,  262.2571, -127.7205,
         456.6053], grad_fn=<SelectBackward0>)
tensor([ 0.0727, -0.2229,  0.0707,  ...,  0.0817, -0.4731,  0.5370],
       grad_fn=<SelectBackward0>)
tensor(1702.2756, grad_fn=<MaxBackward1>)


In [80]:
i = 3 #[0,1,2,3,4]
gnn_block = model.GNP.gnn_blocks[i]
lm_hidden_states = all_hidden_states[-(5-i)]

origin_size = H.x.size(0)

TT = torch.repeat_interleave(lm_hidden_states, repeats=config.num_subgraphs, dim=0)
TT_mask = torch.repeat_interleave(T_mask, repeats=config.num_subgraphs, dim=0)

HH = gnn_block.GNN(H)
HH = HH.view(TT.size(0), config.num_nodes, -1)
print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

context_node_gnn_feats = HH[:, 0, :].clone() # [bs, node_dim]
context_node_gnn_feats = gnn_block.norm(context_node_gnn_feats)
context_node_lm_feats = gnn_block.cross_attn(context_node_gnn_feats.unsqueeze(1), 
                                        TT, 
                                        TT_mask)
context_node_lm_feats = context_node_lm_feats.squeeze(1) # [bs, node_dim]
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
context_node_feats = gnn_block.merge_ffn(context_node_feats)

# residual link
context_node_feats = context_node_feats + context_node_gnn_feats
HH[:, 0, :] = context_node_feats

print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

H.x = HH.view(origin_size, -1)

tensor([ -472.9886,  1065.9122,    -2.3353,  ...,  1392.9152,   614.6643,
        -1501.3160], grad_fn=<SelectBackward0>)
tensor([ 0.4715, -1.1434,  1.1304,  ..., -2.5475,  0.1591, -0.5176],
       grad_fn=<SelectBackward0>)
tensor(2666.8748, grad_fn=<MaxBackward1>)
tensor([ 383.9848,  640.9866,  133.9565,  ..., -633.9208, -509.8645,
           3.7313], grad_fn=<SelectBackward0>)
tensor([ 0.4715, -1.1434,  1.1304,  ..., -2.5475,  0.1591, -0.5176],
       grad_fn=<SelectBackward0>)
tensor(1685.3871, grad_fn=<MaxBackward1>)


In [None]:
i = 4 #[0,1,2,3,4]
gnn_block = model.GNP.gnn_blocks[i]
lm_hidden_states = all_hidden_states[-(5-i)]

origin_size = H.x.size(0)

TT = torch.repeat_interleave(lm_hidden_states, repeats=config.num_subgraphs, dim=0)
TT_mask = torch.repeat_interleave(T_mask, repeats=config.num_subgraphs, dim=0)

HH = gnn_block.GNN(H)
HH = HH.view(TT.size(0), config.num_nodes, -1)
print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

context_node_gnn_feats = HH[:, 0, :].clone() # [bs, node_dim]
context_node_gnn_feats = gnn_block.norm(context_node_gnn_feats)
context_node_lm_feats = gnn_block.cross_attn(context_node_gnn_feats.unsqueeze(1), 
                                        TT, 
                                        TT_mask)
context_node_lm_feats = context_node_lm_feats.squeeze(1) # [bs, node_dim]
context_node_feats = torch.cat([context_node_lm_feats, context_node_gnn_feats], dim=1)
context_node_feats = gnn_block.merge_ffn(context_node_feats)

# residual link
context_node_feats = context_node_feats + context_node_gnn_feats
HH[:, 0, :] = context_node_feats

print(HH[0][0])
print(HH[0][1])
print(torch.max(HH))

H.x = HH.view(origin_size, -1)

: 