In [None]:
import os

import torch
import mlflow
import numpy as np

from mlflow.types import Schema, TensorSpec
from mlflow.models import ModelSignature

from src.sd_vae.ae import VAE
from src.trainers import EarlyStopping
from src.trainers.first_stage_trainer import CLEAR_VAEFirstStageTrainer


from src.utils.exp_utils.train_utils import (
    load_cfg,
    xavier_init,
)
from src.utils.exp_utils.visual import feature_swapping_plot
from src.utils.data_utils.camelyon import build_dataloader

device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG so weight init, loader shuffling and the random
# sample selection further down are reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

In [None]:
# Data location defaults to the original hard-coded path; set the
# CAMELYON_DATA_ROOT environment variable to run on another machine.
DATA_ROOT = os.environ.get(
    "CAMELYON_DATA_ROOT", "/hpc/group/engelhardlab/ms1008/image_data"
)
BATCH_SIZE = 32
NUM_WORKERS = 10

dataloaders = build_dataloader(
    data_root=DATA_ROOT,
    batch_size=BATCH_SIZE,
    download=False,  # data is expected to already exist under DATA_ROOT
    num_workers=NUM_WORKERS,
)

In [None]:
# Unpack the two splits used for training and validation.
train_loader, valid_loader = dataloaders["train"], dataloaders["valid"]

In [None]:
CFG_PATH = './config/camelyon.yaml'  # experiment configuration (vae + trainer_param sections)
cfg = load_cfg(CFG_PATH)

In [None]:
# Display the loaded configuration as this cell's output.
cfg

In [None]:
# Build the MLflow model signature from a real batch instead of hard-coding
# it: the previous [-1, 1, 32, 32] spec disagreed with the 96-pixel images
# plotted at the end of this notebook.
# NOTE(review): assumes the model's reconstruction has the same shape as its
# input and that images are float32 — confirm against VAE.forward / the loader.
sample_images = next(iter(train_loader))["image"]
image_shape = (-1, *sample_images.shape[1:])
input_schema = Schema([TensorSpec(np.dtype(np.float32), image_shape)])
output_schema = Schema([TensorSpec(np.dtype(np.float32), image_shape)])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

vae = VAE(**cfg['vae']).to(device)

vae.apply(xavier_init)

trainer = CLEAR_VAEFirstStageTrainer(
    model=vae,
    early_stopping=EarlyStopping(patience=8),
    verbose_period=2,
    device=device,
    model_signature=signature,
    args=cfg["trainer_param"],
)

In [None]:
# Track the run in a local ./mlruns store, log every model and trainer
# hyper-parameter, then train.
mlflow.set_tracking_uri("./mlruns")
mlflow.set_experiment("test-camelyon")
with mlflow.start_run() as run:
    # dict union: trainer_param values win on any key clash with vae
    hyper_params = cfg['vae'] | cfg['trainer_param']
    mlflow.log_params(hyper_params)
    trainer.fit(
        epochs=1,
        train_loader=train_loader,
        valid_loader=valid_loader,
    )

In [None]:
# Run whose best checkpoint is analysed below. This pins a previously trained
# run; set RUN_ID = None to analyse the run trained in the cell above instead.
RUN_ID = 'dd7f3cd2c6e54d9d9485b7eae144ac8a'
run_id = RUN_ID if RUN_ID is not None else run.info.run_id
run_id

In [None]:
# Encode one training batch with the run's best checkpoint and split the
# posterior mean along dim=1 into the two latent groups (z_c, z_s).
x = next(iter(train_loader))["image"].to(device)
# map_location is forwarded to torch.load so a GPU-saved model also loads
# on a CPU-only machine (and lands on the same device as x).
best_model = mlflow.pytorch.load_model(
    f"runs:/{run_id}/best_model", map_location=device
)
best_model.eval()
with torch.no_grad():
    _, posterior = best_model(x)
z_c, z_s = posterior.mu.split_with_sizes(
    cfg["trainer_param"]["channel_split"], dim=1
)

# Pick distinct samples from the batch actually drawn: randint(0, 32) both
# assumed a full 32-image batch and could repeat indices.
N_SWAP = 5
select = torch.randperm(x.shape[0])[:N_SWAP].tolist()

In [None]:
# Inspect the shape of z_c, the first chunk of the posterior mean split along dim=1.
z_c.shape

In [None]:
# Swap latent groups between the selected samples and plot the decoded results.
IMG_SIZE = 96  # NOTE(review): should match the image side length, i.e. x.shape[-1] — confirm
z_c_sel, z_s_sel, x_sel = z_c[select], z_s[select], x[select]
feature_swapping_plot(z_c_sel, z_s_sel, x_sel, best_model, img_size=IMG_SIZE)