In [None]:
!pip install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

In [1]:
%pip install -q -U pip

# Core Hugging Face + bitsandbytes
%pip install -q \
  transformers \
  accelerate \
  datasets \
  peft \
  trl \
  bitsandbytes \
  huggingface_hub \
  gradio \
  sentencepiece

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.8 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.6/1.8 MB[0m [31m16.3 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
# pip installs

!pip install -q --upgrade requests==2.32.3 bitsandbytes>=0.43.1 transformers==4.48.3 accelerate==1.3.0 datasets==3.2.0 peft==0.14.0 trl==0.14.0 matplotlib wandb

In [2]:


# Sanity check (these should all import cleanly)
import datasets, transformers, accelerate, peft, trl, huggingface_hub, numpy
print({
    "datasets": datasets.__version__,
    "transformers": transformers.__version__,
    "accelerate": accelerate.__version__,
    "peft": peft.__version__,
    "trl": trl.__version__,
    "numpy": numpy.__version__,
})

{'datasets': '4.0.0', 'transformers': '4.55.2', 'accelerate': '1.10.0', 'peft': '0.17.0', 'trl': '0.21.0', 'numpy': '2.0.2'}


In [3]:
! pip install evaluate

Collecting evaluate
  Downloading evaluate-0.4.5-py3-none-any.whl.metadata (9.5 kB)
Downloading evaluate-0.4.5-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.5



# 1) Imports & device


In [4]:

import os, json, math, random, torch
from dataclasses import dataclass
from typing import Dict, List

import datasets
from datasets import load_dataset
from huggingface_hub import hf_hub_download

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)

from peft import LoraConfig
from trl import SFTTrainer
import evaluate
import gradio as gr

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cuda


In [5]:
import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

In [6]:
# Safety for T4: use fp16
DTYPE = torch.float32

# 2) Choose model (BioMistral‑7B — biomedical-tuned Mistral)
#### Source: https://huggingface.co/BioMistral/BioMistral-7B

In [7]:
BASE_MODEL = "BioMistral/BioMistral-7B"

### 4‑bit quantization config for QLoRA

In [8]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=DTYPE
)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
# Fallback pad token if missing
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/72.0 [00:00<?, ?B/s]

In [10]:
import torch, sys
print("Torch:", torch.__version__, "| Python:", sys.version.split()[0])

Torch: 2.8.0+cu126 | Python: 3.12.11


In [11]:
#!pip -q install triton

In [12]:
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map=device,
    quantization_config=bnb_config,
    torch_dtype=DTYPE
)
model.config.use_cache = False  # better for training with gradient checkpointing

config.json:   0%|          | 0.00/567 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/14.5G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/111 [00:00<?, ?B/s]

In [13]:
# ✅ Attach LoRA adapters to a 4‑bit model so it’s trainable
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

# (Optional) for large models, helps memory
# model.gradient_checkpointing_enable()

# Make the quantized layers train-ready
model = prepare_model_for_kbit_training(model)

# LoRA config — common target modules for Llama/Mistral-style blocks
peft_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules= ["q_proj","v_proj","up_proj","down_proj"],
)

# Wrap the base model with LoRA adapters
model = get_peft_model(model, peft_config)

# (Nice to have) show how many params will actually train
def print_trainable_params(m):
    trainable, total = 0, 0
    for _, p in m.named_parameters():
        num = p.numel()
        total += num
        if p.requires_grad:
            trainable += num
    print(f"Trainable params: {trainable:,} / {total:,} "
          f"({100*trainable/total:.2f}%)")

print_trainable_params(model)

Trainable params: 12,845,056 / 3,764,916,224 (0.34%)


# 3) Load BioProBench — Protocol Generation split (GEN)
#### Dataset hub: https://huggingface.co/datasets/BioProBench/BioProBench
### We'll directly load the JSON from the repo (GEN for train, GEN_test for eval).

In [14]:
BIOPROBENCH_REPO = "BioProBench/BioProBench"
gen_train_url = "https://huggingface.co/datasets/BioProBench/BioProBench/resolve/main/GEN.json"
gen_test_url  = "https://huggingface.co/datasets/BioProBench/BioProBench/resolve/main/GEN_test.json"


In [28]:
import requests, json
from datasets import Dataset, Features, Value, Sequence

GEN_TRAIN_URL = "https://huggingface.co/datasets/BioProBench/BioProBench/resolve/main/GEN.json"
GEN_TEST_URL  = "https://huggingface.co/datasets/BioProBench/BioProBench/resolve/main/GEN_test.json"

def fetch_json(url: str):
    r = requests.get(url, timeout=60)
    r.raise_for_status()
    return r.json()

def normalize_records(raw):
    cleaned = []
    for ex in raw:
        out = ex.get("output")
        # normalize output → list[str]
        if out is None:
            out_list = []
        elif isinstance(out, list):
            out_list = [str(s).strip() for s in out if str(s).strip()]
        else:
            out_list = [str(out).strip()] if str(out).strip() else []

        cleaned.append({
            "system_prompt": str(ex.get("system_prompt") or ""),
            "instruction":   str(ex.get("instruction")   or ""),
            "input":         str(ex.get("input")         or ""),
            "output":        out_list,
            "id":            str(ex.get("id")            or ""),
            "type":          str(ex.get("type")          or ""),
        })
    return cleaned

train_raw = fetch_json(GEN_TRAIN_URL)
test_raw  = fetch_json(GEN_TEST_URL)

train_clean = normalize_records(train_raw)
test_clean  = normalize_records(test_raw)

features = Features({
    "system_prompt": Value("string"),
    "instruction":   Value("string"),
    "input":         Value("string"),
    "output":        Sequence(Value("string")),
    "id":            Value("string"),
    "type":          Value("string"),
})

train_ds = Dataset.from_list(train_clean, features=features)
eval_ds  = Dataset.from_list(test_clean,  features=features)
print(train_ds[0].keys())
print("Train size:", len(train_ds), "Eval size:", len(eval_ds))


dict_keys(['system_prompt', 'instruction', 'input', 'output', 'id', 'type'])
Train size: 118955 Eval size: 772


In [29]:
KEYWORDS = ("PCR", "qPCR", "RT‑PCR", "reverse transcription", "gel", "electrophoresis",
            "cell culture", "seeding", "transfection", "western blot", "ELISA")

def domain_filter(ex):
    blob = f"{ex['system_prompt']} {ex['instruction']} {ex['input']}".lower()
    return any(k.lower() in blob for k in KEYWORDS)

train_ds_domain = train_ds.filter(domain_filter)
eval_ds_domain  = eval_ds.filter(domain_filter)

# then optionally cap
N = 30000
train_ds = train_ds_domain.select(range(min(N, len(train_ds_domain))))
eval_ds  = eval_ds_domain.select(range(min(1000, len(eval_ds_domain))))

print("Train size:", len(train_ds), "Eval size:", len(eval_ds))

Filter:   0%|          | 0/118955 [00:00<?, ? examples/s]

Filter:   0%|          | 0/772 [00:00<?, ? examples/s]

Train size: 21570 Eval size: 130


# The GEN items look like:
# {
#   "system_prompt": "...",
#   "instruction": "... formatting/constraints ...",
#   "input": "Goal / scenario (what to do)",
#   "output": ["1. ...", "2. ...", ...],   # list of step strings
#   "id": "...",
#   "type": "easy|medium|hard"
# }
#
# We'll turn each item into a single **instruction-tuning** string with this format:
#
#### System:
# <system_prompt>
#### Instruction:
# <instruction>
#### Input:
# <input>
#### Response:
# <joined numbered steps>

In [30]:
RESPONSE_HEADER = "### Response:\n"

def format_row(ex):
    out_text = "\n".join(ex["output"])
    return {
        "text":
        f"### System:\n{ex['system_prompt']}\n\n"
        f"### Instruction:\n{ex['instruction']}\n\n"
        f"### Input:\n{ex['input']}\n\n"
        f"{RESPONSE_HEADER}{out_text}"
    }

train_ds_fmt = train_ds.map(format_row, remove_columns=train_ds.column_names)
eval_ds_fmt  = eval_ds.select(range(min(400, len(eval_ds)))).map(format_row, remove_columns=eval_ds.column_names)

Map:   0%|          | 0/21570 [00:00<?, ? examples/s]

Map:   0%|          | 0/130 [00:00<?, ? examples/s]

In [31]:
def format_example(ex: Dict) -> str:
    sys_ = ex.get("system_prompt", "").strip()
    inst = ex.get("instruction", "").strip()
    inp  = ex.get("input", "").strip()
    out  = ex.get("output", [])
    if isinstance(out, list):
        out_text = "\n".join([s.strip() for s in out if isinstance(s, str) and s.strip()])
    else:
        out_text = str(out).strip()

    text = (
        f"### System:\n{sys_}\n\n"
        f"### Instruction:\n{inst}\n\n"
        f"### Input:\n{inp}\n\n"
        f"{RESPONSE_HEADER}{out_text}"
    )
    return text

In [32]:
def map_fn(example):
    return {"text": format_example(example)}

In [33]:
print(train_ds_fmt[0]["text"][:1000])

### System:
As a specialist in Biochemical & Molecular Functional Analysis, provide clear step-by-step instructions for experimental procedures.

### Instruction:
Please describe the protocol in a flat list format (using only 1., 2., 3. numbers). Include only the steps, not a rationale or materials list. Use concise language and maintain a chronological order.

### Input:
To prepare cell culture extracts for analyzing mitochondrial and cytosolic aconitase activities, cell pellet harvesting, washing, lysis, and extraction are essential. How to prepare cell culture extracts?

### Response:
1. Harvest and wash cell pellets with PBS at 4°C.
2. Lyse cell pellets using extraction buffer at 4°C.
3. Incubate lysed cells on ice with intermittent agitation.
4. Centrifuge to collect the supernatant (protein extract).


# 4) Data collator that **only computes loss on the Response**
### We don't want the model to learn to emit the prompt tokens.
### Using TRL's DataCollatorForCompletionOnlyLM with the "
### Response:\n" anchor.

In [34]:
from dataclasses import dataclass
import torch

RESPONSE_HEADER = "### Response:\n"

@dataclass
class ResponseOnlyDataCollator:
    tokenizer: any
    response_template: str = RESPONSE_HEADER
    max_length: int = 1024

    def __call__(self, features):
        texts = [f["text"] for f in features]
        batch = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        labels = batch["input_ids"].clone()

        # mask everything before the response template
        for i, text in enumerate(texts):
            idx = text.find(self.response_template)
            if idx == -1:
                labels[i, :] = -100  # no response found → ignore
                continue

            prefix = text[: idx + len(self.response_template)]
            prefix_ids = self.tokenizer(prefix, add_special_tokens=False).input_ids
            cut = min(len(prefix_ids), labels.size(1))
            labels[i, :cut] = -100

        batch["labels"] = labels
        return batch

# ✅ use this collator instead of DataCollatorForCompletionOnlyLM
collator = ResponseOnlyDataCollator(tokenizer, RESPONSE_HEADER, max_length=1024)

# 6) Training args (keep Colab/T4 friendly)

In [35]:

OUT_DIR = "/content/bio-protocol-planner-lora"


In [36]:
training_args = TrainingArguments(
    output_dir=OUT_DIR,
    num_train_epochs=1,                  # one solid pass; you can bump to 2 if you have time
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,      # effective batch size = 128
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    logging_steps=20,
    save_steps=200,
    save_total_limit=2,
    eval_steps=200,
    fp16=(device=="cuda"),
    bf16=False,
    gradient_checkpointing=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    optim="paged_adamw_8bit"             # bitsandbytes optimizer to save VRAM
)

# 7) Trainer (TRL SFTTrainer)

In [37]:
from transformers import Trainer

trainer = Trainer(
    model=model,                    # now a PEFT-wrapped, k-bit-ready model
    args=training_args,
    train_dataset=train_ds_fmt,     # your dataset with a "text" field
    eval_dataset=eval_ds_fmt,
    data_collator=collator,         # the ResponseOnlyDataCollator you made
    processing_class=tokenizer,     # replaces deprecated "tokenizer=" kwarg
)


In [38]:
# Kick off training (≈ 20–40 min on T4; depends on load).
trainer.train()
# Save adapters
trainer.model.save_pretrained(OUT_DIR)
tokenizer.save_pretrained(OUT_DIR)

Step,Training Loss
20,2.0385
40,0.6632
60,0.631
80,0.5916
100,0.5746
120,0.5694
140,0.5749
160,0.5547
180,0.5791
200,0.5957


('/content/bio-protocol-planner-lora/tokenizer_config.json',
 '/content/bio-protocol-planner-lora/special_tokens_map.json',
 '/content/bio-protocol-planner-lora/chat_template.jinja',
 '/content/bio-protocol-planner-lora/tokenizer.model',
 '/content/bio-protocol-planner-lora/added_tokens.json',
 '/content/bio-protocol-planner-lora/tokenizer.json')

In [46]:
from huggingface_hub import notebook_login



PROJECT_RUN_NAME = "ruchirnamjoshi/BioMistralFinetuned"

trainer.model.push_to_hub(PROJECT_RUN_NAME, private=True)
print(f"Saved to the hub: {PROJECT_RUN_NAME}")

Processing Files (0 / 0)                : |          |  0.00B /  0.00B            

New Data Upload                         : |          |  0.00B /  0.00B            

  ...prlmgii8r/adapter_model.safetensors:   1%|1         |  554kB / 51.4MB            

Saved to the hub: ruchirnamjoshi/BioMistralFinetuned


In [51]:
from peft import PeftModel

In [53]:
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
)
base_model.generation_config.pad_token_id = tokenizer.pad_token_id


In [54]:
fine_tuned_model = PeftModel.from_pretrained(base_model, PROJECT_RUN_NAME)

adapter_model.safetensors:   0%|          | 0.00/51.4M [00:00<?, ?B/s]

In [57]:
# ---- Define a test prompt ----
prompt = "Design a step-by-step wet-lab protocol to test the effect of a new compound on E. coli growth. Keep it short and simple."

# ---- Tokenize input ----
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# ---- Generate output ----
outputs = fine_tuned_model.generate(
    **inputs,
    max_new_tokens=1024,   # how long the response can be
    temperature=0.1,      # lower = more deterministic
    top_p=0.9,            # nucleus sampling
    do_sample=True        # enable randomness
)

# ---- Decode prediction ----
result = tokenizer.decode(outputs[0], skip_special_tokens=True)
print("Generated Protocol:\n", result)

Generated Protocol:
 Design a step-by-step wet-lab protocol to test the effect of a new compound on E. coli growth. Keep it short and simple.
1. Prepare the bacterial culture.
2. Add the compound to the culture.
3. Incubate the culture.
4. Measure the optical density.
5. Analyze the results.

### Materials

- E. coli culture
- LB medium
- Compound stock solution
- 96-well plate
- Spectrophotometer

### Steps

1. Prepare the bacterial culture.
1.1. Inoculate 50 mL of LB medium with a single colony of E. coli.
1.2. Incubate the culture overnight at 37°C with shaking.
2. Add the compound to the culture.
2.1. Dilute the compound stock solution to the desired concentration.
2.2. Add 200 µL of the diluted compound to each well of a 96-well plate.
2.3. Add 200 µL of the bacterial culture to each well.
3. Incubate the culture.
3.1. Incubate the plate at 37°C with shaking for 24 hours.
4. Measure the optical density.
4.1. Measure the optical density of each well using a spectrophotometer.
5. An


# 8) Quick evaluation (ROUGE‑L) on a small eval slice

In [47]:

rouge = evaluate.load("rouge")

Downloading builder script: 0.00B [00:00, ?B/s]

In [48]:
def generate_protocol(prompt_text, max_new_tokens=512, temperature=0.3, top_p=0.9):
    # Build the same prompt skeleton the model saw during SFT
    sys_ = "You are an expert wet‑lab protocol planner. Output a single flat list of numbered steps (1., 2., 3., ...). No materials list or explanations."
    inst = "Given the goal and constraints, produce a chronological protocol in a single-level numbered list."
    text = f"### System:\n{sys_}\n\n### Instruction:\n{inst}\n\n### Input:\n{prompt_text}\n\n{RESPONSE_HEADER}"
    inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.no_grad():
        gen_ids = trainer.model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id
        )
    out = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
    # Grab only the part after "### Response:"
    if RESPONSE_HEADER in out:
        out = out.split(RESPONSE_HEADER, 1)[1].strip()
    return out

In [49]:
# Evaluate on ~64 samples for speed
sample_eval = eval_ds.select(range(min(64, len(eval_ds))))
preds, refs = [], []
for ex in sample_eval:
    prompt_text = f"{ex.get('input','').strip()}"
    pred = generate_protocol(prompt_text, max_new_tokens=384)
    preds.append(pred)
    # reference is a list of steps; join
    ref = "\n".join([s.strip() for s in ex.get("output", []) if isinstance(s, str)])
    refs.append(ref)

rouge_scores = rouge.compute(predictions=preds, references=refs, use_stemmer=True)
print("ROUGE‑L:", round(rouge_scores.get("rougeL", 0.0), 4))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Setting `past_key_value=None`.
Caching is incompatible with gradient checkpointing in MistralDecoderLayer. Set

ROUGE‑L: 0.0074


In [42]:
!pip install rouge_score

Collecting rouge_score
  Downloading rouge_score-0.1.2.tar.gz (17 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: rouge_score
[33m  DEPRECATION: Building 'rouge_score' using the legacy setup.py bdist_wheel mechanism, which will be removed in a future version. pip 25.3 will enforce this behaviour change. A possible replacement is to use the standardized build interface by setting the `--use-pep517` option, (possibly combined with `--no-build-isolation`), or adding a `pyproject.toml` file to the source tree of 'rouge_score'. Discussion can be found at https://github.com/pypa/pip/issues/6334[0m[33m
[0m  Building wheel for rouge_score (setup.py) ... [?25l[?25hdone
  Created wheel for rouge_score: filename=rouge_score-0.1.2-py3-none-any.whl size=24934 sha256=9b497a154f7df1962ef223d92981a0599467a243b7f82885fe7b66172139f71b
  Stored in directory: /root/.cache/pip/wheels/85/9d/af/01feefbe7d55ef5468796f0c68225b6788e85d9d0a281e7a70
Successfu