# BitNet Model Implementation with Performance Optimizations

This notebook shows how to load and run the BitNet model (`microsoft/bitnet-b1.58-2B-4T`) on CPU with settings aimed at faster inference. It walks through each step, from setting up the environment to generating a response and measuring generation speed.

In [28]:
# Install required libraries
!pip install torch git+https://github.com/huggingface/transformers.git bitsandbytes

Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /private/var/folders/pt/hmd3x2w503g0phnxg6dz9s1h0000gn/T/pip-req-build-8rugimf1
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /private/var/folders/pt/hmd3x2w503g0phnxg6dz9s1h0000gn/T/pip-req-build-8rugimf1
  Resolved https://github.com/huggingface/transformers.git to commit 31f8a0fe8a7e2db1ee30bf32ed5976cd11f3283c
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


## 1. Set Environment Variables

First, set the environment variables that control tokenizer parallelism and the PyTorch memory allocator. Set them before the libraries that read them are imported.

In [29]:
import os
# Set tokenizer parallelism before importing tokenizer-related modules
os.environ["TOKENIZERS_PARALLELISM"] = "false"

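# Allocator hint. Note: max_split_size_mb is documented for the CUDA caching
# allocator (PYTORCH_CUDA_ALLOC_CONF); the CPU variant below may have no effect
os.environ['PYTORCH_CPU_ALLOC_CONF'] = 'max_split_size_mb:64'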
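
Hard-coding these variables overwrites anything already exported in the shell. The following is a minimal sketch of an alternative, assuming a hypothetical helper named `set_env_defaults` (it is not part of any library) that only fills in variables that are not yet defined:

```python
import os

def set_env_defaults(defaults):
    """Set each variable only if it is not already defined; return what was applied."""
    applied = {}
    for name, value in defaults.items():
        if name not in os.environ:
            os.environ[name] = value
            applied[name] = value
    return applied

# Call this before importing tokenizer-related modules
applied = set_env_defaults({"TOKENIZERS_PARALLELISM": "false"})
print(os.environ["TOKENIZERS_PARALLELISM"])
```

Calling the helper a second time with the same names applies nothing, so a value chosen by the user in the shell always wins.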

## 2. Import Required Libraries

Import the necessary Python libraries and enable PyTorch optimizations.

In [30]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitNetConfig
import torch.backends.cpu
import time
import tqdm as notebook_tqdm

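# Use all logical CPU cores for intra-op parallelism
# Note: torch.backends.cpu.optimize does not appear to be a documented PyTorch
# setting; the assignment below is kept from the original run but likely has no effect
torch.backends.cpu.optimize = True
torch.set_num_threads(os.cpu_count())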
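
`os.cpu_count()` can return `None` when the core count cannot be determined, and it reports logical rather than physical cores. The following is a small sketch of a more defensive way to choose the thread count; `pick_num_threads` is a hypothetical helper introduced here for illustration:

```python
import os

def pick_num_threads(requested=None):
    """Return a positive thread count, falling back to 1 if the CPU count is unknown."""
    available = os.cpu_count() or 1  # os.cpu_count() may return None
    if requested is None:
        return available
    # Clamp the request to the range [1, available]
    return max(1, min(requested, available))

num_threads = pick_num_threads()
print(num_threads)
# The result could then be passed to torch.set_num_threads(num_threads)
```

Passing a smaller `requested` value is worth trying if using every logical core turns out slower than using fewer threads on your machine.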

## 3. Load Tokenizer and Model

Load the BitNet tokenizer and, if it defines no padding token, fall back to the end-of-sequence token. The model itself is loaded in the next section.

In [31]:
model_id = "microsoft/bitnet-b1.58-2B-4T"

# Load tokenizer and set up padding token
tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

## 4. Configure BitNet Model

Set up the model configuration with various optimizations for improved inference performance.

In [32]:
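# Load the BitNet configuration and apply inference-oriented overrides
config = BitNetConfig.from_pretrained(model_id)
config.low_cpu_mem_usage = True  # Normally a from_pretrained() argument; likely ignored as a config attribute
config.pad_token_id = tokenizer.pad_token_id

# Performance-related overrides
# Note: attributes the model code does not read are stored on the config but have no effect
config.use_cache = True  # Enable KV-cache for faster inference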
config.pretraining_tp = 1  # Tensor parallelism degree
config.max_position_embeddings = 2048  # Match with your max_length
config.hidden_dropout_prob = 0  # Disable dropout for inference
config.attention_dropout_prob = 0  # Disable attention dropout for inference
config.use_memory_efficient_attention = True  # Use memory efficient attention
config.scale_attention_softmax_in_fp32 = False  # Keep in lower precision
config.use_flash_attention = True  # Enable flash attention if supported

# Load model with optimized configuration
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    config=config,
    torch_dtype=torch.bfloat16,  # Explicitly set dtype to avoid ValueError
    device_map='auto'  # Automatically handle device placement
)

# Move the model to CPU and compile it
model = model.cpu()
model = torch.compile(model)  # Compilation is lazy: it runs on the first forward pass
model.eval()  # Switch to evaluation mode (disables dropout)

OptimizedModule(
  (_orig_mod): BitNetForCausalLM(
    (model): BitNetModel(
      (embed_tokens): Embedding(128256, 2560, padding_idx=128009)
      (layers): ModuleList(
        (0-29): 30 x BitNetDecoderLayer(
          (self_attn): BitNetAttention(
            (q_proj): AutoBitLinear(in_features=2560, out_features=2560, bias=False)
            (k_proj): AutoBitLinear(in_features=2560, out_features=640, bias=False)
            (v_proj): AutoBitLinear(in_features=2560, out_features=640, bias=False)
            (o_proj): AutoBitLinear(in_features=2560, out_features=2560, bias=False)
            (attn_sub_norm): BitNetRMSNorm((2560,), eps=1e-05)
          )
          (mlp): BitNetMLP(
            (gate_proj): AutoBitLinear(in_features=2560, out_features=6912, bias=False)
            (up_proj): AutoBitLinear(in_features=2560, out_features=6912, bias=False)
            (down_proj): AutoBitLinear(in_features=6912, out_features=2560, bias=False)
            (act_fn): ReLUSquaredActivation()
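
Python accepts assignment of any attribute name on a config object, so an unsupported or misspelled option fails silently. The following is a minimal sketch of one way to spot such overrides by comparing against a freshly constructed instance. `ToyConfig` and its fields are stand-ins for illustration only and do not describe the real `BitNetConfig`:

```python
class ToyConfig:
    """Stand-in for a model config; the field names are illustrative only."""
    def __init__(self):
        self.use_cache = True
        self.attention_dropout = 0.0

def unrecognized_attributes(config, reference):
    """Return attribute names present on `config` but absent from `reference`."""
    return sorted(set(vars(config)) - set(vars(reference)))

toy = ToyConfig()
toy.use_cache = False            # known field: the override is meaningful
toy.use_flash_attention = True   # unknown field: stored, but nothing reads it

print(unrecognized_attributes(toy, ToyConfig()))
```

A name reported by this check is not necessarily useless, but it is a prompt to confirm in the model source that the attribute is actually read.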

## 5. Prepare Input Messages

Define the chat messages and apply the chat template.

In [33]:
# Batch size (not used below; this notebook processes a single prompt)
BATCH_SIZE = 1  # Adjust based on your CPU memory

# Prepare chat messages
messages = [
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Tell me about the latest advancements in AI."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

## 6. Tokenize Input

Tokenize the input messages and measure the tokenization time.

In [34]:
with torch.inference_mode():  # More efficient than no_grad for inference
    # Measure tokenization time
    tokenization_start = time.time()
    chat_input = tokenizer(
        prompt, 
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=2048  # Adjust based on your needs
    )
    tokenization_time = time.time() - tokenization_start
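
The same start/stop timing pattern recurs in the generation and decoding steps. The following is a sketch of a reusable timer, assuming a hypothetical context manager named `timed`; it uses `time.perf_counter()`, which is intended for measuring short intervals:

```python
import time
from contextlib import contextmanager

@contextmanager
def timed(label, results):
    """Record the wall-clock duration of the enclosed block in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Store the elapsed time even if the block raises
        results[label] = time.perf_counter() - start

timings = {}
with timed("sleep", timings):
    time.sleep(0.05)  # placeholder for tokenization, generation, or decoding
print(f"sleep: {timings['sleep']:.2f}s")
```

Collecting all durations in one dictionary keeps the later metrics calculation in a single place.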

## 7. Generate Response

Generate the model response using sampling (temperature, top-k, and top-p) with the KV-cache enabled.

In [35]:
with torch.inference_mode():
    # Measure generation time
    generation_start = time.time()
    chat_outputs = model.generate(
        **chat_input,
        max_new_tokens=500,
        num_return_sequences=1,
        do_sample=True,
        temperature=0.7,  # Below 1.0 sharpens the distribution (less random output)
        top_p=0.95,      # Nucleus sampling: keep the most likely tokens up to 0.95 cumulative probability
        top_k=40,        # Sample only from the 40 most likely tokens
        use_cache=True,  # Enable KV-cache
        num_beams=1,     # Disable beam search for faster generation
        early_stopping=True,  # Beam-search option; no effect with num_beams=1
        pad_token_id=tokenizer.pad_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        repetition_penalty=1.1,  # Light penalty to discourage repetition
        length_penalty=1.0,  # Beam-search option; no effect with num_beams=1
        no_repeat_ngram_size=3  # Prevent repetition of 3-grams
    )
    generation_time = time.time() - generation_start
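
To make the sampling parameters concrete, here is a simplified pure-Python sketch of how temperature, top-k, and top-p reshape a next-token distribution. It is an illustration only, not the Transformers implementation; `filter_probs` and the toy logits are assumptions made for this example:

```python
import math

def filter_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Return renormalized probabilities after temperature, top-k, and top-p filtering."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank token indices from most to least likely; top_k=0 disables the top-k cut
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]
    mass_k = sum(probs[i] for i in order)

    # Keep the smallest prefix whose cumulative (renormalized) probability reaches top_p
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i] / mass_k
        if cumulative >= top_p:
            break

    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

toy_logits = [2.0, 1.0, 0.5, -1.0]
print(filter_probs(toy_logits, temperature=0.7, top_k=3, top_p=0.95))
```

These filters change which tokens can be sampled and how likely each one is. They affect output quality and diversity rather than the cost of each decoding step.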



## 8. Decode and Display Output

Decode the generated tokens and display the model's response.

In [36]:
# Measure decoding time
decoding_start = time.time()
response = tokenizer.decode(
    chat_outputs[0][chat_input['input_ids'].shape[-1]:],
    skip_special_tokens=True
)
decoding_time = time.time() - decoding_start

print("\nAssistant Response:", response)


Assistant Response: As of my last update, here are some key areas where significant advancements have been made in AI:

1. Natural Language Processing (NLP): Recent advancements in NLP include improved sentiment analysis, better handling of context, and more accurate translation between languages.

2. Computer Vision: There has been progress in object recognition, facial recognition, and action recognition. Autonomous vehicles also rely heavily on computer vision for navigation.

3. Machine Learning Algorithms: Techniques like reinforcement learning and unsupervised learning have gained popularity. These allow AI systems to learn from their environment without explicit instructions.

4. Robotics: AI is being increasingly used in robotics, enabling machines to perform tasks independently or even autonomously.

5. Healthcare: AI applications range from drug discovery to personalized medicine. AI algorithms can analyze medical images more accurately than human doctors, making early detec
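
For decoder-only models, `generate()` returns the prompt tokens followed by the continuation, which is why the decoding step slices off the first `input_ids.shape[-1]` positions. The following is a tiny sketch of that slicing with plain lists; the token ids are made up for illustration:

```python
def strip_prompt(sequence, prompt_length):
    """Return only the newly generated token ids from a full output sequence."""
    return sequence[prompt_length:]

prompt_ids = [101, 7592, 2088]                # hypothetical prompt token ids
full_output = prompt_ids + [2023, 2003, 102]  # prompt followed by the continuation
print(strip_prompt(full_output, len(prompt_ids)))
```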

## 9. Performance Metrics

Calculate and display various performance metrics including processing times and generation speed.

In [37]:
# Calculate performance metrics
total_time = tokenization_time + generation_time + decoding_time
num_input_tokens = chat_input['input_ids'].shape[1]
num_output_tokens = chat_outputs[0].shape[0] - num_input_tokens
tokens_per_second = num_output_tokens / generation_time

print("\nPerformance Metrics:")
print(f"Tokenization time: {tokenization_time:.2f}s")
print(f"Generation time: {generation_time:.2f}s")
print(f"Decoding time: {decoding_time:.2f}s")
print(f"Total time: {total_time:.2f}s")
print(f"Input tokens: {num_input_tokens}")
print(f"Output tokens: {num_output_tokens}")
print(f"Generation speed: {tokens_per_second:.2f} tokens/second")


Performance Metrics:
Tokenization time: 0.01s
Generation time: 29.20s
Decoding time: 0.01s
Total time: 29.22s
Input tokens: 26
Output tokens: 354
Generation speed: 12.12 tokens/second
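
As a check on the arithmetic, the reported figures can be reproduced from the rounded values above. The function name `generation_metrics` is an assumption made for this sketch. Note that with `torch.compile`, the first `generate()` call may include one-off compilation time, so a second run could report a different speed:

```python
def generation_metrics(tokenization_s, generation_s, decoding_s, input_tokens, total_tokens):
    """Compute total time, output token count, and generation throughput."""
    output_tokens = total_tokens - input_tokens
    return {
        "total_time_s": tokenization_s + generation_s + decoding_s,
        "output_tokens": output_tokens,
        "tokens_per_second": output_tokens / generation_s,
    }

# Values from the run above: 26 input tokens + 354 generated tokens = 380 total
m = generation_metrics(0.01, 29.20, 0.01, input_tokens=26, total_tokens=380)
print(f"{m['total_time_s']:.2f}s, {m['output_tokens']} tokens, "
      f"{m['tokens_per_second']:.2f} tokens/second")
```

Throughput is computed against generation time only, since tokenization and decoding contribute about 0.02 s of the 29.22 s total.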
