<a href="https://colab.research.google.com/github/tarod13/Stochastic_Games/blob/master/stochastic_games_herkovitz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy.optimize import linprog
import matplotlib.pyplot as plt
import itertools
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter

In [None]:
beta = 0.99  # discount factor

M = 5        # largest offer amount; offers range over 0..M
N_P = 2      # number of players
N_A = 2      # actions per player in the base normal-form game 'G'
N_S = 3 + N_A * (2*(M+1) + 2)  # total number of states, equals len(S)

# State catalogue:
#   'G'                       : the base normal-form game (note ('G') is just 'G')
#   ('O', i)                  : offer stage of player i
#   ('E1'|'E2', a, x)         : player 1|2 evaluates an offer x requesting action a
#   ('R1'|'R2', a)            : player 1|2 plays while the opponent's action is fixed to a
G = ['G']
O = [('O',i) for i in range(1,3)]
E1 = [('E1',i,j) for i in range(0,N_A) for j in range(0,M+1)]
E2 = [('E2',i,j) for i in range(0,N_A) for j in range(0,M+1)]
R1 = [('R1',i) for i in range(0,N_A)]
R2 = [('R2',i) for i in range(0,N_A)]
S = list(itertools.chain(G, O, E1, E2, R1, R2))
# Partition of S into the states grouped with player 1 (including 'G') and player 2.
# Single states are wrapped in lists: chain() iterates each argument, so a bare
# tuple would be split into its elements.
S1 = list(itertools.chain(G, [('O',1)], E1, R1))
S2 = list(itertools.chain([('O',2)], E2, R2))

# Payoff matrices of the base game, indexed [action of player 1, action of player 2].
RG1 = np.array([[3.,0.],[5.,1.]])
RG2 = np.array([[3.,5.],[0.,1.]])

In [None]:
class game(nn.Module):
  """Trainable strategies, state values and Nash multipliers of the game.

  Parameters are stored per state under the key ``str(state)`` for every
  state in ``S``:
    * ``log_pi1`` / ``log_pi2``: unnormalized log-strategy columns of shape
      (n_actions, 1). A player without a real choice in a state gets a
      single-entry column. Multi-action columns start at zero, i.e. uniform
      strategies.
    * ``v``: (N_S, 2) state values, one column per player, initialized to 0.
    * ``sqrt_lambda_nash1`` / ``sqrt_lambda_nash2``: square roots of the Nash
      multipliers, same shapes as the log-strategies, initialized to 0. The
      multipliers themselves are their squares and hence non-negative.
  """

  @staticmethod
  def _column(n_rows, fill_value):
    """Return a trainable (n_rows, 1) float parameter filled with ``fill_value``."""
    return Parameter(torch.full((n_rows, 1), float(fill_value)))

  def __init__(self):
    super().__init__()
    # Offer states: one action per (offer amount 0..M, requested action) pair.
    n_offer_actions = (M+1) * N_A

    # Player 1: offers in ('O', 1), accepts/rejects in E1 (2 actions), plays in R1.
    self.log_pi1 = nn.ParameterDict()
    self.log_pi1['G'] = self._column(N_A, 0.)
    self.log_pi1[str(('O',1))] = self._column(n_offer_actions, 0.)
    self.log_pi1[str(('O',2))] = self._column(1, 1.)
    for state in E1:
      self.log_pi1[str(state)] = self._column(2, 0.)
    for state in E2 + R2:
      self.log_pi1[str(state)] = self._column(1, 1.)
    for state in R1:
      self.log_pi1[str(state)] = self._column(N_A, 0.)

    # Player 2: mirror image of player 1.
    self.log_pi2 = nn.ParameterDict()
    self.log_pi2['G'] = self._column(N_A, 0.)
    self.log_pi2[str(('O',1))] = self._column(1, 1.)
    self.log_pi2[str(('O',2))] = self._column(n_offer_actions, 0.)
    for state in E2:
      self.log_pi2[str(state)] = self._column(2, 0.)
    for state in E1 + R1:
      self.log_pi2[str(state)] = self._column(1, 1.)
    for state in R2:
      self.log_pi2[str(state)] = self._column(N_A, 0.)

    self.v = Parameter(torch.zeros(N_S, 2))

    # Dual variables of the Nash inequalities, parameterized by their square
    # roots so that lambda = sqrt_lambda**2 stays non-negative.
    self.sqrt_lambda_nash1 = nn.ParameterDict()
    self.sqrt_lambda_nash2 = nn.ParameterDict()
    for s in S:
      key = str(s)
      self.sqrt_lambda_nash1[key] = Parameter(torch.zeros_like(self.log_pi1[key].detach()))
      self.sqrt_lambda_nash2[key] = Parameter(torch.zeros_like(self.log_pi2[key].detach()))

  def forward(self):
    """Return (strategies, values, multipliers) = (self.pi(), self.v, self.lambda_nash())."""
    return self.pi(), self.v, self.lambda_nash()

  def pi(self):
    """Return normalized strategies pi[player][str(state)] as (n_actions, 1) columns."""
    pi = {'1':{}, '2':{}}
    for s in S:
      key = str(s)
      # Softmax over the action axis. An additive epsilon inside exp() would
      # cancel in the normalization, so none is used.
      pi['1'][key] = torch.softmax(self.log_pi1[key], dim=0)
      pi['2'][key] = torch.softmax(self.log_pi2[key], dim=0)
    return pi

  def lambda_nash(self):
    """Return the Nash multipliers lambda[player][str(state)] = sqrt_lambda**2."""
    lambda_ = {'1':{}, '2':{}}
    for s in S:
      lambda_['1'][str(s)] = self.sqrt_lambda_nash1[str(s)].pow(2)
      lambda_['2'][str(s)] = self.sqrt_lambda_nash2[str(s)].pow(2)
    return lambda_

In [None]:
# Instantiate the game on the GPU with a single Adam optimizer over all parameters.
game_b = game().to('cuda')
optimizer = optim.Adam(game_b.parameters(), lr=1e-3)

# For every state, record how many actions each player has there, read off
# the shapes of the strategy columns returned by the model.
with torch.no_grad():
  pi, v, lambda_nash = game_b()
  N_A_S = {
      str(s): {player: pi[player][str(s)].shape[0] for player in ('1', '2')}
      for s in S
  }

In [None]:
def offer_accepted(action):
  """Return True exactly when the evaluating player's ``action`` is 0 (accept)."""
  return bool(action == 0)
  

def other_player(i):
  """Return the id of the opponent of player ``i`` (1 <-> 2).

  Raises:
    ValueError: if ``i`` is neither 1 nor 2. The error is raised explicitly
      rather than with ``assert`` so the check still runs under ``python -O``.
  """
  if i == 1:
    return 2
  if i == 2:
    return 1
  raise ValueError('Invalid player id')


def transition_info(state, actions):
  """Describe where the game moves from ``state`` after joint ``actions``.

  Args:
    state: a state from ``S``.
    actions: [action of player 1, action of player 2]. Ignored for random
      transitions, where an empty list may be passed.

  Returns:
    ('R', [(next_state, probability), ...]) for the action-independent move
    out of 'G', 'R1' and 'R2' (to either offer stage with probability 0.5), or
    ('D', next_state) for the deterministic bargaining moves:
      * ('O', i): player i's action encodes offer * N_A + requested_action.
        The other player evaluates the offer in ('E<other>', requested, offer).
      * ('E1', a, x) / ('E2', a, x): if the evaluator accepts (action 0, see
        offer_accepted), play continues in ('R2', a) / ('R1', a). Otherwise it
        returns to 'G'.

  Raises:
    ValueError: if ``state`` is not a recognized state. The error is raised
      explicitly rather than with ``assert`` so it survives ``python -O``.
  """
  if state == 'G' or 'R1' in state or 'R2' in state:
    return ('R', [(('O',1), 0.5), (('O',2), 0.5)])
  if 'O' in state:
    _, id_player = state
    offeral, action_requested = divmod(actions[id_player-1], N_A)
    return ('D', ('E'+str(other_player(id_player)), action_requested, offeral))
  if 'E1' in state:
    _, action_requested, _ = state
    return ('D', ('R2', action_requested) if offer_accepted(actions[0]) else 'G')
  if 'E2' in state:
    _, action_requested, _ = state
    return ('D', ('R1', action_requested) if offer_accepted(actions[1]) else 'G')
  raise ValueError('Invalid state: {!r}'.format(state))

In [None]:
# Per-state flags (deterministic?, action-dependent?). 'G', 'R1' and 'R2' move
# at random to an offer stage regardless of the actions, which gives (0, 0).
# The bargaining stages move deterministically according to the actions,
# which gives (1, 1).
transition_types = {
    str(state): (0, 0) if (state == 'G' or 'R1' in state or 'R2' in state) else (1, 1)
    for state in S
}

In [None]:
def rewards(state, actions):
  """Return the immediate rewards (r1, r2) for joint ``actions`` in ``state``.

  * 'G': base-game payoffs RG1/RG2 at [action 1, action 2].
  * ('O', i): offers carry no immediate reward.
  * ('E1', a, x): if player 1 accepts (see offer_accepted), x is transferred
    from player 2 to player 1. ('E2', a, x) is the mirror case.
  * ('R1', a): base-game payoffs with player 2's action fixed to a.
    ('R2', a) fixes player 1's action to a.

  Raises:
    ValueError: for an unrecognized state. The error is raised explicitly
      rather than with ``assert`` so it survives ``python -O``.
  """
  if state == 'G':
    return RG1[actions[0], actions[1]], RG2[actions[0], actions[1]]
  if 'O' in state:
    return 0.0, 0.0
  if 'E1' in state:
    _, _, offeral = state
    return (offeral, -offeral) if offer_accepted(actions[0]) else (0, 0)
  if 'E2' in state:
    _, _, offeral = state
    return (-offeral, offeral) if offer_accepted(actions[1]) else (0, 0)
  if 'R1' in state:
    _, i = state
    return RG1[actions[0], i], RG2[actions[0], i]
  if 'R2' in state:
    _, i = state
    return RG1[i, actions[1]], RG2[i, actions[1]]
  raise ValueError('Invalid state: {!r}'.format(state))


def reward_matrices(s):
  """Return the immediate reward matrices (RM1, RM2) of state ``s``.

  Each is a float array of shape (actions of player 1, actions of player 2).
  Entry [a1, a2] holds the reward of player 1 / player 2 for that joint
  action. For 'G' these are copies of RG1 / RG2.
  """
  if s == 'G':
    return RG1.copy(), RG2.copy()

  shape = (N_A_S[str(s)]['1'], N_A_S[str(s)]['2'])
  RM1, RM2 = np.zeros(shape), np.zeros(shape)
  for a1, a2 in itertools.product(range(shape[0]), range(shape[1])):
    RM1[a1, a2], RM2[a1, a2] = rewards(s, [a1, a2])
  return RM1, RM2

In [None]:
# Reward matrices of every state as CUDA float tensors, keyed RM[player][str(state)].
RM = {'1': {}, '2': {}}
for s in S:
  for player, matrix in zip(('1', '2'), reward_matrices(s)):
    RM[player][str(s)] = torch.FloatTensor(matrix).to('cuda')

In [None]:
def get_state_index(state):
  """Return the position of ``state`` in the global state list ``S``.

  Raises ValueError (from list.index) if the state is not in ``S``.
  """
  position = S.index(state)
  return position

def player_dim(i):
  """Return the einsum subscript of player ``i``'s action axis ('i' for 1, 'j' for 2).

  Raises:
    ValueError: if ``i`` is neither 1 nor 2. The error is raised explicitly
      rather than with ``assert`` so it survives ``python -O``.
  """
  if i == 1:
    return 'i'
  if i == 2:
    return 'j'
  raise ValueError('Invalid player id')

def players():
  """Return the player ids 1 and 2 as a range."""
  first, last = 1, 2
  return range(first, last + 1)

def players_str():
  """Return a fresh iterator over the player ids as strings: '1', '2'."""
  ids = ('1', '2')
  return iter(ids)

def state_player_pairs():
  """Iterate over all (state, player id) pairs, state-major: (s0, 1), (s0, 2), (s1, 1), ..."""
  return ((state, player) for state in S for player in players())

S_str = list(map(str, S))  # string keys of all states, in the order of S


def state_player_str_pairs():
  """Iterate over (str(state), player) pairs, state-major, with players '1' and '2'."""
  return ((key, player) for key in S_str for player in players_str())


#-------------------------------------------------------------------------------

def next_value_matrices(s, v):
  """Return the next-state values for every joint action in state ``s``.

  Args:
    s: a state from ``S``.
    v: state values indexed v[state_index, player_index], as ``game.v``.

  Returns:
    CUDA tensor of shape (N_A1, N_A2, 2). Entry [a1, a2, k] is player k+1's
    value of the state reached after (a1, a2). For random transitions it is
    the probability-weighted average over successors.

  Raises:
    NotImplementedError: for transition types other than (1, 1)
      (deterministic, action-dependent) and (0, 0) (random, action-independent).
      Other types would otherwise silently yield all-zero next values.
  """
  det, dep = transition_types[str(s)]
  N_A1 = N_A_S[str(s)]['1']
  N_A2 = N_A_S[str(s)]['2']
  vs = torch.zeros((N_A1,N_A2,2)).to('cuda')
  if det and dep:
    for a1 in range(N_A1):
      for a2 in range(N_A2):
        _, next_state = transition_info(s, [a1,a2])
        vs[a1,a2,:] = v[get_state_index(next_state),:]
  elif not det and not dep:
    # The successor distribution ignores the actions, so every joint action
    # shares the same expected next value.
    _, transitions = transition_info(s, [])
    next_v = torch.zeros(2).to('cuda')
    for next_state, transition_prob in transitions:
      next_v = next_v + v[get_state_index(next_state),:] * transition_prob
    vs[:,:,:] = next_v
  else:
    raise NotImplementedError(
        'Unsupported transition type {} for state {!r}'.format((det, dep), s))
  return vs


def transition_matrix(pi):
  """Return the state-to-state transition matrix induced by strategies ``pi``.

  Args:
    pi: strategies pi[player][str(state)] -> (n_actions, 1) columns, as
      returned by ``game.pi()``.

  Returns:
    (N_S, N_S) CUDA tensor P. P[i, j] is the probability of moving from S[i]
    to S[j] when both players follow ``pi``.

  Raises:
    NotImplementedError: for transition types other than (1, 1) and (0, 0),
      which would otherwise silently leave a zero row.
  """
  P = torch.zeros((N_S,N_S)).to('cuda')
  for state in S:
    strategy_1 = pi['1'][str(state)]
    strategy_2 = pi['2'][str(state)]
    det, dep = transition_types[str(state)]
    id_s = get_state_index(state)

    if det and dep:
      for a1 in range(strategy_1.shape[0]):
        for a2 in range(strategy_2.shape[0]):
          _, next_state = transition_info(state, [a1,a2])
          id_ns = get_state_index(next_state)
          # Several joint actions may lead to the same successor: accumulate.
          P[id_s, id_ns] = P[id_s, id_ns] + strategy_1[a1,0] * strategy_2[a2,0]
    elif not det and not dep:
      _, transitions = transition_info(state, [])
      for next_state, transition_prob in transitions:
        id_ns = get_state_index(next_state)
        # Accumulate (rather than overwrite) in case successors repeat.
        P[id_s, id_ns] = P[id_s, id_ns] + transition_prob
    else:
      raise NotImplementedError(
          'Unsupported transition type {} for state {!r}'.format((det, dep), state))
  return P


def partial_transition_matrices(pi):
  """Per-player transition matrices, marginalizing over the opponent's strategy.

  Returns:
    dict T[str(s)][player] -> (n_actions_of_player, N_S) CUDA tensor.
    T[str(s)]['1'][a1, j] is the probability of moving from s to S[j] when
    player 1 plays a1 and player 2 follows pi. The '2' entry is symmetric.

  Raises:
    NotImplementedError: for transition types other than (1, 1) and (0, 0),
      which would otherwise silently yield all-zero matrices.
  """
  transition_matrices = {}
  for s in S:
    key = str(s)
    strategy_1 = pi['1'][key]
    strategy_2 = pi['2'][key]
    N_A1 = strategy_1.shape[0]
    N_A2 = strategy_2.shape[0]
    T1 = torch.zeros((N_A1,N_S)).to('cuda')
    T2 = torch.zeros((N_A2,N_S)).to('cuda')
    transition_matrices[key] = {'1': T1, '2': T2}

    det, dep = transition_types[key]
    if det and dep:
      for a1 in range(N_A1):
        for a2 in range(N_A2):
          _, next_state = transition_info(s, [a1,a2])
          id_ns = get_state_index(next_state)
          # Each player's row is weighted by the *other* player's strategy.
          T1[a1, id_ns] = T1[a1, id_ns] + strategy_2[a2,0]
          T2[a2, id_ns] = T2[a2, id_ns] + strategy_1[a1,0]
    elif not det and not dep:
      _, transitions = transition_info(s, [])
      for next_state, transition_prob in transitions:
        id_ns = get_state_index(next_state)
        T1[:, id_ns] = T1[:, id_ns] + transition_prob
        T2[:, id_ns] = T2[:, id_ns] + transition_prob
    else:
      raise NotImplementedError(
          'Unsupported transition type {} for state {!r}'.format((det, dep), s))
  return transition_matrices


def expected_reward(RM, pi):
  """Return the expected immediate reward of every state under strategies ``pi``.

  Returns an (N_S, 2) CUDA tensor whose entry [k, i-1] is
  sum_{a1, a2} pi1(a1) * pi2(a2) * RM[str(i)][str(S[k])][a1, a2].
  """
  r_mean = torch.zeros((N_S,2)).to('cuda')
  for player_index, player in enumerate(('1', '2')):
    reward_tables = RM[player]
    for s in S:
      key = str(s)
      # Average over player 1's actions first -> (N_A2, 1), then over player 2's.
      marginal = torch.einsum('ij,ik->jk', reward_tables[key], pi['1'][key])
      r_mean[get_state_index(s), player_index] = (marginal * pi['2'][key]).sum()
  return r_mean


def partial_expected_reward(RM, pi):
  """Expected immediate reward of each of a player's actions against the opponent.

  Returns:
    dict r[player][str(s)] -> (n_actions_of_player, 1) tensor. Entry [a] is
    the player's reward for action a, averaged over the other player's
    strategy in ``pi``.
  """
  r_mean = {}
  for player_id in range(1,2+1):
    # Loop-invariant per player: reward tables, opponent id and einsum formula.
    # (The previous unused local `N_A` shadowed the global constant.)
    RM_i = RM[str(player_id)]
    other_player_id = other_player(player_id)
    formula = 'ij,'+player_dim(other_player_id)+'->'+player_dim(player_id)
    r_mean[str(player_id)] = {}
    for s in S:
      strategy = pi[str(other_player_id)][str(s)].view(-1)
      r_mean[str(player_id)][str(s)] = torch.einsum(formula, RM_i[str(s)], strategy).view(-1,1)
  return r_mean


def bellman_projection(RM, pi, v):
  """Apply the Bellman operator of the fixed strategies ``pi`` to ``v``.

  Returns an (N_S, 2) tensor holding, for each state and player, the
  expected immediate reward plus beta times the expected next value.
  """
  next_v = torch.zeros((N_S,2)).to('cuda')
  for s in S:
    key = str(s)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)
    # Average the (N_A1, N_A2, 2) next-value table over player 1, then player 2.
    averaged_over_p1 = torch.einsum('ijk,i->jk', next_value_matrices(s, v), p1)
    next_v[get_state_index(s), :] = torch.einsum('jk,j->k', averaged_over_p1, p2)
  return expected_reward(RM, pi) + beta * next_v


def bellman_partial_projection(RM, pi, v):
  """Return per-action Q-values of each player against the other's strategy.

  Result[player][str(s)] is an (n_actions, 1) tensor. Each entry is the
  action's expected immediate reward plus beta times the expected next
  value, averaged over the other player's strategy.
  """
  r_mean = partial_expected_reward(RM, pi)
  q = {'1': {}, '2': {}}
  for s in S:
    key = str(s)
    next_values = next_value_matrices(s, v)
    # Player 1 averages over player 2's actions (columns); player 2 over player 1's (rows).
    next_1 = torch.einsum('ij,j->i', next_values[:,:,0], pi['2'][key].squeeze(1)).view(-1,1)
    next_2 = torch.einsum('ij,i->j', next_values[:,:,1], pi['1'][key].squeeze(1)).view(-1,1)
    q['1'][key] = r_mean['1'][key] + beta * next_1
    q['2'][key] = r_mean['2'][key] + beta * next_2
  return q


def cost_vector_fixed_policies(pi):
  """LP objective coefficients for the value warm start under strategies ``pi``.

  sum_s [v(s) - beta * sum_s' P[s, s'] v(s')] / N_S equals c^T v with
  c[s'] = (1 - beta * sum_s P[s, s']) / N_S. Returned as an (N_S, 1) numpy array.
  """
  column_sums = transition_matrix(pi).sum(0)
  c = (1 - beta * column_sums).view(-1,1)
  return c.detach().cpu().numpy() / N_S

def reward_baselines(RM, pi):
  """Return each player's expected reward averaged over states, as a (2, 1) numpy array."""
  per_state = expected_reward(RM, pi)
  return per_state.mean(0).view(-1,1).detach().cpu().numpy()

def restriction_matrices_fixed_policies(pi):
  """Build the LP inequality matrices A_ub for each player.

  For every state s (in the order of S) and every action of the player
  there, the row is -(e_s - beta * T_player(. | s, action)). Returns a dict
  mapping player -> numpy array with one row per (state, action) and N_S columns.
  """
  partial_P = partial_transition_matrices(pi)
  rows = {'1': [], '2': []}
  for s, player_id in state_player_pairs():
    block = -beta * partial_P[str(s)][str(player_id)]
    column = get_state_index(s)
    block[:, column] = block[:, column] + 1
    rows[str(player_id)].append(block)
  return {
      player: -torch.cat(blocks, dim=0).detach().cpu().numpy()
      for player, blocks in rows.items()
  }

def restriction_vectors_fixed_policies(RM, pi, alpha=0.1):
  """Build the LP right-hand sides b_ub = -(expected action reward + alpha) per player.

  Rows follow the state order of S, one per action, matching
  restriction_matrices_fixed_policies. Returns a dict mapping
  player -> (rows, 1) numpy array.
  """
  r_mean = partial_expected_reward(RM, pi)
  columns = {'1': [], '2': []}
  for s, player in state_player_str_pairs():
    columns[player].append(r_mean[player][s].view(-1,1))
  return {
      player: -(torch.cat(parts, dim=0)+alpha).detach().cpu().numpy()
      for player, parts in columns.items()
  }


def parameters_fixed_policies(game_, alpha):
  """Return (c, f0, A_ub, b_ub) for the value LP under the model's current strategies."""
  strategies = game_()[0]
  return (
      cost_vector_fixed_policies(strategies),
      reward_baselines(RM, strategies),
      restriction_matrices_fixed_policies(strategies),
      restriction_vectors_fixed_policies(RM, strategies, alpha),
  )


#-------------------------------------------------------------------------------

def calculate_nash_restrictions(pi, v):
  """Return the Nash constraint values g = Q - v(s) per player, state and action.

  g[player][str(s)] is an (n_actions, 1) tensor. The Nash conditions require g <= 0.
  """
  q_values = bellman_partial_projection(RM, pi, v)
  return {
      str(player_id): {
          str(s): q_values[str(player_id)][str(s)] - v[get_state_index(s), player_id-1]
          for s in S
      }
      for player_id in players()
  }


def lagrangian_nash(pi, v, lambda_nash, c=1):
  """Lagrangian of the Bellman-error objective under the Nash constraints.

  The objective is f = sum over states and players of (v - T_pi v), where
  T_pi is the Bellman operator of the current strategies. The Lagrangian is
  f + sum(lambda * g), with g = Q - v(s) the Nash constraints.

  Args:
    pi, v: strategies and state values.
    lambda_nash: multipliers lambda[player][str(state)], one per action.
    c: unused; kept for interface compatibility.

  Returns:
    (lagrangian, f_bellman) as scalar tensors.
  """
  f_bellman = (v - bellman_projection(RM, pi, v)).sum()
  g_nash = calculate_nash_restrictions(pi, v)

  penalty = 0.0
  for s, player_id in state_player_str_pairs():
    weighted = g_nash[player_id][s].view(-1,1) * lambda_nash[player_id][s].view(-1,1)
    penalty = penalty + weighted.sum()

  return f_bellman + penalty, f_bellman


def check_KKT_conditions(pi, v, lambda_nash, tol=1e-8):
  """Check primal feasibility and complementary slackness of the Nash constraints.

  Args:
    pi, v: strategies and state values.
    lambda_nash: multipliers lambda[player][str(state)], one per action.
    tol: absolute tolerance for |lambda * g| to count as zero.

  Returns:
    (g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero):
    whether every g <= 0, the largest g, whether every |lambda * g| <= tol,
    and the largest |lambda * g|.
  """
  with torch.no_grad():
    g_nash = calculate_nash_restrictions(pi, v)

    g_nash_satisfied = True
    product_zero_satisfied = True
    # -np.inf rather than -np.infty: the np.infty alias was removed in NumPy 2.0.
    max_g_nash = -np.inf
    max_product_zero = -np.inf

    for s, player_id in state_player_str_pairs():
      g = g_nash[player_id][s]
      g_nash_satisfied = g_nash_satisfied and torch.all(g <= 0)
      max_g_nash = max(max_g_nash, g.max())

      product = (g.view(-1) * lambda_nash[player_id][s].view(-1)).abs()
      product_zero_satisfied = product_zero_satisfied and torch.all(product <= tol)
      max_product_zero = max(max_product_zero, product.max())

    return g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero


def check_nash_conditions(pi, v, lambda_nash):
  """Check whether every Nash constraint g = Q - v(s) <= 0 holds.

  Args:
    pi, v: strategies and state values.
    lambda_nash: unused; accepted for signature parity with check_KKT_conditions.

  Returns:
    (g_nash_satisfied, max_g_nash): whether all g <= 0, and the largest g.
  """
  with torch.no_grad():
    g_nash = calculate_nash_restrictions(pi, v)
    g_nash_satisfied = True
    # -np.inf rather than -np.infty: the np.infty alias was removed in NumPy 2.0.
    max_g_nash = -np.inf

    for s, player_id in state_player_str_pairs():
      g = g_nash[player_id][s]
      g_nash_satisfied = g_nash_satisfied and torch.all(g <= 0)
      max_g_nash = max(max_g_nash, g.max())

    return g_nash_satisfied, max_g_nash
  

def optimize_game(game_, n_epochs, optimizers, params_dual,
                  print_each=50, save_each=10, checkpoint_each=1):
  """Run simultaneous primal descent / dual ascent on the Nash Lagrangian.

  Each epoch evaluates ``lagrangian_nash`` at the current parameters and
  backpropagates. It then negates the dual gradients so the dual optimizer
  ascends, and steps both optimizers.

  Args:
    game_: the ``game`` module being optimized.
    n_epochs: number of optimization steps.
    optimizers: pair (primal_optimizer, dual_optimizer).
    params_dual: parameters updated by gradient ascent.
    print_each: print diagnostics every this many epochs.
    save_each: record loss and cost every this many epochs.
    checkpoint_each: write ``./game_temp.pth`` (parameters before the step)
      every this many epochs. The default 1 saves every epoch; 0 disables it.

  Returns:
    (losses, costs): the recorded Lagrangian values and Bellman costs.
  """
  losses = []
  costs = []
  opti_primal, opti_dual = optimizers

  for epoch in range(n_epochs):
    pi, v, lambda_nash = game_()
    loss, cost = lagrangian_nash(pi, v, lambda_nash)

    if checkpoint_each and (epoch + 1) % checkpoint_each == 0:
      torch.save(game_.state_dict(), './game_temp.pth')

    report = (epoch + 1) % print_each == 0
    if report:
      # Evaluate diagnostics before the step. `v` is the live parameter
      # game_.v, which the optimizer updates in place, whereas pi and
      # lambda_nash were computed from the pre-step parameters.
      with torch.no_grad():
        nash_satisfied, max_nash_restrictions, KKT_zero_product_satisfied, max_KKT_product = (
            check_KKT_conditions(pi, v, lambda_nash))

    opti_primal.zero_grad()
    opti_dual.zero_grad()
    loss.backward()

    # Descending on -L is ascending on L: the dual variables maximize the Lagrangian.
    for p in params_dual:
      if p.grad is not None:
        p.grad.neg_()

    opti_primal.step()
    opti_dual.step()

    if report:
      print("Epoch: {}, Loss: {:.3f}, f: {:.3f}, nash: {}, max g: {:.3f}, KKT product: {}, max product: {:.3e}".format(
          epoch+1,loss.item(),cost.item(), nash_satisfied, max_nash_restrictions,
          KKT_zero_product_satisfied, max_KKT_product))

    if (epoch + 1) % save_each == 0:
      losses.append(loss.item())
      costs.append(cost.item())
  return losses, costs

def calculate_initial_v(game_):
  """Compute a warm-start value table by linear programming under fixed strategies.

  For each player it solves  min_v c^T v  s.t.  A_ub v <= b_ub. The
  constraints, built with alpha=10, require v(s) - beta * E[v(s')] to be at
  least each action's expected reward plus alpha under the model's current
  strategies.

  Returns:
    (N_S, 2) numpy array. Column k holds the LP solution for player k+1.

  Raises:
    RuntimeError: if the LP solver does not report success. Otherwise
      ``result.x`` would be None and the assignment would fail obscurely.
  """
  v0 = np.zeros((N_S,2))
  c, _, A_ub, b_ub = parameters_fixed_policies(game_, alpha=10)
  for player_id in players():
    # NOTE(review): linprog's default bounds are (0, None), so v >= 0 is
    # imposed implicitly; confirm that is intended.
    result = linprog(np.ravel(c), A_ub=A_ub[str(player_id)], b_ub=np.ravel(b_ub[str(player_id)]))
    if not result.success:
      raise RuntimeError(
          'Initial-value LP failed for player {}: {}'.format(player_id, result.message))
    v0[:,player_id-1] = result.x
  return v0

#-------------------------------------------------------------------------------
def nash_error_gradient(pi):
  """Return (I - P^T) 1 as an (N_S, 1) column, with P = transition_matrix(pi).

  Entry j equals 1 - sum_i P[i, j].

  NOTE(review): cost_vector_fixed_policies uses 1 - beta * P.sum(0), but no
  discount factor is applied here. Confirm whether (I - beta * P^T) 1 was intended.
  """
  P = transition_matrix(pi)
  identity = torch.eye(N_S).to('cuda')
  return (identity - P.t()).sum(1, keepdim=True)

#**Optimization**

In [None]:
# Fresh game with two optimizers: primal (values and log-strategies) and dual
# (square roots of the Nash multipliers, ascended inside optimize_game).
game_b = game().to('cuda')
params_primal = [game_b.v, *game_b.log_pi1.parameters(), *game_b.log_pi2.parameters()]
params_dual = [*game_b.sqrt_lambda_nash1.parameters(), *game_b.sqrt_lambda_nash2.parameters()]
optim_primal = optim.Adam(params_primal, lr=1e-3)
optim_dual = optim.Adam(params_dual, lr=1e-4)
optimizers = [optim_primal, optim_dual]

In [None]:
# Warm-start the value table with the LP solution for the current strategies.
# For a freshly constructed game, v is zero-initialized, so this sets v to v0.
v0 = torch.FloatTensor(calculate_initial_v(game_b)).to('cuda')
game_b.v.data.add_(v0)

tensor([[1163.1011, 1163.1011],
        [1160.3406, 1161.0554],
        [1161.0554, 1160.3406],
        [1161.4701, 1161.9601],
        [1161.4701, 1161.4601],
        [1161.4701, 1160.9601],
        [1161.9851, 1160.4601],
        [1162.9851, 1159.9601],
        [1163.9851, 1159.4601],
        [1161.4701, 1159.9801],
        [1161.4701, 1159.4801],
        [1162.4701, 1158.9801],
        [1163.4701, 1158.4801],
        [1164.4701, 1157.9801],
        [1165.4701, 1157.4801],
        [1161.9601, 1161.4701],
        [1161.4601, 1161.4701],
        [1160.9601, 1161.4701],
        [1160.4601, 1161.9851],
        [1159.9601, 1162.9851],
        [1159.4601, 1163.9851],
        [1159.9801, 1161.4701],
        [1159.4801, 1161.4701],
        [1158.9801, 1162.4701],
        [1158.4801, 1163.4701],
        [1157.9801, 1164.4701],
        [1157.4801, 1165.4701],
        [1164.0911, 1160.5911],
        [1160.0911, 1162.0911],
        [1160.5911, 1164.0911],
        [1162.0911, 1160.0911]], device=

In [None]:
# NOTE(review): none of these globals is referenced by the other cells shown
# in this notebook. optimize_game and calculate_initial_v take their own
# arguments, and calculate_initial_v passes alpha=10 explicitly. They may be
# left over from an earlier solver; confirm before relying on them.
alpha = 0.5
gamma_0 = 0.5
eta = 1e-2
nu = 1/0.9999
rho_0 = 1.0

In [None]:
# Evaluate nash_error_gradient on the current strategies without building an
# autograd graph. grad_f_v is an (N_S, 1) tensor; no later cell shown uses it.
with torch.no_grad():
  pi,v,_ = game_b()
  grad_f_v = nash_error_gradient(pi)

In [None]:
# First training run: 20000 epochs of primal descent / dual ascent.
losses, costs = optimize_game(game_b, n_epochs=20000, optimizers=optimizers,
                              params_dual=params_dual)

In [None]:
# Continue training for another 23000 epochs from the current parameters.
losses2, costs2 = optimize_game(game_b, n_epochs=23000, optimizers=optimizers,
                                params_dual=params_dual)

In [None]:
# Successive differences of the recorded curves (np.diff computes a[1:] - a[:-1]).
dloss = np.diff(np.array(losses))
dcost = np.diff(np.array(costs))

# Top row: loss and cost; bottom row: their differences.
fig, ax = plt.subplots(2,2, figsize=(12,8))
for axis, series in zip(ax.flat, (losses, costs, dloss, dcost)):
  axis.plot(np.array(series))
plt.show()
plt.close()

print(np.argmax(dloss), np.argmax(dcost))

In [None]:
# game_b() returns (pi, v, lambda_nash), so all three are unpacked. The Nash
# multipliers are model parameters (squared sqrt_lambda_nash*) updated by
# dual ascent inside optimize_game, so the current ones come from the model.
# NOTE(review): the helper `update_nash_dual_variables` is not defined in this
# notebook, so it is not called; confirm no extra dual update was intended.
with torch.no_grad():
  pi, v, lambda_nash = game_b()

In [None]:
# NOTE(review): `optimize_policy_values` is not defined anywhere in this
# notebook, so this cell raises NameError as written. The training loop that
# exists is optimize_game(game_b, n_epochs, optimizers, params_dual); confirm intent.
losses2, costs2 = optimize_policy_values(game_b, [lambda_nash], 2000, optimizer)

In [None]:
# Successive differences of the second run's curves (np.diff computes a[1:] - a[:-1]).
dloss2 = np.diff(np.array(losses2))
dcost2 = np.diff(np.array(costs2))

# Top row: loss and cost; bottom row: their differences.
fig, ax = plt.subplots(2,2, figsize=(12,8))
for axis, series in zip(ax.flat, (losses2, costs2, dloss2, dcost2)):
  axis.plot(np.array(series))
plt.show()
plt.close()

print(np.argmax(dloss2), np.argmax(dcost2))

In [None]:
# Persist the model parameters and the current Nash multipliers.
torch.save(game_b.state_dict(), './game.pth')
# A context manager guarantees the pickle file is flushed and closed.
with open('./lambda_nash.p', 'wb') as f:
  pickle.dump(lambda_nash, f)

In [None]:
# game_b() returns (pi, v, lambda_nash), so all three are unpacked. The Nash
# multipliers are model parameters (squared sqrt_lambda_nash*) updated by
# dual ascent inside optimize_game, so the current ones come from the model.
# NOTE(review): the helper `update_nash_dual_variables` is not defined in this
# notebook, so it is not called; confirm no extra dual update was intended.
with torch.no_grad():
  pi, v, lambda_nash = game_b()