In [None]:
def plot_x0_dist_report(x0, save_fig=False, x0_args=None,
                        fig_name='x0_dist_report.png'):
  """Plot a 3x2 report-style summary of a batch of initial states.

  Panels (their titles are the letters (a)-(f)):
    (a) 2-D histogram of the reserves A^B and A^S derived from columns 1 and 2;
    (b) histogram of column 1 ("CPFM Price");
    (c) histogram of column 0 minus column 1 ("Reference - CPFM Price");
    (d) histogram of column 4 ("General Account Value");
    (e) histogram of column 3 ("Position");
    (f) histogram of column 5 ("Margin Account Value").
  Column meanings are those implied by the axis labels.

  Args:
    x0: 2-D array of initial states indexed as ``x0[:, k]``; columns 0-5
      are read.
    save_fig: if True, also write the figure to
      ``os.path.join(results_path, fig_name)`` (``results_path`` is a
      module-level global).
    x0_args: optional sampler-argument dict; when given, its
      ``'initial_price'`` entry is marked with a red vertical line on (b).
    fig_name: file name of the saved figure (only used when ``save_fig``).
  """
  fig, ax = plt.subplots(3, 2, figsize=(15, 15))

  # Split column 2 into two reserves using the column-1 price p:
  # A^B = x / (1 + p) and A^S = x * p / (1 + p).
  price = x0[:, 1]
  ax[0,0].hist2d(x0[:,2] / (1. + price), x0[:,2] * price / (1. + price),
                 bins=50)
  ax[0,0].set_title(r'(a)')
  ax[0,0].set_xlabel(r'$A^B$')
  ax[0,0].set_ylabel(r'$A^S$')
  ax[0,0].grid()

  ax[0,1].hist(x0[:,1], bins = 50)
  if x0_args is not None:
    ax[0,1].axvline(x0_args['initial_price'], color = 'r')
  ax[0,1].set_title(r'(b)')
  ax[0,1].set_xlabel(r'CPFM Price')
  ax[0,1].set_ylabel(r'Count')
  ax[0,1].grid()

  ax[1,0].hist(x0[:,0] - x0[:,1], bins = 50)
  ax[1,0].axvline(0, color = 'r')
  ax[1,0].set_title('(c)')
  ax[1,0].set_xlabel('Price Difference (Reference - CPFM Price)')
  ax[1,0].set_ylabel('Count')
  ax[1,0].grid()

  ax[1,1].hist(x0[:,4], bins = 50)
  ax[1,1].set_title(r'(d)')
  ax[1,1].set_xlabel(r'General Account Value')
  ax[1,1].set_ylabel(r'Count')
  ax[1,1].grid()

  ax[2,0].hist(x0[:,3], bins = 30)
  ax[2,0].axvline(0, color = 'r')
  ax[2,0].set_title(r'(e)')
  ax[2,0].set_xlabel(r'Position')
  ax[2,0].set_ylabel(r'Count')
  ax[2,0].grid()

  ax[2,1].hist(x0[:,5], bins = 50)
  ax[2,1].set_title(r'(f)')
  ax[2,1].set_xlabel('Margin Account Value')
  ax[2,1].set_ylabel('Count')
  ax[2,1].grid()

  fig.tight_layout()
  if save_fig:
    # fig_name is an explicit parameter; it used to be read from an
    # undeclared global.
    fig.savefig(os.path.join(results_path, fig_name))
  plt.show()

In [None]:
def plot_x0_dist(x0, save_fig=False, x0_args=None, fig_name='x0_dist.png'):
  """Plot a 3x3 exploratory grid of a batch of initial states.

  Row 0: 2-D histogram of the reserves derived from columns 1 and 2,
  histogram of column 1 (pool price), scatter of column 1 vs column 0
  (reference price).
  Row 1: histogram of column 0 minus column 1, histogram of column 3
  (position), scatter of column 3 vs column 5.
  Row 2: histograms of column 4 ("General"), column 5 ("Margin") and, if it
  exists, column 6. The last panel is hidden when ``x0`` has only six columns.
  Column meanings are those implied by the titles and axis labels.

  Args:
    x0: 2-D array of initial states indexed as ``x0[:, k]``; at least six
      columns are read.
    save_fig: if True, also write the figure to
      ``os.path.join(results_path, fig_name)`` (``results_path`` is a
      module-level global).
    x0_args: optional sampler-argument dict; when given, its
      ``'initial_price'`` entry is marked with a red vertical line on the
      pool-price histogram.
    fig_name: file name of the saved figure (only used when ``save_fig``).
  """
  fig, ax = plt.subplots(3, 3, figsize=(20, 15))

  # Split column 2 into two reserves using the column-1 price p:
  # A_B = x / (1 + p) and A_S = x * p / (1 + p).
  price = x0[:, 1]
  ax[0,0].hist2d(x0[:,2] / (1. + price), x0[:,2] * price / (1. + price),
                 bins=50)
  ax[0,0].set_title(r'Histogram of initial reserves')
  ax[0,0].set_xlabel(r'$A_B$')
  ax[0,0].set_ylabel(r'$A_S$')
  ax[0,0].grid()

  ax[0,1].hist(x0[:,1], bins = 50)
  if x0_args is not None:
    ax[0,1].axvline(x0_args['initial_price'], color = 'r')
  ax[0,1].set_title(r'Pool Price')
  ax[0,1].set_xlabel(r'$Price$')
  ax[0,1].set_ylabel(r'$Count$')
  ax[0,1].grid()

  # Multi-word labels are plain text: mathtext ($...$) ignores spaces, so
  # '$Pool Price$' would have rendered as 'PoolPrice'.
  ax[0,2].scatter(x0[:,1], x0[:, 0])
  ax[0,2].set_title('Pool vs Reference Price')
  ax[0,2].set_xlabel('Pool Price')
  ax[0,2].set_ylabel('Reference Price')
  ax[0,2].grid()

  ax[1,0].hist(x0[:,0] - x0[:,1], bins = 50)
  ax[1,0].axvline(0, color = 'r')
  ax[1,0].set_title('Residues')
  ax[1,0].set_xlabel('Price Difference')
  ax[1,0].set_ylabel('Count')
  ax[1,0].grid()

  ax[1,1].hist(x0[:,3], bins = 20)
  ax[1,1].axvline(0, color = 'r')
  ax[1,1].set_title(r'Initial Position')
  ax[1,1].set_xlabel(r'$Position$')
  ax[1,1].set_ylabel(r'$Counts$')
  ax[1,1].grid()

  # Scatter (not a histogram) of column 3 against column 5.
  ax[1,2].scatter(x0[:,3], x0[:,5])
  ax[1,2].set_title('Initial accounts')
  ax[1,2].set_xlabel(r'$n$')
  ax[1,2].set_ylabel(r'$m$')
  ax[1,2].grid()

  ax[2,0].hist(x0[:,4], bins = 50)
  ax[2,0].set_title(r'General')
  ax[2,0].set_xlabel(r'$Value$')
  ax[2,0].set_ylabel(r'$Count$')
  ax[2,0].grid()

  ax[2,1].hist(x0[:,5], bins = 50)
  ax[2,1].set_title(r'Margin')
  ax[2,1].set_xlabel(r'$Value$')
  ax[2,1].set_ylabel(r'$Count$')
  ax[2,1].grid()

  if x0.shape[1] > 6:
    ax[2,2].hist(x0[:,6], bins = 50)
    # NOTE(review): this title duplicates panel [0,1]; confirm what column 6
    # actually holds.
    ax[2,2].set_title(r'Pool Price')
    ax[2,2].set_xlabel(r'$Price$')
    ax[2,2].set_ylabel(r'$Count$')
    ax[2,2].grid()
  else:
    # No seventh column: hide the panel instead of showing empty axes.
    ax[2,2].axis('off')

  fig.tight_layout()
  if save_fig:
    fig.savefig(os.path.join(results_path, fig_name))
  plt.show()

In [None]:
def plot_sampler(x0_args, num_samples, save_fig = False):
  """Sample ``num_samples`` initial states and plot their distribution.

  ``sample_x0`` is called with the sample count, ``device=device`` (a
  module-level global) and every entry of ``x0_args`` as keywords. The
  returned tensor is moved to the CPU, converted to numpy and passed to
  ``plot_x0_dist``.

  Returns:
    The sampled initial states as a numpy array.
  """
  samples = sample_x0(num_samples, device=device, **x0_args).cpu().numpy()
  plot_x0_dist(samples, save_fig, x0_args)
  return samples

In [None]:
def plot_bellman(bellman_loss_dict, bellman_approx_dict, save_fig=False,
                 fig_name='bellman.png'):
  """Plot Bellman-loss and Bellman-approximation curves side by side.

  Args:
    bellman_loss_dict: mapping name -> sequence of loss values, drawn on the
      left panel against value-function updates.
    bellman_approx_dict: mapping name -> sequence of values, drawn on the
      right panel against policy updates. Must contain every key of
      ``bellman_loss_dict``.
    save_fig: if True, also write the figure to
      ``os.path.join(results_path, fig_name)`` (``results_path`` is a
      module-level global).
    fig_name: file name of the saved figure (only used when ``save_fig``).
  """
  fig, ax = plt.subplots(1, 2, figsize=(18,6))

  for name, bellman_loss in bellman_loss_dict.items():
    ax[0].plot(torch.Tensor(bellman_loss).cpu(), label = name)
    ax[1].plot(torch.Tensor(bellman_approx_dict[name]).cpu(), label = name)

  ax[0].set_xlabel('Value function updates')
  ax[0].set_ylabel('Bellman loss')
  ax[0].grid()
  ax[1].set_xlabel('Policy updates')
  ax[1].set_ylabel('Bellman approximation of value function')
  ax[1].grid()
  # plt.legend() only decorates the current (last-created) axes, which left
  # the left panel without a legend; attach one to each panel explicitly.
  ax[0].legend()
  ax[1].legend()
  if save_fig:
    fig.savefig(os.path.join(results_path, fig_name))
  plt.show()

In [None]:
def plot_average_pnl_over_time(results, fig_name='', save_fig=False):
  """Plot the per-time-step mean and standard deviation of PnL.

  Args:
    results: iterable of result objects exposing ``name`` and
      ``get_mean_pnl()``, which returns a ``(mean, std)`` pair of tensors.
    fig_name: file name of the saved figure; an empty name falls back to
      ``'average_pnl.png'``, since joining ``''`` onto ``results_path``
      yields a directory path that cannot be saved to.
    save_fig: if True, also write the figure to
      ``os.path.join(results_path, fig_name)`` (``results_path`` is a
      module-level global).
  """
  fig, ax = plt.subplots(1, 2, figsize=(15, 5))

  for result in results:
    mean_pnl, std_pnl = result.get_mean_pnl()
    ax[0].plot(mean_pnl.cpu(), label = result.name)
    ax[1].plot(std_pnl.cpu(), label = result.name)

  ax[0].grid()
  ax[0].set_xlabel('Time Step')
  ax[0].set_ylabel('Average PNL')
  ax[0].legend()

  ax[1].grid()
  ax[1].set_xlabel('Time Step')
  ax[1].set_ylabel('Standard Deviation of PNL')
  ax[1].legend()

  fig.tight_layout()

  # Save before showing: depending on the backend, plt.show() may close the
  # figure, which could leave the saved file blank.
  if save_fig:
    fig.savefig(os.path.join(results_path, fig_name or 'average_pnl.png'))
  plt.show()

In [None]:
def compare_quantiles(results, save_fig = False, fig_name='quantiles.png'):
  """Plot a cumulative-PnL fan chart for each result, one panel per result.

  Each panel shows the mean PnL from ``get_mean_pnl()``. It also shows shaded
  bands between the min/max of ``PnL_cumsum`` over axis 0 and the central
  intervals from ``get_pnl_quantile`` at levels 0.98, 0.9 and 0.5. All panels
  share the same y-limits, and those limits always include 0.

  Args:
    results: non-empty sequence of result objects exposing ``name``,
      ``PnL_cumsum``, ``get_mean_pnl()`` and ``get_pnl_quantile(level)``
      (assumed to return a ``(lower, upper)`` pair of per-time-step tensors).
    save_fig: if True, also write the figure to
      ``os.path.join(results_path, str(len(results)) + fig_name)``
      (``results_path`` is a module-level global).
    fig_name: file-name suffix of the saved figure (only used when
      ``save_fig``).
  """
  n_results = len(results)
  # squeeze=False always yields a 2-D grid of Axes, so indexing also works
  # when there is a single result (plain subplots(1, 1) returns a bare Axes).
  fig, axes = plt.subplots(1, n_results, figsize=(n_results*8, 6),
                           squeeze=False)
  axes = axes[0]
  y_min = 0
  y_max = 0
  for axis, result in zip(axes, results):
    mean_pnl, _ = result.get_mean_pnl()

    ci_5, ci_95 = result.get_pnl_quantile(0.9)
    ci_1, ci_99 = result.get_pnl_quantile(0.98)
    ci_25, ci_75 = result.get_pnl_quantile(0.5)
    ci_max = torch.max(result.PnL_cumsum, 0).values
    ci_min = torch.min(result.PnL_cumsum, 0).values
    # One integer x per time step, derived from the data. The former
    # np.linspace(0, 99, 99) hard-coded 99 steps with a step of ~1.0102,
    # which misaligned the bands with the mean line.
    x = np.arange(mean_pnl.shape[0])

    axis.plot(x, mean_pnl.cpu(), color='k', label = 'Mean')
    axis.fill_between(x, ci_min.cpu(), ci_max.cpu(), alpha=0.25, label = 'Max/Min')
    axis.fill_between(x, ci_1.cpu(), ci_99.cpu(), alpha=0.25, label = '98%')
    axis.fill_between(x, ci_5.cpu(), ci_95.cpu(), alpha=0.25, label = '90%')
    axis.fill_between(x, ci_25.cpu(), ci_75.cpu(), alpha=0.25, label = '50%')

    axis.legend(loc = 'upper left')
    axis.grid()
    axis.set_title(result.name)
    axis.set_xlabel('Time Step')
    axis.set_ylabel('Profit Net Loss (PNL)')
    low, high = axis.get_ylim()
    y_min = min(y_min, low)
    y_max = max(y_max, high)

  # Give every panel the union of the individual y-ranges so they compare
  # directly.
  for axis in axes:
    axis.set_ylim([y_min, y_max])

  if save_fig:
    fig.savefig(os.path.join(results_path, str(n_results) + fig_name))
  plt.show()

In [None]:
def plot_worst_pnl(result, ideal_result, save_fig=False):
  """Plot the worst path of ``result`` next to ``ideal_result``.

  The worst path is the one whose entry in ``result.final_profits()`` is
  smallest. It is passed to ``make_comparison_plots`` together with
  ``save_fig``.

  Returns:
    The index of that path, as produced by ``argmin(...).cpu().numpy()``.
  """
  final_profits = result.final_profits()
  worst = final_profits.argmin().cpu().numpy()
  compared = [result, ideal_result]
  make_comparison_plots(compared, 1, [worst], save_fig)
  return worst

def plt_best_pnl(result, ideal_result, save_fig=False):
  """Plot the best path of ``result`` next to ``ideal_result``.

  Counterpart of ``plot_worst_pnl``. It picks the path whose entry in
  ``result.final_profits()`` is largest and passes it to
  ``make_comparison_plots``.

  Args:
    result: result object exposing ``final_profits()`` (a tensor).
    ideal_result: reference result plotted alongside ``result``.
    save_fig: forwarded to ``make_comparison_plots``, as ``plot_worst_pnl``
      does.

  Returns:
    The index of the best path, as produced by ``argmax(...).cpu().numpy()``.
  """
  best_result_index = torch.argmax(result.final_profits()).cpu().numpy()
  make_comparison_plots([result, ideal_result], 1, [best_result_index], save_fig)
  return best_result_index


# Correctly spelled alias; the original name is kept for existing callers.
plot_best_pnl = plt_best_pnl