In [1]:
import torch

In [None]:
def unpack(hidden_states, cu_seqlens, max_seqlen=None):
    """Unpack the packed hidden_states and pad every sample to max_seqlen.

    Sample ``i`` occupies rows ``cu_seqlens[i]:cu_seqlens[i + 1]`` of ``hidden_states``;
    it is copied to ``output[i, :length]`` and the remaining positions stay zero.

    Args:
        hidden_states (torch.Tensor): packed tensor of size (pack_seqlen, hidden_size).
        cu_seqlens (torch.Tensor): 1-D cumulative sequence boundaries, length batch_size + 1.
        max_seqlen (int, optional): padded length of every sample. When None (the default),
            the longest sample length implied by ``cu_seqlens`` is used.

    Returns:
        torch.Tensor: same dtype and device as hidden_states, size
        (len(cu_seqlens) - 1, max_seqlen, hidden_size).

    Raises:
        RuntimeError: if a sample is longer than max_seqlen (shape mismatch on the copy).
    """
    # Convert the boundaries to Python ints once: slicing with 0-d tensors inside the
    # loop costs a device-to-host sync per bound when cu_seqlens lives on the GPU.
    bounds = cu_seqlens.tolist()
    spans = list(zip(bounds[:-1], bounds[1:]))
    if max_seqlen is None:
        max_seqlen = max((end - start for start, end in spans), default=0)
    hidden_size = hidden_states.shape[1]
    output = hidden_states.new_zeros(len(spans), max_seqlen, hidden_size)
    for i, (start, end) in enumerate(spans):
        output[i, : end - start, :] = hidden_states[start:end, :]
    return output

def pack(unpacked_tensors, cu_seqlens):
    """Reverse function of unpack: drop the padding and concatenate the samples.

    Args:
        unpacked_tensors (torch.Tensor): padded batch of size (batch_size, max_seqlen, hidden_size).
        cu_seqlens (torch.Tensor): 1-D cumulative sequence boundaries, length batch_size + 1.

    Returns:
        torch.Tensor: size (pack_seqlen, hidden_size), where pack_seqlen is the sum of the
        sample lengths; same dtype as unpacked_tensors.
    """
    n_values = cu_seqlens[1:] - cu_seqlens[:-1]
    _, seq_len, _ = unpacked_tensors.size()  # also enforces a 3-D input
    # (batch_size, seq_len) validity mask. Broadcasting avoids materialising a
    # (batch_size, seq_len, hidden_size) index tensor just to build the mask.
    positions = torch.arange(seq_len, device=unpacked_tensors.device)
    mask = positions.unsqueeze(0) < n_values.to(unpacked_tensors.device).unsqueeze(1)
    # Boolean indexing over the first two dims returns rows in row-major
    # (sample, position) order, which is exactly the packed layout.
    return unpacked_tensors[mask]


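## Sample

Load a packed batch and check that `unpack` followed by `pack` reproduces it exactly. Sample `i` occupies rows `cu_seqlens[i]:cu_seqlens[i + 1]` of the packed tensor.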

In [3]:
def _load_on_gpu(path):
    """Load one saved object from ``path`` directly onto cuda:0."""
    # NOTE(review): torch.load unpickles arbitrary objects; only use trusted .pt files.
    return torch.load(path, map_location="cuda:0")

cu_seqlens = _load_on_gpu('cu_seqlens.pt')
indexes = _load_on_gpu('indexes.pt')
hidden_states = _load_on_gpu('packed_hidden_states.pt')

In [4]:
# Per-token position ids of the packed batch; the slices below show they restart at 0
# at each cu_seqlens boundary.
indexes

tensor([ 0,  1,  2,  ..., 76, 77, 78], device='cuda:0')

In [5]:
# Cumulative sequence boundaries: sample i occupies rows cu_seqlens[i]:cu_seqlens[i + 1].
cu_seqlens

tensor([   0,   86,  241, 1103, 1574, 1969, 2048], device='cuda:0',
       dtype=torch.int32)

In [6]:
# Position ids of sample 1, sliced by its cu_seqlens boundaries instead of hard-coded offsets.
indexes[cu_seqlens[1] : cu_seqlens[2]]

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154], device='cuda:0')

In [7]:
# Position ids of sample 3, sliced by its cu_seqlens boundaries instead of hard-coded offsets.
indexes[cu_seqlens[3] : cu_seqlens[4]]

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [8]:
# Position ids of sample 4, sliced by its cu_seqlens boundaries instead of hard-coded offsets.
indexes[cu_seqlens[4] : cu_seqlens[5]]

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [9]:
# Position ids of the last sample, sliced by its cu_seqlens boundaries instead of hard-coded offsets.
indexes[cu_seqlens[5] : cu_seqlens[6]]

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
        36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53,
        54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71,
        72, 73, 74, 75, 76, 77, 78], device='cuda:0')

In [10]:
# Pad every sample to the longest one in the batch. cu_seqlens was loaded above, so the
# former "is not None" fallback was dead code (unpack cannot handle None anyway).
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
unpacked_hidden_states = unpack(hidden_states=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
unpacked_hidden_states

tensor([[[ 0.0601,  0.0728,  0.0055,  ..., -0.0417,  0.0322,  0.0645],
         [ 0.0261, -0.0476, -0.0334,  ...,  0.0391, -0.0068,  0.0132],
         [ 0.0311,  0.0179, -0.0073,  ..., -0.0713, -0.0430,  0.0192],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0601,  0.0728,  0.0055,  ..., -0.0417,  0.0322,  0.0645],
         [ 0.0466, -0.0903,  0.0013,  ..., -0.0153, -0.0181, -0.0342],
         [-0.0303, -0.0113,  0.0020,  ..., -0.0410, -0.0630,  0.0111],
         ...,
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],

        [[ 0.0601,  0.0728,  0.0055,  ..., -0.0417,  0.0322,  0.0645],
         [-0.0422,  0.0425, -0.0693,  ...,  0

In [11]:
# Drop the padding again; the result should match the original packed hidden_states.
pack_back_hidden_states = pack(unpacked_tensors=unpacked_hidden_states, cu_seqlens=cu_seqlens)
pack_back_hidden_states

tensor([[ 0.0601,  0.0728,  0.0055,  ..., -0.0417,  0.0322,  0.0645],
        [ 0.0261, -0.0476, -0.0334,  ...,  0.0391, -0.0068,  0.0132],
        [ 0.0311,  0.0179, -0.0073,  ..., -0.0713, -0.0430,  0.0192],
        ...,
        [ 0.0007, -0.0742, -0.0103,  ...,  0.0140, -0.0189, -0.0120],
        [ 0.0464, -0.0408,  0.0574,  ..., -0.0067,  0.0874, -0.0503],
        [ 0.0023,  0.0684,  0.0225,  ...,  0.0209, -0.0610,  0.0474]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<ViewBackward0>)

In [12]:
# Check that pack(unpack(x)) reproduces the packed input exactly. Fail loudly rather than
# relying on a reader noticing a displayed False.
roundtrip_ok = torch.equal(hidden_states, pack_back_hidden_states)
assert roundtrip_ok, "pack(unpack(hidden_states)) does not reproduce hidden_states"
roundtrip_ok

True

## Mamba example

Run a stock `Mamba` block on the zero-padded, unpacked batch.

In [None]:
from mamba_ssm.modules.mamba_simple import Mamba

# Reuse the zero-padded batch built for the round-trip check instead of recomputing it;
# cast to float32 to match the module's default (float32) parameters.
x = unpacked_hidden_states.float()
dim = x.shape[-1]

model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to(x.device)  # same device as the input rather than a hard-coded one
# NOTE(review): outputs at zero-padded positions should be ignored; this assumes the block
# is causal, so right-padding does not affect the valid positions — confirm.
model(x)

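# Modified Mamba

This section uses a modified `Mamba` whose forward takes a `start_indecies` argument. It is run on all sequences concatenated with zero rows between them, and is given the start offset of every sequence after the first. Its output for one sequence is then compared with the output of running that sequence on its own.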

In [None]:
%load_ext autoreload
%autoreload 2

In [1]:
from mamba_ssm import Mamba
import torch

In [2]:
# Reload the packed batch for this section (execution restarts at In [1] here), onto cuda:0.
# NOTE(review): torch.load unpickles arbitrary objects; only use trusted .pt files.
cu_seqlens = torch.load('cu_seqlens.pt', map_location="cuda:0")
indexes = torch.load('indexes.pt', map_location="cuda:0")
hidden_states = torch.load('packed_hidden_states.pt', map_location="cuda:0")

In [3]:
# Cumulative sequence boundaries of the packed batch (sample i is rows cu_seqlens[i]:cu_seqlens[i + 1]).
cu_seqlens

tensor([   0,   86,  241, 1103, 1574, 1969, 2048], device='cuda:0',
       dtype=torch.int32)

In [4]:
# Re-pack the samples into one tensor with PAD_LEN zero rows between consecutive sequences.
PAD_LEN = 3
sequence_list = torch.split(hidden_states, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist())
# Allocated directly on the data's device (no CPU allocation + copy). The dtype is left at
# the default float32 on purpose: concatenating it with the bf16 sequences promotes the
# result to float32, which is what the float32 model below is fed.
zeros_inter = torch.zeros(PAD_LEN, hidden_states.size(1), device=hidden_states.device)

tensor_list = []
for i, sequence in enumerate(sequence_list):
    if i > 0:
        # Separator before every sequence except the first.
        tensor_list.append(zeros_inter.clone())
    tensor_list.append(sequence)

result_list = torch.concat(tensor_list)
result_list.shape

torch.Size([2063, 2048])

In [5]:
# Start offset of each sequence after the first inside the zero-separated tensor:
# sequence k is shifted right by k separators, i.e. cu_seqlens[k] + k * pad_len.
# (The misspelled name matches the keyword the modified Mamba's forward is called with.)
pad_len = zeros_inter.size(0)
shifts = torch.arange(len(sequence_list) + 1, device=cu_seqlens.device) * pad_len
start_indecies = (cu_seqlens + shifts)[1:-1]
start_indecies

tensor([  89,  247, 1112, 1586, 1984], device='cuda:0')

In [6]:
import os

# Checkpoint location: set MAMBA_CKPT_PATH to override instead of editing this absolute path.
CKPT_PATH = os.environ.get(
    "MAMBA_CKPT_PATH",
    "/mnt/petrelfs/wangzerui/code/train_ssm/mamba-ckpt-1.4B/pytorch_model.bin",
)
# Load onto the CPU explicitly; the model is moved to the GPU after its weights are set.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted checkpoints
# (or pass weights_only=True on a torch version that supports it).
ckpt = torch.load(CKPT_PATH, map_location="cpu")

In [7]:
from torch import nn
dim = hidden_states.shape[-1]
model = Mamba(
    # This module uses roughly 3 * expand * d_model^2 parameters
    d_model=dim, # Model dimension d_model
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
    use_fast_path=False,
)

# Replace the block's parameters with the first mixer layer of the full-model checkpoint.
CKPT_LAYER_PREFIX = "backbone.layers.0.mixer."
MIXER_PARAM_NAMES = (
    "in_proj.weight",
    "conv1d.weight",
    "D",
    "conv1d.bias",
    "x_proj.weight",
    "dt_proj.weight",
    "dt_proj.bias",
    "A_log",
    "out_proj.weight",
)
for param_name in MIXER_PARAM_NAMES:
    module_path, _, attr = param_name.rpartition(".")
    owner = model.get_submodule(module_path)  # "" resolves to the model itself
    loaded = ckpt[CKPT_LAYER_PREFIX + param_name]
    current = getattr(owner, attr)
    # Report a config/checkpoint mismatch (d_model, d_state, expand, ...) here, with the
    # parameter name, rather than as an opaque shape error inside forward().
    if current is not None and current.shape != loaded.shape:
        raise ValueError(
            f"{param_name}: checkpoint shape {tuple(loaded.shape)} "
            f"!= model shape {tuple(current.shape)}"
        )
    setattr(owner, attr, nn.Parameter(loaded))
model = model.to('cuda:0')

In [8]:
# Rows = total packed length plus PAD rows for each gap between consecutive sequences.
result_list.size()

torch.Size([2063, 2048])

In [9]:
# Sanity check: sample 1 (tensor_list[2]) sits unchanged inside the zero-separated tensor.
# Its bounds come from start_indecies instead of hard-coded offsets; expect all zeros.
sample1 = tensor_list[2]
sample1_start = int(start_indecies[0])
sample1 - result_list[sample1_start : sample1_start + sample1.size(0)]

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',
       grad_fn=<SubBackward0>)

In [10]:
# One forward pass over the whole zero-separated batch, passing the sequence start offsets.
x = result_list.unsqueeze(0)
y_pad = model(x, start_indecies=start_indecies)
# Output rows belonging to sample 1 (tensor_list[2]); bounds derived, not hard-coded.
sample1_start = int(start_indecies[0])
y_pad[:, sample1_start : sample1_start + tensor_list[2].size(0), :]

tensor([[[ 0.0045, -0.0064, -0.0049,  ...,  0.0072, -0.0067, -0.0015],
         [ 0.0017,  0.0082,  0.0023,  ..., -0.0035,  0.0002, -0.0076],
         [-0.0049, -0.0077, -0.0040,  ...,  0.0025, -0.0043, -0.0007],
         ...,
         [-0.0154,  0.0111, -0.0051,  ..., -0.0027, -0.0177,  0.0217],
         [ 0.0064, -0.0137,  0.0193,  ...,  0.0096, -0.0096,  0.0073],
         [ 0.0160,  0.0276, -0.0057,  ..., -0.0105, -0.0158, -0.0119]]],
       device='cuda:0', grad_fn=<SliceBackward0>)

In [11]:
# Reference run: sample 1 on its own, cast to float32 like the zero-separated batch.
# The former `dim = x.shape[-1]` was a dead store (the model is already built).
# NOTE(review): an empty start_indecies presumably means "no internal sequence starts".
x = tensor_list[2].float().unsqueeze(0)
y = model(x, start_indecies=[])

In [12]:
# Compare with the sample-1 slice of the padded run displayed above; the rows should match.
y[:, :, :]

tensor([[[ 0.0045, -0.0064, -0.0049,  ...,  0.0072, -0.0067, -0.0015],
         [ 0.0017,  0.0082,  0.0023,  ..., -0.0035,  0.0002, -0.0076],
         [-0.0049, -0.0077, -0.0040,  ...,  0.0025, -0.0043, -0.0007],
         ...,
         [-0.0154,  0.0111, -0.0051,  ..., -0.0027, -0.0177,  0.0217],
         [ 0.0064, -0.0137,  0.0193,  ...,  0.0096, -0.0096,  0.0073],
         [ 0.0160,  0.0276, -0.0057,  ..., -0.0105, -0.0158, -0.0119]]],
       device='cuda:0', grad_fn=<SliceBackward0>)