
rand_like() with the same manual_seed gives transposed results for contiguous and non-contiguous arguments #71090

@tjyuyao

Description


🐛 I get different rand_like() results with the same seed.

The following is minimal reproduction code.

The background is that some torch functions return non-contiguous tensors:

import torch

a = torch.randn([1, 9, 2], device="cuda")
b = torch.randn([1, 2, 9, 2], device="cuda")
c = torch.einsum("bsl,bcsl->bcl", a, b).reshape(1, 2, 1, 2)
assert not c.is_contiguous()
d = c.contiguous()
assert d.is_contiguous()

Now the problem is that a randn_like(c) call returns a different result from randn_like(d).

torch.manual_seed(0)
print(torch.randn_like(c))

torch.manual_seed(0)
print(torch.randn_like(d))

which gives transposed results:

tensor([[[[-0.9247, -2.6438]],

         [[-0.4253,  0.1452]]]], device='cuda:0')
tensor([[[[-0.9247, -0.4253]],

         [[-2.6438,  0.1452]]]], device='cuda:0')

But shouldn't we generally expect the same result from those two statements? They are basically the same tensor sampled with the same seed, differing only in contiguity. Is the current behavior expected?
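
For reference, the two tensors compare equal element-wise and only differ in their strides (a minimal check, reusing c and d from the reproduction above):

assert torch.equal(c, d)                      # identical values
print(c.stride(), d.stride())                 # different memory layouts
print(c.is_contiguous(), d.is_contiguous())   # False, True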

I ran into this problem when comparing two algorithms' gradients via y.backward(torch.rand_like(y)); given the same forward output from the two algorithms (identical in value, but not in contiguity), the backward pass produced different results.
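
A workaround sketch, assuming the discrepancy is caused only by the memory layout: sample the gradient from a contiguous copy, so both algorithms receive identical backward inputs for the same seed.

torch.manual_seed(0)
grad = torch.rand_like(y.contiguous())  # .contiguous() returns a contiguous copy (or y itself if it is already contiguous)
y.backward(grad)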

Versions

PyTorch version: 1.10.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: Could not collect
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.8.12 | packaged by conda-forge | (default, Oct 12 2021, 21:59:51) [GCC 9.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-91-generic-x86_64-with-glibc2.10
Is CUDA available: True
CUDA runtime version: 11.5.119
GPU models and configuration:
GPU 0: NVIDIA GeForce RTX 3090
GPU 1: NVIDIA GeForce RTX 3090
GPU 2: NVIDIA GeForce RTX 3090
GPU 3: NVIDIA GeForce RTX 3090

Nvidia driver version: 495.29.05
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy==0.910
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.19.5
[pip3] torch==1.10.0
[pip3] torchvision==0.11.1
[pip3] torchviz==0.0.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 h2bc3f7f_2
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] mypy 0.910 pypi_0 pypi
[conda] mypy-extensions 0.4.3 pypi_0 pypi
[conda] numpy 1.19.5 pypi_0 pypi
[conda] pytorch 1.10.0 py3.8_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchvision 0.11.1 py38_cu113 pytorch
[conda] torchviz 0.0.2 pypi_0 pypi
