In [1]:
from kgat import load_config_sg
import json

# Read the default model configuration. JSON text is UTF-8 by specification,
# so the encoding is pinned instead of being left to the platform's locale.
with open("./config/model/default.json", "r", encoding="utf-8") as fp:
    config = json.load(fp)

# Build the model described by the config; no causal LM is handed in here
# (clm=None).
model = load_config_sg(config, clm=None)

In [2]:
# Show the module tree of the model built from the default config.
model

GraphModule(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (graphpooler): GATAggregateGraphPooler(
    (conv1): GATConv(768, 1024, heads=3)
    (conv2): GATConv(1024, 1024, heads=3)
    (pooler): SAGPooling(GATConv, 1024, ratio=0.99999999999999

# Subgraph Generation

## Load Dataset

In [3]:
from kgat.data import load_id2map, SubgraphGenerationDataset

In [4]:
# Training split of the processed subgraph-generation data.
sg_data_path = "./data/subgraph-gen/atomic/proc/train.json"

# Entity vocabulary file and the map loaded from it.
sg_entity_path = "./data/subgraph-gen/atomic/proc/entities.txt"
sg_id2entity = load_id2map(sg_entity_path)

# Relation vocabulary file and the map loaded from it.
sg_relations_path = "./data/subgraph-gen/atomic/proc/relations.txt"
sg_id2rel = load_id2map(sg_relations_path)

In [5]:
# Build the dataset from the training file and the two maps loaded above.
sg_ds = SubgraphGenerationDataset(
    sg_data_path,
    sg_id2entity,
    sg_id2rel,
)

## Prepare Model

In [6]:
from kgat.model.graph import GATAggregateGraphPooler, SubgraphPooler
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint shared by the tokenizer and the causal language model.
model_name_or_path = "openai-community/gpt2"

# Pad on the left, so shorter sequences are filled at the front.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    padding_side="left",
)
clm = AutoModelForCausalLM.from_pretrained(model_name_or_path)

In [7]:
# Embedding width of the GPT-2 transformer; later cells pass this value as the
# poolers' in_channels, edge_dim and text_emb_dim.
clm.transformer.embed_dim

768

In [8]:
from torch.utils.data import DataLoader
from kgat.data import SubgraphGenerationCollator

# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

n_process = 3

# The collator is given the tokenizer and n_process.
sg_collator = SubgraphGenerationCollator(
    tokenizer=tokenizer,
    n_process=n_process,
)

# Five examples per batch, served in dataset order (no shuffling).
sg_dataloader = DataLoader(
    sg_ds,
    batch_size=5,
    shuffle=False,
    collate_fn=sg_collator,
)

In [9]:
# Pull only the first collated batch from the loader; the cells below inspect it.
batch = next(iter(sg_dataloader))

In [8]:
# Graph pooler: node features (in_channels) and edge features (edge_dim) both
# use the embedding width of the language model.
graphpooler = GATAggregateGraphPooler(
    in_channels=clm.transformer.embed_dim,
    hidden_channels=1024,
    out_channels=1024,
    heads=3,
    edge_dim=clm.transformer.embed_dim,
    n_mlp_layers=5,
)

# Subgraph pooler: graph_emb_dim equals the graph pooler's out_channels (1024),
# text_emb_dim is the language model's embedding width.
subgraphpooler = SubgraphPooler(
    graph_emb_dim=1024,
    text_emb_dim=clm.transformer.embed_dim,
    hidden_dim=1024,
)

In [9]:
from kgat.model.graph import GraphModule

# Use the EOS id as the padding id on both the transformer config and the
# tokenizer.
clm.transformer.config.pad_token_id = clm.transformer.config.eos_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id

# Wire the transformer backbone and both poolers into one module; the causal
# LM's prepare_inputs_for_generation is handed over as prepare_inputs_method.
graph_module = GraphModule(
    transformer=clm.transformer,
    graphpooler=graphpooler,
    subgraphpooler=subgraphpooler,
    prepare_inputs_method=clm.prepare_inputs_for_generation,
)

## Forward Pass

In [10]:
# Token ids of the graph queries; in the saved output, id 50256 fills the left
# side of the shorter rows.
batch["graph_query_input_ids"]

tensor([[15439,    55,  4113,  7755,    56,   338, 46444,   379,   262, 18264],
        [50256, 50256, 50256, 15439,    55,  2058,   284,   651,  7755,    56],
        [50256, 50256, 50256, 15439,    55,  7584, 46444,   287,   262, 30967],
        [50256, 50256, 50256, 50256, 50256, 15439,    55,   318,  1972,  4697],
        [50256, 50256, 50256, 50256, 15439,    55,  4952,  7755,    56,  1995]])

In [11]:
"""graph_query_input_ids, graph_query_attention_mask,
entities_input_ids, entities_attention_mask,
relations_input_ids, relations_attention_mask,
x_coo, batch"""

'graph_query_input_ids, graph_query_attention_mask,\nentities_input_ids, entities_attention_mask,\nrelations_input_ids, relations_attention_mask,\nx_coo, batch'

In [12]:
# NOTE(review): disabled scratch code. It indexes `batch` by position, while the
# live cells in this notebook access `batch` by key — TODO confirm it is obsolete.
# graph_query = tokenizer(batch[0], padding=True, truncation=True, max_length=64, return_tensors="pt")
# entities = tokenizer(batch[1], padding=True, truncation=True, max_length=16, return_tensors="pt")
# relations = tokenizer(batch[2], padding=True, truncation=True, max_length=16, return_tensors="pt")
# x_coo = batch[3]
# node_batch = batch[4]
# y_coo = batch[5]

In [13]:
# node_batch[x_coo[0]].shape

In [14]:
# node_batch[x_coo[2]].shape

In [15]:
# x_coo[:,-5:]

In [16]:
# node_batch

In [17]:
# Read the labels by key instead of popping them, so `batch` is left intact and
# this cell can be re-run without raising a KeyError.
y_coo_cls = batch["y_coo_cls"]

# Forward every other entry as a keyword argument; the labels are not passed to
# the module.
model_inputs = {key: batch[key] for key in batch.keys() if key != "y_coo_cls"}
mean_fused_score, subgraph_emb, edge_batch = graph_module(**model_inputs)

In [18]:
# Inspect the scores returned by the forward pass.
# NOTE(review): every value visible in the saved output is 0. — confirm whether
# that is expected at this stage.
mean_fused_score

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 

In [19]:
# Labels taken from the batch entry "y_coo_cls"; the saved output shows 0/1 values.
y_coo_cls

tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1,

In [20]:
# Index tensor stored under the key "batch"; presumably the graph id of each
# node (the saved output runs 0, 1, 2, 3, ...) — TODO confirm.
batch["batch"]

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [21]:
# Show the query token ids again after the forward pass; the saved output is
# identical to the earlier display of the same tensor.
batch["graph_query_input_ids"]

tensor([[15439,    55,  4113,  7755,    56,   338, 46444,   379,   262, 18264],
        [50256, 50256, 50256, 15439,    55,  2058,   284,   651,  7755,    56],
        [50256, 50256, 50256, 15439,    55,  7584, 46444,   287,   262, 30967],
        [50256, 50256, 50256, 50256, 50256, 15439,    55,   318,  1972,  4697],
        [50256, 50256, 50256, 50256, 15439,    55,  4952,  7755,    56,  1995]])

In [23]:
# Shape of the returned subgraph embeddings: (5, 3, 1024) in the saved output.
# 5 equals the DataLoader batch_size and 1024 the hidden_dim given to
# SubgraphPooler; the 3 presumably corresponds to n_process — TODO confirm.
subgraph_emb.shape

torch.Size([5, 3, 1024])