<a href="https://colab.research.google.com/github/rajveer43/medusa_task/blob/master/Medusa_1__Preparing_GGUF_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

The task is to implement a FastAPI service that serves a large language model (LLM) with a [Medusa](https://github.com/FasterDecoding/Medusa) head (using `lmsys/vicuna-7b`). The goal is to optimize inference speed using a model compilation library (e.g., `llama.cpp`) and to enhance performance via speculative decoding with the Medusa head. Additionally, you are required to implement dynamic batching to handle multiple concurrent requests efficiently.

### **Key Deliverables:**

1. **Model Compilation:**
    - Use a model compilation library (e.g., [llama.cpp](https://github.com/ggerganov/llama.cpp)) to optimize the inference of the base model.
    - Provide an explanation of your choice of compilation library and its impact on performance.
2. **Medusa Head Implementation:**
    - Implement the medusa head on top of the base model to improve performance via speculative decoding. Avoid using existing implementations.
    - Include a brief explanation of how speculative decoding is implemented and its advantages.
3. **Dynamic Batching:**
    - Implement dynamic batching to efficiently manage multiple concurrent requests.
    - Explain your approach to dynamic batching and its benefits in serving LLMs.
4. **Service Implementation:**
    - Use [FastAPI](https://fastapi.tiangolo.com/) to create a service that serves the LLM with the medusa head.
    - Ensure the service can handle concurrent requests with low latency.
5. **Testing & Validation:**
    - Provide test cases to validate the correctness and efficiency of your implementation.
    - Include performance benchmarks or metrics comparing different configurations (e.g., with and without the medusa head, with and without dynamic batching).

### **Grading Criteria:**

1. **Correctness (40%):**
    - Functional service that correctly serves the LLM.
    - Proper implementation of the medusa head with enhanced performance.
2. **Optimization & Performance (30%):**
    - Effective use of the model compilation library for inference optimization.
    - Performance improvement through speculative decoding with the medusa head.
    - Efficient handling of requests with dynamic batching.
3. **Code Quality & Documentation (20%):**
    - Clean, readable, and maintainable code.
    - Clear and concise documentation explaining implementation choices.
4. **Testing & Validation (10%):**
    - Comprehensive test cases covering key functionalities.
    - Inclusion of performance benchmarks or metrics to demonstrate optimizations.

### **Partial Credit:**

Partial implementations will still be evaluated based on relevant criteria. For instance:

- **Model Optimization Only:** Focusing on base model optimization without medusa head or dynamic batching.
- **Medusa Head Implementation:** Implementing speculative decoding without dynamic batching.
- **Dynamic Batching:** Focusing on request handling efficiency without medusa head.

---

### **Additional Notes:**

- **Free GPU Access:** If you need access to GPUs, consider using services like [Google Colab](https://colab.research.google.com/) or [Kaggle Notebooks](https://www.kaggle.com/kernels), which provide free access to GPU resources.
- **Submission:** Please submit your code, along with a brief report (Markdown or PDF) explaining your implementation, testing, and any performance metrics.

This assignment is designed to test your understanding of model optimization, complex inference strategies, and the ability to build scalable services. Partial implementations are welcome and will be graded accordingly.

### **Medusa Paper (2024) - Brief Summary**

The *Medusa* paper introduces a novel framework to enhance the decoding speed of large language models (LLMs) without compromising output quality. Traditionally, decoding (predicting the next token step-by-step) is the bottleneck in LLMs. Medusa tackles this by allowing the model to predict **multiple future tokens in parallel**, not just one at a time.

It does this by attaching lightweight "Medusa heads" (small prediction modules) on top of the base model's last hidden state. Each head predicts the token at a different future position, so together they propose the next **M** tokens in one shot. A verification step then checks these proposed tokens against the base model in a single forward pass. The longest prefix that passes is accepted; the base model's own next-token prediction is always kept, so each step yields at least one token, just as in standard autoregressive decoding.
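A single Medusa head, as the paper describes it, is a residual feed-forward block on the base model's last hidden state followed by a projection to the vocabulary. The pure-Python sketch below illustrates that computation only; the toy dimensions, the weight values, and the helper names (`matvec`, `medusa_head_logits`) are assumptions for illustration and are not taken from this notebook or the Medusa repository.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def silu(x: float) -> float:
    """SiLU activation: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def medusa_head_logits(hidden: Vector, w1: Matrix, w2: Matrix) -> Vector:
    """Compute logits = W2 @ (SiLU(W1 @ h) + h) for one head."""
    residual = [silu(a) + h for a, h in zip(matvec(w1, hidden), hidden)]
    return matvec(w2, residual)


def argmax(v: Vector) -> int:
    return max(range(len(v)), key=v.__getitem__)


# Toy setup (hypothetical): hidden size 3, vocabulary size 4.
hidden = [0.5, -1.0, 2.0]
w1_zero = [[0.0] * 3 for _ in range(3)]  # zero-initialised residual weights
lm_head = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 1.0, 1.0]]

head_logits = medusa_head_logits(hidden, w1_zero, lm_head)
guess = argmax(head_logits)  # index of the token this head would propose
```

With `W1` set to zero the residual block is the identity, so the head starts out reproducing the projection it was initialised from; training then moves each head toward predicting its own future position.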

In simple terms:  
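Medusa is like giving the model a shortcut to "guess" multiple words ahead and then double-checking its guesses. The paper reports speed-ups of **over 2.2x** for Medusa-1 (frozen base model) and **2.3x to 3.6x** for Medusa-2 (heads trained jointly with the base model) while maintaining generation quality.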

---

### **Speculative Decoding - Brief Explanation**

*Speculative Decoding* is a general strategy to speed up autoregressive generation (token-by-token prediction) in LLMs. Instead of having the large model produce one token per forward pass, a smaller, faster model (called the draft model) proposes **several future tokens**, which the large model then checks together in one parallel forward pass.

Here's how it works:
1. The draft model quickly predicts several tokens ahead, one after another.
2. The main (larger and more accurate) model scores all of these positions in a single forward pass.
3. Draft tokens are accepted up to the first mismatch; at that position the main model's own token is used instead, and drafting resumes from there.

By letting a lightweight model "speculate" multiple tokens and only asking the big model to verify, this method can significantly **reduce the number of expensive forward passes** through the large model.
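The loop can be sketched with toy stand-ins for the two models. This is a minimal greedy version under stated assumptions: `toy_draft` and `toy_target` are hypothetical next-token functions, not real models, and the verification loop emulates what a real target model would compute in one batched forward pass.

```python
from typing import Callable, List, Sequence, Tuple

NextToken = Callable[[Sequence[int]], int]


def speculative_decode(
    prefix: Sequence[int],
    draft_next: NextToken,
    target_next: NextToken,
    k: int,
    max_new: int,
) -> Tuple[List[int], int]:
    """Greedy speculative decoding; returns (tokens, number of target passes)."""
    seq = list(prefix)
    target_passes = 0
    while len(seq) - len(prefix) < max_new:
        # 1. The draft model proposes k tokens autoregressively.
        draft: List[int] = []
        for _ in range(k):
            draft.append(draft_next(seq + draft))
        # 2. The target model checks every drafted position (counted as one pass).
        target_passes += 1
        accepted: List[int] = []
        for i, token in enumerate(draft):
            expected = target_next(seq + draft[:i])
            accepted.append(expected)  # equals the draft token when they agree
            if token != expected:
                break  # first mismatch: keep the correction, drop the rest
        else:
            # Every draft token matched, so the target's next token comes for free.
            accepted.append(target_next(seq + draft))
        seq.extend(accepted)
    return seq[: len(prefix) + max_new], target_passes


def toy_target(seq: Sequence[int]) -> int:
    return (seq[-1] + 1) % 10


def toy_draft(seq: Sequence[int]) -> int:
    # Agrees with the target except right after token 4.
    return 0 if seq[-1] == 4 else (seq[-1] + 1) % 10


output, passes = speculative_decode([0], toy_draft, toy_target, k=3, max_new=12)
```

Because every emitted token is either confirmed or supplied by the target, the output matches plain greedy decoding with the target alone, while the target is invoked far fewer times than once per token.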

# Install Necessary Libraries

In [None]:
!apt update && apt install -y cmake git curl

Get:1 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Get:2 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,632 B]
Get:3 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:8 http://security.ubuntu.com/ubuntu jammy-security/multiverse amd64 Packages [47.7 kB]
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:10 http://security.ubuntu.com/ubuntu jammy-security/restricted amd64 Packages [3,892 kB]
Get:11 http://archive.ubuntu.com/ubuntu jammy-backports InRelease [127 kB]
Hit:12 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:13 http://security.ubuntu.com/ubuntu jammy-

In [None]:
!pip install pyngrok  # Required to expose FastAPI publicly

Collecting pyngrok
  Downloading pyngrok-7.2.3-py3-none-any.whl.metadata (8.7 kB)
Downloading pyngrok-7.2.3-py3-none-any.whl (23 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.3


In [None]:
!pip install fastapi uvicorn transformers huggingface_hub torch accelerate

Collecting fastapi
  Downloading fastapi-0.115.12-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting starlette<0.47.0,>=0.40.0 (from fastapi)
  Downloading starlette-0.46.1-py3-none-any.whl.metadata (6.2 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-no

In [None]:
!git clone https://github.com/FasterDecoding/Medusa.git
%cd Medusa
!pip install -e .

Cloning into 'Medusa'...
remote: Enumerating objects: 353, done.
remote: Counting objects: 100% (162/162), done.
remote: Compressing objects: 100% (63/63), done.
remote: Total 353 (delta 118), reused 99 (delta 99), pack-reused 191 (from 1)
Receiving objects: 100% (353/353), 4.89 MiB | 22.85 MiB/s, done.
Resolving deltas: 100% (193/193), done.
/content/Medusa
Obtaining file:///content/Medusa
  Installing build dependencies ... done
  Checking if build backend supports build_editable ... done
  Getting requirements to build editable ... done
  Preparing editable metadata (pyproject.toml) ... done
Collecting fschat (from medusa-llm==1.0)
  Downloading fschat-0.2.36-py3-none-any.whl.metadata (20 kB)
Collecting markdown2[all] (from fschat->medusa-llm==1.0)
  Downloading markdown2-2.5.3-py3-none-any.whl.metadata (2.1 kB)
Collecting nh3 (from fschat->medusa-llm==1.0)
  Downloading nh3-0.2.21-cp38-abi3-manylinux_2_17_x86_64.manylinux2

In [None]:
%cd ..

/content


In [None]:
!pip install pyngrok sentencepiece
# !pip install medusa-llm



In [None]:
!git clone https://github.com/ggerganov/llama.cpp.git
%cd llama.cpp


Cloning into 'llama.cpp'...
remote: Enumerating objects: 47216, done.
remote: Counting objects: 100% (154/154), done.
remote: Compressing objects: 100% (90/90), done.
remote: Total 47216 (delta 119), reused 66 (delta 64), pack-reused 47062 (from 4)
Receiving objects: 100% (47216/47216), 99.50 MiB | 12.52 MiB/s, done.
Resolving deltas: 100% (33891/33891), done.
/content/llama.cpp


## Build llama.cpp

In [None]:
!cmake -B build
!cmake --build build --config Release -j 8

-- The C compiler identification is GNU 11.4.0
-- The CXX compiler identification is GNU 11.4.0
-- Detecting C compiler ABI info
-- Detecting C compiler ABI info - done
-- Check for working C compiler: /usr/bin/cc - skipped
-- Detecting C compile features
-- Detecting C compile features - done
-- Detecting CXX compiler ABI info
-- Detecting CXX compiler ABI info - done
-- Check for working CXX compiler: /usr/bin/c++ - skipped
-- Detecting CXX compile features
-- Detecting CXX compile features - done
-- Found Git: /usr/bin/git (found version "2.34.1")
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD
-- Performing Test CMAKE_HAVE_LIBC_PTHREAD - Success
-- Found Threads: TRUE
-- CMAKE_SYSTEM_PROCESSOR: x86_64
-- Including CPU backend
-- Found OpenMP_C: -fopenmp (found version "4.5")
-- Found OpenMP_CXX: -fopenmp (found version "4.5")
-- Found OpenMP: TRUE (found version "4.5")
-- x86 detected
-- Adding CPU backend variant ggml-cpu: -march=native 
-- Configuring done (4.1s)
-- Generating done (0

In [None]:
from huggingface_hub import notebook_login
notebook_login()



## Clone Model from HF

In [None]:
!git clone https://huggingface.co/lmsys/vicuna-7b-v1.3


Cloning into 'vicuna-7b-v1.3'...
remote: Enumerating objects: 43, done.
remote: Total 43 (delta 0), reused 0 (delta 0), pack-reused 43 (from 1)
Unpacking objects: 100% (43/43), 6.81 KiB | 435.00 KiB/s, done.
Filtering content: 100% (3/3), 4.55 GiB | 5.54 MiB/s, done.
Encountered 1 file(s) that may not have been copied correctly on Windows:
	pytorch_model-00001-of-00002.bin

See: `git lfs help smudge` for more details.


In [None]:
!cd .. && mkdir optimized_model

In [None]:
!pwd

## Convert to GGUF

In [None]:
!python convert_hf_to_gguf.py vicuna-7b-v1.3 --outtype f16 --outfile ../optimized_model/vicuna-7b-v1.3-F16.gguf


INFO:hf-to-gguf:Loading model: vicuna-7b-v1.3
INFO:gguf.gguf_writer:gguf: This GGUF file is for Little Endian only
INFO:hf-to-gguf:Exporting model...
INFO:hf-to-gguf:gguf: loading model weight map from 'pytorch_model.bin.index.json'
INFO:hf-to-gguf:gguf: loading model part 'pytorch_model-00001-of-00002.bin'
INFO:hf-to-gguf:token_embd.weight,           torch.float16 --> F16, shape = {4096, 32000}
INFO:hf-to-gguf:blk.0.attn_q.weight,         torch.float16 --> F16, shape = {4096, 4096}
INFO:hf-to-gguf:blk.0.attn_k.weight,         torch.float16 --> F16, shape = {4096, 4096}
INFO:hf-to-gguf:blk.0.attn_v.weight,         torch.float16 --> F16, shape = {4096, 4096}
INFO:hf-to-gguf:blk.0.attn_output.weight,    torch.float16 --> F16, shape = {4096, 4096}
INFO:hf-to-gguf:blk.0.ffn_gate.weight,       torch.float16 --> F16, shape = {4096, 11008}
INFO:hf-to-gguf:blk.0.ffn_down.weight,       torch.float16 --> F16, shape = {11008, 4096}
INFO:hf-to-gguf:blk.0.ffn_up.weight,         torch.float16 --> F1

In [None]:
# `!cd` runs in a subshell and does not persist; `%cd` changes the notebook's working directory
%cd ..

## Quantize using llama.cpp

In [None]:
!cd /content/llama.cpp/build/bin && ./llama-quantize /content/optimized_model/vicuna-7b-v1.3-F16.gguf /content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf q4_K_M

main: build = 4974 (029c693f)
main: built with cc (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0 for x86_64-linux-gnu
main: quantizing '/content/optimized_model/vicuna-7b-v1.3-F16.gguf' to '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf' as Q4_K_M
llama_model_loader: loaded meta data with 26 key-value pairs and 291 tensors from /content/optimized_model/vicuna-7b-v1.3-F16.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Vicuna 7b v1.3
llama_model_loader: - kv   3:                            general.version str              = v1.3
llama_model_loader: - kv   4:                           general.basename str              = vicuna
l


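# Model Size Reduction: From 14GB to 4GB with GGUF Conversion and Quantization

The reduction in model size from 14GB to 4GB is a significant optimization that is worth understanding in detail. It comes from the Q4_K_M quantization step; converting to GGUF at F16 changes the container format but not the precision of the weights.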

## Original Model Format vs GGUF

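The original 14GB model is the Hugging Face repository cloned above. As the conversion log shows, its weights are sharded PyTorch files (`pytorch_model-00001-of-00002.bin`, ...) stored as `torch.float16`. Checkpoints of this kind are commonly distributed in one of these formats:
- **PyTorch format** (.pt/.pth/.bin) - Pickle-based; typically uses FP16 (16-bit) or FP32 (32-bit) floating point precision
- **Hugging Face format** - A directory holding the config, tokenizer, and weight files in PyTorch or Safetensors form
- **Safetensors format** - A safer alternative to PyTorch's pickle-based format

These formats store weights in 16- or 32-bit floating point, prioritizing accuracy over size.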

## The GGUF Format

GGUF (GPT-Generated Unified Format) is the successor to GGML, designed specifically for efficient inference of large language models. Key aspects:

- **Improved architecture**: Better organized metadata and weight layout compared to the older GGML format
- **Self-contained**: Includes tokenizer data, model parameters, and quantization information
- **Optimized memory layout**: Designed for faster loading and reduced memory fragmentation
- **Cross-platform compatibility**: Works across different hardware architectures
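The self-contained layout can be inspected directly: a GGUF file begins with a small fixed header. The sketch below parses it with `struct`, assuming the little-endian layout of GGUF version 2 or later (4-byte magic, 32-bit version, 64-bit tensor count, 64-bit metadata key-value count). `parse_gguf_header` is a hypothetical helper, and the sample bytes are synthetic, built to mirror the values in the loader log (version 3, 291 tensors, 26 key-value pairs).

```python
import struct
from typing import Dict

GGUF_HEADER = "<4sIQQ"  # magic, version, tensor_count, metadata_kv_count


def parse_gguf_header(data: bytes) -> Dict[str, int]:
    """Parse the fixed-size GGUF header from the first bytes of a file."""
    magic, version, tensor_count, kv_count = struct.unpack_from(GGUF_HEADER, data, 0)
    if magic != b"GGUF":
        raise ValueError("not a GGUF file")
    return {
        "version": version,
        "tensor_count": tensor_count,
        "metadata_kv_count": kv_count,
    }


# Synthetic header for illustration; not read from the real model file.
sample_header = struct.pack(GGUF_HEADER, b"GGUF", 3, 291, 26)
header_info = parse_gguf_header(sample_header)
```

On a real file, the same function could be applied to the first `struct.calcsize(GGUF_HEADER)` bytes read in binary mode.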

## Quantization Process

The size reduction from 14GB to 4GB (approximately 71% reduction) comes from the `llama-quantize` step, not from the GGUF conversion itself; the intermediate F16 GGUF is about as large as the original checkpoint. It was achieved through:

1. **Precision reduction**: Converting the FP16 weights to a more compressed numerical representation

2. **Weight quantization**: The command above selects `q4_K_M`. For comparison, common llama.cpp quantization types include:
   - **8-bit quantization** (Q8_0): Each weight stored in roughly 8 bits instead of 16
   - **4-bit quantization** (Q4_K_M): Roughly 4 to 5 bits per weight on average; this is the type used here
   - **Mixed precision**: K-quant variants keep some tensors at higher precision while others use lower precision

3. **KV cache is unaffected**: `llama-quantize` quantizes only the model weights; the precision of the key-value cache is a separate runtime setting
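The arithmetic behind these sizes can be checked with a short calculation. This sketch uses the rounded file sizes quoted in this section and assumes a parameter count of about 6.74 billion for a LLaMA-7B-class model; both are approximations, so the results are only indicative.

```python
def bits_per_weight(file_size_bytes: float, n_params: float) -> float:
    """Average number of bits stored per model parameter."""
    return file_size_bytes * 8 / n_params


N_PARAMS = 6.74e9  # assumption: approximate parameter count of a 7B LLaMA-style model

fp16_bpw = bits_per_weight(14e9, N_PARAMS)  # rounded size of the original checkpoint
q4_bpw = bits_per_weight(4e9, N_PARAMS)     # rounded size of the quantized GGUF
reduction = 1 - 4e9 / 14e9

print(f"FP16: {fp16_bpw:.1f} bits/weight")
print(f"Quantized: {q4_bpw:.1f} bits/weight")
print(f"Size reduction: {reduction:.0%}")
```

The quantized figure lands above 4 bits because blockwise formats also store per-block scaling data and, in mixed variants, keep some tensors at higher precision.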

## Technical Implementation


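The "F16_KM" suffix in the output filename is a label chosen in this notebook, not a llama.cpp type name; the quantization type actually applied is `Q4_K_M`:
- **F16**: The source GGUF that was quantized is Float16
- **K**: K-quant method used (blockwise quantization)
- **M**: The "medium" variant of the K-quant family (as opposed to S, small, or L, large)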
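Blockwise quantization can be illustrated with a simplified 4-bit scheme: each block of weights stores one scale, one minimum, and a 4-bit code per weight. This is a teaching sketch only. It is not the actual Q4_K_M bit layout used by llama.cpp, and the function names and sample weights are made up for the example.

```python
from typing import List, Tuple


def quantize_block(block: List[float]) -> Tuple[float, float, List[int]]:
    """Map each weight in the block to one of 16 levels between min and max."""
    lo, hi = min(block), max(block)
    scale = (hi - lo) / 15 or 1.0  # 4 bits give 16 levels; avoid a zero scale
    codes = [round((x - lo) / scale) for x in block]
    return scale, lo, codes


def dequantize_block(scale: float, lo: float, codes: List[int]) -> List[float]:
    """Reconstruct approximate weights from the block's scale, minimum, and codes."""
    return [lo + scale * c for c in codes]


weights = [0.12, -0.53, 0.98, 0.01, -0.27, 0.44, -0.91, 0.66]  # illustrative values
scale, lo, codes = quantize_block(weights)
restored = dequantize_block(scale, lo, codes)
max_err = max(abs(a - b) for a, b in zip(weights, restored))
```

The reconstruction error of each weight is bounded by half the block's scale, which is why smaller blocks (each with its own scale) lose less accuracy at the cost of a little extra storage.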

## Performance Implications

This 4GB GGUF model offers significant advantages:

1. **Memory efficiency**: Runs on consumer hardware with limited VRAM
2. **Loading speed**: Smaller file loads faster into memory
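3. **Inference performance**: Token generation is largely memory-bandwidth bound, so smaller weights usually generate faster than F16; the gain depends on the hardware and was not measured in this notebook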
4. **Disk space**: 71% reduction in storage requirements

The trade-off is typically a very small reduction in output quality that's barely noticeable in most applications. Modern quantization techniques like K-quant have minimized this quality loss significantly compared to earlier methods.

This optimization is particularly valuable for deployment in resource-constrained environments like edge devices, consumer GPUs, or when serving multiple model instances concurrently.


In [None]:
%cd /content/llama.cpp/vicuna-7b-v1.3/

/content/llama.cpp/vicuna-7b-v1.3


## Install Python bindings for llama.cpp

In [None]:
!CMAKE_ARGS="-DGGML_CUDA=on" pip install llama-cpp-python==0.2.85



In [None]:
from llama_cpp import Llama

In [None]:
model_path = "/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf"

## Load model for inference testing

In [None]:
# n_gpu_layers is left at its default of 0, so this load runs entirely on the CPU
llm = Llama(model_path=model_path)

llama_model_loader: loaded meta data with 26 key-value pairs and 291 tensors from /content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Vicuna 7b v1.3
llama_model_loader: - kv   3:                            general.version str              = v1.3
llama_model_loader: - kv   4:                           general.basename str              = vicuna
llama_model_loader: - kv   5:                         general.size_label str              = 7B
llama_model_loader: - kv   6:                          llama.block_count u32              = 32
llama_model_loader: - kv   7:                       

### inference test 1

In [None]:
generation_kwargs = {
    "max_tokens":200,
    "echo":False,
    "top_k":1
}

prompt = "Which country hosted 2018 fifa world cup?"
res = llm(prompt, **generation_kwargs)
res

Llama.generate: prefix-match hit

llama_print_timings:        load time =    7264.88 ms
llama_print_timings:      sample time =       7.45 ms /   200 runs   (    0.04 ms per token, 26838.43 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =  142857.42 ms /   200 runs   (  714.29 ms per token,     1.40 tokens per second)
llama_print_timings:       total time =  143047.91 ms /   200 tokens


{'id': 'cmpl-a9abefdd-d0a7-4247-8571-a44bf2facdc1',
 'object': 'text_completion',
 'created': 1743072997,
 'model': '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf',
 'choices': [{'text': "\nThe 2018 FIFA World Cup was held in Russia from June 14 to July 15, 2 The 2018 FIFA World Cup was the 21st FIFA World Cup, an international football tournament contested by the men's national teams of the member associations of FIFA once every four years. It took place in Russia from 14 June to 15 July 2018. It was the first World Cup to be held in Eastern Europe, and the 11th time that it had been held in Europe.\nThe tournament consisted of 32 teams, which were selected from qualifying matches held between March 2015 and October 2017. A total of 64 matches were played in 12 venues located in 11 cities across Russia. It was the first World Cup to use the video assistant referee (VAR) system.\nFrance won the tournament",
   'index': 0,
   'logprobs': None,
   'finish_reason': 'length'}],
 'usa

### inference test 2

In [None]:
generation_kwargs = {
    "max_tokens":200,
    "echo":False,
    "top_k":1
}

prompt = "who is MS dhoni?"
res = llm(prompt, **generation_kwargs)
res

Llama.generate: prefix-match hit

llama_print_timings:        load time =    7264.88 ms
llama_print_timings:      sample time =       4.43 ms /   113 runs   (    0.04 ms per token, 25519.42 tokens per second)
llama_print_timings: prompt eval time =    2805.45 ms /     7 tokens (  400.78 ms per token,     2.50 tokens per second)
llama_print_timings:        eval time =   78864.45 ms /   112 runs   (  704.15 ms per token,     1.42 tokens per second)
llama_print_timings:       total time =   81763.70 ms /   119 tokens


{'id': 'cmpl-c7243b7d-c6ed-48f3-baf3-c6397b446338',
 'object': 'text_completion',
 'created': 1743073140,
 'model': '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf',
 'choices': [{'text': '\nMS Dhoni is a former Indian cricketer and the current captain of the Indian national cricket team. He is a wicketkeeper-batsman and is widely regarded as one of the greatest finishers in limited-overs cricket. Dhoni made his international debut in 2004 and has since played in over 300 ODIs and 100 Test matches for India. He is also the captain of the Chennai Super Kings franchise in the Indian Premier League (IPL).',
   'index': 0,
   'logprobs': None,
   'finish_reason': 'stop'}],
 'usage': {'prompt_tokens': 8, 'completion_tokens': 112, 'total_tokens': 120}}
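For benchmarking, the decode speed can be pulled out of the `llama_print_timings` lines rather than read by eye. A small sketch, assuming the log format shown in the outputs above; `tokens_per_second` is a hypothetical helper, and the sample line is copied from inference test 2.

```python
import re

# Matches "... eval time = <ms> ms / <n> runs" (or "tokens" for prompt eval lines).
TIMING_RE = re.compile(r"eval time\s*=\s*([\d.]+)\s*ms\s*/\s*(\d+)\s*(?:runs|tokens)")


def tokens_per_second(line: str) -> float:
    """Compute tokens/second from one llama_print_timings eval line."""
    match = TIMING_RE.search(line)
    if match is None:
        raise ValueError("not an eval timing line")
    elapsed_ms, count = float(match.group(1)), int(match.group(2))
    if elapsed_ms == 0:
        return float("nan")  # e.g. a fully cached prompt reports 0 ms / 0 tokens
    return count / (elapsed_ms / 1000.0)


eval_line = (
    "llama_print_timings:        eval time =   78864.45 ms /   112 runs   "
    "(  704.15 ms per token,     1.42 tokens per second)"
)
decode_tps = tokens_per_second(eval_line)
print(f"{decode_tps:.2f} tokens/second")
```

Collecting this value for each configuration gives comparable numbers for the benchmarks the task asks for.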

### Save quantized model to Drive

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
!mkdir -p "/content/drive/My Drive/quantized_models"

In [None]:
import shutil

source_file_path = '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf'
destination_file_path = '/content/drive/My Drive/quantized_models/vicuna-7b-v1.3-F16_KM.gguf'

shutil.copy(source_file_path, destination_file_path)

'/content/drive/My Drive/quantized_models/vicuna-7b-v1.3-F16_KM.gguf'

In [None]:
!pip install transformers==4.36.0 accelerate==0.25.0 huggingface_hub==0.20.0

Collecting transformers==4.36.0
  Downloading transformers-4.36.0-py3-none-any.whl.metadata (126 kB)
Collecting accelerate==0.25.0
  Downloading accelerate-0.25.0-py3-none-any.whl.metadata (18 kB)
Collecting huggingface_hub==0.20.0
  Downloading huggingface_hub-0.20.0-py3-none-any.whl.metadata (12 kB)
Collecting tokenizers<0.19,>=0.14 (from transformers==4.36.0)
  Downloading tokenizers-0.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Downloading transformers-4.36.0-py3-none-any.whl (8.2 MB)

## Without Medusa Head

In [None]:
import torch
from llama_cpp import Llama
import numpy as np
from typing import List, Optional, Dict
import logging
from time import time

class MedusaLlamaCppModel:
    def __init__(
        self,
        model_path: str = "/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf",
        medusa_num_heads: int = 4,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
        n_ctx: int = 2048,
        n_batch: int = 512,  # Batch size for processing
        n_threads: int = 8    # Number of threads for parallel processing
    ):
        self.logger = logging.getLogger(__name__)
        self.device = device

        try:
            # Initialize the base GGUF model with optimized parameters
            self.base_model = Llama(
                model_path=model_path,
                n_ctx=n_ctx,
                n_batch=n_batch,      # Enable batch processing
                n_threads=n_threads,   # Enable multi-threading
                n_gpu_layers=-1,      # Use GPU for all layers if available
                verbose=False
            )
            self.logger.info("Base model loaded successfully")

            self.medusa_num_heads = medusa_num_heads
            self.batch_size = n_batch
            self.logger.info(f"Initialized with {medusa_num_heads} Medusa heads")

        except Exception as e:
            self.logger.error(f"Error initializing model: {str(e)}")
            raise

    def _batch_generate(
        self,
        prompt: str,
        n_tokens: int = 32,  # Maximum number of tokens for this call
        temperature: float = 0.7
    ) -> str:
        """Generate up to n_tokens tokens with a single completion call to the base model"""
        try:
            response = self.base_model(
                prompt,
                max_tokens=n_tokens,
                temperature=temperature,
                echo=False,
                stop=["</s>", "<|endoftext|>"]
            )

            if isinstance(response, dict) and 'choices' in response:
                return response['choices'][0]['text']
            elif isinstance(response, list) and len(response) > 0:
                return response[0]['text']
            return ""

        except Exception as e:
            self.logger.error(f"Error in batch generation: {str(e)}")
            return ""

    def generate(
        self,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        batch_size: int = 32  # Number of tokens to request per completion call
    ) -> str:
        """Generate text in chunks of batch_size tokens, re-feeding the growing text each time"""
        try:
            generated_text = prompt
            tokens_generated = 0
            start_time = time()

            self.logger.info(f"Starting generation with prompt: {prompt[:50]}...")

            while tokens_generated < max_tokens:
                # Calculate remaining tokens
                remaining_tokens = max_tokens - tokens_generated
                current_batch_size = min(batch_size, remaining_tokens)

                # Generate batch of tokens
                new_text = self._batch_generate(
                    generated_text,
                    n_tokens=current_batch_size,
                    temperature=temperature
                )

                if not new_text:
                    break

                generated_text += new_text
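                # Count tokens with the model's tokenizer instead of whitespace-split words
                tokens_generated += len(
                    self.base_model.tokenize(new_text.encode("utf-8"), add_bos=False)
                )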

                # Log progress with speed metrics
                if tokens_generated % 50 == 0:
                    elapsed_time = time() - start_time
                    speed = tokens_generated / elapsed_time
                    self.logger.info(f"Generated {tokens_generated} tokens. Speed: {speed:.2f} tokens/second")

            # Final statistics
            total_time = time() - start_time
            avg_speed = tokens_generated / total_time
            self.logger.info(f"Generation completed. Total tokens: {tokens_generated}")
            self.logger.info(f"Average speed: {avg_speed:.2f} tokens/second")

            return generated_text

        except Exception as e:
            self.logger.error(f"Error in text generation: {str(e)}")
            return generated_text  # Return what we have so far

def print_generation_stats(text: str, time_taken: float):
    """Print generation statistics (counts whitespace-split words of the full text, prompt included)"""
    tokens = len(text.split())
    speed = tokens / time_taken
    print(f"\nGeneration Statistics:")
    print(f"Total tokens: {tokens}")
    print(f"Time taken: {time_taken:.2f} seconds")
    print(f"Speed: {speed:.2f} tokens/second")

In [None]:
model = MedusaLlamaCppModel(
    model_path="/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf",
    medusa_num_heads=4,
    n_batch=512,      # Prompt-processing batch size
    n_threads=8
)

# Generate text


In [None]:
prompt = "Once upon a time"
start_time = time()

generated_text = model.generate(
    prompt=prompt,
    max_tokens=100,
    temperature=0.7,
    batch_size=32     # Tokens requested per completion call
)

end_time = time()

# Print results and statistics
print("\nGenerated text:")
print(generated_text)
print_generation_stats(generated_text, end_time - start_time)


Generated text:
Once upon a time, in the early days of the internet, there was a group of people who were passionate about the music of the 1980s. They spent countless hours collecting and trading songs, creating playlists, and sharing their love for this decade’s music.
One day, they decided to create a website where they could share their passion with others who loved 80s music as much as they did. They called it “The 80s Network” and it quickly became a hub for fans of this music genre.
The 80s Network was different from other music websites because it was created and run by

Generation Statistics:
Total tokens: 103
Time taken: 115.15 seconds
Speed: 0.89 tokens/second


## With Medusa Head

### With Medusa head (partially working code)

In [None]:
import torch
import numpy as np
import time
from llama_cpp import Llama
from typing import List, Dict, Tuple, Optional
import logging
from dataclasses import dataclass
import os

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)

@dataclass
class MedusaConfig:
    """Configuration for Medusa head."""
    num_heads: int = 4
    max_tokens: int = 512
    temperature: float = 0.7
    posterior_threshold: float = 0.09
    posterior_alpha: float = 0.3
    tree_branching_factor: int = 2
    draft_checkpoint_ratio: float = 0.5

class MedusaHead:
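    """Draft-token generator standing in for trained Medusa heads.

    Note: no extra prediction heads are attached to the model here; every
    draft token is produced by a separate call to the base model.
    """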

    def __init__(self, num_heads: int):
        self.logger = logging.getLogger("MedusaHead")
        self.num_heads = num_heads
        self.logger.info(f"Initialized Medusa head with {num_heads} heads")

        self.tree_indices = self._create_tree_indices()
        print(f"Tree Indices: {self.tree_indices}")

    def _create_tree_indices(self) -> List[List[int]]:
        """Create tree indices for heads."""
        tree = []
        for i in range(self.num_heads):
            level = []
            for j in range(min(2**i, self.num_heads - len(tree))):
                level.append(i + j)
            if level:
                tree.append(level)
        self.logger.info(f"Tree indices created: {tree}")
        return tree

    def generate_draft_tokens(self, model: Llama, prompt: str, temperature: float) -> List[str]:
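        """Generate a draft sequence: one base token, then one token per tree level.

        For each level, len(level) candidates are sampled from the same prompt
        and only the first one is kept.
        """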
        base_token = self._generate_token(model, prompt, temperature)
        print(f"Base Token: {base_token}")
        if not base_token:
            return []

        draft_sequence = [base_token]
        current_prompt = prompt + base_token

        for level in self.tree_indices:
            if not current_prompt:
                break

            level_tokens = []
            for _ in level:
                token = self._generate_token(model, current_prompt, temperature)
                print(f"Generated Token at Level {level}: {token}")
                if token:
                    level_tokens.append(token)

            if not level_tokens:
                break

            draft_sequence.append(level_tokens[0])
            current_prompt += level_tokens[0]

        print(f"Final Draft Sequence: {draft_sequence}")
        return draft_sequence

    def _generate_token(self, model: Llama, prompt: str, temperature: float) -> str:
        """Generate a single token from the model."""
        try:
            response = model(
                prompt,
                max_tokens=1,
                temperature=temperature,
                echo=False
            )
            print(f"Model Response: {response}")
            if isinstance(response, dict) and 'choices' in response:
                return response['choices'][0]['text']
            elif isinstance(response, list) and len(response) > 0:
                return response[0]['text']
            return ""
        except Exception as e:
            self.logger.error(f"Error generating token: {str(e)}")
            return ""

class MedusaModel:
    """Main Medusa model combining llama.cpp with Medusa head."""

    def __init__(self, model_path: str, config: Optional[MedusaConfig] = None):
        self.logger = logging.getLogger("MedusaModel")
        self.config = config or MedusaConfig()

        try:
            self.logger.info(f"Loading model from {model_path}")
            self.base_model = Llama(
                model_path=model_path,
                n_ctx=4096,
                n_batch=512,
                n_threads=8,
                n_gpu_layers=-1,
                verbose=False
            )
            self.logger.info("Base model loaded successfully")

            self.medusa_head = MedusaHead(self.config.num_heads)
            self.logger.info(f"Medusa head initialized with {self.config.num_heads} heads")

        except Exception as e:
            self.logger.error(f"Error initializing model: {str(e)}")
            raise

    def generate(self, prompt: str, config: Optional[MedusaConfig] = None) -> Dict:
        """Generate text using Medusa speculative decoding."""
        cfg = config or self.config
        self.logger.info(f"Generating text with Medusa ({cfg.num_heads} heads)")

        generated_text = prompt
        tokens_generated = 0
        tokens_accepted = 0
        draft_tokens_generated = 0
        iterations = 0
        speedup = 0.0

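        baseline_tokens_per_sec = 0

        # Warm-up: time a short plain generation to estimate the baseline speed
        warmup_tokens = min(20, cfg.max_tokens // 10)
        if warmup_tokens > 0:
            warmup_start = time.time()
            warmup = self.base_model(
                prompt, max_tokens=warmup_tokens, temperature=cfg.temperature, echo=False
            )
            warmup_time = time.time() - warmup_start
            # Count the tokens actually produced, which may be fewer than requested
            produced = warmup["usage"]["completion_tokens"]
            if produced > 0 and warmup_time > 0:
                baseline_tokens_per_sec = produced / warmup_time
                print(f"Baseline Speed: {baseline_tokens_per_sec:.2f} tokens/sec")

        # Start the main timer after the warm-up so it does not skew the measured speed
        start_time = time.time()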

        while tokens_generated < cfg.max_tokens:
            iterations += 1

            draft_start = time.time()
            drafts = self.medusa_head.generate_draft_tokens(
                self.base_model, generated_text, cfg.temperature
            )
            draft_time = time.time() - draft_start
            print(f"Draft Tokens: {drafts}")

            draft_tokens_generated += len(drafts)

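            if not drafts:
                # No drafts: fall back to one ordinary decoding step
                token = self.medusa_head._generate_token(
                    self.base_model, generated_text, cfg.temperature
                )
                if not token:
                    break  # the model produced nothing, so stop generating
                generated_text += token
                tokens_generated += 1
                print(f"Fallback Token: {token}")
                continue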

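            verify_start = time.time()
            # Placeholder verification: every draft is accepted with a fixed
            # probability of 0.9, so the acceptance rate below is always 100%
            accepted_count, probs = (len(drafts), [0.9] * len(drafts))
            verify_time = time.time() - verify_start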

            if accepted_count > 0:
                accepted_drafts = drafts[:accepted_count]
                generated_text += ''.join(accepted_drafts)
                tokens_generated += accepted_count
                tokens_accepted += accepted_count
                print(f"Accepted Tokens: {accepted_drafts}")
                for token in accepted_drafts:
                    print(token, end="", flush=True)
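            else:
                # No draft accepted: fall back to one ordinary decoding step
                token = self.medusa_head._generate_token(
                    self.base_model, generated_text, cfg.temperature
                )
                if not token:
                    break
                generated_text += token
                tokens_generated += 1
                print(f"Generated Token: {token}")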

            if tokens_generated % 10 == 0:
                elapsed = time.time() - start_time
                current_speed = tokens_generated / elapsed if elapsed > 0 else 0
                if baseline_tokens_per_sec > 0:
                    speedup = current_speed / baseline_tokens_per_sec
                print(f"Tokens Generated: {tokens_generated}, Speed: {current_speed:.2f} tokens/sec, Speedup: {speedup:.2f}")

        total_time = time.time() - start_time
        tokens_per_sec = tokens_generated / total_time if total_time > 0 else 0
        acceptance_rate = tokens_accepted / draft_tokens_generated * 100 if draft_tokens_generated > 0 else 0

        print(f"Final Stats -> Tokens Generated: {tokens_generated}, Speed: {tokens_per_sec:.2f} tokens/sec, Acceptance Rate: {acceptance_rate:.1f}%")

        return {
            "text": generated_text[len(prompt):],  # completion only, without the prompt
            "tokens_generated": tokens_generated,
            "draft_tokens_generated": draft_tokens_generated,
            "tokens_per_sec": tokens_per_sec,
            "acceptance_rate": acceptance_rate
        }


In [None]:
medusa_config = MedusaConfig(
    num_heads=4,
    max_tokens=50,
    temperature=0.7,
    posterior_threshold=0.09,
    posterior_alpha=0.3         # Alpha for posterior scaling
)

model = MedusaModel(
    model_path="/content/drive/My Drive/quantized_models/vicuna-7b-v1.3-F16_KM.gguf",
    config=medusa_config
)

prompt = "there was a cricketer named MS dhoni"
result = model.generate(prompt)
# generated_text = result["choices"][0]["text"]
# Print the generated text
print("\nGenerated Text:")
print(result)

# Print performance statistics
# stats = result["stats"]
# print(f"\nPerformance Statistics:")
# print(f"- Tokens generated: {stats['tokens_generated']}")
# print(f"- Generation speed: {stats['tokens_per_sec']:.2f} tokens/second")
# print(f"- Speedup: {stats['speedup']:.2f}x faster than standard generation")
# print(f"- Acceptance rate: {stats['acceptance_rate']:.1f}%")

Tree Indices: [[0], [1, 2], [2, 3], [3]]
Baseline Speed: 10485760.00 tokens/sec
Model Response: {'id': 'cmpl-44500d43-6c2b-4337-8a5d-114502cf5540', 'object': 'text_completion', 'created': 1743081958, 'model': '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf', 'choices': [{'text': ' who', 'index': 0, 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 12, 'completion_tokens': 1, 'total_tokens': 13}}
Base Token:  who
Model Response: {'id': 'cmpl-7e8eace2-b4c9-42b8-886f-63b4f8fecb94', 'object': 'text_completion', 'created': 1743081973, 'model': '/content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf', 'choices': [{'text': ' played', 'index': 0, 'logprobs': None, 'finish_reason': 'length'}], 'usage': {'prompt_tokens': 13, 'completion_tokens': 1, 'total_tokens': 14}}
Generated Token at Level [0]:  played
Model Response: {'id': 'cmpl-b5d578fd-4b36-497d-82d8-60c98de4e354', 'object': 'text_completion', 'created': 1743081973, 'model': '/content/optimized_model/vicuna-7b-v
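
The generation loop above accepts every draft unconditionally (`accepted_count = len(drafts)`), so its acceptance rate is 100% by construction. The following is a minimal sketch of greedy verification, assuming the verifier is any callable that returns the base model's next token for a given context. `accept_longest_prefix`, `make_replay_verifier`, and `next_token` are hypothetical names used for illustration; they are not part of llama-cpp-python or Medusa. Drafts are kept up to the first disagreement, and the verifier's own token is returned as the correction.

```python
from typing import Callable, List, Tuple


def accept_longest_prefix(
    context: str,
    drafts: List[str],
    next_token: Callable[[str], str],
) -> Tuple[List[str], str]:
    """Keep drafts for as long as they match what the verifier would emit.

    Returns the accepted drafts and the verifier's correction token
    (an empty string if every draft was accepted).
    """
    accepted: List[str] = []
    for draft in drafts:
        expected = next_token(context + "".join(accepted))
        if draft != expected:
            return accepted, expected
        accepted.append(draft)
    return accepted, ""


def make_replay_verifier(prompt: str, continuation: List[str]) -> Callable[[str], str]:
    """Toy stand-in for the base model: replays a fixed continuation."""
    def next_token(ctx: str) -> str:
        produced = ctx[len(prompt):]
        consumed = 0
        for tok in continuation:
            if produced.startswith(tok):
                produced = produced[len(tok):]
                consumed += 1
            else:
                break
        return continuation[consumed] if consumed < len(continuation) else ""
    return next_token


verifier = make_replay_verifier("MS Dhoni", [" who", " played", " for", " India"])
accepted, correction = accept_longest_prefix(
    "MS Dhoni", [" who", " played", " in"], verifier
)
print(accepted, repr(correction))
```

Calling the verifier once per draft, as this sketch does, yields no speedup on its own. Speculative decoding only pays off when one forward pass of the base model scores all drafts together; the sketch shows the acceptance rule, not the batching.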

In [None]:
# Read the token from an environment variable rather than hard-coding the secret
!ngrok authtoken $NGROK_AUTHTOKEN


Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


### Medusa service: final working code (still need to confirm it returns a response)

#### 1
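
The request models in the cell below carry `posterior_threshold` and `posterior_alpha`. As far as I understand the Medusa reference implementation, these parameterise a typical-acceptance rule: a candidate token is kept when its probability under the base model exceeds `min(posterior_threshold, posterior_alpha * exp(-H))`, where `H` is the entropy of the base model's distribution at that position. The sketch below works under that assumption, with hypothetical helper names and plain Python lists in place of tensors.

```python
import math
from typing import List, Sequence


def entropy(probs: Sequence[float]) -> float:
    """Shannon entropy (in nats) of a probability distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def typical_accept(
    candidate_prob: float,
    probs: Sequence[float],
    posterior_threshold: float = 0.09,
    posterior_alpha: float = 0.3,
) -> bool:
    """Accept a candidate if its probability clears the entropy-scaled threshold."""
    threshold = min(posterior_threshold, posterior_alpha * math.exp(-entropy(probs)))
    return candidate_prob > threshold


def count_accepted(
    candidate_probs: Sequence[float],
    distributions: Sequence[Sequence[float]],
    posterior_threshold: float = 0.09,
    posterior_alpha: float = 0.3,
) -> int:
    """Length of the longest accepted prefix of a draft sequence."""
    accepted = 0
    for prob, dist in zip(candidate_probs, distributions):
        if not typical_accept(prob, dist, posterior_threshold, posterior_alpha):
            break
        accepted += 1
    return accepted


peaked: List[float] = [0.9, 0.05, 0.05]   # low entropy: the model is confident
flat: List[float] = [0.1] * 10            # high entropy: many plausible tokens

n_accepted = count_accepted([0.9, 0.1, 0.05], [peaked, flat, peaked])
print(n_accepted)
```

When the distribution is flat, the threshold drops, so a token with probability 0.1 is accepted. When the distribution is peaked, the threshold is capped at `posterior_threshold`, so an unlikely token with probability 0.05 is rejected.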

In [None]:
# Import necessary libraries
import torch
import logging
import asyncio
from typing import List, Dict, Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from pyngrok import ngrok
import time
import uuid
from dataclasses import dataclass
from datetime import datetime
import numpy as np
from threading import Lock
from contextlib import asynccontextmanager

# For notebook environments
import nest_asyncio
nest_asyncio.apply()

# Import Medusa components from the repository
from medusa.model.medusa_model import MedusaModel
from medusa.model.medusa_choices import mc_sim_7b_63  # Pre-defined Medusa choices
from medusa.model.utils import generate_medusa_buffers, reset_medusa_mode, initialize_medusa
from medusa.model.kv_cache import initialize_past_key_values
from llama_cpp import Llama

# Set up logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Define request and response models
class GenerationRequest(BaseModel):
    prompt: str
    max_length: int = 512
    temperature: float = 0.7
    posterior_threshold: float = 0.09
    posterior_alpha: float = 0.3

class GenerationResponse(BaseModel):
    text: str
    generation_time: float
    tokens_generated: int
    tokens_per_second: float
    speedup_factor: Optional[float] = None

@dataclass
class BatchRequest:
    id: str
    prompt: str
    timestamp: datetime
    max_length: int = 512
    temperature: float = 0.7
    posterior_threshold: float = 0.09
    posterior_alpha: float = 0.3

class MedusaLlamaCppManager:
    """
    Manager class that combines llama.cpp with Medusa for speculative decoding
    """
    def __init__(
        self,
        model_path: str = "/content/drive/My Drive/quantized_models/vicuna-7b-v1.3-F16_KM.gguf",
        medusa_num_heads: int = 4,
        n_ctx: int = 2048,
        n_batch: int = 512,
        n_threads: int = 8
    ):
        self.logger = logging.getLogger("MedusaLlamaCppManager")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.medusa_num_heads = medusa_num_heads
        self.model_lock = Lock()  # For thread safety

        # Load compiled model
        self.llama_model = Llama(
            model_path=model_path,
            n_ctx=n_ctx,
            n_batch=n_batch,
            n_threads=n_threads,
            n_gpu_layers=-1
        )
        self.logger.info(f"Loaded GGUF model from {model_path}")

        # Initialize Medusa components
        self.medusa_choices = mc_sim_7b_63
        self.medusa_buffers = self._initialize_medusa_buffers()
        self.logger.info(f"Initialized Medusa with {medusa_num_heads} heads")

        # Baseline metrics for speedup calculation
        self._baseline_tokens_per_second = self._calculate_baseline_speed()

    def _initialize_medusa_buffers(self) -> Dict:
        """Initialize Medusa buffers for speculative decoding"""
        # Create buffers similar to how the Medusa model does it
        tree_indices = torch.zeros((self.medusa_num_heads, 2), dtype=torch.long)
        medusa_attn_mask = torch.ones((self.medusa_num_heads + 1, self.medusa_num_heads + 1), dtype=torch.bool)
        medusa_attn_mask = torch.triu(medusa_attn_mask, diagonal=1)
        medusa_position_ids = torch.arange(self.medusa_num_heads, dtype=torch.long)

        return {
            "tree_indices": tree_indices,
            "medusa_attn_mask": medusa_attn_mask,
            "medusa_position_ids": medusa_position_ids,
            "retrieve_indices": None  # Will be set during generation
        }

    def _calculate_baseline_speed(self) -> float:
        """Calculate baseline generation speed without Medusa"""
        prompt = "Once upon a time"
        start_time = time.time()

        response = self.llama_model(prompt, max_tokens=20, temperature=0.7)

        if response and 'choices' in response:
            # Count model tokens, not whitespace-separated words
            tokens = response['usage']['completion_tokens']
            elapsed_time = time.time() - start_time

            if elapsed_time > 0 and tokens > 0:
                tokens_per_second = tokens / elapsed_time
                self.logger.info(f"Baseline generation speed: {tokens_per_second:.2f} tokens/second")
                return tokens_per_second

        # Default value if calculation fails
        return 5.0

    def generate(
        self,
        prompt: str,
        max_length: int = 512,
        temperature: float = 0.7,
        posterior_threshold: float = 0.09,
        posterior_alpha: float = 0.3
    ) -> Dict:
        """Generate text using Medusa speculative decoding with llama.cpp backend"""
        with self.model_lock:  # Ensure thread safety
            start_time = time.time()

            # Initial generation
            input_text = prompt
            generated_text = ""
            tokens_generated = 0
            draft_tokens_generated = 0
            accepted_tokens = 0

            while tokens_generated < max_length:
                # Generate base prediction and drafts
                base_token, drafts = self._generate_drafts(input_text, temperature)

                if not base_token:
                    break

                # Verify drafts
                accepted_count, accepted_drafts = self._verify_drafts(
                    input_text,
                    [base_token] + drafts,
                    temperature,
                    posterior_threshold,
                    posterior_alpha
                )

                # Update counts and text
                draft_tokens_generated += len(drafts) + 1  # base + drafts
                accepted_tokens += accepted_count

                if accepted_count > 0:
                    accepted_text = ''.join(accepted_drafts)
                    input_text += accepted_text
                    generated_text += accepted_text
                    tokens_generated += accepted_count
                else:
                    # If no drafts accepted, use base token
                    input_text += base_token
                    generated_text += base_token
                    tokens_generated += 1

                # Log progress periodically
                if tokens_generated % 50 == 0:
                    self.logger.info(f"Generated {tokens_generated} tokens")

            # Calculate statistics
            elapsed_time = time.time() - start_time
            tokens_per_second = tokens_generated / elapsed_time if elapsed_time > 0 else 0
            speedup = tokens_per_second / self._baseline_tokens_per_second

            return {
                "text": generated_text,
                "generation_time": elapsed_time,
                "tokens_generated": tokens_generated,
                "tokens_per_second": tokens_per_second,
                "speedup_factor": speedup,
                "acceptance_rate": (accepted_tokens / draft_tokens_generated * 100) if draft_tokens_generated > 0 else 0
            }

    def _generate_drafts(self, context: str, temperature: float):
        """Generate base token and draft tokens using Medusa tree structure"""
        try:
            # Generate base token
            base_response = self.llama_model(
                context,
                max_tokens=1,
                temperature=temperature,
                echo=False
            )

            if not base_response or 'choices' not in base_response:
                return "", []

            base_token = base_response['choices'][0]['text']

            # Generate draft tokens following the Medusa tree structure
            drafts = []
            draft_context = context + base_token

            for _ in range(self.medusa_num_heads - 1):  # -1 because we already have the base token
                draft_response = self.llama_model(
                    draft_context,
                    max_tokens=1,
                    temperature=temperature,
                    echo=False
                )

                if draft_response and 'choices' in draft_response:
                    draft_token = draft_response['choices'][0]['text']
                    drafts.append(draft_token)
                    draft_context += draft_token
                else:
                    break

            return base_token, drafts

        except Exception as e:
            self.logger.error(f"Error generating drafts: {str(e)}")
            return "", []

    def _verify_drafts(
        self,
        context: str,
        drafts: List[str],
        temperature: float,
        threshold: float,
        alpha: float
    ):
        """Verify draft tokens and return accepted ones"""
        if not drafts:
            return 0, []

        # Calculate verification scores
        scores = []
        accepted_drafts = []
        current_context = context

        for draft in drafts:
            # Calculate probability score for this draft
            try:
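                verify_response = self.llama_model(
                    current_context + draft,
                    max_tokens=1,  # llama-cpp-python treats max_tokens <= 0 as "no limit"
                    temperature=0.0,  # Use 0 for verification
                    echo=True
                )

                # Get verification score (this is an approximation)
                score = 0.0
                if verify_response and 'choices' in verify_response:
                    # In real implementation, we'd get the token probability
                    # Here we use a simple heuristic. 'logprobs' is None unless
                    # the request asks for log-probabilities, so guard against
                    # None before indexing and fall back to -1.0
                    logprobs = verify_response['choices'][0].get('logprobs') or {}
                    token_logprobs = logprobs.get('token_logprobs') or [-1.0]
                    last_logprob = token_logprobs[-1]
                    score = float(last_logprob) if last_logprob is not None else -1.0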

                # Apply temperature and alpha
                score = np.exp(score / max(temperature, 1e-6)) ** alpha
                scores.append(score)

                # Accept if above threshold
                if score >= threshold:
                    accepted_drafts.append(draft)
                    current_context += draft
                else:
                    break

            except Exception as e:
                self.logger.error(f"Error verifying draft: {str(e)}")
                break

        return len(accepted_drafts), accepted_drafts

class BatchProcessor:
    """Groups queued requests into batches; requests in a batch run one at a time"""
    def __init__(self, model_manager, batch_size=4, max_wait_time=0.1):
        self.model_manager = model_manager
        self.batch_size = batch_size
        self.max_wait_time = max_wait_time
        self.queue = asyncio.Queue()
        self.logger = logging.getLogger("BatchProcessor")
        self.processing = False
        self.results = {}
        self.background_task = None

    async def add_request(self, request: BatchRequest) -> str:
        """Add a request to the processing queue"""
        await self.queue.put(request)
        self.logger.info(f"Added request {request.id} to queue, size: {self.queue.qsize()}")

        # Start background processing if not already running
        if not self.processing:
            self.processing = True
            self.background_task = asyncio.create_task(self._process_queue())

        return request.id

    async def get_result(self, request_id: str, timeout: float = 60.0) -> Optional[Dict]:
        """Wait for and retrieve result for a specific request ID"""
        start_time = time.time()
        while time.time() - start_time < timeout:
            if request_id in self.results:
                result = self.results.pop(request_id)
                return result
            await asyncio.sleep(0.1)

        return None  # Timeout

    async def _process_queue(self):
        """Background task to process requests in the queue"""
        try:
            while True:
                # Process requests in batches up to batch_size
                batch = []

                # Try to get up to batch_size requests
                for _ in range(self.batch_size):
                    try:
                        request = await asyncio.wait_for(
                            self.queue.get(),
                            timeout=self.max_wait_time
                        )
                        batch.append(request)
                    except asyncio.TimeoutError:
                        break

                if not batch:
                    self.processing = False
                    break

                self.logger.info(f"Processing batch of {len(batch)} requests")

                # Process each request in the batch
                for request in batch:
                    try:
                        # Run the blocking generate() in a worker thread so the
                        # event loop keeps accepting requests in the meantime
                        result = await asyncio.to_thread(
                            self.model_manager.generate,
                            prompt=request.prompt,
                            max_length=request.max_length,
                            temperature=request.temperature,
                            posterior_threshold=request.posterior_threshold,
                            posterior_alpha=request.posterior_alpha
                        )

                        # Store the result
                        self.results[request.id] = result
                        self.logger.info(f"Completed request {request.id}")

                    except Exception as e:
                        self.logger.error(f"Error processing request {request.id}: {str(e)}")
                        self.results[request.id] = {"error": str(e)}

                    finally:
                        self.queue.task_done()

        except Exception as e:
            self.logger.error(f"Error in batch processing: {str(e)}")
            self.processing = False

# Define global variables for model manager and batch processor
model_manager = None
batch_processor = None

# Lifespan context manager for FastAPI
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup - initialize resources
    global model_manager, batch_processor

    logger.info("Initializing model and batch processor...")
    try:
        model_manager = MedusaLlamaCppManager(
            model_path="/content/drive/My Drive/quantized_models/vicuna-7b-v1.3-F16_KM.gguf",
            medusa_num_heads=4
        )
        batch_processor = BatchProcessor(model_manager)
        logger.info("Model and batch processor initialized successfully")
    except Exception as e:
        logger.error(f"Error initializing model: {str(e)}")
        raise

    try:
        # Set up ngrok tunnel
        ngrok_tunnel = ngrok.connect(8000)
        logger.info(f"Ngrok tunnel established at: {ngrok_tunnel.public_url}")
        print(f"Public URL: {ngrok_tunnel.public_url}")
    except Exception as e:
        logger.error(f"Failed to establish ngrok tunnel: {str(e)}")

    yield

    # Shutdown - clean up resources
    logger.info("Shutting down server and resources")
    # Add any cleanup code here if needed

# Initialize FastAPI app with lifespan
app = FastAPI(
    title="Medusa LLM Service",
    description="Language model service with Medusa speculative decoding and dynamic batching",
    lifespan=lifespan
)

# Endpoint for text generation
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    try:
        # Create a batch request
        request_id = str(uuid.uuid4())
        batch_request = BatchRequest(
            id=request_id,
            prompt=request.prompt,
            timestamp=datetime.now(),
            max_length=request.max_length,
            temperature=request.temperature,
            posterior_threshold=request.posterior_threshold,
            posterior_alpha=request.posterior_alpha
        )

        # Add request to batch processor
        await batch_processor.add_request(batch_request)

        # Wait for result
        result = await batch_processor.get_result(request_id)

        if not result:
            raise HTTPException(status_code=408, detail="Request timed out")

        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])

        # Prepare response
        return GenerationResponse(
            text=result["text"],
            generation_time=result["generation_time"],
            tokens_generated=result["tokens_generated"],
            tokens_per_second=result["tokens_per_second"],
            speedup_factor=result["speedup_factor"]
        )

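    except HTTPException:
        # Preserve deliberate status codes such as 408 instead of masking them as 500
        raise
    except Exception as e: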
        logger.error(f"Error in generate endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# Health check endpoint
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model": "Medusa LLM Service"}

# Benchmark endpoint with and without Medusa
@app.post("/benchmark")
async def benchmark(request: GenerationRequest):
    try:
        # Generate with Medusa speculative decoding
        medusa_result = model_manager.generate(
            prompt=request.prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            posterior_threshold=request.posterior_threshold,
            posterior_alpha=request.posterior_alpha
        )

        # Generate without Medusa (using baseline approach)
        # Hold the model lock: the Llama instance is shared with generate()
        with model_manager.model_lock:
            start_time = time.time()
            standard_response = model_manager.llama_model(
                request.prompt,
                max_tokens=request.max_length,
                temperature=request.temperature
            )
            standard_time = time.time() - start_time

        has_choices = bool(standard_response) and 'choices' in standard_response
        standard_text = standard_response['choices'][0]['text'] if has_choices else ""
        # Use the model's own token count rather than a whitespace word count
        standard_tokens = standard_response['usage']['completion_tokens'] if has_choices else 0
        standard_tokens_per_second = standard_tokens / standard_time if standard_time > 0 else 0

        # Prepare comparative benchmark results
        return {
            "medusa": {
                "text": medusa_result["text"],
                "generation_time": medusa_result["generation_time"],
                "tokens_generated": medusa_result["tokens_generated"],
                "tokens_per_second": medusa_result["tokens_per_second"],
                "acceptance_rate": medusa_result["acceptance_rate"]
            },
            "standard": {
                "text": standard_text,
                "generation_time": standard_time,
                "tokens_generated": standard_tokens,
                "tokens_per_second": standard_tokens_per_second
            },
            "speedup": medusa_result["tokens_per_second"] / standard_tokens_per_second if standard_tokens_per_second > 0 else 0
        }

    except Exception as e:
        logger.error(f"Error in benchmark endpoint: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

import uvicorn

# Function to start the server
async def start_server():
    # No need to initialize model here since it's done in the lifespan context manager
    config = uvicorn.Config(app, host="0.0.0.0", port=8000)
    server = uvicorn.Server(config)
    await server.serve()

# When running the file directly
if __name__ == "__main__":
    # get_ipython() only exists inside IPython/Jupyter, so guard the lookup
    try:
        in_colab = 'google.colab' in str(get_ipython())
    except NameError:
        in_colab = False

    if in_colab:
        # Run for Colab/Jupyter (relies on nest_asyncio.apply() above)
        asyncio.run(start_server())
    else:
        # Run for standard Python environments
        uvicorn.run(app, host="0.0.0.0", port=8000)

  _torch_pytree._register_pytree_node(
Using `is_flash_attn_available` is deprecated and will be removed in v4.38. Please use `is_flash_attn_2_available` instead.
INFO:     Started server process [68612]
INFO:     Waiting for application startup.
llama_model_loader: loaded meta data with 26 key-value pairs and 291 tensors from /content/optimized_model/vicuna-7b-v1.3-F16_KM.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = llama
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Vicu

Public URL: https://fd1a-104-196-160-201.ngrok-free.app
INFO:     14.102.161.98:0 - "POST /generate HTTP/1.1" 422 Unprocessable Entity
INFO:     14.102.161.98:0 - "POST /generate HTTP/1.1" 422 Unprocessable Entity


Llama.generate: prefix-match hit

llama_print_timings:        load time =    2204.98 ms
llama_print_timings:      sample time =       0.04 ms /     1 runs   (    0.04 ms per token, 23255.81 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     756.32 ms /     1 runs   (  756.32 ms per token,     1.32 tokens per second)
llama_print_timings:       total time =     757.24 ms /     1 tokens
Llama.generate: prefix-match hit

llama_print_timings:        load time =    2204.98 ms
llama_print_timings:      sample time =       0.04 ms /     1 runs   (    0.04 ms per token, 23809.52 tokens per second)
llama_print_timings: prompt eval time =       0.00 ms /     0 tokens (    -nan ms per token,     -nan tokens per second)
llama_print_timings:        eval time =     588.10 ms /     1 runs   (  588.10 ms per token,     1.70 tokens per second)
llama_print_timings:       to
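
`BatchProcessor._process_queue` above applies `max_wait_time` to each `queue.get()` separately, so filling one batch can take up to `batch_size * max_wait_time`. The following is a minimal sketch of collecting a batch against a single shared deadline instead, assuming a plain `asyncio.Queue`. `collect_batch` and `demo` are hypothetical helpers for illustration and are not part of the service code.

```python
import asyncio
from typing import Any, List


async def collect_batch(
    queue: asyncio.Queue, batch_size: int, max_wait_time: float
) -> List[Any]:
    """Block for the first item, then gather more until the batch is
    full or one shared deadline has passed."""
    batch = [await queue.get()]
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_time
    while len(batch) < batch_size:
        remaining = deadline - loop.time()
        if remaining <= 0:
            break
        try:
            batch.append(await asyncio.wait_for(queue.get(), timeout=remaining))
        except asyncio.TimeoutError:
            break
    return batch


async def demo() -> List[List[int]]:
    queue: asyncio.Queue = asyncio.Queue()
    for i in range(5):
        queue.put_nowait(i)
    # The first batch fills immediately; the second returns after the deadline
    first = await collect_batch(queue, batch_size=4, max_wait_time=0.05)
    second = await collect_batch(queue, batch_size=4, max_wait_time=0.05)
    return [first, second]


batches = asyncio.run(demo())
print(batches)
```

With a shared deadline, the extra latency a request pays for batching is bounded by `max_wait_time` regardless of `batch_size`. In a notebook that already has a running event loop, `await demo()` can replace `asyncio.run(demo())`.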

In [None]:
!ngrok authtoken


Authtoken saved to configuration file: /root/.config/ngrok/ngrok.yml


In [None]:
!kill -9 9675
