In [1]:
import xarray as xr
import os
import numpy as np
import matplotlib.pyplot as plt
import cv2
from PIL import Image as im

In [2]:
base = '/mnt/data/sonia/occetc/raw'
storms = {}

slp_min = float('Inf')
slp_max = float('-Inf')

shape = (30, 32)
varnames = ['slp', 'fparam', 'thetae']
varranges = {f'{v}_min':float('Inf') for v in varnames} |\
    {f'{v}_max':float('-Inf') for v in varnames}

for month in sorted(os.listdir(base)):
    if month not in set([str(y) for y in range(200609, 201801)]):
        continue
    for nc in sorted(os.listdir(os.path.join(base, month))):
        _, date, hr, lat, lon, vartype, sid, src = nc.split('_')
        assert src == 'MERRA21deg.ncdf'
        
        storm = xr.open_dataset(os.path.join(base, month, nc))
        if storm['latitude'].shape != shape: # ensure it's the right shape
            continue
        
        storms[sid] = storms.get(sid, [])
        storms[sid].append(storm)
        
        for v in varnames:
            if storms[sid][-1][v].min().item() < varranges[f'{v}_min']:
                varranges[f'{v}_min'] = storms[sid][-1][v].min().item()
            if storms[sid][-1][v].max().item() > varranges[f'{v}_max']:
                varranges[f'{v}_max'] = storms[sid][-1][v].max().item()
                
        # if storms[sid][-1]['slp'].min() < slp_min:
        #     slp_min = storms[sid][-1]['slp'].min()
        # if storms[sid][-1]['slp'].max() > slp_max:
        #     slp_max = storms[sid][-1]['slp'].max()
            
# slp_min = int(slp_min)
# slp_max = int(slp_max)

In [3]:
varranges

{'slp_min': 915.33203125,
 'fparam_min': -2217.1796875,
 'thetae_min': -999.0,
 'slp_max': 1053.2540283203125,
 'fparam_max': 2184.4560546875,
 'thetae_max': 468.8287353515625}

In [15]:
example

In [4]:
example = list(storms.values())[0][0]
for var in varnames:
    maxdims = len(example[var].dims)
maxdims

3

## SLP Only

In [5]:
slpset = []
target_len = 8 # how many frames to put in each video

zs = np.full([1, 32], 0)

os.makedirs('/mnt/data/sonia/occetc/out', exist_ok=True)
for sid, logs in storms.items():    
    if len(logs) >= target_len:
        logs = logs[:target_len]
    elif len(logs) < 5: # too short, throw it out
        continue
    else: # too short; need to lengthen
        n = len(logs)
        logs = logs + (target_len-n)*[logs[-1]]
    assert len(logs) == target_len
        
    os.makedirs(f'/mnt/data/sonia/occetc/outxtend/{sid}', exist_ok=True)
        
    for i, log in enumerate(logs):
        log = log['slp'].to_numpy()
        scaled = 255 * (log - slp_min) / (slp_max - slp_min) # scale from 0 to 1
        scaled = np.concatenate([zs, scaled, zs]) # so the frame is 32x32 instead of 30x32 (must be multiples of 8)
        frame = im.fromarray(scaled).convert('RGB')
        frame.save(f'/mnt/data/sonia/occetc/outxtend/{sid}/{i}.png')

  scaled = 255 * (log - slp_min) / (slp_max - slp_min) # scale from 0 to 1


In [6]:
import matplotlib.pyplot as plt
import numpy as np

# Generate 8 random images (32x32 grayscale)
frames = slpset[2]

# Create a 2-row, 4-column subplot layout
fig, axes = plt.subplots(2, 4, figsize=(12, 6))

# Loop through each frame and plot it
for i, ax in enumerate(axes.flat):
    ax.imshow(frames[i], cmap="viridis")  # Change cmap as needed
    ax.set_title(f"Frame {i+1}")
    ax.axis("off")  # Hide axes for clarity

plt.tight_layout()  # Adjust layout to prevent overlap
plt.show()


IndexError: list index out of range

In [None]:
accelerate launch train_svd.py \
    --per_gpu_batch_size=1 --gradient_accumulation_steps=1 \
    --max_train_steps=50 \
    --width=32 \
    --height=32 \
    --checkpointing_steps=25 --checkpoints_total_limit=1 \
    --learning_rate=1e-5 --lr_warmup_steps=0 \
    --seed=123 \
    --mixed_precision="fp16" \
    --validation_steps=50 \
    --num_frames=8

## SLP, F-parameter and Thetae

In [14]:
target_len = 8 # how many frames to put in each video
zs = np.full([1, 32], 0) # just zeros for padding to make a 32x32 square
out_name = 'multivar-25.04.01'

# os.makedirs('/mnt/data/sonia/occetc/out', exist_ok=True)
for sid, logs in storms.items():    
    if len(logs) >= target_len:
        logs = logs[:target_len]
    elif len(logs) < 5: # too short, throw it out
        continue
    else: # too short; need to lengthen
        n = len(logs)
        logs = logs + (target_len-n)*[logs[-1]]
    assert len(logs) == target_len
        
    os.makedirs(f'/mnt/data/sonia/occetc/out/{out_name}/{sid}', exist_ok=True)
        
    for i, log in enumerate(logs):
        os.makedirs(os.path.join('/mnt/data/sonia/occetc', out_name, sid), exist_ok=True)
        scaled = []
        for v in varnames:
            vardata = log[v]
            dims_ordered = [dim for dim in ("latitude", "longitude") if dim in vardata.dims] \
                + [dim for dim in vardata.dims if dim not in ("latitude", "longitude")]
            vardata = vardata.transpose(*dims_ordered).to_numpy()
            vardata = 255 * (vardata - varranges[f'{v}_min']) / (varranges[f'{v}_max'] - varranges[f'{v}_min'])
            pad_width = [(0,2)] + (vardata.ndim-1) * [(0,0)]
            vardata = np.pad(vardata, pad_width, mode='constant', constant_values=0)
            vardata = vardata.reshape(vardata.shape + (1,) * (maxdims - vardata.ndim))
            scaled.append(vardata)
            # print(vardata.shape)
        matrix = np.concatenate(scaled, axis=2)
        np.save(os.path.join('/mnt/data/sonia/occetc', out_name, sid, f'{i}.npy'), matrix)
        # frame = im.fromarray(scaled).convert('RGB')
        # frame.save(f'/mnt/data/sonia/occetc/outxtend/{sid}/{i}.png')

In [16]:
matrix.shape

(32, 32, 27)