## (1) Load Model

In [3]:
from mamba.model import Mamba, ModelArgs
from transformers import AutoTokenizer
from tqdm import tqdm
# Checkpoint identifiers noted by the author as options:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
# The value actually used is a path under `weights/`
# (presumably a local copy of 'state-spaces/mamba-370m' — TODO confirm).
pretrained_model_name = 'weights/mamba-370m'

# Only the tokenizer is created in this cell; the model is loaded in the next cell.
# The tokenizer files are read from the `weights/gpt-neox-20b` path.
tokenizer = AutoTokenizer.from_pretrained("weights/gpt-neox-20b")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
# Re-states the checkpoint path already assigned in the previous cell (same value).
pretrained_model_name = 'weights/mamba-370m'

# Build the model from that checkpoint. How the weights are read is defined by
# `Mamba.from_pretrained` in `mamba.model`, which is not shown in this notebook.
model = Mamba.from_pretrained(pretrained_model_name)

  return self.fget.__get__(instance, owner)()


In [5]:
# Display the `ModelArgs` configuration attached to the loaded model.
model.args

ModelArgs(d_model=1024, n_layer=48, vocab_size=50280, d_state=16, expand=2, dt_rank=64, d_conv=4, pad_vocab_size_multiple=8, conv_bias=True, bias=False)

## (2) Generate Text

In [6]:
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    """Autoregressively extend ``prompt`` by ``n_tokens_to_gen`` tokens.

    On every step the full sequence generated so far is fed through ``model``
    and the logits of the last position are turned into a next-token
    distribution.

    Args:
        model: Callable module mapping a ``(batch, seq_len)`` tensor of token
            ids to logits indexed as ``[:, -1]`` for the last position. It is
            switched to eval mode and left that way.
        tokenizer: Object called as ``tokenizer(prompt, return_tensors='pt')``
            whose result exposes ``input_ids``, and which provides
            ``decode(list_of_ids)``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens appended to the prompt.
        sample: If True, draw the next token from the (optionally top-k
            filtered) distribution; if False, take the most probable token.
        top_k: Keep only tokens whose probability is at least that of the
            k-th most probable one (ties are kept) and renormalise. ``None``
            disables the filter. Values larger than the vocabulary size are
            clamped to the vocabulary size.

    Returns:
        The decoded text of the prompt followed by the generated tokens.
    """
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    # Keep the token ids on the same device as the model weights so the function
    # also works when the model has been moved off the CPU.
    if hasattr(model, 'parameters'):
        first_param = next(iter(model.parameters()), None)
        if first_param is not None:
            input_ids = input_ids.to(first_param.device)

    # Nothing below needs gradients, so the whole loop runs under no_grad.
    with torch.no_grad():
        for _ in tqdm(range(n_tokens_to_gen)):
            next_token_logits = model(input_ids)[:, -1]
            probs = F.softmax(next_token_logits, dim=-1)

            if top_k is not None:
                # torch.topk raises if k exceeds the size of the last dimension.
                k = min(top_k, probs.shape[-1])
                kth_largest = torch.topk(probs, k=k).values[:, -1, None]
                probs = probs.masked_fill(probs < kth_largest, 0)
                probs = probs / probs.sum(dim=-1, keepdim=True)

            if sample:
                next_indices = torch.multinomial(probs, num_samples=1)
            else:
                next_indices = torch.argmax(probs, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_indices], dim=1)

    # A single prompt string yields a single row; decode just that row.
    return tokenizer.decode(input_ids[0].tolist())

In [7]:
# Open-ended continuation using generate()'s defaults (50 new tokens, sampling, top_k=40).
# No random seed is set in the visible cells, so the text differs between runs.
print(generate(model, tokenizer, 'Mamba is the'))

  0%|          | 0/50 [00:00<?, ?it/s]huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
100%|██████████| 50/50 [01:09<00:00,  1.40s/it]

Mamba is the first all-electric motorcycle on the market. It was launched in Japan last summer after the successful release of the first generation Electric Bike. Mamba was created with the goal to combine the qualities of e-bikes with the efficiency of conventional motor





In [8]:
# Dialogue-style prompt: the text ends on a speaker tag, leaving that speaker's line to be generated.
print(generate(model, tokenizer, 'John: Hi!\nSally:'))

100%|██████████| 50/50 [01:07<00:00,  1.34s/it]

John: Hi!
Sally: Hi!
John: We were out and had a nice conversation
about your book and how it was done.
Sally: It definitely was.
John: What was that like when you wrote it and the
audiences were interested in





In [None]:
# Free-text prompt; note the trailing space is part of the prompt that gets tokenized.
print(generate(model, tokenizer, 'The meaning of life is '))

In [None]:
# Code-style prompt: an unfinished Python function signature for the model to continue.
print(generate(model, tokenizer, 'def reverse_string('))

In [11]:
# Display the module tree (embedding, stacked residual blocks, final norm, LM head).
model

Mamba(
  (embedding): Embedding(50280, 1024)
  (layers): ModuleList(
    (0-47): 48 x ResidualBlock(
      (mixer): MambaBlock(
        (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
        (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
        (x_proj): Linear(in_features=2048, out_features=96, bias=False)
        (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
        (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (norm_f): RMSNorm()
  (lm_head): Linear(in_features=1024, out_features=50280, bias=False)
)

ResidualBlock(
  (mixer): MambaBlock(
    (in_proj): Linear(in_features=1024, out_features=4096, bias=False)
    (conv1d): Conv1d(2048, 2048, kernel_size=(4,), stride=(1,), padding=(3,), groups=2048)
    (x_proj): Linear(in_features=2048, out_features=96, bias=False)
    (dt_proj): Linear(in_features=64, out_features=2048, bias=True)
    (out_proj): Linear(in_features=2048, out_features=1024, bias=False)
  )
  (norm): RMSNorm()
)