In [128]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import jax
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat
import pickle
import collections
from tqdm import tqdm 
import math
from transformers import GPT2Tokenizer
from typing import Callable

In [197]:
# Fix PyTorch's global RNG seed up front so torch random draws made later in
# the notebook (e.g. parameter initialization) are repeatable across runs.
torch.random.manual_seed(1)

@dataclass
class LMTrainingArgs:
    ''' Hyper-parameters for the language-modeling training recipe.

        The LR schedule is a linear warm-up followed by a cosine decay that
        bottoms out at ``min_lr``. ``lm_learning_schedule`` (also exposed as
        ``schedule_fn``) returns a *multiplier* on ``peak_lr``, which is the
        form ``torch.optim.lr_scheduler.LambdaLR`` expects, not an absolute LR.
    '''
    # NB this is not the actual learning rate, see __post_init__!
    gpt_3_peak_lr: float # GPT3 spec, copy from table depending on size
    warmup_epochs: int
    n_epochs: int

    # default arguments, specified according to Mamba paper. This is the
    # default for language modeling, for artificial tasks use the other
    # args class!

    min_lr: float = 1e-5
    weight_decay: float = 0.1
    gradient_clip: float = 1.0
    adam_beta: tuple = (0.9, 0.95)
    adam_epsilon: float = 1e-8
    optimizer: str = "AdamW" # TODO: implement ability to use Adam, just because

    def __post_init__(self):
        # the recipe trains at 5x the GPT-3 peak learning rate
        self.peak_lr: float = 5*self.gpt_3_peak_lr
        self.schedule_fn: Callable[[int], float] = self.lm_learning_schedule

        assert self.warmup_epochs < self.n_epochs, "Warmup epochs > total epochs"
        assert self.optimizer == "AdamW" or self.optimizer == "Adam", 'Invalid optimizer'

    def lm_learning_schedule(self, epoch):
        ''' LR multiplier for the 0-indexed `epoch`.

            Warm-up: rises linearly and reaches 1.0 at epoch warmup_epochs-1.
            Decay: cosine from just below 1.0 down to min_lr/peak_lr at the
            last epoch (n_epochs-1). Epochs beyond that stay at the floor
            instead of following the cosine back up.
        '''
        if epoch < self.warmup_epochs:
            return float(epoch+1) / float(max(1, self.warmup_epochs))

        # fraction of the decay phase completed; clamped so that calls past
        # the final epoch (e.g. a scheduler stepped too often) stay at the floor
        progress = float(epoch - self.warmup_epochs + 1) / \
                        float(max(1, self.n_epochs - self.warmup_epochs))
        progress = min(progress, 1.0)
        # shifted/rescaled cosine going from 1 to 0 over the decay phase
        cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
        # rescale so the multiplier never drops below min_lr / peak_lr
        floor = self.min_lr / self.peak_lr
        return cosine_decay * (1 - floor) + floor

    def show_lr_schedule(self):
        ''' Plot the absolute learning rate for every epoch of the schedule. '''
        epochs = np.arange(0, self.n_epochs)
        lr = np.array([self.peak_lr*self.lm_learning_schedule(e) for e in epochs])

        # the decay phase starts at index warmup_epochs, which always exists
        # because __post_init__ enforces warmup_epochs < n_epochs
        min_lr = np.min(lr[self.warmup_epochs:])
        max_lr = np.max(lr)

        plt.figure()
        plt.plot(epochs, lr)
        plt.xlim([0,self.n_epochs-1])
        plt.xlabel("Epoch")
        plt.ylabel("Learning rate")
        plt.title(f"Max = {max_lr}, \nMin = {min_lr:.2e}")
        plt.show()

@dataclass
class MambaArgs:
    ''' Architecture hyper-parameters shared by every Mamba sub-module. '''
    N: int              # SSM state size per inner channel
    D: int              # model / embedding width
    n_layers: int       # number of residual Mamba blocks
    vocab_size: int
    device: str = ("cuda" if torch.cuda.is_available() else "cpu")
    expansion_factor: int = 2
    conv_1d_size: int = 4
    conv_bias: bool = True
    general_bias: bool = False # applies to the input and output projections

    def __post_init__(self):
        # width of the expanded inner stream used inside each MambaBlock
        self.D_inner = int(self.D * self.expansion_factor)

class Mamba(nn.Module):
    ''' Full Mamba language model: token embedding -> stack of residual Mamba
        blocks -> final RMSNorm -> vocabulary logits, with the output head
        sharing its weight matrix with the embedding table. '''
    def __init__(self, args: MambaArgs):
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.D)
        self.layers = nn.ModuleList(
            [ResidualMambaBlock(args) for _ in range(args.n_layers)]
        )
        self.norm_f = RMSNorm(args.D)

        # weight tying: the head reuses the embedding matrix as its weight
        self.logits = nn.Linear(args.D, args.vocab_size, bias=False)
        self.logits.weight = self.embedding.weight

    def forward(self, x):
        ''' Map token ids `x` to unnormalized logits (last dim = vocab_size). '''
        hidden = self.embedding(x)
        for block in self.layers:
            hidden = block(hidden)
        return self.logits(self.norm_f(hidden))

class ResidualMambaBlock(nn.Module):
    ''' Pre-norm residual wrapper used for every layer:
        out = MambaBlock(RMSNorm(x)) + x '''

    def __init__(self, args: MambaArgs):
        super().__init__()
        self.args = args
        self.block = MambaBlock(args)
        self.rms = RMSNorm(args.D)

    def forward(self, x):
        normed = self.rms(x)
        return self.block(normed) + x
    
class MambaBlock(nn.Module):
    ''' Standard Mamba block as illustrated in the paper: one branch runs a
        causal depthwise conv -> SiLU -> selective SSM, and is gated by SiLU of
        a second branch before projecting back to width D. '''
    def __init__(self, args: MambaArgs):
        super().__init__()
        self.args = args
        inner = args.D_inner
        kernel = args.conv_1d_size

        # a single projection yields both branches, hence width 2 * D_inner
        self.in_proj = nn.Linear(args.D, 2 * inner, bias=args.general_bias)
        # depthwise conv (groups == channels), padded by kernel-1 on both
        # sides; forward keeps only the first seq_len outputs so it is causal
        self.conv1d = nn.Conv1d(
            in_channels=inner,
            out_channels=inner,
            bias=args.conv_bias,
            kernel_size=kernel,
            groups=inner,
            padding=kernel - 1,
        )
        self.s6_block = S6Block(args)
        self.out_proj = nn.Linear(inner, args.D, bias=args.general_bias)

    def forward(self, x):
        _, seq_len, _ = x.shape
        ssm_in, gate = self.in_proj(x).split(
            [self.args.D_inner, self.args.D_inner], dim=-1)

        # Conv1d wants (batch, channels, length); crop the padded tail
        ssm_in = self.conv1d(ssm_in.transpose(1, 2))[..., :seq_len].transpose(1, 2)

        ssm_out = self.s6_block(F.silu(ssm_in))
        return self.out_proj(ssm_out * F.silu(gate))

class S6Block(nn.Module):
    ''' Inner selective SSM (S6) block.

        Each token of x is linearly projected to input-dependent B and C
        (width N) and a single step size delta shared by all D_inner channels.
        A diagonal A is discretized with zero-order hold and the recurrence
        h_t = A_bar_t * h_{t-1} + B_bar_t * x_t,  y_t = sum_n C_t[n] h_t[:, n]
        is scanned sequentially over the sequence.

        NOTE(review): A is an unconstrained parameter, so nothing here keeps
        it negative during training (a log-parameterization would); there is
        also no D*x skip term added to the output — confirm both are intended.
    '''
    def __init__(self, args: MambaArgs):

        super().__init__()
        self.args = args

        def s4d_real():
            # initialization for A used in the paper: A = -(1, 2, ..., N),
            # the same diagonal repeated for every inner channel
            A = -(torch.arange(0, args.N) + 1)
            return A.unsqueeze(0).repeat(args.D_inner, 1).float()

        def get_delta_bias():
            # sample a target step size uniformly from [a, b], then pass it
            # through the inverse softplus so that softplus(bias) == sample at
            # initialization. inverse_softplus(y) = log(exp(y) - 1); expm1
            # keeps this accurate for the small values sampled here.
            a = 0.001
            b = 0.1
            sample = (b - a) * torch.rand(1) + a
            return torch.log(torch.expm1(sample))

        self.A = nn.Parameter(s4d_real())  # (D_inner, N)

        # strictly linear projections, no biases; delta's bias is added
        # explicitly in forward
        self.to_BCdelta = nn.Linear(args.D_inner, 2 * args.N + 1, bias=False)

        # explicit delta bias term
        self.delta_bias = nn.Parameter(get_delta_bias())

    def discretize(self, delta, B):
        ''' Zero-order-hold discretization.

            delta: (batch, L, D_inner) step sizes; B: (batch, L, N).
            Returns A_bar, B_bar, both (batch, L, D_inner, N).
            The official implementation approximates B_bar with an Euler step
            instead.
        '''
        delta_A = torch.einsum('bld,dn->bldn', delta, self.A)
        A_bar = torch.exp(delta_A)
        delta_B = torch.einsum('bld,bln->bldn', delta, B)
        # A is diagonal, so (dA)^-1 (exp(dA) - I) dB reduces to the elementwise
        # (exp(dA) - 1) / dA * dB; expm1 avoids cancellation when dA is small
        B_bar = torch.expm1(delta_A) / delta_A * delta_B

        return A_bar, B_bar

    def forward(self, x):
        ''' x: (batch, L, D_inner) -> outputs: (batch, L, D_inner). '''
        b, l, _ = x.shape
        N = self.args.N
        D_inner = self.args.D_inner

        # delta: (b, l, 1); B, C: (b, l, N)
        B, C, delta = self.to_BCdelta(x).split([N, N, 1], dim=-1)

        # broadcast the per-token step size to every channel, add the bias
        # and make it positive
        delta = F.softplus(delta.expand(-1, -1, D_inner) + self.delta_bias)

        A_bar, B_bar = self.discretize(delta, B)  # (b, l, D_inner, N)

        # input transformation is parallelizable
        input_transform = B_bar * x.unsqueeze(-1)  # (b, l, D_inner, N)

        # Sequential scan. A is diagonal, so A_bar h is an elementwise product.
        # The state is allocated like x (same device/dtype), and states are
        # collected in a list rather than written in place, which autograd
        # handles without extra clones.
        h = x.new_zeros(b, D_inner, N)
        states = []
        for t in range(l):
            h = A_bar[:, t] * h + input_transform[:, t]
            states.append(h)
        if states:
            hidden_states = torch.stack(states, dim=1)  # (b, l, D_inner, N)
        else:
            hidden_states = x.new_zeros(b, 0, D_inner, N)

        # compute outputs in parallel
        return torch.einsum('bln,bldn->bld', C, hidden_states)


class RMSNorm(nn.Module):
    ''' Hand-written RMSNorm: divide x by the root-mean-square of its last
        dimension (eps added inside the root), then apply a learned per-feature
        gain. The built-in implementation is bugged in the PyTorch version used
        here, so this avoids a version upgrade. '''
    def __init__(self,
                 D: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(D))

    def forward(self, x):
        mean_sq = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_sq + self.eps)
        return x * inv_rms * self.weight

In [141]:
# Use the GPT-2 BPE tokenizer to create a tokenized version of the Kaggle
# song lyrics dataset. Simply training to predict the next token here.

L = 32   # sequence (context) length
B = 8    # batch size
D = 16   # model width
N = 8    # SSM state size

vocab_size = 10000
# prefer Apple-silicon GPU when present, otherwise CUDA, otherwise CPU
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

mamba_args = MambaArgs(N, D, n_layers=5, vocab_size=vocab_size, device=device)
model = Mamba(mamba_args).to(mamba_args.device)

# tokenizer = GPT2Tokenizer.from_pretrained("openai-community/gpt2", force_download=True)
# presumably a local save of the tokenizer above — TODO confirm
tokenizer = GPT2Tokenizer.from_pretrained("model")

# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open("kaggle_song_lyrics_dataset/kaggle_song_lyrics_dataset.pkl", "rb") as f:
    seqs = pickle.load(f)

# get tokens and corresponding IDs
tokens = [tokenizer.tokenize(word) for word in seqs]
token_ids = [tokenizer.convert_tokens_to_ids(token) for token in tokens]

# flattening them all in one list
token_ids = [item for sublist in token_ids for item in sublist]

# Cap the vocabulary at mamba_args.vocab_size IDs (the size of the model's
# embedding table): every ID above vocab_limit is replaced by the unk ID.
# NOTE(review): this assumes low GPT-2 IDs are the most common tokens — TODO
# confirm that is good enough for "top n most common".
unk_id = tokenizer.get_vocab()['unk']

# If the unk ID itself does not fit in the embedding table, reserve the last
# slot (vocab_size - 1) for it and remap unk onto that slot; otherwise the
# tokenizer's unk ID is already in range and is used as-is.
if unk_id > mamba_args.vocab_size-1:
    vocab_limit = mamba_args.vocab_size-2
    model_unk_id = mamba_args.vocab_size-1
else:
    vocab_limit = mamba_args.vocab_size-1
    model_unk_id = unk_id

filtered_ids = [
    token_id if token_id <= vocab_limit else model_unk_id
    for token_id in token_ids
]
# put the tokenized sequences in datasets + dataloaders

class SeqDataset(Dataset):
    ''' Sliding-window next-token dataset over one flat list of token IDs.

        Item `idx` is (seqs[idx : idx+seq_size], seqs[idx+1 : idx+seq_size+1])
        as long tensors on `device`: an input window and the same window
        shifted by one token as the target.
    '''
    def __init__(self, device, seq_size, seqs):
        super(SeqDataset, self).__init__()
        self.device   = device
        self.seq_size = seq_size
        self.seqs     = seqs

    def __len__(self):
        # a window starting at idx needs seq_size + 1 tokens, so valid starts
        # are 0 .. len(seqs) - seq_size - 1; clamp so short inputs give 0
        return max(0, len(self.seqs) - self.seq_size)

    def __getitem__(self, idx):
        in_seq     = torch.tensor(self.seqs[idx    :idx + self.seq_size    ], dtype=torch.long, device=self.device)
        target_seq = torch.tensor(self.seqs[idx + 1:idx + self.seq_size + 1], dtype=torch.long, device=self.device)
        return in_seq, target_seq

# 80 / 10 / 10 contiguous train / val / test split of the token stream.
# Python slices are half-open, so adjacent splits share no tokens.
n_ids = len(filtered_ids)
train_end = int(0.8 * n_ids)
val_end   = int(0.9 * n_ids)

train_set = SeqDataset(mamba_args.device, L, filtered_ids[:train_end])
val_set   = SeqDataset(mamba_args.device, L, filtered_ids[train_end:val_end])
test_set  = SeqDataset(mamba_args.device, L, filtered_ids[val_end:])

train_loader = DataLoader(train_set, B, shuffle=True)
val_loader   = DataLoader(val_set  , B, shuffle=False)
test_loader  = DataLoader(test_set , B, shuffle=False)

In [199]:
# training!
# Anomaly detection makes every backward pass much slower; switch it on only
# while debugging NaNs or in-place autograd errors.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)

train_args = LMTrainingArgs(gpt_3_peak_lr=1.5e-3, warmup_epochs=10, n_epochs=100)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=train_args.peak_lr,
                              betas=train_args.adam_beta,
                              eps=train_args.adam_epsilon,
                              weight_decay=train_args.weight_decay)

# train_args.schedule_fn is indexed by epoch, so the scheduler is stepped once
# per epoch (after the optimizer steps of that epoch), not once per batch
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=train_args.schedule_fn)

# train_args.show_lr_schedule()

for epoch in range(train_args.n_epochs):
    print(f"Epoch: {epoch + 1}/{train_args.n_epochs}")

    model.train()
    train_loss = 0.0

    for in_seq, target_seq in tqdm(train_loader):
        out_seq = model(in_seq)
        loss = criterion(out_seq.reshape(-1, mamba_args.vocab_size), target_seq.reshape(-1))
        train_loss += loss.item()
        optimizer.zero_grad()
        loss.backward()

        # gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), train_args.gradient_clip)

        optimizer.step()

    scheduler.step()

    train_loss /= len(train_loader)
    train_perplexity = math.exp(train_loss)
    print(f"Train loss: {train_loss:.4f}, Train perplexity: {train_perplexity:.4f}")

    model.eval()
    val_loss = 0.0

    with torch.no_grad():
        for in_seq, target_seq in val_loader:
            # the model returns a single logits tensor, nothing to unpack
            out_seq = model(in_seq)
            loss = criterion(out_seq.reshape(-1, mamba_args.vocab_size), target_seq.reshape(-1))
            val_loss += loss.item()

    val_loss /= len(val_loader)
    val_perplexity = math.exp(val_loss)
    print(f"Val loss: {val_loss:.4f}, Val perplexity: {val_perplexity:.4f}")



Epoch: 1/100


  0%|          | 2/167610 [00:05<132:21:14,  2.84s/it]


KeyboardInterrupt: 