In [None]:
# Enable IPython's autoreload extension so that edits to the imported project
# modules are picked up live without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
from datetime import date
from NashRL import run_Nash_Agent
from NashAgent_lib import *
import time

from simulation_lib import MarketSimulator


# Set global digit printing options
np.set_printoptions(precision=4)

# Define Training and Model Parameters
num_players = 5           # Total number of agents

# Default simulation parameters

# Coefficient of the drift term kappa * (10 - y) used in 'drift_function' below.
kappa = 0.5


def _gpu_tensor(value):
    """Return ``value`` as a detached tensor on the CUDA device.

    The dtype is inferred by ``torch.tensor`` from ``value`` (Python ints stay
    integer tensors, floats become floating-point tensors).
    """
    return torch.tensor(value).cuda().detach()


sim_dict = {
    'perm_price_impact': _gpu_tensor(0.05),
    'transaction_cost': _gpu_tensor(.1),
    'liquidation_cost': _gpu_tensor(.1),
    'running_penalty': _gpu_tensor(0.0),
    'trans_impact_scale': _gpu_tensor(0.02),
    'trans_impact_decay': _gpu_tensor(0.5),
    'T': _gpu_tensor(5),
    'dt': _gpu_tensor(0.5),
    'N_agents': num_players,
    # NOTE: the lambda reads the module-level `kappa` at call time, so
    # re-assigning `kappa` later also changes the drift.
    'drift_function': (lambda x, y: kappa*(10-y)),
    'volatility': _gpu_tensor(0.1),
    'init_inv_var': _gpu_tensor(50),
}

# compute invariant distribution for initial price variance
inv_std = sim_dict['volatility'] * torch.sqrt(
    (1 - torch.exp(-2*kappa*sim_dict['T'])) / (2*kappa)
)
# `inv_std` is already a CUDA tensor; copy it with clone().detach() instead of
# torch.tensor(inv_std), which emits a copy-construct UserWarning.
# NOTE(review): the key ends in `_var` but the value stored is `inv_std`, not
# its square -- confirm which one MarketSimulator expects.
sim_dict['initial_price_var'] = inv_std.clone().detach()

# Normalisation constants handed to run_Nash_Agent as norm_mean / norm_std.
# Presumably one entry per network input feature (non_invar_dim=5) -- TODO confirm.
norm_mean = _gpu_tensor([2.25, 10, 0, 0, 0 ])
norm_std = _gpu_tensor([1.4361406616345072, 0.74204157112471332 * 0.2763, 2.5 * 1.8078, 0.1 * 0.4225, 1 * 1.6726])

sim_obj = MarketSimulator(sim_dict, impact='sqrt')

# The agent takes its player count from the simulator and its terminal cost
# from the 'liquidation_cost' entry of sim_dict.
nash_agent = NashNN(non_invar_dim=5, n_players=sim_obj.N,
                    output_dim=5, max_steps=10, lr = 3e-4, weighted_adam=True,
                    terminal_cost=sim_dict['liquidation_cost'], num_moms=0,c_cons=50,c2_cons=False,c3_pos=False, layers=4)

In [None]:
# Wall-clock reference taken immediately before training starts.
start = time.time()

# Today's date formatted as DDMMYYYY (not referenced again in this cell).
str_dt = date.today().strftime("%d%m%Y")

# Run training; the call returns the agent (rebound to `nash_agent`) together
# with the loss history that is plotted in the next cell.
nash_agent, loss_data = run_Nash_Agent(
    sim_obj,
    sim_dict,
    nash_agent=nash_agent,
    num_sim=20000,
    max_steps=10,
    norm_mean=norm_mean,
    norm_std=norm_std,
    rv_min=0.5,
    rv_max=2.5,
    early_stop=True,
    early_lim=3000,
    path="/pt_files/Nash_DQN/",
    AN_file_name="/pt_files/Nash_DQN/Action_Net_ADA",
    VN_file_name="/pt_files/Nash_DQN/Value_Net_ADA",
)

# Report the elapsed wall-clock time in seconds.
print("Total time taken: ")
print(time.time() - start)

In [None]:
import matplotlib.pyplot as plt

# Draw the training-loss history on the current axes and clip the y-axis to
# the range [0, 100].
loss_axes = plt.gca()
loss_axes.plot(loss_data)
loss_axes.set_ylim((0, 100))