In [1]:
import numpy as np
import torch
from giverny.turbulence_dataset import *
from giverny.turbulence_toolkit import *

In [None]:
auth_token = 'your_auth_token_here'  # Replace with your actual authentication token
dataset_title = 'sabl2048high'
dataset = turb_dataset(dataset_title=dataset_title, output_path='', auth_token=auth_token)

In [3]:
print(dataset.metadata)

{'name': 'JHTDB', 'description': 'The original Johns Hopkins Turbulence Database', 'pickled_metadata_filepath': '/home/idies/workspace/turbulence-ceph/sciserver-turbulence/jhtdb_metadata/jhtdb_pickled', 'variables': [{'code': 'pressure', 'name': 'Pressure', 'description': None, 'component_codes': ['p'], 'cardinality': 1}, {'code': 'temperature', 'name': 'Temperature', 'description': None, 'component_codes': ['Œ∏'], 'cardinality': 1}, {'code': 'soiltemperature', 'name': 'Soil Temperature', 'description': None, 'component_codes': ['Œ∏-soil'], 'cardinality': 1}, {'code': 'sgsenergy', 'name': 'Subgrid-scale Energy', 'description': None, 'component_codes': ['e'], 'cardinality': 1}, {'code': 'sgsviscosity', 'name': 'Subgrid-scale Viscosity', 'description': None, 'component_codes': ['ŒΩ'], 'cardinality': 1}, {'code': 'density', 'name': 'Density', 'description': None, 'component_codes': ['œÅ'], 'cardinality': 1}, {'code': 'velocity', 'name': 'Velocity', 'description': None, 'component_codes': 

In [4]:
timesteps = list(range(1, 20))  # t=1 to 100
stride = 32
patch_size = 32

X = []
y = []

In [5]:
import numpy as np
import torch
import os
from tqdm import tqdm

# -------------------------------
# Configuration
# -------------------------------
patch_size = 32
stride = 32
timesteps = list(range(1, 10))

X, y = [], []

# -------------------------------
# Step 1: First Pass - Collect StdDev stats
# -------------------------------
std_vort_list, std_Tgrad_list, std_Pgrad_list = [], [], []

def compute_vorticity_2d(patch_v):
    u = patch_v[..., 0]
    v = patch_v[..., 1]
    dv_dx = np.gradient(v, axis=1)
    du_dy = np.gradient(u, axis=0)
    return dv_dx - du_dy

def compute_grad_magnitude(field):
    dx = np.gradient(field, axis=1)
    dy = np.gradient(field, axis=0)
    return np.sqrt(dx**2 + dy**2)

def compute_composite_metrics(patch_v, patch_T, patch_p):
    vort = compute_vorticity_2d(patch_v)
    grad_T = compute_grad_magnitude(patch_T)
    grad_P = compute_grad_magnitude(patch_p)
    return np.std(vort), np.std(grad_T), np.std(grad_P)

print("üîç Collecting statistics for thresholding...")

for t in tqdm(timesteps):
    try:
        axes_ranges = np.array([[0, 2047], [0, 2047], [16, 16]], dtype=np.int32)
        strides_arr = np.array([1, 1, 1], dtype=np.int32)

        velocity_data = getCutout(dataset, 'velocity', t, axes_ranges, strides_arr)
        temperature_data = getCutout(dataset, 'temperature', t, axes_ranges, strides_arr)
        pressure_data = getCutout(dataset, 'pressure', t, axes_ranges, strides_arr)

        velocity = np.squeeze(velocity_data[f'velocity_{t:04d}'].values)
        temperature = np.squeeze(temperature_data[f'temperature_{t:04d}'].values)
        pressure = np.squeeze(pressure_data[f'pressure_{t:04d}'].values)

        for i in range(0, 2048 - patch_size + 1, stride):
            for j in range(0, 2048 - patch_size + 1, stride):
                v_patch = velocity[i:i+patch_size, j:j+patch_size, :]
                T_patch = temperature[i:i+patch_size, j:j+patch_size]
                p_patch = pressure[i:i+patch_size, j:j+patch_size]

                std_vort, std_T, std_P = compute_composite_metrics(v_patch, T_patch, p_patch)
                std_vort_list.append(std_vort)
                std_Tgrad_list.append(std_T)
                std_Pgrad_list.append(std_P)

    except Exception as e:
        print(f"‚ùå Failed for timestep {t}: {e}")

# -------------------------------
# Step 2: Compute thresholds
# -------------------------------
thresholds = {
    'vort': np.percentile(std_vort_list, 85),
    'temp': np.percentile(std_Tgrad_list, 85),
    'pres': np.percentile(std_Pgrad_list, 85),
}

print("\nüìà Thresholds computed from 85th percentile:")
print(thresholds)

# -------------------------------
# Step 3: Second Pass - Create dataset
# -------------------------------
X, y = [], []
print("\nüì¶ Generating dataset with new labels...")

def label_patch_composite(std_vort, std_T, std_P, thresholds):
    return int(std_vort > thresholds['vort'] or std_T > thresholds['temp'] or std_P > thresholds['pres'])

for t in tqdm(timesteps):
    try:
        axes_ranges = np.array([[0, 2047], [0, 2047], [16, 16]], dtype=np.int32)
        strides_arr = np.array([1, 1, 1], dtype=np.int32)

        velocity_data = getCutout(dataset, 'velocity', t, axes_ranges, strides_arr)
        temperature_data = getCutout(dataset, 'temperature', t, axes_ranges, strides_arr)
        pressure_data = getCutout(dataset, 'pressure', t, axes_ranges, strides_arr)

        velocity = np.squeeze(velocity_data[f'velocity_{t:04d}'].values)
        temperature = np.squeeze(temperature_data[f'temperature_{t:04d}'].values)
        pressure = np.squeeze(pressure_data[f'pressure_{t:04d}'].values)

        for i in range(0, 2048 - patch_size + 1, stride):
            for j in range(0, 2048 - patch_size + 1, stride):
                v_patch = velocity[i:i+patch_size, j:j+patch_size, :]
                T_patch = temperature[i:i+patch_size, j:j+patch_size]
                p_patch = pressure[i:i+patch_size, j:j+patch_size]

                std_vort, std_T, std_P = compute_composite_metrics(v_patch, T_patch, p_patch)
                label = label_patch_composite(std_vort, std_T, std_P, thresholds)

                input_tensor = np.stack([
                    v_patch[..., 0],
                    v_patch[..., 1],
                    T_patch
                ], axis=0)

                X.append(input_tensor)
                y.append(label)

    except Exception as e:
        print(f"‚ùå Failed for timestep {t}: {e}")

# -------------------------------
# Step 4: Save Dataset
# -------------------------------
X_tensor = torch.tensor(np.array(X), dtype=torch.float32)
y_tensor = torch.tensor(np.array(y), dtype=torch.long)

dataset_dict = {'X': X_tensor, 'y': y_tensor}
save_path = 'turbulence_dataset_sabl2048high_composite_label.pt'
torch.save(dataset_dict, save_path)

print(f"\n‚úÖ Saved dataset to '{save_path}'")
print(f"Total samples: {len(y_tensor)} | üå™Ô∏è Turbulence: {y_tensor.sum().item()} | üõ´ Non-turbulence: {(y_tensor == 0).sum().item()}")


üîç Collecting statistics for thresholding...


  0% 0/9 [00:00<?, ?it/s]


-----
getCutout is processing...

total time elapsed = 37.279 seconds (0.621 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.848 seconds (0.181 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 12.419 seconds (0.207 minutes)

query completed successfully.
-----


 11% 1/9 [01:05<08:42, 65.32s/it]


-----
getCutout is processing...

total time elapsed = 32.538 seconds (0.542 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 12.538 seconds (0.209 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.448 seconds (0.174 minutes)

query completed successfully.
-----


 22% 2/9 [02:05<07:16, 62.38s/it]


-----
getCutout is processing...

total time elapsed = 29.597 seconds (0.493 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 11.536 seconds (0.192 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.937 seconds (0.182 minutes)

query completed successfully.
-----


 33% 3/9 [03:02<05:59, 59.88s/it]


-----
getCutout is processing...

total time elapsed = 30.924 seconds (0.515 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.919 seconds (0.182 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.937 seconds (0.182 minutes)

query completed successfully.
-----


 44% 4/9 [04:00<04:54, 58.94s/it]


-----
getCutout is processing...

total time elapsed = 32.222 seconds (0.537 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.591 seconds (0.177 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 11.865 seconds (0.198 minutes)

query completed successfully.
-----


 56% 5/9 [04:59<03:56, 59.15s/it]


-----
getCutout is processing...

total time elapsed = 32.016 seconds (0.534 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 11.169 seconds (0.186 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 11.273 seconds (0.188 minutes)

query completed successfully.
-----


 67% 6/9 [05:58<02:57, 59.22s/it]


-----
getCutout is processing...

total time elapsed = 32.917 seconds (0.549 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.260 seconds (0.171 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.962 seconds (0.183 minutes)

query completed successfully.
-----


 78% 7/9 [06:57<01:58, 59.16s/it]


-----
getCutout is processing...

total time elapsed = 32.373 seconds (0.540 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.473 seconds (0.175 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 10.324 seconds (0.172 minutes)

query completed successfully.
-----


 89% 8/9 [07:56<00:58, 58.85s/it]


-----
getCutout is processing...

total time elapsed = 32.620 seconds (0.544 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 9.638 seconds (0.161 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 11.403 seconds (0.190 minutes)

query completed successfully.
-----


100% 9/9 [08:54<00:00, 59.42s/it]



üìà Thresholds computed from 85th percentile:
{'vort': 0.19501825124025343, 'temp': 0.014994575642049313, 'pres': 0.0430082231760025}

üì¶ Generating dataset with new labels...


  0% 0/9 [00:00<?, ?it/s]


-----
getCutout is processing...

total time elapsed = 21.558 seconds (0.359 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.732 seconds (0.096 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 4.873 seconds (0.081 minutes)

query completed successfully.
-----


 11% 1/9 [00:37<05:00, 37.56s/it]


-----
getCutout is processing...

total time elapsed = 25.970 seconds (0.433 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 6.615 seconds (0.110 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 6.213 seconds (0.104 minutes)

query completed successfully.
-----


 22% 2/9 [01:21<04:50, 41.53s/it]


-----
getCutout is processing...

total time elapsed = 25.496 seconds (0.425 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 6.180 seconds (0.103 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.847 seconds (0.097 minutes)

query completed successfully.
-----


 33% 3/9 [02:04<04:12, 42.08s/it]


-----
getCutout is processing...

total time elapsed = 27.130 seconds (0.452 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 12.171 seconds (0.203 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 6.373 seconds (0.106 minutes)

query completed successfully.
-----


 44% 4/9 [02:56<03:49, 45.81s/it]


-----
getCutout is processing...

total time elapsed = 29.631 seconds (0.494 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.784 seconds (0.096 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.539 seconds (0.092 minutes)

query completed successfully.
-----


 56% 5/9 [03:42<03:03, 45.89s/it]


-----
getCutout is processing...

total time elapsed = 23.422 seconds (0.390 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.481 seconds (0.091 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 5.097 seconds (0.085 minutes)

query completed successfully.
-----


 67% 6/9 [04:21<02:10, 43.61s/it]


-----
getCutout is processing...

total time elapsed = 24.167 seconds (0.403 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 4.325 seconds (0.072 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 4.368 seconds (0.073 minutes)

query completed successfully.
-----


 78% 7/9 [04:59<01:23, 41.83s/it]


-----
getCutout is processing...

total time elapsed = 17.741 seconds (0.296 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 3.967 seconds (0.066 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 3.948 seconds (0.066 minutes)

query completed successfully.
-----


 89% 8/9 [05:30<00:38, 38.33s/it]


-----
getCutout is processing...

total time elapsed = 12.686 seconds (0.211 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 2.941 seconds (0.049 minutes)

query completed successfully.
-----

-----
getCutout is processing...

total time elapsed = 3.068 seconds (0.051 minutes)

query completed successfully.
-----


100% 9/9 [05:54<00:00, 39.35s/it]



‚úÖ Saved dataset to 'turbulence_dataset_sabl2048high_composite_label.pt'
Total samples: 36864 | üå™Ô∏è Turbulence: 11268 | üõ´ Non-turbulence: 25596
