
Reproducibility breaks down with weighted Cross Entropy loss #46024

@arsenzaryan

Description

Hello, the following code ceases to be reproducible when the weights passed to CrossEntropyLoss are non-integer. Here is an example:

import numpy as np
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader


h, w, in_ch, out_ch = 32, 32, 3, 5
class Dtst(Dataset):
    def __init__(self, N=20):
        self.X = [torch.randn([in_ch, h, w], dtype=torch.float32) for _ in range(N)]
        self.Y = [torch.randint(low=0, high=out_ch, size=(h,w), dtype=torch.int64) for _ in range(N)]
        
    def __getitem__(self, ix):
        return self.X[ix], self.Y[ix]
    
    def __len__(self):
        return len(self.Y)


class Network(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Conv2d(in_channels=in_ch, out_channels=10, kernel_size=3, padding=1)
        self.drop = nn.Dropout2d(p=0.1)
        self.layer2 = nn.Conv2d(in_channels=10, out_channels=out_ch, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.layer2(self.drop(self.layer1(x)))
        return out

seed = 4
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True  # note: benchmark=True lets cuDNN autotune kernels, which can itself vary run to run
torch.backends.cudnn.deterministic = True
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)

dtst = Dtst()
model = Network()

device = 'cuda'
model.to(device)
class_weights = ((torch.arange(out_ch) + 1).float() ** 0.5).to(device)  # sqrt(1..out_ch): non-integer values

loss_fn = torch.nn.CrossEntropyLoss(weight=class_weights)
opt = torch.optim.Adam(model.parameters())

preds_dict = dict()
for e in range(1500):
    dtldr = DataLoader(dtst, batch_size=4)
    for x, y in dtldr:
        preds = model(x.to(device))
        loss = loss_fn(preds, y.to(device))

        opt.zero_grad()  # clear accumulated gradients before the backward pass
        loss.backward()
        opt.step()

        preds_argmax = preds.argmax(dim=1).flatten()
        preds_dict.update(Counter(preds_argmax.tolist()))

print(sorted(preds_dict.items(), key=lambda x: x[1]))
print(model.layer1.weight.data.norm(2).item())
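By "not reproducible" I mean the printed values differ between runs of the script. A small harness that makes the check explicit (a sketch; repro.py is a hypothetical name for the script above):

import subprocess

# Run the script twice in fresh processes and compare their stdout;
# on my setup the outputs differ when the sqrt weights are used.
outs = [subprocess.run(['python', 'repro.py'], capture_output=True, text=True).stdout
        for _ in range(2)]
print('identical runs:', outs[0] == outs[1])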

It's a simple network with a very basic Dataset and a straightforward training loop, yet this code is not reproducible. When I remove the **0.5 from class_weights, it becomes reproducible again. In other words, if the class weight values are genuinely fractional floats, rather than integers cast to float, the code is not reproducible; the sketch below isolates this.

Also, the problem exists only on CUDA: if the device is set to 'cpu', the code is reproducible again!
I'm running this on Ubuntu 18. My environment is the following:
pytorch 1.6.0
cudatoolkit 10.1.243
numpy 1.19.1
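In case it matters for triage: my understanding is that PyTorch 1.7+ adds an experimental global switch that raises on ops lacking a deterministic implementation instead of silently diverging; it is not available on the 1.6.0 above. A sketch:

import torch

# Experimental in PyTorch 1.7+ (not in the 1.6.0 used here, as far as I know).
# With this set, ops without a deterministic implementation raise a RuntimeError
# rather than silently producing run-to-run differences.
torch.set_deterministic(True)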

P.S. I've already opened the same issue on the PyTorch forum (https://discuss.pytorch.org/t/reproducibility-breaks-down-with-weighted-cross-entropy-loss/96632), but since it has gone unanswered for quite a while, I thought it might get more attention here.

cc @ngimel

Labels

module: cuda, module: determinism, module: loss, module: numerical-reproducibility, triaged
