## Dockerfile

```docker
FROM pytorch/pytorch:2.1.0-cuda11.8-cudnn8-runtime

# Flush Python stdout/stderr immediately so container logs appear without delay
ENV PYTHONUNBUFFERED=1

WORKDIR /app

# Install dependencies first, from requirements.txt alone, so this layer stays
# cached when only the application code (helper.py, inference.py, ...) changes
COPY requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir --upgrade pip && \
    pip install --no-cache-dir -r requirements.txt

# Copy the rest of the project files (helper.py, inference.py, serve, ...)
COPY . /app

# Copy the serve script into a global path SageMaker can call
COPY serve /usr/bin/serve
RUN chmod +x /usr/bin/serve

# SageMaker expects inference port 8080
EXPOSE 8080

# Set the ENTRYPOINT for SageMaker (this enables `serve`)
ENTRYPOINT ["/usr/bin/serve"]
```

## Python Scripts

`helper.py`

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from typing import List, Dict

class GemmaChat:
    """Chat wrapper around a Hugging Face causal-LM checkpoint (e.g. Gemma 3).

    The conversation, starting with the system prompt, is kept in
    ``self.messages`` and re-rendered through the tokenizer's chat template on
    every call to ``generate_answer``.
    """

    def __init__(self, model_path: str, token: str, system_prompt: str):
        """
        Initialize the Gemma model, tokenizer, and system prompt.

        Args:
            model_path (str): Hugging Face model repository path.
            token (str): Hugging Face access token (None is accepted for public models).
            system_prompt (str): The system prompt used for instruction tuning.
        """
        # `token` replaces the deprecated `use_auth_token` argument.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
            token=token,
        )
        self.system_prompt = system_prompt
        self.messages: List[Dict[str, str]] = [
            {"role": "system", "content": self.system_prompt}
        ]

    def append_message(self, role: str, content: str) -> List[Dict[str, str]]:
        """
        Append a message to the ongoing conversation.

        Args:
            role (str): The role of the message sender, e.g., "user" or "assistant".
            content (str): The content of the message.

        Returns:
            List[Dict[str, str]]: The updated list of messages.
        """
        self.messages.append({"role": role, "content": content})
        return self.messages

    def generate_answer(self, prompt: str, max_tokens: int = 1024) -> str:
        """
        Generate an answer using chat-style prompting with Gemma-3 sampling settings.

        The user prompt and the model's reply are added to ``self.messages`` only
        after generation succeeds. A failed call (e.g. out of memory) therefore
        leaves the history unchanged, rather than leaving an unanswered user turn
        that would break role alternation on the next call.

        Args:
            prompt (str): The user's input prompt.
            max_tokens (int, optional): Maximum number of tokens to generate. Defaults to 1024.

        Returns:
            str: The decoded model response (newly generated text only, without the prompt).

        Raises:
            Any exception raised by the tokenizer or ``model.generate`` propagates.
        """
        user_message = {"role": "user", "content": prompt}
        chat_text = self.tokenizer.apply_chat_template(
            self.messages + [user_message],
            add_generation_prompt=True,
            tokenize=False,
        )
        # Gemma's rendered chat template already begins with the BOS token;
        # letting the tokenizer add special tokens would prepend a second BOS.
        inputs = self.tokenizer(
            chat_text, return_tensors="pt", add_special_tokens=False
        ).to(self.model.device)

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            # Without do_sample=True, generate() decodes greedily and ignores
            # temperature / top_p / top_k.
            do_sample=True,
            temperature=1.0,
            top_p=0.95,
            top_k=64,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # generate() returns prompt + completion; keep only the new tokens so the
        # reply (and the stored assistant turn) does not repeat the whole prompt.
        prompt_length = inputs["input_ids"].shape[-1]
        decoded = self.tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        self.append_message("user", prompt)
        self.append_message("assistant", decoded)
        return decoded
```

`inference.py`

```python
from fastapi import FastAPI, Request
from pydantic import BaseModel
from helper import GemmaChat
import uvicorn
import os
import threading

app = FastAPI()

# Load configuration from the container environment (e.g. the SageMaker Model's
# `env`). HF_TOKEN defaults to None so public models load without credentials;
# a placeholder string would be sent as a (bogus) credential. Gated models such
# as Gemma still need a real token.
model_path = os.getenv("MODEL_PATH", "google/gemma-3-4b-it")
hf_token = os.getenv("HF_TOKEN")
system_prompt = os.getenv("SYSTEM_PROMPT", "You are a helpful assistant that solves math problems directly by stating the final answer.")

# Loaded once at import time, before uvicorn starts serving requests.
chat = GemmaChat(model_path=model_path, token=hf_token, system_prompt=system_prompt)

# Serializes use of the shared model and of `chat.messages`, which
# GemmaChat.generate_answer mutates; requests now run in worker threads.
_generate_lock = threading.Lock()


class Prompt(BaseModel):
    """Request body for /invocations."""

    prompt: str


@app.get("/ping")
def ping():
    """SageMaker health check endpoint."""
    return {"status": "ok"}


@app.post("/invocations")
def invoke(prompt: Prompt):
    """Generate a single-turn answer for ``prompt.prompt``.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool. The blocking ``generate`` call then no longer stalls the event
    loop, and /ping stays responsive during generation.
    """
    with _generate_lock:
        # Every request is independent. Reset the history left by previous
        # callers so prompts don't leak between users and the context doesn't
        # grow without bound.
        chat.messages = [{"role": "system", "content": chat.system_prompt}]
        response = chat.generate_answer(prompt.prompt)
    return {"response": response}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8080)
```

## Serve

`serve` script

```bash
#!/bin/bash
# Container entrypoint. SageMaker launches the image with a `serve` argument,
# which this script does not need.
set -euo pipefail

# inference.py lives in /app (the image WORKDIR); be explicit so `inference:app` resolves.
cd /app

echo "Starting inference server with uvicorn..."
# `exec` replaces this shell so uvicorn receives SIGTERM directly on shutdown.
# A single worker: each additional worker would load its own copy of the model.
exec uvicorn inference:app --host 0.0.0.0 --port 8080 --workers 1
```

## Requirements

`requirements.txt`

```text
transformers==4.50.2
datasets
tqdm
accelerate>=0.26.0
torch
fastapi
uvicorn
```

## Deploy the Endpoint on SageMaker

```python
import os

import boto3
import sagemaker
from sagemaker.model import Model

role = "arn:aws:iam::<account-id>:role/<sagemaker-execution-role>"
region = "us-west-2"
image_uri = "<account-id>.dkr.ecr.us-west-2.amazonaws.com/sagemaker-gemma-inference:latest"

sagemaker_session = sagemaker.Session(boto3.Session(region_name=region))

# inference.py reads MODEL_PATH and HF_TOKEN from the container environment.
# Gemma weights are gated on Hugging Face, so the container needs a valid token.
# Read it from your local environment (KeyError if unset) instead of hard-coding it.
# NOTE: `env` values are stored in plain text in the SageMaker model definition;
# consider a secrets store for production.
container_env = {
    "MODEL_PATH": "google/gemma-3-4b-it",
    "HF_TOKEN": os.environ["HF_TOKEN"],
}

model = Model(
    image_uri=image_uri,
    role=role,
    env=container_env,
    sagemaker_session=sagemaker_session,
)

predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",  # or g5.2xlarge for testing
    # The container downloads the model weights at startup, which can take
    # several minutes; give it time before the health check fails the deployment.
    container_startup_health_check_timeout=900,
)
```

Alternatively, if you are running inside a SageMaker notebook, use the following version, which picks up the notebook's execution role automatically via `get_execution_role()`.

```python
import os

import boto3
import sagemaker
from sagemaker import get_execution_role
from sagemaker.model import Model

# Automatically retrieve the execution role from the current environment
role = get_execution_role()

region = "us-west-2"
image_uri = "<account-id>.dkr.ecr.us-west-2.amazonaws.com/sagemaker-gemma-inference:latest"

# Create a SageMaker session
sagemaker_session = sagemaker.Session(boto3.Session(region_name=region))

# inference.py reads MODEL_PATH and HF_TOKEN from the container environment.
# Gemma weights are gated on Hugging Face, so the container needs a valid token.
# Read it from the notebook environment (KeyError if unset) instead of hard-coding it.
# NOTE: `env` values are stored in plain text in the SageMaker model definition;
# consider a secrets store for production.
container_env = {
    "MODEL_PATH": "google/gemma-3-4b-it",
    "HF_TOKEN": os.environ["HF_TOKEN"],
}

# Define the model
model = Model(
    image_uri=image_uri,
    role=role,
    env=container_env,
    sagemaker_session=sagemaker_session,
)

# Deploy the model
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.g5.12xlarge",  # or use "ml.g5.2xlarge" for testing
    # The container downloads the model weights at startup, which can take
    # several minutes; give it time before the health check fails the deployment.
    container_startup_health_check_timeout=900,
)
```

## Lambda

```python
import json
import boto3
import traceback

_DEFAULT_SAGEMAKER_REGION = "us-west-2"
_runtime_client = None


def _get_runtime_client():
    """Return a SageMaker runtime client, created on first use and then reused.

    Caching the client at module level lets warm Lambda invocations reuse it
    instead of constructing a new one per request. The region defaults to
    us-west-2 and can be overridden with the SAGEMAKER_REGION environment variable.
    """
    global _runtime_client
    if _runtime_client is None:
        import os

        region = os.environ.get("SAGEMAKER_REGION", _DEFAULT_SAGEMAKER_REGION)
        _runtime_client = boto3.client("sagemaker-runtime", region_name=region)
    return _runtime_client


def _parse_body(event):
    """Extract the request payload from a Lambda event.

    Handles:
      * proxy events whose ``body`` is a JSON string (optionally base64-encoded,
        as flagged by ``isBase64Encoded``);
      * events whose ``body`` is already a dict;
      * direct invocations with no ``body`` key, where ``prompt`` and
        ``endpoint_name`` sit at the top level of the event.

    Returns:
        dict | None: The payload, ``{}`` for an empty/null body, or None when the
        body is not valid JSON or does not decode to a JSON object.
    """
    if "body" not in event:
        return event

    body = event.get("body")
    if body is None:
        return {}

    if isinstance(body, str):
        try:
            if event.get("isBase64Encoded"):
                import base64

                body = base64.b64decode(body).decode("utf-8")
            body = json.loads(body) if body else {}
        except ValueError:  # covers JSONDecodeError, bad base64, bad UTF-8
            return None

    return body if isinstance(body, dict) else None


def _json_response(status_code, headers, payload):
    """Build a proxy-integration response with a JSON-encoded body."""
    return {
        "statusCode": status_code,
        "headers": headers,
        "body": json.dumps(payload),
    }


def lambda_handler(event, context):
    """
    Lambda function to invoke a SageMaker endpoint with a given prompt.
    Expects 'prompt' and 'endpoint_name' in the request body (or at the top
    level of the event for direct invocations).

    Returns:
        JSON response containing model output or error:
          * 200 with {"response": ...} on success (or for a CORS preflight),
          * 400 when the body is malformed or a required key is missing,
          * 500 with error details when the endpoint call fails.
    """
    # Default CORS headers
    headers = {
        "Access-Control-Allow-Origin": "*",
        "Access-Control-Allow-Methods": "POST, OPTIONS",
        "Access-Control-Allow-Headers": "Content-Type"
    }

    try:
        # Handle preflight CORS (REST API events use `httpMethod`; HTTP API v2
        # events put the method under requestContext.http.method).
        method = event.get("httpMethod") or (
            (event.get("requestContext") or {}).get("http", {}).get("method")
        )
        if method == "OPTIONS":
            return _json_response(200, headers, {"message": "CORS preflight check passed"})

        body = _parse_body(event)
        if body is None:
            return _json_response(400, headers, {"error": "Request body must be a JSON object"})

        prompt = body.get("prompt")
        endpoint_name = body.get("endpoint_name")

        if not prompt or not endpoint_name:
            return _json_response(400, headers, {"error": "Missing 'prompt' or 'endpoint_name'"})

        # Invoke SageMaker
        response = _get_runtime_client().invoke_endpoint(
            EndpointName=endpoint_name,
            ContentType="application/json",
            Body=json.dumps({"prompt": prompt})
        )

        result = json.loads(response["Body"].read().decode())

        return _json_response(200, headers, {"response": result.get("response")})

    except Exception as e:
        print("Error:", str(e))
        traceback.print_exc()

        return _json_response(500, headers, {"error": "Internal server error", "details": str(e)})
```

Test event payload for the Lambda console (the request body is a JSON-encoded string under `body`):

```json
{
    "body": "{\"prompt\": \"What is the derivative of x^2?\", \"endpoint_name\": \"sagemaker-888577059780-model-serving-ge-2025-04-02-22-07-37-220\"}",
    "httpMethod": "POST"
}
```

Request body for an API test (e.g. from the API Gateway console):

```json
{
    "prompt": "What is the derivative of x^2?",
    "endpoint_name": "sagemaker-888577059780-model-serving-ge-2025-04-02-22-07-37-220"
}
```

Inline IAM policy for the Lambda execution role (allows invoking only this endpoint):

```json
{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Effect": "Allow",
      "Action": "sagemaker:InvokeEndpoint",
      "Resource": "arn:aws:sagemaker:us-west-2:888577059780:endpoint/sagemaker-888577059780-model-serving-ge-2025-04-02-22-07-37-220"
    }
  ]
}
```