In [None]:
import torch
import json
import numpy as np
np.set_printoptions(linewidth=10000)
import uuid
import random
import string

json_path = '/home/pgajo/pyg/data/yamakata/efrc_ud.json'

with open(json_path, 'r', encoding='utf8') as f:
    data = json.load(f)

# get node and edge labels
pos_tags = []

for line in data:
    pos_tags.extend(line['pos_tags'])

head_tags = []

for line in data:
    head_tags.extend(line['head_tags'])
pos_tags = sorted(set(pos_tags))
head_tags = sorted(set(head_tags))

# let's work on a single sample
sample = data[0]


head_id = np.array(sample['head_indices'])
node_type = np.array(sample['pos_tags'])
edge_type = np.array(sample['head_tags'])

head_id = np.concatenate([[0], head_id])
node_type = np.concatenate([['O'], node_type])
edge_type = np.concatenate([['root'], edge_type])

words = np.array(['root'] + sample['words'])
heads = words[head_id]

# make uuid_ids 
uuid_ids = np.array([str(uuid.uuid4()).split('-')[0] for _ in range(words.shape[0])])

new_uuid_ids = uuid_ids[head_id]

word_id = np.arange(words.shape[0])

has_in_edge = [0 if id not in head_id else 1 for id in word_id]

graph = np.stack([word_id, has_in_edge, words, heads, node_type, edge_type, head_id], axis = 1)

head_id_mask = graph[:, 6].astype(int) != 0
has_in_edge_mask = graph[:, 1].astype(int) == 1
and_mask = np.logical_or(head_id_mask, has_in_edge_mask)
graph = graph[and_mask]

reset_word_id = np.arange(graph.shape[0]).reshape(-1, 1)

new_word_ids = []
for line in graph:
    head_id = line[6]
    target = np.where(graph[:, 0] == head_id)
    new_word_ids.append(reset_word_id[target][0])

new_word_ids = np.array(new_word_ids)
edge_index = np.stack([reset_word_id.squeeze(), new_word_ids.squeeze()])

graph = np.hstack([graph, new_word_ids])
graph

array([['0', '1', 'root', 'root', 'O', 'root', '0', '0'],
       ['1', '1', 'Prick', 'potatoes', 'Ac-B', 'f-eq', '12', '5'],
       ['3', '0', 'potatoes', 'Prick', 'F-B', 't', '1', '1'],
       ['6', '0', 'fork', 'Prick', 'T-B', 't-comp', '1', '1'],
       ['11', '1', 'rub', 'sprinkle', 'Ac-B', 'd', '17', '7'],
       ['12', '1', 'potatoes', 'rub', 'F-B', 't', '11', '4'],
       ['14', '0', 'olive', 'rub', 'F-B', 'f-comp', '11', '4'],
       ['17', '1', 'sprinkle', 'wrap', 'Ac-B', 't', '21', '9'],
       ['19', '0', 'salt', 'sprinkle', 'F-B', 'f-comp', '17', '7'],
       ['21', '1', 'wrap', 'potatoes', 'Ac-B', 'f-eq', '28', '12'],
       ['24', '0', 'foil', 'wrap', 'T-B', 't-comp', '21', '9'],
       ['26', '1', 'Place', 'cover', 'Ac-B', 'd', '34', '14'],
       ['28', '1', 'potatoes', 'Place', 'F-B', 't', '26', '11'],
       ['31', '0', 'slow', 'Place', 'T-B', 'd', '26', '11'],
       ['34', '1', 'cover', 'cook', 'Ac-B', 't', '37', '15'],
       ['37', '1', 'cook', 'root', 'Ac-B', 'ro

In [11]:
edge_index, edge_index.shape

(array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20],
        [ 0,  5,  1,  1,  7,  4,  4,  9,  7, 12,  9, 14, 11, 11, 15,  0, 15, 15, 16, 17, 15]]),
 (2, 21))

In [12]:
graph.shape

(21, 8)

In [13]:
gd = {}

gd['word_id'] = graph[:, 0]
gd['has_in_edge'] = graph[:, 1]
gd['word'] = graph[:, 2]
gd['head'] = graph[:, 3]
gd['node_type'] = graph[:, 4]
gd['edge_type'] = graph[:, 5]
gd['head_id'] = [int(el) for el in graph[:, 6]]
# gd['node_uuid'] = graph[:, 7]
# gd['head_uuid'] = graph[:, 8]
gd['new_id'] = graph[:, 7]

# id_dict = {id: n for n, id in enumerate(gd['node_uuid'])}
gd

{'word_id': array(['0', '1', '3', '6', '11', '12', '14', '17', '19', '21', '24', '26', '28', '31', '34', '37', '39', '41', '49', '51', '57'], dtype='<U21'),
 'has_in_edge': array(['1', '1', '0', '0', '1', '1', '0', '1', '0', '1', '0', '1', '1', '0', '1', '1', '1', '1', '0', '0', '0'], dtype='<U21'),
 'word': array(['root', 'Prick', 'potatoes', 'fork', 'rub', 'potatoes', 'olive', 'sprinkle', 'salt', 'wrap', 'foil', 'Place', 'potatoes', 'slow', 'cover', 'cook', 'High', '4', 'Low', '7', 'tender'], dtype='<U21'),
 'head': array(['root', 'potatoes', 'Prick', 'Prick', 'sprinkle', 'rub', 'rub', 'wrap', 'sprinkle', 'potatoes', 'wrap', 'cover', 'Place', 'Place', 'cook', 'root', 'cook', 'cook', 'High', '4', 'cook'], dtype='<U21'),
 'node_type': array(['O', 'Ac-B', 'F-B', 'T-B', 'Ac-B', 'F-B', 'F-B', 'Ac-B', 'F-B', 'Ac-B', 'T-B', 'Ac-B', 'F-B', 'T-B', 'Ac-B', 'Ac-B', 'T-B', 'D-B', 'T-B', 'D-B', 'Sf-B'], dtype='<U21'),
 'edge_type': array(['root', 'f-eq', 't', 't-comp', 'd', 't', 'f-comp', 't', 'f

In [15]:
import pandas as pd

df_graph = pd.DataFrame(gd,
                        columns=[
                            'word_id',
                            'has_in_edge',
                            'word',
                            'node_type',
                            'edge_type',
                            'head_id',
                            'head',
                            # 'node_uuid',
                            # 'head_uuid',
                            'new_id',
                            ])

# df_graph = df_graph[(df_graph['head_id'] != 0) & (df_graph['has_in_edge'] != 1)]

df_graph

Unnamed: 0,word_id,has_in_edge,word,node_type,edge_type,head_id,head,new_id
0,0,1,root,O,root,0,root,0
1,1,1,Prick,Ac-B,f-eq,12,potatoes,5
2,3,0,potatoes,F-B,t,1,Prick,1
3,6,0,fork,T-B,t-comp,1,Prick,1
4,11,1,rub,Ac-B,d,17,sprinkle,7
5,12,1,potatoes,F-B,t,11,rub,4
6,14,0,olive,F-B,f-comp,11,rub,4
7,17,1,sprinkle,Ac-B,t,21,wrap,9
8,19,0,salt,F-B,f-comp,17,sprinkle,7
9,21,1,wrap,Ac-B,f-eq,28,potatoes,12


In [8]:
from torch_geometric.data import HeteroData
import torch_geometric.transforms as T

data = HeteroData()