In [3]:
import numpy as np
import pdb

import control
import matplotlib.pyplot as plt
import torch
import torchsde
from torch import nn, profiler, Tensor, nan, vmap
import copy
import matplotlib.pyplot as plt

from tqdm import tqdm
import numpy as np
import pdb

import os
import sys

# Make project code importable: the repository root (two directories above the
# notebook) so that `Networks.utils` resolves, plus a `Networks` directory.
# NOTE(review): the second entry is joined onto project_root with another
# '../..', i.e. it points four levels above the notebook, not at
# project_root/Networks -- confirm this is where `helpers` is expected to live.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..'))
sys.path.extend([
    project_root,
    os.path.join(project_root, '../..', 'Networks'),
])

from Networks.utils import get_gpu_memory, normal, ema_update,mass_append, update_ema, generate_system_canon, load_clean_state_dict, sample_linear, compute_gradnorm,save_model, loss_compute
from helpers import   HyperCoeffsLinearControlStochasticLQRImpl,  find_params
import psutil

from IPython.display import clear_output
from adabelief_pytorch import AdaBelief

import gc

# Pick the compute device once; everything below moves tensors/models to it.
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = 'cuda'
else:
    device = 'cpu'
    print("Warning: CUDA not available; falling back to CPU but this is likely to be very slow.")

# Absolute path of the notebook's working directory (later handed to save_model).
path = os.path.abspath(os.getcwd())

# Start from a clean slate: collect Python garbage and release cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

In [5]:
#state size (dimension of the state vector; also passed to loss_compute)
state_size=128
#batch size (number of trajectories simulated together per training step)
batch_size=1024
#weight passed to loss_compute
weight=0.1
# Seed numpy's and torch's global RNGs (torch.manual_seed covers CPU and CUDA)
# so a fresh Restart & Run All draws the same random numbers.
# NOTE(review): assumes generate_system_canon / sample_linear draw from these
# global RNGs -- confirm; GPU kernels may still be non-deterministic.
SEED=0
np.random.seed(SEED)
torch.manual_seed(SEED)
#generate canonical A and B
A,B=generate_system_canon(state_size)


In [None]:
#learning rate
learning_rate=0.001
#EMA coefficient (used for both the EMA model and the EMA loss statistics)
alpha=0.999
#scheduler decay rate (ExponentialLR gamma, stepped once per iteration)
gamma=0.9999
#total training iterations per model
# NOTE(review): `iter` shadows Python's builtin iter() for the rest of the
# notebook; the name is kept because other cells may read it.
iter=5000
# Training of the current model stops once its loss exceeds this value or
# becomes non-finite (NaN/inf).
divergence_threshold=500

#store performance data: one list appended per model depth (`layers` value)
global_perf=[]
global_eperf=[]
global_stperf=[]
global_grad=[]

#initialize dynamic matrices
At=torch.tensor(A).float().to(device) 
Bt=torch.tensor(B).float().to(device) 
# Batched copy of A; its last row is overwritten in place with freshly sampled
# parameters on every training iteration below.
coremat=(At.unsqueeze(0)).repeat(batch_size,1,1)
#choose strategy: HNC or ENC
strategy="HNC"

#total time and time grid for the SDE solver
t_total=1
t_size=int(32)
ts_exp = torch.linspace(0, t_total, t_size).float().to(device)

# Train one model per depth layers = 1..4 and record its learning curves.
for layers in range(1,5):
    #initialize model
    SML= HyperCoeffsLinearControlStochasticLQRImpl(At,Bt,batch_size, device,layers,strategy).to(device) 
    
    #initialize optimizer
    params =  find_params(SML)    
    controller_optimizer = AdaBelief(params, lr=learning_rate, eps=1e-16, betas=(0.9,0.999), weight_decouple = True, rectify = True, amsgrad=False)
    #initialize scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(controller_optimizer, gamma=gamma)
    
    SML.batch=batch_size
    
    # Best score seen so far; save_model returns the (possibly updated) value.
    score= 10000000
    #initialize EMA model
    ema_model = copy.deepcopy(SML)
    
    ema_loss, ema_ener, ema_std, enr_std = (0.0,) * 4

    # Per-iteration histories, listed in the same order mass_append returns them.
    eperf, stperf, eloss, stloss, emaperf, emaener, emastperf, emastener, gradnorms = ([] for _ in range(9))

    for j in tqdm(range(iter)):
        #sample initial conditions and system parameters
        # The sampler needs the state dimension, which is `state_size` here
        # (there is no `dim` defined anywhere in the notebook).
        amat, x0= sample_linear(state_size,batch_size, device)
        
        # Extra zero-initialised column appended to the state (presumably the
        # accumulated energy/cost -- confirm against loss_compute).
        ini_enr = torch.zeros(batch_size, 1, device=device)
    
        ys_exp = torch.cat((x0,ini_enr),dim=1).to(device)
        
        # Write the sampled parameters into the last row of every batched A.
        coremat[:,-1,:]=amat
        SML.At= coremat
        SML.Bt= Bt
        SML.poly.amat=amat

        #forward simulations (gradients obtained through the adjoint SDE)
        ys_tray = torchsde.sdeint_adjoint(SML, ys_exp, ts_exp,method='reversible_heun', dt=ts_exp[1]-ts_exp[0],
                             adjoint_method='adjoint_reversible_heun',)
        
        print(psutil.virtual_memory())
        #calculate loss
        lossvec, logtrick, stdvec, energy, stdener=loss_compute(ys_tray, state_size, weight) 

        print('total loss')
        print(lossvec)
        
        # `logtrick`, not `lossvec`, is the quantity that is backpropagated.
        logtrick.backward()
        
        total_sq_norm=compute_gradnorm(SML)    

        print("grad norm")
        print(total_sq_norm)
    
        #update parameters
        controller_optimizer.step()
        controller_optimizer.zero_grad(set_to_none=True)
        #scheduler step
        scheduler.step()
        #update EMA model
        update_ema(SML, ema_model, alpha)
        [ema_loss, ema_ener, enr_std, ema_std]=ema_update(j,alpha, lossvec,energy,stdener,stdvec,ema_loss,ema_ener,enr_std,ema_std)
        
        print('ema loss')
        print(ema_loss)
        print('ema_ener')
        print(ema_ener)
        
        [eperf, stperf, eloss, stloss, emaperf, emaener,emastperf,emastener, gradnorms]= mass_append(lossvec,stdvec,energy,stdener, batch_size, ema_loss, ema_ener, enr_std, ema_std, eperf, stperf, eloss, stloss, emaperf, emaener,emastperf,emastener,total_sq_norm, gradnorms)
        #save model
        score= save_model(score, ema_loss, ema_model, SML, path,strategy,layers)
     
        # Abandon this model once training has diverged. NaN compares False
        # with `>`, so non-finite losses must be tested explicitly.
        loss_value = float(lossvec)
        if not np.isfinite(loss_value) or loss_value > divergence_threshold:
            break
        
        if j%200==0:
            gc.collect()
            clear_output(wait=True)
        del lossvec, stdvec 
    global_stperf.append(stperf)    
    global_perf.append(eperf)
    global_eperf.append(emaperf)
    global_grad.append(gradnorms)
