This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Wrong jacobian from slice copy operation #1007

@KagamineLenOffical

Description

Hi,
Thanks for your great work. I'm working on accelerating a minimization algorithm with this code and found a problem.
Here is a small example that reproduces the condition:

import torch
from functorch import jacfwd

class FunctionWrapper(object):
    def __init__(self, fun):  # note that fun can be a lambda expression
        self._fun = fun
        self.fevals = 0
        self.cur_x = None
        self.cur_y = None
        self.input_segment = None
        self.constant_dict = None

    def __call__(self, v, **kwargs):
        self.fevals += 1
        # self.cur_x = v.view(-1).requires_grad_()
        self.cur_x = v.view(-1)
        if self.input_segment is not None:
            # Split the flat input vector back into the original arguments.
            x = []
            for i, _ in enumerate(self.input_segment):
                if i > 0:
                    x.append(v[self.input_segment[i - 1]:self.input_segment[i]])
            self.cur_y = self._fun(*x, **kwargs)
        else:
            self.cur_y = self._fun(self.cur_x, **kwargs)
        return self.cur_y

    def input_constructor(self, *args):  # if there is time, convert to kwargs input
        # Flatten and concatenate the arguments into one input vector,
        # remembering the segment boundaries for later splitting.
        parts = []
        self.input_segment = [0]
        cur = 0
        for v in args:
            nv = v.view(-1)
            parts.append(nv)
            cur += nv.size()[0]
            self.input_segment.append(cur)
        x = torch.concat(parts).detach().requires_grad_()
        return x

if __name__ == '__main__':
    xx_ = torch.tensor([4.,5.,6.])
    yy_ = torch.tensor([7.,8.])
    def func(*args):
        #x_ = torch.tensor([1.,2.,3.,4.])
        xx_[:2] = args[0]
        y = args[1]
        return torch.vstack([(xx_**2).sum(),(y**3).sum()])
    funcc = FunctionWrapper(func)
    xx = funcc.input_constructor(xx_[:2],yy_)
    print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc)(xx))

The FunctionWrapper is for counting function calls and splitting/merging the input values.
The output is:

tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]])
tensor([[[  0.,   0.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]], grad_fn=<ViewBackward0>)

You can see that the jacfwd result differs from torch.autograd.functional.jacobian: every entry with respect to the inputs that are slice-copied into the captured tensor xx_ comes out as zero.
But if we move the tensor into the function (creating x_ inside func instead of reusing the outer xx_), we get the correct result.

if __name__ == '__main__':
    xx_ = torch.tensor([4.,5.,6.],requires_grad=True)
    yy_ = torch.tensor([7.,8.],requires_grad=True)
    def func(*args):
        x_ = torch.tensor([1.,2.,3.,4.])
        x_[:2] = args[0]
        y = args[1]
        return torch.vstack([(x_**2).sum(),(y**3).sum()])
    funcc = FunctionWrapper(func)
    xx = funcc.input_constructor(xx_[:2],yy_)
    print(torch.autograd.functional.jacobian(funcc,xx))
    print(jacfwd(funcc)(xx))

With the output:

tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]])
tensor([[[  8.,  10.,   0.,   0.]],

        [[  0.,   0., 147., 192.]]], grad_fn=<ViewBackward0>)

Am I misusing this function somehow? I can work around the problem by splitting the input into two separate inputs first, but I'm looking for a more general solution.
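
For reference, a minimal sketch of what that split-input workaround could look like, assuming the goal is to pass the two pieces as separate arguments (via argnums) and rebuild the full vector functionally instead of slice-copying into the captured xx_; func_split and full_x are illustrative names, not part of the original code:

import torch
from functorch import jacfwd

xx_ = torch.tensor([4., 5., 6.])
yy_ = torch.tensor([7., 8.])

def func_split(x_part, y_part):
    # Rebuild the full vector out-of-place instead of writing into the
    # captured xx_ inside the traced function.
    full_x = torch.cat([x_part, xx_[2:]])
    return torch.vstack([(full_x ** 2).sum(), (y_part ** 3).sum()])

# Differentiate with respect to both arguments separately.
jac_x, jac_y = jacfwd(func_split, argnums=(0, 1))(xx_[:2], yy_)
print(jac_x)  # jacobian w.r.t. x_part; should contain the 8. and 10. entries
print(jac_y)  # jacobian w.r.t. y_part; should contain the 147. and 192. entries

This avoids the in-place slice copy entirely, but it gives up the single flat input vector that FunctionWrapper is built around, which is why it is only a workaround rather than a general fix.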

Metadata

Labels

high priority: These issues are at the top of mind for us.