In [1]:
from sympy import *

In [2]:
var('A1 A2 A3 A4')


def _decay_matrix(cumsums):
    """Build the strictly lower-triangular symbolic decay matrix for `cumsums`.

    Entry (row, col) is exp(cumsums[row] - cumsums[col]) when col < row and 0
    on or above the diagonal.
    """
    return Matrix([
        [exp(row_sum - col_sum) if col < row else 0
         for col, col_sum in enumerate(cumsums)]
        for row, row_sum in enumerate(cumsums)
    ])


# Full 4-step sequence, then the same sequence split into two 2-step chunks.
A_cumsum = _decay_matrix([0, A1, A2, A3, A4])
A_cumsum1 = _decay_matrix([0, A1, A2])
A_cumsum2 = _decay_matrix([0, A3, A4])

In [3]:
for decay_matrix in (A_cumsum, A_cumsum1, A_cumsum2):
    pprint(decay_matrix)

⎡ 0       0          0          0      0⎤
⎢                                       ⎥
⎢ A₁                                    ⎥
⎢ℯ        0          0          0      0⎥
⎢                                       ⎥
⎢ A₂   -A₁ + A₂                         ⎥
⎢ℯ    ℯ              0          0      0⎥
⎢                                       ⎥
⎢ A₃   -A₁ + A₃   -A₂ + A₃              ⎥
⎢ℯ    ℯ          ℯ              0      0⎥
⎢                                       ⎥
⎢ A₄   -A₁ + A₄   -A₂ + A₄   -A₃ + A₄   ⎥
⎣ℯ    ℯ          ℯ          ℯ          0⎦
⎡ 0       0      0⎤
⎢                 ⎥
⎢ A₁              ⎥
⎢ℯ        0      0⎥
⎢                 ⎥
⎢ A₂   -A₁ + A₂   ⎥
⎣ℯ    ℯ          0⎦
⎡ 0       0      0⎤
⎢                 ⎥
⎢ A₃              ⎥
⎢ℯ        0      0⎥
⎢                 ⎥
⎢ A₄   -A₃ + A₄   ⎥
⎣ℯ    ℯ          0⎦


In [6]:
var('hi, hw, hx, hy, hz')
# The initial state `hi` is fixed to zero and used as the leading entry of
# each state vector.
hi = 0
states = Matrix([hi, hw, hx, hy, hz])  # full sequence
states1 = Matrix([hi, hw, hx])         # first chunk
states2 = Matrix([hi, hy, hz])         # second chunk

In [7]:
for decay_matrix, state_vec in ((A_cumsum, states),
                                (A_cumsum1, states1),
                                (A_cumsum2, states2)):
    pprint(decay_matrix.multiply(state_vec))

⎡                    0                     ⎤
⎢                                          ⎥
⎢                    0                     ⎥
⎢                                          ⎥
⎢                   -A₁ + A₂               ⎥
⎢               hw⋅ℯ                       ⎥
⎢                                          ⎥
⎢           -A₁ + A₃       -A₂ + A₃        ⎥
⎢       hw⋅ℯ         + hx⋅ℯ                ⎥
⎢                                          ⎥
⎢    -A₁ + A₄       -A₂ + A₄       -A₃ + A₄⎥
⎣hw⋅ℯ         + hx⋅ℯ         + hy⋅ℯ        ⎦
⎡     0      ⎤
⎢            ⎥
⎢     0      ⎥
⎢            ⎥
⎢    -A₁ + A₂⎥
⎣hw⋅ℯ        ⎦
⎡     0      ⎤
⎢            ⎥
⎢     0      ⎥
⎢            ⎥
⎢    -A₃ + A₄⎥
⎣hy⋅ℯ        ⎦


# State passing test code

In [13]:
import torch.nn.functional as F
def state_passing_ref(states, dA_chunk_cumsum, initial_states=None):
    """
    Reference (quadratic, einsum-based) state passing across chunks.

    The initial state is prepended as an extra leading "chunk", the per-chunk
    increments are cumulatively summed (with a leading 0), and the masked
    lower-triangular decay matrix exp(cum[z] - cum[c]) (c <= z) is applied.
    This is equivalent to the recurrence
        out[:, 0]     = initial_states
        out[:, k + 1] = exp(dA_chunk_cumsum[:, :, k]) * out[:, k] + states[:, k]

    Argument:
        states: (batch, nchunks, nheads, dim)
        dA_chunk_cumsum: (batch, nheads, nchunks) -- per-chunk increments; the
            cumulative sum is taken inside this function
        initial_states: (batch, nheads, dim); zeros when None
    Return:
        out: (batch, nchunks, nheads, dim) -- state entering each chunk
        final_states: (batch, nheads, dim) -- state after the last chunk
    """
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, 0])
    # Prepend the initial state: (batch, nchunks + 1, nheads, dim).
    states = torch.cat([initial_states.unsqueeze(1), states], dim=1)
    # Leading zero so cum[0] = 0 lines up with the prepended initial state.
    dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
    dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
    nchunks = dA_chunk_cumsum.shape[-1]  # original nchunks + 1 after padding
    # (batch, nheads, nchunks + 1, nchunks + 1): cum[z] - cum[c]
    segment_sum = dA_chunk_cumsum[:, :, :, None] - dA_chunk_cumsum[:, :, None, :]
    causal_mask = torch.tril(
        torch.ones(nchunks, nchunks, device=states.device, dtype=torch.bool), diagonal=0
    )
    # Mask with -inf *before* exp. The anti-causal entries can be large and
    # positive, so exp() overflows to inf there; zeroing them after exp keeps
    # the forward pass finite, but the backward pass then computes 0 * inf = NaN.
    decay_chunk = torch.exp(segment_sum.masked_fill(~causal_mask, float("-inf")))
    out = torch.einsum("bhzc,bchd->bzhd", decay_chunk.to(dtype=states.dtype), states)
    return out[:, :-1], out[:, -1]

In [8]:
def state_passing_test(states, dA_chunk_cumsum, initial_states=None, use_exp=False, verbose=True):
    """
    Unbatched, single-head debug variant of the state-passing computation.

    By default the decay matrix is the raw masked segment sum cum[z] - cum[c]
    (no exp), which keeps the printed matrices readable. Pass use_exp=True to
    apply exp and reproduce the real decay exp(cum[z] - cum[c]).

    Argument:
        states: (nchunks, dim)
        dA_chunk_cumsum: (nchunks,) -- per-chunk increments; the cumulative
            sum is taken inside this function
        initial_states: (dim,); zeros when None
        use_exp: apply exp to the masked segment sums
        verbose: print intermediate shapes and matrices
    Return:
        out: (nchunks, dim) -- state entering each chunk
        final_states: (dim,) -- state after the last chunk
    """
    if initial_states is None:
        initial_states = torch.zeros_like(states[0])
        if verbose:
            print(initial_states.unsqueeze(0).shape)
    if verbose:
        print(f"{states.shape = }")
    # Prepend the initial state: (nchunks + 1, dim).
    states = torch.cat([initial_states.unsqueeze(0), states], dim=0)
    # Leading zero so cum[0] = 0 lines up with the prepended initial state.
    dA_chunk_cumsum = F.pad(dA_chunk_cumsum, (1, 0))
    dA_chunk_cumsum = torch.cumsum(dA_chunk_cumsum, dim=-1)
    nchunks = dA_chunk_cumsum.shape[-1]  # original nchunks + 1 after padding
    # (nchunks + 1, nchunks + 1): cum[z] - cum[c]
    dt_chunk_segment_sum = dA_chunk_cumsum[:, None] - dA_chunk_cumsum[None, :]
    if verbose:
        print(f"{dt_chunk_segment_sum = }")
    causal_mask = torch.tril(torch.ones(nchunks, nchunks, device=states.device, dtype=bool), diagonal=0)
    if use_exp:
        # Mask with -inf before exp so anti-causal entries cannot overflow.
        decay_chunk = torch.exp(dt_chunk_segment_sum.masked_fill(~causal_mask, float("-inf")))
    else:
        decay_chunk = dt_chunk_segment_sum
    if verbose:
        print(f"{decay_chunk.shape = }")
        print(f"{causal_mask.shape = }")
    decay_chunk = decay_chunk.masked_fill(~causal_mask, 0)
    if verbose:
        print(f"{decay_chunk = }")
    out = torch.einsum("zc,cd->zd", decay_chunk.to(dtype=states.dtype), states)
    return out[:-1], out[-1]

In [14]:
import torch
from einops import rearrange
import torch.nn.functional as F

chunks = 6   # total number of chunks
gchunk = 2   # chunks processed per group
states = torch.ones([chunks, 3]).cumsum(1)
# states[-1] = 0
print(states)
# Per-chunk increments; state_passing_test takes the cumulative sum itself.
dA_chunk_cumsum = torch.ones([chunks])  # .cumsum(0)
# dA_chunk_cumsum[3] += 2
print(f"{dA_chunk_cumsum = }")

# Reference: a single pass over all chunks.
sout_0, fout_0 = state_passing_test(states, dA_chunk_cumsum)

# The same sequence processed in groups of `gchunk`, feeding each group's
# final state into the next group as its initial state.
fout = torch.zeros(states.shape[-1])
all_sout = []
all_fout = []
for start in range(0, chunks, gchunk):
    stop = min(start + gchunk, chunks)  # the last group may be shorter
    print(f"{start}:{stop}")
    sout, fout = state_passing_test(states[start:stop],
                                    dA_chunk_cumsum[start:stop],
                                    initial_states=fout)
    all_sout.append(sout)
    all_fout.append(fout)
print(fout)
print(fout_0)
print(torch.cat(all_sout))
print(sout_0)

tensor([[1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.],
        [1., 2., 3.]])
dA_chunk_cumsum = tensor([1., 1., 1., 1., 1., 1.])
torch.Size([1, 3])
states.shape = torch.Size([6, 3])
dt_chunk_segment_sum = tensor([[ 0., -1., -2., -3., -4., -5., -6.],
        [ 1.,  0., -1., -2., -3., -4., -5.],
        [ 2.,  1.,  0., -1., -2., -3., -4.],
        [ 3.,  2.,  1.,  0., -1., -2., -3.],
        [ 4.,  3.,  2.,  1.,  0., -1., -2.],
        [ 5.,  4.,  3.,  2.,  1.,  0., -1.],
        [ 6.,  5.,  4.,  3.,  2.,  1.,  0.]])
decay_chunk.shape = torch.Size([7, 7])
causal_mask.shape = torch.Size([7, 7])
decay_chunk = tensor([[0., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [2., 1., 0., 0., 0., 0., 0.],
        [3., 2., 1., 0., 0., 0., 0.],
        [4., 3., 2., 1., 0., 0., 0.],
        [5., 4., 3., 2., 1., 0., 0.],
        [6., 5., 4., 3., 2., 1., 0.]])
0:2
states.shape = torch.Size([2, 3])
dt_chunk_segment_sum = tens

**TODO:** Given `all_fout`, `all_sout`, and `dA_chunk_cumsum`, update `all_sout` so that (concatenated) it equals `sout_0`, without any sharing of information between the elements of `all_sout`.

In [None]:
# Full (unchunked) pass with a constant initial state of 0, 1, 2.
# The initial state must match the feature dimension of `states` (its last
# axis); a hard-coded size that differs from it makes torch.cat fail.
for init_value in range(3):
    sout, fout = state_passing_test(states,
                                    dA_chunk_cumsum,
                                    initial_states=torch.ones(states.shape[-1]) * init_value)
    print(sout)
    print(fout)