In [1]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split, SubsetRandomSampler
from sklearn.model_selection import train_test_split
from scipy.spatial import cKDTree
from scipy.ndimage import zoom  # For resampling
import math

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.linalg import vector_norm

In [2]:
# Switch the working directory to the DSR project root so that the relative
# 'Data/...' and checkpoint paths used below resolve (presumably the local
# `DSR_pretrain` import relies on this as well — verify).
# The location can be overridden with the DSR_ROOT environment variable; when it
# is unset the original hard-coded path is used, so existing runs are unaffected.
os.chdir(os.environ.get('DSR_ROOT', '/data/jianglab1/xiaoyi/code/DSR_code/DSR/'))

In [3]:
from DSR_pretrain import *

In [4]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [13]:
# Load the training and validation tensors (paths are relative to the working directory).
# Each file is read twice so that input and target are separate objects.
# NOTE(review): every target comes from the same file as its input, i.e. the
# pre-training target equals the input — confirm this is intentional and not a
# copy-paste slip in the file names.
train_input, train_target = (torch.load('Data/train_input_100.pt') for _ in range(2))
valid_input, valid_target = (torch.load('Data/valid_input_100.pt') for _ in range(2))

In [14]:
# Sanity check: display the dimensions of the loaded training tensor.
train_input.size()

torch.Size([400, 3, 16, 16, 16])

# Pre-train model

In [7]:
# NOTE(review): `beta` is not passed to DSR_train and is not read by any cell
# visible here — confirm whether a later cell still needs it.
beta = 1

# Pre-train the model; the returned object is what the next cell saves to disk.
# With batch_size=256 the recorded log reports a fall-back to full-batch
# gradient descent.
DSR = DSR_train(
    train_input,
    train_target,
    valid_input,
    valid_target,
    lr=1e-4,
    num_epochs=20,
    batch_size=256,
    device=device,
)

GPU is available, running on GPU.

Batch is larger than half of the sample size. Training based on full-batch gradient descent.

Validation loss on the original (non-standardized) scale:
	Energy-loss: 29.1911,  E(|Y-Yhat|): 44.7358,  E(|Yhat-Yhat'|): 31.0894

Prediction-loss E(|Y-Yhat|) and variance-loss E(|Yhat-Yhat'|) should ideally be equally large
-- consider training for more epochs or adjusting hyperparameters if there is a mismatch 


In [8]:
# Persist the pre-trained model. The evaluation section reloads this file and
# reads the weights from its 'model_state_dict' entry.
DSR.save_model("DSR_unet_model_pretrain.pth")

Model saved to: DSR_unet_model_pretrain.pth


## Evaluate the pre-trained model on the validation set

In [9]:
# Load the model
def load_model(model, model_path, device):
    """Load saved weights from ``model_path`` into ``model`` and return it.

    Accepts either a bare state dict or a checkpoint dict that stores the
    weights under the ``'model_state_dict'`` key (the layout this notebook reads
    back from ``DSR_unet_model_pretrain.pth``).

    Args:
        model: an ``nn.Module`` whose architecture matches the saved weights.
        model_path: path to a file written with ``torch.save``.
        device: forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The same ``model`` object, with the weights loaded in place.
    """
    checkpoint = torch.load(model_path, map_location=device)
    # Unwrap full checkpoints; otherwise treat the loaded object as the state dict.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    return model

In [10]:
# Rebuild the network and restore the pre-trained weights for inference.
# map_location=device lets a checkpoint written on a GPU be loaded on a
# CPU-only machine as well.
checkpoint = torch.load('DSR_unet_model_pretrain.pth', map_location=device)
# NOTE: `DSR` is rebound here from the DSR_train result to a plain UNet3D module.
DSR = UNet3D(in_channels=3, out_channels=3)
DSR.load_state_dict(checkpoint['model_state_dict'])
DSR.to(device)
# eval() puts the BatchNorm3d and Dropout3d layers into inference mode.
DSR.eval()

UNet3D(
  (encoder1): Sequential(
    (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout3d(p=0.2, inplace=False)
    (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU(inplace=True)
    (7): Dropout3d(p=0.2, inplace=False)
  )
  (pool1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (encoder2): Sequential(
    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Dropout3d(p=0.2, inplace=False)
    (4): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=

In [11]:
### DSR prediction
# Monte-Carlo prediction on the validation set: every batch is perturbed with
# fresh Gaussian noise `num_runs` times and the model outputs are averaged.
batch_size = 32
# NOTE(review): the noise below is scaled by sigma_t ** 0.5, so sigma_t acts as
# the noise variance rather than its standard deviation — confirm this matches
# the convention used during training.
sigma_t = 0.1
n_samples = valid_input.shape[0]

num_runs = 200
DSR_predictions = []

# Inference only: no_grad() stops autograd from recording a graph during each of
# the num_runs forward passes (detaching the outputs afterwards would not avoid
# that cost).
with torch.no_grad():
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)

        batch = valid_input[start_idx:end_idx].to(device)
        target = valid_target[start_idx:end_idx].to(device)
        # Running sum of predictions; assumes the model output has the same
        # shape as `target`. zeros_like already allocates on target's device.
        batch_sum = torch.zeros_like(target)

        # Loop over the number of runs to compute predictions
        for run in range(num_runs):
            # Generate noise and add it to the input batch
            epsilon_t = torch.randn_like(batch) * (sigma_t ** 0.5)
            input_batch = batch + epsilon_t
            DSR_preds = DSR(input_batch)
            batch_sum += DSR_preds

        # Compute the average prediction for this batch
        batch_mean = batch_sum / num_runs
        DSR_predictions.append(batch_mean)

# Concatenate all batch predictions into one final prediction tensor
DSR_predictions = torch.cat(DSR_predictions, dim=0)

In [12]:
# Sanity check: display the dimensions of the concatenated prediction tensor.
DSR_predictions.size()

torch.Size([400, 3, 16, 16, 16])