Internal assert in NTK computation #417

Closed
zou3519 opened this issue Jan 26, 2022 · 3 comments

zou3519 commented Jan 26, 2022

```python
import torch
from functorch import jvp, vjp, vmap

torch.manual_seed(0)
x_train = torch.randn(1, 2)
x_test = torch.randn(1, 2)
params = (torch.tensor([[1., 0.5], [1., 1.]]),)

def fnet_single(params, x):
    # Single-example "network": one linear layer followed by a sigmoid.
    return (x @ params[0]).sigmoid()

def empirical_ntk_implicit(fnet_single, params, x1, x2):
    def get_ntk(x1, x2):
        def push_fnet_x1(params):
            return fnet_single(params, x1)

        def push_fnet_x2(params):
            return fnet_single(params, x2)

        output, vjp_fn = vjp(push_fnet_x1, params)

        def get_ntk_slice(vec):
            # J(x1)^T vec via vjp, then J(x2) @ (J(x1)^T vec) via jvp:
            # one row of J(x1) @ J(x2)^T, where J is the Jacobian w.r.t. params.
            vjps = vjp_fn(vec)
            _, jvps = jvp(push_fnet_x2, (params,), vjps)
            return jvps

        # One standard-basis vector per output element of fnet_single(params, x1).
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        result = vmap(get_ntk_slice)(basis)
        return result

    # Outer vmap walks the rows of x1, inner vmap walks the rows of x2.
    return vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

res2 = empirical_ntk_implicit(fnet_single, params, x_train, x_test)
```

gives:

```
RuntimeError: self_bdim.has_value()INTERNAL ASSERT FAILED at "/private/home/rzou/functorch2/functorch/csrc/BatchRulesViews.cpp":389, please report a bug to PyTorch.
```
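One way to sanity-check what `empirical_ntk_implicit` should return is to build the Jacobians explicitly and contract them. With `J1` and `J2` the Jacobians of `fnet_single` with respect to the weight at `x1` and `x2` (flattened to `[outputs, parameters]`), row `k` of the implicit result is `J2 @ J1[k]`, so each per-pair block should equal `J1 @ J2.T`. The sketch below assumes a PyTorch release that ships `torch.func` (2.0 or later, where the functorch transforms now live); `ntk_implicit` and `ntk_reference` are hypothetical helper names, and the reference swaps `vmap` for `torch.autograd.functional.jacobian` plus plain Python loops.

```python
import torch
from torch.autograd.functional import jacobian
from torch.func import jvp, vjp, vmap

torch.manual_seed(0)
x_train = torch.randn(1, 2)
x_test = torch.randn(1, 2)
params = (torch.tensor([[1., 0.5], [1., 1.]]),)


def fnet_single(params, x):
    # Same toy network as the repro: one linear layer and a sigmoid.
    return (x @ params[0]).sigmoid()


def ntk_implicit(params, x1, x2):
    # Hypothetical port of empirical_ntk_implicit to torch.func.
    def get_ntk(a, b):
        output, vjp_fn = vjp(lambda p: fnet_single(p, a), params)

        def get_ntk_slice(vec):
            vjps = vjp_fn(vec)  # (J(a)^T vec,) as a 1-tuple matching (params,)
            _, jvps = jvp(lambda p: fnet_single(p, b), (params,), vjps)
            return jvps  # J(b) @ J(a)^T vec

        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device)
        return vmap(get_ntk_slice)(basis)

    return vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)


def ntk_reference(params, x1, x2):
    # Hypothetical loop-based reference: explicit Jacobians, no vmap.
    def jac(x):
        # Jacobian of the output (shape [2]) w.r.t. the weight (shape [2, 2]),
        # flattened to [outputs, parameters] = [2, 4].
        j = jacobian(lambda w: fnet_single((w,), x), params[0])
        return j.reshape(j.shape[0], -1)

    rows = []
    for a in x1:
        rows.append(torch.stack([jac(a) @ jac(b).T for b in x2]))
    return torch.stack(rows)


implicit = ntk_implicit(params, x_train, x_test)
reference = ntk_reference(params, x_train, x_test)
print(implicit.shape)  # torch.Size([1, 1, 2, 2])
print(torch.allclose(implicit, reference, atol=1e-5))
```

Both versions flatten the weight in the same order, so the contraction over parameters lines up entry for entry.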

zou3519 commented Jan 28, 2022

Minimal repro:

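```python
import torch
from functorch import vmap

torch.manual_seed(0)
x = torch.randn(5, 2, 2, 2)
y = torch.randn(5, 7, 2)

# Outer vmap maps dim 0 of both x and y; inner vmap maps only y (x is unbatched there).
vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y)
```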
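The nested `vmap` should be equivalent to two explicit loops: the outer level pairs `x[i]` with `y[i]`, and the inner level keeps `x[i]` fixed while walking the rows of `y[i]`, so each element is `torch.matmul(x[i], y[i, j])`, a stack of two 2x2 matrix-vector products with shape `[2, 2]`. At the inner level `x` carries no batch dimension of its own, which is presumably the case the failing `self_bdim.has_value()` assertion did not expect. The sketch below writes out that loop reference and compares it with `torch.func.vmap`, assuming a PyTorch release (2.0 or later) in which this bug is fixed; `loop_reference` is a hypothetical name.

```python
import torch
from torch.func import vmap

torch.manual_seed(0)
x = torch.randn(5, 2, 2, 2)
y = torch.randn(5, 7, 2)


def loop_reference(x, y):
    # Outer loop mirrors the outer vmap (dim 0 of both inputs); inner loop mirrors
    # vmap(..., in_dims=(None, 0)), which reuses x[i] for every row of y[i].
    return torch.stack([
        torch.stack([torch.matmul(x[i], y[i, j]) for j in range(y.shape[1])])
        for i in range(x.shape[0])
    ])


batched = vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y)
expected = loop_reference(x, y)
print(batched.shape)  # torch.Size([5, 7, 2, 2])
print(torch.allclose(batched, expected, atol=1e-6))
```

A small `atol` is used because the batched kernel and the per-element calls may round slightly differently.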

zou3519 commented Feb 2, 2022

Ugh, I need to rewrite our codegen to fix this.

zou3519 commented Feb 25, 2022

Fixed in #530.

zou3519 closed this as completed Feb 25, 2022