In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
from sklearn.utils import resample
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
import torch 
from argparse import Namespace
from tqdm import tqdm
import pickle 
import glob 
import ast

# Add the path to the directory containing the sybil module
sys.path.append('/workspace/home/tengyuezhang/sybil_cect/code/Sybil/')
from sybil.utils.metrics import concordance_index, get_survival_metrics
from sybil import Sybil, Serie
from sybil import visualize_attentions


In [2]:
# Expose only the first GPU to this process. The variable is read when CUDA is
# initialised (not at import time), so it must be set before any CUDA work runs.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Roughly a fifth of the available cores, passed as `threads=` to model.predict.
# os.cpu_count() may return None, and integer division gives 0 on machines with
# fewer than 5 cores, so fall back to 1 and clamp to at least one thread.
num_threads = max(1, (os.cpu_count() or 1) // 5)

In [3]:
# Cohort table: one row per case, with at least 'pid', 'Directory' (relative to
# data_root_dir) and 'LungCancer' columns, as used by the prediction cells.
mln_with_labels_path = '/workspace/home/tengyuezhang/sybil_cect/data/MLN-SEG/MLN_SEG_Sybil_cases_w_outcome.csv'
data_root_dir = '/workspace/data/lung/MLN_SEG'
# CSV rewritten after every case in the prediction loop (acts as a checkpoint).
output_path = '/workspace/home/tengyuezhang/sybil_cect/results/MLN-SEG/MLN_SEG_risk_scores.csv'
vis_dir_path = "/workspace/home/tengyuezhang/sybil_cect/visualizations/MLN_SEG_attention_maps"
# When True, the prediction loop also writes attention-map overlays per case.
save_atten_maps = True

# Create every output directory up front so the run cannot fail on its first
# checkpoint write. exist_ok=True makes this idempotent and avoids the
# exists()/makedirs() race.
os.makedirs(os.path.dirname(output_path), exist_ok=True)
os.makedirs(vis_dir_path, exist_ok=True)

In [4]:
# Number of yearly risk scores read back from each Sybil prediction.
num_years = 6

# Instantiate Sybil with the "sybil_ensemble" configuration.
model = Sybil("sybil_ensemble")



In [5]:
# Load the cohort table. `all_cases` stays the untouched raw table and the
# pipeline works on an explicit copy: the prediction cell adds
# pred_risk_year_* columns to `df`, and with a plain alias (`df = all_cases`)
# those writes would silently mutate `all_cases` as well.
all_cases = pd.read_csv(mln_with_labels_path)
df = all_cases.copy()

In [6]:
# Run Sybil on every case and store its yearly risk scores in df. The table is
# checkpointed to output_path after each case, and attention-map overlays are
# optionally saved per case.
risk_cols = [f'pred_risk_year_{i}' for i in range(num_years)]
for col in risk_cols:
    df[col] = np.nan

for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing cases"):
    pid = row['pid']
    dicom_dir = os.path.abspath(os.path.join(data_root_dir, row['Directory']))
    # glob order is filesystem-dependent; sort so every run sees the same order.
    dicom_list = sorted(glob.glob(os.path.join(dicom_dir, '*')))
    if not dicom_list:
        # Fail with the offending path rather than an opaque error from Serie.
        raise FileNotFoundError(f"No files found for pid={pid} in {dicom_dir}")

    event = row['LungCancer']
    # NOTE(review): censor time is a fixed 1-year placeholder for every case,
    # not read from the data -- confirm this is intended before using the
    # Serie's label/censor fields for evaluation.
    years_to_event = 1
    serie = Serie(dicom_list, label=event, censor_time=years_to_event)

    # Attentions are only needed when maps are being saved.
    results = model.predict([serie], return_attentions=save_atten_maps, threads=num_threads)

    # Single-series batch: results.scores[0] holds this case's yearly scores.
    for i, col in enumerate(risk_cols):
        df.at[index, col] = results.scores[0][i]

    # Rewrite the whole table after each case so partial progress survives a crash.
    df.to_csv(output_path, index=False)

    if save_atten_maps:
        visualize_attentions(
            serie,
            attentions=results.attentions,
            pid=pid,
            save_directory=vis_dir_path,
            gain=1,
            save_pngs=True,
            save_rep_slice=True,
        )

Processing cases: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 97/97 [33:32<00:00, 20.75s/it]


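# Attention maps

Re-run Sybil on a single case (the first row of `df`) and save its attention-map overlays with a higher gain to a separate folder, as a visual sanity check.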

In [6]:
# Use the first case (by position) for the single-case attention-map check.
test_case = df.iloc[0, :]
test_case

pid                                                                        case_0018
Directory                          ./Mediastinal-Lymph-Node-SEG/case_0018/07-26-2...
Modality                                                                          CT
Study Date                                                                07-26-2007
Contrast/Bolus Agent                                      VOLUMEN & 100CC/ 2.5CC/SEC
Body Part Examined                                                             CHEST
Slice Thickness                                                                  2.5
Contrast/Bolus Route                                                       Oral & IV
Contrast/Bolus Ingredient                                                        NaN
Requested Procedure Description                                                  NaN
PrimaryCondition                                                  Hodgkin`s Lymphoma
LungCancer                                                       

In [7]:
# Separate output folder for the single-case attention-map check.
test_vis_dir_path = "/workspace/home/tengyuezhang/sybil_cect/visualizations/test_MLN_SEG_attention_maps/"
# Create it up front, as is done for vis_dir_path. This notebook does not show
# that visualize_attentions creates a missing save_directory itself.
os.makedirs(test_vis_dir_path, exist_ok=True)


In [8]:

# Run Sybil on the single test case, always returning attentions so that they
# can be visualised in the next cell.
pid = test_case['pid']
test_data_dir = os.path.abspath(os.path.join(data_root_dir, test_case['Directory']))
# glob order is filesystem-dependent; sort for a deterministic file order.
dicom_list = sorted(glob.glob(os.path.join(test_data_dir, '*')))
if not dicom_list:
    # Fail with the offending path rather than an opaque error from Serie.
    raise FileNotFoundError(f"No files found for pid={pid} in {test_data_dir}")

event = test_case['LungCancer']
# NOTE(review): censor time is a fixed 1-year placeholder, not read from the
# data -- confirm this is intended.
years_to_event = 1
serie = Serie(dicom_list, label=event, censor_time=years_to_event)
results = model.predict([serie], return_attentions=True, threads=num_threads)

attentions = results.attentions



In [9]:
# Save attention-map overlays for the single test case. visualize_attentions is
# already imported in the first cell, so it is not re-imported here.
# gain=3 (vs. 1 in the batch loop) presumably amplifies the overlay for easier
# inspection -- TODO confirm against sybil.visualize_attentions.
series_with_attention = visualize_attentions(
    serie,
    attentions=attentions,
    pid=pid,
    save_directory=test_vis_dir_path,
    gain=3,
    save_pngs=True,  # defaults to True
    save_rep_slice=True,  # defaults to True
)