
# Motif Shuffling Experiment - Dataset Preparation

This notebook documents the steps taken to generate sequence datasets for evaluating how strongly deep learning splicing models (e.g., SpliceAI) rely on known RNA-binding protein (RBP) motifs, such as those bound by SR proteins.

This process is part of a broader investigation into how well neural networks generalize splicing patterns when key regulatory signals are disrupted.

## Purpose

- Detect and score RBP motifs using position weight matrices (PWMs)
- Identify motif locations in exonic sequences
- Select a subset of motifs for targeted shuffling
- Generate new sequences with motif-disruptive shuffles while preserving all other genomic structure

## Workflow

1. **Load exon sequences and PWM matrix**
2. **Score motifs using a sliding window**
3. **Filter high-confidence motif matches**
4. **Randomly shuffle selected motif regions**
5. **Reconstruct the modified sequences**
6. **Output shuffled FASTA and BED files**

This notebook provides the core reproducible framework for generating input sequences used in the motif shuffle benchmarking experiments.


In [None]:
import math
from collections import defaultdict

# SRSF1 position weight matrix (Matrix ID 126).
# Each nucleotide maps to one probability per motif position.
pwm = {
    'A': [0.00031, 0.00031, 0.49969, 0.49969, 0.00031, 0.00031, 0.99906],
    'C': [0.99906, 0.49969, 0.49969, 0.49969, 0.49969, 0.49969, 0.00031],
    'G': [0.00031, 0.49969, 0.00031, 0.00031, 0.49969, 0.49969, 0.00031],
    'T': [0.00031, 0.00031, 0.00031, 0.00031, 0.00031, 0.00031, 0.00031],
}

motif_length = 7

# Log2-odds of every PWM entry against a uniform background of 0.25 per base.
log_odds_pwm = {
    base: [math.log2(probability / 0.25) for probability in pwm[base][:motif_length]]
    for base in 'ACGT'
}


def score_sequence(seq):
    """Return the summed log-odds PWM score of one sequence window.

    Each base is looked up in ``log_odds_pwm`` at its position within the
    window. If any base is not a key of ``log_odds_pwm`` (e.g. ``N``), the
    window cannot be scored and ``None`` is returned.
    """
    total = 0
    for position, nucleotide in enumerate(seq):
        column_scores = log_odds_pwm.get(nucleotide)
        if column_scores is None:
            # Ambiguous / unknown base: the whole window is skipped.
            return None
        total += column_scores[position]
    return total


def scan_sequence(seq, threshold):
    """Slide a motif-length window over ``seq`` and collect qualifying hits.

    Returns a list of ``(start, end, score, window)`` tuples, in order of
    increasing start, for every window whose score is at least ``threshold``.
    Windows that ``score_sequence`` cannot score (``None``) are ignored.
    """
    hits = []
    window_count = len(seq) - motif_length + 1
    for begin in range(window_count):
        stop = begin + motif_length
        kmer = seq[begin:stop]
        kmer_score = score_sequence(kmer)
        if kmer_score is None:
            continue
        if kmer_score >= threshold:
            hits.append((begin, stop, kmer_score, kmer))
    return hits


def scan_fasta(input_file, output_file, threshold):
    """Scan every sequence of a two-column, tab-separated file for motif hits.

    Each non-blank input line must hold exactly a transcript id and a
    sequence separated by one tab; any other line is reported on stdout and
    skipped. The sequence is upper-cased before scanning. One output line is
    written per hit: transcript id, start, end, score (3 decimals) and the
    matching window, tab-separated.
    """
    with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
        for raw_line in infile:
            record = raw_line.strip()
            if not record:
                continue
            columns = record.split('\t')
            if len(columns) != 2:
                print(f"Skipping malformed line: {record}")
                continue
            transcript_id, seq = columns
            for start, end, score, window in scan_sequence(seq.upper(), threshold):
                outfile.write(f"{transcript_id}\t{start}\t{end}\t{score:.3f}\t{window}\n")


# Example usage: scan the sequences in input_file for SRSF1 motif hits and
# write one TSV row per hit to output_file (the file is opened in 'w' mode, so
# any existing content is replaced). The threshold is compared against the
# summed log2-odds score of each 7-base window.
scan_fasta(
    input_file='/mnt/lareaulab/sdahiyat/illumina/canonical_sequence_unflanked_unshuffled.txt',
    output_file='/mnt/lareaulab/sdahiyat/datasets/srsf1_motif_hits.tsv',
    threshold=5.0  # Adjust this value depending on the score distribution
)


In [None]:
def parse_junctions(file_path):
    """Parse a tab-separated annotation file into per-transcript intervals.

    Lines with fewer than 8 tab-separated fields are skipped. For the
    remaining lines, field 2 is the chromosome, fields 4 and 5 are the
    integer start/end used to build the key, and fields 6 and 7 are
    comma-separated integer lists of interval starts and ends.

    Empty tokens in the comma-separated lists are ignored, so lists written
    with a trailing comma (e.g. ``"100,300,"``) parse instead of raising
    ``ValueError`` on ``int('')``.

    Args:
        file_path: Path to the tab-separated annotation file.

    Returns:
        dict mapping ``"chrom:start-end"`` to a list of ``(start, end)``
        tuples. Starts and ends are paired positionally; if the two lists
        differ in length the extra entries are dropped by ``zip``.

    Raises:
        ValueError: If a start/end field or a non-empty list token is not an
            integer.
    """
    junctions = {}
    with open(file_path, "r") as f:
        for line in f:
            fields = line.strip().split("\t")
            if len(fields) < 8:
                continue
            chrom = fields[2]
            start = int(fields[4])
            end = int(fields[5])
            # Skip empty tokens produced by trailing (or doubled) commas.
            exon_starts = [int(tok) for tok in fields[6].split(",") if tok.strip()]
            exon_ends = [int(tok) for tok in fields[7].split(",") if tok.strip()]
            transcript_key = f"{chrom}:{start}-{end}"
            junctions[transcript_key] = list(zip(exon_starts, exon_ends))
    return junctions

def extract_sequence(filepath, transcript_id):
    """Return the sequence stored for ``transcript_id`` in a two-column file.

    The file is read line by line; each usable line holds exactly an id and
    a sequence separated by one tab. Blank lines and lines that do not have
    exactly two tab-separated fields are skipped rather than aborting the
    lookup with ``ValueError``, matching how ``scan_fasta`` treats such lines.

    Args:
        filepath: Path to the tab-separated ``id<TAB>sequence`` file.
        transcript_id: Id to look for in the first column.

    Returns:
        The sequence of the first matching line, or ``None`` if no line
        matches.
    """
    with open(filepath) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            fields = line.split("\t")
            if len(fields) != 2:
                # Malformed record: cannot be the requested entry, keep looking.
                continue
            tid, seq = fields
            if tid == transcript_id:
                return seq
    return None

# === Inputs ===
# Transcript key in "chrom:start-end" form. It must equal both a key built by
# parse_junctions and the first column of a line in the sequence file.
transcript_id = "chr1:65419-71585"
junction_file = "/mnt/lareaulab/sdahiyat/illumina/canonical_dataset_created.txt"
# NOTE(review): the variable is named "unshuffled" but the path points at the
# motif_threshold_shuffled file — confirm which dataset is intended here.
unshuffled_file = "/mnt/lareaulab/sdahiyat/illumina/motif_threshold_shuffled.txt"

# === Load data ===
junctions = parse_junctions(junction_file)
# extract_sequence returns None when the transcript id is not found; the
# slicing below would then fail with a TypeError.
sequence = extract_sequence(unshuffled_file, transcript_id)

# === Get introns ===
# Shift the genomic interval coordinates by the transcript start parsed from
# the key so they can be used as indices into `sequence`.
# NOTE(review): assumes sequence[0] corresponds to that start coordinate —
# confirm the 0-/1-based convention of the input files.
seq_start = int(transcript_id.split(":")[1].split("-")[0])
exons = [(start - seq_start, end - seq_start) for start, end in junctions[transcript_id]]

# Collect every stretch of `sequence` that lies between consecutive exon
# intervals. Any region before the first interval and after the last one is
# collected as well, so those flanks are counted as "introns" too.
# NOTE(review): presumably the intervals are sorted and non-overlapping — verify.
introns = []
last = 0
for start, end in exons:
    if last < start:
        introns.append(sequence[last:start])
    last = end
if last < len(sequence):
    introns.append(sequence[last:])

# Report the number of collected regions and preview the first 100 bases of each.
print(f"🧬 Transcript: {transcript_id} | Total introns: {len(introns)}\n")
for i, intron in enumerate(introns, 1):
    print(f"Intron {i} (length {len(intron)}): {intron[:100]}{'...' if len(intron) > 100 else ''}\n")


In [None]:
# Pad every sequence in input_file with N flanks on both sides and widen the
# coordinates in its "chrom:start-end" header accordingly.
input_file = "/mnt/lareaulab/sdahiyat/illumina/motif_threshold_shuffled.txt"
output_file = "/mnt/lareaulab/sdahiyat/illumina/motif_threshold_shuffled_flanked.txt"

# Define the flanking sequence (5000 Ns); the header adjustment below is
# derived from the same flank_length so the two cannot drift apart.
flank_length = 5000
flanking_seq = "N" * flank_length

with open(input_file, "r") as infile, open(output_file, "w") as outfile:
    for line in infile:
        line = line.strip()
        if not line:
            continue  # Skip empty lines

        # Each record must be exactly "<header>\t<sequence>".
        try:
            header, sequence = line.split("\t")
        except ValueError:
            print(f"Skipping malformed line: {line}")
            continue

        # Extract chromosome, start, and end positions. Any header that does
        # not parse as "chrom:int-int" (missing/extra separators, non-numeric
        # coordinates) is reported and skipped instead of aborting the run
        # with a half-written output file.
        try:
            chrom, positions = header.split(":")
            start, end = map(int, positions.split("-"))
        except ValueError:
            print(f"Skipping improperly formatted header: {header}")
            continue

        # Adjust start and end positions by the flank size.
        # NOTE(review): the extra -1 on the start presumably converts to a
        # 0-based start coordinate — confirm against the downstream consumer.
        new_start = start - flank_length - 1
        new_end = end + flank_length

        # Construct new header
        new_header = f"{chrom}:{new_start}-{new_end}"

        # Construct new sequence with flanking Ns
        flanked_sequence = flanking_seq + sequence + flanking_seq
        # Write to output file
        outfile.write(f"{new_header}\t{flanked_sequence}\n")

print(f"Flanked sequences written to {output_file}.")


In [None]:
# Pad every sequence in input_file with N flanks on both sides and widen the
# coordinates in its "chrom:start-end" header accordingly.
input_file = "/mnt/lareaulab/sdahiyat/datasets/srsf1_matches_shuffled.txt"
output_file = "/mnt/lareaulab/sdahiyat/datasets/srsf1_matches_shuffled_flanked.txt"

# Define the flanking sequence (5000 Ns); the header adjustment below is
# derived from the same flank_length so the two cannot drift apart.
flank_length = 5000
flanking_seq = "N" * flank_length

with open(input_file, "r") as infile, open(output_file, "w") as outfile:
    for line in infile:
        line = line.strip()
        if not line:
            continue  # Skip empty lines

        # Each record must be exactly "<header>\t<sequence>".
        try:
            header, sequence = line.split("\t")
        except ValueError:
            print(f"Skipping malformed line: {line}")
            continue

        # Extract chromosome, start, and end positions. Any header that does
        # not parse as "chrom:int-int" (missing/extra separators, non-numeric
        # coordinates) is reported and skipped instead of aborting the run
        # with a half-written output file.
        try:
            chrom, positions = header.split(":")
            start, end = map(int, positions.split("-"))
        except ValueError:
            print(f"Skipping improperly formatted header: {header}")
            continue

        # Adjust start and end positions by the flank size.
        # NOTE(review): the extra -1 on the start presumably converts to a
        # 0-based start coordinate — confirm against the downstream consumer.
        new_start = start - flank_length - 1
        new_end = end + flank_length

        # Construct new header
        new_header = f"{chrom}:{new_start}-{new_end}"

        # Construct new sequence with flanking Ns
        flanked_sequence = flanking_seq + sequence + flanking_seq
        # Write to output file
        outfile.write(f"{new_header}\t{flanked_sequence}\n")

print(f"Flanked sequences written to {output_file}.")


In [None]:
# Step 1: Collect the transcript keys listed in srsf1_matches_shuffled.txt
motif_file = "/mnt/lareaulab/sdahiyat/datasets/srsf1_matches_shuffled.txt"
motif_keys = set()

with open(motif_file, "r") as f:
    # The key is the first tab-separated column, e.g. "chr1:65419-71585";
    # blank lines contribute nothing.
    motif_keys.update(
        stripped.split("\t")[0]
        for stripped in (raw.strip() for raw in f)
        if stripped
    )

# Step 2: Keep only the canonical_dataset_created.txt rows whose
# "chrom:start-end" key appears among the motif keys
input_file = "/mnt/lareaulab/sdahiyat/illumina/canonical_dataset_created.txt"
output_file = "/mnt/lareaulab/sdahiyat/illumina/canonical_dataset_filtered_motifs.txt"

num_written = 0
with open(input_file, "r") as infile, open(output_file, "w") as outfile:
    for record in infile:
        columns = record.strip().split("\t")
        if len(columns) < 6:
            continue
        # Columns 2, 4 and 5 hold chromosome, start and end; rows whose
        # coordinates are not integers are dropped.
        try:
            coord_key = f"{columns[2]}:{int(columns[4])}-{int(columns[5])}"
        except ValueError:
            continue
        if coord_key not in motif_keys:
            continue
        # Matching rows are copied through unchanged.
        outfile.write(record)
        num_written += 1

print(f" Done! Wrote {num_written} matching entries to: {output_file}")
