In [1]:
import sys
sys.path.append('..')

import torch
from clip_model import CLIP
from vae.vae_model import VAE_Decoder
from tqdm import tqdm

device = torch.device('cuda:7')
# CLIP model
model_root = '../checkpoints/clip_7/clip_best.pth'
clip_model = CLIP(embed_dim=64)
clip_model_weight = torch.load(model_root, map_location=device)
clip_model.load_state_dict(clip_model_weight)
clip_model = clip_model.to(device)

# VAE
decoder = VAE_Decoder()
# VAE_path
vae_path = '../checkpoints/vae_1/VAE_model_ep9.pth'
checkpoint = torch.load(vae_path, map_location=device)
decoder.load_state_dict(checkpoint['decoder'])
decoder = decoder.to(device)


In [2]:
from dataset.ptbxl_dataset import PtbxlDataset_VAE
from torch.utils.data import DataLoader
import pandas as pd

PTB_VAE_PATH = '/data/0shared/laiyongfan/data_text2ecg/ptb-xl_vae'
dataset = PtbxlDataset_VAE(PTB_VAE_PATH)
dataloader = DataLoader(dataset, 64)

In [3]:
embedding_dict = pd.read_csv('/data/0shared/chenjiabo/DiffuSETS/data/ptbxl_database_embed.csv', low_memory=False)[['ecg_id', 'text_embed']]
original_sheet = pd.read_csv('/data/0shared/laiyongfan/data_text2ecg/ptb-xl/ptbxl_database.csv', low_memory=False)
embedding_dict = pd.merge(embedding_dict, original_sheet)

In [4]:
def fetch_text_embedding_ptbxl(text:str):
    text = text.split('|')[0]
    text = text.replace('The report of the ECG is that ', '')
    try:
        text_embed = embedding_dict.loc[embedding_dict['report'] == text, 'text_embed'].values[0]
        text_embed = eval(text_embed)
    except IndexError:
        text_embed = [0] * 1536
        # print(text)
    return torch.tensor(text_embed)

@torch.no_grad()
def CLIP_Score(clip_model, test_dataloader, fetch_text_embedding, device, decoder=None, num_test=None):
    clip_model.eval()
    decoder.eval()

    total_clip_score = 0

    for idx, (X, y) in enumerate(tqdm(test_dataloader)):
        # y: label dict
        text = y['text']
        text = [fetch_text_embedding(x) for x in text]
        text_embedding = torch.stack(text).to(device)

        # ecgs: (gen_B, L, C)
        X = X.to(device)
        if decoder:
            X = decoder(X)

        signal_embedding = clip_model.encode_signal(X)

        # signal features: (B, 1536)
        signal_features = clip_model.ecg_projector(signal_embedding)
        # text features:  (B, 1536)
        text_features = clip_model.text_projector(text_embedding)

        # normalized features
        signal_features = signal_features / signal_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # cosine similarity
        batch_clip_score = torch.trace(signal_features @ text_features.t()) 

        total_clip_score += batch_clip_score


    mean_clip_score = total_clip_score / len(test_dataloader.dataset)

    return mean_clip_score

Compute CLIP Score on PTBXL (training set)

In [11]:
CLIP_Score(clip_model=clip_model, test_dataloader=dataloader, fetch_text_embedding=fetch_text_embedding_ptbxl, device=device, decoder=decoder)

100%|██████████| 341/341 [02:39<00:00,  2.13it/s]


tensor(0.5043, device='cuda:7')

In [6]:
signal_features = torch.randn((256, 64))
text_features = torch.randn((256, 64))
# text_features = signal_features

signal_features = signal_features / signal_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

# cosine similarity
batch_clip_score = torch.trace(signal_features @ text_features.t()) 
print(batch_clip_score / 256)

tensor(-0.0058)


In [7]:
from dataset.mimic_iv_ecg_dataset import VAE_MIMIC_IV_ECG_Dataset

MIMIC_VAE_PATH = '/data/0shared/laiyongfan/data_text2ecg/mimic_vae'
mimic_dataset = VAE_MIMIC_IV_ECG_Dataset(MIMIC_VAE_PATH)
mimic_dataloader = DataLoader(mimic_dataset, 256)

In [8]:
embedding_dict_mimic = pd.read_csv('/data/0shared/chenjiabo/DiffuSETS/data/mimic_iv_text_embed.csv')

In [9]:
def fetch_text_embedding_mimic_report_0(text: str):
    text = text.split('|')[0]
    if len(text) > 0 and text[-1] != '.':
        text += '.'
    try:
        text_embed = embedding_dict_mimic.loc[embedding_dict_mimic['text'] == text, 'embed'].values[0]
        text_embed = eval(text_embed)
    except IndexError:
        text_embed = embedding_dict_mimic.iloc[-1]['embed']
        text_embed = eval(text_embed)
        # print(1, text)
    return torch.tensor(text_embed)
    

In [10]:
CLIP_Score(clip_model=clip_model, test_dataloader=mimic_dataloader, fetch_text_embedding=fetch_text_embedding_mimic_report_0, device=device, decoder=decoder)

100%|██████████| 3104/3104 [1:06:28<00:00,  1.28s/it]


tensor(0.3333, device='cuda:7')

In [35]:
for idx, (X, y) in enumerate(mimic_dataloader):
    text = y['text']
    text = [fetch_text_embedding_mimic_report_0(x) for x in text]
    text_embedding = torch.stack(text)


1 


In [5]:
for idx, (X, y) in enumerate(dataloader):
    text = y['text']
    text = [fetch_text_embedding_ptbxl(x) for x in text]
    text_embedding = torch.stack(text)

 
 
nr: 030315 }lder:  87 }r k
