In [1]:
import plotly.express as px
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import MedNISTDataset
from monai.config import print_config
from monai.data import DataLoader, Dataset
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses.adversarial_loss import PatchAdversarialLoss
from generative.losses.perceptual import PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

from generative.inferers import DiffusionInferer
from generative.networks.nets import DiffusionModelUNet

import pandas as pd
import sys
sys.path.insert(0,'..')

from loaders.ultrasound_dataset import USDataModule, USDataset
from transforms.ultrasound_transforms import DiffusionEvalTransforms, DiffusionTrainTransforms

from loaders.mr_dataset import MRDataModuleVolumes, MRDatasetVolumes
from transforms.mr_transforms import MRDiffusionEvalTransforms, MRDiffusionTrainTransforms
# from callbacks.logger import DiffusionImageLogger

from nets import diffusion
import pickle
import os
import pytorch_lightning as pl


mount_point = "/mnt/raid/C1_ML_Analysis/"

print_config()

2023-04-24 14:59:12,937 - Created a temporary directory at /tmp/tmpquag660o
2023-04-24 14:59:12,939 - Writing /tmp/tmpquag660o/_remote_module_non_scriptable.py
MONAI version: 1.1.0
Numpy version: 1.23.1
Pytorch version: 1.12.1+cu113
MONAI flags: HAS_EXT = False, USE_COMPILED = False, USE_META_DICT = False
MONAI rev id: a2ec3752f54bfc3b40e7952234fbeb5452ed63e3
MONAI __file__: /mnt/raid/home/jprieto/anaconda3/envs/torch_us/lib/python3.8/site-packages/monai/__init__.py

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 9.2.0
Tensorboard version: 2.12.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.13.1+cu113
tqdm version: 4.64.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.9.4
pandas version: 1.4.3
einops version: 0.6.0
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: 

The 'neptune-client' package has been deprecated and will be removed in the future. Install the 'neptune' package instead. For more, see https://docs.neptune.ai/setup/upgrading/
You're importing the Neptune client library via the deprecated `neptune.new` module, which will be removed in a future release. Import directly from `neptune` instead.


In [2]:

csv_test = "/mnt/raid/C1_ML_Analysis/CSV_files/extract_frames_blind_sweeps_c1_30082022_wscores_1e-4_train_train_sample.parquet"

if(os.path.splitext(csv_test)[1] == ".csv"):        
    df_test = pd.read_csv(os.path.join(mount_point, csv_test))
else:        
    df_test = pd.read_parquet(os.path.join(mount_point, csv_test))

In [3]:


# train_transform = DiffusionTrainTransforms()
valid_transform = DiffusionEvalTransforms()

test_ds = USDataset(df_test, mount_point, img_column='img_path', transform=valid_transform, repeat_channel=False)
test_data = DataLoader(test_ds, batch_size=1, num_workers=4, persistent_workers=True, pin_memory=True, shuffle=True, prefetch_factor=1)


In [None]:
# device = torch.device("cuda")

# model = DiffusionModelUNet(
#     spatial_dims=2,
#     in_channels=1,
#     out_channels=1,
#     num_channels=(128, 256, 256),
#     attention_levels=(False, True, True),
#     num_res_blocks=1,
#     num_head_channels=256,
# )
# model.to(device)

# scheduler = DDPMScheduler(num_train_timesteps=1000)

# optimizer = torch.optim.Adam(params=model.parameters(), lr=2.5e-5)

# inferer = DiffusionInferer(scheduler)

In [None]:
# timesteps = torch.randint(
#     0, inferer.scheduler.num_train_timesteps, (2,)
# ).long()
# timesteps

In [None]:
# model.eval()
# noise = torch.randn((1, 1, 64, 64))
# noise = noise.to(device)
# scheduler.set_timesteps(num_inference_steps=1000)
# with autocast(enabled=True):
#     image, intermediates = inferer.sample(
#         input_noise=noise, diffusion_model=model, scheduler=scheduler, save_intermediates=True, intermediate_steps=100
#     )

# chain = torch.cat(intermediates, dim=-1)

# plt.style.use("default")
# plt.imshow(chain[0, 0].cpu(), vmin=0, vmax=1, cmap="gray")
# plt.tight_layout()
# plt.axis("off")
# plt.show()

In [4]:
class StackDataset(Dataset):
    def __init__(self, dataset, multiple_slices=10):        
        self.dataset = dataset
        self.dataset.df = self.dataset.df.sample(frac=1).reset_index(drop=True)
        self.multiple_slices = multiple_slices       

    def __len__(self):
        return len(self.dataset)//self.multiple_slices

    def __getitem__(self, idx):
        
        start_idx = idx*self.multiple_slices

        return torch.stack([self.dataset[idx] for idx in range(start_idx, start_idx + self.multiple_slices)], dim=1)

    
class ConcatDataset(torch.utils.data.Dataset):
    def __init__(self, *datasets):
        self.datasets = datasets

    def __getitem__(self, i):
        return tuple(d[i] for d in self.datasets)

    def __len__(self):
        return min(len(d) for d in self.datasets)

In [5]:
test_ds = USDataset(df_test, mount_point, img_column='img_path', transform=valid_transform, repeat_channel=False)
test_ds_multiple_us = USDatasetMultipleSlices(test_ds, multiple_slices=10)

In [6]:
csv_test = "CSV_files/MR_diffusion_test.csv"
df_test_mr = pd.read_csv(os.path.join(mount_point, csv_test))
test_ds_mr = MRDatasetVolumes(df_test_mr, mount_point=mount_point, img_column="img_path", transform=MRDiffusionEvalTransforms(mount_point=mount_point, random_slice_size=10))

concat_ds = ConcatDataset(test_ds_mr, test_ds_multiple_us)

In [7]:
class MRUSDataModule(pl.LightningDataModule):
    def __init__(self, mr_dataset_train, us_dataset_train, mr_dataset_val, us_dataset_val, batch_size=4, num_workers=4):
        super().__init__()

        self.mr_dataset_train = mr_dataset_train
        self.us_dataset_train = us_dataset_train

        self.mr_dataset_val = mr_dataset_val
        self.us_dataset_val = us_dataset_val
        
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage=None):

        # Assign train/val datasets for use in dataloaders
        self.train_ds = ConcatDataset(self.mr_dataset_train, self.us_dataset_train)
        self.val_ds = ConcatDataset(self.mr_dataset_val, self.us_dataset_val)

    def train_dataloader(self):
        return DataLoader(self.train_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, shuffle=True, collate_fn=self.arrange_slices)

    def val_dataloader(self):
        return DataLoader(self.val_ds, batch_size=self.batch_size, num_workers=self.num_workers, persistent_workers=True, pin_memory=True, collate_fn=self.arrange_slices)    

    def arrange_slices(self, batch):
        mr_batch = [mr for mr, us in batch]
        us_batch = [us for mr, us in batch]        
        mr_batch = torch.cat(mr_batch, axis=1).permute(dims=(1,0,2,3))
        us_batch = torch.cat(us_batch, axis=1).permute(dims=(1,0,2,3))        
        return mr_batch[torch.randperm(mr_batch.shape[0])], us_batch

In [8]:
data_module = MRUSDataModule(test_ds_mr, test_ds_multiple_us, test_ds_mr, test_ds_multiple_us)
data_module.setup()
loader = data_module.train_dataloader()

for l in loader:
    mr, us = l
    print(mr.shape, us.shape)

torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size([40, 1, 256, 256])
torch.Size([40, 1, 256, 256]) torch.Size

KeyboardInterrupt: 