In [None]:
# nts: activate langchain_env 
import cdsapi
import logging
from tqdm import tqdm
import xarray as xr
import numpy as np
import pandas as pd
import os
import sys
import contextlib
import threading

# trackspath='/home/sonia/mcms/tracker/1940-2010/era5/read_era5/out_era5_output_1940_2010.txt'
trackspath='/home/sonia/mcms/tracker/2010-2024/era5/out_era5/era5/mcms_era5_2010_2024_tracks.txt'
use_slp = False # whether to include slp channel
threads = 32

In [73]:
regmask = xr.open_dataset('/home/cyclone/regmask_0723_anl.nc')
regmask.sel(lono=-178.8, lato=-87, method='nearest')['regmaskoc'].values

array([ 58, 107, 116, 117, 122, 123, 129, 132, 138, 141, 147, 150],
      dtype=int32)

In [None]:
# atlantic ocean is regmask['reg_name'].values[109] # so 110 in regmaskoc values
reg_id = 110

In [75]:
# make dataframe of all tracks 
tracks = pd.read_csv(trackspath, sep=' ', header=None, 
        names=['year', 'month', 'day', 'hour', 'total_hrs', 'unk1', 'unk2', 'unk3', 'unk4', 'unk5', 'unk6', 
               'z1', 'z2', 'unk7', 'tid', 'sid'])
tracks = tracks.sort_values(by=['year', 'month', 'day', 'hour'])
tracks['lat'] = 90-tracks['unk1'].values/100
tracks['lon'] = tracks['unk2'].values/100
tracks = tracks[['year', 'month', 'day', 'hour', 'tid', 'sid', 'lat', 'lon']]
# tracks = tracks[tracks['year']<=1942] # for debugging
tracks

Unnamed: 0,year,month,day,hour,tid,sid,lat,lon
0,1940,1,1,0,19400101000110034050,19400101000110034050,79.02,340.65
6,1940,1,1,0,19400101000140017350,19400101000140017350,76.13,173.53
12,1940,1,1,0,19400101000150027900,19400101000150027900,74.89,278.86
22,1940,1,1,0,19400101000195000700,19400101000195000700,70.54,7.21
38,1940,1,1,0,19400101000350004550,19400101000350004550,55.08,45.59
...,...,...,...,...,...,...,...,...
3112943,2010,12,31,18,20101231181610016250,20101230061565015400,-71.01,162.31
3112950,2010,12,31,18,20101231181635023250,20101230061635023250,-73.33,232.30
3112962,2010,12,31,18,20101231180405018900,20101230120465017150,49.33,189.18
3112968,2010,12,31,18,20101231181740027150,20101230121740027150,-84.06,271.39


In [76]:
sids = tracks[tracks['tid'] == tracks['sid']]
natlantic = 0
for _, sid in sids.iterrows():
    if sid['lat'] > 0 and 110 in regmask.sel(lono=sid['lon'], lato=sid['lat'], method='nearest')['regmaskoc'].values:
        natlantic += 1
natlantic

3312

In [3]:
@contextlib.contextmanager
def suppress_output():
    with open(os.devnull, 'w') as devnull:
        old_stdout = sys.stdout
        old_stderr = sys.stderr
        sys.stdout = devnull
        sys.stderr = devnull
        try:
            yield
        finally:
            sys.stdout = old_stdout
            sys.stderr = old_stderr

In [4]:
box = 32 # (box/2 from center in each direction)
if use_slp:
    file_year = 1940
    slp = xr.open_dataset('/home/cyclone/slp.1940.nc')
    slp_next = xr.open_dataset('/home/cyclone/slp.1941.nc')
with suppress_output():
    client = cdsapi.Client()

def prep_point(df, thread=0):
    """make one training datapoint. df contains year/../hr, lat, lon of center"""
    boxes = []
    for _, frame in df.iterrows():
        year, month, day, hour = frame['year'], frame['month'], frame['day'], frame['hour']
        lat, lon = frame['lat'], frame['lon']
        if use_slp:
            # get the box
            if year==file_year:
                slp_box = slp.sel(time=f'{year}-{month:02d}-{day:02d}T{hour:02d}:00:00',
                                lat=slice(lat+box/4, lat-box/4), lon=slice(lon-box/4, lon+box/4))
                                # /4, because /2 for half box and /2 for grid resolution of 0.5 degrees
            elif year==file_year+1:
                slp_box = slp_next.sel(time=f'{year}-{month:02d}-{day:02d}T{hour:02d}:00:00', 
                                lat=slice(lat+box/4, lat-box/4), lon=slice(lon-box/4, lon+box/4))
            else:
                raise ValueError(f'Year {year} not supported, file year is {file_year}')
            slp_box = slp_box.slp.squeeze().values
        
        request = {
            "product_type": ["reanalysis"],
            "variable": [
                "10m_u_component_of_wind",
                "10m_v_component_of_wind",
                # "sea_surface_temperature"
            ],
            "year": [str(year)],
            "month": [str(month)],
            "day": [str(day)],
            "time": [f"{hour}:00"],
            'format': 'netcdf',
            "download_format": "unarchived",
            "area": [lat+box/4, lon-box/4, lat-box/4, lon+box/4],
            'grid': '0.5/0.5', 
        }
        with suppress_output():
            out = client.retrieve('reanalysis-era5-single-levels', request, f'temp_{str(thread)}.nc') #.download()

        ds = xr.open_dataset('temp.nc')
        u = ds['u10'].squeeze().values[:box, :box] # deal with rounding things
        v = ds['v10'].squeeze().values[:box, :box]
        magnitude = np.sqrt(u**2 + v**2)
        boxes.append(magnitude)
        
    return boxes

In [36]:
sids = tracks['sid'].unique()
RADIUS=6371 # Earth radius in km
outpath = '/home/cyclone/train/slp_small'
if not os.path.exists(outpath):
    os.makedirs(outpath)
readme = """small debugging dataset: 32x32 of just slp, 8 frames long, over [1940,1942]"""
with open(f'{outpath}/README.txt', 'w') as f:
    f.write(readme)

def worker(sids_chunk, thread_id):
    for i, sid in enumerate(sids_chunk):
        if i % 100 == 0:
            print(f'Thread {thread_id}: Processing sid {i}/{len(sids_chunk)}: {i//len(sids_chunk)*100:.2f}% complete')
        sid_df = tracks[tracks['sid'] == sid]
        if len(sid_df) < 10:
            continue
        elif sid_df[sid_df['tid']==sid]['lat'].abs().iloc[0] > 70:
            continue # starts poleward of 70 degrees
        elif sid_df[sid_df['tid']==sid]['lon'].iloc[0]<28 or sid_df[sid_df['tid']==sid]['lon'].iloc[0] >120:
            continue # only get indian ocean area

        sid_df = sid_df.sort_values(by=['tid'])
        
        # # check total distance traveled (sum of great circle)
        # lat1 = np.radians(sid_df['lat'].to_numpy()[:-1])
        # lon1 = np.radians(sid_df['lon'].to_numpy()[1])
        # lat2 = np.radians(sid_df['lat'].to_numpy()[1:])
        # lon2 = np.radians(sid_df['lon'].to_numpy()[1:])
        # dlat = lat2 - lat1
        # dlon = lon2 - lon1
        # a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
        # c = 2 * np.arcsin(np.sqrt(a))
        # dist = np.sum(RADIUS * c)
        
        sid_df = sid_df.iloc[:8]  # only take the first 8 frames for debugging
        
        if use_slp and sid_df['year'].iloc[0] == file_year + 1: # starts in the next year
            slp = slp_next
            try:
                slp_next = xr.open_dataset(f'/home/cyclone/slp.{file_year+2}.nc')
            except:
                slp_next = None # reaching the end of our data
            file_year += 1 

        point = prep_point(sid_df)
        os.makedirs(f'{outpath}/{sid}', exist_ok=True)
        for i, frame in enumerate(point):
            np.save(f'{outpath}/{sid}/{i}.npy', frame)

In [None]:
for i in range(threads):
    start = i * len(sids) // threads
    end = (i + 1) * len(sids) // threads
    sids_chunk = sids[start:end]
    print(start, end, sids_chunk.shape)
    thread = threading.Thread(target=worker, args=(sids_chunk, i))
    thread.start()
    # worker(sids_chunk, i)
    
for i in range(threads):
    thread.join()
print("All threads completed.")