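## Step 1: Check the CLIP model

Smoke-test the CLIP-style model on a dummy all-zero batch: a no-grad forward loss, one optimiser step, and the PSG–text similarity matrix.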

In [1]:
import os, torch
from melp.models.uniclip_model import UniCLIPModel          


import torch.distributed as dist

# Initialise a single-process "gloo" process group (rank 0, world size 1),
# presumably because the model code calls torch.distributed — confirm.
# `setdefault` keeps any MASTER_ADDR / MASTER_PORT already set in the environment.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Fixed seed so weight initialisation (and any stochastic step inside the model)
# reproduces under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dummy-input geometry.
BATCH = 4                  # samples per batch
C, FREQ, WIN = 2, 128, 30  # channels, samples per second, window length in seconds — TODO confirm units
T = FREQ * WIN             # samples per channel

dummy_ecg    = torch.zeros(BATCH, C, T)            # (B, C, T), all zeros
dummy_report = ["This is a demo report."] * BATCH  # identical report text for every sample

# NOTE(review): `batch` holds only the signal under "psg"; `dummy_report` is not
# included — confirm which keys UniCLIPModel.shared_step expects.
batch = {"psg": dummy_ecg}


In [2]:
# Build the UniCLIP model (`resnet18` PSG encoder, `google/flan-t5-base` text
# encoder), switch it to eval mode and run its training step once on the dummy
# batch without tracking gradients.
model = UniCLIPModel(
    psg_encoder_name="resnet18",
    text_encoder_name="google/flan-t5-base",
)
model.eval()

with torch.no_grad():
    loss_dict, _ = model.shared_step(batch, batch_idx=0)

# With all-zero signals every embedding is identical, so all logits are equal —
# presumably why the reported loss equals ln(BATCH).
print("contrastive loss =", loss_dict["loss"].item())


contrastive loss = 1.3862943649291992


In [None]:
# One optimisation step on the dummy batch: checks that the loss backpropagates
# and the optimiser can update the model. The model is built with the same
# configuration as the inference check, using the UniCLIPModel class imported
# in the setup cell (no `CLIPModel` is imported anywhere in this notebook).
model = UniCLIPModel(psg_encoder_name="resnet18",
                     text_encoder_name="google/flan-t5-base")
model.train()

# weight_decay applies to every parameter returned by model.parameters(),
# including the learnable temperature `logit_scale`.
opt = torch.optim.AdamW(model.parameters(), lr=2e-4, weight_decay=0.2)

opt.zero_grad(set_to_none=True)
loss_dict, _ = model.shared_step(batch, batch_idx=0)
loss_dict["loss"].backward()
opt.step()

# The printed value is exp(logit_scale), i.e. the multiplier applied to the
# cosine similarities, not the raw parameter.
print("after one step, logit_scale.exp() =", model.logit_scale.exp().item())


In [4]:
# Inspect the scaled PSG↔text similarity matrix that the contrastive loss is built on.
# NOTE(review): assumes UniCLIPModel exposes `encoders["all"]`, `_tokenize`,
# `encode_text(...)["proj_text_emb"]` and `logit_scale` — confirm against
# melp.models.uniclip_model.
model = UniCLIPModel(psg_encoder_name="resnet18",
                     text_encoder_name="google/flan-t5-base").eval()

with torch.no_grad():
    # The Step 1 batch stores the signal under "psg" and has no "ecg"/"report"
    # keys, so the reports are taken from `dummy_report` directly.
    psg_emb = model.encoders["all"](batch["psg"])              # (B, 256) — TODO confirm
    tok     = model._tokenize(dummy_report)
    txt_emb = model.encode_text(tok["input_ids"], tok["attention_mask"])["proj_text_emb"]  # (B, 256) — TODO confirm

    # Cosine similarity scaled by the learnable temperature.
    psg_emb = torch.nn.functional.normalize(psg_emb, dim=-1)
    txt_emb = torch.nn.functional.normalize(txt_emb, dim=-1)
    sim = (psg_emb @ txt_emb.t()) * model.logit_scale.exp()    # (B, B)

# All signals are zeros and all reports identical, so every entry of the matrix
# is expected to be the same value.
print("similarity matrix:\n", sim)


similarity matrix:
 tensor([[1.2841, 1.2841, 1.2841, 1.2841],
        [1.2841, 1.2841, 1.2841, 1.2841],
        [1.2841, 1.2841, 1.2841, 1.2841],
        [1.2841, 1.2841, 1.2841, 1.2841]])


## Step 2: Check the MAE model

Smoke-test `MAEModel` on a dummy all-zero 21-channel batch with identical text reports.

In [1]:
import os, torch
from melp.models.mae_model import MAEModel          


import torch.distributed as dist

# Initialise a single-process "gloo" process group (rank 0, world size 1) if no
# earlier cell has done so; presumably required because the model code calls
# torch.distributed — confirm. `setdefault` keeps any MASTER_ADDR / MASTER_PORT
# already set in the environment.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

# Fixed seed so weight initialisation (and any random masking inside the MAE —
# TODO confirm) reproduces under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Dummy-input geometry for the MAE check (overrides the Step 1 values).
BATCH = 4                  # samples per batch
C, FREQ, WIN = 21, 64, 30  # channels, samples per second, window length in seconds — TODO confirm units
T = FREQ * WIN             # samples per channel

dummy_ecg    = torch.zeros(BATCH, C, T)            # (B, C, T), all zeros
dummy_report = ["This is a demo report."] * BATCH  # identical report text for every sample

batch = {"ecg": dummy_ecg, "report": dummy_report}


In [2]:
# Build the MAE model (`vit_tiny` PSG encoder, `google/flan-t5-base` text
# encoder), switch it to eval mode and run its training step once on the dummy
# batch without tracking gradients. The parameter table and tensor-shape lines
# in the captured output are printed by library code, not by this cell.
model = MAEModel(
    psg_encoder_name="vit_tiny",
    text_encoder_name="google/flan-t5-base",
)
model.eval()

with torch.no_grad():
    loss_dict, _ = model.shared_step(batch, batch_idx=0)

# NOTE(review): the label says "contrastive", but this prints whatever
# MAEModel.shared_step reports under "loss" — confirm whether that includes the
# reconstruction term.
print("contrastive loss =", loss_dict["loss"].item())



Trainable parameters:
| Module      | # params    |
|-------------|-------------|
| dec_blocks  | 12,609,536  |
| dec_norm    | 1,024       |
| dec_proj    | 131,584     |
| encoders    | 5,625,472   |
| lm_model    | 109,628,544 |
| logit_scale | 1           |
| mask_token  | 512         |
| pos_dec     | 24,576      |
| pred_head   | 430,920     |
| proj_t      | 262,656     |
| TOTAL       | 128,714,825 |
torch.Size([4, 48, 192])
torch.Size([4, 12, 256])
torch.Size([4, 48, 840]) torch.Size([4, 48, 840])
contrastive loss = 0.334650456905365
