# 3. Differential Expression Visualization and QC

This notebook performs visualization and quality control analysis of differential expression results.

#### Workflow Overview
1. Creation of standard DE visualization plots:
   - Volcano plots
   - MA plots
   - Gene vs transcript level comparisons

2. Quality control of pydeseq2 results:
   - Validation of log fold change calculations
   - Comparison of raw vs squeezed log fold changes
   - Cross-contrast comparisons

3. Data Processing:
   - TPM scaling with normalized counts
   - Calculating group averages 

#### Expected Input Files
The notebook expects DESeq2 dataset objects (h5ad format) from the previous notebook containing:
   - Normalized counts
   - Size factors
   - Differential expression results
   - Sample metadata
   - Contrast definitions

#### Output Files
- CSV files containing DE results for each contrast
- CSV file that describes samples used in each contrast
- Updated AnnData objects

In [None]:
from pathlib import Path
import warnings
from itertools import combinations

import pandas as pd
import numpy as np
import anndata as ad
import scanpy as sc
from pybiomart import Dataset

import matplotlib.pyplot as plt
from matplotlib import lines

from pydeseq2.dds import DeseqDataSet
from pydeseq2.ds import DeseqStats

warnings.filterwarnings('ignore')

### 3.1 Configure Notebook

#### Required Variable Definitions:
Processing Parameters
- *NUM_CPUS*: Number of CPUs to use for parallel processing (default: 8)
- *LOG2_FC_THRESH*: Log2 fold change threshold for significance (default: log2(2.0))
- *NLOG10_PADJ_THRESH*: -log10 adjusted p-value threshold (default: -log10(0.05))

File Paths
- *DATA_PATH*: Root directory path for the project (default: parent of the current working directory)
- *RESULTS_PATH*: Output directory for differential expression results (default: DATA_PATH/de_results)
- *DDS_TRANSCRIPT_FH*: Path to transcript-level DESeq2 dataset object
- *DDS_GENE_FH*: Path to gene-level DESeq2 dataset object

In [None]:
# --- Processing parameters ---------------------------------------------------
NUM_CPUS = 8
# Significance cut-offs: absolute log2 fold change and -log10 adjusted p-value.
LOG2_FC_THRESH = np.log2(2.0)
NLOG10_PADJ_THRESH = -np.log10(0.05)

# --- File paths --------------------------------------------------------------
# The project root is the parent of the working directory; its last path
# component doubles as the experiment identifier used in file names.
DATA_PATH = Path.cwd().parent
EXPERIMENT_ID = DATA_PATH.parts[-1]
RESULTS_PATH = DATA_PATH / 'de_results'

DDS_TRANSCRIPT_FH, DDS_GENE_FH = (
    RESULTS_PATH / f'{EXPERIMENT_ID}_dds_transcript.h5_ad',
    RESULTS_PATH / f'{EXPERIMENT_ID}_dds_gene.h5_ad',
)

In [None]:
# Load the transcript- and gene-level dds objects written by the previous
# notebook. They are expected to be filtered already and to carry the sample
# metadata in obs.

dds, dds_gene = (ad.read_h5ad(fh) for fh in (DDS_TRANSCRIPT_FH, DDS_GENE_FH))

# Contrast definitions are taken from the transcript-level object.
contrasts = dds.uns['contrasts']

### 3.2 Fetch Ensembl gene ID to external gene name mappings from BioMart

In [None]:
# Fetch Ensembl gene id -> external gene name mappings from BioMart.

# Infer the species from the Ensembl gene-id prefix. Gene ids live on the
# gene-level object (the transcript-level index holds transcript ids), so both
# indexes are inspected together. Human is tested before mouse.
species_prefixes = (('hsapiens', 'ENSG'), ('mmusculus', 'ENSMUSG'))
feature_ids = dds_gene.var.index.append(dds.var.index).astype(str)

species_mapping = next(
    (
        species for species, prefix in species_prefixes
        if feature_ids.str.startswith(prefix).any()
    ),
    None,
)
if not species_mapping:
    raise ValueError('Must pick a valid species, adjust above if not human/mouse.')

# BioMart gene datasets are named '<species>_gene_ensembl'
# (e.g. 'hsapiens_gene_ensembl'); the bare species name is not a dataset.
dataset = Dataset(
    name=f'{species_mapping}_gene_ensembl',
    host='http://www.ensembl.org',
)

external_gene_mapping = dataset.query(
    attributes=['ensembl_gene_id', 'external_gene_name']
)

# The rename below expects the query to return BioMart display names as
# column headers ('Gene stable ID', 'Gene name').
external_gene_mapping = external_gene_mapping.rename(
    {'Gene stable ID': 'gene_id', 'Gene name': 'gene_name'}, axis=1,
)

external_gene_mapping

### 3.3 Volcano plots for all contrasts.

In [None]:
# Create volcano plots of DE-transcripts (left column) and DE-genes (right
# column), one row per contrast. Marker area scales with log2 mean expression;
# colour marks features that pass both the fold-change and padj thresholds.

sig_color, nonsig_color = '#1f77b4', '#ff7f0e'
scale_marker = 2

# squeeze=False always yields a 2-D axes array, also for a single contrast.
fig, ax = plt.subplots(len(contrasts), 2, figsize=(10, 5*len(contrasts)), squeeze=False)

for i, k in enumerate(contrasts.keys()):

    # (axis, results table, alpha, title) for the transcript and gene panels.
    panels = (
        (ax[i, 0], dds.uns['stat_results'][k], 0.1, '%s Transcript' % k),
        (ax[i, 1], dds_gene.uns['stat_results'][k], 0.05, '%s Gene' % k),
    )

    for panel, results, alpha, title in panels:
        # Colours are derived from the same table that is plotted, so x, y,
        # s and c always have matching lengths. NaN padj compares False and
        # is therefore drawn as "not significant".
        passes = (
            (results['log2FoldChange'].abs() > LOG2_FC_THRESH)
            & (results['-log10_padj'] > NLOG10_PADJ_THRESH)
        )
        panel.scatter(
            results['log2FoldChange'],
            results['-log10_padj'],
            alpha=alpha,
            s=scale_marker*np.log2(results['baseMean']),
            c=np.where(passes, sig_color, nonsig_color),
        )
        panel.set_xlabel('log2 FC')
        panel.set_title(title)

    ax[i, 0].set_ylabel('-log10 padj')

    # Size legend: four reference values between 1 and the maximum gene-level
    # log2 mean expression (rounded to a multiple of 5).
    element_range = np.rint(
        np.linspace(1, 5*round(np.log2(dds_gene.uns['stat_results'][k]['baseMean']).max()/5), 4)
    )

    legend_elements = [
        lines.Line2D(
            [0],
            [0],
            lw=0,
            marker="o",
            linestyle=None,
            # scatter `s` is an area (points^2); Line2D markersize is a length.
            markersize=(scale_marker*s)**0.5,
        ) for s in element_range
    ]

    legend = ax[i, 1].legend(
        legend_elements,
        element_range,
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1., 1.),
        title='log2 mean expression'
    )
    # Re-add the size legend; the next legend() call would otherwise replace it.
    ax[i, 1].add_artist(legend)

    color_legend = ax[i, 1].legend(
        [
            lines.Line2D([0], [0], lw=0, marker='o', linestyle=None, markerfacecolor=sig_color),
            lines.Line2D([0], [0], lw=0, marker='o', linestyle=None, markerfacecolor=nonsig_color),
        ],
        [
            f'|log2FC| > {LOG2_FC_THRESH} and -log10_padj > {NLOG10_PADJ_THRESH:.2f}',
            f'|log2FC| <= {LOG2_FC_THRESH} or -log10_padj <= {NLOG10_PADJ_THRESH:.2f}',
        ],
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1., 0.5),
    )


### 3.4 MA plots for all contrasts.

In [None]:
# Create MA plots of DE-transcripts (left column) and DE-genes (right column),
# one row per contrast. Marker area is -log10 padj; colour marks features that
# pass both the fold-change and padj thresholds.

sig_color, nonsig_color = '#1f77b4', '#ff7f0e'

# squeeze=False always yields a 2-D axes array, also for a single contrast.
fig, ax = plt.subplots(len(contrasts), 2, figsize=(10, 5*len(contrasts)), squeeze=False)

for i, k in enumerate(contrasts.keys()):

    # (axis, results table, title) for the transcript and gene panels.
    panels = (
        (ax[i, 0], dds.uns['stat_results'][k], '%s Transcript' % k),
        (ax[i, 1], dds_gene.uns['stat_results'][k], '%s Gene' % k),
    )

    for panel, results, title in panels:
        # Colours come from the table being plotted so all scatter arguments
        # have the same length. NaN padj compares False -> "not significant".
        passes = (
            (results['log2FoldChange'].abs() > LOG2_FC_THRESH)
            & (results['-log10_padj'] > NLOG10_PADJ_THRESH)
        )
        panel.scatter(
            np.log2(results['baseMean']),
            results['log2FoldChange'],
            alpha=0.1,
            s=results['-log10_padj'],
            c=np.where(passes, sig_color, nonsig_color),
        )
        panel.set_xlabel('log2 mean expression')
        panel.set_title(title)

    ax[i, 0].set_ylabel('log2 FC expression')

    # Size legend: four reference values between 1 and the maximum gene-level
    # -log10 padj (rounded to a multiple of 5). Series.max() skips NaN padj
    # values, unlike the builtin max(), whose result depends on where NaN sits.
    element_range = np.rint(
        np.linspace(1, 5*round(dds_gene.uns['stat_results'][k]['-log10_padj'].max()/5), 4)
    )

    legend_elements = [
        lines.Line2D(
            [0],
            [0],
            lw=0,
            marker="o",
            linestyle=None,
            # scatter `s` is an area (points^2); Line2D markersize is a length.
            markersize=s**0.5,
        ) for s in element_range
    ]

    legend = ax[i, 1].legend(
        legend_elements,
        element_range,
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1., 1.),
        title='-log10_padj'
    )
    # Re-add the size legend; the next legend() call would otherwise replace it.
    ax[i, 1].add_artist(legend)

    color_legend = ax[i, 1].legend(
        [
            lines.Line2D([0], [0], lw=0, marker='o', linestyle=None, markerfacecolor=sig_color),
            lines.Line2D([0], [0], lw=0, marker='o', linestyle=None, markerfacecolor=nonsig_color),
        ],
        [
            f'|log2FC| > {LOG2_FC_THRESH} and -log10_padj > {NLOG10_PADJ_THRESH:.2f}',
            f'|log2FC| <= {LOG2_FC_THRESH} or -log10_padj <= {NLOG10_PADJ_THRESH:.2f}',
        ],
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1., 0.5),
    )

### 3.5 DE genes/transcripts for each contrast filtered by log2fc and padj thresholds.

In [None]:
# For every contrast, report the transcripts and genes that exceed both the
# absolute log2 fold-change and the -log10 padj thresholds, sorted by fold change.

for k in contrasts.keys():

    reports = (
        ('%s Transcript: %s', dds.uns['stat_results'][k]),
        ('%s Gene: %s', dds_gene.uns['stat_results'][k]),
    )

    for header, results in reports:
        above_fc = abs(results['log2FoldChange']) > LOG2_FC_THRESH
        above_padj = results['-log10_padj'] > NLOG10_PADJ_THRESH
        markers = results.loc[above_fc & above_padj]

        print(header % (k, len(markers)))
        print(markers.sort_values('log2FoldChange', axis=0))


### 3.6 Plot raw log2fc vs. squeezed log2fc for each contrast (qc issues with pydeseq2 log2fc calculations).

In [None]:
# Run LFC calculations on raw inputs as gut-check for issues with pydeseq2 LFC calculations.
#
# For each contrast (factor, level, reference level) the raw LFC is
# log2(mean of X over `level` samples / mean of X over `reference` samples),
# stored next to the squeezed LFCs in varm['LFC_reflevel_<reference>_raw'].

# Reference levels whose '_raw' frame has already been (re)created in this
# run. The copy must happen only once per reference level: copying again for a
# later contrast with the same reference would discard the raw columns already
# written for earlier contrasts.
initialised_ref_levels = set()

for k, v in contrasts.items():

    squeezed_key = 'LFC_reflevel_%s' % v[2]
    raw_key = 'LFC_reflevel_%s_raw' % v[2]
    contrast_column = '%s_%s_vs_%s' % tuple(v)

    for data in (dds, dds_gene):

        if v[2] not in initialised_ref_levels:
            data.varm[raw_key] = data.varm[squeezed_key].copy()

        in_level = (data.obs[v[0]] == v[1]).to_numpy()
        in_reference = (data.obs[v[0]] == v[2]).to_numpy()

        # Features with a zero mean in either group yield +/-inf or NaN here.
        data.varm[raw_key][contrast_column] = np.log2(
            np.mean(data.X[in_level], axis=0) /
                np.mean(data.X[in_reference], axis=0)
        )

    initialised_ref_levels.add(v[2])
                                                


In [None]:
# Scatter the squeezed LFC (x) against the raw LFC (y) for every contrast:
# transcripts in the left column, genes in the right column.

fig, ax = plt.subplots(len(contrasts), 2, figsize=(10, 5*len(contrasts)))

ax = ax.reshape((-1, 2,))

for i, (k, v) in enumerate(contrasts.items()):

    squeezed_key = 'LFC_reflevel_%s' % v[2]
    raw_key = 'LFC_reflevel_%s_raw' % v[2]
    contrast_column = '%s_%s_vs_%s' % tuple(v)

    for column, data in enumerate((dds, dds_gene)):
        ax[i, column].scatter(
            data.varm[squeezed_key][contrast_column],
            data.varm[raw_key][contrast_column],
            alpha=0.1,
            s=0.1,
        )
        ax[i, column].set_xlabel('log2 FC squeezed')

    # Only the left panel carries the y-axis label.
    ax[i, 0].set_ylabel('log2 FC raw')

    ax[i, 0].set_title('%s Transcript' % k)
    ax[i, 1].set_title('%s Gene' % k)

### 3.7 Plot log2fc gene vs. log2fc transcript.

In [None]:
# Transfer gene to transcript mappings to the transcript-level stat_results
# tables, then plot transcript-level LogFC against the LogFC of the parent gene.
# Marker area is 5x the transcript -log10 padj.

gene_transcript_mapping = dict(zip(dds.uns['gene_transcript_mapping']['tx'],dds.uns['gene_transcript_mapping']['gene_id']))

# Factor applied to -log10 padj to obtain the scatter marker area.
size_scale = 5

# squeeze=False always yields a 2-D axes array, also for a single contrast.
fig, ax = plt.subplots(len(contrasts), 1, figsize=(5, 5*len(contrasts)), squeeze=False)

for i, k in enumerate(contrasts.keys()):

    # Annotate each transcript with its gene id (KeyError if a transcript is
    # missing from the mapping).
    dds.uns['stat_results'][k]['gene_id'] = [
        gene_transcript_mapping[tx] for tx in dds.uns['stat_results'][k].index
    ]

    # Columns shared by both tables get '_x' (transcript) / '_y' (gene) suffixes.
    df = dds.uns['stat_results'][k].merge(dds_gene.uns['stat_results'][k], left_on='gene_id', right_index=True)
    ax[i, 0].scatter(
        df['log2FoldChange_x'],
        df['log2FoldChange_y'],
        alpha=0.05,
        s=size_scale*df['-log10_padj_x']
    )

    ax[i, 0].set_xlabel('log2 FC transcript')
    ax[i, 0].set_ylabel('log2 FC gene')

    ax[i, 0].set_title('%s Transcript v Gene logFC' % k)

    # Size legend built from the plotted transcript padj values, matching the
    # legend title. Series.max() skips NaN padj values.
    element_range = np.rint(
        np.linspace(1, 5*round(df['-log10_padj_x'].max()/5), 4)
    )

    legend_elements = [
        lines.Line2D(
            [0],
            [0],
            lw=0,
            marker="o",
            linestyle=None,
            # Same area scaling as the scatter; markersize is a length, so take the root.
            markersize=(size_scale*s)**0.5,
        ) for s in element_range
    ]

    legend = ax[i, 0].legend(
        legend_elements,
        element_range,
        frameon=False,
        loc='upper left',
        bbox_to_anchor=(1., 1.),
        title='-log10_padj transcript'
    )


### 3.8 Plot log2fc of all pairs of contrasts.

In [None]:
# Compare LogFC between every pair of different contrasts: transcript-level
# values in the left column, gene-level values in the right column.

contrast_combinations = list(combinations(contrasts.keys(), 2))

# Nothing to compare when fewer than two contrasts are defined.
if contrast_combinations:

    fig, ax = plt.subplots(len(contrast_combinations), 2, figsize=(10, 5*len(contrast_combinations)))

    ax = ax.reshape((-1, 2,))

    for i, (k_1, k_2) in enumerate(contrast_combinations):

        for column, (level, data) in enumerate((('transcript', dds), ('gene', dds_gene))):

            stat_results = data.uns['stat_results']

            ax[i, column].scatter(
                stat_results[k_1]['log2FoldChange'],
                stat_results[k_2]['log2FoldChange'],
                s=1,
                alpha=0.1,
            )

            ax[i, column].set_xlabel('log2 FC %s %s' % (level, k_1))
            ax[i, column].set_ylabel('log2 FC %s %s' % (level, k_2))

### 3.9 Create scaled normed_counts to attempt to shift normed_counts distribution to be similar in magnitude to tpm 
* For various reasons, this doesn't work well in practice

In [None]:
# Fit scaling factors for normed counts to scale to the tpm distribution. This is to make visualizing the
# deseq2 normed counts easier across experiments.
#
# Each factor is a single scalar per object:
#     sum(normed_counts * raw_tpm) / sum(normed_counts ** 2)
# i.e. the least-squares slope (no intercept) of raw_tpm on normed_counts.

scale_factor_t = (
    np.sum(dds.layers['normed_counts'] * dds.layers['raw_tpm'])
    / np.sum(dds.layers['normed_counts']**2)
)

scale_factor_g = (
    np.sum(dds_gene.layers['normed_counts'] * dds_gene.layers['raw_tpm'])
    / np.sum(dds_gene.layers['normed_counts']**2)
)

dds.layers['normed_counts_transform'] = dds.layers['normed_counts']*scale_factor_t
dds_gene.layers['normed_counts_transform'] = dds_gene.layers['normed_counts']*scale_factor_g

# Keep the fitted factors with the objects so the transform can be traced.
dds.uns['normed_counts_transform_scalefactor'] = scale_factor_t
dds_gene.uns['normed_counts_transform_scalefactor'] = scale_factor_g

In [None]:
# Wipe previously merged group-count columns (group means and per-sample
# counts) from every contrast table, so re-running the merge cells does not
# produce duplicated / suffixed columns.

# Names of the obs columns (from either object) that define conditions or
# groups; merged count columns are named '<obs column>_<level>_...', so these
# names act as prefixes.
column_prefixes = sorted({
    str(c)
    for data in (dds, dds_gene)
    for c in data.obs.columns
    if str(c).startswith(('condition-', 'group-'))
})

for k in contrasts.keys():
    for data in (dds, dds_gene):
        results = data.uns['stat_results'][k]
        stale_columns = [
            c for c in results.columns
            # An empty prefix tuple never matches, so nothing is dropped then.
            if str(c).startswith(tuple(column_prefixes))
        ]
        data.uns['stat_results'][k] = results.drop(stale_columns, axis=1)

### 3.10 Merge normed_count group summaries into dataframe

In [None]:
# Create group mean counts and individual sample counts that will be merged into de_dataframes.
#
# For every 'condition-' / 'group-' obs column and each of its levels, two sets
# of columns are derived from the scaled normed counts:
#   varm['normed_counts_group_mean']: '<column>_<level>_meannormedcounts'
#   varm['normed_counts_group']:      '<column>_<level>_<sample>' (one per sample)

# Same prefixes as the cell that wipes these columns from the result tables.
# Columns and levels are taken from the gene-level object and applied to both.
group_columns = [
    c for c in dds_gene.obs.columns
    if str(c).startswith(('condition-', 'group-'))
]

for data in (dds_gene, dds):

    scaled_counts = data.layers['normed_counts_transform']
    group_means = {}
    sample_counts = {}

    for c in group_columns:
        for l in dds_gene.obs[c].unique():
            # Row positions of this object's samples that belong to level `l`.
            l_i = np.flatnonzero((data.obs[c] == l).to_numpy())

            group_means['%s_%s_meannormedcounts' % (c, l)] = scaled_counts[l_i, :].mean(axis=0)

            # Sample names come from the object whose rows were selected.
            for row, s in zip(l_i, data.obs.index[l_i]):
                sample_counts['%s_%s_%s' % (c, l, s)] = scaled_counts[row, :]

    # Build each table in one go instead of growing a stored frame column by column.
    data.varm['normed_counts_group_mean'] = pd.DataFrame(group_means, index=data.var.index)
    data.varm['normed_counts_group'] = pd.DataFrame(sample_counts, index=data.var.index)


In [None]:
# Attach the group-mean and per-sample count columns of the two levels
# compared in each contrast to that contrast's stat_results tables.

for k, v in contrasts.items():

    contrast_levels = (v[1], v[2],)

    # Column selection is based on the transcript-level tables and reused for
    # the gene-level tables. Columns are named '<factor>_<level>_...'.
    keep_cols_mean = [
        c for c in dds.varm['normed_counts_group_mean']
        if c.startswith(v[0]) and c.split('_')[1] in contrast_levels
    ]

    keep_cols_ind = [
        c for c in dds.varm['normed_counts_group']
        if c.startswith(v[0]) and c.split('_')[1] in contrast_levels
    ]

    merge_plan = (
        ('normed_counts_group_mean', keep_cols_mean),
        ('normed_counts_group', keep_cols_ind),
    )

    for data in (dds, dds_gene):
        merged = data.uns['stat_results'][k]
        for varm_key, keep_cols in merge_plan:
            merged = merged.merge(
                data.varm[varm_key][keep_cols],
                left_index=True,
                right_index=True,
            )
        data.uns['stat_results'][k] = merged
    

In [None]:
# Define a contrast dataframe that outlines which condition/sample types were used in each contrast.
# The table is written to '<EXPERIMENT_ID>_metadata_contrast.csv' in RESULTS_PATH.

# Columns: contrast name, the obs column the contrast is defined on, plus one
# column per distinct obs column used by any contrast.
contrast_df = pd.DataFrame(
    columns=(
        ['contrast','contrast_condition'] 
        + list(set([c[0] for c in dds.uns['contrasts'].values()]))
    ),
)

# Two rows per contrast: one for the tested level (c[1]) and one for the
# reference level (c[2]); the level is stored in the column named after c[0].
for n, c in dds.uns['contrasts'].items():
    contrast_df.loc[len(contrast_df),['contrast','contrast_condition',c[0]]] = n,c[0],c[1]
    contrast_df.loc[len(contrast_df),['contrast','contrast_condition',c[0]]] = n,c[0],c[2]

# Join the sample metadata onto each level row via the condition column.
# NOTE(review): only columns prefixed 'condition-' are joined; contrasts defined
# on other columns (e.g. 'group-') get no sample rows here - confirm intended.
# NOTE(review): with more than one condition column the merges are applied in
# sequence on the already-merged frame - verify the result for that case.
for c in contrast_df.columns[contrast_df.columns.str.startswith('condition-')]:
    contrast_df = contrast_df.merge(dds.obs.reset_index(), on=c)

contrast_df.to_csv(RESULTS_PATH / f'{EXPERIMENT_ID}_metadata_contrast.csv', index=False)

In [None]:
# Write one CSV per contrast and quantification level to the results folder.

for k in contrasts.keys():

    targets = (
        (dds, RESULTS_PATH / f'{EXPERIMENT_ID}_transcript_{k}.csv'),
        (dds_gene, RESULTS_PATH / f'{EXPERIMENT_ID}_gene_{k}.csv'),
    )

    for data, csv_path in targets:
        data.uns['stat_results'][k].to_csv(csv_path)

In [None]:
# Persist both dds objects for the downstream gsea analysis.

# trend_coeffs / replaced may be held as np.array or pd.Series by pydeseq2;
# they are converted to np.array because that is what saving an h5-formatted
# AnnData object requires.
for data in (dds, dds_gene):
    data.uns['trend_coeffs'] = np.array(data.uns['trend_coeffs'])

for data in (dds, dds_gene):
    data.varm['replaced'] = np.array(data.varm['replaced'])

# DeseqDataSet has no native h5 writer, so the objects are saved (and later
# restored) as plain AnnData objects. This overwrites the input files.
for data, out_path in ((dds, DDS_TRANSCRIPT_FH), (dds_gene, DDS_GENE_FH)):
    data.write(out_path)