Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Ban passing torch view tensors into taichi kernel #4225

Merged
merged 1 commit into from
Feb 8, 2022

Conversation

ailzhang
Copy link
Contributor

@ailzhang ailzhang commented Feb 7, 2022

fixes #4208

Related issue = #

@netlify
Copy link

netlify bot commented Feb 7, 2022

✔️ Deploy Preview for docsite-preview ready!

🔨 Explore the source changes: 4e07bd5

🔍 Inspect the deploy log: https://app.netlify.com/sites/docsite-preview/deploys/62012d0c3f51f600073e13f7

😎 Browse the preview: https://deploy-preview-4225--docsite-preview.netlify.app

Copy link
Member

@victoriacity victoriacity left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM Thanks!

@strongoier strongoier merged commit 35a5cc1 into taichi-dev:master Feb 8, 2022
@ifsheldon
Copy link

This is a breaking change, so I suggest at least bump the version by +0.1 according to semantic versioning

@ifsheldon
Copy link

ifsheldon commented Feb 9, 2022

Besides, this commit affects field.from_torch(), which impacts beyond the scope of ti.ndarray mentioned in #4225.

Moreover, the impact is weird. Can you test this case?

@ti.kernel
def ti_print(a: ti.template()):
    for i in ti.grouped(a):
        print(f"{i[0]} {i[1]}, {a[i]}")


def test_view():
    ti.init(ti.cpu)
    a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
    a_t = a.T
    assert a_t._is_view()
    a_field = ti.field(ti.f32, shape=a.shape)
    a_field.from_torch(a_t)  # ok, but why?
    ti_print(a_field)
    a_complex = torch.ones(2, 2, dtype=torch.cfloat)
    a_real_view = torch.view_as_real(a_complex)
    assert a_real_view._is_view()
    complex_field = ti.field(ti.f32, shape=a_real_view.shape)
    complex_field.from_torch(a_real_view)  # not ok
    ti_print(complex_field)

This impact on field.from_torch() breaks my library that depends on Taichi, but I don't know if it's expected.

Can you please check the commit? @ailzhang

My test log is

tests/test_ti.py:36 (test_view)
def test_view():
        ti.init(ti.cpu)
        a = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
        a_t = a.T
        assert a_t._is_view()
        a_field = ti.field(ti.f32, shape=a.shape)
        a_field.from_torch(a_t)  # ok, but why?
        ti_print(a_field)
        a_complex = torch.ones(2, 2, dtype=torch.cfloat)
        a_real_view = torch.view_as_real(a_complex)
        assert a_real_view._is_view()
        complex_field = ti.field(ti.f32, shape=a_real_view.shape)
>       complex_field.from_torch(a_real_view)  # not ok

test_ti.py:49: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/util.py:226: in wrapped
    return func(*args, **kwargs)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/field.py:146: in from_torch
    self.from_numpy(arr.contiguous())
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/util.py:226: in wrapped
    return func(*args, **kwargs)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/field.py:264: in from_numpy
    ext_arr_to_tensor(arr, self)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/kernel_impl.py:740: in wrapped
    return primal(*args, **kwargs)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/shell.py:37: in new_call
    ret = old_call(*args, **kwargs)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/kernel_impl.py:671: in __call__
    return self.compiled_functions[key](*args)
../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/kernel_impl.py:579: in func__
    tmp, torch_callbacks = self.get_torch_callbacks(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

self = <taichi.lang.kernel_impl.Kernel object at 0x7fbfcc516610>
v = tensor([[[1., 0.],
         [1., 0.]],

        [[1., 0.],
         [1., 0.]]])
has_torch = True, is_ndarray = False

    def get_torch_callbacks(self, v, has_torch, is_ndarray=True):
        callbacks = []
    
        def get_call_back(u, v):
            def call_back():
                u.copy_(v)
    
            return call_back
    
        assert has_torch
        assert isinstance(v, torch.Tensor)
        if v._is_view():
>           raise ValueError(
                "Torch view tensors are not supported, please call tensor.clone() before passing it into taichi kernel."
            )
E           ValueError: Torch view tensors are not supported, please call tensor.clone() before passing it into taichi kernel.

../../../miniconda3/envs/build_ti/lib/python3.9/site-packages/taichi-0.8.12-py3.9-linux-x86_64.egg/taichi/lang/kernel_impl.py:491: ValueError

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Views of PyTorch tensors are not observed in ti.any_arr() access
4 participants