# Mixtral in Colab

Welcome! In this notebook you can run [Mixtral-8x7B-Instruct](https://huggingface.co/mistralai/Mixtral-8x7B-v0.1) with decent generation speed **right in Google Colab or on a consumer-grade GPU**. This is made possible by quantizing the original model in mixed precision and implementing a MoE-specific offloading strategy.

To learn more, read our [tech report](https://arxiv.org/abs/2312.17238) or check out the [repo](https://github.com/dvmazur/mixtral-offloading) on GitHub.

You will need approximately 16 GB of GPU VRAM and 11 GB of RAM to run this notebook and generate reasonably long texts.


<details>

<summary>How to balance between RAM and GPU VRAM usage</summary>

You can balance between RAM and GPU VRAM usage by changing the <code>offload_per_layer</code> variable in the <a href="#scrollTo=_mIpePTMFyRY&line=10&uniqifier=1">Initialize model</a> section. Increasing <code>offload_per_layer</code> decreases GPU VRAM usage, increases RAM usage and decreases generation speed. Decreasing <code>offload_per_layer</code> has the opposite effect.

Note that this notebook should run normally in Google Colab with <code>offload_per_layer = 4</code>, but may crash with other values. If you run it somewhere else, you are free to experiment with this variable.
</details>
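
As a rough illustration of this trade-off, the sketch below computes how many experts stay on the GPU and how many are offloaded to RAM for a given `offload_per_layer`. It uses the same arithmetic as the `OffloadConfig` call in the Initialize model section. The helper name `split_experts` is hypothetical, and the defaults (32 layers, 8 experts per layer) are assumed from Mixtral-8x7B's configuration.

```python
def split_experts(offload_per_layer, num_layers=32, num_experts=8):
    """Return (experts kept on GPU, experts offloaded to RAM) for the whole model."""
    if not 0 <= offload_per_layer <= num_experts:
        raise ValueError("offload_per_layer must be between 0 and num_experts")
    on_gpu = num_layers * (num_experts - offload_per_layer)
    offloaded = num_layers * offload_per_layer
    return on_gpu, offloaded


for k in (3, 4, 5):
    on_gpu, offloaded = split_experts(k)
    print(f"offload_per_layer={k}: {on_gpu} experts on GPU, {offloaded} in RAM")
```

Each step of `offload_per_layer` moves one expert per layer, 32 experts in total, between the GPU and RAM.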

## Install and import libraries

In [1]:
from IPython.display import clear_output

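# Fix Triton in Colab. Each `!` command runs in its own subshell, so `!export`
# would not persist; set the variables on the notebook process instead.
import os
os.environ["LC_ALL"] = "en_US.UTF-8"
os.environ["LD_LIBRARY_PATH"] = "/usr/lib64-nvidia"
os.environ["LIBRARY_PATH"] = "/usr/local/cuda/lib64/stubs"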
!ldconfig /usr/lib64-nvidia

!git clone https://github.com/type-shangshu/mixtral-offloading.git --quiet
!cd mixtral-offloading && pip install -q -r requirements.txt
!huggingface-cli download lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo --quiet --local-dir Mixtral-8x7B-Instruct-v0.1-offloading-demo

clear_output()

In [2]:
import sys
import os

# Get the absolute path to the project root directory.
# Use '..' instead of 'mixtral-offloading' if you cloned the repo and run this notebook locally.
project_root = os.path.abspath(os.path.join(os.getcwd(), 'mixtral-offloading'))
print(f"Project root: {project_root}")

# Add the project root to Python path
sys.path.insert(0, project_root)

# Now import other dependencies
import torch
from torch.nn import functional as F
from hqq.core.quantize import BaseQuantizeConfig
from huggingface_hub import snapshot_download
from IPython.display import clear_output
from tqdm.auto import trange
from transformers import AutoConfig, AutoTokenizer
from transformers.utils import logging as hf_logging

# Import project specific modules
from src.build_model import OffloadConfig, QuantConfig, build_model

Project root: /home/ljw/mixtral-offloading/notebooks/mixtral-offloading
hqq_aten package not installed. HQQBackend.ATEN backend will not work unless you install the hqq_aten lib in hqq/kernels.




## Log in to Hugging Face

In [None]:
from huggingface_hub import notebook_login

notebook_login()

## Initialize model

In [None]:
# Initialize model parameters
model_name = "mistralai/Mixtral-8x7B-Instruct-v0.1"
quantized_model_name = "lavawolfiee/Mixtral-8x7B-Instruct-v0.1-offloading-demo"
state_path = "Mixtral-8x7B-Instruct-v0.1-offloading-demo"

config = AutoConfig.from_pretrained(quantized_model_name)

device = torch.device("cuda:0")

##### Use 4 on Colab (about 16 GB of GPU VRAM); use 5 if you have only 12 GB #####
# offload_per_layer = 4
offload_per_layer = 5
###############################################################

num_experts = config.num_local_experts

##### Choose the initial expert cache strategy #####
cache_strategy = "random"  # Options: "lru", "lfu", "random"

offload_config = OffloadConfig(
    main_size=config.num_hidden_layers * (num_experts - offload_per_layer),
    offload_size=config.num_hidden_layers * offload_per_layer,
    buffer_size=4,
    offload_per_layer=offload_per_layer,
    cache_strategy=cache_strategy,
)


attn_config = BaseQuantizeConfig(
    nbits=4,
    group_size=64,
    quant_zero=True,
    quant_scale=True,
)
attn_config["scale_quant_params"]["group_size"] = 256


ffn_config = BaseQuantizeConfig(
    nbits=2,
    group_size=16,
    quant_zero=True,
    quant_scale=True,
)
quant_config = QuantConfig(ffn_config=ffn_config, attn_config=attn_config)


model = build_model(
    device=device,
    quant_config=quant_config,
    offload_config=offload_config,
    state_path=state_path,
)

Loading experts:   0%|          | 0/32 [00:01<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 14.00 MiB. GPU 0 has a total capacity of 5.61 GiB of which 22.62 MiB is free. Including non-PyTorch memory, this process has 5.08 GiB memory in use. Of the allocated memory 4.38 GiB is allocated by PyTorch, and 647.38 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
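
The error above was produced on a GPU with 5.61 GiB of memory, well below the 12 to 16 GB this notebook targets. A back-of-the-envelope lower bound suggests why. The sketch below counts only the expert weights kept on the GPU at 2 bits per weight. It ignores quantization metadata (scales and zero points), attention layers, embeddings, copy buffers and the KV cache, so real usage is higher. The helper name is hypothetical, and the shapes (hidden size 4096, FFN size 14336, three weight matrices per expert) are assumed from Mixtral-8x7B's configuration.

```python
def expert_weights_gib(offload_per_layer, bits=2, num_layers=32, num_experts=8,
                       hidden_size=4096, ffn_size=14336):
    """Lower bound, in GiB, on GPU memory for the experts that stay on the GPU."""
    params_per_expert = 3 * hidden_size * ffn_size   # three projection matrices
    bytes_per_expert = params_per_expert * bits / 8  # weights only, no metadata
    experts_on_gpu = num_layers * (num_experts - offload_per_layer)
    return experts_on_gpu * bytes_per_expert / 2**30


for k in (4, 5):
    print(f"offload_per_layer={k}: at least {expert_weights_gib(k):.2f} GiB of expert weights on GPU")
```

Under these assumptions, even `offload_per_layer = 5` needs close to 4 GiB for expert weights alone, which leaves little room for everything else on a 5.61 GiB card. On such a GPU, a larger `offload_per_layer` is the setting to try, at the cost of more RAM and slower generation.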

## Specify cache strategy

In [3]:
new_strategy = "lru"  # Options: "lru", "lfu", "random"; applied by switch_cache_strategy() in the next cell
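
To build intuition for how the three options differ, here is a toy simulation of a fixed-size expert cache with LRU, LFU and random eviction. It is a minimal sketch of the general eviction policies, not the repository's `expert_cache` implementation. The function name `simulate` and the request trace are made up for illustration.

```python
import random
from collections import Counter, OrderedDict


def simulate(strategy, capacity, requests, seed=0):
    """Return (hits, misses) for a toy cache holding `capacity` experts."""
    rng = random.Random(seed)
    resident = OrderedDict()  # expert id -> None, least recently used first
    freq = Counter()          # total accesses per expert, used by "lfu"
    hits = misses = 0
    for expert in requests:
        freq[expert] += 1
        if expert in resident:
            hits += 1
            resident.move_to_end(expert)  # mark as most recently used
            continue
        misses += 1
        if len(resident) >= capacity:
            if strategy == "lru":
                victim = next(iter(resident))  # least recently used
            elif strategy == "lfu":
                victim = min(resident, key=lambda e: freq[e])  # least frequently used
            elif strategy == "random":
                victim = rng.choice(list(resident))
            else:
                raise ValueError(f"unknown strategy: {strategy}")
            del resident[victim]
        resident[expert] = None
    return hits, misses


trace = [0, 0, 0, 1, 2, 0, 1, 2]  # hypothetical sequence of expert ids
for name in ("lru", "lfu", "random"):
    hits, misses = simulate(name, capacity=2, requests=trace)
    print(f"{name}: {hits} hits, {misses} misses")
```

On a trace like this one, where a single expert is requested far more often than the others, LFU tends to keep that expert resident while LRU may evict it after a short burst of other requests. Which strategy wins in practice depends on the model's actual routing pattern.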

## Run the model

In [None]:
from transformers import AutoTokenizer, TextStreamer
import time
from collections import Counter

tokenizer = AutoTokenizer.from_pretrained(model_name)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

chat_history = []

def switch_cache_strategy(model, new_strategy: str):
    """Switch cache strategy during inference"""
    for layer in model.model.layers:
        if hasattr(layer, 'block_sparse_moe'):
            if hasattr(layer.block_sparse_moe, 'expert_cache'):
                layer.block_sparse_moe.expert_cache.switch_cache_strategy(new_strategy)
    print(f"Cache strategy switched to: {new_strategy}")

switch_cache_strategy(model, new_strategy)

def get_cache_stats(model):
    """Collect cache statistics across all layers"""
    total_hits = 0
    total_misses = 0
    total_swaps = 0
    active_experts = Counter()
    
    for i, layer in enumerate(model.model.layers):
        if hasattr(layer, 'block_sparse_moe'):
            cache = layer.block_sparse_moe.expert_cache
            for group in cache.group_infos.values():
                total_hits += group.hits
                total_misses += group.misses
            active_experts.update(cache.active_experts)
            total_swaps += getattr(cache, 'swap_count', 0)
    
    return {
        'hits': total_hits,
        'misses': total_misses,
        'hit_rate': total_hits / (total_hits + total_misses) if (total_hits + total_misses) > 0 else 0,
        'swaps': total_swaps,
        'active_experts': dict(active_experts.most_common())
    }

def print_performance_stats(start_time, end_time, total_tokens, cache_stats):
    duration = end_time - start_time
    throughput = total_tokens / duration
    
    print("\n" + "="*50)
    print("Performance Statistics:")
    print(f"Time elapsed: {duration:.2f}s")
    print(f"Total tokens: {total_tokens}")
    print(f"Throughput: {throughput:.2f} tokens/s")
    
    print("\nCache Statistics:")
    print(f"Cache hits: {cache_stats['hits']}")
    print(f"Cache misses: {cache_stats['misses']}")
    print(f"Hit rate: {cache_stats['hit_rate']*100:.2f}%")
    print(f"Total swaps: {cache_stats['swaps']}")
    
    print("\nMost Active Experts:")
    for expert_id, count in list(cache_stats['active_experts'].items())[:5]:
        print(f"Expert {expert_id}: {count} times")
    
    # Performance suggestions
    print("\nPerformance Suggestions:")
    if cache_stats['hit_rate'] < 0.8:
        print("- Consider increasing cache size - low hit rate detected")
    if throughput < 10:
        print("- Consider reducing offload_per_layer to improve throughput")
    if cache_stats['swaps'] > cache_stats['hits'] * 0.5:
        print("- High swap rate detected - consider adjusting cache strategy")
    print("="*50 + "\n")

while True:
    try:
        print("User: ", end="")
        user_input = input()
        if not user_input:
            continue
            
        chat_history.append(dict(role="user", content=user_input))
        input_ids = tokenizer.apply_chat_template(chat_history, return_tensors="pt").to(device)
        attention_mask = torch.ones_like(input_ids)
        
        start_time = time.time()
        print("Mixtral: ", end="")
        result = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            streamer=streamer,
            do_sample=True,
            temperature=0.9,
            top_p=0.9,
            max_new_tokens=512,
            pad_token_id=tokenizer.eos_token_id,
            return_dict_in_generate=True,
            output_hidden_states=False,
        )
        end_time = time.time()
        print("\n")
        
        output_ids = result["sequences"][0]
        reply = tokenizer.decode(output_ids[input_ids.shape[1]:], skip_special_tokens=True)
        chat_history.append(dict(role="assistant", content=reply))
        
        # Count the number of generated tokens
        generated_tokens = len(output_ids) - len(input_ids[0])
        # Collect and print the statistics
        cache_stats = get_cache_stats(model)
        print_performance_stats(start_time, end_time, generated_tokens, cache_stats)
        
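    except KeyboardInterrupt:
        print("\nInterrupted by user")
        # Token counts and timing for an interrupted turn are incomplete (and
        # undefined if no turn has finished), so report cache statistics only.
        cache_stats = get_cache_stats(model)
        print(f"Cache hit rate: {cache_stats['hit_rate']*100:.2f}% "
              f"({cache_stats['hits']} hits, {cache_stats['misses']} misses)")
        break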
    except Exception as e:
        print(f"\nError occurred: {str(e)}")
        break
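
The tuning hints printed by `print_performance_stats` depend on three thresholds: a hit rate below 0.8, a throughput below 10 tokens/s, and more swaps than half the number of hits. The sketch below isolates that logic in a small function so it can be checked without loading the model. The function name `performance_suggestions` and the sample numbers are hypothetical.

```python
def performance_suggestions(hit_rate, throughput, swaps, hits):
    """Return the hints triggered by the thresholds used in print_performance_stats."""
    suggestions = []
    if hit_rate < 0.8:
        suggestions.append("Consider increasing cache size - low hit rate detected")
    if throughput < 10:
        suggestions.append("Consider reducing offload_per_layer to improve throughput")
    if swaps > hits * 0.5:
        suggestions.append("High swap rate detected - consider adjusting cache strategy")
    return suggestions


# Made-up numbers: 900 hits, 300 misses, 2.5 tokens/s, 100 swaps.
hits, misses = 900, 300
for hint in performance_suggestions(hits / (hits + misses), 2.5, swaps=100, hits=hits):
    print("-", hint)
```

With these sample numbers, the hit rate (0.75) and the throughput both fall below their thresholds, while the swap count does not exceed half the hits.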