In [None]:
import torch
from torch import nn
from mamba_ssm import Mamba
import numpy as np
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.decomposition import PCA
from matplotlib import pyplot as plt

In [None]:
# Use the first GPU when one is available; otherwise fall back to CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Path to the trained CT25D state_dict. TODO: fill in before running — with the
# empty placeholder, torch.load in the model-loading cell raises FileNotFoundError.
ckpt = ''

In [None]:

class CT25D(nn.Module):
    """Encode a stack of 2-D slice images per sample into one feature vector.

    Each slice is embedded independently by a DINOv2 ViT-B/14 backbone. The
    sequence of slice embeddings is passed through a 2-layer bidirectional LSTM,
    and the LSTM outputs are averaged over the slice axis.
    """

    def __init__(self):
        super().__init__()

        # DINOv2 ViT-B/14 backbone fetched through torch.hub (needs network / hub cache).
        self.cnn = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
        self.hidden_dim = 768
        # 2-layer BiLSTM over the slice sequence. Because it is bidirectional,
        # each output step has 2 * 768 = 1536 features.
        self.lstm = nn.LSTM(768, 768, 2, batch_first=True, bidirectional=True)

        # NOTE(review): `proj` is never used in `forward`, and its input width
        # (hidden_dim = 768) does not match the BiLSTM output width (1536).
        # It is kept unchanged so checkpoints that contain `proj.*` weights still
        # load with strict `load_state_dict`.
        self.proj = nn.Sequential(
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(self.hidden_dim // 2, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 32),
        )

    def forward(self, imgs):
        """Embed a batch of slice stacks.

        Args:
            imgs: 5-D tensor (bs, n_img, channel, height, width) — ``n_img``
                slice images for each of ``bs`` samples.

        Returns:
            Tensor of shape (bs, 2 * lstm_hidden_size): BiLSTM outputs averaged
            over the slice axis.
        """
        bs, n_img, channel, height, width = imgs.shape
        # Fold slices into the batch axis so the backbone sees ordinary 4-D images.
        slice_feats = self.cnn(imgs.reshape(bs * n_img, channel, height, width))
        slice_feats = slice_feats.reshape(bs, n_img, -1)

        # nn.LSTM returns (output, (h_n, c_n)); only the per-step output is needed.
        # (The original bound the whole tuple and then crashed on `.mean`.)
        seq_feats, _ = self.lstm(slice_feats)
        return seq_feats.mean(1)
    

In [None]:
# Build the encoder on the configured device and load the trained weights.
# weights_only=True limits unpickling to tensors and plain containers, which is
# all a state_dict needs, so a tampered checkpoint file cannot execute code.
model = CT25D().to(device)
model.load_state_dict(torch.load(ckpt, map_location='cpu', weights_only=True))
# Switch to inference behavior (affects dropout / normalization layers, if any).
model = model.eval()

In [None]:
# Placeholders — replace both with real data loaders before running the notebook.
#   data_loader:   yields (bs_img, bs_v_msks, pid) triples; only bs_img is fed to the model.
#   normal_loader: yields image batches, 5-D as unpacked by CT25D.forward
#                  (bs, n_img, channel, height, width).
data_loader, normal_loader = ..., ...

In [None]:
# Embed every batch with the frozen encoder, keeping one CPU tensor per batch.
# No autograd graph is needed for feature extraction.
normal_features = []
with torch.no_grad():
    for normal_batch in normal_loader:
        normal_features.append(model(normal_batch.to(device)).cpu())

    tumor_vessel_features = []
    density_list = []  # NOTE(review): never populated in this cell

    # Only the images are embedded; vessel masks and patient ids are ignored here.
    for tumor_batch, _vessel_masks, _pids in data_loader:
        tumor_vessel_features.append(model(tumor_batch.to(device)).cpu())

In [None]:
# Concatenate the per-batch feature tensors along the batch axis into numpy
# arrays for scikit-learn.
assert len(normal_features) > 0, "normal_loader yielded no batches"
normal_vessel_features = torch.cat(normal_features, 0).numpy()
# `tumor_vessel_features` is rebound from a list to an ndarray here. The guard
# keeps a re-run of this cell from calling torch.cat on the converted array.
if isinstance(tumor_vessel_features, list):
    assert len(tumor_vessel_features) > 0, "data_loader yielded no batches"
    tumor_vessel_features = torch.cat(tumor_vessel_features, 0).numpy()

In [None]:
def calculate_vessel_risk_score(normal_features, tumor_features, n_components=5, random_state=42, do_pca=False,
                                show_plot=True):
    """Score tumor samples by how unlikely they are under a GMM fit on normal samples.

    A full-covariance Gaussian mixture is fit to ``normal_features``. Each tumor
    sample's raw score is its negative log-likelihood under that mixture. The raw
    scores are then rescaled as ``(s - p05) / (p95 - p05)``, using the 5th/95th
    percentiles of the tumor scores themselves. The result is NOT clipped, so
    values below 0 or above 1 are expected for the tails.

    Parameters
    ----------
    normal_features : array-like, shape (n_normal, n_features)
        Reference samples used to fit the (optional) PCA and the GMM.
    tumor_features : array-like, shape (n_tumor, n_features)
        Samples to score. Must be non-empty.
    n_components : int
        Number of mixture components.
    random_state : int
        Seed passed to PCA and GaussianMixture.
    do_pca : bool
        If True, fit a PCA keeping 95% of the normal-set variance and project
        both sets before fitting the GMM.
    show_plot : bool
        If True (default, original behavior), draw a histogram of the normalized
        tumor scores on the current matplotlib axes.

    Returns
    -------
    dict
        'vessel_risk_scores': normalized tumor scores, shape (n_tumor,)
        'gmm': the fitted GaussianMixture
        'pca': the fitted PCA, or None when ``do_pca`` is False
        'normal_scores': raw negative log-likelihoods of the normal samples
        'p05', 'p95': percentiles used for the normalization

    Raises
    ------
    ValueError
        If ``tumor_features`` is empty.
    """
    if len(tumor_features) == 0:
        raise ValueError("tumor_features is empty; nothing to score")

    pca = None
    normal_reduced, tumor_reduced = normal_features, tumor_features
    if do_pca:
        # Keep as many components as needed to explain 95% of the normal-set variance.
        pca = PCA(n_components=0.95, random_state=random_state)
        normal_reduced = pca.fit_transform(normal_features)
        tumor_reduced = pca.transform(tumor_features)

    gmm = GaussianMixture(
        n_components=n_components,
        covariance_type='full',
        random_state=random_state,
        reg_covar=4e-3,  # added to the covariance diagonals for numerical stability
    )
    gmm.fit(normal_reduced)

    # Higher score = less likely under the mixture fit on the normal samples.
    normal_scores = -gmm.score_samples(normal_reduced)
    raw_tumor_scores = -gmm.score_samples(tumor_reduced)

    p05 = np.percentile(raw_tumor_scores, 5)
    p95 = np.percentile(raw_tumor_scores, 95)
    print(p05, p95)

    # With a single tumor sample (or identical central scores), p95 == p05 and the
    # division would produce nan/inf. In that case, only shift by p05.
    spread = p95 - p05
    if not spread > 0:
        spread = 1.0
    tumor_scores = (raw_tumor_scores - p05) / spread

    if show_plot:
        plt.hist(tumor_scores)

    return {
        'vessel_risk_scores': tumor_scores,
        'gmm': gmm,
        'pca': pca,
        'normal_scores': normal_scores,
        'p05': p05,
        'p95': p95,
    }

In [None]:
# Fit a 2-component GMM on the normal-vessel features (no PCA) and score the tumor samples.
VRS_score = calculate_vessel_risk_score(
    normal_features=normal_vessel_features,
    tumor_features=tumor_vessel_features,
    n_components=2,
    random_state=42,
    do_pca=False,
)