In [1]:
import nibabel as nib
from omegaconf import OmegaConf
import numpy as np
from pathlib import Path
import os
import einops
import torch.nn.functional as F
import torch

In [23]:
# Locate the experiment directory, make sure the output folder for the
# un-scaled samples exists, and read the training configuration.
src_path = Path("/home/sz9jt/data/generative_brain/diffusion/diffusion_medil_reduced_CT_pretrained_channelwise")
out_path = src_path / "sample-unscaled"
out_path.mkdir(exist_ok=True)
config = OmegaConf.load(src_path / "config.yaml")

dataset_params = config.dataset.params
print(dataset_params)

# `scale` ends up as either a single global factor or a list with one factor
# per latent channel (ch1 ... ch8), depending on `scale_each_ch` in the config.
if not dataset_params.scale_each_ch:
    print("Global scale")
    scale = dataset_params.fixed_scale
else:
    print("We should scale each dimension")
    scale = [
        getattr(dataset_params, f"ch{channel_idx}_fixed_scale")
        for channel_idx in range(1, 9)
    ]

print("Using scale:", scale)

{'data_paths': '/home/sz9jt/data/t1w_processed/outputs_8mm/2025-01-28T00_42_41_medil_reduced_CT_pretrained/train_val_encodings', 'transform': 'None', 'fixed_scale': 0.13, 'output_shape': [32, 32, 32], 'if_sample': True, 'scale_each_ch': True, 'ch1_fixed_scale': 0.25, 'ch2_fixed_scale': 0.133, 'ch3_fixed_scale': 0.133, 'ch4_fixed_scale': 0.167, 'ch5_fixed_scale': 0.25, 'ch6_fixed_scale': 0.25, 'ch7_fixed_scale': 0.25, 'ch8_fixed_scale': 0.25}
We should scale each dimension
Using scale: [0.25, 0.133, 0.133, 0.167, 0.25, 0.25, 0.25, 0.25]


In [24]:
# Inspect the affine of a few training encodings and keep one of them in
# `affine`; later cells reuse it as the affine of the exported NIfTI files.
encoding_path = config.dataset.params.data_paths

# os.listdir() returns entries in arbitrary order; sort so that the files
# inspected here (and the affine that is kept) are the same on every run.
tmp = sorted(os.listdir(encoding_path))
if not tmp:
    # Without at least one encoding `affine` would stay undefined and the
    # export cell would fail later with an unrelated NameError.
    raise FileNotFoundError(f"No encodings found in {encoding_path}")

# Look at up to five files; slicing avoids an IndexError when the directory
# holds fewer than five entries.
for encoding_name in tmp[:5]:
    img = nib.load(f"{encoding_path}/{encoding_name}")
    print(img.affine)
    affine = img.affine

[[  10.    0.   -0. -155.]
 [   0.   10.   -0. -155.]
 [  -0.    0.   10. -155.]
 [   0.    0.    0.    1.]]
[[  10.    0.   -0. -155.]
 [   0.   10.   -0. -155.]
 [  -0.    0.   10. -155.]
 [   0.    0.    0.    1.]]
[[  10.    0.   -0. -155.]
 [   0.   10.   -0. -155.]
 [  -0.    0.   10. -155.]
 [   0.    0.    0.    1.]]
[[  10.    0.   -0. -155.]
 [   0.   10.   -0. -155.]
 [  -0.    0.   10. -155.]
 [   0.    0.    0.    1.]]
[[  10.    0.   -0. -155.]
 [   0.   10.   -0. -155.]
 [  -0.    0.   10. -155.]
 [   0.    0.    0.    1.]]


In [25]:
# Peek at the first generated sample (alphabetical order) to check its shape.
samples = sorted(os.listdir(src_path / "sample-ckpt1000"))
print(samples[0])

img = np.load(src_path / "sample-ckpt1000" / samples[0])
print(img.shape)

sample_epoch1000_fixscale0.13_0.npy
(1, 8, 32, 32, 32)


In [26]:
# For brain
# Collect the generated sample file names in alphabetical order.
samples = sorted(os.listdir(src_path / "sample-ckpt1000"))
print(samples[0])

    
def scale_samples(samples, scale):
    """Undo the fixed scaling of generated samples and export them as NIfTI.

    Each file name in ``samples`` is loaded from ``src_path / "sample-ckpt1000"``,
    its first (channel) axis is moved to the end, the values are divided by
    ``scale`` and the result is written to ``out_path`` with ``.npy`` replaced
    by ``.nii.gz``. The notebook-level ``affine`` is used for every output.

    Args:
        samples: Iterable of ``.npy`` file names inside the sample directory.
        scale: Either one real number (global scale) or a 1-D list / tuple /
            ndarray holding one factor per channel.

    Raises:
        ValueError: If ``scale`` is neither a real number nor a 1-D sequence
            of numbers. This is checked before any file is read.
        AssertionError: If the number of channel-wise factors differs from the
            number of channels of a sample.
    """
    # Resolve the divisor once: it does not depend on the individual sample.
    # Floats must be accepted here -- the configured global scale is a float,
    # which an `isinstance(scale, int)` test would reject.
    if isinstance(scale, (int, float, np.integer, np.floating)):
        divisor = scale
        channelwise = False
    elif isinstance(scale, (list, tuple, np.ndarray)):
        divisor = np.asarray(scale, dtype=float)
        if divisor.ndim != 1:
            raise ValueError("Wrong scale_type")
        channelwise = True
    else:
        raise ValueError("Wrong scale_type")

    for item in samples:
        out_name = item.replace(".npy", ".nii.gz")
        img = np.load(src_path / "sample-ckpt1000" / item)
        # Drop only the leading batch axis. A bare squeeze() would also remove
        # a size-1 channel or spatial axis and break the transpose below.
        if img.ndim == 5:
            img = img.squeeze(axis=0)
        # Channels first -> channels last.
        img = np.transpose(img, (1, 2, 3, 0))
        if channelwise:
            assert img.shape[-1] == len(divisor)
        # A 1-D divisor broadcasts along the trailing (channel) axis.
        img = img / divisor

        nifti_img = nib.Nifti1Image(img, affine=affine)
        nib.save(nifti_img, out_path / out_name)
        

# Report which scaling mode is in effect, then export all samples.
# A global scale read from the config is a float (e.g. 0.13), so the check
# must accept floats as well as ints.
if isinstance(scale, (int, float)):
    print("Global scale")
elif isinstance(scale, list):
    print("Channelwise scale")
scale_samples(samples, scale)

sample_epoch1000_fixscale0.13_0.npy
Channelwise scale


In [27]:
# Sanity check of the export: count the written files and look at the value
# range of one of them, overall and for the first three channels.
out = os.listdir(out_path)
print(len(out))

first_output = out_path / out[0]
data = nib.load(first_output).get_fdata()

print(data.shape, data.min(), data.max())
for channel_idx in range(3):
    channel = data[..., channel_idx]
    print(channel.min(), channel.max())

1000
(32, 32, 32, 8) -7.30332216822115 6.721348690807371
-3.9100406169891357 3.864105701446533
-7.30332216822115 4.662463091369858
-4.9443482456350685 6.721348690807371
