In [19]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import jax
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat
import pickle
import collections
from tqdm import tqdm 
import math

In [69]:
# Seed PyTorch's global RNG so parameter initialisation is repeatable.
torch.manual_seed(1)


@dataclass
class MambaArgs:
    ''' Hyperparameters shared by every module of the model.

        D_inner, the widened channel count used inside each Mamba block, is
        not a constructor argument: it is derived as expansion_factor * D
        right after construction. '''
    # ---- required ----
    N: int                      # second dimension of the (D_inner, N) state matrix A
    D: int                      # embedding width / width of the residual stream
    n_layers: int               # number of stacked residual Mamba blocks
    vocab_size: int             # number of rows in the embedding table
    # ---- optional ----
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    expansion_factor: int = 2   # ratio D_inner / D
    conv_1d_size: int = 4       # kernel size of the depthwise 1-D convolution
    conv_bias: bool = True      # whether the convolution carries a bias
    general_bias: bool = False  # whether the input and output projections carry a bias

    def __post_init__(self):
        # truncate towards zero so a fractional expansion factor still
        # yields an integer width
        widened = self.expansion_factor * self.D
        self.D_inner = int(widened)

class Mamba(nn.Module):
    ''' Full Mamba architecture: token embedding -> stack of residual Mamba
        blocks -> final RMSNorm -> vocabulary logits. '''
    def __init__(self, args: MambaArgs):
        super().__init__()

        self.args = args

        # token ids -> D-dimensional vectors
        self.embedding = nn.Embedding(args.vocab_size, args.D)

        blocks = []
        for _ in range(args.n_layers):
            blocks.append(ResidualMambaBlock(args))
        self.layers = nn.ModuleList(blocks)
        self.norm_f = RMSNorm(args.D)

        # weight tying: the output head shares its matrix with the embedding,
        # so both point at the very same Parameter object
        self.logits = nn.Linear(args.D, args.vocab_size, bias=False)
        self.logits.weight = self.embedding.weight

    def forward(self, x):
        ''' Map a tensor of token ids to one logit per vocabulary entry for
            every position. '''
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.logits(self.norm_f(hidden))

class ResidualMambaBlock(nn.Module):
    ''' Pre-norm residual wrapper around one Mamba block:
        output = MambaBlock(RMSNorm(x)) + x '''

    def __init__(self, args: MambaArgs):
        super().__init__()

        self.args = args
        self.block = MambaBlock(args)
        self.rms = RMSNorm(args.D)

    def forward(self, x):
        skip = x
        normed = self.rms(skip)
        # the skip connection bypasses both the norm and the block
        return self.block(normed) + skip
    
class MambaBlock(nn.Module):
    ''' Standard Mamba block as illustrated in the paper.

        One linear layer produces two D_inner-wide branches. The first goes
        through a depthwise convolution, SiLU and the S6 block; the second
        goes through SiLU only and multiplies the first elementwise. The
        product is projected back down to width D. '''
    def __init__(self, args: MambaArgs):
        super().__init__()

        self.args = args
        width = args.D_inner
        kernel = args.conv_1d_size

        # a single projection yields both branches, hence the factor of 2
        self.in_proj = nn.Linear(args.D, 2 * width, bias=args.general_bias)
        # groups == channels makes the convolution depthwise; padding by
        # kernel - 1 lets forward() keep only the first l outputs so that
        # position t never sees inputs from positions after t
        self.conv1d = nn.Conv1d(
            in_channels=width,
            out_channels=width,
            kernel_size=kernel,
            groups=width,
            padding=kernel - 1,
            bias=args.conv_bias,
        )
        self.s6_block = S6Block(args)

        self.out_proj = nn.Linear(width, args.D, bias=args.general_bias)

    def forward(self, x):
        # the sequence length is read off the input instead of living in args
        _, seq_len, _ = x.shape

        projected = self.in_proj(x)
        # first half feeds the conv/SSM branch, second half is the gate
        ssm_branch, gate_branch = torch.split(
            projected, self.args.D_inner, dim=-1)

        # Conv1d expects channels before length: (B, L, D) -> (B, D, L),
        # convolve, drop the surplus outputs created by the padding, go back
        convolved = self.conv1d(ssm_branch.transpose(1, 2))[..., :seq_len]
        ssm_branch = convolved.transpose(1, 2)

        ssm_out = self.s6_block(F.silu(ssm_branch))
        gated = ssm_out * F.silu(gate_branch)

        return self.out_proj(gated)

class S6Block(nn.Module):
    ''' Inner SSM block.

        For every position the input is projected to B, C (each of size N)
        and a scalar step size delta. A diagonal state matrix A is discretized
        with delta, a linear recurrence is run over the sequence, and the
        hidden states are read out with C.

        Input and output of forward() both have shape (B, L, D_inner). '''
    def __init__(self, args: MambaArgs):

        super(S6Block, self).__init__()
        self.args = args

        def s4d_real():
            # initialization for A used in the paper. Other complex-valued
            # initializations also possible

            # compute one diagonal (-1, -2, ..., -N), then broadcast across
            # D dimensions
            A = -(torch.arange(0, args.N) + 1)

            return A.unsqueeze(0).repeat(args.D_inner, 1).float()

        def get_delta_bias():
            # sample from a uniform distribution bounded within these values,
            # then pass through an inverse softplus so that
            # softplus(bias) == sample
            a = 0.001
            b = 0.1
            sample = (b - a) * torch.rand(1) + a
            # inverse softplus: log(exp(y) - 1) == y + log(1 - exp(-y)).
            # The second form is used because expm1 stays accurate for the
            # small values of y sampled here.
            return sample + torch.log(-torch.expm1(-sample))

        # the same A is broadcasted across all D token dimensions. A is a
        # diagonal matrix, which is why we represent it only with its diagonal
        # elements.
        self.A = nn.Parameter(s4d_real())

        # B, C, and delta are different for each token, but the linear projections
        # which generate them are shared across all tokens and batches. All
        # parameters are generated at once, and then split + broadcasted
        # as necessary. These are strictly linear projections, no biases used.
        # Delta uses one, which we manually add later.
        self.to_BCdelta = nn.Linear(args.D_inner, 2*args.N+1, bias=False)

        # explicit delta bias term. Same term used for each projection across
        # batches and sequences.
        self.delta_bias = nn.Parameter(get_delta_bias())

    def discretize(self, delta, B):
        ''' Turn the continuous parameters into their discrete counterparts.

            delta: (B, L, D) step sizes, B: (B, L, N).
            Returns (A_bar, B_bar), both of shape (B, L, D, N). '''

        delta_A = torch.einsum('bld,dn->bldn', delta, self.A)
        # effect: multiplying each vector N to be used at every l, b by the
        # corresponding discretization constant at that position, then broadcasted
        # over a new dimension D (we want to use the same A matrix for each input dimension)
        A_bar = torch.exp(delta_A)
        delta_B = torch.einsum('bld,bln->bldn', delta, B)
        # same effect as for delta_A, but shapes are different because B is
        # directly defined for each l, b
        #
        # diagonal matrices, so dividing by delta_A is the inverse, subtracting
        # 1 instead of the identity matrix, and everything is elementwise.
        # expm1 computes exp(delta_A) - 1 without the cancellation error that
        # the explicit subtraction suffers when delta_A is close to zero.
        B_bar = torch.expm1(delta_A) / delta_A * delta_B

        return A_bar, B_bar

    def forward(self, x):
        b, l, _ = x.shape
        # generate all projected parameters and split them up
        BCdelta = self.to_BCdelta(x)
        # delta: (B, L, 1). B, C: (B, L, N)
        (B, C, delta) = BCdelta.split(
            split_size=[self.args.N, self.args.N, 1], dim=-1)

        # add the shared bias, make the step size positive, then broadcast
        # the single value per position across all D channels
        delta = F.softplus(delta + self.delta_bias)         # (B, L, 1)
        delta = delta.expand(-1, -1, self.args.D_inner)     # (B, L, D)

        # discretization
        A_bar, B_bar = self.discretize(delta, B) # (B, L, D, N)

        # input transformation is parallelizable
        input_transform = B_bar * x.unsqueeze(-1) # (B, L, D, N)

        # scan through each individual token to compute hidden states.
        # The running state is created from input_transform so it always
        # lives on the same device and has the same dtype as the data, and
        # each step builds a new tensor so autograd never sees an in-place
        # write.
        state = input_transform.new_zeros(b, self.args.D_inner, self.args.N)
        states = []
        for i in range(l):
            # because A is represented only through diagonal, Ah_t-1 is
            # equivalent to taking the elementwise product of the diagonal
            # and the hidden state
            state = A_bar[:, i] * state + input_transform[:, i] # (B, D, N)
            states.append(state)

        if states:
            hidden_states = torch.stack(states, dim=1) # (B, L, D, N)
        else:
            # zero-length sequence: torch.stack rejects an empty list
            hidden_states = input_transform.new_zeros(
                b, 0, self.args.D_inner, self.args.N)

        # compute outputs in parallel
        outputs = torch.einsum('bln,bldn->bld', C, hidden_states)

        return outputs


class RMSNorm(nn.Module):
    ''' Root-mean-square normalization over the last dimension with a learned
        per-feature gain. Written by hand because the default implementation
        is bugged in this version of PyTorch and updating the version is not
        worth the trouble. '''
    def __init__(self,
                 D: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # gain starts at one, so the layer initially only normalizes
        self.weight = nn.Parameter(torch.ones(D))

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        # eps keeps rsqrt finite when the input is all zeros
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

In [77]:
# use a simple word-level tokenizer to create a tokenized version of the kaggle 
# song lyrics dataset. Simply training to predict the next word here. 

L = 32
B = 8
vocab_size = 10000

# device = 'mps'

!wget -nc https://umuguc.github.io/file-sharing/kaggle_song_lyrics_dataset.pkl.zip
!unzip -n kaggle_song_lyrics_dataset.pkl.zip -d kaggle_song_lyrics_dataset

with open("kaggle_song_lyrics_dataset/kaggle_song_lyrics_dataset.pkl", "rb") as f:
    seqs = pickle.load(f)

vocab       = ["<unk>"] + (lambda counter: sorted(counter, key=counter.get, reverse=True))(collections.Counter(seqs))[:vocab_size - 1]
idx_to_word = {idx: word for idx, word in enumerate(vocab)}
word_to_idx = {word: idx for idx, word in enumerate(vocab)}

class SeqDataset(Dataset):
    ''' Sliding-window dataset over one flat sequence of token ids.

        Item idx is a pair (input, target) of long tensors created on
        `device`: the input is seqs[idx : idx + seq_size] and the target is
        the same window shifted one position to the right. '''
    def __init__(self, device, seq_size, seqs):
        super().__init__()
        self.device   = device
        self.seq_size = seq_size
        self.seqs     = seqs

    def __len__(self):
        total = len(self.seqs)
        return total - self.seq_size - 1

    def __getitem__(self, idx):
        def window(start):
            # seq_size consecutive ids beginning at `start`
            stop = start + self.seq_size
            return torch.tensor(
                self.seqs[start:stop], dtype=torch.long, device=self.device)

        return window(idx), window(idx + 1)

# work with only the first 1% of the corpus
seqs = seqs[:len(seqs)//100]

# 80% / 10% / 10% split boundaries; one token is skipped at each boundary
train_end = int(0.8 * len(seqs))
val_end   = int(0.9 * len(seqs))

# words missing from the vocabulary map to index 0 ("<unk>")
token_ids = [word_to_idx.get(word, 0) for word in seqs]

train_set = SeqDataset(device, L, token_ids[:train_end])
val_set   = SeqDataset(device, L, token_ids[train_end + 1:val_end])
test_set  = SeqDataset(device, L, token_ids[val_end + 1:])

# only the training data is shuffled
train_loader = DataLoader(train_set, batch_size=B, shuffle=True)
val_loader   = DataLoader(val_set,   batch_size=B, shuffle=False)
test_loader  = DataLoader(test_set,  batch_size=B, shuffle=False)

In [78]:
# training!

# Anomaly detection is a debugging aid that makes every backward pass much
# slower. Keep it off for real runs and flip the flag only while hunting a
# NaN or an in-place autograd error. Setting it explicitly also resets the
# global switch if an earlier run of this cell turned it on.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

n_epochs = 3

D = 16
N = 8
n_layers = 5
args = MambaArgs(N, D, n_layers, vocab_size, device)
model = Mamba(args).to(args.device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

min_loss = float("inf")

for epoch in range(n_epochs):
    print(f"Epoch: {epoch + 1}/{n_epochs}")

    # ---- train ----
    model.train()
    train_loss = 0.0

    for in_seq, target_seq in tqdm(train_loader):
        out_seq  = model(in_seq)
        # flatten (B, L, vocab) logits and (B, L) targets so that every
        # position counts as one classification example
        loss        = criterion(out_seq.view(-1, vocab_size), target_seq.view(-1))
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    train_loss       /= len(train_loader)
    train_perplexity  = math.exp(train_loss)
    print(f"Train loss: {train_loss:.4f}, Train perplexity: {train_perplexity:.4f}")

    # ---- validate ----
    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for in_seq, target_seq in val_loader:
            # Mamba.forward returns the logits tensor only; unpacking it into
            # two names would raise a ValueError
            out_seq     = model(in_seq)
            loss        = criterion(out_seq.view(-1, vocab_size), target_seq.view(-1))
            val_loss   += loss.item()

    val_loss       /= len(val_loader)
    val_perplexity  = math.exp(val_loss)
    print(f"Val loss: {val_loss:.4f}, Val perplexity: {val_perplexity:.4f}")

    # Checkpointing is disabled; MODEL_NAME is not defined in the cells shown
    # here and has to be set before enabling it.
    # if val_loss < min_loss:
    #     min_loss = val_loss
    #     torch.save(model.state_dict(), f"kaggle_song_lyrics_dataset/{MODEL_NAME:s}.pt")

Epoch: 1/3


  0%|          | 6/1419 [00:12<48:33,  2.06s/it]


KeyboardInterrupt: 