This notebook provides everything needed to run the QG experiment presented in the paper. It is set up for the LES simulation (on a reduced grid, with a CNN-based closure term). To adapt it to the online-learning experiment, proceed as follows:
* Set the parameter values to those of the DNS (see Table 3 of the paper).
* Run the DNS simulations (in DNS mode, i.e. with the CNN closure term removed).
* Filter the DNS outputs onto the LES grid; this produces the training data.
* Switch back to LES mode (this notebook) and run the training.

Feel free to contact me at said.ouala@imt-atlantique.fr if you need help :)

In [12]:
import math
import torch
from L63.utils import  extract_and_reset_grads
from grid import TwoGrid

In [2]:
# Parameters of the LES simulation.
# `scale` is the coarsening factor applied to a 1024 x 1024 reference grid;
# the time step is enlarged by the same factor.
scale = 16
pretrained = True  # also exposed as a notebook-level name
params = {
    # --- domain and grid ---
    'Lx': 2 * math.pi,    # length of domain in x-direction
    'Ly': 2 * math.pi,    # length of domain in y-direction
    'Nx': 1024 // scale,  # nb grid points in x
    'Ny': 1024 // scale,  # nb grid points in y
    'dt': 5e-5 * scale,   # time step
    # --- linear-operator / physics coefficients ---
    'mu': 0.1,            # linear drag coefficient
    'nu': 1 / 20000,      # viscosity
    't0': 0.0,            # amplitude of wind stress [kg/ms^2]
    'nv': 1,              # power applied to krsq in the viscous term
    'B': 0.0,             # coefficient of the 1j*B*kr*irsq (beta-type) term
    # Fall back to the CPU when no GPU is present instead of failing at the first .to().
    'device': "cuda" if torch.cuda.is_available() else "cpu",
    # --- training settings (presumably read by the training loop, not by the model classes) ---
    'Constraint_wn': 1,
    'forget_every': 10,
    'ntrain': 20,
    'nb_train_simulations': 8,
    'pretrained': pretrained,
    'keep_training': True,
    'Batch_size': 16,
    'seq_size': 10,
}

In [3]:
# CNN-based sub-model
class CNN_LES(torch.nn.Module):
    """Fully convolutional sub-grid closure used inside the LES model.

    ``forward`` takes a batch of physical-space PV fields ``q`` of shape
    ``(batch, Ny, Nx)``; float64 is expected to match the ``.double()`` layers.
    It derives the streamfunction spectrally (``ph = -qh * irsq``) and stacks
    ``(q, psi)`` as a 2-channel image. That image passes through 11 5x5
    convolutions (64 hidden channels, ReLU after the first 10). The output
    has the same ``(batch, Ny, Nx)`` shape as ``q``.

    Layers are registered as ``conv1`` ... ``conv11`` so checkpoint keys stay
    stable.

    Args:
        grid_params: dict with 'device', 'Nx', 'Ny', 'Lx', 'Ly' used to build
            the spectral grid. Defaults to the notebook-level ``params`` dict.
    """

    def __init__(self, grid_params=None):
        super(CNN_LES, self).__init__()
        if grid_params is None:
            # Backward-compatible default: the notebook-level configuration.
            grid_params = params
        # conv1: (q, psi) -> 64, conv2..conv10: 64 -> 64, conv11: 64 -> 1.
        channels = [2] + [64] * 10 + [1]
        for idx, (c_in, c_out) in enumerate(zip(channels[:-1], channels[1:]), start=1):
            # Circular padding on every layer, including the output layer: the
            # fields live on a periodic domain handled with FFTs, so zero padding
            # would corrupt the closure near the domain edges.
            setattr(
                self,
                f"conv{idx}",
                torch.nn.Conv2d(c_in, c_out, 5, padding='same', padding_mode="circular").double(),
            )
        self.grid = TwoGrid(grid_params['device'], Nx=grid_params['Nx'], Ny=grid_params['Ny'],
                            Lx=grid_params['Lx'], Ly=grid_params['Ly'])

    def to_spectral(self, x):
        """Real 2-D FFT over the last two axes (``norm='forward'``)."""
        return torch.fft.rfftn(x, norm='forward', dim=[-2, -1])

    def to_physical(self, x):
        """Inverse of :meth:`to_spectral`."""
        return torch.fft.irfftn(x, norm='forward', dim=[-2, -1])

    def forward(self, q):
        """Return the closure term for PV ``q`` (same shape as ``q``)."""
        qh = self.to_spectral(q)
        ph = -qh * self.grid.irsq.unsqueeze(0)
        p = self.to_physical(ph)
        x = torch.cat([q.unsqueeze(1), p.unsqueeze(1)], dim=1)
        for idx in range(1, 11):
            x = torch.relu(getattr(self, f"conv{idx}")(x))
        # Linear output layer (no activation).
        x = self.conv11(x)
        return x.squeeze(1)

In [20]:
# QG flow
class QG(torch.nn.Module):
    """Pseudo-spectral QG model on a periodic grid with a learned CNN closure.

    The prognostic variable is ``qh``, the real 2-D FFT (``norm='forward'``,
    last two axes) of the potential vorticity ``q``; batches are stacked along
    axis 0. The streamfunction is ``ph = -qh * irsq`` and the velocities are
    ``uh = -1j*ky*ph``, ``vh = 1j*kr*ph``. :meth:`qg_dyns` evaluates the
    tendency as padded advection + linear operator + cosine forcing + CNN
    closure, and :meth:`rk4` integrates it with classical RK4.

    Args:
        params: dict providing 'device', 'mu', 'nu', 'nv', 'B', 'Nx', 'Ny',
            'Lx' and 'Ly'.
    """

    def __init__(self, params):
        super().__init__()
        # Learned sub-grid closure, applied to the physical-space PV.
        self.sgs_ = CNN_LES().to(params['device'])
        # Wavenumber of the fixed cosine forcing used in Fs().
        self.wave_num = 4
        self.mu = params['mu']
        self.nu = params['nu']
        self.nv = params['nv']
        self.B = params['B']
        self.grid = TwoGrid(params['device'], Nx=params['Nx'], Ny=params['Ny'], Lx=params['Lx'], Ly=params['Ly'])
        # 3/2-size grid on which the quadratic products of non_linear() are formed.
        self.d_alias = TwoGrid(params['device'], Nx=int((3./2.)*params['Nx']), Ny=int((3./2.)*params['Ny']),
                               Lx=params['Lx'], Ly=params['Ly'], dealias=1/3)
        self.Lin = self.qg_Lin()

    def Fs(self, bs):
        """Return the spectral forcing ``k*(cos(k*y) + cos(k*x))``, k = ``wave_num``.

        The forcing is repeated ``bs`` times along the batch axis.
        """
        k = self.wave_num
        y = k * (torch.cos(k * self.grid.y).view(self.grid.Ny, 1) + torch.cos(k * self.grid.x).view(1, self.grid.Nx))
        y = y.unsqueeze(0).repeat(bs, 1, 1)
        return self.to_spectral(y)

    def qg_Lin(self):
        """Return the diagonal spectral linear operator with a leading batch axis.

        The operator is ``-mu - nu*krsq**nv + 1j*B*kr*irsq``, with the [0, 0]
        (mean) mode set to 0.
        """
        Lc = -self.mu - self.nu * self.grid.krsq**self.nv + 1j * self.B * self.grid.kr * self.grid.irsq
        Lc[0, 0] = 0
        return Lc.unsqueeze(0)

    def to_spectral(self, x):
        """Real 2-D FFT over the last two axes (``norm='forward'``)."""
        return torch.fft.rfftn(x, norm='forward', dim=[-2, -1])

    def to_physical(self, x):
        """Inverse of :meth:`to_spectral`."""
        return torch.fft.irfftn(x, norm='forward', dim=[-2, -1])

    def reduce(self, x):
        """Truncate a (padded) spectrum back to the model grid.

        Keeps the first ``Ny//2`` and last ``Ny//2`` rows of axis 1 and the
        first ``grid.dk`` columns of axis 2. Returns complex128.
        """
        ny, dk = self.grid.Ny, self.grid.dk
        half = ny // 2
        ny_in = x.shape[1]
        z = torch.zeros([x.shape[0], ny, dk], dtype=torch.complex128, device=x.device)
        z[:, :half, :dk] = x[:, :half, :dk]
        z[:, half:ny, :dk] = x[:, ny_in - half:ny_in, :dk]
        return z

    def increase(self, x):
        """Zero-pad a model-grid spectrum onto the ``d_alias`` grid (inverse of :meth:`reduce`).

        The first and last ``Ny_in//2`` rows of axis 1 are placed at the two
        ends of the padded axis. Returns complex128.
        """
        ny_in, nk_in = x.shape[1], x.shape[2]
        half = ny_in // 2
        ny_out = self.d_alias.Ny
        z = torch.zeros([x.shape[0], ny_out, self.d_alias.dk], dtype=torch.complex128, device=x.device)
        z[:, :half, :nk_in] = x[:, :half, :nk_in]
        z[:, ny_out - half:ny_out, :nk_in] = x[:, half:ny_in, :nk_in]
        return z

    def get_state_vars(self, q):
        """Return physical ``(q, p, u, v)``: PV, streamfunction and x/y velocities."""
        qh = self.to_spectral(q)
        ph = -qh * self.grid.irsq.unsqueeze(0)
        uh = -1j * self.grid.ky.unsqueeze(0) * ph
        vh = 1j * self.grid.kr.unsqueeze(0) * ph

        q = self.to_physical(qh)   # potential vorticity
        p = self.to_physical(ph)   # streamfunction
        u = self.to_physical(uh)   # x-axis velocity
        v = self.to_physical(vh)   # y-axis velocity
        return q, p, u, v

    def non_linear(self, qh):
        """Return ``(S, ph)``.

        ``S = -1j*kr*F[u q] - 1j*ky*F[v q]`` is the advection tendency, with
        the products formed on the 3/2-padded grid (3/2-rule dealiasing).
        ``ph`` is the spectral streamfunction.
        """
        ph = -qh * self.grid.irsq.unsqueeze(0)
        uh = -1j * self.grid.ky.unsqueeze(0) * ph
        vh = 1j * self.grid.kr.unsqueeze(0) * ph

        q = self.to_physical(self.increase(qh))
        u = self.to_physical(self.increase(uh))
        v = self.to_physical(self.increase(vh))

        uqh = self.reduce(self.to_spectral(u * q))
        vqh = self.reduce(self.to_spectral(v * q))

        S = -1j * self.grid.kr.unsqueeze(0) * uqh - 1j * self.grid.ky.unsqueeze(0) * vqh
        return S, ph

    def qg_dyns(self, dt, qh):
        """Return ``(dqh/dt, sgs)``, where ``sgs`` is the spectral CNN closure term.

        ``dt`` is unused and kept for interface compatibility.
        """
        bs = qh.shape[0]
        Nlin, _ = self.non_linear(qh)
        Lin_ = self.Lin * qh
        sgs = self.to_spectral(self.sgs_(self.to_physical(qh)))
        S = Nlin + Lin_ + self.Fs(bs) + sgs
        return S, sgs

    def rk4(self, qh, dt):
        """Advance ``qh`` by one classical RK4 step of size ``dt``.

        Returns ``(qh_next, sgs)``, where ``sgs`` is the closure term
        evaluated at the first stage, i.e. at the input state.
        """
        k1, sgs_term = self.qg_dyns(dt, qh)
        k2, _ = self.qg_dyns(dt, qh + 0.5*dt*k1)
        k3, _ = self.qg_dyns(dt, qh + 0.5*dt*k2)
        k4, _ = self.qg_dyns(dt, qh + dt*k3)
        qhdt = qh + dt*(k1 + 2*k2 + 2*k3 + k4)/6
        return qhdt, sgs_term

    def model_dt(self, inp, dt, grad_mode='exact'):
        """One model step in spectral space; returns ``(qh_next, sgs)``.

        Args:
            inp: spectral PV ``qh``.
            dt: time step.
            grad_mode: how gradients flow through the step.
                ``'exact'`` backpropagates through the full RK4 step.
                ``'EGA-static'`` computes the RK4 step without grad; the
                returned tensor has the RK4 values but the gradient of
                ``dt * sgs(inp.detach()) + inp``. Only the closure at the
                current step is differentiated, and the solver Jacobian is
                replaced by the identity.

        Raises:
            ValueError: for any other ``grad_mode``.
        """
        if grad_mode == 'exact':
            return self.rk4(inp, dt)
        if grad_mode != 'EGA-static':
            raise ValueError(f"unknown grad_mode {grad_mode!r}; expected 'exact' or 'EGA-static'")
        with torch.no_grad():
            pred, _ = self.rk4(inp, dt)
        sgs = self.to_spectral(self.sgs_(self.to_physical(inp.detach())))
        output_p = dt * sgs + inp
        # Swap in the RK4 values while keeping the surrogate autograd graph.
        output_p.data = pred.data
        return output_p, sgs

    def enstrophy(self, q):
        """Return ``0.5 * mean(q**2)``."""
        return 0.5 * torch.mean(q**2)

    def energy(self, u, v):
        """Return ``0.5 * mean(u**2 + v**2)``."""
        return 0.5 * torch.mean(u**2 + v**2)

    def spectrum(self, y):
        """Shell-sum each field in ``y`` over rings ``|k| in (n - 0.5, n + 0.5)``.

        Each shell sum is rescaled by ``n * pi / (count - 0.5)``, where
        ``count`` is the number of modes in the shell. Returns ``(k, [e_j])``
        with ``k = 1 .. Nx//2 - 1`` (CPU tensors).
        """
        K = torch.sqrt(self.grid.krsq)
        d = 0.5
        k = torch.arange(1, self.grid.Nx // 2)
        m = torch.zeros(k.size())
        e = [torch.zeros(k.size()) for _ in range(len(y))]
        for ik, n in enumerate(k):
            shell = torch.nonzero((K < n + d) & (K > n - d), as_tuple=True)
            m[ik] = shell[0].numel()
            for j, yj in enumerate(y):
                e[j][ik] = torch.sum(yj[shell]) * n * math.pi / (m[ik] - d)
        return k, e

    def J(self, grid, qh):
        """Return ``1j*kr*F[u q] + 1j*ky*F[v q]`` on ``grid``.

        This is the negative of the advection tendency; no padding is applied.
        """
        ph = -qh * grid.irsq
        uh = -1j * grid.ky * ph
        vh = 1j * grid.kr * ph

        q = self.to_physical(qh)
        u = self.to_physical(uh)
        v = self.to_physical(vh)

        uqh = self.to_spectral(u * q)
        vqh = self.to_spectral(v * q)
        return 1j * grid.kr * uqh + 1j * grid.ky * vqh

    def fluxes(self, R, qh):
        """Return ``(k, resolved_rate_spectrum, modeled_rate_spectrum)``.

        The resolved rate is ``-conj(qh) * J(qh)``; the modeled rate is
        ``conj(qh) * R``.
        """
        sh = -torch.conj(qh) * self.J(self.grid, qh)  # resolved rate
        lh = torch.conj(qh) * R                       # modeled rate
        k, [sk, lk] = self.spectrum([torch.real(sh), torch.real(lh)])
        return k, sk, lk

    def invariants(self, y):
        """Return ``(k, energy_spectrum, enstrophy_spectrum)`` of the physical PV field ``y``."""
        qh = self.to_spectral(y)
        ph = -qh * self.grid.irsq
        uh = -1j * self.grid.ky * ph
        vh = 1j * self.grid.kr * ph

        e = torch.abs(uh)**2 + torch.abs(vh)**2  # kinetic energy
        z = torch.abs(qh)**2                      # enstrophy
        k, [ek, zk] = self.spectrum([e, z])
        return k, ek, zk

    def _diagnostics(self, qh):
        """Return ``(energy, enstrophy, k, ek, zk)`` for batch member 0 of spectral state ``qh``."""
        q, _, u, v = self.get_state_vars(self.to_physical(qh))
        k, ek, zk = self.invariants(q[0])
        return self.energy(u[0], v[0]), self.enstrophy(q[0]), k, ek, zk

    def pred_and_diag(self, q0, dt, n):
        """Roll the model ``n`` RK4 steps from physical PV ``q0``, collecting diagnostics.

        Diagnostics are computed for batch member 0 only. No ``no_grad`` is
        applied, so the autograd graph of the whole rollout is kept.

        Returns:
            An 11-tuple ``(q_traj, energy, enstrophy, k, energy_spec,
            enstrophy_spec, k_flux, resolved_flux, modeled_flux,
            k_closure_spec, closure_spec)``. The first six entries cover
            ``n + 1`` states (initial state included); the last five cover
            the ``n`` steps.

        Raises:
            ValueError: if ``n < 1`` (the per-step diagnostics would be empty).
        """
        if n < 1:
            raise ValueError(f"pred_and_diag needs at least one step, got n={n}")

        qh_all = [self.to_spectral(q0)]
        state_diags = [self._diagnostics(qh_all[0])]
        k_flux, s_flux, r_flux = [], [], []
        k_r_spect, r_spect = [], []

        for _ in range(n):
            qh_next, R = self.rk4(qh_all[-1], dt)
            qh_all.append(qh_next)
            state_diags.append(self._diagnostics(qh_next))

            # NOTE(review): R is the closure evaluated at the previous state
            # (first RK4 stage) but is paired with the new state in fluxes();
            # confirm this is intended.
            kr, [spec_r] = self.spectrum([torch.abs(R[0])])
            kflux, skflux, lkflux = self.fluxes(R[0], qh_next[0])
            k_flux.append(kflux)
            s_flux.append(skflux)
            r_flux.append(lkflux)
            k_r_spect.append(kr)
            r_spect.append(spec_r)

        energy, enstrophy, k_init, spec_e, spec_z = (torch.stack(list(col)) for col in zip(*state_diags))
        return (
            self.to_physical(torch.stack(qh_all)),
            energy, enstrophy, k_init, spec_e, spec_z,
            torch.stack(k_flux), torch.stack(s_flux), torch.stack(r_flux),
            torch.stack(k_r_spect), torch.stack(r_spect),
        )

    def forward(self, q0, dt, n, grad_mode='exact'):
        """Integrate ``n`` steps from physical PV ``q0``.

        Returns the physical trajectory, ``n + 1`` states stacked on a new
        axis 0.
        """
        qh_all = [self.to_spectral(q0)]
        for _ in range(n):
            qh_all.append(self.model_dt(qh_all[-1], dt, grad_mode=grad_mode)[0])
        return self.to_physical(torch.stack(qh_all))
# Build the LES model from the notebook configuration and place it on the configured device.
QG_model = QG(params)
QG_model = QG_model.to(params['device'])

In [21]:
# Random initial PV field on the model grid. The shape is derived from params
# (not hard-coded) and the dtype is float64 to match the double-precision
# model; the tensor is allocated directly on the target device.
X_train = torch.rand((1, params['Ny'], params['Nx']), dtype=torch.float64, device=params['device'])

In [27]:
# Toy loss (mean squared PV over a 10-step rollout, i.e. MSE against an all-zero
# target), differentiated exactly through every RK4 step.
Y_hat = QG_model(X_train, 0.01, 10, grad_mode='exact')
loss = Y_hat.pow(2).mean()
loss.backward()
grad_exact = extract_and_reset_grads(QG_model).detach().cpu().numpy()  # this also resets gradients

In [28]:
# Same toy loss (MSE against an all-zero target), with the EGA-static gradient
# approximation for comparison against grad_exact.
Y_hat = QG_model(X_train, 0.01, 10, grad_mode='EGA-static')
loss = Y_hat.pow(2).mean()
loss.backward()
grad_EGA_static = extract_and_reset_grads(QG_model).detach().cpu().numpy()