# Deploying Gemma2 9B Model with FastAPI

This tutorial walks through serving the instruction-tuned Gemma2 9B model (`google/gemma-2-9b-it` on Hugging Face) with FastAPI. We will build a FastAPI application that loads the model, exposes a REST endpoint for text generation, and adds a WebSocket endpoint for real-time chat.

---

## 1. Install Required Libraries
We need to install the necessary libraries for our application, including `fastapi`, `uvicorn`, and `transformers`.

**Explanation**
- `fastapi`: Framework for building APIs.
- `uvicorn`: ASGI server to run the FastAPI app.
- `transformers`: For working with the Gemma2 9B model from Hugging Face.
- `torch`: For PyTorch-based model inference.
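- `python-dotenv`: For loading environment variables from a `.env` file.
- `accelerate`: Used by `transformers` to load and place large models efficiently (for example, with `device_map`).
- `websockets` and `nest_asyncio`: Used by the chat client in Section 4.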

In [None]:
!source ./fastapi-env/bin/activate && pip install -q fastapi uvicorn transformers torch python-dotenv 'accelerate>=0.26.0' websockets nest_asyncio
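
The `!` line above runs in a subshell, so it is worth confirming that the packages are visible to the Python interpreter the notebook kernel actually uses. A minimal sketch using the standard library's `importlib.metadata`; the package list simply mirrors the libraries listed above:

```python
from importlib import metadata

# Distribution names as pip knows them (python-dotenv is imported as `dotenv`)
REQUIRED_PACKAGES = ["fastapi", "uvicorn", "transformers", "torch", "python-dotenv", "accelerate"]

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions

missing = [name for name, version in installed_versions(REQUIRED_PACKAGES).items() if version is None]
print("Missing packages:", missing or "none")
```

If anything is reported missing, the kernel is probably not running inside `fastapi-env`.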

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from dotenv import load_dotenv
import os
import torch
from helper import FastAPIServer  # Local helper module (not shown); used later to launch the FastAPI server

## 2. Load the Gemma2 9B Model
We'll use the transformers library to load the Gemma2 9B model and its tokenizer.

**Explanation**
* The `AutoModelForCausalLM` class loads the pre-trained Gemma2 model.
* The `AutoTokenizer` handles input text tokenization.
* Use `from_pretrained()` to load the model and tokenizer from Hugging Face.

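Gemma models are gated on Hugging Face: you must accept Google's license terms on the model page before downloading, and every download must be authenticated with your access token. To keep that token out of the code, we store it in a `.env` file and read it at runtime with `python-dotenv` and Python's `os` module.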

**Steps to Securely Store the Token**

> You need a Hugging Face account and an access token to download private or gated models. Because the user interface changes frequently, refer to the Hugging Face documentation for up-to-date instructions on creating a token. You will also need the model's full identifier, which combines the organization name and the model name (here, `google/gemma-2-9b-it`).

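1. Create a `.env` file in your project directory, and keep it out of version control (for example, via `.gitignore`).
2. Add the Hugging Face token to it as `GEMMA_TOKEN=<your-token>`, the variable name the code reads.
3. Call `load_dotenv()` from the `python-dotenv` library to load the token into the environment.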

**Explanation**
- The `.env` file keeps sensitive information like tokens separate from your codebase.
- `os.getenv()` retrieves the token without hardcoding it in your script.
- The `token` argument of `from_pretrained()` passes the token to Hugging Face for authentication. Older `transformers` releases called it `use_auth_token`, which is now deprecated.
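
The steps above boil down to a few lines. This is a minimal sketch, assuming the token lives in `./.env` as `GEMMA_TOKEN=<your-token>`; `get_hf_token` is a hypothetical helper name. It fails fast with a clear message instead of silently passing `None` to `from_pretrained()`:

```python
import os

try:
    from dotenv import load_dotenv  # provided by the python-dotenv package
except ImportError:
    load_dotenv = None  # fall back to variables already set in the process environment

def get_hf_token(var_name="GEMMA_TOKEN"):
    """Return the Hugging Face token, reading ./.env first when python-dotenv is available."""
    if load_dotenv is not None:
        # Loads ./.env (if present) into os.environ; existing variables are not overridden
        load_dotenv(".env")
    token = os.getenv(var_name)
    if not token:
        raise RuntimeError(f"{var_name} is not set; add it to your .env file")
    return token

# Example (assumes .env contains GEMMA_TOKEN=...):
# GEMMA_TOKEN = get_hf_token()
```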

In [None]:
# Check for CUDA, MPS, and fallback to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# Load variables from .env into the environment, then read the Hugging Face token
load_dotenv()
GEMMA_TOKEN = os.getenv("GEMMA_TOKEN")
save_directory = "saved_model"

# Disable tokenizer parallelism to avoid fork-related warnings and deadlocks
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Try loading the model and tokenizer from the local copy first
try:
    print("Attempting to load the model from the saved directory...")
    loaded_model = AutoModelForCausalLM.from_pretrained(save_directory).to(device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)
    print("Model and tokenizer loaded successfully from the saved directory.")
except Exception as e:
    print(f"Failed to load model from saved directory: {e}")
    print("Downloading model and tokenizer from Hugging Face...")

    # Download the gated model (requires an accepted license and a valid token)
    loaded_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", token=GEMMA_TOKEN)
    loaded_model = AutoModelForCausalLM.from_pretrained(
        "google/gemma-2-9b-it",
        torch_dtype=torch.float32,  # Full precision: about 4 bytes per parameter
        token=GEMMA_TOKEN,
    ).to(device)

    # Save a local copy so later runs (and main.py) can load without downloading
    loaded_model.save_pretrained(save_directory)
    loaded_tokenizer.save_pretrained(save_directory)
    print("Model and tokenizer downloaded and saved successfully.")

# Quick smoke test: pass input IDs and attention mask, and cap the generated tokens
sample_text = "Once upon a time"
inputs = loaded_tokenizer(sample_text, return_tensors="pt").to(device)
outputs = loaded_model.generate(**inputs, max_new_tokens=50)
print(loaded_tokenizer.decode(outputs[0], skip_special_tokens=True))
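
The cell above loads the weights in `torch.float32`. As a back-of-envelope check (weights only, ignoring activations and the KV cache), a model with about 9 billion parameters needs roughly 4 bytes per parameter in float32 and 2 in bfloat16. The parameter count here is a round number, not the exact Gemma2 9B figure:

```python
# Approximate storage size per parameter for common dtypes
BYTES_PER_PARAM = {"float32": 4, "bfloat16": 2, "float16": 2, "int8": 1}

def weight_memory_gb(num_params, dtype):
    """Rough size of the model weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e9

for dtype in ("float32", "bfloat16"):
    print(f"{dtype}: ~{weight_memory_gb(9e9, dtype):.0f} GB")
```

On hardware that supports it, passing `torch_dtype=torch.bfloat16` to `from_pretrained()` roughly halves the footprint. Also note that `main.py` loads its own copy of the model from `saved_model` when the server starts, so expect memory for two copies while both the notebook and the server are running.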

## 3. Create a FastAPI Application
Set up a basic FastAPI application to define endpoints for model inference.

**Explanation**
- Define a root endpoint to verify the API is running.
- Create a `POST` endpoint, `/predict`, for model inference.
- Use `pydantic` for request/response models.

In [None]:
%%writefile main.py
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the saved model and tokenizer
save_directory = "saved_model"
loaded_model = AutoModelForCausalLM.from_pretrained(save_directory)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)

# Initialize FastAPI app
app = FastAPI()

# Define request and response models
class RequestModel(BaseModel):
    prompt: str
    max_length: int = 50  # Optional: maximum number of new tokens to generate

class ResponseModel(BaseModel):
    generated_text: str

# Root endpoint
@app.get("/")
async def root():
    return {"message": "Gemma2 API is running"}

# Prediction endpoint. Declared with plain `def` so FastAPI runs the blocking
# generate() call in a worker thread instead of stalling the event loop.
@app.post("/predict", response_model=ResponseModel)
def predict(request: RequestModel):
    inputs = loaded_tokenizer(request.prompt, return_tensors="pt")
    # Pass the attention mask along with the input IDs; limit only the generated tokens
    outputs = loaded_model.generate(**inputs, max_new_tokens=request.max_length)
    generated_text = loaded_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"generated_text": generated_text}
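
`google/gemma-2-9b-it` is an instruction-tuned chat model, and raw prompt strings bypass the chat format it was trained on. A hedged alternative is the tokenizer's `apply_chat_template` method. As far as we know, Gemma's chat template has no separate `system` role, so this sketch folds any system prompt into the first user message; `build_gemma_messages` and the example prompt are hypothetical:

```python
SYSTEM_PROMPT = "You are a helpful AI assistant. Answer concisely and clearly."

def build_gemma_messages(system_prompt, user_text):
    """Return a chat history with the system prompt folded into the first user turn."""
    instructions = system_prompt.strip()
    content = f"{instructions}\n\n{user_text}" if instructions else user_text
    return [{"role": "user", "content": content}]

messages = build_gemma_messages(SYSTEM_PROMPT, "What is FastAPI?")

# With the tokenizer and model loaded as in main.py, the messages could be used like this:
# input_ids = loaded_tokenizer.apply_chat_template(
#     messages, add_generation_prompt=True, return_tensors="pt"
# )
# outputs = loaded_model.generate(input_ids, max_new_tokens=200)
# reply = loaded_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
```

Decoding only the tokens after `input_ids.shape[-1]` returns the model's reply without echoing the prompt.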

In [None]:
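# Free port 8081 if a server from a previous run is still listening (no-op otherwise)
!kill -9 $(lsof -t -i :8081) 2>/dev/null || true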

In [None]:
# Launch the FastAPI app with the local helper; `port` is reused by the chat client later
port = 8081
server = FastAPIServer(port)
server.run()
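
Assuming the `FastAPIServer` helper starts the server in the background (its code is not shown here), a request sent right away can fail while `main.py` is still loading the model. With uvicorn, the listening socket typically opens only after the app module has been imported, so polling the port is one way to wait. `wait_for_port` is a hypothetical helper, and the five-minute default timeout is an arbitrary choice:

```python
import socket
import time

def wait_for_port(host, port, timeout=300.0, interval=1.0):
    """Poll until something accepts TCP connections on host:port, or give up after `timeout` seconds."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(interval)
            try:
                if sock.connect_ex((host, port)) == 0:
                    return True  # the server accepted the connection
            except OSError:
                pass  # treat timeouts and other socket errors as "not ready yet"
        time.sleep(interval)
    return False

# Example: block until the API is up before sending requests
# wait_for_port("127.0.0.1", 8081)
```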

In [None]:
!curl -X POST "http://127.0.0.1:8081/predict" \
-H "Content-Type: application/json" \
-d '{"prompt": "Hello", "max_length": 100}'
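
The same request can be sent from Python with only the standard library. This sketch assumes the server is reachable at `http://127.0.0.1:8081`, as in the `curl` call; `build_predict_request` and `predict` are hypothetical helper names, and the JSON body mirrors `RequestModel` in `main.py`:

```python
import json
import urllib.request

def build_predict_request(prompt, max_length=50, url="http://127.0.0.1:8081/predict"):
    """Build a POST request whose JSON body matches RequestModel in main.py."""
    body = json.dumps({"prompt": prompt, "max_length": max_length}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )

def predict(prompt, max_length=50, timeout=300):
    """Send the request and return the `generated_text` field of the JSON response."""
    with urllib.request.urlopen(build_predict_request(prompt, max_length), timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))["generated_text"]

# Example (requires the server above to be running):
# print(predict("Hello", max_length=100))
```

The generous timeout reflects that generation on a 9B model, especially on CPU, can take a long time.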

## 4. Real-time Model Serving

### Server

In [None]:
%%writefile main.py
import asyncio
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("websocket")

# Load the saved model and tokenizer
save_directory = "saved_model"
loaded_model = AutoModelForCausalLM.from_pretrained(save_directory)
loaded_tokenizer = AutoTokenizer.from_pretrained(save_directory)

# Define system prompt
SYSTEM_PROMPT = """
You are a helpful AI assistant. Answer concisely and clearly. Your name is not Gemma. It is SilverAIWolf.
"""

# Create an instance of FastAPI
app = FastAPI()

# Add CORS middleware for frontend-backend communication
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Change to specific domains in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ConnectionManager:
    def __init__(self):
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket):
        await websocket.accept()
        self.active_connections.append(websocket)
        logger.info(f"New connection established. Active connections: {len(self.active_connections)}")

    def disconnect(self, websocket: WebSocket):
        if websocket in self.active_connections:
            self.active_connections.remove(websocket)
            logger.info(f"Connection removed. Active connections: {len(self.active_connections)}")

    async def broadcast(self, message: str):
        # logger.info(f"Broadcasting message: {message}")
        # Iterate over a copy: disconnect() may remove entries from the list during the loop
        for connection in list(self.active_connections):
            try:
                await connection.send_text(message)
            except Exception as e:
                logger.error(f"Error broadcasting to connection: {e}")
                self.disconnect(connection)

# Create an instance of the ConnectionManager
manager = ConnectionManager()

@app.websocket("/ws/chat")
async def websocket_chat_endpoint(websocket: WebSocket):
    await manager.connect(websocket)
    try:
        while True:
            try:
                # Close idle connections; the client is expected to ping more often than this
                data = await asyncio.wait_for(websocket.receive_text(), timeout=60)
            except asyncio.TimeoutError:
                logger.info("Connection timed out after 60 seconds without messages.")
                await websocket.close()
                break

            if data == "ping":
                # Respond to ping requests to keep the connection alive
                logger.info("Server received the ping")
                await websocket.send_text("pong")
                continue

            full_prompt = f"{SYSTEM_PROMPT}\nUser: {data}\nAI:"
            inputs = loaded_tokenizer(full_prompt, return_tensors="pt")

            # Generate the entire reply at once (no token streaming)
            outputs = loaded_model.generate(**inputs, max_new_tokens=500)
            # Decode only the newly generated tokens and drop any invented follow-up turn
            new_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
            response = loaded_tokenizer.decode(new_tokens, skip_special_tokens=True)
            response = response.split("\nUser:")[0].strip()
            await manager.broadcast(response)

    except WebSocketDisconnect:
        manager.disconnect(websocket)
        await manager.broadcast("A user has left the chat.")

    except Exception as e:
        logger.error(f"Error in WebSocket handler: {e}")

    finally:
        # Always forget this connection, whatever ended the loop
        manager.disconnect(websocket)
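
`loaded_model.generate(...)` is a blocking call inside an `async` handler, so while one reply is being generated the event loop cannot answer pings or serve other connections. One common remedy is `asyncio.to_thread` (Python 3.9+). This sketch demonstrates the effect with a hypothetical `blocking_generate` stand-in rather than the real model: a heartbeat task keeps ticking while the blocking call runs in a worker thread.

```python
import asyncio
import time

def blocking_generate(prompt):
    """Hypothetical stand-in for a blocking loaded_model.generate(...) call."""
    time.sleep(0.5)
    return f"reply to {prompt!r}"

async def heartbeat(ticks, interval=0.1):
    """Records a timestamp each time the event loop gets to run this task."""
    while True:
        await asyncio.sleep(interval)
        ticks.append(time.monotonic())

async def demo():
    ticks = []
    beat = asyncio.create_task(heartbeat(ticks))
    # The blocking call runs in a worker thread, so the loop keeps serving other tasks
    reply = await asyncio.to_thread(blocking_generate, "hello")
    beat.cancel()
    return reply, len(ticks)

reply, beats = asyncio.run(demo())
print(reply, "| heartbeats during generation:", beats)
```

Inside Jupyter, where an event loop is already running, use `await demo()` instead of `asyncio.run(demo())`. In the WebSocket handler, the corresponding change would be along the lines of `outputs = await asyncio.to_thread(loaded_model.generate, **inputs, max_new_tokens=500)`; whether several concurrent generations fit in memory depends on your hardware, so a lock or a request queue may be needed.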

### Client

In [None]:
import asyncio
import nest_asyncio
import websockets

nest_asyncio.apply()

PING_INTERVAL = 10  # Send a ping every 10 seconds, well under the server's 60-second idle timeout

async def send_ping(websocket):
    """Periodically send a ping message to keep the connection alive."""
    while True:
        try:
            await asyncio.sleep(PING_INTERVAL)
            print("[Client] Sending ping...")
            await websocket.send("ping")
            print("ping sent!")
        except (websockets.exceptions.ConnectionClosed, asyncio.CancelledError):
            print("[Client] Stopping ping task: Connection closed.")
            break

async def chat_client(port):
    """WebSocket chat client with ping handling and reconnection."""
    uri = f"ws://localhost:{port}/ws/chat"

    while True:
        try:
            async with websockets.connect(uri) as websocket:
                print("Connected to the chat server. Type 'exit' to quit.")

                # Start sending pings in the background; cancel them when this connection ends
                ping_task = asyncio.create_task(send_ping(websocket))
                try:
                    while True:
                        # input() blocks the event loop, so pings only go out while this
                        # coroutine is awaiting network I/O; idling at the prompt for more
                        # than 60 seconds lets the server time the connection out.
                        message = input("You: ")

                        if message.lower() == "exit":
                            print("Exiting chat...")
                            return

                        await websocket.send(message)

                        # Skip pongs answering earlier pings until the model's reply arrives
                        response = await websocket.recv()
                        while response == "pong":
                            print("[Client] Received pong from server (connection is alive)")
                            response = await websocket.recv()

                        print(f"SilverAIWolf: {response}")
                finally:
                    ping_task.cancel()

        except (websockets.exceptions.ConnectionClosed, ConnectionRefusedError):
            print("[Client] Connection lost. Reconnecting in 5 seconds...")
            await asyncio.sleep(5)  # Wait before retrying

# Run the chat client (uses `port` from the server cell)
asyncio.get_event_loop().run_until_complete(chat_client(port))
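
The client retries every 5 seconds indefinitely. A common alternative, not used in the original client, is capped exponential backoff, which waits longer after each consecutive failure so a server that is down or still loading the model is not hammered with reconnects. A minimal sketch; `backoff_delays` is a hypothetical helper and the default values are illustrative assumptions:

```python
def backoff_delays(base=1.0, factor=2.0, cap=30.0, attempts=6):
    """Capped exponential backoff: base, base*factor, base*factor**2, ..., never exceeding cap."""
    return [min(cap, base * factor ** i) for i in range(attempts)]

print(backoff_delays())
```

In `chat_client`, this would mean keeping a failure counter, sleeping `delays[min(failures, len(delays) - 1)]` instead of a fixed 5 seconds, and resetting the counter after each successful connection.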