# GPT-OSS-20B on Google Colab for Chartelier

This notebook runs Chartelier end to end against GPT-OSS-20B, served locally by vLLM on a Google Colab A100 GPU.

## Prerequisites
- Google Colab Pro+ account (for A100 access)
- GPU runtime enabled (Runtime -> Change runtime type -> A100 GPU)
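Before downloading the model, you can confirm that the runtime actually exposes a GPU with enough memory. The sketch below assumes `nvidia-smi` is on the PATH, as it is on Colab GPU runtimes. The 16 GB threshold is an assumption for illustration, not a requirement stated by Chartelier, and vLLM needs extra headroom for its KV cache.

```python
import shutil
import subprocess

# Assumed threshold, not from Chartelier: OpenAI says GPT-OSS-20B runs within ~16 GB of memory
MIN_MEMORY_MIB = 16 * 1024


def parse_gpus(csv_text: str) -> list[tuple[str, int]]:
    """Parse lines like 'NVIDIA A100-SXM4-40GB, 40960' into (name, MiB) pairs."""
    gpus = []
    for line in csv_text.strip().splitlines():
        name, mem_mib = line.rsplit(",", 1)
        gpus.append((name.strip(), int(mem_mib)))
    return gpus


def query_gpus() -> list[tuple[str, int]]:
    """Return (name, total memory in MiB) per visible GPU, or [] if nvidia-smi is unavailable."""
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader,nounits"],
        capture_output=True,
        text=True,
    )
    return parse_gpus(result.stdout) if result.returncode == 0 else []


gpus = query_gpus()
if not gpus:
    print("No GPU detected: change the runtime type to a GPU runtime")
for name, mem_mib in gpus:
    status = "OK" if mem_mib >= MIN_MEMORY_MIB else "may be too small"
    print(f"{name}: {mem_mib} MiB ({status})")
```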

## Step 1: Clone Repository and Setup Environment

In [None]:
# Clone the Chartelier repository
!git clone https://github.com/sog4be/chartelier.git
%cd chartelier

# Switch to the branch that adds Colab support for GPT-OSS-20B
!git checkout feature/gpt-oss-20b-colab-support

In [None]:
# Run the setup script
!python colab/setup_gpt_oss.py
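To confirm the setup left the environment usable, you can check that key packages are importable without actually importing them (importing vLLM is slow). A small sketch; the package list is an assumption about what `setup_gpt_oss.py` installs, so adjust it as needed.

```python
import importlib.util


def missing_packages(names: list[str]) -> list[str]:
    """Return the top-level package names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Assumed dependencies; replace with whatever the setup script actually installs
missing = missing_packages(["vllm", "requests"])
print("Missing packages:", missing or "none")
```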

## Step 2: Start vLLM Server (Run in Background)

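**Important**: This cell launches the vLLM server as a background process, then polls it until it is ready (up to about 5 minutes). The server keeps running after the cell finishes, so continue with the next cells once you see the ready message.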

In [None]:
# Start the vLLM server in the background
# The first run downloads the model (~14GB)
import subprocess
import time

import requests

# Write server output to a log file rather than an unread pipe:
# a full pipe buffer can block the server process
LOG_PATH = "/content/vllm_server.log"
log_file = open(LOG_PATH, "w")  # left open while the server runs
server_process = subprocess.Popen(
    ["python", "colab/start_vllm_server.py"], stdout=log_file, stderr=subprocess.STDOUT
)

print("🚀 Starting vLLM server...")
print("⏳ This may take 2-5 minutes on first run while the model downloads")
print(f"   Server log: {LOG_PATH}")

# Give the server a moment to start
time.sleep(10)

# Poll the health endpoint every 5 seconds, for up to 5 minutes
max_attempts = 60
server_ready = False

for attempt in range(max_attempts):
    exit_code = server_process.poll()
    if exit_code not in (None, 0):
        print(f"\n❌ Server process exited with code {exit_code}")
        break
    try:
        response = requests.get("http://localhost:8000/health", timeout=5)
        if response.status_code == 200:
            print("\n✅ vLLM server is ready!")
            server_ready = True
            break
    except requests.exceptions.RequestException:
        pass  # Server is not accepting connections yet

    if attempt % 6 == 0:  # Print status every 30 seconds
        print(f"⏳ Waiting for server... (~{10 + attempt * 5}s elapsed)")

    time.sleep(5)

if not server_ready:
    print("❌ Server failed to start. Last lines of the server log:")
    with open(LOG_PATH) as f:
        print("".join(f.readlines()[-20:]))
else:
    # Confirm which model the server is serving
    try:
        response = requests.get("http://localhost:8000/v1/models", timeout=5)
        response.raise_for_status()
        models = response.json().get("data", [])
        if models:
            print(f"✅ Model loaded: {models[0]['id']}")
            print("\n🎉 You can now run the test in the next cell!")
    except (requests.exceptions.RequestException, ValueError, KeyError):
        print("⚠️ Could not verify model loading")
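The health-polling loop of the server cell can be factored into a reusable function. A stdlib-only sketch (no `requests`), assuming the same `/health` endpoint; the function name and defaults are illustrative, not part of Chartelier.

```python
import http.client
import time
import urllib.request


def wait_for_server(url: str, timeout_s: float = 300.0, interval_s: float = 5.0) -> bool:
    """Poll `url` until it answers HTTP 200; give up after roughly `timeout_s` seconds."""
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with urllib.request.urlopen(url, timeout=interval_s) as resp:
                if resp.status == 200:
                    return True
        except (OSError, http.client.HTTPException):
            # Connection refused, timeout, or an HTTP error status (HTTPError is an OSError)
            pass
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)
```

For example, `wait_for_server("http://localhost:8000/health")` returns `True` once the server responds, or `False` after about five minutes.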

## Step 3: Test the Setup

First, let's verify the server is working with a simple test:

In [None]:
# Quick test of the vLLM server
import json

import requests

# Test the OpenAI-compatible endpoint
url = "http://localhost:8000/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
    "model": "openai/gpt-oss-20b",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 2+2?"},
    ],
    "temperature": 0.0,
    "max_tokens": 50,
}

try:
    response = requests.post(url, headers=headers, json=data, timeout=30)
    if response.status_code == 200:
        result = response.json()
        choice = result["choices"][0]
        print("✅ Server test successful!")
        # content can be empty or null if the token budget runs out before the final answer
        print(f"Response: {choice['message'].get('content')}")
        print(f"Finish reason: {choice.get('finish_reason')}")
    else:
        print(f"❌ Server returned error: {response.status_code}")
        print(response.text)
except requests.exceptions.RequestException as e:
    print(f"❌ Failed to connect to server: {e}")
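GPT-OSS is a reasoning model, so with a small `max_tokens` the budget may run out before a final answer appears, leaving `content` empty. A small helper that reads an OpenAI-style chat-completions response defensively; it relies only on the standard `choices[0].message.content` and `finish_reason` fields, and the function name is hypothetical.

```python
def summarize_completion(result: dict) -> str:
    """Return the assistant text from a chat-completions response, flagging truncation."""
    choices = result.get("choices") or []
    if not choices:
        return "<no choices in response>"
    choice = choices[0]
    content = (choice.get("message") or {}).get("content") or ""
    if choice.get("finish_reason") == "length":
        # The max_tokens budget ran out; raise max_tokens and retry
        return (content + " [truncated: increase max_tokens]").strip()
    return content


example = {"choices": [{"message": {"role": "assistant", "content": "4"}, "finish_reason": "stop"}]}
print(summarize_completion(example))  # prints: 4
```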

## Step 4: Run Chartelier E2E Test

In [None]:
# Point Chartelier at the local vLLM server via environment variables
import os

# Set environment variables for Chartelier
os.environ["CHARTELIER_LLM_MODEL"] = "openai/gpt-oss-20b"
os.environ["CHARTELIER_LLM_API_BASE"] = "http://localhost:8000/v1"
os.environ["CHARTELIER_LLM_API_KEY"] = "dummy"  # vLLM only checks the key if started with --api-key
os.environ["CHARTELIER_LLM_TIMEOUT"] = "30"

print("Environment configured:")
print(f"  Model: {os.environ['CHARTELIER_LLM_MODEL']}")
print(f"  API Base: {os.environ['CHARTELIER_LLM_API_BASE']}")
print(f"  Timeout: {os.environ['CHARTELIER_LLM_TIMEOUT']}s")
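Chartelier's `LLMSettings` presumably picks these variables up when it is constructed in the next cell. A quick stdlib check can catch a missing or malformed value before the end-to-end test runs. This is a hypothetical helper, not part of Chartelier; the variable names are the ones set above.

```python
import os
from urllib.parse import urlparse

REQUIRED = ["CHARTELIER_LLM_MODEL", "CHARTELIER_LLM_API_BASE", "CHARTELIER_LLM_TIMEOUT"]


def check_llm_env(env=os.environ) -> list[str]:
    """Return a list of problems with the CHARTELIER_LLM_* variables (empty if none)."""
    problems = [f"{name} is not set" for name in REQUIRED if not env.get(name)]
    base = env.get("CHARTELIER_LLM_API_BASE", "")
    if base and urlparse(base).scheme not in ("http", "https"):
        problems.append(f"CHARTELIER_LLM_API_BASE is not an http(s) URL: {base!r}")
    timeout = env.get("CHARTELIER_LLM_TIMEOUT", "")
    if timeout:
        try:
            if float(timeout) <= 0:
                problems.append("CHARTELIER_LLM_TIMEOUT must be positive")
        except ValueError:
            problems.append(f"CHARTELIER_LLM_TIMEOUT is not a number: {timeout!r}")
    return problems


print(check_llm_env() or "Environment looks OK")
```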

In [None]:
# Run the Chartelier test directly in the notebook
import json
import os
import sys
from pathlib import Path

# Check the environment variables (they should already be set by the previous cell)
print("Environment variables:")
print(f"  Model: {os.environ.get('CHARTELIER_LLM_MODEL', 'Not set')}")
print(f"  API Base: {os.environ.get('CHARTELIER_LLM_API_BASE', 'Not set')}")
print(f"  Timeout: {os.environ.get('CHARTELIER_LLM_TIMEOUT', 'Not set')}s")
print()

# Add src to the import path
sys.path.insert(0, "/content/chartelier/src")

from chartelier.interfaces.mcp.handler import MCPHandler
from chartelier.interfaces.mcp.protocol import JSONRPCRequest, MCPMethod
from chartelier.infra.llm_client import LLMSettings


def test_chartelier():
    """End-to-end test of Chartelier."""
    print("=" * 60)
    print("🧪 Chartelier End-to-End Test")
    print("=" * 60)

    # Check the configuration
    settings = LLMSettings()
    print("\n✅ Configuration:")
    print(f"   Model: {settings.model}")
    print(f"   API Base: {settings.api_base}")
    print(f"   Timeout: {settings.timeout}s")

    # Create the MCP handler
    handler = MCPHandler()
    print("✅ MCP handler created")

    # Test data
    csv_data = """month,sales,category
2024-01,1000,Product A
2024-02,1200,Product A
2024-03,1100,Product A
2024-04,1300,Product A
2024-01,800,Product B
2024-02,900,Product B
2024-03,950,Product B
2024-04,1050,Product B"""

    # Build the visualization request
    query = "Show monthly sales trends for Product A and Product B as a line chart"
    request = JSONRPCRequest(
        id=1,
        method=MCPMethod.TOOLS_CALL,
        params={
            "name": "chartelier_visualize",
            "arguments": {
                "data": csv_data,
                "query": query,
                "options": {
                    "format": "svg",
                    "width": 800,
                    "height": 600,
                },
            },
        },
    )

    print("\n✅ Request prepared")
    print(f"📈 Data: {len(csv_data.splitlines()) - 1} rows")
    print(f"📝 Query: '{query}'")
    print("🎨 Format: SVG (800x600)")

    print("\n" + "=" * 60)
    print("⚠️  Ready to send request to LLM")
    print(f"🚀 Using local vLLM server with {settings.model}")
    print("=" * 60)

    # Process the request
    print("\n⏳ Processing visualization request...")
    print("   Phase 1: Data validation")
    print("   Phase 2: Pattern selection (LLM)")
    print("   Phase 3: Chart selection (LLM)")
    print("   Phase 4: Data processing")
    print("   Phase 5: Data mapping (LLM)")
    print("   Phase 6: Chart building")

    try:
        response_str = handler.handle_message(json.dumps(request.model_dump()))
        response = json.loads(response_str)

        # Protocol-level failure: a JSON-RPC error object instead of a result
        if "error" in response:
            print("\n❌ Request failed")
            print(f"   JSON-RPC error: {response['error']}")
            return None

        result = response.get("result", {})

        # Tool-level failure: the result is flagged with isError
        if result.get("isError"):
            print("\n❌ Visualization failed")
            error_msg = result["content"][0]["text"]
            print(f"   Error: {error_msg}")

            if "structuredContent" in result:
                error = result["structuredContent"].get("error", {})
                print(f"   Code: {error.get('code')}")
                if error.get("hint"):
                    print(f"   Hint: {error.get('hint')}")
            return None

        print("\n✅ Visualization successful!")

        # Check the image content
        if result.get("content"):
            content = result["content"][0]
            if content["type"] == "image":
                print("\n📊 Chart generated:")
                print(f"   MIME type: {content.get('mimeType', 'unknown')}")
                print(f"   Data size: {len(content.get('data', ''))} characters")

                # Show processing metadata
                if "structuredContent" in result and "metadata" in result["structuredContent"]:
                    metadata = result["structuredContent"]["metadata"]
                    print("\n📊 Processing metadata:")
                    pattern_id = metadata.get("pattern_id")
                    print(f"   Pattern: {pattern_id} - {get_pattern_description(pattern_id)}")
                    print(f"   Template: {metadata.get('template_id')}")

                    if metadata.get("mapping"):
                        print("   Mapping:")
                        for key, value in metadata["mapping"].items():
                            print(f"      {key}: {value}")

                    # Show processing time
                    if metadata.get("stats", {}).get("duration_ms"):
                        duration = metadata["stats"]["duration_ms"]
                        total = duration.get("total", 0)
                        print(f"   Processing time: {total:.0f}ms")

                # Return the SVG data
                if "svg" in content.get("mimeType", ""):
                    return content["data"]

    except Exception as e:
        print(f"\n❌ Unexpected error: {e}")
        import traceback

        traceback.print_exc()
        return None

    return None


def get_pattern_description(pattern_id):
    """Return a short description for a pattern ID."""
    patterns = {
        "P01": "Single time series",
        "P02": "Category comparison",
        "P03": "Distribution overview",
        "P12": "Multiple time series comparison",
        "P13": "Distribution over time",
        "P21": "Category differences over time",
        "P23": "Distribution comparison by category",
        "P31": "Overall patterns over time",
        "P32": "Distribution comparison across categories",
    }
    return patterns.get(pattern_id, "Unknown pattern")


# Run the test
svg_data = test_chartelier()

# Display the result
if svg_data:
    from IPython.display import SVG, display

    print("\n📊 Displaying generated chart:")
    display(SVG(data=svg_data))

    # Save to file
    output_path = Path("/content/output.svg")
    with open(output_path, "w", encoding="utf-8") as f:
        f.write(svg_data)
    print(f"\n💾 Chart saved to: {output_path}")
    print("   You can download it from the Files panel on the left")
else:
    print("\n❌ Chart generation failed")
    print("\nPlease check:")
    print("1. vLLM server is running (Step 2)")
    print("2. Environment variables are set (Step 4)")
    print("3. No errors in server logs")
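An MCP `tools/call` response can fail in two distinct ways: a JSON-RPC `error` object (the request itself failed) or a tool result flagged with `isError` (the tool ran but could not produce a chart). A compact sketch of that classification as a pure function, which is easier to test than inline prints; the field names are the ones the test cell reads, and the function name is hypothetical.

```python
from typing import Any


def classify_tool_response(response: dict[str, Any]) -> tuple[str, Any]:
    """Classify a decoded JSON-RPC response to a tools/call request.

    Returns ("rpc_error", error), ("tool_error", message), or ("ok", first content item).
    """
    if "error" in response:
        return "rpc_error", response["error"]
    result = response.get("result") or {}
    content = result.get("content") or []
    if result.get("isError"):
        message = content[0].get("text", "") if content else ""
        return "tool_error", message
    return "ok", content[0] if content else None


ok = {"jsonrpc": "2.0", "id": 1, "result": {"content": [{"type": "image", "mimeType": "image/svg+xml", "data": "<svg/>"}]}}
print(classify_tool_response(ok)[0])  # prints: ok
```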

## Step 5: View Generated Chart

If the test succeeded, the chart was saved to `/content/output.svg`. This cell loads it from disk and displays it again:

In [None]:
# Display the SVG chart saved by the Step 4 test
import os

from IPython.display import SVG, display

output_path = "/content/output.svg"
if os.path.exists(output_path):
    print("📊 Generated Chart:")
    display(SVG(filename=output_path))
else:
    print("❌ No output file found. The test may have failed.")
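The MCP specification describes the `data` field of image content as base64-encoded, while the cells here pass it straight to `SVG(data=...)` as markup. This notebook does not state which form Chartelier returns, so a tolerant helper is a reasonable hedge. A sketch; the name `svg_text` is hypothetical.

```python
import base64
import binascii


def svg_text(data: str) -> str:
    """Return SVG markup from `data`, which may be raw SVG or base64-encoded SVG."""
    stripped = data.lstrip()
    if stripped.startswith("<"):
        return data  # Already markup, e.g. "<svg ...>" or "<?xml ...>"
    compact = "".join(stripped.split())  # Drop line breaks some encoders insert
    try:
        decoded = base64.b64decode(compact, validate=True).decode("utf-8")
    except (binascii.Error, UnicodeDecodeError) as exc:
        raise ValueError("data is neither SVG markup nor base64-encoded SVG") from exc
    if not decoded.lstrip().startswith("<"):
        raise ValueError("decoded data does not look like SVG markup")
    return decoded
```

With it, `display(SVG(data=svg_text(svg_data)))` works whichever form the tool returns.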

## Optional: Custom Visualization Test

Try creating your own visualization:

In [None]:
# Custom visualization test
import json
import sys

sys.path.insert(0, "src")

from chartelier.interfaces.mcp.handler import MCPHandler
from chartelier.interfaces.mcp.protocol import JSONRPCRequest, MCPMethod

# Sample data - different from the default test
custom_data = """date,temperature,city
2024-01-01,5,Tokyo
2024-01-02,7,Tokyo
2024-01-03,6,Tokyo
2024-01-04,8,Tokyo
2024-01-01,10,Osaka
2024-01-02,12,Osaka
2024-01-03,11,Osaka
2024-01-04,13,Osaka"""

# Create request
handler = MCPHandler()
request = JSONRPCRequest(
    id=2,
    method=MCPMethod.TOOLS_CALL,
    params={
        "name": "chartelier_visualize",
        "arguments": {
            "data": custom_data,
            "query": "Compare daily temperature trends between Tokyo and Osaka",
            "options": {
                "format": "svg",
                "width": 800,
                "height": 600,
            },
        },
    },
)

print("üé® Creating custom visualization...")
print("Query: 'Compare daily temperature trends between Tokyo and Osaka'")

try:
    response_str = handler.handle_message(json.dumps(request.model_dump()))
    response = json.loads(response_str)

    if response.get("result", {}).get("isError"):
        print("❌ Visualization failed")
        print(response["result"]["content"][0]["text"])
    else:
        print("✅ Visualization successful!")

        # Save and display the result
        if "content" in response["result"] and len(response["result"]["content"]) > 0:
            content = response["result"]["content"][0]
            if content["type"] == "image" and "svg" in content.get("mimeType", ""):
                svg_data = content["data"]

                # Save to file in /content, where the Step 4 test saved its chart
                with open("/content/custom_output.svg", "w", encoding="utf-8") as f:
                    f.write(svg_data)

                # Display
                from IPython.display import SVG, display

                display(SVG(data=svg_data))

                # Show metadata
                if "structuredContent" in response["result"]:
                    metadata = response["result"]["structuredContent"].get("metadata", {})
                    print("\nüìä Metadata:")
                    print(f"   Pattern: {metadata.get('pattern_id')}")
                    print(f"   Template: {metadata.get('template_id')}")
                    if metadata.get("processing_time_ms"):
                        print(f"   Processing time: {metadata['processing_time_ms']}ms")

except Exception as e:
    print(f"❌ Error: {e}")
    import traceback

    traceback.print_exc()
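To try other datasets without copying the whole cell, the request can be built by a small function. This sketch produces the JSON-RPC 2.0 message as a plain dict and assumes `MCPMethod.TOOLS_CALL` corresponds to the MCP method name `tools/call`; the argument layout mirrors the cells above. Chartelier's own `JSONRPCRequest` model remains the authoritative way to build requests, and `visualize_request` is a hypothetical name.

```python
import json


def visualize_request(request_id: int, csv_data: str, query: str,
                      width: int = 800, height: int = 600, fmt: str = "svg") -> str:
    """Serialize a chartelier_visualize tools/call request as a JSON-RPC 2.0 string."""
    message = {
        "jsonrpc": "2.0",
        "id": request_id,
        "method": "tools/call",  # Assumed wire name of MCPMethod.TOOLS_CALL
        "params": {
            "name": "chartelier_visualize",
            "arguments": {
                "data": csv_data,
                "query": query,
                "options": {"format": fmt, "width": width, "height": height},
            },
        },
    }
    return json.dumps(message)


payload = visualize_request(3, "x,y\n1,2\n2,4", "Plot y against x")
print(payload[:60])
```

If the handler accepts messages in this shape, the string could be passed to `handler.handle_message(...)` in place of `json.dumps(request.model_dump())`.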

## Cleanup

Stop the vLLM server when done:

In [None]:
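# Stop the vLLM server
import subprocess

try:
    server_process.terminate()
    server_process.wait(timeout=5)
    print("✅ vLLM server stopped")
except NameError:
    print("Server process was never started in this session")
except subprocess.TimeoutExpired:
    # The server ignored SIGTERM; force it to stop
    server_process.kill()
    server_process.wait()
    print("⚠️ vLLM server did not exit in time and was killed")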
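`terminate()` signals only the launcher process. If `start_vllm_server.py` starts vLLM as a child process (this notebook does not say), the child may outlive it and keep holding GPU memory. One common POSIX pattern is to start the server in its own session and signal the whole process group. A sketch of that pattern, demonstrated on a stand-in `sleep` process; it assumes the server was launched with `start_new_session=True`, which the Step 2 cell does not do, and the helper names are hypothetical.

```python
import os
import signal
import subprocess
import sys


def start_in_new_group(cmd: list[str]) -> subprocess.Popen:
    """Start `cmd` as the leader of a new session and process group (POSIX only)."""
    return subprocess.Popen(cmd, start_new_session=True)


def stop_group(proc: subprocess.Popen, timeout_s: float = 10.0) -> int:
    """SIGTERM the whole process group, escalate to SIGKILL after timeout_s; return exit code."""
    if proc.poll() is not None:
        return proc.returncode  # Already finished
    pgid = os.getpgid(proc.pid)
    os.killpg(pgid, signal.SIGTERM)
    try:
        return proc.wait(timeout=timeout_s)
    except subprocess.TimeoutExpired:
        os.killpg(pgid, signal.SIGKILL)
        return proc.wait()


# Demo with a stand-in long-running process instead of the vLLM launcher
demo = start_in_new_group([sys.executable, "-c", "import time; time.sleep(60)"])
print("exit code:", stop_group(demo))
```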

## Troubleshooting

### Common Issues:

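1. **No GPU available**: Make sure you've selected a GPU runtime (Runtime -> Change runtime type -> A100 GPU)
2. **Out of memory**: GPT-OSS-20B fits on an A100 40GB, so OOM errors usually mean GPU memory is still held by an earlier process (for example, a previous vLLM server). Restart the runtime to free it
3. **Model download slow**: The first run downloads the ~14GB model. This is expected; the files stay cached for the rest of the Colab session but are lost when the runtime is reset
4. **Server not starting**: Check the vLLM server's output for specific errors; running `!python colab/start_vllm_server.py` in a separate cell shows it live (interrupt that cell when done)
5. **Connection refused**: Make sure the vLLM server process from Step 2 is still running (`server_process.poll()` returns `None` while it is alive)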
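A "Connection refused" error means nothing is listening on the target port. A quick stdlib check, using the host and port from the URLs in this notebook (`localhost:8000`); `port_is_open` is a hypothetical helper name.

```python
import socket


def port_is_open(host: str, port: int, timeout_s: float = 1.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout_s."""
    try:
        with socket.create_connection((host, port), timeout=timeout_s):
            return True
    except OSError:
        return False


print("vLLM port 8000 open:", port_is_open("localhost", 8000))
```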