In [1]:
import numpy as np
import xarray as xr
import fsspec
import matplotlib
import matplotlib.pyplot as plt
import pyqg
import json

from pyqg.diagnostic_tools import calc_ispec as _calc_ispec
%run coarsening_ops.ipynb

%matplotlib inline
plt.rcParams.update({'font.size': 13})

def calc_ispec(*args, **kwargs):
    """Call pyqg's ``diagnostic_tools.calc_ispec`` with ``averaging=False, truncate=False``.

    All other positional and keyword arguments are forwarded unchanged. Passing
    ``averaging`` or ``truncate`` explicitly raises ``TypeError`` (duplicate
    keyword argument), exactly as the previous lambda did.
    """
    return _calc_ispec(*args, averaging=False, truncate=False, **kwargs)



### ––– Forcing Data ––– 
---
###### Inputs:
- ```nx_hires```: number of grid points in each horizontal direction of the high-resolution model (spatial)
- ```nx_lores```: number of grid points in each horizontal direction of the low-resolution (coarse-grained) model (spatial)
- ```dt```: model timestep, in seconds (temporal)
- ```tmax```: total simulation time, in seconds (temporal)
- ```tavestart```: time, in seconds, at which diagnostic averaging begins (temporal)


###### Outputs:
- A list of ```forcing datasets```, each containing filtered and coarse-grained snapshots of the high-resolution simulation along with diagnostics and subgrid forcing variables.
- Each element of the list corresponds to one particular filtering and coarse-graining method.
---
Notes: 
- The temporal parameters can be adjusted to produce larger or smaller datasets. The model reaches a quasi-steady state after roughly 5 years; the defaults run for 3600 days (~10 years) and start averaging diagnostics after 1800 days (~5 years).
- The spatial parameters will likely stay fixed: whether the model develops eddies/jets depends on how well the grid spacing L/nx resolves the Rossby deformation radius, where L is the physical domain length and nx the number of grid points. Since L is held more or less constant, nx should be as well.
---
TO DO:
- Generalize the function to accept an arbitrary number of filtering and coarse-graining operators
    - add an input parameter listing these operators

In [2]:
def _carry_lowres_diagnostics(op, high_res, prev_diagnostics, count):
    """Restore and, when due, increment the running diagnostics of ``op.m2``.

    Every operator call builds a fresh low-resolution model ``op.m2``, whose
    diagnostics would otherwise start over at each snapshot. This re-attaches the
    diagnostics dict saved at the previous snapshot (if any) and increments it
    under the averaging condition used by pyqg (see link below), evaluated on the
    high-resolution model's clock.

    Parameters
    ----------
    op : operator instance exposing a low-resolution model ``op.m2``
    high_res : the high-resolution ``pyqg.QGModel`` being stepped
    prev_diagnostics : dict or None
        ``op.m2.diagnostics`` saved at the previous snapshot; None on the first one.
    count : int
        Number of increments applied so far to ``prev_diagnostics``.

    Returns
    -------
    (dict, int)
        The diagnostics dict to carry forward and the updated increment count.
    """
    op.m2._initialize_diagnostics(diagnostics_list='all')
    if prev_diagnostics is not None:
        op.m2.diagnostics = prev_diagnostics

    # See: _increment_diagnostics()
    # Link: https://github.com/pyqg/pyqg/blob/8a792d4d4d36580af025b002417be60afd6a991a/pyqg/model.py#L770
    due = (high_res.t >= high_res.dt
           and high_res.t >= high_res.tavestart
           and high_res.tc % high_res.taveints == 0)
    if due:
        # Pin pyqg's per-diagnostic counters to our own tally so accumulation
        # stays consistent even though op.m2 is rebuilt at every snapshot.
        for diag in op.m2.diagnostics.values():
            if diag['active']:
                diag['count'] = count
        op.m2._increment_diagnostics()
        count += 1
    return op.m2.diagnostics, count


def _forcing_snapshot(op, high_res):
    """Return one time slice of coarse-grained state plus subgrid forcing terms.

    Starts from a deep copy of ``op.m2.to_dataset()`` with ``time`` set to the
    high-resolution model time, then adds the subgrid forcings of q/u/v, the
    subgrid fluxes, and the ``dqdt_bar`` / ``dqbar_dt`` tendencies, all on
    dimensions ``('lev', 'y', 'x')``.
    """
    dims = ['lev', 'y', 'x']
    snap = op.m2.to_dataset().copy(deep=True)
    snap['time'] = [high_res.t]

    for var in ['q', 'u', 'v']:
        snap[var + '_subgrid_forcing'] = (dims, op.subgrid_forcing(var))

    # subgrid_fluxes(var) returns the (u-flux, v-flux) pair of `var`, i.e. the
    # (u*var, v*var) terms, as the 'q' case shows. For var='v' that is
    # (u*v, v*v): the v*v flux is the SECOND element, and the first duplicates
    # the u*v flux already obtained from subgrid_fluxes('u').
    uq_flux, vq_flux = op.subgrid_fluxes('q')
    snap['uq_subgrid_flux'] = (dims, uq_flux)
    snap['vq_subgrid_flux'] = (dims, vq_flux)

    uu_flux, uv_flux = op.subgrid_fluxes('u')
    _, vv_flux = op.subgrid_fluxes('v')
    snap['uu_subgrid_flux'] = (dims, uu_flux)
    snap['vv_subgrid_flux'] = (dims, vv_flux)
    snap['uv_subgrid_flux'] = (dims, uv_flux)

    snap['dqdt_bar'] = (dims, op.coarsen(op.m1.dqhdt))
    snap['dqbar_dt'] = (dims, op.to_real(op.m2.dqhdt))
    return snap


def _tidy_forcing_dataset(snapshots, nx_hires, nx_lores, pyqg_params):
    """Concatenate per-snapshot datasets along ``time`` and prepare them for saving.

    Variables missing from the concatenated result are re-attached from the last
    snapshot, complex-valued variables and ``dqdt`` are dropped, and resolution /
    parameter metadata is stored in ``attrs``.
    """
    ds = xr.concat(snapshots, dim='time')

    # Add back diagnostics that may have been dropped by the concat (some lack a
    # time coordinate); they are already time-averaged, so keep the latest value.
    for name, var in snapshots[-1].variables.items():
        if name not in ds:
            ds[name] = var.isel(time=-1) if 'time' in var.dims else var

    # Complex variables cannot be saved; dqdt is dropped too (if present).
    complex_vars = [name for name, var in ds.variables.items() if np.iscomplexobj(var)]
    ds = ds.drop_vars(complex_vars).drop_vars('dqdt', errors='ignore')

    ds.attrs['hires'] = nx_hires
    ds.attrs['lores'] = nx_lores
    # Each dataset gets its own copy so editing one dataset's attrs cannot
    # silently change another's.
    # NOTE(review): netCDF cannot store dict attrs; json.dumps(...) may be needed there.
    ds.attrs['pyqg_params'] = dict(pyqg_params)
    return ds


def generate_forcing_data(nx_hires=256, nx_lores=64, dt=3600.0, tmax=311040000.0,
                          tavestart=155520000.0, operators=None, sampling_freq=1000):
    """Run a high-resolution pyqg model and build coarse-grained forcing datasets.

    Every ``sampling_freq`` timesteps, each operator filters and coarse-grains the
    current high-resolution state. The low-resolution snapshot, its running
    diagnostics (carried across snapshots), and the subgrid forcings/fluxes are
    recorded. Snapshots are finally joined along ``time``.

    Parameters
    ----------
    nx_hires, nx_lores : int
        Grid points of the high- and low-resolution models.
    dt, tmax, tavestart : float
        Timestep, total simulation time and diagnostic-averaging start time, in
        seconds; passed to ``pyqg.QGModel``.
    operators : sequence of callables, optional
        Each is called as ``make_op(high_res, nx_lores)`` and must return an
        operator exposing ``m1``, ``m2``, ``subgrid_forcing``, ``subgrid_fluxes``,
        ``coarsen`` and ``to_real``. Defaults to ``(Operator1, Operator2, Operator3)``.
    sampling_freq : int, optional
        Number of timesteps between snapshots (1000 steps = 1000 hours when dt=3600).

    Returns
    -------
    list of xarray.Dataset
        One forcing dataset per operator, in the order given.

    Raises
    ------
    ValueError
        If no snapshot was collected (e.g. ``tmax <= 0``).
    """
    if operators is None:
        operators = (Operator1, Operator2, Operator3)
    operators = list(operators)
    n_ops = len(operators)

    base_kwargs = {'dt': dt, 'tmax': tmax, 'tavestart': tavestart}
    high_res = pyqg.QGModel(nx=nx_hires, **base_kwargs)

    snapshots = [[] for _ in range(n_ops)]  # per-operator list of time slices
    prev_diagnostics = [None] * n_ops        # diagnostics carried between snapshots
    counts = [0] * n_ops                     # diagnostic increments per operator

    while high_res.t < high_res.tmax:
        if high_res.tc % sampling_freq == 0:
            # Build every operator from the current state before any is used,
            # matching the original construct-all-then-process order.
            ops = [make_op(high_res, nx_lores) for make_op in operators]
            for i, op in enumerate(ops):
                prev_diagnostics[i], counts[i] = _carry_lowres_diagnostics(
                    op, high_res, prev_diagnostics[i], counts[i])
                snapshots[i].append(_forcing_snapshot(op, high_res))
        high_res._step_forward()

    if any(len(snaps) == 0 for snaps in snapshots):
        raise ValueError("no snapshots were collected; check tmax and sampling_freq")

    pyqg_params = dict(base_kwargs, nx=nx_lores)
    return [_tidy_forcing_dataset(snaps, nx_hires, nx_lores, pyqg_params)
            for snaps in snapshots]