In [1]:
from experimental_utils.load_data import generate_yearly_data
from experimental_utils.load_vocab import load_joint_vocab
from train_clf_asd import get_modality_config
from core.data_utils import MODALITY_DATA_SELECT, SUPPRESS_MODALITY
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import torch
from tqdm import tqdm
from core.data_utils import MultiModalCollate
from torch.utils.data import DataLoader, ConcatDataset

from experimental_utils.track_model import select_model, parse_tags_from_filename 
%load_ext autoreload
%autoreload 2

In [2]:
import os

In [3]:
MODALITY_FLAGS = ["all"]
TASK = "binary"
MODALITY_CHECKPOINT = "first_check"

MODALITY_CONFIG = get_modality_config(MODALITY_FLAGS)
MODALITIES = [k for k, v in MODALITY_CONFIG.items() if v]
N_MODALITY = len(MODALITIES)

MIXING_MODULE = "softmax-gating"
ZP_ONLY = False
ADD_CONTRASTIVE = True

In [4]:
FILTERING_TAGS = {
    "modality_checkpoint": MODALITY_CHECKPOINT,
    "mixing_approach": MIXING_MODULE,
    "contrast": ADD_CONTRASTIVE,
    "zp_only": ZP_ONLY
}

In [5]:
DEVICE = "cuda:0"
EMBEDDING_DIM = 256

# train_cohort = pd.read_csv("../../data/outcome/train_asd.csv", dtype={"PATID": str})
reference_cohort = pd.read_csv("../../data/outcome/test_asd.csv", dtype={"PATID": str})
test_cohort = pd.read_csv("../../data/rAOM/test_raom.csv", dtype={"PATID": str})


_, transform = load_joint_vocab(reference_cohort)
# embeddings = load_joint_embeddings(vocab, EMBEDDING_DIM, DEVICE)

test_dataset_list = [
    generate_yearly_data(
        'raom', i, test_cohort, transform, MODALITY_CONFIG, complete_case=False
    )
    for i in tqdm(range(2015, 2023), desc="preparing test")
]
test_dataset = ConcatDataset(test_dataset_list)
collate_fn = MultiModalCollate(n_modality=N_MODALITY, survival=False)
    

100%|██████████| 32/32 [00:47<00:00,  1.47s/file]


vocab size: 19145


100%|██████████| 349/349 [00:00<00:00, 1671.60it/s]s]
100%|██████████| 379/379 [00:00<00:00, 1692.99it/s] 2.64s/it]
100%|██████████| 367/367 [00:00<00:00, 1631.60it/s] 2.90s/it]
100%|██████████| 116/116 [00:00<00:00, 1659.09it/s] 3.06s/it]
0it [00:00, ?it/s]0%|█████     | 4/8 [00:12<00:12,  3.07s/it]
0it [00:00, ?it/s]2%|██████▎   | 5/8 [00:15<00:09,  3.09s/it]
0it [00:00, ?it/s]5%|███████▌  | 6/8 [00:18<00:06,  3.11s/it]
0it [00:00, ?it/s]8%|████████▊ | 7/8 [00:22<00:03,  3.40s/it]
preparing test: 100%|██████████| 8/8 [00:26<00:00,  3.36s/it]


In [None]:
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn
)

In [None]:
batch = next(iter(test_loader))
event = batch["event"].float().to(DEVICE)
time = batch["time"].float().to(DEVICE)
masks = batch["mask"].to(DEVICE)
inputs = batch["inputs"].to(DEVICE)

In [None]:
baselines = [f for f in os.listdir("../../model_checkpoint/asd-archive") if f"all_None_{MIXING_MODULE}" in f and f.endswith("pth")]
baselines

In [None]:
name = "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_869.pth"

In [None]:
model = torch.load(f"../../model_checkpoint/asd/{name}", map_location=DEVICE)
# model = torch.load(f"../../model_checkpoint/asd/joint_survival_all_None_self-attention_570.pth", map_location=DEVICE)

In [None]:
w = model(
    inputs=inputs,
    masks=masks,
)[-1]
w = w.detach().cpu().numpy()

[baby_birth, baby_dev, mom_birth, mom_prenatal] -> [mom_prenatal, mom_birth, baby_birth, baby_dev]

### self-gating

In [None]:
# i = int(torch.randint(size=(1,), low=0, high=w.shape[0]))
# i = 89
i = 152
print(i, event[i].item(), time[i])
print(masks[i][[3,2,0,1]])
plt.figure(figsize=(5, 4))
sns.heatmap(w[i][[3,2,0,1],:], annot=False, cmap="viridis", cbar=True, fmt='.2f')
plt.xticks([])
plt.yticks(
    ticks=[0.5, 1.5, 2.5, 3.5], 
    labels=[
        "$m_{\\text{prenatal}}$", "$m_{\\text{birth}}^{\\text{mom}}$",
        "$m_{\\text{birth}}^{\\text{baby}}$", "$m_{\\text{developmental}}$"
    ], 
    rotation=0
)
plt.show()

### self-attn

In [None]:
# i = int(torch.randint(size=(1,), low=0, high=w.shape[0]))
# i = 89
i = 152
print(i, event[i].item(), time[i])
print(masks[i][[3,2,0,1]])
plt.figure(figsize=(5, 4))
sns.heatmap(w[i][[3,2,0,1]][:,[3,2,0,1]], annot=False, cmap="viridis", cbar=True, fmt='.2f')
plt.xticks(
    ticks=[0.5, 1.5, 2.5, 3.5], 
    labels=[
        "$m_{\\text{prenatal}}$", "$m_{\\text{birth}}^{\\text{mom}}$",
        "$m_{\\text{birth}}^{\\text{baby}}$", "$m_{\\text{developmental}}$"
    ], 
    rotation=0
)
plt.yticks(
    ticks=[0.5, 1.5, 2.5, 3.5], 
    labels=[
        "$m_{\\text{prenatal}}$", "$m_{\\text{birth}}^{\\text{mom}}$",
        "$m_{\\text{birth}}^{\\text{baby}}$", "$m_{\\text{developmental}}$"
    ], 
    rotation=0
)
plt.show()

### Shap

In [9]:
from captum.attr import LayerIntegratedGradients

#### scratch

In [None]:
name = "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_869.pth"
# name = "joint_survival_all_first_check_softmax-gating_157.pth"
model = torch.load(f"../../model_checkpoint/asd-archive/{name}", map_location=DEVICE)

def forward_func(inputs, masks):
    return model(inputs, masks)[0]
lig = LayerIntegratedGradients(forward_func, model.embedding_module)

test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn
)

batch = next(iter(test_loader))
event = batch["event"].float().to(DEVICE)
time = batch["time"].float().to(DEVICE)
masks = batch["mask"].to(DEVICE)
inputs = batch["inputs"].to(DEVICE)

attr, delta = lig.attribute(
    inputs=inputs,
    target=8,
    additional_forward_args=(masks,),
    n_steps=50,
    internal_batch_size=128,
    return_convergence_delta=True
)
attr = -attr[:, [3,2,0,1], :, :]
masks = masks[:, [3,2,0,1]]

### modality level

In [7]:
import numpy as np

In [10]:
# model_names = [
#     "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_869.pth",
#     "joint_survival_all_first_check_softmax-gating_contrast_924.pth",
#     "joint_survival_all_first_check_softmax-gating_157.pth",
#     "joint_survival_all_mid_check_softmax-gating_contrast_zpOnly_241.pth",
#     "joint_survival_all_mid_check_softmax-gating_contrast_859.pth",
#     "joint_survival_all_mid_check_softmax-gating_418.pth",
#     "joint_survival_all_final_check_softmax-gating_contrast_zpOnly_810.pth",
#     "joint_survival_all_final_check_softmax-gating_contrast_134.pth",
#     "joint_survival_all_final_check_softmax-gating.pth"
# ]

filtering_tags = {
    "modality_checkpoint": "first_check",
    "mixing_approach": "softmax-gating",
    "contrast": True,
    "zp_only": True
}

model_names = select_model(
    root="../../model_checkpoint/raom-archive",
    filtering_tags=filtering_tags
)
model_names

['joint_binary_all_first_check_softmax-gating_contrast_zpOnly_232.pth',
 'joint_binary_all_first_check_softmax-gating_contrast_zpOnly_260.pth',
 'joint_binary_all_first_check_softmax-gating_contrast_zpOnly_143.pth',
 'joint_binary_all_first_check_softmax-gating_contrast_zpOnly_971.pth',
 'joint_binary_all_first_check_softmax-gating_contrast_zpOnly_262.pth']

In [12]:
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=False, collate_fn=collate_fn
)

importance_list = []
for name in model_names:
    print(name)
    model = torch.load(f"../../model_checkpoint/raom-archive/{name}", map_location=DEVICE)
    def forward_func(inputs, masks):
        return model(inputs, masks)[0]
    lig = LayerIntegratedGradients(forward_func, model.embedding_module)

    importance = torch.zeros((4,)).to(DEVICE)
    for batch in tqdm(test_loader):
        masks = batch["mask"].to(DEVICE)

        if not masks.all(dim=1).any().item(): # if all samples in the batch do not have complete obs
            continue

        inputs = batch["inputs"].to(DEVICE)
        attr, delta = lig.attribute(
            inputs=inputs,
            target=0,
            additional_forward_args=(masks,),
            n_steps=50,
            internal_batch_size=128,
            return_convergence_delta=True
        )
        attr = attr[:, [3,2,0,1], :, :]
        masks = masks[:, [3,2,0,1]]
        attr = attr[masks.all(dim=1)].abs().sum(dim=-1) #(B, 4, L)

        # min-max norm
        attr = (attr - attr.min()) / (attr.max() - attr.min())

        # clip outliers at 99th percentile
        # attr = torch.clamp(attr, max=torch.quantile(attr, 0.99))

        a = attr.sum(dim=-1)
        importance += (a / a.sum(dim=-1, keepdim=True)).mean(dim=0) / len(test_loader)
    importance_list.append(importance)
        
    print(importance.round(decimals=3))
    print("")

joint_binary_all_first_check_softmax-gating_contrast_zpOnly_232.pth


100%|██████████| 10/10 [00:12<00:00,  1.21s/it]


tensor([0.2280, 0.3040, 0.2890, 0.1780], device='cuda:0')

joint_binary_all_first_check_softmax-gating_contrast_zpOnly_260.pth


100%|██████████| 10/10 [00:12<00:00,  1.24s/it]


tensor([0.2410, 0.3160, 0.2470, 0.1960], device='cuda:0')

joint_binary_all_first_check_softmax-gating_contrast_zpOnly_143.pth


100%|██████████| 10/10 [00:12<00:00,  1.29s/it]


tensor([0.2240, 0.3060, 0.2550, 0.2150], device='cuda:0')

joint_binary_all_first_check_softmax-gating_contrast_zpOnly_971.pth


100%|██████████| 10/10 [00:12<00:00,  1.22s/it]


tensor([0.2590, 0.2960, 0.2440, 0.2010], device='cuda:0')

joint_binary_all_first_check_softmax-gating_contrast_zpOnly_262.pth


100%|██████████| 10/10 [00:12<00:00,  1.22s/it]

tensor([0.2310, 0.2990, 0.2590, 0.2120], device='cuda:0')






In [13]:
print(torch.stack(importance_list).mean(dim=0).cpu().tolist())
print(torch.stack(importance_list).std(dim=0).cpu().tolist())

[0.23660431802272797, 0.3042547106742859, 0.2585369944572449, 0.20060400664806366]
[0.01402726024389267, 0.007799405604600906, 0.01794830895960331, 0.014584501273930073]


In [None]:
m = "|".join([filtering_tags["modality_checkpoint"], filtering_tags["mixing_approach"]])
if filtering_tags["contrast"]:
    m += "|contrast"
if filtering_tags["zp_only"]:
    m += "|zpOnly"
m

In [None]:
import json

In [None]:
result = {}
result[m] = {
    'mean': torch.stack(importance_list).mean(dim=0).cpu().tolist(),
    'std': torch.stack(importance_list).std(dim=0).cpu().tolist()
}
with open("../../modality_importance/test.json", "w") as f:
    json.dump(result, f, indent=4)

### Sctrach

In [None]:
delta.max()

In [None]:
mod_norm = torch.norm(attr.sum(dim=-1), dim=-1)
mod_norm = mod_norm * (mod_norm != 0) + torch.ones_like(mod_norm) * (mod_norm == 0)
attr.sum(dim=-1) / mod_norm[:,:,None]
# shape (batch_size, 4, L)

In [None]:
a = attr.sum(dim=-1) #/ mod_norm[:,:,None]
a

In [None]:
a.sum(dim=-1)[event == 1]

In [None]:
a.sum(dim=-1)[event == 0]

### token level

In [None]:
from joblib import Parallel, delayed
from collections import defaultdict

def parallel_aggregate_joblib(inputs, attrs, n_jobs=-1):
    def process_batch(batch):
        local_dict = defaultdict(list)
        for token, attr in batch:
            local_dict[int(token)].append(float(attr))
        return dict(local_dict)
    
    # Create batches
    batch_size = len(inputs) // (n_jobs if n_jobs > 0 else 4)
    batches = [list(zip(inputs[i:i+batch_size], attrs[i:i+batch_size])) 
               for i in range(0, len(inputs), batch_size)]
    
    # Process in parallel
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_batch)(batch) for batch in batches
    )
    
    # Merge results
    final_dict = defaultdict(list)
    for result in results:
        for key, values in result.items():
            final_dict[key].extend(values)
    
    return dict(final_dict)

In [None]:
def summarize_attributions_across_samples(explainer, dataloader, device):
    """
    Summarize attribution scores across samples for each token.

    Returns:
        dict: A dictionary mapping each unique token to the mean attribution score across samples.
    """
    token_to_attributions = dict()
    for batch in tqdm(dataloader):
        masks = batch["mask"].to(device)
        inputs = batch["inputs"].to(device)
        attr, delta = explainer.attribute(
            inputs=inputs,
            target=8,
            additional_forward_args=(masks,),
            n_steps=50,
            internal_batch_size=128,
            return_convergence_delta=True
        )
        attr = -attr.sum(dim=-1) # negate as the anchored target is "censoring"
        # mod_norm = torch.norm(attr.sum(dim=-1), dim=-1)
        # mod_norm = mod_norm * (mod_norm != 0) + torch.ones_like(mod_norm) * (mod_norm == 0)
        # attr = attr / mod_norm[:,:,None]
        if float(delta.max()) > 0.05:
            raise RuntimeWarning("convergence warning: large delta")

        inputs = inputs.flatten()
        attr = attr.flatten()
        

        # Aggregate attributions by token
        unique_tokens, inv_idx = torch.unique(inputs, return_inverse=True)
        for i, token in enumerate(unique_tokens):
            if int(token) not in token_to_attributions:
                token_to_attributions[int(token)] = []
            token_to_attributions[int(token)].extend(attr[inv_idx == i].tolist())

        # for token, attr in zip(inputs, attr):
        #     if int(token) not in token_to_attributions:
        #         token_to_attributions[int(token)] = []
        #     token_to_attributions[int(token)].append(float(attr))
        

    return token_to_attributions

In [None]:
test_loader = DataLoader(
    test_dataset, batch_size=128, shuffle=True, collate_fn=collate_fn
)

model_names = [
    "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_772",
    "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_555",
    "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_899",
    "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_823",
    "joint_survival_all_first_check_softmax-gating_contrast_zpOnly_274"
]

attr_dict_list = []
for name in model_names:
    model = torch.load(f"../../model_checkpoint/asd-archive/{name}.pth", map_location=DEVICE)
    def forward_func(inputs, masks):
        return model(inputs, masks)[0]
    lig = LayerIntegratedGradients(forward_func, model.embedding_module)
    attr_dict = summarize_attributions_across_samples(lig, test_loader, DEVICE)
    attr_dict_list.append(attr_dict)


In [None]:
h = {token: 0 for token, _ in attr_dict.items()}
for elem in attr_dict_list:
    for token in elem.keys():
        h[token] += torch.tensor(elem[token]).mean() / 5

In [None]:
top10_pos_tokens = dict(sorted(h.items(), key=lambda x: x[1], reverse=True)[:10])
top10_neg_tokens = dict(sorted(h.items(), key=lambda x: x[1], reverse=False)[:10])

In [None]:
top10_pos_tokens

In [None]:
top10_neg_tokens

In [None]:
print(top10_pos_tokens.keys())
vocab.lookup_tokens(list(top10_pos_tokens.keys()))

In [None]:
print(top10_neg_tokens.keys())
vocab.lookup_tokens(list(top10_neg_tokens.keys()))