[LTC] Fail to run testcase of latest lazy_tensor_core branch #65465

Open
leslie-fang-intel opened this issue Sep 22, 2021 · 2 comments
Labels: lazy (Lazy Tensor work items), module: lazy, triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

@leslie-fang-intel
Collaborator

leslie-fang-intel commented Sep 22, 2021

I built the latest lazy_tensor_core branch at commit 7f3d592.
With that build, my test case fails in ltm.mark_step():

import torch
import torch.nn as nn
import copy
import time
import lazy_tensor_core
import lazy_tensor_core.core.lazy_model as ltm

lazy_tensor_core._LAZYC._ltc_init_ts_backend()

class SimpleNet(torch.nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.conv = torch.nn.Conv2d(64, 128, (3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv2 = torch.nn.Conv2d(128, 128, (3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        self.conv3 = torch.nn.Conv2d(64, 128, (1, 1), stride=(2, 2), padding=(0, 0), bias=False)

    def forward(self, x):
        x1 = self.conv(x)
        y1 = self.conv2(x1)
        y2 = self.conv3(x)
        y = y1 + y2
        y = torch.flatten(y, start_dim=1)
        return y

lazy_device = ltm.lazy_device()

model = SimpleNet()
model.train().to(lazy_device)

x = torch.rand(64, 64, 3, 3, requires_grad=True).to(lazy_device)
y = model(x)
yg = torch.rand(64, 512).to(lazy_device)
loss = nn.CrossEntropyLoss()
output = loss(y, yg)
output.backward()

ltm.mark_step()

@wconstab @alanwaketan Could you help take a look?
Here is the detailed failure message:

Traceback (most recent call last):
  File "test_lazy.py", line 46, in <module>
    ltm.mark_step()
  File "/home/lesliefang/pytorch_1_7_1/lazy_tensor/pytorch/lazy_tensor_core/lazy_tensor_core/core/lazy_model.py", line 727, in mark_step
    wait=xu.getenv_as('LTC_SYNC_WAIT', bool, False))
ValueError: stoi
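
A note on reading this error: `stoi` is the message libstdc++ attaches when `std::stoi` throws, and pybind11 translates `std::invalid_argument` into a Python `ValueError`, so the failure likely originates on the C++ side of the call rather than in `xu.getenv_as` itself. One thing that may be worth ruling out is an environment variable that the backend parses as an integer but that holds a non-numeric value. The sketch below is only a diagnostic aid under that assumption: the helper `find_non_integer_env` and the `LTC_` prefix (guessed from `LTC_SYNC_WAIT`) are hypothetical, and the regular expression only approximates what `std::stoi` accepts.

```python
import os
import re

# Approximation of what std::stoi can start parsing:
# optional leading whitespace, an optional sign, then at least one digit.
_LEADING_INT = re.compile(r"\s*[+-]?\d+")


def find_non_integer_env(prefix="LTC_", environ=None):
    """Return {name: value} for variables starting with `prefix`
    whose value does not begin with an integer.

    `prefix` is an assumption; pass another prefix if the backend
    reads differently named variables.
    """
    env = os.environ if environ is None else environ
    return {
        name: value
        for name, value in env.items()
        if name.startswith(prefix) and _LEADING_INT.match(value) is None
    }


if __name__ == "__main__":
    # Print every suspicious variable found in the current environment.
    for name, value in sorted(find_non_integer_env().items()):
        print(f"{name}={value!r}")
```

Running this in the same shell as the failing test lists any `LTC_`-prefixed variables with values such as `true` or an empty string. An empty result does not rule out other causes.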

@ezyang added the lazy, module: lazy, and triaged labels Sep 22, 2021
@alanwaketan self-assigned this Sep 22, 2021
@alanwaketan added this to To do in Lazy Tensor Core via automation Sep 22, 2021
@alanwaketan moved this from To do to In progress in Lazy Tensor Core Sep 22, 2021
@alanwaketan moved this from In progress to To do in Lazy Tensor Core Sep 22, 2021
@alanwaketan
Collaborator

Yup, I'd absolutely love to investigate what's going on.

@leslie-fang-intel
Collaborator Author

@alanwaketan Thanks for taking a look at this issue. I suspect it may be related to the GCC or libc version: on another system (CentOS Linux release 8.4.2105) with GCC 8.4.1, I can't reproduce it.
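
To compare the two systems side by side, something like the following could be run on each one. It is a minimal sketch using only the Python standard library. The helper name `toolchain_summary` is made up for illustration, and it reports the `gcc` found on `PATH`, which is not necessarily the compiler that built lazy_tensor_core.

```python
import platform
import shutil
import subprocess


def toolchain_summary():
    """Collect basic OS, libc, and compiler details as strings."""
    libc_name, libc_version = platform.libc_ver()
    info = {
        "platform": platform.platform(),
        "python": platform.python_version(),
        # Compiler that built the Python interpreter, not PyTorch.
        "python_compiler": platform.python_compiler(),
        "libc": f"{libc_name} {libc_version}".strip(),
    }
    gcc = shutil.which("gcc")
    if gcc is None:
        info["gcc"] = "not found"
    else:
        result = subprocess.run(
            [gcc, "--version"], capture_output=True, text=True, check=False
        )
        lines = result.stdout.splitlines()
        # The first line of `gcc --version` carries the version string.
        info["gcc"] = lines[0] if lines else "unknown"
    return info


if __name__ == "__main__":
    for key, value in toolchain_summary().items():
        print(f"{key}: {value}")
```

PyTorch also ships `python -m torch.utils.collect_env`, which prints a fuller report (including the GCC version it detects) and is the usual thing to attach to an issue.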
