diff --git a/torchvision/utils.py b/torchvision/utils.py index e82752ab28b..afbc1332105 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -56,8 +56,13 @@ def make_grid( """ if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(make_grid) - if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): - raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") + if not torch.is_tensor(tensor): + if isinstance(tensor, list): + for t in tensor: + if not torch.is_tensor(t): + raise TypeError(f"tensor or list of tensors expected, got a list containing {type(t)}") + else: + raise TypeError(f"tensor or list of tensors expected, got {type(tensor)}") if "range" in kwargs.keys(): warnings.warn(