# Mixtral in Colab

Welcome! In this notebook you can run [Mixtral8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) with decent generation speed **right in Google Colab or on a consumer-grade GPU**. This was made possible by quantizing the original model in mixed precision and implementing a MoE-specific offloading strategy.

To learn more, read our [tech report](https://arxiv.org/abs/2312.17238) or check out the [repo](https://github.com/dvmazur/mixtral-offloading) on GitHub.

One will need approximately 16 GB of VRAM and 11 GB of RAM to run this notebook and generate somewhat long texts.


<details>

<summary>How to balance between RAM and GPU VRAM usage</summary>

You can balance between RAM and GPU VRAM usage by changing <code>offload_per_layer</code> variable in the <a href="#scrollTo=_mIpePTMFyRY&line=10&uniqifier=1">Initialize model</a> section. Increasing <code>offload_per_layer</code> will decrease GPU VRAM usage, increase RAM usage and decrease generation speed. Decreasing <code>offload_per_layer</code> will have the opposite effect.

Note that this notebook should run normally in Google Colab with <code>offload_per_layer = 4</code>, but may crush with other values. However, if you run this somewhere else, you're free to play with this variable.
</details>

## Install and import libraries

In [1]:
%load_ext autoreload
%autoreload 2
# fix numpy in colab
import numpy
from IPython.display import clear_output

# fix triton in colab
!export LC_ALL="en_US.UTF-8"
!export HF_HOME=/scratch/bcjw/yyuan6/.cache
# !export LD_LIBRARY_PATH="/usr/lib64-nvidia"
#!export LIBRARY_PATH="/usr/local/cuda/lib64/stubs"
# !ldconfig /usr/lib64-nvidia

#!git clone https://github.com/dvmazur/mixtral-offloading.git --quiet
#!cd mixtral-offloading && pip install -q -r requirements.txt
#!huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo --quiet --local-dir Mixtral-8x7B-Instruct-v0.1-offloading-demo

# clear_output()


In [3]:
import sys

sys.path.append("/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading")
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
# from huggingface_hub import snapshot_download
# from IPython.display import clear_output
# from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
# from transformers.utils import logging as hf_logging
from src.build_model import OffloadConfig, QuantConfig, build_model, build_model_without_quant



[36mhqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.[0m


  from .autonotebook import tqdm as notebook_tqdm


## Initialize model

In [4]:

try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
# state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"
# state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/test_dir"
config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.float16,)
device = torch.device("cuda:0")

import gc
gc.collect()
torch.cuda.empty_cache()

peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"initial GPU Memory Usage: {peak_memory_usage} GB")  

torch.cuda.reset_peak_memory_stats()

offload_per_layer = 6

num_experts = config.num_local_experts
print("number experts:", num_experts)

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)

print(offload_config)

model = build_model_without_quant(
    device=device,
    offload_config=offload_config,
    state_path=state_path,
)
peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")


initial GPU Memory Usage: 0.0 GB
number experts: 8
OffloadConfig(main_size=64, offload_size=192, buffer_size=4, offload_per_layer=6)




/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/test_dir/model-00001-of-00257.safetensors
{'lm_head.weight': tensor([[-0.0006, -0.0018,  0.0006,  ..., -0.0029, -0.0026, -0.0023],
        [-0.0006, -0.0018,  0.0006,  ..., -0.0029, -0.0025, -0.0023],
        [ 0.0043,  0.0060,  0.0060,  ..., -0.0079, -0.0052,  0.0060],
        ...,
        [-0.0014,  0.0058, -0.0052,  ...,  0.0016, -0.0154, -0.0018],
        [ 0.0107, -0.0035, -0.0057,  ...,  0.0097, -0.0006, -0.0059],
        [ 0.0095,  0.0016,  0.0001,  ..., -0.0151, -0.0116, -0.0032]],
       device='cuda:0', dtype=torch.float16), 'model.embed_tokens.weight': tensor([[-0.0000e+00,  0.0000e+00, -0.0000e+00,  ..., -0.0000e+00,
         -0.0000e+00, -0.0000e+00],
        [-1.5259e-02,  9.2030e-05, -1.6113e-02,  ...,  4.2915e-05,
         -1.7212e-02,  2.2125e-03],
        [-7.5531e-04, -4.6349e-04, -3.0136e-04,  ..., -2.0218e-04,
         -3.6812e-04, -4.6730e-05],
        ...,
        [ 8.9111e-03, -1.2817e-02,  5.3101e-03,  ...,  

MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40.


Finish loading trunk states!
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.297885417938232 GB
Peak GPU Memory Usage: 7.428245544433594 GB
Peak GPU Memory Usage: 7.756370544433594 GB
Peak GPU Memory Usage: 8.084495544433594 GB
Peak GPU Memory Usage: 8.412620544433594 GB
Peak GPU Memory Usage: 8.740745544433594 GB
Peak GPU Memory Usage: 9.068870544433594 GB
Peak GPU Memory Usage: 9.396995544433594 GB
Peak GPU Memory Usage: 9.725120544433594 GB
Peak GPU Memory Usage: 10.053245544433594 GB
Peak GPU Memory Usage: 10.381370544433594 GB
Peak GPU Memory Usage: 10.709495544433594 GB
Peak GPU Memory Usage: 11.037620544433594 GB

Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]

Peak GPU Memory Usage: 26.459495544433594 GB
Peak GPU Memory Usage: 26.568870544433594 GB
Peak GPU Memory Usage: 26.787620544433594 GB
Peak GPU Memory Usage: 27.115745544433594 GB
Peak GPU Memory Usage: 27.443870544433594 GB
Peak GPU Memory Usage: 27.553245544433594 GB
Peak GPU Memory Usage: 27.553245544433594 GB


Loading experts:   3%|▎         | 1/32 [00:52<27:16, 52.77s/it]

Peak GPU Memory Usage: 27.553245544433594 GB
Peak GPU Memory Usage: 27.553245544433594 GB
Peak GPU Memory Usage: 27.771995544433594 GB
Peak GPU Memory Usage: 27.881370544433594 GB
Peak GPU Memory Usage: 27.881370544433594 GB
Peak GPU Memory Usage: 27.881370544433594 GB
Peak GPU Memory Usage: 27.881370544433594 GB
Peak GPU Memory Usage: 27.881370544433594 GB


Loading experts:   6%|▋         | 2/32 [01:08<15:29, 30.98s/it]

Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.100120544433594 GB
Peak GPU Memory Usage: 28.209495544433594 GB
Peak GPU Memory Usage: 28.209495544433594 GB


Loading experts:   9%|▉         | 3/32 [01:24<11:44, 24.31s/it]

Peak GPU Memory Usage: 28.209495544433594 GB
Peak GPU Memory Usage: 28.209495544433594 GB
Peak GPU Memory Usage: 28.209495544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB


Loading experts:  12%|█▎        | 4/32 [02:06<14:29, 31.04s/it]

Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.428245544433594 GB
Peak GPU Memory Usage: 28.756370544433594 GB
Peak GPU Memory Usage: 28.756370544433594 GB
Peak GPU Memory Usage: 28.756370544433594 GB


Loading experts:  16%|█▌        | 5/32 [02:21<11:24, 25.35s/it]

Peak GPU Memory Usage: 28.756370544433594 GB
Peak GPU Memory Usage: 28.756370544433594 GB
Peak GPU Memory Usage: 29.084495544433594 GB
Peak GPU Memory Usage: 29.084495544433594 GB
Peak GPU Memory Usage: 29.084495544433594 GB
Peak GPU Memory Usage: 29.084495544433594 GB
Peak GPU Memory Usage: 29.084495544433594 GB
Peak GPU Memory Usage: 29.412620544433594 GB


Loading experts:  19%|█▉        | 6/32 [02:34<09:08, 21.11s/it]

Peak GPU Memory Usage: 29.412620544433594 GB
Peak GPU Memory Usage: 29.412620544433594 GB
Peak GPU Memory Usage: 29.412620544433594 GB
Peak GPU Memory Usage: 29.412620544433594 GB
Peak GPU Memory Usage: 29.740745544433594 GB
Peak GPU Memory Usage: 29.740745544433594 GB
Peak GPU Memory Usage: 29.740745544433594 GB
Peak GPU Memory Usage: 29.740745544433594 GB
Peak GPU Memory Usage: 29.740745544433594 GB


Loading experts:  22%|██▏       | 7/32 [03:16<11:37, 27.91s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  25%|██▌       | 8/32 [03:32<09:41, 24.23s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  28%|██▊       | 9/32 [03:48<08:20, 21.76s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  31%|███▏      | 10/32 [04:18<08:51, 24.18s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  34%|███▍      | 11/32 [04:45<08:44, 24.96s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  38%|███▊      | 12/32 [05:24<09:45, 29.28s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  41%|████      | 13/32 [06:06<10:29, 33.15s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  44%|████▍     | 14/32 [06:57<11:35, 38.61s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  47%|████▋     | 15/32 [08:14<14:13, 50.18s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  50%|█████     | 16/32 [08:29<10:32, 39.55s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  53%|█████▎    | 17/32 [09:21<10:46, 43.13s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  56%|█████▋    | 18/32 [09:37<08:12, 35.16s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  59%|█████▉    | 19/32 [09:51<06:14, 28.83s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  62%|██████▎   | 20/32 [10:20<05:44, 28.69s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  66%|██████▌   | 21/32 [10:36<04:35, 25.08s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  69%|██████▉   | 22/32 [10:51<03:38, 21.88s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  72%|███████▏  | 23/32 [11:06<02:59, 19.93s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  75%|███████▌  | 24/32 [11:50<03:37, 27.14s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  78%|███████▊  | 25/32 [12:04<02:41, 23.10s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  81%|████████▏ | 26/32 [12:18<02:01, 20.33s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  84%|████████▍ | 27/32 [12:53<02:03, 24.74s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  88%|████████▊ | 28/32 [14:07<02:38, 39.66s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  91%|█████████ | 29/32 [14:23<01:37, 32.58s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  94%|█████████▍| 30/32 [14:39<00:55, 27.54s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts:  97%|█████████▋| 31/32 [15:32<00:35, 35.30s/it]

Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB
Peak GPU Memory Usage: 30.068870544433594 GB


Loading experts: 100%|██████████| 32/32 [15:53<00:00, 29.79s/it]

Peak GPU Memory Usage: 30.068870544433594 GB





Peak GPU Memory Usage: 30.068870544433594 GB


In [5]:

try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass
from src.build_model import OffloadConfig, QuantConfig, build_model
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
# state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"

config = AutoConfig.from_pretrained(quantized_model_name)
config.num_experts_per_tok = 1

device = torch.device("cuda:0")
peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
print(f"initial GPU Memory Usage: {peak_memory_usage} GB") 
torch.cuda.reset_peak_memory_stats()
##### Change this to 5 if you have only 12 GB of GPU VRAM #####
offload_per_layer = 7
# offload_per_layer = 5
###############################################################

num_experts = config.num_local_experts
print("number experts:", num_experts)

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)
quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

number experts: 8
dict_keys(['lm_head.weight', 'model.embed_tokens.weight', 'model.layers.0.block_sparse_moe.gate.weight', 'model.layers.0.input_layernorm.weight', 'model.layers.0.post_attention_layernorm.weight', 'model.layers.0.self_attn.k_proj.W_q', 'model.layers.0.self_attn.k_proj.meta.meta_scale.scale', 'model.layers.0.self_attn.k_proj.meta.meta_scale.zero', 'model.layers.0.self_attn.k_proj.meta.meta_zero.scale', 'model.layers.0.self_attn.k_proj.meta.meta_zero.zero', 'model.layers.0.self_attn.k_proj.meta.scale_q', 'model.layers.0.self_attn.k_proj.meta.zero_q', 'model.layers.0.self_attn.o_proj.W_q', 'model.layers.0.self_attn.o_proj.meta.meta_scale.scale', 'model.layers.0.self_attn.o_proj.meta.meta_scale.zero', 'model.layers.0.self_attn.o_proj.meta.meta_zero.scale', 'model.layers.0.self_attn.o_proj.meta.meta_zero.zero', 'model.layers.0.self_attn.o_proj.meta.scale_q', 'model.layers.0.self_attn.o_proj.meta.zero_q', 'model.layers.0.self_attn.q_proj.W_q', 'model.layers.0.self_attn.q_pro

Loading experts:   0%|          | 0/32 [00:00<?, ?it/s]

cuda:0
Peak GPU Memory Usage: 4.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.386344909667969 GB


Loading experts:   3%|▎         | 1/32 [00:04<02:29,  4.84s/it]

cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.448844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB


Loading experts:   6%|▋         | 2/32 [00:11<03:00,  6.03s/it]

cuda:0
Peak GPU Memory Usage: 4.511344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.573844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB


Loading experts:   9%|▉         | 3/32 [00:18<03:01,  6.27s/it]

cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB


Loading experts:  12%|█▎        | 4/32 [00:24<02:51,  6.14s/it]

cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.636344909667969 GB


Loading experts:  16%|█▌        | 5/32 [00:29<02:42,  6.00s/it]

cuda:0
Peak GPU Memory Usage: 4.698844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.761344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.761344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.761344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.823844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.823844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.823844909667969 GB
cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB


Loading experts:  19%|█▉        | 6/32 [00:36<02:36,  6.03s/it]

cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.886344909667969 GB
cuda:0
Peak GPU Memory Usage: 4.948844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.011344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.011344909667969 GB


Loading experts:  22%|██▏       | 7/32 [01:04<05:32, 13.32s/it]

cuda:0
Peak GPU Memory Usage: 5.011344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB


Loading experts:  25%|██▌       | 8/32 [01:17<05:14, 13.10s/it]

cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB


Loading experts:  28%|██▊       | 9/32 [01:22<04:04, 10.61s/it]

cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB


Loading experts:  31%|███▏      | 10/32 [01:26<03:07,  8.54s/it]

cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB


Loading experts:  34%|███▍      | 11/32 [01:31<02:39,  7.59s/it]

cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB


Loading experts:  38%|███▊      | 12/32 [01:38<02:27,  7.36s/it]

cuda:0
Peak GPU Memory Usage: 5.073844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB


Loading experts:  41%|████      | 13/32 [01:43<02:09,  6.82s/it]

cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB


Loading experts:  44%|████▍     | 14/32 [01:50<02:01,  6.77s/it]

cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB


Loading experts:  47%|████▋     | 15/32 [01:57<01:55,  6.77s/it]

cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB


Loading experts:  50%|█████     | 16/32 [02:03<01:45,  6.62s/it]

cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB


Loading experts:  53%|█████▎    | 17/32 [02:10<01:39,  6.66s/it]

cuda:0
Peak GPU Memory Usage: 5.136344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB


Loading experts:  56%|█████▋    | 18/32 [02:18<01:37,  6.97s/it]

cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB


Loading experts:  59%|█████▉    | 19/32 [02:24<01:30,  6.93s/it]

cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB


Loading experts:  62%|██████▎   | 20/32 [02:30<01:18,  6.58s/it]

cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB


Loading experts:  66%|██████▌   | 21/32 [02:58<02:22, 12.93s/it]

cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB


Loading experts:  69%|██████▉   | 22/32 [03:06<01:55, 11.52s/it]

cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.198844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB


Loading experts:  72%|███████▏  | 23/32 [03:10<01:23,  9.23s/it]

cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB


Loading experts:  75%|███████▌  | 24/32 [03:16<01:05,  8.15s/it]

cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB


Loading experts:  78%|███████▊  | 25/32 [03:21<00:51,  7.31s/it]

cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB


Loading experts:  81%|████████▏ | 26/32 [03:27<00:40,  6.83s/it]

cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB


Loading experts:  84%|████████▍ | 27/32 [03:33<00:32,  6.58s/it]

cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.261344909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB


Loading experts:  88%|████████▊ | 28/32 [03:39<00:25,  6.43s/it]

cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB


Loading experts:  91%|█████████ | 29/32 [03:46<00:20,  6.77s/it]

cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB


Loading experts:  94%|█████████▍| 30/32 [03:54<00:14,  7.10s/it]

cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB


Loading experts:  97%|█████████▋| 31/32 [04:01<00:06,  6.90s/it]

cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB
cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB


Loading experts: 100%|██████████| 32/32 [04:07<00:00,  7.73s/it]

cuda:0
Peak GPU Memory Usage: 5.323844909667969 GB





## Run the model

In [14]:
from transformers import TextStreamer
import time


tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
past_key_values = None
sequence = None

seq_len = 0
for output_length in [128, 512]:
    # text = """Write an email with the subject line "Resumes"."""
    for batch_size in [1, 4, 16]:
      text =  """In the twilight haze of a city that never truly slept, where the neon lights flickered like the last gasps of a dying star, there lived a girl named Lyra. She resided in the heart of the metropolis, in an apartment that was too small for dreams as big as hers. The world outside her window was a tapestry of shadows and light, a place where fortunes were made and lost with the flip of a coin, and where secrets whispered on the wind were more valuable than gold.

            Lyra had a peculiar talent, one that set her apart from the millions of souls that hustled through the city’s arteries each day. She could see the threads of fate that bound people together, glowing lines that stretched out into the distance, intertwining and parting in a dance as old as time. It was a gift she had hidden away, fearful of the consequences should the wrong eyes find her. But as the city around her grew darker, consumed by a greed that ate at its heart, Lyra knew she could no longer remain hidden in the shadows."""
      #user_entry = dict(role="user", content=user_input)

      # input_ids = tokenizer(text, return_tensors="pt").to(0).to(device)

# Tokenize all entries (batch operation)
      texts = [text] * batch_size
      input_ids = tokenizer(texts, padding=False, return_tensors="pt").input_ids.to(device)

      # input_ids = tokenizer.apply_chat_template([user_entry] * batch_size, return_tensors="pt").to(device)
      # input_ids = tokenizer.apply_chat_template([user_entry], return_tensors="pt").to(device)


      """
      if past_key_values is None:
        attention_mask = torch.ones_like(input_ids[0])
      else:
        seq_len = input_ids.size(1) + past_key_values[0][0][0].size(1)
        attention_mask = torch.ones([1, seq_len - 1], dtype=torch.int, device=device)
      """
      
      for i in range(2):
        _ = model.generate(
        input_ids=input_ids,
        past_key_values=past_key_values,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        max_new_tokens=output_length,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_hidden_states=True,
      )

      total_time = 0
      for iter in range(4):
          torch.cuda.reset_peak_memory_stats()
          st = time.time()
          result = model.generate(
            input_ids=input_ids,
            past_key_values=past_key_values,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            max_new_tokens=output_length,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_hidden_states=True,
          )
          et = time.time()
          torch.cuda.synchronize()
          # print(tokenizer.decode(outputs[0], skip_special_tokens=True))
          peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
          for i, response in enumerate(result):
              print(len(response), end="")
          print(" ")
          print(f"output length: {output_length}, batch_size: {batch_size}, iter {iter}")
          print(f"total inference time: {et - st} s; latency per token: {(et - st)/output_length} tokens/s")
          print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")
          total_time += et - st

      print("###AVERAGE")
      print(f"output length: {output_length}, batch_size: {batch_size}")
      print(f"total inference time: {total_time / 5} s; latency per token: {total_time / 5 / output_length} tokens/s")

      """
      result = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        past_key_values=past_key_values,
        streamer=streamer,
        do_sample=True,
        temperature=0.9,
        top_p=0.9,
        max_new_tokens=512,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
        output_hidden_states=True,
      )
      print("\n")

      sequence = result["sequences"]
      past_key_values = result["past_key_values"]
      peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
      print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")
      """

91315 
output length: 32, batch_size: 4, iter 0
total inference time: 106.84455347061157 s; latency per token: 3.3388922959566116 tokens/s
Peak GPU Memory Usage: 26.923448085784912 GB
91315 
output length: 32, batch_size: 4, iter 1
total inference time: 107.92586183547974 s; latency per token: 3.3726831823587418 tokens/s
Peak GPU Memory Usage: 27.224334239959717 GB
91315 
output length: 32, batch_size: 4, iter 2
total inference time: 109.82498121261597 s; latency per token: 3.432030662894249 tokens/s
Peak GPU Memory Usage: 27.221885204315186 GB
91315 
output length: 32, batch_size: 4, iter 3
total inference time: 105.26572871208191 s; latency per token: 3.2895540222525597 tokens/s
Peak GPU Memory Usage: 27.22917127609253 GB
91315 
output length: 32, batch_size: 4, iter 4
total inference time: 110.35679316520691 s; latency per token: 3.448649786412716 tokens/s
Peak GPU Memory Usage: 27.228034496307373 GB
###AVERAGE
output length: 32, batch_size: 4
total inference time: 108.0435836791992

In [12]:

import json
import gc
import re
import os
from safetensors.torch import load_file, save_file
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass

save_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/test_dir/"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"
state_index_path = os.path.join(state_path, "model.safetensors.index.json")
save_index_path = os.path.join(save_path, "model.safetensors.index.json")
with open(save_index_path) as f:
    save_weight_map = json.load(f)["weight_map"]
with open(state_index_path) as f:
    state_weight_map = json.load(f)["weight_map"]

state_dict = {}
exclusion_pattern = re.compile(r"model\.layers\.\d+\.block_sparse_moe\.experts\.\d+\.(w1|w2|w3)\.weight")
unique_filenames = set(state_weight_map.values())
for filename in unique_filenames:
    file_path = os.path.join(state_path, filename)
    loaded_file = load_file(file_path, device=str(device))
    for key, file_in_map in state_weight_map.items():
        if filename == file_in_map and not exclusion_pattern.search(key):
            if key in loaded_file:
                # print(f"Adding tensor to state_dict: {key}")
                state_dict[key] = loaded_file[key]
            else:
                print(f"Expected tensor not found in safetensor file: {key}")
    peak_memory_usage = torch.cuda.max_memory_allocated() / (1024 ** 3) 
    print(f"Peak GPU Memory Usage: {peak_memory_usage} GB")  

print(state_dict.keys())
half_state_dict = {key: tensor.half() for key, tensor in state_dict.items()}
file_to_save = save_path + "model-00001-of-00257.safetensors" 
save_file(half_state_dict, file_to_save)

if not os.path.exists(save_index_path):
    state_index_data = {"weight_map": {}}
    print("???")
else:
    with open(save_index_path, 'r') as index_file:
        index_data = json.load(index_file)

for key in state_dict.keys():
    index_data['weight_map'][key] = "model-00001-of-00257.safetensors" 

print(index_data['weight_map'])
with open(save_index_path, 'w') as index_file:
    json.dump(index_data, index_file, indent=4)


Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB


Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
Peak GPU Memory Usage: 27.330001831054688 GB
dict_keys(['model.layers.1.input_layernorm.weight', 'model.layers.1.post_attention_layernorm.weight', 'model.layers.2.block_sparse_moe.gate.weight', 'model.layers.2.input_layernorm.weight', 'model.layers.2.post_attention_layernorm.weight', 'model.layers.2.self_attn.k_proj.weight'

In [8]:

import json
import gc
import re
import os
from safetensors.torch import load_file, save_file
try:
    del model
    gc.collect()
    torch.cuda.empty_cache()
except NameError:
    pass
device = torch.device("cuda:0")
save_path = "/scratch/bcjw/yyuan6/mistral-8x7b/mixtral-offloading/test_dir/"
state_path = "/scratch/bcjw/yyuan6/mistral-8x7b/Mixtral-8x7B-Instruct-v0.1"
index_path = os.path.join(save_path, "model.safetensors.index.json")
i = 1
if not os.path.exists(index_path):
    index_data = {"weight_map": {}}
else:
    with open(index_path, 'r') as index_file:
        index_data = json.load(index_file)

for layer_idx in range(32):  
    for expert_idx in range(8):
        index_path = os.path.join(state_path, "model.safetensors.index.json")
        with open(index_path) as f:
            module_idx = f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}"
            weight_map = json.load(f)["weight_map"]
            state_fpath = weight_map[f"{module_idx}.w1.weight"]
            state_fpath2 = weight_map[f"{module_idx}.w3.weight"]
        
        loaded_state_dict = load_file(os.path.join(state_path, state_fpath), device=str(device))
        state_dict = {}
        state_dict["w1.weight"] = loaded_state_dict[f'model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight'] 

        if f'model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight' not in loaded_state_dict:
            loaded_state_dict = load_file(os.path.join(state_path, state_fpath2), device=str(device))
        state_dict["w2.weight"] = loaded_state_dict[f'model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight'] 
        
        if f'model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight' not in loaded_state_dict:
            loaded_state_dict = load_file(os.path.join(state_path, state_fpath2), device=str(device))
        state_dict["w3.weight"] = loaded_state_dict[f'model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight'] 

        print(state_dict.keys())
        # half_state_dict = {key: tensor.half() for key, tensor in state_dict.items()}

        i = i + 1
        save_file_name = f"model-{i:05d}-of-00257.safetensors"
        # save_file(half_state_dict, save_path + save_file_name)

        for weight_key in ['w1.weight', 'w2.weight', 'w3.weight']:
            index_data['weight_map'][f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.{weight_key}"] = save_file_name

print(index_data['weight_map'])
with open(index_path, 'w') as index_file:
    json.dump(index_data, index_file, indent=4)

dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])


dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.weight', 'w3.weight'])
dict_keys(['w1.weight', 'w2.wei