Linear transformations of batch vs. single row do not satisfy torch.allclose() #29727

@shawntan

🐛 Bug

Applying a linear transform that projects to a much higher dimension (here 20 → 10000) produces outputs that do not satisfy torch.allclose() when the result for a single input row is compared with the corresponding row of the result for the whole minibatch.

To Reproduce

>>> import torch
>>> from torch import nn
>>> f = nn.Linear(20, 10000)
>>> x = torch.randn(16, 20)
>>> y1 = f(x)[:1]
>>> y2 = f(x[:1])
>>> y1 - y2
tensor([[ 0.0000e+00,  0.0000e+00, -2.9802e-08,  ..., -5.9605e-08,
          5.9605e-08, -5.9605e-08]], grad_fn=<SubBackward0>)
>>> torch.allclose(y1, y2)
False
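A minimal sketch of how the size of the discrepancy could be checked, assuming it stems from float32 rounding (the report does not confirm the cause). torch.allclose tests |a - b| <= atol + rtol * |b| elementwise with defaults rtol=1e-05 and atol=1e-08, so differences of about 6e-08 can fail the check wherever the output itself is close to zero. The fixed seed and the atol=1e-5 value are illustrative choices, not part of the original report.

```python
import torch
from torch import nn

torch.manual_seed(0)  # assumption: fixed seed, not in the original report

f = nn.Linear(20, 10000)
x = torch.randn(16, 20)

with torch.no_grad():
    y1 = f(x)[:1]   # first row of the batched result
    y2 = f(x[:1])   # result for the first row on its own

# Largest elementwise discrepancy, next to float32 machine epsilon.
max_abs_diff = (y1 - y2).abs().max().item()
eps32 = torch.finfo(torch.float32).eps  # about 1.19e-07

# Default tolerances (rtol=1e-05, atol=1e-08) may or may not pass,
# depending on the platform and the BLAS kernel that gets selected.
strict = torch.allclose(y1, y2)

# A looser absolute tolerance (illustrative value) absorbs rounding noise.
relaxed = torch.allclose(y1, y2, atol=1e-5)

print(max_abs_diff, eps32, strict, relaxed)
```

Whether `strict` comes out False depends on the machine; the point is only that `max_abs_diff` stays within a few multiples of `eps32`.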

Expected behavior

I tried this with numpy using a large matrix transform, and no issues were encountered:

>>> import numpy as np
>>> W = np.random.randn(20, 10000)
>>> x_ = x.numpy()
>>> np.dot(x_,W)
array([[  2.04262373,  -1.83566229,   5.37089564, ...,  -1.36152633,
          2.09995545,   9.50700872],
       [ -1.41234774,   1.11562106,   8.02134875, ...,   0.51592682,
          3.92124522,   1.50186604],
       [  1.81572648, -10.87273162,   0.1820477 , ...,   0.82404416,
          2.97601253,  -3.18268902],
       ...,
       [ -3.18884018,  -6.77464924,   2.3925408 , ...,  -1.01267284,
          1.78089243,  -3.11052016],
       [ -1.05808344,  -0.72415951,  -1.94846654, ...,   0.15422047,
         -2.3243351 ,   7.35207271],
       [ -2.29791049,   2.42070905,   5.46771606, ...,  -5.39978544,
          6.22757513,  -8.78495285]])
>>> z1 = np.dot(x_,W)
>>> z1 = np.dot(x_,W)[:1]
>>> z2 = np.dot(x_[:1],W)
>>> np.allclose(z1, z2)
True
>>> z1 - z2
array([[0., 0., 0., ..., 0., 0., 0.]])
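One possible contributing factor, offered as an assumption and not as a diagnosis: np.random.randn returns float64, so the NumPy comparison above runs in double precision, while nn.Linear defaults to float32. The sketch below checks the dtypes and repeats the PyTorch comparison in float64 so that both libraries are compared at the same precision. The seed is an illustrative choice.

```python
import numpy as np
import torch
from torch import nn

torch.manual_seed(0)  # assumption: fixed seed, not in the original report

f = nn.Linear(20, 10000)
x = torch.randn(16, 20)

# The NumPy weights are float64, and the float32 input is promoted.
W = np.random.randn(20, 10000)
w_dtype = W.dtype
out_dtype = np.dot(x.numpy(), W).dtype
print(w_dtype, out_dtype)  # float64 float64

# Repeat the PyTorch comparison in double precision.
f64 = f.double()   # converts the module's parameters in place
x64 = x.double()
with torch.no_grad():
    y1 = f64(x64)[:1]
    y2 = f64(x64[:1])

max_diff64 = (y1 - y2).abs().max().item()
ok64 = torch.allclose(y1, y2)
print(max_diff64, ok64)
```

In float64 any remaining discrepancy is many orders of magnitude below the default atol of 1e-08, so the default check passes.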

Environment

Collecting environment information...
PyTorch version: 1.3.1
Is debug build: No
CUDA used to build PyTorch: 10.1.243

OS: Arch Linux
GCC version: (GCC) 9.2.0
CMake version: version 3.15.5

Python version: 3.7
Is CUDA available: No
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] torch==1.3.1
[pip3] torchtext==0.4.0
[pip3] torchvision==0.2.2.post3
[conda] Could not collect

Additional context

Uses CPU.
