In [17]:
import os
from os.path import join as opj
import pickle
import numpy as np
import torch.nn.functional as F
import torch

### Example: processing WebQSP

#### 1. Load raw data, embeddings, and scores

In [2]:
# Raw WebQSP validation samples (a list of dicts with at least an 'id' field).
# NOTE: pickle.load can execute arbitrary code; only load files from a trusted source.
# The `with` block closes the file handle instead of leaking it.
with open("./samples/webqsp_val_raw.pkl", "rb") as raw_file:
    raw_data = pickle.load(raw_file)

In [4]:
# Per-sample retrieval scores, keyed by sample id.
scored_data = torch.load(f="./samples/webqsp_241028_val.pth")

In [25]:
# Fields stored for each scored sample.
scored_data["WebQTrn-11"].keys()

dict_keys(['question', 'scored_triples', 'q_entity_in_graph', 'a_entity_in_graph', 'max_path_length', 'target_relevant_triples'])

In [5]:
# Sample embeddings are split across two shard files, each a dict keyed by sample id.
# map_location="cpu" lets the notebook run on machines without the device the shards
# may have been saved from.
embedding1 = torch.load("./samples/0.pth", map_location="cpu")
embedding2 = torch.load("./samples/1.pth", map_location="cpu")
# `{**a, **b}` silently keeps only the second value for a duplicated id, so make sure
# no sample id appears in both shards.
assert embedding1.keys().isdisjoint(embedding2.keys()), "sample ids overlap between embedding shards"
embeddings = {**embedding1, **embedding2}

In [10]:
# raw_data, embeddings and scored_data must describe the same samples in the same order.
# zip() silently stops at the shortest input, so check the sizes first. Otherwise a
# missing tail in one source would pass the per-index check unnoticed.
assert len(raw_data) == len(embeddings) == len(scored_data), (
    f"size mismatch: raw={len(raw_data)}, "
    f"embeddings={len(embeddings)}, scored={len(scored_data)}"
)
for idx, (emb_id, scored_id) in enumerate(zip(embeddings, scored_data)):
    raw_id = raw_data[idx]['id']
    assert emb_id == raw_id == scored_id, (
        f"id mismatch at index {idx}: raw={raw_id}, embeddings={emb_id}, scored={scored_id}"
    )

#### 2. Filter each graph to its top-`filter_K` scored triples

In [11]:
# Accumulator for the filtered per-sample subgraphs.
processed_data = list()

In [15]:
# Number of top-ranked scored triples kept per graph.
filter_K: int = 300

In [None]:
def build_processed_sample(sample, scored_sample, sample_embd, filter_k):
    """Keep the top-``filter_k`` scored triples of one graph and build its subgraph.

    Args:
        sample: raw sample; uses ``id``, ``question``, ``h_id_list``, ``r_id_list``,
            ``t_id_list``, ``text_entity_list``, ``non_text_entity_list`` and
            ``relation_list``.
        scored_sample: matching scored entry; uses ``scored_triples`` (one entry per
            graph triple, whose first three items are head, relation, tail text) and
            ``target_relevant_triples``.
        sample_embd: matching embedding entry; uses ``entity_embs`` (2-D, one row per
            text entity), ``relation_embs`` and ``q_emb``.
        filter_k: number of leading ``scored_triples`` entries to keep.

    Returns:
        dict with ``edge_index`` (2 x E LongTensor of re-indexed node ids, row 0 = heads,
        row 1 = tails), ``x`` (embeddings of the nodes touched by kept edges, in sorted
        old-id order), ``edge_attr`` (relation embedding per kept edge), plus
        ``triplets``, ``relevant_triples``, ``id``, ``q`` and ``q_embd``.

    Raises:
        AssertionError: if the scored triples do not match the graph size, or the number
            of matched edges differs from the number of kept triples (e.g. duplicates).
    """
    scored_triplets = scored_sample['scored_triples']
    assert len(scored_triplets) == len(sample['h_id_list'])
    # NOTE(review): assumes scored_triples is sorted by descending score, so the first
    # filter_k entries are the top-ranked ones. Confirm with the scoring script.
    filtered_triplets = [(t[0], t[1], t[2]) for t in scored_triplets[:filter_k]]
    # A set keeps the per-edge membership test O(1) instead of scanning K triples.
    kept_triplets = set(filtered_triplets)

    # Map the kept (h, r, t) text triples back to the graph's entity/relation ids.
    entity_list = sample['text_entity_list'] + sample['non_text_entity_list']
    relation_list = sample['relation_list']
    fh_id_list, fr_id_list, ft_id_list = [], [], []
    for h_id, r_id, t_id in zip(sample['h_id_list'], sample['r_id_list'], sample['t_id_list']):
        if (entity_list[h_id], relation_list[r_id], entity_list[t_id]) in kept_triplets:
            fh_id_list.append(h_id)
            fr_id_list.append(r_id)
            ft_id_list.append(t_id)
    # Graphs with fewer than filter_k triples keep all of them, hence the slice length.
    assert len(fh_id_list) == len(fr_id_list) == len(ft_id_list) == len(filtered_triplets)

    # 1. Edge attributes: relation embedding of every kept edge.
    #    An explicit long tensor also handles the empty case cleanly.
    edge_attr = sample_embd['relation_embs'][torch.tensor(fr_id_list, dtype=torch.long)]

    # 2. Node embeddings: non-text entities get zero rows. new_zeros matches the dtype
    #    and device of the text embeddings so torch.cat does not mix them.
    text_embs = sample_embd['entity_embs']
    entity_embeddings = torch.cat(
        [text_embs, text_embs.new_zeros(len(sample['non_text_entity_list']), text_embs.size(1))],
        dim=0,
    )

    # 3. Edge index: unique() returns the sorted set of touched nodes, and
    #    return_inverse re-indexes every endpoint into that set (old id -> new id).
    edge_index_org = torch.tensor([fh_id_list, ft_id_list], dtype=torch.long)
    selected_nodes, edge_index = torch.unique(edge_index_org, return_inverse=True)
    new_entity_embeddings = entity_embeddings[selected_nodes]

    return {
        "edge_index": edge_index,
        "x": new_entity_embeddings,
        "edge_attr": edge_attr,
        # NOTE(review): triplets keep scored order while edges follow the raw graph order,
        # so triplets[i] need not describe edge i. Confirm what downstream code expects.
        "triplets": filtered_triplets,
        'relevant_triples': scored_sample['target_relevant_triples'],
        "id": sample['id'],
        "q": sample['question'],
        "q_embd": sample_embd['q_emb'],
    }


# Re-initialised here so re-running this cell does not append duplicates.
processed_data = []
for sample in raw_data:
    sample_id = sample['id']
    scored_sample = scored_data[sample_id]
    # Skip samples whose max_path_length is missing (None) or 0.
    if scored_sample['max_path_length'] in [None, 0]:
        continue
    processed_data.append(
        build_processed_sample(sample, scored_sample, embeddings[sample_id], filter_K)
    )

len(processed_data)


In [24]:
# Target relevant triples stored for one sample.
scored_data["WebQTrn-126"]["target_relevant_triples"]

[('Bangkok National Museum',
  'travel.tourist_attraction.near_travel_destination',
  'Bangkok'),
 ('Thonburi', 'travel.tourist_attraction.near_travel_destination', 'Bangkok'),
 ('Bangkok',
  'travel.travel_destination.tourist_attractions',
  'Bangkok Aquarium'),
 ('Bangkok', 'travel.travel_destination.tourist_attractions', 'Khaosan Road'),
 ('Bangkok',
  'travel.travel_destination.tourist_attractions',
  'Rajamangala Stadium'),
 ('Bangkok',
  'travel.travel_destination.tourist_attractions',
  'Democracy Monument'),
 ('Samutprakarn Crocodile Farm and Zoo',
  'travel.tourist_attraction.near_travel_destination',
  'Bangkok'),
 ('Bangkok',
  'travel.travel_destination.tourist_attractions',
  'Chatuchak Park'),
 ('Siam Park City',
  'travel.tourist_attraction.near_travel_destination',
  'Bangkok'),
 ('Bangkok', 'travel.travel_destination.tourist_attractions', 'Golden Buddha'),
 ('Bangkok', 'travel.travel_destination.tourist_attractions', 'Thonburi'),
 ('Bangkok International Trade and Exhi

In [26]:
# Previously processed validation set; the "_300" suffix presumably matches filter_K = 300 (confirm).
data_process = torch.load(f="./samples/processed_webqsp_val_300.pth")

In [27]:
# `relevant_idx` field of the first processed sample.
data_process[0]["relevant_idx"]

[]

In [28]:
# Demo: swapping the two rows of a 2 x E edge index. For an index laid out as
# [heads; tails], this reverses the direction of every edge.
# torch is already imported at the top of the notebook. A demo-specific name avoids
# clobbering any `edge_index` already in the notebook namespace.
demo_edge_index = torch.tensor([[0, 1, 2],   # first row
                                [3, 4, 5]])  # second row

# Reverse the order of the two rows.
flipped_edge_index = demo_edge_index.flip(0)

flipped_edge_index

tensor([[3, 4, 5],
        [0, 1, 2]])
