# Attribution Analysis — AlphaGenome FT LentiMPRA

Compute and visualize attributions for our fine-tuned AlphaGenome model on LentiMPRA K562.

**Methods available:**
- **Gradient SHAP**: Expected gradients with Altschul-Erickson dinucleotide shuffle (`deepshap.py`)
- **Gradient × Input**: input gradients of the head output multiplied by the one-hot input sequence
- **ISM**: In silico mutagenesis (wildtype-base importance)

In [1]:
import os, sys, json
import numpy as np
import jax
import jax.numpy as jnp
from pathlib import Path

# Project paths
# NOTE(review): absolute cluster paths; adjust PROJ_DIR when running on another machine.
PROJ_DIR = Path("/grid/wsbs/home_norepl/pmantill/LentiMPRA_mcs/LentiMoCon")
RESULTS_DIR = PROJ_DIR / "lenti_AGFT" / "training" / "results"  # one subdir per model: args.json + checkpoints/
DATA_DIR = str(PROJ_DIR / "test_run_lenti_data")  # passed as path_to_data to LentiMPRADataset

# Add custom head/data modules to path (EncoderMPRAHead, LentiMPRADataset)
# (must precede `from src import ...` in the model-loading cell; re-running this cell
# prepends duplicate sys.path entries, which is harmless)
sys.path.insert(0, str(PROJ_DIR / "alphagenome_FT_MPRA"))
# Add interpreting scripts to path (deepshap module)
sys.path.insert(0, str(PROJ_DIR / "lenti_AGFT" / "interpreting" / "scripts"))

# HuggingFace cache
# Set before the alphagenome imports below so the cache location is already in the
# environment when those libraries first read it. Reuses a weights cache under ~/Liver_AGFT.
os.environ["HF_HOME"] = os.path.expanduser("~/Liver_AGFT/.weights/huggingface")

# Patch: use HuggingFace instead of Kaggle for base model downloads
# This rebinds the module attribute, so it must run before load_checkpoint (section 2).
# NOTE(review): only affects callers that look up dna_model.create_from_kaggle at call time,
# not code that imported the function by name before this line ran.
from alphagenome_research.model import dna_model
dna_model.create_from_kaggle = dna_model.create_from_huggingface

# Resolved through the interpreting/scripts directory added to sys.path above.
from deepshap import deep_lift_shap

print(f"JAX devices: {jax.devices()}")

JAX devices: [CudaDevice(id=0)]


## 1. Configuration

Set the model name and attribution parameters here.

In [2]:
# ---- Configuration ----
# Which trained model / checkpoint to interpret
MODEL_NAME = "second_2stageig"          # Model to analyze
CHECKPOINT_STAGE = "best_stage2"         # "best" (stage 1) or "best_stage2"
CELL_TYPE = "K562"

# Attribution settings
# NOTE(review): ATTRIBUTION_METHOD is not referenced by the cells shown in this notebook;
# section 5 computes all three methods regardless of its value.
ATTRIBUTION_METHOD = "deepshap"          # "deepshap", "gradient_x_input", or "ism"
N_SEQUENCES = 5                          # Number of top sequences to analyze
N_SHAP_REFERENCES = 20                   # Number of shuffle references for DeepSHAP
N_IG_STEPS = 50                          # Integration steps per reference (higher = more accurate)
HEAD_NAME = "mpra_head"
PROMOTER_CONSTRUCT_LENGTH = 281          # passed as init_seq_len to load_checkpoint

# Paths derived from the selection above
model_dir = RESULTS_DIR / MODEL_NAME
checkpoint_dir = model_dir.joinpath("checkpoints", CHECKPOINT_STAGE)
output_dir = PROJ_DIR.joinpath("lenti_AGFT", "interpreting", "results", MODEL_NAME)
output_dir.mkdir(parents=True, exist_ok=True)

# Training-time hyperparameters, used to rebuild the head exactly as it was trained
args_json = json.loads((model_dir / "args.json").read_text())
hp = args_json["hp"]

print(f"Model: {MODEL_NAME}")
print(f"Checkpoint: {CHECKPOINT_STAGE}")
print(f"DeepSHAP: {N_SHAP_REFERENCES} refs x {N_IG_STEPS} IG steps = {N_SHAP_REFERENCES * N_IG_STEPS} grad evals")
print(f"Head config: pooling={hp['pooling_type']}, nl={hp['nl_size']}, do={hp['dropout']}, act={hp['activation']}")

Model: second_2stageig
Checkpoint: best_stage2
DeepSHAP: 20 refs x 50 IG steps = 1000 grad evals
Head config: pooling=flatten, nl=1024, do=0.1, act=relu


## 2. Load Model

In [3]:
from alphagenome.models import dna_output
from alphagenome_ft import (
    register_custom_head,
    load_checkpoint,
    CustomHeadConfig,
    CustomHeadType,
)
from src import EncoderMPRAHead, LentiMPRADataset

# Register head with same config as training, so the checkpoint's head weights
# match the module that gets rebuilt here.
raw_nl_size = hp["nl_size"]
# Normalise a scalar hidden size to a one-element list.
nl_size = raw_nl_size if isinstance(raw_nl_size, list) else [raw_nl_size]

head_metadata = dict(
    center_bp=hp["center_bp"],
    pooling_type=hp["pooling_type"],
    nl_size=nl_size,
    do=hp["dropout"],
    activation=hp["activation"],
)
head_config = CustomHeadConfig(
    type=CustomHeadType.GENOME_TRACKS,
    output_type=dna_output.OutputType.RNA_SEQ,
    num_tracks=1,
    metadata=head_metadata,
)
register_custom_head(HEAD_NAME, EncoderMPRAHead, head_config)

# Load fine-tuned checkpoint (Kaggle->HuggingFace patch applied in setup cell)
model = load_checkpoint(
    str(checkpoint_dir),
    base_model_version="all_folds",
    init_seq_len=PROMOTER_CONSTRUCT_LENGTH,
)

print(f"Model loaded from {checkpoint_dir}")
print(f"Custom heads: {model._custom_heads}")
print(f"Total parameters: {model.count_parameters():,}")

Loading checkpoint from /grid/wsbs/home_norepl/pmantill/LentiMPRA_mcs/LentiMoCon/lenti_AGFT/training/results/second_2stageig/checkpoints/best_stage2
  Custom heads: ['mpra_head']
  Model type: Full model
Loading full model from checkpoint...


Downloading (incomplete total...): 0.00B [00:00, ?B/s]

Fetching 12 files:   0%|          | 0/12 [00:00<?, ?it/s]



✓ Checkpoint loaded successfully
  Total parameters: 455,176,326
Model loaded from /grid/wsbs/home_norepl/pmantill/LentiMPRA_mcs/LentiMoCon/lenti_AGFT/training/results/second_2stageig/checkpoints/best_stage2
Custom heads: ['mpra_head']
Total parameters: 455,176,326


## 3. Load Test Data

In [4]:
# Held-out split, loaded without augmentation (random_shift / reverse_complement
# presumably control train-time augmentation — disabled here for stable attributions).
test_ds = LentiMPRADataset(
    model=model,
    path_to_data=DATA_DIR,
    cell_type=CELL_TYPE,
    split="test",
    random_shift=False,
    reverse_complement=False,
)
print(f"Test dataset: {len(test_ds)} sequences")

# Rank every test sequence by measured activity (descending; ties keep dataset order)
# and keep the N_SEQUENCES most active ones.
activities = sorted(
    ((i, float(test_ds[i]["y"])) for i in range(len(test_ds))),
    key=lambda pair: pair[1],
    reverse=True,
)
top_seqs = activities[:N_SEQUENCES]

print(f"\nTop {N_SEQUENCES} sequences by activity:")
for rank, (idx, act) in enumerate(top_seqs, start=1):
    print(f"  {rank}. idx={idx}, activity={act:.4f}")

Loaded 19670 samples for test split
Test dataset: 19670 sequences

Top 5 sequences by activity:
  1. idx=6277, activity=2.8380
  2. idx=3320, activity=2.5780
  3. idx=3637, activity=2.2870
  4. idx=9113, activity=2.2810
  5. idx=5605, activity=2.2680


## 4. Helper Functions

In [5]:
def decode_one_hot(one_hot_seq):
    """Convert a one-hot encoded sequence back to a DNA string (A=0, C=1, G=2, T=3).

    Args:
        one_hot_seq: NumPy or JAX array of shape (L, 4), or (B, L, 4), in which case
            only the first batch element is decoded.

    Returns:
        A string of length L. Each position becomes the base whose channel holds the
        largest value; ties go to the first channel, so an all-zero row decodes to 'A'.
    """
    arr = np.asarray(one_hot_seq, dtype=np.float32)
    if arr.ndim == 3:
        arr = arr[0]
    # One vectorised argmax on the host instead of a device op + host sync per position.
    base_indices = np.argmax(arr, axis=-1).tolist()
    return ''.join('ACGT'[i] for i in base_indices)


def prepare_sample(dataset, idx):
    """Fetch ``dataset[idx]`` and shape it for the attribution calls.

    Returns:
        A tuple ``(seq, org, seq_str, y)``:
        - ``seq``: JAX array of the sample's ``"seq"``, with a leading batch axis added
          when the stored array is 2-D;
        - ``org``: JAX array of ``"organism_index"``, promoted to 1-D if it was a scalar;
        - ``seq_str``: the sequence decoded to a DNA string via ``decode_one_hot``;
        - ``y``: the sample's ``"y"`` value as a Python float.
    """
    record = dataset[idx]

    seq = jnp.array(record["seq"])
    seq = seq[None, :, :] if seq.ndim == 2 else seq

    org = jnp.array(record["organism_index"])
    org = org[None] if org.ndim == 0 else org

    return seq, org, decode_one_hot(seq), float(record["y"])

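## 4b. DeepSHAP (IG-SHAP) Sum-Rule Check

Verify the completeness (sum) rule on the top-ranked sequence: the attributions summed over all positions and all 4 channels should approximately equal `f(input) - mean(f(references))`. A small residual is expected because each path integral is approximated with a finite number of integration steps (`n_steps`).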

In [7]:
import importlib, deepshap
importlib.reload(deepshap)  # pick up edits to deepshap.py without restarting the kernel
from deepshap import deep_lift_shap, dinucleotide_shuffle, _build_compute_output

# Check on the top-ranked sequence. NOTE: n_check (50) is larger than the N_SHAP_REFERENCES
# used in section 5, so this cell costs n_check * n_steps gradient evaluations.
SUM_RULE_SEED = 42       # shared by deep_lift_shap and the reference regeneration below
SUM_RULE_REL_TOL = 0.05  # warn (do not fail) if the sum rule is off by more than this
seq, org, seq_str, act = prepare_sample(test_ds, top_seqs[0][0])
n_check = 50
n_steps = 50

# Compute attributions (integrated gradients per reference)
attr = deep_lift_shap(model, seq, org, HEAD_NAME,
                      n_shuffles=n_check, n_steps=n_steps, random_state=SUM_RULE_SEED)

# Sum rule: sum over ALL positions and ALL 4 channels = f(x) - mean(f(refs))
attr_sum = float(np.sum(np.array(attr)))

# f(input)
compute_output = _build_compute_output(model, org, HEAD_NAME)
pred_input = float(compute_output(seq))

# mean(f(references)) — regenerated with the same seed as deep_lift_shap above.
# NOTE(review): assumes deep_lift_shap draws its references as
# dinucleotide_shuffle(seq[0], n=n_shuffles, rng=np.random.default_rng(random_state)); confirm in deepshap.py.
refs = dinucleotide_shuffle(np.array(seq[0]), n=n_check, rng=np.random.default_rng(SUM_RULE_SEED))
ref_preds = [float(compute_output(jnp.array(refs[r:r + 1]))) for r in range(n_check)]
pred_ref_mean = np.mean(ref_preds)

expected = pred_input - pred_ref_mean
rel_err = abs(attr_sum - expected) / (abs(expected) + 1e-10)
print(f"IG-SHAP sum rule check (n_refs={n_check}, n_steps={n_steps}):")
print(f"  sum(attr) [all channels]   = {attr_sum:.6f}")
print(f"  f(input) - mean(f(refs))   = {expected:.6f}")
print(f"  Difference                 = {abs(attr_sum - expected):.6f}")
print(f"  Relative error             = {rel_err:.4%}")
print(f"  Total grad evals           = {n_check * n_steps}")
if rel_err > SUM_RULE_REL_TOL:
    print(f"  WARNING: relative error exceeds {SUM_RULE_REL_TOL:.0%}; check n_steps and that "
          f"the regenerated references match deep_lift_shap's.")

IG-SHAP sum rule check (n_refs=50, n_steps=50):
  sum(attr) [all channels]   = 2.249261
  f(input) - mean(f(refs))   = 2.214248
  Difference                 = 0.035013
  Relative error             = 1.5813%
  Total grad evals           = 2500


## 5. Compute Attributions & Visualize

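For each top sequence, compute all three attribution methods and save a 2-column comparison figure (left: full attribution logos; right: logos restricted to the wild-type base) to `output_dir/seq<idx>/method_comparison.png`. Logos are rendered with the non-interactive Agg backend, so the figures are written to disk rather than displayed inline.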

In [57]:
import matplotlib
matplotlib.use("Agg")  # off-screen backend: logos are saved and stitched, not shown inline
import matplotlib.pyplot as plt
from PIL import Image
import tempfile


def _render_logo(sequence, gradients, title, mask_to_sequence):
    """Render one logo with ``model.plot_sequence_logo`` and return it as an in-memory PIL image.

    ``plot_sequence_logo`` writes to a file path, so the logo goes through a temporary PNG.
    The pixels are copied into memory and the temporary file is deleted, so no PNGs
    accumulate in the temp directory across sequences or re-runs.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        tmp_path = tmp.name
    try:
        model.plot_sequence_logo(
            sequence=sequence, gradients=gradients, save_path=tmp_path,
            logo_type="weight", mask_to_sequence=mask_to_sequence, use_absolute=False,
            title=title,
        )
        with Image.open(tmp_path) as img:
            return img.copy()  # loads the pixels so the file can be removed safely
    finally:
        os.remove(tmp_path)
        # NOTE(review): closes any figures plot_sequence_logo may leave open (a no-op if it
        # already closes them) so memory does not grow over the loop.
        plt.close("all")


def _stitch_two_columns(left_images, right_images):
    """Paste image pairs row by row onto a white RGB canvas (left column, right column)."""
    col1_w = max(img.width for img in left_images)
    col2_w = max(img.width for img in right_images)
    row_h = max(img.height for img in left_images + right_images)
    canvas = Image.new("RGB", (col1_w + col2_w, row_h * len(left_images)), "white")
    for row, (left, right) in enumerate(zip(left_images, right_images)):
        canvas.paste(left, (0, row * row_h))
        canvas.paste(right, (col1_w, row * row_h))
    return canvas


for rank, (seq_idx, activity) in enumerate(top_seqs, 1):
    print(f"\n{'='*60}")
    print(f"Sequence {rank}/{N_SEQUENCES}: idx={seq_idx}, activity={activity:.4f}")
    print(f"{'='*60}")

    seq, org, seq_str, act = prepare_sample(test_ds, seq_idx)
    print(f"  Length: {len(seq_str)}bp")

    # Each thunk is called within this same iteration, so closing over seq/org is safe.
    methods = {
        "DeepSHAP": lambda: jnp.array(deep_lift_shap(
            model, seq, org, HEAD_NAME,
            n_shuffles=N_SHAP_REFERENCES, n_steps=N_IG_STEPS,
            random_state=42, hypothetical=False,
        )),
        "Grad x Input": lambda: model.compute_input_gradients(
            sequence=seq, organism_index=org, head_name=HEAD_NAME, gradients_x_input=True,
        ),
        "ISM": lambda: model.compute_ism_attributions(
            sequence=seq, organism_index=org, head_name=HEAD_NAME,
        ),
    }

    logo_col1 = []  # full attribution logos
    logo_col2 = []  # WT-only logos

    for name, compute_fn in methods.items():
        attr = jnp.float32(compute_fn())
        if name != "ISM":
            # Subtract the per-position mean across the 4 channels (ISM scores are left as-is).
            # NOTE(review): this centring is normally applied to raw gradients / hypothetical
            # scores. Here it is applied to Grad x Input and non-hypothetical DeepSHAP output;
            # if those are zero off the WT base, a quarter of the WT value is moved onto every
            # channel. Confirm this is intended.
            attr = attr - jnp.mean(attr, axis=-1, keepdims=True)

        # Column 1: full attribution logo
        logo_col1.append(_render_logo(
            seq, attr, f"{name} — seq {seq_idx} (act={act:.3f})", mask_to_sequence=False,
        ))

        # Column 2: one-hot x attribution (WT base only)
        attr_wt = attr * seq
        logo_col2.append(_render_logo(
            seq, attr_wt, f"{name} (WT only) — seq {seq_idx}", mask_to_sequence=True,
        ))
        print(f"  Computed {name}")

    # Stitch into 2-column layout: one row per method
    combined = _stitch_two_columns(logo_col1, logo_col2)

    seq_out = output_dir / f"seq{seq_idx}"
    seq_out.mkdir(parents=True, exist_ok=True)
    save_path = seq_out / "method_comparison.png"
    combined.save(save_path, dpi=(150, 150))
    print(f"  Saved: {save_path}")

    for img in logo_col1 + logo_col2:
        img.close()

print(f"\nAll outputs saved to: {output_dir}")


Sequence 1/5: idx=6277, activity=2.8380
  Length: 281bp
Sequence logo saved to: /tmp/slurm_tmp/703917/tmp980k_n25.png
Sequence logo saved to: /tmp/slurm_tmp/703917/tmp7jfgwrq5.png
  Computed DeepSHAP
Sequence logo saved to: /tmp/slurm_tmp/703917/tmpvk60pte8.png
Sequence logo saved to: /tmp/slurm_tmp/703917/tmphtk4a30t.png
  Computed Grad x Input
Sequence logo saved to: /tmp/slurm_tmp/703917/tmpmuk16psl.png
Sequence logo saved to: /tmp/slurm_tmp/703917/tmp17esddq9.png
  Computed ISM
  Saved: /grid/wsbs/home_norepl/pmantill/LentiMPRA_mcs/LentiMoCon/lenti_AGFT/interpreting/results/fastdrop/seq6277/method_comparison.png

Sequence 2/5: idx=3320, activity=2.5780
  Length: 281bp
Sequence logo saved to: /tmp/slurm_tmp/703917/tmp4zyttmbh.png
Sequence logo saved to: /tmp/slurm_tmp/703917/tmpvavdaa7h.png
  Computed DeepSHAP
Sequence logo saved to: /tmp/slurm_tmp/703917/tmpks_9kk8g.png
Sequence logo saved to: /tmp/slurm_tmp/703917/tmpy553sevx.png
  Computed Grad x Input
Sequence logo saved to: /t