In [None]:
import argparse
import itertools
import json
import pathlib
from types import SimpleNamespace

import numpy as np
import pandas as pd
import torch
from pathlib import Path

import sys
sys.path.append("../")

from lib.read_data import read_data
from lib.get_models import get_diffusion_model
from lib.diffusion_models import GaussianDiffusion
from lib.sampling import coherent_sample, sample
from lib.config import modalities_list

In [None]:
# --- Configuration ---
# Imputation strategy used by the cells below:
#   'multi'    -> one model per target, conditioned jointly on all other modalities
#   'coherent' -> one single-conditioning model per (target, conditioning modality) pair
mode = 'multi'
test_iterations = 1        # number of generation repeats (not referenced by the cells visible here)
dim = '32'                 # your chosen dimension; kept as a string because it is interpolated into paths

# Fail fast on a typo: later cells branch with `if mode == 'multi': ... else: ...`,
# so any other value would silently be treated as 'coherent'.
if mode not in ('multi', 'coherent'):
    raise ValueError(f"mode must be 'multi' or 'coherent', got {mode!r}")

# Root folder holding the trained checkpoints, and folder with the normalized modality data.
results_path = '../../results'
data_dir = '../../datasets_TCGA/07_normalized/'

# Prefer the GPU when one is visible to torch, otherwise run on CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# Load the training split of every modality listed in `modalities_list`,
# together with the training mask file stored for the chosen dimension.
print("Loading data...")

mask_train_path = '../../datasets_TCGA/06_masked/{}/masks_train.csv'.format(dim)
modalities_map = read_data(
    modalities=modalities_list,
    splits=['train'],
    data_dir=data_dir,
    dim=dim,
    mask_train_path=mask_train_path,
)

In [None]:
# Single diffusion object (1000 timesteps) shared by every model loaded below.
diffusion = GaussianDiffusion(num_timesteps=1000)
diffusion = diffusion.to(device)

# target modality name -> {'models': [...], 'diffusion': ..., 'conds': [...]}
models_dict = dict()

In [None]:
def _load_best_model(target, source, x_dim, cond_dims):
    """Rebuild a trained model from its `best_by_mse.pth` checkpoint and put it in eval mode.

    target: modality the model generates
    source: suffix of the checkpoint folder name, i.e. 'multi' or the name of the
        single conditioning modality (folder is `{target}_from_{source}`)
    x_dim: number of feature columns of the target modality
    cond_dims: conditioning size(s), forwarded unchanged to `get_diffusion_model`

    Returns the model moved to `device` with the checkpoint's 'best_model_mse' weights loaded.
    """
    # Use the configured results root instead of a second hard-coded copy of the path.
    ckpt_path = f"{results_path}/{dim}/{target}_from_{source}/train/best_by_mse.pth"
    # Deserialize on CPU; the model itself is moved to `device` below.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    cfg = ckpt['config']
    model = get_diffusion_model(
        cfg['architecture'], diffusion, SimpleNamespace(**cfg),
        x_dim=x_dim, cond_dims=cond_dims
    ).to(device)
    model.load_state_dict(ckpt['best_model_mse'])
    model.eval()
    return model


for target in modalities_map.keys():
    # Every modality except the target can act as conditioning.
    conds = [m for m in modalities_map.keys() if m != target]
    x_dim = modalities_map[target]['train'].shape[1]
    if mode == 'multi':
        # load the multi-conditioning model (trained on all other modalities jointly)
        cond_dims = [modalities_map[c]['train'].shape[1] for c in conds]
        models = [_load_best_model(target, 'multi', x_dim, cond_dims)]
    elif mode == 'coherent':
        # coherent: load one single-conditioning model per conditioning modality,
        # in the same order as `conds`. Each gets a scalar conditioning size.
        models = [
            _load_best_model(target, c, x_dim, modalities_map[c]['train'].shape[1])
            for c in conds
        ]
    else:
        # Previously any unknown mode silently fell into the coherent branch.
        raise ValueError(f"mode must be 'multi' or 'coherent', got {mode!r}")
    models_dict[target] = {'models': models, 'diffusion': diffusion, 'conds': conds}

In [None]:
# Sanity check: display the fields stored for one target.
# The 'cna' key is hard-coded, so this raises KeyError if 'cna' is not among the loaded targets.
models_dict['cna'].keys()

In [None]:
def generate_batch(target, cond_batches, present):
    """
    Generate the `target` modality for a batch of samples from their available modalities.

    target: modality name
    cond_batches: list of np.ndarray for present modalities, in the order of `present`
    present: tuple of present modality names

    Returns the sampler output moved to CPU and converted to np.ndarray.

    Raises ValueError if `cond_batches` and `present` differ in length, or if none of
    the target's conditioning modalities is present.
    """
    if len(cond_batches) != len(present):
        raise ValueError(
            f"cond_batches has {len(cond_batches)} arrays but present has {len(present)} names"
        )

    entry = models_dict[target]
    # Use the diffusion object stored with the models rather than the notebook global.
    diff = entry['diffusion']
    full_conds = entry['conds']  # list of all cond modality names
    num_features = modalities_map[target]['train'].shape[1]

    # name -> batch lookup replaces repeated `present.index(...)` scans
    batch_by_mod = dict(zip(present, cond_batches))
    available = [cm for cm in full_conds if cm in batch_by_mod]
    if not available:
        # Without this guard the code below fails with an opaque IndexError.
        raise ValueError(f"no conditioning modality available for target {target!r}")
    batch_size = len(batch_by_mod[available[0]])

    if mode == 'multi':
        # build full cond list in order entry['conds'], zero for missing
        cond_ts = []
        mask_rows = []
        for cm in full_conds:
            if cm in batch_by_mod:
                arr = batch_by_mod[cm]
                mask_rows.append(np.ones(arr.shape[0]))
            else:
                # zero batch standing in for the absent modality, flagged with a 0 mask
                arr = np.zeros((batch_size, modalities_map[cm]['train'].shape[1]), dtype=np.float32)
                mask_rows.append(np.zeros(batch_size))
            cond_ts.append(torch.tensor(arr, dtype=torch.float32, device=device))
        # mask_tensor shape: (num_conditions, batch_size)
        mask_tensor = torch.tensor(np.stack(mask_rows, axis=0), dtype=torch.float32, device=device)
        gen = sample(
            model=entry['models'][0],
            diffusion=diff,
            cond=cond_ts,
            num_features=num_features,
            mask=mask_tensor,
            device=device
        )
    else:
        # coherent: only use available modalities and corresponding models
        # (entry['models'] is paired positionally with entry['conds'])
        cond_ts = []
        models = []
        for mdl, cm in zip(entry['models'], full_conds):
            if cm in batch_by_mod:
                cond_ts.append(torch.tensor(batch_by_mod[cm], dtype=torch.float32, device=device))
                models.append(mdl)
        gen = coherent_sample(
            models=models,
            diffusion=diff,
            num_samples=cond_ts[0].shape[0],
            num_features=num_features,
            conds=cond_ts,
            device=device
        )
    return gen.cpu().numpy()

In [None]:
# Bucket sample ids by availability pattern: key = (tuple(present), tuple(missing)).
# Only samples with at least one missing modality and at least two present ones are kept.
groups = {}
sample_ids = modalities_map[next(iter(modalities_map))]['train'].index
for idx in sample_ids:
    # mask value of this sample in every modality (0 marks the modality as missing)
    mask = {}
    for m in modalities_map:
        mask[m] = modalities_map[m]['mask_train'].loc[idx]
    missing = tuple(sorted(m for m, v in mask.items() if v == 0))
    present = tuple(sorted(m for m in modalities_map if m not in missing))
    if missing and len(present) >= 2:
        groups.setdefault((present, missing), []).append(idx)

In [None]:
# Number of distinct (present, missing) availability patterns found.
len(groups)

In [None]:
# Impute every group that lacks exactly one modality, one batch per group.
imputed = {}
for (present, missing), idxs in groups.items():
    if len(missing) != 1:
        continue
    tgt = missing[0]
    # One stacked array per present modality; rows follow the order of `idxs`.
    cond_batches = []
    for c in present:
        rows = [modalities_map[c]['train'].loc[idx].values for idx in idxs]
        cond_batches.append(np.vstack(rows))
    # Generate the missing modality for the whole group at once.
    gen_batch = generate_batch(tgt, cond_batches, present)
    # Build one complete record per sample: observed values plus the generated target.
    for i, idx in enumerate(idxs):
        rec = {}
        for c in present:
            rec[c] = modalities_map[c]['train'].loc[idx].values
        rec[tgt] = gen_batch[i]
        imputed[idx] = rec

In [None]:
# Process groups with two (or more) missing modalities in batch.
# Each missing modality is generated independently from the observed modalities only;
# generated modalities are never fed back as conditioning for one another.
for (present, missing), idxs in groups.items():
    if len(missing) < 2:
        continue  # single-missing groups are handled by the one-missing cell
    # collect cond arrays once for the batch, ordered like `present`
    cond_list = [
        np.vstack([modalities_map[c]['train'].loc[idx].values for idx in idxs])
        for c in present
    ]
    # Generate in the (sorted) order of `missing`, matching the former a-then-b order.
    generated = {}
    for tgt in missing:
        generated[tgt] = generate_batch(tgt, cond_list, present)
    # combine into single imputed record per sample
    for i, idx in enumerate(idxs):
        rec = {c: modalities_map[c]['train'].loc[idx].values for c in present}
        for tgt in missing:
            rec[tgt] = generated[tgt][i]
        imputed[idx] = rec

In [None]:
# Append the imputed records to each modality's training table and mark them as available.
# NOTE(review): the imputed ids are taken from the first modality's existing 'train' index,
# so the concatenation below adds rows whose index labels already exist (duplicate labels)
# instead of replacing the original rows. Re-running this cell appends them again.
# Confirm that augmentation, rather than in-place filling, is intended.
for mod in modalities_map:
    tables = modalities_map[mod]
    orig = tables['train']
    new_data = {}
    for idx, rec in imputed.items():
        new_data[idx] = rec[mod]
    if len(new_data) == 0:
        continue
    df_new = pd.DataFrame.from_dict(new_data, orient='index', columns=orig.columns)
    tables['train'] = pd.concat([orig, df_new])
    # every appended row gets mask value 1
    mask_orig = tables['mask_train']
    mask_new = pd.Series(1, index=df_new.index)
    tables['mask_train'] = pd.concat([mask_orig, mask_new])

print(f"Batched imputation complete: {len(imputed)} records generated.")