In [27]:
import argparse
import copy
import gzip
import heapq
import itertools
import os
import pickle
from collections import defaultdict
from itertools import count
import matplotlib.pyplot as plt

import numpy as np
from scipy.stats import norm
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.distributions.categorical import Categorical

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import ipywidgets as widgets
from ipywidgets import interact


In [28]:
_dev = [torch.device('cpu')]  # one-slot mutable holder so set_device() can rebind the target device


def tf(x):
    """Convert array-like ``x`` to a float32 tensor on the current device.

    ``x`` is passed through ``np.array`` first: torch warns that building a
    tensor directly from a list of ndarrays is slow, and numpy stacks them.
    """
    as_array = np.array(x)
    return torch.FloatTensor(as_array).to(_dev[0])


def tl(x):
    """Convert array-like ``x`` to an int64 tensor on the current device."""
    as_array = np.array(x)
    return torch.LongTensor(as_array).to(_dev[0])


def set_device(dev):
    """Make ``tf``/``tl`` place subsequently created tensors on ``dev``."""
    _dev[0] = dev

def func_corners(x, ndim=None):
    """Multimodal "corners" reward over coordinates ``x`` (last axis = coordinates).

    Returns 0.5 when every |x_d| > 0.5, plus 2 when every |x_d| lies in
    (0.6, 0.8), plus a 0.1 floor. Works on a single point or a batch of points.

    ``ndim`` is accepted and ignored so this function matches the
    ``func(x, ndim)`` call convention that ``GridEnv`` uses for reward
    functions (without it, selecting ``func='corners'`` raised TypeError).
    """
    ax = abs(x)
    return (ax > 0.5).prod(-1) * 0.5 + ((ax < 0.8) * (ax > 0.6)).prod(-1) * 2 + 1e-1



# Define the sigmoid function
def sigmoid(z):
    """Elementwise logistic function 1 / (1 + exp(-z)) for scalars or arrays."""
    denominator = 1 + np.exp(-z)
    return 1 / denominator

# Define the dynamical system for the n-node system with sigmoid
def node_system_with_sigmoid(x, t, coord):
    """Right-hand side of dx/dt = sigmoid(M x) - x, in the form odeint expects.

    x: current state vector of length n.
    t: current time; unused, but required by the odeint callback signature.
    coord: n*n weights, reshaped row-major into the n x n matrix M.
    """
    n = len(x)
    weight_matrix = np.reshape(coord, (n, n))
    activation = sigmoid(weight_matrix.dot(x))
    return activation - x

# Calculate reward given the weights
def reward_oscillator(coord, ndim, delta=0.0001, no_peak_reward=2.5e-5, power=3):
    """Score a weight vector by the number of sharp peaks its sigmoid network produces.

    ``coord`` (length ``ndim``) is reshaped row-major into an n x n matrix M,
    with n = sqrt(ndim). The system dx/dt = sigmoid(M x) - x is integrated from
    x0 = linspace(0, 1, n, endpoint=False) at 200 points over t in [0, 20].
    Sample j of a node counts as a sharp peak when it is a strict local
    maximum (x[j-1] < x[j] > x[j+1]) and x[j] - (x[j-1] + x[j+1]) / 2 > delta.

    Parameters
    ----------
    coord : array-like of length ``ndim``
        Flattened weight matrix.
    ndim : int
        Number of weights; must be a perfect square.
    delta : float
        Minimum sharpness for a local maximum to count as a peak.
    no_peak_reward : float
        Reward returned when no peak is found (kept equal to Args.log_reg_c).
    power : int
        The reward is ``total_peaks ** power`` when at least one peak is found.

    Raises
    ------
    ValueError
        If ``ndim`` is not a perfect square.
    """
    matrix_dim = int(round(np.sqrt(ndim)))
    if matrix_dim * matrix_dim != ndim:
        raise ValueError(f"ndim={ndim} is not a perfect square; cannot form a square weight matrix")

    x0 = np.linspace(0, 1, matrix_dim, endpoint=False)  # Initial conditions
    t = np.linspace(0, 20, 200)  # Time points
    sol = odeint(node_system_with_sigmoid, x0, t, args=(coord, ))  # shape (len(t), matrix_dim)

    # Vectorised peak test over every node at once. Row k of each slice below
    # refers to time index k + 1 (the interior samples 1 .. len(t) - 2).
    dx = np.diff(sol, axis=0)
    is_local_max = (dx[:-1] > 0) & (dx[1:] < 0)
    sharpness = sol[1:-1] - (sol[:-2] + sol[2:]) / 2
    total_peaks = int(np.count_nonzero(is_local_max & (sharpness > delta)))

    if total_peaks == 0:
        return no_peak_reward  # see log_reg_c
    return total_peaks ** power


In [29]:

class GridEnv:
    """Hypergrid environment whose terminal states are scored by a reward function.

    A state ``s`` is an integer vector of length ``ndim`` that starts at all
    zeros. Action ``a < ndim`` increments coordinate ``a``; action ``ndim`` is
    the stop action. An episode ends on the stop action or as soon as any
    coordinate reaches ``horizon - 1``. Only then is the reward
    ``func(self.s2x(s), self.ndim)`` evaluated; intermediate steps get 0.
    """

    def __init__(self, horizon, ndim=2, multiplier=2, func=None):
        self.horizon = horizon
        self.ndim = ndim
        self.multiplier = multiplier
        self.func = func   # reward function, called as func(x, ndim)
        self._true_density = None  # NOTE(review): never read inside this class

    def obs(self, s=None):
        """Return the one-hot observation of state ``s`` (default: the current state).

        The result is a float32 vector of length ``horizon * ndim``. Coordinate
        ``d`` with value ``v`` sets entry ``d * horizon + v`` to 1.
        """
        s = np.int32(self._state if s is None else s)
        z = np.zeros((self.horizon * self.ndim), dtype=np.float32)
        z[np.arange(len(s)) * self.horizon + s] = 1
        return z

    def s2x(self, s):
        """Map grid state ``s`` to Cartesian coordinates via hyperspherical angles.

        Every component of ``s`` is first multiplied by ``self.multiplier``
        (2 by default). After that scaling:

        - ``s[0]`` is the radius ``r``;
        - ``s[1] .. s[ndim-2]`` give angles ``theta_i = s[i] * pi / horizon``;
        - ``s[ndim-1]`` gives ``phi = s[ndim-1] * 2 * pi / horizon``.

        The coordinates are built by peeling off ``r * prod(cos(theta_<i)) * sin(theta_i)``
        for each theta, then ``cos(phi)`` / ``sin(phi)`` for the last two, so for
        ``ndim >= 2`` the result has norm ``r``.

        Because the multiplier also scales the angle indices, with
        ``multiplier > 1`` the angles exceed [0, pi] and [0, 2*pi), and distinct grid
        states can map to the same point. NOTE(review): confirm whether the
        multiplier was meant to scale only the radius.
        """
        s = s * self.multiplier
        r = s[0]
        x = np.zeros(self.ndim)
        horizon = self.horizon

        product = r  # running r * cos(theta_1) * ... * cos(theta_{i-1})
        for i in range(1, self.ndim):
            if i == self.ndim - 1:
                # Last angle (phi) fills the final two coordinates.
                phi = s[i] * 2 * np.pi / horizon
                x[i - 1] = product * np.cos(phi)
                x[i] = product * np.sin(phi)
            else:
                theta = s[i] * np.pi / horizon
                x[i - 1] = product * np.sin(theta)
                product *= np.cos(theta)

        return x

    def reset(self):
        """Reset to the all-zero start state.

        Returns ``(observation, reward_at_start, state)``. Note that this
        evaluates ``func`` at the start state on every reset.
        """
        self._state = np.int32([0] * self.ndim)
        self._step = 0
        return self.obs(), self.func(self.s2x(self._state), self.ndim), self._state

    def parent_transitions(self, s, used_stop_action):
        """Return ``(parent_observations, actions)`` that lead into state ``s``.

        If ``used_stop_action`` is truthy, the only parent is ``s`` itself,
        reached via the stop action ``ndim``. Otherwise each coordinate ``i``
        with ``s[i] > 0`` yields parent ``s - e_i`` with action ``i``, unless
        that parent would itself be terminal (some coordinate at
        ``horizon - 1``), because episodes end there.
        """
        if used_stop_action:
            return [self.obs(s)], [self.ndim]

        parents = []
        actions = []
        for i in range(self.ndim):
            if s[i] > 0:
                sp = s.copy()
                sp[i] -= 1
                if sp.max() == self.horizon - 1:  # Can't have a terminal parent
                    continue
                parents.append(self.obs(sp))
                actions.append(i)
        return parents, actions

    def step(self, a):
        """Apply action ``a`` and return ``(observation, reward, done, new_state)``.

        ``a`` may be a Python int, a numpy integer, or a 0-d torch tensor
        (as produced by ``Categorical.sample()``). It is normalised to ``int``
        so that ``done`` is always a plain ``bool`` rather than, depending on
        which condition fired, an ``np.bool_`` or a torch tensor.
        """
        a = int(a)
        s = self._state.copy()
        if a < self.ndim:
            s[a] += 1

        done = bool(s.max() >= self.horizon - 1 or a == self.ndim)
        self._state = s
        self._step += 1

        reward = self.func(self.s2x(s), self.ndim) if done else 0
        return self.obs(), reward, done, s



class ReplayBuffer:
    """Replay buffer of the highest-reward terminal states seen so far.

    With ``replay_strategy == 'top_k'`` the buffer holds at most
    ``replay_buf_size`` ``(reward, state)`` pairs, kept in ascending order, so
    ``buf[0]`` is the lowest-reward entry. Any other strategy makes ``add`` a
    no-op, and ``sample`` then always returns an empty list.
    """

    def __init__(self, args, env):
        self.env = env
        self.strat = args.replay_strategy
        self.sample_size = args.replay_sample_size
        self.bufsize = args.replay_buf_size
        self.buf = []

    def add(self, x, r_x):
        """Offer terminal state ``x`` with reward ``r_x`` to the buffer."""
        if self.strat != 'top_k':
            return
        has_room = len(self.buf) < self.bufsize
        if has_room or r_x > self.buf[0][0]:
            merged = sorted(self.buf + [(r_x, x)])
            self.buf = merged[-self.bufsize:]

    def sample(self):
        """Replay ``sample_size`` buffered states (drawn with replacement) as backward trajectories.

        Returns the concatenation of their transition lists, in draw order.
        """
        if len(self.buf) == 0:
            return []
        picks = np.random.randint(0, len(self.buf), self.sample_size)
        transitions = []
        for k in picks:
            transitions.extend(self.generate_backward(*self.buf[k]))
        return transitions

    def generate_backward(self, r, s0):
        """Rebuild one trajectory ending at terminal state ``s0`` by walking to random parents.

        Each transition is a list of tensors ``[parents, actions, [r], [obs(s)], [done]]``.
        Only the first (terminal) transition carries reward ``r`` and ``done=True``;
        every earlier one gets reward 0 and ``done=False``.
        """
        s = np.int8(s0)
        # A state that never reached horizon-1 can only have ended via the
        # explicit stop action, so its first backward step is that stop
        # transition (s is unchanged); otherwise parent_transitions already
        # finds the forced-terminal parents.
        used_stop_action = s.max() < self.env.horizon - 1
        done = True
        traj = []
        while s.sum() > 0:
            parents, actions = self.env.parent_transitions(s, used_stop_action)
            traj.append([tf(item) for item in (parents, actions, [r], [self.env.obs(s)], [done])])
            if not used_stop_action:
                pick = np.random.randint(0, len(parents))
                s[actions[pick]] -= 1
            # All earlier states are intermediate: no stop, not done, no reward.
            used_stop_action = False
            done = False
            r = 0
        return traj

def make_mlp(l, act=nn.LeakyReLU(), tail=[]):
    """Build an MLP from layer widths ``l``.

    Consecutive widths are joined by ``nn.Linear`` layers. ``act`` (the same
    module instance) is inserted between them but not after the last linear
    layer, and the modules in ``tail`` are appended at the end.
    """
    layers = []
    n_linear = len(l) - 1
    for idx in range(n_linear):
        layers.append(nn.Linear(l[idx], l[idx + 1]))
        if idx < n_linear - 1:
            layers.append(act)
    layers.extend(tail)
    return nn.Sequential(*layers)
    
class FlowNetAgent:
    """Flow-matching GFlowNet agent.

    ``self.model`` maps a one-hot state observation to ``ndim + 1`` outputs
    used as log-flows Q(s, a), one per action (the last one is stop).
    """

    def __init__(self, args, envs):
        self.model = make_mlp([args.horizon * args.ndim] +
                              [args.n_hid] * args.n_layers +
                              [args.ndim+1])
        self.model.to(args.dev)
        self.target = copy.deepcopy(self.model)
        self.envs = envs
        self.ndim = args.ndim
        self.tau = args.bootstrap_tau
        self.replay = ReplayBuffer(args, envs[0])
        self.log_reg_c = args.log_reg_c

        # Training history; main() copies its global lists here before each checkpoint.
        self.trn_all_losses = []
        self.trn_all_visited_done = []

    def parameters(self):
        """Return the trainable parameters of the online model."""
        return self.model.parameters()

    def sample_many(self, mbsize, all_visited_done):
        """Roll out one trajectory per environment in lockstep and return all transitions.

        Backward trajectories replayed from ``self.replay`` come first. Each
        transition is ``[parents, actions, [r], [sp], [done]]`` converted with
        ``tf``. Every terminal ``(state, reward)`` is appended to
        ``all_visited_done`` and offered to the replay buffer.
        ``mbsize`` is presumably ``len(self.envs)`` (main passes args.mbsize for both).
        """
        batch = []
        batch += self.replay.sample()
        s = tf([i.reset()[0] for i in self.envs])
        done = [False] * mbsize
        while not all(done):
            with torch.no_grad():
                acts = Categorical(logits=self.model(s)).sample()
            # Only still-running environments are stepped; acts[k] belongs to the k-th of them.
            step = [i.step(a) for i,a in zip([e for d, e in zip(done, self.envs) if not d], acts)]
            p_a = [self.envs[0].parent_transitions(sp_state, a == self.ndim)
                   for a, (sp, r, done, sp_state) in zip(acts, step)]
            batch += [[tf(i) for i in (p, a, [r], [sp], [d])]
                      for (p, a), (sp, r, d, _) in zip(p_a, step)]
            # m maps a global env index to its position in `step` (running envs only).
            c = count(0)
            m = {j:next(c) for j in range(mbsize) if not done[j]}
            done = [bool(d or step[m[i]][2]) for i, d in enumerate(done)]
            s = tf([i[0] for i in step if not i[2]])
            for (_, r, d, sp) in step:
                if d:
                    all_visited_done.append((tuple(sp), r))  # (state, reward) pairs
                    self.replay.add(tuple(sp), r)
        return batch

    def sample_one_traj(self):
        """Sample a single trajectory in ``self.envs[0]``.

        Returns a list whose first entry is a placeholder for the start state,
        followed by ``[parent_states, parent_actions, r, state, done]`` per step.
        """
        traj = []
        env = self.envs[0]
        traj.append([[], [], 0, np.int32([0] * self.ndim), False])

        s = tf(env.reset()[0])
        done = False
        while not done:
            with torch.no_grad():
                logits = self.model(s.unsqueeze(0))
                action_dist = Categorical(logits=logits)
                a = action_dist.sample().item()
            sp, r, done_flag, sp_state = env.step(a)
            parent_states, parent_actions = env.parent_transitions(sp_state, a == self.ndim)
            traj.append([parent_states, parent_actions, r, sp_state, done_flag])

            s = tf(sp)
            done = done_flag
        return traj

    def learn_from(self, it, batch):
        """Compute the flow-matching loss for a batch of transitions.

        Each batch entry is ``[parents, actions, r, sp, done]`` as built by
        ``sample_many`` / ``ReplayBuffer.generate_backward``. ``parents`` holds
        the k parent observations of ``sp`` (k rows) and ``actions`` the k
        actions leading from them to ``sp``; ``r``, ``sp`` and ``done`` each
        hold a single row. With c = ``self.log_reg_c``:

            in_flow(sp)  = log(c + sum_parents exp(Q(parent, action)))
            out_flow(sp) = logsumexp([log(c + r), Q(sp, .)]), Q terms masked to -1000 when done

        Returns ``(loss, term_loss, flow_loss)``, where
        ``loss = 10 * term_loss + flow_loss``. Returns ``None`` for an empty
        batch, which main() skips. ``it`` is unused.
        """
        if not batch:
            return None
        parents, actions, r, sp, done = map(torch.cat, zip(*batch))
        device = parents.device

        # batch_idxs[k] is the index of the transition that parent row k belongs to.
        # repeat_interleave avoids the quadratic sum-of-lists and stays on `device`.
        n_parents = torch.tensor([len(p) for p, _, _, _, _ in batch], device=device)
        batch_idxs = torch.repeat_interleave(torch.arange(len(batch), device=device), n_parents)

        parents_Qsa = self.model(parents)[torch.arange(parents.shape[0], device=device), actions.long()]
        # Allocate the accumulator on the batch device (previously always CPU, which
        # made index_add_ fail when training on a GPU).
        flow_acc = torch.zeros(sp.shape[0], device=device, dtype=parents_Qsa.dtype)
        in_flow = torch.log(self.log_reg_c + flow_acc.index_add_(0, batch_idxs, torch.exp(parents_Qsa)))

        if self.tau > 0:
            with torch.no_grad():
                next_q = self.target(sp)
        else:
            next_q = self.model(sp)
        # Terminal states have no children: replace their outgoing Q terms by -1000 (~log 0).
        next_qd = next_q * (1 - done).unsqueeze(1) + done.unsqueeze(1) * (-1000.0)
        out_flow = torch.logsumexp(torch.cat([torch.log(self.log_reg_c + r)[:, None], next_qd], 1), 1)

        term_loss = ((in_flow - out_flow) * done).pow(2).sum() / (done.sum() + 1e-20)
        flow_loss = ((in_flow - out_flow) * (1 - done)).pow(2).sum() / ((1 - done).sum() + 1e-20)

        leaf_coef = 10  # weight of the terminal (leaf) matching term
        loss = term_loss * leaf_coef + flow_loss

        if self.tau > 0:
            # Polyak update of the target network.
            for a, b in zip(self.model.parameters(), self.target.parameters()):
                b.data.mul_(1 - self.tau).add_(self.tau * a)

        return loss, term_loss, flow_loss




In [30]:
# Training 

def make_opt(params, args):
    """Build the optimizer selected by ``args.opt`` for ``params``.

    ``'adam'`` uses ``args.learning_rate`` and ``(args.adam_beta1, args.adam_beta2)``;
    ``'msgd'`` is SGD with ``args.momentum``.

    Returns ``None`` when ``params`` is empty.

    Raises:
        ValueError: for any other ``args.opt``. Previously this surfaced as
        an UnboundLocalError on ``opt``.
    """
    params = list(params)
    if not params:
        return None
    if args.opt == 'adam':
        return torch.optim.Adam(params, args.learning_rate,
                                betas=(args.adam_beta1, args.adam_beta2))
    if args.opt == 'msgd':
        return torch.optim.SGD(params, args.learning_rate, momentum=args.momentum)
    raise ValueError(f"Unsupported optimizer {args.opt!r}; expected 'adam' or 'msgd'")

def compute_empirical_reward_distribution(visited):
    """Return the fraction of ``visited`` entries that carry each reward value.

    ``visited`` is a sequence of ``(state, reward)`` pairs. The result maps
    reward -> relative frequency, in first-seen order; it is ``{}`` for an
    empty input.
    """
    if len(visited) == 0:
        return {}
    counts_by_reward = defaultdict(int)
    for _state, reward_value in visited:
        counts_by_reward[reward_value] += 1
    n_total = sum(counts_by_reward.values())
    return {rv: n / n_total for rv, n in counts_by_reward.items()}



# Module-level accumulators: main() appends per-update losses and terminal
# (state, reward) visits here, and later analysis cells read them.
all_losses, all_visited_done = [], []

def _log_line(log_file, message):
    """Print ``message`` and append it to ``log_file``, flushing so the log stays current."""
    print(message)
    log_file.write(message + '\n')
    log_file.flush()


def _save_gz_pickle(obj, path):
    """Pickle ``obj`` into a gzip-compressed file at ``path``."""
    with gzip.open(path, 'wb') as fh:
        pickle.dump(obj, fh)


def main(args):
    """Train an agent as configured by ``args``.

    Appends per-update losses to the module-level ``all_losses`` and terminal
    ``(state, reward)`` visits to ``all_visited_done``. Writes a log file and
    gzip-pickled agent/model checkpoints (every 1000 steps) under
    ``args.save_path``.

    Raises:
        KeyError: unknown ``args.func``.
        ValueError: unknown ``args.method``.
    """
    args.dev = torch.device(args.device)
    set_device(args.dev)
    # Named reward_fn (not f) so the checkpoint file handles below cannot shadow it.
    reward_fn = {'default': None,
                 'corners': func_corners,
                 'oscillator': reward_oscillator,
                 }[args.func]

    envs = [GridEnv(args.horizon, args.ndim, multiplier=args.multiplier, func=reward_fn)
            for _ in range(args.mbsize)]
    nnode = args.nnode

    if args.method == 'flownet':
        agent = FlowNetAgent(args, envs)
    elif args.method == 'mcmc':
        agent = MHAgent(args, envs)  # NOTE(review): not defined in the visible cells
    elif args.method == 'random_traj':
        agent = RandomTrajAgent(args, envs)  # NOTE(review): not defined in the visible cells
    else:
        raise ValueError(f"Unknown method {args.method!r}; expected 'flownet', 'mcmc' or 'random_traj'")

    opt = make_opt(agent.parameters(), args)

    root = args.save_path
    log_file_path = os.path.join(root, f'trn-out-{args.nnode}-node.log')
    # `or '.'`: an empty save_path yields dirname '' and os.makedirs('') would fail.
    os.makedirs(os.path.dirname(log_file_path) or '.', exist_ok=True)
    with open(log_file_path, 'w') as log_file:
        ttsr = max(int(args.train_to_sample_ratio), 1)  # gradient updates per sampled batch
        sttr = max(int(1/args.train_to_sample_ratio), 1)  # sample_many calls per training step

        for i in tqdm(range(args.n_train_steps+1), disable=not args.progress):
            data = []  # transitions from this step's trajectories (plus replayed ones)
            for _ in range(sttr):
                data += agent.sample_many(args.mbsize, all_visited_done)
            for j in range(ttsr):
                losses = agent.learn_from(i * ttsr + j, data)  # (opt loss, *metrics) or None
                if losses is not None:
                    # NOTE(review): retain_graph on every 50th step only keeps that graph
                    # alive until `losses` is rebound; nothing in this loop backprops twice.
                    losses[0].backward(retain_graph=(not i % 50))
                    if args.clip_grad_norm > 0:
                        torch.nn.utils.clip_grad_norm_(agent.parameters(),
                                                       args.clip_grad_norm)
                    opt.step()
                    opt.zero_grad()
                    all_losses.append([x.item() for x in losses])

            if not i % 100:
                empirical_distribution = compute_empirical_reward_distribution(
                    all_visited_done[-args.num_empirical_loss:])
                _log_line(log_file, f'Partial Empirical Reward Distribution: {empirical_distribution}')

            if not i % 1000:
                # Snapshot the training history into the agent before pickling it.
                agent.trn_all_losses = all_losses.copy()
                agent.trn_all_visited_done = all_visited_done.copy()

                _save_gz_pickle(agent, os.path.join(root, f"agent_checkpoint_{i}.pkl.gz"))
                _log_line(log_file, f"Agent checkpoint saved at iteration {i} in {nnode}.")

                _save_gz_pickle(agent.model, os.path.join(root, f"model_checkpoint_{i}.pkl.gz"))
                _log_line(log_file, f"Model checkpoint saved at iteration {i} in {nnode}.")


In [31]:
class Args:
    """Hyper-parameters for one training run, read as plain attributes by main()."""

    # --- Run / bookkeeping ---
    save_path = 'results-v3/7-node-v1'  # directory for the log file and checkpoints
    device = 'cpu'
    progress = True  # show the tqdm progress bar
    method = 'flownet'

    # --- Optimisation ---
    opt = 'adam'  # 'adam' or 'msgd' (SGD with momentum)
    learning_rate = 5e-4
    adam_beta1 = 0.9
    adam_beta2 = 0.999
    momentum = 0.9  # only used by 'msgd'
    clip_grad_norm = 0.0  # values <= 0 disable gradient clipping
    mbsize = 8  # parallel environments, i.e. trajectories sampled per sample_many() call
    # > 1: that many updates per sampled batch; < 1: that many sampled batches per update.
    train_to_sample_ratio = 1.0
    n_train_steps = 90000
    # Window size: how many of the most recent terminal visits feed the logged
    # empirical reward distribution.
    num_empirical_loss = 200000

    # --- Model ---
    n_hid = 256  # width of each hidden layer
    n_layers = 3  # number of hidden layers

    # --- Environment ---
    func = 'oscillator'
    horizon = 12
    nnode = 7
    ndim = nnode ** 2  # one grid coordinate per entry of the nnode x nnode weight matrix
    multiplier = 2

    # --- Flow network / replay ---
    # 0 disables the target network: learn_from evaluates next-state flows with the online model.
    bootstrap_tau = 0.0
    replay_strategy = 'top_k'  # 'top_k' or 'none'
    replay_sample_size = 5  # buffered states replayed as backward trajectories per sample_many() call
    replay_buf_size = 100  # maximum number of (reward, state) pairs kept by the top-k buffer
    log_reg_c = 2.5e-5



args = Args()

# Seed both RNGs the run draws from (torch: weight init and action sampling;
# numpy: replay-buffer sampling) so a re-run reproduces the same training.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Requesting more intra-op threads than there are cores only adds contention.
torch.set_num_threads(min(200, os.cpu_count() or 1))

# main() appends to these module-level lists; clear them so a re-run of this
# cell does not mix the previous run's history into the new agent's checkpoints.
all_losses.clear()
all_visited_done.clear()

main(args)



  0%|          | 1/90001 [00:00<12:50:12,  1.95it/s]

Partial Empirical Reward Distribution: {2.5e-05: 1.0}
Agent checkpoint saved at iteration 0 in 7.
Model checkpoint saved at iteration 0 in 7.


  0%|          | 101/90001 [01:27<11:59:54,  2.08it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.8923267326732673, 1: 0.10643564356435643, 8: 0.0012376237623762376}


  0%|          | 201/90001 [02:21<15:45:04,  1.58it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9067164179104478, 1: 0.09079601990049752, 8: 0.0024875621890547263}


  0%|          | 301/90001 [03:36<17:25:48,  1.43it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9136212624584718, 1: 0.08388704318936877, 8: 0.0024916943521594683}


  0%|          | 401/90001 [04:40<16:39:35,  1.49it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9223815461346634, 1: 0.07543640897755612, 8: 0.0021820448877805485}


  1%|          | 501/90001 [05:41<16:14:13,  1.53it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9301397205588823, 1: 0.06786427145708583, 8: 0.001996007984031936}


  1%|          | 601/90001 [06:29<8:45:41,  2.83it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9363560732113144, 1: 0.06198003327787022, 8: 0.0016638935108153079}


  1%|          | 701/90001 [07:05<13:15:00,  1.87it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.942225392296719, 1: 0.056348074179743225, 8: 0.0014265335235378032}


  1%|          | 801/90001 [07:50<11:54:43,  2.08it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9464731585518102, 1: 0.0516541822721598, 8: 0.0018726591760299626}


  1%|          | 901/90001 [08:35<12:05:35,  2.05it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9504716981132075, 1: 0.047863485016648166, 8: 0.001664816870144284}


  1%|          | 1000/90001 [09:22<10:08:30,  2.44it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.954045954045954, 1: 0.044455544455544456, 8: 0.0014985014985014985}


  1%|          | 1001/90001 [09:24<20:33:25,  1.20it/s]

Agent checkpoint saved at iteration 1000 in 7.
Model checkpoint saved at iteration 1000 in 7.


  1%|          | 1101/90001 [10:05<8:30:50,  2.90it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9564032697547684, 1: 0.04212079927338783, 8: 0.0014759309718437783}


  1%|▏         | 1201/90001 [10:45<8:24:41,  2.93it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.958368026644463, 1: 0.040070774354704415, 8: 0.0014571190674437969, 27: 0.00010407993338884263}


  1%|▏         | 1301/90001 [11:26<10:32:19,  2.34it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9591660261337432, 1: 0.03920061491160646, 8: 0.0014411990776325902, 27: 0.0001921598770176787}


  2%|▏         | 1401/90001 [12:08<9:12:38,  2.67it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9596716630977873, 1: 0.03854389721627409, 8: 0.001516773733047823, 27: 0.0002676659528907923}


  2%|▏         | 1501/90001 [12:47<14:28:13,  1.70it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9601932045303131, 1: 0.03772485009993338, 8: 0.0018321119253830779, 27: 0.0002498334443704197}


  2%|▏         | 1602/90001 [13:21<7:11:14,  3.42it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9626795752654591, 1: 0.035368519675203, 8: 0.001717676452217364, 27: 0.00023422860712054967}


  2%|▏         | 1701/90001 [13:55<7:38:57,  3.21it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9648736037624926, 1: 0.033289241622574954, 8: 0.0016166960611405056, 27: 0.0002204585537918871}


  2%|▏         | 1801/90001 [14:36<12:17:34,  1.99it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9665463631315936, 1: 0.03171848972792893, 8: 0.001526929483620211, 27: 0.0002082176568573015}


  2%|▏         | 1901/90001 [15:16<11:12:45,  2.18it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9679116254602841, 1: 0.030444502893214098, 8: 0.0014466070489216202, 27: 0.00019726459758022093}


  2%|▏         | 2000/90001 [15:51<7:08:28,  3.42it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9694527736131934, 1: 0.028985507246376812, 8: 0.001374312843578211, 27: 0.0001874062968515742}


  2%|▏         | 2001/90001 [15:54<24:46:37,  1.01s/it]

Agent checkpoint saved at iteration 2000 in 7.
Model checkpoint saved at iteration 2000 in 7.


  2%|▏         | 2101/90001 [16:29<7:07:38,  3.43it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9700142789148025, 1: 0.02849833412660638, 8: 0.0013089005235602095, 27: 0.00017848643503093765}


  2%|▏         | 2201/90001 [17:05<9:19:27,  2.62it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9698432530667879, 1: 0.02862335302135393, 8: 0.0013062244434348023, 27: 0.00022716946842344388}


  3%|▎         | 2301/90001 [17:39<7:09:45,  3.40it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9711538461538461, 1: 0.027379400260756193, 8: 0.0012494567579313341, 27: 0.000217296827466319}


  3%|▎         | 2401/90001 [18:13<8:33:16,  2.84it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9723032069970845, 1: 0.02629112869637651, 8: 0.001197417742607247, 27: 0.00020824656393169514}


  3%|▎         | 2501/90001 [18:50<7:24:30,  3.28it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9734106357457017, 1: 0.025239904038384647, 8: 0.0011495401839264295, 27: 0.00019992003198720512}


  3%|▎         | 2601/90001 [19:24<8:56:13,  2.72it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9744329104190695, 1: 0.02426951172625913, 8: 0.0011053440984236831, 27: 0.00019223375624759708}


  3%|▎         | 2701/90001 [19:59<8:35:49,  2.82it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9753794890781192, 1: 0.023370973713439467, 8: 0.0010644205849685302, 27: 0.00018511662347278786}


  3%|▎         | 2801/90001 [20:33<9:13:42,  2.62it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.976258479114602, 1: 0.022536594073545163, 8: 0.001026419136022849, 27: 0.0001785076758300607}


  3%|▎         | 2901/90001 [21:07<8:49:42,  2.74it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9770768700448121, 1: 0.02175973802137194, 8: 0.0009910375732506032, 27: 0.0001723543605653223}


  3%|▎         | 3000/90001 [21:41<6:39:36,  3.63it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.97784071976008, 1: 0.02103465511496168, 8: 0.0009580139953348884, 27: 0.00016661112962345885}


  3%|▎         | 3001/90001 [21:44<28:54:40,  1.20s/it]

Agent checkpoint saved at iteration 3000 in 7.
Model checkpoint saved at iteration 3000 in 7.


  3%|▎         | 3101/90001 [22:18<7:05:08,  3.41it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9773863269912931, 1: 0.02128345694937117, 8: 0.0010883585940019349, 27: 0.0002418574653337633}


  4%|▎         | 3201/90001 [23:01<9:23:02,  2.57it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9780927835051546, 1: 0.020618556701030927, 8: 0.0010543580131208998, 27: 0.00023430178069353328}


  4%|▎         | 3301/90001 [23:36<7:20:12,  3.28it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9787564374431991, 1: 0.019993941229930322, 8: 0.0010224174492578007, 27: 0.0002272038776128446}


  4%|▍         | 3401/90001 [24:11<7:58:52,  3.01it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9793810643928257, 1: 0.019406057042046457, 8: 0.000992355189650103, 27: 0.00022052337547780066}


  4%|▍         | 3502/90001 [24:45<7:31:51,  3.19it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9799700085689803, 1: 0.018851756640959727, 8: 0.0009640102827763496, 27: 0.00021422450728363326}


  4%|▍         | 3601/90001 [25:19<8:55:48,  2.69it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9805262427103583, 1: 0.018328242154956955, 8: 0.000937239655651208, 27: 0.00020827547903360177}


  4%|▍         | 3701/90001 [25:57<10:42:10,  2.24it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9810524182653337, 1: 0.01783301810321535, 8: 0.0009119156984598757, 27: 0.00020264793299108348}


  4%|▍         | 3801/90001 [26:38<9:14:18,  2.59it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9815180215732702, 1: 0.017396737700605104, 8: 0.000887924230465667, 27: 0.0001973164956590371}


  4%|▍         | 3901/90001 [27:14<7:29:26,  3.19it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9819917969751346, 1: 0.016950781850807485, 8: 0.0008651627787746732, 27: 0.00019225839528326071}


  4%|▍         | 4000/90001 [27:48<8:35:17,  2.78it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9823794051487128, 1: 0.01658960259935016, 8: 0.0008435391152211947, 27: 0.00018745313671582103}


  4%|▍         | 4001/90001 [27:52<35:24:36,  1.48s/it]

Agent checkpoint saved at iteration 4000 in 7.
Model checkpoint saved at iteration 4000 in 7.


  5%|▍         | 4101/90001 [28:31<8:08:31,  2.93it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9826871494757377, 1: 0.016276517922457937, 8: 0.000853450377956596, 27: 0.000182882223847842}


  5%|▍         | 4201/90001 [29:07<8:27:24,  2.82it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9830695072601762, 1: 0.015918828850273743, 8: 0.0008331349678647941, 27: 0.00017852892168531301}


  5%|▍         | 4301/90001 [29:43<8:38:19,  2.76it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9832597070448733, 1: 0.015752150662636598, 8: 0.0008137642408742153, 27: 0.00017437805161590327}


  5%|▍         | 4401/90001 [30:12<5:59:15,  3.97it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.983498068620768, 1: 0.01550783912747103, 8: 0.0007952738014087707, 27: 0.00019881845035219268}


  5%|▌         | 4501/90001 [30:43<8:31:27,  2.79it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9838091535214397, 1: 0.015218840257720507, 8: 0.0007776049766718507, 27: 0.00019440124416796267}


  5%|▌         | 4601/90001 [31:18<9:17:46,  2.55it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9840252119104542, 1: 0.01502390784612041, 8: 0.0007607041947402739, 27: 0.00019017604868506847}


  5%|▌         | 4701/90001 [31:53<6:39:55,  3.55it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9838598170602, 1: 0.015182939800042544, 8: 0.000771112529249096, 27: 0.00018613061050840247}


  5%|▌         | 4802/90001 [32:26<5:46:09,  4.10it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9836492397417205, 1: 0.015413455530097896, 8: 0.000755051031035201, 27: 0.00018225369714642783}


  5%|▌         | 4901/90001 [32:58<6:50:11,  3.46it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9838553356457865, 1: 0.015226484390940624, 8: 0.0007396449704142012, 27: 0.0001785349928586003}


  6%|▌         | 5000/90001 [33:27<6:33:46,  3.60it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9841531693661267, 1: 0.014947010597880424, 8: 0.0007248550289942012, 27: 0.00017496500699860028}


  6%|▌         | 5001/90001 [33:32<37:53:30,  1.60s/it]

Agent checkpoint saved at iteration 5000 in 7.
Model checkpoint saved at iteration 5000 in 7.


  6%|▌         | 5101/90001 [34:03<7:13:15,  3.27it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9844638306214468, 1: 0.014653989413840423, 8: 0.0007106449715742012, 27: 0.00017153499313860026}


  6%|▌         | 5201/90001 [34:38<7:27:33,  3.16it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9847625456642953, 1: 0.014372236108440685, 8: 0.0006969813497404346, 27: 0.00016823687752355316}


  6%|▌         | 5301/90001 [35:10<6:53:36,  3.41it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9850499905678174, 1: 0.014101112997547632, 8: 0.0006838332390115072, 27: 0.00016506319562346728}


  6%|▌         | 5401/90001 [35:43<7:00:50,  3.35it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9851879281614516, 1: 0.013955748935382336, 8: 0.000694315867431957, 27: 0.0001620070357341233}


  6%|▌         | 5501/90001 [36:16<8:12:08,  2.86it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9854117433193965, 1: 0.013747500454462824, 8: 0.0006816942374113798, 27: 0.00015906198872932194}


  6%|▌         | 5601/90001 [36:51<7:59:53,  2.93it/s] 

Partial Empirical Reward Distribution: {2.5e-05: 0.9854043920728441, 1: 0.013747545081235494, 8: 0.0006918407427245135, 27: 0.0001562221031958579}


  6%|▋         | 5701/90001 [37:23<6:49:27,  3.43it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9856384844764077, 1: 0.01352832836344501, 8: 0.0006797053148570426, 27: 0.00015348184529029996}


  6%|▋         | 5801/90001 [37:56<8:45:39,  2.67it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.9857352180658507, 1: 0.013424409584554388, 8: 0.0006895362868470954, 27: 0.0001508360627478021}


  7%|▋         | 5901/90001 [38:31<8:20:11,  2.80it/s]

Partial Empirical Reward Distribution: {2.5e-05: 0.985510930350788, 1: 0.013578207083545162, 8: 0.0007413997627520759, 27: 0.0001694628029147602}


  7%|▋         | 5976/90001 [38:55<9:07:21,  2.56it/s]


KeyboardInterrupt: 

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Assuming all_losses, all_visited, empirical_distrib_losses are lists of lists or arrays

# Plot the losses.
# all_losses may hold either one scalar per training step or a row of several loss
# terms per step; reshaping to (steps, terms) lets both layouts plot one curve per term.
# len() is used instead of truthiness so a numpy-array all_losses does not raise
# "truth value of an array is ambiguous".
if len(all_losses) > 0:
    all_losses_array = np.array(all_losses)
    all_losses_array = all_losses_array.reshape(len(all_losses_array), -1)
    fig, ax = plt.subplots(figsize=(12, 6))
    for i in range(all_losses_array.shape[1]):
        ax.plot(all_losses_array[:, i], label=f'Loss {i+1}')
    ax.set_xlabel('Training Step')
    ax.set_ylabel('Loss')
    ax.set_title('Training Losses Over Time')
    ax.legend()
    plt.show()


In [None]:
# Number of highest-reward visited states to list.
top_k = 20

# Sort visited states by reward in descending order.
# Entries of all_visited_done are unpacked below as (state, reward) pairs.
visited_sorted_by_reward = sorted(all_visited_done, key=lambda x: x[1], reverse=True)

# Take the top_k states with the highest rewards.
top_states = visited_sorted_by_reward[:top_k]

# The label uses the actual count, so it stays correct if top_k changes
# or fewer than top_k states were visited.
print(f"Top {len(top_states)} States with Highest Rewards:")
for i, (state, reward) in enumerate(top_states):
    print(f"State {i+1}: {state}, Reward: {reward}")


In [None]:
import pickle
import gzip
import os

args = Args()

env = GridEnv(args.horizon, args.ndim, func=f)

# Checkpoint to load. Defaults to the original absolute path; set the AGENT_CHECKPOINT
# environment variable to load a different file without editing the notebook.
agent_path = os.environ.get(
    "AGENT_CHECKPOINT",
    "/home/dannyhuang/gflownet/frontline/results-v3/7-node/agent_checkpoint_30000.pkl.gz",
)
print(agent_path)

# Reset first so a failed load cannot leave a stale agent from an earlier run in scope.
trained_agent = None

# Load the trained_agent object from the file.
# The file handle is named `fh`, not `f`: `f` is the function passed to GridEnv as
# `func` above, and rebinding it here would make a re-run of this cell hand GridEnv a
# closed file object instead.
try:
    with gzip.open(agent_path, 'rb') as fh:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted checkpoints.
        trained_agent = pickle.load(fh)
    print(trained_agent)
except FileNotFoundError:
    print(f"Error: The file {agent_path} was not found.")
except Exception as e:
    print(f"An error occurred while loading the file: {str(e)}")

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint
import ipywidgets as widgets
from ipywidgets import interact

# Peak-sharpness threshold: a local maximum counts only if it exceeds the mean of its
# two neighbouring samples by more than this. Used as calculate_reward's default
# `delta`; same value as the local `delta` in reward_oscillator.
delta = 1e-4

# Define the sigmoid function
def sigmoid(z):
    """Logistic function 1 / (1 + exp(-z)), elementwise for scalars or numpy arrays.

    Duplicates the `sigmoid` defined near the top of the notebook.
    """
    decay = np.exp(-z)
    return 1 / (decay + 1)

# Define the dynamical system for the n-node system with sigmoid
def n_node_system_with_sigmoid(x, t, *weights):
    """Right-hand side dx/dt = sigmoid(M @ x) - x, in the signature odeint expects.

    Args:
        x: current node state vector; its length must equal the side of M.
        t: time point supplied by odeint; unused (the system is autonomous).
        *weights: the side*side entries of M, laid out row-major
            (numpy's default reshape order).

    Returns:
        numpy array dx/dt with the same length as x.
    """
    side = int(np.sqrt(len(weights)))
    coupling = np.reshape(np.array(weights), (side, side))
    activation = sigmoid(coupling.dot(x))
    return activation - x

# Function to update plot and calculate reward
def update_plot(*weights):
    """Integrate the n-node sigmoid system for `weights`, plot it, and print its reward.

    Args:
        *weights: the nnode*nnode entries of the coupling matrix, row-major.
            The number of weights must be a perfect square.

    Raises:
        ValueError: if the number of weights is not a perfect square.
    """
    nnode = int(np.sqrt(len(weights)))
    if nnode * nnode != len(weights):
        # Fail here with a clear message instead of deep inside odeint's reshape.
        raise ValueError(f"expected a square number of weights, got {len(weights)}")

    # Same initial conditions and time grid as reward_oscillator uses for training.
    x0 = np.linspace(0, 1, nnode, endpoint=False)
    t = np.linspace(0, 20, 200)

    # `weights` is already a tuple; odeint forwards it as extra positional args to the RHS.
    sol = odeint(n_node_system_with_sigmoid, x0, t, args=weights)

    # Reuse and clear the current figure (as plt.clf() does), then draw on an explicit axes.
    fig = plt.gcf()
    fig.clf()
    ax = fig.add_subplot()
    for i in range(nnode):
        ax.plot(t, sol[:, i], label=f'$x_{i+1}$')
    ax.set_xlabel('Time')
    ax.set_ylabel('Concentration')
    ax.set_title(f'{nnode}-Node System Dynamics with Sigmoid')
    ax.legend()
    ax.grid(True)
    plt.show()

    # Calculate reward based on the number of sharp peaks
    reward = calculate_reward(sol)
    print(f"Reward: {reward}")

# Function to calculate reward based on the number of sharp peaks
def calculate_reward(sol, delta=delta):
    """Count sharp local maxima summed over all trajectories of an ODE solution.

    Sample j (which must have a neighbour on each side) is a peak of a series when the
    series strictly rises into it (x[j] - x[j-1] > 0) and strictly falls after it
    (x[j+1] - x[j] < 0). The peak counts when it is sharp:
    x[j] - (x[j-1] + x[j+1]) / 2 > delta. Flat-topped peaks are not counted.

    Args:
        sol: array-like of shape (n_times, n_series), e.g. odeint output. A 1-D array
            is treated as a single series.
        delta: sharpness threshold. Defaults to the module-level `delta` captured when
            this function is defined.

    Returns:
        int: total number of sharp peaks across all series.
    """
    values = np.asarray(sol)
    if values.ndim == 1:
        values = values[:, np.newaxis]
    # Vectorised over time and series, using the same element-wise arithmetic as a
    # per-sample loop: steps[j] = x[j+1] - x[j].
    steps = np.diff(values, axis=0)
    rising = steps[:-1] > 0
    falling = steps[1:] < 0
    sharpness = values[1:-1] - (values[:-2] + values[2:]) / 2
    return int(np.count_nonzero(rising & falling & (sharpness > delta)))


# # Example usage
# nnode = 5  # Change this to set the number of nodes
# ndim = nnode * nnode
# test_weights = np.ones(ndim)
# update_plot(*test_weights)

# test_state = [3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
# 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
# test_weight = env.s2x(np.int32(test_state))
# update_plot(*test_weight)




# Create sliders for each weight
import ipywidgets as widgets
from ipywidgets import interactive
from IPython.display import display

# Initialize weights based on the test_state
initial_weights = env.s2x(np.int32([3, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]))

# One slider per weight. The count comes from the data instead of a hard-coded 49.
# update_plot requires it to be a perfect square.
n_weights = len(initial_weights)
sliders = [
    widgets.FloatSlider(value=initial_weights[i], min=-15, max=15, step=0.1, description=f'Weight {i+1}')
    for i in range(n_weights)
]

# Create an output widget to display the plot
out = widgets.Output()

def update_plot_with_sliders(**w):
    """Redraw the dynamics plot inside `out` from keyword weights w0, w1, ...

    Weights are passed to update_plot in index order, so the result does not depend
    on keyword-argument ordering.
    """
    with out:
        out.clear_output(wait=True)
        new_weights = [w[f'w{i}'] for i in range(len(w))]
        update_plot(*new_weights)

# Create the interactive plot
interactive_plot = interactive(update_plot_with_sliders, **{f'w{i}': slider for i, slider in enumerate(sliders)})

# Display the sliders and the plot
display(widgets.VBox([interactive_plot, out]))

# Initial plot update. Use the sliders' actual values, which FloatSlider clamps to
# [min, max], so the first plot matches what the sliders show.
update_plot_with_sliders(**{f'w{i}': slider.value for i, slider in enumerate(sliders)})

In [None]:
# Show the weights that seeded the sliders (env.s2x applied to the hand-picked test state).
initial_weights