# Downstream Analysis

This notebook should be used for downstream analysis of your OPS screen.
Cells marked with <font color='red'>SET PARAMETERS</font> contain crucial variables that need to be set according to your specific experimental setup and data organization.
Please review and modify these variables as needed before proceeding with the analysis.

## <font color='red'>SET PARAMETERS</font>

### Fixed parameters for cluster module

- `CONFIG_FILE_PATH`: Path to a Brieflow config file used during processing. Absolute or relative to where workflows are run from.

In [None]:
CONFIG_FILE_PATH = "config/config.yml"

In [None]:
from pathlib import Path
import warnings
warnings.filterwarnings('ignore', category=UserWarning)
import yaml
import pandas as pd
import matplotlib.pyplot as plt

from lib.cluster.cluster_analysis import (
    differential_analysis, 
    waterfall_plot, 
    two_feature_plot, 
    cluster_heatmap
)
from lib.cluster.phate_leiden_clustering import plot_phate_leiden_clusters
from lib.shared.metrics import get_all_stats

In [None]:
# load config file and determine root path
with open(CONFIG_FILE_PATH, "r") as config_file:
    config = yaml.safe_load(config_file)
    ROOT_FP = Path(config["all"]["root_fp"])

## Pipeline Statistics Report

The `get_all_stats` function analyzes the entire data processing pipeline from raw microscopy images to clustered perturbation profiles. It will:

1. **Preprocessing**: Count input files and generated image tiles
2. **SBS**: Analyze cell segmentation and barcode mapping success rates
3. **Phenotype**: Count cells and morphological features extracted
4. **Merge**: Calculate how many cells were successfully matched between SBS and phenotype data
5. **Aggregation**: Measure perturbation coverage and batch effect correction (per cell class/channel)
6. **Clustering**: Evaluate pathway enrichment and optimal clustering parameters

**Expected runtime: 5-15 minutes** depending on dataset size, as this function loads and analyzes large parquet files to calculate batch statistics and enrichment metrics.

The report will display progress updates as each stage is analyzed.

In [None]:
statistics = get_all_stats(config)

In [None]:
# Load cell classes and channel combos
cluster_combo_fp = config["cluster"]["cluster_combo_fp"]
cluster_combos = pd.read_csv(cluster_combo_fp, sep="\t")

CHANNEL_COMBOS = list(cluster_combos["channel_combo"].unique())
print(f"Channel Combos: {CHANNEL_COMBOS}")

CELL_CLASSES = list(cluster_combos["cell_class"].unique())
print(f"Cell classes: {CELL_CLASSES}")

# Plural name, so the single LEIDEN_RESOLUTION chosen below does not shadow this list
LEIDEN_RESOLUTIONS = list(cluster_combos["leiden_resolution"].unique())
print(f"Leiden resolutions: {LEIDEN_RESOLUTIONS}")

## <font color='red'>SET PARAMETERS</font>

### Cluster Selection for Analysis

Set these parameters to select the specific cluster data to analyze:
- `CHANNEL_COMBO`: Select from the available channel combinations.
- `CELL_CLASS`: Select from the available cell classes.
- `LEIDEN_RESOLUTION`: Select from the available Leiden resolutions.

These parameters determine which folder of cluster data will be analyzed.
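
As a minimal sketch of how the three selections map to files on disk, the helper below rebuilds the two paths using the same naming pattern as the next cell and fails early if a parameter is still `None`. The function name `build_cluster_paths` and the example values are hypothetical and only serve as illustration.

```python
from pathlib import Path


def build_cluster_paths(root_fp, channel_combo, cell_class, leiden_resolution):
    """Return (aggregate_file, cluster_path) following this notebook's naming pattern."""
    selection = {
        "CHANNEL_COMBO": channel_combo,
        "CELL_CLASS": cell_class,
        "LEIDEN_RESOLUTION": leiden_resolution,
    }
    # Fail early with a readable message instead of a TypeError from Path joining
    missing = [name for name, value in selection.items() if value is None]
    if missing:
        raise ValueError(f"Set these parameters before continuing: {', '.join(missing)}")

    root_fp = Path(root_fp)
    aggregate_file = (
        root_fp / "aggregate" / "tsvs"
        / f"CeCl-{cell_class}_ChCo-{channel_combo}__feature_table.tsv"
    )
    cluster_path = root_fp / "cluster" / channel_combo / cell_class / str(leiden_resolution)
    return aggregate_file, cluster_path


# Hypothetical values, for illustration only
aggregate_file, cluster_path = build_cluster_paths("analysis_root", "DAPI_COXIV", "all", 8)
print(aggregate_file)
print(cluster_path)
```

Checking for `None` up front gives a clearer error than the path join would if a parameter were left unset.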

In [None]:
CHANNEL_COMBO = None
CELL_CLASS = None
LEIDEN_RESOLUTION = None

In [None]:
aggregate_file = ROOT_FP / "aggregate" / "tsvs" / f"CeCl-{CELL_CLASS}_ChCo-{CHANNEL_COMBO}__feature_table.tsv"
print(f"Aggregate file: {aggregate_file}")

if not aggregate_file.exists():
    print(f"Aggregate file does not exist: {aggregate_file}")
else:
    print("Aggregate file found")

cluster_path = ROOT_FP / "cluster" / CHANNEL_COMBO / CELL_CLASS / str(LEIDEN_RESOLUTION)
print(f"Cluster path: {cluster_path}")

if not cluster_path.exists():
    print(f"Cluster directory does not exist: {cluster_path}")
else:
    print("Cluster directory found")

## Feature Plot Analysis

This section generates visualizations to explore the phenotypic effects of gene perturbations in your screen. The plots will help you:

1. **Differential Feature Analysis**: Identify genes with significant phenotypic changes vs. controls
2. **Waterfall Plots**: Rank genes by their effect on specific features of interest
3. **Two-Feature Plots**: Discover relationships between different phenotypic measurements
4. **Heatmaps**: Visualize patterns across multiple features and gene sets simultaneously

The interactive analysis allows you to customize each visualization for your specific biological questions.

In [None]:
cluster_file = cluster_path / "phate_leiden_clustering.tsv"
cluster_df = pd.read_csv(cluster_file, sep="\t")
display(cluster_df)

In [None]:
aggregate_df = pd.read_csv(aggregate_file, sep="\t")
display(aggregate_df)

## <font color='red'>SET PARAMETERS</font>

### Cluster Selection for Visualization

Set this parameter to select the cluster to analyze:
- `CLUSTER_ID`: The cluster of interest for which plots are generated.

In [None]:
CLUSTER_ID = None

In [None]:
print(f"\n{'='*50}")
print(f"Analyzing Cluster {CLUSTER_ID}")
cluster_genes = cluster_df[cluster_df['cluster'] == CLUSTER_ID][config["aggregate"]["perturbation_name_col"]]
print(f"Genes in Cluster {CLUSTER_ID}:", ", ".join(cluster_genes.unique()))
print(f"{'='*50}")

# Run differential analysis using robust z-score
diff_results = differential_analysis(
    feature_df=aggregate_df,  
    cluster_df=cluster_df,    
    cluster_id=CLUSTER_ID,
    control_type="nontargeting",  
    control_label=config["aggregate"]["control_key"],
    use_nonparametric=True,  
    normalize_method="robust_zscore"
)

# Display results in a more user-friendly format
print("\nTop Upregulated Features:")
up_df = diff_results['top_up'][['feature', 'robust_zscore', 'p_value', 'median_test', 'median_control']]
up_df = up_df.rename(columns={
    'robust_zscore': 'Z-score',
    'p_value': 'p-value',
    'median_test': 'Median (test)',
    'median_control': 'Median (control)'
})
display(up_df.style.format({
    'Z-score': '{:.2f}',
    'p-value': '{:.2e}',
    'Median (test)': '{:.3f}',
    'Median (control)': '{:.3f}'
}))

print("\nTop Downregulated Features:")
down_df = diff_results['top_down'][['feature', 'robust_zscore', 'p_value', 'median_test', 'median_control']]
down_df = down_df.rename(columns={
    'robust_zscore': 'Z-score',
    'p_value': 'p-value',
    'median_test': 'Median (test)',
    'median_control': 'Median (control)'
})
display(down_df.style.format({
    'Z-score': '{:.2f}',
    'p-value': '{:.2e}',
    'Median (test)': '{:.3f}',
    'Median (control)': '{:.3f}'
}))
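
The cell above requests `normalize_method="robust_zscore"`. As a rough illustration of what a robust z-score usually means, the sketch below scales values by the median and the median absolute deviation (MAD) of the controls. This is the common textbook definition written in plain Python; it is an assumption that `differential_analysis` computes something similar, and its actual implementation may differ.

```python
from statistics import median


def robust_zscores(values, control_values):
    """Scale values by the control median and MAD (median absolute deviation)."""
    control_median = median(control_values)
    mad = median(abs(v - control_median) for v in control_values)
    if mad == 0:
        raise ValueError("Control MAD is zero; the robust z-score is undefined")
    # 1.4826 makes the MAD comparable to a standard deviation for normal data
    scale = 1.4826 * mad
    return [(v - control_median) / scale for v in values]


# Hypothetical feature values, for illustration only
controls = [1.0, 2.0, 3.0, 4.0, 5.0]
scores = robust_zscores([3.0, 6.0], controls)
print(scores)
```

Because the median and MAD are barely affected by a few extreme wells or cells, this scaling is less sensitive to outliers than a mean and standard deviation.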

## <font color='red'>SET PARAMETERS</font>

### Feature Selection for Visualization

Set these parameters to select the features and genes to visualize:
- `FEATURES_TO_ANALYZE`: A list of features found in the differential analysis to generate plots for.
- `GENES_TO_LABEL`: The genes within the cluster to label on the plots.

In [None]:
FEATURES_TO_ANALYZE = None
GENES_TO_LABEL = None

In [None]:
# Create waterfall plots for selected features
print("\n--- Waterfall Plots for Selected Features ---")
for feature in FEATURES_TO_ANALYZE:
    print(f"\nPlotting: {feature}")
    waterfall_plot(
        feature_df=aggregate_df, 
        feature=feature,
        cluster_df=cluster_df,
        cluster_id=CLUSTER_ID,
        nontargeting_pattern=config["aggregate"]["control_key"],
        title=f"Cluster {CLUSTER_ID}: {feature}",
        label_genes=GENES_TO_LABEL
    )

# Create two-feature plots for combinations
print("\n--- Two-Feature Plots ---")
# Generate all unique pairs of features
feature_pairs = [(FEATURES_TO_ANALYZE[i], FEATURES_TO_ANALYZE[j]) 
                 for i in range(len(FEATURES_TO_ANALYZE)) 
                 for j in range(i+1, len(FEATURES_TO_ANALYZE))]

for feature1, feature2 in feature_pairs:
    print(f"\nPlotting: {feature1} vs {feature2}")
    two_feature_plot(
        feature_df=aggregate_df,
        x=feature1,
        y=feature2,
        cluster_df=cluster_df,
        cluster_id=CLUSTER_ID,
        nontargeting_pattern=config["aggregate"]["control_key"],
        title=f"Cluster {CLUSTER_ID}: {feature1} vs {feature2}",
        label_genes=GENES_TO_LABEL
    )

# Create heatmap with all differential features
print("\n--- Heatmap of Differential Features ---")
# Get top features from differential analysis
top_up = diff_results['top_up']['feature'].tolist() if not diff_results['top_up'].empty else []
top_down = diff_results['top_down']['feature'].tolist() if not diff_results['top_down'].empty else []
diff_features = top_up + top_down

fig, ax, heatmap_data = cluster_heatmap(
    feature_df=aggregate_df,
    cluster_df=cluster_df,
    cluster_ids=[CLUSTER_ID],
    features=diff_features,
    perturbation_name_col=config["aggregate"]["perturbation_name_col"],
    z_score="global",
    title=f"Cluster {CLUSTER_ID}: Top Differential Features",
)

# PHATE plot for the cluster
print("\n--- PHATE Plot for Cluster ---")
phate_fig = plot_phate_leiden_clusters(
    phate_leiden_clustering=cluster_df,
    perturbation_name_col=config["aggregate"]["perturbation_name_col"],
    control_key=config["aggregate"]["control_key"],
    clusters_of_interest=[CLUSTER_ID],
)
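
The two-feature loop in the cell above builds every unordered pair of selected features with a nested comprehension. The sketch below, using hypothetical feature names, shows that `itertools.combinations` yields the same pairs in the same order, and that the number of plots grows as n(n-1)/2 with the number of features.

```python
from itertools import combinations

# Hypothetical feature names, for illustration only
features = ["feature_a", "feature_b", "feature_c", "feature_d"]

# Nested comprehension, as written in the notebook cell
nested_pairs = [
    (features[i], features[j])
    for i in range(len(features))
    for j in range(i + 1, len(features))
]

# Equivalent standard-library form
pairs = list(combinations(features, 2))

print(nested_pairs == pairs)  # True
print(len(pairs))             # 6 plots for 4 features
```

Keeping `FEATURES_TO_ANALYZE` short therefore keeps the number of two-feature plots manageable.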

## Mozzarellm: LLM-based Gene Cluster Analysis

### Overview
[Mozzarellm](https://github.com/cheeseman-lab/mozzarellm) is a Python package that leverages Large Language Models (LLMs) to analyze gene clusters for pathway identification and novel gene discovery. This section guides you through the process of:

1. **Loading and reshaping gene cluster data** from your OPS screen
2. **Analyzing gene clusters with LLMs** to identify biological pathways
3. **Categorizing genes** as established pathway members, uncharacterized, or having novel potential roles
4. **Prioritizing candidates** for experimental validation

### Prerequisites

You need to install the mozzarellm package in your Brieflow environment:

```bash
pip install git+https://github.com/cheeseman-lab/mozzarellm.git
```

Mozzarellm requires API keys to access LLM services. You need at least one of these keys:

- **OpenAI API Key**: Required for GPT models (gpt-4o, gpt-4.5, etc.)
- **Anthropic API Key**: Required for Claude models (claude-3-7-sonnet, etc.)
- **Google API Key**: Required for Gemini models (gemini-2.0-pro, etc.)

These keys provide access to paid API services, and usage incurs costs based on the number of tokens processed. The cost per analysis varies by model but typically ranges from $0.01 to $0.10 per cluster, depending on cluster size and model choice. For this reason, we run these analyses only on one chosen Leiden resolution rather than on every resolution generated upstream.
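
As a back-of-the-envelope sketch, the per-cluster range quoted above can be scaled to a whole run. The helper name `estimate_cost_range` and the cluster count are hypothetical, and actual charges depend on the provider's current pricing and your cluster sizes.

```python
def estimate_cost_range(n_clusters, low_per_cluster=0.01, high_per_cluster=0.10):
    """Return (low, high) cost estimates in USD for analyzing n_clusters clusters."""
    return n_clusters * low_per_cluster, n_clusters * high_per_cluster


# Hypothetical cluster count, for illustration only
low, high = estimate_cost_range(200)
print(f"Estimated cost: ${low:.2f} to ${high:.2f}")
```

Multiplying this estimate by the number of Leiden resolutions shows why running every resolution quickly becomes expensive.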

In [None]:
from mozzarellm import analyze_gene_clusters, reshape_to_clusters

In [None]:
import os

# Set API keys (replace the placeholders; avoid committing real keys to version control)
os.environ["OPENAI_API_KEY"] = "your_openai_key_here"
os.environ["ANTHROPIC_API_KEY"] = "your_anthropic_key_here"
os.environ["GOOGLE_API_KEY"] = "your_google_key_here"
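
Before spending time on an analysis run, it can help to confirm which provider keys are set. The sketch below is a hypothetical helper, not part of Mozzarellm: it reads the three environment variable names used in this notebook and treats the `your_..._key_here` placeholders as unset.

```python
import os

# Environment variable names used in this notebook, keyed by provider
PROVIDER_KEYS = {
    "OpenAI": "OPENAI_API_KEY",
    "Anthropic": "ANTHROPIC_API_KEY",
    "Google": "GOOGLE_API_KEY",
}


def configured_providers(environ=None):
    """Return providers whose API key is set to a non-placeholder value."""
    environ = os.environ if environ is None else environ
    found = []
    for provider, var_name in PROVIDER_KEYS.items():
        value = environ.get(var_name, "").strip()
        # Treat the notebook's "your_..._key_here" placeholders as unset
        if value and not value.startswith("your_"):
            found.append(provider)
    return found


# Example with an explicit mapping instead of the real environment
example_env = {
    "ANTHROPIC_API_KEY": "sk-example",
    "OPENAI_API_KEY": "your_openai_key_here",
}
print(configured_providers(example_env))  # ['Anthropic']
```

Calling `configured_providers()` with no argument checks the real environment, which also works when keys are exported in the shell rather than written into the notebook.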

In [None]:
RESULTS_DIR = cluster_path / "mozzarellm_analysis"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
print(f"Results will be saved to: {RESULTS_DIR}")

In [None]:
cluster_file = cluster_path / "phate_leiden_clustering.tsv"
cluster_df = pd.read_csv(cluster_file, sep="\t")
display(cluster_df)

## <font color='red'>SET PARAMETERS</font>

### Data Reshaping

Configure gene clustering parameters:
- `GENE_COL`: Column containing gene identifiers
- `CLUSTER_COL`: Column containing cluster assignments
- `UNIPROT_COL`: Column with UniProt annotations

These parameters control how gene-level data is converted to cluster-level data. The defaults match the values established by the upstream notebooks and Snakemake workflows, so they usually do not need to be changed.
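
To make the gene-level to cluster-level conversion concrete, the sketch below groups hypothetical rows by cluster in plain Python. It does not reproduce `reshape_to_clusters`; the output keys (`cluster_id`, `genes`, `n_genes`), the `gene_symbol` column name, and the `;` separator are assumptions chosen for illustration.

```python
from collections import defaultdict


def genes_by_cluster(rows, gene_col, cluster_col):
    """Collapse one-row-per-gene records into one record per cluster."""
    clusters = defaultdict(list)
    for row in rows:
        clusters[row[cluster_col]].append(row[gene_col])
    return [
        {"cluster_id": cluster_id, "genes": ";".join(sorted(genes)), "n_genes": len(genes)}
        for cluster_id, genes in sorted(clusters.items())
    ]


# Hypothetical gene-level rows, for illustration only
gene_rows = [
    {"gene_symbol": "GENE_B", "cluster": 0},
    {"gene_symbol": "GENE_A", "cluster": 0},
    {"gene_symbol": "GENE_C", "cluster": 1},
]
cluster_rows = genes_by_cluster(gene_rows, gene_col="gene_symbol", cluster_col="cluster")
print(cluster_rows)
```

One record per cluster is the natural unit for an LLM prompt, since each cluster is analyzed as a gene set.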

In [None]:
GENE_COL = config["aggregate"]["perturbation_name_col"]
CLUSTER_COL = "cluster"
UNIPROT_COL = "uniprot_function"

In [None]:
llm_cluster_df, llm_uniprot_df = reshape_to_clusters(
    input_df=cluster_df, 
    gene_col=GENE_COL,
    cluster_col=CLUSTER_COL,
    uniprot_col=UNIPROT_COL, 
    verbose=True
)
display(llm_cluster_df)
display(llm_uniprot_df)

## <font color='red'>SET PARAMETERS</font>

### LLM Analysis Configuration

- `MODEL_NAME`: LLM to use for analysis. Usable models include:
  - OpenAI: `o4-mini`, `o3-mini`, `gpt-4.1`, `gpt-4o`
  - Anthropic: `claude-3-7-sonnet-20250219`, `claude-3-5-haiku-20241022`
  - Google: `gemini-2.5-pro-preview-03-25`, `gemini-2.5-flash-preview-04-17`
- `CONFIG_DICT`: Configuration dictionary for the selected model (for example, `DEFAULT_ANTHROPIC_CONFIG`)
- `SCREEN_CONTEXT`: Context for the analysis and how to evaluate _clusters_
- `CLUSTER_ANALYSIS_PROMPT`: Context for the analysis and how to evaluate _genes within clusters_

Mozzarellm includes optimized [configurations](https://github.com/cheeseman-lab/mozzarellm/blob/main/mozzarellm/configs.py) and [prompts](https://github.com/cheeseman-lab/mozzarellm/blob/main/mozzarellm/prompts.py) you can import as shown below.

Custom text files can also be used by setting `screen_context_path` and `cluster_analysis_prompt_path` parameters.

In [None]:
from mozzarellm.prompts import ROBUST_SCREEN_CONTEXT, ROBUST_CLUSTER_PROMPT
from mozzarellm.configs import DEFAULT_ANTHROPIC_CONFIG

In [None]:
# Set up model configs
MODEL_NAME = "claude-3-7-sonnet-20250219"
CONFIG_DICT = DEFAULT_ANTHROPIC_CONFIG
SCREEN_CONTEXT = ROBUST_SCREEN_CONTEXT
CLUSTER_ANALYSIS_PROMPT = ROBUST_CLUSTER_PROMPT

In [None]:
# Run LLM analysis
anthropic_results = analyze_gene_clusters(
    # Input data options
    input_df=llm_cluster_df,
    # Model and configuration
    model_name=MODEL_NAME,
    config_dict=CONFIG_DICT,
    # Analysis context and prompts
    screen_context=SCREEN_CONTEXT,
    cluster_analysis_prompt=CLUSTER_ANALYSIS_PROMPT,
    # Gene annotations
    gene_annotations_df=llm_uniprot_df,
    # Processing options
    batch_size=1,
    # Output options
    output_file=f"{RESULTS_DIR}/{MODEL_NAME}",
    save_outputs=True,
    outputs_to_generate=["json", "clusters", "flagged_genes"],
)