In [1]:
import os

# Restrict this process to GPU index 3. This is done before importing llama so the
# setting is already in place when CUDA gets initialised (presumably llama pulls in
# torch -- confirm).
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

from llama import Llama

# Single source of truth for the checkpoint location; the tokenizer file lives in
# the same directory, so its path is derived instead of being spelled out twice.
CKPT_DIR = "/raid/sourab/mixtral-8x7b-32kseqlen/"

# NOTE(review): only one device is listed in CUDA_VISIBLE_DEVICES above while
# num_gpus=2 is requested here -- confirm that this combination is intended.
generator = Llama.build(
    ckpt_dir=CKPT_DIR,
    tokenizer_path=os.path.join(CKPT_DIR, "tokenizer.model"),
    max_seq_len=128,
    max_batch_size=4,
    num_gpus=2,
)

=== created Mixtral 8x7B. Experts spread over 2 GPUs ===
Loaded in 488.12 seconds


In [3]:
# Total element count summed over every parameter tensor of the model.
params = sum(tensor.numel() for tensor in generator.model.parameters())

# Bare expression so the notebook displays the total.
params

46702792704

In [4]:
# Short alias for the model object held by the generator; later cells operate on it directly.
model = generator.model

In [5]:
# Display the token-embedding weight parameter as the cell output.
model.tok_embeddings.weight

Parameter containing:
tensor([[-7.4938e-38,  1.2214e-38, -5.7305e-37,  ..., -1.8220e-37,
         -2.9534e-37, -2.1894e-37],
        [-1.4954e-02,  2.5940e-04, -1.6113e-02,  ...,  2.2531e-05,
         -1.6846e-02,  2.1973e-03],
        [-5.3883e-05,  3.1948e-05,  6.3777e-06,  ...,  4.0293e-05,
          1.1623e-05,  1.1474e-06],
        ...,
        [ 8.7891e-03, -1.3000e-02,  5.2795e-03,  ...,  5.5237e-03,
          3.7079e-03,  3.7231e-03],
        [ 1.9073e-03,  1.8677e-02, -4.9133e-03,  ...,  1.3245e-02,
          4.6997e-03,  5.6152e-03],
        [ 3.6926e-03,  1.6235e-02,  1.8597e-04,  ..., -7.7820e-03,
         -9.5215e-03, -2.4292e-02]], requires_grad=True)

In [6]:
import bitsandbytes as bnb
from llama.model import MoE
import torch


def replace_remaining_with_4bit(model):
    """Recursively replace plain ``torch.nn.Linear`` children with ``bnb.nn.Linear4bit``.

    The module tree is modified in place. A child is converted when it is a
    ``torch.nn.Linear`` that is neither a ``bnb.nn.modules.LinearSparse`` nor
    already a ``bnb.nn.Linear4bit``, and whose attribute name contains neither
    ``"output"`` nor ``"gate"`` (substring match on the child's own name).
    The new layer reuses the original weight data, wrapped in
    ``bnb.nn.Params4bit`` with ``requires_grad=False``, and the original bias
    parameter when there is one.

    Calling the function again on an already converted model is a no-op.

    Args:
        model: root ``torch.nn.Module`` whose descendants should be converted.

    Returns:
        None. ``model`` is mutated in place.
    """
    # Linear4bit is itself an nn.Linear subclass in bitsandbytes, so it has to be
    # excluded explicitly; otherwise a second call (e.g. re-running the notebook
    # cell) would wrap already converted layers a second time.
    already_converted = (bnb.nn.modules.LinearSparse, bnb.nn.Linear4bit)

    # Snapshot the children so re-registering one cannot disturb the iteration.
    for name, module in list(model.named_children()):
        convertible = (
            isinstance(module, torch.nn.Linear)
            and not isinstance(module, already_converted)
            and "output" not in name
            and "gate" not in name
        )
        if convertible:
            bias = module.bias
            quantized = bnb.nn.Linear4bit(
                module.in_features,
                module.out_features,
                bias is not None,
            )
            quantized.weight = bnb.nn.Params4bit(module.weight.data, requires_grad=False)
            if bias is not None:
                quantized.bias = bias
            setattr(model, name, quantized)
            # The freshly created layer has no children to descend into.
            continue

        if len(list(module.children())) > 0:
            replace_remaining_with_4bit(module)

    
    

def replace_moe_with_sparse_linear_and_remaining_with_4bit(model):
    """Convert MoE expert layers to ``LinearSparse``, then convert the rest to 4-bit.

    For every ``MoE`` module found in ``model``, each direct child of each
    expert is replaced by a ``bnb.nn.modules.LinearSparse`` of the same
    ``in_features`` / ``out_features``. The replacement reuses the original
    weight data, wrapped in ``bnb.nn.modules.ParamsSparse``, and the original
    bias parameter when there is one. Children that are already
    ``LinearSparse`` are skipped, so calling the function a second time does
    not wrap them again.

    Afterwards ``replace_remaining_with_4bit`` is run on the whole model.

    Args:
        model: root ``torch.nn.Module``; it is modified in place.

    Returns:
        None.
    """
    sparse_linear_cls = bnb.nn.modules.LinearSparse

    for _, module in model.named_modules():
        if not isinstance(module, MoE):
            continue
        for expert in module.experts:
            # Snapshot the children: they are re-registered on the expert below.
            # The dedicated name also avoids shadowing the outer loop variable.
            for layer_name, layer in list(expert.named_children()):
                if isinstance(layer, sparse_linear_cls):
                    # Already converted by an earlier call.
                    continue
                bias = layer.bias
                sparse_layer = sparse_linear_cls(
                    layer.in_features, layer.out_features, bias is not None
                )
                sparse_layer.weight = bnb.nn.modules.ParamsSparse(layer.weight.data)
                if bias is not None:
                    sparse_layer.bias = bias
                setattr(expert, layer_name, sparse_layer)

    replace_remaining_with_4bit(model)
                

In [7]:
%%time
# Convert the model in place: MoE expert layers become LinearSparse and the remaining
# eligible linear layers become Linear4bit. The %%time magic reports how long this takes.
replace_moe_with_sparse_linear_and_remaining_with_4bit(model)

CPU times: user 5min 22s, sys: 1min 9s, total: 6min 31s
Wall time: 6min 32s


In [8]:
# Display the module tree to check which layers were replaced by the conversion.
model

Transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (wk): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (wv): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (wo): Linear4bit(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): MoE(
        (experts): ModuleList(
          (0-7): 8 x FeedForward(
            (w1): LinearSparse(in_features=4096, out_features=14336, bias=False)
            (w2): LinearSparse(in_features=14336, out_features=4096, bias=False)
            (w3): LinearSparse(in_features=4096, out_features=14336, bias=False)
          )
        )
        (gate): Linear(in_features=4096, out_features=8, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_fea

In [9]:
# Move the converted model to the GPU.
# NOTE(review): with bitsandbytes, 4-bit weights are presumably quantized at this point,
# when they are transferred to the GPU -- confirm for the bitsandbytes build in use.
model.to("cuda")
# Switch to evaluation mode; as the last expression, its return value is displayed.
model.eval()

Transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear4bit(in_features=4096, out_features=4096, bias=False)
        (wk): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (wv): Linear4bit(in_features=4096, out_features=1024, bias=False)
        (wo): Linear4bit(in_features=4096, out_features=4096, bias=False)
      )
      (feed_forward): MoE(
        (experts): ModuleList(
          (0-7): 8 x FeedForward(
            (w1): LinearSparse(in_features=4096, out_features=14336, bias=False)
            (w2): LinearSparse(in_features=14336, out_features=4096, bias=False)
            (w3): LinearSparse(in_features=4096, out_features=14336, bias=False)
          )
        )
        (gate): Linear(in_features=4096, out_features=8, bias=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_fea

In [11]:
# Batch of prompts sent to the model in one call.
prompts = [
    # Open-ended prompts: the expected answer is simply the natural continuation.
    "Mistral.ai is a company that",
    "Simply put, the theory of relativity states that ",
    """A brief message congratulating the team on the launch:

Hi everyone,

I just """,
    # Few-shot prompt: several worked examples, with the last one left for the model.
    """Translate English to French:

sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>""",
]

# Only the generation call needs to run under float16 autocast on the GPU.
with torch.autocast(device_type="cuda", dtype=torch.float16):
    results = generator.text_completion(
        prompts, max_gen_len=64, temperature=0.2, top_p=0.95
    )

# Echo every prompt followed by its completion and a separator line.
for prompt, result in zip(prompts, results):
    print(
        prompt,
        f"> {result['generation']}",
        "\n==================================\n",
        sep="\n",
    )

Mistral.ai is a company that
>                  ,       ,       , ,      ,   ,     a a,     ,s ,


Simply put, the theory of relativity states that 
>                          ,   ,   , s,  ,  ,s,  ,   ,s,s,s ,s,s,


A brief message congratulating the team on the launch:

Hi everyone,

I just 
>   ,         ,  ,   ,   s , -,s,s s,s,s,s,s.s,s,s.s,s,s,s.s,s,s,


Translate English to French:

sea otter => loutre de mer
peppermint => menthe poivrée
plush girafe => girafe peluche
cheese =>
> s,s,s,s.s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,s,


