In [34]:
import seaborn as sns
from pathlib import Path
import torch
import pandas as pd
from DQN import DQN_agent_modular
from envs.GraphEnv.impnode import ImpnodeEnv
from DQN.finetune_dqn import finetune_dqn
from DQN.test_and_compare import test_loop, test_loop2
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
import time
import mlflow
import  numpy as np
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [35]:
# MLflow runs are stored in the SQLite file `mlflow.db` (relative to the
# working directory) and grouped under a single experiment.
MLFLOW_TRACKING_URI = 'sqlite:///mlflow.db'
MLFLOW_EXPERIMENT_NAME = 'finetune_baseline_models'

mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)
mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

<Experiment: artifact_location='file:///C:/rituja_git/ma-rituja-pardhi/mlruns/21', creation_time=1714560410552, experiment_id='21', last_update_time=1714560410552, lifecycle_stage='active', name='finetune_baseline_models', tags={}>

In [36]:
import random

# Seed every random number generator the experiment touches (Python, NumPy,
# torch on the CPU, the current CUDA device and all CUDA devices) with one value.
seed = 412
for _seed_fn in (torch.manual_seed, np.random.seed, random.seed,
                 torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)
del _seed_fn

# Request deterministic torch/cuDNN kernels; with warn_only=True, ops lacking a
# deterministic implementation emit a warning instead of raising an error.
torch.use_deterministic_algorithms(True, warn_only=True)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [37]:
def plot_sns(sum_reward_histories, style="classic", plot_size=(10, 8),
             labels=("ImpNode_ANC", "ImpNode_ANC_finetune"), show_legend=False):
    """Draw several reward curves on one figure and return it without displaying it.

    Args:
        sum_reward_histories: sequence of reward curves; item ``i`` is drawn
            under ``labels[i]``. Must contain at least ``len(labels)`` items
            (presumably cumulative rewards per removed node, matching the axis
            labels -- confirm against callers).
        style: matplotlib style (name, dict or path) applied while the figure
            is built.
        plot_size: figure size in inches as ``(width, height)``.
        labels: one series name per curve; defaults to the baseline /
            fine-tuned pair used originally.
        show_legend: keep seaborn's legend. ``False`` (default) removes it,
            matching the previous behaviour.

    Returns:
        The closed ``matplotlib.figure.Figure``, ready to be saved or logged.

    Raises:
        IndexError: if ``sum_reward_histories`` has fewer items than ``labels``.
    """
    # Index explicitly (rather than zip) so too few histories still raise.
    series = {label: sum_reward_histories[i] for i, label in enumerate(labels)}

    with plt.style.context(style):
        fig, ax = plt.subplots(figsize=plot_size)
        # Draw on the axes created above instead of seaborn's "current axes".
        sns.lineplot(data=series, ax=ax)

        ax.set_title("cumulative sum of rewards", fontsize=14)
        ax.set_xlabel("num nodes removed", fontsize=12)
        ax.set_ylabel("reward", fontsize=12)
        # tick_params also applies to tick labels matplotlib regenerates later,
        # unlike setting the size on the currently existing Text objects.
        ax.tick_params(axis="both", labelsize=10)

        # get_legend() is None when no legend was drawn; guard instead of
        # failing with AttributeError on ``ax.legend_.remove()``.
        legend = ax.get_legend()
        if legend is not None and not show_legend:
            legend.remove()
        fig.tight_layout()

    # Close so the notebook does not auto-render the figure; the returned
    # Figure object is still usable for saving/logging.
    plt.close(fig)
    return fig

In [38]:
def _build_test_agent(device):
    """Create a DQN agent configured for evaluation only (mode='test').

    Exploration (eps_*), learning (lr, discount) and replay-memory arguments are
    all zero; the network sizes presumably must match the checkpoints loaded
    into it afterwards -- confirm against the training configuration.
    """
    return DQN_agent_modular.DQNAgent(device=device,
                                      alpha=0.001,
                                      gnn_depth=4,
                                      state_size=2,
                                      hidden_size1=32,
                                      hidden_size2=64,
                                      action_size=1,
                                      discount=0.0,
                                      eps_max=0.0,
                                      eps_min=0.0,
                                      eps_step=0.0,
                                      memory_capacity=0,
                                      lr=0,
                                      mode='test')


def _evaluate_checkpoint(device, checkpoint_dir, data_path, file_name,
                         max_removed_nodes, num_test_eps, step_ratio):
    """Load ``<checkpoint_dir>/model.pt`` and run test_loop2 on a fresh environment.

    Returns whatever test_loop2 returns (unpacked by the caller as
    ``actions, reward_history, ep_score_history``).
    """
    # A new environment per agent guarantees the second agent cannot observe
    # state left behind by the first run, whether or not test_loop2 resets it.
    env = ImpnodeEnv(anc='dw_nd', num_nodes=(30, 50), data_path=data_path,
                     mode='test', file_name=file_name,
                     max_removed_nodes=max_removed_nodes)
    agent = _build_test_agent(device)
    agent.load_model(f'{checkpoint_dir}/model.pt')
    return test_loop2(env=env, agent=agent,
                      NUM_TEST_EPS=num_test_eps, step_ratio=step_ratio)


def _save_actions(actions, csv_path):
    """Convert each action to a Python int, write them to ``csv_path``, return the list."""
    # int() works for one-element tensors on any device as well as plain ints.
    action_ids = [int(action) for action in actions]
    pd.DataFrame(action_ids).to_csv(csv_path)
    return action_ids


def test(device, model_path, finetuned_model_path, graph_name,
         num_test_eps=1, max_removed_nodes=None, step_ratio=0.01):
    """Evaluate a baseline and a fine-tuned DQN checkpoint on one real graph.

    Both agents are run with test_loop2 on
    ``<cwd>/data/real/Cost/<graph_name>_degree.gml``; the chosen actions are
    written as CSV files into ``finetuned_model_path``.

    Args:
        device: torch device given to both agents.
        model_path: directory containing the baseline ``model.pt``.
        finetuned_model_path: directory containing the fine-tuned ``model.pt``;
            ``actions.csv`` (baseline) and ``finetuned_actions.csv`` are
            written here.
        graph_name: test graph name used to build the ``.gml`` file name.
        num_test_eps: number of test episodes passed to test_loop2.
        max_removed_nodes: forwarded to ImpnodeEnv.
        step_ratio: forwarded to test_loop2.

    Returns:
        dict with the baseline and fine-tuned actions (as ints) and the
        reward / episode-score histories returned by test_loop2.
    """
    data_path = Path.cwd() / 'data/real/Cost'
    file_name = f'{graph_name}_degree.gml'

    actions, reward_history, ep_score_history = _evaluate_checkpoint(
        device, model_path, data_path, file_name,
        max_removed_nodes, num_test_eps, step_ratio)
    f_actions, f_reward_history, f_ep_score_history = _evaluate_checkpoint(
        device, finetuned_model_path, data_path, file_name,
        max_removed_nodes, num_test_eps, step_ratio)

    # Baseline actions go into the fine-tuned run's directory: that directory
    # is specific to graph_name, whereas one baseline model dir is reused for
    # several test graphs in the run configuration.
    actions = _save_actions(actions, f'{finetuned_model_path}/actions.csv')
    f_actions = _save_actions(f_actions, f'{finetuned_model_path}/finetuned_actions.csv')

    # Returned so callers can plot (e.g. cumulative rewards via plot_sns) or log them.
    return {
        'actions': actions,
        'reward_history': reward_history,
        'ep_score_history': ep_score_history,
        'finetuned_actions': f_actions,
        'finetuned_reward_history': f_reward_history,
        'finetuned_ep_score_history': f_ep_score_history,
    }
    

In [39]:
def _seed_everything(seed):
    """Seed Python, NumPy and torch (CPU, current and all CUDA devices) RNGs
    and switch torch/cuDNN to deterministic settings."""
    import random  # local import so this cell does not depend on another cell

    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # warn_only=True: ops without a deterministic implementation warn, not raise.
    torch.use_deterministic_algorithms(True, warn_only=True)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


def testing(run_id, model_path, finetuned_model_path, graph_name, seed=412):
    """Re-seed all RNGs, pick a device and run ``test`` inside MLflow run ``run_id``.

    Re-seeding on every call makes each evaluation reproducible independently
    of how many evaluations ran before it in the same kernel.

    Args:
        run_id: id of the MLflow run to resume (``mlflow.start_run(run_id=...)``).
        model_path: directory of the baseline checkpoint.
        finetuned_model_path: directory of the fine-tuned checkpoint.
        graph_name: name of the real test graph.
        seed: RNG seed (default 412, the value used throughout this notebook).

    Returns:
        Whatever ``test`` returns.
    """
    _seed_everything(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with mlflow.start_run(run_id=run_id):
        return test(device, model_path, finetuned_model_path, graph_name)

In [40]:
# One record per evaluation: (MLflow run id, baseline model dir,
# fine-tuned model dir, test graph). Keeping the four values together means
# they cannot drift out of alignment the way four hand-maintained parallel
# lists can (zip() in the evaluation loop would silently mis-pair them).
# The commented-out records are the PG fine-tuning runs, currently excluded.
_RESULTS_DIR = 'results/train_baseline_models'
_BA_DIR = f'{_RESULTS_DIR}/barabasi-albert_20240430122831'
_ER_DIR = f'{_RESULTS_DIR}/erdos-renyi_20240430184633'
_WS_DIR = f'{_RESULTS_DIR}/watts-strogatz_20240430231416'

eval_runs = [
    ('3381e0c8fcce4b84a5120b6a9714b822', _BA_DIR, f'{_BA_DIR}/t-BA_f-Facebook_20240501221814', 'Facebook'),
    ('0fc0785ffc3d49d7a9c6883224300d9b', _BA_DIR, f'{_BA_DIR}/t-BA_f-Gnutella31_20240501230533', 'Gnutella31'),
    # ('74ee27e3a2194a29bf33c8009fef0945', _BA_DIR, f'{_BA_DIR}/t-BA_f-PG_20240501161733', 'PG'),
    ('ea8e10d895024d7c9f3b35a0262784d5', _ER_DIR, f'{_ER_DIR}/t-ER_f-Facebook_20240501234340', 'Facebook'),
    ('5f9bbe6b95d448f5b6447f5a57d07bb8', _ER_DIR, f'{_ER_DIR}/t-ER_f-Gnutella31_20240502003032', 'Gnutella31'),
    # ('32300d3e7f494257a4d5324c0041a46d', _ER_DIR, f'{_ER_DIR}/t-ER_f-PG_20240501184319', 'PG'),
    ('9f97a43eb0d345399d9e7647ee6c93be', _WS_DIR, f'{_WS_DIR}/t-WS_f-Facebook_20240502010733', 'Facebook'),
    ('016ee29739714dc9b24ce5e89388749c', _WS_DIR, f'{_WS_DIR}/t-WS_f-Gnutella31_20240502015348', 'Gnutella31'),
    # ('ae7de19f6e0245489966e5a7bdf0f0da', _WS_DIR, f'{_WS_DIR}/t-WS_f-PG_20240501202443', 'PG'),
]

# Sanity-check each record: the fine-tuned model lives inside its baseline
# model's directory, and its directory name encodes the graph it was tuned on.
for _run_id, _model_dir, _finetuned_dir, _graph in eval_runs:
    assert _finetuned_dir.startswith(_model_dir + '/'), \
        f'{_finetuned_dir} is not inside {_model_dir} (run {_run_id})'
    assert f'_f-{_graph}_' in _finetuned_dir, \
        f'{_finetuned_dir} was not fine-tuned on {_graph} (run {_run_id})'

# Same four list names as before, so downstream cells are unaffected.
run_ids, model_paths, finetuned_model_paths, graph_names = (
    list(column) for column in zip(*eval_runs)
)

In [41]:
# zip() stops at the shortest input, so a missing entry in any configuration
# list would silently skip evaluations; fail loudly instead.
if not (len(run_ids) == len(model_paths) == len(finetuned_model_paths) == len(graph_names)):
    raise ValueError(
        f'run configuration lists differ in length: run_ids={len(run_ids)}, '
        f'model_paths={len(model_paths)}, finetuned_model_paths={len(finetuned_model_paths)}, '
        f'graph_names={len(graph_names)}'
    )

# NOTE(review): each run prints enough console output to exceed Jupyter's
# IOPub data-rate limit; consider redirecting it to a log file.
for run_id, model_path, finetuned_model_path, graph_name in zip(
        run_ids, model_paths, finetuned_model_paths, graph_names):
    testing(run_id, model_path, finetuned_model_path, graph_name)
    # Identify the run, since its other output may be truncated by Jupyter.
    print('Finished', graph_name, run_id)
    

IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished


IOPub data rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_data_rate_limit`.

Current values:
ServerApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
ServerApp.rate_limit_window=3.0 (secs)



Finished
