In [4]:
from transformers import (
    AutoTokenizer,
    TrainingArguments,
    PreTrainedTokenizer,
    HfArgumentParser,
)
from mamba.configuration_mamba import MambaConfig
from mamba.modeling_mamba import MambaForCausalLM

In [5]:
# Local checkpoint directory from which both the tokenizer and the model
# weights are loaded.
model_name_or_path = "modeloutput/checkpoint-4000"

# NOTE(review): trust_remote_code=True lets the tokenizer loader run Python
# code shipped alongside the checkpoint -- only use with checkpoints you trust.
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
)

# The model is placed on the first CUDA device; this cell presumably fails on
# a machine without a GPU -- TODO confirm whether a CPU fallback is wanted.
model = MambaForCausalLM.from_pretrained(
    model_name_or_path,
    device_map="cuda:0",
)

# Count parameters keyed by their data pointer, so tensors that share the same
# underlying storage contribute to the total only once.
n_params = sum(
    dict((param.data_ptr(), param.numel()) for param in model.parameters()).values()
)
n_params

371303424

In [11]:
# Encode the prompt to a tensor of token ids and move it to the device the
# model lives on.
inputs = tokenizer.encode("中国是", return_tensors="pt")
inputs = inputs.to(model.device)

# Decoding setup (assuming MambaForCausalLM uses the standard Hugging Face
# `generate` -- TODO confirm): num_beams > 1 combined with do_sample=True
# selects beam-search multinomial sampling. `penalty_alpha` belongs to
# contrastive search, which requires num_beams=1 and do_sample=False, so it had
# no effect on this call and has been removed to keep the configuration honest.
outputs = model.generate(
    inputs,
    num_beams=4,
    max_new_tokens=128,
    do_sample=True,
    top_k=10,
    temperature=0.9,
    repetition_penalty=1.2,
)

# Decode the first returned sequence (special tokens are kept in the text).
print(tokenizer.decode(outputs[0]))

 <s>中国是全球最大的企业之一，是全球最大的企业之一，是全球最大的跨境电商企业，是全球最大的跨境电商企业，也是全球最大的跨境电商企业之一，是全球最大的跨境电商企业，是全球最大的跨境电商企业，是全球最大的跨境电商企业之一，是全球最大的跨境电商和跨境电商企业，是全球最大的跨境电商企业。
B2B是跨境电商企业，在跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境电商、跨境


In [None]:
import torch
import torch.nn.functional as F
from tqdm import tqdm


def generate(
    model,
    tokenizer,
    prompt: str,
    n_tokens_to_gen: int = 50,
    sample: bool = True,
    top_k: int = 10,
):
    """Autoregressively extend ``prompt`` and return the decoded text.

    At every step the whole running sequence is fed through ``model`` again
    (no state/cache is reused), the logits of the last position are turned
    into probabilities, optionally restricted to the ``top_k`` most likely
    tokens, and one token is appended.

    Args:
        model: Callable returning an object with a ``.logits`` tensor indexed
            as ``[:, -1]`` to get the last position; must also expose
            ``.eval()`` and ``.device``.
        tokenizer: Callable as ``tokenizer(prompt, return_tensors="pt")``
            yielding ``.input_ids``, and providing ``.decode(list_of_ids)``.
        prompt: Text to continue.
        n_tokens_to_gen: Number of tokens to append; exactly this many are
            generated (there is no early stop on an end-of-sequence token).
        sample: If True, draw the next token from the (filtered)
            distribution; if False, pick the most probable token.
        top_k: Keep only the ``top_k`` most probable tokens before
            choosing. ``None`` disables the filter. Values larger than the
            vocabulary size are clamped to it.

    Returns:
        The decoded string for the first sequence in the batch, prompt
        included.

    Raises:
        ValueError: If ``top_k`` is given but is smaller than 1.

    Note:
        Calls ``model.eval()`` and does not restore the previous mode.
    """
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be a positive integer or None, got {top_k}")

    model.eval()

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    for _ in tqdm(range(n_tokens_to_gen)):
        with torch.no_grad():
            # Only the prediction for the final position is needed.
            next_token_logits = model(input_ids).logits[:, -1]

        probs = F.softmax(next_token_logits, dim=-1)

        if top_k is not None:
            # torch.topk raises if k exceeds the dimension size, so clamp to
            # the vocabulary size.
            k = min(top_k, probs.shape[-1])
            kth_largest = torch.topk(probs, k=k).values[:, -1, None]
            # Tokens tied with the k-th probability are kept (strict `<`).
            probs = probs.masked_fill(probs < kth_largest, 0.0)
            probs = probs / probs.sum(dim=-1, keepdim=True)

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1, keepdim=True)

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    # Only the first sequence is returned, so only that one is decoded.
    return tokenizer.decode(input_ids[0].tolist())


# Try the hand-written sampling loop on the same prompt used with
# `model.generate`, leaving every optional argument at its default.
print(generate(model=model, tokenizer=tokenizer, prompt="中国是"))

In [None]:
# Second prompt for the hand-written sampling loop, again with the optional
# arguments left at their defaults.
print(generate(model=model, tokenizer=tokenizer, prompt="如何看待"))