
clean up cleaned tu dataset

rusty1s committed Oct 28, 2019
1 parent 2b1fb68 commit d9d0152275abd4e0ea03d9cc57fe80fc1e698871
@@ -27,21 +27,21 @@
 hiddens = [16, 32, 64, 128]
 datasets = ['MUTAG', 'PROTEINS', 'IMDB-BINARY', 'REDDIT-BINARY']  # , 'COLLAB']
 nets = [
-    # GCNWithJK,
-    # GraphSAGEWithJK,
-    # GIN0WithJK,
-    # GINWithJK,
-    # Graclus,
-    # TopK,
-    # SAGPool,
-    # DiffPool,
+    GCNWithJK,
+    GraphSAGEWithJK,
+    GIN0WithJK,
+    GINWithJK,
+    Graclus,
+    TopK,
+    SAGPool,
+    DiffPool,
     GCN,
     GraphSAGE,
     GIN0,
     GIN,
-    # GlobalAttentionNet,
-    # Set2SetNet,
-    # SortPool,
+    GlobalAttentionNet,
+    Set2SetNet,
+    SortPool,
 ]


@@ -57,18 +57,14 @@ def logger(info):
     best_result = (float('inf'), 0, 0)  # (loss, acc, std)
     print('-----\n{} - {}'.format(dataset_name, Net.__name__))
     for num_layers, hidden in product(layers, hiddens):
-        dataset = get_dataset(dataset_name, sparse=Net != DiffPool, cleaned=True)
+        dataset = get_dataset(dataset_name, sparse=Net != DiffPool,
+                              cleaned=True)
         model = Net(dataset, num_layers, hidden)
         loss, acc, std = cross_validation_with_val_set(
-            dataset,
-            model,
-            folds=10,
-            epochs=args.epochs,
-            batch_size=args.batch_size,
-            lr=args.lr,
+            dataset, model, folds=10, epochs=args.epochs,
+            batch_size=args.batch_size, lr=args.lr,
             lr_decay_factor=args.lr_decay_factor,
-            lr_decay_step_size=args.lr_decay_step_size,
-            weight_decay=0,
+            lr_decay_step_size=args.lr_decay_step_size, weight_decay=0,
             logger=None)
         if loss < best_result[0]:
             best_result = (loss, acc, std)
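
For context, here is a minimal sketch of a get_dataset helper with the signature used above, assuming it simply forwards the cleaned flag to TUDataset and densifies graphs when sparse=False; the root path and the dense conversion are illustrative assumptions, not taken from this commit:

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset


def get_dataset(name, sparse=True, cleaned=False):
    # Load the (optionally cleaned) TU benchmark dataset; 'data/' is an
    # assumed storage location.
    dataset = TUDataset('data/' + name, name, cleaned=cleaned)

    if not sparse:
        # Dense-pooling models such as DiffPool operate on dense adjacency
        # matrices, so pad every graph to the maximum number of nodes
        # (assumption: the real helper may handle this differently).
        max_nodes = max(data.num_nodes for data in dataset)
        dataset.transform = T.ToDense(max_nodes)

    return dataset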
@@ -14,7 +14,6 @@ def test_batch():
     s2 = '2'
 
     data = Batch.from_data_list([Data(x1, e1, s=s1), Data(x2, e2, s=s2)])
-    print(data)
 
     assert data.__repr__() == (
         'Batch(batch=[5], edge_index=[2, 6], s=[2], x=[5])')
@@ -69,3 +69,12 @@ def test_enzymes():
     assert dataset.num_edge_features == 0
 
     shutil.rmtree(root)
+
+
+def test_cleaned_enzymes():
+    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
+    dataset = TUDataset(root, 'ENZYMES', cleaned=True)
+
+    assert len(dataset) == 595
+
+    shutil.rmtree(root)
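
The new test exercises the cleaned option of TUDataset (presumably the variant of the TU benchmark with duplicate isomorphic graphs removed). A small usage sketch, assuming separate root folders for the two variants:

from torch_geometric.datasets import TUDataset

# The standard ENZYMES split contains 600 graphs; the cleaned variant keeps
# 595 of them, matching the assertion in test_cleaned_enzymes above.
full = TUDataset('/tmp/ENZYMES', 'ENZYMES')
cleaned = TUDataset('/tmp/ENZYMES_cleaned', 'ENZYMES', cleaned=True)

print(len(full), len(cleaned))  # 600 595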

This file was deleted.

This file was deleted.

@@ -44,6 +44,7 @@ class Dataset(torch.utils.data.Dataset):
             value, indicating whether the data object should be included in the
             final dataset. (default: :obj:`None`)
     """
+
     @property
     def raw_file_names(self):
         r"""The name of the files to find in the :obj:`self.raw_dir` folder in
@@ -77,8 +78,6 @@ def __init__(self, root, transform=None, pre_transform=None,
         super(Dataset, self).__init__()
 
         self.root = osp.expanduser(osp.normpath(root))
-        self.raw_dir = osp.join(self.root, 'raw')
-        self.processed_dir = osp.join(self.root, 'processed')
         self.transform = transform
         self.pre_transform = pre_transform
         self.pre_filter = pre_filter
@@ -102,6 +101,14 @@ def __init__(self, root, transform=None, pre_transform=None,
         self._download()
         self._process()
 
+    @property
+    def raw_dir(self):
+        return osp.join(self.root, 'raw')
+
+    @property
+    def processed_dir(self):
+        return osp.join(self.root, 'processed')
+
     @property
     def num_node_features(self):
         r"""Returns the number of features per node in the dataset."""

This file was deleted.
