# QC Data, filter samples, and run differential expression testing
* User input needed:
    * Define specific samples or groups to drop.
    * Define contrasts — the notebook will not complete if `contrasts` is not defined or if a contrast references a column or level that does not exist in the metadata CSV.

In [None]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
from IPython.display import HTML
import html
from tqdm import tqdm
import copy
from typing import List
from statsmodels.stats.multitest import multipletests
from scipy.stats import chi2
from joblib import Parallel, delayed, parallel_backend
import warnings

from pydeseq2.ds import DeseqStats
from pydeseq2.dds import DeseqDataSet
from pydeseq2.utils import build_design_matrix, nb_nll

import plotly.graph_objects as go
from plotly.offline import init_notebook_mode, iplot

# Notebook-wide setup: initialise plotly's offline notebook mode so plotly figures can be
# rendered inline (connected=True -- presumably loads plotly.js from a CDN; confirm).
init_notebook_mode(connected=True)
# NOTE(review): this silences *every* warning for the whole kernel, including pandas
# FutureWarning/DeprecationWarning messages that flag API changes; consider a narrower filter.
warnings.filterwarnings('ignore')

In [None]:
NUM_CPUS = 8

# The run directory is named after the basename of the notebook's working directory.
DATA_PATH = f'/data/expression_atlas/runs/{os.path.basename(os.getcwd())}/'
_RUN_NAME = DATA_PATH.rstrip('/').split('/')[-1]

MULTIQC_PATH = os.path.join(DATA_PATH, 'rnaseq_output/multiqc/star_salmon/multiqc_report.html')

COUNT_PATH = os.path.join(DATA_PATH, 'rnaseq_output/star_salmon')

# File-name prefix (not a directory) shared by every result written by this notebook.
RESULTS_PATH = f'{DATA_PATH}de_results/{_RUN_NAME}'

METADATA_FH = f'{DATA_PATH}{_RUN_NAME}_metadata.csv'

# Sample-level variables intended for colouring PCA plots.
PCA_VARIABLES = ['lib_sizes', 'library_layout', 'instrument_model']

DDS_TRANSCRIPT_FH, DDS_GENE_FH = (
    RESULTS_PATH + suffix for suffix in ('_dds_transcript.h5_ad', '_dds_gene.h5_ad')
)

# Value substituted for -log10(padj) when padj is exactly 0 (i.e. the result is +inf).
MAX_NLOG10_PADJ = 400.
# Minimum count thresholds applied when filtering transcripts / genes.
TRANSCRIPT_SUM_FILTER = 1
GENE_SUM_FILTER = 1

In [None]:
# Utility functions for working with pydeseq2 GLMs.

def relevel_design(dds: DeseqDataSet, ref_level: List[str]) -> None:
    """Relevel a pydeseq2 DeseqDataSet so that ``ref_level[1]`` is the reference of ``ref_level[0]``.

    The coefficients already stored in ``dds.varm['LFC']`` are re-expressed relative to the new
    reference level and the design matrix is rebuilt to match. If ``dds.varm`` has no ``'LFC'``
    entry yet, ``dds.deseq2()`` is run first. ``dds`` is modified in place.

    Args:
        dds (DeseqDataSet) pydeseq2 object to modify
        ref_level (List[str]) list of two elements ['condition', 'new_reference_level']

    Raises:
        ValueError: if the factor is not an obs column or the level does not occur in it.
        AttributeError: if ``dds.ref_level`` is not set for the original design.
    """
    factor, new_ref = ref_level[0], ref_level[1]

    if factor not in dds.obs.columns:
        raise ValueError('%s condition not in design.' % factor)
    if new_ref not in dds.obs[factor].values:
        raise ValueError('%s condition level not in %s.' % (new_ref, factor))
    if not dds.ref_level:
        raise AttributeError('Define reference level "ref_level" for original design.')

    # Coefficient columns of this factor are named '<factor>_<level>_vs_<reference>'. Match on
    # '<factor>_' so that e.g. 'condition-1' does not also capture 'condition-10' columns.
    prefix = factor + '_'

    # Require an exact '_vs_<level>' suffix: a bare endswith(level) would also fire for a level
    # that merely ends with the same characters (e.g. 'CONTROL' vs. 'NO-CONTROL').
    if any(
        c.startswith(prefix) and c.endswith('_vs_' + new_ref)
        for c in dds.obsm['design_matrix'].columns
    ):
        print('%s already reference level for %s' % (new_ref, factor))
        return

    design_matrix = build_design_matrix(
        metadata=dds.obs.copy(),
        design_factors=dds.design_factors,
        ref_level=ref_level,
    )

    if 'LFC' not in dds.varm.keys():
        dds.deseq2()

    old_lfc = dds.varm['LFC']
    old_factor_columns = [c for c in old_lfc.columns if c.startswith(prefix)]
    # The current reference level is whatever follows '_vs_' in this factor's columns.
    old_ref = old_factor_columns[0].split('_vs_')[-1]

    def old_coef(level):
        # Coefficient of `level` relative to the current (old) reference level.
        return old_lfc['%s%s_vs_%s' % (prefix, level, old_ref)]

    coef_df = old_lfc.copy()

    # The new reference level's linear predictor becomes the intercept.
    coef_df['intercept'] = old_lfc['intercept'] + old_coef(new_ref)

    for c in design_matrix.columns:
        if not c.startswith(prefix):
            continue
        level, ref = c[len(prefix):].split('_vs_')
        if level == old_ref:
            # Old reference vs. new reference: negate the old new-vs-old coefficient.
            coef_df[c] = -1. * old_coef(ref)
        else:
            # level vs. new reference = (level vs. old) - (new vs. old).
            coef_df[c] = old_coef(level) - old_coef(ref)

    coef_df = coef_df.drop(columns=old_factor_columns)[design_matrix.columns]

    dds.varm['LFC'] = coef_df.copy()
    dds.obsm['design_matrix'] = design_matrix.copy()
    dds.ref_level = ref_level

    print('dds releveled to %s-%s' % (factor, new_ref))



def _twice_nb_nll(dds: DeseqDataSet) -> np.ndarray:
    """Return 2 * negative log-likelihood per feature for the GLM currently stored in ``dds``.

    Means are ``size_factors * exp(design_matrix @ LFC)`` per feature; the NB negative
    log-likelihood of each feature's counts is computed with pydeseq2's ``nb_nll`` using
    ``dds.varm['dispersions']``, in parallel with the dataset's joblib settings.
    """
    design = dds.obsm['design_matrix'].values
    size_factors = dds.obsm['size_factors']
    mu = np.stack(
        dds.varm['LFC'].apply(
            lambda beta: size_factors * np.exp(design @ beta),
            axis=1,
        )
    )

    with parallel_backend("loky", inner_max_num_threads=1):
        res = Parallel(
            n_jobs=dds.n_processes,
            verbose=dds.joblib_verbosity,
            batch_size=dds.batch_size,
        )(
            delayed(nb_nll)(
                dds.X[:, i],
                mu[i, :],
                dds.varm['dispersions'][i],
            )
            for i in range(dds.X.shape[1])
        )

    return 2. * np.array(res)


def likelihood_ratio_test(dds: DeseqDataSet, factors: List[str], alpha: float=0.05 ) -> pd.DataFrame:
    """Perform likelihood ratio test of full model against null model lacking factors described in arguments.

    The null model is refit with ``fit_LFC`` on a deep copy of ``dds`` whose design matrix lacks
    every column belonging to ``factors``; dispersions are reused from the full model.

    Args:
        dds (DeseqDataSet) pydeseq2 object with full design
        factors (List[str]) factors to drop from full design
        alpha (float) passed to statsmodels ``multipletests``; for ``fdr_bh`` the adjusted
            p-values themselves do not depend on it.

    Returns:
        (pd.DataFrame) dataframe containing lr-statistic, pvalue, and padj of full model vs. null model.
        Features with a NaN p-value get a NaN padj and are excluded from the BH correction.

    Raises:
        ValueError: if any entry of ``factors`` is not in ``dds.design_factors``.
    """
    if any(f not in dds.design_factors for f in factors):
        raise ValueError('check to make sure all factors %s in design factors.' % (','.join(factors)))

    # Likelihood of the fitted full model.
    l_fit = _twice_nb_nll(dds)

    # Null model: drop the design columns of the given factors. Categorical factors expand to
    # '<factor>_<level>_vs_<ref>' columns; an exact-name match is kept for any column named
    # after the factor itself. The '_' separator avoids 'condition-1' matching 'condition-10'.
    dds_null = copy.deepcopy(dds)
    null_design = dds_null.obsm['design_matrix']
    dropped = [
        c for c in null_design.columns
        if any(c == f or c.startswith(f + '_') for f in factors)
    ]
    null_design.drop(dropped, axis=1, inplace=True)

    # Fit null model given estimated dispersion parameters from full model.
    dds_null.fit_LFC()
    l_null = _twice_nb_nll(dds_null)

    # Perform LRT.
    lr_statistic = l_null - l_fit
    df = dds.obsm['design_matrix'].shape[1] - dds_null.obsm['design_matrix'].shape[1]
    p = chi2.sf(lr_statistic, df)

    # BH over an array containing NaN can turn every adjusted value into NaN, so adjust only
    # the finite p-values and leave NaN where no p-value exists.
    padj = np.full(np.shape(p), np.nan)
    tested = ~np.isnan(p)
    if tested.any():
        padj[tested] = multipletests(p[tested], alpha=alpha, method="fdr_bh")[1]

    return pd.DataFrame(
                    np.array([lr_statistic, p, padj]).T,
                    index=dds.varm['LFC'].index,
                    columns=['lrstat', 'pvalue', 'padj'],
                )

### QC1.1 MultiQC output

In [None]:
# Embed the MultiQC report into the notebook. srcdoc was required to render the HTML inside an
# iframe without breaking the notebook's own styling. Buttons on the side don't work, but
# everything else seems to work fine.

# Read with an explicit encoding so the result does not depend on the platform default.
with open(MULTIQC_PATH, 'r', encoding='utf-8') as f_in:
    # Escape the report (quotes included) so it can sit safely inside the srcdoc attribute.
    html_raw = html.escape(f_in.read())

HTML('<iframe srcdoc="%s" width="1200px" height="1000px"></iframe>' % html_raw)

### QC1.2 Metadata

In [None]:
# Read sample metadata into dataframe.

metadata = pd.read_csv(METADATA_FH, index_col=0)

# Number of samples in the least-populated combination of condition-* levels; used later when
# filtering lowly-detected features. .min() avoids positional `series[-1]` indexing, which is
# deprecated for label-indexed (here MultiIndex) Series.
smallest_condition_size = (
    metadata[[c for c in metadata.columns if c.startswith('condition')]].value_counts().min()
)

metadata, smallest_condition_size

### QC1.3 Define contrasts
* Contrasts need to be defined in a dict as contrast_name: [column_name, treatment_level, reference_level].
* Each contrast in `contrasts` gets a DE results dataframe stored under `anndata.uns['stat_results'][<contrast_name>]`.

In [None]:
# Manually define contrasts given conditions in metadata dataframe. Define reference level for comparisons.


# Pydeseq2 contrasts require condition-name, treatment level, reference level format.

# Example contrast:
# contrasts = {
#     'TYPE_2_DIABETES_vs_CONTROL': ['condition-1','DISEASE-1', 'CONTROL'],
#     }

reference_level = ['condition-1', 'CONTROL']


def _level_in_metadata(column, level):
    """Return True if `level` occurs in metadata[column], as written or with '-' read back as '_'."""
    values = set(metadata[column].values)
    return level in values or level.replace('-', '_') in values


# Validate user input up front with a readable error instead of a bare assert.
try:
    contrasts
except NameError:
    raise NameError('Define `contrasts` above before running the rest of the notebook.') from None

if not contrasts:
    raise ValueError('Define at least one contrast in `contrasts`.')
for contrast_name, (column, treatment, reference) in contrasts.items():
    if column not in metadata.columns:
        raise ValueError('Contrast %s: column %s not in metadata.' % (contrast_name, column))
    for level in (treatment, reference):
        if not _level_in_metadata(column, level):
            raise ValueError(
                'Contrast %s: level %s not found in metadata column %s.' % (contrast_name, level, column)
            )
if reference_level[0] not in metadata.columns or not _level_in_metadata(reference_level[0], reference_level[1]):
    raise ValueError('reference_level %s not found in metadata.' % (reference_level,))

# All underscores are replaced by hyphens.
reference_level = [r.replace('_', '-') for r in reference_level]
contrasts = {
    n: [r.replace('_', '-') for r in l] for n, l in contrasts.items()
}

In [None]:
# Merge effective lengths and average across all runs for transcript and gene dataframes.

def _effective_lengths(sample, fname):
    """Read COUNT_PATH/<sample>/<fname> and keep only EffectiveLength, renamed per sample."""
    quant = pd.read_csv(os.path.join(COUNT_PATH, sample, fname), delimiter='\t', index_col=0)
    quant = quant.drop(columns=['Length', 'TPM', 'NumReads'])
    return quant.rename(columns={'EffectiveLength': 'EffectiveLength_%s' % sample})


# Seed both tables from the first sample in the metadata dataframe, then inner-merge the
# remaining samples on the 'Name' index.
samples = list(metadata.index)
transcript_length = _effective_lengths(samples[0], 'quant.sf')
gene_length = _effective_lengths(samples[0], 'quant.genes.sf')

for srx in tqdm(samples[1:]):
    transcript_length = transcript_length.merge(_effective_lengths(srx, 'quant.sf'), on='Name')
    gene_length = gene_length.merge(_effective_lengths(srx, 'quant.genes.sf'), on='Name')

# Average effective lengths. It would probably be better to use a weighted averaging similar to tximport.
for lengths in (transcript_length, gene_length):
    per_sample = [c for c in lengths.columns if c.startswith('EffectiveLength_')]
    lengths['length'] = lengths[per_sample].mean(axis=1)
    lengths.drop(columns=per_sample, inplace=True)

transcript_length.shape, gene_length.shape

In [None]:
# Read in transcript and gene count/TPM dataframes.

def _read_count_table(fname):
    """Load a tab-separated table from COUNT_PATH, indexed by its first column."""
    return pd.read_csv(os.path.join(COUNT_PATH, fname), delimiter='\t', index_col=0)


expression = _read_count_table('salmon.merged.transcript_counts.tsv')
# Keep the transcript -> gene_id annotation before dropping it from the count matrix.
gene_transcript_mapping = expression[['gene_id']].copy().reset_index()
expression = expression.drop(columns='gene_id')

tpm = _read_count_table('salmon.merged.transcript_tpm.tsv').drop(columns='gene_id')

# See: https://nf-co.re/rnaseq/3.12.0/docs/output#salmon on output choice below
# salmon.merged.gene_counts_length_scaled.tsv is the gene-level output of nf-core rnaseq that is bias-corrected
# and is already scaled by potential transcript length
expression_gene = _read_count_table('salmon.merged.gene_counts_length_scaled.tsv').drop(columns='gene_name')

tpm_gene = _read_count_table('salmon.merged.gene_tpm.tsv').drop(columns='gene_name')

expression.shape, expression_gene.shape


In [None]:
# Drop sample columns that have no row in metadata.csv from every count/TPM table.

known_samples = set(metadata.index)
samples_to_drop = [sample for sample in expression.columns if sample not in known_samples]

for table in (expression, expression_gene, tpm, tpm_gene):
    table.drop(columns=samples_to_drop, inplace=True)
expression

### QC1.4 Drop samples or groups from count dataframe or design matrix, respectively.
* Samples are dropped based on "SRX" id, or cell in 'accession' column of the metadata dataframe.
* Groups are dropped based on group id ex. 'group-1', a grouping present in the metadata dataframe.

In [None]:
# Drop specific samples from dataframe. Provide accession name of samples to remove from analysis.

# Clear outlier.
samples_to_drop = []

# Remove the listed samples from every count/TPM table (columns) and from metadata (rows).
for table in (expression, expression_gene, tpm, tpm_gene):
    table.drop(columns=samples_to_drop, inplace=True)
metadata.drop(index=samples_to_drop, inplace=True)
expression

In [None]:
# Drop specific conditions/groups from the metadata dataframe: list the metadata column
# names (e.g. 'group-1') to remove in groups_to_drop.

groups_to_drop = []
metadata.drop(columns=groups_to_drop, inplace=True)
metadata

In [None]:
# Filter expression dataframes and prepare for QC/EDA.

def _filter_features(counts_by_sample, min_count, min_samples):
    """Keep features (columns) of a samples x features table that pass both count filters.

    A feature is kept when its total count is >= min_count and it reaches min_count in at
    least `min_samples` samples. Using `>=` (as in the DESeq2 vignette's pre-filtering) keeps
    features that are detected only in the smallest group; a strict `>` would drop them.
    """
    kept = counts_by_sample.loc[:, counts_by_sample.sum(axis=0) >= min_count]
    return kept.loc[:, (kept >= min_count).sum(axis=0) >= min_samples]


filtered_expression_transcript = _filter_features(
    expression.T.copy(), TRANSCRIPT_SUM_FILTER, smallest_condition_size,
)
filtered_tpm_transcript = tpm.T.copy()[filtered_expression_transcript.columns]

filtered_expression_gene = _filter_features(
    expression_gene.T.copy(), GENE_SUM_FILTER, smallest_condition_size,
)
filtered_tpm_gene = tpm_gene.T.copy()[filtered_expression_gene.columns]

# Counts and TPMs must be aligned sample-for-sample and feature-for-feature. Index.equals
# also checks length, which the previous zip-based comparison silently ignored.
assert (
        filtered_expression_gene.columns.equals(filtered_tpm_gene.columns) and
        filtered_expression_transcript.columns.equals(filtered_tpm_transcript.columns) and
        filtered_expression_gene.index.equals(filtered_tpm_gene.index) and
        filtered_expression_transcript.index.equals(filtered_tpm_transcript.index)
    )

(   expression.shape,
    filtered_expression_transcript.shape,
    expression_gene.shape,
    filtered_expression_gene.shape,
    filtered_tpm_transcript.shape,
    filtered_tpm_gene.shape,
    )

In [None]:
# Create Deseq datasets (AnnData objects) for transcript- and gene-level counts.

# DeseqDataSet expects integers in counts matrix, need to check in on the default method for
# rounding fractional counts to integers in tximport.

def _build_dds(counts):
    """Build a DeseqDataSet over `counts`; design factors are all group-* then condition-* columns."""
    design_factors = (
        [c for c in metadata.columns if c.startswith('group')]
        + [c for c in metadata.columns if c.startswith('condition')]
    )
    return DeseqDataSet(
        counts=counts.astype(int),
        metadata=metadata,
        design_factors=design_factors,
        ref_level=reference_level,
        refit_cooks=True,
        n_cpus=NUM_CPUS,
    )


dds = _build_dds(filtered_expression_transcript)
dds_gene = _build_dds(filtered_expression_gene)

In [None]:
# Store the same gene-transcript mapping in uns of both datasets so gene- and
# transcript-level quantifications can be compared later.

for ds in (dds, dds_gene):
    ds.uns['gene_transcript_mapping'] = gene_transcript_mapping

In [None]:
# Store the (feature-filtered) raw TPMs as the 'raw_tpm' layer of each dataset.

for ds, tpm_table in ((dds, filtered_tpm_transcript), (dds_gene, filtered_tpm_gene)):
    ds.layers['raw_tpm'] = np.array(tpm_table)

In [None]:
# Attach the averaged effective lengths ('length' column) to each dataset's var table
# via an index-aligned merge.

for ds, lengths in ((dds, transcript_length), (dds_gene, gene_length)):
    ds.var = ds.var.merge(lengths, left_index=True, right_index=True)

### QC1.5 Compute size factors and library sizes.

In [None]:
# Compute size factors and library sizes (total counts per sample) and expose both as obs
# columns so they are available for colouring plots.

for ds in (dds, dds_gene):
    ds.fit_size_factors()
    ds.obs['size_factors'] = ds.obsm['size_factors']
    ds.obs['lib_sizes'] = ds.X.sum(axis=1)

dds.obs

In [None]:
# Variance-stabilizing transformation; the results are read back from layers['vst_counts'].

for ds in (dds, dds_gene):
    ds.vst()

dds.layers['vst_counts'], dds_gene.layers['vst_counts']

In [None]:
# Keep a copy of the raw counts in layers['counts'] so X can be restored after it is
# temporarily replaced by transformed values.

for ds in (dds, dds_gene):
    ds.layers['counts'] = ds.X.copy()

In [None]:
# Fraction of each sample's library contributed by each feature -- a quick check for weird
# skews in library composition.

for ds in (dds, dds_gene):
    raw_counts = ds.layers['counts']
    ds.layers['fraction_counts'] = raw_counts / raw_counts.sum(axis=1, keepdims=True)

dds.layers['fraction_counts'], dds_gene.layers['fraction_counts']


### QC1.6 CDF curves of library composition by fractional counts.

In [None]:
# Plot CDF of fractional composition of libraries.

def _plot_fraction_ecdf(adata, title):
    """Plot one ECDF of log2 fraction counts per sample, with a legend keyed by sample name.

    Samples are passed to seaborn as named DataFrame columns, and the legend is rebuilt from
    seaborn's own handles and labels. Calling ``ax.legend(labels=...)`` on the raw array pairs
    labels with line artists in drawing order, which is not guaranteed to be the sample order,
    so samples could be mislabelled.
    """
    log_fraction = pd.DataFrame(
        np.log2(adata.layers['fraction_counts'].T),
        columns=adata.obs.index,
    )
    ax = sns.ecdfplot(data=log_fraction)
    ax.set_xlabel('log2 fraction counts')

    seaborn_legend = ax.get_legend()
    if seaborn_legend is not None:
        handles = seaborn_legend.legend_handles
        labels = [text.get_text() for text in seaborn_legend.get_texts()]
    else:
        handles, labels = ax.get_legend_handles_labels()

    n_samples = len(adata.obs.index)
    ax.legend(
        handles=handles,
        labels=labels,
        loc='upper left',
        bbox_to_anchor=(1., 1.),
        # About ten entries per legend column (ceil division), at least one column.
        ncols=max(1, -(-n_samples // 10)),
        frameon=False,
    )
    ax.set_title(title)
    plt.show()
    return ax


ax_transcript = _plot_fraction_ecdf(dds, 'transcript')
ax_gene = _plot_fraction_ecdf(dds_gene, 'gene')


In [None]:
# Swap the count matrix for variance-stabilized counts (DESeq2's recommended preprocessing
# before QC visualization), replacing NaN/inf entries in place with finite values.

for ds in (dds, dds_gene):
    ds.X = ds.layers['vst_counts'].copy()
    np.nan_to_num(ds.X, copy=False)

dds.layers['counts'], dds.X, dds.layers['vst_counts']

In [None]:
# Scale each transformed feature with scanpy, then replace any NaN/inf in X in place;
# the displayed per-feature means/stds are a sanity check of the scaling.

for ds in (dds, dds_gene):
    sc.pp.scale(ds)
    np.nan_to_num(ds.X, copy=False)

dds.X.mean(axis=0), dds.X.std(axis=0), dds_gene.X.mean(axis=0), dds_gene.X.std(axis=0)

### QC1.7 PCA on vst-counts colored on conditions, groups, and defined PCA variables. 

In [None]:
# Preliminary PCA on transcript- and gene-level data.

suffix_size = 4


def _pca_with_sample_labels(adata, suffix_size):
    """Run PCA on `adata`, plot it coloured by group-*, condition-* and PCA_VARIABLES columns,
    and annotate each point with the last `suffix_size` characters of its sample name."""
    sc.pp.pca(adata)
    axes = sc.pl.pca(
        adata,
        color=(
            [c for c in adata.obs.columns if c.startswith('group')]
            + [c for c in adata.obs.columns if c.startswith('condition')]
            + PCA_VARIABLES
        ),
        size=128,
        show=False,
    )
    # sc.pl.pca may return a list of axes (one per colour) or a single Axes.
    axes_list = axes if isinstance(axes, list) else [axes]
    for name, coords in zip(adata.obs.index, adata.obsm['X_pca']):
        for ax in axes_list:
            ax.text(coords[0], coords[1], name[-suffix_size:])
    return axes


ax_transcript_pca = _pca_with_sample_labels(dds, suffix_size)
ax_gene_pca = _pca_with_sample_labels(dds_gene, suffix_size)


In [None]:
# Plot PCA for the component pairs PC1-PC2, PC2-PC3, PC3-PC4 and PC4-PC5.

transcript_color_columns = (
    [c for c in dds.obs.columns if c.startswith(('group', 'condition',))] + PCA_VARIABLES
)
ax_transcript_pca = sc.pl.pca(
    dds,
    color=transcript_color_columns,
    size=128,
    show=False,
    ncols=len(transcript_color_columns),
    components=['1,2', '2,3', '3,4', '4,5'],
)

gene_color_columns = (
    [c for c in dds_gene.obs.columns if c.startswith(('group', 'condition',))] + PCA_VARIABLES
)
ax_gene_pca = sc.pl.pca(
    dds_gene,
    color=gene_color_columns,
    size=128,
    show=False,
    ncols=len(gene_color_columns),
    components=['1,2', '2,3', '3,4', '4,5'],
)


### QC1.8 PC - explained variance ratios. 

In [None]:
# Plot the explained variance ratio of each principal component (PCs numbered from 1).

fig, ax = plt.subplots(1, 2, figsize=(10, 5))

for axis, adata, title in ((ax[0], dds, 'Transcript'), (ax[1], dds_gene, 'Gene')):
    variance_ratio = adata.uns['pca']['variance_ratio']
    # 1-based x positions so the first point sits at PC 1 rather than 0.
    axis.plot(np.arange(1, len(variance_ratio) + 1), variance_ratio)
    axis.set_xlabel('PC')
    axis.set_title(title)

ax[0].set_ylabel('fraction explained variance')


### QC1.9 PCA loadings for transcript (top) and gene (bottom).

In [None]:
# Plot feature loadings of PCs 1-3, transcript-level first, then gene-level.

for ds in (dds, dds_gene):
    sc.pl.pca_loadings(ds, components='1,2,3')

### QC1.10 Sample-sample Pearson correlation.

In [None]:
# Sample-sample Pearson correlation of VST counts (one row per sample).

def _plot_sample_correlation(adata, title):
    """Draw a heatmap of pairwise Pearson correlations between samples' VST counts."""
    corr = np.corrcoef(adata.layers['vst_counts'])
    ax = sns.heatmap(
        corr,
        xticklabels=adata.obs.index,
        # NOTE(review): y labels assume a 'condition-1' metadata column exists.
        yticklabels=adata.obs['condition-1'],
        cbar_kws={'label': 'pearson r'},
    )
    ax.set_title(title)
    plt.show()
    return ax


ax_transcript = _plot_sample_correlation(dds, 'transcript')
ax_gene = _plot_sample_correlation(dds_gene, 'gene')



In [None]:
# # Sample-sample pearson correlation with plotly.

# dds.layers['vst_counts'].shape
# dds_gene.layers['vst_counts'].shape

# dist = np.corrcoef(dds.layers['vst_counts'])

# fig = go.Figure(
#         data=go.Heatmap(
#                     z=dist,
#                     x=dds.obs.index,
#                     y=dds.obs.reset_index()[
#                         [c for c in dds.obs.columns if c.startswith('condition') or c.startswith('group')]+['accession']
#                                             ].agg('_'.join,axis=1),
#                     hoverongaps = False, 
#                     colorbar={'title': 'pearson r'},
#                 )
#             )

# fig.update_layout(
#                 height=500,
#                 width=700,
#                 title_text='transcript'
#             )

# iplot(fig)

# dist = np.corrcoef(dds_gene.layers['vst_counts'])

# fig = go.Figure(
#         data=go.Heatmap(
#                     z=dist,
#                     x=dds_gene.obs.index,
#                     y=dds_gene.obs.reset_index()[
#                         [c for c in dds_gene.obs.columns if c.startswith('condition') or c.startswith('group')]+['accession']
#                                             ].agg('_'.join, axis=1),
#                     hoverongaps = False, 
#                     colorbar={'title': 'pearson r'},
#                 )
#             )

# fig.update_layout(
#                 height=500,
#                 width=700,
#                 title_text='gene'
#             )

# iplot(fig)


In [None]:
# Put the raw counts saved in layers['counts'] back into X before model fitting.

for ds in (dds, dds_gene):
    ds.X = ds.layers['counts'].copy()


In [None]:
# Run the pydeseq2 pipeline (dispersions, logFCs, Cook's distances) on both datasets.

for ds in (dds, dds_gene):
    ds.deseq2()

### DE1.1 fitted and squeezed dispersions. 

In [None]:
# Plot raw (genewise), squeezed (final) and trended dispersions against normalized means.

def _plot_dispersions(ax, adata, title):
    """Scatter log dispersions vs. log normalized mean for `adata` on `ax` with an opaque legend."""
    log_means = np.log(adata.varm['_normed_means'])
    for key, label, extra in (
        ('genewise_dispersions', 'raw', {}),
        ('dispersions', 'squeezed', {}),
        ('fitted_dispersions', 'trended', {'c': 'r'}),
    ):
        ax.scatter(log_means, np.log(adata.varm[key]), s=1, alpha=0.01, label=label, **extra)
    ax.set_xlabel('log normalized mean')
    ax.set_title(title)
    # Points are drawn nearly transparent; make the legend markers fully opaque.
    legend = ax.legend(frameon=False)
    for handle in legend.legend_handles:
        handle.set_alpha(1)


fig, ax = plt.subplots(1, 2, figsize=(10, 5))
_plot_dispersions(ax[0], dds, 'transcript-level')
ax[0].set_ylabel('log dispersions')
_plot_dispersions(ax[1], dds_gene, 'gene-level')

### DE1.2 metadata and design matrix.

In [None]:
# Display the sample metadata that was passed to the DeseqDataSets.
metadata

In [None]:
# Display the design matrices of the transcript- and gene-level datasets.
dds.obsm['design_matrix'], dds_gene.obsm['design_matrix']

### DE1.3 Fit model and run tests.

In [None]:
# For every contrast: relevel the design so the contrast's reference level is the model
# reference, run Wald tests and LFC shrinkage, then store the unshrunk coefficients and the
# results dataframe on each dataset.

datasets = (dds, dds_gene)

# Results dataframes per dataset, keyed by contrast name.
for ds in datasets:
    ds.uns['stat_results'] = {}

for contrast_name, contrast in contrasts.items():
    factor, reference = contrast[0], contrast[2]

    # Relevel design matrix and recalculate logFCs.
    for ds in datasets:
        relevel_design(ds, [factor, reference])

    stats = [DeseqStats(ds, contrast=contrast, n_cpus=NUM_CPUS) for ds in datasets]
    for stat in stats:
        stat.summary()
    for stat in stats:
        stat.lfc_shrink()

    for ds, stat in zip(datasets, stats):
        ds.varm['LFC_reflevel_%s' % reference] = ds.varm['LFC'].copy()
        ds.uns['stat_results'][contrast_name] = stat.results_df.copy()

### DE1.4 Run LRT on full design against null design.

In [None]:
# # Run an anova-like test by manually specifying design factors to drop from the full model.


# lrt_df = likelihood_ratio_test(dds, dds.design_factors)
# lrt_df_gene = likelihood_ratio_test(dds_gene, dds_gene.design_factors)

# display(lrt_df)
# display(lrt_df_gene)

# dds.uns['stat_results_lrtnull'] = lrt_df.copy()
# dds_gene.uns['stat_results_lrtnull'] = lrt_df_gene.copy()

In [None]:
# Calculate -log10_padj and fill NaN. NaN caused by independent filtering of pvalues before bh correction.
# padj == 0 gives +inf, which is capped at MAX_NLOG10_PADJ.
# The transformed column is assigned back rather than using df['col'].fillna(..., inplace=True):
# that chained in-place pattern does not update the parent dataframe under pandas Copy-on-Write.

for k in contrasts:
    for ds in (dds, dds_gene):
        results = ds.uns['stat_results'][k]
        neg_log10_padj = -1. * np.log10(results['padj'])
        results['-log10_padj'] = neg_log10_padj.fillna(0.0).replace(np.inf, MAX_NLOG10_PADJ)


In [None]:
# Write one CSV per contrast and feature level, named <RESULTS_PATH>_<level>_<contrast>.csv.

for contrast_name in contrasts:
    for level, ds in (('transcript', dds), ('gene', dds_gene)):
        ds.uns['stat_results'][contrast_name].to_csv('%s_%s_%s.csv' % (RESULTS_PATH, level, contrast_name))


In [None]:
# Write dds objects to files for DE and LogFC calculations.

# Pydeseq2 supports trend_coeffs/replaced as either np.array or pd.Series; np.array is required
# for saving h5-formatted AnnData objects. The contrasts are stored in uns alongside the data.
# DeseqDataSet doesn't have native support for writing h5, so save as AnnData objects and
# restore from AnnData objects.
for ds in (dds, dds_gene):
    ds.uns['trend_coeffs'] = np.array(ds.uns['trend_coeffs'])
    ds.varm['replaced'] = np.array(ds.varm['replaced'])
    ds.uns['contrasts'] = contrasts

for ds, out_path in ((dds, DDS_TRANSCRIPT_FH), (dds_gene, DDS_GENE_FH)):
    ds.write(out_path)