diff --git a/tntorch/round.py b/tntorch/round.py index 8b273ac..aa5de68 100644 --- a/tntorch/round.py +++ b/tntorch/round.py @@ -131,10 +131,10 @@ def truncated_svd( # NOTE: Special case: M = zero -> rank is 1 if batch: if svd[1].max() < 1e-13: - return torch.zeros([batch_size, M.shape[1], 1]), torch.zeros([batch_size, 1, M.shape[2]]) + return torch.zeros([batch_size, M.shape[1], 1]).to(M.device), torch.zeros([batch_size, 1, M.shape[2]]).to(M.device) else: if svd[1][0] < 1e-13: - return torch.zeros([M.shape[0], 1]), torch.zeros([1, M.shape[1]]) + return torch.zeros([M.shape[0], 1]).to(M.device), torch.zeros([1, M.shape[1]]).to(M.device) S = svd[1]**2