# Subgraph Generation

## Load Dataset

In [1]:
from kgat.data import load_id2map, SubgraphGenerationDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Entity vocabulary file for the ATOMIC subgraph-generation data, loaded via load_id2map.
sg_entity_path = "./data/subgraph-gen/atomic/proc/entities.txt"
sg_id2entity = load_id2map(sg_entity_path)

# Relation vocabulary file, loaded the same way.
sg_relations_path = "./data/subgraph-gen/atomic/proc/relations.txt"
sg_id2rel = load_id2map(sg_relations_path)

# Training split; handed to SubgraphGenerationDataset in the next cell.
sg_data_path = "./data/subgraph-gen/atomic/proc/train.json"

In [3]:
# Build the subgraph-generation dataset from the training file and the two id maps loaded above.
sg_ds = SubgraphGenerationDataset(sg_data_path, sg_id2entity, sg_id2rel)

In [4]:
from torch.utils.data import DataLoader
from kgat.data import subgraphgen_collate_fn

# Batches of 5 examples in dataset order (no shuffling), assembled by the project's collate function.
sg_dataloader = DataLoader(
    sg_ds,
    batch_size=5,
    shuffle=False,
    collate_fn=subgraphgen_collate_fn,
)

In [5]:
# Pull only the first batch; its fields are unpacked by the forward-pass cells further down.
batch = next(iter(sg_dataloader))

## Prepare Model

In [6]:
from kgat.model.graph import GATAggregateGraphPooler, SubgraphPooler
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint identifier shared by the causal LM and its tokenizer.
model_name_or_path = "openai-community/gpt2"

# The tokenizer is created with left-side padding; the model is loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left")
clm = AutoModelForCausalLM.from_pretrained(model_name_or_path)

In [7]:
# Hidden size of the loaded backbone; passed below as in_channels, edge_dim and text_emb_dim.
clm.transformer.embed_dim

768

In [8]:
# Graph pooler: node (in_channels) and edge (edge_dim) feature sizes follow the LM hidden size.
graphpooler = GATAggregateGraphPooler(
    in_channels=clm.transformer.embed_dim,
    edge_dim=clm.transformer.embed_dim,
    hidden_channels=1024,
    out_channels=1024,
    heads=3,
    n_mlp_layers=5,
)

# Subgraph pooler: graph_emb_dim is set to the same 1024 used for the graph pooler's out_channels.
subgraphpooler = SubgraphPooler(
    text_emb_dim=clm.transformer.embed_dim,
    graph_emb_dim=1024,
    hidden_dim=1024,
)

In [10]:
from kgat.model.graph import GraphModule

# Use the EOS token id as the padding id on both the backbone config and the tokenizer.
clm.transformer.config.pad_token_id = clm.transformer.config.eos_token_id
tokenizer.pad_token_id = tokenizer.eos_token_id

# Wire the backbone and both poolers into the graph module.
graph_module = GraphModule(
    transformer=clm.transformer,
    graphpooler=graphpooler,
    subgraphpooler=subgraphpooler,
    prepare_inputs_method=clm.prepare_inputs_for_generation,
)

## Forward Pass

In [11]:
# Reminder of the keyword arguments passed to graph_module in the forward pass below.
# Being a bare string expression, it is echoed as this cell's output.
"""graph_query_input_ids, graph_query_attention_mask,
entities_input_ids, entities_attention_mask,
relations_input_ids, relations_attention_mask,
x_coo, batch"""

'graph_query_input_ids, graph_query_attention_mask,\nentities_input_ids, entities_attention_mask,\nrelations_input_ids, relations_attention_mask,\nx_coo, batch'

In [12]:
# Batch fields that are used as-is (displayed as tensors in the cells below).
x_coo = batch[3]
node_batch = batch[4]
y_coo = batch[5]

# Text fields are tokenized with padding and truncation: 64 tokens for the query, 16 for names.
graph_query = tokenizer(batch[0], return_tensors="pt", padding=True, truncation=True, max_length=64)
entities = tokenizer(batch[1], return_tensors="pt", padding=True, truncation=True, max_length=16)
relations = tokenizer(batch[2], return_tensors="pt", padding=True, truncation=True, max_length=16)

In [13]:
# Index node_batch with the first row of x_coo and check the resulting size.
# NOTE(review): row 0 is presumably the head-node index of each edge — confirm.
node_batch[x_coo[0]].shape

torch.Size([467])

In [14]:
# Same lookup using the third row of x_coo.
# NOTE(review): row 2 is presumably the tail-node index of each edge — confirm.
node_batch[x_coo[2]].shape

torch.Size([467])

In [15]:
# Last five columns of x_coo (the recorded output shows three rows).
x_coo[:,-5:]

tensor([[442, 391, 443, 391, 372],
        [ 35,  43,  35,  37,  37],
        [393, 389, 412, 418, 449]])

In [16]:
# Display node_batch in full.
# NOTE(review): the values look like a per-node graph index within the batch (0, 1, 2, ...) — confirm.
node_batch

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3,

In [17]:
# Forward pass of the graph module on the first subgraph-generation batch.
mean_fused_score, subgraph_emb, edge_batch = graph_module(
    graph_query_input_ids=graph_query["input_ids"],
    graph_query_attention_mask=graph_query["attention_mask"],
    entities_input_ids=entities["input_ids"],
    entities_attention_mask=entities["attention_mask"],
    relations_input_ids=relations["input_ids"],
    relations_attention_mask=relations["attention_mask"],
    x_coo=x_coo,
    batch=node_batch,
)

In [18]:
# First ten scores returned by the graph module.
# NOTE(review): the recorded output is exactly 0.5000 for all ten — confirm this is expected
# for freshly initialised poolers rather than a sign that the inputs are being ignored.
mean_fused_score[:10]

tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
        0.5000], grad_fn=<SliceBackward0>)

In [19]:
# Display y_coo (batch[5]); the recorded output is a tensor of 0/1 values.
y_coo

tensor([0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
        0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0,
        0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1,
        0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0,
        0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0,
        0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0,
        1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0,
        1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0,
        0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        1, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0,
        0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1,

# LM KBC

## Load Dataset

In [20]:
from kgat.data import load_json, LMKBCDataset
# Validation split of the ConceptNet LM-KBC data; handed to LMKBCDataset further down.
lmkbc_data_path = "./data/lm-kbc/conceptnet/proc/val.json"

# Entity vocabulary, loaded via load_id2map.
lmkbc_entity_path = "./data/lm-kbc/conceptnet/proc/entities.txt"
lmkbc_id2entity = load_id2map(lmkbc_entity_path)

# Relation vocabulary, loaded the same way.
lmkbc_relations_path = "./data/lm-kbc/conceptnet/proc/relations.txt"
lmkbc_id2rel = load_id2map(lmkbc_relations_path)

# Triples file, loaded via load_json.
lmkbc_triples_path = "./data/lm-kbc/conceptnet/proc/triples.json"
lmkbc_triples = load_json(lmkbc_triples_path)

In [21]:
from kgat.utils import SUBJECT_MASK, RELATION_MASK, OBJECT_MASK, KG_MASK

# Prompt template; the *_MASK placeholders are the constants imported from kgat.utils.
prompt_template = f"Based on {KG_MASK} complete this -> S : {SUBJECT_MASK} , R : {RELATION_MASK} , O : {OBJECT_MASK}"
# Shorter template containing only subject and relation; passed to LMKBCDataset as graph_query_template.
graph_query_template = f"S : {SUBJECT_MASK} , R : {RELATION_MASK}"

In [22]:
# Build the LM-KBC dataset from the validation file, the id maps, the triples and both templates.
lmkbc_ds = LMKBCDataset(
    lmkbc_data_path,
    lmkbc_id2entity,
    lmkbc_id2rel,
    lmkbc_triples,
    prompt_template=prompt_template,
    graph_query_template=graph_query_template,
)

In [23]:
# Register KG_MASK as a new tokenizer token (the recorded output, 1, is the call's return value).
# NOTE(review): no clm.resize_token_embeddings(len(tokenizer)) call appears in the cells shown here;
# confirm TextModule never looks the KG_MASK id up in the unresized embedding table.
tokenizer.add_tokens(KG_MASK)

1

In [24]:
# Id assigned to KG_MASK in the tokenizer's added vocabulary; passed to TextModule as kg_id below.
kg_id = tokenizer.get_added_vocab()[KG_MASK]

## Prepare Model

In [25]:
from kgat.model.text import TextModule
from kgat.model.text import SubgraphVTTransformation
from kgat.model import ModelForLMKBC

# word_emb_dim is read from the loaded model instead of being hard-coded to 768, so this cell
# stays consistent with the pooler setup (which also uses clm.transformer.embed_dim) and keeps
# working if model_name_or_path is changed. For the GPT-2 checkpoint the value is still 768.
# subgraph_emb_dim is the same 1024 used as hidden_dim when SubgraphPooler is built.
vt_transformer = SubgraphVTTransformation(
    subgraph_emb_dim=1024,
    hidden_dim=1024,
    word_emb_dim=clm.transformer.embed_dim,
)

# The text module receives the LM's token-embedding layer, the full causal LM and the KG_MASK id.
text_module = TextModule(vt_transformer=vt_transformer,
                         clm_embedding=clm.transformer.wte,
                         clm_model=clm,
                         kg_id=kg_id)

# Combine the graph module built earlier with the text module.
lmkbc_model = ModelForLMKBC(graph_module=graph_module, text_module=text_module)

## Forward Pass

In [26]:
from kgat.data import lmkbc_collate_fn

# Batches of 2 examples in dataset order (no shuffling), assembled by the project's collate function.
lmkbc_dataloader = DataLoader(
    lmkbc_ds,
    batch_size=2,
    shuffle=False,
    collate_fn=lmkbc_collate_fn,
)

In [27]:
import time

# Time how long it takes to produce the first LM-KBC batch.
# perf_counter() is a monotonic, high-resolution clock meant for measuring elapsed time;
# time.time() follows the wall clock and can jump if the system time is adjusted.
start = time.perf_counter()
batch = next(iter(lmkbc_dataloader))

# Elapsed seconds. NOTE(review): the recorded run took about 5.5 s for a batch of 2 — confirm
# where that time goes before using this loader for training.
print(time.perf_counter() - start)

5.475267648696899


In [28]:
# Model inputs assembled in the cells below:
#   graph_query_input_ids, graph_query_attention_mask,
#   prompt_input_ids, prompt_attention_mask,
#   entities_input_ids, entities_attention_mask,
#   relations_input_ids, relations_attention_mask,
#   x_coo, batch

In [29]:
# Batch tuple layout: (text_in, graph_query, entities, relations, x_coo, batch, text_out)

In [30]:
# Batch fields that are used as-is.
x_coo = batch[4]
node_batch = batch[5]

# Text fields are tokenized with padding and truncation: 64 tokens for prompt, query and labels,
# 16 for entity and relation names.
prompt = tokenizer(batch[0], return_tensors="pt", padding=True, truncation=True, max_length=64)
graph_query = tokenizer(batch[1], return_tensors="pt", padding=True, truncation=True, max_length=64)
labels = tokenizer(batch[6], return_tensors="pt", padding=True, truncation=True, max_length=64)
entities = tokenizer(batch[2], return_tensors="pt", padding=True, truncation=True, max_length=16)
relations = tokenizer(batch[3], return_tensors="pt", padding=True, truncation=True, max_length=16)

In [31]:
# Forward pass of the graph module on the LM-KBC batch.
# Arguments are passed by keyword, using the same names as the earlier graph_module call in this
# notebook: all eight inputs are tensors, so a positional mix-up would not raise an error.
mean_fused_score, subgraph_emb, edge_batch = lmkbc_model.graph_module(
    graph_query_input_ids=graph_query["input_ids"],
    graph_query_attention_mask=graph_query["attention_mask"],
    entities_input_ids=entities["input_ids"],
    entities_attention_mask=entities["attention_mask"],
    relations_input_ids=relations["input_ids"],
    relations_attention_mask=relations["attention_mask"],
    x_coo=x_coo,
    batch=node_batch,
)

In [32]:
# Run the text module on the prompt tokens together with the subgraph embedding computed above.
out = lmkbc_model.text_module(prompt["input_ids"], prompt["attention_mask"], subgraph_emb)

In [33]:
# Display the raw model output; the recorded repr shows loss=None followed by the logits tensor.
out

CausalLMOutputWithCrossAttentions(loss=None, logits=tensor([[[-31.3032, -30.2778, -32.2992,  ..., -40.4724, -39.6234, -30.6677],
         [-61.9246, -61.7303, -62.8354,  ..., -68.9247, -69.4393, -62.8554],
         [-67.1411, -67.2283, -70.2463,  ..., -75.0992, -75.7782, -68.2980],
         ...,
         [-57.2997, -56.2022, -55.2405,  ..., -65.2396, -64.3749, -55.5000],
         [-84.1450, -83.7201, -84.8075,  ..., -87.3241, -87.7441, -81.8412],
         [-48.9581, -50.0280, -50.1860,  ..., -55.1180, -55.6311, -49.8897]],

        [[-31.3032, -30.2778, -32.2992,  ..., -40.4724, -39.6234, -30.6677],
         [-25.7829, -26.1424, -27.8339,  ..., -31.2297, -31.0444, -26.6844],
         [-81.4803, -82.6399, -83.9219,  ..., -91.2887, -90.9388, -82.3843],
         ...,
         [-52.7651, -51.7204, -51.9096,  ..., -63.2706, -61.4108, -51.0049],
         [-84.2648, -84.8246, -85.7637,  ..., -90.9844, -90.2013, -82.8538],
         [-65.1761, -66.9361, -67.2005,  ..., -74.6778, -73.4716, -66.9