In [None]:
import torch

# Show the PyTorch build running in this kernel next to the build on which
# the issue was first observed, so the two can be compared at a glance.
print("Current version", torch.__version__)
print("Problem was detected on 1.10.0+cu111")

Current version 1.10.0+cu111
Problem was detected on 1.10.0+cu111


In [None]:
class FastConv(object):
  """2-D convolution whose forward and backward passes are delegated to
  torch.nn.Conv2d and autograd.

  Both methods are static; everything `backward` needs is carried in the
  `cache` tuple returned by `forward`.
  """

  @staticmethod
  def forward(x, w, b, conv_param):
    """Convolve `x` with filters `w` and bias `b`.

    Args:
      x: Input tensor; must unpack as (N, C, H, W).
      w: Filter tensor; must unpack as (F, _, HH, WW).
      b: Bias tensor, installed as the layer's bias.
      conv_param: Dict with keys 'stride' and 'pad'; they are forwarded to
        torch.nn.Conv2d as `stride` and `padding`.

    Returns:
      A tuple (out, cache). `out` is the layer output and `cache` is
      (x, w, b, conv_param, tx, out, layer), where `tx` is the detached,
      grad-enabled copy of `x` that was actually fed to the layer.
    """
    # print(x.shape, w.shape, b.shape, conv_param)
    # Only the channel count and the kernel extent are needed to build the
    # layer; the unpacking still enforces that both tensors are 4-D.
    _, C, _, _ = x.shape
    F, _, HH, WW = w.shape
    stride, pad = conv_param['stride'], conv_param['pad']
    layer = torch.nn.Conv2d(C, F, (HH, WW), stride=stride, padding=pad)
    # Replace the freshly initialised parameters with the caller's tensors.
    layer.weight = torch.nn.Parameter(w)
    layer.bias = torch.nn.Parameter(b)
    # Detach so this call owns a separate autograd graph rooted at `tx`.
    tx = x.detach()
    tx.requires_grad = True
    out = layer(tx)
    cache = (x, w, b, conv_param, tx, out, layer)
    return out, cache

  @staticmethod
  def backward(dout, cache):
    """Back-propagate `dout` through the graph recorded by `forward`.

    Args:
      dout: Upstream gradient, passed to `out.backward`.
      cache: The tuple returned by `forward`.

    Returns:
      (dx, dw, db): gradients with respect to the input, the weight and the
      bias. If autograd raises a RuntimeError, a RuntimeWarning is emitted
      and all three are returned as zero tensors of the matching shapes.
    """
    import warnings  # local import: keeps this class usable on its own

    # Unpack outside the try block: the fallback branch below reads these
    # names, so they must be bound before anything can raise RuntimeError.
    _, _, _, _, tx, out, layer = cache
    try:
      out.backward(dout)
      dx = tx.grad.detach()
      dw = layer.weight.grad.detach()
      db = layer.bias.grad.detach()
      # Clear the parameter grads so they do not accumulate on the layer.
      layer.weight.grad = layer.bias.grad = None
    except RuntimeError as err:
      # Best-effort fallback is kept, but it is no longer silent: zero
      # gradients are otherwise indistinguishable from a real result.
      warnings.warn(
        "FastConv.backward: autograd failed (%s); returning zero gradients" % (err,),
        RuntimeWarning,
        stacklevel=2,
      )
      dx = torch.zeros_like(tx)
      dw = torch.zeros_like(layer.weight)
      db = torch.zeros_like(layer.bias)
    return dx, dw, db

In [None]:
# Experiment configuration. Every name below is a kernel global.
device = 'cpu'
num_inputs = 2
input_dims = (3, 16, 16)
next_filt = 16

batchnorm = True
dtype = torch.float32

kernel_size = 3
bn_param = {'mode': 'train'}
# stride and padding preserve output spatial size
conv_param = {'stride': 1, 'pad': (kernel_size - 1) // 2}


# NOTE(review): no RNG seed is set in the visible cells, so the norm printed
# at the end differs from run to run.
x = torch.randn(num_inputs, *input_dims, dtype=dtype, device=device)

gamma = torch.ones(input_dims[0], device=device, dtype=dtype)
beta = torch.zeros(input_dims[0], device=device, dtype=dtype)

Weight = torch.randn(next_filt, input_dims[0], kernel_size, kernel_size, dtype=dtype, device=device)
b = torch.zeros(next_filt, dtype=dtype, device=device)

N, C, H, W = x.shape

# Reset both results up front. Each path below only prints its exception, so
# without this the final comparison would either raise NameError or silently
# reuse tensors left in the kernel by an earlier run.
out_bn_2d = out_bn_1d = None

## PyTorch BN2d
try:
  out = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(x)

  out_bn_2d, _ = FastConv.forward(out, Weight, b, conv_param)
except Exception as e:
  print(e)

## Pytorch BN1d
try:
  # (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): one row per spatial location.
  ch_view = x.transpose(1,2).transpose(2,3).reshape(N * H * W, C)
  out = torch.nn.BatchNorm1d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(ch_view)
  # Back to (N, C, H, W) through transposes only (no explicit copy is made here).
  out = out.reshape(N, H, W, C).transpose(2,3).transpose(1,2)

  out_bn_1d, _ = FastConv.forward(out, Weight, b, conv_param)
except Exception as e:
  print(e)

# Compare only when both paths actually produced an output in this run.
if out_bn_1d is None or out_bn_2d is None:
  print("BN1d vs BN2d: skipped, at least one path failed (see error above)")
else:
  print("BN1d vs BN2d", torch.norm(out_bn_1d - out_bn_2d).item())

BN1d vs BN2d 8.046893344726413e-05


In [3]:
# Upgrade torch/torchvision to the latest pre-release (nightly) cu111 wheels.
# The install is unpinned (--pre, -U), so the build fetched depends on the day
# this cell is run.
# NOTE(review): the recorded output of this cell still reports a 1.10.x build,
# while the following cell (execution count 1) reports the nightly; presumably
# the runtime has to be restarted before the new wheel is picked up.
!pip3 install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html -U

import torch
print(torch.__version__)

Looking in links: https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cu111/torch-1.12.0.dev20220209%2Bcu111-cp37-cp37m-linux_x86_64.whl (1922.9 MB)
[K     |█████████████▉                  | 834.1 MB 1.5 MB/s eta 0:12:06tcmalloc: large alloc 1147494400 bytes == 0x562593bca000 @  0x7f517dab3615 0x56255969b3bc 0x56255977c18a 0x56255969e1cd 0x562559790b3d 0x562559712458 0x56255970d02f 0x56255969faba 0x5625597122c0 0x56255970d02f 0x56255969faba 0x56255970ecd4 0x562559791986 0x56255970e350 0x562559791986 0x56255970e350 0x562559791986 0x56255970e350 0x56255969ff19 0x5625596e3a79 0x56255969eb32 0x5625597121dd 0x56255970d02f 0x56255969faba 0x56255970ecd4 0x56255970d02f 0x56255969faba 0x56255970deae 0x56255969f9da 0x56255970e108 0x56255970d02f
[K     |█████████████████▋              | 1055.7 MB 1.6 MB/s eta 0:09:11tcmalloc: large alloc 1434370048 bytes == 0x5625d8220000 @  0x7f517dab3615 0x56255969b3bc 0x562

1.10.2+cu111


In [1]:
import torch
# Confirm which PyTorch build is active in this kernel before re-running the experiment.
print(torch.__version__)

class FastConv(object):
  """2-D convolution whose forward and backward passes are delegated to
  torch.nn.Conv2d and autograd.

  Both methods are static; everything `backward` needs is carried in the
  `cache` tuple returned by `forward`.
  """

  @staticmethod
  def forward(x, w, b, conv_param):
    """Convolve `x` with filters `w` and bias `b`.

    Args:
      x: Input tensor; must unpack as (N, C, H, W).
      w: Filter tensor; must unpack as (F, _, HH, WW).
      b: Bias tensor, installed as the layer's bias.
      conv_param: Dict with keys 'stride' and 'pad'; they are forwarded to
        torch.nn.Conv2d as `stride` and `padding`.

    Returns:
      A tuple (out, cache). `out` is the layer output and `cache` is
      (x, w, b, conv_param, tx, out, layer), where `tx` is the detached,
      grad-enabled copy of `x` that was actually fed to the layer.
    """
    # print(x.shape, w.shape, b.shape, conv_param)
    # Only the channel count and the kernel extent are needed to build the
    # layer; the unpacking still enforces that both tensors are 4-D.
    _, C, _, _ = x.shape
    F, _, HH, WW = w.shape
    stride, pad = conv_param['stride'], conv_param['pad']
    layer = torch.nn.Conv2d(C, F, (HH, WW), stride=stride, padding=pad)
    # Replace the freshly initialised parameters with the caller's tensors.
    layer.weight = torch.nn.Parameter(w)
    layer.bias = torch.nn.Parameter(b)
    # Detach so this call owns a separate autograd graph rooted at `tx`.
    tx = x.detach()
    tx.requires_grad = True
    out = layer(tx)
    cache = (x, w, b, conv_param, tx, out, layer)
    return out, cache

  @staticmethod
  def backward(dout, cache):
    """Back-propagate `dout` through the graph recorded by `forward`.

    Args:
      dout: Upstream gradient, passed to `out.backward`.
      cache: The tuple returned by `forward`.

    Returns:
      (dx, dw, db): gradients with respect to the input, the weight and the
      bias. If autograd raises a RuntimeError, a RuntimeWarning is emitted
      and all three are returned as zero tensors of the matching shapes.
    """
    import warnings  # local import: keeps this class usable on its own

    # Unpack outside the try block: the fallback branch below reads these
    # names, so they must be bound before anything can raise RuntimeError.
    _, _, _, _, tx, out, layer = cache
    try:
      out.backward(dout)
      dx = tx.grad.detach()
      dw = layer.weight.grad.detach()
      db = layer.bias.grad.detach()
      # Clear the parameter grads so they do not accumulate on the layer.
      layer.weight.grad = layer.bias.grad = None
    except RuntimeError as err:
      # Best-effort fallback is kept, but it is no longer silent: zero
      # gradients are otherwise indistinguishable from a real result.
      warnings.warn(
        "FastConv.backward: autograd failed (%s); returning zero gradients" % (err,),
        RuntimeWarning,
        stacklevel=2,
      )
      dx = torch.zeros_like(tx)
      dw = torch.zeros_like(layer.weight)
      db = torch.zeros_like(layer.bias)
    return dx, dw, db


# Same experiment as the earlier CPU cell, but on the GPU and without the
# try/except wrappers, so any failure surfaces as a traceback.
device = 'cuda'
num_inputs = 2
input_dims = (3, 16, 16)
next_filt = 16

# NOTE(review): `batchnorm`, `bn_param`, `gamma` and `beta` are assigned but
# not read anywhere in this cell.
batchnorm = True
dtype = torch.float32

kernel_size = 3
bn_param = {'mode': 'train'}
# stride and padding preserve output spatial size
conv_param = {'stride': 1, 'pad': (kernel_size - 1) // 2}

# NOTE(review): no RNG seed is set in this cell, so the random input and
# weights differ between runs.
x = torch.randn(num_inputs, *input_dims, dtype=dtype, device=device)

gamma = torch.ones(input_dims[0], device=device, dtype=dtype)
beta = torch.zeros(input_dims[0], device=device, dtype=dtype)

Weight = torch.randn(next_filt, input_dims[0], kernel_size, kernel_size, dtype=dtype, device=device)
b = torch.zeros(next_filt, dtype=dtype, device=device)

N, C, H, W = x.shape

## PyTorch BN2d
# Normalise x with BatchNorm2d, then run the result through the conv layer.
out = torch.nn.BatchNorm2d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(x)
out_bn_2d, _ = FastConv.forward(out, Weight, b, conv_param)
print(out_bn_2d.shape)
# > torch.Size([2, 16, 16, 16])

## Pytorch BN1d
# (N, C, H, W) -> (N, H, W, C) -> (N*H*W, C): one row per spatial location,
# so BatchNorm1d sees C features per row.
ch_view = x.transpose(1,2).transpose(2,3).reshape(N * H * W, C) 
out = torch.nn.BatchNorm1d(C, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=device, dtype=dtype).forward(ch_view)
# Back to (N, C, H, W) through transposes only; presumably this leaves `out`
# as a non-contiguous view of the BatchNorm1d result (TODO confirm).
out = out.reshape(N, H, W, C).transpose(2,3).transpose(1,2)
out_bn_1d, _ = FastConv.forward(out, Weight, b, conv_param)
print(out_bn_1d.shape)
# > torch.Size([2, 16, 16, 16])

1.12.0.dev20220209+cu111
torch.Size([2, 16, 16, 16])
torch.Size([2, 16, 16, 16])
