In [None]:
import torch
import mlflow
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, random_split

from tqdm import tqdm
from mlflow.types import Schema, TensorSpec
from mlflow.models import ModelSignature

from src.sd_vae.ae import VAE
from src.trainers import EarlyStopping
from src.trainers.first_stage_trainer import CLEAR_VAEFirstStageTrainer
from sklearn.manifold import TSNE


from src.utils.exp_utils.train_utils import (
    load_cfg,
    build_first_stage_trainer,
    xavier_init,
)
from src.utils.exp_utils.visual import feature_swapping_plot
from src.utils.data_utils.camelyon import build_dataloader

# Planned comparisons: in-distribution swapping, and OOD x in-distribution swapping.


# TODO: write up the "in-the-middle" experiment protocol carefully.
# Run on the GPU when one is visible to torch, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
# Arguments for the project's loader factory, gathered in one place so they
# are easy to spot and tweak. Later cells index the result by the keys
# "train", "valid" and "test".
loader_kwargs = dict(
    data_root="/hpc/group/engelhardlab/ms1008/image_data",
    batch_size=64,
    download=False,
    num_workers=32,
)
dataloaders = build_dataloader(**loader_kwargs)

In [None]:
# The factory's "train" loader is used as a template (batch size, workers,
# collate function); its dataset is re-split below.
base_train_loader = dataloaders["train"]
valid_loader = dataloaders['valid']
test_loader = dataloaders['test']
collate_fn = getattr(base_train_loader, "collate_fn", None)

# Split the training set 60% / 40% into train / "cv" parts with a fixed seed
# so the partition is the same on every run.
train_dataset = base_train_loader.dataset
n_total = len(train_dataset)
n_train = int(0.6 * n_total)
n_cv = n_total - n_train
split_generator = torch.Generator().manual_seed(42)
train_subset, cv_subset = random_split(
    train_dataset, [n_train, n_cv], generator=split_generator
)


def _subset_loader(subset, shuffle):
    """Wrap `subset` in a DataLoader that mirrors the base train loader's settings."""
    return DataLoader(
        subset,
        batch_size=base_train_loader.batch_size,
        shuffle=shuffle,
        num_workers=base_train_loader.num_workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )


# Only the training part is shuffled; the cv part keeps a fixed order.
train_loader = _subset_loader(train_subset, shuffle=True)
cv_loader = _subset_loader(cv_subset, shuffle=False)

In [None]:
# Experiment configuration; later cells read cfg['vae'], cfg['trainer_param']
# and cfg['train'] from it. The bare `cfg` on the last line displays it.
cfg_path = './config/camelyon.yaml'
cfg = load_cfg(cfg_path)
cfg

In [None]:
# Describe the model's input/output tensors for MLflow using the shape of a
# real batch rather than a hard-coded one: the previous [-1, 1, 32, 32]
# disagreed with the 96-pixel images used by the plotting cells.
# Pulling one batch starts the loader's workers once, a one-off cost.
_, n_channels, img_height, img_width = next(iter(train_loader))["image"].shape
image_spec = TensorSpec(np.dtype(np.float32), [-1, n_channels, img_height, img_width])
input_schema = Schema([image_spec])
output_schema = Schema([image_spec])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# Build the VAE from the config and re-initialise its weights with the
# project's `xavier_init` hook.
vae = VAE(**cfg['vae']).to(device)

vae.apply(xavier_init)

trainer = CLEAR_VAEFirstStageTrainer(
    model=vae,
    early_stopping=EarlyStopping(patience=8),
    verbose_period=2,
    device=device,
    model_signature=signature,
    args=cfg["trainer_param"],
)


In [None]:
# Track the run in the local ./mlruns store under the "test-camelyon" experiment.
mlflow.set_tracking_uri("./mlruns")
mlflow.set_experiment("test-camelyon")

with mlflow.start_run() as run:
    # Log model and trainer hyper-parameters together; on a key clash the
    # trainer value wins (right-hand side of the dict union).
    run_params = cfg['vae'] | cfg['trainer_param']
    mlflow.log_params(run_params)
    trainer.fit(
        epochs=cfg['train']['epochs'],
        train_loader=train_loader,
        valid_loader=valid_loader,
    )

In [None]:
# Keep the MLflow run id: later cells load the checkpoint from
# `runs:/{run_id}/best_model`.
run_info = run.info
run_id = run_info.run_id
print(run_id)
# To analyse an earlier run instead, override the id by hand, e.g.:
# run_id = '8a30d6b78488426bab8b7b09b014b80c'

In [None]:
# Encode one training batch with the best checkpoint of the run and split the
# posterior mean channel-wise into the z_c / z_s parts.
x = next(iter(train_loader))["image"].to(device)
best_model = mlflow.pytorch.load_model(f"runs:/{run_id}/best_model")
best_model.eval()
with torch.no_grad():
    _, posterior = best_model(x)
z_c, z_s = posterior.mu.split_with_sizes(
    cfg["trainer_param"]["channel_split"], dim=1
)

# Pick 10 random samples (with replacement) for the swapping plot. The upper
# bound follows the actual batch size; the previous hard-coded 32 raised
# IndexError for smaller batches and ignored half of a 64-sample batch.
n_select = 10
select = torch.randint(0, x.size(0), (n_select,)).tolist()

In [None]:
# Sanity check: display the shape of the z_c part of the posterior mean.
z_c.size()

In [None]:
# Bookkeeping note kept from the original cell (presumably hyper-parameters
# followed by an MLflow run id; the meaning is not recorded here):
# 6000 1500 0.05 4 8 8a30d6b78488426bab8b7b09b014b80c

# Gather the randomly selected samples, then draw the swapping grid.
zc_selected = z_c[select]
zs_selected = z_s[select]
x_selected = x[select]

feature_swapping_plot(
    zc_selected,
    zs_selected,
    x_selected,
    best_model,
    img_size=96,
)

In [None]:
from torchvision.utils import make_grid
from src.utils.exp_utils.visual import make_colored_grid

def feature_swapping_plot_rows_cols(
    z_c_rows,
    z_s_cols,
    x_rows,
    x_cols,
    vae: torch.nn.Module,
    img_size=32,
    out_dir=None,
    run_id=None,
):
    """Plot decodings that pair every row's z_c code with every column's z_s code.

    Cell (i, j) of the main grid is ``vae.decoder`` applied to the channel-wise
    concatenation of ``z_c_rows[i]`` and ``z_s_cols[j]``. The source images are
    drawn as a blue-framed bar on the left (``x_rows``) and a red-framed bar on
    top (``x_cols``).

    Args:
        z_c_rows: z_c latents for the rows; indexed as a 4-D tensor (N_row, C, H, W).
        z_s_cols: z_s latents for the columns; indexed as a 4-D tensor (N_col, C, H, W).
        x_rows: Images shown in the left bar, one per row.
        x_cols: Images shown in the top bar, one per column.
        vae: Model exposing a ``decoder`` callable that maps latents to images.
        img_size: Unused; kept so existing callers keep working.
        out_dir: If given, the figure is written to
            ``<out_dir>[/<run_id>]/swap_rows_cv_cols_train.png`` and closed.
            If ``None``, the figure is shown instead.
        run_id: Optional sub-directory of ``out_dir``. ``None`` saves directly
            into ``out_dir``.
    """
    with torch.no_grad():
        n_row = z_c_rows.size(0)
        n_col = z_s_cols.size(0)
        device = z_c_rows.device

        # Build all (row, col) latent pairs: broadcast both parts to
        # (n_row, n_col, C, H, W), join along the channel axis, then merge the
        # two leading axes so the decoder sees a plain batch in row-major order.
        paired_z = torch.cat(
            (
                z_c_rows[:, None, :, :, :].repeat(1, n_col, 1, 1, 1),
                z_s_cols[None, :, :, :, :].repeat(n_row, 1, 1, 1, 1),
            ),
            dim=2,
        ).flatten(start_dim=0, end_dim=1)

        # NOTE(review): 0.18215 looks like the Stable Diffusion latent scaling
        # constant; confirm the decoder expects latents scaled in this direction.
        paired_z = paired_z * 0.18215
        x_inter = vae.decoder(paired_z)

        # nrow is the number of images per grid row, i.e. the number of columns.
        maingrid = make_grid(x_inter, nrow=n_col, padding=2)

        top_bar = make_colored_grid(x_cols, nrow=n_col, color="red")
        left_bar = make_colored_grid(x_rows, nrow=1, color="blue")

        # White filler for the top-left corner so the bars line up with the grid.
        corner_h = top_bar.size(1)
        corner_w = left_bar.size(2)
        empty_corner = torch.ones(3, corner_h, corner_w, device=device)

        top_row = torch.cat([empty_corner, top_bar], dim=2)
        bottom_row = torch.cat([left_bar, maingrid], dim=2)
        final_grid = torch.cat([top_row, bottom_row], dim=1)

        fig = plt.figure(figsize=(12, 12))
        plt.imshow(final_grid.permute(1, 2, 0).cpu().numpy())
        plt.axis("off")

        if out_dir is not None:
            # Imported here because `os` is not imported elsewhere in the
            # notebook; without it this branch raised NameError.
            import os

            # run_id is optional: os.path.join(out_dir, None) would raise TypeError.
            save_dir = out_dir if run_id is None else os.path.join(out_dir, run_id)
            os.makedirs(save_dir, exist_ok=True)
            filepath = os.path.join(save_dir, "swap_rows_cv_cols_train.png")
            plt.savefig(filepath, bbox_inches="tight", pad_inches=0)
            plt.close(fig)
        else:
            plt.show()


In [None]:
# One batch from the test loader (rows) and one from the held-out cv loader (columns).
batch_test = next(iter(test_loader))
batch_cv = next(iter(cv_loader))
x_test = batch_test["image"].to(device)
x_cv = batch_cv["image"].to(device)

best_model = mlflow.pytorch.load_model(f"runs:/{run_id}/best_model")

# Encode both batches and split each posterior mean into its z_c / z_s parts.
split_sizes = cfg["trainer_param"]["channel_split"]
with torch.no_grad():
    best_model.eval()

    _, posterior_test = best_model(x_test)
    zc_test, zs_test = posterior_test.mu.split_with_sizes(split_sizes, dim=1)

    _, posterior_cv = best_model(x_cv)
    zc_cv, zs_cv = posterior_cv.mu.split_with_sizes(split_sizes, dim=1)

# Grid size: rows take z_c from test samples, columns take z_s from cv samples.
n_row, n_col = 20, 20

# Indices are drawn with replacement, so a sample may appear more than once.
row_idx = torch.randint(0, x_test.size(0), (n_row,), device=device)
col_idx = torch.randint(0, x_cv.size(0), (n_col,), device=device)

swap_inputs = dict(
    z_c_rows=zc_test[row_idx],
    z_s_cols=zs_cv[col_idx],
    x_rows=x_test[row_idx],
    x_cols=x_cv[col_idx],
)
feature_swapping_plot_rows_cols(
    **swap_inputs,
    vae=best_model,
    img_size=96,
    out_dir=None,
    run_id=run_id,
)

In [None]:
import torch
import torch.nn as nn
import numpy as np
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from src.modules.distribution import IsotropicNormalDistribution 

class DownstreamMLPTrainer:
    """Train and evaluate a classifier head on features from a VAE encoder.

    The VAE is used only for feature extraction: its encoder output is turned
    into a flattened latent vector, and during training this happens under
    ``torch.no_grad()``. Only ``model`` is updated by ``optimizer``.
    """

    # Keys tried, in order, when a batch is a dict.
    _INPUT_KEYS = ("x", "image", "data")
    _TARGET_KEYS = ("y", "label", "target")
    _GROUP_KEYS = ("c", "group")

    def __init__(
        self,
        vae: nn.Module,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: nn.Module,
        verbose_period: int,
        device: torch.device,
        transform=None,
    ) -> None:
        """Store the collaborators; nothing is moved to ``device`` here.

        Args:
            vae: Module exposing an ``encoder`` callable.
            model: Classifier applied to the flattened VAE features.
            optimizer: Optimizer used to update ``model``.
            criterion: Loss called as ``criterion(logits, targets)``.
            verbose_period: Stored; not read by the methods of this class.
            device: Device that inputs and targets are moved to.
            transform: Optional callable applied to inputs during training only.
        """
        self.vae = vae
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.verbose_period = verbose_period
        self.device = device
        self.transform = transform

    @staticmethod
    def _first_present(batch, keys, fallback_index):
        """Return ``batch[key]`` for the first key present, else the value at ``fallback_index``."""
        for key in keys:
            if key in batch:
                return batch[key]
        return list(batch.values())[fallback_index]

    def _unpack_batch(self, batch):
        """Return ``(inputs, targets)`` from a dict, list or tuple batch.

        Dict batches are probed for the known input/target keys and fall back
        to the first/second value. Sequences use positions 0 and 1.

        Raises:
            TypeError: If ``batch`` is neither a dict nor a list/tuple.
        """
        if isinstance(batch, dict):
            x = self._first_present(batch, self._INPUT_KEYS, 0)
            y = self._first_present(batch, self._TARGET_KEYS, 1)
            return x, y
        if isinstance(batch, (list, tuple)):
            return batch[0], batch[1]
        raise TypeError(f"Unsupported batch type: {type(batch)}")

    def _extract_groups(self, batch):
        """Return the per-sample group ids of ``batch``, or ``None`` if it carries none."""
        if isinstance(batch, dict):
            for key in self._GROUP_KEYS:
                if key in batch:
                    return batch[key]
            return None
        if isinstance(batch, (list, tuple)) and len(batch) > 2:
            return batch[2]
        return None

    def _get_vae_feature(self, x):
        """Encode ``x`` and return a flattened ``(batch, features)`` tensor.

        Uses the posterior's ``mean`` attribute when present, then its
        ``mode()`` method, and otherwise the first half of the encoder's
        output channels.
        """
        moments = self.vae.encoder(x)

        posterior = IsotropicNormalDistribution(moments)

        if hasattr(posterior, 'mean'):
            z = posterior.mean
        elif hasattr(posterior, 'mode'):
            z = posterior.mode()
        else:
            c = moments.shape[1] // 2
            z = moments[:, :c, :, :]

        return z.reshape(z.shape[0], -1)

    def fit(self, epochs: int, train_loader: DataLoader, valid_loader: DataLoader = None):
        """Run ``epochs`` training epochs, validating after each when a loader is given."""
        for epoch in range(epochs):
            self._train(train_loader, verbose=True, epoch_id=epoch)
            # Compare against None explicitly: truthiness of a DataLoader goes
            # through __len__, which skips empty loaders and raises for
            # datasets that have no length.
            if valid_loader is not None:
                self._valid(valid_loader, verbose=True, epoch_id=epoch)

    def _train(self, dataloader: DataLoader, verbose: bool, epoch_id: int):
        """Train ``self.model`` for one pass over ``dataloader``."""
        self.model.train()
        # The VAE is only a feature extractor, so keep it in inference mode.
        self.vae.eval()

        with tqdm(dataloader, unit="batch", disable=not verbose) as bar:
            bar.set_description(f"epoch {epoch_id}")
            for batch in bar:
                X_batch, y_batch = self._unpack_batch(batch)

                # The criterion receives a flat vector of integer targets.
                y_batch = y_batch.reshape(-1).long()
                X_batch, y_batch = X_batch.to(self.device), y_batch.to(self.device)

                # NOTE(review): the transform is applied here but not in
                # `evaluate`; confirm it is an augmentation, not preprocessing.
                if self.transform:
                    X_batch = self.transform(X_batch)

                self.optimizer.zero_grad()

                # No gradients flow into the VAE.
                with torch.no_grad():
                    z_feature = self._get_vae_feature(X_batch)

                logits = self.model(z_feature)

                loss = self.criterion(logits, y_batch)
                loss.backward()
                self.optimizer.step()
                bar.set_postfix(loss=loss.item())

    def _valid(self, dataloader: DataLoader, verbose: bool, epoch_id: int):
        """Evaluate on ``dataloader`` and print the accuracy; a no-op when not verbose."""
        if verbose:
            (aupr_scores, auroc_scores), acc = self.evaluate(
                dataloader, verbose, epoch_id
            )
            print(f"val_acc: {acc:.3f}")

    def evaluate(self, dataloader: DataLoader, verbose: bool, epoch_id: int):
        """Compute accuracy and per-group AUPR / AUROC on ``dataloader``.

        Group ids are read from the batch (keys ``"c"`` / ``"group"``, or the
        third element of a sequence batch). If no batch carries any, all
        samples form a single group. AUROC and AUPR use column 1 of the softmax
        output as the positive-class score. A group for which they cannot be
        computed gets AUROC 0.5 and AUPR 0.0.

        Returns:
            ``((aupr_scores, auroc_scores), acc)`` where both score objects are
            dicts keyed by group name and ``acc`` is the overall accuracy.
        """
        import warnings

        self.model.eval()
        self.vae.eval()
        all_y = []
        all_probs = []
        groups = []

        with torch.no_grad():
            iterator = tqdm(dataloader, disable=not verbose, desc=f"val-epoch {epoch_id}")
            for batch in iterator:
                X_batch, y_batch = self._unpack_batch(batch)

                g_batch = self._extract_groups(batch)
                if g_batch is not None:
                    if torch.is_tensor(g_batch):
                        g_batch = g_batch.cpu().numpy()
                    # Flatten so the boolean masks below line up with all_y.
                    groups.append(np.asarray(g_batch).reshape(-1))

                y_batch = y_batch.reshape(-1)
                X_batch = X_batch.to(self.device)

                z_feature = self._get_vae_feature(X_batch)

                logits = self.model(z_feature)
                probs = torch.softmax(logits, dim=1)

                all_y.append(y_batch.cpu())
                all_probs.append(probs.cpu())

        all_y = torch.cat(all_y).numpy()
        all_probs = torch.cat(all_probs).numpy()

        if len(groups) > 0:
            groups = np.concatenate(groups)
        else:
            # No group information: treat everything as one group (id 0).
            groups = np.zeros_like(all_y)

        acc = accuracy_score(all_y, np.argmax(all_probs, axis=1))
        aupr_scores = {}
        auroc_scores = {}

        for g in np.unique(groups):
            mask = (groups == g)
            y_sub = all_y[mask]
            prob_sub = all_probs[mask]

            # Chance-level defaults for groups where the metrics are undefined.
            auroc, aupr = 0.5, 0.0
            if len(np.unique(y_sub)) > 1 and prob_sub.shape[1] >= 2:
                try:
                    auroc = roc_auc_score(y_sub, prob_sub[:, 1])
                    aupr = average_precision_score(y_sub, prob_sub[:, 1])
                except ValueError as err:
                    # Catch only the metric's own error rather than everything,
                    # and make the fallback visible.
                    auroc, aupr = 0.5, 0.0
                    warnings.warn(
                        f"AUROC/AUPR could not be computed for group {g!r}: {err}; "
                        "falling back to 0.5 / 0.0"
                    )

            k = f"group_{g}" if isinstance(g, (int, np.integer)) else str(g)
            auroc_scores[k] = float(auroc)
            aupr_scores[k] = float(aupr)

        return (aupr_scores, auroc_scores), acc


In [None]:
import torch
import torch.nn as nn
import numpy as np
import mlflow.pytorch
from src.modules.distribution import IsotropicNormalDistribution

def evaluate_loaded_vae(
    best_model,
    train_loader,
    valid_loader,
    test_loader,
    device,
    n_class=2,
    epochs=1,
    lr=3e-4,
    hidden_dim=256,
):
    """Train an MLP probe on frozen VAE features and report test metrics.

    The VAE is moved to ``device``, put in eval mode and frozen in place (its
    parameters get ``requires_grad = False``). An MLP is then trained on its
    flattened latent features with ``DownstreamMLPTrainer``.

    Args:
        best_model: VAE exposing an ``encoder`` callable.
        train_loader: Loader used to train the MLP and to size its input.
        valid_loader: Loader evaluated after every epoch.
        test_loader: Loader the returned metrics are computed on.
        device: Device for the VAE, the MLP and the data.
        n_class: Number of output classes of the MLP.
        epochs: Number of training epochs for the MLP.
        lr: Adam learning rate for the MLP.
        hidden_dim: Width of the MLP's hidden layer.

    Returns:
        Dict with ``"acc"``, and ``"pr"`` / ``"roc"`` entries that each hold the
        mean over groups (``"overall"``) and the per-group dict
        (``"stratified"``). Rounded values are plain Python floats.
    """
    best_model.to(device)
    best_model.eval()

    for p in best_model.parameters():
        p.requires_grad = False
    print("Calculating MLP input dimension...")

    try:
        batch = next(iter(train_loader))

        if isinstance(batch, dict):
            if "x" in batch: x_dummy = batch["x"]
            elif "image" in batch: x_dummy = batch["image"]
            elif "data" in batch: x_dummy = batch["data"]
            else: x_dummy = list(batch.values())[0]
        else:
            x_dummy = batch[0]

        # A single sample is enough to measure the feature size.
        x_dummy = x_dummy[0:1].to(device)

    except Exception as e:
        # Deliberate best-effort: fall back to a synthetic input if the loader
        # cannot provide one.
        print(f"Error getting dummy input: {e}. Fallback to random input (3x96x96).")
        x_dummy = torch.randn(1, 3, 96, 96).to(device)

    # Measure the input width with the trainer's own feature extractor, so the
    # first Linear layer always matches what the trainer feeds it. Deriving it
    # from the raw encoder output could disagree with that feature. The probe
    # is a throwaway instance; only its `vae` attribute is read here.
    feature_probe = DownstreamMLPTrainer(best_model, None, None, None, 1, device)
    with torch.no_grad():
        moments = best_model.encoder(x_dummy)
        flatten_dim = feature_probe._get_vae_feature(x_dummy).shape[1]

    print(f"Detected encoder output shape: {moments.shape}")
    print(f"MLP input dimension (flattened): {flatten_dim}")

    mlp = nn.Sequential(
        nn.Linear(flatten_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, n_class),
    ).to(device)

    optimizer = torch.optim.Adam(mlp.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    trainer = DownstreamMLPTrainer(
        best_model, mlp, optimizer, criterion, 1, device
    )

    print("Training downstream MLP classifier...")
    trainer.fit(epochs, train_loader, valid_loader)

    print("Evaluating on test set...")
    (aupr_scores, auroc_scores), acc = trainer.evaluate(test_loader, False, 0)

    # Cast NumPy scalars to float so the dict prints and serialises cleanly.
    results = {
        "acc": round(float(acc), 3),
        "pr": {
            "overall": round(float(np.mean(list(aupr_scores.values()))), 3),
            "stratified": aupr_scores,
        },
        "roc": {
            "overall": round(float(np.mean(list(auroc_scores.values()))), 3),
            "stratified": auroc_scores,
        },
    }

    return results

# Probe the frozen VAE with a 2-class MLP and print the resulting test metrics.
metrics = evaluate_loaded_vae(
    best_model=best_model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    test_loader=test_loader,
    device=device,
    n_class=2,
)
print(metrics)


In [None]:
from torchvision.utils import make_grid
# Show every latent channel of the posterior mean as its own image grid
# (16 samples per grid row).
mu = posterior.mu
print(mu.shape)
n_latent_channels = mu.shape[1]
for channel in range(n_latent_channels):
    # Re-insert a singleton channel axis so make_grid sees (B, 1, H, W).
    channel_maps = mu[:, channel].unsqueeze(1)
    channel_grid = make_grid(channel_maps, nrow=16)
    plt.imshow(channel_grid.cpu().permute(1, 2, 0))
    plt.show()

In [None]:
# Encode the whole test set and collect the z_c / z_s latents together with
# the label and style annotations of every sample.
z_cs = []
z_ss = []
labels = []
styles = []
# Take the split sizes from the config, as the other cells do, instead of a
# second hard-coded copy that can drift from it.
channel_split = cfg["trainer_param"]["channel_split"]
with torch.no_grad():
    best_model.eval()
    for batch in tqdm(test_loader):
        # Loop-local names: the original reused `x`, `posterior`, `z_c` and
        # `z_s`, overwriting the globals that earlier plotting cells read.
        x_batch = batch['image'].to(device)
        _, posterior_batch = best_model(x_batch)
        # NOTE(review): this draws a sample, whereas the swapping cells use
        # the posterior mean (`.mu`); confirm that sampling is intended.
        zc_batch, zs_batch = posterior_batch.sample().split_with_sizes(channel_split, dim=1)
        z_cs.append(zc_batch.cpu())
        z_ss.append(zs_batch.cpu())
        labels.append(batch['label'])
        styles.append(batch['style'])

z_cs = torch.cat(z_cs, dim=0)
z_ss = torch.cat(z_ss, dim=0)
labels = torch.cat(labels, dim=0)
styles = torch.cat(styles, dim=0)

In [None]:
# t-SNE of the flattened z_c latents collected from the test loader.
X = z_cs.view(z_cs.shape[0], -1).cpu().numpy()
y_content = labels.cpu().numpy()
y_style = styles.cpu().numpy()

# Subsample to keep t-SNE tractable. The generator is seeded so the subset is
# as reproducible as the embedding itself (TSNE uses random_state=0).
N = X.shape[0]
N_sub = 10000
idx = np.random.default_rng(0).choice(N, size=min(N_sub, N), replace=False)

X_sub = X[idx]
content_sub = y_content[idx]
style_sub = y_style[idx]

tsne = TSNE(
    n_components=2,
    init='pca',
    perplexity=100,
    learning_rate=800,
    max_iter=8000,
    early_exaggeration=40,
    random_state=0,
)
z_2d = tsne.fit_transform(X_sub)

# Same embedding, coloured once by label and once by style. The second panel
# was previously created but left empty, and `style_sub` was never used.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
panels = (
    (axs[0], content_sub, 'z_c t-SNE: color by content label'),
    (axs[1], style_sub, 'z_c t-SNE: color by style'),
)
for ax, point_colors, panel_title in panels:
    points = ax.scatter(
        z_2d[:, 0],
        z_2d[:, 1],
        c=point_colors,
        cmap='tab10',
        alpha=0.4,
        s=5,
    )
    fig.colorbar(points, ax=ax)
    ax.set_title(panel_title)
plt.tight_layout()
plt.show()

In [None]:
# t-SNE of the flattened z_s latents collected from the test loader.
X = z_ss.view(z_ss.shape[0], -1).cpu().numpy()
y_content = labels.cpu().numpy()
y_style = styles.cpu().numpy()

# Subsample to keep t-SNE tractable. The generator is seeded so the subset is
# as reproducible as the embedding itself (TSNE uses random_state=0).
N = X.shape[0]
N_sub = 10000
idx = np.random.default_rng(0).choice(N, size=min(N_sub, N), replace=False)

X_sub = X[idx]
content_sub = y_content[idx]
style_sub = y_style[idx]

tsne = TSNE(
    n_components=2,
    init='pca',
    perplexity=100,
    learning_rate=800,
    max_iter=8000,
    early_exaggeration=40,
    random_state=0,
)
z_2d = tsne.fit_transform(X_sub)

# Same embedding, coloured once by label and once by style. The second panel
# was previously created but left empty, and `style_sub` was never used.
fig, axs = plt.subplots(1, 2, figsize=(10, 4))
panels = (
    (axs[0], content_sub, 'z_s t-SNE: color by content label'),
    (axs[1], style_sub, 'z_s t-SNE: color by style'),
)
for ax, point_colors, panel_title in panels:
    points = ax.scatter(
        z_2d[:, 0],
        z_2d[:, 1],
        c=point_colors,
        cmap='tab10',
        alpha=0.4,
        s=5,
    )
    fig.colorbar(points, ax=ax)
    ax.set_title(panel_title)

plt.tight_layout()
plt.show()