In [63]:
# Core imports
import os
import pickle
import random
from urllib.request import urlretrieve

import h5py
import numpy as np
import pandas as pd

# TensorFlow/Keras imports for model loading
import tensorflow as tf
from keras.models import model_from_json

# SQUID imports for mutagenesis
import squid

In [64]:
# Load the Dev_20 library, drop the excluded sequences, and re-save the result.
import os
import pickle

import pandas as pd

dev_path = "/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/library_creation/Dev_20_library/Dev_20"
save_path = "/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/data_and_models/dev_20_library/"

# Read the pickled library and keep only its "dev" table.
path = os.path.join(dev_path, "dev_20_library.pkl")
dev_pkl = pd.read_pickle(path)["dev"]
print(len(dev_pkl))

# Drop the sequences that were removed (test_idx 21916, 1693, 8389) and renumber the rows.
excluded_test_idx = [21916, 1693, 8389]
keep_mask = ~dev_pkl["test_idx"].isin(excluded_test_idx)
dev_pkl = dev_pkl[keep_mask].reset_index(drop=True)
print(len(dev_pkl))

# Store the filtered table, wrapped under the 'dev' key, in the data-and-models directory.
os.makedirs(save_path, exist_ok=True)
filtered_pickle_file = os.path.join(save_path, 'dev_20_library.pkl')
with open(filtered_pickle_file, 'wb') as handle:
    pickle.dump({'dev': dev_pkl}, handle)

23
20


In [65]:
## Reload the filtered Dev_20 pickle from the data-and-models directory

dev_20_path = "/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/data_and_models/dev_20_library/"

# Read through dev_20_path (declared in this cell) instead of `save_path`, which only
# exists when the filtering cell has already been run in the current kernel session.
dev_pkl = pd.read_pickle(os.path.join(dev_20_path, 'dev_20_library.pkl'))["dev"]
print(len(dev_pkl))

# Display the first record as a quick sanity check of the columns.
dev_pkl.iloc[0]

20


test_idx                                                22612
sequence    TTTTAATGACTGAAATTAAAACATCATTAAGGCGAATTGGCCACCG...
activity                                             3.265582
ohe_seq     [[0.0, 0.0, 0.0, 1.0], [0.0, 0.0, 0.0, 1.0], [...
Name: 0, dtype: object

In [66]:
# Download (when missing) and load the DeepSTARR model
model_dir = "/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/data_and_models/models/"
MODEL_DIR = model_dir  # also read by seam_deepshap() when it reloads the model

model_json_file = os.path.join(model_dir, 'deepstarr.model.json')
model_weights_file = os.path.join(model_dir, 'deepstarr.model.h5')

# Make sure the target directory exists so a first-time download has somewhere to land.
os.makedirs(model_dir, exist_ok=True)

# Fetch each model file only if it is not already on disk.
# `urlretrieve` comes from urllib.request (imported in the notebook's import cell).
model_downloads = [
    (model_json_file, 'https://www.dropbox.com/scl/fi/y1mwsqpv2e514md9t68jz/deepstarr.model.json?rlkey=cdwhstqf96fibshes2aov6t1e&st=9a0c5skz&dl=1'),
    (model_weights_file, 'https://www.dropbox.com/scl/fi/6nl6e2hofyw70lh99h3uk/deepstarr.model.h5?rlkey=hqfnivn199xa54bjh8dn2jpaf&st=l4jig4ky&dl=1'),
]
for local_file, url in model_downloads:
    if os.path.exists(local_file):
        print(f"Using existing {local_file}")
    else:
        print(f"Downloading {os.path.basename(local_file)}...")
        urlretrieve(url, local_file)

# Load the model architecture from JSON
with open(model_json_file, 'r') as f:
    model_json = f.read()

model = model_from_json(model_json, custom_objects={'Functional': tf.keras.Model})

# Set random seeds for reproducibility
np.random.seed(113)
random.seed(0)

# Load the model weights
model.load_weights(model_weights_file)
num_tasks = 2  # Dev [0] and Hk [1]

alphabet = ['A','C','G','T']

# Reference sequence: first Dev_20 entry, with a leading batch axis added.
x_ref = np.expand_dims(dev_pkl.iloc[0]["ohe_seq"], 0)

# Define mutagenesis window for sequence
seq_length = x_ref.shape[1]
mut_window = [0, seq_length]  # [start_position, stop_position]
print("\nModel loaded successfully!")

# Forward pass to get output for the specific head
output = model(x_ref)

# Run predict() once and split the two heads, instead of predicting twice.
wt_predictions = model.predict(x_ref)
predd, predh = wt_predictions[0], wt_predictions[1]
print(f"\nWild-type predictions: {predd[0][0], predh[0][0]}")
print(f"Model input shape: {model.input_shape}")
print(f"Model output shape: {model.output_shape}")

Using existing /grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/data_and_models/models/deepstarr.model.json
Using existing /grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/data_and_models/models/deepstarr.model.h5

Model loaded successfully!

Wild-type predictions: (3.2655823, 0.6504629)
Model input shape: (None, 249, 4)
Model output shape: [(None, 1), (None, 1)]


In [67]:
# Helper function to save library to HDF5
def save_library(filepath, sequences, predictions, original_idx):
    """Write a mutagenesis library to an HDF5 file.

    Stores ``sequences`` and ``predictions`` as gzip-compressed datasets, plus a
    ``library_index`` dataset running from 0 to ``len(sequences) - 1`` so later
    subsetting can refer to stable positions. ``original_idx`` and the sample
    count are recorded as file attributes. Any existing file at ``filepath`` is
    overwritten.
    """
    gzip_settings = {'compression': 'gzip', 'compression_opts': 4}
    n_samples = len(sequences)

    # Dataset name -> payload, written in this order.
    datasets = {
        'sequences': sequences,
        'predictions': predictions,
        'library_index': np.arange(n_samples),
    }

    with h5py.File(filepath, 'w') as h5_out:
        for dataset_name, values in datasets.items():
            h5_out.create_dataset(dataset_name, data=values, **gzip_settings)
        h5_out.attrs['original_idx'] = original_idx
        h5_out.attrs['n_samples'] = n_samples

In [68]:



## DeepSHAP attribution function with checkpointing
def seam_deepshap(x_mut, task_index, checkpoint_path=None, checkpoint_every=5000):
    """Compute DeepSHAP attributions for a set of sequences, with optional checkpointing.

    A fresh copy of the DeepSTARR Keras model is loaded from ``MODEL_DIR`` after
    TensorFlow eager execution has been disabled, and attributions are computed
    chunk by chunk through ``Attributer.compute``. Each chunk is passed as both
    ``x`` and ``x_ref``.

    Parameters
    ----------
    x_mut : array-like
        Sequences to attribute; sliced along the first axis.
    task_index : int
        Index of the model output to attribute. Forwarded to ``Attributer`` and
        used by the compress function to pick ``outputs[task_index]``.
    checkpoint_path : str, optional
        HDF5 file used to persist progress. If it exists, computation resumes
        after the stored ``last_completed_idx``; it is rewritten after every chunk.
    checkpoint_every : int, default 5000
        Number of samples per chunk, i.e. how often a checkpoint is written.

    Returns
    -------
    numpy.ndarray
        Chunk attributions concatenated along axis 0.

    Raises
    ------
    ImportError
        If the ``shap`` package is not installed.
    ValueError
        If ``x_mut`` is empty and there is no completed checkpoint to return.
    RuntimeError
        If the model could not be loaded during the DeepSHAP setup.

    Notes
    -----
    ``tf.compat.v1.disable_eager_execution()`` changes global TensorFlow state, so
    it affects everything that runs in the kernel after this function is called.
    """
    import time
    import tensorflow as tf
    from keras.models import model_from_json
    import numpy as np
    import random

    print(f"Computing attributions for task_index: {task_index}")
    n_samples = len(x_mut)

    # Check for existing checkpoint
    if checkpoint_path and os.path.exists(checkpoint_path):
        with h5py.File(checkpoint_path, 'r') as f:
            start_idx = int(f.attrs['last_completed_idx']) + 1
            attributions_partial = f['attributions'][:start_idx]
        print(f"Resuming from checkpoint at index {start_idx}")
    else:
        start_idx = 0
        attributions_partial = None

    # If already complete, return what the checkpoint holds. Only reachable with a
    # checkpoint, so an empty input can no longer try to open ``checkpoint_path=None``.
    if attributions_partial is not None and start_idx >= n_samples:
        print("Attributions already complete, loading from checkpoint")
        with h5py.File(checkpoint_path, 'r') as f:
            return f['attributions'][:]

    if n_samples == 0:
        raise ValueError("x_mut is empty; there is nothing to attribute")

    # Configuration (the method is fixed to DeepSHAP in this notebook)
    attribution_method = 'deepshap'
    gpu = 0

    # Model paths
    keras_model_json = os.path.join(MODEL_DIR, 'deepstarr.model.json')
    keras_model_weights = os.path.join(MODEL_DIR, 'deepstarr.model.h5')

    model_local = None
    setup_error = None
    try:
        tf.compat.v1.disable_eager_execution()
        tf.compat.v1.disable_v2_behavior()
        print("TensorFlow eager execution disabled for DeepSHAP compatibility")

        try:
            import shap
        except ImportError as err:
            raise ImportError("SHAP package required for DeepSHAP attribution") from err

        # Route the AddV2 op through SHAP's passthrough handler.
        shap.explainers.deep.deep_tf.op_handlers["AddV2"] = shap.explainers.deep.deep_tf.passthrough

        # Read the architecture with a context manager so the file handle is closed.
        with open(keras_model_json) as json_file:
            keras_model = model_from_json(json_file.read(), custom_objects={'Functional': tf.keras.Model})
        np.random.seed(113)
        random.seed(0)
        keras_model.load_weights(keras_model_weights)
        model_local = keras_model

        # Call the model once on a symbolic input of the model's own input shape.
        _ = model_local(tf.keras.Input(shape=model_local.input_shape[1:]))

    except ImportError:
        raise
    except Exception as e:
        # Best effort: a failure after the model was loaded is only reported.
        setup_error = e
        print(f"Warning: Could not setup TensorFlow for DeepSHAP. Error: {e}")
        print("DeepSHAP may not work properly.")

    # Without a model there is nothing to attribute; fail with a clear message
    # instead of an UnboundLocalError further down.
    if model_local is None:
        raise RuntimeError(
            "DeepSTARR model could not be loaded; cannot compute DeepSHAP attributions"
        ) from setup_error

    def deepstarr_compress(x):
        # Reduce the selected head over its last axis when given an object with
        # `.outputs`; anything else is passed through untouched.
        if hasattr(x, 'outputs'):
            return tf.reduce_sum(x.outputs[task_index], axis=-1)
        else:
            return x

    # NOTE(review): `Attributer` is not imported in the visible notebook cells;
    # presumably it comes from the SEAM package (`from seam import Attributer`) - confirm.
    attributer = Attributer(
        model_local,
        method=attribution_method,
        task_index=task_index,
        compress_fun=deepstarr_compress
    )

    attributer.show_params(attribution_method)

    t1 = time.time()

    # Process in chunks with checkpointing
    all_attributions = []

    # Add previously computed attributions if resuming
    if attributions_partial is not None:
        all_attributions.append(attributions_partial)

    for chunk_start in range(start_idx, n_samples, checkpoint_every):
        chunk_end = min(chunk_start + checkpoint_every, n_samples)
        print(f"\nProcessing samples {chunk_start} to {chunk_end} of {n_samples}")

        x_chunk = x_mut[chunk_start:chunk_end]

        # The chunk itself is used as the reference input as well.
        chunk_attributions = attributer.compute(
            x_ref=x_chunk,
            x=x_chunk,
            save_window=None,
            batch_size=64,
            gpu=gpu,
        )

        all_attributions.append(chunk_attributions)

        # Save checkpoint
        if checkpoint_path:
            attributions_so_far = np.concatenate(all_attributions, axis=0)
            # Keep one merged array so later checkpoints do not re-merge every chunk.
            all_attributions = [attributions_so_far]

            checkpoint_dir = os.path.dirname(checkpoint_path)
            if checkpoint_dir:  # '' for a bare filename; os.makedirs('') would raise
                os.makedirs(checkpoint_dir, exist_ok=True)

            # Write to a temporary file and swap it in, so an interrupt during the
            # write cannot leave a truncated checkpoint behind.
            tmp_checkpoint_path = f"{checkpoint_path}.tmp"
            with h5py.File(tmp_checkpoint_path, 'w') as f:
                f.create_dataset('attributions', data=attributions_so_far, compression='gzip', compression_opts=4)
                f.attrs['last_completed_idx'] = chunk_end - 1
                f.attrs['n_samples'] = n_samples
            os.replace(tmp_checkpoint_path, checkpoint_path)
            print(f"Checkpoint saved at index {chunk_end - 1}")

    attributions = np.concatenate(all_attributions, axis=0)

    t2 = time.time() - t1
    print(f'Attribution time: {t2/60:.2f} minutes')

    return attributions


### Helper functions

def load_library_25k(seq_idx, mut_rate=None):
    """Load the 25K mutagenesis library for a Dev_20 sequence.

    Parameters
    ----------
    seq_idx : int
        ``test_idx`` of the sequence; selects the ``seq_{seq_idx}`` directory.
    mut_rate : float, optional
        Mutation rate of the sweep library to load. When given, the file is read
        from the ``{mut_rate*100}%`` subdirectory, which is where the mutation-rate
        sweep writes its libraries. When omitted, the file directly under the
        sequence directory is read (the original behaviour).

    Returns
    -------
    tuple
        ``(sequences, predictions, original_idx, library_index)``. If the file has
        no ``library_index`` dataset, ``np.arange(len(sequences))`` is used.
    """
    seq_dir = f'/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/b_mutation_rate_sweep/seq_libraries/mut_sweep/Dev/seq_{seq_idx}'
    if mut_rate is not None:
        # Same directory-name formatting as the sweep cell uses when saving.
        seq_dir = f'{seq_dir}/{mut_rate*100}%'
    filepath = f'{seq_dir}/25K.h5'

    with h5py.File(filepath, 'r') as f:
        sequences = f['sequences'][:]
        predictions = f['predictions'][:]
        original_idx = f.attrs['original_idx']
        library_index = f['library_index'][:] if 'library_index' in f else np.arange(len(sequences))
    return sequences, predictions, original_idx, library_index


def create_subset_indices(library_index, subset_size, seed=42):
    """Return the first ``subset_size`` entries of a seeded shuffle of ``library_index``.

    Parameters
    ----------
    library_index : numpy.ndarray
        Indices to draw from. It is copied, so the caller's array is not modified.
    subset_size : int
        Number of shuffled indices to keep.
    seed : int, default 42
        Seed for the shuffle; the same seed always yields the same subset.

    Returns
    -------
    numpy.ndarray
        The leading ``subset_size`` elements of the shuffled copy.
    """
    indices = library_index.copy()
    # A private RandomState gives the same permutation as seeding the global
    # generator with `seed`, without resetting the notebook-wide RNG state.
    np.random.RandomState(seed).shuffle(indices)
    return indices[:subset_size]


def save_attributions(filepath, attributions, original_idx, subset_idx=None):
    """Save attributions to an HDF5 file, optionally with the subset indices used.

    Parameters
    ----------
    filepath : str
        Destination file; it is overwritten if it exists. Its parent directory is
        created when missing.
    attributions : array-like
        Stored as the gzip-compressed ``attributions`` dataset.
    original_idx
        Recorded as the ``original_idx`` file attribute.
    subset_idx : array-like, optional
        When given, stored as the gzip-compressed ``subset_idx`` dataset.
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:  # '' for a bare filename; os.makedirs('') would raise
        os.makedirs(parent_dir, exist_ok=True)

    with h5py.File(filepath, 'w') as f:
        f.create_dataset('attributions', data=attributions, compression='gzip', compression_opts=4)
        if subset_idx is not None:
            f.create_dataset('subset_idx', data=subset_idx, compression='gzip', compression_opts=4)
        f.attrs['original_idx'] = original_idx
        f.attrs['n_samples'] = len(attributions)


def load_attributions(filepath):
    """Read the whole ``attributions`` dataset of an HDF5 file into memory and return it."""
    with h5py.File(filepath, 'r') as h5_in:
        # Slicing with [:] materialises the data before the file is closed.
        stored_attributions = h5_in['attributions'][:]
    return stored_attributions


def attributions_exist(seq_idx):
    """Return True when the 100K DeepSHAP attribution file for ``seq_idx`` is on disk."""
    deepshap_root = f'{RESULTS_DIR}/attribution_maps/deepSHAP/Dev'
    attribution_file = f'{deepshap_root}/seq_{seq_idx}/100K.h5'
    found = os.path.exists(attribution_file)
    return found


def all_attributions_exist(seq_idx):
    """Return True only if an attribution file exists for every label in ``subset_sizes``."""
    seq_dir = f'{RESULTS_DIR}/attribution_maps/deepSHAP/Dev/seq_{seq_idx}'
    # all() stops at the first missing file, like the early return it replaces.
    return all(
        os.path.exists(f'{seq_dir}/{size_label}.h5')
        for size_label in subset_sizes.keys()
    )


In [None]:
# Generate 25K mutagenesis libraries for each Dev_20 sequence 
# and sweep through mutation rates

from typing import Any


# NOTE(review): .50/.5 and .10/.1 are the same values, so the last two entries only
# hit the "already exists" branches below. Were .05 and .01 intended? Confirm before
# changing; the values are left as they were.
mutation_rates = [.75, .50, .25, .10, .5, .1]
lib_size = 25000
task_index = 0  # 0 for Dev

# Root of the mutation-rate sweep outputs (libraries and DeepSHAP attributions).
sweep_root = '/grid/wsbs/home_norepl/pmantill/SEAM_revisions/SEAM_revisions/hyperparameter_selection/b_mutation_rate_sweep/seq_libraries/mut_sweep'

x_seqs = dev_pkl["ohe_seq"]
seq_indices = dev_pkl["test_idx"]

# ---------------------------------------------------------------------------
# Phase 1: generate every library, for every mutation rate.
# seam_deepshap() disables TensorFlow eager execution, so all work that uses the
# eagerly-built `model` is finished before the first attribution call.
# ---------------------------------------------------------------------------
for mut_rate in mutation_rates:
    for i, (x_seq, idx) in enumerate(zip(x_seqs, seq_indices)):
        output_dir = f'{sweep_root}/Dev/seq_{idx}/{mut_rate*100}%/'
        output_file = os.path.join(output_dir, '25K.h5')

        # Check if library already exists
        if os.path.exists(output_file):
            print(f"Skipping seq_{idx} - already exists")
            continue

        os.makedirs(output_dir, exist_ok=True)

        x_seq = np.array(x_seq)
        # Take the window from the sequence itself rather than a hard-coded 249.
        n_positions = x_seq.shape[0]

        # Create predictor
        pred_generator = squid.predictor.ScalarPredictor(
            pred_fun=model.predict_on_batch,
            task_idx=task_index,
            batch_size=512
        )

        # Create mutagenizer
        mut_generator = squid.mutagenizer.RandomMutagenesis(
            mut_rate=mut_rate,
            seed=42
        )

        # Create MAVE
        mave = squid.mave.InSilicoMAVE(
            mut_generator,
            pred_generator,
            seq_length=n_positions,
            mut_window=[0, n_positions]
        )

        # Generate the mutant sequences and their predictions
        x_mut, y_mut = mave.generate(x_seq, num_sim=lib_size)

        save_library(output_file, x_mut, y_mut, idx)
        print(f"[{i+1}/{len(x_seqs)}] Created seq_{idx}/25K.h5 with Mutation Rate {mut_rate*100}%")

# ---------------------------------------------------------------------------
# Phase 2: DeepSHAP attributions for each library.
# ---------------------------------------------------------------------------
for mut_rate in mutation_rates:
    for idx in seq_indices:
        library_file = os.path.join(f'{sweep_root}/Dev/seq_{idx}/{mut_rate*100}%/', '25K.h5')
        output_dir = f'{sweep_root}/deepshap/Dev/seq_{idx}/{mut_rate*100}%/'
        output_file = os.path.join(output_dir, '25K.h5')
        # Progress goes to its own file so that a partially finished run is never
        # mistaken for a finished result by the existence check below.
        # NOTE(review): earlier runs checkpointed straight into `output_file`; a
        # leftover partial file from such a run would still be skipped here.
        checkpoint_file = os.path.join(output_dir, '25K.checkpoint.h5')

        if os.path.exists(output_file):
            print(f"Skipping seq_{idx} - already exists")
            continue

        # Load the library generated for THIS mutation rate (not the rate-less path).
        with h5py.File(library_file, 'r') as f:
            x_mut = f['sequences'][:]
            original_idx = f.attrs['original_idx']

        # get deepshap attributions
        attributions = seam_deepshap(x_mut, task_index, checkpoint_path=checkpoint_file)

        # save attributions, then drop the no-longer-needed checkpoint
        save_attributions(output_file, attributions, original_idx)
        if os.path.exists(checkpoint_file):
            os.remove(checkpoint_file)

        print(f"Saved attributions for seq_{idx} with Mutation Rate {mut_rate*100}%")

print("\nDone!")


Building in silico MAVE...


Mutagenesis: 100%|██████████| 25000/25000 [00:02<00:00, 12433.97it/s]
Inference: 100%|██████████| 48/48 [00:01<00:00, 27.04it/s]


[1/20] Created seq_22612/25K.h5 with Mutation Rate 75.0%

Building in silico MAVE...


Mutagenesis: 100%|██████████| 25000/25000 [00:01<00:00, 12593.21it/s]
Inference: 100%|██████████| 48/48 [00:01<00:00, 27.39it/s]


[2/20] Created seq_21069/25K.h5 with Mutation Rate 75.0%

Building in silico MAVE...


Mutagenesis: 100%|██████████| 25000/25000 [00:02<00:00, 11455.86it/s]
Inference: 100%|██████████| 48/48 [00:01<00:00, 29.28it/s]


[3/20] Created seq_13748/25K.h5 with Mutation Rate 75.0%

Building in silico MAVE...


Mutagenesis: 100%|██████████| 25000/25000 [00:02<00:00, 12451.55it/s]
Inference: 100%|██████████| 48/48 [00:01<00:00, 29.59it/s]


[4/20] Created seq_3881/25K.h5 with Mutation Rate 75.0%

Building in silico MAVE...


Mutagenesis: 100%|██████████| 25000/25000 [00:02<00:00, 12305.06it/s]
Inference: 100%|██████████| 48/48 [00:01<00:00, 29.83it/s]


[5/20] Created seq_2974/25K.h5 with Mutation Rate 75.0%

Building in silico MAVE...


Mutagenesis:  87%|████████▋ | 21807/25000 [00:01<00:00, 11503.16it/s]


KeyboardInterrupt: 