In [5]:
import torch

# Per-layer state shared by AutogradLinearAttn.forward and .backward.
# It lives at module level because backward() only receives gradients, so the
# running fast weight cannot be handed to it as an extra non-gradient argument.
# Callers must give both lists one entry per layer (0.0 or a zero tensor of
# shape (B, v_dim, dim)) before each forward/backward run.
fast_weight = []
grad_fast_weight = []

class AutogradLinearAttn(torch.autograd.Function):
    """One step of linear attention with an additive fast-weight memory.

    Forward adds the outer product ``v k^T`` to ``fast_weight[layer_id]`` and
    returns that updated weight applied to ``q``.

    Backward does not store the fast weight per step. Instead it reads the
    current global weight, then subtracts this step's outer product again.
    The value it reads therefore equals the forward-time weight only if the
    backward of every later step on the same layer has already run, i.e. the
    steps must be back-propagated in the reverse of their forward order.

    ``grad_fast_weight[layer_id]`` is accumulated across backward calls and is
    never cleared here; the caller has to reset it between runs.
    """

    @staticmethod
    def forward(ctx, k, v, q, layer_id):
        """Update the fast weight of ``layer_id`` and query it.

        Args:
            k: key, shape (B, dim).
            v: value, shape (B, v_dim).
            q: query, shape (B, dim).
            layer_id: index into the global ``fast_weight`` list.

        Returns:
            Tensor of shape (B, v_dim): the updated fast weight times ``q``.
        """
        global fast_weight
        ctx.save_for_backward(k, v, q)
        ctx.layer_id = layer_id

        # Outer product v k^T, shape (B, v_dim, dim).
        weight_update = torch.bmm(v.unsqueeze(2), k.unsqueeze(1))
        fast_weight[layer_id] += weight_update
        output = torch.bmm(fast_weight[layer_id], q.unsqueeze(2)).squeeze(2)

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Return gradients for ``(k, v, q, layer_id)``.

        Args:
            grad_out: gradient w.r.t. the forward output, shape (B, v_dim).

        Returns:
            ``(grad_k, grad_v, grad_q, None)``; ``layer_id`` gets no gradient.
        """
        global fast_weight, grad_fast_weight
        k, v, q = ctx.saved_tensors
        layer_id = ctx.layer_id

        # grad_q = W^T grad_out, using the weight as it was in this step's
        # forward. Squeeze only the singleton row axis: a bare squeeze() would
        # also drop the batch axis when B == 1 (or the feature axis when
        # dim == 1) and autograd would reject the mis-shaped gradient.
        grad_q = torch.bmm(grad_out.unsqueeze(1), fast_weight[layer_id]).squeeze(1)

        # Undo this step's update so the previous step sees its own weight.
        fast_weight[layer_id] -= torch.bmm(v.unsqueeze(2), k.unsqueeze(1))

        # Accumulate the gradient flowing into the fast weight from this step
        # on top of what the later steps already contributed.
        grad_fast_weight[layer_id] += torch.bmm(grad_out.unsqueeze(2), q.unsqueeze(1))

        # The update was v k^T, so grad_k = grad_W^T v and grad_v = grad_W k.
        grad_k = torch.bmm(v.unsqueeze(1), grad_fast_weight[layer_id]).squeeze(1)
        grad_v = torch.bmm(grad_fast_weight[layer_id], k.unsqueeze(2)).squeeze(2)

        return grad_k, grad_v, grad_q, None

In [6]:
# Problem sizes for the check: batch, key/query width, value width, time steps.
bsz = 3
dim = 5
v_dim = 2
steps = 10

# Double precision so the comparison against the reference gradients below
# can rely on the default allclose tolerances.
dtype = torch.double

# Seed the RNG so a fresh-kernel rerun draws the same inputs.
torch.manual_seed(0)

# Reset the global fast-weight state: a single layer (index 0) starting at zero.
# This has to be redone before every run because backward accumulates into
# grad_fast_weight without ever clearing it.
fast_weight = [0.0]
grad_fast_weight = [0.0]

k = torch.randn(steps, bsz, dim, requires_grad=True, dtype=dtype)
v = torch.randn(steps, bsz, v_dim, requires_grad=True, dtype=dtype)
q = torch.rand([steps, bsz, dim], requires_grad=True, dtype=dtype)

func = AutogradLinearAttn.apply

output = torch.zeros(bsz, v_dim, dtype=dtype)

# Feed the steps in order; the custom backward expects to be unrolled in the
# reverse of this order.
for i in range(steps):
    output += func(k[i], v[i], q[i], 0)

loss = output.sum()
loss.backward()

# Reference: the same recurrence written with ordinary differentiable ops,
# on independent copies of the inputs.
k_ref = k.detach().clone().requires_grad_(True)
v_ref = v.detach().clone().requires_grad_(True)
q_ref = q.detach().clone().requires_grad_(True)

ref_weight = torch.zeros(bsz, v_dim, dim, dtype=dtype)
ref_output = torch.zeros(bsz, v_dim, dtype=dtype)
for i in range(steps):
    ref_weight = ref_weight + torch.bmm(v_ref[i].unsqueeze(2), k_ref[i].unsqueeze(1))
    ref_output = ref_output + torch.bmm(ref_weight, q_ref[i].unsqueeze(2)).squeeze(2)
ref_output.sum().backward()

# The hand-written backward must agree with autograd on every input.
assert torch.allclose(output, ref_output), "forward output mismatch"
assert torch.allclose(k.grad, k_ref.grad), "grad_k mismatch"
assert torch.allclose(v.grad, v_ref.grad), "grad_v mismatch"
assert torch.allclose(q.grad, q_ref.grad), "grad_q mismatch"
    


torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
torch.Size([3, 5]) torch.Size([3, 2, 5])
