Skip to content

Commit

Permalink
Fixing bugs in model_parallel example (#9231)
Browse files Browse the repository at this point in the history
Two bugs were fixed in `examples/multi_gpu/model_parallel.py`.
The first one was the attribute error:
```
  File "/workspace/examples/multi_gpu/model_parallel.py", line 30, in forward
    x, edge_index = x.to(self.device2), edge_index.to(self.device2)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1704, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'GCN' object has no attribute 'device2'
```

After fixing the initial error, the next one appeared:
```
  File "/workspace/examples/multi_gpu/model_parallel.py", line 58, in test
    out = model(data.x, data.edge_index, data.edge_attr)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1527, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1536, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: GCN.forward() takes 3 positional arguments but 4 were given
```

By the way, `data.edge_attr` is equal to `None`. 

With these fixes, the example works fine.

---------

Co-authored-by: rusty1s <matthias.fey@tu-dortmund.de>
  • Loading branch information
drivanov and rusty1s committed Apr 26, 2024
1 parent f69f595 commit 212c4ce
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion examples/multi_gpu/model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
class GCN(torch.nn.Module):
def __init__(self, in_channels, out_channels, device1, device2):
super().__init__()
self.device1 = device1
self.device2 = device2

self.conv1 = GCNConv(in_channels, 16).to(device1)
self.conv2 = GCNConv(16, out_channels).to(device2)

Expand Down Expand Up @@ -54,7 +57,7 @@ def train():
@torch.no_grad()
def test():
model.eval()
out = model(data.x, data.edge_index, data.edge_attr)
out = model(data.x, data.edge_index)
pred = out.argmax(dim=-1).to('cuda:0')

accs = []
Expand Down

0 comments on commit 212c4ce

Please sign in to comment.