diff --git a/torch_geometric/data/collate.py b/torch_geometric/data/collate.py index 74aef446ec8e..7eff4c8b4814 100644 --- a/torch_geometric/data/collate.py +++ b/torch_geometric/data/collate.py @@ -58,6 +58,7 @@ def collate( # elements as attributes that got incremented need to be decremented # while separating to obtain original values. device = None + print('---------------') slice_dict, inc_dict = defaultdict(dict), defaultdict(dict) for out_store in out.stores: key = out_store._key @@ -84,7 +85,8 @@ def collate( value, slices, incs = _collate(attr, values, data_list, stores, increment) - device = value.device if isinstance(value, Tensor) else device + if isinstance(value, Tensor) and value.is_cuda: + device = value.device out_store[attr] = value if key is not None: