From c549b3b33df5583894687744cf8ca00a0f2e5638 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Fri, 25 Feb 2022 11:30:05 +0000 Subject: [PATCH] batch to cuda (#3611) --- torch_geometric/data/collate.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index 74aef446ec8e..7eff4c8b4814 100644 --- a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -58,6 +58,7 @@ def collate( # elements as attributes that got incremented need to be decremented # while separating to obtain original values. device = None + print('---------------') slice_dict, inc_dict = defaultdict(dict), defaultdict(dict) for out_store in out.stores: key = out_store._key @@ -84,7 +85,8 @@ def collate( value, slices, incs = _collate(attr, values, data_list, stores, increment) - device = value.device if isinstance(value, Tensor) else device + if isinstance(value, Tensor) and value.is_cuda: + device = value.device out_store[attr] = value if key is not None: