In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

In [None]:
# Repository root: two directory levels above the notebook's working directory.
repo_dir = os.path.abspath(os.path.join(os.pardir, os.pardir))
repo_dir

In [None]:
import sys
# Put the repository root first on the import path so the local `pridict` package takes precedence.
sys.path.insert(0,repo_dir)
import pridict

In [None]:
# Dataset directory inside the repository root.
data_pth = os.path.join(repo_dir, "dataset")
data_pth

In [None]:
df = pd.read_csv(os.path.join(data_pth, 'proc_v2', 'data_23k_v1.csv'))  # load <data_pth>/proc_v2/data_23k_v1.csv

In [None]:
list(df.columns)  # show all column names of the loaded table

In [None]:
from pridict.pridictv2.dataset import MinMaxNormalizer
from pridict.pridictv2.data_preprocess import *
from pridict.pridictv2.utilities import *

In [None]:
def get_outcome_colnames(prefix, suffix=None):
    """Build the three outcome column names for one outcome prefix.

    Parameters
    ----------
    prefix : str
        Prepended verbatim to every name, e.g. 'HEK' or 'K562'.
    suffix : str or None, optional
        Appended as ``_<suffix>`` when truthy; ``None`` and ``''`` add nothing.

    Returns
    -------
    list of str
        ``<prefix>averageedited<tail>``, ``<prefix>averageunedited<tail>`` and
        ``<prefix>averageunintended<tail>``, in that order.
    """
    tail = f'_{suffix}' if suffix else ''
    measures = ('averageedited', 'averageunedited', 'averageunintended')
    return [f'{prefix}{measure}{tail}' for measure in measures]

### Process the initial and mutated target sequences of `df`

In [None]:
# Align every initial/mutated target pair. Besides the aligned frame `tdf`, the processor returns,
# for the initial and for the mutated sequences, a processed frame plus (presumably) a column count.
# NOTE(review): the meaning of align_symbol=2 is defined in pridict and not visible here.
pe_seq_processor = PESeqProcessor()
(tdf,
 proc_seq_init_df, num_init_cols,
 proc_seq_mut_df, num_mut_cols) = pe_seq_processor.process_init_mut_seqs(
    df, 'wide_initial_target', 'wide_mutated_target', align_symbol=2)

In [None]:
# Sanity-check the alignment stored in `tdf`, using its 'Correction_Length_effective' column.
# NOTE(review): whether this raises or only reports on failure is defined in pridict — confirm.
check_editing_alignment_correctness(tdf, correction_len_colname='Correction_Length_effective')

### Merge the processed data frame back into the original one

In [None]:
# Add the aligned sequences and the effective correction length from `tdf` to `df`.
# Dropping these columns first (when present) makes the cell idempotent: without it a re-run
# would merge them a second time and produce *_x / *_y duplicate columns.
align_colnames = ['wide_initial_target_align', 'wide_mutated_target_align', 'Correction_Length_effective']
df = pd.merge(left=df.drop(columns=align_colnames, errors='ignore'),
              right=tdf[['seq_id'] + align_colnames],
              how='inner',
              on='seq_id',
              # seq_id must identify exactly one row on each side; duplicates would multiply rows
              # and break the index-based assignment of the PBS lengths below.
              validate='one_to_one')

# NOTE(review): these assignments align on the index; this assumes the rows of
# proc_seq_init_df / proc_seq_mut_df line up with the merged df's RangeIndex — confirm.
df['PBSinitlength'] = proc_seq_init_df['end_PBS'] - proc_seq_init_df['start_PBS']
df['PBSmutlength'] = proc_seq_mut_df['end_PBS'] - proc_seq_mut_df['start_PBS']
# Only the initial-sequence length is kept as PBSlength, so both lengths must agree; fail loudly
# instead of printing False and continuing with a possibly wrong feature.
assert (df['PBSinitlength'] == df['PBSmutlength']).all(), 'PBS length differs between initial and mutated sequences'
df['PBSlength'] = df['PBSinitlength']


### Visualize sequences

In [None]:
from IPython.core.display import HTML

#### Using the precomputed aligned data frame (`tdf`)

In [None]:
# Show one random example with Correction_Length > 7 for each correction type, rendered from the
# precomputed alignment in `tdf`. A seeded generator keeps the chosen examples reproducible
# under Restart & Run All.
viz_rng = np.random.default_rng(42)
for correction_type in ['Replacement', 'Insertion', 'Deletion']:
    cond = (df['Correction_Type'] == correction_type) & (df['Correction_Length'] > 7)
    candidate_ids = df.loc[cond, 'seq_id'].to_numpy()
    if candidate_ids.size == 0:
        # choice() raises on an empty population; report and move on instead
        print(f'no {correction_type} example with Correction_Length > 7; skipping')
        continue
    # `seq_id` stays bound after the loop on purpose: the next cell re-renders the last example.
    seq_id = viz_rng.choice(candidate_ids)
    display(HTML(Viz_PESeqs().viz_align_initmut_seq_precomputed(tdf, seq_id, wsize=20, return_type='html')))

#### Using the original data frame (`df`)

In [None]:
# Render the same example again, this time from the original `df` instead of the precomputed `tdf`.
# NOTE(review): relies on `seq_id` left over from the previous cell's loop (its last iteration).
display(HTML(Viz_PESeqs().viz_align_initmut_seq(df, seq_id, wsize=20, return_type='html')))

### Normalize the continuous features

In [None]:
# Min-max normalize the continuous features; the cells below inspect the resulting '*_norm' columns of df.
include_MFE, include_addendumfeat = False, False
minmax_normalizer = MinMaxNormalizer(include_MFE=include_MFE, include_addendumfeat=include_addendumfeat)
norm_colnames = minmax_normalizer.normalize_cont_cols_max(df, suffix='_norm')

In [None]:
df.hist(column=['PBSlength', 'PBSlength_norm'])  # raw vs normalized PBS length

In [None]:
df.hist(column=['RToverhanglength', 'RToverhanglength_norm'])  # raw vs normalized RT overhang length

In [None]:
df.hist(column=['Correction_Length', 'Correction_Length_norm'])  # raw vs normalized correction length

In [None]:
df.hist(column=['Correction_Length_effective', 'Correction_Length_effective_norm'])  # raw vs normalized effective correction length

### Create data partitions and data tensors, and dump them to disk

In [None]:
from pridict.pridictv2.dataset import *

In [None]:
def run_clean_check_tests(df, dpartitions, outcome_name, suffix=''):
    """Print sanity checks for the (cleaned) data partitions of one outcome.

    Two checks are printed for every run in ``dpartitions``:

    1. For each train/validation/test split, whether any outcome column
       (from ``get_outcome_colnames``) still has NaN at the split's rows;
       every printed value should be ``False``.
    2. How many ``grp_id`` values of one split also occur in another split
       (test vs train, test vs validation, validation vs train); every printed
       count should be ``0``.

    Parameters
    ----------
    df : pandas.DataFrame
        Holds the outcome columns and a ``grp_id`` column.
    dpartitions : indexable by run number ``0 .. len(dpartitions) - 1``
        Each entry maps 'train', 'validation' and 'test' to row labels of ``df``.
    outcome_name : str
        Prefix of the outcome columns, e.g. 'HEK'.
    suffix : str, optional
        Suffix of the outcome columns; a falsy value means no suffix.
    """
    dsettypes = ['train', 'validation', 'test']
    num_runs = len(dpartitions)
    # get_outcome_colnames already treats a falsy suffix as "no suffix"; compute the names once.
    ocols = get_outcome_colnames(outcome_name, suffix)

    print('run test to check NaN rows are removed for outcome:', outcome_name)
    print('>> True would mean that rows are still there!! <<')
    for run in range(num_runs):
        print('run:', run)
        for dsettype in dsettypes:
            indices = dpartitions[run][dsettype]
            print(df.loc[indices, ocols].isna().any())
        print()

    print('run test for confirming there is no overlap between train, validation and test sets')
    print('>> 0 means no overlap <<')
    # Check every run rather than a hard-coded 5, so other fold counts are fully covered
    # (and fewer folds do not raise KeyError).
    for run in range(num_runs):
        print('run:', run)
        grp_ids = {dsettype: df.loc[dpartitions[run][dsettype], 'grp_id'] for dsettype in dsettypes}
        print(grp_ids['test'].isin(grp_ids['train']).sum())
        print(grp_ids['test'].isin(grp_ids['validation']).sum())
        print(grp_ids['validation'].isin(grp_ids['train']).sum())
        
def clean_dpartitions(dpartitions, nan_indices):
    """Return a copy of ``dpartitions`` with ``nan_indices`` removed from every split.

    Parameters
    ----------
    dpartitions : indexable by run number ``0 .. len(dpartitions) - 1``
        Each entry maps 'train', 'validation' and 'test' to a sequence of indices.
    nan_indices : iterable of hashable
        Indices to drop from every split of every run.

    Returns
    -------
    dict
        ``{run: {dsettype: numpy.ndarray}}``. Each split keeps its original
        (first-occurrence) order. Duplicate indices are collapsed, as in the
        earlier set-based version. The input dtype is kept, so an emptied split
        is still an index array of the same dtype, not a float array.

    Prints the split sizes before and after cleaning.
    """
    # build the lookup once instead of once per split
    nan_lookup = set(nan_indices)
    dpartitions_upd = {}
    for run in range(len(dpartitions)):
        print('run_id:', run)
        dpartitions_upd[run] = {}
        for dsettype in ['train', 'validation', 'test']:
            indices = np.asarray(dpartitions[run][dsettype])
            print(f'# of {dsettype} indices:', len(indices))
            # dict.fromkeys de-duplicates while preserving order; a plain set difference would
            # scramble the order and make downstream iteration non-reproducible.
            clean_indices = [idx for idx in dict.fromkeys(indices.tolist()) if idx not in nan_lookup]
            print(f'# of {dsettype} indices after:', len(clean_indices))
            dpartitions_upd[run][dsettype] = np.array(clean_indices, dtype=indices.dtype)
        print()
    return dpartitions_upd

def plot_y_distrib_acrossfolds(dpartitions, y, opt='separate_folds'):
    """Overlay histograms of the targets ``y`` for the splits of every run.

    Parameters
    ----------
    dpartitions : indexable by run number ``0 .. len(dpartitions) - 1``
        Each entry maps 'train', 'validation' and 'test' to indices into ``y``.
    y : array-like
        Target values; must support fancy indexing with the split indices.
    opt : {'separate_folds', 'separate_dsettypes'}
        'separate_folds': one panel per run, overlaying train/validation/test.
        'separate_dsettypes': one panel per split type, overlaying all runs.

    Returns
    -------
    matplotlib.figure.Figure
        The figure holding the panels.

    Raises
    ------
    ValueError
        If ``opt`` is not one of the supported values.
    """
    dsettypes = ['train', 'validation', 'test']
    num_runs = len(dpartitions)
    if opt == 'separate_dsettypes':
        nrows, alpha = len(dsettypes), 0.3
    elif opt == 'separate_folds':
        # one panel per run instead of a hard-coded 5 (at least one panel so subplots() is valid)
        nrows, alpha = max(num_runs, 1), 0.4
    else:
        raise ValueError(f"opt must be 'separate_folds' or 'separate_dsettypes', got {opt!r}")

    # squeeze=False keeps `axs` a 2-D array even for a single row, so ravel() is always valid
    fig, axs = plt.subplots(figsize=(9, 11), nrows=nrows, squeeze=False, constrained_layout=True)
    axs = axs.ravel()
    for run_num in range(num_runs):
        for dset_pos, dsettype in enumerate(dsettypes):
            curr_ax = axs[dset_pos] if opt == 'separate_dsettypes' else axs[run_num]
            ids = dpartitions[run_num][dsettype]
            curr_ax.hist(y[ids], alpha=alpha, label=f"{dsettype}_run{run_num}")
    # build each legend once, after all histograms are drawn
    for curr_ax in axs:
        if curr_ax.has_data():
            curr_ax.legend()
    return fig


### Create the cleaned data partitions and data tensors

In [None]:
# Build the grouped 5-fold partitions once. Then, per outcome: build the data tensor, drop that
# outcome's NaN rows from every split, run the sanity checks, plot the target distributions and
# dump partitions + tensor to disk.
tfolder = 'proc_v2'
tdir = create_directory(os.path.join(repo_dir, 'dataset', tfolder))
fsuffix = 'withMFE' if include_MFE else 'withoutMFE'
dump_dir = create_directory(os.path.join(repo_dir, 'dataset', tfolder, f'align_{fsuffix}'))
hek_indices_nan = ReaderWriter.read_data(os.path.join(tdir, 'hek_indices_nan.pkl'))
k562_indices_nan = ReaderWriter.read_data(os.path.join(tdir, 'k562_indices_nan.pkl'))
# An explicit outcome -> NaN-index mapping ties each outcome to its own indices; an outcome
# missing here can never silently reuse another outcome's indices.
nan_indices_by_outcome = {'HEK': hek_indices_nan, 'K562': k562_indices_nan}

wsize = 20
outcome_suffix = 'clamped'

# get grouped 5-fold data partitions
dpartitions = get_stratified_partitions(df['grp_id'].values, num_folds=5, valid_set_portion=0.1, random_state=42)
validate_partitions(dpartitions, range(df['grp_id'].shape[0]), valid_set_portion=0.1, test_set_portion=0.2)
print()
for outcome_name, nan_indices in nan_indices_by_outcome.items():
    dtensor = create_datatensor(df, proc_seq_init_df, num_init_cols,
                                proc_seq_mut_df, num_mut_cols, norm_colnames,
                                window=wsize, y_ref=get_outcome_colnames(outcome_name, outcome_suffix))
    print()
    dpartitions_upd = clean_dpartitions(dpartitions, nan_indices)
    run_clean_check_tests(df, dpartitions_upd, outcome_name, suffix=outcome_suffix)
    print()
    plot_y_distrib_acrossfolds(dpartitions_upd, dtensor.y_score.numpy(), opt='separate_folds')

    # dump on disk
    fname = f'dpartitions_{outcome_name}_{outcome_suffix}_wsize{wsize}.pkl'
    ReaderWriter.dump_data(dpartitions_upd, os.path.join(dump_dir, fname))
    fname = f'dtensor_{outcome_name}_{outcome_suffix}_wsize{wsize}.pkl'
    ReaderWriter.dump_data(dtensor, os.path.join(dump_dir, fname))
