# H200 Long-Form Residual Runs

Use this notebook to orchestrate long-sequence residual comparisons on the Vast H200 instance with the following guarantees:

- Pre-download every base/SFT pair before exercising the GPU.
- Force the four production prompts to remain long-form (≈200 tokens) with automatic 100-token and 50-token fallbacks if a model cannot handle the full text.
- Run only the default base vs SFT embedding configuration (no swaps, no fuzzing/perturbations).
- Stream progress with `tqdm` and immediately download each JSON/log back to this machine after every run.

> Fill in the connection settings in the next cell before running anything else.



In [8]:
from __future__ import annotations

import json
import os
import re
import shlex
import subprocess
import tarfile
import tempfile
import textwrap
from dataclasses import dataclass
from pathlib import Path, PurePosixPath
from typing import Dict, Iterable, List, Optional
from urllib.parse import urlparse

import pandas as pd
from tqdm.auto import tqdm
from transformers import AutoTokenizer

REPO_ROOT = Path("..").resolve()
if not (REPO_ROOT / "experiments").exists():
    REPO_ROOT = Path.cwd().resolve()


def detect_vast_instance(preferred_id: Optional[str]) -> Dict[str, Optional[str]]:
    try:
        result = subprocess.run(
            ["vastai", "show", "instances"],
            capture_output=True,
            text=True,
            check=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return {"instance_id": preferred_id, "ssh_host": None, "ssh_port": None}

    lines = [line.strip() for line in result.stdout.splitlines() if line.strip()]
    rows = [line for line in lines if line and line[0].isdigit()]
    target_row = None
    for row in rows:
        parts = row.split()
        if not parts:
            continue
        row_id = parts[0]
        if preferred_id and row_id == preferred_id:
            target_row = parts
            break
        if target_row is None:
            target_row = parts
    if not target_row or len(target_row) < 11:
        return {"instance_id": preferred_id or (target_row[0] if target_row else None), "ssh_host": None, "ssh_port": None}

    return {
        "instance_id": target_row[0],
        "ssh_host": target_row[9],
        "ssh_port": target_row[10],
    }


def fetch_ssh_settings(instance_id: Optional[str]) -> Dict[str, Optional[str]]:
    if not instance_id:
        return {}
    try:
        result = subprocess.run(
            ["vastai", "ssh-url", instance_id],
            capture_output=True,
            text=True,
            check=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        return {}
    url = result.stdout.strip()
    if not url:
        return {}
    parsed = urlparse(url)
    if not parsed.hostname:
        return {}
    return {
        "ssh_user": parsed.username or "root",
        "ssh_host": parsed.hostname,
        "ssh_port": str(parsed.port) if parsed.port else None,
    }


detected = detect_vast_instance(os.environ.get("H200_INSTANCE_ID"))
INSTANCE_ID = os.environ.get("H200_INSTANCE_ID") or detected.get("instance_id") or "28315978"
ssh_url = fetch_ssh_settings(INSTANCE_ID)
SSH_USER = os.environ.get("H200_SSH_USER") or ssh_url.get("ssh_user") or "root"
SSH_HOST = os.environ.get("H200_SSH_HOST") or ssh_url.get("ssh_host") or detected.get("ssh_host") or "ssh7.vast.ai"
SSH_PORT = int(os.environ.get("H200_SSH_PORT") or ssh_url.get("ssh_port") or detected.get("ssh_port") or "35978")
_identity_env = os.environ.get("H200_SSH_IDENTITY")
default_identity = Path.home() / ".ssh" / "id_rsa"
if _identity_env:
    SSH_IDENTITY: Optional[Path] = Path(_identity_env).expanduser()
elif default_identity.exists():
    SSH_IDENTITY = default_identity
else:
    SSH_IDENTITY = None
REMOTE_REPO = os.environ.get("H200_REMOTE_REPO", "/workspace/vastai-ssh-jupyter-pytorch")
LOCAL_OUTPUT_ROOT = Path(os.environ.get("H200_LOCAL_OUTPUT", "h200_long_outputs")).resolve()
LOCAL_OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

PROMPT_SOURCE = REPO_ROOT / "experiments" / "prompts" / "prod_prompts.txt"
PROMPT_VARIANTS_DIR = REPO_ROOT / "notebooks" / "generated_prompts"
PROMPT_VARIANTS_DIR.mkdir(parents=True, exist_ok=True)
REMOTE_PROMPT_DIR = f"{REMOTE_REPO}/experiments/prompts"
REMOTE_PYTHON = os.environ.get("H200_REMOTE_PYTHON", "python3")
PROD_CONFIG_REMOTE = f"{REMOTE_REPO}/configs/prod_config.yaml"
PROD_CONFIG_BACKUP = f"{REMOTE_REPO}/configs/prod_config.notebook.bak"
REMOTE_REPO_PATH = PurePosixPath(REMOTE_REPO)
REMOTE_REPO_PARENT = str(REMOTE_REPO_PATH.parent)
REMOTE_REPO_NAME = REMOTE_REPO_PATH.name
REMOTE_REPO_URL = os.environ.get(
    "H200_REMOTE_REPO_URL",
    "https://github.com/spencermcbridemoore/vastai-ssh-jupyter-pytorch.git",
)
LOCAL_CONFIG_PATH = REPO_ROOT / "configs" / "prod_config.yaml"

FALLBACK_TOKENIZER = os.environ.get("H200_FALLBACK_TOKENIZER", "Qwen/Qwen2.5-0.5B-Instruct")
PROMPT_VARIANT_ORDER = ("full", "tok100", "tok50")

print(f"Repo root: {REPO_ROOT}")
print(f"Local outputs: {LOCAL_OUTPUT_ROOT}")
print(f"Vast H200 instance {INSTANCE_ID}: ssh {SSH_USER}@{SSH_HOST} -p {SSH_PORT}")


Repo root: C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch
Local outputs: C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\h200_long_outputs
Vast H200 instance 28316581: ssh root@208.64.254.178 -p 30975


In [9]:
@dataclass(frozen=True)
class ModelSpec:
    name: str
    base: str
    sft: str
    tokenizer: str
    dtype: str = "bfloat16"
    device: str = "cuda:0"
    notes: str = ""


MODEL_SPECS: List[ModelSpec] = [
    ModelSpec(
        name="qwen-0_5b",
        base="Qwen/Qwen2.5-0.5B",
        sft="Qwen/Qwen2.5-0.5B-Instruct",
        tokenizer="Qwen/Qwen2.5-0.5B-Instruct",
        dtype="float16",
        notes="Tiny sanity check; fast to run.",
    ),
    ModelSpec(
        name="qwen-1_5b",
        base="Qwen/Qwen2.5-1.5B",
        sft="Qwen/Qwen2.5-1.5B-Instruct",
        tokenizer="Qwen/Qwen2.5-1.5B-Instruct",
        dtype="float16",
        notes="Fits comfortably on a single H200 GPU.",
    ),
    ModelSpec(
        name="qwen-3b",
        base="Qwen/Qwen2.5-3B",
        sft="Qwen/Qwen2.5-3B-Instruct",
        tokenizer="Qwen/Qwen2.5-3B-Instruct",
        notes="Baseline mid-size model.",
    ),
    ModelSpec(
        name="qwen-7b",
        base="Qwen/Qwen2.5-7B",
        sft="Qwen/Qwen2.5-7B-Instruct",
        tokenizer="Qwen/Qwen2.5-7B-Instruct",
        notes="Primary comparison target.",
    ),
    ModelSpec(
        name="qwen-14b",
        base="Qwen/Qwen2.5-14B",
        sft="Qwen/Qwen2.5-14B-Instruct",
        tokenizer="Qwen/Qwen2.5-14B-Instruct",
        notes="Largest Qwen2.5 variant that fits on H200.",
    ),
    ModelSpec(
        name="mistral-minitron-8b",
        base="nvidia/Mistral-NeMo-Minitron-8B-Base",
        sft="nvidia/Mistral-NeMo-Minitron-8B-Instruct",
        tokenizer="nvidia/Mistral-NeMo-Minitron-8B-Instruct",
        notes="NVIDIA NeMo Minitron pair.",
    ),
]

MODEL_SPECS


[ModelSpec(name='qwen-0_5b', base='Qwen/Qwen2.5-0.5B', sft='Qwen/Qwen2.5-0.5B-Instruct', tokenizer='Qwen/Qwen2.5-0.5B-Instruct', dtype='float16', device='cuda:0', notes='Tiny sanity check; fast to run.'),
 ModelSpec(name='qwen-1_5b', base='Qwen/Qwen2.5-1.5B', sft='Qwen/Qwen2.5-1.5B-Instruct', tokenizer='Qwen/Qwen2.5-1.5B-Instruct', dtype='float16', device='cuda:0', notes='Fits comfortably on a single H200 GPU.'),
 ModelSpec(name='qwen-3b', base='Qwen/Qwen2.5-3B', sft='Qwen/Qwen2.5-3B-Instruct', tokenizer='Qwen/Qwen2.5-3B-Instruct', dtype='bfloat16', device='cuda:0', notes='Baseline mid-size model.'),
 ModelSpec(name='qwen-7b', base='Qwen/Qwen2.5-7B', sft='Qwen/Qwen2.5-7B-Instruct', tokenizer='Qwen/Qwen2.5-7B-Instruct', dtype='bfloat16', device='cuda:0', notes='Primary comparison target.'),
 ModelSpec(name='qwen-14b', base='Qwen/Qwen2.5-14B', sft='Qwen/Qwen2.5-14B-Instruct', tokenizer='Qwen/Qwen2.5-14B-Instruct', dtype='bfloat16', device='cuda:0', notes='Largest Qwen2.5 variant that fit

In [10]:
JSON_PATH_RE = re.compile(r"Wrote comparison JSON to:\s*(.+)")


def _ssh_base() -> List[str]:
    cmd = ["ssh", "-p", str(SSH_PORT)]
    identity = None
    if SSH_IDENTITY:
        identity = Path(SSH_IDENTITY).expanduser()
    if identity and identity.exists():
        cmd += ["-i", str(identity)]
    cmd.append(f"{SSH_USER}@{SSH_HOST}")
    return cmd


def _scp_base() -> List[str]:
    cmd = ["scp", "-P", str(SSH_PORT)]
    identity = None
    if SSH_IDENTITY:
        identity = Path(SSH_IDENTITY).expanduser()
    if identity and identity.exists():
        cmd += ["-i", str(identity)]
    return cmd


def run_remote(command: str, *, desc: Optional[str] = None, check: bool = True) -> subprocess.CompletedProcess:
    if desc:
        tqdm.write(desc)
    full_cmd = _ssh_base() + [command]
    result = subprocess.run(full_cmd, capture_output=True, text=True)
    if check and result.returncode != 0:
        raise RuntimeError(
            f"Remote command failed ({result.returncode})\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
        )
    return result


def scp_to_remote(local_path: Path, remote_path: str, *, desc: Optional[str] = None) -> None:
    if desc:
        tqdm.write(f"Upload → {remote_path}: {local_path}")
    local_path = local_path.expanduser().resolve()
    cmd = _scp_base() + [str(local_path), f"{SSH_USER}@{SSH_HOST}:{remote_path}"]
    subprocess.run(cmd, check=True)


def scp_from_remote(remote_path: str, local_path: Path, *, desc: Optional[str] = None) -> None:
    if desc:
        tqdm.write(f"Download ← {remote_path} -> {local_path}")
    local_path.parent.mkdir(parents=True, exist_ok=True)
    cmd = _scp_base() + [f"{SSH_USER}@{SSH_HOST}:{remote_path}", str(local_path)]
    subprocess.run(cmd, check=True)


def write_log(log_path: Path, stdout: str, stderr: str) -> None:
    log_path.parent.mkdir(parents=True, exist_ok=True)
    log_path.write_text(
        stdout + ("\n--- stderr ---\n" + stderr if stderr else ""),
        encoding="utf-8",
    )


def format_env(env: Dict[str, str]) -> str:
    return " ".join(f"{key}={shlex.quote(str(value))}" for key, value in env.items())


In [None]:
def inspect_remote_repo(*, log: bool = True) -> Dict[str, object]:
    script = textwrap.dedent(
        f"""
        from pathlib import Path
        import json

        workspace = Path('/workspace')
        repo = Path(r"{REMOTE_REPO}")
        configs_dir = repo / 'configs'
        config_path = configs_dir / 'prod_config.yaml'

        def list_dir(path):
            if not path.exists():
                return None
            return sorted(f"{{child.name}}/" if child.is_dir() else child.name for child in path.iterdir())

        info = {{
            "workspace_entries": list_dir(workspace),
            "repo_exists": repo.exists(),
            "repo_entries": list_dir(repo),
            "configs_entries": list_dir(configs_dir),
            "config_exists": config_path.exists(),
            "repo_path": str(repo),
        }}
        print(json.dumps(info))
        """
    )
    desc = "Inspect remote filesystem" if log else None
    result = run_remote(
        f"{REMOTE_PYTHON} - <<'PY'\n{script}\nPY",
        desc=desc,
        check=False,
    )
    info: Dict[str, object] = {}
    stdout = (result.stdout or "").strip()
    if stdout:
        try:
            info = json.loads(stdout)
        except json.JSONDecodeError:
            info = {}
            if log:
                tqdm.write(stdout)
    if log and info:
        tqdm.write(json.dumps(info, indent=2))
    if result.stderr:
        tqdm.write("--- stderr ---\n" + result.stderr)
    return info


def clone_repo_via_git() -> bool:
    clone_script = textwrap.dedent(
        f"""
        set -euo pipefail
        mkdir -p {REMOTE_REPO_PARENT}
        cd {REMOTE_REPO_PARENT}
        rm -rf {REMOTE_REPO_NAME}
        git clone {REMOTE_REPO_URL} {REMOTE_REPO_NAME}
        """
    )
    result = run_remote(
        f"bash -lc \"{clone_script}\"",
        desc="Clone repo via git",
        check=False,
    )
    if result.returncode != 0:
        tqdm.write("Git clone failed; stderr →")
        tqdm.write(result.stderr or "(no stderr)")
    return result.returncode == 0


def upload_repo_snapshot() -> None:
    fd, tmp_path = tempfile.mkstemp(suffix=".tar.gz")
    os.close(fd)
    tmp_file = Path(tmp_path)
    try:
        with tarfile.open(tmp_file, "w:gz") as archive:
            archive.add(str(REPO_ROOT), arcname=REMOTE_REPO_NAME)
        remote_tar = f"/tmp/{REMOTE_REPO_NAME}.tar.gz"
        scp_to_remote(tmp_file, remote_tar, desc="Upload repo snapshot")
        extract_script = textwrap.dedent(
            f"""
            set -euo pipefail
            rm -rf {REMOTE_REPO}
            mkdir -p {REMOTE_REPO_PARENT}
            tar -xzf {remote_tar} -C {REMOTE_REPO_PARENT}
            rm -f {remote_tar}
            """
        )
        run_remote(
            f"bash -lc \"{extract_script}\"",
            desc="Extract repo snapshot",
        )
    finally:
        tmp_file.unlink(missing_ok=True)


def sync_prod_config_file() -> None:
    if not LOCAL_CONFIG_PATH.exists():
        raise FileNotFoundError(f"Missing local config: {LOCAL_CONFIG_PATH}")
    run_remote(
        f"bash -lc \"mkdir -p {REMOTE_REPO}/configs\"",
        desc="Ensure remote configs dir",
    )
    scp_to_remote(LOCAL_CONFIG_PATH, PROD_CONFIG_REMOTE, desc="Sync prod_config.yaml")


def validate_remote_config() -> bool:
    script = textwrap.dedent(
        f"""
        import sys
        import yaml
        from pathlib import Path

        config_path = Path(r"{PROD_CONFIG_REMOTE}")
        if not config_path.exists():
            print("prod_config.yaml missing")
            sys.exit(1)
        try:
            yaml.safe_load(config_path.read_text(encoding='utf-8'))
        except Exception as exc:  # pylint: disable=broad-except
            print("YAML error:", exc)
            sys.exit(1)
        """
    )
    result = run_remote(
        f"{REMOTE_PYTHON} - <<'PY'\n{script}\nPY",
        desc="Validate remote prod_config.yaml",
        check=False,
    )
    if result.returncode != 0 and result.stdout:
        tqdm.write(result.stdout)
    return result.returncode == 0


def ensure_remote_repo() -> Dict[str, object]:
    status = inspect_remote_repo(log=False)
    if not status.get("repo_exists"):
        if not clone_repo_via_git():
            tqdm.write("Falling back to local snapshot upload.")
            upload_repo_snapshot()
        status = inspect_remote_repo(log=False)
    if not status.get("config_exists"):
        tqdm.write("Remote prod_config.yaml missing; syncing local copy.")
        sync_prod_config_file()
        status = inspect_remote_repo(log=False)
    elif not validate_remote_config():
        tqdm.write("Remote prod_config.yaml invalid; syncing local copy.")
        sync_prod_config_file()
        validate_remote_config()
    run_remote(
        f"bash -lc \"mkdir -p {REMOTE_PROMPT_DIR}\"",
        desc="Ensure remote prompt dir",
    )
    return status


def ensure_remote_ready() -> Dict[str, object]:
    status = ensure_remote_repo()
    tqdm.write("Remote repo status:")
    tqdm.write(json.dumps(status, indent=2))
    return status



In [12]:
def ensure_remote_repo() -> None:
    script = textwrap.dedent(
        f"""
        import shutil
        import subprocess
        from pathlib import Path

        repo_path = Path(r"{REMOTE_REPO}")
        workspace = repo_path.parent
        config_path = Path(r"{PROD_CONFIG_REMOTE}")
        repo_url = "https://github.com/spencermcbridemoore/vastai-ssh-jupyter-pytorch.git"

        def clone_repo() -> None:
            workspace.mkdir(parents=True, exist_ok=True)
            subprocess.run(["git", "clone", repo_url, str(repo_path)], check=True)

        if not repo_path.exists():
            clone_repo()
        elif not config_path.exists():
            print("prod_config.yaml missing; recloning repo")
            if repo_path.exists():
                shutil.rmtree(repo_path)
            clone_repo()

        (repo_path / "experiments" / "prompts").mkdir(parents=True, exist_ok=True)
        """
    )
    run_remote(
        f"{REMOTE_PYTHON} - <<'PY'\n{script}\nPY",
        desc="Ensure remote repo + prompt directory exists",
    )


def ensure_remote_ready() -> None:
    ensure_remote_repo()
    inspect_remote_repo()


ensure_remote_ready()
tqdm.write("Verified remote repo checkout and prompt directory (with inspection).")


Ensure remote repo + prompt directory exists
Inspect remote filesystem
{
  "workspace_entries": [
    ".hf_home/",
    ".venv-backups/",
    "vastai-ssh-jupyter-pytorch/"
  ],
  "repo_exists": true,
  "repo_entries": [
    ".git/",
    ".gitignore",
    "NEXT_STEPS.md",
    "QUICKSTART.md",
    "README.md",
    "configs/",
    "docs/",
    "env.template",
    "experiments/",
    "h200_outputs_multi/",
    "notebooks/",
    "requirements.txt",
    "scripts/",
    "setup/",
    "src/",
    "tests/"
  ],
  "configs_entries": [
    "dev_config.yaml",
    "prod_config.notebook.bak",
    "prod_config.yaml"
  ],
  "config_exists": true,
  "repo_path": "/workspace/vastai-ssh-jupyter-pytorch"
}
--- stderr ---
Welcome to vast.ai. If authentication fails, try again after a few seconds, and double check your ssh key.
Have fun!

Verified remote repo checkout and prompt directory (with inspection).


In [13]:
with PROMPT_SOURCE.open("r", encoding="utf-8") as handle:
    RAW_PROMPTS = [line.strip() for line in handle if line.strip()]

if len(RAW_PROMPTS) != 4:
    raise ValueError(f"Expected 4 prompts, found {len(RAW_PROMPTS)} in {PROMPT_SOURCE}")

base_tokenizer = AutoTokenizer.from_pretrained(FALLBACK_TOKENIZER)


def truncate_prompt(text: str, max_tokens: Optional[int]) -> str:
    if max_tokens is None:
        return text
    encoded = base_tokenizer(text, add_special_tokens=False)
    ids = encoded["input_ids"]
    if len(ids) <= max_tokens:
        return text
    truncated = base_tokenizer.decode(
        ids[:max_tokens],
        skip_special_tokens=False,
        clean_up_tokenization_spaces=False,
    )
    return truncated


PROMPT_VARIANTS: Dict[str, List[str]] = {
    "full": RAW_PROMPTS,
    "tok100": [truncate_prompt(prompt, 100) for prompt in RAW_PROMPTS],
    "tok50": [truncate_prompt(prompt, 50) for prompt in RAW_PROMPTS],
}

LOCAL_PROMPT_FILES: Dict[str, Path] = {}
for label, prompts in PROMPT_VARIANTS.items():
    local_path = PROMPT_VARIANTS_DIR / f"prod_prompts_{label}.txt"
    with local_path.open("w", encoding="utf-8") as handle:
        for prompt in prompts:
            handle.write(prompt + "\n")
    LOCAL_PROMPT_FILES[label] = local_path
    tqdm.write(f"Wrote {label} prompt file → {local_path}")

REMOTE_PROMPT_FILES: Dict[str, str] = {}
for label in PROMPT_VARIANT_ORDER:
    local_path = LOCAL_PROMPT_FILES[label]
    remote_path = f"{REMOTE_PROMPT_DIR}/prod_prompts_{label}.txt"
    scp_to_remote(local_path, remote_path, desc=f"Sync {label} prompts")
    REMOTE_PROMPT_FILES[label] = remote_path

REMOTE_PROMPT_FILES


Wrote full prompt file → C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_full.txt
Wrote tok100 prompt file → C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_tok100.txt
Wrote tok50 prompt file → C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_tok50.txt
Upload → /workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_full.txt: C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_full.txt
Upload → /workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_tok100.txt: C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_tok100.txt
Upload → /workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_tok50.txt: C:\Users\spenc\Cursor Repos\vastai-ssh-jupyter-pytorch\notebooks\generated_prompts\prod_prompts_tok50.txt


{'full': '/workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_full.txt',
 'tok100': '/workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_tok100.txt',
 'tok50': '/workspace/vastai-ssh-jupyter-pytorch/experiments/prompts/prod_prompts_tok50.txt'}

In [14]:
CONFIG_PATCH_SCRIPT = textwrap.dedent(
    f"""
    import yaml
    from pathlib import Path

    config_path = Path(r"{PROD_CONFIG_REMOTE}")
    backup_path = Path(r"{PROD_CONFIG_BACKUP}")
    if not backup_path.exists():
        backup_path.write_text(config_path.read_text(encoding='utf-8'), encoding='utf-8')

    data = yaml.safe_load(config_path.read_text(encoding='utf-8'))
    rc = data.get('residual_compare', {{}})
    rc['prompt_variants'] = ['identity']
    rc['prompt_variant_options'] = {{}}
    rc['embedding_variants'] = [
        {{
            'name': 'default',
            'base': {{'embedding_source': 'base', 'unembedding_source': 'base'}},
            'sft': {{'embedding_source': 'sft', 'unembedding_source': 'sft'}},
        }}
    ]
    data['residual_compare'] = rc
    config_path.write_text(yaml.safe_dump(data, sort_keys=False), encoding='utf-8')
    """
)

CONFIG_RESTORE_SCRIPT = textwrap.dedent(
    f"""
    from pathlib import Path

    config_path = Path(r"{PROD_CONFIG_REMOTE}")
    backup_path = Path(r"{PROD_CONFIG_BACKUP}")
    if backup_path.exists():
        config_path.write_text(backup_path.read_text(encoding='utf-8'), encoding='utf-8')
    """
)


def apply_config_override() -> None:
    ensure_remote_repo()
    run_remote(
        f"cd {REMOTE_REPO} && {REMOTE_PYTHON} - <<'PY'\n{CONFIG_PATCH_SCRIPT}\nPY",
        desc="Apply no-fuzz config override",
    )


def restore_config_override() -> None:
    ensure_remote_repo()
    run_remote(
        f"cd {REMOTE_REPO} && {REMOTE_PYTHON} - <<'PY'\n{CONFIG_RESTORE_SCRIPT}\nPY",
        desc="Restore original prod_config.yaml",
        check=False,
    )


apply_config_override()
tqdm.write("prod_config.yaml patched (prompt_variants=identity, embedding_variants=default).")


Ensure remote repo + prompt directory exists
Apply no-fuzz config override


RuntimeError: Remote command failed (1)
STDOUT:

STDERR:
Welcome to vast.ai. If authentication fails, try again after a few seconds, and double check your ssh key.
Have fun!
Traceback (most recent call last):
  File "<stdin>", line 10, in <module>
  File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 125, in safe_load
    return load(stream, SafeLoader)
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/__init__.py", line 81, in load
    return loader.get_single_data()
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/constructor.py", line 49, in get_single_data
    node = self.get_single_node()
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 36, in get_single_node
    document = self.compose_document()
               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 55, in compose_document
    node = self.compose_node(None, None)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 84, in compose_node
    node = self.compose_mapping_node(anchor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 133, in compose_mapping_node
    item_value = self.compose_node(node, item_key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 84, in compose_node
    node = self.compose_mapping_node(anchor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 133, in compose_mapping_node
    item_value = self.compose_node(node, item_key)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 84, in compose_node
    node = self.compose_mapping_node(anchor)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/composer.py", line 127, in compose_mapping_node
    while not self.check_event(MappingEndEvent):
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/parser.py", line 98, in check_event
    self.current_event = self.state()
                         ^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/yaml/parser.py", line 438, in parse_block_mapping_key
    raise ParserError("while parsing a block mapping", self.marks[-1],
yaml.parser.ParserError: while parsing a block mapping
  in "<unicode string>", line 111, column 5:
        enabled: false
        ^
expected <block end>, but found '-'
  in "<unicode string>", line 132, column 5:
        - name: "shared_unembedding"
        ^


In [None]:
def prefetch_models(specs: Iterable[ModelSpec]) -> None:
    repos = set()
    for spec in specs:
        repos.update([spec.base, spec.sft, spec.tokenizer])
    for repo in tqdm(sorted(repos), desc="Prefetching HF weights", unit="repo"):
        script = textwrap.dedent(
            f"""
            from huggingface_hub import snapshot_download
            snapshot_download('{repo}', repo_type='model', resume_download=True)
            """
        )
        run_remote(
            f"cd {REMOTE_REPO} && {REMOTE_PYTHON} - <<'PY'\n{script}\nPY",
            desc=f"Cache {repo}",
            check=False,
        )


prefetch_models(MODEL_SPECS)
tqdm.write("All requested models/tokenizers have been cached on the H200.")


In [None]:
BASE_ENV = {
    "DEV_MODE": "False",
    "PYTHONUNBUFFERED": "1",
}


def build_env(spec: ModelSpec, prompt_label: str) -> Dict[str, str]:
    env = dict(BASE_ENV)
    env.update(
        {
            "RESIDUAL_BASE_MODEL": spec.base,
            "RESIDUAL_SFT_MODEL": spec.sft,
            "RESIDUAL_TOKENIZER": spec.tokenizer,
            "RESIDUAL_DTYPE": spec.dtype,
            "RESIDUAL_DEVICE": spec.device,
            "RESIDUAL_PROMPT_FILE": REMOTE_PROMPT_FILES[prompt_label],
        }
    )
    return env


def execute_residual_attempt(spec: ModelSpec, prompt_label: str) -> Dict[str, object]:
    env_line = format_env(build_env(spec, prompt_label))
    command = f"cd {REMOTE_REPO} && {env_line} {REMOTE_PYTHON} experiments/base_vs_sft_residual.py"
    result = run_remote(command, desc=f"{spec.name} [{prompt_label}]", check=False)
    match = JSON_PATH_RE.search(result.stdout)
    json_path = match.group(1).strip() if match else None
    return {
        "completed_process": result,
        "json_path": json_path,
        "prompt_label": prompt_label,
    }


def download_artifacts(spec: ModelSpec, prompt_label: str, json_remote: str, result: subprocess.CompletedProcess) -> Dict[str, str]:
    model_dir = LOCAL_OUTPUT_ROOT / spec.name
    model_dir.mkdir(parents=True, exist_ok=True)
    local_json = model_dir / f"{Path(json_remote).name}" if json_remote else model_dir / f"{spec.name}_missing.json"
    scp_from_remote(json_remote, local_json, desc=f"JSON → {spec.name}")
    log_path = model_dir / f"{local_json.stem}_{prompt_label}.log"
    write_log(log_path, result.stdout, result.stderr)
    return {"json_local": str(local_json), "log_local": str(log_path)}


def run_with_fallbacks(spec: ModelSpec, prompts: Iterable[str] = PROMPT_VARIANT_ORDER) -> Dict[str, str]:
    attempts = list(prompts)
    for label in attempts:
        outcome = execute_residual_attempt(spec, label)
        result = outcome["completed_process"]
        if result.returncode == 0 and outcome["json_path"]:
            artifacts = download_artifacts(spec, label, outcome["json_path"], result)
            tqdm.write(f"✅ {spec.name} succeeded with {label} prompts")
            return {
                "model": spec.name,
                "prompt_variant": label,
                "remote_json": outcome["json_path"],
                **artifacts,
            }
        tqdm.write(
            f"⚠️ {spec.name} failed with {label} prompts (rc={result.returncode}). Trying next fallback..."
        )
    raise RuntimeError(f"All prompt variants failed for {spec.name}. Check remote logs for details.")


In [None]:
run_records: List[Dict[str, str]] = []

for spec in tqdm(MODEL_SPECS, desc="Residual sweep", unit="model"):
    try:
        outcome = run_with_fallbacks(spec)
        outcome["status"] = "ok"
    except Exception as exc:  # pylint: disable=broad-except
        tqdm.write(f"❌ {spec.name} failed: {exc}")
        outcome = {"model": spec.name, "status": "failed", "error": str(exc)}
    run_records.append(outcome)

summary_df = pd.DataFrame(run_records)
summary_df


In [None]:
# Run this cell after all experiments finish to restore prod_config.yaml
restore_config_override()
tqdm.write("prod_config.yaml restored from backup.")


In [None]:
import sys
sys.path.insert(0, str(REPO_ROOT / "src"))

from src.analysis.residual_results import iter_results, summarize_file  # type: ignore

json_paths = sorted(LOCAL_OUTPUT_ROOT.rglob("residual_compare_*.json"))
if not json_paths:
    raise RuntimeError(f"No residual_compare JSON files found in {LOCAL_OUTPUT_ROOT}")

summaries = []
for path in json_paths:
    summary = summarize_file(path)
    summaries.append(
        {
            "model": path.parent.name,
            "json_path": str(path),
            "num_results": summary.num_results,
            "avg_tokens": summary.avg_tokens,
        }
    )

summary_table = pd.DataFrame(summaries)
summary_table
