# Deploying Gemma2 9B Model with FastAPI

This tutorial will guide you through deploying the Gemma2 9B model (from Hugging Face) using FastAPI. We will build a FastAPI application to serve the model, handle requests, and return predictions.

---

## 1. Install Required Libraries
We need to install the necessary libraries for our application, including `fastapi`, `uvicorn`, and `transformers`.

**Explanation**
- `fastapi`: Framework for building APIs.
- `uvicorn`: ASGI server to run the FastAPI app.
- `transformers`: For working with the Gemma2 9B model from Hugging Face.
- `torch`: For PyTorch-based model inference.
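- `python-dotenv`: For loading environment variables from a `.env` file.
- `accelerate`: Used by `transformers` for loading large models and placing them on the available devices.
- `websockets`: Client library used by the WebSocket chat client in section 4.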

In [None]:
!source ./fastapi-env/bin/activate && pip install -q transformers torch python-dotenv 'accelerate>=0.26.0' websockets

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import os
import torch
from helper import FastAPIServer
import asyncio
import nest_asyncio
import websockets

## 2. Load the Gemma2 9B Model
We'll use the `transformers` library to load the Gemma2 9B model and its tokenizer.

**Explanation**
* The `AutoModelForCausalLM` class loads the pre-trained Gemma2 model.
* The `AutoTokenizer` handles input text tokenization.
* Use `from_pretrained()` to load the model and tokenizer from Hugging Face.

For secure access to private or restricted Hugging Face models, we will store the authentication token in a `.env` file. Python's `os` module and the `dotenv` library will retrieve the token at runtime.

**Steps to Securely Store the Token**

> You must have a Hugging Face account and a token to access private models. Since user interfaces change frequently, please refer to the Hugging Face documentation for the most up-to-date instructions on how to obtain a token. You will also need to combine the organization name and the model name to form the model ID (for example, `google/gemma-2-9b-it`).

1. Create a .env file in your project directory.
2. Add the Hugging Face token to the .env file.
3. Use the dotenv library to load the token into your application.

**Explanation**
- The `.env` file keeps sensitive information like tokens separate from your codebase.
- `os.getenv()` retrieves the token without hardcoding it in your script.
- `use_auth_token` passes the token to the Hugging Face Hub for authentication. Recent `transformers` releases deprecate this argument in favor of `token`, so you may see a warning.
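
Note that `os.getenv()` returns `None` when a variable is missing, so a forgotten `.env` entry only surfaces later as an authentication error. The sketch below shows one way to fail early instead. It assumes the `.env` file contains a line such as `GEMMA_TOKEN=<your token>`; the helper name `require_env` is hypothetical and not part of the notebook.

```python
import os


def require_env(name: str) -> str:
    """Return the value of environment variable `name`, or raise if it is unset or empty.

    Hypothetical helper: wraps os.getenv() so a missing token fails early
    with a readable message instead of a later authentication error.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set. "
            "Add it to your .env file and call load_dotenv() first."
        )
    return value
```

After `load_dotenv()` has run, the token could then be read with `require_env("GEMMA_TOKEN")`.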

In [None]:
# Check for CUDA, MPS, and fallback to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load .env so os.getenv() can see the token, then set the save directory
load_dotenv()
GEMA_TOKEN = os.getenv("GEMMA_TOKEN")
save_directory = "saved_model"

# Disable tokenizer parallelism to avoid warnings and deadlocks in forked processes
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Try loading the model and tokenizer
try:
    print("Attempting to load the model from the saved directory...")
    loaded_model = AutoModelForCausalLM.from_pretrained(save_directory).to(device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)
    print("Model and tokenizer loaded successfully from the saved directory.")
except Exception as e:
    print(f"Failed to load model from saved directory: {e}")
    print("Downloading model and tokenizer from Hugging Face...")
    
    # Download and save the model and tokenizer
    loaded_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", use_auth_token=GEMA_TOKEN)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        torch_dtype=torch.float32,
        use_auth_token=GEMA_TOKEN
    ).to(device)

    # Save to the directory
    loaded_model.save_pretrained(save_directory)
    loaded_tokenizer.save_pretrained(save_directory)
    print("Model and tokenizer downloaded and saved successfully.")

# Test functionality
sample_text = "Once upon a time"
inputs = loaded_tokenizer(sample_text, return_tensors="pt").to(device)
# Pass the attention mask along with the input IDs and cap the number of new tokens
outputs = loaded_model.generate(**inputs, max_new_tokens=20)
print(loaded_tokenizer.decode(outputs[0], skip_special_tokens=True))

## 3. Gemma2 API: Text Generation with FastAPI & Transformers

This script sets up a text generation API using FastAPI and a pre-trained transformer model. It exposes a simple HTTP interface that generates text from a user prompt and can be called with curl, Postman, or any other HTTP client. Later, we will build a more intuitive and user-friendly interface inside the Jupyter notebook.

#### **1️⃣ Model Initialization**  
- Loads a pre-trained language model (`AutoModelForCausalLM`) and its tokenizer from a saved directory (`saved_model`).  
- Ensures the model is ready for inference as soon as the API starts.  

#### **2️⃣ FastAPI Setup**  
- Initializes FastAPI (`app = FastAPI()`) to provide an interface for text generation.  
- Defines structured request and response models using Pydantic as an example of its usage (another tutorial will cover Pydantic in more detail):  
  - `RequestModel`: Accepts a prompt and an optional `max_length` for response generation.  
  - `ResponseModel`: Returns generated text in a structured format.  

#### **3️⃣ API Endpoints**  
- **`/` (GET)** - Returns a simple status message confirming the API is running.  
- **`/predict` (POST)** - Generates text based on the user's prompt using the AI model.  

#### **4️⃣ Text Generation Process**  
1. Receives a JSON request containing a user prompt and optional `max_length`.  
2. Encodes the input using the tokenizer.  
3. Passes it to the AI model, which generates a continuation.  
4. Decodes and returns the output in JSON format.  

#### **Key Benefits**  
- Fast and lightweight API using FastAPI for quick response times.  
- Easy integration with any frontend or automation system.  
- Customizable response length for better control over AI output.  
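- Endpoints are declared with `async def`; note that `generate()` itself is a blocking call, so a long generation occupies the server until it finishes.  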
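
When choosing `max_length`, keep in mind that in `transformers` this value caps the total sequence length, meaning prompt tokens plus generated tokens. The sketch below only illustrates that arithmetic; the function `new_token_budget` is a hypothetical helper, not part of the API.

```python
def new_token_budget(prompt_tokens: int, max_length: int) -> int:
    """Tokens left for generation when `max_length` caps prompt + output."""
    if prompt_tokens < 0 or max_length < 0:
        raise ValueError("token counts must be non-negative")
    return max(0, max_length - prompt_tokens)


print(new_token_budget(prompt_tokens=4, max_length=50))   # 46
print(new_token_budget(prompt_tokens=60, max_length=50))  # 0
```

If the prompt is already as long as `max_length`, there is no room left for new tokens. `generate()` also accepts `max_new_tokens`, which limits only the generated part.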

#### **Next Steps**  
1. Run the API server:  
   ```sh
   uvicorn main:app --host 127.0.0.1 --port 8080
   ```  
   In this notebook, the same command is wrapped by the `FastAPIServer` helper:  
   ```python
    port = 8080
    fastapi = FastAPIServer(port)
    fastapi.run()
   ```
2. Test in a browser by visiting `http://127.0.0.1:8080/` to check if the API is running.  
3. Send a prediction request using `cURL` or Postman:  
   ```sh
   curl -X 'POST' 'http://127.0.0.1:8080/predict' \
        -H 'Content-Type: application/json' \
        -d '{"prompt": "Hello, AI!", "max_length": 50}'
   ```  
4. Integrate the API into chatbots, content generation tools, or automation workflows.  

This API provides a simple yet powerful way to interact with AI-generated text, making it ideal for NLP applications.
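
The cURL request above can also be sent from Python. The following is a minimal sketch using only the standard library; the function names `build_predict_request` and `send_predict_request` are hypothetical, and the URL and payload fields are taken from the cURL example.

```python
import json
import urllib.request


def build_predict_request(prompt, max_length=50, base_url="http://127.0.0.1:8080"):
    """Build a POST request for the /predict endpoint (hypothetical helper)."""
    payload = json.dumps({"prompt": prompt, "max_length": max_length}).encode("utf-8")
    return urllib.request.Request(
        url=f"{base_url}/predict",
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


def send_predict_request(request, timeout=120):
    """Send the request and return the `generated_text` field of the JSON reply."""
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.loads(response.read().decode("utf-8"))["generated_text"]
```

With the server running, `send_predict_request(build_predict_request("Hello, AI!"))` should return the generated text. The generous timeout is an assumption, since generation on a 9B model can be slow.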

Tip: run `!kill -9 $(lsof -t -i :8080)` to stop the server if it is still running and holding the port.

In [None]:
%%writefile main.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Load the saved model and tokenizer
save_directory = "saved_model"
loaded_model = AutoModelForCausalLM.from_pretrained(save_directory)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)

# Initialize FastAPI app
app = FastAPI()

# Define request and response models
class RequestModel(BaseModel):
    prompt: str
    max_length: int = 50  # Optional: cap on total tokens (prompt + generated)

class ResponseModel(BaseModel):
    generated_text: str

# Root endpoint
@app.get("/")
async def root():
    return {"message": "Gemma2 API is running"}

# Prediction endpoint
@app.post("/predict", response_model=ResponseModel)
async def predict(request: RequestModel):
    inputs = loaded_tokenizer(request.prompt, return_tensors="pt")
    outputs = loaded_model.generate(inputs["input_ids"], max_length=request.max_length)
    generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}

In [None]:
port = 8080
fastapi = FastAPIServer(port)
fastapi.run()

In [None]:
!curl -X POST "http://127.0.0.1:8080/predict" \
-H "Content-Type: application/json" \
-d '{"prompt": "Hello, AI", "max_length": 50}'

## 4. WebSocket-Based AI Chat Server with FastAPI

In this section we will create a WebSocket server with FastAPI to serve the model in real time on the local machine. To follow along, you should have a basic understanding of FastAPI and WebSockets, and a machine with a GPU to run the model.

### Server

This script sets up a **WebSocket server using FastAPI** to handle real-time communication with an AI-powered assistant. Below is a structured breakdown of its functionality:

#### **1️⃣ Logger Configuration**
- Configures **critical-level logging** to reduce unnecessary log output.

#### **2️⃣ AI Model Loading**
- Loads a pre-trained **transformer model** (`AutoModelForCausalLM`) and tokenizer from a saved directory.
- Defines a **system prompt** to guide AI responses.
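
The system prompt is combined with the user's message before generation, and the reply is recovered by splitting the generated text on the `AI:` marker. A small sketch of that round trip follows. The helper names `build_prompt` and `extract_reply` are hypothetical, the system prompt is shortened, and the model output is faked so the example runs without loading a model.

```python
SYSTEM_PROMPT = "You are a helpful AI assistant. Answer concisely and clearly."


def build_prompt(user_message: str, system_prompt: str = SYSTEM_PROMPT) -> str:
    """Assemble the text that is tokenized and passed to the model."""
    return f"{system_prompt}\nUser: {user_message}\nAI:"


def extract_reply(generated_text: str) -> str:
    """Keep only the text after the last 'AI:' marker."""
    return generated_text.split("AI:")[-1].strip()


prompt = build_prompt("What is FastAPI?")
# Causal language models echo the prompt, so the fake output starts with it.
fake_output = prompt + " A Python web framework for building APIs."
print(extract_reply(fake_output))  # A Python web framework for building APIs.
```

One limitation of this approach: if the model's own answer contains `AI:`, only the text after the last occurrence is kept.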

#### **3️⃣ WebSocket Handling (`/chat` endpoint)**
- Accepts **WebSocket connections** and maintains active connections in a list.
- Listens for incoming messages:
  - **Ignores "ping" messages** (heartbeat signals from clients). These pings help keep the connection alive.
  - **Processes user input** using the AI model and system prompt.
  - **Runs model inference asynchronously** with `asyncio.to_thread()` to prevent blocking.
  - **Sends AI responses** back to all connected clients.

#### **4️⃣ Heartbeat Mechanism**
- A separate `model_heartbeats()` task **sends periodic "processing..." messages** every 5 seconds.
- Ensures WebSocket remains open while the AI model generates a response.
- The heartbeat task is **canceled once the response is ready** so that no further status messages are sent.
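
The heartbeat pattern can be tried in isolation, without a model or a WebSocket. In the sketch below, `send` stands in for `websocket.send_text`, `slow_job` stands in for model inference, and the short interval is chosen only so the demo finishes quickly. The names `run_with_heartbeats` and `demo` are hypothetical.

```python
import asyncio
import time


async def heartbeats(send, interval: float) -> None:
    """Send a growing 'processing...' message every `interval` seconds."""
    counter = 0
    while True:
        await asyncio.sleep(interval)
        await send("processing" + "." * ((counter % 30) + 1))
        counter += 1


async def run_with_heartbeats(blocking_fn, send, interval: float = 5.0):
    """Run `blocking_fn` in a worker thread while heartbeats are being sent."""
    task = asyncio.create_task(heartbeats(send, interval))
    try:
        return await asyncio.to_thread(blocking_fn)
    finally:
        task.cancel()  # stop status messages whether or not blocking_fn succeeded


async def demo():
    sent = []

    async def send(message):  # stand-in for websocket.send_text
        sent.append(message)

    def slow_job():  # stand-in for model inference
        time.sleep(0.3)
        return "done"

    result = await run_with_heartbeats(slow_job, send, interval=0.05)
    return result, sent
```

Calling `asyncio.run(demo())` returns the job's result together with the heartbeat messages collected while it ran. Cancelling in a `finally` block means the heartbeats also stop if the job raises.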

#### **5️⃣ Graceful Disconnection Handling**
- **Removes clients from the active connections list** when they disconnect.
- Catches **WebSocket errors** to prevent crashes.
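
Because responses go to every connected client, one broken connection can interrupt the whole broadcast. The sketch below shows one possible way to drop failed connections during a broadcast. It is an assumption about how the handling could be extended, not the server's actual code, and `FakeConnection` only mimics the `send_text()` method of a WebSocket.

```python
import asyncio
from typing import Any, List


async def broadcast(connections: List[Any], message: str) -> int:
    """Send `message` to every connection, drop the ones that fail, return the delivered count."""
    delivered = 0
    for conn in list(connections):  # iterate over a copy so removal is safe
        try:
            await conn.send_text(message)
            delivered += 1
        except Exception:
            if conn in connections:
                connections.remove(conn)
    return delivered


class FakeConnection:
    """Stand-in for a WebSocket; only mimics send_text()."""

    def __init__(self, alive: bool = True):
        self.alive = alive
        self.received = []

    async def send_text(self, message: str) -> None:
        if not self.alive:
            raise ConnectionError("client went away")
        self.received.append(message)
```

With one live and one dead `FakeConnection` in the list, `broadcast` delivers to the live one and removes the dead one.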

This example demonstrates how to **integrate AI models with WebSocket servers** for interactive applications, chatbots, and more. Feel free to customize the AI model, system prompts, and WebSocket handling to suit your use case! 🤖🌐

In [None]:
%%writefile main.py
import asyncio
import logging
from fastapi import FastAPI, WebSocket
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import List

# LOGGER CONFIGURATION
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.CRITICAL,
)
logger = logging.getLogger(__name__)

# MODEL CONFIGURATION
# Load the saved model and tokenizer
save_directory = "saved_model"
loaded_model = AutoModelForCausalLM.from_pretrained(save_directory)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)

# Define system prompt
SYSTEM_PROMPT = """
You are a helpful AI assistant. Answer concisely and clearly. Your name is not Gemma. It is SilverAIWolf.
"""

app = FastAPI()
connections: List[WebSocket] = []

@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Handles WebSocket connections."""
    await websocket.accept()
    connections.append(websocket)
    client_ip = websocket.client.host
    logger.info(f"Client {client_ip} connected. Total connections: {len(connections)}")

    try:
        while True:
            data = await websocket.receive_text()
            logger.debug(data)
            if data == "ping":  # Ignore heartbeat messages
                logger.debug(f"Received heartbeat from {client_ip}")
                continue

            # Start heartbeat task
            model_heartbeat_task = asyncio.create_task(model_heartbeats(websocket))

            # Build the full prompt from the system prompt and user message, then tokenize it
            full_prompt = f"{SYSTEM_PROMPT}\nUser: {data}\nAI:"
            inputs = loaded_tokenizer(full_prompt, return_tensors="pt")

            # Process the model response in a separate thread to avoid blocking the model heartbeats
            try:
                response = await asyncio.to_thread(
                    generate_response, loaded_model, loaded_tokenizer, inputs
                )
            finally:
                # Stop heartbeats whether generation succeeded or failed
                model_heartbeat_task.cancel()

            # Send the response to all connected clients
            for conn in connections:
                await conn.send_text(response)

    except Exception as e:
        logger.error(f"WebSocket error with {client_ip}: {e}")

    finally:
        connections.remove(websocket)
        logger.info(f"Client {client_ip} disconnected. Total connections: {len(connections)}")


async def model_heartbeats(websocket):
    logger.debug("Model heartbeat started")
    counter = 0
    while True:
        await asyncio.sleep(5)
        try:
            await websocket.send_text("processing" + "." * ((counter % 30) + 1))
            counter += 1
        except Exception as e:
            logger.error(f"Heartbeat error: {e}")
            break


def generate_response(model, tokenizer, inputs):
    """Runs the model processing in a separate thread to avoid blocking."""
    outputs = model.generate(inputs["input_ids"], max_length=500)
    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text.split("AI:")[-1].strip()

### Client

### **Interactive WebSocket Chat Client for SilverAIWolf**

This script sets up an **interactive chat client** that connects to the **SilverAIWolf WebSocket server**. It enables real-time communication with the AI assistant while ensuring a smooth user experience.


#### **1️⃣ WebSocket Connection**
- Establishes a **WebSocket connection** with the server at `ws://127.0.0.1:8080/chat` using the `websockets` library.
- The connection remains **open for continuous interaction** with the AI assistant.


#### **2️⃣ Unique Welcome Screen**
- Displays an **ASCII logo** and an **engaging introduction**.
- Highlights key features:
  - **Instant AI-powered responses**.
  - **Simple commands** (`exit` to close the chat).
  - **Immersive chat experience**.


#### **3️⃣ User Input Handling**
- Uses `asyncio.to_thread(input, "You: ")` to **avoid blocking the event loop**, allowing WebSocket communication to run efficiently.
- Ensures **graceful exit** when the user types `"exit"` or `"bye"`.


#### **4️⃣ Real-Time Message Processing**
- Sends the **user's message** to the WebSocket server.
- **Waits for responses** while handling:
  - **"Processing..." messages**: Keeps the chat interactive during AI processing.
  - **Final AI responses**: Displays neatly formatted output.
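
Telling status updates apart from final answers comes down to a check on the message text. A stricter variant than a substring test is sketched below. It assumes heartbeats always have the form `processing` followed by one or more dots, as the server produces them; the name `is_heartbeat` is hypothetical.

```python
import re

# Heartbeats look like "processing." up to "processing" plus 30 dots.
HEARTBEAT_PATTERN = re.compile(r"processing\.+")


def is_heartbeat(message: str) -> bool:
    """True only when the whole message is 'processing' followed by dots."""
    return HEARTBEAT_PATTERN.fullmatch(message) is not None
```

Matching the whole message avoids treating an AI answer that merely contains the word "processing" as a status update.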


#### **5️⃣ User-Friendly Output**
- **"Processing..." updates** overwrite the same console line instead of spamming multiple lines.
- **AI responses** are displayed in a structured format:
  ```
  SilverAIWolf: <AI Response>
  ------------------------------------------------------------
  ```
- **Farewell Message**: Displays `"Goodbye, wanderer! Until next time. 🐺✨"` when exiting.


#### **✅ Key Benefits**
| **Feature** | **Description** |
|------------|----------------|
| ✅ **Seamless WebSocket Communication** | Ensures smooth interaction with the AI assistant. |
| ✅ **Interactive UI** | Custom ASCII branding and clear instructions. |
| ✅ **Efficient Message Handling** | Prevents UI lag using `asyncio.to_thread()`. |
| ✅ **Real-Time Processing Updates** | Shows `"processing..."` instead of making the user wait silently. |
| ✅ **Graceful Disconnection** | Properly exits and displays a friendly goodbye message. |


#### **🚀 Next Steps**
1. **Ensure the WebSocket server is running (`main.py`)**.
2. **Run the script** and test chat interactions.
3. **Try different inputs** to see AI responses in action.
4. **Enjoy chatting with SilverAIWolf!** 🐺✨

---

This tutorial ensures **a smooth, engaging, and efficient** WebSocket chat experience with **SilverAIWolf**! 🚀🔥

In [None]:
import asyncio
import websockets

async def chat():
    uri = "ws://127.0.0.1:8080/chat"
    
    async with websockets.connect(uri) as websocket:
        print("""
░██████╗██╗██╗░░░░░██╗░░░██╗███████╗██████╗░░█████╗░██╗░██╗░░░░░░░██╗░█████╗░██╗░░░░░███████╗
██╔════╝██║██║░░░░░██║░░░██║██╔════╝██╔══██╗██╔══██╗██║░██║░░██╗░░██║██╔══██╗██║░░░░░██╔════╝
╚█████╗░██║██║░░░░░╚██╗░██╔╝█████╗░░██████╔╝███████║██║░╚██╗████╗██╔╝██║░░██║██║░░░░░█████╗░░
░╚═══██╗██║██║░░░░░░╚████╔╝░██╔══╝░░██╔══██╗██╔══██║██║░░████╔═████║░██║░░██║██║░░░░░██╔══╝░░
██████╔╝██║███████╗░░╚██╔╝░░███████╗██║░░██║██║░░██║██║░░╚██╔╝░╚██╔╝░╚█████╔╝███████╗██║░░░░░
╚═════╝░╚═╝╚══════╝░░░╚═╝░░░╚══════╝╚═╝░░╚═╝╚═╝░░╚═╝╚═╝░░░╚═╝░░░╚═╝░░░╚════╝░╚══════╝╚═╝░░░░░

Welcome to SilverAIWolf Chat! 🐺⚡
Your intelligent AI assistant, always ready to help!

🔹 Type your messages below and get instant AI-powered responses.
🔹 Type `bye` or `exit` to close the chat.
🔹 AI is listening... Let's chat! 🚀
        """)

        while True:
            msg = await asyncio.to_thread(input, "\nYou: ")  # Read input in a worker thread so the event loop is not blocked
            if msg.lower() in ["bye", "exit"]:
                print("\nSilverAIWolf: Goodbye, wanderer! Until next time. 🐺✨\n")
                break
            
            await websocket.send(msg)  # Send user input
            
            while True:
                response = await websocket.recv()
                if response.startswith("processing"):  # Heartbeat sent by the server while the model runs
                    print(response, end="\r", flush=True)  # Overwrites console with "processing..."
                else:
                    print(f"\nSilverAIWolf: {response}\n", end="-"*60)  # Prints final response
                    break  # Exit processing loop before taking new input

# Run the chat function asynchronously
await chat()

