In [2]:
import scanpy as sc
import re
import pandas as pd
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from matplotlib_venn import venn2

In [3]:
adata = sc.read_h5ad("sim_CNV_final_calls_2.h5ad")
h5ad_filename = 'sim_CNV_final_calls_2.h5ad'

In [4]:
adata

AnnData object with n_obs × n_vars = 10309 × 19186
    obs: 'n_genes_by_counts', 'total_counts', 'total_counts_ribo', 'pct_counts_ribo', 'total_counts_mt', 'pct_counts_mt', 'n_genes', 'n_counts', 'cell_type', 'simulated_cnvs', 'leiden', 'cnv_calls'
    var: 'gene_ids', 'feature_types', 'genome', 'mt', 'ribo', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts', 'n_cells', 'query_x', 'chromosome_x', 'start_x', 'end_x', 'strand_x', 'query_y', 'chromosome_y', 'start_y', 'end_y', 'strand_y', 'highly_variable', 'means', 'dispersions', 'dispersions_norm', 'query', 'chromosome', 'start', 'end', 'strand', 'chrom'
    uns: 'cell_type_colors', 'hvg', 'leiden', 'log1p', 'neighbors', 'pca', 'seed_value', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    layers: 'counts', 'lognorm'
    obsp: 'connectivities', 'distances'

In [5]:
adata.obs

Unnamed: 0,n_genes_by_counts,total_counts,total_counts_ribo,pct_counts_ribo,total_counts_mt,pct_counts_mt,n_genes,n_counts,cell_type,simulated_cnvs,leiden,cnv_calls
AAACCCAAGCGCCCAT-1,1005,1760.0,0.0,0.0,0.0,0.0,1005,1760.0,CD4 T cell,,0,
AAACCCAAGGTTCCGC-1,4101,14240.0,0.0,0.0,0.0,0.0,4101,14240.0,Dendritic,,11,10:101845599-102065349 (CN 3); 10:103453240-10...
AAACCCACAGAGTTGG-1,1742,4208.0,0.0,0.0,0.0,0.0,1742,4208.0,CD14 monocyte,,1,12:7089587-7115736 (CN 3); 19:57435325-5746666...
AAACCCACAGGTATGG-1,2122,4354.0,0.0,0.0,0.0,0.0,2122,4354.0,NK cell,,6,10:103453240-104309698 (CN 3); 10:110868890-11...
AAACCCACATAGTCAC-1,1521,2819.0,0.0,0.0,0.0,0.0,1521,2819.0,B cell,7:28921393-29108349 (gain),8,
...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTGCGTCGT-1,1245,2279.0,0.0,0.0,0.0,0.0,1245,2279.0,CD4 T cell,,2,
TTTGTTGGTGTCATGT-1,1245,2548.0,0.0,0.0,0.0,0.0,1245,2548.0,CD14 monocyte,,1,
TTTGTTGGTTTGAACC-1,1308,2468.0,0.0,0.0,0.0,0.0,1308,2468.0,CD8 T cell,,4,
TTTGTTGTCCAAGCCG-1,1576,3092.0,0.0,0.0,0.0,0.0,1576,3092.0,CD4 T cell,,0,


In [6]:
import re

def reformat_cnvs(adata):
    """
    Reformat CNVs in both 'cnv_calls' and 'simulated_cnvs' columns to match the format
    'chrom:start-end (gain/loss)', and save them to new columns:
    - 'formatted_cnv_calls'
    - 'formatted_simulated_cnvs'
    The originals remain unchanged.
    If CNVs are already formatted correctly, they will be copied directly to the new columns.
    """
    def is_correct_format(cnv_str):
        """
        Check if the CNV string is already in the correct format.
        Correct format: 'chrom:start-end (gain/loss)'
        """
        return bool(re.match(r"([XY0-9]+):(\d+)-(\d+)\s\((loss|gain)\)", cnv_str))

    def parse_cnv(cnv_str):
        """
        Reformat CNV string if it's not in the correct format.
        CNV string must be in format 'chrom:start-end (CN 0/1/3/4)'.
        """
        if not cnv_str or cnv_str.strip() == '':
            return None

        # Check if the CNV string is already in the correct format
        if is_correct_format(cnv_str):
            return cnv_str  # No need to reformat if already correct

        # Match the format like '22:19807132-29743868 (CN 0)'
        match = re.match(r"([XY0-9]+):(\d+)-(\d+)\s\(CN (\d+)\)", cnv_str)
        if match:
            chrom, start, end, cn_type = match.groups()
            if cn_type in {'0', '1'}:
                return f"{chrom}:{start}-{end} (loss)"
            elif cn_type in {'3', '4'}:
                return f"{chrom}:{start}-{end} (gain)"
        return None  # Return None if it's neither in the correct format nor reformatable

    # Apply to both columns
    adata.obs['formatted_cnv_calls'] = adata.obs['cnv_calls'].apply(parse_cnv)
    adata.obs['formatted_simulated_cnvs'] = adata.obs['simulated_cnvs'].apply(parse_cnv)

    # Check for any invalid formats and report
    for col in ['cnv_calls', 'simulated_cnvs']:
        formatted_col = f"formatted_{col}"
        invalid_rows = adata.obs[adata.obs[formatted_col].isnull() & adata.obs[col].notnull()]
        if not invalid_rows.empty:
            print(f"Warning: Invalid CNV formats found in '{col}' column:")
            print(invalid_rows[[col]])

    print("First few rows of the newly formatted CNVs:")
    print(adata.obs[['cnv_calls', 'formatted_cnv_calls', 'simulated_cnvs', 'formatted_simulated_cnvs']].head())

    return adata


adata = reformat_cnvs(adata)


                   cnv_calls
AAACCCAAGCGCCCAT-1          
AAACCCACATAGTCAC-1          
AAACGAACACAAGCTT-1          
AAACGAAGTTATGTCG-1          
AAACGAATCTACTTCA-1          
...                      ...
TTTGTTGGTGCGTCGT-1          
TTTGTTGGTGTCATGT-1          
TTTGTTGGTTTGAACC-1          
TTTGTTGTCCAAGCCG-1          
TTTGTTGTCTTACTGT-1          

[2630 rows x 1 columns]
                   simulated_cnvs
AAACCCAAGCGCCCAT-1               
AAACCCAAGGTTCCGC-1               
AAACCCACAGAGTTGG-1               
AAACCCACAGGTATGG-1               
AAACCCACATCCAATG-1               
...                           ...
TTTGTTGGTGCGTCGT-1               
TTTGTTGGTGTCATGT-1               
TTTGTTGGTTTGAACC-1               
TTTGTTGTCCAAGCCG-1               
TTTGTTGTCTTACTGT-1               

[9278 rows x 1 columns]
First few rows of the newly formatted CNVs:
                                                            cnv_calls  \
AAACCCAAGCGCCCAT-1                                                      
AAAC

In [7]:
adata.obs

Unnamed: 0,n_genes_by_counts,total_counts,total_counts_ribo,pct_counts_ribo,total_counts_mt,pct_counts_mt,n_genes,n_counts,cell_type,simulated_cnvs,leiden,cnv_calls,formatted_cnv_calls,formatted_simulated_cnvs
AAACCCAAGCGCCCAT-1,1005,1760.0,0.0,0.0,0.0,0.0,1005,1760.0,CD4 T cell,,0,,,
AAACCCAAGGTTCCGC-1,4101,14240.0,0.0,0.0,0.0,0.0,4101,14240.0,Dendritic,,11,10:101845599-102065349 (CN 3); 10:103453240-10...,10:101845599-102065349 (gain),
AAACCCACAGAGTTGG-1,1742,4208.0,0.0,0.0,0.0,0.0,1742,4208.0,CD14 monocyte,,1,12:7089587-7115736 (CN 3); 19:57435325-5746666...,12:7089587-7115736 (gain),
AAACCCACAGGTATGG-1,2122,4354.0,0.0,0.0,0.0,0.0,2122,4354.0,NK cell,,6,10:103453240-104309698 (CN 3); 10:110868890-11...,10:103453240-104309698 (gain),
AAACCCACATAGTCAC-1,1521,2819.0,0.0,0.0,0.0,0.0,1521,2819.0,B cell,7:28921393-29108349 (gain),8,,,7:28921393-29108349 (gain)
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGTTGGTGCGTCGT-1,1245,2279.0,0.0,0.0,0.0,0.0,1245,2279.0,CD4 T cell,,2,,,
TTTGTTGGTGTCATGT-1,1245,2548.0,0.0,0.0,0.0,0.0,1245,2548.0,CD14 monocyte,,1,,,
TTTGTTGGTTTGAACC-1,1308,2468.0,0.0,0.0,0.0,0.0,1308,2468.0,CD8 T cell,,4,,,
TTTGTTGTCCAAGCCG-1,1576,3092.0,0.0,0.0,0.0,0.0,1576,3092.0,CD4 T cell,,0,,,


In [8]:
def evaluate_cnv_predictions(adata, h5ad_filename):
    """
    Evaluate CNV predictions at the cell level by comparing predicted CNVs ('formatted_cnv_calls')
    to true CNVs ('formatted_simulated_cnvs') in the AnnData object.

    Prints confusion matrix, standard metrics, and saves a confusion matrix heatmap.

    Parameters:
    - adata: AnnData object containing 'formatted_simulated_cnvs' and 'formatted_cnv_calls' in .obs
    - h5ad_filename: string, the filename of the h5ad file (used to create the PNG name)
    """
    # Extract true and predicted CNV columns
    true_cnv = adata.obs['formatted_simulated_cnvs'].fillna('').astype(str)
    pred_cnv = adata.obs['formatted_cnv_calls'].fillna('').astype(str)

    # Define binary labels: 1 if CNV present, 0 if absent
    y_true = (true_cnv != '').astype(int)
    y_pred = (pred_cnv != '').astype(int)

    # Confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    tn, fp, fn, tp = cm.ravel()

    print("Confusion Matrix:")
    print(f"True Positives: {tp}")
    print(f"False Positives: {fp}")
    print(f"False Negatives: {fn}")
    print(f"True Negatives: {tn}")
    print()

    # Classification report
    report = classification_report(y_true, y_pred, target_names=['No CNV', 'CNV'], zero_division=0)
    print("Classification Report:")
    print(report)

    # Plot confusion matrix heatmap
    labels = ['No CNV', 'CNV']
    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.title('Confusion Matrix')

    # Create dynamic filename
    base_name = os.path.splitext(os.path.basename(h5ad_filename))[0]
    png_filename = f"{base_name}_confusion_matrix.png"

    # Save plot
    plt.tight_layout()
    plt.savefig(png_filename, dpi=300)
    plt.close()

    print(f"Confusion matrix saved as {png_filename}")

    # Return results as a dictionary
    return {
        'true_positives': tp, # Cells that have a CNV and were predicted to have a CNV
        'false_positives': fp, # Cells that don't have a CNV but were predicted to have a CNV
        'false_negatives': fn, # Cells that have a CNV but were predicted to not have a CNV
        'true_negatives': tn, # Cells that don't have a CNV and were predicted to not have a CNV
        'classification_report': report,
        'confusion_matrix_file': png_filename
    }


In [10]:
results_cell = evaluate_cnv_predictions(adata, h5ad_filename)

Confusion Matrix:
True Positives: 753
False Positives: 6926
False Negatives: 278
True Negatives: 2352

Classification Report:
              precision    recall  f1-score   support

      No CNV       0.89      0.25      0.40      9278
         CNV       0.10      0.73      0.17      1031

    accuracy                           0.30     10309
   macro avg       0.50      0.49      0.28     10309
weighted avg       0.81      0.30      0.37     10309

Confusion matrix saved as sim_CNV_final_calls_2_confusion_matrix.png


In [11]:


def evaluate_true_positives(adata):
    """
    Evaluate the true positives (cells with both true and predicted CNVs) and compare
    the predicted CNVs against the true CNVs based on multiple criteria.

    Parameters:
        adata (AnnData): The AnnData object containing true CNVs in 'formatted_simulated_cnvs' and predicted CNVs in 'formatted_cnv_calls'.

    Returns:
        dict: A dictionary containing the evaluation results.
    """
    # Initialize counters
    correct_chromosome = 0
    correct_region = 0
    correct_direction = 0
    exact_match = 0
    total_true_positives = 0

    # Iterate over each cell and check if it has a true CNV and predicted CNVs
    for cell in adata.obs_names:
        true_cnv = adata.obs.loc[cell, 'formatted_simulated_cnvs']
        predicted_cnvs = adata.obs.loc[cell, 'formatted_cnv_calls']

        if true_cnv and predicted_cnvs:  # If both true and predicted CNVs exist
            total_true_positives += 1
            true_chrom, true_start_end, true_direction = parse_cnv(true_cnv)

            # Split the predicted CNVs (could have multiple CNVs per cell)
            predicted_cnv_list = predicted_cnvs.split(';')

            for pred_cnv in predicted_cnv_list:
                pred_chrom, pred_start_end, pred_direction = parse_cnv(pred_cnv)

                # Evaluate chromosome match
                if true_chrom == pred_chrom:
                    correct_chromosome += 1

                # Evaluate region match (overlapping base pairs)
                true_start, true_end = map(int, true_start_end.split('-'))
                pred_start, pred_end = map(int, pred_start_end.split('-'))
                if (true_chrom == pred_chrom) and (max(true_start, pred_start) <= min(true_end, pred_end)):
                    correct_region += 1

                # Evaluate direction match (same gain/loss)
                if true_direction == pred_direction:
                    correct_direction += 1

                # Evaluate exact match (chromosome, region, and direction)
                if true_chrom == pred_chrom and true_start == pred_start and true_end == pred_end and true_direction == pred_direction:
                    exact_match += 1

    # Compile results
    results = {
        'total_true_positives': total_true_positives,
        'correct_chromosome': correct_chromosome,
        'correct_region': correct_region,
        'correct_direction': correct_direction,
        'exact_match': exact_match
    }

    print("Results:")
    print(f"True Positives: {total_true_positives}")
    print(f"Chromosome Matched: {correct_chromosome}")
    print(f"Region Matched: {correct_region}")
    print(f"Direction Matched: {correct_direction}")
    print(f"Exact Matched: {exact_match}")
    print()

    return results

def parse_cnv(cnv_str):
    """
    Parse a CNV string of the form 'chrom:start-end (direction)' into its components.

    Parameters:
        cnv_str (str): CNV string in the format 'chrom:start-end (direction)'.

    Returns:
        tuple: (chromosome, start-end range, direction)
    """
    chrom, rest = cnv_str.split(':')
    start_end, direction = rest.split(' (')
    direction = direction.strip(')')
    return chrom, start_end, direction

results_cell_cnv = evaluate_true_positives(adata)

Results:
True Positives: 753
Chromosome Matched: 18
Region Matched: 0
Direction Matched: 491
Exact Matched: 0



In [13]:


def compare_unique_cnvs(adata, h5ad_filename, predicted_cnv_col='formatted_cnv_calls', true_cnv_col='formatted_simulated_cnvs'):
    # Extract true CNVs
    true_cnvs = set()
    for cnv_str in adata.obs[true_cnv_col]:
        if cnv_str and cnv_str != 'None':
            match = re.match(r"([XY0-9]+):(\d+)-(\d+)\s\((gain|loss)\)", cnv_str.strip())
            if match:
                chrom, start, end, direction = match.groups()
                true_cnvs.add((chrom, int(start), int(end), direction))

    # Extract predicted CNVs
    predicted_cnvs = set()
    for cnv_str in adata.obs[predicted_cnv_col]:
        if cnv_str and cnv_str != 'None':
            for cnv_part in cnv_str.split(';'):
                cnv_part = cnv_part.strip()
                match = re.match(r"([XY0-9]+):(\d+)-(\d+)\s\((gain|loss)\)", cnv_part)
                if match:
                    chrom, start, end, direction = match.groups()
                    predicted_cnvs.add((chrom, int(start), int(end), direction))

    # Identify matches
    matching_cnvs = true_cnvs & predicted_cnvs

    # Venn diagram
    plt.figure(figsize=(6, 6))
    venn = venn2([true_cnvs, predicted_cnvs], set_labels=('True CNVs', 'Predicted CNVs'))
    plt.title("Unique CNV Comparison")

    # Create dynamic filename
    base_name = os.path.splitext(os.path.basename(h5ad_filename))[0]
    venn_filename = f"{base_name}_venn_diagram.png"
    plt.savefig(venn_filename)
    plt.close()
    print(f"Venn diagram saved as: {venn_filename}")

    return {
        'total_unique_true_cnvs': len(true_cnvs),
        'total_unique_predicted_cnvs': len(predicted_cnvs),
        'matching_cnvs_count': len(matching_cnvs),
        'matching_cnv_list': list(matching_cnvs),
        'true_cnv_list': list(true_cnvs),
        'predicted_cnv_list': list(predicted_cnvs),
    }



results_cnv = compare_unique_cnvs(adata, h5ad_filename)

print(f"Total unique true CNVs: {results_cnv['total_unique_true_cnvs']}")
print(f"Total unique predicted CNVs: {results_cnv['total_unique_predicted_cnvs']}")
print(f"Matching CNVs: {results_cnv['matching_cnvs_count']}")
print(f"List of matching CNVs: {results_cnv['matching_cnv_list']}")
print(f"List of true CNVs: {results_cnv['true_cnv_list']}")
print(f"List of predicted CNVs: {results_cnv['predicted_cnv_list']}")



Venn diagram saved as: sim_CNV_final_calls_2_venn_diagram.png
Total unique true CNVs: 8
Total unique predicted CNVs: 185
Matching CNVs: 0
List of matching CNVs: []
List of true CNVs: [('X', 99307506, 99749358, 'gain'), ('14', 37511489, 37706949, 'loss'), ('5', 17749068, 17849998, 'gain'), ('6', 9654515, 9827269, 'loss'), ('19', 33464603, 33571478, 'gain'), ('8', 51104463, 51489550, 'gain'), ('20', 34285250, 34774160, 'loss'), ('7', 28921393, 29108349, 'gain')]
List of predicted CNVs: [('19', 248551, 589881, 'gain'), ('11', 1888078, 1919703, 'gain'), ('13', 19674624, 20728731, 'gain'), ('10', 51062577, 73253762, 'gain'), ('2', 70093885, 70808643, 'gain'), ('3', 16586792, 16586792, 'gain'), ('16', 22365121, 22479121, 'gain'), ('6', 3118374, 4018358, 'gain'), ('9', 137618992, 138199777, 'gain'), ('10', 17137336, 17315421, 'gain'), ('12', 49002274, 49188736, 'gain'), ('16', 3432414, 3443649, 'gain'), ('12', 7346685, 7470811, 'gain'), ('12', 47075707, 47352376, 'gain'), ('12', 66767, 307735

In [14]:
# Evaluating the true and predicted CNVs on the same chromosomes
# Measuring their distance in base pairs from each other

def compare_cnv_distances_with_plot(
    adata,
    predicted_cnv_col='formatted_cnv_calls',
    true_cnv_col='formatted_simulated_cnvs',
    h5ad_filename='adata.h5ad'
):
    comparisons = []

    # Get base filename without extension
    base_name = os.path.splitext(os.path.basename(h5ad_filename))[0]

    # Extract true CNVs
    true_cnvs = []
    for cnv_str in adata.obs[true_cnv_col]:
        if cnv_str and cnv_str != 'None':
            match = re.match(r"([XY0-9]+):(\d+)-(\d+)\s\((gain|loss)\)", cnv_str.strip())
            if match:
                chrom, start, end, direction = match.groups()
                true_cnvs.append((chrom, int(start), int(end), direction))

    # Extract predicted CNVs
    predicted_cnvs = []
    for cnv_str in adata.obs[predicted_cnv_col]:
        if cnv_str and cnv_str != 'None':
            for cnv_part in cnv_str.split(';'):
                cnv_part = cnv_part.strip()
                match = re.match(r"([XY0-9]+):(\d+)-(\d+)\s\((gain|loss)\)", cnv_part)
                if match:
                    chrom, start, end, direction = match.groups()
                    predicted_cnvs.append((chrom, int(start), int(end), direction))

    # Compare all true and predicted CNVs on the same chromosome
    for true_cnv in true_cnvs:
        t_chrom, t_start, t_end, t_dir = true_cnv
        for pred_cnv in predicted_cnvs:
            p_chrom, p_start, p_end, p_dir = pred_cnv
            if t_chrom == p_chrom:
                # Calculate distance between the intervals
                if p_end < t_start:
                    distance = t_start - p_end
                elif t_end < p_start:
                    distance = p_start - t_end
                else:
                    distance = 0  # They overlap
                comparisons.append({
                    'chromosome': t_chrom,
                    'true_cnv': true_cnv,
                    'predicted_cnv': pred_cnv,
                    'distance': distance
                })

    # Plot histogram of distances
    distances = [comp['distance'] for comp in comparisons]
    if distances:
        plt.figure(figsize=(8, 5))
        plt.hist(distances, bins=30, color='skyblue', edgecolor='black')
        plt.xlabel('Distance between CNVs (bp)')
        plt.ylabel('Count')
        plt.title('Histogram of CNV Distances (Same Chromosome)')
        plot_filename = f"{base_name}_cnv_distance_histogram.png"
        plt.savefig(plot_filename)
        plt.close()
        print(f"Histogram saved as: {plot_filename}")
    else:
        print("No distances to plot.")

    # Sort comparisons by distance (optional)
    comparisons.sort(key=lambda x: x['distance'])

    return comparisons


In [15]:
distances = compare_cnv_distances_with_plot(
    adata,
    predicted_cnv_col='formatted_cnv_calls',
    h5ad_filename=h5ad_filename
)

print(f"Found {len(distances)} comparisons on matching chromosomes.")
for comp in distances:
    print(f"Chromosome: {comp['chromosome']}, "
          f"True CNV: {comp['true_cnv']}, "
          f"Predicted CNV: {comp['predicted_cnv']}, "
          f"Distance: {comp['distance']} bp")


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136832988 bp
Chromosome: 5, True CNV: ('5', 17749068, 17849998, 'gain'), Predicted CNV: ('5', 154682986, 154818492, 'gain'), Distance: 136