In [None]:
import numpy as np
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split

In [None]:
def get_file_list(npy_dir, folder_name):
  """Return the sorted entry names found in ``npy_dir/folder_name``.

  Args:
    npy_dir: Base directory.
    folder_name: Sub-directory of ``npy_dir`` whose entries are listed.

  Returns:
    A list of entry names (not full paths) in ascending string order.
  """
  target_dir = os.path.join(npy_dir, folder_name)
  entries = os.listdir(target_dir)
  # Plain string sort: names are ordered lexicographically, not numerically.
  entries.sort()
  return entries

In [None]:
def join_files(velocity_files, pressure_files):
  """Pair each velocity file with the pressure file that shares its index.

  The index of a file name is the text after its last ``'_'`` up to the next
  ``'.'`` (e.g. ``'velocity_12.npy'`` -> ``'12'``). Velocity files without a
  matching pressure file are dropped; when several pressure files share an
  index, the first one in ``pressure_files`` wins.

  Args:
    velocity_files: Iterable of velocity file names.
    pressure_files: Iterable of pressure file names. It is traversed exactly
      once, so a one-shot iterator is handled correctly.

  Returns:
    A list of ``(velocity_file, pressure_file)`` tuples, in the order of
    ``velocity_files``.
  """
  def index_of(file_name):
    # Assuming the format '<prefix>_X.npy' -> 'X'
    return file_name.split('_')[-1].split('.')[0]

  # Build the lookup once instead of rescanning the pressure list for every
  # velocity file; setdefault keeps the first file seen for each index.
  first_pressure_by_index = {}
  for p_file in pressure_files:
    first_pressure_by_index.setdefault(index_of(p_file), p_file)

  joined_files = []
  for u_file in velocity_files:
    corresponding_p_file = first_pressure_by_index.get(index_of(u_file))
    if corresponding_p_file:
      joined_files.append((u_file, corresponding_p_file))

  return joined_files

In [None]:
def get_joined_files(velocity_dir, pressure_dir, folder_name):
  """List ``folder_name`` under both base directories and pair the files.

  Args:
    velocity_dir: Base directory holding the velocity data folders.
    pressure_dir: Base directory holding the pressure data folders.
    folder_name: Sub-folder listed under both base directories.

  Returns:
    The result of ``join_files`` applied to the two sorted listings, i.e. a
    list of ``(velocity_file, pressure_file)`` tuples.
  """
  # The velocity folder is listed first, then the pressure folder.
  velocity_names, pressure_names = (
      get_file_list(base_dir, folder_name)
      for base_dir in (velocity_dir, pressure_dir)
  )
  return join_files(velocity_names, pressure_names)

In [None]:
class FlowFieldDataset(Dataset):
  """Paired true/noisy flow-field samples loaded eagerly from ``.npy`` files.

  Every sample is read into memory at construction time. The i-th entry of
  ``files_true`` is paired with the i-th entry of ``files_noisy``, so both
  lists must have the same length.
  """

  def __init__(self, velocity_dir, pressure_dir, files_true, files_noisy):
    """Load all true and noisy samples.

    Args:
      velocity_dir: Base directory containing the ``true_data`` and
        ``noisy_data`` velocity folders.
      pressure_dir: Base directory containing the ``true_data`` and
        ``noisy_data`` pressure folders.
      files_true: Iterable of ``(velocity_file, pressure_file)`` name tuples
        read from the ``true_data`` folders.
      files_noisy: Iterable of ``(velocity_file, pressure_file)`` name tuples
        read from the ``noisy_data`` folders.

    Raises:
      ValueError: If ``files_true`` and ``files_noisy`` differ in length.
    """
    # Materialise so one-shot iterables can be both counted and iterated.
    files_true = list(files_true)
    files_noisy = list(files_noisy)

    # Samples are paired purely by position; a length mismatch would otherwise
    # surface as an IndexError in __getitem__ or as silently ignored data.
    if len(files_true) != len(files_noisy):
      raise ValueError(
          "files_true and files_noisy must have the same length, got "
          f"{len(files_true)} and {len(files_noisy)}"
      )

    self.velocity_dir = velocity_dir
    self.pressure_dir = pressure_dir
    self.data_true = [self.load_one_instance(velocity_dir, pressure_dir, f, 'true_data') for f in files_true]
    self.data_noisy = [self.load_one_instance(velocity_dir, pressure_dir, f, 'noisy_data') for f in files_noisy]

  def load_one_instance(self, velocity_dir, pressure_dir, file_tuple, data_type):
    """Load one (velocity, pressure) pair as float32 tensors.

    Args:
      velocity_dir: Base directory of the velocity data.
      pressure_dir: Base directory of the pressure data.
      file_tuple: ``(velocity_file, pressure_file)`` file names.
      data_type: Sub-folder to read from (``'true_data'`` or ``'noisy_data'``
        as used by ``__init__``).

    Returns:
      ``(u_tensor, p_tensor)`` with the same layout as the stored arrays.
    """
    u_file, p_file = file_tuple
    u_data = np.load(os.path.join(velocity_dir, data_type, u_file))
    p_data = np.load(os.path.join(pressure_dir, data_type, p_file))

    return torch.from_numpy(u_data).float(), torch.from_numpy(p_data).float()

  def __len__(self):
    """Return the number of paired samples."""
    return len(self.data_true)

  def __getitem__(self, idx):
    """Return ``((u_true, p_true), (u_noisy, p_noisy))`` for sample ``idx``.

    Velocity tensors have their last axis moved to the front and pressure
    tensors gain a leading axis of size 1.
    """
    u_hr_tensor, p_hr_tensor = self.data_true[idx]
    u_lr_tensor, p_lr_tensor = self.data_noisy[idx]

    # Reorder the dimensions: [height, width, channels] to [channels, height, width]
    # NOTE(review): assumes velocity arrays are stored as 3-D (H, W, C) and
    # pressure arrays as (H, W) - confirm against the data files.
    u_hr_tensor = u_hr_tensor.permute(2, 0, 1)
    p_hr_tensor = p_hr_tensor.unsqueeze(0)
    u_lr_tensor = u_lr_tensor.permute(2, 0, 1)
    p_lr_tensor = p_lr_tensor.unsqueeze(0)

    return (u_hr_tensor, p_hr_tensor), (u_lr_tensor, p_lr_tensor)

In [None]:
def train_test_split(dataset, test_size=0.2, generator=None):
  """Randomly partition ``dataset`` into train and test subsets.

  Args:
    dataset: Any object accepted by ``torch.utils.data.random_split``.
    test_size: Fraction of samples, within ``[0, 1]``, assigned to the test
      subset. The test count is truncated towards zero; the remainder goes to
      the train subset.
    generator: Optional ``torch.Generator`` used for the shuffle. Pass a
      seeded generator to make the split reproducible. When ``None``, the
      global torch RNG is used.

  Returns:
    ``(train_dataset, test_dataset)``.

  Raises:
    ValueError: If ``test_size`` is outside ``[0, 1]``.
  """
  # Out-of-range fractions would yield negative or oversized split lengths.
  if not 0 <= test_size <= 1:
    raise ValueError(f"test_size must be within [0, 1], got {test_size!r}")

  total_samples = len(dataset)
  test_sample_size = int(test_size * total_samples)
  train_sample_size = total_samples - test_sample_size
  lengths = [train_sample_size, test_sample_size]

  if generator is None:
    # Leave random_split's own default generator untouched.
    train_dataset, test_dataset = random_split(dataset, lengths)
  else:
    train_dataset, test_dataset = random_split(dataset, lengths, generator=generator)

  return train_dataset, test_dataset