# NIMS-KISTI-NVIDIA Hackathon 

# Problem Statement 



## Agenda 
 - Introduction to FourCastNet
 - Configure FourCastNet
 - Typhoon dataset (JMA best track)
 - ECMWF ERA5 dataset (CDS API)
 - Custom interval
 - Inference
 - Post-processing
 






## Introduction to FourCastNet

 - [FourCastNet Paper Link](https://arxiv.org/abs/2202.11214)
 - [GitHub Repository](https://github.com/NVlabs/FourCastNet)
 
FourCastNet, short for Fourier ForeCasting Neural Network, is a global data-driven weather forecasting model that provides accurate short- to medium-range global predictions at 0.25° resolution.
<div align="center">

<img src="https://github.com/NVlabs/FourCastNet/blob/master/assets/FourCastNet.gif?raw=true">
</div>





### FourCastNet Architecture:
FourCastNet uses a Fourier transform-based token-mixing scheme with a vision transformer (ViT) backbone. This approach builds on the Fourier neural operator, which learns in a resolution-invariant manner and has been successful at modeling challenging partial differential equations (PDEs), such as those governing fluid dynamics.


<div align="center">
<img src="https://docscontent.nvidia.com/dims4/default/9c9b4d6/2147483647/strip/true/crop/1188x788+0+0/resize/1188x788!/quality/90/?url=https%3A%2F%2Fk3-prod-nvidia-docs.s3.amazonaws.com%2Fbrightspot%2Fsphinx%2F00000187-c443-dd05-a3bf-e57fb6480000%2Fdeeplearning%2Fmodulus%2Fmodulus-sym%2F_images%2Ffourcastnet_overview.png" >

</div>
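
To make the token-mixing idea concrete, here is a minimal NumPy sketch of mixing tokens in the Fourier domain. It is a simplification for illustration and not the FourCastNet implementation: the function name `fourier_token_mixing` and the per-frequency `weights` array are hypothetical, and the model described in the paper mixes the Fourier modes with a small learned network rather than a plain elementwise multiplier.

```python
import numpy as np

def fourier_token_mixing(x, weights):
    """Mix a grid of tokens globally by filtering in the 2D Fourier domain.

    x:       real array of shape (H, W, C), one C-channel token per grid point.
    weights: complex array of shape (H, W // 2 + 1, C), one multiplier per
             retained frequency and channel (hypothetical simplification).
    """
    h, w, _ = x.shape
    x_hat = np.fft.rfft2(x, axes=(0, 1))             # spatial grid -> frequencies
    y_hat = x_hat * weights                          # scale every Fourier mode
    return np.fft.irfft2(y_hat, s=(h, w), axes=(0, 1))  # back to the spatial grid

rng = np.random.default_rng(0)
tokens = rng.standard_normal((8, 16, 3))

# Keep only the zero-frequency mode of each channel.
dc_only = np.zeros((8, 9, 3), dtype=complex)
dc_only[0, 0, :] = 1.0
mixed = fourier_token_mixing(tokens, dc_only)
```

Keeping only the zero-frequency weight returns the spatial mean of each channel at every grid point, which shows that every output token depends on all input tokens after a single mixing step.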

 


### FourCastNet modeled variables:    
| Vertical Level | Variable                |
|:-------------- |:-----------------------|
| Surface        | U10, V10, T2M, SP, MSLP |
| 1000 hPa       | U, V, Z                |
| 850 hPa        | T, U, V, Z, RH         |
| 500 hPa        | T, U, V, Z, RH         |
| 50 hPa         | Z                      |
| Integrated     | TCWV                   |
 


## copy model source code 

### install required modules for FourCastNet

## download dataset with CDS API 

FourCastNet modeled variables
<table align="left" border="1">
  <tr>
    <th>Vertical Level</th>
    <th>Variable</th>
  </tr>
  <tr>
    <td>Surface</td>
    <td>U10, V10, T2M, SP, MSLP</td>
  </tr>
  <tr>
    <td>1000 hPa</td>
    <td>U, V, Z</td>
  </tr>
  <tr>
    <td>850 hPa</td>
    <td>T, U, V, Z, RH</td>
  </tr>
  <tr>
    <td>500 hPa</td>
    <td>T, U, V, Z, RH</td>
  </tr>
  <tr>
    <td>50 hPa</td>
    <td>Z</td>
  </tr>
  <tr>
    <td>Integrated</td>
    <td>TCWV</td>
  </tr>
</table>
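
The table above can be turned into CDS API requests roughly as follows. This is a sketch under assumptions, not the notebook's download code: `build_requests` and `download` are hypothetical helpers, the 6-hourly times are an assumption, and the request keys follow the legacy CDS request style (for example `format`), which may differ in the current CDS service. Running `download` requires the `cdsapi` package and a configured CDS account.

```python
# ERA5 names for the variables in the table above (check them against the CDS catalogue).
SURFACE_VARS = [
    "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature",
    "surface_pressure", "mean_sea_level_pressure", "total_column_water_vapour",
]
WIND_Z = ["u_component_of_wind", "v_component_of_wind", "geopotential"]
PRESSURE_VARS = {
    "1000": WIND_Z,
    "850": ["temperature", *WIND_Z, "relative_humidity"],
    "500": ["temperature", *WIND_Z, "relative_humidity"],
    "50": ["geopotential"],
}

def build_requests(year, month, day, times=("00:00", "06:00", "12:00", "18:00")):
    """Return (dataset, request) pairs covering all 20 variables for one day."""
    base = {
        "product_type": "reanalysis",
        "year": f"{year:04d}",
        "month": f"{month:02d}",
        "day": f"{day:02d}",
        "time": list(times),   # assumed 6-hourly sampling
        "format": "netcdf",    # legacy key name; an assumption for newer CDS versions
    }
    requests = [("reanalysis-era5-single-levels", {**base, "variable": SURFACE_VARS})]
    for level, variables in PRESSURE_VARS.items():
        requests.append((
            "reanalysis-era5-pressure-levels",
            {**base, "variable": variables, "pressure_level": level},
        ))
    return requests

def download(requests, prefix="era5"):
    """Submit each request; needs cdsapi installed and CDS credentials configured."""
    import cdsapi
    client = cdsapi.Client()
    for i, (dataset, request) in enumerate(requests):
        client.retrieve(dataset, request, f"{prefix}_{i}.nc")

requests = build_requests(2022, 9, 12)
```

Splitting the pressure-level variables into one request per level keeps each request small and avoids downloading variables that the model does not use at a given level.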



# visualize

In [None]:
!ls -lah FCN_output

In [None]:
h5_gt_filename = '/mnt/workspace/nims/notebooks/hdf5_dir/merged_01_2214_2022091212_2022092000.h5'
h5_pred_filename = '/mnt/workspace/nims/notebooks/FCN_output/autoregressive_predictions_z500_vis.h5'

In [None]:
import h5py
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def get_variable(file_path, variable_name):
    # Read one dataset fully into memory; the file is closed even if reading fails
    with h5py.File(file_path, 'r') as hdf5_file:
        return hdf5_file[variable_name][:]

def plot_variable(variable, frame_index, variable_index):
    variable_names = ['u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000', 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv']
    vmin_values = [-30, -30, 220, None, None, 220, -30, -30, None, -30, -30, None, -30, -30, None, 220, None, 0, None, None]
    vmax_values = [30, 30, 300, None, None, 300, 30, 30, None, 30, 30, None, 30, 30, None, 270, None, 120, None, None]
    cmaps = [None, None, 'coolwarm', None, None, 'coolwarm', None, None, None, None, None, None, None, None, None, 'coolwarm', None, None, None, None]

    variable_name = variable_names[variable_index]
    vmin = vmin_values[variable_index]
    vmax = vmax_values[variable_index]
    cmap = cmaps[variable_index]

    plt.figure(figsize=(8, 4))

    # imshow autoscales when vmin/vmax are None, so a single call covers both cases
    plt.imshow(variable[frame_index][variable_index], vmin=vmin, vmax=vmax, cmap=cmap)

    plt.colorbar(shrink=0.75)
    plt.title(f"Variable {variable_name} at Frame {frame_index}")
    plt.show()



In [None]:
gt_variables = get_variable(h5_gt_filename, 'fields' )

In [None]:
gt_variables.shape

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4 # Replace with the desired variable index
plot_variable(gt_variables, frame_index, variable_index)

In [None]:
pred_variables = get_variable(h5_pred_filename, 'predicted' )

In [None]:
pred_variables.shape

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4  # Replace with the desired variable index
plot_variable(pred_variables[0], frame_index, variable_index)

### TODO denormalize 


In [None]:
gm_file = '/mnt/workspace/nims/notebooks/FCN_weights_v0/stats_v0/global_means.npy'
gs_file = '/mnt/workspace/nims/notebooks/FCN_weights_v0/stats_v0/global_stds.npy'
tm_file = '/mnt/workspace/nims/notebooks/FCN_weights_v0/stats_v0/time_means.npy'

In [None]:
import numpy as np 
gm_value = np.load(gm_file)
gs_value = np.load(gs_file)
tm_value = np.load(tm_file)

In [None]:
print(gm_value.shape, gs_value.shape, tm_value.shape)

In [None]:
def normalize_gt(gt, global_means, global_stds):
    normalized_gt = (gt - global_means) / global_stds
    return normalized_gt

def denormalize_pred(pred, global_means, global_stds):
    denormalized_pred = pred * global_stds + global_means
    return denormalized_pred
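
The two helpers above rely on NumPy broadcasting between the per-channel statistics and the field arrays. The following sketch checks that on dummy data. The shapes are assumptions (statistics of shape `(1, 21, 1, 1)`, fields of shape `(time, 20, lat, lon)`), chosen to be consistent with the `[:, :20, :, :]` slicing used in this notebook; the real shapes are the ones printed by the cell above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Dummy statistics with an assumed (1, 21, 1, 1) layout: one value per channel.
means = rng.normal(size=(1, 21, 1, 1))
stds = rng.uniform(0.5, 2.0, size=(1, 21, 1, 1))

# Dummy fields: 3 time steps, 20 channels, a tiny 4 x 8 grid.
fields = rng.normal(size=(3, 20, 4, 8))

# Same formulas as normalize_gt and denormalize_pred.
normalized = (fields - means[:, :20]) / stds[:, :20]
restored = normalized * stds[:, :20] + means[:, :20]

# A 5D prediction array (batch, time, channel, lat, lon) broadcasts the same way,
# because broadcasting aligns the trailing (channel, lat, lon) axes.
pred = rng.normal(size=(1, 3, 20, 4, 8))
denormalized_pred = pred * stds[:, :20] + means[:, :20]
```

Because the statistics have size 1 on every axis except the channel axis, the same slice works for both the 4D ground-truth array and the 5D prediction array.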

## gt values

In [None]:
normalized_gt_variables = normalize_gt(gt_variables, gm_value[:,:20,:,:], gs_value[:,:20,:,:])
print(gt_variables.shape, normalized_gt_variables.shape)

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4  # Replace with the desired variable index
plot_variable(normalized_gt_variables , frame_index, variable_index)

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4  # Replace with the desired variable index
plot_variable(pred_variables[0], frame_index, variable_index)

## predict

In [None]:
denormalized_pred_variables = denormalize_pred(pred_variables, gm_value[:,:20,:,:], gs_value[:,:20,:,:])
print(pred_variables.shape, denormalized_pred_variables.shape)

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4  # Replace with the desired variable index
plot_variable(pred_variables[0], frame_index, variable_index)

In [None]:
# Plot the variable for a specific frame and variable index
frame_index = 16  # Replace with the desired frame index
variable_index = 4  # Replace with the desired variable index
plot_variable(denormalized_pred_variables[0], frame_index, variable_index)

## visualize trajectory from FCN inference 

In [None]:
import h5py
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

def get_variable(file_path, variable_name):
    # Read one dataset fully into memory; the file is closed even if reading fails
    with h5py.File(file_path, 'r') as hdf5_file:
        return hdf5_file[variable_name][:]

def extract_typhoon_trajectory(var_data):
    """Locate the minimum-pressure point in each frame of a (time, lat, lon) array in Pa."""
    trajectory_data = []

    # Assumed grid: longitudes from 0 to 360 and latitudes from north (90) to south (-90),
    # consistent with the latitude flip applied before plotting. Change if your data differs.
    longitude = np.linspace(0, 360, var_data.shape[2])
    latitude = np.linspace(90, -90, var_data.shape[1])

    for idx, frame in enumerate(var_data):
        min_pressure = np.nanmin(frame)
        # A minimum below 99,000 Pa (990 hPa) anywhere on the grid counts as a typhoon
        typhoon_present = min_pressure < 99_000

        if typhoon_present:
            lat_idx, lon_idx = np.unravel_index(np.nanargmin(frame), frame.shape)
            typhoon_lat, typhoon_lon = latitude[lat_idx], longitude[lon_idx]
        else:
            typhoon_lat, typhoon_lon = np.nan, np.nan

        trajectory_data.append([idx, typhoon_present, typhoon_lon, typhoon_lat, min_pressure])

    df = pd.DataFrame(trajectory_data, columns=['idx_frame', 'typhoon_or_not', 'lon', 'lat', 'pressure'])

    return df

def visualize_pressure_animation_with_trajectory(var_data, df, projection='2d'):
    # Flip the latitude axis so rows run south to north, matching the -90..90 grid below
    var_data = np.flip(var_data, axis=1)
    central_longitude = 130
    central_latitude = 0
    longitude = np.linspace(0, 360, var_data.shape[2])

    if projection == '3d':
        fig = plt.figure(figsize=(10, 6))
        ax = plt.axes(projection=ccrs.Orthographic(central_longitude=central_longitude, central_latitude=central_latitude))
    elif projection == '2d':
        fig = plt.figure(figsize=(10, 5))
        ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=central_longitude))
    else:
        raise ValueError("Invalid projection value. Choose either '2d' or '3d'.")

    coastline = cfeature.COASTLINE.with_scale('110m')
    lon_grid, lat_grid = np.meshgrid(longitude, np.linspace(-90, 90, var_data.shape[1]))
    
    # Initial plot
    im = ax.pcolormesh(lon_grid, lat_grid, var_data[0], vmin=97_000, vmax=100_000,  cmap='hot', transform=ccrs.PlateCarree())
    ax.add_feature(coastline)
    #ax.set_extent([100, 150, -5, 50], crs=ccrs.PlateCarree())
    ax.set_extent([50, 180, -5, 50], crs=ccrs.PlateCarree())    
    ax.gridlines()
    cbar = plt.colorbar(im, label='pressure (Pa)', aspect=40, shrink=0.65)
    title = ax.set_title('Pressure - Frame 0')
    
    # Typhoon trajectory line
    traj_line, = ax.plot([], [], color='black', transform=ccrs.PlateCarree())

    # Update function for animation
    def update(frame):
        im.set_array(var_data[frame].ravel())  # Use flattened array
        title.set_text(f'Pressure - Frame {frame}')

        # Update typhoon trajectory
        traj_df = df[df['idx_frame'] <= frame]
        traj_line.set_data(traj_df['lon'], traj_df['lat'])

        return [im, title, traj_line]

    animation = FuncAnimation(fig, update, frames=range(var_data.shape[0]), interval=200)
    plt.close()  # Prevents displaying the initial plot
    animation_html = animation.to_jshtml()
    autoplay_html = animation_html.replace('controls>', 'controls autoplay>')

    return HTML(autoplay_html)


In [None]:
# Get your data
gt_data = get_variable(h5_gt_filename, 'fields')
pred_data = get_variable(h5_pred_filename, 'predicted').squeeze(0) 
print("input",gt_data.shape, pred_data.shape)

pred_data = denormalize_pred(pred_data, gm_value[:,:20,:,:], gs_value[:,:20,:,:])
print("denorm", pred_data.shape)

# Select mean sea level pressure (msl, channel index 4); change the index for another variable
gt_var_data = gt_data[:, 4, :, :]  # msl across all ground-truth frames
pred_var_data = pred_data[:, 4, :, :]  # msl across all predicted frames
print("select variable", gt_var_data.shape, pred_var_data.shape)

# Extract typhoon trajectories
gt_df = extract_typhoon_trajectory(gt_var_data)
pred_df = extract_typhoon_trajectory(pred_var_data)
print("traj ", gt_df.shape, pred_df.shape)
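
`extract_typhoon_trajectory` searches the whole globe for the pressure minimum, so a deep low elsewhere (for example over the Southern Ocean) can be picked up instead of the typhoon. One possible refinement, sketched below, restricts the search to a latitude/longitude box. The helper `find_regional_minimum` is hypothetical, the default box reuses the plot extent `[50, 180, -5, 50]` from this notebook, and the grid orientation (north to south, longitudes starting at 0) is an assumption.

```python
import numpy as np

def find_regional_minimum(frame, lat_range=(-5, 50), lon_range=(50, 180)):
    """Return (lat, lon, value) of the minimum of a (lat, lon) field inside a box."""
    n_lat, n_lon = frame.shape
    lat = np.linspace(90, -90, n_lat)                 # assumed north-to-south rows
    lon = np.linspace(0, 360, n_lon, endpoint=False)  # assumed 0 <= lon < 360
    lat_mask = (lat >= lat_range[0]) & (lat <= lat_range[1])
    lon_mask = (lon >= lon_range[0]) & (lon <= lon_range[1])
    # Hide everything outside the box so nanargmin ignores it.
    masked = np.where(lat_mask[:, None] & lon_mask[None, :], frame, np.nan)
    i, j = np.unravel_index(np.nanargmin(masked), masked.shape)
    return lat[i], lon[j], masked[i, j]

# Synthetic 1-degree field: a deeper low at 60S and a typhoon-like low at 20N, 130E.
field = np.full((181, 360), 101_000.0)
field[150, 100] = 95_000.0   # 60S, 100E: the global minimum, outside the box
field[70, 130] = 96_000.0    # 20N, 130E: inside the box
lat, lon, pressure = find_regional_minimum(field)
```

On this synthetic field the global minimum lies at 60S, while the boxed search returns the low at 20N, 130E.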


In [None]:
plt.imshow(gt_var_data[26], origin='upper' )

In [None]:
plt.imshow(pred_var_data[26], origin='upper' )

In [None]:
print(gt_df)

In [None]:
print(pred_df)

In [None]:
# Ground truth visualization
visualize_pressure_animation_with_trajectory(gt_var_data, gt_df, projection='3d')

In [None]:
# Prediction visualization
visualize_pressure_animation_with_trajectory(pred_var_data, pred_df, projection='3d')

# Assessment 



Evaluate 30 Sep. 2018 (extreme case, ECMWF-HRES forecast)

```
{'id': '1825', 'dur': 49, 'start': '2018092800', 'end': '2018100712', 'latitude': 7.4, 'longitude': 150.9, 'lp': 900, 'hws': 115, 'grade': 5} # Kong-rey
{'id': '1826', 'dur': 54, 'start': '2018102018', 'end': '2018110300', 'latitude': 8.4, 'longitude': 160.7, 'lp': 900, 'hws': 115, 'grade': 5} # Yutu
{'id': '2211', 'dur': 69, 'start': '2022082718', 'end': '2022090900', 'latitude': 23.8, 'longitude': 151.1, 'lp': 920, 'hws': 105, 'grade': 5} # Hinnamnor

```
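
One way to put a number on such an evaluation is the track error: the great-circle distance between the ground-truth and predicted typhoon centers at each frame. The sketch below is a suggestion, not part of the notebook's pipeline. `haversine_km` and `track_error` are hypothetical helpers that take trajectory DataFrames with the `idx_frame`, `lon` and `lat` columns produced by `extract_typhoon_trajectory`, and it assumes that frame `i` of both trajectories refers to the same valid time.

```python
import numpy as np
import pandas as pd

EARTH_RADIUS_KM = 6371.0  # mean Earth radius

def haversine_km(lat1, lon1, lat2, lon2):
    """Great-circle distance in km between points given in degrees."""
    lat1, lon1, lat2, lon2 = map(np.radians, (lat1, lon1, lat2, lon2))
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * EARTH_RADIUS_KM * np.arcsin(np.sqrt(a))

def track_error(gt_df, pred_df):
    """Per-frame distance between two trajectories; NaN where a center is missing."""
    merged = gt_df.merge(pred_df, on="idx_frame", suffixes=("_gt", "_pred"))
    merged["error_km"] = haversine_km(
        merged["lat_gt"], merged["lon_gt"], merged["lat_pred"], merged["lon_pred"]
    )
    return merged[["idx_frame", "error_km"]]

# Toy trajectories: identical at frame 0, one degree of latitude apart at frame 1.
toy_gt = pd.DataFrame({"idx_frame": [0, 1], "lon": [130.0, 130.0], "lat": [10.0, 11.0]})
toy_pred = pd.DataFrame({"idx_frame": [0, 1], "lon": [130.0, 130.0], "lat": [10.0, 12.0]})
errors = track_error(toy_gt, toy_pred)
```

In the toy example the error is zero at frame 0 and about 111 km at frame 1, since one degree of latitude spans roughly 111 km.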
- FourCastNet autoregressive prediction 
<img src='https://docscontent.nvidia.com/dims4/default/9c9b4d6/2147483647/strip/true/crop/1188x788+0+0/resize/1188x788!/quality/90/?url=https%3A%2F%2Fk3-prod-nvidia-docs.s3.amazonaws.com%2Fbrightspot%2Fsphinx%2F00000187-c443-dd05-a3bf-e57fb6480000%2Fdeeplearning%2Fmodulus%2Fmodulus-sym%2F_images%2Ffourcastnet_overview.png'>


- FourCastNet iterative prediction with adjustment

<img src='https://www.researchgate.net/publication/372137432/figure/fig4/AS:11431281172607179@1688612874115/Pangu-Weather-is-more-accurate-at-early-stage-cyclone-tracking-than-ECMWF-HRES-a-b.png'>
Image from the Pangu-Weather paper.