In [1]:
import numpy as np
import pdb
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import sys

# Make project code importable from this notebook.
# `project_root` is the directory two levels above the current working directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), '../..')) 
sys.path.append(project_root)  # modules/packages located directly under project_root become importable
# NOTE(review): project_root is already two levels up, so the entry below resolves to
# <cwd>/../../../../Networks (four levels above the working directory) -- confirm that
# this is the intended location of the 'Networks' directory.
sys.path.append(os.path.join(project_root, '../..', 'Networks'))  # additionally expose a 'Networks' directory on sys.path
import matplotlib.pyplot as plt
import torch
import torch.optim.swa_utils as swa_utils

from torch import nn
import copy

from torch import profiler
from tqdm import tqdm
import numpy as np

import pdb
import torchsde
from Networks.utils import get_gpu_memory,objective,update_ema, adj_lattice, criticalK, order_param,  load_clean_state_dict, sample_kuramoto,loss_compute_kur,save_model
from helpers import  KuramotoHyperUniversal,  find_params
import psutil
from adabelief_pytorch import AdaBelief
import gc
from IPython.display import clear_output

# Choose the compute device: the GPU when PyTorch can see one, otherwise the CPU.
# (To force CPU execution for debugging, overwrite `device` with 'cpu' after this block.)
is_cuda = torch.cuda.is_available()
if is_cuda:
    device = 'cuda'
else:
    device = 'cpu'
    print("Warning: CUDA not available; falling back to CPU but this is likely to be very slow.")

# Start from a clean slate: drop unreachable Python objects, then ask PyTorch to release
# any cached GPU blocks that are no longer in use.
gc.collect()
torch.cuda.empty_cache()

# Absolute path of the working directory; passed to save_model() in the training cell.
path = os.path.abspath(os.getcwd())

# Report GPU memory usage before anything is allocated.
print(get_gpu_memory())


[{'device': 0, 'total_memory': 24564.0, 'free_memory': 20583.875, 'used_memory': 3980.125}]


In [3]:
#state size
state_size=64
#batch size
batch_size=1024
#weight
weight=10**(-4)
#scale of frequencies
scale=5
#scale of initial phases
scale_kin=5
#learning rate
learning_rate=0.0001
#scheduler decay rate
gamma=0.999
#ema coefficient
alpha=0.995
#generate graph
Astored=adj_lattice(state_size, "erdos", 0.3)

#calculate critical coupling and laplacian
Kcrit, L=criticalK(Astored,state_size, scale)
#set coupling constant in experiments
K = 0.01* Kcrit

%store  Astored
A=torch.tensor(Astored).to(torch.float32).to(device)
L=torch.tensor(L).to(torch.float32).to(device)
#check memory
print(get_gpu_memory())
#check coupling constant (to avoid Nan, inf)
print(K)



[{'device': 0, 'total_memory': 24564.0, 'free_memory': 20376.1875, 'used_memory': 4187.8125}]
0.841310456374586


In [None]:
# Histories collected across the whole sweep; the training loop appends one list
# per network depth to each of these.
global_perf, global_eperf, global_stperf, global_param = [], [], [], []

# Standard normal distribution handed to sample_kuramoto() for drawing samples.
dist = torch.distributions.Normal(0, 1)

# Number of training iterations per configuration.
# NOTE: this name shadows Python's builtin `iter` inside the notebook namespace; it is
# kept because the training loop reads it.
iter = 5000

# Time discretisation of the simulation: `t_size` grid points on [0, t_total].
t_size = 32
t_total = 1

# Time grid on which the trajectories are evaluated, moved to the compute device.
ts_exp = torch.linspace(0, t_total, t_size).to(device)

# Solver step size = spacing of the time grid. It is loop-invariant, so compute it once
# instead of re-deriving it on every training step.
dt_exp = ts_exp[1] - ts_exp[0]

# Sweep over network depth: one full training run per value of `layers` (1..4).
for layers in range(1, 5):

    skip = False
    # select strategy: tHNC, HNC, ENC
    strategy = "tHNC"
    # initialize model
    SML = KuramotoHyperUniversal(batch_size, state_size, K, A, layers, skip, strategy, device).to(device)

    # initialize optimizer on the parameters selected by find_params()
    params = find_params(SML)
    controller_optimizer = torch.optim.AdamW(params, lr=learning_rate, weight_decay=0.0, amsgrad=False)
    # initialize scheduler (learning rate is multiplied by `gamma` after every step)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(controller_optimizer, gamma=gamma)

    # per-run histories
    perf = []    # raw loss
    eperf = []   # exponentially averaged loss
    stperf = []  # second value returned by loss_compute_kur
    erpe = []    # not filled in this cell; kept so the notebook namespace is unchanged
    orpam = []   # order parameter at the final time step

    # initialize EMA model as a copy of the freshly built network
    ema_model = copy.deepcopy(SML)
    ema_loss = 0

    # Best score so far, threaded through save_model().
    # NOTE(review): presumably save_model checkpoints when ema_loss beats `score` --
    # confirm in Networks.utils.
    score = 1000000

    for j in tqdm(range(iter)):
        # sample initial conditions and parameters
        freqs, ys_exp = sample_kuramoto(state_size, batch_size, dist, scale, scale_kin, device)
        SML.freqs = freqs

        # forward simulation with the reversible Heun solver and its matching adjoint
        ys_tray = torchsde.sdeint_adjoint(SML, ys_exp, ts_exp, method='reversible_heun', dt=dt_exp,
                                          adjoint_method='adjoint_reversible_heun')
        # compute loss; `logtrick` is the quantity that is actually differentiated
        loss, ste_loss, logtrick = loss_compute_kur(ys_tray, state_size, batch_size, weight, A)

        print("order parameter")
        # Logging-only metric on the final time index (first `state_size` entries of the
        # last axis). Evaluated without autograd so no graph is built on top of the trajectory.
        with torch.no_grad():
            ord_param = order_param(ys_tray[-1, :, 0:state_size], L)
        print(ord_param)

        logtrick.backward()
        print("total loss")
        print(loss)

        # perform update of parameters
        controller_optimizer.step()
        controller_optimizer.zero_grad(set_to_none=True)
        # perform scheduler step
        scheduler.step()
        # update EMA model
        update_ema(SML, ema_model, alpha)
        print(get_gpu_memory())

        # Move the loss to host memory once; every conversion is a blocking
        # device-to-host copy, so reuse the result below.
        loss_np = loss.detach().cpu().numpy()

        # exponential moving average of the loss (seeded with the first value)
        if j == 0:
            ema_loss = loss_np
        else:
            ema_loss = ema_loss * alpha + loss_np * (1 - alpha)

        # save the model
        score = save_model(score, ema_loss, ema_model, SML, path, strategy, layers)

        # append data
        perf.append(loss_np)
        eperf.append(ema_loss)
        stperf.append(ste_loss.detach().cpu().numpy())
        orpam.append(ord_param.detach().cpu().numpy())

        # Clear memory. The trajectory must be dropped BEFORE flushing the allocator
        # cache: empty_cache() can only release blocks that are no longer referenced.
        del ys_tray
        gc.collect()
        torch.cuda.empty_cache()

        # clear output periodically so the notebook does not accumulate log lines
        if j > 100 and j % 500 == 0:
            clear_output(wait=False)

    # append this run's histories to the sweep-level records
    global_perf.append(perf)
    global_eperf.append(eperf)
    global_stperf.append(stperf)
    global_param.append(orpam)

    # Collect garbage first so the cache flush can return as much memory as possible.
    gc.collect()
    torch.cuda.empty_cache()