In [1]:
import numpy as np
import torch
from einops import rearrange
from sklearn.decomposition import PCA
from tqdm.auto import tqdm

from diffusion_3d.chestct.autoencoder.vae.config import get_config
from diffusion_3d.chestct.autoencoder.vae.model import VAELightning
from diffusion_3d.datasets.ct_rate import CTRATEDataModule
from diffusion_3d.utils.visualize import plot_scans

In [2]:
config = get_config()
config

[1;35mMunch[0m[1m([0m[1m{[0m[32m'data'[0m: [1;35mMunch[0m[1m([0m[1m{[0m[32m'csvpath'[0m: [32m'/raid3/arjun/ct_pretraining/csvs/sources.csv'[0m, [32m'datapath'[0m: [32m'/raid3/arjun/ct_pretraining/scans/'[0m, [32m'checkpointspath'[0m: [32m'/raid3/arjun/checkpoints/adaptive_autoencoder/'[0m, [32m'limited_dataset_size'[0m: [3;35mNone[0m, [32m'allowed_spacings'[0m: [1m([0m[1m([0m[1;36m0.4[0m, [1;36m7[0m[1m)[0m, [1m([0m[1;36m-1[0m, [1;36m-1[0m[1m)[0m, [1m([0m[1;36m-1[0m, [1;36m-1[0m[1m)[0m[1m)[0m, [32m'allowed_shapes'[0m: [1m([0m[1m([0m[1;36m64[0m, [1;36m-1[0m[1m)[0m, [1m([0m[1;36m256[0m, [1;36m-1[0m[1m)[0m, [1m([0m[1;36m256[0m, [1;36m-1[0m[1m)[0m[1m)[0m, [32m'train_augmentations'[0m: [1;35mMunch[0m[1m([0m[1m{[0m[32m'_target_'[0m: [32m'monai.transforms.Compose'[0m, [32m'transforms'[0m: [1m[[0m[1;35mMunch[0m[1m([0m[1m{[0m[32m'_target_'[0m: [32m'vision_architectures.transforms.cr

In [4]:
checkpoint_path = r"/raid3/arjun/checkpoints/adaptive_autoencoder/v42__2025_03_24/version_0/checkpoints/last.ckpt"

model = VAELightning.load_from_checkpoint(
    checkpoint_path,
    map_location="cpu",
    model_config=config.model,
).autoencoder
model.eval()


[1;35mVAE[0m[1m([0m
  [1m([0mencoder[1m)[0m: [1;35mSwinV23DModel[0m[1m([0m
    [1m([0mpatchify[1m)[0m: [1;35mPatchEmbeddings3D[0m[1m([0m
      [1m([0mpatch_embeddings[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1[0m, [1;36m24[0m, [33mkernel_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m[1m)[0m
      [1m([0mnormalization[1m)[0m: [1;35mLayerNorm[0m[1m([0m[1m([0m[1;36m24[0m,[1m)[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33melementwise_affine[0m=[3;92mTrue[0m[1m)[0m
    [1m)[0m
    [1m([0mabsolute_position_embeddings[1m)[0m: [1;35mAbsolutePositionEmbeddings3D[0m[1m([0m[1m)[0m
    [1m([0mencoder[1m)[0m: [1;35mSwinV23DEncoder[0m[1m([0m
      [1m([0mstages[1m)[0m: [1;35mModuleList[0m[1m([0m
        [1m([0m[1;36m0[0m[1m)[0m: [1;35mSwinV23DStage[0m[1m([0m
          [1m([0mblocks[1m)[0m: [1;35mModuleList[

In [5]:
datamodule = CTRATEDataModule(config.data)
dataloader = datamodule.val_dataloader()
len(dataloader)

valid:   0%|          | 0/500 [00:00<?, ?it/s]

No. of valid datapoints: 500


[1;36m500[0m

In [6]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        if i >= 1:
            break

        x = batch["image"][:4]
        crop_offsets = batch['crop_offset'][:4]

        reconstructed = model(x, crop_offsets, "valid")["reconstructed"].cpu()
        print(f"Batch idx: {i}")
        print(reconstructed.min(), reconstructed.max())
        print(x.shape, reconstructed.shape)
        print((x.min(), x.max()), (reconstructed.min(), reconstructed.max()))

        x[:, :, :, 0, 0] = -1.0
        x[:, :, :, -1, -1] = 1.0
        reconstructed[:, :, :, 0, 0] = -1.0
        reconstructed[:, :, :, -1, -1] = 1.0

        for i in range(x.shape[0]):
            plot_scans([x[i][0], reconstructed[i][0]], ["Original", "Reconstructed"])

torch.cuda.empty_cache()

Batch idx: 0
metatensor(-0.7735) metatensor(0.3016)
torch.Size([4, 1, 64, 64, 64]) torch.Size([4, 1, 64, 64, 64])
(metatensor(-1.), metatensor(0.7240)) (metatensor(-0.7735), metatensor(0.3016))


interactive(children=(IntSlider(value=0, description='z', max=63), Output()), _dom_classes=('widget-interact',…

interactive(children=(IntSlider(value=0, description='z', max=63), Output()), _dom_classes=('widget-interact',…

interactive(children=(IntSlider(value=0, description='z', max=63), Output()), _dom_classes=('widget-interact',…

interactive(children=(IntSlider(value=0, description='z', max=63), Output()), _dom_classes=('widget-interact',…

# Whole CT

In [None]:
raise Exception("Don't want to automatically run beyond this point")

In [None]:
dataloader = datamodule.test_dataloader()
len(dataloader)

test:   0%|          | 0/500 [00:00<?, ?it/s]

No. of test datapoints: 500


[1;36m500[0m

In [None]:
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        if i < 1:
            continue

        if i >= 2:
            break

        x = batch["image"]
        print(x.shape)
        x = x[:, :, :, 128:384, 128:384]
        print(x.shape)

        model.cuda()
        x = x.float()

        reconstructed = model(x.cuda(), "valid")["reconstructed"].cpu()
        print(f"Batch idx: {i}")
        print(reconstructed.min(), reconstructed.max())
        print(x.shape, reconstructed.shape)
        print((x.min(), x.max()), (reconstructed.min(), reconstructed.max()))

        reconstructed[:, :, :, 0, 0] = -1.0
        reconstructed[:, :, :, -1, -1] = 1.0

        for i in range(x.shape[0]):
            plot_scans([x[i][0], reconstructed[i][0]], ["Original", "Reconstructed"])

torch.cuda.empty_cache()

torch.Size([1, 1, 512, 512, 512])
torch.Size([1, 1, 512, 256, 256])
Batch idx: 1
metatensor(-0.9945) metatensor(0.5118)
torch.Size([1, 1, 512, 256, 256]) torch.Size([1, 1, 512, 256, 256])
(metatensor(-1.), metatensor(1.)) (metatensor(-0.9945), metatensor(0.5118))


interactive(children=(IntSlider(value=0, description='z', max=511), Output()), _dom_classes=('widget-interact'…

# PCA

In [None]:
raise Exception("Don't want to automatically run beyond this point")

In [None]:
import gc
from collections import defaultdict

device = torch.device("cuda:0")

latent_vectors = []
stage_output_vectors = defaultdict(list)

model.to(device)

with torch.no_grad():
    for batch in tqdm(dataloader):
        x = batch["image"].to(device)
        mu, _, stage_outputs = model.encode(x, return_stage_outputs=True)
        x.cpu()
        del x
        latent_vectors.append(mu.cpu().numpy())
        for i in range(len(stage_outputs)):
            stage_output_vectors[i].append(stage_outputs[i].cpu().numpy())

gc.collect()
torch.cuda.empty_cache()

latent_vectors = np.concatenate(latent_vectors, axis=0)

  0%|          | 0/500 [00:00<?, ?it/s]

In [None]:
pattern = "b c z y x -> (b z y x) c"
# pattern = "b c z y x -> b (c z y x)"

latent_vectors_rearranged = rearrange(latent_vectors, pattern)
for i in range(len(stage_output_vectors)):
    stage_output_vectors[i] = np.concatenate(stage_output_vectors[i], axis=0)
    stage_output_vectors[i] = rearrange(stage_output_vectors[i], pattern)

for name, fit_vectors in [(f"Stage {i+1}", stage_output_vectors[i]) for i in range(len(stage_output_vectors))] + [
    ("Sampled", latent_vectors_rearranged)
]:
    try:
        # Compute PCA
        pca = PCA()
        pca.fit(fit_vectors)

        # Calculate effective dimensionality
        explained_variance_ratio = pca.explained_variance_ratio_
        cumulative_variance = np.cumsum(explained_variance_ratio)
        effective_dim = np.argmax(cumulative_variance > 0.95) + 1

        print(
            f"Effective dim {name}: {effective_dim} / {fit_vectors.shape[1]}\t({int(effective_dim * 100 / fit_vectors.shape[1])}%)"
        )
    except Exception as e:
        print(f"Failed to compute PCA for {name}: {e}")

# Visualize patches

In [None]:
from matplotlib import pyplot as plt

In [None]:
patch_size = 4

x = batch["scan"]

plt.figure(figsize=(10, 10))
plt.imshow(x[0][0][128], cmap="gray")
plt.grid(color="r", linestyle="-", linewidth=0.3)
plt.gca().set_xticks(np.arange(0, x.shape[-1], patch_size))
plt.gca().set_yticks(np.arange(0, x.shape[-1], patch_size))
plt.show()

# Check latent space

In [None]:
raise Exception("Don't want to automatically run beyond this point")

In [None]:
dataloader_iter = iter(dataloader)
batch1 = next(dataloader_iter)
batch2 = next(dataloader_iter)
x1, x2 = batch1["scan"], batch2["scan"]
(batch1["uid"], batch2["uid"]), (x1.shape, x2.shape)

In [None]:
with torch.no_grad():
    output1 = model.process_step(x1, [], "valid", 0)
    output2 = model.process_step(x2, [], "valid", 0)

reconstructed1 = output1["reconstructed"]
reconstructed2 = output2["reconstructed"]

adapted_encoded1 = output1["adapted_encoded"]
adapted_encoded2 = output2["adapted_encoded"]

(reconstructed1.shape, reconstructed2.shape), (adapted_encoded1.shape, adapted_encoded2.shape)

In [None]:
adapted_encoded_inter = (adapted_encoded1 + adapted_encoded2) / 2
adapted_encoded_inter.shape

In [None]:
from torch.nn import functional as F

with torch.no_grad():
    decoded_inter = model.decode(adapted_encoded_inter)
    reconstructed_inter = F.interpolate(decoded_inter, x1.shape[2:], mode="trilinear")
reconstructed_inter.shape

In [None]:
plot_scans(
    [x1[0][0], x2[0][0], reconstructed1[0][0], reconstructed2[0][0], reconstructed_inter[0][0]],
    ["Scan1", "Scan2", "Reconstructed1", "Reconstructed2", "Reconstructed interpolated latent"],
    cols=2,
)

In [None]:
plot_scans((reconstructed1 - reconstructed2).abs()[0][0])