In [2]:
import os
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from monai import transforms
from monai.config import print_config
from monai.utils import first, set_determinism
from torch import nn
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Normalize
from utils import set_seed, DoseActivityDataset, JointCompose, Resize3D

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using {device}")

Using cpu


  return torch._C._cuda_getDeviceCount() > 0


In [None]:
# Dataset locations: one directory for the input samples, one for the outputs.
input_dir, output_dir = "data/dataset_1/input", "data/dataset_1/output"

# Dataset statistics, computed beforehand over the entire Prostate dataset.
# Input statistics:
mean_input, std_input = 0.002942, 0.036942
min_input, max_input = 0.0, 1.977781

# Output statistics (several orders of magnitude smaller than the inputs):
mean_output, std_output = 5.7475e-07, 6.62656e-06
min_output, max_output = 0.0, 6.0621166e-04

# Standardise inputs and outputs with their respective dataset mean/std.
# (A GaussianBlurFloats(p=0.3, sigma=1) step on the input is currently disabled.)
input_transform = Compose([Normalize(mean=mean_input, std=std_input)])
output_transform = Compose([Normalize(mean=mean_output, std=std_output)])

# Target 3D size handed to Resize3D; the resize is wrapped in a JointCompose
# (presumably applied to input and output together -- see utils.JointCompose).
img_size = (128, 64, 64)
joint_transform = JointCompose([Resize3D(img_size)])


# `seed` was never defined in the notebook, so set_seed(seed) raised NameError
# on a fresh kernel (Restart & Run All). Define it explicitly here; change the
# value if earlier experiments relied on a different seed.
seed = 42
set_seed(seed)

# Number of samples passed to the dataset.
num_samples = 948

# Build the dataset from the input/output directories, applying the per-side
# normalisation transforms and the joint resize transform.
dataset = DoseActivityDataset(
    input_dir=input_dir,
    output_dir=output_dir,
    input_transform=input_transform,
    output_transform=output_transform,
    joint_transform=joint_transform,
    num_samples=num_samples,
)

# Partition the dataset: 80% training, 15% validation, and the remainder (~5%)
# for testing. The test share absorbs the rounding so the sizes sum to len(dataset).
train_size = int(len(dataset) * 0.8)
val_size = int(len(dataset) * 0.15)
test_size = len(dataset) - (train_size + val_size)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, lengths=[train_size, val_size, test_size]
)

# Batch loaders for the three splits; only the training split is shuffled.
batch_size = 8
num_workers = 4
train_loader = DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
