# DiffuSETS metric scripts

## Prerequisites

In [1]:
import os
import re
import numpy as np
import pandas as pd
import torch
from scipy.linalg import sqrtm

from diffusers import DDPMScheduler
from torch.utils.data import DataLoader
from unet.conditional_unet_patient_3 import ECGconditional
from vae.vae_model import VAE_Decoder
from clip.clip_model import CLIP
from dataset.ptbxl_dataset import PtbxlDataset, PtbxlDataset_VAE
from dataset.mimic_iv_ecg_dataset import MIMIC_IV_ECG_Dataset, VAE_MIMIC_IV_ECG_Dataset

from tqdm import tqdm
import json

# Dataset locations: raw recordings and the pre-computed VAE latent versions.
PTB_PATH = '/data/0shared/laiyongfan/data_text2ecg/ptb-xl/'
PTB_VAE_PATH = '/data/0shared/laiyongfan/data_text2ecg/ptb-xl_vae'
MIMIC_PATH = '/data1_science/1shared/physionet.org/files/mimic-iv-ecg/1.0/mimic-iv-ecg-diagnostic-electrocardiogram-matched-subset-1.0'
MIMIC_VAE_PATH = '/data/0shared/laiyongfan/data_text2ecg/mimic_vae'

# Restrict this process to physical GPU 7. The variable CUDA reads is
# CUDA_VISIBLE_DEVICES (plural); the singular spelling is silently ignored.
# It has to be set before the first CUDA call in this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '7'
device = "cuda"

  from .autonotebook import tqdm as notebook_tqdm


Loading Models

In [2]:
# CLIP model
# The checkpoint loaded below lives under the clip_4 folder and is built with embed_dim=64.
# NOTE(review): the earlier comment said "model under clip_3 folder is trained on mimic" —
# confirm which checkpoint folder is actually intended.
clip_model_root = '/home/laiyongfan/text2ecg/checkpoints/clip_4/CLIP_model_ep9.pth'
clip_model = CLIP(embed_dim=64)
# map_location puts the loaded tensors on `device` regardless of where they were saved
clip_model_weight = torch.load(clip_model_root, map_location=device)
clip_model.load_state_dict(clip_model_weight)
# the model is only used for inference in this notebook
clip_model.eval()
clip_model = clip_model.to(device)

In [3]:
# UNET
# Noise scheduler and conditional U-Net used to generate ECG latents.
n_channels = 4
num_train_steps = 1000
diffused_model = DDPMScheduler(num_train_timesteps=num_train_steps, beta_start=0.00085, beta_end=0.0120)
unet = ECGconditional(num_train_steps, resolution=128, kernel_size=7, num_levels=5, n_channels=n_channels)
unet_weights_dir = '/data/0shared/chenjiabo/DiffuSETS/weights/ablation_new/DDPM_1111/test/'

# Pick the checkpoint with the largest numeric suffix.
# fullmatch() ignores look-alikes such as "*.pth.bak", and the matched file name
# itself is loaded, so zero-padded ids (e.g. "_072.pth") resolve to the real file.
files = os.listdir(unet_weights_dir)
max_id = -1
latest_weight_file = None
for file in files:
    match = re.fullmatch(r'ecg2chan_input20K_1000ts_(\d+)\.pth', file)
    if match is None:
        continue
    file_id = int(match.group(1))
    if file_id > max_id:
        max_id = file_id
        latest_weight_file = file

# Fail with a clear message instead of trying to load "..._-1.pth"
if latest_weight_file is None:
    raise FileNotFoundError(
        f'No ecg2chan_input20K_1000ts_<id>.pth checkpoint found in {unet_weights_dir}'
    )

unet_weights_path = os.path.join(unet_weights_dir, latest_weight_file)
print(unet_weights_path)

unet.load_state_dict(torch.load(unet_weights_path, map_location=device))
unet.eval()
unet = unet.to(device)
# use the full 1000-step schedule at inference time
diffused_model.set_timesteps(1000)

# Default: no decoder; the VAE cell assigns the real decoder.
decoder = None

/data/0shared/chenjiabo/DiffuSETS/weights/ablation_new/DDPM_1111/test/ecg2chan_input20K_1000ts_72.pth


In [4]:
# VAE
# Decoder applied to generated latents (see `generation_from_net`, which calls decoder(xi)).
decoder = VAE_Decoder()
# VAE_path
vae_path = '/home/laiyongfan/text2ecg/checkpoints/vae_1/VAE_model_ep9.pth'
checkpoint = torch.load(vae_path, map_location=device)
# only the 'decoder' entry of the checkpoint is used here
decoder.load_state_dict(checkpoint['decoder'])
decoder.eval()
decoder = decoder.to(device)

Read features from json

In [5]:
def read_features(features_path):
    """Load `<features_path>/features.json` and decode its string-encoded entries.

    The values stored under 'text_embed', 'Ori Latent' and 'Gen Latent' are
    kept in the JSON file as Python-literal strings; they are evaluated back
    into Python objects. Every other entry is returned exactly as parsed.

    features_path: directory that contains `features.json`.
    return: the parsed dict with the three entries above decoded.
    """
    string_encoded = ('text_embed', 'Ori Latent', 'Gen Latent')

    with open(features_path + '/features.json', 'r') as handle:
        loaded = json.load(handle)

    # NOTE(review): eval() executes whatever the file contains, so only read
    # feature files you produced yourself.
    for name in loaded:
        if name in string_encoded:
            loaded[name] = eval(loaded[name])

    return loaded

DDPM inference function

In [6]:
@torch.no_grad()
def generation_from_net(diffused_model, net, text_embedding, condition, batch_size, device, dim=128, decoder=None, n_channels=4):
    """Run the reverse diffusion loop and return the generated samples.

    diffused_model: scheduler exposing `timesteps` and
        `step(model_output=..., timestep=..., sample=...)` whose result is
        indexed with 'prev_sample'.
    net: noise-prediction network, called as `net(x, t, text_embedding)` or
        `net(x, t, text_embedding, condition)`.
    text_embedding: text conditioning forwarded unchanged to `net`.
    condition: extra conditioning forwarded to `net`; any falsy value
        (e.g. False) selects the call without it.
    batch_size: number of samples generated in parallel.
    device: device the initial noise is moved to.
    dim: length of the last axis of the initial noise.
    decoder: optional callable applied to the final latent.
    n_channels: channel count of the initial noise (4, as before, by default).
    return: the final sample, passed through `decoder` when one is given.
    """
    net.eval()

    # Noise is drawn on the CPU and then moved, as before, so seeded runs keep
    # the same random stream.
    xi = torch.randn(batch_size, n_channels, dim).to(device)

    for timestep in diffused_model.timesteps:
        # one timestep value per sample
        # NOTE(review): `t` is created on the CPU while `xi` is on `device`;
        # presumably `net` moves it itself — confirm.
        t = timestep * torch.ones(batch_size, dtype=torch.long)

        # change this call to fit your unet
        if condition:
            noise_predict = net(xi, t, text_embedding, condition)
        else:
            noise_predict = net(xi, t, text_embedding)

        xi = diffused_model.step(model_output=noise_predict,
                                 timestep=timestep,
                                 sample=xi)['prev_sample']

    # Explicit None check: a decoder that defines __len__ (e.g. a container
    # module) could evaluate as falsy and be skipped silently.
    if decoder is not None:
        xi = decoder(xi)

    # xi: (B, L, C) whether using VAE decoder or not
    return xi

Fetch text embedding

Prototype:
```python
# Step 0: load the embedding dict
embedding_dict = pd.read_csv(path, low_memory=False)

def fetch_embedding_xxx(text: str):
    # Step 1: transform text to key-form that can be inquired from table
    text = get_key_text(text)

    # Step 2: Inquire from table
    try:
        text_embed = embedding_dict.loc[embedding_dict['xxx'] == text, 'embed'].values[0]
        text_embed = eval(text_embed)
    except IndexError:
        # Special case
        text_embed = special_case()
    
    # Step 3: Convert to torch.Tensor
    return torch.tensor(text_embed)

# Usage:
for idx, (X, y) in enumerate(dataloader):
    text = y['text']
    text = [fetch_text_embedding_mimic_report_0(x) for x in text]
    # text_embedding: (B, 1536)
    text_embedding = torch.stack(text).to(device)
```

In [7]:
# Step 0 for PTB-XL: lookup table used by `fetch_text_embedding_ptbxl`.
# Only `ecg_id` and `text_embed` are kept from the embedding sheet ...
ptbxl_embedding_dict = pd.read_csv('/data/0shared/chenjiabo/DiffuSETS/data/ptbxl_database_embed.csv', low_memory=False)[['ecg_id', 'text_embed']]
ptbxl_original_sheet = pd.read_csv('/data/0shared/laiyongfan/data_text2ecg/ptb-xl/ptbxl_database.csv', low_memory=False)
# ... then merged with the original sheet (no `on=` given, so pandas joins on the shared columns);
# the `report` column queried later presumably comes from the original sheet — TODO confirm.
ptbxl_embedding_dict = pd.merge(ptbxl_embedding_dict, ptbxl_original_sheet)

# converting ptbxl text to ada embedding
def fetch_text_embedding_ptbxl(text:str):
    """Look up the stored embedding of a PTB-XL report text.

    text: label text; only the part before the first '|' is used, and the
        prefix 'The report of the ECG is that ' is removed before the lookup.
    return: torch.Tensor built from the `text_embed` cell of the first row of
        `ptbxl_embedding_dict` whose `report` equals the cleaned text. When no
        row matches, the cleaned text is printed and a float zero vector of
        length 1536 is returned.
    """
    text = text.split('|')[0]
    text = text.replace('The report of the ECG is that ', '')
    try:
        text_embed = ptbxl_embedding_dict.loc[ptbxl_embedding_dict['report'] == text, 'text_embed'].values[0]
        # NOTE(review): eval() executes the CSV cell as Python code — only use
        # with embedding tables you trust.
        text_embed = eval(text_embed)
    except IndexError:
        # Float zeros: integer zeros would produce an int64 tensor, whose dtype
        # differs from the float tensors built from real embeddings.
        text_embed = [0.0] * 1536
        print(text)
    return torch.tensor(text_embed)

In [8]:
# Step 0 for MIMIC: lookup table queried by `fetch_text_embedding_mimic_report_0`
# through its `text` and `embed` columns.
mimic_embedding_dict = pd.read_csv('/data/0shared/chenjiabo/DiffuSETS/data/mimic_iv_text_embed.csv')

# converting mimic label text into embedding
# NOTE: ONLY use report 0
def fetch_text_embedding_mimic_report_0(text: str):
    """Look up the stored embedding of the first MIMIC report in `text`.

    text: label text; reports are separated by '|' and only the first one is
        used. A trailing '.' is appended when the report is non-empty and
        does not already end with one.
    return: torch.Tensor built from the `embed` cell of the first row of
        `mimic_embedding_dict` whose `text` equals the report. When no row
        matches, the last row of the table is used instead and `1` plus the
        report are printed.
    """
    report = text.split('|')[0]
    if report and not report.endswith('.'):
        report = report + '.'

    is_match = mimic_embedding_dict['text'] == report
    try:
        # NOTE(review): eval() executes the CSV cell as Python code — only use
        # with embedding tables you trust.
        embedding = eval(mimic_embedding_dict.loc[is_match, 'embed'].values[0])
    except IndexError:
        # unknown report: fall back to the last row of the table
        embedding = eval(mimic_embedding_dict.iloc[-1]['embed'])
        print(1, report)

    return torch.tensor(embedding)

## CLIP Score

Compare the cosine similarity between generated ecg embedding and label text embedding


In [9]:
# Datasets built from the VAE directories, one loader per corpus.
# batch_size=1 because `CLIP_Score` reads a single label per batch (y['text'][0]);
# shuffle=True means a `num_test`-limited run sees a random subset of the data.
ptbxl_vae_dataset = PtbxlDataset_VAE(PTB_VAE_PATH)
ptbxl_vae_dataloader = DataLoader(ptbxl_vae_dataset, batch_size=1, shuffle=True)
mimic_vae_dataset = VAE_MIMIC_IV_ECG_Dataset(MIMIC_VAE_PATH)
mimic_vae_dataloader = DataLoader(mimic_vae_dataset, batch_size=1, shuffle=True)

In [10]:
# Generating ecg during test
@torch.no_grad()
def CLIP_Score(clip_model, diffused_model, unet, test_dataloader, fetch_func, device, gen_batch, decoder=None, num_test=None, use_condition=True):
    """Mean cosine similarity between generated-ECG features and their text features.

    For each batch of `test_dataloader` (one label per batch), `gen_batch`
    ECGs are generated from the label's text embedding, both modalities are
    projected with `clip_model`, and the per-sample cosine similarities are
    averaged.

    clip_model: provides `encode_signal`, `ecg_projector` and `text_projector`.
    diffused_model, unet, decoder: forwarded to `generation_from_net`.
    test_dataloader: yields `(X, y)`; `y` is a dict with 'text' and, when
        `use_condition` is True, 'gender', 'age' and 'hr'.
    fetch_func: maps the label text to its embedding tensor.
    device: device used for embeddings and conditions.
    gen_batch: number of ECGs generated per label.
    num_test: number of batches to evaluate; None means `len(test_dataloader)`.
    use_condition: when False the unet is called without the patient condition.
    return: mean CLIP score (tensor) over the batches actually evaluated.
    raises ValueError: if `num_test` is not positive or no batch was evaluated.
    """
    if num_test is None:
        num_test = len(test_dataloader)
    if num_test <= 0:
        raise ValueError(f'num_test must be a positive number of batches, got {num_test}')

    total_clip_score = 0
    num_evaluated = 0

    with tqdm(total=num_test) as pbar:
        for X, y in test_dataloader:
            # y: label dict, one label during one generation
            text = y['text'][0]

            # text to embedding transform
            # unet text embedding: (gen_B, 1, 1536)
            text_embedding = fetch_func(text).to(device)
            unet_text_embedding = text_embedding.repeat((gen_batch, 1, 1))

            condition = False
            if use_condition:
                # A collated batch presumably carries strings as a list/tuple
                # (y['text'] is indexed with [0] above), so unwrap before
                # comparing; comparing the container itself with 'M' is always False.
                gender_label = y['gender']
                if isinstance(gender_label, (list, tuple)):
                    gender_label = gender_label[0]
                gender = torch.tensor([1 if gender_label == 'M' else 0])

                # condition part, value: (gen_B, 1, 1)
                condition = {'gender': gender,
                             'age': y['age'],
                             'heart rate': y['hr']}

                for key in condition:
                    condition[key] = np.array([condition[key]])
                    condition[key] = np.repeat(condition[key][np.newaxis, :], gen_batch, axis=0)
                    condition[key] = torch.Tensor(condition[key])
                    condition[key] = condition[key].to(device)

            # Generating a bunch of ECGs according to the same text
            # ecgs: (gen_B, L, C)
            ecgs = generation_from_net(diffused_model=diffused_model,
                                       net=unet,
                                       text_embedding=unet_text_embedding,
                                       condition=condition,
                                       batch_size=gen_batch,
                                       device=device,
                                       decoder=decoder)

            signal_embedding = clip_model.encode_signal(ecgs)

            # signal features: (gen_B, embed_dim)
            signal_features = clip_model.ecg_projector(signal_embedding)
            # text features: (1, embed_dim) -> (gen_B, embed_dim)
            text_features = clip_model.text_projector(text_embedding)
            text_features = text_features.repeat((gen_batch, 1))

            # normalized features
            signal_features = signal_features / signal_features.norm(dim=-1, keepdim=True)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

            # Row-wise cosine similarity, averaged over the generated batch.
            # Same value as trace(S @ T.t()) / gen_batch without the
            # (gen_B, gen_B) product.
            sample_clip_score = (signal_features * text_features).sum(dim=-1).mean()

            total_clip_score += sample_clip_score
            num_evaluated += 1

            pbar.update(1)
            if num_evaluated >= num_test:
                break

    if num_evaluated == 0:
        raise ValueError('test_dataloader yielded no batches')

    # Divide by the batches actually seen: the loader may hold fewer than num_test.
    mean_clip_score = total_clip_score / num_evaluated

    return mean_clip_score
        

In [11]:
# CLIP Score on MIMIC VAE
# Evaluates num_test=10 batches of the shuffled loader, generating gen_batch=64 ECGs for each.
CLIP_Score(clip_model=clip_model, 
           diffused_model=diffused_model, 
           unet=unet, 
           test_dataloader=mimic_vae_dataloader, 
           fetch_func=fetch_text_embedding_mimic_report_0, 
           device=device, 
           gen_batch=64, 
           decoder=decoder, 
           num_test=10)

100%|██████████| 10/10 [02:20<00:00, 14.08s/it]


tensor(0.1472, device='cuda:0')

In [12]:
# CLIP Score on Ptbxl VAE
# Same settings as the MIMIC run: num_test=10 batches, gen_batch=64 ECGs for each.
CLIP_Score(clip_model=clip_model, 
           diffused_model=diffused_model, 
           unet=unet, 
           test_dataloader=ptbxl_vae_dataloader, 
           fetch_func=fetch_text_embedding_ptbxl, 
           device=device, 
           gen_batch=64, 
           decoder=decoder, 
           num_test=10)

100%|██████████| 10/10 [02:16<00:00, 13.68s/it]


tensor(0.1065, device='cuda:0')

In [15]:
@torch.no_grad()
def CLIP_Score_saved_samples(sample_dir:str, clip_model, decoder, device):
    """CLIP score computed on previously saved generation results.

    sample_dir: directory holding one sub-folder per generation, each with a
        `features.json` readable by `read_features`:

        /path/to/sample_dir
           |-001
           |-002
           ...

        Entries that are not directories are skipped.
    clip_model: provides `encode_signal`, `ecg_projector` and `text_projector`.
    decoder: maps the stored 'Gen Latent' to ECG signals.
    device: device the tensors are moved to.
    return: dict with 'CLIP' (mean cosine similarity over all generated ECGs)
        and 'num_samples' (number of ECGs averaged).
    raises ValueError: if no sample folder was found.
    """
    total_clip_score = 0
    total_num = 0

    for root in tqdm(os.listdir(sample_dir)):
        sample_path = os.path.join(sample_dir, root)
        # stray files (e.g. notes next to the folders) hold no features.json
        if not os.path.isdir(sample_path):
            continue

        feature_dict = read_features(sample_path)
        gen_batch = feature_dict['batch']

        # text_embedding: (gen_B, 1536)
        text_embedding = feature_dict['text_embed']
        text_embedding = torch.tensor(text_embedding).repeat((gen_batch, 1)).to(device)

        # ecg_latent: (gen_B, 4, 128)
        ecg_latent = feature_dict['Gen Latent']
        ecg_latent = torch.tensor(ecg_latent).to(device)

        # generated ECGs: (gen_B, L, C)
        ecgs = decoder(ecg_latent)

        signal_embedding = clip_model.encode_signal(ecgs)

        # signal features: (gen_B, embed_dim)
        signal_features = clip_model.ecg_projector(signal_embedding)
        # text features: (gen_B, embed_dim). The embedding was already repeated
        # above; repeating again would only build a gen_B-times larger matrix.
        text_features = clip_model.text_projector(text_embedding)

        # normalized features
        signal_features = signal_features / signal_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # sum of the row-wise cosine similarities of this folder
        total_clip_score += (signal_features * text_features).sum()
        # count per folder so that folders with different batch sizes average correctly
        total_num += gen_batch

    if total_num == 0:
        raise ValueError(f'No sample folders found in {sample_dir}')

    return {'CLIP': total_clip_score / total_num, 'num_samples': total_num}

In [16]:
# CLIP score over every generation folder stored under ./sample/
CLIP_Score_saved_samples(sample_dir='./sample/', clip_model=clip_model, decoder=decoder, device=device)

100%|██████████| 100/100 [00:16<00:00,  5.90it/s]


{'CLIP': tensor(0.0989, device='cuda:0'), 'num_samples': 6400}

## CLIP Classifier (Zero shot)

Compare the cosine similarity between generated ecg embedding and given text label embeddings

Assign each ECG to the class whose embedding is most similar.

Directly read generated ecg and the relating label

In [None]:
# TODO: build the dataset / dataloader of generated ECGs with their text labels
# for the zero-shot classifier. None placeholders keep this cell valid Python
# until that is done.
classifier_dataset = None
classifier_dataloader = None

In [None]:
# TODO
# Assign classes to ecg according to the text
def text_to_class(text:str):
    """Placeholder: map a label text to its class.

    Not implemented yet: the body is only `pass`, so every call returns None.
    """
    pass

# Get all the super-class features
# 1. get the ada embedding of classes
# 2. use clip model to get the feature
# A tensor with shape (num_class, 1536)
# NOTE(review): if step 2 projects with the CLIP text projector, the feature width
# would presumably be the CLIP embed dim rather than 1536 — confirm when implementing.
def fetch_class_features(clip_model):
    """Placeholder: build the class feature matrix described above.

    Not implemented yet: the body is only `pass`, so every call returns None.
    """
    pass

# Evaluates to None for as long as `fetch_class_features` is still a stub.
class_features = fetch_class_features(clip_model=clip_model)

In [None]:
@torch.no_grad()
def CLIP_Classifier(clip_model, class_features, test_dataloader):
    """Zero-shot classification of ECGs against a set of class features.

    clip_model: provides `encode_signal` and `ecg_projector`.
    class_features: tensor of shape (num_class, feature_dim).
    test_dataloader: yields `(X, y)` where `y['text']` holds one label text per
        sample; each text is mapped to a class index with `text_to_class`.
    return: `(acc, label_all, logits_all)` — accuracy over
        `len(test_dataloader.dataset)`, the concatenated class indices, and
        the logits of all batches flattened into one 1-D tensor.
    """
    correct = 0
    label_all = []
    logits_all = []
    for X, y in test_dataloader:
        # The matmul below requires signal features and class features on the
        # same device, so the class features' device is the model's device.
        # assumes X is a tensor batch — TODO confirm against the dataset
        X = X.to(class_features.device)

        signal_embedding = clip_model.encode_signal(X)
        signal_features = clip_model.ecg_projector(signal_embedding)

        # signal_features: (B, feature_dim), class_features: (num_class, feature_dim)
        # logits: (B, num_class)
        logits = signal_features @ class_features.t()
        pred = torch.argmax(logits, dim=1)

        # Labels must be a tensor: a Python list can neither be compared
        # element-wise with `pred` nor be concatenated by torch.concat.
        label = torch.tensor([text_to_class(x) for x in y['text']],
                             dtype=torch.long, device=pred.device)

        correct += torch.sum(label == pred)
        # NOTE(review): flattening drops the (B, num_class) layout; kept so the
        # returned shape stays as before.
        logits_all.append(logits.reshape(-1, ))
        label_all.append(label)

    acc = correct / len(test_dataloader.dataset)
    logits_all = torch.concat(logits_all)
    label_all = torch.concat(label_all)

    return acc, label_all, logits_all

## FID Score

Compute the distance between the generated ecg embedding and real ecg embedding

The score has to be computed between two datasets (real and generated).

**INPUT**: Stacked embeddings generated from CLIP model

**INPUT SHAPE**: (num_ecgs, num_features)

In [17]:
@torch.no_grad()
def generate_feature_matrix(sample_dir:str, clip_model, device, decoder, use_all_batch=True):
    """Build CLIP feature matrices for the generated and the original ECGs.

    sample_dir: experiment directory with one sub-folder per generation, each
        readable by `read_features`:

        /path/to/sample_dir
           |-001
           |-002
           ...

    clip_model: provides `encode_signal` and `ecg_projector`.
    device: device used while decoding and encoding.
    decoder: maps stored latents back to ECG signals.
    use_all_batch: keep every generated ECG of a folder; when False only the
        first one of each folder is kept.
    return: dict with 'gen' and 'real', each a CPU feature matrix of shape
        (num_samples, feature_dim).
    """
    def _latent_to_features(latent):
        # stored latent -> decoded ECG -> CLIP signal embedding -> projected feature
        ecgs = decoder(torch.tensor(latent).to(device))
        return clip_model.ecg_projector(clip_model.encode_signal(ecgs))

    generated_rows = []
    real_rows = []

    for folder in tqdm(os.listdir(sample_dir)):
        stored = read_features(os.path.join(sample_dir, folder))

        # generated latents: (gen_B, 4, 128) -> features (gen_B, feature_dim)
        generated = _latent_to_features(stored['Gen Latent'])
        if not use_all_batch:
            # keep a single representative, shape (1, feature_dim)
            generated = generated[0].unsqueeze(0)
        generated_rows.append(generated.cpu())

        # original latent: (1, 4, 128) -> feature (1, feature_dim)
        real_rows.append(_latent_to_features(stored['Ori Latent']).cpu())

    # each matrix: (num_samples, num_features)
    return {'gen': torch.concat(generated_rows), 'real': torch.concat(real_rows)}


In [18]:
# Feature matrices for FID / precision-recall; use_all_batch=True keeps every generated ECG
state = generate_feature_matrix(sample_dir='./sample/', clip_model=clip_model, device=device, decoder=decoder, use_all_batch=True)
M_gen, M_real = state['gen'], state['real']
# sanity check: (num_generated, feature_dim) and (num_real, feature_dim)
print(M_gen.shape, M_real.shape)

100%|██████████| 100/100 [00:18<00:00,  5.28it/s]

torch.Size([6400, 64]) torch.Size([100, 64])





In [19]:
def FID_score(M1: torch.Tensor, M2: torch.Tensor, eps: float = 1e-6):
    """Frechet distance between two sets of feature vectors.

    Each input is treated as samples of a Gaussian; the score is
    ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)).

    M1, M2: feature matrices of shape (num_samples, num_features), given as
        torch tensors (any device, with or without grad) or array-likes.
    eps: diagonal offset added to both covariances when the matrix square
        root of their product is not finite (singular covariances).
    return: the FID value as a numpy float.
    """
    def _as_float64_matrix(M):
        # detach()/cpu() make the conversion work for CUDA and grad-tracking tensors
        if isinstance(M, torch.Tensor):
            M = M.detach().cpu().numpy()
        return np.asarray(M, dtype=np.float64)

    M1, M2 = _as_float64_matrix(M1), _as_float64_matrix(M2)

    # atleast_2d: np.cov returns a 0-d array for a single feature column
    mu1, sigma1 = M1.mean(axis=0), np.atleast_2d(np.cov(M1, rowvar=False))
    mu2, sigma2 = M2.mean(axis=0), np.atleast_2d(np.cov(M2, rowvar=False))

    ssdiff = np.sum((mu1 - mu2) ** 2.0)

    # calculate sqrt of product between cov
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # near-singular product: retry with a small offset on both diagonals
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

    # check and correct imaginary numbers from sqrt
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    # calculate score
    fid = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
    return fid

In [20]:
# FID between the real and the generated feature matrices
fid_score = FID_score(M_real, M_gen) 
num_samples = M_real.shape[0]
# Reference value: FID between the first and the second half of the real features
scaler = FID_score(M_real[:num_samples // 2], M_real[num_samples // 2:])

# From ME-GAN, can produce result more stable to num_samples (and smaller result)
# r_FID is the raw FID divided by the real-vs-real reference above
r_FID = fid_score / scaler

print(fid_score, r_FID)

339490635.045777 6.848712837192828


### NOTE

The more samples are used, the closer the estimate is to the real performance.

**If M_real and M_gen are both sampled from N(0, 1)** 

when num_samples is 2048: FID = 257, r_FID = 0.50; 

when num_samples is 20480: FID = 25, r_FID = 0.50;

**If M_real ~ N(0, 1) and M_gen ~ N(1, 4)** 

when num_samples is 2048: FID = 2567, r_FID = 5; √

when num_samples is 20480: FID = 2099, r_FID = 40;

## Precision and Recall

Use k-NN to construct manifold

In [21]:
class ManifoldDetector():
    """k-NN manifold estimate of a point set.

    Every data point gets a radius equal to the distance to its k-th nearest
    neighbour (the point itself excluded).

    data: (num_points, num_features) tensor.
    k: which neighbour defines the radius; requires num_points > k.

    Attributes:
        k: the neighbour rank used.
        data: the tensor passed in, unchanged.
        radii: (num_points, 1) tensor of per-point radii.
    """
    def __init__(self, data: torch.Tensor, k=3):
        if data.shape[0] <= k:
            raise ValueError(
                f'ManifoldDetector needs more than k={k} points, got {data.shape[0]}'
            )

        self.k = k
        self.data = data

        # Pairwise Euclidean distances as one (n, n) matrix. cdist avoids the
        # (n, n, d) difference tensor; the non-matmul mode keeps the exact
        # subtraction-based distances.
        points = data if data.is_floating_point() else data.float()
        distances = torch.cdist(points, points, compute_mode='donot_use_mm_for_euclid_dist')

        # The k+1 smallest distances of a row are the point itself (distance 0)
        # and its k nearest neighbours, so the last one is the k-th neighbour's.
        nearest, _ = torch.topk(distances, k=self.k + 1, dim=1, largest=False)
        self.radii = nearest[:, -1:]

def is_in_manifold(test_point: torch.Tensor, manifold_detector: ManifoldDetector):
    """Tell whether `test_point` lies inside the estimated manifold.

    The point is inside when its Euclidean distance to at least one data
    point of `manifold_detector` does not exceed that data point's radius.

    return: boolean scalar tensor.
    """
    offsets = manifold_detector.data - test_point
    dist_to_samples = torch.sqrt(torch.sum(offsets**2, dim=1))
    per_sample_radius = manifold_detector.radii.squeeze()
    return (dist_to_samples <= per_sample_radius).any()

def points_in_manifold(test_points: torch.Tensor, manifold_detector: ManifoldDetector):
    """Count how many of `test_points` fall inside the manifold.

    Each point is checked on its own with `is_in_manifold`; the boolean
    results are added up starting from the integer 0.
    """
    return sum(is_in_manifold(candidate, manifold_detector) for candidate in test_points)

def precision_recall(M_g, M_r, k=3):
    """Precision and recall of a generation result on k-NN manifolds.

    M_g: feature matrix of generated ECG.
    M_r: feature matrix of real ECG.
    k: the distance to the k-th nearest neighbour is used to construct each manifold.
    return: dict with 'precision' (share of generated points inside the real
        manifold) and 'recall' (share of real points inside the generated manifold).
    """
    print("Initializing...")
    generated_manifold = ManifoldDetector(M_g, k)
    real_manifold = ManifoldDetector(M_r, k)

    print("computing precision...")
    precision = points_in_manifold(M_g, real_manifold) / M_g.shape[0]

    print("computing recall...")
    recall = points_in_manifold(M_r, generated_manifold) / M_r.shape[0]

    return {'precision': precision, 'recall': recall}

In [22]:
# Precision / recall on the feature matrices computed above (k defaults to 3)
precision_recall(M_g=M_gen, M_r=M_real)

Initializing...
computing precision...
computing recall...


{'precision': tensor(0.9906), 'recall': tensor(0.3000)}