🐛 Bug
BatchNorm1d and BatchNorm2d can break reproducibility when training on CUDA, even with all determinism settings enabled.
To Reproduce
I wrote a script to demonstrate this.
The script applies both BatchNorm1d and BatchNorm2d; removing all layers of either kind yields similar results.
import os

# Enable deterministic CUBLAS behavior.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
import random
import argparse


def set_seed(seed=1):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.__version__ >= '1.8':
        torch.use_deterministic_algorithms(True)
    else:
        torch.set_deterministic(True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed)


def hash_module(module: nn.Module) -> float:
    """Return the hash value of the given module.

    Note: this operation will change the random seed of `torch`.

    Args:
        module (nn.Module): Module.

    Returns:
        float: Hash value.
    """
    hash_value = 0
    with torch.no_grad():
        for param in module.parameters():
            hash_value += param.norm(p=2)
    return hash_value


class NoOP(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


def run(args):
    set_seed(1)
    X = torch.rand(10000, 1, 28, 28, dtype=torch.float32)
    print("Norm of X: ", X.norm(p=2))
    dataset = data.TensorDataset(
        X,
        torch.randint(0, 1, (10000,))
    )
    dataLoader = data.DataLoader(dataset, batch_size=64, shuffle=False)
    model = nn.Sequential(
        nn.Conv2d(1, 16, 3, 1, 1),
        nn.BatchNorm2d(16) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, 16, 7, 3, 0),
        nn.BatchNorm2d(16) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Flatten(1),
        nn.Linear(16 * 8 * 8, 16 * 4 * 4),
        nn.BatchNorm1d(16 * 4 * 4) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Linear(16 * 4 * 4, 8 * 2 * 2),
        nn.BatchNorm1d(8 * 2 * 2) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Linear(32, 10)
    )
    if args.cuda:
        model = model.cuda()
    print("Norm of model before training: ", hash_module(model))
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    for x, y in dataLoader:
        if args.cuda:
            x = x.cuda()
            y = y.cuda()
        optimizer.zero_grad()
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
    print("Norm of model after training: ", hash_module(model))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bn", action="store_true")
    parser.add_argument("--cuda", action="store_true")
    args = parser.parse_args()
    for i in range(4):
        print("{}-th RUN:".format(i))
        run(args)
Expected behavior
Without BN, CUDA disabled:
python run.py
0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and functionality may change in the future. (function operator())
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299)
Norm of model after training: tensor(21.0595)
1-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299)
Norm of model after training: tensor(21.0595)
2-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299)
Norm of model after training: tensor(21.0595)
3-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299)
Norm of model after training: tensor(21.0595)
With BN, CUDA disabled:
python run.py --bn
0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and functionality may change in the future. (function operator())
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867)
Norm of model after training: tensor(50.2153)
1-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867)
Norm of model after training: tensor(50.2153)
2-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867)
Norm of model after training: tensor(50.2153)
3-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867)
Norm of model after training: tensor(50.2153)
Comparing the two commands above shows that BN alone is not the problem: on CPU, every run is identical with or without BN.
Without BN, CUDA enabled:
python run.py --cuda
0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and functionality may change in the future. (function operator())
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299, device='cuda:0')
Norm of model after training: tensor(21.0595, device='cuda:0')
1-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299, device='cuda:0')
Norm of model after training: tensor(21.0595, device='cuda:0')
2-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299, device='cuda:0')
Norm of model after training: tensor(21.0595, device='cuda:0')
3-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(20.4299, device='cuda:0')
Norm of model after training: tensor(21.0595, device='cuda:0')
With BN, CUDA enabled:
python run.py --bn --cuda
0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and functionality may change in the future. (function operator())
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867, device='cuda:0')
Norm of model after training: tensor(50.2151, device='cuda:0')
1-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867, device='cuda:0')
Norm of model after training: tensor(50.2153, device='cuda:0')
2-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867, device='cuda:0')
Norm of model after training: tensor(50.2154, device='cuda:0')
3-th RUN:
Norm of X: tensor(1616.1311)
Norm of model before training: tensor(50.0867, device='cuda:0')
Norm of model after training: tensor(50.2151, device='cuda:0')
Across the four runs, the final model norm takes three different values: 50.2151, 50.2153, and 50.2154. The norm before training is identical every time, so the divergence is introduced during training, not by seeding or initialization.
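This pattern is consistent with a nondeterministic CUDA kernel in BatchNorm's backward pass. As a follow-up check (my own sketch, not part of the repro above; it assumes the nondeterminism lives in the BN backward), one can run the same forward/backward through a single BatchNorm2d on CUDA twice and compare the gradients bitwise:

import torch
import torch.nn as nn

# Run an identical forward/backward twice and compare gradients bitwise.
torch.manual_seed(0)
bn = nn.BatchNorm2d(16).cuda()
x = torch.randn(64, 16, 8, 8, device="cuda")

grads = []
for _ in range(2):
    bn.zero_grad()
    xi = x.clone().requires_grad_(True)
    bn(xi).sum().backward()
    grads.append((xi.grad.clone(), bn.weight.grad.clone()))

# With a deterministic kernel both comparisons print True; if the backward
# reduces with atomic adds, the weight gradient in particular may differ.
print(torch.equal(grads[0][0], grads[1][0]))
print(torch.equal(grads[0][1], grads[1][1]))

If the gradients already differ here, the training-loop repro above is just amplifying that single-layer nondeterminism over many steps.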
Environment
conda list:
- Linux
- Python 3.8.5
- PyTorch 1.7.1 (py3.8_cuda10.2.89_cudnn7.6.5_0)
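For completeness, the full environment report can be generated with PyTorch's bundled helper:

python -m torch.utils.collect_env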
cc @ngimel @aocsa @mruberry @kurtamohler @albanD @jbschlosser