BatchNorm may impact reproducibility with CUDA #53691

@ZyUestc

🐛 Bug

Both BatchNorm1d and BatchNorm2d can break reproducibility when training on CUDA.

To Reproduce

The script below demonstrates the problem. It applies both BatchNorm1d and BatchNorm2d; deleting all of either the BN1d or the BN2d layers gives similar results.

import os
# CUBLAS_WORKSPACE_CONFIG must be set for deterministic CUBLAS behavior
# on CUDA >= 10.2 (see the PyTorch reproducibility notes).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import random
import argparse


def set_seed(seed=1): 
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    
    # Note: plain string comparison on the version; fine for 1.7/1.8 but
    # would misorder versions such as '1.10'.
    if torch.__version__ >= '1.8':
        torch.use_deterministic_algorithms(True)
    else:
        torch.set_deterministic(True)
    
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed)
    
    
def hash_module(module: nn.Module) -> torch.Tensor:
    """Return a hash of the given module: the sum of the L2 norms
    of all its parameters.

    Args:
        module (nn.Module): Module.

    Returns:
        torch.Tensor: Scalar hash value.
    """
    hash_value = 0
    with torch.no_grad():
        for param in module.parameters():
            hash_value += (param.norm(p=2))
    return hash_value
    
    
class NoOP(nn.Module):
    """Pass-through placeholder used when BN is disabled (equivalent to nn.Identity)."""
    def forward(self, x):
        return x
    
def run(args):
    set_seed(1)
    
    X = torch.rand(10000, 1, 28, 28, dtype=torch.float32)
    print("Norm of X: ", X.norm(p=2))
    
    dataset = data.TensorDataset(
        X, 
        torch.randint(0, 1, (10000,))  # randint(0, 1) is always 0, so every label is 0
    )
    
    dataLoader = data.DataLoader(dataset, batch_size=64, shuffle=False)
    
    model = nn.Sequential(
        nn.Conv2d(1, 16, 3, 1, 1),
        nn.BatchNorm2d(16) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Conv2d(16, 16, 7, 3, 0),
        nn.BatchNorm2d(16) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Flatten(1),
        nn.Linear(16 * 8 * 8, 16 * 4 * 4),
        nn.BatchNorm1d(16 * 4 * 4) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Linear(16 * 4 * 4, 8 * 2 * 2),
        nn.BatchNorm1d(8 * 2 * 2) if args.bn else NoOP(),
        nn.ReLU(inplace=True),
        nn.Linear(32, 10)
    )
    if args.cuda:
        model = model.cuda()
    
    print("Norm of model before training: ", hash_module(model))
    
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    
    for x, y in dataLoader:
        if args.cuda:
            x = x.cuda()
            y = y.cuda()
        
        optimizer.zero_grad()
        y_hat = model(x)
        loss = F.cross_entropy(y_hat, y)
        loss.backward()
        optimizer.step()
        
    print("Norm of model after training: ", hash_module(model))
    

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--bn", action="store_true")
    parser.add_argument("--cuda", action="store_true")
    args = parser.parse_args()
    
    for i in range(4):
        print("{}-th RUN:".format(i))
        run(args)
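
A side note on the measurement: hash_module compares only a summed L2 norm, which could in principle hide tiny or mutually cancelling differences. A stricter check is a bitwise comparison of two trained models' parameters. The helper below is a minimal sketch of that idea; the name models_bitwise_equal is mine and not part of the script above.

import torch
import torch.nn as nn

def models_bitwise_equal(m1: nn.Module, m2: nn.Module) -> bool:
    """True iff both models have the same parameters, bitwise."""
    p1, p2 = list(m1.parameters()), list(m2.parameters())
    if len(p1) != len(p2):
        return False
    with torch.no_grad():
        # torch.equal is True only for identical shape and identical values.
        return all(torch.equal(a, b) for a, b in zip(p1, p2))

Returning the model from run() and comparing two runs with this helper would turn "the norms look equal" into a hard yes/no answer.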
    

Expected behavior

Without BN, CUDA disabled

python run.py

0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and  functionality may change in the future. (function operator())
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299)
Norm of model after training:  tensor(21.0595)
1-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299)
Norm of model after training:  tensor(21.0595)
2-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299)
Norm of model after training:  tensor(21.0595)
3-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299)
Norm of model after training:  tensor(21.0595)

With BN, CUDA disabled

python run.py --bn

0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and  functionality may change in the future. (function operator())
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867)
Norm of model after training:  tensor(50.2153)
1-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867)
Norm of model after training:  tensor(50.2153)
2-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867)
Norm of model after training:  tensor(50.2153)
3-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867)
Norm of model after training:  tensor(50.2153)

Comparing the two commands above shows that BN by itself does not affect reproducibility on the CPU.

Without BN, CUDA enabled

python run.py --cuda

0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and  functionality may change in the future. (function operator())
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299, device='cuda:0')
Norm of model after training:  tensor(21.0595, device='cuda:0')
1-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299, device='cuda:0')
Norm of model after training:  tensor(21.0595, device='cuda:0')
2-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299, device='cuda:0')
Norm of model after training:  tensor(21.0595, device='cuda:0')
3-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(20.4299, device='cuda:0')
Norm of model after training:  tensor(21.0595, device='cuda:0')

With BN, CUDA enabled

python run.py --bn --cuda

0-th RUN:
[W Context.cpp:69] Warning: torch.set_deterministic is in beta, and its design and  functionality may change in the future. (function operator())
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867, device='cuda:0')
Norm of model after training:  tensor(50.2151, device='cuda:0')
1-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867, device='cuda:0')
Norm of model after training:  tensor(50.2153, device='cuda:0')
2-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867, device='cuda:0')
Norm of model after training:  tensor(50.2154, device='cuda:0')
3-th RUN:
Norm of X:  tensor(1616.1311)
Norm of model before training:  tensor(50.0867, device='cuda:0')
Norm of model after training:  tensor(50.2151, device='cuda:0')

The post-training norm is no longer stable: three different values appear across the four runs (50.2151, 50.2153, and 50.2154). Only the combination of BN and CUDA is non-reproducible.
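
For what it's worth, one way to narrow this down is to check whether the cuDNN batch-norm kernels are involved. The snippet below is a diagnostic sketch only, reusing run() and args from the script above: with torch.backends.cudnn.enabled = False, PyTorch falls back to its native CUDA kernels. If the norms become stable, the cuDNN BN backward is the likely source; if they still differ, the nondeterminism is presumably in the native CUDA path (e.g., an atomics-based reduction in the BN backward).

import torch

# Diagnostic only: disable cuDNN entirely so batch norm (and everything
# else) runs through PyTorch's native CUDA kernels, then rerun.
torch.backends.cudnn.enabled = False

for i in range(4):
    print("{}-th RUN (cuDNN disabled):".format(i))
    run(args)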

Environment

From conda list:

  • OS: Linux
  • Python: 3.8.5
  • PyTorch: 1.7.1 (py3.8_cuda10.2.89_cudnn7.6.5_0)

cc @ngimel @aocsa @mruberry @kurtamohler @albanD @jbschlosser

Labels: module: cuda, module: determinism, module: nn, module: norms and normalization, triaged
