In [13]:
import os
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import scipy.io

import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split, SubsetRandomSampler

from sklearn.model_selection import train_test_split
from scipy.spatial import cKDTree
from scipy.ndimage import zoom  # For resampling
import math

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.linalg import vector_norm

In [5]:
from DSR_finetune import *

# Data preparation for fine-tuning

For the fine-tuning process, we use both 4DF and CFD data. When these data are stored as 3D velocity fields with corresponding spatial coordinates (as in the pretraining stage), cubic patches can be constructed using the `pretrain_data_preparation` procedure. Below, we provide the code for constructing the cubic patches when the 4DF and CFD data are stored as arrays of shape (m, n, k, 3).

In [8]:
# Load the two MATLAB .mat files used for fine-tuning.
# The velocity arrays themselves are pulled out of these objects by key in
# the next cell ('Velocity_CFD' and 'Velocity').
vcdf_mat = scipy.io.loadmat("data/CFD4DF/VCFD.mat")
Velocity_mat = scipy.io.loadmat("data/CFD4DF/Velocity.mat")

In [9]:
# Take entry 0 along the leading axis of each stored variable.
# Further down, 'Velocity_CFD' becomes the training target and 'Velocity'
# becomes the model input.
vcdf_data = vcdf_mat['Velocity_CFD'][0]
vel_data = Velocity_mat['Velocity'][0]

In [10]:
# Inspect the shape of the CFD velocity array (three spatial axes + 3 velocity components).
vcdf_data.shape

(76, 901, 231, 3)

In [11]:
# Collapse the spatial axes so that each row holds one voxel's 3-component velocity.
vcdf_data_flat = vcdf_data.reshape(-1, 3)

# mask0 flags voxels whose velocity is exactly (0, 0, 0); mask1 is its
# complement. mask1 is reused below to select the same voxels from the 4DF
# data and, at the end, to scatter predictions back onto the full grid.
mask0 = (vcdf_data_flat == 0).all(axis=1)
mask1 = ~mask0
vcdf_data1 = vcdf_data_flat[mask1]

# Size of the dense block formed by the non-zero voxels.
# NOTE(review): assumes the non-zero voxels form a 46 x 46 x 168 box in
# raster order -- confirm against the data.
x, y, z = 46, 46, 168

vcdf_data2 = vcdf_data1.reshape(x, y, z, 3)

In [12]:
# Collapse the spatial axes so that each row holds one voxel's 3-component velocity.
vel_data_flat = vel_data.reshape(-1, 3)

# No new mask is built here: the CFD-derived mask1 from the previous cell is
# reused so that input and target keep exactly the same voxels.
vel_data1 = vel_data_flat[mask1]

# Same block dimensions as used for the CFD data above.
x = 46
y = 46
z = 168

vel_data2 = vel_data1.reshape(x, y, z, 3)

In [14]:
# 4DF velocities are the model input, CFD velocities the target.
data_input = torch.Tensor(vel_data2)
data_target = torch.Tensor(vcdf_data2)

# F.pad consumes (left, right) pairs starting from the LAST dimension:
#   dim 3 (components): +0      -> stays 3
#   dim 2             : +4 / +4 -> 168 + 8 = 176
#   dim 1             : +1 / +1 ->  46 + 2 = 48
#   dim 0             : +1 / +1 ->  46 + 2 = 48
# so every spatial extent becomes a multiple of the 16-voxel patch size.
pad_tuple = (0, 0, 4, 4, 1, 1, 1, 1)

# Zero-pad both volumes identically.
data_input_expanded, data_target_expanded = (
    F.pad(volume, pad_tuple) for volume in (data_input, data_target)
)

# Confirm the padded shapes.
print("Expanded data1 shape:", data_input_expanded.shape)
print("Expanded data2 shape:", data_target_expanded.shape)

Expanded data1 shape: torch.Size([48, 48, 176, 3])
Expanded data2 shape: torch.Size([48, 48, 176, 3])


In [15]:
# Step 2: cut both padded volumes into non-overlapping cubes of shape
# [16, 16, 16, 3]. Patches are collected in raster order (dim 0 outermost,
# dim 2 innermost); the shape-reconstruction code later relies on this order.
patch_size = 16
expected_patch_shape = torch.Size([patch_size, patch_size, patch_size, 3])
extent0, extent1, extent2 = data_input_expanded.shape[:3]

# Matching input (4DF) and target (CFD) patches, index-aligned.
X_inputs_patches = []
X_targets_patches = []

for i in range(0, extent0, patch_size):
    span0 = slice(i, i + patch_size)
    for j in range(0, extent1, patch_size):
        span1 = slice(j, j + patch_size)
        for k in range(0, extent2, patch_size):
            span2 = slice(k, k + patch_size)

            patch_data1 = data_input_expanded[span0, span1, span2, :]
            patch_data2 = data_target_expanded[span0, span1, span2, :]

            # Skip truncated cubes at the boundary (none occur when the
            # padded extents are multiples of patch_size).
            if patch_data1.shape != expected_patch_shape:
                continue

            X_inputs_patches.append(patch_data1)
            X_targets_patches.append(patch_data2)

# Stack the patch lists into tensors of shape [num_patches, 16, 16, 16, 3].
X_inputs_tensor = torch.stack(X_inputs_patches)
X_targets_tensor = torch.stack(X_targets_patches)

# Report the resulting shapes.
print("X_inputs_tensor shape:", X_inputs_tensor.shape)  # Should be [num_patches, 16, 16, 16, 3]
print("X_targets_tensor shape:", X_targets_tensor.shape)  # Should be [num_patches, 16, 16, 16, 3]

X_inputs_tensor shape: torch.Size([99, 16, 16, 16, 3])
X_targets_tensor shape: torch.Size([99, 16, 16, 16, 3])


In [16]:
# Move the velocity-component axis to position 1:
# [N, 16, 16, 16, 3] -> [N, 3, 16, 16, 16].
channels_first_order = (0, 4, 1, 2, 3)
X_inputs_tensor1 = X_inputs_tensor.permute(*channels_first_order)
X_targets_tensor1 = X_targets_tensor.permute(*channels_first_order)

In [17]:
# Fixed seed so the same patches are selected on every run.
np.random.seed(42)

n_samples = 15                   # number of patches used for fine-tuning
N = X_inputs_tensor1.shape[0]    # total number of patches

# Draw the fine-tuning patch indices without replacement.
indices = np.random.choice(N, size=n_samples, replace=False)
print(indices)

# Boolean mask over all patches: True = held out for evaluation.
mask = np.full(N, True)
mask[indices] = False

# Fine-tuning subset (in the drawn order) ...
train_x_input, train_x_target = X_inputs_tensor1[indices], X_targets_tensor1[indices]

# ... and every remaining patch, in ascending patch order, for evaluation.
test_x_input, test_x_target = X_inputs_tensor1[mask], X_targets_tensor1[mask]

[62 40 95 18 97 84 64 42 10  0 31 76 47 26 44]


# Two-step fine-tuning

In [18]:
# Checkpoint passed as load_path to fine-tuning step 1 in the next cell.
model_path = 'code/DSR_code/DSR/DSR_unet_model_pretrain.pth'

In [19]:
# Fine-tuning step 1 on the 15 sampled patches, starting from the checkpoint
# at model_path.
# NOTE(review): dsr_fine_tune_step1 and `device` come from the star import of
# DSR_finetune; what exactly step 1 trains is defined there -- confirm.
dsr_finetune1 = dsr_fine_tune_step1(train_x_input, train_x_target, 
                                  load_path=model_path,
                                  lr=1e-5,  
                                  num_epochs=30,
                                  print_every_nepoch=10,
                                  device=device)

Batch is larger than half of the sample size. Training based on full-batch gradient descent.


In [20]:
# Persist the step-1 result; this file is reloaded as the starting point of step 2.
# NOTE(review): the recorded output reports "validation loss inf" -- presumably
# no validation split was used; confirm in DSR_finetune.
dsr_finetune1.save_model("code/DSR_code/DSR/DSR_unet_model_finetune_s1.pth")

Model saved with validation loss inf


In [21]:
# Step-1 checkpoint written above; passed as load_path to fine-tuning step 2.
model_path1 = 'code/DSR_code/DSR/DSR_unet_model_finetune_s1.pth'

In [22]:
# Fine-tuning step 2 on the same 15 patches, starting from the step-1
# checkpoint, with a larger learning rate (1e-4 vs 1e-5) and fewer epochs
# (20 vs 30) than step 1.
# NOTE(review): what step 2 trains differently from step 1 is defined in
# DSR_finetune -- confirm there.
dsr_finetune2 = dsr_fine_tune_step2(train_x_input, train_x_target, 
                                  load_path=model_path1,
                                  lr=1e-4,  
                                  num_epochs=20,
                                  print_every_nepoch=10,
                                  device=device)

Batch is larger than half of the sample size. Training based on full-batch gradient descent.


In [23]:
# Persist the step-2 result; the Evaluation section reloads this file.
# NOTE(review): the recorded output reports "validation loss inf" -- presumably
# no validation split was used; confirm in DSR_finetune.
dsr_finetune2.save_model("code/DSR_code/DSR/DSR_unet_model_finetune_s2.pth")

Model saved with validation loss inf


# Evaluation

In [25]:
# Load the step-2 checkpoint. map_location remaps the stored tensors onto
# `device`, so a checkpoint saved on GPU can still be loaded on a CPU-only
# machine (without it torch.load tries to restore the original device).
checkpoint = torch.load('code/DSR_code/DSR/DSR_unet_model_finetune_s2.pth',
                        map_location=device)

# Rebuild the architecture the weights were saved from: a UNet3D backbone
# whose `conv` attribute is replaced by Identity (presumably its final output
# convolution -- confirm in DSR_finetune), followed by a new head that takes
# the 32 backbone feature channels and produces 3 output channels.
backbone = UNet3D(in_channels=3, out_channels=3, init_features=32).to(device)
backbone.conv = nn.Identity()
new_head = AdvancedHead(feat_ch=32, mid_ch=128, out_ch=3,
                        dropout_p=0.1).to(device)
DSR_tune = nn.Sequential(backbone, new_head)
DSR_tune.load_state_dict(checkpoint['model_state_dict'])

# Switch to evaluation mode (dropout off, batch-norm uses running statistics).
# The prediction cell below relies on this.
DSR_tune.eval()

Sequential(
  (0): UNet3D(
    (encoder1): Sequential(
      (0): Conv3d(3, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout3d(p=0.2, inplace=False)
      (4): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (5): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): Dropout3d(p=0.2, inplace=False)
    )
    (pool1): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (encoder2): Sequential(
      (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Dropout3d(p=0.2, inplace=False)
      (4): Conv3d(64, 64, kernel_size=

In [26]:
### DSR prediction
# For every held-out patch, average the model output over `num_runs` noisy
# copies of the input. sigma_t is used as a noise variance: the Gaussian
# perturbation has standard deviation sqrt(sigma_t).
batch_size = 32
sigma_t = 0.1
n_samples = test_x_input.shape[0]  # NOTE: re-binds n_samples (15 in the split cell) to the test-set size

num_runs = 200
DSR_predictions = []

# Inference only: disable autograd so the num_runs forward passes per batch
# do not build and hold computation graphs.
with torch.no_grad():
    for start_idx in range(0, n_samples, batch_size):
        end_idx = min(start_idx + batch_size, n_samples)

        batch = test_x_input[start_idx:end_idx].to(device)
        target = test_x_target[start_idx:end_idx].to(device)
        # Accumulator with the target's shape, dtype and device
        # (target is only used as a template here).
        batch_sum = torch.zeros_like(target)

        # Accumulate predictions over independently perturbed inputs.
        for run in range(num_runs):
            epsilon_t = torch.randn_like(batch) * (sigma_t ** 0.5)
            input_batch = batch + epsilon_t
            DSR_preds = DSR_tune(input_batch)
            batch_sum += DSR_preds

        # Average prediction for this batch.
        batch_mean = batch_sum / num_runs
        DSR_predictions.append(batch_mean)

# Concatenate all batch averages into one tensor, in test-patch order.
DSR_predictions = torch.cat(DSR_predictions, dim=0)

In [27]:
# Inspect the stacked prediction shape: one channels-first patch per held-out sample.
DSR_predictions.shape

torch.Size([84, 3, 16, 16, 16])

## Shape reconstruction

In [28]:
def reconstructed_predict(predict_tensor, patch_size=16, padded_shape=(48, 48, 176), pad=(1, 1, 4)):
    """Reassemble channels-first cubic patches into one volume and strip the padding.

    This is the inverse of the patch extraction above: patches must be ordered
    in raster order over the patch grid (dim 0 outermost, dim 2 innermost).

    Parameters
    ----------
    predict_tensor : array-like or torch.Tensor
        Patches of shape (num_patches, C, patch_size, patch_size, patch_size).
    patch_size : int, default 16
        Edge length of each cubic patch.
    padded_shape : tuple of 3 ints, default (48, 48, 176)
        Spatial shape of the padded volume the patches were cut from; every
        entry must be a multiple of ``patch_size``.
    pad : tuple of 3 ints, default (1, 1, 4)
        Symmetric padding per spatial dimension to crop from both ends.

    Returns
    -------
    torch.Tensor
        float32 tensor of shape
        (padded_shape[0] - 2*pad[0], padded_shape[1] - 2*pad[1],
         padded_shape[2] - 2*pad[2], C).

    Raises
    ------
    ValueError
        If ``padded_shape`` is not divisible by ``patch_size`` or the number
        of patches does not match the patch grid.
    """
    if any(extent % patch_size for extent in padded_shape):
        raise ValueError(
            f"padded_shape {tuple(padded_shape)} must be a multiple of patch_size {patch_size}"
        )
    grid0, grid1, grid2 = (extent // patch_size for extent in padded_shape)

    # as_tensor accepts both NumPy arrays and tensors (on any device) and
    # yields float32, matching what torch.Tensor(ndarray) produced before.
    patches = torch.as_tensor(predict_tensor, dtype=torch.float32)
    if patches.shape[0] != grid0 * grid1 * grid2:
        raise ValueError(
            f"expected {grid0 * grid1 * grid2} patches for padded_shape "
            f"{tuple(padded_shape)}, got {patches.shape[0]}"
        )

    # (P, C, ps, ps, ps) -> (P, ps, ps, ps, C)
    patches = patches.permute(0, 2, 3, 4, 1)
    channels = patches.shape[-1]

    # Split the patch index into its three grid coordinates ...
    volume = patches.reshape(grid0, grid1, grid2,
                             patch_size, patch_size, patch_size, channels)
    # ... and interleave each grid axis with its within-patch axis so that
    # merging adjacent pairs yields the full spatial dimensions.
    volume = volume.permute(0, 3, 1, 4, 2, 5, 6)
    volume = volume.reshape(grid0 * patch_size, grid1 * patch_size,
                            grid2 * patch_size, channels)

    # Crop the padding with explicit end indices so that a pad of 0 keeps the
    # whole axis (a `[0:-0]` slice would be empty).
    pad0, pad1, pad2 = pad
    size0, size1, size2 = volume.shape[:3]
    return volume[pad0:size0 - pad0, pad1:size1 - pad1, pad2:size2 - pad2, :]

In [31]:
def construct_origin_shape(prediction, fine_tuning_indices=None):
    """Merge predictions with the fine-tuning targets and map them back onto the original grid.

    Patches used for fine-tuning are filled with their targets
    (``train_x_target``); all other patches are filled with ``prediction``.
    The combined patches are reassembled with ``reconstructed_predict`` and
    scattered into the positions selected by ``mask1``; every other voxel
    stays zero.

    Parameters
    ----------
    prediction : torch.Tensor
        Predictions for the held-out patches, channels first, ordered by
        ascending patch index (the order produced by boolean-mask selection
        of the test set).
    fine_tuning_indices : array-like of int, optional
        Patch indices used for fine-tuning, in the same order as
        ``train_x_target``. Defaults to the notebook-level ``indices`` drawn
        in the split cell, so the two can never drift apart.

    Returns
    -------
    numpy.ndarray
        float64 array with the shape of ``vcdf_data``.

    Raises
    ------
    ValueError
        If the number of predictions does not match the number of patches
        not used for fine-tuning.
    """
    if fine_tuning_indices is None:
        # Use the indices actually sampled rather than a hand-copied list.
        fine_tuning_indices = indices
    fine_tuning_indices = np.asarray(fine_tuning_indices)

    # Explicit host-side NumPy copies (also works if the tensors live on GPU).
    fine_tuning_data = train_x_target.detach().cpu().numpy()
    predictions = prediction.detach().cpu().numpy()

    total_patches = fine_tuning_data.shape[0] + predictions.shape[0]
    remaining_indices = np.setdiff1d(np.arange(total_patches), fine_tuning_indices)
    if remaining_indices.shape[0] != predictions.shape[0]:
        raise ValueError(
            f"got {predictions.shape[0]} predictions for "
            f"{remaining_indices.shape[0]} held-out patches"
        )

    # Full patch stack in original patch order.
    combined_data = np.zeros((total_patches,) + predictions.shape[1:])
    combined_data[fine_tuning_indices] = fine_tuning_data
    combined_data[remaining_indices] = predictions

    # Patches -> cropped dense volume -> one row per voxel.
    eng_reshape1 = reconstructed_predict(combined_data)
    eng_reshape1_flat = eng_reshape1.reshape(-1, 3)

    # Scatter into the flattened original grid: mask1 has one entry per voxel
    # of the original array and is True where data was kept.
    eng_prepare = np.zeros((mask1.shape[0], 3))
    eng_prepare[mask1] = eng_reshape1_flat.detach().cpu().numpy()

    eng_reconstructed = eng_prepare.reshape(vcdf_data.shape)

    return eng_reconstructed

In [32]:
# Map the averaged DSR predictions (together with the fine-tuning targets)
# back onto the original grid, then display the resulting shape.
DSR_reconstructed = construct_origin_shape(DSR_predictions)
DSR_reconstructed.shape

(76, 901, 231, 3)