# NIMS-KISTI-NVIDIA Hackathon 

# Problem Statement 



## Agenda 
 - Introduction to FourCastNet
 - Configure FourCastNet
 - Typhoon dataset (JMA best track)
 - ECMWF ERA5 dataset (CDS API)
 - Custom interval
 - Inference
 - Post-processing
 






## Introduction to FourCastNet

 - [FourCastNet Paper Link](https://arxiv.org/abs/2202.11214)
 - [GitHub Repository](https://github.com/NVlabs/FourCastNet)
 
FourCastNet, short for Fourier ForeCasting Neural Network, is a global data-driven weather forecasting model that provides accurate short to medium range global predictions at 0.25° resolution.
<div align="center">

<img src="https://github.com/NVlabs/FourCastNet/blob/master/assets/FourCastNet.gif?raw=true">
</div>
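
As a quick sanity check on what 0.25° resolution means for array sizes, the sketch below derives the latitude/longitude grid shape. It assumes a regular grid that includes both poles and does not repeat the 0°/360° meridian; the function name `grid_shape` is illustrative and not part of FourCastNet.

```python
def grid_shape(resolution_deg):
    """Return (n_lat, n_lon) for a regular global lat/lon grid."""
    # Latitude runs from -90 to +90 inclusive, so both end points are counted.
    n_lat = int(round(180 / resolution_deg)) + 1
    # Longitude wraps around, so 0 and 360 degrees are the same column.
    n_lon = int(round(360 / resolution_deg))
    return n_lat, n_lon


shape = grid_shape(0.25)
print(shape)  # (721, 1440)
```

These are the H(721) x W(1440) dimensions used for the data arrays later in this notebook.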





### FourCastNet Architecture
FourCastNet uses a Fourier transform-based token-mixing scheme with a vision transformer (ViT) backbone. This approach is based on the recent Fourier neural operator that learns in a resolution-invariant manner and has shown success in modeling challenging partial differential equations (PDE) such as fluid dynamics.


<div align="center">
<img src="https://docscontent.nvidia.com/dims4/default/9c9b4d6/2147483647/strip/true/crop/1188x788+0+0/resize/1188x788!/quality/90/?url=https%3A%2F%2Fk3-prod-nvidia-docs.s3.amazonaws.com%2Fbrightspot%2Fsphinx%2F00000187-c443-dd05-a3bf-e57fb6480000%2Fdeeplearning%2Fmodulus%2Fmodulus-sym%2F_images%2Ffourcastnet_overview.png" >

</div>
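
To give a feel for what "Fourier transform-based token mixing" means, here is a deliberately simplified NumPy sketch: the token grid is transformed with a 2-D FFT, each frequency is multiplied by a filter, and the result is transformed back. This is not the mixer used by FourCastNet itself, which is more elaborate; the function name, the filter shape and the random inputs are all assumptions made for illustration.

```python
import numpy as np


def fourier_token_mix(x, freq_filter):
    """Mix tokens globally by filtering in the frequency domain.

    x:           real array of shape (H, W, C), one token per grid point
    freq_filter: complex array of shape (H, W // 2 + 1, C)
    """
    h, w, _ = x.shape
    x_hat = np.fft.rfft2(x, axes=(0, 1))   # to frequency space
    y_hat = x_hat * freq_filter            # per-frequency, per-channel scaling
    return np.fft.irfft2(y_hat, s=(h, w), axes=(0, 1))  # back to token space


rng = np.random.default_rng(0)
tokens = rng.standard_normal((8, 16, 4))
spec_shape = (8, 16 // 2 + 1, 4)

# An all-ones filter leaves the tokens unchanged.
unchanged = fourier_token_mix(tokens, np.ones(spec_shape, dtype=complex))

# A random filter spreads a change in one token over many grid points.
filt = rng.standard_normal(spec_shape) + 1j * rng.standard_normal(spec_shape)
bumped = tokens.copy()
bumped[0, 0, 0] += 1.0
diff = fourier_token_mix(bumped, filt) - fourier_token_mix(tokens, filt)
affected = int((np.abs(diff[..., 0]) > 1e-9).sum())
print("grid points affected by one changed token:", affected)
```

Because multiplication in frequency space corresponds to a global (circular) convolution in token space, a single layer can exchange information between distant grid points.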

 


### FourCastNet modeled variables:    
| Vertical Level | Variable                |
|:-------------- |:-----------------------|
| Surface        | U10, V10, T2M, SP, MSLP |
| 1000 hPa       | U, V, Z                |
| 850 hPa        | T, U, V, Z, RH         |
| 500 hPa        | T, U, V, Z, RH         |
| 50 hPa         | Z                      |
| Integrated     | TCWV                   |
 


## copy model source code 

### install required modules for FourCastNet

In [None]:
! pip install wandb h5py mpi4py netCDF4 cdsapi ruamel.yaml tqdm timm einops xarray
! pip install git+https://github.com/romerojosh/benchy.git

### copy FourCastNet source code 

In [None]:
!git clone https://github.com/NVlabs/FourCastNet.git

## download checkpoint and stats_v0

The trained model weights are hosted at https://portal.nersc.gov/project/m4134/FCN_weights_v0/ :

```
FCN_weights_v0/
│   backbone.ckpt  
│   precip.ckpt  
```

The pre-computed normalization statistics are hosted in the `stats_v0` subfolder of the same location. If you use the provided pre-trained weights, it is crucial that you also use the provided statistics, because these are the statistics that were used when training the model. The normalization statistics go hand-in-hand with the trained model weights. The stats folder contains:


```
stats_v0
│   global_means.npy  
│   global_stds.npy  
│   land_sea_mask.npy  
│   latitude.npy  
│   longitude.npy  
│   time_means.npy
│   time_means_daily.h5
```


In [None]:
%%time 
!wget -N -P FCN_weights_v0          https://portal.nersc.gov/project/m4134/FCN_weights_v0/backbone.ckpt 
#!wget -N -P FCN_weights_v0         https://portal.nersc.gov/project/m4134/FCN_weights_v0/precip.ckpt               
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/global_means.npy 
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/global_stds.npy  
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/land_sea_mask.npy
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/latitude.npy    
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/longitude.npy    
!wget -N -P FCN_weights_v0/stats_v0 https://portal.nersc.gov/project/m4134/FCN_weights_v0/stats_v0/time_means.npy   

# Typhoon Dataset

### get typhoon data with JMA best track (~2022)
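
The sketch below shows one way to read a best-track text file and derive the start and end time of a typhoon, which can then serve as the custom download interval. It assumes the plain-text layout of the RSMC Tokyo best-track files: a header line starting with `66666` per storm, followed by one line per analysis time holding `yymmddhh`, an indicator, the grade, latitude and longitude in 0.1°, and central pressure in hPa. The function name, the century rule for two-digit years and the sample record are all made up for illustration, so check the official format description before relying on this.

```python
from datetime import datetime

# Synthetic record in the assumed layout (not real best-track data).
SAMPLE = """\
66666 2299  003 0099 2299 0 6              EXAMPLE              20230101
22091212 002 2 220 1400 1004     000
22091218 002 3 222 1395 1000     035
22091300 002 3 225 1390  998     040
"""


def parse_besttrack(text):
    """Return {storm_id: [(time, lat_deg, lon_deg, pressure_hpa), ...]}."""
    storms = {}
    current = None
    for line in text.splitlines():
        tokens = line.split()
        if not tokens:
            continue
        if tokens[0] == "66666":
            # Header line: the second field is the storm's international ID.
            current = tokens[1]
            storms[current] = []
        elif current is not None:
            yy = int(tokens[0][:2])
            # Assumed century rule: the archive starts in 1951.
            year = 1900 + yy if yy >= 51 else 2000 + yy
            time = datetime.strptime(f"{year}{tokens[0][2:]}", "%Y%m%d%H")
            lat = int(tokens[3]) / 10.0   # stored in units of 0.1 degree
            lon = int(tokens[4]) / 10.0
            pressure = int(tokens[5])     # central pressure in hPa
            storms[current].append((time, lat, lon, pressure))
    return storms


tracks = parse_besttrack(SAMPLE)
track = tracks["2299"]
start, end = track[0][0], track[-1][0]
print(start, end)
```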

## download dataset with CDS API 

FourCastNet modeled variables
<table align="left" border="1">
  <tr>
    <th>Vertical Level</th>
    <th>Variable</th>
  </tr>
  <tr>
    <td>Surface</td>
    <td>U10, V10, T2M, SP, MSLP</td>
  </tr>
  <tr>
    <td>1000 hPa</td>
    <td>U, V, Z</td>
  </tr>
  <tr>
    <td>850 hPa</td>
    <td>T, U, V, Z, RH</td>
  </tr>
  <tr>
    <td>500 hPa</td>
    <td>T, U, V, Z, RH</td>
  </tr>
  <tr>
    <td>50 hPa</td>
    <td>Z</td>
  </tr>
  <tr>
    <td>Integrated</td>
    <td>TCWV</td>
  </tr>
</table>
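
A minimal sketch of the corresponding CDS API requests follows: one for the single-level variables and one for the pressure-level variables. The helper names (`build_requests`, `download`), the 6-hourly time list and the output file stem are assumptions. Request keys have also changed between CDS versions (for example `format` versus `data_format`), so compare with the download form of each dataset before running it. The `_sl.nc` / `_pl.nc` suffixes and the `./custom_interval` folder match what the conversion script in this notebook expects.

```python
SINGLE_LEVEL_VARS = [
    "10m_u_component_of_wind", "10m_v_component_of_wind", "2m_temperature",
    "surface_pressure", "mean_sea_level_pressure", "total_column_water_vapour",
]
PRESSURE_LEVEL_VARS = [
    "temperature", "u_component_of_wind", "v_component_of_wind",
    "geopotential", "relative_humidity",
]
PRESSURE_LEVELS = ["50", "500", "850", "1000"]


def build_requests(year, month, days, times):
    """Build the single-level and pressure-level request dictionaries."""
    common = {
        "product_type": "reanalysis",
        "format": "netcdf",  # assumption: newer CDS versions may expect "data_format"
        "year": year, "month": month, "day": days, "time": times,
    }
    sl_request = dict(common, variable=SINGLE_LEVEL_VARS)
    pl_request = dict(common, variable=PRESSURE_LEVEL_VARS,
                      pressure_level=PRESSURE_LEVELS)
    return sl_request, pl_request


def download(sl_request, pl_request, stem):
    """Download both files; requires CDS credentials in ~/.cdsapirc."""
    import cdsapi  # imported here so building requests works without it

    client = cdsapi.Client()
    client.retrieve("reanalysis-era5-single-levels", sl_request, f"{stem}_sl.nc")
    client.retrieve("reanalysis-era5-pressure-levels", pl_request, f"{stem}_pl.nc")


sl_req, pl_req = build_requests(
    "2022", "09", ["12", "13"], ["00:00", "06:00", "12:00", "18:00"])
# download(sl_req, pl_req, "./custom_interval/example")  # hypothetical file stem
```

After downloading, it is worth checking that the level axis of the `_pl.nc` file is ordered 50, 500, 850, 1000 hPa, because the conversion step selects levels by index.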



# Data conversion (NetCDF4 to HDF5)

## PyTorch data loader 
- ECMWF ERA5 dataset: multiple NetCDF4 files
  - SL (single level): 6 variables (u10, v10, t2m, sp, msl, tcwv), each time_frame x H(721) x W(1440)
  - PL (pressure level): 5 variables (t, u, v, z, r), each time_frame x pressure levels(4) x H(721) x W(1440)
  - lat/lon coordinates are dropped (they are identical across all files)
- PyTorch DataLoader: HDF5 data (structured data)
  - time_frame x variables(20) x H(721) x W(1440)
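
The 20-channel layout can be written down as a lookup table. The sketch below mirrors the order of the `writetofile` calls in `merge_h5.py`. The names `LEVELS`, `CHANNELS` and `channel_name` are illustrative, and the level order (index 0 = 50 hPa up to index 3 = 1000 hPa) is an assumption about how the pressure-level NetCDF files are stored.

```python
LEVELS = [50, 500, 850, 1000]  # assumed order of the level axis in *_pl.nc

# (source file, NetCDF variable, level index or None for single-level fields)
CHANNELS = [
    ("sl", "u10", None), ("sl", "v10", None), ("sl", "t2m", None),
    ("sl", "sp", None), ("sl", "msl", None),
    ("pl", "t", 2),
    ("pl", "u", 3), ("pl", "v", 3), ("pl", "z", 3),
    ("pl", "u", 2), ("pl", "v", 2), ("pl", "z", 2),
    ("pl", "u", 1), ("pl", "v", 1), ("pl", "z", 1),
    ("pl", "t", 1), ("pl", "z", 0), ("pl", "r", 1), ("pl", "r", 2),
    ("sl", "tcwv", None),
]


def channel_name(source, var, level_idx):
    """Human-readable channel name, e.g. ('pl', 'z', 1) -> 'z500'."""
    return var if source == "sl" else f"{var}{LEVELS[level_idx]}"


names = [channel_name(*channel) for channel in CHANNELS]
# float32 storage: bytes needed for one time frame of all 20 channels
bytes_per_frame = len(CHANNELS) * 721 * 1440 * 4
print(names)
print(f"{bytes_per_frame / 2**20:.1f} MiB per time frame")
```

Keeping the order in one table makes it easier to confirm that the converted file matches the channel order the pre-trained weights expect.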
  

In [None]:
%%file merge_h5.py 

import os
import h5py
import numpy as np
from netCDF4 import Dataset as DS
from glob import glob

netcdf_dir = './custom_interval'
hdf5_dir = './hdf5_dir'
prefix = 'merged'
file_ext = 'h5'
dset_name = 'fields'

def get_matched_files(netcdf_dir, DEBUG=False):
    pl_files = sorted(glob(os.path.join(netcdf_dir, '*_pl.nc')))
    sl_files = sorted(glob(os.path.join(netcdf_dir, '*_sl.nc')))
    
    matched_files = []
    for pl_file in pl_files:
        sl_file = pl_file[:-len('_pl.nc')] + '_sl.nc'
        if sl_file in sl_files:
            matched_files.append((pl_file, sl_file))
    if DEBUG :
        print( f"DEBUG {len(matched_files)}")
    
    return matched_files

def generate_output_filename(orig_filename, idx,  hdf5_dir, prefix):
    base_name = os.path.basename(orig_filename) 
    base_name_without_ext = os.path.splitext(base_name)[0][:-3] # this removes the last '_pl' or '_sl' part
    return os.path.join(hdf5_dir, f"{prefix}_{idx:02d}_{base_name_without_ext}.h5")

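def writetofile(src, dest, channel_idx, varslist, src_idx=0, frmt='nc', DEBUG=False):
    """Copy each variable in varslist from src into consecutive channels of dest."""
    if not os.path.isfile(src):
        print(f"WARNING: {src} not found, skipping {varslist}")
        return

    # Open the destination once in append mode; it is closed even on error.
    with h5py.File(dest, 'a') as fdest:
        for variable_name in varslist:
            # Open the source file and look up the current variable.
            if frmt == 'nc':
                fsrc_file = DS(src, 'r', format="NETCDF4")
                fsrc = fsrc_file.variables[variable_name]
            elif frmt == 'h5':
                fsrc_file = h5py.File(src, 'r')
                fsrc = fsrc_file[variable_name]
            else:
                raise ValueError(f"unsupported source format: {frmt}")

            try:
                if dset_name not in fdest:
                    print("DEBUG: create dataset")
                    Nimgtot = fsrc.shape[0]  # number of time frames
                    fdest.create_dataset(dset_name, (Nimgtot, 20, 721, 1440), dtype='f')

                # 4-D variables are (time, level, lat, lon): pick one pressure level.
                if len(fsrc.shape) == 4:
                    ims = fsrc[:, src_idx]
                else:
                    ims = fsrc[:]

                fdest[dset_name][:, channel_idx, :, :] = ims
            finally:
                fsrc_file.close()

            if DEBUG:
                print("done", variable_name, ims.shape)
            channel_idx += 1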

def process_files(files, hdf5_dir, prefix, file_ext, DEBUG=False):
    for i, (pl_file, sl_file) in enumerate(files):
        output_file = generate_output_filename(pl_file, i, hdf5_dir, prefix)
        if DEBUG : 
            print( f"DEBUG {i} {sl_file}  {pl_file} {output_file}") 

        writetofile(sl_file, output_file, 0, ['u10'])
        writetofile(sl_file, output_file, 1, ['v10'])
        writetofile(sl_file, output_file, 2, ['t2m'])
        writetofile(sl_file, output_file, 3, ['sp'])
        writetofile(sl_file, output_file, 4, ['msl'])

        writetofile(pl_file, output_file, 5, ['t'], 2)
        writetofile(pl_file, output_file, 6, ['u'], 3)
        writetofile(pl_file, output_file, 7, ['v'], 3)
        writetofile(pl_file, output_file, 8, ['z'], 3)

        writetofile(pl_file, output_file, 9, ['u'], 2)
        writetofile(pl_file, output_file, 10, ['v'], 2)
        writetofile(pl_file, output_file, 11, ['z'], 2)

        writetofile(pl_file, output_file, 12, ['u'], 1)
        writetofile(pl_file, output_file, 13, ['v'], 1)
        writetofile(pl_file, output_file, 14, ['z'], 1)

        writetofile(pl_file, output_file, 15, ['t'], 1)
        writetofile(pl_file, output_file, 16, ['z'], 0)
        writetofile(pl_file, output_file, 17, ['r'], 1)
        writetofile(pl_file, output_file, 18, ['r'], 2)

        writetofile(sl_file, output_file, 19, ['tcwv'])

def main():
    DEBUG= True
    print("get lists")
    files_to_process = get_matched_files(netcdf_dir, DEBUG=DEBUG)
    os.makedirs(hdf5_dir, exist_ok=True)
    
    print("start process")
    process_files(files_to_process, hdf5_dir, prefix, file_ext, DEBUG=DEBUG)

if __name__ == '__main__':
    main()


In [None]:
!python3 merge_h5.py

## visualize variables 

In [None]:
import h5py
import matplotlib.pyplot as plt

def get_variable(file_path, variable_name):
    # Open the HDF5 file read-only; the context manager closes it on exit
    with h5py.File(file_path, 'r') as hdf5_file:
        # Read the whole dataset into memory as a NumPy array
        variable = hdf5_file[variable_name][:]

    return variable


def plot_variable(variable, frame_index, variable_index):
    variable_names = ['u10', 'v10', 't2m', 'sp', 'msl', 't850', 'u1000', 'v1000', 'z1000', 'u850', 'v850', 'z850', 'u500', 'v500', 'z500', 't500', 'z50', 'r500', 'r850', 'tcwv']
    vmin_values = [-30, -30, 220, None, None, 220, -30, -30, None, -30, -30, None, -30, -30, None, 220, None, 0, None, None]
    vmax_values = [30, 30, 300, None, None, 300, 30, 30, None, 30, 30, None, 30, 30, None, 270, None, 120, None, None]
    cmaps = [None, None, 'coolwarm', None, None, 'coolwarm', None, None, None, None, None, None, None, None, None, 'coolwarm', None, None, None, None]

    variable_name = variable_names[variable_index]
    vmin = vmin_values[variable_index]
    vmax = vmax_values[variable_index]
    cmap = cmaps[variable_index]

    plt.figure(figsize=(8, 4))

    if vmin is not None and vmax is not None:
        plt.imshow(variable[frame_index][variable_index], vmin=vmin, vmax=vmax, cmap=cmap)
    else:
        plt.imshow(variable[frame_index][variable_index], cmap=cmap)

    plt.colorbar(shrink=0.75)
    plt.title(f"Variable {variable_name} at Frame {frame_index}")
    plt.show()

        

In [None]:
variables = get_variable('./hdf5_dir/merged_01_2214_2022091212_2022092000.h5', 'fields' )

In [None]:
variables.shape

In [None]:
for i in range(20):
    plot_variable(variables,8,i)

## normalize 
The FourCastNet data loader normalizes each variable with the per-variable mean and standard deviation stored in stats_v0 (`global_means.npy`, `global_stds.npy`):


```
stats_v0
│   global_means.npy  
│   global_stds.npy  
│   land_sea_mask.npy  
│   latitude.npy  
│   longitude.npy  
│   time_means.npy
│   time_means_daily.h5
```
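
The normalization itself is a per-channel z-score. The sketch below uses synthetic stand-ins for the statistics so that it runs without the downloaded files. It assumes the stats arrays broadcast against a (time_frame, channel, H, W) array, for example with shape (1, C, 1, 1); if the real files hold more channels than the 20 used here, slice them first. The function names are illustrative.

```python
import numpy as np


def normalize(fields, means, stds):
    """Per-channel z-score: subtract the mean, divide by the std."""
    return (fields - means) / stds


def denormalize(fields, means, stds):
    """Inverse of normalize, used to bring predictions back to physical units."""
    return fields * stds + means


rng = np.random.default_rng(0)
n_channels = 20

# Synthetic stand-ins; in practice these would come from
# np.load("FCN_weights_v0/stats_v0/global_means.npy") and global_stds.npy.
means = rng.normal(size=(1, n_channels, 1, 1))
stds = rng.uniform(0.5, 2.0, size=(1, n_channels, 1, 1))

# Small fake batch: 3 time frames on an 8 x 16 grid instead of 721 x 1440.
fields = means + stds * rng.standard_normal((3, n_channels, 8, 16))

normed = normalize(fields, means, stds)
restored = denormalize(normed, means, stds)
print(normed.shape, float(np.abs(restored - fields).max()))
```

Model outputs are in the same normalized space, so the inverse transform is needed before plotting or comparing with ERA5 values.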
