# Mixtral in Colab

Welcome! In this notebook you can run [Mixtral-8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) with decent generation speed **right in Google Colab or on a consumer-grade GPU**. This is made possible by quantizing the original model in mixed precision and implementing a MoE-specific offloading strategy.

To learn more, read our [tech report](https://arxiv.org/abs/2312.17238) or check out the [repo](https://github.com/dvmazur/mixtral-offloading) on GitHub.

You will need approximately 16 GB of VRAM and 11 GB of RAM to run this notebook and generate moderately long texts.


<details>

<summary>How to balance between RAM and GPU VRAM usage</summary>

You can balance RAM and GPU VRAM usage by changing the <code>offload_per_layer</code> variable in the <a href="#scrollTo=_mIpePTMFyRY&line=10&uniqifier=1">Initialize model</a> section. Increasing <code>offload_per_layer</code> decreases GPU VRAM usage, increases RAM usage and decreases generation speed. Decreasing <code>offload_per_layer</code> has the opposite effect.

Note that this notebook should run normally in Google Colab with <code>offload_per_layer = 4</code>, but may crash with other values. However, if you run it somewhere else, you're free to experiment with this variable.
</details>
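
As a rough illustration of the trade-off described above, the sketch below computes how many experts stay on the GPU and how many go to RAM for a given `offload_per_layer`. It uses the same arithmetic as the `OffloadConfig` call in the Initialize model section. The helper name `split_experts` is hypothetical, and the 32 layers and 8 experts per layer are the values this notebook reports for Mixtral-8x7B.

```python
def split_experts(num_layers: int, num_experts: int, offload_per_layer: int) -> tuple[int, int]:
    """Return (experts kept on the GPU, experts offloaded to RAM)."""
    if not 0 <= offload_per_layer <= num_experts:
        raise ValueError("offload_per_layer must be between 0 and num_experts")
    main_size = num_layers * (num_experts - offload_per_layer)
    offload_size = num_layers * offload_per_layer
    return main_size, offload_size


# Compare a few settings; a higher value moves more experts from VRAM to RAM.
for k in (4, 5, 7):
    on_gpu, in_ram = split_experts(num_layers=32, num_experts=8, offload_per_layer=k)
    print(f"offload_per_layer={k}: {on_gpu} experts on GPU, {in_ram} in RAM")
```

With `offload_per_layer=7` this gives the `main_size=32, offload_size=224` split that appears in the cell output further down.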

## Install and import libraries

In [1]:
%load_ext autoreload
%autoreload 2
# fix numpy in colab
import numpy
from IPython.display import clear_output

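# fix triton in colab
# `!export` runs in a throwaway subshell and does not persist,
# so set the variables for this kernel through os.environ instead
import os
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["HF_HOME"] = "/scratch/bcjw/yyuan6/.cache"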
# !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
#!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
# !ldconfig /usr/lib64-nvidia

#!git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
#!cd mixtral-offloading && pip install -q -r requirements.txt
#!huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo --quiet --local-dir Mixtral-8x7B-Instruct-v0.1-offloading-demo

# clear_output()


In [2]:
import sys

sys.path.append("/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading")
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
# from huggingface_hub import snapshot_download
# from IPython.display import clear_output
# from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
# from transformers.utils import logging as hf_logging
from src.build_model import OffloadConfig, QuantConfig, build_model, build_model_without_quant



hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.




## Initialize model

In [3]:

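import gc  # must be imported before the cleanup below uses it

# free a previously loaded model, if any, before building a new one
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass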
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"
# state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.float16,)
device = torch.device("cuda:0")
config.num_experts_per_tok = 1

import gc
gc.collect()
torch.cuda.empty_cache()

peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"initial GPU Memory Usage: {peak_memory_usage} GB")  

torch.cuda.reset_peak_memory_stats()

offload_per_layer = 7

num_experts = config.num_local_experts
print("number experts:", num_experts)

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)

print(offload_config)

model = build_model_without_quant(
    device=device,
    offload_config=offload_config,
    state_path=state_path,
)
peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")


initial GPU Memory Usage: 0.0 GB
number experts: 8
OffloadConfig(main_size=32, offload_size=224, buffer_size=4, offload_per_layer=7)




Peak GPU Memory Usage: 8.459648132324219 GB
Peak GPU Memory Usage: 13.100425720214844 GB
Peak GPU Memory Usage: 13.256828308105469 GB
Peak GPU Memory Usage: 13.413230895996094 GB
Peak GPU Memory Usage: 13.413230895996094 GB
Peak GPU Memory Usage: 13.413230895996094 GB
Peak GPU Memory Usage: 14.04840087890625 GB
Peak GPU Memory Usage: 14.120803833007812 GB
Peak GPU Memory Usage: 14.199005126953125 GB
Peak GPU Memory Usage: 14.605331420898438 GB
Peak GPU Memory Usage: 14.761734008789062 GB
Peak GPU Memory Usage: 14.918136596679688 GB
Peak GPU Memory Usage: 15.074539184570312 GB
Peak GPU Memory Usage: 15.152740478515625 GB
Peak GPU Memory Usage: 15.230941772460938 GB
Peak GPU Memory Usage: 15.387344360351562 GB
Peak GPU Memory Usage: 15.543746948242188 GB
Peak GPU Memory Usage: 15.700149536132812 GB


MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40.


Peak GPU Memory Usage: 15.778350830078125 GB
dict_keys(['model.layers.10.block_sparse_moe.gate.weight', 'model.layers.10.self_attn.k_proj.weight', 'model.layers.10.self_attn.o_proj.weight', 'model.layers.10.self_attn.q_proj.weight', 'model.layers.10.self_attn.v_proj.weight', 'model.layers.8.input_layernorm.weight', 'model.layers.8.post_attention_layernorm.weight', 'model.layers.9.block_sparse_moe.gate.weight', 'model.layers.9.input_layernorm.weight', 'model.layers.9.post_attention_layernorm.weight', 'model.layers.9.self_attn.k_proj.weight', 'model.layers.9.self_attn.o_proj.weight', 'model.layers.9.self_attn.q_proj.weight', 'model.layers.9.self_attn.v_proj.weight', 'model.layers.18.input_layernorm.weight', 'model.layers.18.post_attention_layernorm.weight', 'model.layers.19.block_sparse_moe.gate.weight', 'model.layers.19.input_layernorm.weight', 'model.layers.19.post_attention_layernorm.weight', 'model.layers.19.self_attn.k_proj.weight', 'model.layers.19.self_attn.o_proj.weight', 'model.

Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]

Peak GPU Memory Usage: 27.335128784179688 GB
Peak GPU Memory Usage: 27.663253784179688 GB
Peak GPU Memory Usage: 27.663253784179688 GB
Peak GPU Memory Usage: 27.991378784179688 GB
Peak GPU Memory Usage: 28.319503784179688 GB
Peak GPU Memory Usage: 28.647628784179688 GB
Peak GPU Memory Usage: 28.647628784179688 GB


Loading experts:   3%|▎         | 1/32 [00:21<11:02, 21.37s/it]

Peak GPU Memory Usage: 28.647628784179688 GB
Peak GPU Memory Usage: 28.647628784179688 GB
Peak GPU Memory Usage: 28.647628784179688 GB
Peak GPU Memory Usage: 28.897628784179688 GB
Peak GPU Memory Usage: 28.897628784179688 GB
Peak GPU Memory Usage: 32.63215637207031 GB
Peak GPU Memory Usage: 32.63215637207031 GB
Peak GPU Memory Usage: 32.63215637207031 GB


Loading experts:   6%|▋         | 2/32 [00:44<11:10, 22.36s/it]



Loading experts:   9%|▉         | 3/32 [01:06<10:40, 22.10s/it]

Peak GPU Memory Usage: 32.63215637207031 GB
Peak GPU Memory Usage: 32.63215637207031 GB
Peak GPU Memory Usage: 32.63215637207031 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB


Loading experts:  12%|█▎        | 4/32 [01:31<10:51, 23.28s/it]



Loading experts:  16%|█▌        | 5/32 [01:53<10:13, 22.72s/it]



Loading experts:  19%|█▉        | 6/32 [02:14<09:40, 22.34s/it]

Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.21601867675781 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB


Loading experts:  22%|██▏       | 7/32 [02:38<09:27, 22.69s/it]



Loading experts:  25%|██▌       | 8/32 [02:59<08:57, 22.41s/it]



Loading experts:  28%|██▊       | 9/32 [03:26<09:04, 23.69s/it]



Loading experts:  31%|███▏      | 10/32 [03:48<08:33, 23.34s/it]



Loading experts:  34%|███▍      | 11/32 [04:10<08:00, 22.89s/it]

Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.62232971191406 GB
Peak GPU Memory Usage: 33.95045471191406 GB


Loading experts:  38%|███▊      | 12/32 [04:34<07:42, 23.11s/it]



Loading experts:  41%|████      | 13/32 [04:56<07:13, 22.79s/it]

Peak GPU Memory Usage: 33.95045471191406 GB
Peak GPU Memory Usage: 33.95045471191406 GB
Peak GPU Memory Usage: 33.95045471191406 GB
Peak GPU Memory Usage: 33.95045471191406 GB
Peak GPU Memory Usage: 33.95045471191406 GB
Peak GPU Memory Usage: 35.341156005859375 GB
Peak GPU Memory Usage: 35.341156005859375 GB
Peak GPU Memory Usage: 35.341156005859375 GB


Loading experts:  44%|████▍     | 14/32 [05:23<07:13, 24.10s/it]



Loading experts:  47%|████▋     | 15/32 [05:45<06:39, 23.50s/it]



Loading experts:  50%|█████     | 16/32 [06:10<06:21, 23.84s/it]



Loading experts:  53%|█████▎    | 17/32 [06:36<06:10, 24.68s/it]



Loading experts:  56%|█████▋    | 18/32 [06:59<05:36, 24.00s/it]



Loading experts:  59%|█████▉    | 19/32 [07:26<05:23, 24.91s/it]



Loading experts:  62%|██████▎   | 20/32 [07:48<04:48, 24.07s/it]



Loading experts:  66%|██████▌   | 21/32 [08:12<04:25, 24.17s/it]



Loading experts:  69%|██████▉   | 22/32 [08:35<03:55, 23.60s/it]

Peak GPU Memory Usage: 35.341156005859375 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB


Loading experts:  72%|███████▏  | 23/32 [09:00<03:37, 24.19s/it]

Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 35.59107971191406 GB
Peak GPU Memory Usage: 36.90357971191406 GB


Loading experts:  75%|███████▌  | 24/32 [09:24<03:11, 23.95s/it]



Loading experts:  78%|███████▊  | 25/32 [09:46<02:43, 23.43s/it]



Loading experts:  81%|████████▏ | 26/32 [10:10<02:21, 23.56s/it]



Loading experts:  84%|████████▍ | 27/32 [10:32<01:55, 23.12s/it]



Loading experts:  88%|████████▊ | 28/32 [10:58<01:35, 23.99s/it]



Loading experts:  91%|█████████ | 29/32 [11:24<01:14, 24.74s/it]



Loading experts:  94%|█████████▍| 30/32 [11:47<00:48, 24.05s/it]



Loading experts:  97%|█████████▋| 31/32 [12:10<00:23, 23.72s/it]



Loading experts: 100%|██████████| 32/32 [12:30<00:00, 23.45s/it]

Peak GPU Memory Usage: 36.90357971191406 GB





Peak GPU Memory Usage: 36.90357971191406 GB
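
The steps of roughly 0.33 GB in the log above (for example from 27.34 to 27.66) are consistent with one fp16 expert being placed on the GPU at a time. Here is a back-of-the-envelope check, assuming the published Mixtral-8x7B dimensions: hidden size 4096, expert intermediate size 14336, and three weight matrices per expert. These numbers are not stated in this notebook, and the function name is hypothetical.

```python
HIDDEN_SIZE = 4096         # assumption: Mixtral-8x7B hidden size
INTERMEDIATE_SIZE = 14336  # assumption: expert FFN intermediate size
MATRICES_PER_EXPERT = 3    # assumption: three projection matrices per expert
BYTES_PER_FP16 = 2


def expert_size_gib(bytes_per_param: float = BYTES_PER_FP16) -> float:
    """Approximate size of one expert's weights in GiB."""
    params = MATRICES_PER_EXPERT * HIDDEN_SIZE * INTERMEDIATE_SIZE
    return params * bytes_per_param / 1024**3


per_expert = expert_size_gib()
print(f"one fp16 expert: {per_expert:.6f} GiB")
# With offload_per_layer = 7, one expert per layer (32 in total) stays on the GPU.
print(f"32 resident fp16 experts: {32 * per_expert:.2f} GiB")
```

This only accounts for the expert weights. Attention layers, embeddings, transfer buffers and temporary allocations during loading add to the reported peak.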


In [5]:

try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass
from src.build_model import OffloadConfig, QuantConfig, build_model
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
# state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"

config = AutoConfig.from_pretrained(quantized_model_name)
config.num_experts_per_tok = 1

device = torch.device("cuda:0")
peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"initial GPU Memory Usage: {peak_memory_usage} GB") 
torch.cuda.reset_peak_memory_stats()
##### Change this to 5 if you have only 12 GB of GPU VRAM #####
offload_per_layer = 7
# offload_per_layer = 5
###############################################################

num_experts = config.num_local_experts
print("number experts:", num_experts)

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)
quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

number experts: 8
dict_keys(['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.block_sparse_moe.gate.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.W_q', 'model.layers.0.self_attn.k_proj.meta.meta_scale.scale', 'model.layers.0.self_attn.k_proj.meta.meta_scale.zero', 'model.layers.0.self_attn.k_proj.meta.meta_zero.scale', 'model.layers.0.self_attn.k_proj.meta.meta_zero.zero', 'model.layers.0.self_attn.k_proj.meta.scale_q', 'model.layers.0.self_attn.k_proj.meta.zero_q', 'model.layers.0.self_attn.o_proj.W_q', 'model.layers.0.self_attn.o_proj.meta.meta_scale.scale', 'model.layers.0.self_attn.o_proj.meta.meta_scale.zero', 'model.layers.0.self_attn.o_proj.meta.meta_zero.scale', 'model.layers.0.self_attn.o_proj.meta.meta_zero.zero', 'model.layers.0.self_attn.o_proj.meta.scale_q', 'model.layers.0.self_attn.o_proj.meta.zero_q', 'model.layers.0.self_attn.q_proj.W_q', 'model.layers.0.self_attn.q_pro

Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]

cuda:0
Peak GPU Memory Usage: 4.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.386344909667969 GB


Loading experts:   3%|▎         | 1/32 [00:04<02:29,  4.84s/it]

cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB


Loading experts:   6%|▋         | 2/32 [00:11<03:00,  6.03s/it]

cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.573844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB


Loading experts:   9%|▉         | 3/32 [00:18<03:01,  6.27s/it]

cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
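
To get a feel for what the two `BaseQuantizeConfig` settings above mean for storage, the sketch below estimates the bits stored per weight. It assumes one scale and one zero-point per group, each stored in 8 bits, and ignores second-level metadata. These are simplifying assumptions, not the exact HQQ storage layout, and `effective_bits` is a hypothetical helper.

```python
def effective_bits(nbits: int, group_size: int, meta_bits: int = 8) -> float:
    """Approximate bits per weight: payload plus a per-group scale and zero-point."""
    return nbits + 2 * meta_bits / group_size


attn_bits = effective_bits(nbits=4, group_size=64)  # attention config
ffn_bits = effective_bits(nbits=2, group_size=16)   # expert (FFN) config
print(f"attention: ~{attn_bits:.2f} bits/weight")
print(f"experts:   ~{ffn_bits:.2f} bits/weight")
```

Under these assumptions an expert shrinks from 16 bits to roughly 3 bits per weight. That is in line with the per-expert memory steps in the quantized log above, which are about 0.06 GB compared with about 0.33 GB for the unquantized model.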


## Run the model

In [4]:
from transformers import TextStreamer


tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
past_key_values = None
sequence = None

seq_len = 0
user_input = "tell me a joke"
print("\n")

user_entry = dict(role="user", content=user_input)
input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)

if past_key_values is None:
    attention_mask = torch.ones_like(input_ids)
else:
    seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
    attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)

print("Mixtral: ", end="")
result = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    past_key_values=past_key_values,
    streamer=streamer,
    do_sample=True,
    temperature=0.9,
    top_p=0.9,
    max_new_tokens=512,
    pad_token_id=tokenizer.eos_token_id,
    return_dict_in_generate=True,
    output_hidden_states=True,
)
print("\n")

sequence = result["sequences"]
past_key_values = result["past_key_values"]
peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")



Mixtral: Sure! Here's a light-hearted joke for you:

Why don't scientists trust atoms?

Because they make up everything!

I hope that gave you a chuckle. Do you have any specific topic you'd like to hear a joke about, or just a general joke? I'm here to help with any question you have to the best of my ability.


Peak GPU Memory Usage: 36.90357971191406 GB
