In [None]:
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import skimage

import warnings
warnings.filterwarnings("ignore")

from tqdm import tqdm

In [None]:
import sys
sys.path.append("../")
sys.path.append("../imagen/")
sys.path.append("../../dataproc")

from utils import *

In [None]:
# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so that the
# diffusion-model sampling below is reproducible across Restart & Run All.
seed_value = 42
random.seed(seed_value)
np.random.seed(seed_value)
torch.manual_seed(seed_value)
# torch.manual_seed already covers CUDA devices on recent PyTorch; kept explicit for clarity.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed_value)

In [None]:
# Choose the cyclone to forecast. `region` and `name` identify the storm; `start` is the
# index into cyclone.metadata['satmaps'] of the first (observed) frame.
# Exactly one selection must be active, otherwise `name` is undefined below on a fresh
# kernel. Swap the active line for one of the commented alternatives as needed.
region = "North Indian Ocean" ; name = "Amphan" ; start = 0
# region = "North Indian Ocean" ; name = "Tauktae" ; start = 0
# region = "Australia" ; name = "Ilsa" ; start = 0
# region = "West Indian Ocean" ; name = "Batsirai" ; start = 0
# region = "North Atlantic Ocean" ; name = "Iota" ; start = 0
# region = "West Pacific Ocean" ; name = "Chanthu" ; start = 0

# region = "West Indian Ocean" ; name = "Emnati" ; start = 24
# region = "North Pacific Ocean" ; name = "Orlene" ; start = 12

# region = "North Indian Ocean" ; name = "Mocha" ; start = 0
# region = "North Indian Ocean" ; name = "Maha" ; start = 24
# region = "Australia" ; name = "Veronica" ; start = 0
# region = "West Indian Ocean" ; name = "Gombe" ; start = 24
# region = "North Atlantic Ocean" ; name = "Ida" ; start = 12
# NOTE(review): "Rosyln" may be a misspelling of "Roslyn"; left as-is in case the data uses it.
# region = "North Pacific Ocean" ; name = "Rosyln" ; start = 0
# region = "West Pacific Ocean" ; name = "Molave" ; start = 30


# Normalise the storm name (no spaces, lowercase) before looking it up.
name = name.replace(' ', '').lower()

cyclone = Cyclone(region, name)
cyclone.load_era5()

# Observed IR 10.8 µm image at `start`, made square via transform_make_sq_image and
# downsampled to the 64x64 resolution the diffusion models work at.
ir108_fn = cyclone.metadata['satmaps'][start]['ir108_fn']
ir108_scn = cyclone.get_ir108_data(ir108_fn)
img = ir108_scn.to_numpy()
img = transform_make_sq_image(img)
img_o = skimage.transform.resize(img, (64, 64), anti_aliasing=True)

In [None]:
# Autoregressive 64x64 IR forecast without ERA5 conditioning, then super-resolution
# to 128x128 and 64x64 total-precipitation generation.
FORECAST_DAYS = 1
# 24 frames per forecast day (presumably hourly satellite maps — confirm), capped by
# the number of maps available from `start` onwards.
horizon = min(cyclone.metadata['count'] - start,
              FORECAST_DAYS * 24)
# With no frames the loop never builds `tpdiff_model`, and sampling below would fail
# on None; fail early with a readable message instead.
assert horizon > 0, f"no satellite maps from start={start} (count={cyclone.metadata['count']})"

# The sequence begins with the observed image at `start`; forecasts are appended to it.
prev_img = torch.from_numpy(img_o).unsqueeze(0)
img_64_seq = torch.cat([torch.empty(0, 64, 64), prev_img])

era5_64_seq = torch.empty(0, 3, 64, 64)
era5_128_seq = torch.empty(0, 3, 128, 128)

fcdiff_model = FCDiffModel("64_FC_woERA5_rot904_3e-4", img_o, woERA5=True)
srdiff_model = SRDiffModel("64_128_woERA5_rot904_sep_3e-4", img_o, woERA5=True)
tpdiff_model = None  # built on the first loop iteration from the observed precipitation

print("Generating forecasts in 64x64 ...")

for satmap_idx in tqdm(range(start, start + horizon)):
    era5_idx = cyclone.metadata['satmaps'][satmap_idx]['era5_idx']
    era5 = cyclone.get_era5_data(era5_idx, gfs=True)

    # ERA5 fields for every frame (including the observed one), at both model resolutions.
    era5_64 = torch.from_numpy(skimage.transform.resize(era5, (3, 64, 64), anti_aliasing=True))
    era5_64_seq = torch.cat([era5_64_seq, era5_64.unsqueeze(0)])

    era5_128 = torch.from_numpy(skimage.transform.resize(era5, (3, 128, 128), anti_aliasing=True))
    era5_128_seq = torch.cat([era5_128_seq, era5_128.unsqueeze(0)])

    if satmap_idx == start:
        # The first frame is observed, not forecast: hand its ERA5 precipitation
        # (resized to 64x64) to the TP model and move on to the next frame.
        era5_tp = cyclone.get_era5_tp_data(era5_idx)
        era5_tp = skimage.transform.resize(era5_tp, (64, 64), anti_aliasing=True)
        tpdiff_model = TPDiffModel("64_PRP_woERA5_rot904_3e-4", era5_tp, woERA5=True)
        continue

    # Condition the forecast model on the previous frame only, flattened to
    # (1, n_pixels). Kept separate from `era5_64`: this holds the image, not ERA5 data.
    fc_cond = prev_img.reshape(1, -1).float()
    curr_img = fcdiff_model.get_sampled_image(fc_cond).cpu()
    img_64_seq = torch.cat([img_64_seq, curr_img])
    prev_img = curr_img

print("Forecast generation completed.")

print("Performing super-resolution to 128x128 ...")
sr_images = srdiff_model.get_sampled_images(img_64_seq, era5_128_seq)
print("Super resolution completed.")

print("Generating 64x64 precipitation maps ...")
tp_images = tpdiff_model.get_sampled_images(img_64_seq, era5_64_seq)
print("Precipitation maps generated.")

In [None]:
# Observed data for the same frames that were forecast: IR 10.8 µm images at
# 128x128, ERA5 total precipitation at 64x64, and the frame timestamps.
ir108_frames = []
era5_tp_frames = []
dates = []

print("Loading actual data ...")

for satmap_idx in tqdm(range(start, start + horizon)):
    satmap = cyclone.metadata['satmaps'][satmap_idx]

    ir108_scn = cyclone.get_ir108_data(satmap['ir108_fn'])
    img = transform_make_sq_image(ir108_scn.to_numpy())
    img_n = skimage.transform.resize(img, (128, 128), anti_aliasing=True)
    ir108_frames.append(torch.from_numpy(img_n))

    era5_tp = cyclone.get_era5_tp_data(satmap['era5_idx'])
    era5_tp = skimage.transform.resize(era5_tp, (64, 64), anti_aliasing=True)
    era5_tp_frames.append(torch.from_numpy(era5_tp))

    dates.append(satmap['date'])

# Stack once at the end instead of re-allocating with torch.cat on every frame.
actual_ir108 = torch.stack(ir108_frames)
actual_era5_tp = torch.stack(era5_tp_frames)

# Frame 0 is the observed starting frame (img_64_seq[0] is the real IR image), so
# show its ERA5 precipitation rather than a model sample.
tp_images[0] = actual_era5_tp[0]
print("Actual data loaded.")

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

from mpl_toolkits.basemap import Basemap
from scipy import ndimage

In [None]:
def update(frame_idx):
    """Redraw the 2x2 ground-truth vs. forecast panel for one animation frame.

    Parameters
    ----------
    frame_idx : int
        0-based position in the forecast sequence (FuncAnimation passes
        0 .. horizon-1). It is NOT offset by `start`.

    Reads notebook globals: fig, m, cyclone, actual_ir108, actual_era5_tp,
    sr_images, tp_images, dates, name, region.
    """
    map_bounds = cyclone.metadata['map_bounds']

    fig.clear()
    axs = fig.subplot_mosaic([['ir', 'tp'], ['ir_pred', 'tp_pred']],
                             gridspec_kw={'width_ratios': [1, 1]})

    # Top row: observations.
    get_map_img(m, axs['ir'], img2req(actual_ir108[frame_idx]), map_bounds)
    axs['ir'].set_title("IR 10.8 µm\nGround Truth")

    get_map_img(m, axs['tp'], actual_era5_tp[frame_idx], map_bounds, era5=True)
    axs['tp'].set_title("Total Precipitation\nGround Truth")

    # Bottom row: diffusion-model outputs.
    get_map_img(m, axs['ir_pred'], img2req(sr_images[frame_idx][0].cpu()), map_bounds)
    axs['ir_pred'].set_title("IR 10.8 µm\nDiffusion Model Forecast\n[w/o ERA5]")

    tp_pred = tp_images[frame_idx][0].cpu()
    # Frame 0 is the observed starting frame (its TP map is replaced with ground truth
    # when the actual data is loaded), so only genuine forecasts get the 3x3 minimum
    # filter. Compare with 0, not `start`: frame indices here are 0-based.
    if frame_idx != 0:
        tp_pred = ndimage.minimum_filter(tp_pred, size=3)
    get_map_img(m, axs['tp_pred'], tp_pred, map_bounds, era5=True)
    axs['tp_pred'].set_title("Total Precipitation\nDiffusion Model Forecast\n[w/o ERA5]")

    fig.suptitle(f"Cyclone {name.replace('-', ' ').title()}\n{region}\n{dates[frame_idx].strftime('%Y-%m-%d %H:%M')}")

In [None]:
SAVE = False

# Cylindrical ('cyl') base map over the cyclone's bounding box, with
# map_bounds ordered [lon_min, lat_min, lon_max, lat_max].
map_bounds = cyclone.metadata['map_bounds']
m = Basemap(llcrnrlon=map_bounds[0], llcrnrlat=map_bounds[1],
            urcrnrlon=map_bounds[2], urcrnrlat=map_bounds[3],
            projection='cyl', resolution='l')

fig = plt.figure(figsize=(8, 8), constrained_layout=True)
# One frame per forecast step, shown 250 ms apart.
animation = FuncAnimation(fig, update, frames=horizon, interval=250)

if SAVE:
    out_stem = f"{region_to_abbv[region]}_{start:02}_{name}_forecast"
    # The output folders may not exist yet; create them so saving does not fail.
    os.makedirs("./pkls_woERA5", exist_ok=True)
    os.makedirs("./gifs_woERA5", exist_ok=True)

    predictions_dict = {
        "actual": {
            "ir_108": actual_ir108,
            "tp": actual_era5_tp,
        },
        "predicted": {
            "ir_108": sr_images,
            "tp": tp_images
        },
        "dates": dates
    }
    with open(f"./pkls_woERA5/{out_stem}.pkl", "wb") as file:
        pickle.dump(predictions_dict, file)
    animation.save(f'./gifs_woERA5/{out_stem}.gif', writer='imagemagick')

# Close the figure explicitly so only the HTML animation below is rendered,
# not an extra static copy of the last frame.
plt.close(fig)
HTML(animation.to_jshtml())