Skip to content

Commit

Permalink
node, edge features to vary across labels
Browse files Browse the repository at this point in the history
  • Loading branch information
arunppsg committed Feb 11, 2022
1 parent 2db5eff commit 76946ad
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions torch_geometric/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,23 +72,34 @@ def generate_data(self) -> Data:

data = Data()

if self._num_classes > 0 and self.task == 'node':
data.y = torch.randint(self._num_classes, (num_nodes, ))
elif self._num_classes > 0 and self.task == 'graph':
data.y = torch.tensor([random.randint(0, self._num_classes - 1)])

data.edge_index = get_edge_index(num_nodes, num_nodes, self.avg_degree,
self.is_undirected, remove_loops=True)

if self.num_channels > 0:
data.x = torch.randn(num_nodes, self.num_channels)
if self.num_channels > 0 and self.task == 'graph':
data.x = torch.randn(num_nodes, self.num_channels) + data.y
elif self.num_channels > 0 and self.task == 'node':
data.x = torch.randn(num_nodes,
self.num_channels) + data.y.unsqueeze(1)
else:
data.num_nodes = num_nodes

if self.edge_dim > 1:
data.edge_attr = torch.rand(data.num_edges, self.edge_dim)
if self.task == 'graph':
data.edge_attr = torch.rand(data.num_edges,
self.edge_dim) + data.y
elif self.task == 'node':
# no need to consider variance in edge distribution
data.edge_attr = torch.rand(data.num_edges, self.edge_dim)
elif self.edge_dim == 1:
data.edge_weight = torch.rand(data.num_edges)

if self._num_classes > 0 and self.task == 'node':
data.y = torch.randint(self._num_classes, (num_nodes, ))
elif self._num_classes > 0 and self.task == 'graph':
data.y = torch.tensor([random.randint(0, self._num_classes - 1)])
if self.task == 'graph':
data.edge_weight = torch.rand(data.num_edges) + data.y
elif self.task == 'node':
data.edge_weight = torch.rand(data.num_edges)

return data

Expand Down

0 comments on commit 76946ad

Please sign in to comment.