
Non-deterministic results although I did everything needed for them to be deterministic #39849

Closed
Esaada opened this issue Jun 11, 2020 · 7 comments
Labels
module: cudnn (Related to torch.backends.cudnn, and CuDNN support)
module: determinism
module: numerical-reproducibility
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@Esaada

Esaada commented Jun 11, 2020

Non-deterministic output is a long-running discussion; I thought I had found the formula for deterministic results:

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(args.seed)

Apparently it's not enough.
The issue, of course, is that the results differ between runs even though training uses the same seed and the settings above.
Reproducing is easy. I'm attaching the code under "Additional context": save it as main.py, download the data from https://www.kaggle.com/nicapotato/womens-ecommerce-clothing-reviews, put the CSV in the same directory, and just run:
python3 main.py

(If I could upload the .py and .csv files directly, this would have been much easier.)
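
For reference, a couple of extra seeding calls I've seen recommended for multi-GPU machines; they are not part of the script below, and I'm not claiming they fix this particular case:

    torch.cuda.manual_seed_all(args.seed)          # seed all visible GPUs, not just the current device
    os.environ["PYTHONHASHSEED"] = str(args.seed)  # note: only takes effect if set before the interpreter starts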

Environment

Collecting environment information

...
PyTorch version: 1.5.0
Is debug build: No
CUDA used to build PyTorch: 10.2

OS: Ubuntu 18.04.2 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.10.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: TITAN V
GPU 1: TITAN V
GPU 2: TITAN V
GPU 3: TITAN V

Nvidia driver version: 440.33.01
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.4.2

Versions of relevant libraries:
[pip3] numpy==1.16.2
[pip3] torch==1.5.0
[pip3] torchfile==0.1.0
[pip3] torchnet==0.0.4
[pip3] torchvision==0.2.2
[conda] Could not collect


Additional context

The full code:

import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import argparse
import os
import re
import spacy

from collections import Counter
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import string
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
import random

parser = argparse.ArgumentParser(description='PyTorch CINIC10 Training')
parser.add_argument('--seed',type=int,  default=60)
parser.add_argument('--epochs',type=int,  default=90)

parser.add_argument('--save_path',type=str,  default="save_path/")
parser.add_argument('--batch_size',type=int,  default=1024)

best_acc = 0

class ReviewsDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.y = Y
    def __len__(self):
        return len(self.y)
    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx][0].astype(np.int32)), self.y[idx], self.X[idx][1]

def tokenize(text):
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    regex = re.compile('[' + re.escape(string.punctuation) + '0-9\\r\\t\\n]')  # remove punctuation and numbers
    nopunct = regex.sub(" ", text.lower())
    return [token.text for token in tok.tokenizer(nopunct)]
def encode_sentence(text, vocab2index, N=70):
    tokenized = tokenize(text)
    encoded = np.zeros(N, dtype=int)
    enc1 = np.array([vocab2index.get(word, vocab2index["UNK"]) for word in tokenized])
    length = min(N, len(enc1))
    encoded[:length] = enc1[:length]
    return encoded, length

def train_model(model, epochs, lr=0.001):
    global best_acc
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.Adam(parameters, lr=lr)
    for i in range(epochs):
        model.train()
        sum_loss = 0.0
        total = 0
        for batch_idx,(x, y, l) in enumerate(train_dl):
            x = x.long().cuda()
            y = y.long().cuda()
            l = l.cuda()
            y_pred = model(x,l)
            optimizer.zero_grad()
            loss = F.cross_entropy(y_pred, y)
            loss.backward()
            optimizer.step()
            sum_loss += loss.item() * y.shape[0]
            total += y.shape[0]
        val_loss, val_acc  = validation_metrics(model, val_dl)
        if val_acc > best_acc:
           best_acc = val_acc
        if i % 5 == 1:
            print("train loss %.3f, val loss %.3f, val accuracy %.3f" % (
            sum_loss / total, val_loss, val_acc))


def validation_metrics(model, valid_dl):
    model.eval()
    correct = 0
    total = 0
    sum_loss = 0.0
    sum_rmse = 0.0
    for x, y, l in valid_dl:
        x = x.long().cuda()
        y = y.long().cuda()
        l = l.cuda()
        y_hat = model(x, l)
        loss = F.cross_entropy(y_hat, y)
        pred = torch.max(y_hat, 1)[1]
        correct += (pred == y).float().sum()
        total += y.shape[0]
        sum_loss += loss.item() * y.shape[0]
        

    return sum_loss / total, correct / total


class LSTM_fixed_len(torch.nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, embedding_dim, padding_idx=0)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
        self.linear = nn.Linear(hidden_dim, 5)

    def forward(self, x, l):
        x = self.embeddings(x)
        lstm_out, (ht, ct) = self.lstm(x)
        return self.linear(ht[-1])
if __name__ == "__main__":
    args = parser.parse_args()
    
    if not os.path.exists(args.save_path):
      os.mkdir(args.save_path)
    #####
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    random.seed(args.seed)
    #####

    reviews = pd.read_csv("Reviews.csv")
    #print(reviews.shape)
    #Replacing Nan values
    reviews['Title'] = reviews['Title'].fillna('')
    reviews['Review Text'] = reviews['Review Text'].fillna('')

    reviews['review'] = reviews['Title'] + ' ' + reviews['Review Text']
    reviews = reviews[['review', 'Rating']]
    reviews.columns = ['review', 'rating']
    reviews['review_length'] = reviews['review'].apply(lambda x: len(x.split()))
    # changing ratings to 0-numbering
    zero_numbering = {1: 0, 2: 1, 3: 2, 4: 3, 5: 4}
    reviews['rating'] = reviews['rating'].apply(lambda x: zero_numbering[x])
    np.mean(reviews['review_length'])
    tok = spacy.load('en_core_web_sm')



    counts = Counter()
    for index, row in reviews.iterrows():
        counts.update(tokenize(row['review']))
    for word in list(counts):
        if counts[word] < 2:
            del counts[word]
    vocab2index = {"":0, "UNK":1}
    words = ["", "UNK"]
    for word in counts:
        vocab2index[word] = len(words)
        words.append(word)

    reviews['encoded'] = reviews['review'].apply(lambda x: np.array(encode_sentence(x,vocab2index )))
    reviews.head()
    Counter(reviews['rating'])
    X = list(reviews['encoded'])
    y = list(reviews['rating'])
    #haluka = int(len(X)*0.8)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=args.seed)#X[:haluka],X[haluka:],y[:haluka],y[haluka:]#
    train_ds = ReviewsDataset(X_train, y_train)
    valid_ds = ReviewsDataset(X_valid, y_valid)

     
    vocab_size = len(words)
    train_dl = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    val_dl = DataLoader(valid_ds, batch_size=args.batch_size)
    model_fixed = LSTM_fixed_len(vocab_size, 50, 50)

    train_model(model_fixed.cuda(), epochs=args.epochs, lr=0.01)
    print("Best accuracy:",best_acc)

cc @csarofeen @ptrblck

@cpchen

cpchen commented Jun 11, 2020

If you use torch 1.4, is the result still non-deterministic?

@xwang233
Collaborator

I see that you are using cudnn 7.4.2. Can you try with the latest cudnn 7.6.5 or v8? https://developer.nvidia.com/cudnn

This may or may not be related to LSTM. There is a known LSTM non-deterministic issue #35661, but that one only happens with non-zero dropout.

@mruberry mruberry added the module: determinism, module: numerical-reproducibility, triaged, and module: cudnn labels Jun 12, 2020
@Esaada
Author

Esaada commented Jun 15, 2020

Using PyTorch 1.4.0 solved the problem (at least in the current situation).
But it's not always possible to downgrade to 1.4.0, since sometimes there is a need for torchvision==0.6.0, which requires torch 1.5.0.
Thanks for your help.

@Esaada Esaada closed this as completed Jun 15, 2020
@cpchen

cpchen commented Jun 15, 2020

It might be worth reopening this bug, because I have had the same issue with torch 1.5 and cuDNN 7.6.5.

@xwang233
Collaborator

This might be a regression, since @Esaada mentioned that the 1.4 result is deterministic, but @cpchen mentioned 1.5 is not.

I'll check across cuDNN libraries and PyTorch versions and see if the problem still exists.

@xwang233 xwang233 reopened this Jun 15, 2020
@xwang233
Collaborator

After some tests, it seems this issue only exists on Volta (Titan V), but not on Turing (2070). I tested with CUDA 10.2 and cuDNN 7.6.5. The PyTorch version (master/1.5/1.4) doesn't matter.

This is very likely an LSTM issue, and I will check with the cuDNN team for a fix. As a temporary workaround, you can use CUDA_LAUNCH_BLOCKING=1 python script.py to get deterministic training results.

@xwang233
Collaborator

xwang233 commented Jun 17, 2020

This is a known issue in cuDNN 7.6.5 and v8, and it will be fixed in the next release. Here is the explanation and a workaround, in the Limitations section: https://docs.nvidia.com/deeplearning/sdk/cudnn-release-notes/rel_8.html#rel-800-Preview__section_qhc_jc1_5kb

Specifically, you can use
CUBLAS_WORKSPACE_CONFIG=:16:8 python script.py
or
CUBLAS_WORKSPACE_CONFIG=:4096:2 python script.py
to make the LSTM result deterministic.
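
If editing the launch command is inconvenient, the same variable should also be settable from inside the script, as long as it happens before the first cuBLAS call; a minimal sketch (the placement at the very top of main.py is my assumption):

    import os
    # Same workaround as the command lines above, applied programmatically.
    # Must run before the first cuBLAS call, e.g. before building the model.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:2"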

I'll close this issue. Please feel free to reopen it if there is anything new.

Edit: please track this issue at #35661.
