In [None]:
import os

os.environ['SF_BACKEND'] = 'torch'
os.environ['SF_SLIDE_BACKEND'] = 'libvips'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

import slideflow as sf
from slideflow.slide import qc
import pandas as pd
import slideflow.clam
import fastai
import numpy as np
from sklearn.metrics import precision_recall_curve, auc, f1_score, roc_auc_score
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Record the slideflow environment in the notebook output (presumably version /
# backend details — see sf.about() documentation) so runs can be reproduced.
sf.about()

# Tile extraction

In [None]:
# P = sf.Project('/home/sramesh/PROJECTS/NB')
# P.annotations = "./annotations_updated_031824.csv"
# dataset = P.dataset(tile_px=299, tile_um=302, filters={'site': ['3'], 'QC':['Pass']}) 

# qc_list = [qc.GaussianV2(), qc.Otsu()]

# dataset.extract_tiles(qc=qc_list, grayspace_fraction=1, roi_method='inside', img_format='jpg',skip_extracted=False, dry_run=False)

# Feature Generation

In [None]:
# from slideflow.model import build_feature_extractor

# # Build extractor
# ctranspath = sf.model.build_feature_extractor('ctranspath', tile_px=299)
# features_pt=P.generate_features(ctranspath,dataset=dataset, normalizer = 'reinhard_mask')

# # Export feature bags.
# features_pt.to_torch('/home/sramesh/PROJECTS/NB/pt_files/NB_ctranspath')

# Dataset Handling

In [None]:
P = sf.Project('/home/sramesh/PROJECTS/NB')
P.annotations = '/home/sramesh/PROJECTS/NB/annotations_updated_031824.csv'

# Base dataset: 299 px / 302 um tiles from QC-passed slides with at least
# `min_tiles` tiles, pooled across the internal sources and the external cohort.
full_dataset = P.dataset(
    tile_px=299,
    tile_um=302,
    filters={'QC': ['Pass']},
    min_tiles=20,
    sources=["NB-1", "NB-2", "NB-3", "NB-4", "NB-external"],
)

# Restrict to neuroblastoma slides labelled for the outcome being modelled.
# Outcome filters used in other runs (swap one in place of 'MYCN' as needed):
#   'Grade': ['Poorly Differentiated', 'Differentiating']
#   'MKI':   ['Low/intermediate', 'High']
outcome_filters = {
    'Diagnosis': ['Neuroblastoma'],
    'MYCN': ['Non-Amplified', 'Amplified'],
}
outcome_dataset = full_dataset.filter(outcome_filters)

print("All slides/tfrecords:", len(full_dataset.tfrecords()))
print("Outcome-specific slides/tfrecords:", len(outcome_dataset.tfrecords()))

In [None]:
# Class balance for the current outcome, counted only over slides that are still
# in `outcome_dataset` after all filters (QC, Diagnosis, MYCN, min_tiles).
# `Dataset.annotations` is the full annotations table regardless of filters;
# `filtered_annotations` keeps only rows for slides remaining in the dataset.
label_counts = outcome_dataset.filtered_annotations['MYCN'].value_counts()
print(label_counts)

### Slide Manifest Reconciliation

- 188 training slides
- 184 extracted (excluded: 47, 97-3640, 4 — corrupted; 99-1955 — no tiles passed QC, and pathology review recommended exclusion)
- 176 after 8 additional slides failed QC
- 172 after 4 slides had fewer than 20 tiles extracted
- 25 external cohort

**Total: 197** (172 + 25). Note: this assumes outcome labels are available for every slide; if some patients are missing labels for a given outcome, the total is lower.

### Splits for k-fold and internal/external validation

In [None]:
#Split the dataset 
# train,val = outcome_dataset.split(
#     val_fraction=0.15,
#     labels='Grade',
#     splits='/home/sramesh/PROJECTS/NB/splits/split_Grade_8515-2.json'
# )

In [None]:
# splits=outcome_dataset.kfold_split(
#     k=5,
#     labels='Grade',
#     splits='/home/sramesh/PROJECTS/NB/splits/split_Grade_5fold.json'
# )

In [None]:
# Fixed split by site: sites 1 and 2 form the training set, and site 3 is held
# out for evaluation (the "external" split in this notebook's terminology).
split_sites = {'train': ['1', '2'], 'val': ['3']}

train = outcome_dataset.filter({'site': split_sites['train']})
val = outcome_dataset.filter({'site': split_sites['val']})

In [None]:
# Number of tfrecords (slides) in each split.
n_train = len(train.tfrecords())
n_val = len(val.tfrecords())
print(f'Train size: {n_train}, Val size: {n_val}')

In [None]:
from slideflow.mil import mil_config

# Attention-MIL training hyperparameters.
# Previously tried and currently disabled: fit_one_cycle=False, lr=7e-4.
config = mil_config(
    'attention_mil',
    epochs=10,
    batch_size=32,
    bag_size=256,
    wd=1e-05,
)

In [None]:
# Train on each cross-fold
# fold_counter=1
# for train, val in splits:
#     P.train_mil(
#         config=config,
#         outcomes='Grade',
#         train_dataset=train,
#         val_dataset=val,
#         bags='/home/sramesh/PROJECTS/NB/pt_files/NB_ctranspath',
#         exp_label=f'NB_ctranspath_RHMnorm_Grade-NB-kfold-{fold_counter}',
#         attention_heatmaps=True
#     )
#     fold_counter+=1

In [None]:
# Train attention-MIL on the site 1/2 training split only. No validation set is
# passed (val_dataset=None); the held-out site 3 split is scored separately
# with P.evaluate_mil().
P.train_mil(
    config=config,
    outcomes='MYCN',
    train_dataset=train,
    val_dataset=None,
    bags='/home/sramesh/PROJECTS/NB/pt_files/NB_ctranspath',
    exp_label='NB_ctranspath_RHMnorm_MYCN-fulltrain',
    attention_heatmaps=False,
)

In [None]:
# Evaluate a saved MIL model on the held-out `val` split.
# NOTE(review): this path points at a specific saved run ("...-revised"), which is
# not necessarily the model produced by the training cell — update when retraining.
saved_model_dir = '/home/sramesh/PROJECTS/NB/mil/00099-NB_ctranspath_RHMnorm_MYCN-fulltrain-revised/'

P.evaluate_mil(
    saved_model_dir,
    dataset=val,
    outcomes='MYCN',
    bags='/home/sramesh/PROJECTS/NB/pt_files/NB_ctranspath',
    attention_heatmaps=True,
)

# K-Fold / External Validation Results Aggregation

In [None]:
# For k-fold

# Initialize an empty DataFrame to store aggregated data
# aggregated_data = pd.DataFrame()
# k=1

# # Loop through each fold and aggregate data
# for i in range(89, 94):
#     file_name = f'/home/sramesh/PROJECTS/NB/mil/000{i}-NB_ctranspath_RHMnorm_Grade-NB-kfold-{k}/predictions.parquet'
#     data = pd.read_parquet(file_name)
#     aggregated_data = pd.concat([aggregated_data, data])
#     k+=1

# For external validation: slide-level predictions from one saved evaluation run.
file_name = '/home/sramesh/PROJECTS/NB/mil_eval/00038-attention_mil-MKI_fulltrain-revised/predictions.parquet'
aggregated_data = pd.read_parquet(file_name)

# Decision threshold applied to y_pred1, the predicted score for class index 1.
# NOTE(review): which outcome label is encoded as index 1 is decided by slideflow's
# label encoding, not here. Sensitivity/specificity/precision/F1/AUPRC below all treat
# class 1 as "positive" — confirm that is the clinically intended class.
threshold = 0.5
aggregated_data['prediction'] = (aggregated_data['y_pred1'] > threshold).astype(int)

# Confusion-matrix cells keyed by (true label, predicted label).
_counts = {
    (t, p): np.sum((aggregated_data['y_true'] == t) & (aggregated_data['prediction'] == p))
    for t in (0, 1)
    for p in (0, 1)
}
TP, FN, TN, FP = _counts[1, 1], _counts[1, 0], _counts[0, 0], _counts[0, 1]


def _ratio_or_zero(numerator, denominator):
    """Return numerator / denominator, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


# Threshold-dependent metrics.
sensitivity = _ratio_or_zero(TP, TP + FN)
specificity = _ratio_or_zero(TN, TN + FP)
precision = _ratio_or_zero(TP, TP + FP)
recall = sensitivity  # recall and sensitivity are the same quantity
f1 = f1_score(aggregated_data['y_true'], aggregated_data['prediction'])

# Threshold-free metrics from the raw class-1 scores.
# AUPRC is the trapezoidal area under the PR curve (sklearn.metrics.auc).
# NOTE(review): trapezoidal interpolation of a PR curve can be optimistic;
# sklearn's average_precision_score is the usual alternative summary.
precision_curve, recall_curve, _ = precision_recall_curve(aggregated_data['y_true'], aggregated_data['y_pred1'])
auprc = auc(recall_curve, precision_curve)
auroc = roc_auc_score(aggregated_data['y_true'], aggregated_data['y_pred1'])

In [None]:
# Print the metrics from the aggregation cell. They are computed once over a single
# pooled set of slide-level predictions (one evaluation run, or the k-fold prediction
# files concatenated together), so they are pooled values, not means across folds.
print("MKI Aggregated Metrics:")
for metric_name, metric_value in (
    ("Sensitivity", sensitivity),
    ("Specificity", specificity),
    ("Precision", precision),
    ("Recall", recall),
    ("F1-Score", f1),
    ("AUPRC", auprc),
    ("AUROC", auroc),
):
    print(f"{metric_name}: {metric_value:.4f}")

In [None]:
# Thresholded call per slide, compared with the true label. Because True == 1 and
# False == 0, this comparison assumes y_true is 0/1-encoded.
called_positive = aggregated_data['y_pred1'] > threshold
is_correct_prediction = called_positive == aggregated_data['y_true']
correct_predictions_df = aggregated_data[is_correct_prediction]

# Show every row. This changes pandas' global display option for the rest of the session.
pd.set_option('display.max_rows', None)
print(correct_predictions_df)

# Written relative to the notebook's current working directory.
correct_predictions_df.to_csv('MKI-correct-predictions-external.csv')

In [None]:
# Distribution of the class-1 score (y_pred1) among correctly classified slides,
# coloured by true label.
# Reset the index so seaborn never sees duplicate index labels (e.g. when k-fold
# prediction files were concatenated).
correct_predictions_df_reset = correct_predictions_df.reset_index(drop=True)

fig, ax = plt.subplots(figsize=(10, 6))
sns.histplot(
    data=correct_predictions_df_reset,
    x="y_pred1",
    hue="y_true",
    bins=20,
    kde=True,
    ax=ax,
)
ax.set_title("Distribution of Correct Predictions by Confidence")
ax.set_xlabel("Predicted score for class 1 (y_pred1)")
ax.set_ylabel("Count")

# Keep seaborn's own hue legend, whose entries are tied to the plotted colours, and
# only retitle it. Calling plt.legend(labels=[...]) instead attaches labels to the
# axes' artists in draw order (histogram bars / KDE lines), not to the hue levels,
# so the label/colour pairing is not reliable.
hue_legend = ax.get_legend()
if hue_legend is not None:
    hue_legend.set_title("True Label")
plt.show()


# Threshold Metrics (Sensitivity, Specificity, Precision, AUPRC, F1)

In [None]:
# Slide-level predictions from a saved MIL evaluation run.
data = pd.read_parquet('/home/sramesh/PROJECTS/NB/mil_eval/00037-attention_mil-MYCN_fulltrain/predictions.parquet')

# Define a threshold for classification
threshold = 0.5

# Hard call: 1 if y_pred1 > threshold, else 0 (class index 1 treated as positive).
data['prediction'] = np.where(data['y_pred1'] > threshold, 1, 0)

# Calculate True Positives, False Negatives, True Negatives, and False Positives
TP = np.sum((data['y_true'] == 1) & (data['prediction'] == 1))
FN = np.sum((data['y_true'] == 1) & (data['prediction'] == 0))
TN = np.sum((data['y_true'] == 0) & (data['prediction'] == 0))
FP = np.sum((data['y_true'] == 0) & (data['prediction'] == 1))

# Sensitivity and specificity, guarded against a class that is absent from y_true
# (the same convention as the aggregated-metrics cell). Without the guard, the
# numpy integer division yields nan with a RuntimeWarning.
sensitivity = TP / (TP + FN) if (TP + FN) > 0 else 0
specificity = TN / (TN + FP) if (TN + FP) > 0 else 0

print(f"Sensitivity: {sensitivity}")
print(f"Specificity: {specificity}")

In [None]:
# Precision-recall curve over all thresholds and its (trapezoidal) area.
# The curve arrays are named *_curve so they do not overwrite the scalar
# `precision` / `recall` metrics, which the metrics print cell formats with :.4f
# (formatting an ndarray that way raises TypeError when that cell is re-run).
precision_curve, recall_curve, _ = precision_recall_curve(data['y_true'], data['y_pred1'])
auprc = auc(recall_curve, precision_curve)

# Precision (positive predictive value) at the chosen threshold.
precision_value = TP / (TP + FP) if (TP + FP) > 0 else 0

# F1 score of the thresholded predictions.
f1 = f1_score(data['y_true'], data['prediction'])

print(f"Precision: {precision_value}")
print(f"AUPRC: {auprc}")
print(f"F1-score: {f1}")

# Correct / Incorrect Prediction Assessment

In [None]:
# Re-derive the hard calls so this cell reflects the current `threshold`.
data['prediction'] = (data['y_pred1'] > threshold).astype(int)

# Slides whose hard call disagrees with the true label.
misclassified = data['y_true'] != data['prediction']
data_wrong = data[misclassified]
print(data_wrong)

In [None]:
# Re-derive the hard calls so this cell reflects the current `threshold`.
data['prediction'] = (data['y_pred1'] > threshold).astype(int)

# Slides whose hard call matches the true label (correct predictions).
classified_correctly = data['y_true'] == data['prediction']
data_correct = data[classified_correctly]
print(data_correct)

In [None]:
from slideflow.mil import predict_slide
from slideflow.slide import qc

# Inputs: the whole-slide image to score and the saved MIL model directory.
slide = '/home/sramesh/labshare/SLIDES/UCH_APPLEBAUM/May7_Scans/Mark_Applebaum-042_05-07.svs'
model = '/home/sramesh/PROJECTS/NB/mil/00013-NB_ctranspath_RHMnorm_dx'

# Open the slide at the same tile geometry used elsewhere in this notebook
# (299 px / 302 um) and mask background with Otsu thresholding.
wsi = sf.WSI(slide, tile_px=299, tile_um=302)
wsi.qc(qc.Otsu())

# Slide-level prediction plus attention values (for the attention heatmap).
y_pred, y_att = predict_slide(model, wsi)

# UMAP

In [None]:
# ctranspath = sf.model.build_feature_extractor('ctranspath', tile_px=299)

# # Generate DatasetFeatures from an extractor
# ftrs = sf.DatasetFeatures(ctranspath, full_dataset, normalizer= 'reinhard_mask',cache='/home/sramesh/PROJECTS/NB/UMAP/ctranspath.pkl')

# # Create the base UMAP
# slide_map = ftrs.map_activations()

# # Load annotations (site, outcomes)
# outcome_labels, _ = full_dataset.labels('Diagnosis_binary_old', format='name')

# # Label UMAP by outcome
# slide_map.label_by_slide(outcome_labels)
# #slide_map.save_plot('/home/sramesh/PROJECTS/NB/UMAP/ctranspath_Diagnosis_binary_old.png', s=5)
# ftrs