In [1]:
import glob
import os
from concurrent.futures import ThreadPoolExecutor

import pysam
import pandas as pd
from sklearn.model_selection import train_test_split

from siren.utils.squiggletools import SquiggleFile, SquiggleFileLegacy

In [2]:
def classify_transcripts(bam_file_path):
    # Read BAM file
    bam_file = pysam.AlignmentFile(bam_file_path, "rb")
    
    # Extract necessary information from BAM file
    bam_arr = []
    for aln in bam_file.fetch():
        bam_arr.append({
            'query_name': aln.query_name,
            'reference_name': aln.reference_name,
            'sequence': len(aln.seq) 
        })
    df_bam = pd.DataFrame(bam_arr)

    # Extract max length and sequence length
    df_bam['max_len'] = df_bam['reference_name'].str.split('|').str[6].astype(int)
    df_bam['seq_len'] = pd.to_numeric(df_bam['sequence'], errors='coerce')

    # Retain only sequences longer than 100 bp
    df_bam = df_bam[((df_bam['sequence'] / df_bam['seq_len']) > 0.2)]

    # Extract transcript name
    df_bam['transcript_name'] = df_bam['reference_name'].str.split('|').str[4]

    # Classify transcripts as "mt" or "non-mt"
    df_bam['class'] = df_bam['transcript_name'].apply(lambda x: 'mt' if x.split('-')[0] == 'MT' else 'non-mt')

    # Count occurrences of each transcript
    transcript_counts = df_bam['reference_name'].value_counts()

    # Retain only transcripts with at least 20 copies
    valid_transcripts = transcript_counts[transcript_counts > 20].index
    df_bam = df_bam[df_bam['reference_name'].isin(valid_transcripts)]

    return df_bam

def extract_signal_dir(directory, query_names):
    signals = []
    found_query_names = []
    query_names_set = set(query_names)  # For faster lookup

    try:
        # Fetch both fast5 and pod5 files in a single search
        all_files = glob.glob(os.path.join(directory, "*.fast5")) + glob.glob(os.path.join(directory, "*.pod5"))

        def process_file(file):
            sq = None
            if file.endswith('.fast5'):
                sq = SquiggleFileLegacy(file)
            elif file.endswith('.pod5'):
                sq = SquiggleFile(file)

            signals_from_file = []
            found_queries = []
            
            if sq:
                reads = set(sq.list_reads())  # Convert reads to set for faster query
                relevant_queries = query_names_set & reads  # Get the intersection of reads and query_names
                for query_name in relevant_queries:
                    if file.endswith('.fast5'):
                        signal = sq.fetch_squiggle(f'read_{query_name}')['Raw']['Signal']
                    else:
                        signal = sq.fetch_squiggle(f'{query_name}')
                    if signal is not None:
                        signals_from_file.append(signal)
                        found_queries.append(query_name)

            return signals_from_file, found_queries

        # Parallelize file processing if the files are large or numerous
        with ThreadPoolExecutor() as executor:
            results = list(executor.map(process_file, all_files))
        
        # Collect the signals and query names from the parallel results
        for signal_list, query_list in results:
            signals.extend(signal_list)
            found_query_names.extend(query_list)

    except Exception as er:
        print(f"Error occurred: {er}")  # Improved error logging

    return signals, found_query_names

In [3]:
base = ""
bam_file = f"{base}data/rna004/gm12878.gc45.bam"
signal_dir = f"{base}data/rna004/pod5/"

df = classify_transcripts(bam_file)
query_names = df['query_name'].to_list()
signals, query_names = extract_signal_dir(signal_dir, query_names)
labels = [df[df['query_name'] == query_name]['class'].iloc[0] if not df[df['query_name'] == query_name].empty else None for query_name in query_names]
labels = [0 if x == 'mt' else 1 for x in labels]


In [4]:
X_train_val, X_test, y_train_val, y_test = train_test_split(
    signals, labels, test_size=0.15, random_state=42, stratify=labels
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=0.1765, random_state=42, stratify=y_train_val
)
print(len(X_train))
print(len(X_val))
print(len(X_test))

5173
1109
1109


In [5]:
from teri.main import Teri

tr = Teri()
tr.load_model(num_classes=2)

train_loader = tr.prepare_dataset(X_train, y_train, batch_size=32, shuffle=True)
val_loader   = tr.prepare_dataset(X_val, y_val, batch_size=32, shuffle=False)
test_loader  = tr.prepare_dataset(X_test, y_test, batch_size=32, shuffle=False)

tr.train_model(train_loader, val_loader, 400)

Epoch 1/400, Train Loss: 0.6880, Train Acc: 54.82, Val Loss: 0.6888, Val Acc: 55.73
Epoch 2/400, Train Loss: 0.6873, Train Acc: 55.64, Val Loss: 0.6880, Val Acc: 55.46
Epoch 3/400, Train Loss: 0.6863, Train Acc: 55.65, Val Loss: 0.6883, Val Acc: 55.46
Epoch 4/400, Train Loss: 0.6858, Train Acc: 56.31, Val Loss: 0.6874, Val Acc: 55.46
Epoch 5/400, Train Loss: 0.6858, Train Acc: 55.48, Val Loss: 0.6883, Val Acc: 55.46
Epoch 6/400, Train Loss: 0.6860, Train Acc: 55.71, Val Loss: 0.6879, Val Acc: 55.91
Epoch 7/400, Train Loss: 0.6862, Train Acc: 56.39, Val Loss: 0.6872, Val Acc: 55.46
Epoch 8/400, Train Loss: 0.6859, Train Acc: 55.62, Val Loss: 0.6870, Val Acc: 55.91
Epoch 9/400, Train Loss: 0.6854, Train Acc: 56.18, Val Loss: 0.6883, Val Acc: 55.46
Epoch 10/400, Train Loss: 0.6847, Train Acc: 55.92, Val Loss: 0.6865, Val Acc: 55.46
Epoch 11/400, Train Loss: 0.6843, Train Acc: 55.81, Val Loss: 0.6860, Val Acc: 55.46
Epoch 12/400, Train Loss: 0.6834, Train Acc: 55.62, Val Loss: 0.6849, Val 

In [6]:
from teri.dataset import RNA004_Dataset

train_dataset = RNA004_Dataset(X_train, y_train)
val_dataset = RNA004_Dataset(X_val, y_val)

tr.hypyerparameter_tuning(train_dataset, val_dataset)

Epoch 1/10, Train Loss: 0.2969, Train Acc: 91.24, Val Loss: 0.4131, Val Acc: 87.74
Epoch 2/10, Train Loss: 0.1271, Train Acc: 95.46, Val Loss: 0.4026, Val Acc: 87.02
Epoch 3/10, Train Loss: 0.1957, Train Acc: 93.04, Val Loss: 0.4107, Val Acc: 87.74
Epoch 4/10, Train Loss: 0.1235, Train Acc: 95.21, Val Loss: 0.4199, Val Acc: 86.29
Epoch 5/10, Train Loss: 0.1128, Train Acc: 95.98, Val Loss: 0.4233, Val Acc: 86.65
Epoch 6/10, Train Loss: 0.1800, Train Acc: 93.62, Val Loss: 0.4226, Val Acc: 87.20
Epoch 7/10, Train Loss: 0.1387, Train Acc: 94.49, Val Loss: 0.4571, Val Acc: 86.11
Epoch 8/10, Train Loss: 0.1046, Train Acc: 96.66, Val Loss: 0.4581, Val Acc: 85.30
Epoch 9/10, Train Loss: 0.1184, Train Acc: 95.32, Val Loss: 0.4466, Val Acc: 85.66
Epoch 10/10, Train Loss: 0.1445, Train Acc: 94.63, Val Loss: 0.3917, Val Acc: 88.01
Epoch 1/20, Train Loss: 0.2170, Train Acc: 93.20, Val Loss: 0.4030, Val Acc: 87.74
Epoch 2/20, Train Loss: 0.1182, Train Acc: 95.40, Val Loss: 0.3741, Val Acc: 87.29
Epo

In [8]:
x_loader  = tr.prepare_dataset(X_test, y_test, batch_size=512, shuffle=False)

val_running_loss, val_total, val_correct = tr.evaluate_model(x_loader)
print(val_correct/val_total)

0.8548241659152389
