In [1]:
from torch.utils.data import *
from lircst_ana_dataset import LircstAnaDataset
from torch import Generator

import os

# Root folder of the dataset. Set the LIRCST_DATA_DIR environment variable to run
# the notebook on another machine; the fallback is the original location.
DATA_DIR = os.environ.get('LIRCST_DATA_DIR', '/home/samnub/dev/lircst-ana/data/')

dataset = LircstAnaDataset(DATA_DIR)

# Seeded generator so the split below is the same on every run.
rand_generator = Generator().manual_seed(42) # The meaning of life, the universe and everything

# Fractional lengths: 80% train, 10% validation, 10% test.
dataset_train, dataset_valid, dataset_test = random_split(dataset, [0.8, 0.1, 0.1], generator=rand_generator)

# Sanity check: the three subsets together must cover the whole dataset.
assert len(dataset_train) + len(dataset_valid) + len(dataset_test) == len(dataset), \
    "train/valid/test split does not cover the whole dataset"

print(f"Train set size: {len(dataset_train)}")
print(f"Validation set size: {len(dataset_valid)}")
print(f"Test set size: {len(dataset_test)}")

Train set size: 14969
Validation set size: 1871
Test set size: 1871


In [2]:
# One row per experiment: (name, use physics-based loss, use latent diffusion).
_MODEL_VARIANTS = (
    ("ECD-Phys", True, False),
    ("ECD", False, False),
    ("ECLD-Phys", True, True),
    ("ECLD", False, True),
)

# Keyword arguments for each model variant, keyed by variant name.
# Every variant shares the same train/test subsets and differs only in the
# "physics" and "latent" switches.
# NOTE(review): "valid_dataset" is None for every variant, so dataset_valid is
# not passed to any model here -- confirm that this is intentional.
model_args = {
    variant_name: {
        "train_dataset": dataset_train,
        "valid_dataset": None,
        "test_dataset": dataset_test,
        "physics": use_physics,
        "latent": use_latent,
    }
    for variant_name, use_physics, use_latent in _MODEL_VARIANTS
}


In [None]:
# Full pipeline
from encoded_conditional_diffusion import ECDiffusion
from util import generate_directory_name, get_latest_ckpt


# Setup Diffusion modules
import pytorch_lightning as pl
from Diffusion.EMA import EMA
from pytorch_lightning.callbacks import ModelCheckpoint

# Run switches read by train() and by the guard that calls it.
# pre_load: when True, train() resumes from get_latest_ckpt(name) and derives the
# output directory from it; when False, training starts without a checkpoint.
pre_load: bool = False # Load the latest checkpoint if available
# train_mode: when True, train() is called when this cell runs.
train_mode: bool = True
# test_afterward: when True, train() calls trainer.test() after each trainer.fit().
test_afterward: bool = True

def train(max_epochs: int = 200, max_steps: int = 200_000, ema_decay: float = 0.9999):
    """Fit (and optionally test) one ``ECDiffusion`` model per entry of ``model_args``.

    The module-level flags steer the run: ``pre_load`` makes fitting resume from
    ``get_latest_ckpt(name)``, and ``test_afterward`` runs ``trainer.test`` after
    each fit. Training uses GPU 0 and an ``EMA`` callback.

    Args:
        max_epochs: Epoch limit passed to ``pl.Trainer``.
        max_steps: Step limit passed to ``pl.Trainer``. It is converted to ``int``
            because the Trainer documents this argument as an integer; the
            previous hard-coded value was the float ``2e5``.
        ema_decay: Value passed to the ``EMA`` callback constructor.
    """
    for name, model_arg in model_args.items():
        print(f"Training {name}...")

        model = ECDiffusion(**model_arg)

        # Look up the checkpoint once and reuse it for both the output directory
        # and the resume path (assumes get_latest_ckpt has no side effects).
        latest_ckpt = get_latest_ckpt(name) if pre_load else None
        resume_path = latest_ckpt[0] if latest_ckpt is not None else None
        resume_tag = latest_ckpt[1] if latest_ckpt is not None else None

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            max_steps=int(max_steps),
            callbacks=[EMA(ema_decay)],
            accelerator='gpu',
            devices=[0],
            num_sanity_val_steps=0,  # Disable sanity check on dataloader
            default_root_dir=generate_directory_name(name, resume_tag),
        )

        trainer.fit(model, ckpt_path=resume_path)

        if test_afterward:
            # Deliberately looked up again after fitting, as the original code did:
            # the latest checkpoint may have changed during training.
            # NOTE(review): with pre_load=False the in-memory model is tested.
            trainer.test(model, ckpt_path=get_latest_ckpt(name)[0] if pre_load else None)

# Run the training loop over every entry of model_args only when train_mode is set.
if train_mode:
    train()


  from .autonotebook import tqdm as notebook_tqdm


Training ECD-Phys...
Is Time embed used ?  True


  rank_zero_warn(
You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA RTX A4000') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                        | Params | Mode 
----------------------------------------------------------------------
0 | model         | EncodedConditionalDiffusion | 56.6 M | train
1 | physics_model | PhysicsIncorporated         | 0      | train
-------------------------------------------------------------------

Epoch 0:  10%|▉         | 92/936 [00:41<06:23,  2.20it/s, v_num=0, train_loss=0.141] 

In [None]:
# Display some samples from each model
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from encoded_conditional_diffusion import ECDiffusion
from util import get_latest_ckpt

def _compare_images(imageA, imageB):
    """Return ``(SSIM, PSNR)`` between two single-channel images.

    Both arguments are NumPy arrays; ``imageB`` defines the target size and the
    data range. If the shapes differ, ``imageA`` is bilinearly resized to the
    last two dimensions of ``imageB``.
    Assumes both images are 2-D (H, W) -- TODO confirm for other callers.
    """
    if imageA.shape != imageB.shape:
        # F.interpolate needs an (N, C, H, W) tensor, so wrap the 2-D array,
        # resize it, and unwrap it back to a NumPy array.
        resized = F.interpolate(
            torch.as_tensor(imageA)[None, None],
            size=imageB.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )
        imageA = resized[0, 0].numpy()

    data_range = imageB.max() - imageB.min()

    # P.S. Scikit-image returns a value between -1 and 1, where 1 is a perfect match and -1 is a complete mismatch
    # No channel axis is passed: the images are single-channel, and the old
    # `multichannel=True` flag would have treated the last spatial axis as channels.
    s = ssim(imageA, imageB, data_range=data_range)
    p = psnr(imageA, imageB, data_range=data_range)

    return s, p


def show_samples(model: ECDiffusion, dataset_idx: int|None=None, num_samples=4):
    """Sample from ``model`` for one test item, plot the results and print metrics.

    One item of ``dataset_test`` is repeated ``num_samples`` times, moved to the
    GPU and passed through ``model``. The figure is a single row: the input
    summed over axis 2, then three panels per sample (the encoded condition
    summed over axis 0, the first output channel, the last output channel), and
    finally the two reference channels titled 'S' and 'A'. SSIM and PSNR of each
    sample against the reference are printed.

    Args:
        model: Model to sample from; called as ``model(batch, verbose=True)`` and
            expected to return ``(out, encoded_condition)``.
        dataset_idx: Index into ``dataset_test``; a random index is drawn when None.
        num_samples: Number of samples generated for the same input.
    """
    if dataset_idx is None:
        dataset_idx = np.random.randint(0, len(dataset_test))
    phan, sino, _ = dataset_test[dataset_idx]
    batch_input = torch.stack(num_samples*[sino]).cuda()

    out, encoded_condition = model(batch_input, verbose=True)

    # Pre-process our data to be consistently -1 to 1 scaled
    phan = model.preprocess(image=phan.unsqueeze(0))[0].squeeze(0)

    # print min and max values of the output
    print(f"Output min: {out.min().item()}, max: {out.max().item()}")
    print(f"Encoded condition min: {encoded_condition.min().item()}, max: {encoded_condition.max().item()}")
    print(f"Phan min: {phan.min().item()}, max: {phan.max().item()}")

    n_out = len(out)
    # Columns: 1 input + 3 per sample + 2 reference channels.
    # Every panel, including the first, must use this same grid width.
    n_cols = 3 + n_out*3

    plt.figure(dpi=800)
    plt.subplot(1, n_cols, 1)
    plt.imshow(torch.sum(sino, axis=2))
    plt.title('Input')
    plt.axis('off')
    for sample_idx in range(n_out):
        sample = out[sample_idx].detach().cpu()
        first_col = 3*sample_idx + 2  # column 1 holds the input panel
        plt.subplot(1, n_cols, first_col)
        plt.imshow(torch.sum(encoded_condition[sample_idx].detach().cpu(), axis=0))
        plt.axis('off')
        plt.subplot(1, n_cols, first_col + 1)
        plt.imshow(sample[0])
        plt.axis('off')
        plt.subplot(1, n_cols, first_col + 2)
        plt.imshow(sample[-1])
        plt.axis('off')
    plt.subplot(1, n_cols, n_cols - 1)
    plt.imshow(phan[0].cpu())
    plt.title('S')
    plt.axis('off')
    plt.subplot(1, n_cols, n_cols)
    plt.imshow(phan[1].cpu())
    plt.title('A')
    plt.axis('off')
    plt.show()

    print("SSIM, PSNR:")
    for sample_idx in range(n_out):
        sample_np = out[sample_idx].detach().cpu().numpy()
        print(f"Scatter channel: {_compare_images(sample_np[0], phan[0].cpu().numpy())}")
        print(f"Attenuation channel: {_compare_images(sample_np[-1], phan[1].cpu().numpy())}")

# Pick one test item so every model variant is shown on the same input.
random_idx = np.random.randint(0, len(dataset_test))
print(f"Random dataset index: {random_idx}")

for name, model_arg in model_args.items():
    print(f"Showing samples for {name}...")

    model = ECDiffusion.load_from_checkpoint(
        get_latest_ckpt(name)[0],
        **model_arg
    ).cuda()
    # Switch to inference mode before sampling, as the Lightning checkpoint docs
    # recommend, so layers such as dropout behave deterministically.
    # NOTE(review): confirm ECDiffusion's forward has no train-mode-only behaviour.
    model.eval()

    show_samples(model, dataset_idx=random_idx)