Skip to content

Commit

Permalink
fix gplus
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s committed Oct 17, 2020
1 parent 0f5b422 commit 708e7e4
Showing 1 changed file with 46 additions and 30 deletions.
76 changes: 46 additions & 30 deletions torch_geometric/datasets/snap_dataset.py
Expand Up @@ -22,6 +22,10 @@ def __inc__(self, key, item):

def read_ego(files, name):
all_featnames = []
files = [
x for x in files if x.split('.')[-1] in
['circles', 'edges', 'egofeat', 'feat', 'featnames']
]
for i in range(4, len(files), 5):
featnames_file = files[i]
with open(featnames_file, 'r') as f:
Expand All @@ -39,57 +43,69 @@ def read_ego(files, name):
feat_file = files[i + 3]
featnames_file = files[i + 4]

x = pandas.read_csv(feat_file, sep=' ', header=None, dtype=np.int64)
x = torch.from_numpy(x.values)
# Catch the case of empty features.
try:
x_ego = pandas.read_csv(egofeat_file, sep=' ', header=None,
dtype=np.float32)
x_ego = torch.from_numpy(x_ego.values)
except: # noqa
x_ego = None

x = None
if x_ego is not None:
x = pandas.read_csv(feat_file, sep=' ', header=None,
dtype=np.float32)
x = torch.from_numpy(x.values)[:, 1:]

x_all = torch.cat([x, x_ego], dim=0)

# Reorder `x` according to `featnames` ordering.
x_all = torch.zeros(x.size(0), len(all_featnames))
with open(featnames_file, 'r') as f:
featnames = f.read().split('\n')[:-1]
featnames = [' '.join(x.split(' ')[1:]) for x in featnames]
indices = [all_featnames[featname] for featname in featnames]
x_all[:, torch.tensor(indices)] = x
x = x_all

idx = pandas.read_csv(feat_file, sep=' ', header=None, dtype=str,
usecols=[0], squeeze=True)

idx, x = x[:, 0].to(torch.long), x[:, 1:].to(torch.float)
idx_assoc = {}
for i, j in enumerate(idx.tolist()):
for i, j in enumerate(idx):
idx_assoc[j] = i

circles = []
circles_batch = []
with open(circles_file, 'r') as f:
for i, circle in enumerate(f.read().split('\n')[:-1]):
circle = [int(idx_assoc[int(c)]) for c in circle.split()[1:]]
circle = [idx_assoc[c] for c in circle.split()[1:]]
circles += circle
circles_batch += [i] * len(circle)
circle = torch.tensor(circles)
circle_batch = torch.tensor(circles_batch)

edge_index = pandas.read_csv(edges_file, sep=' ', header=None,
dtype=np.int64)
edge_index = torch.from_numpy(edge_index.values).t()
edge_index = edge_index.flatten()
for i, e in enumerate(edge_index.tolist()):
edge_index[i] = idx_assoc[e]
edge_index = edge_index.view(2, -1)
row, col = edge_index
row = pandas.read_csv(edges_file, sep=' ', header=None, dtype=str,
usecols=[0], squeeze=True)
col = pandas.read_csv(edges_file, sep=' ', header=None, dtype=str,
usecols=[1], squeeze=True)

x_ego = pandas.read_csv(egofeat_file, sep=' ', header=None,
dtype=np.float32)
x_ego = torch.from_numpy(x_ego.values)
row = torch.tensor([idx_assoc[i] for i in row])
col = torch.tensor([idx_assoc[i] for i in col])

N = max(int(row.max()), int(col.max())) + 2
N = x.size(0) if x is not None else N

row_ego = torch.full((x.size(0), ), x.size(0), dtype=torch.long)
col_ego = torch.arange(x.size(0))
row_ego = torch.full((N - 1, ), N - 1, dtype=torch.long)
col_ego = torch.arange(N - 1)

# Ego node should be connected to every other node.
row = torch.cat([row, row_ego, col_ego], dim=0)
col = torch.cat([col, col_ego, row_ego], dim=0)
edge_index = torch.stack([row, col], dim=0)

x = torch.cat([x, x_ego], dim=0)

# Reorder `x` according to `featnames` ordering.
x_all = torch.zeros(x.size(0), len(all_featnames))
with open(featnames_file, 'r') as f:
featnames = f.read().split('\n')[:-1]
featnames = [' '.join(x.split(' ')[1:]) for x in featnames]
indices = [all_featnames[featname] for featname in featnames]
x_all[:, torch.tensor(indices)] = x

edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
data = Data(x=x_all, edge_index=edge_index, circle=circle,
edge_index, _ = coalesce(edge_index, None, N, N)
data = Data(x=x, edge_index=edge_index, circle=circle,
circle_batch=circle_batch)

data_list.append(data)
Expand Down

0 comments on commit 708e7e4

Please sign in to comment.