In [None]:
import requests
import json
import warnings
from urllib3.exceptions import InsecureRequestWarning

# Requests in this notebook are sent with certificate verification turned off,
# which would otherwise emit an InsecureRequestWarning on every call; silence
# that one warning category only.
warnings.simplefilter("ignore", category=InsecureRequestWarning)

In [None]:
# URL that chat_completion() posts to. The two adjacent literals are joined by
# the parser into a single string.
API_ENDPOINT = (
    "<inference_endpoint>"    # placeholder: replace with the service's base URL
    "/v1/chat/completions"    # chat-completions route
)

In [None]:
def chat_completion(messages,
                    model="google/gemma-2-2b-instruct",
                    top_p=1,
                    n=1,
                    max_tokens=50,
                    stream=False,
                    frequency_penalty=0.0,
                    stop=None,
                    timeout=30,
                    verify=False):
    """
    Calls the /v1/chat/completions endpoint on the NIM container.

    Sends one POST request to API_ENDPOINT and returns the decoded JSON body.

    Args:
        messages: Conversation for a 'chat' style model, e.g. a list of
            {"role": ..., "content": ...} dicts.
        model: Value sent as "model" in the request payload.
        top_p: Value sent as "top_p".
        n: Value sent as "n".
        max_tokens: Value sent as "max_tokens".
        stream: Value sent as "stream". This helper always decodes the reply
            as one JSON document, so a reply that is not a single JSON
            document ends up on the error path (returns None).
        frequency_penalty: Value sent as "frequency_penalty".
        stop: Optional stop value; the "stop" key is only added to the payload
            when this is not None.
        timeout: Timeout in seconds handed to requests.post.
        verify: TLS verification setting handed to requests.post. Defaults to
            False (certificate checks disabled) to match the notebook's
            existing behaviour; pass True or a CA bundle path to enable them.

    Returns:
        The parsed JSON response, or None if the request failed, the server
        answered with a non-2xx status, or the body was not valid JSON. In all
        failure cases the error is printed.
    """
    payload = {
        "model": model,
        "messages": messages,
        "top_p": top_p,
        "n": n,
        "max_tokens": max_tokens,
        "stream": stream,
        "frequency_penalty": frequency_penalty
    }

    # "stop" is optional: leave it out entirely unless the caller supplied one.
    if stop is not None:
        payload["stop"] = stop

    try:
        # json= lets requests serialise the payload and set the JSON
        # Content-Type header itself.
        response = requests.post(
            API_ENDPOINT,
            json=payload,
            timeout=timeout,
            verify=verify
        )
        # Turn non-2xx responses into an exception handled below.
        response.raise_for_status()
        return response.json()
    except (requests.exceptions.RequestException, ValueError) as err:
        # ValueError covers JSON decoding failures that older requests
        # releases do not wrap in a RequestException.
        print(f"Request error: {err}")
        return None

In [None]:
if __name__ == "__main__":
    # Demo: a short three-turn conversation whose final user turn asks for a poem.
    messages = [
        {"role": speaker, "content": text}
        for speaker, text in (
            ("user", "Hello! How are you?"),
            ("assistant", "Hi! I'm well, how can I help you today?"),
            ("user", "Can you write me a short poem?"),
        )
    ]

    # max_tokens caps the reply at 30 tokens; stream stays off so the reply
    # comes back as one JSON document; frequency_penalty is raised to 1.0
    # (the function's default is 0.0).
    result = chat_completion(messages, max_tokens=30, stream=False, frequency_penalty=1.0)

    # result is None when the request failed (the error has already been printed).
    print("Inference result:", result)