In [1]:
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import dictys
import joblib

In [2]:
from utils_custom import *
from pseudotime_curves import *
from episode_plots import *
from episodic_dynamics import *
from config import *

In [3]:
# Project configuration object; `Config` is presumably supplied by `from config import *` — TODO confirm.
config = Config()

In [4]:
# Define file paths: CSV/figure outputs written via `output_folder` go to <config._BCELL_BASE>/output.
output_folder = os.path.join(config._BCELL_BASE, 'output')

In [5]:
# Load the dictys dynamic network from its HDF5 file.
# NOTE(review): absolute cluster path; consider deriving it from `config` so the notebook is portable.
DYNAMIC_NETWORK_H5 = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output/dynamic.h5'
dictys_dynamic_object = dictys.net.dynamic_network.from_file(DYNAMIC_NETWORK_H5)

#### Getting attributes of the dynamic object

In [None]:
# Distances stored on the 's' point set of the dynamic network; column 0 is exported as the
# pseudotime of each window (presumably the distance from the first node — TODO confirm).
psuedotime_values_of_windows = dictys_dynamic_object.point['s'].dist
first_distance_column = psuedotime_values_of_windows[:, 0]
psuedotime_values_of_windows_df = pd.DataFrame(first_distance_column, columns=['psuedotime'])

# Persist the per-window pseudotime values as a CSV file (no index column).
psuedotime_csv_path = os.path.join(output_folder, 'psuedotime_values_of_windows.csv')
psuedotime_values_of_windows_df.to_csv(psuedotime_csv_path, index=False)

In [None]:
# Read the per-cell labels from the `cell_type_major` column of day_labels.csv.
PATH_TO_CELL_LABELS = "/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/data/day_labels.csv"
day_labels_df = pd.read_csv(PATH_TO_CELL_LABELS)
day_labels = day_labels_df["cell_type_major"].to_list()

# Group those labels by window of the dynamic network.
day_labels_per_window = get_state_labels_in_window(
    dictys_dynamic_object,
    day_labels,
)

In [None]:
# Turn the per-window labels into a count table and preview the first rows.
cell_count_per_window_df = window_labels_to_count_df(day_labels_per_window)
display(cell_count_per_window_df.head())

# Save the count table as CSV, keeping its index.
day_count_csv_path = os.path.join(output_folder, 'day_count_per_window_df.csv')
cell_count_per_window_df.to_csv(day_count_csv_path, index=True)

# Visualize custom TFs

In [None]:
# Each trajectory branch is a (starting node, ending node) pair from trajectory inference.
# See main1.ipynb
branches = dict(
    PlasmaBlast=(0, 2),
    GerminalCenter=(0, 3),
)

# Draw the TF-expression discovery figures for every branch.
for branchname, (start_node, stop_node) in branches.items():
    print(branchname)
    figs = dictys_dynamic_object.draw_discover(
        start_node,
        stop_node,
        ntops=(12, 12, 12, 12),
        num=20,
        dist=0.001,
        mode='TF_expression',
    )
    plt.show()

In [None]:
# TFs of interest and their indices in the dynamic network.
custom_tfs = ['BACH2', 'PAX5', 'BATF', 'NFATC2', 'IKZF2', 'RBPJ']
custom_tf_lookup = get_tf_indices(dictys_dynamic_object, custom_tfs)
tf_indices_custom, tf_gene_indices_custom, missing_tfs = custom_tf_lookup

display(tf_indices_custom)
# Missing TFs are not present in the motif databases, hence not in the final GRN.
# QC filtering has been masked for important genes so that they do not drop out.
display(missing_tfs)

In [None]:
# Hand-picked pairs of gene symbols (presumably (regulator, target) — TODO confirm).
# NOTE(review): `custom_lf_pairs` is not referenced anywhere else in the visible notebook.
custom_lf_pairs = [('PBX3','PAX5'),('RFX3','CEP128'),('CREB3L2','FNDC3A'),('CREB3L2','TXNDC5'),('CREB3L2','TRAM1'),('PAX5','RUNX2')]

In [None]:
# PRDM1 -> target links to visualise; every link shares the same regulator, so the list is
# built from the target symbols (order preserved).
PRDM1_LINK_TARGETS = [
    'CYB561A3', 'SPIB', 'CNPY3', 'TGFB1', 'FUT8', 'BABAM1', 'PNISR', 'BCAP31', 'FOXJ3', 'RBM38',
    'NT5C', 'OCIAD1', 'NDUFB5', 'AIP', 'VAPA', 'LZIC', 'TMEM248', 'NAA38', 'EMC3', 'NSUN5',
    'NIBAN3', 'STX7', 'MYBL2', 'PVT1', 'ARHGAP42', 'PACC1', 'DPYSL2', 'KCTD13', 'MTG2', 'RRP9',
    'FOXK1', 'RCL1', 'PINX1', 'ADARB1', 'PDLIM5', 'NEK6', 'PPM1D', 'PRKAR1B', 'STK33', 'ANAPC4',
]
custom_tf_links = [('PRDM1', target_gene) for target_gene in PRDM1_LINK_TARGETS]

In [None]:
# Plot the regulation heatmap for the PRDM1 -> target links in `custom_tf_links`
# (start=1, stop=2; the third return value is the matrix inspected in the next cell).
prdm1_heatmap_kwargs = dict(
    network=dictys_dynamic_object,
    start=1,
    stop=2,
    regulations=custom_tf_links,
    num=50,
    dist=0.0005,
    cmap='RdBu_r',
)
fig, ax, dnet = fig_regulation_heatmap(**prdm1_heatmap_kwargs)

In [None]:
print("Original dnet values range:", dnet.min(), "to", dnet.max())

# Wrap dnet in a DataFrame whose rows are labelled "<TF>-<target>".
link_labels = ["-".join(link) for link in custom_tf_links]
df = pd.DataFrame(dnet, index=link_labels)
df_values = df.values
print("DataFrame values range:", df_values.min(), "to", df_values.max())

# Largest absolute value, used later for a colour scale symmetric around zero.
vmax_val = float(df.abs().values.max())

In [None]:
# Cluster the links on absolute strength while displaying the signed values,
# with colour limits symmetric around zero.
abs_link_strengths = df.abs()
fig, x, y = cluster_heatmap(
    abs_link_strengths,   # clustering input
    dshow=df,             # values actually shown in the heatmap
    dtop=0,               # no clustering on time points
    dright=0.3,           # cluster the links
    method='ward',
    metric='euclidean',
    cmap='RdYlGn',
    aspect=0.1,
    xtick=False,
    vmin=-vmax_val,
    vmax=vmax_val,
)

In [None]:
# Plot the expression-gradient heatmap for PAX5 and RUNX2 (start=0, stop=2).
gradient_kwargs = dict(
    network=dictys_dynamic_object,
    start=0,
    stop=2,
    genes_or_regulations=['PAX5', 'RUNX2'],
    num=100,
    dist=0.0005,
    cmap='YlGn',  # original note: green is positive, yellow is negative
)
fig, ax, cmap = fig_expression_gradient_heatmap(**gradient_kwargs)


#### Get curve characteristics like regulation, expression, etc.

In [None]:
from pseudotime_curves import *

In [None]:
# Smoothed expression curves for trajectory_range (0, 3), i.e. the branch plotted as 'GC branch' below.
gc_exp_smoother = SmoothedCurves(
    dictys_dynamic_object,
    trajectory_range=(0, 3),
    num_points=100,
    dist=0.0005,
    sparsity=0.01,
    mode="expression",
)
gc_exp_curves_dy, gc_exp_curves_dx = gc_exp_smoother.get_smoothed_curves()

In [None]:
# Smoothed regulation curves for the same trajectory_range (0, 3) and smoothing settings.
gc_reg_smoother = SmoothedCurves(
    dictys_dynamic_object,
    trajectory_range=(0, 3),
    num_points=100,
    dist=0.0005,
    sparsity=0.01,
    mode="regulation",
)
gc_reg_curves_dy, gc_reg_curves_dx = gc_reg_smoother.get_smoothed_curves()

In [None]:
# Plot expression trajectories of selected genes along the GC branch.
fig, ax = plt.subplots(figsize=(10, 6))

# Hide the top and right spines.
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Genes to highlight and the colour used for each one.
genes_of_interest = ['BACH2', 'PAX5', 'BATF', 'NFATC2', 'IKZF2', 'RBPJ']
colors = ['green', 'blue', 'red', 'orange', 'purple', 'pink']

for gene, color in zip(genes_of_interest, colors):
    # Genes missing from the smoothed-curve table are skipped silently.
    if gene not in gc_exp_curves_dy.index:
        continue
    gene_curve = gc_exp_curves_dy.loc[gene]
    line = ax.plot(gc_exp_curves_dx, gene_curve, linewidth=2, color=color)
    # Label the curve at its last point.
    ax.text(
        gc_exp_curves_dx.iloc[-1],
        gene_curve.iloc[-1],
        f' {gene}',
        color=color,
        verticalalignment='center',
    )

ax.set_xlabel('Pseudotime')
ax.set_ylabel('Log (CPM)')
ax.set_title('GC branch')

# Save the figure, then display it.
fig.savefig(os.path.join(output_folder, 'gc_branch_episodic_tfs_expression.pdf'), dpi=300)
plt.show()

# Animations

In [None]:
# (TF, target) gene-symbol pairs; used below as `regulations` for the heatmap and to build the animated subgraph.
pairs = [('PBX3','PAX5'), ('NRF1','PAX5'),('NRF1','RUNX2'),('PAX5','RUNX2')]

In [None]:
# Plot the regulation heatmap of the lineage pairs (start=0, stop=3); the Axes is kept
# as `ax_gc` because a later cell reads the drawn image back from it.
gc_heatmap_kwargs = dict(
    network=dictys_dynamic_object,
    start=0,
    stop=3,
    regulations=pairs,
    num=100,
    dist=0.0005,
    cmap='RdBu',
)
fig, ax_gc, cmap = fig_regulation_heatmap(**gc_heatmap_kwargs)

In [None]:
# Get the raw data from heatmap: the array stored on the first image artist of the Axes.
# NOTE(review): whether this is a scalar (rows x timepoints) or an RGB(A) array depends on how
# fig_regulation_heatmap drew the image — confirm.
heatmap_data_gc = ax_gc.images[0].get_array()

# Unique TFs and targets from `pairs`, de-duplicated in first-seen order.
# dict.fromkeys is used instead of set() because set iteration order for strings changes
# between kernel restarts, which made the row/column order of the arrays built below
# (and the printed lists) non-reproducible.
tfs_gc = list(dict.fromkeys(pair[0] for pair in pairs))
targets_gc = list(dict.fromkeys(pair[1] for pair in pairs))

# Function to convert RdBu colors to regulation strength
def rdbu_to_regulation(rgb_values):
    """Convert one heatmap row into a per-timepoint regulation strength.

    Parameters
    ----------
    rgb_values : array-like
        Either a (timepoints, channels) array of RGB(A) colours, or a 1-D row of
        scalar values (what an image artist holds when the heatmap was drawn from
        scalar data through a colormap).

    Returns
    -------
    array
        For colour input, red minus blue per timepoint. For scalar input, the values
        unchanged (previously this case raised IndexError).

    Notes
    -----
    NOTE(review): in Matplotlib's non-reversed 'RdBu' colormap the LOW end is red and the
    HIGH end is blue, so R - B is positive for low values unless the plotting helper
    flips the data or the map — confirm the sign convention before interpreting the result.
    """
    if not hasattr(rgb_values, 'ndim'):
        rgb_values = np.asarray(rgb_values)
    if rgb_values.ndim < 2:
        # Already scalar strengths: there is no colour channel axis to decode.
        return rgb_values
    return rgb_values[:, 0] - rgb_values[:, 2]  # R - B gives regulation strength

# Map each (tf, target) pair to its regulation-strength row.
# Assumes heatmap row i corresponds to pairs[i] — TODO confirm against fig_regulation_heatmap.
regulation_dict_gc = {
    (tf, target): rdbu_to_regulation(heatmap_data_gc[row_idx])
    for row_idx, (tf, target) in enumerate(pairs)
}

# Dense (n_tf x n_target x n_timepoints) array; TF/target combinations not in `pairs` stay zero.
n_timepoints_gc = heatmap_data_gc.shape[1]
regulation_array_gc = np.zeros((len(tfs_gc), len(targets_gc), n_timepoints_gc))
for tf_idx, tf in enumerate(tfs_gc):
    for target_idx, target in enumerate(targets_gc):
        strength = regulation_dict_gc.get((tf, target))
        if strength is not None:
            regulation_array_gc[tf_idx, target_idx, :] = strength

print("\nFinal array shape:", regulation_array_gc.shape)
print("TFs:", tfs_gc)
print("Targets:", targets_gc)

#### Animate the subgraph of the mentioned regulations

In [None]:
# Look up network indices for the regulators and targets named in `pairs`.
lineage_tf_names = [pair[0] for pair in pairs]
lineage_target_names = [pair[1] for pair in pairs]
tf_indices_lineage, tf_gene_indices_lineage, _ = get_tf_indices(dictys_dynamic_object, lineage_tf_names)
gene_indices_lineage = get_gene_indices(dictys_dynamic_object, lineage_target_names)


In [None]:
import networkx as nx
import matplotlib.animation as animation
from matplotlib.patches import ArrowStyle, ConnectionStyle
import matplotlib.pyplot as plt

def create_network_animation(weights, tf_names, target_names, output_path, branch_name,
                             threshold=0.1, max_weight=2.0, layout_seed=None):
    """
    Create and save a network animation showing weight changes across windows.
    Node positions remain fixed, only edges (presence, colour, width) change.

    Parameters
    ----------
    weights : array-like, shape (n_tfs, n_targets, n_windows)
        Signed edge weight of TF i -> target j in each window/frame.
    tf_names, target_names : sequence of str
        Node names; ``weights[i, j]`` belongs to ``tf_names[i] -> target_names[j]``.
    output_path : str
        Directory the MP4 is written to (created if missing).
    branch_name : str
        Used in the output file name ``network_animation_<branch_name>.mp4``.
    threshold : float, optional
        Only edges with ``abs(weight) > threshold`` are drawn (default 0.1).
    max_weight : float, optional
        Absolute weight that maps to the maximum edge width (default 2.0).
    layout_seed : int or None, optional
        Seed for the spring layout; None (default) keeps the previous random layout,
        pass an integer for a reproducible node arrangement.

    Returns
    -------
    str
        Path of the saved animation file.

    Raises
    ------
    ValueError
        If ``weights`` is not 3-D or there are more names than rows/columns in it.
    """
    weights = np.asarray(weights)
    if weights.ndim != 3:
        raise ValueError(f"weights must be 3-D (n_tfs, n_targets, n_windows), got shape {weights.shape}")
    n_tfs, n_targets, n_windows = weights.shape
    if len(tf_names) > n_tfs or len(target_names) > n_targets:
        raise ValueError(
            f"weights shape {weights.shape} is too small for "
            f"{len(tf_names)} TFs x {len(target_names)} targets"
        )

    # Base graph holding every node; edges are added per frame.
    # A gene that is both a TF and a target keeps a single node (typed by the last add_node call).
    G_init = nx.DiGraph()
    for tf in tf_names:
        G_init.add_node(tf, node_type='TF')
    for target in target_names:
        G_init.add_node(target, node_type='target')

    # Node lists and layout do not change between frames, so compute them once.
    tf_nodes = [n for n, attr in G_init.nodes(data=True) if attr['node_type'] == 'TF']
    target_nodes = [n for n, attr in G_init.nodes(data=True) if attr['node_type'] == 'target']
    pos = nx.spring_layout(G_init, k=1, iterations=50, seed=layout_seed)

    width_span = max_weight - threshold

    def scale_width(weight):
        """Map an absolute weight to a line width clamped to [1, 5]."""
        fraction = (weight - threshold) / width_span if width_span > 0 else 1.0
        return 1 + 4 * min(max(fraction, 0.0), 1.0)

    def create_graph(window_idx):
        """Return a copy of the base graph with this window's above-threshold edges."""
        G = G_init.copy()
        for i, tf in enumerate(tf_names):
            for j, target in enumerate(target_names):
                weight = weights[i, j, window_idx]
                if abs(weight) > threshold:  # Only show stronger connections
                    G.add_edge(tf, target, weight=weight)
        return G

    fig, ax = plt.subplots(figsize=(10, 8))

    def update(frame):
        # Everything is drawn on `ax` explicitly so that other open figures in the
        # notebook cannot be picked up through pyplot's "current axes".
        ax.clear()
        G = create_graph(frame)

        nx.draw_networkx_nodes(G, pos, nodelist=tf_nodes, node_color='lightblue',
                               node_size=1000, label='TFs', ax=ax)
        nx.draw_networkx_nodes(G, pos, nodelist=target_nodes, node_color='lightgreen',
                               node_size=1000, label='Targets', ax=ax)

        # Positive edges are red, negative edges blue; both use the same '->' arrow head.
        edges = list(G.edges(data=True))
        pos_edges = [(u, v, abs(d['weight'])) for (u, v, d) in edges if d['weight'] > 0]
        neg_edges = [(u, v, abs(d['weight'])) for (u, v, d) in edges if d['weight'] < 0]
        for edge_group, edge_color in ((pos_edges, 'red'), (neg_edges, 'blue')):
            if not edge_group:
                continue
            nx.draw_networkx_edges(G, pos,
                                   edgelist=[(u, v) for u, v, _ in edge_group],
                                   edge_color=edge_color,
                                   arrows=True,
                                   arrowsize=20,
                                   width=[scale_width(w) for _, _, w in edge_group],
                                   arrowstyle='->',
                                   ax=ax)

        nx.draw_networkx_labels(G, pos, ax=ax)
        ax.set_title(f'Window {frame + 1}/{n_windows}')
        ax.legend(markerscale=0.5, prop={'size': 8})
        return ax

    anim = animation.FuncAnimation(fig, update, frames=n_windows,
                                   interval=200, blit=False)

    os.makedirs(output_path, exist_ok=True)
    output_file = os.path.join(output_path, f'network_animation_{branch_name}.mp4')
    try:
        anim.save(output_file, writer='ffmpeg', fps=15)
    finally:
        # Close this specific figure even if saving fails (e.g. ffmpeg missing).
        plt.close(fig)

    return output_file

In [None]:
# Sanity-check the inputs before rendering.
print("Regulation array shape:", regulation_array_gc.shape)
print("TFs:", tfs_gc)
print("Targets:", targets_gc)
print("Lineage pairs:", pairs)

# Render the GC animation. `tfs_gc` / `targets_gc` are the lists used to index
# `regulation_array_gc`, so they are passed as-is rather than re-derived from `pairs`.
gc_animation_dir = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output'
output_file = create_network_animation(
    regulation_array_gc,
    tfs_gc,
    targets_gc,
    gc_animation_dir,
    'GC',
)

print(f"Animation saved to: {output_file}")

### Branch specific animations

In [None]:
# Branch name -> (starting node, ending node); looked up by `branchname` in the animation cells below.
# NOTE: this rebinds `branches`, replacing the (0,2)/(0,3) mapping defined earlier in the notebook.
branches={
	'Plasmablast':(1,2),
	'Germinal-center':(1,3)
}
#Value range for coloring
# NOTE(review): `vrange` is not referenced anywhere else in the visible notebook.
vrange={
	'Terminal logFC':[-4,4],
	'Transient logFC':[-4,4],
	'Switching time':[0.0015,0.0045],
}

#### Update dictys object with cell type labels

In [None]:
# Read the per-cell cluster labels (column `Cluster` of clusters.csv).
cell_labels_file = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/data/clusters.csv'
cell_labels_df = pd.read_csv(cell_labels_file)
cell_labels = cell_labels_df['Cluster']

# Plain Python list without the pandas index.
if isinstance(cell_labels, pd.Series):
    cell_type_list = cell_labels.values.tolist()
else:
    cell_type_list = list(cell_labels)

# Store the labels on the dictys object (mutates it in place) and report what was stored.
# NOTE(review): nothing here checks that the number/order of labels matches the cells in the network.
dictys_dynamic_object.prop['c']['color'] = cell_type_list
n_colored_cells = len(dictys_dynamic_object.prop['c']['color'])
print(f"Number of cells with color labels: {n_colored_cells}")
print(f"First few labels: {dictys_dynamic_object.prop['c']['color'][:5]}")

In [None]:
from IPython.display import FileLink
from dictys.plot import layout,panel
# Number of frames (interpolated time points/equispaced time points), use 100 or higher for finer resolution
nframe=20
# Animation FPS for saving. Determines speed of play.
# Tied to nframe, so the clip length stays at nframe / fps = 10 seconds whatever nframe is.
fps=0.10*nframe
# DPI for animation
dpi=100


In [None]:
# Branch to animate; must be a key of `branches` (it is looked up as branches[branchname]).
branchname='Plasmablast'
# Select TFs for each row's dynamic subnetwork graph
# (one inner list per row; passed as `f_tfs` to the layout's draw call)
tfs_subnet=[
	['PRDM1']
]
# Select TFs for each row's other plots
# (one inner list per row; passed as `bcde_tfs`)
tfs_ann=[
	['IRF4','PRDM1','BATF', 'SPIB', 'BACH2']
]
# Select genes to annotate as targets in all rows (passed as `e_targets`)
target_ann=['RUNX2','MZB1','PRDM1','AFF3', 'IRF4']


In [None]:
branch = branches[branchname]

# Initialize the layout with dist, number of frames and dpi.
layout1 = layout.notch(dist=0.0005, nframe=nframe, dpi=dpi)

draw_options = dict(
    # Genes to annotate
    bcde_tfs=tfs_ann,
    e_targets=target_ann,
    f_tfs=tfs_subnet,
    # Custom legend location for long cell type names
    a_ka={'scatterka': {'legend_loc': (0.6, 1)}},
    # Custom color range
    e_ka={'lim': [-0.02, 0.02]},
)
pts, fig, panels, animate_ka = layout1.draw(dictys_dynamic_object, branch, **draw_options)

# Build the animation object from the drawn panels.
ca = panel.animate_generic(pts, fig, panels)
anim = ca.animate(**animate_ka)


In [None]:
# `matplotlib` itself is never bound elsewhere in this notebook (only the `plt` / `animation`
# aliases are), so import the submodule here; this also binds the top-level `matplotlib` name.
import matplotlib.animation

# File-based ffmpeg writer encoding h264 from JPEG frames.
w = matplotlib.animation.writers['ffmpeg_file'](fps=fps, codec='h264')
w.frame_format = 'jpeg'
fo = f'/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output/20frames-{branchname}.mp4'
anim.save(fo, writer=w, dpi='figure')
# Show a link to the saved file in the notebook output.
display(FileLink(fo))


### Print window labels 

In [None]:
# Reload the cluster labels, this time using the first CSV column as the index,
# for the per-window cell-type proportions computed below.
cell_labels_file = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/data/clusters.csv'
cell_labels_df = pd.read_csv(cell_labels_file, index_col=0)
cell_labels = cell_labels_df.loc[:, 'Cluster']

In [None]:
# Top 3 states of window 135.
# Window order noted by the author: 1, 97, 96, ... : 0, 98, ... : 146, 2 : ... 193, 3
top_3_states = get_top_k_fraction_labels(dictys_dynamic_object, 135, cell_labels, k=3)

# Order the states by their first statistic (used as a fraction in the legend below), largest first.
sorted_states = list(top_3_states)
sorted_states.sort(key=lambda state: state[1][0], reverse=True)
display(sorted_states)

In [None]:
# Create custom legend handles and labels
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D

# Colour assigned to each cluster label.
leiden_color_dict = {
    'ActB-1': 'lightskyblue',
    'ActB-2': 'dodgerblue',
    'ActB-4': 'mediumorchid',
    'GC-1': 'limegreen',
    'ActB-3': 'darkblue',
    'Naive': 'darkgray',
    'GC-2': 'green',
    'PB-2': 'firebrick',
    'earlyPB': 'lightcoral',
    'earlyActB': 'teal',
}

# One legend handle per state: a coloured dot labelled "<bold percentage>% <state> state".
legend_elements = []
for state in sorted_states:
    state_name = state[0]
    state_fraction = state[1][0]
    handle = Line2D(
        [0], [0],
        marker='o',
        color='w',
        markerfacecolor=leiden_color_dict[state_name],
        markersize=8,
        label=f'$\\mathbf{{{state_fraction*100:.1f}}}$% {state_name} state',
    )
    legend_elements.append(handle)

# Legend-only figure: the axis is hidden and the legend is stacked in a single column.
fig, ax = plt.subplots(figsize=(3, 2))
ax.set_visible(False)
fig.legend(handles=legend_elements, loc='center', ncol=1, frameon=False)

# NOTE(review): the file name hard-codes window 135; keep it in sync with the window queried above.
legend_pdf_path = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output/figures/legend_w135.pdf'
plt.savefig(
    legend_pdf_path,
    bbox_inches='tight',  # keeps the legend from being cut off
    dpi=300,
    format='pdf',
)


# Chromatin dynamics

In [6]:
from tqdm import tqdm
from multiprocessing import Pool, cpu_count
from functools import partial

def process_single_window(i, tfs, base_path):
    """
    Read one window's binding table and summarise it per TF.

    The file ``<base_path>/Subset<i>/binding.tsv.gz`` is read as a gzip-compressed,
    tab-separated table with (at least) the columns ``TF``, ``loc`` and ``score``;
    ``loc`` is split on ':' into chromosome, start and end.

    Parameters:
    -----------
    i : int
        Window index (1-194)
    tfs : list
        List of TFs to query
    base_path : str
        Directory containing the ``Subset<i>`` folders

    Returns:
    --------
    tuple : (window_index, score_dict, count_dict)
        ``score_dict[tf]`` is the mean over chromosomes of the per-chromosome mean score,
        ``count_dict[tf]`` the mean over chromosomes of the per-chromosome row count.
        A TF absent from the table gets NaN / 0. If anything fails (missing file,
        unexpected format) the error is printed and every TF gets NaN / 0.
    """
    try:
        binding_path = f'{base_path}/Subset{i}/binding.tsv.gz'
        df = pd.read_csv(binding_path, sep='\t', compression='gzip')
        df[['chr', 'start', 'end']] = df['loc'].str.split(':', expand=True)

        window_scores, window_counts = {}, {}
        for tf in tfs:
            tf_df = df[df['TF'] == tf]

            if tf_df.empty:
                # TF not found in this window
                window_scores[tf] = float('nan')
                window_counts[tf] = 0
                continue

            # Aggregate per chromosome first, then average across chromosomes.
            scores_by_chr = tf_df.groupby('chr')['score']
            window_scores[tf] = scores_by_chr.mean().mean()
            window_counts[tf] = scores_by_chr.count().mean()

        return (i, window_scores, window_counts)

    except Exception as e:
        print(f"Error processing window {i}: {e}")
        # Best-effort fallback: NaN/0 for every TF when the window cannot be processed.
        nan_scores = {tf: float('nan') for tf in tfs}
        zero_counts = {tf: 0 for tf in tfs}
        return (i, nan_scores, zero_counts)

def multiprocess_tf_binding_data(tfs, base_path, n_windows=194, n_processes=None):
    """
    Run ``process_single_window`` for windows 1..n_windows in a process pool.

    Parameters:
    -----------
    tfs : list
        List of TFs to query
    base_path : str
        Base path to the tmp_dynamic folder
    n_windows : int
        Number of windows to process (default: 194)
    n_processes : int or None
        Number of worker processes; None uses all CPUs but one (at least 1)

    Returns:
    --------
    tuple : (score_dict, count_dict)
        Each maps TF -> list of length ``n_windows``; position ``k`` holds the value
        for window ``k + 1``.
    """
    if n_processes is None:
        # Leave one CPU free
        n_processes = max(1, cpu_count() - 1)

    print(f"Processing {n_windows} windows using {n_processes} processes...")

    # Worker with the TF list and base path bound; only the window index varies.
    worker = partial(process_single_window, tfs=tfs, base_path=base_path)
    window_ids = range(1, n_windows + 1)

    with Pool(processes=n_processes) as pool:
        progress = tqdm(
            pool.imap(worker, window_ids),
            total=n_windows,
            desc="Processing windows"
        )
        results = list(progress)

    # Scatter the per-window dictionaries into per-TF lists.
    score = {tf: [None] * n_windows for tf in tfs}
    count = {tf: [None] * n_windows for tf in tfs}
    for window_idx, window_scores, window_counts in results:
        slot = window_idx - 1
        for tf in tfs:
            score[tf][slot] = window_scores[tf]
            count[tf][slot] = window_counts[tf]

    return score, count

In [27]:
# TFs labelled static vs. episodic by the author.
static_tfs = [
    'PRDM1', 'PAX5', 'BATF', 'BACH2', 'ARID5B',
    'IRF4', 'XBP1', 'CREB3L2', 'RUNX2', 'TCF12',
]
episodic_tfs = [
    'MEF2C', 'MAX', 'USF2', 'MEF2A', 'IRF1',
    'IKZF3', 'POU2F1', 'TEAD2', 'IRF7', 'TFEC',
]

# De-duplicated union of both lists (order is arbitrary).
all_tfs = list(set(static_tfs + episodic_tfs))

base_path = '/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/tmp_dynamic'

# Extract binding scores/counts for every TF across the 194 windows;
# passing None for n_processes lets the helper pick the process count.
score, count = multiprocess_tf_binding_data(all_tfs, base_path, 194, None)

Processing 194 windows using 127 processes...


Processing windows: 100%|██████████| 194/194 [00:05<00:00, 37.43it/s]


In [97]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

# Helper functions
def _to_float(v):
    if hasattr(v, "values"):
        arr = v.values
        if len(arr) == 0:
            return np.nan
        return arr[0]
    return float(v)

def order_series(series, idx_order):
    """Return ``series`` re-indexed by ``idx_order`` as a float array.

    Positions whose index is not below ``len(series)`` become NaN.
    """
    def _pick(idx):
        if idx < len(series):
            return series[idx]
        return np.nan

    return np.asarray(list(map(_pick, idx_order)), dtype=float)

def smooth(series, sigma=2):
    """Gaussian-smooth a 1-D series with standard deviation ``sigma`` (in samples)."""
    smoothed = gaussian_filter1d(series, sigma)
    return smoothed

# General processing function; data_df is a dict such as `count`: data_df[tf] = [series_or_scalar, ...]
def process_tf_data(data_df, tfs, idx_pb, idx_gc, order=True, smooth_sigma=1):
    """Build per-TF value series for the PB and GC window orderings.

    For each TF the raw per-window values are converted to floats (a TF missing from
    ``data_df`` yields an empty list). With ``order=True`` the values are re-ordered by
    ``idx_pb`` / ``idx_gc`` and Gaussian-smoothed with ``smooth_sigma``; otherwise both
    outputs hold the unordered, unsmoothed values.

    Returns ``(series_pb, series_gc)``, two dicts keyed by TF.
    """
    def _trajectory_series(values, idx_order):
        if not order:
            return np.array(values)
        return smooth(order_series(values, idx_order), sigma=smooth_sigma)

    series_pb, series_gc = {}, {}
    for tf in tfs:
        vals = [_to_float(raw) for raw in data_df.get(tf, [])]
        series_pb[tf] = _trajectory_series(vals, idx_pb)
        series_gc[tf] = _trajectory_series(vals, idx_gc)
    return series_pb, series_gc

# Map window indices to pseudotime values
def get_pseudotimes_for_indices(window_indices, window_pseudotimes):
    """Look up the pseudotime of each window index, preserving the given order."""
    pseudotimes = []
    for window_idx in window_indices:
        pseudotimes.append(window_pseudotimes[window_idx])
    return np.array(pseudotimes)

# Get pseudotime values for all windows.
aligner_kwargs = dict(
    dictys_dynamic_object=dictys_dynamic_object,
    trajectory_range=(1, 2),  # author's note: any trajectory starting from node 1 works here
    num_points=100,
    dist=0.0005,
    sparsity=0.01,
)
aligner = AlignTimeScales(**aligner_kwargs)
window_pseudotimes = aligner.pseudotime_of_windows()

In [None]:
import numpy as np
import plotly.graph_objects as go

# Window orderings along each fate. Windows are 0-based; with the 194 windows used in this
# notebook that is 0..193, where 0-3 are the node windows.
#   shared trunk : 1, 97, 96, ..., 4   (then the bifurcation window 0)
#   PB branch    : 0, 98, ..., 146, 2
#   GC branch    : 0, 147, ..., 193, 3
# The GC range previously stopped at 193 (exclusive), so window 193 appeared on no path, whereas
# the PB range does include its last window (146). The stop is now 194.
# NOTE(review): confirm against the trajectory layout that window 193 belongs to the GC branch.
_trunk_window_indices = [1] + list(range(97, 3, -1))
PB_post_bifurcation_window_indices = [0] + list(range(98, 147, 1)) + [2]
GC_post_bifurcation_window_indices = [0] + list(range(147, 194, 1)) + [3]
PB_fate_window_indices = _trunk_window_indices + PB_post_bifurcation_window_indices
GC_fate_window_indices = _trunk_window_indices + GC_post_bifurcation_window_indices

# Get scores y-values (ordered along each fate and smoothed)
series_score_pb, series_score_gc = process_tf_data(
    score, all_tfs, PB_fate_window_indices, GC_fate_window_indices, order=True, smooth_sigma=2
)

# Same for the counts
series_count_pb, series_count_gc = process_tf_data(
    count, all_tfs, PB_fate_window_indices, GC_fate_window_indices, order=True, smooth_sigma=2
)

# Get pseudotime x-axes
x_pb_pseudotime = get_pseudotimes_for_indices(PB_fate_window_indices, window_pseudotimes)
x_gc_pseudotime = get_pseudotimes_for_indices(GC_fate_window_indices, window_pseudotimes)


In [98]:
# Colour palettes for the TFs that are plotted: yellows for static TFs, purples for episodic TFs.
# Only TFs listed here are drawn by the plotting function further down.
static_colors = dict(
    CREB3L2='#B8860B',  # Dark goldenrod
    PRDM1='#D4A017',    # Muted gold
    IRF4='#E6B800',     # Muted bright yellow
    BATF='#F0CF85',     # Light muted yellow
    PAX5='#F5E5B8',     # Pale cream-yellow
)

episodic_colors = dict(
    USF2='#6B21A8',     # Deep purple
    IRF1='#7E22CE',     # Vivid purple
    TFEC='#9333EA',     # Bright violet
    MEF2A='#A855F7',    # Light purple
    MEF2C='#C084FC',    # Pale lavender
)

# Master dictionary: static entries first, then episodic ones.
tf_colors = dict(static_colors)
tf_colors.update(episodic_colors)

# Largest pseudotime reached on either trajectory (NaNs ignored).
max_pseudotime = max(np.nanmax(x_pb_pseudotime), np.nanmax(x_gc_pseudotime))

def get_tf_category(tf):
    """Return "Static" or "Episodic" depending on which palette lists the TF, else "Unknown"."""
    for category, palette in (("Static", static_colors), ("Episodic", episodic_colors)):
        if tf in palette:
            return category
    return "Unknown"

In [104]:
import numpy as np
import plotly.graph_objects as go

def plot_tf_chromatin_dynamics(
    data_df: dict, 
    y_axis_label: str,
    title: str = None
    ) -> go.Figure:
    """
    Generates a plot of TF dynamics over pseudotime for PB and GC trajectories.

    This function assumes several variables are defined in the global scope:
    - process_tf_data: A function to process the raw data.
    - get_pseudotimes_for_indices: A function to get pseudotime values.
    - all_tfs: A list of all transcription factors to process.
    - PB_fate_window_indices, GC_fate_window_indices: Lists of window orders.
    - window_pseudotimes: Mapping of window index to pseudotime (shared by both
      trajectories).
    - static_colors, episodic_colors: Dictionaries mapping TFs to colors; only the
      TFs listed in them are drawn, and each must also be in all_tfs.

    Parameters:
    -----------
    data_df : dict
        The data to plot (e.g., the 'score' or 'count' dictionary).
        It should map TF names to lists of numerical values.
    y_axis_label : str
        The label for the y-axis (e.g., "Binding Score" or "Binding Site Count").
    title : str, optional
        A custom title for the plot. If None, a default is generated.

    Returns:
    --------
    go.Figure
        The configured Plotly figure object, ready to be shown or saved.
    """
    
    # 1. Process raw data to get smoothed Y-values for each trajectory
    series_pb, series_gc = process_tf_data(
        data_df, all_tfs, PB_fate_window_indices, GC_fate_window_indices, 
        order=True, smooth_sigma=2
    )

    # 2. Get the corresponding X-values (pseudotime).
    # Both trajectories read the single `window_pseudotimes` mapping; the previously referenced
    # `window_pseudotimes_pb` / `window_pseudotimes_gc` are never defined in this notebook,
    # so the function raised NameError on a fresh kernel.
    x_pb_pseudotime = get_pseudotimes_for_indices(PB_fate_window_indices, window_pseudotimes)
    x_gc_pseudotime = get_pseudotimes_for_indices(GC_fate_window_indices, window_pseudotimes)

    # 3. Truncate the PB data at the largest GC pseudotime so both curves share an x-range
    min_max_pseudotime = np.nanmax(x_gc_pseudotime)
    pb_mask = x_pb_pseudotime <= min_max_pseudotime

    # 4. Create the plot figure
    fig = go.Figure()

    # Loop through TF categories to plot and group them in the legend
    for category, tf_dict in [("Static", static_colors), ("Episodic", episodic_colors)]:
        for i, (tf, tf_color) in enumerate(tf_dict.items()):
            # Add GC trace (dashed); it carries the legend entry for this TF and,
            # for the first TF of a category, the category title.
            fig.add_trace(go.Scatter(
                x=x_gc_pseudotime,
                y=series_gc[tf],
                mode='lines',
                name=tf,
                line=dict(dash='dash', color=tf_color, width=2.5),
                legendgroup=tf,
                legendgrouptitle_text=category if i == 0 else None,
                showlegend=True
            ))
            
            # Add truncated PB trace (solid); same legend group, no separate legend entry
            fig.add_trace(go.Scatter(
                x=x_pb_pseudotime[pb_mask],
                y=series_pb[tf][pb_mask],
                mode='lines',
                name=tf,
                line=dict(dash='solid', color=tf_color, width=2.5),
                legendgroup=tf,
                showlegend=False 
            ))
            
    # 5. Configure the layout
    if title is None:
        title = f"<b>TF {y_axis_label} Dynamics</b>"
        
    fig.update_layout(
        title=dict(text=title, x=0.5),
        xaxis=dict(
            title='Pseudotime',
            showgrid=True,
            range=[0, min_max_pseudotime] 
        ),
        yaxis=dict(title=y_axis_label),
        legend=dict(
            orientation='v', 
            x=1.02, 
            y=0.5,
            tracegroupgap=25,
            title_text="<b>TF Categories</b><br>(Solid=PB, Dashed=GC)"
        ),
        margin=dict(t=100, r=250),
        plot_bgcolor="white",
        paper_bgcolor="white",
        font=dict(family="Arial, sans-serif")
    )

    return fig

In [106]:
# Binding-score dynamics: build the figure, display it, then export it as SVG.
fig = plot_tf_chromatin_dynamics(score, "Binding Score")
fig.show()

score_svg_path = "/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output/figures/tf_binding_score_dynamics_by_category.svg"
fig.write_image(score_svg_path, format='svg', width=2000, height=600)

In [107]:
# Count dynamics: build the figure, display it, then export it as SVG.
fig = plot_tf_chromatin_dynamics(count, "OCR Counts per TF")
fig.show()

count_svg_path = "/ocean/projects/cis240075p/asachan/datasets/B_Cell/multiome_1st_donor_UPMC_aggr/dictys_outs/actb1_added_v2/output/figures/tf_count_dynamics_by_category.svg"
fig.write_image(count_svg_path, format='svg', width=2000, height=600)