## Visualize and Interpret WOT Results

In [20]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import anndata
import scipy.stats 
import wot
import math
from matplotlib.lines import Line2D
import ipywidgets as widgets

In [21]:
# Input files: the annotated expression data and the two gene lookup tables.
ADATA_PATH = 'data/wot_input.h5ad'
COMMON_NAMES = 'data/common-names.csv'
TF_PATH = 'data/tfs.txt'

# Path prefix that the WOT result filenames are appended to.
OUTPUT_PATH = 'output/wot/'

### Setup

In [22]:
# Load the transcription factors and common gene names
# The TF file is read without a header row; its first column becomes the
# index, which is then kept as a plain Python list.
tfs = list(pd.read_csv(TF_PATH, header=None, index_col=0).index)
# Common-name table, indexed by its first column.
common_names = pd.read_csv(COMMON_NAMES, index_col=0)

In [23]:
# Create a color dictionary for visualization
# `lineages` and `lineage_colors` are parallel lists: the color at position k
# belongs to the lineage at position k.
lineages = [
    "Quiescent Center", "Stem Cell Niche", "Columella", "Lateral Root Cap",
    "Atrichoblast", "Trichoblast", "Cortex", "Cortex -", "Cortex +",
    "Endodermis", "Pericycle", "Phloem", "Xylem", "Procambium",
]
lineage_colors = [
    "#9400d3", 'tab:pink', "#5ab953", "#bfef45",
    "#008080", "#21B6A8", "#82b6ff", "#82b6ff", "cyan",
    "#0000FF", "#ff9900", "#e6194b", "#9a6324", "#ffe119",
]

# Pair the two lists up into a lineage -> color lookup.
lineage_color_dict = dict(zip(lineages, lineage_colors))
    
# Create the legend for the color dictionary
def lineage_leg_elm(n):
    """Return the legend handle for the n-th lineage.

    The handle is a Line2D drawn with a round marker of size 10 whose face
    color is `lineage_colors[n]` and whose label is `lineages[n]`.
    """
    return Line2D(
        [0], [0],
        marker='o',
        color='w',
        label=lineages[n],
        markerfacecolor=lineage_colors[n],
        markersize=10,
    )

# One handle per lineage, in the same order as `lineages`.
lineage_legend_elems = list(map(lineage_leg_elm, range(len(lineages))))

In [24]:
# Create a tissue color dictionary
# Maps each tissue label to the color used when plotting its trend line.
tissue_color_dict = dict(
    endodermis='#0000FF',
    cortex='#82b6ff',
    stele='tab:orange',
    columella='#5ab953',
    epidermis='#21B6A8',
    QC='#9400d3',
)

In [25]:
# Load the anndata and set the reference time
# T is substituted into the names of the result files loaded further down.
T = 2
# Annotated dataset read from the configured input path.
adata = anndata.read_h5ad(ADATA_PATH)

In [26]:
# Load trajectories and fates
# Each filename template is completed with the reference time T and read from
# OUTPUT_PATH; the files are read in the order the names are listed.
type_trajectory_ds, typezone_trajectory_ds, tissue_trajectory_ds, fate_ds = (
    anndata.read_h5ad(OUTPUT_PATH + filename_template.format(T))
    for filename_template in (
        'type_trajectories_T={}.h5ad',
        'typezone_trajectories_T={}.h5ad',
        'tissue_trajectories_T={}.h5ad',
        "fate-adata-T={}.h5ad",
    )
)

In [27]:
# Load celltype trends
# One trends file exists per celltype trajectory; names and datasets are kept
# in two parallel lists so a dataset can be looked up by its name's position.
trajectory_names = list(type_trajectory_ds.var.index)
trajectory_trend_datasets = [
    anndata.read_h5ad(OUTPUT_PATH + "trends/T={}-{}_trends.h5ad".format(T, trajectory))
    for trajectory in trajectory_names
]

# Gene lookup tables taken from the second trends dataset: one keeps the
# table's own index, the other is re-indexed by the 'Name' column.
AT1G_indexed_names = trajectory_trend_datasets[1].var.copy()
common_indexed_names = trajectory_trend_datasets[1].var.copy().set_index('Name')

In [28]:
# Load tissue trends
# One trends file exists per tissue trajectory; names and datasets are kept
# in two parallel lists so a dataset can be looked up by its name's position.
tissue_trajectory_names = list(tissue_trajectory_ds.var.index)
tissue_trajectory_trend_datasets = [
    anndata.read_h5ad(OUTPUT_PATH + "trends/T={}-{}_trends.h5ad".format(T, tissue))
    for tissue in tissue_trajectory_names
]

# Gene lookup tables taken from the second tissue trends dataset: one keeps
# the table's own index, the other is re-indexed by the 'Name' column.
AT1G_indexed_names = tissue_trajectory_trend_datasets[1].var.copy()
common_indexed_names = tissue_trajectory_trend_datasets[1].var.copy().set_index('Name')

In [29]:
# Load differential expression results
# Fate names in sorted order; one CSV of results is read per fate.
diff_exp_types = sorted(type_trajectory_ds.var.index)
diff_exp_results = {
    fate: pd.read_csv(OUTPUT_PATH + "diff-exp/diff-exp-T={}-{}.csv".format(T, fate), index_col=0)
    for fate in diff_exp_types
}

In [30]:
# Load the list of transcription factors
# (reads the same file, with the same options, as the `tfs` load in the Setup section)
tfs = list(pd.read_csv(TF_PATH, header=None, index_col=0).index)

### Visualize Celltype Trajectories

In [31]:
# Visualize trajectories by cell type
times = pd.unique(adata.obs['exp_time'])
times.sort()

trajectory_dropdown = widgets.Dropdown(
    options=type_trajectory_ds.var.index,
    description='Trajectory:'
)

def update_trajectory_vis(name):
    """Draw one UMAP panel per time point, colored by trajectory `name`.

    Every panel first shows all cells of that time point in light grey, then
    overlays the trajectory values from `type_trajectory_ds`, summed over
    cells that share the same (umap_x, umap_y) position.

    Parameters
    ----------
    name : str
        Column (variable) of `type_trajectory_ds` to display.
    """
    ncols = 3
    nrows = math.ceil(len(times)/ncols)
    size = 3
    fig = plt.figure(figsize=(ncols*size, nrows*size))
    fig.suptitle(name, fontsize=14, fontweight='bold')
    for position, time in enumerate(times, start=1):
        ax = fig.add_subplot(nrows, ncols, position)
        ax.axis('off')
        ax.set_title("{} hr".format(time))

        # Grey backdrop: every cell observed at this time point.
        background = adata.obs[adata.obs['exp_time'] == time]
        ax.scatter(background['umap_x'], background['umap_y'], c='#f0f0f0',
                   s=4, marker=',', edgecolors='none', alpha=0.8)

        trajectory_ds_filtered = type_trajectory_ds[type_trajectory_ds.obs['exp_time'] == time]
        # Keep only the columns that are plotted. Summing the whole obs table
        # would also aggregate unrelated (possibly non-numeric) annotations.
        binned_df = trajectory_ds_filtered.obs[['umap_x', 'umap_y']].copy()
        # Flatten so the column is 1-D whatever shape the single-variable X has.
        binned_df['values'] = np.ravel(trajectory_ds_filtered[:, name].X)
        binned_df = binned_df.groupby(['umap_x', 'umap_y'], as_index=False).sum()
        ax.scatter(binned_df['umap_x'], binned_df['umap_y'], c=binned_df['values'],
                   s=6, edgecolors='none')

    # Lay the panels out once, after all of them have been drawn.
    fig.tight_layout()

widgets.interact(update_trajectory_vis, name=trajectory_dropdown)

interactive(children=(Dropdown(description='Trajectory:', options=('Atrichoblast', 'Columella', 'Cortex', 'Cor…

<function __main__.update_trajectory_vis(name)>

In [32]:
# Visualize trajectories by cell type and developmental zone
zonetypes = list(typezone_trajectory_ds.var.index)
zonetypes.sort()

trajectory_dropdown = widgets.Dropdown(
    options=zonetypes,
    description='Trajectory:'
)

def update_trajectory_vis(name):
    """Draw one UMAP panel per time point, colored by trajectory `name`.

    Every panel first shows all cells of that time point in light grey, then
    overlays the trajectory values from `typezone_trajectory_ds`, summed over
    cells that share the same (umap_x, umap_y) position.

    Parameters
    ----------
    name : str
        Column (variable) of `typezone_trajectory_ds` to display.
    """
    ncols = 3
    nrows = math.ceil(len(times)/ncols)
    size = 3
    fig = plt.figure(figsize=(ncols*size, nrows*size))
    fig.suptitle(name, fontsize=14, fontweight='bold')
    for position, time in enumerate(times, start=1):
        ax = fig.add_subplot(nrows, ncols, position)
        ax.axis('off')
        ax.set_title("{} hr".format(time))

        # Grey backdrop: every cell observed at this time point.
        background = adata.obs[adata.obs['exp_time'] == time]
        ax.scatter(background['umap_x'], background['umap_y'], c='#f0f0f0',
                   s=4, marker=',', edgecolors='none', alpha=0.8)

        trajectory_ds_filtered = typezone_trajectory_ds[typezone_trajectory_ds.obs['exp_time'] == time]
        # Keep only the columns that are plotted. Summing the whole obs table
        # would also aggregate unrelated (possibly non-numeric) annotations.
        binned_df = trajectory_ds_filtered.obs[['umap_x', 'umap_y']].copy()
        # Flatten so the column is 1-D whatever shape the single-variable X has.
        binned_df['values'] = np.ravel(trajectory_ds_filtered[:, name].X)
        binned_df = binned_df.groupby(['umap_x', 'umap_y'], as_index=False).sum()
        ax.scatter(binned_df['umap_x'], binned_df['umap_y'], c=binned_df['values'],
                   s=6, marker=',', edgecolors='none')

    # Lay the panels out once, after all of them have been drawn.
    fig.tight_layout()

widgets.interact(update_trajectory_vis, name=trajectory_dropdown)

interactive(children=(Dropdown(description='Trajectory:', options=('Distal Columella-Columella', 'Distal Later…

<function __main__.update_trajectory_vis(name)>

### Visualize Gene Expression Trajectories

In [33]:
# Widget to search for genes by common name
search_text = widgets.Text(
    value='SCAMP',
    placeholder='Text to search',
    description='Search:',
    disabled=False,
)

def table(search):
    """Return the gene names that contain `search`, ignoring case.

    Names are taken from the 'Name' column of the second tissue trends
    dataset and returned in that column's order.
    """
    needle = search.lower()
    gene_names = tissue_trajectory_trend_datasets[1].var.Name
    return [gene for gene in gene_names if needle in gene.lower()]

widgets.interact(table, search=search_text)

interactive(children=(Text(value='SCAMP', description='Search:', placeholder='Text to search'), Output()), _do…

<function __main__.table(search)>

In [34]:
# Create a widget to interactively visualize expression trends by tissue
times = pd.unique(adata.obs['exp_time'])
times.sort()

tissue_name_type = widgets.RadioButtons(
    options=['Common Names', 'AT1G Names'],
    value='Common Names',
    description='Gene Names:',
    disabled=False
)

tissue_AT1G_gene_input = widgets.Text(
    placeholder='',
    description='Genes:',
    value='AT1G01020',
    continuous_update=False
)

tissue_common_gene_input = widgets.Text(
    placeholder='',
    description='Genes:',
    value='ARV1',
    continuous_update=False
)

# Helper to toggle cell selection controls on and off
def view_controls(mode, controls):
    """Show every widget in `controls` when `mode` is truthy, hide it otherwise."""
    for control in controls:
        control.layout.display = 'flex' if mode else 'none'

def update_gene_vis(name_type, AT1G_gene_names, common_gene_names):
    """Plot expression trends per tissue for the requested genes.

    Parameters
    ----------
    name_type : str
        'Common Names' to read genes from `common_gene_names`; any other
        value reads them from `AT1G_gene_names`.
    AT1G_gene_names : str
        Comma-separated gene identifiers (spaces are ignored).
    common_gene_names : str
        Comma-separated common gene names (spaces are ignored).

    Genes that are not present in the lookup tables are silently dropped.
    One subplot is drawn per remaining gene, with one line per tissue
    trajectory except 'QC'.
    """
    # Based on the gene name toggle get common and AT1G named gene lists
    if name_type == 'Common Names':
        requested = common_gene_names.replace(' ', '').split(',')
        gene_labels = [gene for gene in requested if gene in common_indexed_names.index]
        AT1G_names = [common_indexed_names.loc[gene, 'gene'] for gene in gene_labels]
        view_controls(True, [tissue_common_gene_input])
        view_controls(False, [tissue_AT1G_gene_input])
    else:
        requested = AT1G_gene_names.replace(' ', '').split(',')
        AT1G_names = [gene for gene in requested if gene in AT1G_indexed_names.index]
        gene_labels = [AT1G_indexed_names.loc[gene, 'Name'] for gene in AT1G_names]
        view_controls(True, [tissue_AT1G_gene_input])
        view_controls(False, [tissue_common_gene_input])

    # Nothing to draw: avoid creating a figure with zero rows.
    if not AT1G_names:
        print("No matching genes found.")
        return

    n_cols = 3
    # Size the grid by the number of resolved genes, not by the length of the
    # raw input string.
    n_rows = math.ceil(len(AT1G_names)/n_cols)

    figure = plt.figure(figsize=(5*n_cols, 5*n_rows))
    for j, (gene_label, AT1G_name) in enumerate(zip(gene_labels, AT1G_names)):
        plt.subplot(n_rows, n_cols, j+1)
        plt.title("{} ({})".format(gene_label, AT1G_name))
        for selected_trajectory in tissue_trajectory_ds.var.index:
            # The 'QC' trajectory is left out of this plot.
            if selected_trajectory == 'QC':
                continue
            trajectory_index = tissue_trajectory_names.index(selected_trajectory)
            mean = tissue_trajectory_trend_datasets[trajectory_index][:, AT1G_name]
            # The observation index holds the time points.
            timepoints = mean.obs.index.values.astype(float)
            # NOTE(review): the reason for re-typing the index is not visible
            # here; kept unchanged from the original.
            mean.obs.index = mean.obs.index.astype('category')

            plt.plot(timepoints, mean.X, c=tissue_color_dict[selected_trajectory], label=selected_trajectory)
        # Legend only on the first panel; it applies to every panel.
        if j == 0:
            plt.legend()
        plt.xlabel("Time Point")
        plt.ylabel("Expression")

widgets.interact(update_gene_vis, name_type = tissue_name_type, common_gene_names=tissue_common_gene_input, AT1G_gene_names=tissue_AT1G_gene_input)

interactive(children=(RadioButtons(description='Gene Names:', options=('Common Names', 'AT1G Names'), value='C…

<function __main__.update_gene_vis(name_type, AT1G_gene_names, common_gene_names)>

In [35]:
# Compare gene expression trajectories for cortex + and - cells
name_type = widgets.RadioButtons(
    options=['Common Names', 'AT1G Names'],
    value='Common Names',
    description='Gene Names:',
    disabled=False
)

AT1G_gene_input = widgets.Text(
    placeholder='',
    description='Genes:',
    value='AT1G01020',
    continuous_update=False
)

common_gene_input = widgets.Text(
    placeholder='',
    description='Genes:',
    value='ARV1',
    continuous_update=False
)

# Helper to toggle cell selection controls on and off
def view_controls(mode, controls):
    """Show every widget in `controls` when `mode` is truthy, hide it otherwise."""
    for control in controls:
        control.layout.display = 'flex' if mode else 'none'

def update_gene_vis(name_type, AT1G_gene_names, common_gene_names):
    """Plot 'Cortex +' and 'Cortex -' expression trends for the requested genes.

    Parameters
    ----------
    name_type : str
        'Common Names' to read genes from `common_gene_names`; any other
        value reads them from `AT1G_gene_names`.
    AT1G_gene_names : str
        Comma-separated gene identifiers (spaces are ignored).
    common_gene_names : str
        Comma-separated common gene names (spaces are ignored).

    Genes that are not present in the lookup tables are silently dropped.
    One subplot is drawn per remaining gene, with one line per trajectory.
    """
    # Based on the gene name toggle get common and AT1G named gene lists
    if name_type == 'Common Names':
        requested = common_gene_names.replace(' ', '').split(',')
        gene_labels = [gene for gene in requested if gene in common_indexed_names.index]
        AT1G_names = [common_indexed_names.loc[gene, 'gene'] for gene in gene_labels]
        view_controls(True, [common_gene_input])
        view_controls(False, [AT1G_gene_input])
    else:
        requested = AT1G_gene_names.replace(' ', '').split(',')
        AT1G_names = [gene for gene in requested if gene in AT1G_indexed_names.index]
        gene_labels = [AT1G_indexed_names.loc[gene, 'Name'] for gene in AT1G_names]
        view_controls(True, [AT1G_gene_input])
        view_controls(False, [common_gene_input])

    # Nothing to draw: avoid creating a figure with zero rows.
    if not AT1G_names:
        print("No matching genes found.")
        return

    n_cols = 3
    n_rows = math.ceil(len(AT1G_names)/n_cols)

    figure = plt.figure(figsize=(5*n_cols, 5*n_rows))
    for j, (gene_label, AT1G_name) in enumerate(zip(gene_labels, AT1G_names)):
        plt.subplot(n_rows, n_cols, j+1)
        plt.title("{} ({})".format(gene_label, AT1G_name))
        for selected_trajectory in ['Cortex +', 'Cortex -']:
            trajectory_index = trajectory_names.index(selected_trajectory)
            mean = trajectory_trend_datasets[trajectory_index][:, AT1G_name]
            # The observation index holds the time points.
            timepoints = mean.obs.index.values.astype(float)
            # NOTE(review): the reason for re-typing the index is not visible
            # here; kept unchanged from the original.
            mean.obs.index = mean.obs.index.astype('category')

            plt.plot(timepoints, mean.X, c=lineage_color_dict[selected_trajectory], label=selected_trajectory)
        # Legend only on the first panel; it applies to every panel.
        if j == 0:
            plt.legend()
        plt.xlabel("Time Point")
        plt.ylabel("Expression")

widgets.interact(update_gene_vis, name_type=name_type, AT1G_gene_names=AT1G_gene_input, common_gene_names=common_gene_input)

interactive(children=(RadioButtons(description='Gene Names:', options=('Common Names', 'AT1G Names'), value='C…

<function __main__.update_gene_vis(name_type, AT1G_gene_names, common_gene_names)>

### Visualize Pairs of Fates on Barycentric Coordinates

In [36]:
# Visualize pairs of fates on barycentric coordinates, coloring cells by annotation
# NOTE(review): the first dropdown was labelled 'Endodermis'; that string may
# have been meant as its default value. Only the label is corrected here.
fate_dropdown1 = widgets.Dropdown(
    options=type_trajectory_ds.var.index,
    description='Fate 1:'
)
fate_dropdown2 = widgets.Dropdown(
    options=type_trajectory_ds.var.index,
    description='Fate 2:',
    value='Cortex'
)
time_dropdown = widgets.Dropdown(
    options=fate_ds.obs['exp_time'].unique(),
    description='Time',
    value=T
)


def update_fate_vis(name1, name2, hour):
    """Scatter the cells of one time point inside a fate triangle.

    Each cell is placed at the weighted average of the three triangle
    corners, with weights (fate `name1`, fate `name2`, remainder), and is
    colored by its 'celltype' annotation.

    Parameters
    ----------
    name1, name2 : str
        Fate columns of `fate_ds`. 'Cortex' is replaced by the sum of the
        'Cortex -' and 'Cortex +' columns.
    hour :
        Value of `fate_ds.obs['exp_time']` selecting the cells to draw.
    """
    figure = plt.figure(figsize=(6, 6))

    # Cells observed at the requested time; computed once and reused below.
    at_hour = fate_ds.obs['exp_time'] == hour

    def fate_values(name):
        # 'Cortex' stands for both cortex fates taken together.
        if name == 'Cortex':
            return (fate_ds[:, 'Cortex -'][at_hour].X.flatten()
                    + fate_ds[:, 'Cortex +'][at_hour].X.flatten())
        return fate_ds[:, name][at_hour].X.flatten()

    fate1 = fate_values(name1)
    fate2 = fate_values(name2)

    # Corners of the triangle, at angles 0, 2*pi/3 and 4*pi/3 on the unit circle.
    P = np.array([[1, 0],
                  [np.cos(2*np.pi/3), np.sin(2*np.pi/3)],
                  [np.cos(4*np.pi/3), np.sin(4*np.pi/3)]])

    # One row of weights per cell; a single matrix product places all cells.
    weights = np.column_stack([fate1, fate2, 1 - (fate1 + fate2)])
    coords = weights @ P
    x = coords[:, 0]
    y = coords[:, 1]

    t1 = plt.Polygon(P, color=(0,0,0,0.1))
    plt.gca().add_patch(t1)

    # Plain list of color strings, one per plotted cell.
    cell_colors = [lineage_color_dict[celltype] for celltype in fate_ds.obs['celltype'][at_hour]]
    plt.scatter(x, y, c=cell_colors)
    plt.text(P[0,0]+.1, P[0,1], name1)
    plt.text(P[1,0]-.1, P[1,1]+.1, name2)
    plt.text(P[2,0]-.1, P[2,1]-.2, 'Other')
    plt.axis('equal')
    plt.axis('off')

    plt.title('{} vs. {}'.format(name1, name2))

widgets.interact(update_fate_vis, name1=fate_dropdown1, name2=fate_dropdown2, hour=time_dropdown)

interactive(children=(Dropdown(description='Endodermis', options=('Atrichoblast', 'Columella', 'Cortex', 'Cort…

<function __main__.update_fate_vis(name1, name2, hour)>

### Differential Expression

In [42]:
# Function to filter diff exp results based on widget input
def diff_exp_table(adata, t, day, t2, timepoints='day only', method='pairwise', genes='all', n=20):
    """Return the top differential-expression rows for fate `t` versus `t2`.

    Parameters
    ----------
    adata :
        Object whose `obs['exp_time']` lists the available time points. It is
        only read when `timepoints` is one of the three named modes.
    t : str
        Key into `diff_exp_results`; also used in two output column names.
    day :
        Reference time point.
    t2 : str
        Required value of the 'Comparison Type' column.
    timepoints : str
        'all' keeps every time point, 'ancestors only' keeps those before
        `day`, 'descendants only' those after `day`; any other value keeps
        only `day` itself.
    method :
        Accepted but not used by this function.
    genes : str
        'TFs' keeps only rows whose 'TF' column is True; any other value
        keeps all genes.
    n : int
        Maximum number of rows returned.

    Returns
    -------
    pandas.DataFrame
        Rows with 't_fdr' below 0.01, sorted by 'Expression Ratio' in
        descending order and truncated to `n` rows. The frame is empty, with
        the same columns, when nothing matches.

    Raises
    ------
    KeyError
        If `t` is not a key of `diff_exp_results`.
    """
    # Set timepoints
    if timepoints == 'all':
        days = pd.unique(adata.obs['exp_time'])
    elif timepoints == 'ancestors only':
        days = [d for d in pd.unique(adata.obs['exp_time']) if d < day]
    elif timepoints == 'descendants only':
        days = [d for d in pd.unique(adata.obs['exp_time']) if d > day]
    else:
        days = [day]

    # Get results
    results = diff_exp_results[t]

    # Build one boolean mask. `isin` returns a boolean Series even when no
    # row is left, which a row-wise apply does not guarantee.
    keep = (results['Comparison Type'] == t2) & results['Time'].isin(days)
    if genes == 'TFs':
        keep &= results['TF'].eq(True)
    keep &= results['t_fdr'] < 0.01

    diff_exp_cols = ['Common Name', 'Time', 'Comparison Type', '{} Exp. Mean'.format(t),
                     'Comparison Exp. Mean', '{} Prop. Exp.'.format(t), 'Comparison Prop. Exp.',
                     'Expression Ratio']

    # Truncate
    top_rows = results[keep].sort_values('Expression Ratio', ascending=False).head(n)
    return top_rows[diff_exp_cols]

In [43]:
# Visualize pairs of fates for a single time
# NOTE(review): T_prime is not referenced anywhere else in this cell.
T_prime = int(T) if T != 0.5 else T

diff_fate_dropdown1 = widgets.Dropdown(
    options=diff_exp_types,
    description='Fate 1:',
    value='Cortex +'
)

diff_fate_dropdown2 = widgets.Dropdown(
    options=diff_exp_types,
    description='Fate 2:',
    value='Cortex -'
)

diff_t_dropdown = widgets.Dropdown(
    options=[0.,0.5,1,2,4,8],
    description='Time:',
    value=T
)

n_slider = widgets.IntSlider(
    min = 0,
    max = 100,
    step=5,
    value=15,
    description='Results:'
)

gene_type = widgets.RadioButtons(
    options=['all', 'TFs'],
    value='all',
    description='Genes:',
    disabled=False
)

# 'day only' restricts the table to the time chosen in the Time dropdown.
# Without it the dropdown is ignored whenever 'all' is selected. The default
# selection is unchanged.
days_type = widgets.RadioButtons(
    options=['all', 'day only', 'ancestors only', 'descendants only'],
    value='all',
    description='Timepoints:',
    disabled=False
)

output = widgets.Output()

# Helper to toggle cell selection controls on and off
def view_controls(mode, controls):
    """Show every widget in `controls` when `mode` is truthy, hide it otherwise."""
    for control in controls:
        control.layout.display = 'flex' if mode else 'none'

def update_diff_exp_vis(name1, name2, t, days_type, genes, n):
    """Re-render the differential expression table inside `output`.

    Parameters
    ----------
    name1, name2 : str
        Fate whose results are shown and the comparison fate.
    t :
        Reference time point.
    days_type : str
        Time point mode, passed to `diff_exp_table` as `timepoints`.
    genes : str
        Gene filter, passed to `diff_exp_table` as `genes`.
    n : int
        Maximum number of rows to show.
    """
    data = adata
    with output:
        # Replace the previously shown table with the new one.
        output.clear_output()
        display(diff_exp_table(adata=data, t=name1, t2=name2, day=t, timepoints=days_type, genes=genes, n=n))

# Bind the controls to the callback; the controls are laid out by hand below.
out = widgets.interactive(update_diff_exp_vis, name1=diff_fate_dropdown1, name2=diff_fate_dropdown2, t=diff_t_dropdown, 
                          days_type=days_type, genes=gene_type, n=n_slider)

display(widgets.VBox([ widgets.HBox([widgets.VBox([diff_fate_dropdown1, diff_fate_dropdown2, diff_t_dropdown, n_slider]), 
                                     gene_type, days_type]),
                       output]))

# Run the callback once so a table is shown before any control is touched.
out.update()

VBox(children=(HBox(children=(VBox(children=(Dropdown(description='Fate 1:', index=3, options=('Atrichoblast',…