In [12]:
import pandas as pd
import numpy as np
import re, random, pandas as pd
random.seed(7)


In [13]:
%reload_ext autoreload
%autoreload 2

In [14]:
%pwd

'/Users/yvesgreatti/github/drug_discovery/notebook'

In [15]:

# Load the dataset
usecols = ["entry","uniprot_description","GPT_description","all_description"]
df = pd.read_csv("../mutadescribe_data/structural_split/train.csv", usecols=usecols, low_memory=False)

# Prefer enriched text, fallback to raw
df["effect_text"] = df["all_description"].fillna(df["uniprot_description"]).fillna(df["GPT_description"])
df = df.dropna(subset=["effect_text"])
df.head()

Unnamed: 0,entry,uniprot_description,GPT_description,all_description,effect_text
0,Q8N884-D95A,No effect on type I IFN and RSAD2 induction. N...,,No effect on type I IFN and RSAD2 induction. N...,No effect on type I IFN and RSAD2 induction. N...
1,P06276-L153F,In BCHED; seems to cause reduced expression of...,The mutation in the BCHE gene leads to reduced...,In BCHED; seems to cause reduced expression of...,In BCHED; seems to cause reduced expression of...
2,P63096-E245L,Enhances interaction (inactive GDP-bound) with...,,Enhances interaction (inactive GDP-bound) with...,Enhances interaction (inactive GDP-bound) with...
3,O35244-S32A,Abolishes lipid binding.,"Increased Prdx6 alpha-helical content, key rol...",Abolishes lipid binding. Increased Prdx6 alpha...,Abolishes lipid binding. Increased Prdx6 alpha...
4,P80365-R337C,In AME; decreased half-life from 21 to 4 hours...,This mutation has been discovered in a consang...,In AME; decreased half-life from 21 to 4 hours...,In AME; decreased half-life from 21 to 4 hours...


In [16]:
# 2) Lightweight polarity rules (negation-aware)
NEG = [r"loss[- ]of[- ]function", r"\blof\b", r"decreas(?:e|es|ed|ing)", r"reduc(?:e|es|ed|ing)",
       r"impair(?:s|ed|ing|ment)", r"inactivat(?:e|es|ed|ing)", r"disrupt(?:s|ed|ion)",
       r"abolish(?:es|ed|ing)?", r"destabiliz(?:e|es|ed|ing)", r"misfold(?:s|ed|ing)?",
       r"defect(?:ive)?", r"deleterious", r"inhibit(?:s|ed|ing|ion)"]
POS = [r"gain[- ]of[- ]function", r"\bgof\b", r"increas(?:e|es|ed|ing)", r"enhanc(?:e|es|ed|ing)",
       r"activat(?:e|es|ed|ing)", r"stabiliz(?:e|es|ed|ing)", r"improv(?:e|es|ed|ing)",
       r"up[- ]?regulat(?:e|es|ed|ing)", r"beneficial", r"protect(?:ive|ion|s|ed)?"]
NEU = [r"no (?:significant|measurable|observable) (?:change|effect|difference)",
       r"does not (?:affect|alter|impact)", r"not (?:affect|alter|impact)(?:ed|ing)?",
       r"unchang(?:ed|ing)", r"wild[- ]type(?:[- ]like)?", r"\bWT[- ]?like\b",
       r"comparable to (?:wt|wild[- ]type)", r"\bneutral\b", r"\btolerated\b"]

NEGATORS = {"no","not","without","lack","lacks","lacking","fails","failed","absence"}

In [17]:
def _has_negator(left, n=4):
    toks = re.findall(r"[a-zA-Z']+", left.lower())
    return any(t in NEGATORS for t in toks[-n:])

def _count(text, pats):
    c=0
    for p in pats:
        for m in re.finditer(p, text, flags=re.I):
            if not _has_negator(text[:m.start()]):
                c+=1
    return c

def label(text):
    t = " ".join(str(text).split())
    if _count(t, NEU)>0: return "Not significant"
    neg, pos = _count(t, NEG), _count(t, POS)
    if neg==0 and pos==0: return "Unknown"
    return "Malignant" if neg>pos else "Benign"

df["label"] = df["effect_text"].apply(label)

# 3) Build balanced, short few-shots
def pick_shots(df, per_class=4, max_words=35):
    shots = []
    for cls in ["Malignant","Benign","Not significant","Unknown"]:
        cand = df[(df["label"]==cls) & (df["effect_text"].str.split().str.len()<=max_words)]
        # prefer sentences with a clear single clause
        cand = cand[cand["effect_text"].str.count(r"[.;]")<=2]
        take = cand.sample(min(per_class, len(cand)), random_state=7)["effect_text"].tolist()
        shots += [(t, cls) for t in take]
    random.shuffle(shots)
    return shots

shots = pick_shots(df, per_class=3)

In [18]:
SYSTEM = """You label the functional polarity of a protein mutation from text.
- Malignant = detrimental/negative effect (LoF, decreased activity, destabilization, inhibition, misfolding…)
- Benign = beneficial/positive effect (GoF, increased activity, stabilization, activation, rescue…)
- Not significant = explicitly no/negligible effect (WT-like, unchanged, no significant change).
- Unknown = insufficient/contradictory.
Respect negation (“no decrease” is NOT negative). Output exactly one: Malignant | Benign | Not significant | Unknown."""
print(SYSTEM)
print("\nFEW-SHOTS:")
for t, y in shots:
    print(f"TEXT: {t}\nLABEL: {y}\n")

You label the functional polarity of a protein mutation from text.
- Malignant = detrimental/negative effect (LoF, decreased activity, destabilization, inhibition, misfolding…)
- Benign = beneficial/positive effect (GoF, increased activity, stabilization, activation, rescue…)
- Not significant = explicitly no/negligible effect (WT-like, unchanged, no significant change).
- Unknown = insufficient/contradictory.
Respect negation (“no decrease” is NOT negative). Output exactly one: Malignant | Benign | Not significant | Unknown.

FEW-SHOTS:
TEXT: Increased binding affinity to curdlan compared to the wild-type.
LABEL: Not significant

TEXT: Gain of activity.
LABEL: Unknown

TEXT: Enables self-association and NF-kappa-B inhibition by B14. The mutation in the B14 protein promotes its binding to IKKβ and activating NF-κB-dependent gene expression.
LABEL: Benign

TEXT: No effect on BMAL1 binding.
LABEL: Unknown

TEXT: In RP62; results in a complete loss of kinase activity compared to wild-type

In [19]:
%pwd

'/Users/yvesgreatti/github/drug_discovery/notebook'

In [None]:
import os, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
HF_TOKEN = "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"
#HF_TOKEN = os.environ.get("HF_TOKEN")  # set this first if the repo is gated
CACHE_DIR = "../hf_models"  # choose where files go on disk


tok = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, token=HF_TOKEN, cache_dir=CACHE_DIR
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    device_map="cpu",                # <— CPU only
    torch_dtype=torch.float32,       # <— CPU needs fp32
    low_cpu_mem_usage=True,          # <— reduces peak load
    cache_dir=CACHE_DIR,
).eval()


In [2]:
from huggingface_hub import hf_hub_download
import os
HF_TOKEN = "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"
CACHE_DIR = "../hf_models"  # choose where files go on disk

local_path = hf_hub_download(
    repo_id="bartowski/Meta-Llama-3-8B-Instruct-GGUF",  # or QuantFactory/*
    filename="Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
    token=HF_TOKEN,
    local_dir=CACHE_DIR
)
print(local_path)


Meta-Llama-3-8B-Instruct-Q4_K_M.gguf:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

../hf_models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf


In [4]:
# from llama_cpp import Llama

# llm = Llama(
#     model_path="../hf_models/Meta-Llama-3-8B-Instruct-Q4_K_M.gguf",
#     n_ctx=4096,
#     n_threads=8,
# )


In [20]:
# pip install -U transformers accelerate bitsandbytes torch

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from huggingface_hub import login
token = "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"
login(token=token)

# 0) MODEL
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"   # or a biomed LLM
labels   = ["Malignant", "Benign", "Not significant", "Unknown"]

bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
tok = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

CACHE_DIR = "../hf_models"  # choose where files go on disk

tok = AutoTokenizer.from_pretrained(
    MODEL_ID, use_fast=True, token=token, cache_dir=CACHE_DIR
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=token,
    device_map="cpu",                # <— CPU only
    torch_dtype=torch.float32,       # <— CPU needs fp32
    low_cpu_mem_usage=True,          # <— reduces peak load
    cache_dir=CACHE_DIR,
).eval()

# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID, device_map="auto", quantization_config=bnb
# ).eval()

# You already have:
#   SYSTEM  -> string
#   shots   -> list[tuple[text,label]] like: [("text...", "Malignant"), ...]



Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

: 

In [None]:

def _chat_messages(effect_text: str):
    """
    Build a proper chat log:
      [system] rules
      [user]/[assistant] few-shot pairs
      [user]  query ending with 'LABEL:'
    """
    msgs = [{"role": "system", "content": SYSTEM}]
    for t, y in shots:
        msgs.append({"role": "user", "content": f"TEXT: {t}\nLABEL:"})
        msgs.append({"role": "assistant", "content": y})
    msgs.append({"role": "user", "content": f"TEXT: {effect_text}\nLABEL:"})
    return msgs

def build_prompt(effect_text: str) -> torch.Tensor:
    """
    Turn the chat messages into model-ready token ids using the tokenizer's
    chat template (best for Llama-style instruct models).
    """
    try:
        # returns token ids when tokenize=True
        ids = tok.apply_chat_template(
            _chat_messages(effect_text),
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
    except Exception:
        # Fallback to a plain string if the model has no chat template
        few = "".join([f"\nTEXT: {x}\nLABEL: {y}" for x, y in shots])
        plain = (
            f"<s>[SYSTEM]\n{SYSTEM}\n[/SYSTEM]\n"
            f"[USER]\nClassify the following text.{few}\n\nTEXT: {effect_text}\nLABEL: [/USER]\n[ASSISTANT]\n"
        )
        ids = tok(plain, return_tensors="pt", add_special_tokens=False).input_ids
    return ids.to(model.device)

@torch.inference_mode()
def label_logprob(prompt_ids: torch.Tensor, label: str) -> float:
    """
    log P(label | prompt) as a sum of next-token logprobs over the label tokens.
    """
    lab_ids = tok(" " + label, return_tensors="pt", add_special_tokens=False).input_ids.to(model.device)
    input_ids = torch.cat([prompt_ids, lab_ids], dim=1)
    attn_mask = torch.ones_like(input_ids)

    out = model(input_ids=input_ids, attention_mask=attn_mask).logits[:, :-1, :]
    targets = input_ids[:, 1:]

    Lp = prompt_ids.size(1)
    Ll = lab_ids.size(1)
    # positions where the label tokens are predicted
    logprobs = out[:, Lp-1:Lp+Ll-1, :].log_softmax(dim=-1)
    tgt = targets[:, Lp-1:Lp+Ll-1]
    return float(logprobs.gather(-1, tgt.unsqueeze(-1)).squeeze(-1).sum().item())

# def label_logprob(llm: Llama, prompt_text: str, label: str, *, add_bos: bool = True) -> float:
#     """
#     Return log P(label | prompt) by summing next-token logprobs over label tokens.
#     Works by feeding the concatenated prompt+label with echo=True and reading
#     the per-token logprobs for the label slice.
#     """
#     # Ensure label tokenization is aligned with Llama tokenizer conventions
#     label_text = " " + label

#     # Tokenize to find the slice that corresponds to the label
#     prompt_toks = llm.tokenize(prompt_text.encode("utf-8"), add_bos=add_bos)
#     full_toks   = llm.tokenize((prompt_text + label_text).encode("utf-8"), add_bos=add_bos)
#     Lp = len(prompt_toks)
#     Ll = len(full_toks) - Lp

#     # Ask llama.cpp for per-token logprobs for *prompt+label*
#     # max_tokens=0 + echo=True means: don't generate, just score what we sent.
#     out = llm.create_completion(
#         prompt=prompt_text + label_text,
#         max_tokens=0,
#         echo=True,
#         logprobs=1,   # return per-token logprobs
#     )

#     token_logprobs = out["choices"][0]["logprobs"]["token_logprobs"]
#     # token_logprobs aligns with the tokenized full sequence (may have a leading None for BOS)
#     start = len(token_logprobs) - Ll
#     label_slice = token_logprobs[start:]

#     # Sum over the label tokens; guard against any None values (e.g., BOS)
#     return float(sum(lp for lp in label_slice if lp is not None))


@torch.inference_mode()
def classify(text: str, margin_threshold: float = 1.0) -> dict:
    """
    Deterministic zero/few-shot classifier.
    margin_threshold controls 'Unknown' gating when top-2 scores are close.
    """
    if not isinstance(text, str) or not text.strip():
        return {"label": "Unknown", "margin": 0.0, "scores": {}}

    prompt_ids = build_prompt(text)
    scores = {lab: label_logprob(prompt_ids, lab) for lab in labels}
    best = max(scores, key=scores.get)
    sorted_vals = sorted(scores.values(), reverse=True)
    margin = sorted_vals[0] - sorted_vals[1] if len(sorted_vals) > 1 else float("inf")
    label = best if margin >= margin_threshold else "Unknown"
    return {"label": label, "margin": margin, "scores": scores}

In [None]:
res = classify("Variant reduces catalytic activity and destabilizes the protein.")
    print(res)  # {'label': 'Malignant', 'margin': ..., 'scores': {...}}