## (1) Load model

In [1]:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# Checkpoint to load. Identifiers that can be used here (pick one):
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
# Example: pretrained_model_name = 'state-spaces/mamba-370m'
# This run points at a filesystem path to a mamba-2.8b checkpoint instead.
pretrained_model_name = '/share/public_models/mamba-2.8b'

# Build the model from the checkpoint first, then move it onto the GPU.
model = Mamba.from_pretrained(pretrained_model_name)
model = model.cuda()

# The tokenizer is loaded separately, from the 'EleutherAI/gpt-neox-20b' files.
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

  from .autonotebook import tqdm as notebook_tqdm


In [8]:
# Print the module tree to inspect the loaded architecture (layer types and sizes).
print(model)

Mamba(
  (embedding): Embedding(50280, 2560)
  (layers): ModuleList(
    (0-63): 64 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=2560, out_features=10240, bias=False)
        (conv1d): Conv1d(5120, 5120, kernel_size=(4,), stride=(1,), padding=(3,), groups=5120)
        (x_proj): Linear(in_features=5120, out_features=192, bias=False)
        (dt_proj): Linear(in_features=160, out_features=5120, bias=True)
        (out_proj): Linear(in_features=5120, out_features=2560, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=2560, out_features=50280, bias=False)
)


## (2) Generate Text

In [10]:
import torch
import torch.nn.functional as F


def _model_device(model):
    """Return the device that holds `model`'s weights.

    Looks at the first parameter, then the first buffer. If the module has
    neither, falls back to CUDA when available and CPU otherwise.
    """
    first = next(model.parameters(), None)
    if first is None:
        first = next(model.buffers(), None)
    if first is not None:
        return first.device
    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    """Autoregressively extend `prompt` and return the decoded text.

    Each step feeds the whole sequence generated so far through `model`,
    takes the logits at the last position, optionally restricts the
    distribution to the `top_k` most likely tokens, and appends one token.

    Args:
        model: Callable module mapping token ids of shape (batch, seq) to
            logits of shape (batch, seq, vocab). It is switched to eval mode
            and left in that mode.
        tokenizer: Callable such that `tokenizer(prompt, return_tensors='pt')`
            exposes an `input_ids` tensor, and that provides `decode(list)`.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens to append; 0 returns the decoded
            prompt unchanged.
        sample: If True, draw the next token from the (filtered) distribution;
            if False, pick the most likely token (greedy decoding).
        top_k: Keep only the `top_k` most probable tokens before sampling.
            `None` disables the filter. Values larger than the vocabulary are
            clamped to the vocabulary size.

    Returns:
        The decoded string for the first sequence in the batch, i.e. the
        prompt followed by the generated continuation.

    Raises:
        ValueError: If `top_k` is given but smaller than 1.
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f'top_k must be a positive integer or None, got {top_k}')

    model.eval()

    # Follow the model's device instead of assuming CUDA, so the function also
    # works for models kept on the CPU or on a non-default GPU.
    device = _model_device(model)
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)

    with torch.no_grad():
        for _ in range(n_tokens_to_gen):
            # Only the logits of the last position predict the next token.
            next_token_logits = model(input_ids)[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None:
                # torch.topk raises if k exceeds the vocabulary size, so clamp.
                k = min(top_k, probs.shape[-1])
                top_values, _ = torch.topk(probs, k=k)
                # Zero everything below the k-th largest probability (ties with
                # the threshold are kept), then renormalise to sum to 1.
                threshold = top_values[:, -1, None]
                probs = probs.masked_fill(probs < threshold, 0.0)
                probs = probs / probs.sum(dim=1, keepdim=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1)[:, None]

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # A single prompt is tokenized, so only the first row is decoded.
    return tokenizer.decode(input_ids[0].tolist())

In [11]:
# Open-ended continuation using generate()'s defaults (50 new tokens, top-k 40 sampling).
# Sampling is not seeded, so the text printed below changes from run to run.
print(generate(model, tokenizer, 'Mamba is the'))

Mamba is the Mamba"

"How about your brother" was the question from the reporter.

"Oh he's gone to Paris, my brother has been to Paris many times, my brother goes
to Paris on business and Paris is the only


In [12]:
# Dialogue-style prompt: the model is asked to continue after the "Sally:" speaker tag.
print(generate(model, tokenizer, 'John: Hi!\nSally:'))

John: Hi!
Sally: Hi.
John: How's your momm--
John: How's your--
(John and Sally scream together)
John: It's okay, I know, it's--
I--
Sally: Stop it, you


In [13]:
# Sentence-completion prompt; note the prompt string ends with a trailing space.
print(generate(model, tokenizer, 'The meaning of life is '))

The meaning of life is ~~~~~to stop worrying and then enjoy life.

The three most important things are ~~~~

You cannot use your intellect

You cannot judge the time.

You cannot judge the place.

You are always one step


In [14]:
# Code-completion prompt: an unfinished Python function signature.
print(generate(model, tokenizer, 'def reverse_string('))

def reverse_string(string):
    # Write code here
    return string.reverse()

# Write your code here
def reverse_string2(string):
    for i in range(len(string)):
        string[i] = reverse


In [15]:
# Longer natural-language prompt that ends mid-sentence ("... and").
print(generate(model, tokenizer, 'My cat wrote all this CUDA code for a new language model and'))


OpenCL: failed to create a device object (error: 3, e.g. No such device
