<a href="https://colab.research.google.com/github/spatank/GraphRL/blob/main/test_on_larger.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
from google.colab import drive

# Mount Google Drive, then work from the project folder so relative paths resolve there.
drive.mount('/content/drive')
os.chdir('/content/drive/My Drive/GraphRL/')

In [None]:
#@title Imports
# Setup cell: imports, plotting configuration, device selection, and installation of
# PyTorch Geometric and Ripser. The `!{...}` lines are IPython shell escapes, so this
# cell only runs inside a notebook (it is not plain Python).

import networkx as nx
from functools import lru_cache
import random
import numpy as np
import time
from copy import deepcopy
import time # NOTE(review): duplicate of the `import time` a few lines above
from typing import NamedTuple
from tqdm import tqdm
import glob

import matplotlib.pyplot as plt
import matplotlib.animation as animation

plt.rcParams["animation.html"] = "jshtml" # render matplotlib animations as JavaScript widgets in the notebook

import torch # check version using torch.__version__ before using PyG wheels
import torch.nn as nn
import torch.nn.functional as F
# from torch.utils import checkpoint # unused

# NOTE(review): `device` is not referenced by the code in this section; models and
# tensors appear to stay on the CPU -- confirm whether GPU use is intended.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

import sys
# The torch-scatter / torch-sparse wheel index below is pinned to torch 1.12.0 + CUDA 11.3;
# it must match the installed torch.__version__, otherwise installation or loading may fail.
!{sys.executable} -m pip install -q torch-scatter -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
!{sys.executable} -m pip install -q torch-sparse -f https://data.pyg.org/whl/torch-1.12.0+cu113.html
!{sys.executable} -m pip install -q torch-geometric

# torch_geometric is imported only after the pip installs above have run.
import torch_geometric
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GCNConv, SAGEConv
from torch_geometric import utils, transforms

# Ripser (persistent homology) is installed here and imported right after.
!{sys.executable} -m pip install -q Cython
!{sys.executable} -m pip install -q Ripser

from ripser import ripser

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
#@title Environments

@lru_cache(maxsize = 100000)
def get_NX_subgraph(environment, frozen_set_of_nodes):
  """
  Return the subgraph of environment.graph_NX induced by the given nodes.

  Memoized on (environment, frozen_set_of_nodes). Because the environment
  object is part of the cache key, cached entries keep their environment
  alive until they are evicted.
  """
  node_list = list(frozen_set_of_nodes)
  induced = environment.graph_NX.subgraph(node_list)
  return induced

@lru_cache(maxsize = 500000)
def get_PyG_subgraph(environment, frozen_set_of_nodes):
  """
  Return the subgraph of environment.graph_PyG induced by the given nodes.

  The node labels are used directly as PyG node indices.
  NOTE(review): that assumes the NetworkX labels are 0..n-1 in the order
  from_networkx() used when building graph_PyG -- confirm for each dataset.

  Memoized on (environment, frozen_set_of_nodes); cached entries keep their
  environment object alive until evicted.
  """
  # An explicit integer dtype keeps the empty node set valid: torch.tensor([])
  # would default to float, which is not a usable index tensor.
  subset = torch.tensor(list(frozen_set_of_nodes), dtype = torch.long)
  return environment.graph_PyG.subgraph(subset)

@lru_cache(maxsize = 100000)
def compute_feature_value(environment, state_subgraph_NX):
  """
  Evaluate environment.feature_function on a state subgraph.

  Memoized on (environment, state_subgraph_NX), so both must be hashable.
  """
  feature_function = environment.feature_function
  value = feature_function(state_subgraph_NX)
  return value

@lru_cache(maxsize = 100000)
def get_neighbors(environment, frozen_set_of_nodes, cutoff = 1):
  """
  Return the nodes within `cutoff` hops of any node in `frozen_set_of_nodes`,
  excluding the input nodes themselves.

  If that neighborhood is empty (the input nodes cover everything reachable
  from them), fall back to every graph node not in `frozen_set_of_nodes`.
  A set covering the whole graph therefore yields an empty list, which
  DQNAgent.train() treats as "terminal state, no available actions".

  The result is cached on (environment, frozen_set_of_nodes, cutoff), so it
  must depend only on those arguments. In particular it must not read mutable
  environment state such as environment.visited. Callers like
  DQNAgent.train() ask about replay-buffer states that differ from the
  environment's current one.
  """
  neighbors = set()

  for node in frozen_set_of_nodes:
    neighbors.update(set(nx.single_source_shortest_path_length(environment.graph_NX, 
                                                               node, 
                                                               cutoff = cutoff).keys()))

  neighbors = neighbors - frozen_set_of_nodes # remove input nodes from their own neighborhood

  if not neighbors:
    # jump to any node outside the given set (e.g. another connected component)
    neighbors = set(environment.graph_NX.nodes()) - frozen_set_of_nodes

  return list(neighbors)

class GraphEnvironment():
  """
  Environment in which an agent grows a set of visited nodes on a fixed graph.

  An episode starts from one uniformly random node. Each action visits one
  more node, chosen from get_actions(). The state is the subgraph induced by
  the visited nodes, kept both as a NetworkX view and as a PyG Data object.
  After each step the reward is the mean of the feature values of every
  state seen so far in the episode, including the initial one. The episode
  is terminal once every node has been visited.
  """
  
  def __init__(self, ID, graph_NX, feature):
    super().__init__()

    # identifier; DQNAgent.train() uses it as an index into MultipleEnvironments.environments
    self.ID = ID

    self.graph_NX = graph_NX # environment graph (NetworkX Graph object)
    # group_node_attrs = all: PyG stacks every node attribute into the feature matrix x
    self.graph_PyG = utils.from_networkx(graph_NX, group_node_attrs = all)
    self.num_nodes = self.graph_NX.number_of_nodes()

    self.feature_function = feature # function handle to network feature-of-interest

    self.visited = [random.choice(list(self.graph_NX.nodes()))] # list of visited nodes
    self.state_NX = get_NX_subgraph(self, frozenset(self.visited))
    self.state_PyG = get_PyG_subgraph(self, frozenset(self.visited))

    # feature-of-interest value of every state of the current episode
    self.feature_values = [compute_feature_value(self, self.state_NX)]
    
  def step(self, action):
    """
    Visit node `action` and return (state_dict, reward, terminal, info).

    Raises AssertionError if `action` is not currently available.
    """

    assert action in self.get_actions(self.visited), "Invalid action!"
    # Build a new list rather than appending in place: state dicts returned
    # earlier (e.g. stored in a replay buffer) hold a reference to the old one.
    self.visited = self.visited + [action]
    self.state_NX = get_NX_subgraph(self, frozenset(self.visited))
    self.state_PyG = get_PyG_subgraph(self, frozenset(self.visited))
    reward = self.compute_reward()
    terminal = bool(len(self.visited) == self.graph_NX.number_of_nodes())

    return self.get_state_dict(), reward, terminal, self.get_info()

  def compute_reward(self):
    """
    Record the current state's feature value and return the running mean of
    all feature values recorded this episode.
    """

    self.feature_values.append(compute_feature_value(self, self.state_NX))
    # one feature value per visited node, so this is the episode's mean so far
    reward = sum(self.feature_values)/len(self.visited)

    return reward

  def reset(self):
    """
    Start a new episode from a uniformly random node.

    Returns (state_dict, terminal, info) with terminal always False.
    """

    self.visited = [random.choice(list(self.graph_NX.nodes()))] # restart from a single random node
    self.state_NX = get_NX_subgraph(self, frozenset(self.visited))
    self.state_PyG = get_PyG_subgraph(self, frozenset(self.visited))
    self.feature_values = [compute_feature_value(self, self.state_NX)]
    terminal = False

    return self.get_state_dict(), terminal, self.get_info()

  def get_state_dict(self):
    """Return the current state: visited-node list plus NetworkX and PyG subgraphs."""

    return {'visited': self.visited, 
            'state_NX': self.state_NX, 
            'state_PyG': self.state_PyG}
      
  def get_info(self):
    """Return the environment ID and the feature value of the current state."""
    
    return {'environment_ID': self.ID, # useful for DQN training
            'feature_value': compute_feature_value(self, self.state_NX)}
  
  def get_actions(self, nodes):
    """ 
    Return the nodes that may be visited next, given the visited `nodes`.
    """

    return get_neighbors(self, frozenset(nodes))
  
  def render(self):
    """
    Draw the current state subgraph in a new matplotlib figure.
    """

    plt.figure()
    nx.draw(self.state_NX, with_labels = True)

class MultipleEnvironments():
  """
  Container that fans reset()/step() calls out to a list of environments and
  gathers the per-environment results into parallel lists.
  """

  def __init__(self, environments):
    
    self.environments = environments
    self.num_environments = len(self.environments)

  def reset(self):
    """Reset every environment; return (state_dicts, terminals, all_info) lists."""

    outcomes = [environment.reset() for environment in self.environments]

    state_dicts = [state_dict for state_dict, _, _ in outcomes]
    terminals = [terminal for _, terminal, _ in outcomes]
    all_info = [info for _, _, info in outcomes]

    return state_dicts, terminals, all_info
  
  def step(self, actions):
    """
    Step environment i with actions[i]; return
    (state_dicts, rewards, terminals, all_info) lists.
    """

    outcomes = [environment.step(actions[position])
                for position, environment in enumerate(self.environments)]

    state_dicts = [state_dict for state_dict, _, _, _ in outcomes]
    rewards = [reward for _, reward, _, _ in outcomes]
    terminals = [terminal for _, _, terminal, _ in outcomes]
    all_info = [info for _, _, _, info in outcomes]

    return state_dicts, rewards, terminals, all_info
  
  def __len__(self):
    return self.num_environments

In [None]:
#@title Baseline Agents

class RandomAgent():
  """
  RandomAgent() picks, in every environment, a uniformly random available
  action (via random.choice). The agent is not deterministic.
  """
  
  def __init__(self):
    super().__init__()
    
    self.environments = None # expected: a MultipleEnvironments() instance
    self.is_trainable = False # baseline agents are never trained

  def choose_action(self):
    """Return one randomly chosen available action per environment."""

    assert self.environments, "Supply environment(s) for the agent to interact with."

    return [random.choice(env.get_actions(env.visited))
            for env in self.environments.environments]

class HighestDegreeAgent():
  """
  HighestDegreeAgent() picks, in every environment, the available node with
  the largest degree in the full environment graph. Ties go to the node that
  appears first in the available-action list. The agent is deterministic.
  """

  def __init__(self):
    super().__init__()

    self.environments = None # expected: a MultipleEnvironments() instance
    self.is_trainable = False # baseline agents are never trained

  def choose_action(self):
    """Return the highest-degree available node for each environment."""

    assert self.environments, "Supply environment(s) for the agent to interact with."

    chosen = []
    for env in self.environments.environments:
      candidates = env.get_actions(env.visited)
      degree_of = dict(env.graph_NX.degree(candidates))
      # max() keeps the first maximal element, matching first-in-list tie-breaking
      chosen.append(max(candidates, key = degree_of.__getitem__))

    return chosen

class LowestDegreeAgent():
  """
  LowestDegreeAgent() picks, in every environment, the available node with
  the smallest degree in the full environment graph. Ties go to the node that
  appears first in the available-action list. The agent is deterministic.
  """

  def __init__(self):
    super().__init__()

    self.environments = None # expected: a MultipleEnvironments() instance
    self.is_trainable = False # baseline agents are never trained

  def choose_action(self):
    """Return the lowest-degree available node for each environment."""

    assert self.environments, "Supply environment(s) for the agent to interact with."

    chosen = []
    for env in self.environments.environments:
      candidates = env.get_actions(env.visited)
      degree_of = dict(env.graph_NX.degree(candidates))
      # min() keeps the first minimal element, matching first-in-list tie-breaking
      chosen.append(min(candidates, key = degree_of.__getitem__))

    return chosen

class GreedyAgent():
  """
  GreedyAgent() chooses, in each environment, the available action with the
  greatest immediate reward, taking the first such action on ties. Each
  candidate is evaluated by stepping the environment and then restoring its
  state, so the reward comes from the environment's own step() logic. The
  agent is deterministic when the environment's feature function is.
  """

  def __init__(self):
    super().__init__()

    self.environments = None # should be instance of MultipleEnvironments() class
    self.is_trainable = False # useful to manage control flow during simulations

  @staticmethod
  def _lookahead_reward(environment, action):
    """
    Return the reward environment.step(action) yields, leaving the environment
    exactly as it was before the call (also when step() raises).

    This replaces deep-copying the whole environment per candidate. A copy
    duplicates both graphs. Because the module-level lru_caches key on the
    environment object, every copy also adds cache entries that keep it alive
    and are never reused. Restoring in place keeps the cache keyed on the real
    environment.

    GraphEnvironment.step() rebinds visited, state_NX and state_PyG, and
    compute_reward() appends one entry to feature_values; exactly those are
    restored here.
    """
    saved_visited = environment.visited
    saved_state_NX = environment.state_NX
    saved_state_PyG = environment.state_PyG
    num_feature_values = len(environment.feature_values)

    try:
      _, reward, _, _ = environment.step(action)
    finally:
      environment.visited = saved_visited
      environment.state_NX = saved_state_NX
      environment.state_PyG = saved_state_PyG
      del environment.feature_values[num_feature_values:]

    return reward

  def choose_action(self):
    """Return the immediately best available action for each environment."""

    assert self.environments, "Supply environment(s) for the agent to interact with."

    actions = []

    for environment in self.environments.environments:
      available_actions = environment.get_actions(environment.visited)
      best_reward = float('-inf')
      best_action = None

      for action in available_actions:
        reward = self._lookahead_reward(environment, action)

        if reward > best_reward: # strict '>' keeps the first best action on ties
          best_reward = reward
          best_action = action

      actions.append(best_action)

    return actions

In [None]:
#@title DQN Agent

class GNN(nn.Module):
  """
  Two-layer GraphSAGE encoder (mean aggregation, ReLU after each layer)
  followed by sum pooling, giving one embedding per graph.

  Reads hyperparameters['num_node_features'], ['GNN_latent_dimensions'] and
  ['embedding_dimensions'].
  """

  def __init__(self, hyperparameters):
    super().__init__()

    in_dim = hyperparameters['num_node_features']
    hidden_dim = hyperparameters['GNN_latent_dimensions']
    out_dim = hyperparameters['embedding_dimensions']

    # the attribute names conv1/conv2 are the state_dict keys saved in checkpoints
    self.conv1 = SAGEConv(in_dim, hidden_dim, aggr = 'mean')
    self.conv2 = SAGEConv(hidden_dim, out_dim, aggr = 'mean')

  def forward(self, x, edge_index, batch = None):
    """
    Embed a (batch of) graph(s).

    x: node feature matrix; edge_index: connectivity in PyG format;
    batch: node-to-graph assignment vector as produced by PyG's Batch
    (None pools all nodes into a single graph embedding).
    """

    hidden = F.relu(self.conv1(x, edge_index))
    node_embeddings = F.relu(self.conv2(hidden, edge_index))

    return torch_geometric.nn.global_add_pool(node_embeddings, batch = batch)

class QN(nn.Module):
  """
  Two-layer MLP mapping each graph embedding to a scalar Q-value.

  Reads hyperparameters['embedding_dimensions'] and ['QN_latent_dimensions'].
  """

  def __init__(self, hyperparameters):
    super().__init__()

    embedding_dim = hyperparameters['embedding_dimensions']
    hidden_dim = hyperparameters['QN_latent_dimensions']

    self.fc1 = nn.Linear(embedding_dim, hidden_dim)
    self.fc2 = nn.Linear(hidden_dim, 1)

  def forward(self, x):
    """Return one Q-value per input row (output shape (rows, 1))."""

    return self.fc2(F.relu(self.fc1(x)))

class DQNAgent():
  """
  Deep Q-learning agent for GraphEnvironment.

  The value of visiting node a from state s is scored on the resulting state:
  the PyG subgraph induced by s's visited nodes plus a. `embedding_module`
  maps a batch of such subgraphs to embeddings and `q_net` maps each
  embedding to a scalar.

  Frozen target copies of both networks are used for action selection and
  for bootstrapped targets. They are refreshed every `learn_every` calls to
  train() once burn-in is over.
  """

  def __init__(self, embedding_module, q_net, 
               replay_buffer, train_start, batch_size, learn_every,
               optimizer, 
               epsilon, epsilon_decay_rate, epsilon_min):
    super().__init__()

    self.environments = None # should be instance of MultipleEnvironments() class
    self.is_trainable = True # useful to manage control flow during simulations
    
    self.embedding_module = embedding_module
    self.q_net = q_net
    
    self.target_embedding_module = deepcopy(embedding_module)
    self.target_q_net = deepcopy(q_net)
    
    # disable gradients for target networks
    for parameter in self.target_embedding_module.parameters():
      parameter.requires_grad = False

    for parameter in self.target_q_net.parameters():
      parameter.requires_grad = False
    
    self.replay_buffer = replay_buffer
    self.train_start = train_start # specify burn-in period
    self.batch_size = batch_size
    self.learn_every = learn_every # steps between updates to target nets

    self.optimizer = optimizer

    self.epsilon = epsilon # probability with which to select a non-greedy action
    self.epsilon_decay_rate = epsilon_decay_rate
    self.epsilon_min = epsilon_min

    self.step = 0 # number of train() calls so far

  def _environment_for(self, info):
    """Return the training environment a transition came from."""
    # assumes environment IDs equal positions in self.environments.environments
    return self.environments.environments[info['environment_ID']]

  def _target_q_values(self, environment, visited, candidate_actions):
    """
    Target-network Q-values (one row per candidate, in candidate order) of the
    states reached by adding each candidate node to `visited`.
    """
    subgraphs = [get_PyG_subgraph(environment, frozenset(list(visited) + [action]))
                 for action in candidate_actions]
    batch = Batch.from_data_list(subgraphs) # one forward pass for all candidates

    with torch.no_grad(): # target parameters are frozen anyway
      return self.target_q_net(self.target_embedding_module(batch.x, 
                                                            batch.edge_index, 
                                                            batch.batch))

  def choose_action(self):
    """
    Choose one epsilon-greedy action per environment in self.environments.

    With probability epsilon a random available node is returned. Otherwise
    the node whose resulting state has the highest target-network Q-value is
    returned, taking the first one on ties.
    """

    assert self.environments, "Supply environment(s) for the agent to interact with."

    actions = []

    for environment in self.environments.environments:
      available_actions = environment.get_actions(environment.visited)

      if torch.rand(1) < self.epsilon: # explore: no Q-values needed
        action = np.random.choice(available_actions)
      else: # exploit
        q_values = self._target_q_values(environment, environment.visited, available_actions)
        action = available_actions[torch.argmax(q_values).item()]

      actions.append(action)

    return actions

  def train(self, state_dicts, actions, next_state_dicts, rewards, discounts, all_info):
    """
    Store the latest transitions and, once past the burn-in period, take one
    gradient step on a minibatch sampled from the replay buffer.

    Returns the minibatch MSE loss as a float, or None during burn-in.
    """

    self.replay_buffer.add(state_dicts, actions, next_state_dicts, rewards, discounts, all_info)
    self.step += 1
    
    if self.step < self.train_start: # inside the burn-in period
      return 

    # (1) Get lists of experiences from memory
    states, actions, next_states, rewards, discounts, all_info = self.replay_buffer.sample(self.batch_size)
    
    # (2) Build state + action = new subgraph (technically identical to next state)
    new_subgraphs = []
    for state_dict, action, next_state_dict, info in zip(states, actions, next_states, all_info):
      visited_nodes_new = list(state_dict['visited']) + [action]
      assert visited_nodes_new == next_state_dict['visited'], "train() assertion failed."
      new_subgraphs.append(get_PyG_subgraph(self._environment_for(info), 
                                            frozenset(visited_nodes_new)))

    batch = Batch.from_data_list(new_subgraphs)

    # (3) Pass batch of next_state subgraphs through ANN to get predicted q-values
    q_predictions = self.q_net(self.embedding_module(batch.x, 
                                                     batch.edge_index, 
                                                     batch.batch))

    # (4) Compute target q-values: r + discount * max_a' Q_target(s', a')
    q_targets = []
    for next_state_dict, reward, discount, info in zip(next_states, rewards, discounts, all_info):
      environment = self._environment_for(info)
      available_actions = environment.get_actions(next_state_dict['visited'])

      if available_actions:
        max_next_q = self._target_q_values(environment, 
                                           next_state_dict['visited'], 
                                           available_actions).max().item()
        q_targets.append(reward + discount * max_next_q)
      else: # terminal states have no available actions
        q_targets.append(reward)

    # plain floats -> one tensor matching the predictions' dtype and device
    q_targets = torch.tensor(q_targets, 
                             dtype = q_predictions.dtype, 
                             device = q_predictions.device).view(-1, 1)
      
    # (5) Compute MSE loss between predicted and target q-values (mean over batch)
    loss = F.mse_loss(q_predictions, q_targets)

    # (6) Backpropagate gradients
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    # (7) Copy parameters from source to target networks
    if self.step % self.learn_every == 0: 
      copy_parameters_from_to(self.embedding_module, self.target_embedding_module)
      copy_parameters_from_to(self.q_net, self.target_q_net)
      
    # (8) Decrease exploration rate
    self.epsilon *= self.epsilon_decay_rate
    self.epsilon = max(self.epsilon, self.epsilon_min)

    return loss.item() # if needed for logging

In [None]:
#@title Helper Functions: Miscellaneous

def initialize_weights(m):
  """
  Xavier initialization of model weights.
  """

  if isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
    m.weight.data.fill_(1.0)
    m.bias.data.zero_()

  elif isinstance(m, SAGEConv):
    m.lin_l.weight.data = nn.init.xavier_uniform_(
        m.lin_l.weight.data, gain = nn.init.calculate_gain('relu'))
    
    if m.lin_l.bias is not None:
      m.lin_l.bias.data.zero_()

    m.lin_r.weight.data = nn.init.xavier_uniform_(
        m.lin_r.weight.data, gain = nn.init.calculate_gain('relu'))
    
    if m.lin_r.bias is not None: # redundant
      m.lin_r.bias.data.zero_()

  elif isinstance(m, nn.Linear):
    m.weight.data = nn.init.xavier_uniform_(
        m.weight.data, gain = nn.init.calculate_gain('relu'))
    
    if m.bias is not None:
      m.bias.data.zero_()

def compute_Frobenius_norm(network):
  """
  Return, as a Python float, the sum over all parameter tensors of `network`
  of each tensor's Frobenius (L2) norm.

  Note: this is a sum of per-tensor norms, not the Frobenius norm of all
  parameters concatenated (the square root of all summed squares).
  Returns 0.0 for a network without parameters.
  """
  total = 0.0

  for param in network.parameters():
    total = total + torch.norm(param.detach())

  # float() works both for the 0-d tensor and for the untouched 0.0
  return float(total)

def copy_parameters_from_to(source_network, target_network):
  """
  Overwrite target_network's parameters and buffers with copies of
  source_network's values (used to refresh DQN target networks).

  load_state_dict() copies in place, so `requires_grad` on the target's
  parameters is unchanged. It also copies buffers such as BatchNorm running
  statistics. It raises on mismatched architectures instead of silently
  copying only a prefix of the parameters.
  """

  target_network.load_state_dict(source_network.state_dict())

  return

def average_area_under_the_curve(all_feature_values):
  """
  Average area under the feature-value curve across environments.

  `all_feature_values` is nested as environment -> episode -> per-step
  feature values. An episode's area is the sum of its values. Areas are
  first averaged over the episodes of each environment, then those
  per-environment means are averaged.
  """

  per_environment = [
      sum(sum(episode) for episode in episodes) / len(episodes)
      for episodes in all_feature_values
  ]

  return sum(per_environment) / len(per_environment)

def generate_video(plotting_dict):
  """
  Build a matplotlib animation comparing baseline curves with DQN curves.

  plotting_dict['random'], ['degree'] and ['greedy'] are single curves drawn
  as static lines. plotting_dict['DQN'] is a sequence of curves of equal
  length, presumably one per checkpoint; they are shown one per frame
  (100 ms apart) as the black line. The figure is closed before returning
  so only the animation is displayed.
  """

  baselines = [('random', 'random', 'blue'),
               ('degree', 'max degree', 'orange'),
               ('greedy', 'greedy', 'green')]

  dqn_curves = np.array(plotting_dict['DQN'])
  num_steps = dqn_curves.shape[1]
  steps = np.arange(num_steps) # number of nodes

  peaks = [max(plotting_dict[key]) for key, _, _ in baselines]
  peaks.append(np.max(dqn_curves))
  y_max = max(peaks)

  fig, ax = plt.subplots()
  ax.axis([0, num_steps, 0, y_max + 0.01 * y_max]) # 1% headroom above the highest value

  for key, label, color in baselines:
    ax.plot(steps, plotting_dict[key], label = label, color = color)
  dqn_line, = ax.plot([], [], label = 'DQN', color = 'black')

  ax.legend()
  ax.set_xlabel('Step')
  ax.set_ylabel('Value')

  def draw_frame(frame_idx):
    dqn_line.set_data(steps, dqn_curves[frame_idx])

  movie = animation.FuncAnimation(fig, draw_frame, 
                                  frames = len(dqn_curves),
                                  interval = 100,  
                                  blit = False, repeat = False, 
                                  repeat_delay = 10000)
  plt.close() # do not show extra figure

  return movie

def _neighborhood_degree_stats(graph_NX, node, cutoff):
  """
  Return (min, max, mean, std) of the degrees of all nodes within `cutoff`
  hops of `node`, excluding `node` itself; (0, 0, 0.0, 0.0) if there are none.
  """

  neighborhood = set(nx.single_source_shortest_path_length(graph_NX, node, cutoff = cutoff).keys())
  neighborhood.remove(node) # a node is not part of its own neighborhood
  neighborhood = list(neighborhood)

  if neighborhood:
    degrees = [degree for _, degree in graph_NX.degree(neighborhood)]
  else: # no other node within `cutoff` hops
    degrees = [0]

  return min(degrees), max(degrees), float(np.mean(degrees)), float(np.std(degrees))

def node_featurizer(graph_NX):
  """
  Return a copy of graph_NX in which every node carries nine structural
  attributes: its own degree ('degree_1') plus the min/max/mean/std of the
  degrees of the nodes within 1 hop ('*_degree_1') and within 2 hops
  ('*_degree_2') of it. The input graph is not modified.

  NOTE: the attribute insertion order is kept fixed. It presumably
  determines the column order of the PyG feature matrix built with
  from_networkx(..., group_node_attrs = all) -- confirm before reordering.
  """

  graph_NX = deepcopy(graph_NX)

  attributes = {}

  for node in graph_NX.nodes():
    min_1, max_1, mean_1, std_1 = _neighborhood_degree_stats(graph_NX, node, cutoff = 1)
    min_2, max_2, mean_2, std_2 = _neighborhood_degree_stats(graph_NX, node, cutoff = 2)

    attributes[node] = {'degree_1': graph_NX.degree(node),
                        'min_degree_1': min_1,
                        'max_degree_1': max_1,
                        'mean_degree_1': mean_1,
                        'std_degree_1': std_1,
                        'min_degree_2': min_2,
                        'max_degree_2': max_2,
                        'mean_degree_2': mean_2,
                        'std_degree_2': std_2}
    
  nx.set_node_attributes(graph_NX, attributes)

  return graph_NX

def node_defeaturizer(graph_NX):
  """
  Return a copy of graph_NX with the nine node attributes written by
  node_featurizer() removed from every node. The input graph is not modified.

  Raises KeyError if a node lacks one of those attributes.
  """

  graph_NX = deepcopy(graph_NX)

  feature_names = ('degree_1', 'min_degree_1', 'max_degree_1',
                   'mean_degree_1', 'std_degree_1',
                   'min_degree_2', 'max_degree_2',
                   'mean_degree_2', 'std_degree_2')

  for _, node_data in graph_NX.nodes(data = True):
    for name in feature_names:
      del node_data[name]

  # return only after the loop so every node is processed (and an empty
  # graph still returns its copy rather than None)
  return graph_NX

class ReplayBuffer():
  """
  Fixed-capacity circular buffer of transitions
  (state, action, next_state, reward, discount, info).

  Each field lives in its own preallocated list. Once the buffer is full,
  new transitions overwrite the oldest ones.
  """

  _FIELDS = ('states', 'actions', 'next_states', 'rewards', 'discounts', 'all_info')
  
  def __init__(self, buffer_size):

    self.buffer_size = buffer_size
    self.ptr = 0 # slot the next transition will be written to
    self.num_experiences = 0 # number of filled slots (at most buffer_size)
    for field in self._FIELDS:
      setattr(self, field, [None] * self.buffer_size)

  def _columns(self):
    # the per-field storage lists, in _FIELDS order
    return tuple(getattr(self, field) for field in self._FIELDS)

  def add(self, state_dicts, actions, next_state_dicts, rewards, discounts, all_info):
    """
    Store one transition, or several given as parallel lists
    (a list-valued `state_dicts` signals the batched form).
    """

    incoming = (state_dicts, actions, next_state_dicts, rewards, discounts, all_info)
    if not isinstance(state_dicts, list): # a single transition: wrap every field
      incoming = tuple([value] for value in incoming)

    columns = self._columns()
    for i in range(len(incoming[0])):
      for column, source in zip(columns, incoming):
        column[self.ptr] = source[i]

      if self.num_experiences < self.buffer_size:
        self.num_experiences += 1

      self.ptr = (self.ptr + 1) % self.buffer_size # wrap around: overwrite oldest

  def sample(self, batch_size):
    """
    Draw `batch_size` stored transitions uniformly at random, with
    replacement, and return them as six parallel lists.
    """

    picks = np.random.choice(self.num_experiences, batch_size)

    return tuple([column[k] for k in picks] for column in self._columns())

def save_checkpoint(embedding_module, q_net, 
                    optimizer, 
                    replay_buffer, 
                    returns, feature_values_train,
                    validation_scores, feature_values_val,
                    step,
                    save_path):
  """
  Write a training checkpoint to `save_path` with torch.save().

  The file holds a single dict with:
  - the network and optimizer state dicts;
  - the complete replay-buffer contents (pointer, fill count and every field);
  - the training and validation logs;
  - the step counter.
  load_checkpoint() reads this format back.
  """

  networks = {'embedding_module_state_dict': embedding_module.state_dict(),
              'q_net_state_dict': q_net.state_dict(),
              'optimizer_state_dict': optimizer.state_dict()}

  buffer_contents = {'buffer_ptr': replay_buffer.ptr,
                     'buffer_num_experience': replay_buffer.num_experiences,
                     'buffer_states': replay_buffer.states,
                     'buffer_actions': replay_buffer.actions,
                     'buffer_next_states': replay_buffer.next_states,
                     'buffer_rewards': replay_buffer.rewards,
                     'buffer_discounts': replay_buffer.discounts,
                     'buffer_all_info': replay_buffer.all_info}

  logs = {'returns': returns,
          'feature_values_train': feature_values_train,
          'validation_scores': validation_scores,
          'feature_values_val': feature_values_val,
          'step': step}

  torch.save({**networks, **buffer_contents, **logs}, save_path)

def load_checkpoint(load_path, embedding_module, q_net, 
                    optimizer = None, 
                    replay_buffer = None,
                    map_location = None):
  """
  Restore a checkpoint written by save_checkpoint().

  Args:
    load_path: path of the checkpoint file.
    embedding_module, q_net: networks whose state dicts are overwritten.
    optimizer: if given, its state dict is restored as well.
    replay_buffer: if given, its pointer, fill count and stored fields are
      replaced by the saved ones.
    map_location: forwarded to torch.load(). For example, 'cpu' loads a
      checkpoint saved on a GPU onto a CPU-only machine; None keeps
      torch.load's default.

  Returns:
    (train_results, val_results): {'returns', 'feature_values_train'} and
    {'validation_scores', 'feature_values_val'} dicts.
  """
  
  # NOTE(review): the saved buffer holds NetworkX/PyG objects inside its state
  # dicts; torch versions whose torch.load defaults to weights_only=True may
  # refuse to unpickle them -- confirm against the torch version in use.
  checkpoint = torch.load(load_path, map_location = map_location)

  embedding_module.load_state_dict(checkpoint['embedding_module_state_dict'])
  q_net.load_state_dict(checkpoint['q_net_state_dict'])

  if optimizer is not None:
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

  if replay_buffer is not None:
    replay_buffer.ptr = checkpoint['buffer_ptr']
    replay_buffer.num_experiences = checkpoint['buffer_num_experience']
    replay_buffer.states = checkpoint['buffer_states']
    replay_buffer.actions = checkpoint['buffer_actions']
    replay_buffer.next_states = checkpoint['buffer_next_states']
    replay_buffer.rewards = checkpoint['buffer_rewards']
    replay_buffer.discounts = checkpoint['buffer_discounts']
    replay_buffer.all_info = checkpoint['buffer_all_info']

  train_results = {'returns': checkpoint['returns'],
                   'feature_values_train': checkpoint['feature_values_train']}

  val_results = {'validation_scores': checkpoint['validation_scores'],
                 'feature_values_val': checkpoint['feature_values_val']}

  return train_results, val_results

In [None]:
#@title Helper Functions: Simulation

def simulate(agent, environments, num_episodes = 100, verbose = True):
  """
  Run `agent` for `num_episodes` complete episodes in each environment and
  record the feature value after every step.

  Environments are simulated one at a time, rather than in lock-step via
  MultipleEnvironments.step(), because episodes in different environments
  may have different lengths (graphs can have different numbers of nodes).

  Returns:
    A list (one entry per environment) of lists (one per episode) of
    per-step feature values -- the nesting average_area_under_the_curve()
    expects. All environments are reset before returning.
  """
  import copy # the module-level import only brings in deepcopy

  # A shallow copy is enough to point the agent at these environments without
  # touching the caller's agent. Unlike deepcopy, it does not duplicate the
  # agent's networks, replay buffer and current (training) environments on
  # every call; choose_action() only reads them.
  agent = copy.copy(agent)
  agent.environments = environments # supply the agent with different environments

  all_feature_values = []

  for idx, environment in enumerate(tqdm(environments.environments, 
                                         disable = not verbose)):
    
    state_dict, terminal, info = environment.reset()
    environment_feature_values = []

    for _ in range(num_episodes):
      episode_feature_values = []
      
      while not terminal:
        # choose_action() proposes one action per environment; only this one is used
        action = agent.choose_action()[idx]
        state_dict, reward, terminal, info = environment.step(action)
        episode_feature_values.append(info['feature_value'])
      
      state_dict, terminal, info = environment.reset() # reset environment after use

      environment_feature_values.append(episode_feature_values)
      
    all_feature_values.append(environment_feature_values)
  
  environments.reset() # redundant

  return all_feature_values

def learn_environments(agent, train_environments, val_environments, 
                       num_steps, discount_factor, base_save_path,
                       log_val_results = True, verbose = True):
  """
  Train `agent` on several environments in lock-step for `num_steps` steps.

  Each step the agent picks one action per training environment and all
  environments are stepped. For trainable agents the transitions are then
  passed to agent.train(). Environments that reach a terminal state have
  their episode logged and are reset individually.

  For trainable agents only:
  - Validation runs every 2000 steps and on the final step, when
    log_val_results is True and val_environments is non-empty.
  - Checkpoints are written to base_save_path every 2500 steps and on the
    final step.

  Returns:
    (train_results, val_results) dicts with the same keys that
    load_checkpoint() returns.
  """

  agent.environments = train_environments # supply the agent with environments
  num_environments = train_environments.num_environments

  # training logs
  all_episode_returns_train = [[] for _ in range(num_environments)]
  all_episode_feature_values_train = [[] for _ in range(num_environments)]
  episode_returns_train = [0] * num_environments
  episode_feature_values_train = [[] for _ in range(num_environments)]

  # validation logs
  all_episode_feature_values_val = []
  if not val_environments: 
    log_val_results = False
  val_scores = []
  val_score = -float('inf')

  state_dicts, terminals, all_info = train_environments.reset()

  pbar = tqdm(range(num_steps), unit = 'Step', disable = not verbose)

  for step in pbar:
    is_last_step = step == num_steps - 1 # range() never reaches num_steps itself

    actions = agent.choose_action() # choose an action for each environment
    next_state_dicts, rewards, terminals, all_info = train_environments.step(actions)
    episode_returns_train = [total + reward for total, reward in zip(episode_returns_train, rewards)]

    for idx, info in enumerate(all_info):
      episode_feature_values_train[idx].append(info['feature_value'])

    if agent.is_trainable:
      discounts = [discount_factor * (1 - terminal) for terminal in terminals]
      loss = agent.train(state_dicts, actions, next_state_dicts, rewards, discounts, all_info)
      
      # parenthesized so validation never runs when log_val_results is False
      if log_val_results and (step % 2000 == 0 or is_last_step):
        all_feature_values_val = simulate(agent, val_environments,
                                          num_episodes = 10, verbose = False)
        val_score = average_area_under_the_curve(all_feature_values_val)
        val_scores.append(val_score)
        all_episode_feature_values_val.append(all_feature_values_val)

      if loss is not None: 
        pbar.set_description('Loss: %0.5f, Val. Score: %0.5f' % (loss, val_score))
      else: # no loss value is returned inside the burn-in period
        pbar.set_description('Loss: %0.5f, Val. Score: %0.5f' % (float('inf'), val_score))

    state_dicts = next_state_dicts

    for idx, terminal in enumerate(terminals):
      # if terminal then gather episode results for this environment and reset
      if terminal: 
        all_episode_returns_train[idx].append(episode_returns_train[idx])
        episode_returns_train[idx] = 0
        all_episode_feature_values_train[idx].append(episode_feature_values_train[idx])
        episode_feature_values_train[idx] = []
        state_dict, terminal, info = train_environments.environments[idx].reset()
        state_dicts[idx] = state_dict

    # only trainable agents have networks, an optimizer and a replay buffer to save
    if agent.is_trainable and (step % 2500 == 0 or is_last_step):
      checkpoint_name = 'checkpoint_' + str(step) + '.pt'
      save_path = os.path.join(base_save_path, checkpoint_name) 
      save_checkpoint(agent.embedding_module, agent.q_net, 
                      agent.optimizer,
                      agent.replay_buffer,
                      all_episode_returns_train, all_episode_feature_values_train,
                      val_scores, all_episode_feature_values_val, 
                      step, 
                      save_path)

  train_environments.reset()

  train_results = {'returns': all_episode_returns_train,
                   'feature_values_train': all_episode_feature_values_train}

  val_results = {'validation_scores': val_scores,
                 'feature_values_val': all_episode_feature_values_val}

  return train_results, val_results

In [None]:
#@title Helper Functions: Rewards

def make_filtration_matrix(G):
    """
    Build a Ripser-ready filtration (distance) matrix from an adjacency matrix.

    Nodes enter the filtration in index order: node k is born at filtration
    value k (stored on the diagonal). For a 0/1 adjacency matrix, edge (i, j)
    appears at max(i, j), i.e. as soon as its later endpoint is present.
    Positions where ``G`` is zero become ``np.inf`` so Ripser treats them as
    non-edges.

    Parameters
    ----------
    G : numpy array, shape (N, N)
        Adjacency matrix; non-zero off-diagonal entries mark edges.

    Returns
    -------
    np.ndarray of float, shape (N, N)
    """
    num_nodes = G.shape[0]
    order = np.arange(num_nodes, dtype = float)
    # Stamp each off-diagonal cell with max(i, j) + 1. The +1 keeps edges that
    # touch node 0 non-zero, so they survive the "zero means non-edge" step.
    stamps = np.maximum.outer(order, order) + 1
    filtration = np.multiply(G, stamps)
    # Birth values 1..N on the diagonal (shifted down to 0..N-1 below).
    np.fill_diagonal(filtration, np.arange(1, num_nodes + 1))
    filtration[filtration == 0] = np.inf
    filtration -= 1
    return filtration

def betti_numbers(G, maxdim = 2, dim = 1):
  """
  Count the ``dim``-dimensional persistent bars of a graph's clique complex.

  Builds a distance matrix in which every vertex is born at filtration value 1,
  adjacent nodes sit at their ``nx.to_numpy_array`` weight (1 for an
  unweighted graph) and non-adjacent nodes are infinitely far apart. It then
  runs Ripser and returns the number of bars in the ``dim``-dimensional
  persistence diagram, i.e. the number of ``dim``-dimensional topological
  cycles (the Betti number).

  Args:
    G: NetworkX graph.
    maxdim: largest dimension that may be requested (kept for backward
      compatibility). Only dimensions up to ``dim`` are actually computed.
    dim: homology dimension of interest; must satisfy 0 <= dim <= maxdim.

  Returns:
    int: number of bars in the ``dim``-dimensional persistence diagram.

  Raises:
    ValueError: if ``dim`` lies outside [0, maxdim].
  """
  if not 0 <= dim <= maxdim:
    raise ValueError('dim must lie in [0, maxdim]; got dim={}, maxdim={}'.format(dim, maxdim))
  adj = nx.to_numpy_array(G)
  adj[adj == 0] = np.inf # set unconnected nodes to be infinitely apart
  np.fill_diagonal(adj, 1) # set diagonal to 1 to indicate all nodes are born at once
  # Ripser computes every dimension <= its maxdim. Diagrams above `dim` are
  # never used, so stop at `dim`; the `dim`-dimensional diagram is unchanged.
  diagrams = ripser(adj, distance_matrix = True, maxdim = dim)['dgms']
  return len(diagrams[dim])

def get_barcode(filt_mat, maxdim = 2):
    """
    Compute persistent homology of a precomputed filtration matrix with Ripser.

    Args:
        filt_mat: square matrix handed to Ripser as a distance matrix
            (for example the output of ``make_filtration_matrix``).
        maxdim: highest homology dimension to compute (default 2).

    Returns:
        list of ``(dimension, diagram)`` tuples for dimensions 0..maxdim,
        where each diagram is the matching entry of Ripser's ``'dgms'``.
    """
    diagrams = ripser(filt_mat, distance_matrix = True, maxdim = maxdim)['dgms']
    dimensions = range(maxdim + 1)
    return [(dimension, diagram) for dimension, diagram in zip(dimensions, diagrams)]

def betti_curves(bars, length):
    """
    Turn persistence barcodes into Betti curves.

    Args:
        bars: sequence of ``(dimension, diagram)`` pairs such as the output of
            ``get_barcode``. Row ``i`` of the result is built from
            ``bars[i][1]``, an iterable of ``(birth, death)`` pairs whose
            ``death`` may be ``np.inf``.
        length: number of integer filtration steps (columns) in the output.

    Returns:
        np.ndarray of shape ``(len(bars), length)``. Entry ``[i, t]`` is the
        number of bars of ``bars[i]`` alive at step ``t``, i.e. with
        ``birth <= t < death``.
    """
    bettis = np.zeros((len(bars), length))
    for i, entry in enumerate(bars):
        for bar in entry[1]:
            birth, death = bar[0], bar[1]
            start = int(birth)
            # Ripser intervals are half-open [birth, death): at step `death`
            # the class has already been killed, so it is not counted there.
            # Essential (infinite) bars persist to the end of the curve.
            stop = length if np.isinf(death) else int(death)
            bettis[i, start:stop] += 1

    return bettis

def plot_bettis(bettis):
  """
  Plot Betti curves, one line per homology dimension, on the current axes.

  Args:
    bettis: 2-D numpy array whose row ``i`` is the Betti-``i`` curve (e.g. the
      output of ``betti_curves``). Up to three rows are drawn, one colour per
      dimension. Fewer rows are fine; rows beyond the third are not drawn.
  """
  colors = ['xkcd:emerald green', 'xkcd:tealish', 'xkcd:peacock blue']
  num_steps = bettis.shape[1]
  # zip stops at the shorter sequence, so curves computed with maxdim < 2 no
  # longer raise IndexError (the old loop always read rows 0, 1 and 2).
  for i, (curve, color) in enumerate(zip(bettis, colors)):
    plt.plot(list(range(num_steps)), curve, color = color,
             label = '$\\beta_{}$'.format(i),
             linewidth = 1)
  plt.xlabel('Nodes')
  plt.ylabel('Number of Cycles')
  plt.legend()

In [None]:
#@title Parameters from Darvariu et al. 

# Data
# |G_train| = 10000
# |G_val| = 100
# |G_test| = 100

# Model
# 3 message passing rounds
# 128 hidden units in MLP
# linear exploration decay from 1 to 0.1 for steps/2 and then 0.1  

# Training
# steps = 40000 (*1 or *2 or *5)
# gamma = 1 (finite horizon)
# learn_every = 50
# weights initialized using Glorot scheme
# learning rate = 0.0001
# rewards scaled by 100

In [None]:
#@title Load Networks + Build Environments

base_path = '/content/drive/My Drive/GraphRL/Networks/'

network_type = 'Synthetic'
# generator_type = 'BA'
generator_type = 'ER'

# feature = nx.average_clustering
feature = betti_numbers

# build test environments
mode = 'LargerTest'
full_path = os.path.join(base_path, network_type, generator_type, mode)
# glob returns files in an arbitrary, filesystem-dependent order; sort so the
# environment index assigned below refers to the same network on every run
all_test_net_paths = sorted(glob.glob(os.path.join(full_path, '*.gml')))
if not all_test_net_paths:
  # fail loudly on a wrong path or unmounted Drive instead of silently
  # evaluating on zero environments
  raise FileNotFoundError('No .gml networks found in %s' % full_path)

test_environments = []
for idx, net_path in enumerate(all_test_net_paths):
  G = nx.read_gml(net_path, destringizer = int)
  G = node_featurizer(G)
  environment = GraphEnvironment(idx, G, feature)
  test_environments.append(environment)

test_environments = MultipleEnvironments(test_environments)

# Run Simulations: Larger Test Networks

In [None]:
num_episodes = 10


def mean_feature_curve(agent):
  """
  Run ``agent`` on the test environments and average its feature values.

  Returns:
    (all_feature_values, mean_curve): the raw output of ``simulate`` and that
    output converted to an array and averaged over its two leading axes.
  """
  raw_values = simulate(agent, test_environments, num_episodes)
  per_axis0 = np.mean(np.array(raw_values), axis = 0)
  return raw_values, np.mean(per_axis0, axis = 0)


# heuristic baselines
agent = RandomAgent()
all_feature_values, feature_values_mean_random = mean_feature_curve(agent)

agent = HighestDegreeAgent()
all_feature_values, feature_values_mean_max_degree = mean_feature_curve(agent)

agent = LowestDegreeAgent()
all_feature_values, feature_values_mean_min_degree = mean_feature_curve(agent)

agent = GreedyAgent()
all_feature_values, feature_values_mean_greedy = mean_feature_curve(agent)

# trained DQN, restored from the final checkpoint
checkpoint = 'checkpoint_59999.pt'
load_path = os.path.join(base_path, network_type, generator_type,
                         'Model', feature.__name__, 'Checkpoints', checkpoint)

hyperparameters = {
    'num_node_features': 9,
    'GNN_latent_dimensions': 64,
    'embedding_dimensions': 64,
    'QN_latent_dimensions': 32,
    'buffer_size': 500000,
    'train_start': 320,
    'batch_size': 32,
    'learn_every': 16,
    'epsilon_initial': 0.1,
    'epsilon_decay_rate': 1,
    'epsilon_min': 0.1,
    'discount_factor': 0.75,
    'learning_rate': 3e-4,
}

embedding_module = GNN(hyperparameters)
q_net = QN(hyperparameters)

_, _ = load_checkpoint(load_path, embedding_module, q_net)

# evaluation only: no replay buffer/optimizer, epsilon fixed at 0
agent = DQNAgent(embedding_module, q_net,
                 replay_buffer = None, train_start = None, batch_size = None,
                 learn_every = None,
                 optimizer = None,
                 epsilon = 0, epsilon_decay_rate = None, epsilon_min = None)
all_feature_values, feature_values_mean_DQN = mean_feature_curve(agent)

In [None]:
# Control: the network as saved at step 0 of training, evaluated the same way
# as the trained DQN so that only the weights differ between the two curves.
checkpoint = 'checkpoint_0.pt'
load_path = os.path.join(base_path, network_type, generator_type, 
                         'Model', feature.__name__, 'Checkpoints', checkpoint)

embedding_module_untrained = GNN(hyperparameters)
q_net_untrained = QN(hyperparameters)

_, _ = load_checkpoint(load_path, embedding_module_untrained, q_net_untrained)

# epsilon = 0, matching the trained-DQN evaluation. Previously this was 1; assuming
# DQNAgent is epsilon-greedy (as its epsilon/epsilon_min arguments suggest),
# that made every action random and the loaded weights were never used.
agent = DQNAgent(embedding_module_untrained, q_net_untrained, 
                 replay_buffer = None, train_start = None, batch_size = None, learn_every = None, 
                 optimizer = None, 
                 epsilon = 0, epsilon_decay_rate = None, epsilon_min = None)
all_feature_values = simulate(agent, test_environments, num_episodes)
feature_values_mean_DQN_untrained = np.mean(np.mean(np.array(all_feature_values), axis = 0), axis = 0)

In [None]:
plt.title('Test Performance (Larger Networks)')
plt.xlabel('Step')
plt.ylabel('Feature Value')
plt.plot(feature_values_mean_random, label = 'Random', color = 'blue')
plt.plot(feature_values_mean_max_degree, label = 'Max Degree', color = 'orange')
plt.plot(feature_values_mean_min_degree, label = 'Min Degree', color = 'red')
plt.plot(feature_values_mean_greedy, label = 'Greedy', color = 'green')
plt.plot(feature_values_mean_DQN, label = 'DQN', color = 'black')
plt.plot(feature_values_mean_DQN_untrained, label = 'DQN (untrained)', color = 'black', linestyle = 'dashed')
plt.legend()
save_path = os.path.join(base_path, 
                         network_type, generator_type, 
                         'Model', feature.__name__, 'Figures', 'LargerTest')
# savefig does not create missing directories; make sure the target exists
os.makedirs(save_path, exist_ok = True)
plt.savefig(os.path.join(save_path, 'larger_test_performance.eps'), format = 'eps')

In [None]:
# (curve, legend label, line style) for every agent evaluated above
plot_specs = [
    (feature_values_mean_random, 'Random', {'color': 'blue'}),
    (feature_values_mean_max_degree, 'Max Degree', {'color': 'orange'}),
    (feature_values_mean_min_degree, 'Min Degree', {'color': 'red'}),
    (feature_values_mean_greedy, 'Greedy', {'color': 'green'}),
    (feature_values_mean_DQN, 'DQN', {'color': 'black'}),
    (feature_values_mean_DQN_untrained, 'DQN (untrained)', {'color': 'black', 'linestyle': 'dashed'}),
]

plt.title('Test Performance (Larger Networks)')
plt.xlabel('Step')
plt.ylabel('Feature Value')
for curve, label, style in plot_specs:
  plt.plot(curve, label = label, **style)
plt.legend()
save_path = os.path.join(base_path,
                         network_type, generator_type,
                         'Model', feature.__name__, 'Figures', 'LargerTest')
# savefig does not create missing directories; make sure the target exists
os.makedirs(save_path, exist_ok = True)
plt.savefig(os.path.join(save_path, 'larger_test_performance.eps'), format = 'eps')