In [None]:
from model_pytorch import Mamba, ModelArgs
from transformers import AutoTokenizer

import torch
import torch.nn.functional as F

import jax
from jax import numpy as jnp
from jax import random

from tqdm import tqdm

# Pretrained Mamba checkpoint to load. One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

# Tokenizer checkpoint; the same one is loaded whichever Mamba checkpoint is selected above.
TOKENIZER_NAME = 'EleutherAI/gpt-neox-20b'

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)

In [None]:
from model import Mamba as MambaJax
from model import ModelArgs as ModelArgsJax
from params import pytorch_to_jax_weights

# Copy d_model, n_layer and vocab_size from the loaded PyTorch model's args.
# Any other ModelArgsJax fields keep their defaults.
# NOTE(review): confirm those defaults match model.args.
modelArgsJax = ModelArgsJax(
    **{name: getattr(model.args, name) for name in ('d_model', 'n_layer', 'vocab_size')}
)
modelJax = MambaJax(modelArgsJax)

# Convert the PyTorch state dict into the parameter tree used by the JAX model.
weights = pytorch_to_jax_weights(model.state_dict())

In [None]:
def generate_jax(model,
                 params,
                 tokenizer,
                 prompt: str,
                 n_tokens_to_gen: int = 50,
                 sample: bool = True,
                 top_k: int = 40,
                 temperature: float = 1.0,
                 seed: int = 0):
    """Autoregressively extend ``prompt`` with a Flax model and return the decoded text.

    Each step re-runs the model on the whole prefix generated so far and
    uses the logits at the last position to choose the next token.

    Args:
        model: Object exposing ``apply({"params": params}, input_ids)`` whose
            result is indexed as ``[batch, position, vocab]``.
        params: Parameter tree passed to ``model.apply``.
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors='jax'``; must also provide ``decode``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of new tokens to append.
        sample: If True, sample from the temperature-scaled, top-k filtered
            distribution; otherwise decode greedily (argmax of the logits).
        top_k: Keep only the ``top_k`` most likely tokens before sampling
            (clamped to the vocabulary size); ``None`` disables filtering.
        temperature: Divides the logits before the softmax when sampling.
            Must be > 0 when ``sample`` is True; ignored for greedy decoding.
        seed: Seed for ``jax.random.PRNGKey``.

    Returns:
        Decoded text of the first batch row (prompt plus generated tokens),
        with special tokens skipped.

    Raises:
        ValueError: if ``sample`` is True and ``temperature`` is not positive.
    """
    if sample and temperature <= 0:
        raise ValueError(f"temperature must be positive when sampling, got {temperature}")

    def model_inference(input_ids):
        return model.apply({"params": params}, input_ids)

    prompt_ids = jnp.array(tokenizer(prompt, return_tensors='jax').input_ids)
    prompt_len = prompt_ids.shape[1]
    # Pre-allocate the token buffer; each step fills one more column.
    buffer = jnp.pad(prompt_ids, ((0, 0), (0, n_tokens_to_gen)))

    def next_tokens(logits, rng):
        """Choose the next token id for every batch row from last-position logits."""
        if not sample:
            # For temperature > 0, softmax and top-k filtering never change the argmax.
            return jnp.argmax(logits, axis=-1), rng

        probs = jax.nn.softmax(logits / temperature, axis=-1)
        if top_k is not None:
            # lax.top_k rejects k larger than the vocabulary dimension.
            k = min(top_k, probs.shape[-1])
            top_probs, top_indices = jax.lax.top_k(probs, k=k)
            rows = jnp.arange(probs.shape[0])[:, None]
            probs = jnp.zeros_like(probs).at[rows, top_indices].set(top_probs)
            probs = probs / probs.sum(axis=-1, keepdims=True)

        rng, step_rng = random.split(rng)
        # random.categorical expects log-probabilities (logits), not probabilities.
        # log(0) = -inf keeps tokens removed by top-k unreachable.
        return random.categorical(step_rng, jnp.log(probs), axis=-1), rng

    rng = random.PRNGKey(seed)
    for pos in tqdm(range(prompt_len, prompt_len + n_tokens_to_gen)):
        logits = model_inference(buffer[:, :pos])[:, -1, :]
        token_ids, rng = next_tokens(logits, rng)
        buffer = buffer.at[:, pos].set(token_ids)

    return tokenizer.decode(buffer[0].tolist(), skip_special_tokens=True)

In [None]:
def generate_pytorch(model,
                     tokenizer,
                     prompt: str,
                     n_tokens_to_gen: int = 50,
                     sample: bool = True,
                     top_k: int = 40,
                     temperature: float = 1.0):
    """Autoregressively extend ``prompt`` with a PyTorch model and return the decoded text.

    The model is put in eval mode. Each step re-runs it (without gradients)
    on the whole prefix and uses the last-position logits.

    Args:
        model: Module called as ``model(input_ids)`` whose result is indexed as
            ``[batch, position, vocab]``.
        tokenizer: Callable returning an object with ``input_ids`` when called
            with ``return_tensors='pt'``; must also provide ``decode``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of new tokens to append.
        sample: If True, sample with ``torch.multinomial`` from the top-k
            filtered distribution; otherwise decode greedily (argmax).
        top_k: Keep only tokens whose probability is at least the ``top_k``-th
            largest (clamped to the vocabulary size); ``None`` disables filtering.
        temperature: Divides the logits before the softmax when sampling.
            Must be > 0 when ``sample`` is True; ignored for greedy decoding.

    Returns:
        Decoded text of the first batch row (prompt plus generated tokens).

    Raises:
        ValueError: if ``sample`` is True and ``temperature`` is not positive.
    """
    if sample and temperature <= 0:
        raise ValueError(f"temperature must be positive when sampling, got {temperature}")

    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for _ in range(n_tokens_to_gen):
        with torch.no_grad():
            next_token_logits = model(input_ids)[:, -1]

        if not sample:
            # Softmax and top-k filtering never change the argmax.
            next_indices = torch.argmax(next_token_logits, dim=-1, keepdim=True)
        else:
            probs = F.softmax(next_token_logits / temperature, dim=-1)
            if top_k is not None:
                # torch.topk rejects k larger than the vocabulary dimension.
                k = min(top_k, probs.shape[-1])
                kth_largest = torch.topk(probs, k=k).values[:, -1, None]
                probs = probs.masked_fill(probs < kth_largest, 0.0)
                probs = probs / probs.sum(dim=-1, keepdim=True)
            next_indices = torch.multinomial(probs, num_samples=1)

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    return tokenizer.decode(input_ids[0].tolist())

In [None]:
# Greedy decoding (sample=False) of 10 tokens with the JAX port.
prompt = "Mamba is the"

output = generate_jax(
    modelJax, weights, tokenizer, prompt,
    n_tokens_to_gen=10, sample=False, seed=0,
)
output

In [None]:
# Greedy decoding (sample=False) of 10 tokens with the reference PyTorch model.
prompt = "Mamba is the"

output_pytorch = generate_pytorch(
    model, tokenizer, prompt,
    n_tokens_to_gen=10, sample=False,
)
output_pytorch