# Model Playground

This notebook lets you:
- Download and load checkpoints from training runs
- Test the model with conversation history
- See full reasoning and tool calls
- Simulate text messages and actions


In [None]:
import os
import sys
import subprocess
import tempfile
from pathlib import Path
import json
import re
import random

from datetime import datetime
from typing import List, Dict, Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import yaml

# Add training directory to path (notebook is in training/notebooks/)
notebook_dir = Path.cwd()
training_dir = notebook_dir.parent
sys.path.insert(0, str(training_dir))

from dataloader import format_message
from config import Config

# Workaround: some transformers versions fail when config contains torch.dtype (not JSON-serializable in logger)
import transformers.configuration_utils as _conf_utils


def _to_dict_dtype_safe(self):
    """Return the config's ``to_dict()`` result with dtype values stringified.

    Calls the original ``PretrainedConfig.to_dict`` (saved in the module-level
    ``_orig_to_dict``) and replaces every top-level value that is a
    ``torch.dtype`` -- or whose class is literally named ``"dtype"`` -- with
    ``str(value)``.  Nested containers are not inspected.
    """
    d = _orig_to_dict(self)
    for k, v in list(d.items()):
        if isinstance(v, torch.dtype) or v.__class__.__name__ == "dtype":
            d[k] = str(v)
    return d


# Install the patch only once.  Notebook cells get re-run; without this guard a
# second run would rebind `_orig_to_dict` to the already-installed wrapper, and
# that wrapper (which looks `_orig_to_dict` up at call time) would then call
# itself forever -> RecursionError on every `to_dict()`.
if not getattr(_conf_utils.PretrainedConfig.to_dict, "_dtype_safe_patch", False):
    _orig_to_dict = _conf_utils.PretrainedConfig.to_dict
    _to_dict_dtype_safe._dtype_safe_patch = True
    _conf_utils.PretrainedConfig.to_dict = _to_dict_dtype_safe

# Device detection: prefer CUDA, then Apple-silicon MPS, and fall back to CPU
def get_device():
    """Return the name of the preferred inference device.

    Returns:
        "cuda" when CUDA is available, otherwise "mps" when the MPS backend
        is available, otherwise "cpu".
    """
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"

# Resolve the default device once and report it (with a hint when it is MPS)
DEVICE = get_device()
_startup_lines = [f"Using device: {DEVICE}"]
if DEVICE == "mps":
    _startup_lines.append("Note: MPS (Metal) acceleration enabled for Mac GPU")
print("\n".join(_startup_lines))


In [None]:
def _run_ssh(host: str, remote_command: str, timeout: Optional[float] = None):
    """Run ``remote_command`` on ``host`` over ssh and capture its output as text.

    The command is handed to ssh as a single argv entry, so no local shell is
    involved; only the remote login shell interprets it.  A missing ``ssh``
    binary is reported as return code 127 (what a shell would report) instead
    of raising.  ``subprocess.TimeoutExpired`` propagates when ``timeout`` is hit.
    """
    argv = ["ssh", host, remote_command]
    try:
        return subprocess.run(argv, capture_output=True, text=True, timeout=timeout)
    except FileNotFoundError as exc:
        return subprocess.CompletedProcess(argv, 127, stdout="", stderr=str(exc))


def find_latest_checkpoint_remote(
    job_id: str,
    remote_path: str = "/projects/llpr/amiri.ry/projects/yap-for-me/training",
    ssh_host: str = "amiri.ry@login.explorer.northeastern.edu",
) -> Optional[str]:
    """Find the latest checkpoint in a remote job directory.

    Lists every sub-directory of ``{remote_path}/out/{job_id}`` over ssh, keeps
    those containing ``training_state.pt``, determines each one's step number
    (from a ``checkpoint-<N>`` directory name, else by loading
    ``training_state.pt`` remotely) and returns the directory with the highest
    step.  Progress is printed along the way.

    Args:
        job_id: SLURM job ID
        remote_path: Remote path to training directory
        ssh_host: ``user@host`` used for the ssh calls

    Returns:
        Path to latest checkpoint directory, or None if not found
    """
    import shlex  # local import: only the remote-command quoting needs it

    remote_job_dir = f"{remote_path}/out/{job_id}"

    print(f"Checking remote directory: {remote_job_dir}")

    # List all subdirectories in the job output directory.  The path is quoted
    # for the remote shell while the glob stays outside the quotes.  The listing
    # is deliberately NOT truncated: `ls` sorts lexicographically
    # (checkpoint-10000 < checkpoint-9000), so cutting it off could drop the
    # newest checkpoint.  `|| true` keeps "no match" from looking like an ssh
    # failure.
    listing = _run_ssh(ssh_host, f"ls -d {shlex.quote(remote_job_dir)}/*/ 2>/dev/null || true")

    if listing.returncode != 0:
        print(f"Error listing directories. Return code: {listing.returncode}")
        print(f"stderr: {listing.stderr}")
        print(f"stdout: {listing.stdout}")
        return None

    checkpoint_dirs = [
        line.strip().rstrip('/')
        for line in listing.stdout.strip().split('\n')
        if line.strip()
    ]

    print(f"Found {len(checkpoint_dirs)} directories: {checkpoint_dirs}")

    if not checkpoint_dirs:
        print(f"No directories found in {remote_job_dir}")
        return None

    # Remote one-liner that prints the "step" entry of a training_state.pt.
    # NOTE(review): weights_only=False unpickles arbitrary objects on the
    # remote host -- only acceptable for checkpoints you produced yourself.
    read_step_code = (
        'import sys, torch; '
        'state = torch.load(sys.argv[1], map_location="cpu", weights_only=False); '
        'print(state.get("step", 0))'
    )

    # Check which ones have training_state.pt and get their step numbers
    checkpoints_with_steps = []
    total = len(checkpoint_dirs)

    for idx, checkpoint_dir in enumerate(checkpoint_dirs, 1):
        state_file = f"{checkpoint_dir}/training_state.pt"
        checkpoint_name = Path(checkpoint_dir).name

        print(f"\n[{idx}/{total}] Checking {checkpoint_name}...")

        # First check if file exists
        print(f"  Checking if {state_file} exists...")
        check_result = _run_ssh(
            ssh_host,
            f"test -f {shlex.quote(state_file)} && echo exists || echo notfound",
        )

        if check_result.stdout.strip() != "exists":
            print(f"  {checkpoint_name}: No training_state.pt (skipping)")
            continue

        # Try to parse step number from checkpoint name (e.g., "checkpoint-2000" -> 2000)
        step = None
        if checkpoint_name.startswith("checkpoint-"):
            try:
                step = int(checkpoint_name.split("-")[-1])
                print(f"  ✓ {checkpoint_name}: step {step} (from name)")
            except ValueError:
                # Non-numeric suffix: fall through to reading the state file.
                step = None

        # If we couldn't parse from name, read from training_state.pt
        if step is None:
            print(f"  {checkpoint_name}: Reading step number from training_state.pt...")
            remote_cmd = f"python3 -c {shlex.quote(read_step_code)} {shlex.quote(state_file)}"

            try:
                step_result = _run_ssh(ssh_host, remote_cmd, timeout=30)
            except subprocess.TimeoutExpired:
                print(f"  ✗ {checkpoint_name}: Timeout reading step (took >30s)")
                continue
            except Exception as e:
                # Best-effort: one unreadable checkpoint must not abort the scan.
                print(f"  ✗ {checkpoint_name}: Error reading step: {e}")
                continue

            if step_result.returncode != 0:
                print(f"  ✗ {checkpoint_name}: Error reading step remotely")
                print(f"    stderr: {step_result.stderr[:200]}")
                continue

            try:
                step = int(step_result.stdout.strip())
            except ValueError:
                print(f"  ✗ {checkpoint_name}: Could not parse step number: {step_result.stdout[:100]}")
                continue
            print(f"  ✓ {checkpoint_name}: step {step} (from training_state.pt)")

        checkpoints_with_steps.append((checkpoint_dir, step))

    if not checkpoints_with_steps:
        print(f"\nNo valid checkpoints found with training_state.pt")
        return None

    # Sort by step and return the latest
    checkpoints_with_steps.sort(key=lambda x: x[1], reverse=True)

    print(f"\n" + "="*60)
    print(f"Summary: Found {len(checkpoints_with_steps)} checkpoint(s):")
    for checkpoint_dir, step in checkpoints_with_steps:
        checkpoint_name = Path(checkpoint_dir).name
        print(f"  - {checkpoint_name}: step {step}")
    print(f"="*60)

    latest_checkpoint, latest_step = checkpoints_with_steps[0]
    latest_name = Path(latest_checkpoint).name
    print(f"\nUsing latest checkpoint: {latest_name} (step {latest_step})")
    return latest_checkpoint

def download_checkpoint(job_id: str, local_dir: Optional[str] = None, remote_path: str = "/projects/llpr/amiri.ry/projects/yap-for-me/training") -> Optional[str]:
    """Download the latest checkpoint from a training run using rsync.

    Looks up the latest remote checkpoint with ``find_latest_checkpoint_remote``,
    creates ``<local_dir>/<job_id>/<checkpoint name>`` locally and rsyncs the
    remote directory into it.  rsync's own output goes straight to the
    notebook (it is not captured).

    Args:
        job_id: SLURM job ID
        local_dir: Local directory to save checkpoint (defaults to training/checkpoints)
        remote_path: Remote path to training directory

    Returns:
        Path to local checkpoint directory, or None if failed
    """
    import shlex  # local import: used only to render the command for display

    print(f"Finding latest checkpoint for job {job_id}...")
    remote_checkpoint = find_latest_checkpoint_remote(job_id, remote_path)

    if not remote_checkpoint:
        return None

    # Default to training/checkpoints (parent of notebooks directory)
    if local_dir is None:
        local_dir = str(training_dir / "checkpoints")

    # Create local directory path
    local_checkpoint_dir = Path(local_dir) / job_id / Path(remote_checkpoint).name
    local_checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # IMPORTANT: Use the dedicated transfer node (xfer.discovery.neu.edu) NOT login node!
    # According to Northeastern Explorer HPC documentation:
    # - Transfer node: xfer.discovery.neu.edu (optimized for data transfers)
    # - Login node: login.explorer.northeastern.edu (NOT for transfers)
    transfer_node = "xfer.discovery.neu.edu"
    username = "amiri.ry"

    local_full_path = str(local_checkpoint_dir.absolute())

    # Rsync via transfer node (recommended by Northeastern).  The command is an
    # argv list, not a shell string, so a local path containing spaces or
    # shell metacharacters is passed through intact.
    rsync_argv = [
        "rsync",
        "-av",
        "--partial",
        "--progress",
        f"{username}@{transfer_node}:{remote_checkpoint}/",
        f"{local_full_path}/",
    ]
    # Shell-quoted rendering, safe to copy-paste into a terminal.
    rsync_cmd = shlex.join(rsync_argv)

    print(f"\nDownloading checkpoint from {remote_checkpoint}...")
    print(f"Using transfer node: {transfer_node}")
    print(f"Local destination: {local_full_path}")
    print(f"\nRunning: {rsync_cmd}")
    print(f"This may take several minutes for large checkpoints...\n")

    # Execute rsync command; a missing rsync binary is treated like the
    # shell's "command not found" (127) so the failure help below is shown.
    try:
        returncode = subprocess.run(rsync_argv).returncode
    except FileNotFoundError:
        returncode = 127

    if returncode == 0:
        print(f"\n✓ Checkpoint downloaded successfully to {local_full_path}")
        return str(local_full_path)

    print(f"\n✗ Download failed (return code: {returncode})")
    print(f"\nAlternative methods:")
    print(f"1. Try running the command manually:")
    print(f"   {rsync_cmd}")
    print(f"2. Use Globus (recommended for large files):")
    print(f"   - Visit: https://www.globus.org/")
    print(f"   - Sign in with Northeastern credentials")
    print(f"   - Install Globus Connect Personal")
    print(f"   - Search for 'Discovery Cluster' endpoint")
    print(f"   - Transfer files via web interface")
    return None

def load_model_from_checkpoint(checkpoint_path: str, device: Optional[str] = None):
    """Load the tokenizer and causal-LM weights stored in a checkpoint directory.

    The model is loaded in bfloat16, except on "mps" where float16 is tried
    first and bfloat16 is used only if that load raises.  The model is then
    moved to the device and switched to eval mode.

    Args:
        checkpoint_path: Path to checkpoint directory
        device: Device to load model on (defaults to the result of get_device())

    Returns:
        Tuple of (model, tokenizer)
    """
    target_device = get_device() if device is None else device
    ckpt_dir = Path(checkpoint_path)

    print(f"Loading tokenizer from {ckpt_dir}...")
    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)

    print(f"Loading model from {ckpt_dir}...")

    def _load_with(dtype):
        # Single place that knows how the model is instantiated.
        return AutoModelForCausalLM.from_pretrained(
            ckpt_dir,
            dtype=dtype,
            trust_remote_code=True
        )

    if target_device != "mps":
        model = _load_with(torch.bfloat16)
    else:
        # MPS supports float16 but bfloat16 support may vary, so float16 goes first
        try:
            model = _load_with(torch.float16)
        except Exception as e:
            print(f"Warning: float16 failed, trying bfloat16: {e}")
            model = _load_with(torch.bfloat16)

    model.to(target_device)
    model.eval()

    print(f"Model loaded on {target_device}")
    return model, tokenizer

def format_conversation_history(messages: List[Dict], max_context_tokens: int = 1800) -> str:
    """Render a list of message dicts as the text block fed to the model.

    Each message is passed through ``format_message`` and terminated with a
    newline; the pieces are concatenated in order.

    Args:
        messages: List of message dicts with keys: timestamp, speaker, text,
            replying_to, guid. Missing keys fall back to '' (timestamp, text),
            'Unknown' (speaker) or None (replying_to, guid).
        max_context_tokens: Accepted for interface compatibility; this function
            does not read it, so no truncation is applied here.

    Returns:
        Formatted conversation string ('' for an empty list)
    """
    return ''.join(
        format_message(
            entry.get('timestamp', ''),
            entry.get('speaker', 'Unknown'),
            entry.get('text', ''),
            entry.get('replying_to'),
            entry.get('guid'),
        ) + '\n'
        for entry in messages
    )

def generate_tool_calls(
    model,
    tokenizer,
    conversation_history: str,
    max_new_tokens: int = 512,
    temperature: float = 0.7,
    top_p: float = 0.9,
    device: Optional[str] = None
) -> str:
    """Generate tool calls given conversation history.

    Args:
        model: Loaded model
        tokenizer: Loaded tokenizer
        conversation_history: Formatted conversation history
        max_new_tokens: Maximum tokens to generate
        temperature: Sampling temperature. A value of 0 or below (or None)
            switches to greedy decoding instead of sampling.
        top_p: Nucleus sampling parameter (ignored when decoding greedily)
        device: Device to run inference on. Defaults to the device the model
            already lives on (``model.device``), falling back to the
            auto-detected best device when the model exposes none.

    Returns:
        Generated tool calls text (only the newly generated tokens, decoded
        with special tokens skipped)
    """
    if device is None:
        # Follow the model rather than re-detecting: a model loaded on "cpu"
        # must not receive inputs placed on the auto-detected accelerator.
        device = getattr(model, "device", None)
        if device is None:
            device = get_device()

    # `device` may be a string or a torch.device; normalise before comparing.
    on_mps = torch.device(device).type == "mps"

    # Tokenize the conversation history
    inputs = tokenizer(
        conversation_history,
        return_tensors="pt",
        add_special_tokens=False
    ).to(device)

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling with a non-positive temperature is rejected by generate();
        # greedy decoding is the limit behaviour, so use it instead.
        generation_kwargs["do_sample"] = False

    # Generate tool calls
    with torch.no_grad():
        if on_mps:
            # MPS sometimes has issues with certain generation settings
            try:
                outputs = model.generate(**inputs, **generation_kwargs)
            except Exception as e:
                print(f"Warning: Generation with MPS failed, trying with adjusted settings: {e}")
                # Fallback: retry once without nucleus sampling
                generation_kwargs.pop("top_p", None)
                outputs = model.generate(**inputs, **generation_kwargs)
        else:
            outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the newly generated tokens (the output starts with the prompt)
    input_length = inputs['input_ids'].shape[1]
    generated_tokens = outputs[0][input_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True)


## Usage Example

### Step 1: Download and load a checkpoint


In [None]:
# SLURM job ID of the training run to inspect (replace with your own)
JOB_ID = "3105120"

# To fetch the latest checkpoint from the cluster instead of using a local copy:
# checkpoint_path = download_checkpoint(JOB_ID)
checkpoint_path = "/Users/ramiri/dev/projects/YapForMe/training/checkpoints/3105120/checkpoint-2000"

if not checkpoint_path:
    # Only reachable when download_checkpoint() above is used and returns None
    print("Failed to download checkpoint")
else:
    # Load the model
    model, tokenizer = load_model_from_checkpoint(checkpoint_path)
    print("Model ready for inference!")


### Step 2: Create conversation history


In [None]:
# Example conversation history, written as (timestamp, speaker, text, guid) rows
_example_turns = [
    ("2025-12-09 10:00:00", "Alice", "Hey, are you free for lunch today?", "1"),
    ("2025-12-09 10:05:00", "Ryan", "Let me check my calendar", "2"),
    ("2025-12-09 10:06:00", "Alice", "Sounds good! Let me know", "3"),
]
conversation_messages = [
    {
        "timestamp": stamp,
        "speaker": who,
        "text": body,
        "replying_to": None,
        "guid": guid,
    }
    for stamp, who, body, guid in _example_turns
]

# Format the conversation and show it between two rules
formatted_history = format_conversation_history(conversation_messages)
print("Conversation History:", "=" * 60, formatted_history, "=" * 60, sep="\n")


### Step 3: Generate tool calls


In [None]:
# Generate tool calls based on conversation history
tool_calls = generate_tool_calls(
    model,
    tokenizer,
    formatted_history,
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.9
)

print("Generated Tool Calls:")
print("=" * 60)
print(tool_calls)
print("=" * 60)

print("\nFull Input + Output Sequence:")
print("=" * 60)
full_sequence = formatted_history + tool_calls
print(full_sequence)
print("=" * 60)

# Try to extract and display individual tool calls if they're in a structured format
# Look for common tool call patterns (adjust based on your actual format).
# Each pattern has exactly one capture group, and re.findall returns that
# group's text -- so the group must wrap everything we want to display.
tool_patterns = [
    r'<tool_call>(.*?)</tool_call>',
    r'```(?:json|python)?\s*(\{.*?\})\s*```',
    # Capture the whole call, e.g. send_message(text="hi"), not just the
    # function name (the group used to cover only `\w+`).
    r'(\w+\s*\([^)]*\))',
]

print("\nExtracted Tools:")
print("=" * 60)
tools_found = []
# Patterns are tried in order; the first one that matches anything wins.
for pattern in tool_patterns:
    matches = re.findall(pattern, tool_calls, re.DOTALL | re.IGNORECASE)
    if matches:
        tools_found.extend(matches)
        break

if tools_found:
    for i, tool in enumerate(tools_found, 1):
        print(f"Tool {i}:")
        print(tool)
        print()
else:
    print("No structured tool calls detected. Raw output shown above.")
print("=" * 60)


### Interactive Testing

You can build conversations interactively and test the model:


In [None]:
def simulate_conversation(messages: List[Dict], model, tokenizer):
    """Print a conversation, run the model on it, and print and return its output.

    Args:
        messages: Message dicts; 'timestamp', 'speaker' and 'text' are shown
            (missing keys print as empty strings)
        model: Loaded model
        tokenizer: Loaded tokenizer

    Returns:
        The raw text produced by generate_tool_calls() with its default settings
    """
    history_text = format_conversation_history(messages)
    rule = "=" * 60

    print("Current Conversation:")
    print(rule)
    for entry in messages:
        stamp = entry.get('timestamp', '')
        who = entry.get('speaker', '')
        body = entry.get('text', '')
        print(f"[{stamp}] {who}: {body}")
    print(rule)
    print("\n")

    response = generate_tool_calls(model, tokenizer, history_text)

    print("Model's Response (Tool Calls):")
    print(rule)
    print(response)
    print(rule)

    return response

# Example input: a one-message conversation to feed simulate_conversation()
test_messages = [
    dict(
        timestamp="2025-12-09 10:00:00",
        speaker="Friend",
        text="What are you up to?",
        replying_to=None,
        guid=None,
    ),
]

# Uncomment to test:
# simulate_conversation(test_messages, model, tokenizer)


### Load from existing local checkpoint

If you already have a checkpoint downloaded, you can load it directly:


In [None]:
# Load from an existing local checkpoint
# checkpoint_path = str(training_dir / "checkpoints" / "3105120" / "checkpoint-2000")
# model, tokenizer = load_model_from_checkpoint(checkpoint_path)


### Experiment with generation parameters

Try different temperature and top_p values to see how the model responds:


In [None]:
# Temperature sweep over one fixed prompt
# Lower temperature = more focused/deterministic
# Higher temperature = more creative/random

test_temps = [0.3, 0.7, 1.0]
test_context = format_conversation_history([
    dict(
        timestamp="2025-12-09 14:00:00",
        speaker="Mom",
        text="Are you coming home for dinner?",
        replying_to=None,
        guid=None,
    ),
])

for temp in test_temps:
    # Banner: blank line, rule, label, rule
    print("\n" + "=" * 60)
    print(f"Temperature: {temp}")
    print("=" * 60)
    response = generate_tool_calls(
        model,
        tokenizer,
        test_context,
        max_new_tokens=256,
        temperature=temp,
    )
    print(response)


## Interactive Chat

Chat with the model in real-time. You type messages as "the other person" and the model responds as Ryan.

Run this cell and follow the prompts. Type "quit", "exit", or "q" to stop.

In [None]:
# Building blocks for the tool-call patterns.  A text argument is matched in
# two passes: first as a properly quoted string (which may contain ')' and
# backslash-escaped quotes), then -- only if that fails -- with the loose
# "anything up to the first ')'" form.
_TOOL_QUOTED_TEXT = r'''(?:"(?P<dq>(?:[^"\\]|\\.)*)"|'(?P<sq>(?:[^'\\]|\\.)*)')'''
_TOOL_LOOSE_TEXT = r'(?P<loose>.+?)'
_TOOL_CALL_TAIL = r'\s*\)'
_TOOL_REPLY_HEAD = r'reply\s*\(\s*message_guid\s*=\s*"(?P<guid>[^"]+)"\s*,\s*text\s*=\s*'
_TOOL_SEND_HEAD = r'send_message\s*\(\s*text\s*=\s*'

_TOOL_REACT_RE = re.compile(
    r'react\s*\(\s*message_guid\s*=\s*"([^"]+)"\s*,\s*reaction_type\s*=\s*"([^"]+)"\s*\)'
)
_TOOL_REPLY_RES = tuple(
    re.compile(_TOOL_REPLY_HEAD + body + _TOOL_CALL_TAIL, re.DOTALL)
    for body in (_TOOL_QUOTED_TEXT, _TOOL_LOOSE_TEXT)
)
_TOOL_SEND_RES = tuple(
    re.compile(_TOOL_SEND_HEAD + body + _TOOL_CALL_TAIL, re.DOTALL)
    for body in (_TOOL_QUOTED_TEXT, _TOOL_LOOSE_TEXT)
)


def _tool_first_match(patterns, line: str):
    """Return the first successful ``pattern.match(line)`` among ``patterns``, else None."""
    for pattern in patterns:
        found = pattern.match(line)
        if found:
            return found
    return None


def _tool_strip_quotes(value: str) -> str:
    """Strip surrounding whitespace, then one pair of matching '"' or "'" quotes if present."""
    value = value.strip()
    for quote in ('"', "'"):
        if value.startswith(quote) and value.endswith(quote):
            return value[1:-1]
    return value


def _tool_text_from_match(found) -> str:
    """Extract the text argument from a reply/send_message match object."""
    groups = found.groupdict()
    for key in ("dq", "sq"):
        if groups.get(key) is not None:
            # Quoted form: the group already excludes the quotes.
            return groups[key]
    return _tool_strip_quotes(groups["loose"])


def parse_tool_calls(model_output: str) -> list:
    """Parse code-style tool calls from model output.

    The output is processed line by line; each non-empty line is matched
    (from its start) against, in order:
    - react(message_guid="...", reaction_type="...")
    - reply(message_guid="...", text="...")
    - send_message(text="...")
    Lines matching none of these are ignored.

    A quoted ``text`` value may contain ')' characters (e.g. ``":) ok"``)
    and backslash-escaped quotes; escape sequences are left as written.
    Unquoted or irregular values fall back to "everything up to the first
    ')'", with one pair of surrounding quotes stripped.

    Returns list of dicts with 'action_type' and 'params' keys.
    """
    tool_calls = []

    for raw_line in model_output.strip().split('\n'):
        line = raw_line.strip()
        if not line:
            continue

        react_match = _TOOL_REACT_RE.match(line)
        if react_match:
            tool_calls.append({
                'action_type': 'react',
                'params': {
                    'message_guid': react_match.group(1),
                    'reaction_type': react_match.group(2),
                },
            })
            continue

        reply_match = _tool_first_match(_TOOL_REPLY_RES, line)
        if reply_match:
            tool_calls.append({
                'action_type': 'reply',
                'params': {
                    'message_guid': reply_match.group('guid'),
                    'text': _tool_text_from_match(reply_match),
                },
            })
            continue

        send_msg_match = _tool_first_match(_TOOL_SEND_RES, line)
        if send_msg_match:
            tool_calls.append({
                'action_type': 'send_message',
                'params': {
                    'text': _tool_text_from_match(send_msg_match),
                },
            })

    return tool_calls


def format_tool_calls_display(tool_calls: list, conversation: list) -> str:
    """Turn parsed tool calls into a human-readable, newline-joined summary.

    - send_message: the text itself
    - reply / react: when the target guid is found in ``conversation``, the
      first 30 characters of that message's text (always followed by "...")
      are quoted; otherwise a generic "[reply]" / "[<type> reaction]" tag
    - any other action type contributes nothing

    Returns "(No valid tool calls parsed)" for an empty ``tool_calls`` and
    "(empty response)" when no call produced a line.
    """
    if not tool_calls:
        return "(No valid tool calls parsed)"

    def _find_message(guid):
        # First conversation entry whose 'guid' equals the requested one.
        return next((m for m in conversation if m.get('guid') == guid), None)

    rendered = []
    for call in tool_calls:
        kind, args = call['action_type'], call['params']

        if kind == 'send_message':
            rendered.append(args.get('text', ''))
            continue

        if kind == 'reply':
            target = _find_message(args.get('message_guid', ''))
            body = args.get('text', '')
            if target:
                snippet = target['text'][:30]
                rendered.append(f'[replying to "{snippet}..."] {body}')
            else:
                rendered.append(f"[reply] {body}")
            continue

        if kind == 'react':
            reaction = args.get('reaction_type', '')
            target = _find_message(args.get('message_guid', ''))
            if target:
                snippet = target['text'][:30]
                rendered.append(f'[{reaction} reaction to "{snippet}..."]')
            else:
                rendered.append(f"[{reaction} reaction]")

    if not rendered:
        return "(empty response)"
    return '\n'.join(rendered)


def interactive_chat(
    model,
    tokenizer,
    temperature: float = 0.7,
    top_p: float = 0.9,
    max_new_tokens: int = 128,
    show_raw_output: bool = False
):
    """Interactive chat loop with the model.
    
    You play the role of someone texting Ryan, and the model responds as Ryan.
    Type 'quit', 'exit' or 'q' to stop; EOF or Ctrl-C at the prompt also ends
    the chat.
    
    Each turn appends your message to the running conversation, formats the
    whole conversation with format_conversation_history(), generates a
    response, parses it with parse_tool_calls() and prints the result.
    Parsed send_message/reply calls are appended to the conversation as
    messages from "Ryan"; react calls are displayed but not recorded. If
    nothing could be parsed, the raw model output is recorded as Ryan's
    message instead. Timestamps are simulated: they start at the current
    local time and advance by small random intervals.
    
    Args:
        model: Loaded model
        tokenizer: Loaded tokenizer
        temperature: Sampling temperature (0.3-0.5 focused, 0.8-1.0 creative)
        top_p: Nucleus sampling parameter
        max_new_tokens: Max tokens to generate
        show_raw_output: If True, also shows raw tool call output
    
    Returns:
        The conversation as a list of message dicts (timestamp, speaker,
        text, replying_to, guid), in order.
    """
    # timedelta is not imported at the top of the notebook, hence the local import
    from datetime import datetime, timedelta
    
    print("Interactive Chat with Ryan")
    print("Type your messages and see how Ryan would respond.")
    print("Type 'quit' or 'exit' to stop.\n")
    
    partner_name = input("Who are you chatting as? (e.g., Mom, Alice, Friend): ").strip()
    if not partner_name:
        partner_name = "Friend"
    
    print(f"\nYou are now chatting as '{partner_name}'. Start the conversation!\n")
    
    conversation = []
    # Simulated clock; advanced by random amounts after each message
    current_time = datetime.now()
    # Monotonic counter used as the guid of every recorded message
    msg_id = 1
    
    while True:
        try:
            user_input = input(f"{partner_name}: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n\nChat ended.")
            break
        
        if user_input.lower() in ("quit", "exit", "q"):
            print("\nChat ended.")
            break
        
        # Ignore empty lines without advancing the clock or calling the model
        if not user_input:
            continue
        
        timestamp_str = current_time.strftime("%Y-%m-%d %H:%M:%S")
        conversation.append({
            "timestamp": timestamp_str,
            "speaker": partner_name,
            "text": user_input,
            "replying_to": None,
            "guid": str(msg_id)
        })
        msg_id += 1
        current_time += timedelta(seconds=random.randint(5, 30))
        
        # The full conversation so far is re-formatted and sent every turn
        formatted_history = format_conversation_history(conversation)
        
        print("\nRyan is typing...")
        raw_response = generate_tool_calls(
            model,
            tokenizer,
            formatted_history,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p
        )
        
        # Timestamp for Ryan's first message of this turn; the clock then moves on
        response_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
        current_time += timedelta(seconds=random.randint(10, 60))
        
        tool_calls = parse_tool_calls(raw_response)
        display_text = format_tool_calls_display(tool_calls, conversation)
        
        if show_raw_output:
            print(f"\n[Raw output]: {raw_response}")
        print(f"\nRyan: {display_text}\n")
        
        # Record Ryan's textual actions so they become context for the next turn
        for tc in tool_calls:
            action = tc['action_type']
            params = tc['params']
            
            # Only messages with text are recorded; 'react' calls are skipped.
            # Replies are stored with replying_to=None (the target guid is not kept).
            if action in ('send_message', 'reply'):
                text = params.get('text', raw_response)
                conversation.append({
                    "timestamp": response_time,
                    "speaker": "Ryan",
                    "text": text,
                    "replying_to": None,
                    "guid": str(msg_id)
                })
                msg_id += 1
                # Each further message in the same turn gets a later timestamp
                current_time += timedelta(seconds=random.randint(2, 10))
                response_time = current_time.strftime("%Y-%m-%d %H:%M:%S")
        
        # Nothing parseable: keep the raw output as Ryan's message
        if not tool_calls:
            conversation.append({
                "timestamp": response_time,
                "speaker": "Ryan",
                "text": raw_response,
                "replying_to": None,
                "guid": str(msg_id)
            })
            msg_id += 1
    
    # Print a transcript once the loop has ended (skipped if nothing was said)
    if conversation:
        print("\nFull conversation:")
        print("-" * 60)
        for msg in conversation:
            print(f"[{msg['timestamp']}] {msg['speaker']}: {msg['text']}")
        print("-" * 60)
    
    return conversation

# Launch the chat loop -- `model` and `tokenizer` must already be loaded.
# Temperature guide:
#   - Lower (0.3-0.5): More focused, predictable
#   - Higher (0.8-1.0): More creative, varied
# show_raw_output=True additionally prints the raw tool calls each turn.
chat_history = interactive_chat(
    model,
    tokenizer,
    temperature=0.7,
    show_raw_output=True,
)