# Run ABE8e DESeq2 Model

### Prepare Data Input

You will only need the encoding output from the previous step. Alternatively, you can download the pre-computed encodings from [Zenodo](https://doi.org/10.5281/zenodo.13737880) at the following paths:
- CRISPR-CLEAR-data/data/encoding_output/20240807_v0_1_84_encoding_dataframes_denoised_removed_ABE8e_encodings_rep0.tsv
- CRISPR-CLEAR-data/data/encoding_output/20240807_v0_1_84_encoding_dataframes_denoised_removed_ABE8e_encodings_rep1.tsv
- CRISPR-CLEAR-data/data/encoding_output/20240807_v0_1_84_encoding_dataframes_denoised_removed_ABE8e_encodings_rep2.tsv

### Import Packages

In [None]:
from crispr_millipede import encoding as cme
from crispr_millipede import modelling as cmm

import plotly
import plotly.graph_objects as go
import plotly.express as px

import pandas as pd
import numpy as np

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
from pydeseq2.utils import load_example_data

### Prepare Millipede specification parameters

In [None]:
# Design-matrix processing options shared by every model specification below.
design_matrix_spec = cmm.MillipedeDesignMatrixProcessingSpecification(
    wt_normalization=False,
    total_normalization=True,
    sigma_scale_normalized=True,
    decay_sigma_scale=True,
    K_enriched=5,
    K_baseline=5,
    a_parameter=0.0005,
    set_offset_as_presort=True,
    offset_normalized=False,
    # A pseudocount of 1 keeps the log-transformed offset finite: zero counts
    # can be acceptable depending on thresholding, and would otherwise
    # produce an infinity error.
    offset_psuedocount=1,
)

# Read-count cutoffs applied to the baseline, enriched, and presort populations.
abe8e_cutoff_specification = cmm.MillipedeCutoffSpecification(
    per_replicate_each_condition_num_cutoff=0,
    per_replicate_all_condition_num_cutoff=1,
    all_replicate_num_cutoff=0,
    all_experiment_num_cutoff=0,
    baseline_pop_all_condition_each_replicate_num_cutoff=3,
    baseline_pop_all_condition_acceptable_rep_count=2,
    enriched_pop_all_condition_each_replicate_num_cutoff=3,
    enriched_pop_all_condition_acceptable_rep_count=2,
    presort_pop_all_condition_each_replicate_num_cutoff=3,
    presort_pop_all_condition_acceptable_rep_count=2,
)

# Single model specification: replicates are merged as covariates and
# experiments are kept separate.
joint_replicate_specification = cmm.MillipedeModelSpecification(
    model_types=[
        cmm.MillipedeModelType.NORMAL_SIGMA_SCALED,
        cmm.MillipedeModelType.NORMAL,
    ],
    replicate_merge_strategy=cmm.MillipedeReplicateMergeStrategy.COVARIATE,
    experiment_merge_strategy=cmm.MillipedeExperimentMergeStrategy.SEPARATE,
    S=5,
    tau=0.01,
    tau_intercept=0.0001,
    cutoff_specification=abe8e_cutoff_specification,
    design_matrix_processing_specification=design_matrix_spec,
)

millipede_model_specification_set = {
    "joint_replicate_per_experiment_models": joint_replicate_specification,
}

In [None]:
# All three populations are read from the same per-replicate encoding file;
# only the read-count column differs. "{}" is the replicate placeholder.
abe8e_encoding_fn_template = "20240807_v0_1_84_encoding_dataframes_denoised_removed_ABE8e_encodings_rep{}.tsv"

paired_end_experiments_inputdata_denoised = cmm.MillipedeInputDataExperimentalGroup(
    data_directory="./",
    experiment_labels=["ABE8e"],
    reps=[0, 1, 2],
    # Enriched population: CD19-negative reads.
    enriched_pop_fn_experiment_list=[abe8e_encoding_fn_template],
    enriched_pop_df_reads_colname="#Reads_CD19minus",
    # Baseline population: CD19-positive reads.
    baseline_pop_fn_experiment_list=[abe8e_encoding_fn_template],
    baseline_pop_df_reads_colname="#Reads_CD19plus",
    # Presort population.
    presort_pop_fn_experiment_list=[abe8e_encoding_fn_template],
    presort_pop_df_reads_colname='#Reads_presort',
    millipede_model_specification_set=millipede_model_specification_set,
)

### Prepare DESeq2 count matrix

In [None]:
# Pull the processed table for the joint-replicate model out of the input-data
# object. NOTE(review): the meaning of the `[1]` and `.data[0]` positions is
# defined by crispr_millipede and is not visible here -- confirm.
joint_model_with_data = paired_end_experiments_inputdata_denoised.millipede_model_specification_set_with_data[
    'joint_replicate_per_experiment_models'
]
data = joint_model_with_data[1].data[0]

# Split the rows by replicate using the per-replicate intercept indicator columns.
alleles = [
    data[data[f'intercept_exp0_rep{rep_idx}'] == 1]
    for rep_idx in range(3)
]

In [None]:
# Build one count table per replicate. Each table is indexed by the allele's
# edit label -- the comma-joined names of its edit columns, or "wt" when the
# allele carries no edit -- and holds the raw CD19- and CD19+ read counts in
# columns named after the 1-based replicate number.
allele_tables = []

for rep, rep_alleles in enumerate(alleles):
    # Edit indicator columns are the ones whose name contains ">".
    edit_cols = [col for col in rep_alleles.columns if ">" in col]
    edit_names = np.array(edit_cols, dtype=object)

    # Boolean mask (alleles x edit columns): True where the allele has the edit.
    has_edit = rep_alleles[edit_cols].eq(1).to_numpy()

    # An empty join yields "", which falls back to the wild-type label.
    edits = [",".join(edit_names[row_mask]) or "wt" for row_mask in has_edit]

    allele_tables.append(
        pd.DataFrame(
            {
                f"CD19minus_rep{rep + 1}": rep_alleles["#Reads_CD19minus_raw"].to_numpy(),
                f"CD19plus_rep{rep + 1}": rep_alleles["#Reads_CD19plus_raw"].to_numpy(),
            },
            index=pd.Index(edits, name="edits"),
        )
    )

In [None]:
# Outer-join the three replicate tables on the edit label; an allele that is
# absent from a replicate gets a count of 0 there.
merged_df = (
    allele_tables[0]
    .merge(allele_tables[1], how='outer', on='edits')
    .fillna(0)
    .merge(allele_tables[2], how='outer', on='edits')
    .fillna(0)
)

In [None]:
# Transpose so that rows are samples and columns are alleles, then label the
# row index as "sample".
count_df = merged_df.transpose()
count_df.index.name = 'sample'

In [None]:
# Display the count matrix (rows: samples, columns: allele edit labels).
count_df

### Prepare DESeq2 design matrix

In [None]:
# Per-sample metadata. The condition/replicate labels are positional: they
# assume count_df rows alternate CD19-, CD19+ within replicate 1, 2, 3.
samples = count_df.index.tolist()
condition = [label for _ in range(3) for label in ("CD19-", "CD19+")]
replicate = [rep_number for rep_number in (1, 2, 3) for _ in range(2)]

metadata_dict = dict(sample=samples, condition=condition, replicate=replicate)
metadata_df = pd.DataFrame(metadata_dict).set_index("sample")

In [None]:
# Display the sample metadata (condition and replicate per sample).
metadata_df

### Run DESeq2

In [None]:
# Inference backend using 8 CPUs; the same object is reused for DeseqStats.
inference = DefaultInference(n_cpus=8)

# DESeq2 dataset: allele counts per sample, modelled by "condition" only.
# The CPU count is carried by `inference`, so it is not repeated here.
dds = DeseqDataSet(
    counts=count_df,
    metadata=metadata_df,
    design_factors="condition",
    inference=inference,
    refit_cooks=True,
)

In [None]:
# Run the PyDESeq2 fitting pipeline on the dataset; the fitted `dds` is what
# DeseqStats consumes in the next cell.
dds.deseq2()

In [None]:
# Statistical tests on the fitted dataset. `summary()` must be called:
# PyDESeq2 only creates `stat_res.results_df` (read by the visualization
# cells) when the summary is computed.
stat_res = DeseqStats(dds, inference=inference)
stat_res.summary()

### Visualize 

In [None]:
# Keep only the fold change and p-value, dropping alleles where either is missing.
results_df = stat_res.results_df.loc[:, ['log2FoldChange', 'pvalue']].dropna()

# Significance score used as the volcano-plot y-axis: -10 * log10(pvalue).
negative_log = -10 * np.log10(results_df["pvalue"].to_numpy())
results_df["-10 * log(pvalue)"] = negative_log

In [None]:
# Display the per-allele log2 fold change, p-value, and -10*log10(pvalue).
results_df

In [None]:
# Colour each allele: red when |log2 fold change| > 5 and the significance
# score exceeds 10, blue otherwise.
lfc = list(results_df['log2FoldChange'])
negative_log_pvalue = list(results_df['-10 * log(pvalue)'])

colors = [
    "red" if abs(fold_change) > 5 and score > 10 else "blue"
    for fold_change, score in zip(lfc, negative_log_pvalue)
]

In [None]:
# Volcano plot of the allele-level results with dashed guide lines at the
# colouring thresholds (x = -5 and 5, y = 10), saved as an SVG.
x_range = [-10, 10]
y_range = [0, 40]

# Styling shared by both axes and by all three threshold lines.
axis_style = dict(
    title_font=dict(size=18),
    showgrid=False,
    zeroline=False,
    ticks='outside',
    ticklen=10,
    tickfont=dict(size=15),
)
threshold_line = dict(color='black', dash='dash', width=1)

fig = go.Figure()

# The fold change is negated to match the "CD19+ vs CD19-" axis title.
# NOTE(review): presumably results_df reports the opposite contrast -- confirm.
scatter = go.Scatter(
    x=-results_df['log2FoldChange'],
    y=results_df['-10 * log(pvalue)'],
    mode='markers',
    marker=dict(color=colors, size=6, opacity=0.5),
    hovertext=results_df.index.tolist(),
)

fig.update_layout(
    showlegend=False,
    xaxis=dict(title="LFC [CD19+ vs CD19-]", range=x_range, **axis_style),
    yaxis=dict(title="-10 * log10(pvalue)", range=y_range, **axis_style),
    plot_bgcolor='white',
)
fig.add_trace(scatter)

# Horizontal significance threshold spanning the full x-range.
ythresh = go.Scatter(x=[x_range[0], x_range[1]], y=[10, 10], mode='lines', line=threshold_line)

# Vertical fold-change thresholds spanning the full y-range.
xthresh1 = go.Scatter(x=[-5, -5], y=[y_range[0], y_range[1]], mode='lines', line=threshold_line)
xthresh2 = go.Scatter(x=[5, 5], y=[y_range[0], y_range[1]], mode='lines', line=threshold_line)

fig.update_xaxes(showline=True, linewidth=1, linecolor='black')
fig.update_yaxes(showline=True, linewidth=1, linecolor='black')

fig.add_traces([ythresh, xthresh1, xthresh2])

fig.show()

fig.write_image("ABE8e_allelic_analysis.svg")