
Update neural_tangent_kernels.ipynb #788

Merged 1 commit into pytorch:main on May 10, 2022
Conversation

@ain-soph commented May 9, 2022

Fix a small bug: the `result` operand is missing from the `torch.einsum` calls.

```python3
    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM')         # should be torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK')        # should be torch.einsum('NMKK->NMK', result)
```
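
For context, a minimal check (my own illustration, with made-up shapes) that the corrected calls perform the intended contractions on an `(N, M, K, K)` result tensor:

```python3
import torch

# Hypothetical shapes, only for illustration: N and M samples, K outputs.
N, M, K = 3, 4, 5
result = torch.randn(N, M, K, K)

trace = torch.einsum('NMKK->NM', result)    # sum over the K x K diagonal
diag = torch.einsum('NMKK->NMK', result)    # keep the K x K diagonal

assert torch.allclose(trace, result.diagonal(dim1=-2, dim2=-1).sum(-1))
assert torch.allclose(diag, result.diagonal(dim1=-2, dim2=-1))
```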

@ain-soph commented May 9, 2022

I followed the tutorial and implemented a version without using functorch. What's the advantage of using functorch?

import torch
import torch.nn as nn
from torch.nn.utils import _stateless

import functools

def ntk(module: nn.Module, input1: torch.Tensor, input2: torch.Tensor,
        parameters: dict[str, nn.Parameter] | None = None,
        compute: str = 'full') -> torch.Tensor:
    einsum_expr: str = ''
    match compute:
        case 'full':
            einsum_expr = 'Naf,Mbf->NMab'
        case 'trace':
            einsum_expr = 'Naf,Maf->NM'
        case 'diagonal':
            einsum_expr = 'Naf,Maf->NMa'
        case _:
            raise ValueError(compute)

    if parameters is None:
        parameters = dict(module.named_parameters())
    keys, values = zip(*parameters.items())

    def func(*params: torch.Tensor, _input: torch.Tensor | None = None):
        # Run the module with the given parameter tensors substituted in.
        _output: torch.Tensor = _stateless.functional_call(
            module, {n: p for n, p in zip(keys, params)}, _input)
        return _output  # (N, C)

    # Jacobians of the outputs w.r.t. every parameter tensor, one entry per parameter.
    jac1: tuple[torch.Tensor, ...] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input1), values, vectorize=True)
    jac2: tuple[torch.Tensor, ...] = torch.autograd.functional.jacobian(
        functools.partial(func, _input=input2), values, vectorize=True)
    # Flatten each Jacobian to (N, C, p) / (M, C, p), contract over the parameter
    # dimension per parameter, then sum the contributions.
    jac1 = [j.flatten(2) for j in jac1]
    jac2 = [j.flatten(2) for j in jac2]
    result = torch.stack([torch.einsum(einsum_expr, j1, j2)
                          for j1, j2 in zip(jac1, jac2)]).sum(0)
    return result
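
For comparison, here is roughly what the functorch-based version in the tutorial looks like (a sketch assuming the functorch `make_functional`/`vmap`/`jacrev` API of that release); its main draw is that `vmap(jacrev(...))` gives per-sample Jacobians as one composable transform instead of going through `torch.autograd.functional.jacobian`:

```python3
import torch
import torch.nn as nn
from functorch import make_functional, vmap, jacrev

def ntk_functorch(module: nn.Module, x1: torch.Tensor, x2: torch.Tensor,
                  compute: str = 'full') -> torch.Tensor:
    # Turn the module into a pure function of (params, input).
    fnet, params = make_functional(module)

    def fnet_single(params, x):
        # Single-sample forward pass; vmap adds the batch dimension back.
        return fnet(params, x.unsqueeze(0)).squeeze(0)

    # Per-sample Jacobians w.r.t. every parameter tensor, batched over the inputs.
    jac1 = vmap(jacrev(fnet_single), (None, 0))(params, x1)
    jac2 = vmap(jacrev(fnet_single), (None, 0))(params, x2)
    jac1 = [j.flatten(2) for j in jac1]   # (N, C, p)
    jac2 = [j.flatten(2) for j in jac2]   # (M, C, p)

    einsum_expr = {'full': 'Naf,Mbf->NMab',
                   'trace': 'Naf,Maf->NM',
                   'diagonal': 'Naf,Maf->NMa'}[compute]
    return torch.stack([torch.einsum(einsum_expr, j1, j2)
                        for j1, j2 in zip(jac1, jac2)]).sum(0)
```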


@Chillee commented May 10, 2022

Thanks!

@Chillee merged commit a7a8e66 into pytorch:main on May 10, 2022
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022