In [2]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset, concatenate_datasets
from torch_geometric.data import Data
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, load_from_disk
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer
from torch.utils.data import DataLoader
import torch
import numpy as np
from pcst_fast import pcst_fast
from torch_geometric.data.data import Data

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
dataset = load_dataset("rmanluo/RoG-cwq")

Downloading data: 100%|██████████| 18/18 [29:09<00:00, 97.22s/files] 
Generating train split: 100%|██████████| 27639/27639 [00:38<00:00, 709.46 examples/s]
Generating validation split: 100%|██████████| 3519/3519 [00:05<00:00, 683.30 examples/s]
Generating test split: 100%|██████████| 3531/3531 [00:05<00:00, 668.53 examples/s]


In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 27639
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 3519
    })
    test: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 3531
    })
})

In [5]:
dataset = concatenate_datasets([dataset['train'], dataset['validation'], dataset['test']])
dataset

Dataset({
    features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
    num_rows: 34689
})

In [6]:
seed = 0
percent_data = 0.01

train_test_split = dataset.train_test_split(test_size=percent_data, seed=seed, shuffle=True)

X_train = train_test_split['train']
dataset = train_test_split['test']

In [7]:

dataset

Dataset({
    features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
    num_rows: 347
})

In [8]:

train_size = 147
val_size = 100
test_size = 100

train_data = dataset.select(range(train_size))
val_data = dataset.select(range(train_size, train_size + val_size))
test_data = dataset.select(range(train_size + val_size, train_size + val_size + test_size))


# Combine the splits into a DatasetDict
sample_dataset = DatasetDict({
    "train": train_data,
    "validation": val_data,
    "test":test_data
})


In [9]:
sample_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 147
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 100
    })
    test: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 100
    })
})

In [10]:
model_name = 'sbert'
path = '/home/ahmadi/sadaf/GraphNeighborLM/G-retriever/dataset/ComplexWebQuestions'
path_nodes = f'{path}/nodes'
path_edges = f'{path}/edges'
path_graphs = f'{path}/graphs'

In [11]:
from datasets import DatasetDict

sample_dataset.save_to_disk(f'{path}/processed_dataset')


Saving the dataset (0/1 shards):   0%|          | 0/147 [00:00<?, ? examples/s]

Saving the dataset (1/1 shards): 100%|██████████| 147/147 [00:00<00:00, 1173.73 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 1394.16 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 1749.66 examples/s]


In [4]:
sample_dataset = load_from_disk('/home/ahmadi/sadaf/GraphNeighborLM/G-retriever/dataset/ComplexWebQuestions/processed_dataset')

In [5]:
sample_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 147
    })
    validation: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 100
    })
    test: Dataset({
        features: ['id', 'question', 'answer', 'q_entity', 'a_entity', 'graph', 'choices'],
        num_rows: 100
    })
})

In [12]:
sample_dataset["train"]["question"][0:5]

['What is the state slogan of the state whic held the 2008 Kentucky state Senator election?',
 'The the country that contains Kaduna State has what type of government?',
 'In the country where the Chihuahua originated, which currency would you use?',
 'What country hosted the Euro 2012 tourney and had a TV show called Kryminalni recorded there?',
 'In what country is perpignan located that has the legislative session of the Ayrault Government?']

In [14]:
sample_dataset["train"]["answer"][0:5]

[['United we stand, divided we fall'],
 ['Presidential system', 'Federal republic'],
 ['Mexican peso'],
 ['Warsaw'],
 ['France']]

In [15]:
sample_dataset["train"]["q_entity"][0:5]

[['United States Senate election in Kentucky, 2008', 'Motto'],
 ['Kaduna State'],
 ['Chihuahua'],
 ['Kryminalni', 'UEFA Euro 2012'],
 ['Perpignan', 'Ayrault Government']]

In [16]:
sample_dataset["train"]["a_entity"][0:5]

[['United we stand, divided we fall'],
 ['Presidential system', 'Federal republic'],
 ['Mexican peso'],
 ['Warsaw'],
 ['France']]

In [17]:
sample_dataset["train"]["graph"][0]

[['Big Sandy River', 'geography.river.origin', 'Tug Fork'],
 ['m.09jx_rp',
  'government.government_position_held.jurisdiction_of_office',
  'United States of America'],
 ['Illinois', 'location.location.time_zones', 'Central Time Zone'],
 ['Kentucky',
  'military.military_unit_place_of_origin.military_units',
  '18th Regiment Kentucky Volunteer Infantry'],
 ['Levisa Fork', 'location.location.containedby', 'United States of America'],
 ['Kentucky',
  'government.governmental_jurisdiction.official_symbols',
  'm.04stp_k'],
 ['Michigan', 'location.location.time_zones', 'Eastern Time Zone'],
 ['Secretary of State of Kentucky',
  'common.topic.notable_types',
  'Government Office or Title'],
 ['Ohio', 'common.topic.notable_types', 'US State'],
 ['United States of America',
  'base.aareas.schema.administrative_area.administrative_children',
  'Minnesota'],
 ['Texas',
  'meteorology.cyclone_affected_area.cyclones',
  'Tropical Storm Chris'],
 ['Tennessee', 'common.topic.notable_types', 'US St

In [18]:
len(sample_dataset["train"]["graph"][0])

3588

In [19]:
len(sample_dataset['train'])

147