In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import roc_auc_score, roc_curve, confusion_matrix
from sklearn.utils import resample
import matplotlib.pyplot as plt
import seaborn as sns
import sys
import os
import torch 
from argparse import Namespace
from tqdm import tqdm
import pickle 
import glob 
import ast

# Add the path to the directory containing the sybil module
sys.path.append('/workspace/home/tengyuezhang/sybil_cect/code/Sybil/')
from sybil.utils.metrics import concordance_index, get_survival_metrics
from sybil import Sybil, Serie
from sybil import visualize_attentions_v2

In [2]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
num_threads = os.cpu_count() // 5

In [3]:
IN_CASES_PATH = '/workspace/home/kkulkarni/Sybil/Results/lungx_Diagnosis_corrected.csv' 
OUT_RISK_PATH = '/workspace/home/tengyuezhang/sybil_cect/results/lungx/lungx_risk_scores.csv'
# attention maps 
# OUT_VIS_DIR_PATH = '/workspace/home/tengyuezhang/sybil_cect/visualizations/lungx_attention_maps'
OUT_VIS_DIR_PATH = '/workspace/home/tengyuezhang/sybil_cect/visualizations/lungx_attention_maps_w_nodule'
SAVE_ATTN_MAPS = True 
if SAVE_ATTN_MAPS and not os.path.exists(OUT_VIS_DIR_PATH):
    os.makedirs(OUT_VIS_DIR_PATH)
# nodule 
IN_NODULE_PATH = '/workspace/home/tengyuezhang/sybil_cect/data/lungx/lungx_nodule_location.csv'

In [4]:
# Initialize the Sybil model
model = Sybil("sybil_ensemble")
num_years = 6



In [5]:
# Load the CSV file
all_cases = pd.read_csv(IN_CASES_PATH)
df = all_cases 
df['pid'] = df['Scan Number']

In [6]:
# load nodule location for visualization 
nodule_df = pd.read_csv(IN_NODULE_PATH)
nodule_df

Unnamed: 0,pid,Nodule Number,"Nodule Center x,y Position*",Nodule Center Image,Final Diagnosis,diagnosis
0,LUNGx-CT001,1,"135, 303",142,Benign nodule,0
1,LUNGx-CT002,1,"330, 348",205,Benign nodule,0
2,LUNGx-CT002,2,"364, 212",150,Benign nodule,0
3,LUNGx-CT004,1,"197, 290",41,Primary lung cancer,1
4,LUNGx-CT004,2,"328, 242",80,Primary lung cancer,1
...,...,...,...,...,...,...
58,CT-Training-be001,1,"405, 296",169,benign,0
59,CT-Training-be002,1,"184, 268",117,benign,0
60,CT-Training-be006,1,"449, 266",241,benign,0
61,CT-Training-be007,1,"385, 206",194,benign,0


In [7]:
for i in range(num_years):
    df[f'pred_risk_year_{i}'] = np.nan 



for index, row in tqdm(df.iterrows(), total=len(df), desc="Processing cases"):
    dicom_dir = os.path.join('/workspace', row['Directory'][1:])
    event = 0
    years_to_event = 1
    pid = row['pid']
    dicom_list = glob.glob(dicom_dir + '/*')
    serie = Serie(dicom_list, label=event, censor_time=years_to_event)
    
    # get predicted risk scores and features from the last hidden layer (returned along with the attentions)
    results = model.predict([serie], return_attentions=True, threads=num_threads)
            
    # update the risk scores columns for the current row
    for i in range(num_years):
        df.at[index, f'pred_risk_year_{i}'] = results.scores[0][i]
        
    # update risk score csv file 
    df.to_csv(OUT_RISK_PATH, index=False)
    
    # Save attention maps 
    
    # ----- for lungx only -----
    
    save_bbox = True 
    filtered_nodule_df = nodule_df[nodule_df['pid'].str.lower() == pid.lower()]
    centers = [] 
    nodule_slices = [] 
    nodule_labels = []
    for i, r in filtered_nodule_df.iterrows(): 
        centers.append(r["Nodule Center x,y Position*"])
        nodule_slices.append(int(len(dicom_list) - r['Nodule Center Image']))
        nodule_labels.append(r['diagnosis'])
    
     # --------------------------
        
    if SAVE_ATTN_MAPS: 
        attentions = results.attentions

        series_with_attention = visualize_attentions_v2(
            serie,
            attentions = attentions,
            pid = pid, 
            save_directory = os.path.join(OUT_VIS_DIR_PATH, str(pid)),
            gain = 1, 
            save_pngs = True, 
            save_rep_slice = True,
            save_bbox = True, 
            centers = centers, 
            nodule_slices = nodule_slices, 
            nodule_labels = nodule_labels,
        )

Processing cases: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 54/54 [2:53:02<00:00, 192.26s/it]
