TorchScript pack_padded_sequence and pad_packed_sequence run time error #41869

Open
nlpconf opened this issue Jul 22, 2020 · 13 comments
Labels: oncall: jit
@nlpconf

nlpconf commented Jul 22, 2020

❓ Questions and Help

Hi, I am facing this problem and have been searching for an answer for a day. Can anyone help?

RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript, serialized code (most recent call last):
  File "code/__torch__.py", line 24, in forward
    output = torch.empty_like(sorted_indices0, dtype=4, layout=0, device=torch.device("cpu"), pin_memory=False, memory_format=0)
    _4 = torch.arange(0, 1, 1, dtype=None, layout=0, device=torch.device("cpu"), pin_memory=False)
    tensor1 = torch.scatter_(output, 0, sorted_indices0, _4)
              ~~~~~~~~~~~~~~ <--- HERE
    padded_output, lengths2 = torch._pad_packed_sequence(tensor, tensor0, True, 0., 5)
    tensor2 = torch.index_select(padded_output, 0, tensor1)
...
...
... 

RuntimeError: Expected index [4] to be smaller than self [4] apart from dimension 0 and to be smaller size than src [1]

The part of the code causing the problem is below. I was trying to use an RNN; I removed the RNN layer and the problem persists.

    def forward(self, input1, input2, mask_len):
        embedded1 = self.embedding(input1)
        packed_embeds_1 = pack_padded_sequence(embedded1,
                                               mask_len,
                                               batch_first=True,
                                               enforce_sorted=False)
        encoder_output_1 = pad_packed_sequence(packed_embeds_1,
                                               batch_first=True,
                                               total_length=self.max_length)

My code runs without a problem before I try to convert it to TorchScript, and the TorchScript version works fine until I add the RNN layer (using pack_padded_sequence and pad_packed_sequence).
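
To help isolate it, here is a minimal round-trip of just pack_padded_sequence/pad_packed_sequence that can be traced on its own (the function name and shapes below are only illustrative, not taken from my real code). My guess, going by the hard-coded arange(0, 1, 1) in the serialized code above, is that the batch size seen at trace time gets baked into the un-sorting scatter_, so calling the traced function with a different batch size would hit the same error; I have not confirmed this.

import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

def roundtrip(x, lengths):
    # pack and immediately unpack, as in forward() above
    packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
    padded, _ = pad_packed_sequence(packed, batch_first=True, total_length=x.size(1))
    return padded

x = torch.randn(2, 5, 8)            # [batch_size, max_length, embed_dim]
lengths = torch.tensor([5, 3])
traced = torch.jit.trace(roundtrip, (x, lengths))
# the same batch size as at trace time works; re-running with a different
# batch size is where I would expect the scatter_ failure to show up
out = traced(torch.randn(2, 5, 8), torch.tensor([4, 2]))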

cc @suo @gmagogsfm

@facebook-github-bot facebook-github-bot added the oncall: jit Add this issue/PR to JIT oncall triage queue label Jul 22, 2020
@github-actions github-actions bot added this to Need triage in JIT Triage Jul 22, 2020
@wconstab wconstab moved this from Need triage to In discussion in JIT Triage Jul 30, 2020
@wconstab
Contributor

wconstab commented Jul 30, 2020

@nlpconf Please provide more details about how to reproduce this issue, such as which input arguments to supply to the function. Ideally, provide a whole self-contained repro .py script.

@nlpconf
Author

nlpconf commented Jul 30, 2020

Hi @wconstab, thanks for replying. input1 and input2 are just 2-D encoded, padded text sequences of size [batch_size, max_length], for example [[1,2,3,4,5],[2,3,5,0,0], ...]. mask_len is the real text length of each sequence, for example [5,3,...]. The model is like this:

class Model(nn.Module):
    def __init__(self, vocab_size, max_length,embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.similarity_function = torch.nn.CosineSimilarity(dim =-1,eps=1e-6)
        self.max_length = max_length
        self.embed_dim = embed_dim
        self.fc = nn.Linear(self.max_length, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, input1, input2, mask_len):
        embedded1 = self.embedding(input1)
        packed_embeds_1 = pack_padded_sequence(embedded1,
                                               mask_len,
                                               batch_first=True,
                                               enforce_sorted=False)
        encoder_output_1 = pad_packed_sequence(packed_embeds_1,
                                               batch_first=True,
                                               total_length=self.max_length)
        embedded2 = self.embedding(input2)
        embedded1_tiled = torch.transpose(
            torch.reshape(encoder_output_1[0].repeat((1, self.max_length, 1)),
                          (-1, self.max_length, self.max_length, self.embed_dim)),
            1, 2)
        embedded2_tiled = torch.reshape(
            embedded2.repeat((1, self.max_length, 1)),
            (-1, self.max_length, self.max_length, self.embed_dim))
        similarity = self.similarity_function(embedded1_tiled,embedded2_tiled)
        similarity = torch.mean(similarity,dim = -1)
        return self.fc(similarity[0])

Without the steps using pack_padded_sequence and pad_packed_sequence, it works fine.
This is how I saved the model:

device = torch.device("cpu")
model.to(device)
text1, text2, label,x_length = generate_batch(valid_data[:1])
text1, text2, x_length = text1.to(device), text2.to(device),x_length.to(device)
traced_script_module = torch.jit.trace(model, (text1, text2,x_length))
traced_script_module.save(saved_model_path)
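
One more thought, in case it is relevant (just an assumption on my part, not something I have verified): torch.jit.trace records a single example run, so tracing on valid_data[:1] and then feeding the traced module bigger batches could be part of the problem. Scripting the module instead of tracing would avoid specializing on that one example, assuming the model compiles under torch.jit.script:

# sketch only; whether this model scripts cleanly still needs to be checked
scripted_module = torch.jit.script(model)
scripted_module.save(saved_model_path)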

@wanchaol
Contributor

wanchaol commented Aug 5, 2020

@nlpconf could you provide a small, self-contained script that I can run locally? It's hard for me to reproduce since many functions are missing; a smaller self-contained script would help me root-cause the issue faster :)

@nlpconf
Author

nlpconf commented Aug 6, 2020

@wanchaol

Here is a small piece of code I prepared with fake data. Please let me know if it works for you.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch.nn.functional as F

import time
import random


def generate_fake_data(n_data, vocab_size, n_class, max_length):
    data = []
    for i in range(n_data):
        length = random.randint(1, max_length)
        d = [random.randint(0, vocab_size - 1) for i in range(length)]
        d = d + [0] * (max_length - len(d))
        lab = random.randint(0, n_class - 1)
        data.append((d, lab, length))
    return data


class Model(nn.Module):
    def __init__(self, vocab_size, max_length, embed_dim, num_class):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.similarity_function = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
        self.max_length = max_length
        self.embed_dim = embed_dim
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, input1, mask_len):
        embedded1 = self.embedding(input1)
        packed_embeds_1 = pack_padded_sequence(embedded1,
                                               mask_len,
                                               batch_first=True,
                                               enforce_sorted=False)
        encoder_output_1, _ = pad_packed_sequence(packed_embeds_1,
                                               batch_first=True,
                                               total_length=self.max_length)
        encoder_output_1 = torch.mean(encoder_output_1, dim=1)

        output = self.fc(encoder_output_1)
        return output


def generate_batch(batch):
    label = torch.tensor([entry[1] for entry in batch])
    text = torch.tensor([entry[0] for entry in batch])
    x_length = torch.tensor([entry[2] for entry in batch])

    return text, label, x_length


def train_func(train_data, model):
    # Train the model
    torch.autograd.set_detect_anomaly(True)
    train_loss = 0
    train_acc = 0
    data = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                      collate_fn=generate_batch)
    for i, (text1, label, x_length) in enumerate(data):
        optimizer.zero_grad()
        text1, label, x_length = text1.to(device), label.to(device), x_length.to(device)
        output = model(text1, x_length)
        loss = criterion(output, label)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()
        train_acc += (output.argmax(1) == label).sum().item()
    scheduler.step()

    return train_loss / len(train_data), train_acc / len(train_data)


def test(data_, model):
    model.eval()
    loss = 0
    acc = 0
    data = DataLoader(data_, batch_size=BATCH_SIZE, collate_fn=generate_batch)
    for text1, label, x_length in data:
        text1, label, x_length = text1.to(device), label.to(device), x_length.to(device)
        with torch.no_grad():
            output = model(text1, x_length)
            batch_loss = criterion(output, label)  # keep the batch loss separate from the running total
            loss += batch_loss.item()
            acc += (output.argmax(1) == label).sum().item()

    return loss / len(data_), acc / len(data_)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embed_dim = 100
num_class = 2
N_EPOCHS = 1
BATCH_SIZE = 5
vocab_size = 1000
n_data = 10000
max_length = 10
all_data = generate_fake_data(n_data, vocab_size, num_class, max_length)

random.seed(19)
random.shuffle(all_data)
split_ratio = 0.8
num_data = len(all_data)
train_data, test_data = all_data[:int(num_data * split_ratio)], all_data[int(num_data * split_ratio):]
num_train_data = len(train_data)
train_data, valid_data = train_data[:int(num_train_data * split_ratio)], train_data[int(num_train_data * split_ratio):]

model = Model(vocab_size, max_length, embed_dim, num_class).to(device)
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=4.0)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

for epoch in range(N_EPOCHS):
    start_time = time.time()
    train_loss, train_acc = train_func(train_data, model)
    valid_loss, valid_acc = test(valid_data, model)

    secs = int(time.time() - start_time)
    mins = secs / 60
    secs = secs % 60

    print('train', train_loss, train_acc)
    print('valid', valid_loss, valid_acc)

model.eval()
saved_model_path = '../test_model/'
device = torch.device("cpu")
model.to(device)
text1, label, x_length = generate_batch(valid_data[:1])
text1, x_length = text1.to(device), x_length.to(device)
traced_script_module = torch.jit.trace(model, (text1, x_length))
traced_script_module.save(saved_model_path + 'model.pt')

model = torch.jit.load(saved_model_path + 'model.pt')
loss, acc = test(test_data, model)

@BartlomiejSkwira

@nlpconf have you resolved this issue in any way? I think I have something similar.

@superlyc

@nlpconf have you resolved this issue in any way? I think I have something similar.

Nope. Maybe @wanchaol can help.

@wanchaol wanchaol removed their assignment Jul 15, 2021
@wanchaol wanchaol moved this from In discussion to Need triage in JIT Triage Jul 15, 2021
@nikithamalgifb nikithamalgifb moved this from Need triage to In discussion in JIT Triage Jul 20, 2021
@nikithamalgifb
Contributor

@BartlomiejSkwira @nlpconf : Did you guys try out the above suggested workaround?

@BartlomiejSkwira

@nikithamalgifb I don't see any workaround in this thread. What did you mean exactly?

In my case the code 'X = X.cuda()' was causing an exception. My workaround was to comment out this part (it had to be commented out, not just disabled with an if). In the end I trained and served my model on CPU.

@nikithamalgifb
Contributor

I am referring to this: #41869 (comment), which @wanchaol had suggested earlier.

@BartlomiejSkwira

That comment is from @nlpconf, not @wanchaol, and it's the code that reproduces the issue?


@gmagogsfm
Contributor

@nikithamalgifb I don't see any workaround in this thread. What did you mean exactly?

In my case the code 'X = X.cuda()' was causing an exception. My workaround was to comment out this part (it had to be commented out, not just disabled with an if). In the end I trained and served my model on CPU.

This seems like an interesting case, especially since you mention that guarding it with an if doesn't work; that makes me suspect a mis-compilation. Would you mind sharing a bit more of your code and the changes you made?

@BartlomiejSkwira

@gmagogsfm I have removed the CUDA code entirely and trained/served my model on CPU.

Btw: I think all code paths have to be visited during compilation, so guarding the statement with an if is not enough.
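
A device-agnostic pattern that might make the guard unnecessary altogether (a sketch with illustrative names, not my actual code) is to take the device from an existing parameter instead of hard-coding .cuda():

import torch
import torch.nn as nn

class MyModule(nn.Module):  # illustrative module, not the real one
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 2)

    def forward(self, x):
        # follow whatever device the weights live on instead of calling x.cuda()
        x = x.to(self.fc.weight.device)
        return self.fc(x)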

@antoniojkim
Collaborator

I've just recently encountered this issue and have a minimal reproducer on the latest stable PyTorch (2.0.1).

The reproducer in question is a custom implementation of top-k mask construction:

def topk_mask(score, num_dense_elem):
    indices = torch.sort(score, dim=-1, descending=True).indices
    iota = torch.arange(
            indices.shape[-1],
            dtype=num_dense_elem.dtype,
            device=num_dense_elem.device,
    )
    in_topk = iota <= num_dense_elem
    indices = torch.where(in_topk, indices, indices[..., 0:1])

    mask = torch.zeros_like(score, dtype=torch.bool)
    mask = mask.scatter(-1, indices, torch.tensor(True))
    return mask

Calling this with eager PyTorch, I get a valid result:

In : topk_mask(torch.arange(10, dtype=torch.float32), torch.tensor(4))
Out: tensor([False, False, False, False, False,  True,  True,  True,  True,  True])

But when I trace the function with TorchScript, I get the following error:

In : torch.jit.trace(topk_mask, (torch.arange(10, dtype=torch.float32), torch.tensor(4)))
...
<ipython-input-31-8fd7c943c439> in topk_mask(score, num_dense_elem)
     10
     11     mask = torch.zeros_like(score, dtype=torch.bool)
---> 12     mask = mask.scatter(-1, indices, torch.tensor(True))
     13     return mask
     14

RuntimeError: Expected index [10] to be smaller than self [10] apart from dimension 0 and to be smaller size than src []

The error appears to be coming from this check:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/ScatterGatherChecks.h#L90-L100
It apparently requires each dimension of the src tensor to be at least as large as the corresponding dimension of the index tensor, which clearly isn't enforced the same way in eager execution; perhaps eager broadcasts the 0-d src? Either way, this seems to be a hole in TorchScript's support of a fairly simple usage of scatter.
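
A workaround that might sidestep the check (untested on my side, so treat it as an assumption) is to give scatter a src tensor with the same shape as the index tensor instead of a 0-d tensor, so the size comparison holds even in the traced graph:

# same effect as scattering a scalar True, but src now matches the index shape
mask = mask.scatter(-1, indices, torch.ones_like(indices, dtype=torch.bool))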
