In [2]:
import numpy as np
import pandas as pd
import seaborn as sns

import matplotlib.pyplot as plt
import pickle
import warnings

from sklearn.base import clone
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, HistGradientBoostingClassifier
from sklearn.inspection import permutation_importance
from sklearn.metrics import RocCurveDisplay, roc_curve, auc, precision_recall_curve, average_precision_score, classification_report
from sklearn.model_selection import train_test_split, RandomizedSearchCV, StratifiedKFold

from metaorf.modeling.etl import generate_orf_id
from metaorf.modeling.ensemble import plot_roc_pr, Dataset

from Bio.Seq import Seq
from pathlib import Path

pd.set_option('display.max_columns', 100)
pd.set_option('display.max_rows', 500)

warnings.filterwarnings('ignore')

In [3]:
# Absolute path of the shared data directory (one level up from the notebook).
data_dir = (Path('..') / 'data').absolute()

In [4]:
# Sample ids that passed stringent QC. Underscores become hyphens so the ids
# match the hyphen-joined sample names derived from the feature file names.
with open(data_dir / 'StringentQC_samples.txt', 'r') as infile:
    qc_samples = [
        line.strip().replace('_', '-')
        for line in infile
        # Skip blank lines: an empty id could otherwise match a file name
        # from which no sample fields are extracted.
        if line.strip()
    ]

In [7]:
# Keep the feature tables whose embedded sample name passed stringent QC.
# The sample name is made of the underscore-separated fields of the file name
# from index 3 up to (but excluding) the last two, re-joined with '-'.
dataset_names = []
for feature_file in data_dir.glob('*orf_features.csv'):
    file_name = feature_file.name
    sample_name = '-'.join(file_name.split('_')[3:-2])
    if sample_name not in qc_samples:
        continue
    dataset_names.append(file_name)

In [8]:
# Fail fast: with no QC-passing feature file, the scoring cells below would
# have nothing to score or concatenate.
assert dataset_names, 'no *orf_features.csv file matched a QC-passing sample'
len(dataset_names)

175

In [6]:
call_names = []

for call_file in data_dir.glob('*any_caller.csv'):
    # Derive the sample name from *this* call file (previously the stale
    # `feature_file` variable from the feature-file cell was used, so the
    # result depended on whichever feature file happened to be globbed last).
    name = '-'.join(call_file.name.split('_')[3:-2])
    # NOTE(review): assumes 'merged_*' call files pool ORFs across samples, so
    # the per-sample QC filter is not meaningful for them -- confirm.
    if call_file.name.startswith('merged_') or name in qc_samples:
        call_names.append(call_file.name)

In [7]:
# Fail fast: with no call file, passing_orf_ids would be empty and every ORF
# would be filtered out downstream.
assert call_names, 'no *any_caller.csv call file was selected'
call_names

['merged_orfs_found_by_any_caller.csv']

In [8]:
def return_passing_orf_ids(data_dir, call_files, filter_categories=None):
    """Collect ids of ORFs that no caller labelled with an excluded ORF type.

    Each call file is read as a tab-separated table. A row is kept only when
    none of its ``ORF_type_price``, ``ORF_type_ribotish`` and
    ``ORF_type_ribocode`` values is in ``filter_categories``. Missing values
    never count as a match.

    Parameters
    ----------
    data_dir : pathlib.Path
        Directory containing the call files.
    call_files : iterable of str
        File names, relative to ``data_dir``, to read.
    filter_categories : iterable of str, optional
        ORF type labels to exclude. Defaults to the annotated / known /
        CDS-frame-overlapping / truncated / extended categories below.

    Returns
    -------
    set of str
        Ids of the form ``{chrom_id}_{orf_start}_{orf_end}_{strand}_{exon_blocks}``,
        pooled over all call files.
    """
    if filter_categories is None:
        filter_categories = ["3'UTR:CDSFrameOverlap", "5'UTR:CDSFrameOverlap", 'Annotated', 'annotated',
                             'Internal:CDSFrameOverlap', 'Trunc', "5'UTR:Known", "3'UTR:Known", 'Internal:Known',
                             'Truncated', 'Truncated:Known', 'Extended', 'Extended:CDSFrameOverlap', 'Extended:Known']
    filter_categories = list(filter_categories)

    type_cols = ['ORF_type_price', 'ORF_type_ribotish', 'ORF_type_ribocode']
    id_cols = ['chrom_id', 'orf_start', 'orf_end', 'strand', 'exon_blocks']

    orf_ids = set()
    for call_file in call_files:
        df = pd.read_csv(data_dir.joinpath(call_file), sep='\t')
        # A row is dropped if *any* caller assigned it an excluded type.
        is_excluded = df[type_cols].isin(filter_categories).any(axis=1)
        id_parts = df.loc[~is_excluded, id_cols].astype(str)
        # Vectorised equivalent of the per-row f-string id (row-wise apply is
        # very slow on tables with hundreds of thousands of ORFs).
        ids = id_parts['chrom_id'].str.cat([id_parts[col] for col in id_cols[1:]], sep='_')
        orf_ids.update(ids)

    return orf_ids

In [9]:
# Recompute the novel-ORF id set from the caller tables, or reload the cached copy.
overwrite = True
orf_id_cache = data_dir / 'novel_orf_ids.txt'

if overwrite:
    passing_orf_ids = return_passing_orf_ids(data_dir, call_names)
    # Sorted so the cache file is identical across runs (set order is not stable).
    with open(orf_id_cache, 'w') as outfile:
        for orf_id in sorted(passing_orf_ids):
            outfile.write(f'{orf_id}\n')
else:
    # Read back as a set so both branches produce the same type.
    with open(orf_id_cache, 'r') as infile:
        passing_orf_ids = {line.rstrip('\n') for line in infile}

In [10]:
# Number of ORF ids that passed the caller-type filter (computed or reloaded in the previous cell).
len(passing_orf_ids)

489720

In [11]:
# Load the pickled Dataset bundle (model, feature matrix X, labels y) and refit its model.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted, locally built artefacts.
model_path = '../data/top_model_all_gb.pkl'
with open(model_path, 'rb') as model_file:
    ds = pickle.load(model_file)

# Fit outside the `with` so the file handle is not held open during training.
# The model is trained without the chrom_id column.
training_features = ds.X.drop(columns=['chrom_id'])
ds.model = ds.model.fit(training_features, ds.y)

In [12]:
def load_features(data_dir, datasets):
    """Load tab-separated ORF feature tables and index them by ORF id.

    Parameters
    ----------
    data_dir : pathlib.Path
        Directory containing the feature files.
    datasets : iterable of str
        Feature file names relative to ``data_dir``. Each file's ``dataset``
        label is its name split on ``'_'`` with the last two fields removed,
        re-joined with ``'-'`` (e.g. ``'a_b_s1_orf_features.csv'`` -> ``'a-b-s1'``).

    Returns
    -------
    pandas.DataFrame
        All rows of all files, with added ``dataset`` and ``orf_idx_str``
        (``{chrom_id}_{orf_start}_{orf_end}_{strand}_{exon_blocks}``) columns,
        indexed by ``orf_id`` = ``{orf_idx_str}_{dataset}``.

    Raises
    ------
    ValueError
        If ``datasets`` is empty.
    """
    id_cols = ['chrom_id', 'orf_start', 'orf_end', 'strand', 'exon_blocks']

    feature_dfs = []
    for dataset in datasets:
        tmp_df = pd.read_csv(data_dir.joinpath(f'{dataset}'), sep='\t')
        tmp_df['dataset'] = '-'.join(dataset.split('_')[:-2])
        feature_dfs.append(tmp_df)
    if not feature_dfs:
        raise ValueError('load_features() needs at least one dataset file name')

    # The per-file row index is discarded by set_index below anyway; renumbering
    # keeps the column assignments free of duplicate labels across files.
    feature_df = pd.concat(feature_dfs, ignore_index=True)

    # Vectorised equivalent of the per-row f-strings (row-wise apply is slow on large tables).
    id_parts = feature_df[id_cols].astype(str)
    feature_df['orf_idx_str'] = id_parts['chrom_id'].str.cat([id_parts[col] for col in id_cols[1:]], sep='_')
    feature_df['orf_id'] = feature_df['orf_idx_str'] + '_' + feature_df['dataset'].astype(str)

    return feature_df.set_index('orf_id')

In [1]:
# Preview the first few selected feature files (the full count is displayed earlier).
dataset_names[:5]

In [13]:
# Thresholds / columns used to turn model scores into the final novel-ORF call set.
PROBA_THRESHOLD = 0.95   # keep ORFs scored strictly above this
MIN_AA_LENGTH = 14       # keep peptides strictly longer than this (trailing stop removed)
START_CODONS = ['ATG', 'CTG', 'GTG']
DROP_COLS = ['orf_start', 'orf_end']
STR_COLS = ['orf_idx_str', 'chrom_id', 'orf_start', 'orf_end', 'strand', 'exon_blocks', 'orf_sequence']


def score_dataset(dataset):
    """Score one feature file and return its confident novel ORF calls.

    Rows are restricted to ids in ``passing_orf_ids`` and scored with ``ds.model``.
    They are then thresholded at ``PROBA_THRESHOLD`` and re-joined with their
    coordinate/sequence columns. Next, each ORF is translated. Finally, rows are
    filtered on peptide length and start codon.

    Returns ``None`` when no ORF in the file is in ``passing_orf_ids``: the model
    cannot score an empty matrix, which previously aborted the whole loop.
    """
    feature_all_df = load_features(data_dir, [dataset])

    # Positional boolean mask: unlike .loc[index_labels], it does not multiply
    # rows when an orf_id occurs more than once.
    passing_mask = feature_all_df['orf_idx_str'].isin(passing_orf_ids).to_numpy()
    if not passing_mask.any():
        return None

    # Model inputs: numeric columns only, with the genomic coordinates removed.
    feature_df = (
        feature_all_df
        .drop(columns=DROP_COLS)
        .select_dtypes(include='number')
        .loc[passing_mask]
    )
    proba = ds.model.predict_proba(feature_df)[:, 1]
    feature_df = feature_df.assign(prediction_proba=proba)

    is_confident = (feature_df['prediction_proba'] > PROBA_THRESHOLD).to_numpy()
    pred_df = feature_df.loc[is_confident].copy()

    # The index is '<orf_idx_str>_<dataset>' (built in load_features); split it back apart.
    # Vectorised string ops also work on an empty frame, where row-wise apply fails.
    # NOTE(review): 'chrom' is the first '_' field, so it is wrong if chrom_id itself contains '_'.
    id_parts = pred_df.index.to_series().astype(str).str.split('_')
    pred_df['chrom'] = id_parts.str[0].to_numpy()
    pred_df['dataset'] = id_parts.str[-1].to_numpy()
    pred_df['orf_idx'] = id_parts.str[:-1].str.join('_').to_numpy()

    # NOTE(review): this merge assumes orf_id is unique within the file.
    pred_all_df = pred_df.merge(feature_all_df[STR_COLS], left_index=True, right_on='orf_id', how='left')
    # [:-1] drops the final residue, presumably the '*' of the stop codon
    # (assumes every orf_sequence ends in a stop codon -- TODO confirm).
    pred_all_df['aa'] = pred_all_df['orf_sequence'].map(lambda seq: str(Seq(seq).translate())[:-1])
    pred_all_df['length'] = pred_all_df['aa'].map(len)

    has_start_codon = pred_all_df['orf_sequence'].astype(str).str[:3].isin(START_CODONS)
    keep = (pred_all_df['length'] > MIN_AA_LENGTH) & has_start_codon
    return pred_all_df.loc[keep.to_numpy()]


pred_df_list = []
for dataset in dataset_names:
    dataset_preds = score_dataset(dataset)
    if dataset_preds is not None:
        pred_df_list.append(dataset_preds)

In [14]:
# Stack the per-dataset confident predictions into one table (row order follows pred_df_list).
pred_df = pd.concat(pred_df_list, axis='index')

In [15]:
# Distinct ORF sequences vs. total (sequence, dataset) calls.
n_unique_sequences = len(set(pred_df['orf_sequence']))
n_predictions = len(pred_df)
for count in (n_unique_sequences, n_predictions):
    print(count)

3531
46427


In [16]:
# Write the confident novel-ORF calls into data_dir, alongside the other artefacts.
# The file name presumably encodes model (gb), probability cutoff (0.95) and run date (240414).
pred_df.to_csv(data_dir / 'top_orfs_gb-95_240414.csv')