In [None]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import v2
import albumentations as A
from datetime import datetime
from albumentations.pytorch import ToTensorV2
from sklearn.preprocessing import MinMaxScaler, Binarizer
import numpy as np
from sklearn.preprocessing import minmax_scale
from scipy.stats import pearsonr
from mamba_ssm import Mamba
import seaborn as sns
import pickle, argparse
from matplotlib import pyplot as plt
import os
from tqdm import tqdm
from models.feature_extractor import CT25D
import cv2
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from TumorVesselDataset import CustomDataset, NormalVesselDataset
from sklearn.metrics import mean_squared_error, accuracy_score, precision_score, recall_score,confusion_matrix, f1_score, roc_auc_score
from sklearn.model_selection import train_test_split, StratifiedKFold
from pydicom import dcmread
import shutil
import transformers
from torch_ema import ExponentialMovingAverage
from glob import glob
from pycox.models.loss import cox_ph_loss, cox_cc_loss, nll_logistic_hazard, cox_ph_loss_sorted
import random, timm
import cv2
import os
import pandas as pd
from tqdm import tqdm
from glob import glob
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import logrank_test
import matplotlib.pyplot as plt
from scipy import stats

In [None]:


def calculate_confidence_interval(data, confidence=0.95):
    """Format the sample mean of ``data`` together with its Student-t confidence interval.

    Args:
        data: array-like of numeric samples (treated as a single sample).
        confidence: two-sided coverage of the interval, e.g. 0.95 for a 95 % CI.

    Returns:
        A string ``'<mean>(<lower>-<upper>)'`` with the mean printed to three
        decimals and the interval bounds to two, e.g. ``'3.000(1.04-4.96)'``.
    """
    sample_mean = np.mean(data)
    degrees_of_freedom = len(data) - 1
    # Upper critical value of the t distribution for a two-sided interval.
    critical_t = stats.t.ppf((1 + confidence) / 2., degrees_of_freedom)
    half_width = stats.sem(data) * critical_t
    lower, upper = sample_mean - half_width, sample_mean + half_width
    return f'{sample_mean:.3f}({lower:.2f}-{upper:.2f})'

In [None]:
# SBRT cohort table; passed to CustomDataset when building the evaluation loader.
dataset = pd.read_csv(filepath_or_buffer='SBRT_cohort.csv', index_col=None)

In [None]:
# Evaluation pipeline: bicubic resize to 224x224, then conversion to a torch tensor.
# 'mask1' and 'mask2' are registered as extra mask targets so they receive the same
# geometric transform as the image.
valid_transform = A.Compose(
    [
        A.Resize(224, 224, interpolation=cv2.INTER_CUBIC),
        ToTensorV2(),
    ],
    additional_targets={'mask1': 'mask', 'mask2': 'mask'},
)

In [None]:
class CONFIG:
    """Static run settings, read through the ``args`` instance created right after the class."""

    # --- data / experiment ---
    root = 'data'
    fold = 1
    seed = 2
    IMG_SIZE = 224

    # --- data loading ---
    batch_size = 12
    num_workers = 4

    # --- hardware / checkpoint ---
    device = 'cuda:0'
    ckpt = 'ckpt/best_RT.pt'

    # --- optimisation hyper-parameters (presumably consumed by training code) ---
    loss_type = 'BCE'
    lr = 5e-5
    cnn_lr = 5e-6
    seq_lr = 5e-5
    weight_decay = 1e-3
    ema_decay = 0.995
    warmup_epoch = 10

    # --- misc; meaning defined by the consuming code ---
    step = 9
    interval = 1


args = CONFIG()

In [None]:
# Evaluation loader over the SBRT cohort: one sample per batch, in dataset order.
sbrt_dataset = CustomDataset(dataset, transform=valid_transform, mode='test')
sbrt_loader = DataLoader(
    sbrt_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=args.num_workers,
)

In [None]:
# Build the feature extractor on the target device and restore the trained weights.
model = CT25D().to(args.device)
# map_location='cpu' deserialises the checkpoint tensors on the CPU regardless of the
# device they were saved from; load_state_dict then copies them into the parameters
# that already live on args.device. strict=True requires an exact key match (the
# previous call passed the string 'cpu' into the `strict` slot by mistake).
model.load_state_dict(torch.load(args.ckpt, map_location='cpu'), strict=True)
# Inference mode for layers such as dropout / batch-norm.
model = model.eval()

In [None]:
# Entries of the normal-vessel reference folder (presumably one per patient ID).
# Sorted so the order does not depend on the file system's listing order, and
# hidden entries such as '.DS_Store' are skipped. The path follows args.root.
pids = sorted(
    entry
    for entry in os.listdir(os.path.join(args.root, 'normal_vessels'))
    if not entry.startswith('.')
)

In [None]:
# Reference ("normal") vessel samples, using the same evaluation transform as the
# SBRT cohort, batched 32 at a time without shuffling.
normal_dataset = NormalVesselDataset(pids, valid_transform, 'test', args)
normal_loader = DataLoader(
    normal_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=args.num_workers,
)

In [None]:
# Extract backbone features for every normal-vessel batch and keep them on the CPU.
# Gradients are disabled for the whole pass because nothing here is trained.
normal_features = []
with torch.no_grad():
    for images in tqdm(normal_loader):
        features = model.get_features(images.to(args.device))
        normal_features.append(features.cpu())

In [None]:
class MahalanobisDetector:
    """Out-of-distribution scorer: Mahalanobis distance to a Gaussian fitted on reference features.

    Attributes:
        feature_dim: expected size of the feature axis; checked in ``fit``.
        epsilon: ridge term added to the covariance diagonal before inversion.
        mean: per-dimension mean of the fitted features, shape ``(feature_dim,)``.
        inv_covariance: float32 inverse of the regularised sample covariance,
            shape ``(feature_dim, feature_dim)``.
    """

    def __init__(self, feature_dim, epsilon=1e-6):
        self.feature_dim = feature_dim
        # Keeps the covariance invertible even when the sample covariance is
        # singular (e.g. fewer samples than feature dimensions).
        self.epsilon = epsilon
        self.mean = None
        self.inv_covariance = None

    def fit(self, normal_features):
        """Estimate mean and inverse covariance from reference features.

        Args:
            normal_features: 2-D tensor of shape ``(num_samples, feature_dim)``.

        Raises:
            ValueError: if the tensor is not 2-D with ``feature_dim`` columns, or
                has fewer than two rows (the unbiased covariance is undefined).
        """
        if normal_features.dim() != 2 or normal_features.size(1) != self.feature_dim:
            raise ValueError(
                f"expected features of shape (N, {self.feature_dim}), "
                f"got {tuple(normal_features.shape)}"
            )
        num_samples = normal_features.size(0)
        if num_samples < 2:
            raise ValueError("at least two samples are required to estimate a covariance")

        self.mean = torch.mean(normal_features, dim=0)

        # Accumulate and invert the covariance in float64: a high-dimensional sample
        # covariance is often ill-conditioned and float32 inversion loses precision.
        centered = (normal_features - self.mean).double()
        covariance = centered.t() @ centered / (num_samples - 1)
        covariance_np = covariance.detach().cpu().numpy()
        covariance_np += self.epsilon * np.eye(self.feature_dim)
        self.inv_covariance = torch.from_numpy(np.linalg.inv(covariance_np)).float()

    def calculate_distance(self, features):
        """Mahalanobis distance of each row of ``features`` to the fitted mean.

        Args:
            features: 2-D tensor of shape ``(batch, feature_dim)``; any device/dtype.

        Returns:
            Tensor of shape ``(batch,)`` with the same dtype/device as ``features``.

        Raises:
            RuntimeError: if ``fit`` has not been called.
        """
        if self.mean is None or self.inv_covariance is None:
            raise RuntimeError("MahalanobisDetector.fit must be called before calculate_distance")

        # Match the fitted statistics to the caller's device and dtype.
        mean = self.mean.to(features)
        inv_covariance = self.inv_covariance.to(features)

        centered_features = features - mean
        squared = torch.sum(torch.mm(centered_features, inv_covariance) * centered_features, dim=1)
        # Round-off can make the quadratic form slightly negative near the mean;
        # clamp so sqrt never yields NaN.
        return torch.sqrt(torch.clamp(squared, min=0))
    
# Fit the reference Gaussian on all normal-vessel features stacked into a single
# matrix (presumably of shape (num_samples, 768)).
detector = MahalanobisDetector(feature_dim=768)
detector.fit(torch.cat(normal_features, dim=0))

In [None]:
mahalanobis_distance = []
vessel_density = []

# Normaliser turning the summed vessel mask into a density. 224 * 224 matches the
# resize in valid_transform; NOTE(review): the 3 * 3 factor presumably counts the
# mask channels/slices per sample -- confirm against CustomDataset.
VESSEL_DENSITY_DENOM = 224 * 224 * 3 * 3

with torch.no_grad():
    for bs_img, bs_v_msks in tqdm(sbrt_loader):
        feature = model.get_features(bs_img.to(args.device)).cpu()
        # One distance per sample; extend (rather than .item()) keeps the list aligned
        # with the dataset for any loader batch size.
        mahalanobis_distance.extend(detector.calculate_distance(feature).tolist())
        # Sum each sample's mask separately (axis 0 is the batch axis added by
        # DataLoader) and store plain floats, like mahalanobis_distance.
        per_sample_sum = bs_v_msks.reshape(bs_v_msks.size(0), -1).sum(dim=1)
        vessel_density.extend((per_sample_sum / VESSEL_DENSITY_DENOM).tolist())

In [None]:
# Independent copy: later in-place edits of the risk score (normalisation, sorting, ...)
# must not silently rewrite the raw Mahalanobis distances.
vessel_risk_score = list(mahalanobis_distance)