In [1]:
import os
import numpy as np
import glob
import vtk
import vtk.numpy_interface.dataset_adapter as dsa
wdo = dsa.WrapDataObject
from tqdm import tqdm
import matplotlib.pyplot as plt

from vtk.util.numpy_support import vtk_to_numpy

In [2]:
import torch
from torch.utils.data import TensorDataset, Dataset, DataLoader, random_split, SubsetRandomSampler
from sklearn.model_selection import train_test_split
from scipy.spatial import cKDTree
from scipy.spatial import KDTree
from scipy.ndimage import zoom  # For resampling
import math 
from pathos.multiprocessing import ProcessingPool as Pool

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.linalg import vector_norm

In [4]:
from sklearn.preprocessing import StandardScaler

The pre-training dataset consists of computational fluid dynamics (CFD) simulation data. For each patient and each observation time, the data contain n points, each described by two components:

- (i) velocity vectors of shape (n,3);
- (ii) corresponding spatial coordinates of shape (n,3).

The following code shows how to sample 100 points from a single patient and turn the neighbourhood of each sampled point into a 16×16×16 cubic patch with 3 velocity channels; the resulting super-resolution training tensors have shape (N, 3, 16, 16, 16).

## Step 1: Generate training and validation data

In [7]:
def process_point(args):
    """Build a fixed-size voxel patch of velocities around one sampled point.

    The neighbourhood of the sampled point (points of ``all_data`` within
    radius ``epsilon``) is split into a regular ``2**t`` per-axis grid over its
    bounding box. Velocities are interpolated onto the cell centres by
    inverse-distance weighting. Both the velocity grid and the grid of centre
    coordinates are then resampled to ``fixed_grid_size`` with
    nearest-neighbour zoom.

    Args:
        args: tuple ``(idx, sample_all_data, sample_all_data_point, all_data,
            all_data_point, epsilon, tree, fixed_grid_size)``.
            ``idx`` indexes ``sample_all_data``. ``tree`` is a KD-tree built on
            ``all_data``. ``sample_all_data_point`` is accepted for interface
            compatibility but is not used.

    Returns:
        ``(velocity_patch, center_patch, t, idx)``. Both patches are float32
        tensors with 3 trailing components, resampled to ``fixed_grid_size``.
    """
    (idx, sample_all_data, _unused_sample_values, all_data, all_data_point,
     epsilon, tree, fixed_grid_size) = args
    point = sample_all_data[idx]

    # query_ball_point returns indices into ``all_data`` (the data the tree was
    # built on), while ``idx`` indexes ``sample_all_data``. The two index spaces
    # must not be mixed: appending ``idx`` would drag the unrelated point
    # ``all_data[idx]`` into the neighbourhood and stretch the bounding box.
    # A sample that is itself a row of ``all_data`` is already in the ball
    # (distance 0). Otherwise, if the ball is empty, fall back to the nearest
    # data point so the neighbourhood is never empty.
    indices = list(tree.query_ball_point(point, r=epsilon))
    if not indices:
        _, nearest = tree.query(point)
        indices = [int(nearest)]

    selected_points = all_data[indices]
    selected_values = all_data_point[indices]
    num_neighbors = len(indices)

    # Grid resolution grows with neighbourhood size: roughly one cell per
    # neighbour, and never fewer than 2 cells per axis.
    t = max(1, int(math.log2(num_neighbors) // 3))
    num_partitions = 2 ** t

    # Regular grid over the neighbourhood's bounding box; keep the cell centres.
    lower = selected_points.min(axis=0)
    upper = selected_points.max(axis=0)
    axis_centers = []
    for lo, hi in zip(lower, upper):
        edges = np.linspace(lo, hi, num_partitions + 1)
        axis_centers.append((edges[:-1] + edges[1:]) / 2)
    Xc, Yc, Zc = np.meshgrid(*axis_centers, indexing='ij')
    centers = np.column_stack((Xc.ravel(), Yc.ravel(), Zc.ravel()))

    # Inverse-distance-weighted velocity at each centre. A data point lying
    # exactly on a centre is copied directly, which avoids dividing by zero.
    channel_values = np.zeros((centers.shape[0], 3))
    for i, center in enumerate(centers):
        distances = np.linalg.norm(selected_points - center, axis=1)
        exact = np.flatnonzero(distances == 0)
        if exact.size:
            channel_values[i] = selected_values[exact[0], :3]
        else:
            weights = 1 / distances
            weights /= weights.sum()
            channel_values[i] = weights @ selected_values[:, :3]

    tensor_shape = (num_partitions, num_partitions, num_partitions, 3)
    velocity_grid = channel_values.reshape(tensor_shape)
    center_grid = centers.reshape(tensor_shape)

    # Both grids share one shape, so one set of zoom factors serves both.
    # order=0 is nearest-neighbour resampling; the trailing 1 leaves the
    # 3-component axis untouched.
    zoom_factors = [
        fixed / float(orig) for fixed, orig in zip(fixed_grid_size, tensor_shape[:3])
    ] + [1]
    tensor_resized = torch.from_numpy(zoom(velocity_grid, zoom_factors, order=0)).float()
    center_tensor_resized = torch.from_numpy(zoom(center_grid, zoom_factors, order=0)).float()

    return (tensor_resized, center_tensor_resized, t, idx)

In [9]:
def process_all_points_parallel(
    sample_all_data, sample_all_data_point, all_data, all_data_point, epsilon, batch_size=32, save_dir='tensor_batches', fixed_grid_size=(32, 32, 32)
):
    """Turn every sampled point into a voxel patch and write the patches to disk in chunks.

    Despite the name, points are handled one after another in this process,
    each by ``process_point``. Results are written with ``torch.save`` as lists
    of ``(tensor, center_tensor, t, idx)`` tuples to ``<save_dir>/batch_<k>.pt``
    for k = 0, 1, .... Every file holds ``batch_size`` items except possibly
    the last one.

    Returns:
        int: the number of batch files written.
    """
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)

    n_points = sample_all_data.shape[0]
    tree = cKDTree(all_data)  # built once and shared by every per-point query

    def write_chunk(chunk, chunk_no):
        torch.save(chunk, os.path.join(save_dir, f'batch_{chunk_no}.pt'))

    pending = []
    files_written = 0

    for point_no in tqdm(range(n_points), total=n_points, desc='Processing'):
        job = (point_no, sample_all_data, sample_all_data_point, all_data, all_data_point, epsilon, tree, fixed_grid_size)
        outcome = process_point(job)
        if outcome is None:
            continue
        velocity, centers, level, point_idx = outcome
        pending.append((velocity, centers, level, point_idx))
        if len(pending) >= batch_size:
            write_chunk(pending, files_written)
            pending = []
            files_written += 1

    # Flush whatever is left as a final, shorter file.
    if pending:
        write_chunk(pending, files_written)
        files_written += 1

    return files_written

In [11]:
# Define a custom Dataset to load the saved batches
class train_VoxelDataset(Dataset):
    """Map-style dataset over patch batches saved with ``torch.save``.

    Every ``*.pt`` file in ``save_dir`` is expected to hold a list of
    ``(tensor, center_tensor, t, idx)`` tuples. Dataset item ``k`` is one such
    tuple, returned as ``(tensor, center_tensor, t)``.
    """

    def __init__(self, save_dir='tensor_batches/train'):
        self.save_dir = save_dir
        # Numeric order (batch_2 before batch_10) so items follow the order in
        # which the batch files were written; a plain string sort would not.
        self.batch_files = sorted(
            (os.path.join(save_dir, f) for f in os.listdir(save_dir) if f.endswith('.pt')),
            key=self._batch_sort_key,
        )
        self.index_map = []
        # Single-entry cache: consecutive items usually live in the same file,
        # so avoid re-reading the whole batch from disk for every item.
        self._cached_file = None
        self._cached_batch = None
        self._create_index_map()

    @staticmethod
    def _batch_sort_key(path):
        """Sort key ordering ``<prefix>_<n>.pt`` files by prefix, then by integer n."""
        stem = os.path.splitext(os.path.basename(path))[0]
        prefix, _, suffix = stem.rpartition('_')
        if suffix.isdigit():
            return (prefix, int(suffix))
        return (stem, -1)

    def _load_batch(self, batch_file):
        """Return the list stored in ``batch_file``, reusing the most recently loaded one."""
        if batch_file != self._cached_file:
            self._cached_batch = torch.load(batch_file)
            self._cached_file = batch_file
        return self._cached_batch

    def _create_index_map(self):
        # One (file, position) entry per stored tuple, in file order.
        for batch_file in self.batch_files:
            batch = self._load_batch(batch_file)
            self.index_map.extend((batch_file, i) for i in range(len(batch)))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        batch_file, tensor_idx = self.index_map[idx]
        tensor, center_tensor, t, _point_idx = self._load_batch(batch_file)[tensor_idx]
        return tensor, center_tensor, t  # Return both velocity, location and t

In [13]:
# Define a custom Dataset to load the saved batches
class valid_VoxelDataset(Dataset):
    """Map-style dataset over validation patch batches saved with ``torch.save``.

    Every ``*.pt`` file in ``save_dir`` is expected to hold a list of
    ``(tensor, center_tensor, t, idx)`` tuples. Dataset item ``k`` is one such
    tuple, returned as ``(tensor, center_tensor, t)``.
    """

    def __init__(self, save_dir='tensor_batches/valid'):
        self.save_dir = save_dir
        # Numeric order (batch_2 before batch_10) so items follow the order in
        # which the batch files were written; a plain string sort would not.
        self.batch_files = sorted(
            (os.path.join(save_dir, f) for f in os.listdir(save_dir) if f.endswith('.pt')),
            key=self._batch_sort_key,
        )
        self.index_map = []
        # Single-entry cache: consecutive items usually live in the same file,
        # so avoid re-reading the whole batch from disk for every item.
        self._cached_file = None
        self._cached_batch = None
        self._create_index_map()

    @staticmethod
    def _batch_sort_key(path):
        """Sort key ordering ``<prefix>_<n>.pt`` files by prefix, then by integer n."""
        stem = os.path.splitext(os.path.basename(path))[0]
        prefix, _, suffix = stem.rpartition('_')
        if suffix.isdigit():
            return (prefix, int(suffix))
        return (stem, -1)

    def _load_batch(self, batch_file):
        """Return the list stored in ``batch_file``, reusing the most recently loaded one."""
        if batch_file != self._cached_file:
            self._cached_batch = torch.load(batch_file)
            self._cached_file = batch_file
        return self._cached_batch

    def _create_index_map(self):
        # One (file, position) entry per stored tuple, in file order.
        for batch_file in self.batch_files:
            batch = self._load_batch(batch_file)
            self.index_map.extend((batch_file, i) for i in range(len(batch)))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, idx):
        batch_file, tensor_idx = self.index_map[idx]
        tensor, center_tensor, t, _point_idx = self._load_batch(batch_file)[tensor_idx]
        return tensor, center_tensor, t  # Return both tensor and t

In [15]:
## extract data from vtk file

def vtk_to_numpy_array(vtk_array):
    """Convert a VTK data array to a NumPy array via ``vtk.util.numpy_support.vtk_to_numpy``."""
    converted = vtk_to_numpy(vtk_array)
    return converted

In [17]:
def get_true_function(func_name):
    """
    Look up the "true" transformation g by name.

    Known names, tried in this order:
      "softplus" -> softplus(x)
      "square"   -> ReLU(x)**2 / 20
      "log"      -> x/3 + log(3) - 2/3 for x <= 2, log(1 + x) for x > 2
      "cubic"    -> x**3 / 30
      "None"     -> identity
    Any other name also yields the identity. Each callable maps a tensor to a
    tensor of the same shape.
    """
    def softplus(x):
        return F.softplus(x)

    def square(x):
        return F.relu(x).pow(2) / 20

    def log_piecewise(x):
        below = x <= 2
        above = x > 2
        linear_part = (x / 3 + np.log(3) - 2 / 3) * below
        # Masking the argument keeps log() at log(1) = 0 wherever x <= 2.
        log_part = torch.log(1 + x * above) * above
        return linear_part + log_part

    def cubic(x):
        return x.pow(3) / 30

    def identity(x):
        return x

    registry = {
        "softplus": softplus,
        "square": square,
        "log": log_piecewise,
        "cubic": cubic,
        "None": identity,
    }
    for name, fn in registry.items():
        if func_name == name:
            return fn
    return identity

class SuperResolutionDataset(Dataset):
    """Pairs of (degraded input, less-degraded target) volumes for iterative super-resolution.

    For each image, ``generate_downsampled_images`` builds a ladder
    ``X_0, ..., X_T``. ``X_T`` is the original volume. Each lower level is
    obtained by halving the previous (un-upsampled) level, blurring it, and
    upsampling back to the original size. Item ``k`` is
    ``(g(X_{t-1} + noise), X_t, loc, t)`` for some ``t`` in ``1..T``.
    """

    def __init__(self, images, locs, T, sigma_t_list, true_function='None'):
        """
        images: List of Tensors, each Tensor is shape (C, D, H, W).
        locs: List of coordinate Tensors, one per image; returned unchanged with
            every pair built from that image.
        T: Number of downsampling steps. Each step halves D, H and W (floor),
            so the spatial sizes should be at least 2**T.
        sigma_t_list: list of length T; entry t-1 is the noise *variance* added
            to X_{t-1} (the noise standard deviation is its square root).
        true_function: string or callable. If string, we map it via get_true_function.
        """
        super().__init__()
        self.images = images
        self.T = T
        self.locs = locs
        self.sigma_t_list = sigma_t_list

        # Convert the true_function string to an actual callable
        if isinstance(true_function, str):
            self.g = get_true_function(true_function)
        else:
            self.g = true_function  # assume user passed a callable directly

        self.data_pairs = []
        self.prepare_data()

    def gaussian_kernel_1d(self, kernel_size, sigma):
        """Normalised 1-D Gaussian of length ``kernel_size`` centred on its middle tap."""
        x = torch.arange(kernel_size) - kernel_size // 2
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
        return kernel / kernel.sum()

    def gaussian_blur_3d(self, x, kernel_size=5, sigma=1):
        """Per-channel 3-D Gaussian blur of a (C, D, H, W) tensor.

        The kernel is shrunk to the largest odd size not exceeding any spatial
        dimension. Presumably this keeps the reflect padding smaller than each
        dimension. A size-1 kernel leaves the input unchanged.
        """
        device = x.device
        x = x.unsqueeze(0)  # add batch dimension -> (1, C, D, H, W)
        N, C, D, H, W = x.shape

        # Adjust kernel_size if needed
        max_kernel_size = min(kernel_size, D, H, W)
        if max_kernel_size % 2 == 0:
            max_kernel_size -= 1  # ensure odd
        if max_kernel_size < 1:
            x_blur = x
        else:
            # Match the input's dtype as well as its device. conv3d rejects mixed
            # dtypes, so a float64 volume would fail against the float32 kernel.
            kernel = self.gaussian_kernel_1d(max_kernel_size, sigma).to(device=device, dtype=x.dtype)
            kernel_3d = kernel[:, None, None] * kernel[None, :, None] * kernel[None, None, :]
            kernel_3d = kernel_3d / kernel_3d.sum()
            kernel_3d = kernel_3d.view(1, 1, max_kernel_size, max_kernel_size, max_kernel_size)
            # One identical filter per channel, applied independently via groups=C.
            kernel_3d = kernel_3d.repeat(C, 1, 1, 1, 1)
            padding = max_kernel_size // 2
            x_padded = F.pad(x, (padding,) * 6, mode='reflect')
            x_blur = F.conv3d(x_padded, kernel_3d, groups=C)
        return x_blur.squeeze(0)

    def interpolate_3d(self, tensor, **kwargs):
        """Apply ``F.interpolate`` to a single (C, D, H, W) volume."""
        tensor = tensor.unsqueeze(0)
        tensor_interp = F.interpolate(tensor, **kwargs)
        return tensor_interp.squeeze(0)

    def generate_downsampled_images(self, image, loc):
        """
        image: (C, D, H, W)
        Returns ([X_0, ..., X_T], [loc] * (T + 1)). X_T is ``image`` itself and
        X_0 is the most degraded level. Every level has the original spatial
        size.
        """
        original_size = image.shape[1:]  # (D, H, W); every level is brought back to this
        X_t_list = [image]
        X_t = image  # start from the highest resolution
        for _ in range(self.T):
            # 1. Downsample to half size
            X_t_down = self.interpolate_3d(X_t, scale_factor=0.5,
                                           mode='trilinear',
                                           align_corners=False,
                                           recompute_scale_factor=True)
            # 2. Blur
            X_t_blur = self.gaussian_blur_3d(X_t_down)
            # 3. Upsample back to original size and prepend (lower levels first)
            X_t_list.insert(0, self.interpolate_3d(X_t_blur, size=original_size,
                                                   mode='trilinear', align_corners=False))
            # The next level degrades the half-size volume, not the upsampled one.
            X_t = X_t_down
        # Coordinates are not degraded: every level reuses the original grid.
        X_loc_list = [loc] * (self.T + 1)
        return X_t_list, X_loc_list

    def prepare_data(self):
        """
        Creates pairs (g(X_{t-1} + epsilon_{t-1}), X_t, loc, t) for t in [1..T],
        where epsilon_{t-1} is elementwise Gaussian noise with variance
        sigma_t_list[t - 1].
        """
        for i in range(len(self.images)):
            X_t_list, X_loc_list = self.generate_downsampled_images(self.images[i], self.locs[i])
            for t in range(1, self.T + 1):
                X_t = X_t_list[t]
                X_loc = X_loc_list[t]
                X_t_minus_1 = X_t_list[t - 1]
                sigma_t_minus_1 = self.sigma_t_list[t - 1]
                # Noise: std = sqrt(variance)
                epsilon_t_minus_1 = torch.randn_like(X_t_minus_1) * (sigma_t_minus_1 ** 0.5)
                X_t_minus_1_noisy = X_t_minus_1 + epsilon_t_minus_1
                # APPLY the user-specified g(...) to the noised input
                X_t_pred = self.g(X_t_minus_1_noisy)
                self.data_pairs.append((X_t_pred, X_t, X_loc, t))

    def __len__(self):
        return len(self.data_pairs)

    def __getitem__(self, idx):
        X_input, X_target, X_coords, t = self.data_pairs[idx]
        return X_input, X_target, X_coords, t

In [19]:
# Custom collate function
def custom_collate_fn(batch):
    """Collate ``(X_input, X_target, X_coords, t)`` samples into batched tensors.

    X_input, X_target and X_coords are ``(C, D, H, W)`` tensors. Within each of
    the three groups, tensors are zero-padded at the high end of D, H and W up
    to the largest size in that group, then stacked along a new batch
    dimension. Channel counts are not padded, so they must agree within a group.

    Returns:
        (X_inputs_batch, X_targets_batch, X_coords_batch, ts_batch); the first
        three are ``(B, C, D_max, H_max, W_max)`` and ``ts_batch`` is
        ``torch.tensor(ts)``.
    """
    X_inputs, X_targets, X_coords, ts = zip(*batch)

    def _pad_and_stack(tensors):
        # Use the maximum over this group only. Measuring targets/coords against
        # the inputs' sizes would give negative padding and silently crop them.
        max_D = max(x.shape[1] for x in tensors)
        max_H = max(x.shape[2] for x in tensors)
        max_W = max(x.shape[3] for x in tensors)
        padded = [
            F.pad(
                x,
                (0, max_W - x.shape[3],   # width
                 0, max_H - x.shape[2],   # height
                 0, max_D - x.shape[1]),  # depth
                mode='constant',
                value=0,
            )
            for x in tensors
        ]
        return torch.stack(padded)

    X_inputs_batch = _pad_and_stack(X_inputs)
    X_targets_batch = _pad_and_stack(X_targets)
    X_coords_batch = _pad_and_stack(X_coords)
    ts_batch = torch.tensor(ts)
    return X_inputs_batch, X_targets_batch, X_coords_batch, ts_batch

In [21]:
# Patient IDs; each ID is turned into the file path '/Data/<id>_00.vtm' when loading.
train_patient = ["ANY-011-001"]
valid_patient = ["ANY-035-001"]

In [23]:
epsilon = 1e-3  # neighbourhood radius passed to the KD-tree ball query
batch_size = 64  # items per saved .pt batch file and per DataLoader batch
fixed_grid_size = (16,) * 3  # every patch is resampled to this (D, H, W)

In [37]:
train_input_list = []
train_target_list = []
train_loc_list = []
train_t_list = []

for i in range(len(train_patient)):
    filename = train_patient[i] + '_00.vtm'
    path = '/Data/' + filename

    reader = vtk.vtkXMLMultiBlockDataReader()
    reader.SetFileName(path)  ## change here for other data sets
    reader.Update()

    # Separate comprehension variable so it is not confused with the patient index `i`.
    multiblock = reader.GetOutput()
    blocks = [multiblock.GetBlock(b) for b in range(multiblock.GetNumberOfBlocks())]

    # Collect coordinates and 'velocity' point data from every unstructured-grid block
    points_list = []
    point_data_list = []

    for block in blocks:
        if isinstance(block, vtk.vtkUnstructuredGrid):
            points = vtk_to_numpy_array(block.GetPoints().GetData())
            point_data = vtk_to_numpy_array(block.GetPointData().GetArray('velocity'))

            points_list.append(points)
            point_data_list.append(point_data)

    # Convert lists to single numpy arrays
    all_points = np.vstack(points_list)
    all_point_data = np.vstack(point_data_list)

    # Keep only points whose velocity is not identically zero
    mask = np.any(all_point_data != 0, axis=1)
    filtered_all_point_data = all_point_data[mask]
    filtered_all_points = all_points[mask]

    # Fix a random seed for reproducibility
    np.random.seed(42)
    # Number of samples to subsample
    n_samples = 100

    # Generate random indices (without replacement)
    indices = np.random.choice(filtered_all_points.shape[0], size=n_samples, replace=False)

    # Subsample data
    all_points_subsampled = filtered_all_points[indices]
    all_point_data_subsampled = filtered_all_point_data[indices]

    # Build one voxel patch per sampled point (sequentially) and save them in batch files
    save_dir = 'tensor_batches/train/' + train_patient[i]

    train_num_batches = process_all_points_parallel(
        all_points_subsampled, all_point_data_subsampled, filtered_all_points, filtered_all_point_data, epsilon, batch_size=batch_size, save_dir=save_dir, fixed_grid_size=fixed_grid_size
    )
    print(f"Saved {train_num_batches} valid batches of tensors.")

    # Sanity-check the saved patches through the Dataset/DataLoader path
    train_dataset = train_VoxelDataset(save_dir=save_dir)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False)

    for batch_tensors, batch_center, batch_t_values in train_dataloader:
        print(f"Batch tensors shape: {batch_tensors.shape}")
        print(f"Batch centers shape: {batch_center.shape}")
        print(f"Unique batch t values: {batch_t_values.unique()}")

    # Load exactly the files written above, in write order. Listing the
    # directory would also pick up stale batch_*.pt files from earlier runs,
    # and a string sort puts batch_10 before batch_2.
    batch_files = [os.path.join(save_dir, f'batch_{k}.pt') for k in range(train_num_batches)]

    images = []
    locs = []

    print("Loading images from processed tensors...")
    for batch_file in tqdm(batch_files, desc="Loading batches"):
        batch = torch.load(batch_file)  # Each batch is a list of (tensor, center, t, idx)
        for tensor, center, t, idx in batch:
            # Stored as (D, H, W, C); the super-resolution dataset expects (C, D, H, W)
            images.append(tensor.permute(3, 0, 1, 2))
            locs.append(center.permute(3, 0, 1, 2))

    T = 4  # Maximum resolution level
    sigma_t_list = [0.1] * T  # Prespecified \sigma_t^2 for each t from 0 to T-1

    train_dataset = SuperResolutionDataset(images, locs, T, sigma_t_list)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    for X_input, X_target, X_coords, t in train_dataloader:
        train_input_list.append(X_input)
        train_target_list.append(X_target)
        train_loc_list.append(X_coords)
        train_t_list.append(t)

Processing: 100%|████████████████████████████████████████████████████████████████████| 100/100 [00:03<00:00, 30.32it/s]
  batch = torch.load(batch_file)
  batch = torch.load(batch_file)


Saved 2 valid batches of tensors.
Batch tensors shape: torch.Size([64, 16, 16, 16, 3])
Batch centers shape: torch.Size([64, 16, 16, 16, 3])
Unique batch t values: tensor([3, 4])
Batch tensors shape: torch.Size([36, 16, 16, 16, 3])
Batch centers shape: torch.Size([36, 16, 16, 16, 3])
Unique batch t values: tensor([2, 3, 4])
Loading images from processed tensors...


  batch = torch.load(batch_file)  # Each batch is a list of (tensor, t, idx)
Loading batches: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]


In [39]:
valid_input_list = []
valid_target_list = []
valid_loc_list = []

for i in range(len(valid_patient)):
    filename = valid_patient[i] + '_00.vtm'
    path = '/Data/' + filename

    reader = vtk.vtkXMLMultiBlockDataReader()
    reader.SetFileName(path)  ## change here for other data sets
    reader.Update()

    # Separate comprehension variable so it is not confused with the patient index `i`.
    multiblock = reader.GetOutput()
    blocks = [multiblock.GetBlock(b) for b in range(multiblock.GetNumberOfBlocks())]

    # Collect coordinates and 'velocity' point data from every unstructured-grid block
    points_list = []
    point_data_list = []

    for block in blocks:
        if isinstance(block, vtk.vtkUnstructuredGrid):
            points = vtk_to_numpy_array(block.GetPoints().GetData())
            point_data = vtk_to_numpy_array(block.GetPointData().GetArray('velocity'))

            points_list.append(points)
            point_data_list.append(point_data)

    # Convert lists to single numpy arrays
    all_points = np.vstack(points_list)
    all_point_data = np.vstack(point_data_list)

    # Keep only points whose velocity is not identically zero
    mask = np.any(all_point_data != 0, axis=1)
    filtered_all_point_data = all_point_data[mask]
    filtered_all_points = all_points[mask]

    # Fix a random seed for reproducibility
    np.random.seed(42)
    # Number of samples to subsample
    n_samples = 100

    # Generate random indices (without replacement)
    indices = np.random.choice(filtered_all_points.shape[0], size=n_samples, replace=False)

    # Subsample data
    all_points_subsampled = filtered_all_points[indices]
    all_point_data_subsampled = filtered_all_point_data[indices]

    # Build one voxel patch per sampled point (sequentially) and save them in batch files
    save_dir = 'tensor_batches/valid/' + valid_patient[i]

    valid_num_batches = process_all_points_parallel(
        all_points_subsampled, all_point_data_subsampled, filtered_all_points, filtered_all_point_data, epsilon, batch_size=batch_size, save_dir=save_dir, fixed_grid_size=fixed_grid_size
    )
    print(f"Saved {valid_num_batches} valid batches of tensors.")

    # Sanity-check the saved patches through the Dataset/DataLoader path
    valid_dataset = valid_VoxelDataset(save_dir=save_dir)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)

    for batch_tensors, batch_center, batch_t_values in valid_dataloader:
        print(f"Batch tensors shape: {batch_tensors.shape}")
        print(f"Unique batch t values: {batch_t_values.unique()}")

    # Load exactly the files written above, in write order. Listing the
    # directory would also pick up stale batch_*.pt files from earlier runs,
    # and a string sort puts batch_10 before batch_2.
    batch_files = [os.path.join(save_dir, f'batch_{k}.pt') for k in range(valid_num_batches)]

    images = []
    locs = []

    print("Loading images from processed tensors...")
    for batch_file in tqdm(batch_files, desc="Loading batches"):
        batch = torch.load(batch_file)  # Each batch is a list of (tensor, center, t, idx)
        for tensor, center, t, idx in batch:
            # Stored as (D, H, W, C); the super-resolution dataset expects (C, D, H, W)
            images.append(tensor.permute(3, 0, 1, 2))
            locs.append(center.permute(3, 0, 1, 2))

    T = 4  # Maximum resolution level
    sigma_t_list = [0] * T  # No noise for validation; \sigma_t^2 for each t from 0 to T-1

    valid_dataset = SuperResolutionDataset(images, locs, T, sigma_t_list)

    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    for X_input, X_target, X_coords, t in valid_dataloader:
        valid_input_list.append(X_input)
        valid_target_list.append(X_target)
        valid_loc_list.append(X_coords)

Processing: 100%|████████████████████████████████████████████████████████████████████| 100/100 [00:06<00:00, 16.53it/s]
  batch = torch.load(batch_file)
  batch = torch.load(batch_file)


Saved 2 valid batches of tensors.
Batch tensors shape: torch.Size([64, 16, 16, 16, 3])
Unique batch t values: tensor([2, 3, 4])
Batch tensors shape: torch.Size([36, 16, 16, 16, 3])
Unique batch t values: tensor([2, 3, 4])
Loading images from processed tensors...


  batch = torch.load(batch_file)  # Each batch is a list of (tensor, t, idx)
Loading batches: 100%|███████████████████████████████████████████████████████████████████████████| 2/2 [00:00<?, ?it/s]


In [41]:
# Stack the per-DataLoader-batch chunks into single tensors along dim 0
valid_input, valid_target, valid_loc = (
    torch.cat(chunks, dim=0)
    for chunks in (valid_input_list, valid_target_list, valid_loc_list)
)

In [43]:
valid_input.size()

torch.Size([400, 3, 16, 16, 16])

In [45]:
# Stack the per-DataLoader-batch chunks into single tensors along dim 0
train_input, train_target, train_loc = (
    torch.cat(chunks, dim=0)
    for chunks in (train_input_list, train_target_list, train_loc_list)
)

In [47]:
train_input.size()

torch.Size([400, 3, 16, 16, 16])

In [49]:
torch.save(obj=train_input, f="/Data/train_input_100.pt")

In [51]:
torch.save(obj=train_target, f="/Data/train_target_100.pt")

In [53]:
torch.save(obj=valid_input, f="/Data/valid_input_100.pt")

In [55]:
torch.save(obj=valid_target, f="/Data/valid_target_100.pt")