<a href="https://colab.research.google.com/github/sanAkel/ufs_diurnal_diagnostics/blob/main/RTOFS/AWS_download/rtofs_aws_download_plot_ssh.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## RTOFS nowcast (`n`) 2-D files from [AWS S3](https://registry.opendata.aws/noaa-rtofs/)

### Files are downloaded from [this URL](https://noaa-nws-rtofs-pds.s3.amazonaws.com/index.html).

In [None]:
!pip install cartopy

In [None]:
import xarray as xr
import numpy as np
import pandas as pd

import glob as glob
from PIL import Image

import matplotlib.pyplot as plt

import cartopy.crs as ccrs
import cartopy.feature as cfeature
from cartopy.mpl.gridliner import LONGITUDE_FORMATTER, LATITUDE_FORMATTER

#%matplotlib inline

## Cut-out region coordinates (see [this notebook](https://github.com/sanAkel/ufs_diurnal_diagnostics/blob/main/RTOFS/AWS_download/rtofs_cut_out_region.ipynb) if there are issues)

In [None]:
def get_index(lat_array, lon_array, lat0, lon0):
  """Return the index of the grid point nearest to (lat0, lon0).

  "Nearest" means the smallest max(|lat - lat0|, |lon - lon0|) in degrees.
  Longitude differences are NOT wrapped modulo 360, so ``lon0`` must use the
  same longitude convention as ``lon_array``.

  Parameters
  ----------
  lat_array, lon_array : numpy.ndarray
      2-D arrays (same shape) of grid-point latitudes / longitudes.
  lat0, lon0 : float
      Target latitude and longitude.

  Returns
  -------
  list of int
      ``[xloc, yloc]``: the index along axis 0 and along axis 1 of the arrays.
      If several points tie for nearest, the first one in row-major order is
      returned. NaN coordinates are ignored.
  """
  abslat = np.abs(lat_array - lat0)
  abslon = np.abs(lon_array - lon0)
  c = np.maximum(abslon, abslat)

  # nanargmin always yields a single position (ties -> first occurrence) and
  # skips NaNs. Unpacking np.where(c == c.min()) raised ValueError on ties.
  xloc, yloc = np.unravel_index(np.nanargmin(c), c.shape)
  return [int(xloc), int(yloc)]

def get_cutOut(ds, lon_s, lon_e, lat_s, lat_e):
  """Locate the index box that covers [lat_s, lat_e] x [lon_s, lon_e].

  Parameters
  ----------
  ds : xarray.Dataset
      Must provide ``Latitude`` and ``Longitude`` arrays. Assumes they are 2-D
      with axis 0 mapping to ``Y`` and axis 1 to ``X``, matching how the caller
      slices with the result. TODO confirm against the file's dims.
  lon_s, lon_e, lat_s, lat_e : float
      Box corners, in the same longitude convention as ``ds.Longitude``.

  Returns
  -------
  list
      ``[y1, y2, x1, x3]``: start/end index along axis 1 (from the lower-left
      and lower-right corners) followed by start/end index along axis 0 (from
      the lower-left and upper-left corners).
  """
  # Read the coordinate arrays once instead of once per corner.
  lat2d = ds.Latitude.values
  lon2d = ds.Longitude.values

  # Lower-left corner: start row (x1) and start column (y1).
  x1, y1 = get_index(lat2d, lon2d, lat_s, lon_s)
  # Lower-right corner: only its column is needed (its row is expected to be
  # close to x1 because it lies on the same latitude).
  _, y2 = get_index(lat2d, lon2d, lat_s, lon_e)
  # Upper-left corner: only its row is needed.
  x3, _ = get_index(lat2d, lon2d, lat_e, lon_s)
  # The upper-right corner was never used in the result, so it is not
  # searched. Each search scans the whole grid.

  return [y1, y2, x1, x3]

In [None]:
def plot_rtofs_ssh(ds, xMin, xMax, yMin, yMax,
                   vName1='ssh',
                   vName2='surf_curr'):
  """Plot ``vName1`` (filled + contours) over an X/Y box and save it as a PNG.

  Parameters
  ----------
  ds : xarray.Dataset
      Must contain ``MT`` (its first value gives the title/file date) and the
      variable ``vName1`` with ``X``/``Y`` coordinates.
  xMin, xMax, yMin, yMax :
      Bounds passed to ``.sel(X=slice(xMin, xMax), Y=slice(yMin, yMax))``.
  vName1 : str
      Variable to draw. Colour limits are fixed at -0.5 .. 0.5.
  vName2 : str
      Currently unused; kept so existing calls keep working.

  Side effects: prints the date string and writes
  ``<vName1>_<YYYY-MM-DD>T<HH>.png`` to the current directory.
  """
  # form date string; expects an ISO-like 'YYYY-MM-DDTHH:MM:...' representation
  yyyymmdd = ds.MT.values[0].astype("str")
  dStr = yyyymmdd.split('T')[0] + 'T' + yyyymmdd.split('T')[1].split(':')[0]
  print(dStr)

  # The original referenced an undefined `vName`. Squeeze singleton dims
  # (e.g. a length-1 time axis) so xarray draws a 2-D map, not a histogram.
  field = ds[vName1].sel(X=slice(xMin, xMax), Y=slice(yMin, yMax)).squeeze()

  fig = plt.figure(figsize=(8, 6), dpi=180)
  ax = fig.add_subplot(111)

  im1 = field.plot(ax=ax, vmin=-.5, vmax=.5, cmap="gist_ncar", add_colorbar=False)
  field.plot.contour(ax=ax, vmin=-.5, vmax=.5, cmap="gray")

  cbar = fig.colorbar(im1, ax=ax)
  if vName1 == 'ssh':
    cbar.set_label("Sea Surface Height (m)")

  ax.set_title("{}".format(dStr))
  ax.set_xlabel("")
  ax.set_ylabel("")

  figName = vName1 + '_' + dStr + '.png'
  fig.savefig(figName)
  # Close this specific figure so figures do not pile up when called in a loop.
  plt.close(fig)

In [None]:
def plot_region(plot_data, vName1='ssh', vName2='surf_curr',
                vMin1=-0.3, vMax1=0.8, cMap1='gist_ncar',
                vMin2=0., vMax2=1., cMap2='gray'):
  """Map ``vName1`` (filled) with ``vName2`` contours on a cartopy map and save a PNG.

  Parameters
  ----------
  plot_data : xarray.Dataset
      Must provide ``Longitude`` and ``Latitude`` arrays, the variables
      ``vName1`` and ``vName2`` (singleton dims are squeezed away), and ``MT``.
      The first ``MT`` value is used for the title and file name.
  vName1, vMin1, vMax1, cMap1 :
      Field drawn with pcolormesh, and its colour limits / colormap.
  vName2, vMin2, vMax2, cMap2 :
      Field drawn with contours, and its colour limits / colormap.

  Side effects: prints the date string and writes
  ``<vName1>_<vName2>_<YYYY-MM-DD>HH<hh>.png`` (tight bbox) to the current
  directory.
  """
  fig = plt.figure(figsize=(10, 12), dpi=180)
  ax = fig.add_subplot(1, 1, 1, projection=ccrs.PlateCarree(central_longitude=-180.0))

  ax.add_feature(cfeature.LAND, zorder=0,
                 edgecolor='k', facecolor=("lightgray"), alpha=0.7)
  ax.coastlines(alpha=0.2)

  gl = ax.gridlines(draw_labels=True)
  gl.top_labels = False
  gl.right_labels = False
  gl.xformatter = LONGITUDE_FORMATTER
  gl.yformatter = LATITUDE_FORMATTER
  gl.xlabel_style = {'size': 6}
  gl.ylabel_style = {'size': 6}

  lon = plot_data['Longitude'].values
  lat = plot_data['Latitude'].values

  im1 = ax.pcolormesh(lon, lat,
                      plot_data[vName1].values.squeeze(),
                      transform=ccrs.PlateCarree(),
                      vmin=vMin1, vmax=vMax1, cmap=cMap1)

  ax.contour(lon, lat,
             plot_data[vName2].values.squeeze(),
             transform=ccrs.PlateCarree(),
             vmin=vMin2, vmax=vMax2, cmap=cMap2, alpha=0.3)

  cbar = fig.colorbar(im1, ax=ax, pad=0.03, orientation='horizontal', shrink=0.85)

  if vName1 == 'ssh':
    cbar.set_label("Sea Surface Height (m)")

  # form date string from the data actually being plotted. The original read
  # a global `ds`, which only worked because the download loop defined one.
  yyyymmdd = plot_data.MT.values[0].astype("str")
  dStr = yyyymmdd.split('T')[0] + 'HH' + yyyymmdd.split('T')[1].split(':')[0]
  print(dStr)
  ax.set_title("{}".format(dStr))

  figName = vName1 + '_' + vName2 + '_' + dStr + '.png'
  fig.savefig(figName, bbox_inches='tight')
  # Close this specific figure so figures do not accumulate across the loop.
  plt.close(fig)

## User inputs

In [None]:
# Inclusive date range (YYYYMMDD) of daily rtofs.<date>/ folders to download.
# Per the original note, file time stamps lag the folder date by 1 day.
start_date, end_date = ["20250321", "20250331"]

hSkip = 12  # plot every hSkip hours (12 -> hours 0, 12 and 24)
hours = np.arange(0, 24 + 1, hSkip)  # 0..24 inclusive, step hSkip

# Plot region bounds in degrees; longitudes here go above 180 (0-360 style).
lat_s, lat_e = [-50, 50]
lon_s, lon_e = [100, 290]

### Set fixed parameters

In [None]:
# One entry per calendar day from start_date through end_date (inclusive).
data_dates = pd.date_range(start_date, end_date)

# Files live under <url_base><YYYYMMDD>/ and are named <fPref><HHH><fSuff>.
url_base = "https://noaa-nws-rtofs-pds.s3.amazonaws.com/rtofs."
fPref, fSuff = "rtofs_glo_2ds_n", "_diag.nc"

## Download each file, plot it, save the figure, and delete the downloaded file

In [None]:
for dd in data_dates:
      print("Downloading data for:\t", dd.strftime("%Y%m%d"))
      for hr in hours:
        url = url_base + dd.strftime("%Y%m%d") + "/" +\
        fPref + str(hr).zfill(3)+ fSuff
        #fName = fPref + str(hr).zfill(3)+ fSuff
        fName_save = fPref + dd.strftime("_%Y%m%d_") + str(hr).zfill(3)+ fSuff
        #print(url, fName_save)

        !wget $url -O $fName_save
        ds = xr.open_dataset(fName_save)

        if (dd == data_dates[0] and hr== hours[0]):
          # Get cut-out box limits
          [y1, y2, x1, x3] = get_cutOut(ds, lon_s, lon_e, lat_s, lat_e)

        plot_data=ds.sel(X=slice(y1, y2), Y=slice(x1, x3))
        plot_data['surf_curr'] = np.sqrt(plot_data['u_barotropic_velocity']**2 + plot_data['v_barotropic_velocity']**2)

        #plot_rtofs_ssh(ds, xMin, xMax, yMin, yMax)
        plot_region(plot_data) # nicer plot than what above makes

        !rm $fName_save

In [None]:
# Move all PNG frames written above into figs/. This is plain Python, so it is
# portable and does not error when there are no PNGs yet.
import os
import shutil

FIG_DIR = 'figs'
os.makedirs(FIG_DIR, exist_ok=True)
for png in glob.glob('*.png'):
  # Use a full destination path so an existing file of the same name is replaced, as `mv` does.
  shutil.move(png, os.path.join(FIG_DIR, png))

## Make a GIF animation from the PNG files

In [None]:
import os

png_files_path = 'figs/'
png_files = sorted(glob.glob(os.path.join(png_files_path, 'ssh_*.png')))
if not png_files:
  raise FileNotFoundError('No ssh_*.png frames found in {}'.format(png_files_path))

DUR, LOOP = [150, 0] # Duration in milliseconds, infinite loop

# Read each frame fully into memory and close its file right away.
# .copy() forces PIL's lazy loader to read the pixels before the file closes.
images = []
for filename in png_files:
  with Image.open(filename) as im:
    images.append(im.copy())

# save as a gif: images[0] is the base frame and only the remaining frames are
# appended. Appending images[0:] showed the first frame twice.
fOut = 'rtofs_ssh_' + start_date + '_' + end_date + '.gif'
images[0].save(fOut, save_all=True,
               append_images=images[1:],
               optimize=False,
               duration=DUR, # Duration in milliseconds
               loop=LOOP) # infinite loop

print('\nSaved:\t{}\n'.format(fOut))