# Mycobacterium tuberculosis validation 

## Import libraries

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import json
import os
import matplotlib.pyplot as plt

## Set input & output filepaths

In [None]:
# --- Inputs ---
# Tab-separated clinical validation table (loaded with pd.read_csv in the preprocessing cell)
km_tb_tsv_filepath = "/data/bnf/dev/ryan/validation/tb/tb_clinical_validation.tsv"
# Per-sample result directories; files are looked up as <directory>/<sample id><suffix>
tbprofiler_result_dir = "/fs1/results_dev/jasen/mtuberculosis/tbprofiler_mergedb/"
quast_result_dir = "/fs1/results_dev/jasen/mtuberculosis/quast/"
paqc_result_dir = "/fs1/results_dev/jasen/mtuberculosis/postalignqc/"
# --- Outputs ---
# Clinical table merged with the AMR, QUAST and postalignqc columns
tb_validation_csv_fpath = "/data/bnf/dev/ryan/validation/tb/tb_validation.csv"
# Per-comparison outcome percentages and the stacked bar chart drawn from them
tb_validation_accuracy_output = "/data/bnf/dev/ryan/validation/tb/tb_validation_accuracy.csv"
tb_validation_amr_fig = "/data/bnf/dev/ryan/validation/tb/tb_validation.png"
# Accuracy, sensitivity and specificity per comparison
tb_validation_ass_output = "/data/bnf/dev/ryan/validation/tb/tb_validation_ass.csv"
# QUAST metrics per sample and their box plots
tb_validation_qc_output = "/data/bnf/dev/ryan/validation/tb/tb_validation_qc.csv"
tb_validation_qc_fig = "/data/bnf/dev/ryan/validation/tb/tb_validation_qc.png"
# postalignqc coverage values per sample and their box plots
tb_validation_paqc_output = "/data/bnf/dev/ryan/validation/tb/tb_validation_paqc.csv"
tb_validation_paqc_fig = "/data/bnf/dev/ryan/validation/tb/tb_validation_paqc.png"

## Preprocess data

In [None]:
# Load the tab-separated clinical validation table
km_tb_df = pd.read_csv(km_tb_tsv_filepath, sep='\t')

# Substrings that mark a column for removal
patterns = ["tb-profiler", "Mykrobe", "NGS"]

# Drop every column whose name contains at least one of the substrings
columns_to_drop = [
    column
    for column in km_tb_df.columns
    if any(pattern in column for pattern in patterns)
]
km_tb_df = km_tb_df.drop(columns=columns_to_drop)

# Turn cells holding exactly "eh" into NaN.
# NOTE(review): "ej testad" values are NOT replaced here; they stay in the table and
# are counted as "Not tested" in the accuracy cell. Confirm that "eh" is the
# placeholder actually used in the TSV.
km_tb_df = km_tb_df.replace(["eh"], np.nan)

# print(km_tb_df.head())
# print(km_tb_df.columns.tolist())

## Define functions

### Function to get AMR information

In [None]:
def get_sample_pred_array(analysis_result_dir, sample_id, suffix):
    """Read a sample's TBProfiler JSON and summarise its resistance calls.

    The file is expected at ``<analysis_result_dir>/<sample_id><suffix>``.

    Args:
        analysis_result_dir: Directory holding the TBProfiler JSON files.
        sample_id: Sample identifier; may be a string or a number.
        suffix: File name suffix appended to the sample id.

    Returns:
        A dict with ``"Labnummer"`` (the sample id as passed in),
        ``"main_lineage"`` and one entry per drug that has a resistance
        annotation: ``"R"`` for confidence ``"Assoc w R"`` and ``"I"`` for
        ``"Assoc w R - Interim"``. Returns ``None`` when the file does not
        exist.
    """
    # Format instead of concatenating so that non-string ids (for example when
    # pandas parses the Labnummer column as integers) do not raise a TypeError.
    pred_filepath = os.path.join(analysis_result_dir, f"{sample_id}{suffix}")
    if not os.path.exists(pred_filepath):
        # print(f"{sample_id} does not exist in {analysis_result_dir}")
        return None

    with open(pred_filepath, "r") as fin:
        pred_dict = json.load(fin)

    res_drugs_dict = {}
    for var in pred_dict["dr_variants"]:
        for annot in var["annotation"]:
            confidence = annot["confidence"]
            if confidence == "Assoc w R":
                # A full resistance call always wins, also over an earlier "I"
                res_drugs_dict[annot["drug"]] = "R"
            elif confidence == "Assoc w R - Interim":
                # An interim call never downgrades an existing entry
                res_drugs_dict.setdefault(annot["drug"], "I")

    return {
        "Labnummer": sample_id,
        "main_lineage": pred_dict["main_lineage"],
        **res_drugs_dict,
    }


### Function to get QC information 

In [None]:
def get_sample_qc_array(analysis_result_dir, sample_id, suffix, qc_params):
    """Read the selected QC columns from a sample's tab-separated QC report.

    The file is expected at ``<analysis_result_dir>/<sample_id><suffix>``.

    Args:
        analysis_result_dir: Directory holding the QC reports.
        sample_id: Sample identifier; may be a string or a number.
        suffix: File name suffix appended to the sample id.
        qc_params: Names of the report columns to extract.

    Returns:
        A dict with ``"Labnummer"`` (the sample id as passed in) plus one
        entry per requested column, taken from the first data row of the
        report. Returns ``None`` when the file does not exist.
    """
    # Format instead of concatenating so that non-string ids (for example when
    # pandas parses the Labnummer column as integers) do not raise a TypeError.
    qc_filepath = os.path.join(analysis_result_dir, f"{sample_id}{suffix}")
    if not os.path.exists(qc_filepath):
        # print(f"{sample_id} does not exist in {analysis_result_dir}")
        return None

    tb_qc_df = pd.read_csv(qc_filepath, delimiter='\t')
    # Round-trip through JSON so the values come back as plain Python types;
    # only the first record is kept.
    qc_array = json.loads(tb_qc_df[qc_params].to_json(orient="records"))[0]
    return {"Labnummer": sample_id, **qc_array}

### Function to get postalignqc information

In [None]:
def get_sample_paqc_array(analysis_result_dir, sample_id, suffix, qc_params):
    """Read the ``pct_above_x`` values from a sample's postalignqc JSON.

    The file is expected at ``<analysis_result_dir>/<sample_id><suffix>``.

    Args:
        analysis_result_dir: Directory holding the postalignqc JSON files.
        sample_id: Sample identifier; may be a string or a number.
        suffix: File name suffix appended to the sample id.
        qc_params: Keys of ``pct_above_x`` to keep. Keys that are missing
            from the file are simply absent from the result. Pass ``None``
            to keep every key found in the file.

    Returns:
        A dict with ``"Labnummer"`` (the sample id as passed in) plus the
        selected ``pct_above_x`` entries, or ``None`` when the file does not
        exist.
    """
    # Format instead of concatenating so that non-string ids (for example when
    # pandas parses the Labnummer column as integers) do not raise a TypeError.
    paqc_filepath = os.path.join(analysis_result_dir, f"{sample_id}{suffix}")
    if not os.path.exists(paqc_filepath):
        # print(f"{sample_id} does not exist in {analysis_result_dir}")
        return None

    with open(paqc_filepath, "r") as fin:
        qc_array = json.load(fin)["pct_above_x"]

    # Honour qc_params the same way get_sample_qc_array does; previously the
    # argument was ignored and every key in the file was returned.
    if qc_params is not None:
        wanted = set(qc_params)
        qc_array = {key: value for key, value in qc_array.items() if key in wanted}

    return {"Labnummer": sample_id, **qc_array}

## Add amr columns to dataframe

In [None]:
# Columns to be filled from the TBProfiler results: one per drug, plus the main lineage
amr_drugs = [
    'isoniazid', 'rifampicin', 'ethambutol', 'amikacin', 'ofloxacin',
    'pyrazinamide', 'linezolid', 'streptomycin', 'kanamycin', 'moxifloxacin',
    'levofloxacin', 'rifabutin', 'ethionamide', 'capreomycin', 'cycloserine',
    'PAS', 'clofazimine', 'bedaquiline', 'delamanid', 'main_lineage',
]
# Create all of these columns in one go, filled with NaN.
# NOTE(review): 'PAS' is also the reference column name in column_pairs; if the
# clinical table already has a 'PAS' column, it is overwritten with NaN here - confirm.
km_tb_df = km_tb_df.assign(**{drug: np.nan for drug in amr_drugs})

## Add QC columns to dataframe

In [None]:
# Columns taken from each sample's QUAST report
qc_params = ['# contigs', 'Largest contig', 'Total length', 'N50']
# Create all of these columns in one go, filled with NaN
km_tb_df = km_tb_df.assign(**{qc_param: np.nan for qc_param in qc_params})

## Add postalignqc columns to dataframe

In [None]:
# Keys of the postalignqc "pct_above_x" mapping that get a column of their own
paqc_params = ["1", "10", "30", "100", "250", "500"]
# Create all of these columns in one go, filled with NaN
km_tb_df = km_tb_df.assign(**{paqc_param: np.nan for paqc_param in paqc_params})

## Loop through sample ids' JASEN output and add columns

In [None]:
def _store_sample_values(df, row_mask, values, as_object=False):
    """Write one sample's values into the rows of *df* selected by *row_mask*.

    Args:
        df: DataFrame to update in place.
        row_mask: Boolean Series selecting the sample's row(s).
        values: Mapping of column name to value; the 'Labnummer' entry is skipped.
        as_object: Cast an existing target column to object dtype first, so
            that strings can be stored in a column created as float NaN.
    """
    for column, value in values.items():
        if column == 'Labnummer':
            continue
        # Cast only when needed, and only when the column exists. A key with no
        # placeholder column (e.g. a drug missing from amr_drugs) used to raise
        # KeyError here; now .loc below creates that column instead.
        if as_object and column in df.columns and df[column].dtype != object:
            df[column] = df[column].astype(object)
        df.loc[row_mask, column] = value


# Rows without a Labnummer cannot be matched to a result file, so they are skipped
for sample_id in km_tb_df['Labnummer'].dropna():
    sample_pred_array = get_sample_pred_array(tbprofiler_result_dir, sample_id, "_tbprofiler.json")
    sample_qc_array = get_sample_qc_array(quast_result_dir, sample_id, "_quast.tsv", qc_params)
    sample_paqc_array = get_sample_paqc_array(paqc_result_dir, sample_id, "_bwa.qc", paqc_params)

    # Every result dict carries the same Labnummer, so one row lookup is enough
    row_mask = km_tb_df['Labnummer'] == sample_id

    if sample_pred_array:
        # Resistance calls and lineage are strings, hence the object cast
        _store_sample_values(km_tb_df, row_mask, sample_pred_array, as_object=True)
    if sample_qc_array:
        _store_sample_values(km_tb_df, row_mask, sample_qc_array)
    if sample_paqc_array:
        _store_sample_values(km_tb_df, row_mask, sample_paqc_array)

# Save the merged table
km_tb_df.to_csv(tb_validation_csv_fpath, index=False)
# print(km_tb_df.head())

## Set column pairs

In [None]:
# Column pairs compared in the accuracy cell: (reference column, predicted column).
# The reference column is matched against "R", "S" and "ej testad" there, so it is
# presumably the laboratory result from the clinical table (verify against the TSV).
# The predicted column is one of the amr_drugs columns filled from TBProfiler output.
# Both isoniazid concentrations, and both rifampicin reference columns, are compared
# against the same predicted column.
column_pairs = [
    ('Isoniazid 0.1', 'isoniazid'),
    ('Isoniazid 0.4', 'isoniazid'),
    ('Genotypisk Rifampicin', 'rifampicin'),
    ('Rifampicin 1.0', 'rifampicin'),
    ('Etambutol 2.5/5.0', 'ethambutol'),
    ('Amikacin 1.0', 'amikacin'),
    ('Ofloxacin 2.0', 'ofloxacin'),
    ('Pyrazinamid 100', 'pyrazinamide'),
    ('Linezolid', 'linezolid'),
    ('Streptomycin', 'streptomycin'),
    ('Kanamycin', 'kanamycin'),
    ('Moxifloxacin', 'moxifloxacin'),
    ('Levofloxacin', 'levofloxacin'),
    ('Rifabutin', 'rifabutin'),
    ('Ethionamid', 'ethionamide'),
    ('Capreomycin', 'capreomycin'),
    ('Cykloserin', 'cycloserine'),
    # NOTE(review): both names are identical, so this pair compares the column with
    # itself. The 'PAS' entry in amr_drugs also resets that column to NaN before
    # results are merged. Confirm the intended reference and predicted columns.
    ('PAS', 'PAS'),
    ('Clofazimine', 'clofazimine'),
    ('Bedakilin', 'bedaquiline'),
    ('Delamanid', 'delamanid')
]


## Plot AMR calling accuracy

In [None]:
# Outcome categories; this order is also the stacking order of the bars
outcome_columns = ['True positives', 'True negatives', 'False positives', 'False negatives', 'Not tested', 'Unknown']

# One row per comparison, holding the share (%) of all samples in each category
results = pd.DataFrame(columns=outcome_columns)

ass_loa = []
n_counts = []
for reference_col, predicted_col in column_pairs:
    reference = km_tb_df[reference_col]
    predicted = km_tb_df[predicted_col]
    comparison = f'{reference_col} vs {predicted_col}'

    has_prediction = predicted.notna()
    reference_resistant = reference == "R"
    reference_susceptible = reference == "S"

    # Identical values count as a hit, and so does reference "R" with predicted "I"
    tp = ((reference == predicted) | (reference_resistant & (predicted == "I"))).sum()
    tn = (reference_susceptible & ~has_prediction).sum()
    fp = (reference_susceptible & has_prediction).sum()
    fn = (reference_resistant & ~has_prediction).sum()
    not_tested = reference.isin(["ej testad", "Ej testad", np.nan]).sum()
    unknown = (~reference.isin(["ej testad", "Ej testad", "R", "S", np.nan])).sum()

    # The sample count shown next to each bar is the full table length
    total_count = len(km_tb_df)
    n_counts.append(total_count)

    # Each metric falls back to 0.0 when its denominator is zero
    assessed = tp + tn + fp + fn
    accuracy = 0.0 if assessed == 0 else (tp + tn) / assessed
    resistant_total = tp + fn
    sensitivity = 0.0 if resistant_total == 0 else tp / resistant_total
    susceptible_total = tn + fp
    specificity = 0.0 if susceptible_total == 0 else tn / susceptible_total

    ass_loa.append({
        "comparison": f"{reference_col} vs {predicted_col}",
        "accuracy": accuracy,
        "sensitivity": sensitivity,
        "specificity": specificity
    })
    results.loc[comparison] = [
        count / total_count * 100
        for count in (tp, tn, fp, fn, not_tested, unknown)
    ]

# Save the percentage table
results.to_csv(tb_validation_accuracy_output, index=True)

# Stacked horizontal bars, one per comparison
fig, ax = plt.subplots(figsize=(6, 8))
results.plot.barh(stacked=True, ax=ax, color=['green', 'gold', 'orange', 'red', 'darkblue', 'lightblue'])

# Axis labels, legend outside the plot area, and room on the right for the n= labels
ax.set_xlabel('Percentage (%)')
ax.set_title('AMR calling accuracy')
ax.legend(outcome_columns, bbox_to_anchor=(1.05, 1), loc='upper left')
ax.set_xlim(0, 110)
for bar_position, sample_count in zip(range(len(results)), n_counts):
    ax.text(100, bar_position, f'n={sample_count}', ha='left', va='center', fontsize=10, color='black')
plt.savefig(tb_validation_amr_fig, dpi=600, bbox_inches='tight')
plt.show()

# Save accuracy / sensitivity / specificity per comparison (written without a header row)
ass_df = pd.DataFrame(ass_loa)
ass_df.to_csv(tb_validation_ass_output, index=False, header=False)

## Plot QC scores

In [None]:
# QUAST metrics per sample, keyed by Labnummer
qc_results = km_tb_df[['Labnummer', *qc_params]]

# One box plot per metric on a 2x2 grid; the axes array is made 1-D for simple indexing
fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(6, 8))
axs = axs.ravel()

for position, qc_param in enumerate(qc_params):
    panel = axs[position]
    sns.boxplot(y=qc_results[qc_param], ax=panel)
    panel.set_title(qc_param)
    panel.set_xlabel('Value')

# Tag the panels A-D just outside their top-left corner
labels = list('ABCD')
for ax, label in zip(axs, labels):
    ax.text(-0.1, 1.05, label, transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top', ha='right')

plt.tight_layout()
plt.savefig(tb_validation_qc_fig, dpi=600, bbox_inches='tight')
plt.show()

# Save the plotted values
qc_results.to_csv(tb_validation_qc_output, index=False)

## Plot read depth

In [None]:
# postalignqc coverage values per sample, keyed by Labnummer
paqc_results = km_tb_df[['Labnummer', *paqc_params]]

# One box plot per depth threshold on a 3x2 grid; the axes array is made 1-D for simple indexing
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(10, 8))
axs = axs.ravel()

for position, paqc_param in enumerate(paqc_params):
    panel = axs[position]
    sns.boxplot(y=paqc_results[paqc_param], ax=panel)
    panel.set_title(f"Coverage depth above {paqc_param}X")
    panel.set_xlabel(paqc_param)
    # Same y-range on every panel so they can be compared directly
    panel.set_ylim(0, 105)

# Tag the panels A-F just outside their top-left corner
labels = list('ABCDEF')
for ax, label in zip(axs, labels):
    ax.text(-0.1, 1.05, label, transform=ax.transAxes,
            fontsize=16, fontweight='bold', va='top', ha='right')

plt.tight_layout()
plt.savefig(tb_validation_paqc_fig, dpi=600, bbox_inches='tight')
plt.show()

# Save the plotted values
paqc_results.to_csv(tb_validation_paqc_output, index=False)

## Inspect two columns

In [None]:
# Optional debugging aid (disabled): uncomment to print a reference column next to
# its predicted column for each sample, and to write that subset to a CSV in the
# current working directory.
# columns_to_print = ['Linezolid', 'linezolid', 'Labnummer']
# print(km_tb_df[columns_to_print])
# km_tb_df[columns_to_print].to_csv(f"./tb_validation_{columns_to_print[1]}.csv", index=False)
