Skip to content

Commit

Permalink
update event datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Apr 15, 2019
1 parent 281ecf3 commit 8bd1bb5
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 63 deletions.
4 changes: 2 additions & 2 deletions torch_geometric/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .pcpnet_dataset import PCPNetDataset
from .geometry import GeometricShapes
from .bitcoin_otc import BitcoinOTC
from .icews import ICEWS
from .icews import ICEWS18
from .gdelt import GDELT

__all__ = [
Expand All @@ -38,6 +38,6 @@
'PCPNetDataset',
'GeometricShapes',
'BitcoinOTC',
'ICEWS',
'ICEWS18',
'GDELT',
]
51 changes: 46 additions & 5 deletions torch_geometric/datasets/gdelt.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from .icews import TemporalDataset
import torch
from torch_geometric.data import download_url
from torch_geometric.read import read_txt_array

from .icews import EventDataset

class GDELT(EventDataset):
    r"""The Global Database of Events, Language, and Tone (GDELT) dataset used
    in the, *e.g.*, `"Recurrent Event Network for Reasoning over Temporal
    Knowledge Graphs" <https://arxiv.org/abs/1904.05530>`_ paper, consisting of
    (subject, relation, object, time) events collected at a 15-minute time
    granularity.

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'https://github.com/INK-USC/RENet/raw/master/data/GDELT'

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        super(GDELT, self).__init__(root, transform, pre_transform, pre_filter)

    @property
    def num_nodes(self):
        # Fixed number of distinct entities in the raw GDELT event files.
        return 7691

    @property
    def num_rels(self):
        # Fixed number of distinct relation types.
        return 240

    @property
    def splits(self):
        # Event-index boundaries [train_start, val_start, test_start, end]
        # into the concatenated event tensor produced by `process_events`.
        return [0, 1734399, 1973164, 2278405]  # Train/Val/Test splits.

    @property
    def raw_file_names(self):
        return ['{}.txt'.format(name) for name in ['train', 'valid', 'test']]

    def download(self):
        # Fetch the three raw split files from the RENet repository.
        for filename in self.raw_file_names:
            download_url('{}/{}'.format(self.url, filename), self.raw_dir)

    def process_events(self):
        """Read all (subject, relation, object, time) quadruples from the raw
        train/valid/test files and concatenate them into one long tensor."""
        events = []
        for path in self.raw_paths:
            data = read_txt_array(path, sep='\t', end=4, dtype=torch.long)
            # Raw timestamps come in 15-minute units; floor-divide to obtain
            # consecutive integer time steps.  `//` makes the integer
            # division explicit (the former `/` relied on implicit long
            # tensor division and fails on newer PyTorch versions).
            data[:, 3] = data[:, 3] // 15
            events.append(data)
        return torch.cat(events, dim=0)
122 changes: 66 additions & 56 deletions torch_geometric/datasets/icews.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,83 +3,55 @@
from torch_geometric.read import read_txt_array


class EventDataset(InMemoryDataset):
    r"""Abstract in-memory dataset of temporal knowledge-graph events, where
    every example is a single :obj:`(sub, rel, obj, t)` quadruple.

    Subclasses are expected to provide :obj:`num_nodes`, :obj:`num_rels` and
    :meth:`process_events` (and, per the concrete datasets in this file,
    :obj:`raw_file_names`, :obj:`splits` and :meth:`download`).
    """

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        super(EventDataset, self).__init__(root, transform, pre_transform,
                                           pre_filter)
        # Load the collated data produced by `process` below.
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def processed_file_names(self):
        return 'data.pt'

    @property
    def num_nodes(self):
        # Total number of entities; must be provided by subclasses.
        raise NotImplementedError

    @property
    def num_rels(self):
        # Total number of relation types; must be provided by subclasses.
        raise NotImplementedError

    def process_events(self):
        # Must return a `[num_events, 4]` long tensor of raw
        # (subject, relation, object, time) rows; provided by subclasses.
        raise NotImplementedError

    def process(self):
        events = self.process_events()
        # Shift every column so indices and timestamps start at zero.
        events = events - events.min(dim=0, keepdim=True)[0]

        data_list = []
        for (sub, rel, obj, t) in events.tolist():
            data = Data(sub=sub, rel=rel, obj=obj, t=t)
            if self.pre_filter is not None and not self.pre_filter(data):
                continue
            if self.pre_transform is not None:
                data = self.pre_transform(data)
            data_list.append(data)

        torch.save(self.collate(data_list), self.processed_paths[0])


class ICEWS18(EventDataset):
    r"""The Integrated Crisis Early Warning System (ICEWS) dataset used in
    the, *e.g.*, `"Recurrent Event Network for Reasoning over Temporal
    Knowledge Graphs" <https://arxiv.org/abs/1904.05530>`_ paper, consisting of
    events collected from 1/1/2018 to 10/31/2018 (24 hours time granularity).

    Args:
        root (string): Root directory where the dataset should be saved.
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """

    url = 'https://github.com/INK-USC/RENet/raw/master/data/ICEWS18'

    def __init__(self, root, transform=None, pre_transform=None,
                 pre_filter=None):
        super(ICEWS18, self).__init__(root, transform, pre_transform,
                                      pre_filter)

    @property
    def num_nodes(self):
        # Fixed number of distinct entities in the raw ICEWS18 event files.
        return 23033

    @property
    def num_rels(self):
        # Fixed number of distinct relation types.
        return 256

    @property
    def splits(self):
        # Event-index boundaries [train_start, val_start, test_start, end]
        # into the concatenated event tensor produced by `process_events`.
        return [0, 373018, 419013, 468558]  # Train/Val/Test splits.

    @property
    def raw_file_names(self):
        return ['{}.txt'.format(name) for name in ['train', 'valid', 'test']]

    def download(self):
        # Fetch the three raw split files from the RENet repository.
        for filename in self.raw_file_names:
            download_url('{}/{}'.format(self.url, filename), self.raw_dir)

    def process_events(self):
        """Read all (subject, relation, object, time) quadruples from the raw
        train/valid/test files and concatenate them into one long tensor."""
        events = []
        for path in self.raw_paths:
            data = read_txt_array(path, sep='\t', end=4, dtype=torch.long)
            # Raw timestamps are in hours; floor-divide by 24 to obtain day
            # indices.  `//` makes the integer division explicit (the former
            # `/` relied on implicit long tensor division and fails on newer
            # PyTorch versions).
            data[:, 3] = data[:, 3] // 24
            events.append(data)
        return torch.cat(events, dim=0)

0 comments on commit 8bd1bb5

Please sign in to comment.