# Build Reference

## Fetch files

In [None]:
# Download the GRCm39 primary-assembly genome and the matching Ensembl 113 GFF3 annotation.
wget "https://ftp.ensembl.org/pub/release-113/fasta/mus_musculus/dna/Mus_musculus.GRCm39.dna.primary_assembly.fa.gz"
wget "https://ftp.ensembl.org/pub/release-113/gff3/mus_musculus/Mus_musculus.GRCm39.113.gff3.gz"
# Decompress exactly the two downloads; a bare *.gz glob would also unpack any
# unrelated archive that happens to be in the working directory.
gzip -d Mus_musculus.GRCm39.dna.primary_assembly.fa.gz Mus_musculus.GRCm39.113.gff3.gz

## Construct reference transcriptome

In [None]:
# Write the spliced sequence of every GFF3 transcript (-w) using the genome FASTA (-g).
gffread -w GRCm39_transcripts.fa -g Mus_musculus.GRCm39.dna.primary_assembly.fa Mus_musculus.GRCm39.113.gff3

## Fetch additional annotations

In [None]:
# annotations.py

import re
import pandas as pd
import requests
import sys
import multiprocessing as mp
from tqdm import tqdm

# CLI: annotations.py INPUT_FASTA ANNOTATIONS_CSV NUM_CORES BATCH_SIZE
input_fasta, annotations_csv = sys.argv[1], sys.argv[2]
num_cores, batch_size = int(sys.argv[3]), int(sys.argv[4])  # worker processes, IDs per REST request

def fetch_chunk(chunk, max_retries=5, timeout=120):
    """Resolve one batch of Ensembl stable IDs with the REST ``POST /lookup/id`` endpoint.

    Args:
        chunk: List of stable IDs to look up in a single request.
        max_retries: How many times to retry after an HTTP 429 (rate-limited)
            response before giving up.
        timeout: Seconds to wait for the server before the request fails.

    Returns:
        The values of the JSON response as a list, one per requested ID; the
        caller filters out ``None`` entries (IDs that could not be resolved).

    Raises:
        requests.HTTPError: On a non-200 response, or a 429 that persists after
            ``max_retries`` retries.
    """
    import time  # only needed for the rate-limit back-off below

    for attempt in range(max_retries + 1):
        response = requests.post(
            "https://rest.ensembl.org/lookup/id",
            json={"ids": chunk},
            headers={"Accept": "application/json"},
            timeout=timeout,
        )
        if response.status_code == 200:
            return list(response.json().values())
        if response.status_code == 429 and attempt < max_retries:
            # Many worker processes share one client rate limit; wait as the
            # server asks instead of failing the whole pool.
            try:
                delay = float(response.headers.get("Retry-After", 1))
            except ValueError:
                delay = 1.0
            time.sleep(max(delay, 0.0))
            continue
        break
    raise requests.HTTPError(f'HTTP request failed with status code {response.status_code}: {response.text}')

def get_transcript_info(ids):
    """Fetch Ensembl lookup records for ``ids`` in parallel.

    IDs are split into batches of ``batch_size``; each batch is resolved by
    ``fetch_chunk`` in a pool of ``num_cores`` processes. Batches may finish in
    any order. ``None`` entries (unresolved IDs) are dropped from the result.
    """
    chunks = [ids[offset:offset + batch_size] for offset in range(0, len(ids), batch_size)]
    with mp.Pool(processes=num_cores) as pool:
        progress = tqdm(pool.imap_unordered(fetch_chunk, chunks), total=len(chunks), desc="Fetch Data")
        batches = list(progress)
    records = []
    for batch in batches:
        records.extend(record for record in batch if record is not None)
    return records

# Collect transcript IDs from the FASTA header lines only. Sequence lines are
# nucleotide letters and never contain "transcript:", so this yields the same
# IDs in the same order without reading the whole file into memory.
transcript_pattern = re.compile(r"(?<=transcript:)(\w+)(?=\s|$)")
ids = []
with open(input_fasta, "r") as f:
    for line in f:
        if line.startswith(">"):
            ids.extend(transcript_pattern.findall(line))

annotations_data = get_transcript_info(ids)
annotations = pd.DataFrame(annotations_data)
annotations.to_csv(annotations_csv, index=False)


In [None]:
# args: input FASTA, output CSV, worker processes, IDs per Ensembl REST request
python annotations.py GRCm39_transcripts.fa annotations.csv 16 100

## Extract the longest transcript
- The in-house reference transcriptome keeps only the longest transcript of each gene (one transcript per gene).

In [None]:
# transcriptome.py

import re
import pandas as pd
from Bio import SeqIO
import sys
import subprocess as sp
from tqdm import tqdm

# CLI: transcriptome.py INPUT_FASTA OUTPUT_FASTA ANNOTATIONS_CSV
input_fasta, output_fasta, annotations_csv = sys.argv[1], sys.argv[2], sys.argv[3]

def count_transcripts(input_fasta):
    """Return the number of FASTA records, i.e. lines starting with '>'.

    Counted in Python rather than via ``grep -c`` through a shell. This avoids
    quoting problems with unusual paths. It also avoids grep's non-zero exit
    status (and the resulting CalledProcessError) when the file has no records.
    """
    with open(input_fasta, "rb") as handle:
        return sum(1 for line in handle if line.startswith(b">"))

annotations = pd.read_csv(annotations_csv)
# One row per gene ("Parent"): the transcript with the greatest length.
longest_transcripts = annotations.loc[annotations.groupby("Parent")["length"].idxmax()]
total_transcripts = count_transcripts(input_fasta)

# Build an id -> description map once so each FASTA record is an O(1) lookup
# instead of a scan of the whole table. setdefault keeps the first row for an
# id; missing display names become an empty description.
display_names = {}
for transcript_id, display_name in zip(longest_transcripts["id"], longest_transcripts["display_name"].fillna("")):
    display_names.setdefault(transcript_id, display_name)

transcript_pattern = re.compile(r"(?<=transcript:)(\w+)(?=\s|$)")
with open(output_fasta, "w") as output:
    for record in tqdm(SeqIO.parse(input_fasta, "fasta"), total=total_transcripts, desc="Filter Transcripts"):
        transcript_id = transcript_pattern.findall(record.id)[0]
        if transcript_id not in display_names:
            continue
        # Rewrite the header as "ID display_name" for downstream tools.
        record.id = transcript_id
        record.description = display_names[transcript_id]
        SeqIO.write(record, output, "fasta")


In [None]:
# Keep the longest transcript per gene; args: input FASTA, output FASTA, annotations CSV
python transcriptome.py GRCm39_transcripts.fa GRCm39_lt.fa annotations.csv

## Build index

In [None]:
# -p: do not fail when index/ already exists (e.g. when re-running this cell)
mkdir -p index
bowtie2-build --threads 16 GRCm39_lt.fa index/GRCm39_lt

# Build Label File

In [None]:
# NOTE(review): project.py, left-join.py and label.py are not part of this notebook.
# From their names and arguments these presumably project both SRA exports to the
# listed columns, left-join run info onto experiments, and write the label file
# SRRList.csv used by every later step — confirm against those scripts.
python project.py sra_result.csv sra_result_projected.csv '["Experiment Accession", "Experiment Title", "Instrument"]'
python project.py SraRunInfo.csv SraRunInfo_projected.csv '["Run", "Experiment", "LibraryLayout", "SampleName"]'
python left-join.py sra_result_projected.csv SraRunInfo_projected.csv "Experiment Accession" "Experiment" "Run" SRRList_joined.csv
python label.py SRRList_joined.csv SRRList.csv

# Preprocess Data

## Fetch data

In [None]:
# fetch.py

import multiprocessing as mp
from tqdm import tqdm
import sys
import utils

# CLI: fetch.py LABEL_CSV MAX_SIZE NUM_CORES BATCH_SIZE
label_csv, max_size = sys.argv[1], sys.argv[2]  # max_size stays a string (e.g. "100G") for prefetch
num_cores, batch_size = int(sys.argv[3]), int(sys.argv[4])  # threads per tool call, SRRs in parallel

def download_srr(srr):
    """Download the .sra archive of ``srr`` into raw/ with prefetch."""
    command = ["prefetch", "--max-size", max_size, srr, "--output-directory", "raw"]
    executer.run(command, f"{srr} downloaded")

def dump_fastq(srr):
    """Convert raw/SRR/SRR.sra into FASTQ under fastq/ (fasterq-dump --split-files)."""
    command = [
        "fasterq-dump", "--force", "--verbose", "--split-files",
        "--outdir", "fastq/", "--temp", "cache/",
        "--threads", str(num_cores),
        f"raw/{srr}/{srr}.sra",
    ]
    executer.run(command, f"{srr} dumped")

def compress_fastq(srr):
    """Gzip the FASTQ file(s) of ``srr`` with pigz; paired runs have _1/_2 mates."""
    if labels.are_equal(srr, {"Layout": ["PAIRED"]}):
        fastq_files = [f"fastq/{srr}_1.fastq", f"fastq/{srr}_2.fastq"]
    else:
        fastq_files = [f"fastq/{srr}.fastq"]
    executer.run(["pigz", "--processes", str(num_cores), *fastq_files], f"{srr} compressed")

def fetch_srr(srr):
    """Download, dump and compress one SRR, in that order."""
    for step in (download_srr, dump_fastq, compress_fastq):
        step(srr)

labels = utils.Label(label_csv)
srr_list = labels.get_srr_list()

executer = utils.Executer()
executer.log(f"SRRs to download: {srr_list}")
executer.run(["mkdir", "-p", "raw", "fastq"], "directories created")

# batch_size SRRs are fetched concurrently; each tool call uses num_cores threads.
with mp.Pool(processes=batch_size) as pool:
    with tqdm(total=len(srr_list), desc="Fetch Data") as pbar:
        for _finished in pool.imap_unordered(fetch_srr, srr_list):
            pbar.update(1)


In [None]:
# args: label CSV, prefetch --max-size, threads per tool call, SRRs fetched in parallel
python fetch.py SRRList.csv 100G 4 4

## Process RNA-Seq data

In [None]:
# rnaseq.py

import multiprocessing as mp
from tqdm import tqdm
import sys
import sam
import utils

# CLI: rnaseq.py INPUT_DIR LABEL_CSV BOWTIE2_INDEX REFERENCE_FASTA NUM_CORES BATCH_SIZE
input_dir, label_csv = sys.argv[1], sys.argv[2]
bowtie2_index, reference_fasta = sys.argv[3], sys.argv[4]  # reference_fasta is not used by this script
num_cores, batch_size = int(sys.argv[5]), int(sys.argv[6])  # threads per tool call, SRRs in parallel

def decompress_fastq(srr):
    """Copy the paired gzipped FASTQs of ``srr`` into cache/SRR/ and gunzip them there."""
    mates = [f"{srr}_1.fastq.gz", f"{srr}_2.fastq.gz"]
    executer.run(["cp", *[f"{input_dir}/{mate}" for mate in mates], f"cache/{srr}/"], f"{srr} copied")
    executer.run(["pigz", "--decompress", "--processes", str(num_cores), *[f"cache/{srr}/{mate}" for mate in mates]], f"{srr} decompressed")

def trim_adaptor(srr):
    """Trim adapters and Ns from the read pair with Trim Galore (outputs SRR_val_1.fq / SRR_val_2.fq)."""
    workdir = f"cache/{srr}/"
    command = ["trim_galore", "--trim-n", "--output_dir", workdir, "--basename", srr, "--cores", str(num_cores), "--paired", f"{workdir}{srr}_1.fastq", f"{workdir}{srr}_2.fastq"]
    executer.run(command, f"{srr} trimmed")

def align_reads(srr):
    """Align the trimmed pair end-to-end with bowtie2, writing cache/SRR/SRR.sam."""
    workdir = f"cache/{srr}"
    command = [
        "bowtie2", "--xeq", "--non-deterministic", "--end-to-end", "--very-sensitive",
        "--threads", str(num_cores), "-x", bowtie2_index,
        "-1", f"{workdir}/{srr}_val_1.fq", "-2", f"{workdir}/{srr}_val_2.fq",
        "-S", f"{workdir}/{srr}.sam",
    ]
    executer.run(command, f"{srr} aligned")

def process_sam(srr):
    """Return the per-reference metrics computed by sam.count_metrics for the alignment."""
    return sam.count_metrics(f"cache/{srr}/{srr}.sam")

def process_srr(srr):
    """Run copy/decompress -> trim -> align for one SRR and return (srr, metrics)."""
    executer.run(["mkdir", "-p", f"cache/{srr}/"], f"cache/{srr}/ created")
    for step in (decompress_fastq, trim_adaptor, align_reads):
        step(srr)
    return srr, process_sam(srr)

def write_database(srr, output_data):
    """Store one SRR's per-reference records in the metrics database, then drop its cache.

    Each record is keyed "BASE_NAME|REF_NAME", where the base name is taken
    from the label file for this SRR.
    """
    prefix = labels.get_base_name(srr)
    metrics.connect()
    for ref_name, record in output_data.items():
        metrics.write(f"{prefix}|{ref_name}", record)
    metrics.close()

    executer.log(f"{srr} processed")
    executer.run(["rm", "-r", f"cache/{srr}"], f"cache/{srr}/ removed")

labels = utils.Label(label_csv)
srr_list = labels.get_srr_list({"Experiment": ["RNA-Seq"]})

metrics = utils.Database("metrics.db", "metrics")
executer = utils.Executer()
executer.log(f"SRRs to process: {srr_list}")

# SRRs are processed in batch_size worker processes; database writes happen
# only here in the parent, one finished SRR at a time.
with mp.Pool(processes=batch_size) as pool:
    finished = pool.imap_unordered(process_srr, srr_list)
    for srr, output_data in tqdm(finished, total=len(srr_list), desc="Process SRRs"):
        write_database(srr, output_data)


In [None]:
# args: FASTQ dir, label CSV, bowtie2 index prefix, reference FASTA (unused), threads per tool call, SRRs in parallel
python rnaseq.py fastq SRRList.csv genome/index/GRCm39_lt genome/GRCm39_lt.fa 4 4

## Process icSHAPE data

In [None]:
# icshape.py

import sys
from tqdm import tqdm
import multiprocessing as mp
import sam
import utils

# CLI: icshape.py INPUT_DIR LABEL_CSV TRIM_HEAD BOWTIE2_INDEX REFERENCE_FASTA NUM_CORES BATCH_SIZE
input_dir, label_csv = sys.argv[1], sys.argv[2]
trim_head = sys.argv[3]  # bases cut from the start of each read (cutadapt --cut); kept as a string
bowtie2_index, reference_fasta = sys.argv[4], sys.argv[5]  # reference_fasta is not used by this script
num_cores, batch_size = int(sys.argv[6]), int(sys.argv[7])  # threads per tool call, SRRs in parallel

def decompress_fastq(srr):
    """Copy the gzipped single-end FASTQ of ``srr`` into cache/SRR/ and gunzip it there."""
    gz_name = f"{srr}.fastq.gz"
    executer.run(["cp", f"{input_dir}/{gz_name}", f"cache/{srr}/"], f"{srr} copied")
    executer.run(["pigz", "--decompress", "--processes", str(num_cores), f"cache/{srr}/{gz_name}"], f"{srr} decompressed")

def trim_adaptor(srr):
    """Trim adapters and Ns with Trim Galore (outputs cache/SRR/SRR_trimmed.fq)."""
    workdir = f"cache/{srr}/"
    command = ["trim_galore", "--trim-n", "--output_dir", workdir, "--basename", srr, "--cores", str(num_cores), f"{workdir}{srr}.fastq"]
    executer.run(command, f"{srr} trimmed")

def collapse_reads(srr):
    """Remove exact duplicate reads with clumpify (dedupe, 0 substitutions) -> SRR_deduped.fq."""
    workdir = f"cache/{srr}"
    options = [f"threads={num_cores}", "dedupe=t", "subs=0", "usetmpdir=t", "tmpdir=cache/"]
    executer.run(["clumpify.sh", f"in={workdir}/{srr}_trimmed.fq", f"out={workdir}/{srr}_deduped.fq", *options], f"{srr} collapsed")

def remove_index(srr):
    """Cut the first ``trim_head`` bases of every read with cutadapt -> SRR_unindex.fq."""
    workdir = f"cache/{srr}"
    command = ["cutadapt", "--cut", trim_head, "--cores", str(num_cores), "--output", f"{workdir}/{srr}_unindex.fq", f"{workdir}/{srr}_deduped.fq"]
    executer.run(command, f"{srr} index removed")

def align_reads(srr):
    """Align the processed reads end-to-end with bowtie2, writing cache/SRR/SRR.sam."""
    workdir = f"cache/{srr}"
    command = [
        "bowtie2", "--xeq", "--non-deterministic", "--end-to-end", "--very-sensitive",
        "--threads", str(num_cores), "-x", bowtie2_index,
        "-U", f"{workdir}/{srr}_unindex.fq", "-S", f"{workdir}/{srr}.sam",
    ]
    executer.run(command, f"{srr} aligned")

def process_sam(srr):
    """Return the per-reference RT-stop data computed by sam.count_rtstops for the alignment."""
    return sam.count_rtstops(f"cache/{srr}/{srr}.sam")

def process_srr(srr):
    """Run copy -> trim -> dedupe -> cut -> align for one SRR and return (srr, rtstops)."""
    executer.run(["mkdir", "-p", f"cache/{srr}/"], f"cache/{srr}/ created")
    for step in (decompress_fastq, trim_adaptor, collapse_reads, remove_index, align_reads):
        step(srr)
    return srr, process_sam(srr)

def write_database(srr, output_data):
    """Store one SRR's per-reference RT-stop records in the rtstops database, then drop its cache.

    Each record is keyed "BASE_NAME|REF_NAME", where the base name is taken
    from the label file for this SRR.
    """
    prefix = labels.get_base_name(srr)
    rtstops.connect()
    for ref_name, record in output_data.items():
        rtstops.write(f"{prefix}|{ref_name}", record)
    rtstops.close()

    executer.log(f"{srr} processed")
    executer.run(["rm", "-r", f"cache/{srr}"], f"cache/{srr}/ removed")

labels = utils.Label(label_csv)
srr_list = labels.get_srr_list({"Experiment": ["icSHAPE"]})

rtstops = utils.Database("rtstops.db", "rtstops")
executer = utils.Executer()
executer.log(f"SRRs to process: {srr_list}")

# SRRs are processed in batch_size worker processes; database writes happen
# only here in the parent, one finished SRR at a time.
with mp.Pool(processes=batch_size) as pool:
    finished = pool.imap_unordered(process_srr, srr_list)
    for srr, output_data in tqdm(finished, total=len(srr_list), desc="Process SRRs"):
        write_database(srr, output_data)


In [None]:
# args: FASTQ dir, label CSV, bases cut from each read start (13, cutadapt --cut), bowtie2 index prefix,
#       reference FASTA (unused), threads per tool call, SRRs in parallel
python icshape.py fastq SRRList.csv 13 genome/index/GRCm39_lt genome/GRCm39_lt.fa 4 4

## Check Consistency

In [None]:
# replicates.py

import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import multiprocessing as mp
import utils
from tqdm import tqdm
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# CLI: replicates.py DATABASE_DB TABLE_NAME ANNOTATIONS_CSV OUTPUT_DIR BATCH_SIZE
database_db, table_name = sys.argv[1], sys.argv[2]
annotations_csv, output_dir = sys.argv[3], sys.argv[4]
batch_size = int(sys.argv[5])  # worker processes used to load the per-SRR data

def arrange_keys(keys):
    """Index database keys of the form "sample|experiment|group|srr|ref_name".

    Returns:
        key_maps: {srr: {ref_name: full key}}.
        srr_maps: {"sample|experiment|group": [srr, ...]}, with SRRs listed once
            each, in first-seen order.

    Raises:
        ValueError: if a key does not split into exactly five fields.
    """
    key_maps = {}
    srr_maps = {}
    for key in keys:
        sample, experiment, group, srr, ref_name = key.split("|")
        key_maps.setdefault(srr, {})[ref_name] = key
        members = srr_maps.setdefault(f"{sample}|{experiment}|{group}", [])
        if srr not in members:
            members.append(srr)
    return key_maps, srr_maps

def load_data(args):
    """Compute normalised expression of every annotated transcript for one SRR.

    Args:
        args: ``(srr, ref_names)``, where ``ref_names`` maps reference name -> database key.

    Returns:
        ``(srr, data)``. ``data`` has one entry per reference in ``annotations.index``.
        For references the SRR has a record for, the value is
        (sum of the record's "ED" values / transcript length) / total "ED" over all
        of the SRR's records. All other references get 0.
    """
    srr, ref_names = args
    per_length = {}
    total_reads = 0
    for ref_name, key in ref_names.items():
        reads = np.sum(database.read(key)["ED"])
        total_reads += reads
        per_length[ref_name] = reads / annotations.loc[ref_name, "length"]
    data = dict.fromkeys(annotations.index, 0)
    for ref_name, value in per_length.items():
        data[ref_name] = value / total_reads
    return srr, data

def group_points(data, srr_maps):
    """Split ``data`` rows by legend: returns {legend: data.loc[srr_list]}."""
    return {legend: data.loc[srrs] for legend, srrs in srr_maps.items()}

# Load data: longest transcript per gene, indexed (and sorted) by transcript id
annotations = pd.read_csv(annotations_csv)
annotations = annotations.loc[annotations.groupby("Parent")["length"].idxmax()].set_index("id").sort_index()
database = utils.Database(database_db, table_name)
database.connect()
keys = database.list()
key_maps, srr_maps = arrange_keys(keys)
data = {}

# NOTE(review): the workers call database.read() on a connection opened in the
# parent before the pool forks; SQLite advises against using a connection
# across fork() — confirm utils.Database tolerates this.
with mp.Pool(processes=batch_size) as pool:
    loaded = pool.imap_unordered(load_data, key_maps.items())
    for srr, srr_data in tqdm(loaded, total=len(key_maps), desc="Load Data"):
        data[srr] = srr_data

database.close()
data = pd.DataFrame(data).transpose()  # rows: SRRs, columns: transcripts

# PCA decomposition of the standardised per-transcript values
data_scaled = StandardScaler().fit_transform(data)
data_pca = pd.DataFrame(PCA(n_components=2).fit_transform(data_scaled), index=data.index, columns=['Component 1', 'Component 2'])

# Create scatter plot: one colour per sample|experiment|group, every point labelled with its SRR
data_grouped = group_points(data_pca, srr_maps)
plt.figure(figsize=(10, 10), dpi=300)
for legend, points in data_grouped.items():
    plt.scatter(points.iloc[:, 0], points.iloc[:, 1], label=legend)
for srr in data_pca.index:
    plt.annotate(srr, (data_pca.loc[srr, 'Component 1'], data_pca.loc[srr, 'Component 2']))
plt.xlabel("Component 1")
plt.ylabel("Component 2")
plt.legend()
plt.title(table_name)
plt.savefig(f"{output_dir}/{table_name}_replicates.png")


In [None]:
# PCA of per-SRR normalised values; args: database, table, annotations CSV, output dir, worker processes
python replicates.py metrics.db metrics genome/annotations.csv replicates 16
python replicates.py rtstops.db rtstops genome/annotations.csv replicates 16

# Build Dataset

## Format data

In [None]:
# format.py

import numpy as np
import sys
import sqlite3
import utils
import multiprocessing as mp
from tqdm import tqdm
import orjson
import ast

# CLI: format.py LABEL_CSV METRICS_DB RTSTOPS_DB REFERENCE_FASTA ALPHA STRIP DATASET_DB TABLE_NAME BATCH_SIZE
label_csv, metrics_db, rtstops_db = sys.argv[1], sys.argv[2], sys.argv[3]
reference_fasta = sys.argv[4]
alpha = float(sys.argv[5])  # weight of the DMSO stops subtracted in the reactivity score, e.g. 0.25
strip = ast.literal_eval(sys.argv[6])  # True/False: crop tracks to the first..last fully-covered position
dataset_db, table_name = sys.argv[7], sys.argv[8]
batch_size = int(sys.argv[9])  # worker processes

# IUPAC nucleotide codes -> the RNA bases they may stand for (T is read as U).
_IUPAC_RNA_BASES = {
    "A": "A", "C": "C", "G": "G", "U": "U", "T": "U",
    "R": "AG", "Y": "CU", "S": "CG", "W": "AU", "K": "GU", "M": "AC",
    "B": "CGU", "D": "AGU", "H": "ACU", "V": "ACG", "N": "ACGU",
}
# Per-code weights over the A, C, G, U channels. Unambiguous bases use int 1 (as
# before), so sequences without ambiguity codes still encode to integers.
_ONEHOT_TOKENS = {
    code: [(1 if len(bases) == 1 else 1 / len(bases)) if base in bases else 0 for base in "ACGU"]
    for code, bases in _IUPAC_RNA_BASES.items()
}

def onehot_encode(sequence):
    """One-hot encode an RNA sequence into four per-position channels.

    Args:
        sequence: Nucleotide string. Case-insensitive; ``T`` is read as ``U``.
            IUPAC ambiguity codes spread their weight evenly over the bases
            they denote (``N`` -> 0.25 each, ``R`` -> 0.5 on A and G).

    Returns:
        ``[A, C, G, U]``: four lists with one value per position.

    Raises:
        KeyError: for a symbol that is not an IUPAC nucleotide code.
    """
    if not sequence:
        return [[], [], [], []]
    encoded = [_ONEHOT_TOKENS[nt] for nt in sequence.upper()]
    return np.transpose(encoded).tolist()

def winsorize_scale(scores):
    """Clip values to their 5th-95th percentile range and rescale them to [0, 1].

    If the 5th and 95th percentiles coincide, no clipping is done and the full
    min..max range is used instead. If that range is also a single value, every
    value is divided by it, unless it is 0, in which case values are kept as is.

    Args:
        scores: {position: value}.

    Returns:
        {position: scaled value as float}, with the same keys in the same order.
    """
    positions = list(scores)
    raw = np.asarray(list(scores.values()))
    low = np.percentile(raw, 5)
    high = np.percentile(raw, 95)

    if low != high:
        clipped = np.clip(raw, low, high)
    else:
        # Percentile window collapsed: fall back to the full range, no clipping.
        low, high = np.min(raw), np.max(raw)
        clipped = raw

    if low != high:
        scaled = (clipped - low) / (high - low)
    elif high == 0:
        scaled = clipped
    else:
        scaled = clipped / high

    return {pos: float(val) for pos, val in zip(positions, scaled)}

def filter_entries(keys):
    """Return, per sample, the references present for every one of its SRRs.

    Args:
        keys: Database keys "sample|experiment|group|srr|ref_name" (RNA-Seq and
            icSHAPE keys together).

    Returns:
        {sample: [ref_name, ...]}: the intersection over all SRRs that the label
        file lists for the sample. The list order is unspecified (it comes from a set).
    """
    per_srr = {
        sample: {srr: [] for srr in labels.get_srr_list({"Sample": [sample]})}
        for sample in labels.unique_values("Sample")
    }
    for key in keys:
        fields = key.split("|")
        per_srr[fields[0]][fields[3]].append(fields[4])
    return {
        sample: list(set.intersection(*(set(refs) for refs in srr_refs.values())))
        for sample, srr_refs in per_srr.items()
    }

def read_fasta(filename):
    """Parse a FASTA file into {record id: sequence}.

    The id is the header text after '>' up to the first space. Sequence lines
    are stripped of surrounding whitespace and concatenated. Records with an
    empty id, and any lines before the first header, are ignored.
    """
    sequences = {}
    seq_id = None
    chunks = []
    with open(filename, 'r') as handle:
        for raw_line in handle:
            line = raw_line.strip()
            if not line.startswith('>'):
                chunks.append(line)
                continue
            if seq_id:
                sequences[seq_id] = "".join(chunks)
            seq_id = line[1:].split(' ')[0]
            chunks = []
    if seq_id:
        sequences[seq_id] = "".join(chunks)
    return sequences

def load_reactivity(sample, ref_name):
    """Compute per-position icSHAPE reactivity for one transcript of one sample.

    Replicates are averaged: every replicate adds value / number-of-replicates.
    From the rtstops entries, DMSO "RD" feeds DMSO_depth, DMSO "ED" feeds
    DMSO_stop, and NAIN3 "ED" feeds NAIN3_stop. All three are keyed by the
    positions listed in "PS".

    For each position from the smallest to the largest seen in either condition,
    the score is (NAIN3_stop - alpha * DMSO_stop) / DMSO_depth when both
    conditions report the position. Otherwise the score is 0 and a missing-data
    bit is set. Scores are then rescaled by winsorize_scale().

    Args:
        sample: Sample name, as used in the label file and the database keys.
        ref_name: Reference (transcript) name.

    Returns:
        (reactivity_scores, md_indicators, total_density):
        {pos: scaled score}, {pos: missing-data bitmask}, and the sum of the
        averaged DMSO depths over all reported positions.

    Raises:
        ValueError: from min()/max() when either condition reports no position
            (e.g. the sample has no DMSO or no NAIN3 SRRs). format_entry logs and
            skips such entries.
    """
    NAIN3_list = labels.get_srr_list({"Sample": [sample], "Experiment": ["icSHAPE"], "Group": ["NAIN3"]})
    DMSO_list = labels.get_srr_list({"Sample": [sample], "Experiment": ["icSHAPE"], "Group": ["DMSO"]})

    DMSO_depth = dict()
    DMSO_stop = dict()
    NAIN3_stop = dict()

    # DMSO (background) replicates: depth and stop counts, averaged over replicates.
    # DMSO_depth and DMSO_stop always receive the same positions.
    num_rep = len(DMSO_list)
    for srr in DMSO_list:
        key = f"{sample}|icSHAPE|DMSO|{srr}|{ref_name}"
        rtstops_entry = rtstops.read(key)
        for index, pos in enumerate(rtstops_entry["PS"]):
            if pos not in DMSO_depth:
                DMSO_depth[pos] = 0
            if pos not in DMSO_stop:
                DMSO_stop[pos] = 0
            DMSO_depth[pos] += rtstops_entry["RD"][index] / num_rep
            DMSO_stop[pos] += rtstops_entry["ED"][index] / num_rep

    # NAIN3 (treated) replicates: only stop counts are used.
    num_rep = len(NAIN3_list)
    for srr in NAIN3_list:
        key = f"{sample}|icSHAPE|NAIN3|{srr}|{ref_name}"
        rtstops_entry = rtstops.read(key)
        for index, pos in enumerate(rtstops_entry["PS"]):
            if pos not in NAIN3_stop:
                NAIN3_stop[pos] = 0
            NAIN3_stop[pos] += rtstops_entry["ED"][index] / num_rep

    # Raises ValueError if either dict is empty (see docstring).
    start = min(min(DMSO_stop.keys()), min(NAIN3_stop.keys())) # 1-based position
    end = max(max(DMSO_stop.keys()), max(NAIN3_stop.keys())) # 1-based position
        
    md_indicators = dict() # missing data indicators, 0b001: DMSO missing, 0b010: NAIN3 missing, 0b100: RNA-Seq missing
    reactivity_scores = dict()
    for pos in range(start, end + 1):
        md_indicators[pos] = 0
        reactivity_scores[pos] = 0
        if pos not in DMSO_stop:
            md_indicators[pos] |= 0b001
        if pos not in NAIN3_stop:
            md_indicators[pos] |= 0b010
        if md_indicators[pos] == 0:
            # NOTE(review): a position with averaged DMSO depth 0 raises
            # ZeroDivisionError here, which format_entry does not catch — confirm
            # sam.count_rtstops never lists a position with RD == 0.
            reactivity_scores[pos] = (NAIN3_stop[pos] - alpha * DMSO_stop[pos]) / DMSO_depth[pos]

    total_density = np.sum(list(DMSO_depth.values()))
    # Missing positions (score 0) are included when computing the scaling percentiles.
    reactivity_scores = winsorize_scale(reactivity_scores)

    return reactivity_scores, md_indicators, total_density

def load_metrics(sample, ref_name):
    """Average the RNA-Seq metrics of one transcript over a sample's replicates.

    Only SRRs with a stored record for this transcript count as replicates;
    each contributes value / number-of-such-SRRs. From each record, "RD" feeds
    read_depth, "ED" feeds end_depth, and "MC" feeds mismatch_count, all keyed
    by the positions listed in "PS".

    Args:
        sample: Sample name, as used in the label file and the database keys.
        ref_name: Reference (transcript) name.

    Returns:
        (read_depth, end_rate, mismatch_rate, total_depth, total_end_rate,
        total_mismatch_rate):
        - read_depth is rescaled by winsorize_scale().
        - end_rate and mismatch_rate are the raw per-position ratios to depth.
        - the totals are sums over all positions, with total_depth taken
          before rescaling.

    Raises:
        ValueError: "no RNA-Seq data" when no SRR of the sample has a record
            for this transcript.
    """
    srr_list = labels.get_srr_list({"Sample": [sample], "Experiment": ["RNA-Seq"]})
    
    keys = list()
    for srr in srr_list:
        key = f"{sample}|RNA-Seq|NA|{srr}|{ref_name}"
        # NOTE(review): if metrics_keys is a list this membership test is a
        # linear scan per call; a set keeps it O(1).
        if key in metrics_keys:
            keys.append(key)
    
    num_rep = len(keys)
    if num_rep == 0:
        raise ValueError("no RNA-Seq data")
    
    read_depth = dict()
    end_depth = dict()
    end_rate = dict()
    mismatch_count = dict()
    mismatch_rate = dict()

    for key in keys:
        metrics_entry = metrics.read(key)
        for index, pos in enumerate(metrics_entry["PS"]):
            if pos not in read_depth:
                read_depth[pos] = 0
            if pos not in end_depth:
                end_depth[pos] = 0
            if pos not in mismatch_count:
                mismatch_count[pos] = 0
            read_depth[pos] += metrics_entry["RD"][index] / num_rep
            end_depth[pos] += metrics_entry["ED"][index] / num_rep
            mismatch_count[pos] += metrics_entry["MC"][index] / num_rep

    # NOTE(review): a position with averaged read depth 0 raises ZeroDivisionError
    # (not caught by format_entry) — confirm sam.count_metrics never lists RD == 0.
    for pos in read_depth:
        end_rate[pos] = end_depth[pos] / read_depth[pos]
        mismatch_rate[pos] = mismatch_count[pos] / read_depth[pos]

    total_depth = np.sum(list(read_depth.values()))
    total_end_rate = np.sum(list(end_rate.values()))
    total_mismatch_rate = np.sum(list(mismatch_rate.values()))
    read_depth = winsorize_scale(read_depth)

    return read_depth, end_rate, mismatch_rate, total_depth, total_end_rate, total_mismatch_rate

def format_entry(args):
    """Build one dataset row (sample, transcript) combining sequence, RNA-Seq and icSHAPE tracks.

    Args:
        args: ``(sample, ref_name)``.

    Returns:
        A tuple whose fields are in the column order of the dataset table:
        JSON-encoded A/C/G/U, RD, ER, MR, RT and IC tracks, then the scalar
        columns. Returns None (after logging) when load_reactivity/load_metrics
        raise ValueError, or when the transcript has fewer than two positions
        with no missing data.

    Positions are 1-based. ``start`` is the first fully-covered position.
    ``end`` counts from the back (-1 is the last position) and marks the last
    fully-covered position. With ``strip`` true, every track is cropped to
    start..end inclusive.

    NOTE(review): a ref_name missing from reference_transcriptome (KeyError) or a
    non-IUPAC symbol in the sequence is not caught here and would abort the pool.
    """
    sample, ref_name = args
    entry_name = f"{sample}|{ref_name}"

    try:
        reactivity_scores_dict, reactivity_indicators_dict, total_density = load_reactivity(sample, ref_name)
        read_depth_dict, end_rate_dict, mismatch_rate_dict, total_depth, total_end_rate, total_mismatch_rate = load_metrics(sample, ref_name)
    except ValueError as e:
        executer.log(f"{e} for {entry_name}")
        return None
    
    sequence = reference_transcriptome[ref_name].replace("T", "U")
    full_length = len(sequence)
    channel_A, channel_C, channel_G, channel_U = onehot_encode(sequence)
    
    indicators = [0] * full_length # 0b001: DMSO missing, 0b010: NAIN3 missing, 0b100: RNA-Seq missing
    reactivity = [0] * full_length
    read_depth = [0] * full_length
    end_rate = [0] * full_length
    mismatch_rate = [0] * full_length
    # Fill the full-length tracks; positions outside the icSHAPE range are marked
    # as missing both conditions, positions without RNA-Seq data as RNA-Seq missing.
    for pos in range(1, full_length + 1):
        if pos in reactivity_scores_dict:
            reactivity[pos - 1] = reactivity_scores_dict[pos]
            indicators[pos - 1] |= reactivity_indicators_dict[pos]
        else:
            indicators[pos - 1] |= 0b011
        if pos in read_depth_dict:
            read_depth[pos - 1] = read_depth_dict[pos]
            end_rate[pos - 1] = end_rate_dict[pos]
            mismatch_rate[pos - 1] = mismatch_rate_dict[pos]
        else:
            indicators[pos - 1] |= 0b100

    # find the first and last pos where indicator == 0
    start = None # 1-based position
    end = None # (-1)-based position in reverse
    for index in range(full_length):
        if indicators[index] == 0 and start is None:
            start = index + 1
        if indicators[index] == 0:
            end = index - full_length
            
    # start - 1 and end + full_length are the 0-based first/last valid indices.
    # NOTE(review): `>=` also rejects a window of exactly one valid position —
    # confirm that is intended.
    if (start is None) or (end is None) or (start - 1 >= end + full_length):
        executer.log(f"no valid data for {entry_name}")
        return None
    
    # Number of positions from start to end inclusive.
    valid_length = end + full_length - start + 2
    strip_length = full_length - valid_length
    # Positions inside the window that still have some missing data.
    gap = 0
    for index in range(start - 1, end + full_length + 1):
        if indicators[index] != 0:
            gap += 1
    # The totals cover every reported position but are averaged over the window length.
    mean_depth = total_depth / valid_length
    mean_end = total_end_rate / valid_length
    mean_density = total_density / valid_length
    mean_mismatch = total_mismatch_rate / valid_length

    if strip:
        sequence = sequence[(start - 1):(end + full_length + 1)]
        channel_A = channel_A[(start - 1):(end + full_length + 1)]
        channel_C = channel_C[(start - 1):(end + full_length + 1)]
        channel_G = channel_G[(start - 1):(end + full_length + 1)]
        channel_U = channel_U[(start - 1):(end + full_length + 1)]
        read_depth = read_depth[(start - 1):(end + full_length + 1)]
        end_rate = end_rate[(start - 1):(end + full_length + 1)]
        mismatch_rate = mismatch_rate[(start - 1):(end + full_length + 1)]
        reactivity = reactivity[(start - 1):(end + full_length + 1)]
        indicators = indicators[(start - 1):(end + full_length + 1)]

    # Field order must match the CREATE TABLE column order of the dataset table.
    entry = (entry_name, 
             orjson.dumps(channel_A), orjson.dumps(channel_C), orjson.dumps(channel_G), orjson.dumps(channel_U), orjson.dumps(read_depth), orjson.dumps(end_rate), orjson.dumps(mismatch_rate), orjson.dumps(reactivity), orjson.dumps(indicators), 
             sample, ref_name, sequence, start, end, full_length, valid_length, strip_length,
             mean_depth, mean_end, mean_density, mean_mismatch, gap
            )
    return entry

labels = utils.Label(label_csv)
reference_transcriptome = read_fasta(reference_fasta)

metrics = utils.Database(metrics_db, "metrics")
rtstops = utils.Database(rtstops_db, "rtstops")
executer = utils.Executer()

metrics.connect()
rtstops.connect()
# load_metrics() tests `key in metrics_keys` once per (entry, RNA-Seq SRR);
# a set makes each test O(1) instead of a scan over every stored key.
metrics_keys = set(metrics.list())
rtstops_keys = rtstops.list()
ref_names = filter_entries([*rtstops_keys, *metrics_keys])
args = [(sample, ref_name) for sample in ref_names for ref_name in ref_names[sample]]

dataset = sqlite3.connect(dataset_db)
dataset_cursor = dataset.cursor()
dataset_cursor.execute(f"CREATE TABLE IF NOT EXISTS {table_name} (SeqID TEXT PRIMARY KEY, A TEXT, C TEXT, G TEXT, U TEXT, RD TEXT, ER TEXT, MR TEXT, RT TEXT, IC TEXT, Sample TEXT, RefName TEXT, Sequence TEXT, Start INT, End INT, FullLength INT, ValidLength INT, StripLength INT, MeanDepth REAL, MeanEnd REAL, MeanDensity REAL, MeanMismatch REAL, Gap INT)")

# NOTE(review): workers inherit the already-open metrics/rtstops connections when
# the pool forks; SQLite advises against using a connection across fork() —
# confirm utils.Database tolerates this, or reconnect in each worker.
with mp.Pool(processes=batch_size) as pool:
    for entry in tqdm(pool.imap_unordered(format_entry, args), total=len(args), desc="Format Dataset"):
        if entry is not None:
            dataset_cursor.execute(f"INSERT INTO {table_name} VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", entry)

dataset.commit()
dataset.close()
metrics.close()
rtstops.close()


In [None]:
# args: label CSV, metrics DB, rtstops DB, reference FASTA, alpha (DMSO weight), strip (False keeps
#       full-length tracks), output DB, output table, worker processes
python format.py SRRList.csv metrics.db rtstops.db genome/GRCm39_lt.fa 0.25 False mouse.db mouse 16

## Add references

In [None]:
# references.py

import sys
import pandas as pd
import numpy as np
import sqlite3

annotations_csv = sys.argv[1]
dataset_db = sys.argv[2]

annotations = pd.read_csv(annotations_csv)
# Longest transcript per gene ("Parent"), the same selection as the reference transcriptome.
annotations = annotations.loc[annotations.groupby("Parent")["length"].idxmax()].sort_values(by="id")
dataset = sqlite3.connect(dataset_db)
try:
    cursor = dataset.cursor()
    cursor.execute("CREATE TABLE IF NOT EXISTS ref (RefName TEXT PRIMARY KEY, DisplayName TEXT, Biotype TEXT, Start INT, End INT, Length INT, Strand INT, Parent TEXT, SeqRegionName INT, Canonical INT, GENCODEPrimary INT, AssemblyName TEXT, Version INT)")

    for _, row in annotations.iterrows():
        display_name = row["display_name"]
        # Missing names are NaN/None. Test with pd.isna: an identity check against
        # np.nan only works when pandas happens to store that exact object.
        if pd.isna(display_name):
            display_name = row["id"]
        entry = (row["id"], display_name, row["biotype"], row["start"], row["end"], row["length"], row["strand"], row["Parent"], row["seq_region_name"], row["is_canonical"], row["gencode_primary"], row["assembly_name"], row["version"])
        cursor.execute("INSERT INTO ref VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", entry)

    dataset.commit()
finally:
    dataset.close()


In [None]:
# Add the `ref` annotation table to the dataset; args: annotations CSV, dataset DB
python references.py genome/annotations.csv mouse.db

# Assemble Datasets

In [None]:
# assembly.py

import sqlite3
import re
import sys

source_db = sys.argv[1]
source_table = sys.argv[2]
target_db = sys.argv[3]
target_table = sys.argv[4]

source_conn = sqlite3.connect(source_db)
source_cursor = source_conn.cursor()
target_conn = sqlite3.connect(target_db)
target_cursor = target_conn.cursor()

try:
    # fetchone() returns a row tuple (None if the table does not exist); the
    # CREATE TABLE statement is its first column. Copy its column definitions.
    source_cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (source_table,))
    schema_row = source_cursor.fetchone()
    if schema_row is None:
        sys.exit(f"table {source_table!r} not found in {source_db}")
    source_schema = re.search(r'\((.*)\)', schema_row[0], re.DOTALL).group(1).strip()

    target_cursor.execute(f"CREATE TABLE IF NOT EXISTS {target_table} ({source_schema})")
    target_cursor.execute("ATTACH DATABASE ? AS source", (source_db,))
    target_cursor.execute(f"PRAGMA table_info('{target_table}');")
    target_pragma = target_cursor.fetchall()
    target_columns = [col[1] for col in target_pragma]

    # Copy rows column-by-name so the column order of source and target may differ.
    insert_columns = ", ".join(target_columns)
    select_columns = ", ".join([f"source.{source_table}.{col}" for col in target_columns])

    query = f"INSERT INTO {target_table} ({insert_columns}) SELECT {select_columns} FROM source.{source_table}"
    print(query)
    target_cursor.execute(query)
    target_conn.commit()
finally:
    source_conn.close()
    target_conn.close()


In [None]:
# Merge the per-dataset tables and their `ref` tables into assembly.db.
# args: source DB, source table, target DB, target table
# NOTE: the absolute source paths are machine-specific.
python assembly.py /home/test/xwt/pred/data/neural/process/neural.db neural assembly.db assembly
python assembly.py /home/test/xwt/pred/data/zebrafish/process/zebrafish.db zebrafish assembly.db assembly
python assembly.py /home/test/xwt/pred/data/neural/process/neural.db ref assembly.db ref
python assembly.py /home/test/xwt/pred/data/zebrafish/process/zebrafish.db ref assembly.db ref

# Quality Control

## Data distribution

In [None]:
# histogram.py

import sys
import matplotlib.pyplot as plt
from tqdm import tqdm
import sqlite3
import ast
import numpy as np

database_db = sys.argv[1]
table_name = sys.argv[2]
features = ast.literal_eval(sys.argv[3])  # list of column names to plot
clause = sys.argv[4]  # optional SQL suffix, e.g. a WHERE filter
num_bins = int(sys.argv[5])
logarithmic = ast.literal_eval(sys.argv[6])  # True: plot log10(|x| + eps)
eps = float(sys.argv[7])
output_dir = sys.argv[8]

conn = sqlite3.connect(database_db)
cursor = conn.cursor()

for feature in tqdm(features, desc="Calculate Features"):
    cursor.execute(f"SELECT {feature} FROM {table_name} {clause}")
    rows = cursor.fetchall()
    if logarithmic:
        data = [np.log10(abs(float(row[0])) + eps) for row in rows]
        xlabel = f"$\\log_{{10}}$(|{feature}| + {eps})"
    else:
        data = [float(row[0]) for row in rows]
        xlabel = feature
    plt.figure(figsize=(10, 5), dpi=300)
    plt.hist(data, bins=num_bins, edgecolor="white")
    plt.xlabel(xlabel)
    plt.ylabel("Frequency")
    plt.title(f"{table_name}: {feature}")
    plt.savefig(f"{output_dir}/{table_name}_{feature}.png")
    # Release the figure; otherwise one high-DPI figure per feature stays open until exit.
    plt.close()

conn.close()


In [None]:
# args: database, table, feature list, SQL clause, bins, log10 transform, eps, output dir
# The "filtered" runs use the same WHERE clause as the information/completeness steps.
python histogram.py ../assembly.db assembly "['MeanEnd', 'MeanMismatch']" "" 100 True 1e-5 original
python histogram.py ../assembly.db assembly "['Start', 'End', 'FullLength', 'ValidLength', 'StripLength', 'MeanDepth', 'MeanDensity', 'Gap']" "" 100 True  1 original
python histogram.py ../assembly.db assembly "['MeanEnd', 'MeanMismatch']" "WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0" 100 True 1e-5 filtered
python histogram.py ../assembly.db assembly "['Start', 'End', 'FullLength', 'ValidLength', 'StripLength', 'MeanDepth', 'MeanDensity', 'Gap']" "WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0" 100 True 1 filtered

## Mutual information

In [None]:
# information.py

import pandas as pd
from sklearn.feature_selection import mutual_info_regression
import sqlite3
import sys
import orjson
import ast
from tqdm import tqdm

database_db = sys.argv[1]
table_name = sys.argv[2]
key_col = sys.argv[3]  # accepted for CLI compatibility; not used by this script
features = ast.literal_eval(sys.argv[4])
clause = sys.argv[5]
output_dir = sys.argv[6]

# Read all rows, then release the connection before the long MI computation.
conn = sqlite3.connect(database_db)
try:
    cursor = conn.cursor()
    cursor.execute(f"SELECT {', '.join(features)} FROM {table_name} {clause}")
    rows = cursor.fetchall()
finally:
    conn.close()

# Each selected column holds a JSON-encoded per-position list; concatenate the
# positions of all rows into one column per feature.
data = {feature: list() for feature in features}

for row in tqdm(rows, desc="Load Data"):
    for feature, value in zip(features, row):
        data[feature].extend(orjson.loads(value))

data = pd.DataFrame(data)
column_names = list(data.columns)
data = data.to_numpy()
n_features = data.shape[1]
mi_matrix = list()

# Row i: mutual information of every feature with feature i as the target.
for i in tqdm(range(n_features), desc="Compute MI"):
    mi_matrix.append(mutual_info_regression(data, data[:, i]))

mi_matrix = pd.DataFrame(mi_matrix, columns=column_names, index=column_names)
mi_matrix.to_csv(f'{output_dir}/information.csv', index=True)


In [None]:
# args: database, table, key column (unused), track list, SQL clause, output dir
python information.py ../assembly.db assembly SeqID "['A', 'C', 'G', 'U', 'RD', 'ER', 'MR', 'IC', 'RT']" "WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0" filtered

## Motif completeness

In [None]:
# completeness.py

import sys
import sqlite3
import multiprocessing as mp
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
from tqdm import tqdm
from functools import partial

# CLI: completeness.py DATABASE_DB TABLE_NAME CLAUSE MAX_LEN OUTPUT_DIR BATCH_SIZE
database_db, table_name, clause = sys.argv[1], sys.argv[2], sys.argv[3]
max_len = int(sys.argv[4])  # longest motif (k-mer) length analysed
output_dir = sys.argv[5]
batch_size = int(sys.argv[6])  # worker processes for motif counting

def add_dict(this, other):
    """Add the counts in ``other`` into ``this`` in place and return ``this``."""
    for motif, count in other.items():
        if motif not in this:
            this[motif] = count
            continue
        this[motif] += count
    return this

def count_motif(sequence, motif_length):
    """Count every overlapping k-mer of ``sequence`` (k = motif_length), skipping any containing 'N'.

    The scan walks each of the k reading frames in turn. Together the frames
    visit every start position exactly once.
    """
    counts = {}
    span = len(sequence) - motif_length + 1
    for frame in range(motif_length):
        for pos in range(frame, span, motif_length):
            kmer = sequence[pos:pos + motif_length]
            if 'N' not in kmer:
                counts[kmer] = counts.get(kmer, 0) + 1
    return counts

def summarize_counts(motif_counts, motif_lengths=None):
    """Summarize motif coverage and balance for each motif length.

    Args:
        motif_counts: mapping ``motif_length -> {motif: count}``.
        motif_lengths: motif lengths to summarize. Defaults to
            ``1..max_len`` using the module-level ``max_len`` (the script's
            command-line argument).

    Returns:
        (completeness, p_values): two dicts keyed by motif length.
        ``completeness`` is the fraction of the ``4**k`` possible motifs that
        were observed. ``p_values`` is the chi-squared survival probability
        for the hypothesis that all ``4**k`` motifs are equally frequent.
    """
    if motif_lengths is None:
        motif_lengths = range(1, max_len + 1)
    completeness = dict()
    p_values = dict()
    for motif_length in motif_lengths:
        counts = motif_counts[motif_length]
        observed = len(counts)        # distinct motifs actually seen
        expected = 4**motif_length    # all motifs over a 4-letter alphabet
        observed_counts = np.array(list(counts.values()))
        # Uniform expectation: total occurrences spread evenly over all motifs.
        expected_counts = observed_counts.sum() / expected
        completeness[motif_length] = observed / expected
        # Pearson chi-squared over all 4**k motifs: observed motifs contribute
        # (O - E)^2 / E; each unobserved motif contributes (0 - E)^2 / E = E.
        chi_squared = (
            np.sum((observed_counts - expected_counts)**2 / expected_counts)
            + (expected - observed) * expected_counts
        )
        p_values[motif_length] = stats.chi2.sf(chi_squared, expected - 1)
    return completeness, p_values

# Load every sequence selected by the user-supplied clause. `table_name` and
# `clause` are interpolated verbatim into the SQL, so they must come from a
# trusted command line.
conn = sqlite3.connect(database_db)
try:
    cursor = conn.cursor()
    cursor.execute(f"SELECT Sequence FROM {table_name} {clause}")
    rows = cursor.fetchall()
finally:
    conn.close()
sequences = [row[0] for row in rows]
motif_counts = dict()

# Count motifs of every length in parallel; `batch_size` is the pool size.
with mp.Pool(processes=batch_size) as pool:
    # imap_unordered defaults to chunksize=1 (one IPC round-trip per sequence).
    # Batch the sequences instead, using the same heuristic as Pool.map.
    chunksize = max(1, len(sequences) // (batch_size * 4))
    for motif_length in tqdm(range(1, max_len + 1), desc="Count Motif"):
        partial_count = partial(count_motif, motif_length=motif_length)
        motif_counts[motif_length] = dict()
        for motifs in pool.imap_unordered(partial_count, sequences, chunksize=chunksize):
            motif_counts[motif_length] = add_dict(motif_counts[motif_length], motifs)

completeness, p_values = summarize_counts(motif_counts)

# One plot per summary; filenames and titles match the original layout.
for series, ylabel, heading, suffix in (
    (completeness, "Proportion Complete", "Completeness", "completeness"),
    (p_values, "P-Value Balanced", "Balance", "balance"),
):
    plt.figure(figsize=(10, 5), dpi=300)
    plt.plot(list(series.keys()), list(series.values()))
    plt.xlabel("Motif Length")
    plt.ylabel(ylabel)
    plt.title(f"{table_name}: {heading}")
    plt.savefig(f"{output_dir}/{table_name}_{suffix}.png")
    plt.close()  # pyplot otherwise keeps every figure alive


In [None]:
python completeness.py ../assembly.db assembly "WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0" 16 filtered 16

# Export Data

## Export datasets

In [None]:
# query.py

import sqlite3
import pandas as pd
import sys

database_db = sys.argv[1]
output_csv = sys.argv[2]
query = sys.argv[3]

# Run the query and capture both the rows and the result-set column names.
# The connection is closed even if the query fails.
conn = sqlite3.connect(database_db)
try:
    cursor = conn.cursor()
    cursor.execute(query)
    rows = cursor.fetchall()
    col_names = [description[0] for description in cursor.description]
finally:
    conn.close()

# Build the frame column by column. BLOB values (bytes) are decoded with
# bytes.decode() (UTF-8) so they are written to the CSV as text.
df = {col_name: [] for col_name in col_names}
for row in rows:
    for col_name, value in zip(col_names, row):
        if isinstance(value, bytes):
            value = value.decode()
        df[col_name].append(value)

df = pd.DataFrame(df)
df.to_csv(output_csv, index=False)

In [None]:
python query.py assembly.db sample.csv "SELECT * FROM assembly WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0 ORDER BY RANDOM() LIMIT 10"
python query.py assembly.db assembly.csv "SELECT * FROM assembly WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0 ORDER BY RANDOM()"

## Update datasets

In [None]:
# update.py

import sqlite3
import pandas as pd
import sys

database_db = sys.argv[1]
dataset_csv = sys.argv[2]
key_column = sys.argv[3]
query = sys.argv[4]

# Fetch the new columns; the query must also select `key_column` for the merge.
# The connection is closed even if the query fails.
conn = sqlite3.connect(database_db)
try:
    cursor = conn.cursor()
    cursor.execute(query)
    rows = cursor.fetchall()
    col_names = [description[0] for description in cursor.description]
finally:
    conn.close()

# Build the frame column by column; BLOBs (bytes) are decoded to UTF-8 text.
df = {col_name: [] for col_name in col_names}
for row in rows:
    for col_name, value in zip(col_names, row):
        if isinstance(value, bytes):
            value = value.decode()
        df[col_name].append(value)

df = pd.DataFrame(df)
data = pd.read_csv(dataset_csv)
# Outer join: rows present on only one side are kept, with NaN for the other
# side's columns. Columns present in both frames (other than the key) get
# pandas' default "_x"/"_y" suffixes rather than being replaced.
data = pd.merge(data, df, on=key_column, how="outer")
data.to_csv(dataset_csv, index=False)

In [None]:
python update.py assembly.db assembly.csv SeqID "SELECT SeqID, RibonanzaNetPredictions, RibonanzaNetMAE FROM assembly WHERE FullLength BETWEEN 64 AND 4096 AND StripLength <= 16 AND MeanDepth >= 16 AND MeanDensity >= 64 AND Gap = 0"

# Modules

## Utilities

In [None]:
# utils.py

import pandas as pd
import sqlite3
import orjson
import ast
import subprocess
import datetime

class Label:
    """Sample sheet mapping SRR run accessions to experiment metadata.

    The CSV is read with ``keep_default_na=False`` / ``na_values=[]`` so the
    literal string ``"NA"`` (the required ``Group`` of RNA-Seq rows) stays text
    instead of becoming NaN. After validation, each ``SRR`` cell is parsed with
    ``ast.literal_eval`` into a Python list of run accessions.
    """

    def __init__(self, label_csv):
        self.data = pd.read_csv(label_csv, keep_default_na=False, na_values=[])
        self.sanity_check()
        self.data["SRR"] = self.data["SRR"].apply(ast.literal_eval)

    def sanity_check(self):
        """Validate the layout, experiment and group of every row.

        Raises:
            ValueError: if ``Layout`` is not SINGLE/PAIRED; if an RNA-Seq row is
                not PAIRED or an icSHAPE row is not SINGLE; or if the group is
                not ``NA`` for RNA-Seq / not DMSO or NAIN3 for icSHAPE.
        """
        # Single quotes inside the f-string fields: reusing the enclosing double
        # quote is a SyntaxError before Python 3.12 (PEP 701).
        for _, row in self.data.iterrows():
            if row["Layout"] != "SINGLE" and row["Layout"] != "PAIRED":
                raise ValueError(f"invalid library layout type: {row['Layout']}")
            if (row["Experiment"] == "RNA-Seq" and row["Layout"] != "PAIRED") or (
                row["Experiment"] == "icSHAPE" and row["Layout"] != "SINGLE"
            ):
                raise ValueError(f"incorrect library layout for {row['GSM']}")
            if (row["Experiment"] == "RNA-Seq" and row["Group"] != "NA") or (
                row["Experiment"] == "icSHAPE"
                and (row["Group"] != "DMSO" and row["Group"] != "NAIN3")
            ):
                raise ValueError(f"incorrect experiment group for {row['GSM']}")

    def get_row(self, srr):
        """Return the single row whose SRR list contains `srr`.

        Raises:
            ValueError: if `srr` is found in no row or in more than one row.
        """
        rows = self.data[self.data["SRR"].apply(lambda x: srr in x)]
        if rows.shape[0] > 1:
            raise ValueError(f"{srr} found in multiple rows")
        elif rows.shape[0] == 0:
            raise ValueError(f"{srr} not found")
        return rows.iloc[0]

    def unique_values(self, property):
        """Return the distinct values of column `property` as a list."""
        return self.data[property].unique().tolist()

    def get_base_name(self, srr):
        """Return ``"Sample|Experiment|Group|srr"`` for the row containing `srr`."""
        row = self.get_row(srr)
        return f"{row['Sample']}|{row['Experiment']}|{row['Group']}|{srr}"

    def are_equal(self, srr, map):
        """Return True if, for every ``column -> allowed values`` entry of `map`,
        the row containing `srr` has a value in the allowed values."""
        row = self.get_row(srr)
        return all([row[prop] in allowed for prop, allowed in map.items()])

    def get_srr_list(self, map=None):
        """Return all SRR accessions, flattened across rows.

        Args:
            map: optional ``column -> allowed values`` filter; only rows whose
                value in every listed column is among the allowed values are
                kept. ``None`` or an empty mapping keeps every row.
        """
        # `None` default instead of a shared mutable `dict()` default.
        filtered_data = self.data
        if map:
            for prop, allowed in map.items():
                filtered_data = filtered_data[filtered_data[prop].isin(allowed)]
        srr_lists = filtered_data["SRR"].tolist()
        return [srr for item in srr_lists for srr in item]

class Database:
    """Minimal key/value store persisted in a single SQLite table.

    Keys are TEXT primary keys. Values are serialized with ``orjson`` on write
    and decoded on read. Call ``connect()`` before any other method, and
    ``close()`` to commit and release the connection.
    """

    def __init__(self, database_db, table_name):
        self.database_db = database_db
        self.table_name = table_name

    def connect(self):
        """Open the database file and create the key/value table if missing."""
        self.conn = sqlite3.connect(self.database_db)
        self.cursor = self.conn.cursor()
        schema = f"CREATE TABLE IF NOT EXISTS {self.table_name} (key TEXT PRIMARY KEY, value TEXT)"
        self.cursor.execute(schema)

    def write(self, key, value):
        """Insert a new row for `key`.

        A duplicate key violates the PRIMARY KEY constraint
        (``sqlite3.IntegrityError``); existing rows are not overwritten.
        """
        payload = orjson.dumps(value)
        insert_sql = f"INSERT INTO {self.table_name} VALUES (?, ?)"
        self.cursor.execute(insert_sql, (key, payload))

    def read(self, key):
        """Return the decoded value stored under `key`.

        A missing key makes ``fetchone()`` return None, so indexing it raises
        TypeError.
        """
        select_sql = f"SELECT value FROM {self.table_name} WHERE key = ?"
        self.cursor.execute(select_sql, (key,))
        stored = self.cursor.fetchone()[0]
        return orjson.loads(stored)

    def list(self):
        """Return every key currently in the table."""
        self.cursor.execute(f"SELECT key FROM {self.table_name}")
        return [key for (key,) in self.cursor.fetchall()]

    def close(self):
        """Commit pending writes and close the connection."""
        self.conn.commit()
        self.conn.close()

class Executer:
    """Run external commands, appending their output and timestamped notes to a log file.

    The log file is opened in append mode. Use the instance as a context
    manager (or call ``close()``) to release the file handle.
    """

    def __init__(self, log_file="logs.txt", executable_path="/bin/bash"):
        self.logs = open(log_file, "a")
        self.executable = executable_path

    def run(self, command, message):
        """Run `command` (an argv list, no shell) with stdout/stderr sent to the log.

        Logs the command, runs it, logs a note if it exited non-zero, then logs
        `message`.

        Returns:
            subprocess.CompletedProcess: lets callers inspect ``returncode``.
        """
        self.log(str(command))
        result = subprocess.run(command, stdout=self.logs, stderr=self.logs)
        self._log_failure(result)
        self.log(message)
        return result

    def shell(self, script, message):
        """Run `script` through ``self.executable`` (shell=True), logging like ``run``.

        Returns:
            subprocess.CompletedProcess: lets callers inspect ``returncode``.
        """
        self.log(script)
        result = subprocess.run(
            script, shell=True, executable=self.executable, stdout=self.logs, stderr=self.logs
        )
        self._log_failure(result)
        self.log(message)
        return result

    def log(self, message):
        """Append a timestamped line to the log and flush it immediately.

        Flushing keeps the log ordered with the subprocess output, which writes
        to the same file descriptor.
        """
        self.logs.write(f"{datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}: {message}\n")
        self.logs.flush()

    def close(self):
        """Close the log file; safe to call more than once."""
        if not self.logs.closed:
            self.logs.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.close()
        return False

    def _log_failure(self, result):
        # subprocess.run is called without check=True, so a failing command
        # raises nothing; record its exit status instead of staying silent.
        if result.returncode != 0:
            self.log(f"non-zero exit status {result.returncode}")


## SAM parser

In [None]:
# sam.py

import pysam
from collections import defaultdict

def _reference_symbols(read):
    """Expand `read`'s CIGAR into one symbol per reference base: 'D', '=' or 'X'.

    Only deletion (2), sequence-match (7) and mismatch (8) operations are
    expanded; every other operation (insertions, clips, and also M=0 / N=3)
    contributes nothing.
    NOTE(review): positions derived from this list are therefore only correct
    for alignments written with =/X operations and without N skips -- confirm
    against the aligner settings.
    """
    symbols = []
    for op, count in read.cigartuples:
        if op == 2:  # D
            symbols.extend(['D'] * count)
        elif op == 7:  # =
            symbols.extend(['='] * count)
        elif op == 8:  # X
            symbols.extend(['X'] * count)
    return symbols


def _new_position_counts():
    """Return a nested counter: reference name -> 1-based position -> {RD, ED, MC}."""
    return defaultdict(lambda: defaultdict(lambda: {"RD": 0, "ED": 0, "MC": 0}))


def _add_coverage(counts, rname, pos, symbols):
    """Add read depth (RD) at each covered base, and MC at each base that is not '='."""
    for offset, symbol in enumerate(symbols):
        cell = counts[rname][pos + offset]
        cell["RD"] += 1
        if symbol != '=':
            cell["MC"] += 1


def _to_columns(counts):
    """Flatten the nested counter into per-reference column lists sorted by position."""
    columns = {}
    for rname, positions in counts.items():
        entry = {"PS": [], "RD": [], "ED": [], "MC": []}
        for bp, cell in sorted(positions.items()):
            entry["PS"].append(bp)
            entry["RD"].append(cell["RD"])
            entry["ED"].append(cell["ED"])
            entry["MC"].append(cell["MC"])
        columns[rname] = entry
    return columns


def count_metrics(sam_file):
    """Per-base RD/ED/MC counts from properly paired primary alignments.

    Reads are kept only if flags 1 and 2 are both set (paired, proper pair)
    and none of 4, 8, 256, 512, 1024, 2048 is set (unmapped, mate unmapped,
    secondary, QC-fail, duplicate, supplementary). For each kept read:
      * ED (end depth) is incremented at the rightmost covered base for
        reverse-strand reads (flag 16), otherwise at the leftmost base.
      * RD is incremented at every covered base; MC at every base whose
        CIGAR symbol is a deletion or mismatch.

    Returns:
        dict: reference name -> {"PS", "RD", "ED", "MC"} lists, with PS the
        1-based positions in ascending order.
    """
    metrics_count = _new_position_counts()
    with pysam.AlignmentFile(sam_file, "rb") as sam_data:
        for read in sam_data.fetch():
            flag = read.flag
            # In Python `&` binds tighter than `==`: this is (flag & 3) == 3.
            if (not (flag & 3 == 3)) or (flag & (4 | 8 | 256 | 512 | 1024 | 2048)):
                continue
            rname = sam_data.get_reference_name(read.reference_id)
            pos = read.reference_start + 1  # pysam is 0-based; output is 1-based
            symbols = _reference_symbols(read)
            end = pos + len(symbols) - 1 if flag & 16 else pos
            metrics_count[rname][end]["ED"] += 1
            _add_coverage(metrics_count, rname, pos, symbols)
    return _to_columns(metrics_count)


def count_rtstops(sam_file):
    """Per-base RD/ED/MC counts from reads whose flag is exactly 0.

    Flag 0 keeps only mapped, forward-strand, unpaired primary alignments.
    ED is incremented at each read's leftmost aligned base (presumably the
    reverse-transcription stop site -- confirm against the library design).
    RD/MC are counted as in ``count_metrics``.

    Returns:
        dict: reference name -> {"PS", "RD", "ED", "MC"} lists, with PS the
        1-based positions in ascending order.
    """
    rtstops_count = _new_position_counts()
    with pysam.AlignmentFile(sam_file, "rb") as sam_data:
        for read in sam_data.fetch():
            if read.flag:
                continue
            rname = sam_data.get_reference_name(read.reference_id)
            pos = read.reference_start + 1
            symbols = _reference_symbols(read)
            rtstops_count[rname][pos]["ED"] += 1
            _add_coverage(rtstops_count, rname, pos, symbols)
    return _to_columns(rtstops_count)


# Extra Tools

## Left join CSV

In [None]:
# left-join.py

import pandas as pd
import sys

def left_join_csv(file1, file2, left_col, right_col, multiple_map_col):
    """Performs a left join on two CSV files with specified column mappings.

    Args:
        file1 (str): Path to the first CSV file.
        file2 (str): Path to the second CSV file.
        left_col (str): Column name from file1 to use in the join.
        right_col (str): Corresponding column name from file2 to use in the join.

    Returns:
        pandas.DataFrame: The resulting DataFrame after the join.
    """
    
    # Load data from CSV files into Pandas DataFrames
    df1 = pd.read_csv(file1)
    df2 = pd.read_csv(file2)

    def aggregate_columns(group):
        """Aggregates columns from the right table, using a list or single value."""
        result = {}
        for col in df2.columns:
            if col != right_col and col != multiple_map_col:  # Exclude the 'SRX' column used for joining
                values = group[col].unique()
                assert len(values) == 1
                result[col] = values[0]
            elif col == multiple_map_col:
                values = group[col].unique()
                result[col] = list(values)
        return pd.Series(result)

    # Perform the inner join 
    joined_df = pd.merge(df1, df2, left_on=left_col, right_on=right_col, how='left')

    grouped_df = joined_df.groupby(left_col).apply(aggregate_columns, include_groups=False).reset_index()

    final_df = pd.merge(df1, grouped_df, on=left_col, how="left")

    return final_df

if __name__ == "__main__":
    # Usage: left-join.py <left.csv> <right.csv> <left_col> <right_col> <multiple_map_col> <output.csv>
    argv = sys.argv
    joined = left_join_csv(argv[1], argv[2], argv[3], argv[4], argv[5])
    joined.to_csv(argv[6], index=False)


## Project CSV

In [None]:
# project.py

import pandas as pd
import sys
import ast

def project_dataframe(input_file, columns_to_keep):
    """Load a CSV and keep only the requested columns, in the requested order.

    Requested names that do not exist in the file are skipped silently rather
    than raising an error.

    Args:
        input_file: Path (or anything ``pandas.read_csv`` accepts) of the CSV.
        columns_to_keep: Column titles to keep.

    Returns:
        pandas.DataFrame: The selected columns only.
    """
    frame = pd.read_csv(input_file)
    available = frame.columns
    selected = [name for name in columns_to_keep if name in available]
    return frame[selected]

if __name__ == "__main__":
    # Usage: project.py <input.csv> <output.csv> "<Python list of column names>"
    argv = sys.argv
    input_file, output_file = argv[1], argv[2]
    columns_to_keep = ast.literal_eval(argv[3])
    project_dataframe(input_file, columns_to_keep).to_csv(output_file, index=False)


## Label CSV

In [None]:
# label.py

import pandas as pd
import sys

def label_experiment(source):
    """Derive sample/experiment labels for every run of an SRA run table.

    Reads ``SampleName``, ``Run``, ``LibraryLayout``, ``Instrument`` and
    ``Experiment Title`` from each row of `source`. The title is classified by
    substring matching:

    * Sample: ``"0h"`` or ``"4h"`` (first match wins), else ``OTHER``.
    * Experiment/Group: icSHAPE titles containing ``-N`` -> NAIN3, containing
      ``-D`` -> DMSO; otherwise titles containing ``WT`` -> RNA-Seq / ``NA``;
      anything else -> ``OTHER``.
    * Repetition: ``"1"``-``"3"`` from a ``-k``, ``-Nk`` or ``-Dk`` marker,
      else ``"NA"``.

    Rows with ``OTHER`` in any column are dropped; surviving rows keep their
    original index.

    NOTE(review): matching is by plain substring, so e.g. a "10h" title would
    also match "0h" -- confirm titles only use the expected tags.
    """
    def sample_of(title):
        for tag in ("0h", "4h"):
            if tag in title:
                return tag
        return "OTHER"

    def assay_of(title):
        if "icSHAPE" in title:
            if "-N" in title:
                return "icSHAPE", "NAIN3"
            if "-D" in title:
                return "icSHAPE", "DMSO"
        if "WT" in title:
            return "RNA-Seq", "NA"
        return "OTHER", "OTHER"

    def repetition_of(title):
        for rep in ("1", "2", "3"):
            if any(marker + rep in title for marker in ("-", "-N", "-D")):
                return rep
        return "NA"

    table = {
        key: []
        for key in ("GSM", "SRR", "Sample", "Experiment", "Group", "Repetition", "Layout", "Instrument")
    }
    for _, row in source.iterrows():
        table["GSM"].append(row["SampleName"])
        table["SRR"].append(row["Run"])
        table["Layout"].append(row["LibraryLayout"])
        table["Instrument"].append(row["Instrument"])
        title = row["Experiment Title"]
        sample = sample_of(title)
        experiment, group = assay_of(title)
        table["Sample"].append(sample)
        table["Experiment"].append(experiment)
        table["Group"].append(group)
        table["Repetition"].append(repetition_of(title))

    labelled = pd.DataFrame(table)
    has_other = labelled.isin(['OTHER']).any(axis=1)
    return labelled[~has_other]

if __name__ == "__main__":
    # Usage: label.py <SRA run table.csv> <output labels.csv>
    input_file, output_file = sys.argv[1], sys.argv[2]
    labels = label_experiment(pd.read_csv(input_file))
    labels.to_csv(output_file, index=False)


## Merge CSV

In [None]:
# merge.py

import pandas as pd
import sys

def merge_csv(file1, file2):
    """Stack the rows of two CSV files that share a header.

    Args:
        file1 (str): Path to the first CSV file; its rows come first.
        file2 (str): Path to the second CSV file.

    Returns:
        pandas.DataFrame: All rows of both files with a fresh 0..n-1 index.
    """
    frames = [pd.read_csv(source) for source in (file1, file2)]
    return pd.concat(frames, ignore_index=True)

if __name__ == "__main__":
    # Usage: merge.py <first.csv> <second.csv> <output.csv>
    first, second = sys.argv[1], sys.argv[2]
    merge_csv(first, second).to_csv(sys.argv[3], index=False)