# 2. Quality Control, Filtering and Differential Expression Analysis

This notebook performs quality control checks, filters samples, and runs differential expression testing on RNA-seq data.

### Required User Input

1. Define Contrasts for Differential Expression Analysis
    - You must specify the contrasts you want to analyze as a dictionary of the form:
        - `contrasts = {'CONTRAST_NAME': ['column_name', 'treatment_level', 'reference_level']}`
    - **Important**: The notebook will not complete if:
        - Contrasts are not defined
        - Contrast levels don't exist in the metadata CSV

2. Define Samples/Groups to Drop
    - Optionally specify any samples or groups that should be excluded from the analysis:
        - Individual samples can be dropped by their accession ID
        - Entire groups can be dropped based on their group designation

#### Workflow Overview
1. Load and prepare data
2. Perform quality control checks
3. Filter samples based on user criteria
4. Run differential expression analysis
5. Generate visualizations and statistics

#### Expected Input Files
1. MultiQC Report
   - Location: `rnaseq_output/multiqc/star_salmon/multiqc_report.html`
   - Generated by nf-core/rnaseq pipeline

2. Salmon Quantification Files
   - Location: `rnaseq_output/star_salmon/`
   - Files per sample:
     - `quant.sf` - Transcript-level quantification
     - `quant.genes.sf` - Gene-level quantification

3. Merged Expression Files
   - Location: `rnaseq_output/star_salmon/`
   - Files:
     - `salmon.merged.transcript_counts.tsv` - Merged transcript counts
     - `salmon.merged.transcript_tpm.tsv` - Merged transcript TPMs
     - `salmon.merged.gene_counts_length_scaled.tsv` - Merged gene counts (length-scaled)
     - `salmon.merged.gene_tpm.tsv` - Merged gene TPMs

4. Metadata File
   - Location: `de_results/{EXPERIMENT_ID}_metadata.csv`
   - Required columns:
     - Sample accession IDs as index
     - At least one condition column (e.g., 'condition-1')
     - Optional grouping columns (e.g., 'group-1')
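
A quick pre-flight check of the metadata layout can be sketched with the standard library alone. The sample accessions, the column values, and the helper name `check_metadata_columns` are made up for illustration; only the `condition`/`group` column prefixes come from the requirements above.

```python
import csv
import io

# Hypothetical metadata file contents; the first column holds the sample accession IDs.
example_csv = """accession,condition-1,group-1
SRX0000001,CONTROL,batch-A
SRX0000002,CONTROL,batch-B
SRX0000003,DISEASE-1,batch-A
SRX0000004,DISEASE-1,batch-B
"""


def check_metadata_columns(text):
    """Return (condition_columns, group_columns); raise if no condition column exists."""
    reader = csv.DictReader(io.StringIO(text))
    columns = reader.fieldnames[1:]  # skip the index (accession) column
    conditions = [c for c in columns if c.startswith('condition')]
    groups = [c for c in columns if c.startswith('group')]
    if not conditions:
        raise ValueError('metadata needs at least one condition column')
    return conditions, groups


condition_columns, group_columns = check_metadata_columns(example_csv)
print(condition_columns, group_columns)
```

Running a check like this before the notebook makes a missing condition column fail early, rather than when the design matrix is built.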



#### Output Files 
- AnnData objects with DE data

In [None]:
from pathlib import Path 
import os 
import warnings 
import html

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scanpy as sc
import seaborn as sns
from IPython.display import HTML
from tqdm import tqdm

from pydeseq2.ds import DeseqStats
from pydeseq2.dds import DeseqDataSet

from src.utils import (
    relevel_design, 
    likelihood_ratio_test,
    pca_anova_model, 
    pca_variance_components_model,
)

warnings.filterwarnings('ignore')

### 2.1 Configure Notebook 

#### Required Variable Definitions:
Notebook Paths/Parameters:
- *NUM_CPUS*: Number of CPUs to use for parallel processing
- *DATA_PATH*: Root directory path for the project
- *MULTIQC_PATH*: Path to MultiQC report from nf-core/rnaseq pipeline
- *COUNT_PATH*: Directory containing Salmon quantification files
- *RESULTS_PATH*: Output directory for differential expression results
- *METADATA_FH*: Path to metadata CSV file containing sample information
- *PCA_VARIABLES*: Variables to color PCA plots by
- *DDS_TRANSCRIPT_FH*: Path to save transcript-level DESeq2 dataset object
- *DDS_GENE_FH*: Path to save gene-level DESeq2 dataset object

Analysis Parameters:
- *MAX_NLOG10_PADJ*: Value substituted for infinite -log10 adjusted p-values, which occur when an adjusted p-value is exactly 0 (default: 400)
- *TRANSCRIPT_SUM_FILTER*: Minimum sum of counts to keep a transcript (default: 1)
    - Should be set low for experiments where lowly-expressed receptors are of interest
- *GENE_SUM_FILTER*: Minimum sum of counts to keep a gene (default: 1)
    - Should be set low for experiments where lowly-expressed receptors are of interest
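
To make the effect of the count filters concrete, here is a dependency-free sketch on toy data. It follows the two criteria applied in the filtering cell later in this notebook: the total count must reach the threshold, and the feature must reach the threshold in more samples than the smallest condition group contains. The feature names, the counts, and the helper `passes_filters` are illustrative assumptions.

```python
# Toy counts: feature -> counts across four samples (hypothetical values).
counts = {
    'tx_silent': [0, 0, 0, 0],
    'tx_sparse': [0, 3, 0, 0],
    'tx_low': [1, 2, 1, 1],
    'tx_high': [120, 98, 143, 110],
}

SUM_FILTER = 1
smallest_condition_size = 2  # assumed size of the smallest condition group


def passes_filters(values, sum_filter, min_group_size):
    """Apply both filters: total count, then number of samples at or above the threshold."""
    total_ok = sum(values) >= sum_filter
    n_expressing = sum(v >= sum_filter for v in values)
    return total_ok and n_expressing > min_group_size


kept = [
    name for name, values in counts.items()
    if passes_filters(values, SUM_FILTER, smallest_condition_size)
]
print(kept)
```

With these toy values, `tx_sparse` passes the sum filter but is removed because only one sample reaches the threshold.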

In [None]:
NUM_CPUS = 8

DATA_PATH = Path.cwd().parent

EXPERIMENT_ID = DATA_PATH.parts[-1]

MULTIQC_PATH = DATA_PATH / 'rnaseq_output/multiqc/star_salmon/multiqc_report.html'

COUNT_PATH = DATA_PATH / 'rnaseq_output/star_salmon'

RESULTS_PATH = DATA_PATH / 'de_results'

METADATA_FH = RESULTS_PATH / f'{EXPERIMENT_ID}_metadata.csv'

PCA_VARIABLES = ['lib_sizes', 'library_layout', 'instrument_model']

DDS_TRANSCRIPT_FH = RESULTS_PATH / f'{EXPERIMENT_ID}_dds_transcript.h5_ad'
DDS_GENE_FH = RESULTS_PATH / f'{EXPERIMENT_ID}_dds_gene.h5_ad'

MAX_NLOG10_PADJ = 400.
TRANSCRIPT_SUM_FILTER = 1
GENE_SUM_FILTER = 1

### 2.2 MultiQC output

In [None]:
# Embed the MultiQC report into the notebook. Note: srcdoc was required to render the HTML inside an iframe without
# breaking the styling of the notebook. Buttons on the side don't work, but everything else seems to work fine.

with open(MULTIQC_PATH,'r') as f_in:
    html_raw = html.escape(f_in.read())

HTML(f'<iframe srcdoc="{html_raw}" width="1200px" height="1000px"></iframe>')

### 2.3 Metadata

In [None]:
# Read sample metadata into dataframe

metadata = pd.read_csv(METADATA_FH, index_col=0)
# Size of the smallest condition group (fewest samples sharing the same condition levels).
smallest_condition_size = metadata.loc[
    :, metadata.columns.str.startswith('condition')
].value_counts().min()

metadata, smallest_condition_size

### 2.4 Define Contrasts

Contrasts must be defined in a dictionary with the following format:

`contrasts = {'CONTRAST_NAME': ['column_name', 'treatment_level', 'reference_level']}`

**Requirements:**
- Column name must exist in metadata DataFrame
- Treatment and reference levels must exist in the specified column
- Underscores in levels will be automatically replaced with hyphens

**Example:**
`contrasts = {'TYPE_2_DIABETES_vs_CONTROL': ['condition-1', 'DISEASE-1', 'CONTROL']}`

Each contrast will generate a DE results DataFrame stored in:
- *anndata.uns['stat_results'][<contrast_name>]*
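
The requirements above can be checked up front with a small helper that reports every problem at once instead of stopping at the first failed assertion. This is a minimal sketch: `contrast_problems` and the toy `metadata_levels` mapping are hypothetical, and the sketch compares levels literally, without the underscore/hyphen normalization the notebook applies.

```python
# Hypothetical mapping of metadata column -> levels observed in that column.
metadata_levels = {'condition-1': {'CONTROL', 'DISEASE-1'}}


def contrast_problems(contrasts, metadata_levels):
    """Return a list of human-readable problems; an empty list means the contrasts look valid."""
    problems = []
    if not contrasts:
        problems.append('no contrasts defined')
    for name, spec in contrasts.items():
        if len(spec) != 3:
            problems.append(f'{name}: expected [column, treatment, reference]')
            continue
        column, treatment, reference = spec
        if column not in metadata_levels:
            problems.append(f'{name}: column {column!r} not in metadata')
            continue
        for level in (treatment, reference):
            if level not in metadata_levels[column]:
                problems.append(f'{name}: level {level!r} not in {column!r}')
    return problems


good = {'TYPE_2_DIABETES_vs_CONTROL': ['condition-1', 'DISEASE-1', 'CONTROL']}
bad = {'TYPO_vs_CONTROL': ['condition-1', 'DISEASE-2', 'CONTROL']}

print(contrast_problems(good, metadata_levels))
print(contrast_problems(bad, metadata_levels))
```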

In [None]:
# Manually define contrasts given conditions in metadata dataframe. Define reference level for comparisons.


# Pydeseq2 contrasts require condition-name, treatment level, reference level format.

# Example contrast:
# contrasts = {
#     'TYPE_2_DIABETES_vs_CONTROL': ['condition-1','DISEASE-1', 'CONTROL'],
#     }

reference_level = ['condition-1', 'CONTROL']

# All underscores are replaced by hyphens.
assert (len(contrasts) > 0), "No contrasts defined"
assert all(c[0] in metadata.columns for c in contrasts.values()), "Column not found in metadata" 
assert all(
    (   
        l_1.replace('-','_') in metadata[c].values and 
        l_2.replace('-','_') in metadata[c].values
    ) for c, l_1, l_2 in contrasts.values()
), "Treatment or reference level not found in metadata"

# Convert underscores to hyphens to adhere to pydeseq2 requirements
reference_level = [r.replace('_','-') for r in reference_level]
contrasts = {
    n: [r.replace('_','-') for r in l] for n, l in contrasts.items()
}

In [None]:
# Merge effective lengths and average across all runs for transcript and gene dataframes.

# Build the transcript length dataframe off of the first sample in the metadata dataframe.
transcript_length = pd.read_csv(
    COUNT_PATH / metadata.index[0] / 'quant.sf',
    delimiter='\t',
    index_col=0,
)
transcript_length.rename({'EffectiveLength':f'EffectiveLength_{metadata.index[0]}'}, axis=1, inplace=True)
transcript_length.drop(['Length', 'TPM', 'NumReads'], inplace=True, axis=1)

gene_length = pd.read_csv(
    COUNT_PATH / metadata.index[0] / 'quant.genes.sf',
    delimiter='\t',
    index_col=0,
)
gene_length.rename({'EffectiveLength':f'EffectiveLength_{metadata.index[0]}'}, axis=1, inplace=True)
gene_length.drop(['Length', 'TPM', 'NumReads'], inplace=True, axis=1)

# Populate samples into effective length dataframe with remaining samples.
for srx in tqdm(metadata.index[1:]):
    df = pd.read_csv(
        COUNT_PATH / srx / 'quant.sf',
        delimiter='\t',
        index_col=0,
    )
    df.drop(['Length', 'TPM', 'NumReads'], inplace=True, axis=1)
    df.rename({'EffectiveLength':f'EffectiveLength_{srx}'}, axis=1, inplace=True)
    transcript_length = transcript_length.merge(df, on='Name')

    df_gene = pd.read_csv(COUNT_PATH / srx / 'quant.genes.sf', delimiter='\t', index_col=0)
    df_gene.drop(['Length', 'TPM', 'NumReads'], inplace=True, axis=1)
    df_gene.rename({'EffectiveLength':f'EffectiveLength_{srx}'}, axis=1, inplace=True)
    gene_length = gene_length.merge(df_gene, on='Name')


# Average effective lengths.
transcript_length['length'] = transcript_length.mean(axis=1)
gene_length['length'] = gene_length.mean(axis=1)
transcript_length.drop([c for c in transcript_length.columns if c.startswith('EffectiveLength_')], inplace=True, axis=1)
gene_length.drop([c for c in gene_length.columns if c.startswith('EffectiveLength_')], inplace=True, axis=1)

transcript_length.shape, gene_length.shape
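
The cell above reduces to: keep the features present in every sample (the default inner join of `merge`), then average each feature's effective length across samples. A dependency-free sketch of that logic follows; the accessions, the transcript names, and the lengths are invented for illustration.

```python
from statistics import mean

# Hypothetical per-sample effective lengths, keyed by sample accession.
effective_lengths = {
    'SRX0000001': {'tx1': 900.0, 'tx2': 1500.0, 'tx3': 400.0},
    'SRX0000002': {'tx1': 1100.0, 'tx2': 1450.0},
}

# Inner join: only features quantified in every sample are kept.
shared = set.intersection(*(set(lengths) for lengths in effective_lengths.values()))

# Average the effective length of each shared feature across samples.
average_length = {
    feature: mean(lengths[feature] for lengths in effective_lengths.values())
    for feature in sorted(shared)
}
print(average_length)
```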

In [None]:
# Read in transcript and gene count/TPM dataframes.

expression = pd.read_csv(
    COUNT_PATH / 'salmon.merged.transcript_counts.tsv', 
    delimiter = '\t', 
    index_col=0,
)
gene_transcript_mapping = expression.loc[:,['gene_id']].reset_index()
expression.drop('gene_id', inplace=True, axis=1)

tpm = pd.read_csv(
    COUNT_PATH / 'salmon.merged.transcript_tpm.tsv', 
    delimiter = '\t', 
    index_col=0,
)
tpm.drop('gene_id', inplace=True, axis=1)

# See: https://nf-co.re/rnaseq/3.12.0/docs/output#salmon on output choice below 
# salmon.merged.gene_counts_length_scaled.tsv is the gene-level output of nf-core rnaseq that is bias-corrected
# and is already scaled by potential transcript length
expression_gene = pd.read_csv(
    COUNT_PATH / 'salmon.merged.gene_counts_length_scaled.tsv', 
    delimiter='\t', 
    index_col=0,
)
expression_gene.drop('gene_name', inplace=True, axis=1)

tpm_gene = pd.read_csv(
    COUNT_PATH / 'salmon.merged.gene_tpm.tsv', 
    delimiter='\t', 
    index_col=0,
)
tpm_gene.drop('gene_name', inplace=True, axis=1)

expression.shape, expression_gene.shape

In [None]:
# Drop samples from the expression dataframes that are not listed in the metadata CSV.

samples_to_drop = [c for c in expression.columns if c not in metadata.index]

expression.drop(samples_to_drop, axis=1, inplace=True)
expression_gene.drop(samples_to_drop, axis=1, inplace=True)
tpm.drop(samples_to_drop, axis=1, inplace=True)
tpm_gene.drop(samples_to_drop, axis=1, inplace=True)
expression

### 2.5 Drop samples or groups from count dataframe or design matrix, respectively.
* Samples are dropped by accession ID (e.g. an "SRX" ID), i.e. their entry in the index of the metadata dataframe.
* Groups are dropped by column name, e.g. 'group-1', a grouping column present in the metadata dataframe.

In [None]:
# Drop specific samples from dataframe. Provide accession name of samples to remove from analysis.

# Example use: list the accession IDs of clear outliers.
samples_to_drop = []

expression.drop(samples_to_drop, axis=1, inplace=True)
expression_gene.drop(samples_to_drop, axis=1, inplace=True)
tpm.drop(samples_to_drop, axis=1, inplace=True)
tpm_gene.drop(samples_to_drop, axis=1, inplace=True)
metadata.drop(samples_to_drop, axis=0, inplace=True)

# Drop specific conditions/groups from metadata dataframe. 

groups_to_drop = []
metadata.drop(groups_to_drop, axis=1, inplace=True)

In [None]:
# Filter expression dataframe and prepare for QC/EDA.

# Filter and prepare transcript-level expression data
filtered_expression_transcript = expression.T.copy()
filtered_expression_transcript = filtered_expression_transcript.loc[
    :,filtered_expression_transcript.sum(axis=0) >= TRANSCRIPT_SUM_FILTER
]
filtered_expression_transcript = filtered_expression_transcript.loc[
    :,(filtered_expression_transcript >= TRANSCRIPT_SUM_FILTER).sum(axis=0) > smallest_condition_size
]

# Filter corresponding transcript TPMs
filtered_tpm_transcript = tpm.T.copy()
filtered_tpm_transcript = filtered_tpm_transcript.loc[:, filtered_expression_transcript.columns]

# Filter and prepare gene-level expression data
filtered_expression_gene = expression_gene.T.copy()
filtered_expression_gene = filtered_expression_gene.loc[
    :, filtered_expression_gene.sum(axis=0) >= GENE_SUM_FILTER
]
filtered_expression_gene = filtered_expression_gene.loc[
    :, (filtered_expression_gene >= GENE_SUM_FILTER).sum(axis=0) > smallest_condition_size
]

# Filter corresponding gene TPMs
filtered_tpm_gene = tpm_gene.T.copy()
filtered_tpm_gene = filtered_tpm_gene[filtered_expression_gene.columns]

# Verify alignment of all filtered datasets
assert all([
    filtered_expression_gene.columns.equals(filtered_tpm_gene.columns),
    filtered_expression_transcript.columns.equals(filtered_tpm_transcript.columns),
    filtered_expression_gene.index.equals(filtered_tpm_gene.index),
    filtered_expression_transcript.index.equals(filtered_tpm_transcript.index)
]), "Misaligned indices or columns in filtered datasets"

# Shapes of original and filtered datasets

(   expression.shape, 
    filtered_expression_transcript.shape, 
    expression_gene.shape, 
    filtered_expression_gene.shape, 
    filtered_tpm_transcript.shape, 
    filtered_tpm_gene.shape, 
)

### 2.6 Build DeseqDataSet objects from gene and transcript data 

In [None]:
# Create DeseqDataSet objects (AnnData subclasses).

# DeseqDataSet expects integers in the counts matrix. astype(int) truncates fractional counts toward zero;
# still need to check the default method for rounding fractional counts to integers in tximport.
 
dds = DeseqDataSet(
    counts = filtered_expression_transcript.astype(int), 
    metadata = metadata, 
    design_factors = (
        [c for c in metadata.columns if c.startswith('group')]+
        [c for c in metadata.columns if c.startswith('condition')]
    ),
    ref_level=reference_level,
    refit_cooks = True, 
    n_cpus = NUM_CPUS, 
)

dds_gene = DeseqDataSet(
    counts = filtered_expression_gene.astype(int), 
    metadata = metadata, 
    design_factors = (
        [c for c in metadata.columns if c.startswith('group')]+
        [c for c in metadata.columns if c.startswith('condition')]
    ),
    ref_level=reference_level,
    refit_cooks = True, 
    n_cpus = NUM_CPUS, 
)

In [None]:
# Set gene-transcript mapping attribute in uns for comparisons between 
# gene- and transcript-level quantifications.

dds.uns['gene_transcript_mapping'] = gene_transcript_mapping
dds_gene.uns['gene_transcript_mapping'] = gene_transcript_mapping

In [None]:
# Set raw TPMs as layer in dds objects.

dds.layers['raw_tpm'] = np.array(filtered_tpm_transcript)
dds_gene.layers['raw_tpm'] = np.array(filtered_tpm_gene)

In [None]:
# Merge the average effective lengths into the var dataframe.

dds.var = dds.var.merge(transcript_length, left_index=True, right_index=True)
dds_gene.var = dds_gene.var.merge(gene_length, left_index=True, right_index=True)

### 2.7 Compute size-factors, library sizes, and perform vst

In [None]:
# Compute size-factors and library sizes.

dds.fit_size_factors()
dds.obs['size_factors'] = dds.obsm['size_factors']
dds.obs['lib_sizes'] = dds.X.sum(axis=1)

dds_gene.fit_size_factors()
dds_gene.obs['size_factors'] = dds_gene.obsm['size_factors']
dds_gene.obs['lib_sizes'] = dds_gene.X.sum(axis=1)

dds.obs
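
As background on what a size factor represents, the sketch below implements the textbook median-of-ratios estimator described for DESeq2 using only the standard library. It is an illustration on toy counts, not a copy of the pydeseq2 implementation, whose details may differ; the function name `median_of_ratios` is hypothetical.

```python
import math
from statistics import median

# Toy counts: rows are samples, columns are genes (hypothetical values).
# Sample 2 is sequenced twice as deeply as sample 1, sample 3 half as deeply.
counts = [
    [100, 200, 50],
    [200, 400, 100],
    [50, 100, 25],
]


def median_of_ratios(counts):
    """Estimate one size factor per sample with the median-of-ratios method."""
    n_genes = len(counts[0])
    # Only genes with non-zero counts in every sample enter the reference.
    usable = [g for g in range(n_genes) if all(row[g] > 0 for row in counts)]
    # Log geometric mean of each usable gene across samples.
    log_geo_means = {
        g: sum(math.log(row[g]) for row in counts) / len(counts) for g in usable
    }
    # Size factor = median ratio of a sample's counts to the per-gene reference.
    return [
        math.exp(median(math.log(row[g]) - log_geo_means[g] for g in usable))
        for row in counts
    ]


size_factors = median_of_ratios(counts)
print(size_factors)
```

Dividing each sample's counts by its size factor puts the three toy libraries on the same scale.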

In [None]:
# Variance-stabilizing transformation.

dds.vst()
dds_gene.vst()

# Set recoverable count data.

dds.layers['counts'] = dds.X.copy()
dds_gene.layers['counts'] = dds_gene.X.copy()

# Compute fractional counts to get a quick idea for any weird skews in library composition.

dds.layers['fraction_counts'] = dds.layers['counts'] / np.reshape(dds.layers['counts'].sum(axis=1), (-1,1))
dds_gene.layers['fraction_counts'] = dds_gene.layers['counts'] / np.reshape(dds_gene.layers['counts'].sum(axis=1), (-1,1))

dds.layers['fraction_counts'], dds_gene.layers['fraction_counts']

### 2.8 CDF curves library composition by fraction counts. 
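
As a rough illustration of what these curves show, the sketch below converts counts to per-library fractions and builds an empirical CDF of the log2 fractions for a balanced and a skewed toy library. The library names, the counts, and the helper `log2_fraction_ecdf` are assumptions; unlike the notebook cell, the sketch skips zero counts rather than plotting them at negative infinity.

```python
import math

# Hypothetical raw counts for two libraries over the same five features.
libraries = {
    'balanced': [20, 20, 20, 20, 20],
    'skewed': [1, 1, 1, 1, 96],
}


def log2_fraction_ecdf(counts):
    """Return sorted (log2 fraction, cumulative proportion) pairs for one library."""
    total = sum(counts)
    # Zero counts are skipped because log2(0) is undefined.
    values = sorted(math.log2(c / total) for c in counts if c > 0)
    return [(x, (i + 1) / len(values)) for i, x in enumerate(values)]


for name, counts in libraries.items():
    print(name, log2_fraction_ecdf(counts))
```

A library dominated by a few features, like `skewed`, has a curve shifted to the left of the others, with most features at very small fractions.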

In [None]:
# Plot CDF of fractional composition of libraries. 

ax_transcript = sns.ecdfplot(np.log2(dds.layers['fraction_counts'].T))
ax_transcript.set_xlabel('log2 fraction counts')
ax_transcript.legend(
    labels=dds.obs.index, 
    loc='upper left', 
    bbox_to_anchor=(1.,1.), 
    ncols=1 if len(dds.obs.index) < 10 else int(len(dds.obs.index)/10),
    frameon=False,
)
ax_transcript.set_title('transcript')
plt.show()

ax_gene = sns.ecdfplot(np.log2(dds_gene.layers['fraction_counts'].T))
ax_gene.set_xlabel('log2 fraction counts')
ax_gene.legend(
    labels=dds_gene.obs.index, 
    loc='upper left', 
    bbox_to_anchor=(1.,1.), 
    ncols=1 if len(dds_gene.obs.index) < 10 else int(len(dds_gene.obs.index)/10),
    frameon=False,
)
ax_gene.set_title('gene')
plt.show()

In [None]:
# Replace count matrix with variance-stabilized (VST) counts before QC and visualization.

dds.X = dds.layers['vst_counts'].copy()
dds_gene.X = dds_gene.layers['vst_counts'].copy()

np.nan_to_num(dds.X, copy=False)
np.nan_to_num(dds_gene.X, copy=False)

# Scale transformed variables.

sc.pp.scale(dds)
sc.pp.scale(dds_gene)

np.nan_to_num(dds.X, copy=False)
np.nan_to_num(dds_gene.X, copy=False)

dds.X.mean(axis=0), dds.X.std(axis=0), dds_gene.X.mean(axis=0), dds_gene.X.std(axis=0)

### 2.9 PCA on vst-counts colored on conditions, groups, and defined PCA variables. 

In [None]:
# Preliminary PCA on transcript- and gene-level data.

suffix_size = 4

sc.pp.pca(dds)
ax_transcript_pca = sc.pl.pca( 
    dds, 
    color=(
        [c for c in dds.obs.columns if c.startswith('group')]+
        [c for c in dds.obs.columns if c.startswith('condition')]+
        PCA_VARIABLES
    ), 
    size = 128,
    show=False,
)

for i, s in enumerate(dds.obsm['X_pca']):
    if isinstance(ax_transcript_pca, list):
        for ax in ax_transcript_pca:
            ax.text(s[0], s[1], dds.obs.index[i][-suffix_size:])
    else:
        ax_transcript_pca.text(s[0], s[1], dds.obs.index[i][-suffix_size:])

sc.pp.pca(dds_gene)
ax_gene_pca = sc.pl.pca(
    dds_gene, 
    color=(
        [c for c in dds_gene.obs.columns if c.startswith('group')]+
        [c for c in dds_gene.obs.columns if c.startswith('condition')]+
        PCA_VARIABLES
    ), 
    size = 128,
    show=False, 
    )

for i, s in enumerate(dds_gene.obsm['X_pca']):
    if isinstance(ax_gene_pca, list):
        for ax in ax_gene_pca:
            ax.text(s[0], s[1], dds_gene.obs.index[i][-suffix_size:])
    else:
        ax_gene_pca.text(s[0], s[1], dds_gene.obs.index[i][-suffix_size:])


### 2.10 Plot additional PCs 

In [None]:
# Plot consecutive PC pairs (PC1 through PC5) for transcript- and gene-level quantifications.

transcript_color_columns = [c for c in dds.obs.columns if c.startswith(('group', 'condition',))]+ \
    PCA_VARIABLES
ax_transcript_pca = sc.pl.pca( 
    dds, 
    color=transcript_color_columns, 
    size = 128,
    show=False,
    ncols=len(transcript_color_columns),
    components=['1,2', '2,3', '3,4', '4,5'],
)

gene_color_columns = [c for c in dds_gene.obs.columns if c.startswith(('group', 'condition',))]+ \
    PCA_VARIABLES
ax_gene_pca = sc.pl.pca( 
    dds_gene, 
    color=gene_color_columns, 
    size = 128,
    show=False,
    ncols=len(gene_color_columns),
    components=['1,2', '2,3', '3,4', '4,5'],
)


### 2.11 PC - explained variance ratios. 

In [None]:
# Plot explained variance ratios.

fig, ax = plt.subplots(1,2,figsize=(10,5))

ax[0].plot(dds.uns['pca']['variance_ratio'])
ax[1].plot(dds_gene.uns['pca']['variance_ratio'])

ax[0].set_ylabel('fraction explained variance')
ax[0].set_xlabel('PC')
ax[1].set_xlabel('PC')
ax[0].set_title('Transcript')
ax[1].set_title('Gene')

### 2.12 PCA loadings for transcript (top) and gene (bottom).

In [None]:
# Plot loadings for first 3 PCs.

sc.pl.pca_loadings(dds, components = '1,2,3')
sc.pl.pca_loadings(dds_gene, components = '1,2,3')

### 2.13 Sample-Sample pearson correlation.
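
For readers who want the arithmetic behind the heatmaps, here is a dependency-free sketch of a sample-by-sample Pearson correlation matrix. The sample names and values are invented. `np.corrcoef` treats each row as one variable, which is why the notebook can pass the samples-by-features VST matrix to it directly.

```python
import math


def pearson_r(x, y):
    """Pearson correlation coefficient between two equal-length sequences."""
    n = len(x)
    mean_x, mean_y = sum(x) / n, sum(y) / n
    cov = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y))
    var_x = sum((a - mean_x) ** 2 for a in x)
    var_y = sum((b - mean_y) ** 2 for b in y)
    return cov / math.sqrt(var_x * var_y)


# Hypothetical transformed expression values: one list of features per sample.
samples = {
    'S1': [5.0, 7.0, 9.0, 11.0],
    'S2': [5.5, 7.5, 9.5, 11.5],  # S1 shifted by a constant
    'S3': [11.0, 9.0, 7.0, 5.0],  # S1 reversed
}

names = list(samples)
corr = [[pearson_r(samples[a], samples[b]) for b in names] for a in names]
for name, row in zip(names, corr):
    print(name, [round(r, 3) for r in row])
```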

In [None]:
# Sample-sample pearson correlation.

dds.layers['vst_counts'].shape
dds_gene.layers['vst_counts'].shape

dist = np.corrcoef(dds.layers['vst_counts'])

ax_transcript = sns.heatmap(
    dist, 
    xticklabels=dds.obs.index, 
    yticklabels=dds.obs['condition-1'], 
    cbar_kws={'label': 'pearson r'}, 
)
ax_transcript.set_title('transcript')
plt.show()

dist = np.corrcoef(dds_gene.layers['vst_counts'])

ax_gene = sns.heatmap(
    dist, 
    xticklabels=dds_gene.obs.index, 
    yticklabels=dds_gene.obs['condition-1'], 
    cbar_kws={'label': 'pearson r'},
)
ax_gene.set_title('gene')
plt.show()

### 2.14 Reset counts, fit dispersion estimates.

In [None]:
# Restore the original counts data.

dds.X = dds.layers['counts'].copy()
dds_gene.X = dds_gene.layers['counts'].copy()

# Fit dispersions and logFCs, and calculate Cook's distances.

dds.deseq2()
dds_gene.deseq2()

### 2.15 Plot fitted and squeezed dispersions. 

In [None]:
# Plot fitted dispersions.

fig, ax = plt.subplots(1,2,figsize=(10,5))

ax[0].scatter(
    np.log(dds.varm['_normed_means']), 
    np.log(dds.varm['genewise_dispersions']), 
    s=1, 
    alpha=0.01, 
    label='raw',
)
ax[0].scatter(
    np.log(dds.varm['_normed_means']), 
    np.log(dds.varm['dispersions']), 
    s=1, 
    alpha=0.01, 
    label='squeezed',
)
ax[0].scatter(
    np.log(dds.varm['_normed_means']), 
    np.log(dds.varm['fitted_dispersions']), 
    s=1, 
    alpha=0.01, 
    label='trended', 
    c='r', 
)
ax[0].set_ylabel('log dispersions')
ax[0].set_xlabel('log normalized mean')
ax[0].set_title('transcript-level')
legend = ax[0].legend(frameon=False)
for lh in legend.legend_handles:
    lh.set_alpha(1)

ax[1].scatter(
    np.log(dds_gene.varm['_normed_means']), 
    np.log(dds_gene.varm['genewise_dispersions']), 
    s=1, 
    alpha=0.01, 
    label='raw',
)
ax[1].scatter(
    np.log(dds_gene.varm['_normed_means']), 
    np.log(dds_gene.varm['dispersions']), 
    s=1, 
    alpha=0.01, 
    label='squeezed',
)
ax[1].scatter(
    np.log(dds_gene.varm['_normed_means']), 
    np.log(dds_gene.varm['fitted_dispersions']), 
    s=1, 
    alpha=0.01, 
    label='trended', 
    c='r', 
)
ax[1].set_xlabel('log normalized mean')
ax[1].set_title('gene-level')
legend = ax[1].legend(frameon=False)
for lh in legend.legend_handles:
    lh.set_alpha(1)

### 2.16 Examine metadata and design matrix.

In [None]:
metadata

In [None]:
dds.obsm['design_matrix']

### 2.17 Fit model and run DE tests.

In [None]:
# Create Stats object. Define relevant contrasts for DE and LogFC computations and run tests. 

# Holds all DeseqStats objects as defined in contrasts.
dds.uns['stat_results'] = {}
dds_gene.uns['stat_results'] = {}

for k, v in contrasts.items():

    # Relevel design matrix and recalculate logFCs.

    relevel_design(dds, [v[0], v[2]])
    relevel_design(dds_gene, [v[0], v[2]])

    stat_res = DeseqStats(dds, contrast=v, n_cpus=NUM_CPUS)
    stat_res_gene = DeseqStats(dds_gene, contrast=v, n_cpus=NUM_CPUS)

    stat_res.summary()
    stat_res_gene.summary()

    stat_res.lfc_shrink()
    stat_res_gene.lfc_shrink()

    dds.varm[f'LFC_reflevel_{v[2]}'] = dds.varm['LFC'].copy()
    dds_gene.varm[f'LFC_reflevel_{v[2]}'] = dds_gene.varm['LFC'].copy()
    
    dds.uns['stat_results'][k] = stat_res.results_df.copy()
    dds_gene.uns['stat_results'][k] = stat_res_gene.results_df.copy()

### 2.18 Run LRT on full design against null design.
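
For reference, a likelihood ratio test in its generic form compares the maximized log-likelihoods of the full and the reduced (null) model for each feature. How `likelihood_ratio_test` in `src.utils` implements this is not shown in this notebook, so the expression below is the textbook statistic rather than a description of that helper. The symbols $\ell_{\text{full},g}$, $\ell_{\text{reduced},g}$ and $k$ are notation introduced for this note.

$$
\Lambda_g = 2\left(\ell_{\text{full},g} - \ell_{\text{reduced},g}\right),
\qquad
\Lambda_g \overset{\text{approx.}}{\sim} \chi^2_{k} \quad \text{under the reduced model}
$$

Here $g$ indexes the feature (transcript or gene) and $k$ is the number of coefficients dropped from the full design to obtain the reduced one. A large $\Lambda_g$ indicates that the dropped design factors explain variation in feature $g$ that the reduced model cannot.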

In [None]:
# Run an ANOVA-like test by manually specifying design factors to drop from the full model.

lrt_df = likelihood_ratio_test(dds, dds.design_factors)
lrt_df_gene = likelihood_ratio_test(dds_gene, dds_gene.design_factors)

dds.uns['stat_results_lrtnull'] = lrt_df.copy()
dds_gene.uns['stat_results_lrtnull'] = lrt_df_gene.copy()

### 2.19 Clean up DE results and dump.
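
The clean-up rule applied here can be stated per value: a missing adjusted p-value maps to 0, an adjusted p-value of exactly 0 maps to the cap, and everything else maps to its -log10. A minimal standard-library sketch, in which the helper name `neg_log10_padj` is hypothetical and the cap mirrors the notebook default:

```python
import math

MAX_NLOG10_PADJ = 400.0  # same default as in the configuration cell


def neg_log10_padj(padj, cap=MAX_NLOG10_PADJ):
    """Map one adjusted p-value to a finite -log10 value."""
    if padj is None or math.isnan(padj):
        return 0.0  # features removed by independent filtering have no padj
    if padj == 0.0:
        return cap  # -log10(0) would be infinite
    return -math.log10(padj)


for value in (0.01, 0.0, float('nan')):
    print(value, neg_log10_padj(value))
```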

In [None]:
# Calculate -log10(padj). NaN values (caused by independent filtering of p-values before BH correction)
# are set to 0, and infinite values (padj == 0) are replaced with MAX_NLOG10_PADJ.

for k in contrasts:
    for d in (dds, dds_gene):
        # Assign the cleaned column back instead of using chained inplace calls,
        # which may silently operate on a temporary copy in recent pandas versions.
        results = d.uns['stat_results'][k]
        results['-log10_padj'] = (
            (-np.log10(results['padj']))
            .fillna(0.0)
            .replace(np.inf, MAX_NLOG10_PADJ)
        )

In [None]:
# Write dds objects to files for DE and LogFC calculations.

# Pydeseq2 stores trend_coeffs/replaced as either np.ndarray or pd.Series; np.ndarray is required for
# saving h5-formatted AnnData objects.
dds.uns['trend_coeffs'] = np.array(dds.uns['trend_coeffs'])
dds_gene.uns['trend_coeffs'] = np.array(dds_gene.uns['trend_coeffs'])

dds.varm['replaced'] = np.array(dds.varm['replaced'])
dds_gene.varm['replaced'] = np.array(dds_gene.varm['replaced'])

# Add contrasts to uns.
dds.uns['contrasts'] = contrasts
dds_gene.uns['contrasts'] = contrasts

# DeseqDataSet doesn't have native support for writing h5, save as AnnData objects and restore from AnnData objects.

dds.write(DDS_TRANSCRIPT_FH)
dds_gene.write(DDS_GENE_FH)