In [4]:
# !pip install accelerate scikit-learn #matplotlib seaborn datasets scikit-learn modal

In [6]:
import modal

stub = modal.Stub("rast-modal-runner")

image = modal.Image.debian_slim().pip_install(
    "torch",
    "transformers",
    "datasets",
    "tqdm",
    "numpy",
)

# Mount your steerit package (assuming it's local)
stub.mount("./steerit", remote_path="/root/steerit")

@stub.function(
    image=image,
    gpu="A10G",
    timeout=1800,
    secrets=[modal.Secret.from_name("huggingface")],
)
def run_rast():
    import math, random, warnings, re
    from collections import deque

    import numpy as np
    import torch, torch.nn.functional as F
    from tqdm import tqdm
    from datasets import load_dataset
    from transformers import AutoTokenizer, AutoModelForCausalLM

    from steerit.steerit import SteeringModel, SteeringVector

    MODEL_NAME = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
    HF_TOKEN   = os.getenv("HF_TOKEN", None)
    DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"
    DTYPE      = torch.float16 if DEVICE == "cuda" else torch.float32

    STEER_LAY  = 20
    K_WIN      = 16
    DKL_THR    = 0.15
    ALPHA_HI   = 7.0
    GEN_LIMIT  = 512
    MAX_TOKENS = 4096
    MAX_POOL   = 4000
    SUFFIX     = " Answer step by step and end with: Final answer:"
    SEED       = 42

    random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
    warnings.filterwarnings("ignore")

    tok = AutoTokenizer.from_pretrained(MODEL_NAME, use_auth_token=HF_TOKEN)
    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=DTYPE,
        device_map="auto" if DEVICE == "cuda" else None,
        use_auth_token=HF_TOKEN,
    )
    model = SteeringModel(base_model, [STEER_LAY], DEVICE)
    STOP_IDS = tok("Final answer:", add_special_tokens=False).input_ids

    def kl_div(a, b):
        p, q = torch.softmax(a, -1), torch.softmax(b, -1)
        return (p * (p.log() - q.log())).sum(-1).item()

    def numeric_match(text, gold):
        m = re.search(r'([-+]?\d+(?:\.\d+)?)\s*$', text)
        if not m: return False
        try: return math.isclose(float(m.group(1)), float(eval(gold)), 1e-3)
        except: return False

    def sample_next_cpu(logits_gpu, prev_ids, temperature=0.9, top_p=0.9, eps=1e-12):
        dev = logits_gpu.device
        logits = (logits_gpu.float() / temperature).cpu().numpy()[0].astype(np.float64)
        logits = np.clip(logits, -80.0, 80.0)
        probs = np.exp(logits - logits.max()); probs /= probs.sum()

        idx = np.argsort(probs)[::-1]; cum = np.cumsum(probs[idx])
        keep = cum <= top_p; keep[0] = True
        mask = np.zeros_like(probs, dtype=bool); mask[idx[keep]] = True
        probs = probs * mask + eps; probs /= probs.sum()

        if not np.isfinite(probs).all(): probs[:] = 1.0 / len(probs)
        for _ in range(10):
            tok = np.random.choice(len(probs), p=probs)
            if prev_ids.size(1) == 0 or tok != prev_ids[0, -1].item(): break
        return torch.tensor([[tok]], device=dev, dtype=torch.long)

    @torch.no_grad()
    def stream(prompt, max_len, model, tok, k_win=None):
        ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
        past, logs = None, []
        for _ in range(max_len):
            inp = ids if past is None else ids[:, -1:]
            out = model(input_ids=inp, past_key_values=past, use_cache=True)
            past = out.past_key_values
            if k_win:
                past = tuple((k[..., -k_win:, :], v[..., -k_win:, :]) for k, v in past)
            logits = out.logits[:, -1, :]
            logs.append(logits.detach())
            nxt = sample_next_cpu(logits, ids)
            ids = torch.cat([ids, nxt], -1)
            if nxt.item() in STOP_IDS or nxt.item() == tok.eos_token_id:
                break
        return ids.squeeze(0), logs

    gsm = load_dataset("gsm8k", "main")['test'].shuffle(SEED)
    train_rows = gsm.select(range(50))
    eval_rows = gsm.select(range(50, 60))

    def reservoir(pool, vec):
        if len(pool) < MAX_POOL: pool.append(vec)
        else:
            j = random.randrange(len(pool)+1)
            if j < MAX_POOL: pool[j] = vec

    hi, lo = [], []
    for row in tqdm(train_rows):
        prompt = f"Question: {row['question']}.{SUFFIX}"
        ids, logs = stream(prompt, GEN_LIMIT, base_model, tok, k_win=K_WIN)
        hs = model(input_ids=ids.unsqueeze(0).to(DEVICE), output_hidden_states=True).hidden_states[STEER_LAY][0]
        off = len(ids) - len(logs)
        for j in range(K_WIN, len(logs)):
            vec = hs[off+j].detach().cpu().numpy()
            reservoir(hi if kl_div(logs[j], logs[j-K_WIN]) >= DKL_THR else lo, vec)

    vec_dir = np.mean(hi, 0) - np.mean(lo, 0)
    vec_dir /= np.linalg.norm(vec_dir) + 1e-9
    rast_vec = SteeringVector({STEER_LAY: vec_dir.astype(np.float32)})

    @torch.no_grad()
    def rast_generate(prompt, steer_vec, model, tok, max_tokens=256):
        ids = tok(prompt, return_tensors="pt").to(DEVICE)["input_ids"]
        past, buf = None, deque(maxlen=K_WIN+1)
        model.set_steering(steer_vec, coeff=0.0)

        for _ in range(max_tokens):
            inp = ids if past is None else ids[:, -1:]
            out = model(input_ids=inp, past_key_values=past, use_cache=True)
            past = out.past_key_values
            past = tuple((k[..., -K_WIN:, :], v[..., -K_WIN:, :]) for k, v in past)
            logits = out.logits[:, -1, :]
            buf.append(logits.detach())
            model.coeff = ALPHA_HI if (len(buf) > K_WIN and kl_div(logits, buf[0]) < DKL_THR) else 0.0
            nxt = sample_next_cpu(logits, ids)
            ids = torch.cat([ids, nxt], -1)
            if nxt.item() in STOP_IDS or nxt.item() == tok.eos_token_id:
                break

        model.reset_steering()
        return ids.squeeze(0)

    tok_b, tok_r, acc_b, acc_r = [], [], [], []
    for row in tqdm(eval_rows):
        prompt = f"Question: {row['question']}.{SUFFIX}"
        base_ids, _ = stream(prompt, MAX_TOKENS, base_model, tok, k_win=K_WIN)
        rast_ids = rast_generate(prompt, rast_vec, model, tok, max_tokens=MAX_TOKENS)

        base_txt = tok.decode(base_ids, skip_special_tokens=True)
        rast_txt = tok.decode(rast_ids, skip_special_tokens=True)

        tok_b.append(base_ids.numel()); tok_r.append(rast_ids.numel())
        acc_b.append(numeric_match(base_txt, row["answer"]))
        acc_r.append(numeric_match(rast_txt, row["answer"]))

    print("\n──────── RESULTS ────────")
    print(f"Mean tokens baseline : {np.mean(tok_b):.1f}")
    print(f"Mean tokens RAST     : {np.mean(tok_r):.1f}")
    print(f"Token saving         : {100*(np.mean(tok_b)-np.mean(tok_r))/np.mean(tok_b):.1f}%")
    print(f"Accuracy baseline    : {np.mean(acc_b):.3f}")
    print(f"Accuracy RAST        : {np.mean(acc_r):.3f}")
    print("────────────────────────")


AttributeError: 'Stub' object has no attribute 'mount'