[JIT] nn.ModuleList loses None objects inside it after scripting #39309

Open
WaterKnight1998 opened this issue May 31, 2020 · 1 comment
Labels: oncall: jit, triage review


WaterKnight1998 commented May 31, 2020

🐛 Bug

To Reproduce

Running torch.jit.script on my model succeeds, but the returned model fails at runtime.

On closer inspection, the scripted nn.ModuleList drops the None elements it contained.

The following code reproduces the error:

```python
import os
import sys
import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision import transforms
from PIL import Image


class TestBlock(nn.Module):
    def __init__(self):
        super(TestBlock, self).__init__()

        layers = []
        layers.append(None)
        layers.append(None)
        layers.append(nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
                               bias=False))
        self.layer = nn.ModuleList(layers)

    def forward(self, x):
        for aux in self.layer:
            print("ENTER")
            if aux is not None:
                x = aux(x)
                print("Not None")
        return x
```

Creating the model and scripting it:

```python
model = TestBlock()
traced_cell = torch.jit.script(model)
```

Testing the model with an image:

```python
# convert("RGB") guarantees 3 channels, matching Conv2d(3, ...) and the 3-value Normalize
img = Image.open("test.png").convert("RGB")

my_transforms = transforms.Compose([
    transforms.Resize((1002, 1002)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
img_input = my_transforms(img).unsqueeze(0).cpu()

res = model(img_input)
```

This prints:

```
ENTER
ENTER
ENTER
Not None
```

Output of the scripted version:

```python
res = traced_cell(img_input)
```

```
ENTER
Not None
```
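One plausible explanation, offered as an assumption rather than anything confirmed in this thread: in eager mode, iterating an nn.ModuleList walks every stored entry, including None, whereas PyTorch's child-module enumeration (`named_children()`) skips None entries, and the scripted loop appears to behave like the latter. The sketch below only demonstrates that eager-mode distinction; it does not inspect TorchScript internals.

```python
import torch.nn as nn

# Same layout as TestBlock.layer: two None placeholders, then one real module.
layers = nn.ModuleList([
    None,
    None,
    nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
])

# Plain iteration yields every stored entry, None included.
eager_items = list(layers)

# named_children() only yields entries that are actual modules.
children = list(layers.named_children())

print(len(eager_items))                # 3
print([name for name, _ in children])  # ['2']
```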

Expected behavior

The scripted model should produce the same output as the original model.
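A possible workaround, not confirmed by the maintainers here, is to avoid storing None in the ModuleList and use `nn.Identity()` as a no-op placeholder instead, so every entry is a real module. `PaddedBlock` below is a hypothetical variant of TestBlock; it counts loop iterations instead of printing, and uses a small random tensor in place of test.png.

```python
from typing import Tuple

import torch
import torch.nn as nn


class PaddedBlock(nn.Module):
    """Hypothetical variant of TestBlock: nn.Identity() stands in for None."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.ModuleList([
            nn.Identity(),  # placeholder instead of None
            nn.Identity(),  # placeholder instead of None
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
        visited = 0
        for aux in self.layer:
            visited += 1  # count every entry the loop sees
            x = aux(x)    # Identity returns its input unchanged
        return x, visited


model = PaddedBlock().eval()
scripted = torch.jit.script(model)

x = torch.randn(1, 3, 32, 32)  # small random input instead of test.png
eager_out, eager_visits = model(x)
script_out, script_visits = scripted(x)
print(eager_visits, script_visits)  # 3 3
print(torch.allclose(eager_out, script_out))
```

With this layout the `is not None` check is no longer needed, since nn.Identity simply returns its input.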

cc @suo

facebook-github-bot added the oncall: jit label May 31, 2020
WaterKnight1998 changed the title "[JIT] TorchScript Scripted model got fewer elements at nn.ModuleList" to "[JIT] nn.ModuleList with some Nones inside it losses them after Scripting" Jun 1, 2020
WaterKnight1998 changed the title "[JIT] nn.ModuleList with some Nones inside it losses them after Scripting" to "[JIT] nn.ModuleList loses None objects inside it after scripting" Jun 2, 2020
WaterKnight1998 (author) commented:

Tracing doesn't work either: it throws an assertion error, `assert(isinstance(orig, torch.nn.Module))`.
