In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
import pandas as pd

# Visualizations for the publication on CRC biomarker prediction with Transformer

1. AUROC heatmaps of the cross-cohort results
2. Donut plots giving an overview of the cohorts

## AUROC heatmaps

### Load data from the results Excel file into a DataFrame

In [None]:
# Directory the exported figures are written to. It is created up front because
# the (currently commented-out) `savefig` calls below do not create missing
# parent directories and would otherwise fail on a fresh checkout.
figure_path = Path('figures')
figure_path.mkdir(parents=True, exist_ok=True)

In [None]:
# --- identify the runs to visualize
norm = 'histaugan'
target = 'isMSIH'

# Sheet of the results workbook to read; uncomment exactly one alternative to switch.
sheet_name = 'MSI high (single cohorts)'
# sheet_name = 'BRAF (single cohorts)'
# sheet_name = 'KRAS (single cohorts)'
# sheet_name = 'MSI high (AttentionMIL)'
# sheet_name = 'MSI high (Echle et al.)'

# Read the selected sheet into a DataFrame.
results_path = Path('experimental_results.xlsx')
results = pd.read_excel(results_path, sheet_name=sheet_name)

# Cohorts shown in the heatmap: either inferred from the sheet or listed by hand.
# cohorts = results['Test'].unique().tolist()  # infer from dataframe
cohorts = [  # MSI-H cohorts
    'CPTAC', 'DACHS', 'DUSSEL', 'Epi700', 'ERLANGEN', 'FOXTROT', 'MCO',
    'MECC', 'MUNICH', 'QUASAR', 'RAINBOW', 'TCGA', 'TRANSCOT',
    'YCR-BCIP-resections',
]
# cohorts = ['DACHS', 'Epi700', 'MCO', 'QUASAR', 'RAINBOW', 'TCGA']  # BRAF/KRAS cohorts
print(cohorts)

In [None]:
# Reduce the results table to the three columns the heatmaps read:
# training cohort, test cohort and the mean test AUROC.
columns_to_keep = ['Train', 'Test', 'auroc/test mean']
unused_columns = [col for col in results.columns if col not in columns_to_keep]
results = results.drop(columns=unused_columns)

In [None]:
# Display the reduced table for a visual check; after the previous cell only
# the 'Train', 'Test' and 'auroc/test mean' columns can remain.
results

In [None]:
# data stats
# Size of each cohort, keyed by the cohort identifier used throughout this notebook.
cohort_size = dict([
    ('CHINA', 35),
    ('CPTAC', 105),
    ('DACHS', 2039),
    ('DUSSEL', 196),
    ('Epi700', 603),
    ('ERLANGEN', 458),
    ('FOXTROT', 702),
    ('MAINZ', 86),
    ('MCO', 1388),
    ('MECC', 683),
    ('MUNICH', 287),
    ('QUASAR', 1774),
    ('RAINBOW', 2068),
    ('TCGA', 426),
    ('TRANSCOT', 1972),
    ('YCR-BCIP-resections', 867),
])

# Display label of each cohort. Most cohorts are shown under their identifier;
# only the three entries below are renamed for the figures.
cohort_label = {cohort: cohort for cohort in cohort_size}
cohort_label.update({
    'CHINA': 'GUANGZHOU',
    'RAINBOW': 'NLCS',
    'YCR-BCIP-resections': 'YCR-BCIP',
})

In [None]:
# Order the selected cohorts from largest to smallest and derive the matching
# size and display-label lists in that same order.
cohorts_ordered = sorted(cohorts, key=cohort_size.__getitem__, reverse=True)
size_ordered = list(map(cohort_size.__getitem__, cohorts_ordered))
label_ordered = list(map(cohort_label.__getitem__, cohorts_ordered))

### Plot heatmaps

Plot the large heatmap, with a bar chart of the number of samples per cohort next to it

In [None]:
%%capture
# Build the matrix of metric values: heatmap[j, i] holds the 'auroc/test mean'
# of the run trained on cohorts_ordered[i] and tested on cohorts_ordered[j].
# Train/test pairs that do not occur in `results` keep the initial value 0.
heatmap = np.zeros((len(cohorts_ordered), len(cohorts_ordered)))
for i, c_i in enumerate(cohorts_ordered):
    trained_on_i = results['Train'] == c_i
    for j, c_j in enumerate(cohorts_ordered):
        # One combined mask with .loc: chaining two boolean selections
        # (results[mask_a][mask_b]) makes pandas reindex the second mask
        # and emit a UserWarning on every iteration.
        scores = results.loc[trained_on_i & (results['Test'] == c_j), 'auroc/test mean']
        if not scores.empty:
            # Several matching rows: use the first one.
            heatmap[j, i] = scores.iloc[0]

In [None]:
name = sheet_name
labels = False  # show labels in heatmap

n_cohorts = len(cohorts_ordered)
n_bars = len(size_ordered)
# Transposed so that rows are training cohorts and columns are test cohorts.
auroc_grid = heatmap.T

# Left panel: heatmap; right panel: cohort sizes as horizontal bars.
# Alternative layout: figsize=(7, 5) with 'width_ratios': [6, 1]
fig, ax = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=((6/5)*5, 5),
    gridspec_kw={'width_ratios': [5, 1], 'wspace': 0, 'hspace': 0},
)
im = ax[0].imshow(auroc_grid, vmin=0.55, vmax=1, cmap='plasma')

if labels:
    # Write the value of every cell into the heatmap.
    for i, j in np.ndindex(n_cohorts, n_cohorts):
        ax[0].text(j, i, f'{auroc_grid[i][j]:.2f}',
                   ha="center", va="center", color="w", fontsize=14)

tick_positions = np.arange(len(label_ordered))
ax[0].set_xticks(tick_positions, labels=label_ordered)
ax[0].set_yticks(tick_positions, labels=label_ordered)
plt.setp(ax[0].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")

fig.tight_layout()

# ax[0].set(
#     title=f"AUROC scores (target: {target})",
# )
# Remove the frame around the heatmap.
for side in ('top', 'right', 'left', 'bottom'):
    ax[0].spines[side].set_visible(False)

ax[0].set_ylabel('Train')
ax[0].set_xlabel('Test')

# Bar plot of the cohort sizes. Bars are placed at y = n, ..., 1 so that the
# first (largest) cohort is the topmost bar, next to the first heatmap row.
ax[1].barh(range(n_bars, 0, -1), size_ordered, color='slategray')

ax[1].axes.get_yaxis().set_visible(False)
ax[1].set_xticks((0, 1000, 2000))
plt.setp(ax[1].get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
for side in ('top', 'right'):
    ax[1].spines[side].set_visible(False)
ax[1].set(ylim=[0.5, n_bars + 0.5])
ax[1].set_xlabel('Counts')

# fig.savefig(figure_path / f'results_heatmap_{name}_new_cohorts.svg', format='svg', bbox_inches = 'tight', pad_inches = 0)
plt.show()

Plot smaller heatmaps for comparison with other methods

In [None]:
# Subset of cohorts shown in the small comparison heatmaps.
cohorts = ['RAINBOW', 'DACHS', 'QUASAR', 'TCGA']

# Order the subset from largest to smallest cohort and derive the matching
# size and display-label lists in that same order.
cohorts_ordered = sorted(cohorts, key=cohort_size.__getitem__, reverse=True)
size_ordered = list(map(cohort_size.__getitem__, cohorts_ordered))
label_ordered = list(map(cohort_label.__getitem__, cohorts_ordered))

In [None]:
%%capture
# Build the matrix of metric values for the cohort subset: heatmap[j, i] holds
# the 'auroc/test mean' of the run trained on cohorts_ordered[i] and tested on
# cohorts_ordered[j]. Pairs that do not occur in `results` keep the value 0.
heatmap = np.zeros((len(cohorts_ordered), len(cohorts_ordered)))
for i, c_i in enumerate(cohorts_ordered):
    trained_on_i = results['Train'] == c_i
    for j, c_j in enumerate(cohorts_ordered):
        # One combined mask with .loc: chaining two boolean selections
        # (results[mask_a][mask_b]) makes pandas reindex the second mask
        # and emit a UserWarning on every iteration.
        scores = results.loc[trained_on_i & (results['Test'] == c_j), 'auroc/test mean']
        if not scores.empty:
            # Several matching rows: use the first one.
            heatmap[j, i] = scores.iloc[0]

In [None]:
# Only used in the file name of the (commented-out) export below.
model = 'echle'
name = 'histaugan'

n_cells = len(cohorts)
# Transposed so that rows are training cohorts and columns are test cohorts.
cell_values = heatmap.T

fig, ax = plt.subplots(figsize=(1.875, 1.875))
im = ax.imshow(cell_values, vmin=0.55, vmax=1, cmap='plasma')

# Write the value of every cell into the heatmap.
for i, j in np.ndindex(n_cells, n_cells):
    ax.text(j, i, f'{cell_values[i][j]:.2f}',
            ha="center", va="center", color="w", fontsize=10)

tick_positions = np.arange(n_cells)
ax.set_xticks(tick_positions, labels=label_ordered)
ax.set_yticks(tick_positions, labels=label_ordered)

# Remove the frame around the heatmap.
for side in ('top', 'right', 'left', 'bottom'):
    ax.spines[side].set_visible(False)

# ax.set_xlabel('Train')
# ax.set_ylabel('Transformer')

ax.get_xaxis().set_visible(True)

plt.setp(ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor")
# fig.savefig(figure_path / f'results_{model}_{name}_zoom.svg', format='svg', bbox_inches = 'tight', pad_inches = 0)

plt.show()

## Donut plots for cohort overview

In [None]:
# All cohorts. The position of a cohort in this list determines its colour in
# the donut chart, so the order must stay fixed.
cohorts_all = [
    'CPTAC',
    'DACHS',
    'DUSSEL',
    'ERLANGEN',
    'Epi700',
    'FOXTROT',
    'CHINA',
    'MAINZ',
    'MCO',
    'MECC',
    'MUNICH',
    'RAINBOW',
    'QUASAR',
    'TCGA',
    'TRANSCOT',
    'YCR-BCIP-resections',
]

In [None]:
# Choose the target of the donut chart and the cohorts that belong to it.
# Uncomment exactly one `target` and the matching `cohorts` list.

# --- MSI cohorts
# target = 'MSI'
# cohorts = ['CPTAC', 'DACHS', 'DUSSEL', 'ERLANGEN', 'Epi700', 'FOXTROT', 'CHINA', 'MAINZ', 'MCO', 'MECC', 'MUNICH', 'RAINBOW', 'QUASAR', 'TCGA', 'TRANSCOT', 'YCR-BCIP-resections']

# --- BRAF/KRAS cohorts
# target = 'BRAF'
target = 'KRAS'
cohorts = [
    'DACHS',
    'Epi700',
    'MCO',
    'QUASAR',
    'RAINBOW',
    'TCGA',
]

In [None]:
# Cohort sizes for every target of the donut chart, keyed by the `target` value
# set in the previous cell. Keeping all three tables active (instead of
# commenting two of them out) means the counts cannot drift away from the
# selected target.
cohort_size_by_target = {
    # --- MSI cohorts
    'MSI': {
        'CHINA': 35,
        'CPTAC': 105,
        'DACHS': 2039,
        'DUSSEL': 196,
        'Epi700': 603,
        'ERLANGEN': 458,
        'FOXTROT': 702,
        'MAINZ': 86,
        'MCO': 1388,
        'MECC': 683,
        'MUNICH': 287,
        'QUASAR': 1774,
        'RAINBOW': 2068,
        'TCGA': 426,
        'TRANSCOT': 1972,
        'YCR-BCIP-resections': 867,
    },
    # --- BRAF cohorts
    'BRAF': {
        'DACHS': 2075,
        'Epi700': 641,
        'MCO': 1388,
        'QUASAR': 1477,
        'RAINBOW': 2038,
        'TCGA': 500,
    },
    # --- KRAS cohorts
    'KRAS': {
        'DACHS': 2068,
        'Epi700': 645,
        'MCO': 1390,
        'QUASAR': 1436,
        'RAINBOW': 2033,
        'TCGA': 500,
    },
}

# Select the table that belongs to the chosen target. Fail with a clear message
# if `target` still holds a value from another section of the notebook.
if target not in cohort_size_by_target:
    raise KeyError(
        f"no cohort sizes defined for target {target!r}; "
        f"expected one of {sorted(cohort_size_by_target)}"
    )

# NOTE: this rebinds `cohort_size`, which the heatmap cells further up also
# read; re-run the earlier "data stats" cell before re-running those cells.
cohort_size = cohort_size_by_target[target]

In [None]:
# Total over all entries of the currently active cohort-size table (sanity check).
sum(cohort_size.values())

In [None]:
# Order the selected cohorts by their display label and derive the matching
# size and label lists in that same order.
cohorts_ordered = sorted(cohorts, key=lambda x: cohort_label[x])
size_ordered = [cohort_size[c] for c in cohorts_ordered]
label_ordered = [cohort_label[c] for c in cohorts_ordered]
# Position of each cohort in `cohorts_all`; the donut plot turns it into a colour.
cohort_ids = [cohorts_all.index(c) for c in cohorts_ordered]

# Report the totals of the cohorts that are actually plotted. Summing
# `size_ordered` (rather than every entry of `cohort_size`) keeps the number
# correct when the size table contains cohorts that were not selected.
print(f'{len(cohorts_ordered)} cohorts with {sum(size_ordered)} patients')
print(label_ordered)
print(size_ordered)

In [None]:
# Donut chart of the cohort sizes: a pie chart whose wedges only span the outer
# half of the radius.
fig, ax = plt.subplots(figsize=(6, 6), subplot_kw={'aspect': 'equal'})

# Map every cohort to a colour through its position in `cohorts_all`.
cmap = plt.get_cmap("twilight")
color_positions = np.array(cohort_ids) / len(cohorts_all)
colors = cmap(color_positions)

ring_style = {'width': 0.5, 'linewidth': 2, 'edgecolor': 'w'}
wedges, texts = ax.pie(
    size_ordered,
    wedgeprops=ring_style,
    startangle=90,
    colors=colors,
)

ax.legend(label_ordered, bbox_to_anchor=(1.35, 0.5), loc='right', fontsize=16)

# plt.savefig(figure_path / f"num_patients_{target}.svg", format='svg', bbox_inches = 'tight', pad_inches = 0)
plt.show()