# Inference and Trajectory Training Lab

This notebook is a hands-on lab for Data Scientists and AI Researchers.

What you will do:
- Run end-to-end inference with the LangGraph workflow.
- Generate trajectory logs (a `.jsonl` event log plus a `.summary.json` run summary).
- Validate trajectory schema quality for training.
- Extract training-friendly samples from tool interactions.


## Workflow

1. Environment and project checks
2. Inference run (`main_langgraph.py`)
3. Trace inspection and schema validation
4. Trajectory-to-training extraction and dataset export


In [None]:
from __future__ import annotations

import json
import os
import subprocess
import sys
from pathlib import Path
from collections import Counter
from dotenv import load_dotenv


def find_project_root(start: Path) -> Path:
    """Return the nearest directory at or above *start* holding the project entry files.

    A directory qualifies when both ``main_langgraph.py`` and ``tools.py`` exist in it.
    *start* itself is checked first, then each ancestor from nearest to farthest.

    Raises:
        RuntimeError: if neither *start* nor any ancestor qualifies.
    """
    markers = ('main_langgraph.py', 'tools.py')
    for directory in (start, *start.parents):
        if all((directory / name).exists() for name in markers):
            return directory
    raise RuntimeError('Cannot locate project root containing main_langgraph.py and tools.py')


PROJECT_ROOT = find_project_root(Path.cwd().resolve())
# Work from the project root so the relative CLI script name used later resolves,
# then load environment variables (API keys, model name) from the project's .env file.
os.chdir(PROJECT_ROOT)
load_dotenv(PROJECT_ROOT / '.env')

for label, value in (('PROJECT_ROOT', PROJECT_ROOT), ('PYTHON', sys.executable)):
    print(f'{label} =', value)


In [None]:
# Configuration
TOPIC = 'AI agent planning strategy market snapshot 2026'
MODEL = os.getenv('OPENAI_MODEL', 'gpt-4.1-mini')
PROVIDER = 'openai'
RUN_INFERENCE = True  # Set False to reuse the latest existing trace

# Output directories for this notebook run (created up front, even if the key check fails)
NOTEBOOK_RUN_DIR = PROJECT_ROOT / 'test_outputs' / 'notebook_lab'
TRACE_DIR = NOTEBOOK_RUN_DIR / 'trajectories'
REPORT_PATH = NOTEBOOK_RUN_DIR / 'inference_report.md'
for _out_dir in (NOTEBOOK_RUN_DIR, TRACE_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# API key diagnostics: only whether it is set and a short prefix are printed.
api_key = os.getenv('OPENAI_API_KEY', '')
_has_key = bool(api_key)
print('OPENAI_API_KEY set =', _has_key)
if not _has_key:
    raise RuntimeError('Missing OPENAI_API_KEY in environment.')
print('OPENAI_API_KEY prefix =', f'{api_key[:7]}...')


In [None]:
# Run end-to-end inference via CLI
import shlex

if RUN_INFERENCE:
    cmd = [
        sys.executable,
        'main_langgraph.py',
        TOPIC,
        '--provider', PROVIDER,
        '--model', MODEL,
        '--output', str(REPORT_PATH),
        '--trace-dir', str(TRACE_DIR),
    ]

    env = os.environ.copy()
    # Force no-key search mode: DuckDuckGo + Wikipedia
    env['SERPER_API_KEY'] = ''

    print('Running command:')
    # shlex.join quotes arguments such as the multi-word TOPIC, so the printed
    # line can be pasted into a POSIX shell as-is (a plain ' '.join cannot).
    print(shlex.join(cmd))
    # check=False on purpose: with check=True a non-zero exit raises
    # CalledProcessError immediately, so the captured stdout/stderr explaining
    # the failure would never be printed. We print first and raise afterwards.
    completed = subprocess.run(
        cmd,
        cwd=str(PROJECT_ROOT),
        env=env,
        text=True,
        capture_output=True,
        check=False,
    )

    print('Return code:', completed.returncode)
    print('STDOUT preview:')
    print(completed.stdout[:1200])
    if completed.stderr:
        print('STDERR preview:')
        print(completed.stderr[:1200])
    if completed.returncode != 0:
        # Error details (e.g. Python tracebacks) usually sit at the end of stderr.
        if len(completed.stderr) > 1200:
            print('STDERR tail:')
            print(completed.stderr[-1200:])
        # Raises CalledProcessError, the same exception check=True would raise.
        completed.check_returncode()
else:
    print('RUN_INFERENCE=False, skipping inference step.')


In [None]:
# Locate the latest trajectory files.
# Take the newest summary and prefer the event log sharing its file-name prefix
# (run_<id>.summary.json -> run_<id>.jsonl). Picking the newest file of each type
# independently can pair a summary with the log of a different run.
summary_files = sorted(TRACE_DIR.glob('run_*.summary.json'), key=lambda p: p.stat().st_mtime)
jsonl_files = sorted(TRACE_DIR.glob('run_*.jsonl'), key=lambda p: p.stat().st_mtime)

if not summary_files or not jsonl_files:
    raise RuntimeError('No trajectory files found. Run inference first.')

SUMMARY_PATH = summary_files[-1]
_run_prefix = SUMMARY_PATH.name[: -len('.summary.json')]
_matching_jsonl = SUMMARY_PATH.with_name(_run_prefix + '.jsonl')
if _matching_jsonl.exists():
    JSONL_PATH = _matching_jsonl
else:
    # NOTE(review): assumes the tracer names both files run_<id>.*; if not, fall
    # back to the newest log (previous behavior) and make the mismatch visible.
    JSONL_PATH = jsonl_files[-1]
    print('WARNING: no', _matching_jsonl.name, 'next to', SUMMARY_PATH.name,
          '- falling back to newest log', JSONL_PATH.name)

summary = json.loads(SUMMARY_PATH.read_text(encoding='utf-8'))
print('SUMMARY_PATH =', SUMMARY_PATH)
print('JSONL_PATH =', JSONL_PATH)
print('status =', summary.get('status'))
print('event_count =', summary.get('event_count'))
print('tool_call_count =', summary.get('tool_call_count'))
print('tool_error_count =', summary.get('tool_error_count'))
print('report_len =', summary.get('report_len'))


In [None]:
# Validate trajectory schema and event coverage
rows = [
    json.loads(line)
    for line in JSONL_PATH.read_text(encoding='utf-8').splitlines()
    if line.strip()
]
required_top = {'ts_utc', 'run_id', 'idx', 'event_type', 'payload'}
# 1-based row numbers of events missing any required top-level field.
missing_top_rows = [i for i, r in enumerate(rows, 1) if not required_top.issubset(r.keys())]

# .get() so a row lacking 'event_type' is reported by the check below (via
# missing_top_rows) instead of aborting here with a bare KeyError.
event_counts = Counter(r.get('event_type') for r in rows)
print('rows =', len(rows))
print('missing_top_rows =', len(missing_top_rows))
print('event_counts =', dict(event_counts))

required_events = {
    'run_started',
    'phase',
    'tool_call',
    'tool_result',
    'message_snapshot',
    'run_completed',
    'final_report',
}
missing_events = sorted(required_events - set(event_counts.keys()))
print('missing_required_events =', missing_events)
if missing_top_rows or missing_events:
    raise AssertionError(
        'Trajectory schema/event validation failed: '
        f'rows missing top-level fields (first 10) = {missing_top_rows[:10]}, '
        f'missing required events = {missing_events}'
    )


In [None]:
# Build training-friendly samples from tool interactions
# Sample schema: input state (tool call) + outcome (tool result)

tool_calls = [r for r in rows if r['event_type'] == 'tool_call']
tool_results = [r for r in rows if r['event_type'] == 'tool_result']

# NOTE(review): calls and results are paired purely by position (k-th call with
# k-th result). This assumes tool calls are logged sequentially and each call gets
# exactly one result — confirm against the tracer.
if len(tool_calls) != len(tool_results):
    print(
        f'WARNING: {len(tool_calls)} tool_call vs {len(tool_results)} tool_result events; '
        'unmatched trailing events are dropped and positional pairing may be misaligned.'
    )

# zip stops at the shorter list, matching the previous min(len, len) behavior.
training_samples = [
    {
        'run_id': call['run_id'],
        'step_index': step,
        'tool': call['payload'].get('tool'),
        'tool_kwargs': call['payload'].get('kwargs'),
        'ok': result['payload'].get('ok'),
        'latency_ms': result['payload'].get('latency_ms'),
        'result_preview': result['payload'].get('result_preview'),
    }
    for step, (call, result) in enumerate(zip(tool_calls, tool_results), start=1)
]

print('tool_calls =', len(tool_calls))
print('tool_results =', len(tool_results))
print('paired_training_samples =', len(training_samples))
print('first_sample =', training_samples[0] if training_samples else None)


In [None]:
# Persist extracted dataset for downstream training pipelines.
# Output is JSON Lines: one ASCII-escaped JSON object per training sample.
DATASET_DIR = NOTEBOOK_RUN_DIR / 'datasets'
DATASET_DIR.mkdir(parents=True, exist_ok=True)

dataset_jsonl = DATASET_DIR / (SUMMARY_PATH.stem + '.training_samples.jsonl')
with dataset_jsonl.open('w', encoding='utf-8') as out_file:
    out_file.writelines(
        json.dumps(record, ensure_ascii=True) + '\n' for record in training_samples
    )

print('Saved training samples:', dataset_jsonl)
print('Sample count:', len(training_samples))


In [None]:
# Optional: inspect the generated report (prints its length and the first 1200 characters)
if not REPORT_PATH.exists():
    print('Report not found at', REPORT_PATH)
else:
    report_text = REPORT_PATH.read_text(encoding='utf-8')
    print('Report size:', len(report_text))
    print('Report preview:')
    print(report_text[:1200])


## Next steps

- Add reward labels (quality, factuality, citation quality) to each tool step.
- Merge multiple runs into one curated training corpus.
- Build a train/validation split by topic and source diversity.
