Skip to content

Commit

Permalink
node, edge features to vary across labels
Browse files Browse the repository at this point in the history
  • Loading branch information
arunppsg committed Feb 11, 2022
1 parent 2db5eff commit 2840bc1
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions torch_geometric/datasets/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,24 +72,24 @@ def generate_data(self) -> Data:

data = Data()

if self._num_classes > 0 and self.task == 'node':
data.y = torch.randint(self._num_classes, (num_nodes, ))
elif self._num_classes > 0 and self.task == 'graph':
data.y = torch.tensor([random.randint(0, self._num_classes - 1)])

data.edge_index = get_edge_index(num_nodes, num_nodes, self.avg_degree,
self.is_undirected, remove_loops=True)

if self.num_channels > 0:
data.x = torch.randn(num_nodes, self.num_channels)
data.x = torch.randn(num_nodes, self.num_channels) + data.y
else:
data.num_nodes = num_nodes

if self.edge_dim > 1:
data.edge_attr = torch.rand(data.num_edges, self.edge_dim)
data.edge_attr = torch.rand(data.num_edges, self.edge_dim) + data.y
elif self.edge_dim == 1:
data.edge_weight = torch.rand(data.num_edges)

if self._num_classes > 0 and self.task == 'node':
data.y = torch.randint(self._num_classes, (num_nodes, ))
elif self._num_classes > 0 and self.task == 'graph':
data.y = torch.tensor([random.randint(0, self._num_classes - 1)])

return data


Expand Down

0 comments on commit 2840bc1

Please sign in to comment.