## The Well rSVD Preprocessing

[link to official documentation](https://polymathic-ai.org/the_well/tutorials/dataset/)  
[link to original paper](https://proceedings.neurips.cc/paper_files/paper/2024/file/4f9a5acd91ac76569f2fe291b1f4772b-Paper-Datasets_and_Benchmarks_Track.pdf)

## Imports

In [1]:

import sys
import pprint as pp
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import torch
from einops import rearrange
from tqdm import tqdm

from the_well.benchmark.metrics import VRMSE
from the_well.data import WellDataset
from the_well.utils.download import well_download
from the_well.data import WellDataset
from torch.utils.data import DataLoader

# Add the project root directory to Python path
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root.absolute()))

from src.helpers import *

debug = False
n_iters = 5
dataset = 'planetswe'
n_steps = {'planetswe': 1008, 'active_matter': 81}
n_rank = {'active_matter': 50, 'planetswe': 75}

save_dir = project_root / 'datasets' / 'the_well_custom' / dataset
save_dir.mkdir(exist_ok=True, parents=True)

%load_ext autoreload
%autoreload 2
%matplotlib inline


## Load Data

In [2]:

# Load data from online, when using in practice we'll have to download the dataset
train_data = WellDataset(
    well_base_path=Path('/data') / 'alexey' / 'the_well',
    well_dataset_name=dataset,
    well_split_name="train",
    n_steps_input=n_steps[dataset],
    n_steps_output=0,
    use_normalization=False,
)

valid_data = WellDataset(
    well_base_path=Path('/data') / 'alexey' / 'the_well',
    well_dataset_name=dataset,
    well_split_name="valid",
    n_steps_input=n_steps[dataset],
    n_steps_output=0,
    use_normalization=False,
)

test_data = WellDataset(
    well_base_path=Path('/data') / 'alexey' / 'the_well',
    well_dataset_name=dataset,
    well_split_name="test",
    n_steps_input=n_steps[dataset],
    n_steps_output=0,
    use_normalization=False,
)

im_shape = train_data[0]['input_fields'].shape
im_rows, im_cols, im_dim = im_shape[1], im_shape[2], im_shape[3]
print(im_shape)



torch.Size([1008, 256, 512, 3])


## Concatenate

In [None]:

combine_all = False
train_mats_raw = create_mats(train_data, combine_all=combine_all, debug=debug)
valid_mats_raw = create_mats(valid_data, combine_all=combine_all, debug=debug)
test_mats_raw = create_mats(test_data, combine_all=combine_all, debug=debug)

print(len(train_mats_raw))
print(train_mats_raw[0].shape)


96
torch.Size([1008, 393216])


: 

In [None]:
full_mat = torch.cat(train_mats_raw + valid_mats_raw + test_mats_raw)

print(full_mat.shape)

## POD

In [5]:

U_full, S_full, V_full = generate_SVD(full_mat, n_rank=n_rank[dataset], n_iters=n_iters)


In [6]:

full_pod = create_pod(full_mat, V_full)

print(full_pod.shape)


torch.Size([3024, 75])


In [7]:

full_pod_scaled, full_scaler = scale_pod(full_pod)


## Inverse POD

In [8]:

full_mat_hat = inverse_pod(full_pod_scaled, full_scaler, V_full)


## Print Errors

In [None]:

print_errors([full_mat], [full_mat_hat], mean_relative_error, "Training POD errors:")


Training POD errors:
Error for i=0 is 3.37%



## Separate into Training, Validation, and Testing Data

In [None]:

total_tracks = int(full_mat.shape[0] / n_steps[dataset])

train_num = int(0.8*n_steps[dataset])
valid_num = int(0.1*n_steps[dataset])
test_num = int(0.1*n_steps[dataset])

print("train_num", train_num)
print("valid_num", valid_num)
print("test_num", test_num)

train_save = []
valid_save = []
test_save = []

train_pods_save = []
valid_pods_save = []
test_pods_save = []

for track_num in range(total_tracks):
    track_pod_scaled = full_pod_scaled[track_num*n_steps[dataset]:(track_num+1)*n_steps[dataset],:]
    track = full_mat[track_num*n_steps[dataset]:(track_num+1)*n_steps[dataset],:]

    train_pod_scaled = track_pod_scaled[0:train_num]
    valid_pod_scaled = track_pod_scaled[train_num:train_num+valid_num]
    test_pod_scaled = track_pod_scaled[train_num+valid_num:]

    train = track[0:train_num]
    val = track[train_num:train_num+valid_num]
    test = track[train_num+valid_num:]

    train_pods_save.append(train_pod_scaled)
    valid_pods_save.append(valid_pod_scaled)
    test_pods_save.append(test_pod_scaled)

    train_save.append(train)
    valid_save.append(val)
    test_save.append(test)



train_num 806
valid_num 100
test_num 100


## Calculate Per-Track and Per-Split Errors

In [11]:

train_save_hat = inverse_pods(train_pods_save, full_scaler, V_full)
print_errors(train_save, train_save_hat, mean_relative_error, "Train POD errors:")

valid_save_hat = inverse_pods(valid_pods_save, full_scaler, V_full)
print_errors(valid_save, valid_save_hat, mean_relative_error, "Validation POD errors:")

test_save_hat = inverse_pods(test_pods_save, full_scaler, V_full)
print_errors(test_save, test_save_hat, mean_relative_error, "Testing POD errors:")



Train POD errors:
Error for i=0 is 3.30%
Error for i=1 is 3.49%
Error for i=2 is 3.17%

Validation POD errors:
Error for i=0 is 3.16%
Error for i=1 is 4.23%
Error for i=2 is 2.97%

Testing POD errors:
Error for i=0 is 3.37%
Error for i=1 is 4.83%
Error for i=2 is 3.01%



## Save Results

In [12]:

# Create directories
(save_dir / 'metadata').mkdir(exist_ok=True)
(save_dir / 'full').mkdir(exist_ok=True)
(save_dir / 'pod').mkdir(exist_ok=True)

# Save scaler, V_full, and image metadata
torch.save(V_full, save_dir / 'metadata' / 'V.pt')
torch.save(full_scaler, save_dir / 'metadata' / 'scaler.pt')
torch.save((im_rows, im_cols, im_dim), save_dir / 'metadata' / 'im_dims.pt')

# Save full and pod tracks
for i in range(total_tracks):
    torch.save(train_save[i], save_dir / 'full' / f'train_{i}.pt')
    torch.save(train_pods_save[i], save_dir / 'pod' / f'train_{i}.pt')
    torch.save(valid_save[i], save_dir / 'full' / f'valid_{i}.pt')
    torch.save(valid_pods_save[i], save_dir / 'pod' / f'valid_{i}.pt')
    torch.save(test_save[i], save_dir / 'full' / f'test_{i}.pt')
    torch.save(test_pods_save[i], save_dir / 'pod' / f'test_{i}.pt')


## Example Loading

In [13]:

train_loaded = torch.load(save_dir / 'full' / 'train_0.pt', weights_only=False)
im_rows, im_cols, im_dim = torch.load(save_dir / 'metadata' / 'im_dims.pt', weights_only=False)
print(train_loaded.shape)
train_loaded_shaped = rearrange(train_loaded, "t (r c d) -> t r c d", t=int(n_steps[dataset]*0.8), r=im_rows, c=im_cols, d=im_dim)
print(train_loaded_shaped.shape)



torch.Size([806, 393216])
torch.Size([806, 256, 512, 3])
