# Analyze and visualize hits

Here we compute the number of hits per ClinVar category, and also plot selected sites from allele pairs of interest.

In [1]:
# imports
import collections
from importlib import reload

import displayImages as di
import matplotlib.pyplot as plt
import polars as pl
from tqdm import tqdm

reload(di)

<module 'displayImages' from '/dgx1nas1/storage/data/jess/repos/2021_09_01_VarChAMP/7.downstream_analysis_jess/final_notebooks/displayImages.py'>

In [10]:
# Read in data and set parameters
# Input locations: classifier metrics, per-batch profiles, and the image root
metrics_dir = "/dgx1nas1/storage/data/jess/varchamp/sc_data/classification_results/B7B8_1percent_updatedmeta"
prof_dir = "/dgx1nas1/storage/data/jess/repos/2021_09_01_VarChAMP/6.downstream_analysis_snakemake/outputs/batch_profiles"
img_dir = "/dgx1nas1/storage/data/jess/varchamp/images"
# NOTE(review): metrics_df is not referenced by the cells shown below; only
# metrics_wtvar (the summary table) is filtered and analyzed.
metrics_df = pl.read_csv(f"{metrics_dir}/metrics.csv")
metrics_wtvar = pl.read_csv(f"{metrics_dir}/metrics_summary.csv")
# NOTE(review): thresh and min_class_num are not used by the cells shown
# below — confirm whether they are still needed.
thresh = 3  # previously 10
min_class_num = 2

## Annotate with ClinVar labels

In [11]:
# Split the summary metrics by classifier type. Only the localization table is
# annotated and analyzed in the cells that follow.
morph_wtvar = metrics_wtvar.filter(pl.col("Classifier_type") == "morphology")
local_wtvar = metrics_wtvar.filter(pl.col("Classifier_type") == "localization")
print(local_wtvar.shape)

# Analyze wrt clinvar annotations
# The lookup key is "<symbol>_<aa_change>", built to match the allele_0 column.
clinvar = pl.read_csv("../data/allele_collection_clinical_significance.csv")
clinvar = clinvar.with_columns(
    pl.concat_str(["symbol", "aa_change"], separator="_").alias("allele_0")
)

# Deduplicate the lookup before joining: an (allele_0, clinvar_cs) pair that is
# listed more than once would otherwise repeat every matching classifier row.
clinvar_labels = clinvar.select(["allele_0", "clinvar_cs"]).unique()

# Default (inner) join: rows whose allele_0 has no ClinVar entry are dropped.
local_wtvar = local_wtvar.join(clinvar_labels, on="allele_0")
print(local_wtvar.shape)

(942, 11)
(962, 12)


In [12]:
# Alleles must be mislocalized in both batches

# A row counts as a hit (1) when its mean AUROC exceeds its own threshold.
exceeds_thresh = pl.col("AUROC_mean") > pl.col("AUROC_thresh")
local_wtvar = local_wtvar.with_columns(
    pl.when(exceeds_thresh).then(1).otherwise(0).alias("Mislocalized")
).unique()

# Wide table of hit flags: one row per allele / allele set / ClinVar label,
# one column per batch.
hits_by_batch = local_wtvar.pivot(
    index=["allele_0", "Allele_set", "clinvar_cs"],
    columns="Batch",
    values="Mislocalized",
)
hit_in_both = (pl.col("batch7") == 1) & (pl.col("batch8") == 1)
misloc_binary = hits_by_batch.with_columns(
    hit_in_both.alias("Mislocalized_both_batches")
).rename({"batch7": "mislocalized_batch7", "batch8": "mislocalized_batch8"})

# Wide table of mean AUROCs: one row per allele, one column per batch.
auroc_by_batch = local_wtvar.pivot(
    index=["allele_0"],
    columns="Batch",
    values="AUROC_mean",
)
misloc_auroc = auroc_by_batch.rename(
    {"batch7": "auroc_batch7", "batch8": "auroc_batch8"}
)

# Combine hit flags and AUROC values into a single per-allele summary.
misloc_summary = misloc_binary.join(misloc_auroc, on="allele_0")

In [5]:
# Save the per-allele summary (per-batch hit flags joined with per-batch mean
# AUROCs) as CSV; the path is relative to the notebook's working directory.
misloc_summary.write_csv("../results/summary_auroc.csv")

In [16]:
# count %  mislocalized by label type
benign_local = misloc_summary.filter(pl.col("clinvar_cs") == "Benign")
vus_local = misloc_summary.filter(pl.col("clinvar_cs") == "VUS")
path_local = misloc_summary.filter(pl.col("clinvar_cs") == "Pathogenic")


def _print_misloc_rate(label, subset):
    """Print the both-batch mislocalization rate and counts for one category.

    Args:
        label: Heading printed before the numbers.
        subset: Rows of ``misloc_summary`` belonging to the category.

    The rate is printed as ``nan`` when ``subset`` has no rows, instead of
    raising ``ZeroDivisionError``.
    """
    n_total = subset.shape[0]
    n_misloc = subset.filter(pl.col("Mislocalized_both_batches")).shape[0]
    print(label)
    print(n_misloc / n_total if n_total else float("nan"))
    print("# misloc: " + str(n_misloc))
    print("total #: " + str(n_total))


_label_subsets = [
    ("Benign", benign_local),
    ("VUS", vus_local),
    ("Pathogenic", path_local),
]
for _position, (_label, _subset) in enumerate(_label_subsets):
    # Blank-line separator between categories (none after the last one)
    if _position > 0:
        print("\n")
    _print_misloc_rate(_label, _subset)

Benign
0.16666666666666666
# misloc: 9
total #: 54


VUS
0.2911392405063291
# misloc: 23
total #: 79


Pathogenic
0.3465909090909091
# misloc: 61
total #: 176


In [18]:
# Define different lists of alleles of interest


def _both_batch_hits(frame):
    """Return the allele_0 values of the rows flagged Mislocalized_both_batches."""
    flagged = frame.filter(pl.col("Mislocalized_both_batches"))
    return flagged.get_column("allele_0").to_list()


misloc_benign = _both_batch_hits(benign_local)

# Only the first five pathogenic hits are kept here.
misloc_path = _both_batch_hits(path_local)[0:5]

misloc_all = _both_batch_hits(misloc_summary)

In [19]:
# Make barplot
# Hit rate (in percent) per ClinVar category. These numbers are entered by
# hand rather than derived from misloc_summary.
# NOTE(review): confirm they still match the latest computed rates.
values = {
    "Pathogenic": 34,
    "VUS": 28,
    "Conflicting": 23,
    "Benign": 16,
    "No annotation": 24,
}

# Set the font to Arial
plt.rcParams["font.family"] = "Arial"

# Create the barplot on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(8, 6))
bars = ax.bar(list(values.keys()), list(values.values()), color="skyblue")

# Add axis labels
ax.set_xlabel("Categories", fontsize=12)
ax.set_ylabel("Hit rate (%)", fontsize=12)

# Write each bar's value (already a percentage) just above the bar
for bar in bars:
    height = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2,
        height,
        f"{int(height)}%",
        ha="center",
        va="bottom",
        fontsize=12,
    )

# Save next to the other results (same relative directory as the summary CSV)
# and close the figure without displaying it.
fig.tight_layout()
fig.savefig("../results/variant_hit_rate.pdf", format="pdf")
plt.close(fig)

## Plot images

In [20]:
# Get metadata required for plotting (for batch 7 only here)
meta_cols = [
    "Metadata_Well",
    "Metadata_Plate",
    "Metadata_gene_allele",
    "Metadata_node_type",
]
batch7_profiles = f"{prof_dir}/2024_01_23_Batch_7/profiles_tcdropped_filtered_var_mad_outlier_featselect_filtcells_metacorr.parquet"

# Lazily scan the profiles and materialize only the distinct metadata rows.
well_meta = pl.scan_parquet(batch7_profiles).select(meta_cols).unique().collect()

# Shorter column names, a fixed site label, and batch / plate codes sliced out
# of the full plate name (offset 11; 6 characters for Batch, 9 for Plate).
short_names = {
    "Metadata_Well": "Well",
    "Metadata_Plate": "Plate",
    "Metadata_gene_allele": "Allele",
    "Metadata_node_type": "control_type",
}
well_meta = well_meta.rename(short_names).with_columns(
    pl.lit("05").alias("Site"),
    pl.col("Plate").str.slice(11, 6).alias("Batch"),
    pl.col("Plate").str.slice(11, 9).alias("Plate"),
)

# Every batch code paired with replicates T1-T4 (batch-major row order).
batch_codes = ["B7A1R1", "B7A2R1", "B8A1R2", "B8A2R2"]
replicate_codes = ["T1", "T2", "T3", "T4"]
batch_replicate_pairs = [
    (batch_code, replicate_code)
    for batch_code in batch_codes
    for replicate_code in replicate_codes
]
rep_df = pl.DataFrame({
    "Batch": [pair[0] for pair in batch_replicate_pairs],
    "Replicate": [pair[1] for pair in batch_replicate_pairs],
})

# Attach the replicate labels: each metadata row is repeated once per replicate
# listed for its batch.
pm_df = well_meta.join(rep_df, on="Batch")

In [37]:
# Define allele list and plot dir
# Alleles whose name contains "MVK". This is a substring (regex) match, so any
# other name that contains "MVK" is selected as well.
# NOTE(review): confirm that is intended.
# The column is named explicitly (rather than taking the first column) and the
# original row order is kept so the plotting order is reproducible.
alleles = (
    misloc_summary.filter(pl.col("allele_0").str.contains("MVK"))
    .get_column("allele_0")
    .unique(maintain_order=True)
    .to_list()
)
# To plot a hand-picked set instead, assign the list directly, e.g.
# ["MVK_Leu41Pro", "MVK_Leu255Pro"].
plot_dir = f"{img_dir}/Images_Maxime"

In [38]:
# plot the 5th site from all images, organized by classifier
for var_allele in tqdm(alleles):
    gfp_nm = f"{var_allele}_GFP.png"

    # The paired reference allele is the text before the first underscore
    wt_allele = var_allele.split("_")[0]
    plot_img = (
        pm_df.filter(
            ((pl.col("Allele") == var_allele) | (pl.col("Allele") == wt_allele))
        )
        .unique()
        .sort(["Allele", "Batch", "Replicate"])
    )

    # Keep only plates that carry exactly two distinct control_type values for
    # this pair (presumably one reference and one variant type — TODO confirm).
    # Counting directly also handles the case where no plate matches.
    plate_type_pairs = plot_img.select(["Plate", "control_type"]).unique()
    types_per_plate = collections.Counter(
        plate_type_pairs.get_column("Plate").to_list()
    )
    plates_img = [
        plate for plate, n_types in types_per_plate.items() if n_types == 2
    ]
    plot_img = plot_img.filter(pl.col("Plate").is_in(plates_img))

    di.plotMultiImages(
        plot_img, "GFP", 0.99, 4, display=False, plotpath=f"{plot_dir}/{gfp_nm}"
    )

100%|██████████| 25/25 [03:32<00:00,  8.48s/it]
