<a href="https://colab.research.google.com/github/tarod13/Stochastic_Games/blob/master/stochastic_games_herkovitz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import numpy as np
from scipy.optimize import linprog
import matplotlib.pyplot as plt
import itertools
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter

In [3]:
beta = 0.99

M = 5
N_P = 2
N_A = 2
N_S = 3 + N_P * (2*(M+1) + 2)
G = [('G',)]
O = [('O',i) for i in range(1,N_P+1)]
E1 = [('E1',i,j) for i in range(0,N_P) for j in range(0,M+1)]
E2 = [('E2',i,j) for i in range(0,N_P) for j in range(0,M+1)]
R1 = [('R1',i) for i in range(0,N_P)]
R2 = [('R2',i) for i in range(0,N_P)]
S = list(itertools.chain(G, O, E1, E2, R1, R2)) 
S1 = list(itertools.chain(G, ('O',1), E1, R1))
S2 = list(itertools.chain([('O',2), E2, R2]))

RG1 = np.array([[3.,0.],[5.,1.]])
RG2 = np.array([[3.,5.],[0.,1.]])

In [4]:
def get_state_index(state):
  if state in S:
    return S.index(state)
  elif state in S_str:
    return S_str.index(state)
  else:
    assert 0 == 1, 'Invalid state'

def player_dim(i):
  if i in [1, '1']:
    return 'i'
  elif i in [2, '2']:
    return 'j'
  else:
    assert 0 == 1, 'Invalid player id'
  
def get_player_id(player):
  if player == '1':
    return 0
  elif player == '2':
    return 1
  else:
    assert 0 == 1, 'Invalid player'

def players():
  return range(1,2+1)

def players_str():
  return iter(['1','2']) 

def state_player_pairs():
  return itertools.product(S, players())

S_str = [str(s) for s in S]
def state_player_str_pairs():
  return itertools.product(S_str, players_str())

In [5]:
class log_game(nn.Module):
  def __init__(self):
    super().__init__()
    self.transform_type = 'log'
    self.log_pi1 = nn.ParameterDict()
    self.log_pi1[str(('G',))] = Parameter(torch.Tensor(N_A,1))
    nn.init.zeros_(self.log_pi1[str(('G',))])
    self.log_pi1[str(('O',1))] = Parameter(torch.Tensor((M+1)*N_A,1))
    nn.init.zeros_(self.log_pi1[str(('O',1))])    
    self.log_pi1[str(('O',2))] = Parameter(torch.Tensor(1,1))
    nn.init.ones_(self.log_pi1[str(('O',2))])

    for state in E1:
      self.log_pi1[str(state)] = Parameter(torch.Tensor(2,1))
      nn.init.zeros_(self.log_pi1[str(state)])
      
    for state in E2 + R2:
      self.log_pi1[str(state)] = Parameter(torch.Tensor(1,1))
      nn.init.ones_(self.log_pi1[str(state)])

    for state in R1:
      self.log_pi1[str(state)] = Parameter(torch.Tensor(N_A,1))
      nn.init.zeros_(self.log_pi1[str(state)])

    self.log_pi2 = nn.ParameterDict()
    self.log_pi2[str(('G',))] = Parameter(torch.Tensor(N_A,1))
    nn.init.zeros_(self.log_pi2[str(('G',))])
    self.log_pi2[str(('O',1))] = Parameter(torch.Tensor(1,1))
    nn.init.ones_(self.log_pi2[str(('O',1))])
    self.log_pi2[str(('O',2))] = Parameter(torch.Tensor((M+1)*N_A,1))
    nn.init.zeros_(self.log_pi2[str(('O',2))])

    for state in E2:
      self.log_pi2[str(state)] = Parameter(torch.Tensor(2,1))
      nn.init.zeros_(self.log_pi2[str(state)])
      
    for state in E1 + R1:
      self.log_pi2[str(state)] = Parameter(torch.Tensor(1,1))
      nn.init.ones_(self.log_pi2[str(state)])

    for state in R2:
      self.log_pi2[str(state)] = Parameter(torch.Tensor(N_A,1))
      nn.init.zeros_(self.log_pi2[str(state)])

    self.v = Parameter(torch.Tensor(N_S,2))
    nn.init.zeros_(self.v)    

    self.sqrt_lambda_nash1 = nn.ParameterDict()
    self.sqrt_lambda_nash2 = nn.ParameterDict()  # dual vars. for Nash inequalities
    for s in S:
        self.sqrt_lambda_nash1[str(s)] = Parameter(torch.Tensor(self.log_pi1[str(s)].detach().size()))
        self.sqrt_lambda_nash2[str(s)] = Parameter(torch.Tensor(self.log_pi2[str(s)].detach().size()))
        nn.init.zeros_(self.sqrt_lambda_nash1[str(s)])
        #self.sqrt_lambda_nash1[str(s)].data.mul_(1.0)
        nn.init.zeros_(self.sqrt_lambda_nash2[str(s)])
        #self.sqrt_lambda_nash2[str(s)].data.mul_(1.0)

    for param in self.parameters():  
      nn.init.uniform_(param)

  def forward(self):
    return self.pi(), self.v, self.lambda_nash()
  
  def pi(self):
    pi = {'1':{}, '2':{}}
    for s in S:
      log_pi1 = self.log_pi1[str(s)]
      log_pi2 = self.log_pi2[str(s)]
      
      log_pi1_shift = log_pi1 - log_pi1.max()
      log_pi2_shift = log_pi2 - log_pi2.max()

      pi1 = torch.exp(log_pi1_shift + 1e-20)
      pi2 = torch.exp(log_pi2_shift + 1e-20)
      
      pi['1'][str(s)] = pi1 / pi1.sum()
      pi['2'][str(s)] = pi2 / pi2.sum()
    return pi
  
  def lambda_nash(self):
    lambda_ = {'1':{}, '2':{}}
    for s in S:
      lambda_['1'][str(s)] = self.sqrt_lambda_nash1[str(s)].pow(2)
      lambda_['2'][str(s)] = self.sqrt_lambda_nash2[str(s)].pow(2)
    return lambda_

In [6]:
class nolast_game(nn.Module):
  def __init__(self):
    super().__init__()
    self.transform_type = 'nolast'
    self.pi_nolast1 = nn.ParameterDict()
    self.pi_nolast1[str(('G',))] = Parameter(torch.Tensor(N_A-1,1))
    nn.init.constant_(self.pi_nolast1[str(('G',))], 1.0/N_A)
    self.pi_nolast1[str(('O',1))] = Parameter(torch.Tensor((M+1)*N_A-1,1))
    nn.init.constant_(self.pi_nolast1[str(('O',1))], 1.0/((M+1)*N_A))    
    self.pi_nolast1[str(('O',2))] = Parameter(torch.Tensor(1-1,1))
    
    for state in E1:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(2-1,1))
      nn.init.constant_(self.pi_nolast1[str(state)], 1.0/2.0)
      
    for state in E2 + R2:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(1-1,1))

    for state in R1:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(N_A-1,1))
      nn.init.constant_(self.pi_nolast1[str(state)], 1.0/N_A)

    self.pi_nolast2 = nn.ParameterDict()
    self.pi_nolast2[str(('G',))] = Parameter(torch.Tensor(N_A-1,1))
    nn.init.constant_(self.pi_nolast2[str(('G',))], 1.0/N_A)
    self.pi_nolast2[str(('O',1))] = Parameter(torch.Tensor(1-1,1))
    self.pi_nolast2[str(('O',2))] = Parameter(torch.Tensor((M+1)*N_A-1,1))
    nn.init.constant_(self.pi_nolast2[str(('O',2))], 1.0/((M+1)*N_A))

    for state in E2:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(2-1,1))
      nn.init.constant_(self.pi_nolast2[str(state)], 1.0/2.0)
      
    for state in E1 + R1:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(1-1,1))

    for state in R2:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(N_A-1,1))
      nn.init.constant_(self.pi_nolast2[str(state)], 1.0/N_A)

    self.v = Parameter(torch.Tensor(N_S,2))
    nn.init.zeros_(self.v)    

  def forward(self):
    return self.pi(), self.v
  
  def pi(self):
    pi = {'1':{}, '2':{}}
    for s in S:
      NA1 = self.pi_nolast1[str(s)].shape[0]+1
      pi1 = torch.zeros(NA1,1).to('cuda')
      if NA1 > 1:
        pi1[:-1,:] = self.pi_nolast1[str(s)]
      pi1[-1,0] = 1.0 - pi1.sum()
      
      NA2 = self.pi_nolast2[str(s)].shape[0]+1
      pi2 = torch.zeros(NA2,1).to('cuda')
      if NA2 > 1:
        pi2[:-1,:] = self.pi_nolast2[str(s)]
      pi2[-1,0] = 1.0 - pi2.sum()

      pi['1'][str(s)] = pi1
      pi['2'][str(s)] = pi2
    return pi

  def pi_discriminated(self):
    pi_nolast = {'1':[], '2':[]}
    pi_last = {'1':[], '2':[]}
    for s in S_str:
      if self.pi_nolast1[s].shape[0] >= 1:
        pi_nolast['1'].append(self.pi_nolast1[s])
        pi_last['1'].append(1.0 - self.pi_nolast1[s].sum())
      if self.pi_nolast2[s].shape[0] >= 1:
        pi_nolast['2'].append(self.pi_nolast2[s])
        pi_last['2'].append(1.0 - self.pi_nolast2[s].sum())
    for player in players_str():
      pi_nolast[player] = torch.cat(pi_nolast[player])
      pi_last[player] = torch.FloatTensor(pi_last[player]).view(-1,1).to('cuda')
    return pi_nolast, pi_last 

In [7]:
game_b = nolast_game().to('cuda')
optimizer = optim.Adam(game_b.parameters(), lr=1e-3)

N_A_S = {}
N_A_S_tensors = {'1':[], '2':[]}
N_A_S_tensors_reduced = {'1':[], '2':[]}
with torch.no_grad():
  pi, v = game_b()
  for state in S:
    N_A1_state = pi['1'][str(state)].shape[0]
    N_A2_state = pi['2'][str(state)].shape[0]
    N_A_S[str(state)] = {'1': N_A1_state, '2': N_A2_state}
    N_A_S_tensors['1'].append(N_A1_state)
    N_A_S_tensors['2'].append(N_A2_state)
    if N_A1_state >= 2:
      N_A_S_tensors_reduced['1'].append(N_A1_state)
    if N_A2_state >= 2:
      N_A_S_tensors_reduced['2'].append(N_A2_state)
for player in players_str():
  N_A_S_tensors[player] = torch.FloatTensor(N_A_S_tensors[player]).to('cuda').view(-1,1)
  N_A_S_tensors_reduced[player] = torch.FloatTensor(N_A_S_tensors_reduced[player]).to('cuda').view(-1,1)

In [8]:
def offer_accepted(action):
  if action == 0:
    return True
  else:
    return False
  

def other_player(i):
  if i == 1:
    return 2
  elif i == 2:
    return 1
  elif i == '1':
    return '2'
  elif i == '2':
    return '1'
  else:
    assert 0 == 1, 'Invalid player id'


def transition_info(state, actions):
  if 'G' in state or 'R1' in state or 'R2' in state:
    T = [(('O',1), 0.5), (('O',2), 0.5)]
    return ('R', T)
  elif 'O' in state:
    _, id_player = state
    offeral = actions[id_player-1] // N_A
    action_requested = actions[id_player-1] % N_A
    return ('D', ('E'+str(other_player(id_player)), action_requested, offeral))
  elif 'E1' in state:
    _, action_requested, _ = state
    if offer_accepted(actions[0]):
      return ('D', ('R2', action_requested))
    else:
      return ('D', ('G',))
  elif 'E2' in state:
    _, action_requested, _ = state
    if offer_accepted(actions[1]):
      return ('D', ('R1', action_requested))
    else:
      return ('D', ('G',))   
  else:
    assert 0 == 1, 'Invalid state' 

In [9]:
transition_types = {}
for state in S:
  if 'G' in state or 'R1' in state or 'R2' in state:
    transition_types[str(state)] = (0,0) 
  else:
    transition_types[str(state)] = (1,1) # entries correspond to deterministic behaviour and dependence on actions

In [10]:
def rewards(state, actions):
  if 'G' in state:
    r1 = RG1[actions[0], actions[1]]
    r2 = RG2[actions[0], actions[1]]
  elif 'O' in state:
    r1 = r2 = 0.0
  elif 'E1' in state:
    _, _, offeral = state
    if offer_accepted(actions[0]):
      r1 = offeral
      r2 = -offeral
    else:
      r1 = r2 = 0
  elif 'E2' in state:
    _, _, offeral = state
    if offer_accepted(actions[1]):
      r1 = -offeral
      r2 = offeral
    else:
      r1 = r2 = 0
  elif 'R1' in state:
    _, i = state
    r1 = RG1[actions[0], i]
    r2 = RG2[actions[0], i]
  elif 'R2' in state:
    _, i = state
    r1 = RG1[i, actions[1]]
    r2 = RG2[i, actions[1]]
  else:
    assert 0 == 1, 'Invalid state'

  return r1, r2


def reward_matrices(s):
  if 'G' in s:
    RM1 = RG1.copy()
    RM2 = RG2.copy()
  else:
    N_A1 = N_A_S[str(s)]['1']
    N_A2 = N_A_S[str(s)]['2']
    RM1 = np.zeros((N_A1,N_A2))
    RM2 = np.zeros((N_A1,N_A2))
    for a1 in range(0,N_A1):
      for a2 in range(0,N_A2):
        r1, r2 = rewards(s, [a1,a2])
        RM1[a1,a2] = r1
        RM2[a1,a2] = r2
  return RM1, RM2  

In [11]:
RM = {}
RM['1'] = {}
RM['2'] = {}
for s in S:
  RM1, RM2 = reward_matrices(s)
  RM['1'][str(s)] = torch.FloatTensor(RM1).to('cuda')
  RM['2'][str(s)] = torch.FloatTensor(RM2).to('cuda')

In [12]:
def player_consistent_reward_matrices():
  consistent_RM = {'1':{}, '2':{}}
  for s in S_str:
    consistent_RM['1'][s] = RM['1'][s].clone()
    consistent_RM['2'][s] = torch.t(RM['2'][s].clone())  
  return consistent_RM 

In [13]:
def next_value_matrices(s, v): # TODO: consider other 2 cases
  det, dep = transition_types[str(s)]
  N_A1 = N_A_S[str(s)]['1']
  N_A2 = N_A_S[str(s)]['2']
  vs = torch.zeros((N_A1,N_A2,2)).to('cuda')
  if det and dep:
    for a1 in range(0,N_A1):
      for a2 in range(0,N_A2):
        _, next_state = transition_info(s, [a1,a2])
        vs[a1,a2,:] = v[get_state_index(next_state),:]
  elif (not det) and (not dep):
    _, transition_dic = transition_info(s, [])
    next_v = torch.zeros(1,2).to('cuda')
    for next_state, transition_prob in transition_dic:
      next_v = next_v + v[get_state_index(next_state),:].view(1,-1) * transition_prob      
    for a1 in range(0,N_A1):
      for a2 in range(0,N_A2):
        vs[a1,a2,:] = next_v.view(-1)
  return vs

def next_value_dictionary(v):
  next_v_dic = {}
  for s in S:
    next_v_dic[str(s)] = next_value_matrices(s, v)
  return next_v_dic


def transition_matrix(pi): # TODO: consider other 2 cases
  transition_matrix = torch.zeros((N_S,N_S)).to('cuda')
  for state in S:
    strategy_1 = pi['1'][str(state)]
    strategy_2 = pi['2'][str(state)]

    det, dep = transition_types[str(state)]
    N_A1 = strategy_1.shape[0]
    N_A2 = strategy_2.shape[0]
    id_s = get_state_index(state)

    if det and dep:
      for a1 in range(0,N_A1):
        for a2 in range(0,N_A2):
          _, next_state = transition_info(state, [a1,a2])
          id_ns = get_state_index(next_state)
          transition_prob = strategy_1[a1,0] * strategy_2[a2,0]
          transition_matrix[id_s, id_ns] = transition_matrix[id_s, id_ns] + transition_prob
    elif (not det) and (not dep):
      _, transition_dic = transition_info(state, [])
      for next_state, transition_prob in transition_dic:
        transition_matrix[id_s, get_state_index(next_state)] = transition_prob
  return transition_matrix


def partial_transition_matrices(pi): # TODO: consider other 2 cases
  # Create dictionary of transition matrices for each state given
  # the strategy of the other player 
  transition_matrices = {}
  for s in S:
    strategy_1 = pi['1'][str(s)]
    strategy_2 = pi['2'][str(s)]
    N_A1 = strategy_1.shape[0]
    N_A2 = strategy_2.shape[0]
    transition_matrices[str(s)] = {
        '1': torch.zeros((N_A1,N_S)).to('cuda'),
        '2': torch.zeros((N_A2,N_S)).to('cuda')
    }

    # Fill matrices with transition probabilities depending on the type
    # of transition, i.e., if deterministic or random and independent or
    # not on the actions
    det, dep = transition_types[str(s)]
    if det and dep:
      for a1 in range(0,N_A1):
        for a2 in range(0,N_A2):
          _, next_state = transition_info(s, [a1,a2])
          id_ns = get_state_index(next_state)
          transition_prob1 = strategy_2[a2,0]
          transition_prob2 = strategy_1[a1,0]
          transition_matrices[str(s)]['1'][a1, id_ns] = (
              transition_matrices[str(s)]['1'][a1, id_ns] + transition_prob1)
          transition_matrices[str(s)]['2'][a2, id_ns] = (
              transition_matrices[str(s)]['2'][a2, id_ns] + transition_prob2)
    elif (not det) and (not dep):
      _, transition_dic = transition_info(s, [])
      for next_state, transition_prob in transition_dic:
        transition_matrices[str(s)]['1'][:, get_state_index(next_state)] = transition_prob
        transition_matrices[str(s)]['2'][:, get_state_index(next_state)] = transition_prob
  return transition_matrices


def expected_reward(RM, pi):
  r_mean = torch.zeros((N_S,2)).to('cuda')
  for i in range(1,2+1):
    RM_i = RM[str(i)]
    for s in S:
      strategy_1 = pi['1'][str(s)]
      strategy_2 = pi['2'][str(s)]
      r_mean_1 = torch.einsum('ij,ik->jk', RM_i[str(s)], strategy_1)
      r_mean[get_state_index(s),i-1] = (r_mean_1 * strategy_2).sum()
  return r_mean


def partial_expected_reward_other(RM, pi):
  r_mean = {}
  for player in players_str():
    r_mean[player] = {}

  for s, player in state_player_str_pairs():
      N_A = N_A_S[s][player]
      RM_i = RM[player]
      other_player_ = other_player(player)
      strategy = pi[other_player_][s].view(-1)
      formula = 'ij,'+player_dim(other_player_)+'->'+player_dim(player)
      r_mean[player][s] = torch.einsum(formula, RM_i[s], strategy).view(-1,1)
  return r_mean


def partial_expected_reward(RM, pi):
  # Create reward dictionary for each combination of players
  r_mean = {}
  for player in players_str():  
    r_mean[player] = {'1':{}, '2':{}}

  # Calculate expected reward for combination of players wrt the policy of one of the players
  for s, player in state_player_str_pairs():
    other_player_ = other_player(player) # Player used to calculate expected reward
    for second_player in players_str():  
      N_A = N_A_S[s][second_player]
      RM_i = RM[second_player] # Reward matrix for one of the players  
      strategy = pi[other_player_][s].view(-1)
      formula = 'ij,'+player_dim(other_player_)+'->'+player_dim(player)
      r_mean[player][second_player][s] = torch.einsum(formula, RM_i[s], strategy).view(-1,1)
  return r_mean


def bellman_projection(RM, pi, v):
  r_mean = expected_reward(RM, pi)
  next_v = torch.zeros((N_S,2)).to('cuda')
  for s in S:
    next_state_value_matrix = next_value_matrices(s, v)
    strategy_1 = pi['1'][str(s)].squeeze(1)
    strategy_2 = pi['2'][str(s)].squeeze(1)
    next_value_1 = torch.einsum('ijk,i->jk', next_state_value_matrix, strategy_1)
    next_v[get_state_index(s),:] = torch.einsum('jk,j->k', next_value_1, strategy_2)
  return r_mean + beta * next_v


def bellman_partial_projection_other(RM, pi, v):
  r_mean = partial_expected_reward_other(RM, pi)
  next_v = {'1':{}, '2':{}}
  for s in S:
    next_state_value_matrix = next_value_matrices(s, v)
    strategy_1 = pi['1'][str(s)].squeeze(1)
    strategy_2 = pi['2'][str(s)].squeeze(1)
    # mean next value when considering the strategy of the other player. Output: array of size m^i(s)
    next_v['1'][str(s)] = torch.einsum('ij,j->i', next_state_value_matrix[:,:,0], strategy_2).view(-1,1)
    next_v['2'][str(s)] = torch.einsum('ij,i->j', next_state_value_matrix[:,:,1], strategy_1).view(-1,1)
  bellman_projection_dic = {'1':{}, '2':{}}
  for s, player_id in state_player_pairs():
    bellman_projection_dic[str(player_id)][str(s)] = r_mean[str(player_id)][str(s)] + beta * next_v[str(player_id)][str(s)]
  return bellman_projection_dic


def partial_next_values(pi, v):
  # Create next-value dictionary for each player combination
  next_v = {'1':{}, '2':{}}
  for player in players_str():
    next_v[player] = {'1':{}, '2':{}}

  # Fill dictionary 
  for s in S:
    next_state_value_matrix = next_value_matrices(s, v)
    strategy_1 = pi['1'][str(s)].squeeze(1)
    strategy_2 = pi['2'][str(s)].squeeze(1)
    
    # Calculate mean next value when considering the strategy of one of the players. Output: array of size m^i(s)
    for player_id in players():
      next_v['1'][str(player_id)][str(s)] = torch.einsum('ij,j->i', next_state_value_matrix[:,:,player_id-1], strategy_2).view(-1,1)
      next_v['2'][str(player_id)][str(s)] = torch.einsum('ij,i->j', next_state_value_matrix[:,:,player_id-1], strategy_1).view(-1,1)  
  return next_v


def bellman_partial_projection(RM, pi, v):
  r_mean = partial_expected_reward(RM, pi)
  next_v = partial_next_values(pi, v)

  bellman_projection_dic = {'1':{}, '2':{}}
  for player in players_str():
    bellman_projection_dic[player] = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    for second_player in players_str():
      bellman_projection_dic[player][second_player][s] = (
          r_mean[player][second_player][s] + beta * next_v[player][second_player][s]
          )
  return bellman_projection_dic


def reward_baselines(RM, pi):
  r_mean = expected_reward(RM, pi)
  r_baseline = r_mean.mean(0).view(-1,1).detach().cpu().numpy()
  return r_baseline

In [14]:
def cost_vector_fixed_policies(pi):
  P = transition_matrix(pi)
  cost_vector = (1 - beta * P.sum(0)).view(-1,1).detach().cpu().numpy() / N_S
  return cost_vector

def restriction_matrices_fixed_policies(pi):
  transition_matrices = partial_transition_matrices(pi)
  restriction_matrices = {'1':[], '2':[]}
  for s, player_id in state_player_pairs():
    temp_matrix = - beta * transition_matrices[str(s)][str(player_id)]
    temp_matrix[:, get_state_index(s)] = temp_matrix[:, get_state_index(s)] + 1
    restriction_matrices[str(player_id)].append(temp_matrix)
  for player in players_str():
    restriction_matrices[player] = -torch.cat(restriction_matrices[player], dim=0).detach().cpu().numpy()
  return restriction_matrices

def restriction_vectors_fixed_policies(RM, pi, alpha=0.1):
  r_mean = partial_expected_reward_other(RM, pi)
  restriction_vectors = {'1':[], '2':[]}
  for s, player in state_player_str_pairs():
    restriction_vectors[player].append(r_mean[player][s].view(-1,1))
  for player in players_str():
    restriction_vectors[player] = -(torch.cat(restriction_vectors[player], dim=0)+alpha).detach().cpu().numpy()
  return restriction_vectors


def parameters_fixed_policies(game_, alpha):
  pi = game_()[0]
  c = cost_vector_fixed_policies(pi)
  f0 = reward_baselines(RM, pi)
  A_ub = restriction_matrices_fixed_policies(pi)
  b_ub = restriction_vectors_fixed_policies(RM, pi, alpha)
  return c, f0, A_ub, b_ub


def calculate_initial_v(game_, alpha=10):
  v0 = np.zeros((N_S,2))
  c, f0, A_ub, b_ub = parameters_fixed_policies(game_, alpha=alpha)
  for player_id in players():
    temp_res = linprog(c, A_ub=A_ub[str(player_id)], b_ub=b_ub[str(player_id)])
    v0[:,player_id-1] = temp_res.x
  return v0

In [15]:
def calculate_nash_restrictions(pi, v):
  q_estimated = bellman_partial_projection_other(RM, pi, v) # Dic. with an array of 'q'-values for each agent 
  g_nash = {'1':{}, '2':{}}
  for s, player_id in state_player_pairs():
    g_nash[str(player_id)][str(s)] = q_estimated[str(player_id)][str(s)] - v[get_state_index(s), player_id-1]
  return g_nash


def calculate_bellman_error(pi, v):
  # Calculation of original target function: Bellman approximation error 
  v_estimated = bellman_projection(RM, pi, v)
  f_bellman = (v - v_estimated).sum()
  return f_bellman


def lagrangian_nash(pi, v, lambda_nash, c=1):
  # Calculation of original target function: Bellman approximation error 
  v_estimated = bellman_projection(RM, pi, v)
  f_bellman = (v - v_estimated).sum()

  # Calculation of restrictions
  g_nash = calculate_nash_restrictions(pi, v)

  # Calculation of shifted restrictions  
  nash_restriction_products = {'1':{}, '2':{}}  
  for s, player_id in state_player_str_pairs():
    nash_restriction_products[player_id][s] = (g_nash[player_id][s].view(-1,1) * lambda_nash[player_id][s].view(-1,1)) 
                                                #+ c * g_nash[player_id][s].view(-1,1).pow(2) * lambda_nash[player_id][s].view(-1,1).detach())    
  # Calculation of 2-norms  
  nash_restriction_sum = 0.0
  for s, player_id in state_player_str_pairs():
    nash_restriction_sum = nash_restriction_sum + nash_restriction_products[player_id][s].sum()
    
  # Calculation of augmented lagrangian
  lagrangian = f_bellman + nash_restriction_sum  
  return lagrangian, f_bellman


def check_nash_KKT_conditions(pi, v, lambda_nash, tol=1e-8):
  with torch.no_grad():
    # Calculate restrictions
    g_nash = calculate_nash_restrictions(pi, v)

    g_nash_satisfied = True
    product_zero_satisfied = True
    
    max_g_nash = -np.infty
    max_product_zero = -np.infty
    
    for player in players_str():
      remaining_duals = lambda_nash[player].clone()
      for s in S_str:
        NA = N_A_S[s][player]
        g_nash_satisfied = g_nash_satisfied and torch.all(g_nash[player][s] <= 0)
        max_g_nash = max(max_g_nash, g_nash[player][s].max())

        lambda_g_nash_product = g_nash[player][s].view(-1) * remaining_duals[:NA,:].view(-1)
        product_zero_satisfied = product_zero_satisfied and torch.all(lambda_g_nash_product.abs() <= tol)
        max_product_zero = max(max_product_zero, lambda_g_nash_product.abs().max())
        remaining_duals = remaining_duals[NA:,:]
    return g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero


def check_pi_KKT_conditions(pi_nolast, pi_last, duals, tol=1e-8):
  with torch.no_grad():
    g_pi_plus_satisfied = True
    g_pi_one_satisfied = True

    product_zero_plus_satisfied = True
    product_zero_one_satisfied = True

    max_g_pi_plus = -np.infty
    max_g_pi_one = -np.infty

    max_product_zero_plus = -np.infty
    max_product_zero_one = -np.infty

    for player in players_str():
      g_pi_plus_satisfied = g_pi_plus_satisfied and torch.all(pi_nolast[player] >= 0)
      g_pi_one_satisfied = g_pi_one_satisfied and torch.all(pi_last[player] >= 0)

      max_g_pi_plus = max(max_g_pi_plus, -pi_nolast[player].min())
      max_g_pi_one = max(max_g_pi_one, -pi_last[player].min())

      N_A_total = int(N_A_S_tensors[player].sum().item())
      lambda_g_pi_plus_product = -pi_nolast[player].view(-1) * duals[player][N_A_total:2*N_A_total-N_S,:].view(-1) # TODO: fix for assymetric number of actions
      product_zero_plus_satisfied = product_zero_plus_satisfied and torch.all(lambda_g_pi_plus_product.abs() <= tol)
      max_product_zero_plus = max(max_product_zero_plus, lambda_g_pi_plus_product.abs().max())
      lambda_g_pi_one_product = -pi_last[player].view(-1) * duals[player][2*N_A_total-N_S:,:].view(-1)
      product_zero_one_satisfied = product_zero_one_satisfied and torch.all(lambda_g_pi_one_product.abs() <= tol)
      max_product_zero_one = max(max_product_zero_one, lambda_g_pi_one_product.abs().max())
   
    return (g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus, 
            g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one)


def check_nash_conditions(pi, v):
  with torch.no_grad():
    # Calculate restrictions
    g_nash = calculate_nash_restrictions(pi, v)
    g_nash_satisfied = True
    max_g_nash = -np.infty
    
    for s, player_id in state_player_str_pairs():
      g_nash_satisfied = g_nash_satisfied and torch.all(g_nash[player_id][s] <= 0)
      max_g_nash = max(max_g_nash, g_nash[player_id][s].max())
    return g_nash_satisfied, max_g_nash
  

# def optimize_game(game_, n_epochs, optimizers, params_dual, 
#                   print_each=50, save_each=10):
#   losses = []
#   costs = []
#   delta_cost = 0.0
#   previous_cost = np.infty
#   opti_primal, opti_dual = optimizers
  
#   for epoch in range(0, n_epochs):
#     pi, v, lambda_nash = game_()
#     loss, cost = lagrangian_nash(pi, v, lambda_nash)
#     delta_cost = cost.item() - previous_cost

#     previous_cost = cost.item()
#     torch.save(game_.state_dict(), './game_temp.pth')

#     opti_primal.zero_grad()
#     opti_dual.zero_grad()    
#     loss.backward()
    
#     for p in params_dual:
#       if p.grad is not None:
#         p.grad.data.mul_(-1)

#     opti_primal.step()
#     opti_dual.step()

#     if (epoch + 1) % print_each == 0:
#       with torch.no_grad():
#         nash_satisfied, max_nash_restrictions, KKT_zero_product_satisfied, max_KKT_product = check_KKT_conditions(pi, v, lambda_nash)
#       print("Epoch: {}, Loss: {:.3f}, f: {:.3f}, nash: {}, max g: {:.3f}, KKT product: {}, max product: {:.3e}".format(
#           epoch+1,loss.item(),cost.item(), nash_satisfied, max_nash_restrictions,
#           KKT_zero_product_satisfied, max_KKT_product))

#     if (epoch + 1) % save_each == 0:
#       losses.append(loss.item())
#       costs.append(cost.item())
#   return losses, costs

In [62]:
def bellman_error_gradients(pi, v):
  P = transition_matrix(pi)
  bellman_error_grad_v = (torch.eye(N_S).to('cuda') - torch.t(P)).sum(1, keepdim=True)

  q_individual = bellman_partial_projection(RM, pi, v)
  bellman_error_grad_pi = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    bellman_error_grad_pi[player][s] = 0.0
    for second_player in players_str():
      bellman_error_grad_pi[player][s] = bellman_error_grad_pi[player][s] - q_individual[player][second_player][s]  
  return bellman_error_grad_v, bellman_error_grad_pi

def nash_restriction_gradients(pi, v):
  P_partial = partial_transition_matrices(pi)
  next_value_dic = next_value_dictionary(v)
  consistent_RM = player_consistent_reward_matrices()

  nash_restriction_grad_v = {'1':{}, '2':{}}
  nash_restriction_grad_pi = {'1':{}, '2':{}}

  for s, player in state_player_str_pairs():
    Delta = torch.zeros_like(P_partial[s][player])
    Delta[:, get_state_index(s)] = 1.0
    nash_restriction_grad_v[player][s] = beta * P_partial[s][player]  - Delta
  
    other_player_ = other_player(player)
    next_value_matrix = next_value_dic[s][:,:,get_player_id(other_player_)]
    r_matrix = consistent_RM[other_player_][s]
    if player == '1':
      next_value_matrix = torch.t(next_value_matrix)
    nash_restriction_grad_pi[player][s] = r_matrix + beta * next_value_matrix
  return nash_restriction_grad_v, nash_restriction_grad_pi


def calculate_jacobian_pi_log_pi(pi):
  jacobian_pi_log_pi_dic = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    p = pi[player][s].view(-1)
    jacobian_pi_log_pi_dic[player][s] = torch.diag(p) - torch.einsum('i,j->ij',p,p)
  return jacobian_pi_log_pi_dic


def calculate_jacobian_pi_pi_nolast():
  jacobian_pi_pi_nolast_dic = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    NA = N_A_S[s][player]
    if NA > 1:
      jacobian_pi_pi_nolast_dic[player][s] = torch.eye(NA, m=NA-1).to('cuda')
      jacobian_pi_pi_nolast_dic[player][s][-1,:] = -1.0
    else:
      jacobian_pi_pi_nolast_dic[player][s] = None
  return jacobian_pi_pi_nolast_dic


def calculate_nash_pi_transformed_gradients(pi, grad_f_pi, grad_g_pi, transform_type='nolast'):
  grad_f_pi_transformed = {'1':{}, '2':{}}
  grad_g_pi_transformed = {'1':{}, '2':{}}

  if transform_type == 'nolast':
    jacobian_pi_pi_transformed = calculate_jacobian_pi_pi_nolast()
  elif transform_type == 'log':
    jacobian_pi_pi_transformed = calculate_jacobian_pi_log_pi(pi)
  else:
    assert 0 == 1, 'Invalid transformation'

  for s, player in state_player_str_pairs():
    gradient_f = grad_f_pi[player][s]
    jacobian_g = grad_g_pi[player][s]
    jacobian_pi = jacobian_pi_pi_transformed[player][s]

    if jacobian_pi is not None:
      grad_f_pi_transformed[player][s] = torch.einsum('ij,iw->jw', jacobian_pi, gradient_f)
      grad_g_pi_transformed[player][s] = torch.einsum('ij,jk->ik', jacobian_g, jacobian_pi)
    else:
      grad_f_pi_transformed[player][s] = None
      grad_g_pi_transformed[player][s] = None
  return grad_f_pi_transformed, grad_g_pi_transformed


def build_grad_tensors(grad_f_v, grad_f_pit, grad_g_v, grad_g_pit):
  N_A_total = {}
  N_S_reduced = {}
  for player in players_str():
    N_A_total[player] = int(N_A_S_tensors[player].sum().item()) 
    N_S_reduced[player] = N_A_S_tensors_reduced[player].view(-1).shape[0]

  n_restrictions = 2*(N_A_total['1'] + N_A_total['2']-N_S) + N_S_reduced['1'] + N_S_reduced['2']
  n_vars = N_A_total['1'] + N_A_total['2']
  grad_f_vector = torch.zeros(n_vars, 1).to('cuda')
  grad_g_matrix = torch.zeros(n_restrictions, n_vars).to('cuda')

  for i in range(0,2):
    grad_f_vector[i*N_S:(i+1)*N_S,:] = grad_f_v.clone()
  
  for player in players_str():
    grad_f_pit_list = []
    for s in S_str:
      if grad_f_pit[player][s] is not None:
        grad_f_pit_list.append(grad_f_pit[player][s].clone())  
    grad_f_pit_tensor = torch.cat(grad_f_pit_list, dim=0)
    y0 = 2*N_S + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = N_S + N_A_total['1'] + get_player_id(player) * N_A_total['2']
    grad_f_vector[y0:yf,:] = grad_f_pit_tensor

  for player in players_str():
    other_player_ = other_player(player)

    grad_g_v_list = [grad_g_v[player][s].clone() for s in S_str]
    J_gi_vi = torch.cat(grad_g_v_list, dim=0)
    y0 = get_player_id(player) * N_A_total['1']
    yf = N_A_total['1'] + get_player_id(player) * N_A_total['2']
    x0 = get_player_id(player) * N_S
    xf = x0 + N_S
    grad_g_matrix[y0:yf,x0:xf] = J_gi_vi

    grad_g_pit_list = []
    s_not_added_list = []
    for s in S_str:
      if grad_g_pit[player][s] is not None:
        if len(s_not_added_list) == 0:
          grad_g_pit_list.append(grad_g_pit[player][s])
        else:
          n_null_rows = 0
          for s_ in s_not_added_list:
            n_null_rows += N_A_S[s_][other_player_]
          n_rows = n_null_rows+grad_g_pit[player][s].shape[0]
          n_cols = grad_g_pit[player][s].shape[1]
          null_rows_and_grad_g_pit = torch.zeros(n_rows, n_cols).to('cuda')
          null_rows_and_grad_g_pit[-n_rows+n_null_rows:,:] = grad_g_pit[player][s]
          grad_g_pit_list.append(null_rows_and_grad_g_pit)
          s_not_added_list = []
      else:
        s_not_added_list.append(s)
    if len(s_not_added_list) != 0:
      n_null_rows = 0
      for s_ in s_not_added_list:
        n_null_rows += N_A_S[s_][other_player_]
      n_rows = n_null_rows+grad_g_pit_list[-1].shape[0]
      n_cols = grad_g_pit_list[-1].shape[1]
      null_rows_and_grad_g_pit = torch.zeros(n_rows, n_cols).to('cuda')
      null_rows_and_grad_g_pit[:n_rows-n_null_rows,:] = grad_g_pit_list[-1].clone()
      grad_g_pit_list[-1] = null_rows_and_grad_g_pit
    
    J_gmi_pit = torch.block_diag(*grad_g_pit_list)
    y0 = get_player_id(other_player_) * N_A_total['1']
    yf = N_A_total['1'] + get_player_id(other_player_) * N_A_total['2']
    x0 = 2*N_S + get_player_id(player) * (N_A_total['1'] - N_S)
    xf = N_S + N_A_total['1'] + get_player_id(player) * (N_A_total['2'] - N_S)
    grad_g_matrix[y0:yf,x0:xf] = J_gmi_pit

    J_gpi_plus = -torch.eye(N_A_total[player] - N_S).to('cuda')
    y0 = N_A_total['1'] + N_A_total['2'] + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = 2*N_A_total['1'] + N_A_total['2'] - N_S + get_player_id(player) * (N_A_total['2'] - N_S)
    grad_g_matrix[y0:yf,x0:xf] = J_gpi_plus

    grad_gpi_one_list = []
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        grad_gpi_one_list.append(torch.ones(1,NA-1).to('cuda'))

    J_gpi_one = torch.block_diag(*grad_gpi_one_list)
    y0 = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + get_player_id(player) * N_S_reduced['1']
    yf = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + N_S_reduced['1'] + get_player_id(player) * N_S_reduced['2']
    grad_g_matrix[y0:yf,x0:xf] = J_gpi_one

  return grad_f_vector, grad_g_matrix
  

def calculate_descent_direction(pi_nolast_vector, pi_last_vector, g_nash, 
                                grad_f_v, grad_f_pit, grad_g_v, grad_g_pit): # TODO: make compatible with log transform
  duals_0 = {}
  grad_gi_vi_matrix = {}
  A_system = {}
  b_system = {}
  A_system_dic = {'1': {}, '2': {}}
  b_system_dic = {'1': {}, '2': {}}
  for player in players_str():
    for i in range(1,3+1):
      A_system_dic[player][str(i)] = {}    

  d0_v = torch.zeros((N_S,2)).to('cuda')  
  d0_pit = {'1':{}, '2':{}}
  norm_2_d0 = 0.0

  for player in players_str():
    other_player_ = other_player(player)
    N_A_total = int(N_A_S_tensors[player].sum().item())
    
    grad_g_v_list = [grad_g_v[player][s] for s in S_str]
    J_gi_vi = torch.cat(grad_g_v_list, dim=0)
    g_nash_list = [g_nash[player][s] for s in S_str]
    restriction_diagonal = torch.diag(torch.cat(g_nash_list).view(-1))
    A_system_dic[player]['1']['1'] = (r_value * restriction_diagonal - 
                                  torch.einsum('as,bs->ab', J_gi_vi, J_gi_vi))
    A_system_dic[player]['2']['2'] = (r_value * torch.diag(-pi_nolast_vector[other_player_].view(-1)) - 
                                  torch.eye(int(N_A_total-N_S)).to('cuda'))
    A_system_dic[player]['3']['3'] = (r_value * torch.diag(-pi_last_vector[other_player_].view(-1)) - 
                                  torch.diag(N_A_S_tensors_reduced[other_player_].view(-1)-1.0))
    
    b_system_dic[player]['1'] = torch.einsum('as,sw->aw', J_gi_vi, grad_f_v)

    A_system_pit = {}
    for i in range(1,3+1):
      A_system_pit[str(i)] = {}
      for j in range(1,3+1):
        A_system_pit[str(i)][str(j)] = []  
    
    b_system_pit = {}
    for i in range(1,3+1):
        b_system_pit[str(i)] = []

    n_null_rows = 0
    for s in S_str:
      NA = N_A_S[s][other_player_]
      if NA > 1:
        A_system_pit['1']['1'].append(torch.einsum('as,bs->ab', grad_g_pit[other_player_][s], grad_g_pit[other_player_][s])) # TODO: is other_player_ ok?
        A_system_pit['2']['3'].append(-torch.ones(NA-1,1).to('cuda'))
        
        grad_g_v_times_grad_g_pi_plus = -grad_g_pit[other_player_][s]
        grad_g_v_times_grad_g_pi_one = grad_g_pit[other_player_][s].sum(1, keepdim=True)
        if n_null_rows != 0:
          zero_matrix = torch.zeros(n_null_rows, NA-1).to('cuda')
          zero_vector = torch.zeros(n_null_rows, 1).to('cuda')
          grad_g_v_times_grad_g_pi_plus = torch.cat([zero_matrix, grad_g_v_times_grad_g_pi_plus.clone()], dim=0)
          grad_g_v_times_grad_g_pi_one = torch.cat([zero_vector, grad_g_v_times_grad_g_pi_one.clone()], dim=0)          
          n_null_rows = 0

        A_system_pit['1']['2'].append(grad_g_v_times_grad_g_pi_plus)
        A_system_pit['1']['3'].append(grad_g_v_times_grad_g_pi_one)

        b_system_pit['1'].append(torch.einsum('as,sw->aw', grad_g_pit[other_player_][s], grad_f_pit[other_player_][s]))
        b_system_pit['2'].append(-grad_f_pit[other_player_][s])
        b_system_pit['3'].append(grad_f_pit[other_player_][s].sum())

      else: # If there are no gradients of g_v_sa for certain pair (s,a) - that happens when there is only one action for the other player 
        A_system_pit['1']['1'].append(torch.zeros(N_A_S[s][player], N_A_S[s][player]).to('cuda'))
        n_null_rows = n_null_rows + N_A_S[s][player] # + since there might be succeding states with no gradients
        if get_state_index(s) == (N_S-1):
          N_A_previous_s = A_system_pit['1']['2'][-1].shape[1]
          zero_matrix = torch.zeros(n_null_rows, N_A_previous_s).to('cuda')
          zero_vector = torch.zeros(n_null_rows, 1).to('cuda')
          A_system_pit['1']['2'][-1] = torch.cat([A_system_pit['1']['2'][-1].clone(), zero_matrix], dim=0)
          A_system_pit['1']['3'][-1] = torch.cat([A_system_pit['1']['3'][-1].clone(), zero_vector], dim=0)
        
        b_system_pit['1'].append(torch.zeros(N_A_S[s][player],1).to('cuda'))

    A_system_dic[player]['1']['1'] = A_system_dic[player]['1']['1'] - torch.block_diag(*A_system_pit['1']['1'])
    for i,j in [('2','3'), ('1','2'), ('1','3')]:
      A_system_dic[player][i][j] = -torch.block_diag(*A_system_pit[i][j])
    for i,j in [('3','2'), ('2','1'), ('3','1')]:
      A_system_dic[player][i][j] = -torch.t(A_system_dic[player][j][i])
    
    b_system_dic[player]['1'] = b_system_dic[player]['1'] + torch.cat(b_system_pit['1'], dim=0)   
    b_system_dic[player]['2'] = torch.cat(b_system_pit['2']) 
    b_system_dic[player]['3'] = torch.FloatTensor(b_system_pit['3']).view(-1,1).to('cuda')   

    dim = 0
    dims = [0]
    for i in range(1,3+1):
      dim = dim + A_system_dic[player][str(i)][str(i)].shape[0]
      dims.append(dim)
    A_system[player] = torch.zeros(dim,dim).to('cuda')
    b_system[player] = torch.zeros(dim,1).to('cuda')

    for i in range(1,3+1):
      b_system[player][dims[i-1]:dims[i],:] = b_system_dic[player][str(i)].clone()
      for j in range(1,3+1):
        A_system[player][dims[i-1]:dims[i],dims[j-1]:dims[j]] = A_system_dic[player][str(i)][str(j)].clone()
    
    # Solve system to find duals
    duals_0[player], _ = torch.solve(b_system[player], A_system[player])    
    
    # Calculate descent direction from duals
    d0_v[:,get_player_id(player)] = -torch.einsum('as,a->s', J_gi_vi, duals_0[player].view(-1)[:N_A_total]) - grad_f_v.view(-1)
    norm_2_d0 += d0_v[:,get_player_id(player)].pow(2).sum()

    remaining_duals = duals_0[player].clone()
    for s in S_str:
      NA = N_A_S[s][player]
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d0_pit[other_player_][s] = -torch.einsum('ba,bw->aw', grad_g_pit[other_player_][s], remaining_duals[:NA,:]) - grad_f_pit[other_player_][s]
      else:
        d0_pit[other_player_][s] = None
      remaining_duals = remaining_duals[NA:,:]
    for s in S_str:
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d0_pit[other_player_][s] = d0_pit[other_player_][s] + remaining_duals[:NA_other-1,:]
        assert remaining_duals[:NA_other-1,:].shape == (NA_other-1,1), 'Incorrect dual size'
        assert d0_pit[other_player_][s].shape == (NA_other-1,1), 'Incorrect gradient size'
      remaining_duals = remaining_duals[NA_other-1:,:]
    for s in S_str:
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d0_pit[other_player_][s] = d0_pit[other_player_][s] - remaining_duals[0,0] * torch.ones(NA_other-1,1).to('cuda')
        assert d0_pit[other_player_][s].shape == (NA_other-1,1), 'Incorrect gradient size'
        remaining_duals = remaining_duals[1:,:]
        norm_2_d0 += d0_pit[other_player_][s].pow(2).sum()      
  return d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system


def calculate_descent_direction_w_tensors(pi_nolast_vector, pi_last_vector, 
                                          g_nash, grad_f_vector, grad_g_matrix):
  g_list = []
  for player in players_str():
    for s in S_str:
      g_list.append(g_nash[player][s])
  for player in players_str():
    g_list.append(-pi_nolast_vector[player])
  for player in players_str():
    g_list.append(-pi_last_vector[player])
  g_tensor = torch.cat(g_list, dim=0)
  g_diag_matrix = torch.diag(g_tensor.view(-1))

  A_matrix = g_diag_matrix - torch.einsum('ik,jk->ij', grad_g_matrix, grad_g_matrix) 
  b_vector = torch.einsum('ij,jk->ik', grad_g_matrix, grad_f_vector) 
  duals_0_vector = torch.solve(b_vector, A_matrix)[0]

  d0_vector = - grad_f_vector - torch.einsum('ij,ik->jk', grad_g_matrix, duals_0_vector)
  norm_2_d0 = d0_vector.pow(2).sum().item() 
  return d0_vector, norm_2_d0, duals_0_vector, A_matrix, b_vector


def calculate_feasible_direction_w_tensors(duals_0_vector, norm_2_d0, 
                                           A_matrix, b_vector, rho,
                                           grad_f_vector, grad_g_matrix):
  div = duals_0_vector.sum().item()
  if div > 0:
      rho_1 = (1-alpha) / div
      if rho_1 < rho:
        rho = 0.5 * rho_1

  duals_vector = torch.solve(b_vector - rho * norm_2_d0, A_matrix)[0]

  d_vector = - grad_f_vector - torch.einsum('ij,ik->jk', grad_g_matrix, duals_vector)
  return d_vector, duals_vector, rho


def calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, 
                                 A_system, b_system, grad_f_v, grad_g_v, 
                                 grad_f_pit, grad_g_pit, rho): # TODO: check if compatible with log transform
  div = duals_0['1'].sum() + duals_0['2'].sum()
  if div > 0:
      rho_1 = (1-alpha) / div
      if rho_1 < rho:
        rho = 0.5 * rho_1

  duals = {}  
  d_pit = {'1':{}, '2':{}}
  d_v = torch.zeros((N_S,2)).to('cuda')  
  
  for player in players_str():    
    
    # Solve system to find duals
    duals[player], _ = torch.solve(b_system[player]-rho*norm_2_d0, A_system[player])    
    
    # Calculate descent direction from duals
    N_A_total = int(N_A_S_tensors[player].sum().item())
    grad_g_v_list = [grad_g_v[player][s] for s in S_str]
    J_gi_vi = torch.cat(grad_g_v_list, dim=0)
    d_v[:,get_player_id(player)] = -torch.einsum('as,a->s', J_gi_vi, duals[player].view(-1)[:N_A_total]) - grad_f_v.view(-1)
    
    other_player_ = other_player(player)
    remaining_duals = duals[player].clone()
    for s in S_str:
      NA = N_A_S[s][player]
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d_pit[other_player_][s] = -torch.einsum('ba,bw->aw', grad_g_pit[other_player_][s], remaining_duals[:NA,:]) - grad_f_pit[other_player_][s]
      else:
        d_pit[other_player_][s] = None
      remaining_duals = remaining_duals[NA:,:]
    for s in S_str:
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d_pit[other_player_][s] = d_pit[other_player_][s] + remaining_duals[:NA_other-1,:]
        assert remaining_duals[:NA_other-1,:].shape == (NA_other-1,1), 'Incorrect dual size'
        assert d_pit[other_player_][s].shape == (NA_other-1,1), 'Incorrect gradient size'
      remaining_duals = remaining_duals[NA_other-1:,:]
    for s in S_str:
      NA_other = N_A_S[s][other_player_]
      if NA_other > 1:
        d_pit[other_player_][s] = d_pit[other_player_][s] - remaining_duals[0,0] * torch.ones(NA_other-1,1).to('cuda')
        assert d_pit[other_player_][s].shape == (NA_other-1,1), 'Incorrect gradient size'
        remaining_duals = remaining_duals[1:,:]
            
  return d_v, d_pit, duals, rho


# def calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, A, b, grad_g_v, grad_g_pit, rho): # TODO: check if compatible with log transform
#   div = duals_0['1'].sum() + duals_0['2'].sum()
#   if div > 0:
#       rho_1 = (1-alpha) / div
#       if rho_1 < rho:
#         rho = 0.5 * rho_1

#   duals = {}  
#   d_v = d0_v.clone() 
#   d_pit = {'1':{}, '2':{}}
#   for s, player in state_player_str_pairs():
#     if d0_pit[player][s] is not None:
#       d_pit[player][s] = d0_pit[player][s].clone()
#     else:
#       d_pit[player][s] = None
  
#   for player in players_str():    
    
#     dual_corrections = -rho * norm_2_d0 * torch.solve(torch.ones_like(b[player]), A[player])[0]
    
#     N_A_total = int(N_A_S_tensors[player].sum().item())
#     grad_g_v_list = [grad_g_v[player][s] for s in S_str]
#     J_gi_vi = torch.cat(grad_g_v_list, dim=0)
#     d_v[:,get_player_id(player)] -= torch.einsum('as,a->s', J_gi_vi, dual_corrections.view(-1)[:N_A_total])

#     other_player_ = other_player(player)
#     remaining_corrections = dual_corrections.clone()
#     for s in S_str:
#       NA = N_A_S[s][player]
#       NA_other = N_A_S[s][other_player_]
#       if NA_other > 1:
#         grad_displacement = - torch.einsum('ba,bw->aw', grad_g_pit[other_player_][s], remaining_corrections[:NA,:])
#         d_pit[other_player_][s] = d_pit[other_player_][s] + grad_displacement # TODO: check mistake in previous version (displacement or replacement?)
#       remaining_corrections = remaining_corrections[NA:,:]
#     for s in S_str:
#       NA_other = N_A_S[s][other_player_]
#       if NA_other > 1:
#         d_pit[other_player_][s] = d_pit[other_player_][s] + remaining_corrections[:NA_other-1,:]        
#         remaining_corrections = remaining_corrections[NA_other-1:,:]
#     for s in S_str:
#       NA_other = N_A_S[s][other_player_]
#       if NA_other > 1:
#         d_pit[other_player_][s] = d_pit[other_player_][s] - remaining_corrections[0,0] * torch.ones(NA_other-1,1).to('cuda')
#         remaining_corrections = remaining_corrections[1:,:]

#     duals[player] = duals_0[player] + dual_corrections.clone()
#   return d_v, d_pit, duals, rho


def copy_game(game_original):
  if game_original.transform_type == 'log':
    game_copy = log_game().to('cuda')
  elif game_original.transform_type == 'nolast':
    game_copy = nolast_game().to('cuda')
  else:
    assert 0 == 1, 'Incorrect game type'
  game_copy.load_state_dict(game_original.state_dict())
  return game_copy


def g_dics2g_vec(g_nash, pi_nolast_vector, pi_last_vector):
  g_list = []
  for player in players_str():
    for s in S_str:
      g_list.append(g_nash[player][s])
  for player in players_str():
    g_list.append(-pi_nolast_vector[player])
  for player in players_str():
    g_list.append(-pi_last_vector[player])
  g_vector = torch.cat(g_list, dim=0)
  return g_vector


def feasible_gradient_descent(game_0, pi, v, g, grad_f_v, grad_f_pit, d_v, d_pit, duals_vector,
                              transform_type='nolast', pi_nolast=None, pi_last=None, 
                              max_steps=10000, verbose=False):
  # Calculate required decrement in the loss function
  f = calculate_bellman_error(pi, v)
  g_vector = g_dics2g_vec(g, pi_nolast, pi_last)
  decrement = torch.einsum('ij,i->j', d_v, grad_f_v.view(-1)).sum()
  for s, player in state_player_str_pairs():
    NA = N_A_S[s][player]
    if NA > 1 or transform_type!='nolast':
      decrement += (d_pit[player][s].view(-1) * grad_f_pit[player][s].view(-1)).sum()

  # Procedure to find feasible step size 
  found_feasible_step_size = False
  step_size = 1.0
  n_step = 0
  max_steps = max(1, max_steps)
  while (not found_feasible_step_size) and (n_step < max_steps):
    game_temp = copy_game(game_0)
    v_before = game_temp.v.clone()
    # Update parameters performing step in feasible descent direction 
    game_temp.v.data.add_(step_size * d_v)
    v_after = game_temp.v.clone()

    for s in S_str:
      if transform_type == 'log':
        game_temp.log_pi1[s].data.add_(step_size * d_pit['1'][s])
        game_temp.log_pi2[s].data.add_(step_size * d_pit['2'][s])
      elif transform_type == 'nolast':
        if d_pit['1'][s] is not None:
          game_temp.pi_nolast1[s].data.add_(step_size * d_pit['1'][s])
        if d_pit['2'][s] is not None:
          game_temp.pi_nolast2[s].data.add_(step_size * d_pit['2'][s]) 

    pi_temp, v_temp = game_temp()    
    f_temp = calculate_bellman_error(pi_temp, v_temp)
    g_temp = calculate_nash_restrictions(pi_temp, v_temp)
    pi_nolast_temp, pi_last_temp = game_temp.pi_discriminated()

    g_vector_temp = g_dics2g_vec(g_temp, pi_nolast_temp, pi_last_temp)
    gamma = gamma_0 * torch.ones_like(duals_vector)
    gamma[duals_vector < 0] = 1.0

    f_decreased = f_temp <= f + step_size * eta * decrement
    g_v_valid = torch.all(g_vector_temp[:114] <= (g_vector * gamma)[:114])
    g_pi_valid = torch.all(g_vector_temp[114:] <= (g_vector * gamma)[114:])
    g_valid = g_v_valid and g_pi_valid
    
    # Check if the current step size results in a feasible descent direction
    if f_decreased and g_valid:
      found_feasible_step_size = True
      
    step_size /= nu
    n_step += 1
    if verbose:
      print(f_decreased.item(), g_v_valid.item(), g_pi_valid.item())
      dv = (v_after - v_before).pow(2).sum().pow(0.5).item()
      print("Step: {}, Step size: {:.3e}, Delta V: {:.3e}, Found feasible: {}".format(n_step, step_size, dv, found_feasible_step_size))
  return game_temp, found_feasible_step_size, n_step, f_temp


def vec2dic(d_vector, duals_vector):
  d_v = torch.zeros((N_S,2)).to('cuda')  
  d_pit = {'1':{}, '2':{}}
  duals = {}

  for i in range(0,2):
    d_v[:,i] = d_vector[i*N_S:(i+1)*N_S,:].view(-1)
  
  N_A_total = {}
  N_S_reduced = {}
  for player in players_str():
    N_A_total[player] = int(N_A_S_tensors[player].sum().item()) 
    N_S_reduced[player] = N_A_S_tensors_reduced[player].view(-1).shape[0]

  for player in players_str():
    y0 = 2*N_S + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = N_S + N_A_total['1']  + get_player_id(player) * (N_A_total['2'] - N_S) 
    d_pit_vector = d_vector[y0:yf,:]
    n = 0
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d_pit[player][s] = d_pit_vector[n:n+NA-1,:]
      else:
        d_pit[player][s] = None
      n = n + NA-1
    
  for player in players_str():
    i = get_player_id(player)
    other_player_ = other_player(player)
    n_restrictions = 2*N_A_total[player]-N_S+N_S_reduced[player]
    duals[player] = torch.zeros(n_restrictions,1).to('cuda')
    
    y0 = get_player_id(player) * N_A_total['1']
    yf = N_A_total['1']  + get_player_id(player) * N_A_total['2'] 
    duals[player][:N_A_total[player],:] = duals_vector[y0:yf,:]
    
    y0 = N_A_total['1'] + N_A_total['2'] + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = 2*N_A_total['1'] + N_A_total['2'] - N_S + get_player_id(player) * (N_A_total['2'] - N_S) 
    duals[player][N_A_total[player]:2*N_A_total[player]-N_S,:] = duals_vector[y0:yf,:]

    y0 = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + get_player_id(player) * N_S_reduced['1']
    yf = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + N_S_reduced['1'] + get_player_id(player) * N_S_reduced['2'] 
    duals[player][-N_S_reduced[player]:,:] = duals_vector[y0:yf,:]

  return d_v, d_pit, duals


def optimize_game(game_0, rho, n_epochs=100, verbose=False):
  game_new = copy_game(game_0)
  pi_0, v_0 = game_0()
  f_0 = calculate_bellman_error(pi_0, v_0)
  ttype = game_new.transform_type
  with torch.no_grad():
    for epoch in range(0, n_epochs):
      pi, v = game_new()
      pi_nolast, pi_last = game_new.pi_discriminated()
      g_nash = calculate_nash_restrictions(pi, v)
      grad_f_v, grad_f_pi = bellman_error_gradients(pi, v)
      grad_g_v, grad_g_pi = nash_restriction_gradients(pi, v)
      grad_f_pit, grad_g_pit = calculate_nash_pi_transformed_gradients(pi, grad_f_pi, grad_g_pi, transform_type=ttype)
      
      # d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system = calculate_descent_direction(pi_nolast, pi_last, g_nash, 
      #                                                                                    grad_f_v, grad_f_pit, grad_g_v, grad_g_pit)
      # d_v, d_pit, duals, rho = calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system, grad_f_v, grad_g_v, grad_f_pit, grad_g_pit, rho)
      
      grad_f_vector, grad_g_matrix = build_grad_tensors(grad_f_v, grad_f_pit, grad_g_v, grad_g_pit)
      d0_vector, norm_2_d0, duals_0_vector, A_matrix, b_vector = calculate_descent_direction_w_tensors(pi_nolast, pi_last, g_nash, grad_f_vector, grad_g_matrix)
      d_vector, duals_vector, rho = calculate_feasible_direction_w_tensors(duals_0_vector, norm_2_d0, A_matrix, b_vector, rho, grad_f_vector, grad_g_matrix)
      d_v, d_pit, duals = vec2dic(d_vector, duals_vector)

      g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero = check_nash_KKT_conditions(pi, v, duals)
      if ttype == 'nolast':
        pi_KKT = check_pi_KKT_conditions(pi_nolast, pi_last, duals)
        g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus = pi_KKT[:4] 
        g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one = pi_KKT[4:]
      game_temp, found_feasible_step_size, n_step, f_new = feasible_gradient_descent(game_new, pi, v, g_nash, grad_f_v, grad_f_pit, d_v, d_pit, 
                                                                                     duals_vector, transform_type=ttype, pi_nolast=pi_nolast, pi_last=pi_last, verbose=verbose) # TODO: check if passing wrong grad_f
      
      if found_feasible_step_size:
        game_new = game_temp
        delta_f = (f_new - f_0) / f_0 * 100
        print('Epoch: {}, f: {:.3e}, delta f: {:.3e}%, nash satisfied: {}, max g: {:.3e}, dual nash satisfied: {}, max product: {:.3e}, rho: {:.3e}, norm2 d0:{:.3e}'.format(
            epoch, f_new.item(), delta_f.item(), g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero, rho, norm_2_d0))
        if ttype == 'nolast':
          print('Epoch: {}, >0 satisfied: {}, max g>0: {:.3e}, dual >0 satisfied: {}, max product >0: {:.3e}, =1 satisfied: {}, max g=1: {:.3e}, dual =1 satisfied: {}, max product =1: {:.3e}'.format(
            epoch, g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus, g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one))
      else:
        break
        
  return game_new, rho     

#**Optimization**

In [17]:
game_b = nolast_game().to('cuda')

v0 = calculate_initial_v(game_b)
v0 = torch.FloatTensor(v0).to('cuda')
with torch.no_grad():
  game_b.v.data.add_(v0)

In [18]:
alpha = 0.5
gamma_0 = 0.5
eta = 0.01
nu = 1/0.98
rho_0 = 1.0
rho = rho_0
r_value = 1.0

In [82]:
game_b2, rho = optimize_game(game_b, rho, n_epochs=100, verbose=False)

Epoch: 0, f: 5.977e+02, delta f: -8.175e+00%, nash satisfied: True, max g: -1.000e+01, dual nash satisfied: False, max product: 4.894e+00, rho: 1.270e-02, norm2 d0:5.685e+01
Epoch: 0, >0 satisfied: True, max g>0: -8.333e-02, dual >0 satisfied: False, max product >0: 1.232e+00, =1 satisfied: True, max g=1: -8.333e-02, dual =1 satisfied: False, max product =1: 1.161e+00
Epoch: 1, f: 5.400e+02, delta f: -1.704e+01%, nash satisfied: True, max g: -8.364e+00, dual nash satisfied: False, max product: 5.272e+00, rho: 6.268e-03, norm2 d0:5.082e+01
Epoch: 1, >0 satisfied: True, max g>0: -4.246e-02, dual >0 satisfied: False, max product >0: 8.840e-01, =1 satisfied: True, max g=1: -1.025e-01, dual =1 satisfied: False, max product =1: 6.844e-01
Epoch: 2, f: 5.001e+02, delta f: -2.317e+01%, nash satisfied: True, max g: -6.662e+00, dual nash satisfied: False, max product: 5.273e+00, rho: 6.268e-03, norm2 d0:4.441e+01
Epoch: 2, >0 satisfied: True, max g>0: -6.582e-02, dual >0 satisfied: False, max pro

In [83]:
game_b3, rho = optimize_game(game_b2, rho, n_epochs=100, verbose=False)

Epoch: 0, f: 1.001e+02, delta f: -2.695e-02%, nash satisfied: True, max g: -2.441e-04, dual nash satisfied: False, max product: 3.578e-01, rho: 6.611e-04, norm2 d0:3.494e-01
Epoch: 0, >0 satisfied: True, max g>0: -1.111e-05, dual >0 satisfied: False, max product >0: 8.222e-02, =1 satisfied: True, max g=1: -2.486e-04, dual =1 satisfied: False, max product =1: 8.176e-02
Epoch: 1, f: 9.995e+01, delta f: -1.339e-01%, nash satisfied: True, max g: -2.441e-04, dual nash satisfied: False, max product: 3.568e-01, rho: 6.611e-04, norm2 d0:3.482e-01
Epoch: 1, >0 satisfied: True, max g>0: -6.400e-06, dual >0 satisfied: False, max product >0: 8.184e-02, =1 satisfied: True, max g=1: -2.472e-04, dual =1 satisfied: False, max product =1: 8.138e-02
Epoch: 2, f: 9.993e+01, delta f: -1.549e-01%, nash satisfied: True, max g: -2.441e-04, dual nash satisfied: False, max product: 3.516e-01, rho: 6.611e-04, norm2 d0:3.429e-01
Epoch: 2, >0 satisfied: True, max g>0: -1.007e-05, dual >0 satisfied: False, max pro

In [106]:
torch.save(game_b3.state_dict(), './game_b3.pth')

In [94]:
v03 = calculate_initial_v(game_b3, alpha=0.1)
v03 = torch.FloatTensor(v03).to('cuda')

pi3, v3 = game_b3()

f03 = calculate_bellman_error(pi3, v03)
f3 = calculate_bellman_error(pi3, v3)
print(f3, f03)

tensor(95.7261, device='cuda:0', grad_fn=<SumBackward0>) tensor(21.2995, device='cuda:0', grad_fn=<SumBackward0>)


In [102]:
game_b4 = copy_game(game_b3)
with torch.no_grad():
  nn.init.zeros_(game_b4.v)
  game_b4.v.data.add_(v03)  

In [104]:
game_b5, rho = optimize_game(game_b4, 1.0, n_epochs=100, verbose=False)

Epoch: 0, f: 1.941e+01, delta f: -8.890e+00%, nash satisfied: True, max g: -9.998e-02, dual nash satisfied: False, max product: 2.674e-01, rho: 1.598e-03, norm2 d0:1.475e-01
Epoch: 0, >0 satisfied: True, max g>0: -4.300e-06, dual >0 satisfied: False, max product >0: 1.141e-02, =1 satisfied: True, max g=1: -1.551e-04, dual =1 satisfied: False, max product =1: 1.094e-02
Epoch: 1, f: 1.909e+01, delta f: -1.035e+01%, nash satisfied: True, max g: -5.020e-02, dual nash satisfied: False, max product: 2.118e-01, rho: 1.598e-03, norm2 d0:9.467e-02
Epoch: 1, >0 satisfied: True, max g>0: -6.252e-05, dual >0 satisfied: False, max product >0: 9.986e-03, =1 satisfied: True, max g=1: -1.744e-04, dual =1 satisfied: False, max product =1: 9.683e-03
Epoch: 2, f: 1.872e+01, delta f: -1.211e+01%, nash satisfied: True, max g: -4.373e-02, dual nash satisfied: False, max product: 2.041e-01, rho: 1.598e-03, norm2 d0:8.872e-02
Epoch: 2, >0 satisfied: True, max g>0: -3.230e-05, dual >0 satisfied: False, max pro

In [24]:
#torch.save(game_b5.state_dict(), './game_b5.pth')

v05 = calculate_initial_v(game_b5, alpha=0.0001)
v05 = torch.FloatTensor(v05).to('cuda')

pi5, v5 = game_b5()

f05 = calculate_bellman_error(pi5, v05)
f5 = calculate_bellman_error(pi5, v5)
print(f5, f05)

tensor(13.3938, device='cuda:0', grad_fn=<SumBackward0>) tensor(13.1853, device='cuda:0', grad_fn=<SumBackward0>)


In [65]:
game_b5 = nolast_game().to('cuda')
game_b5.load_state_dict(torch.load('./game_b5.pth'))

<All keys matched successfully>

In [66]:
game_b6 = copy_game(game_b5)
with torch.no_grad():
  nn.init.zeros_(game_b6.v)
  game_b6.v.data.add_(v05) 

In [67]:
game_b7, rho = optimize_game(game_b6, 1.0, n_epochs=12, verbose=False)

Epoch: 0, f: 1.318e+01, delta f: -1.296e-02%, nash satisfied: True, max g: -6.104e-05, dual nash satisfied: False, max product: 4.484e-03, rho: 1.682e-03, norm2 d0:4.922e-03
Epoch: 0, >0 satisfied: True, max g>0: -7.249e-07, dual >0 satisfied: False, max product >0: 9.706e-04, =1 satisfied: True, max g=1: -7.868e-06, dual =1 satisfied: False, max product =1: 8.843e-04
Epoch: 1, f: 1.318e+01, delta f: -1.805e-02%, nash satisfied: True, max g: -3.052e-05, dual nash satisfied: False, max product: 4.529e-03, rho: 1.682e-03, norm2 d0:4.923e-03
Epoch: 1, >0 satisfied: True, max g>0: -8.677e-07, dual >0 satisfied: False, max product >0: 1.040e-03, =1 satisfied: True, max g=1: -7.868e-06, dual =1 satisfied: False, max product =1: 9.572e-04
Epoch: 2, f: 1.318e+01, delta f: -2.083e-02%, nash satisfied: True, max g: -6.104e-05, dual nash satisfied: False, max product: 4.596e-03, rho: 1.682e-03, norm2 d0:4.926e-03
Epoch: 2, >0 satisfied: True, max g>0: -6.517e-07, dual >0 satisfied: False, max pro

In [71]:
#torch.save(game_b7.state_dict(), './game_b7.pth')

v07 = calculate_initial_v(game_b7, alpha=0.1)
v07 = torch.FloatTensor(v07).to('cuda')

pi7, v7 = game_b7()

f07 = calculate_bellman_error(pi7, v07)
f7 = calculate_bellman_error(pi7, v7)

  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=sym_pos)
  return sp.linalg.solve(M, r, sym_pos=s

In [72]:
g_7 = calculate_nash_restrictions(pi7, v7)
pi_nolast_7, pi_last_7 = game_b7.pi_discriminated()
g_vector_7 = g_dics2g_vec(g_7, pi_nolast_7, pi_last_7)

In [70]:
g_vector_7

tensor([[-4.8826e+00],
        [-2.8826e+00],
        [-1.2207e-04],
        [-1.2207e-04],
        [-1.2207e-04],
        [-1.5259e-04],
        [-1.2207e-04],
        [-1.5259e-04],
        [-1.2207e-04],
        [-1.8311e-04],
        [-1.2207e-04],
        [-4.5776e-04],
        [-6.1035e-05],
        [-6.4026e-02],
        [-9.1553e-05],
        [-7.8038e+00],
        [-6.1035e-05],
        [-6.8038e+00],
        [-6.1035e-05],
        [-5.8038e+00],
        [-6.1035e-05],
        [-4.8038e+00],
        [-6.1035e-05],
        [-3.8038e+00],
        [-6.1035e-05],
        [-2.8039e+00],
        [-1.2207e-04],
        [-4.0016e+00],
        [-6.1035e-05],
        [-3.0016e+00],
        [-6.1035e-05],
        [-2.0016e+00],
        [-6.1035e-05],
        [-1.0016e+00],
        [-6.1035e-05],
        [-1.5564e-03],
        [-3.0518e-05],
        [-3.0518e-05],
        [-9.9850e-01],
        [-6.1035e-05],
        [-6.1035e-05],
        [-6.1035e-05],
        [-6.1035e-05],
        [-6

In [73]:
print(f7, f07)

tensor(13.1753, device='cuda:0', grad_fn=<SumBackward0>) tensor(19.3723, device='cuda:0', grad_fn=<SumBackward0>)


In [74]:
game_b8 = copy_game(game_b7)
with torch.no_grad():
  nn.init.zeros_(game_b8.v)
  game_b8.v.data.add_(v07) 

In [60]:
alpha = 0.5
gamma_0 = 0.5
eta = 0.1
nu = 1/0.98
rho_0 = 1.0
rho = rho_0
r_value = 1.0

In [75]:
game_b9, rho = optimize_game(game_b8, 1.0, n_epochs=20, verbose=False)

Epoch: 0, f: 1.748e+01, delta f: -9.750e+00%, nash satisfied: True, max g: -9.998e-02, dual nash satisfied: False, max product: 2.679e-01, rho: 1.692e-03, norm2 d0:1.469e-01
Epoch: 0, >0 satisfied: True, max g>0: -9.996e-07, dual >0 satisfied: False, max product >0: 1.226e-02, =1 satisfied: True, max g=1: -8.225e-06, dual =1 satisfied: False, max product =1: 1.176e-02
Epoch: 1, f: 1.714e+01, delta f: -1.150e+01%, nash satisfied: True, max g: -5.005e-02, dual nash satisfied: False, max product: 2.118e-01, rho: 1.692e-03, norm2 d0:9.393e-02
Epoch: 1, >0 satisfied: True, max g>0: -8.012e-05, dual >0 satisfied: False, max product >0: 1.042e-02, =1 satisfied: True, max g=1: -8.678e-05, dual =1 satisfied: False, max product =1: 1.010e-02
Epoch: 2, f: 1.672e+01, delta f: -1.371e+01%, nash satisfied: True, max g: -4.303e-02, dual nash satisfied: False, max product: 2.034e-01, rho: 1.692e-03, norm2 d0:8.761e-02
Epoch: 2, >0 satisfied: True, max g>0: -4.061e-05, dual >0 satisfied: False, max pro

In [76]:
game_b10, rho = optimize_game(game_b9, 1.0, n_epochs=100, verbose=False)

Epoch: 0, f: 1.195e+01, delta f: -8.758e-01%, nash satisfied: True, max g: -1.221e-04, dual nash satisfied: False, max product: 1.294e-01, rho: 1.528e-03, norm2 d0:6.857e-02
Epoch: 0, >0 satisfied: True, max g>0: -1.072e-05, dual >0 satisfied: False, max product >0: 3.023e-02, =1 satisfied: True, max g=1: -1.538e-04, dual =1 satisfied: False, max product =1: 3.002e-02
Epoch: 1, f: 1.179e+01, delta f: -2.214e+00%, nash satisfied: True, max g: -1.221e-04, dual nash satisfied: False, max product: 1.160e-01, rho: 1.528e-03, norm2 d0:5.613e-02
Epoch: 1, >0 satisfied: True, max g>0: -7.519e-06, dual >0 satisfied: False, max product >0: 2.714e-02, =1 satisfied: True, max g=1: -1.376e-04, dual =1 satisfied: False, max product =1: 2.697e-02
Epoch: 2, f: 1.177e+01, delta f: -2.393e+00%, nash satisfied: True, max g: -9.155e-05, dual nash satisfied: False, max product: 8.716e-02, rho: 1.528e-03, norm2 d0:3.394e-02
Epoch: 2, >0 satisfied: True, max g>0: -8.572e-06, dual >0 satisfied: False, max pro

In [86]:
torch.save(game_b10.state_dict(), './game_b10.pth')

In [80]:
v0_10 = calculate_initial_v(game_b10, alpha=0.1)
v0_10 = torch.FloatTensor(v0_10).to('cuda')

pi_10, v_10 = game_b10()
f0_10 = calculate_bellman_error(pi_10, v0_10)
f_10 = calculate_bellman_error(pi_10, v_10)

In [82]:
print(f_10.item(), f0_10.item())

11.4930419921875 17.479324340820312


In [84]:
game_b11 = copy_game(game_b10)
with torch.no_grad():
  nn.init.zeros_(game_b11.v)
  game_b11.v.data.add_(v0_10) 

In [85]:
game_b12, rho = optimize_game(game_b11, 1.0, n_epochs=100, verbose=False)

Epoch: 0, f: 1.559e+01, delta f: -1.078e+01%, nash satisfied: True, max g: -9.998e-02, dual nash satisfied: False, max product: 2.674e-01, rho: 1.807e-03, norm2 d0:1.457e-01
Epoch: 0, >0 satisfied: True, max g>0: -5.987e-07, dual >0 satisfied: False, max product >0: 1.239e-02, =1 satisfied: True, max g=1: -6.974e-06, dual =1 satisfied: False, max product =1: 1.186e-02
Epoch: 1, f: 1.521e+01, delta f: -1.300e+01%, nash satisfied: True, max g: -5.013e-02, dual nash satisfied: False, max product: 2.114e-01, rho: 1.807e-03, norm2 d0:9.319e-02
Epoch: 1, >0 satisfied: True, max g>0: -8.461e-05, dual >0 satisfied: False, max product >0: 1.043e-02, =1 satisfied: True, max g=1: -9.096e-05, dual =1 satisfied: False, max product =1: 1.009e-02
Epoch: 2, f: 1.469e+01, delta f: -1.595e+01%, nash satisfied: True, max g: -4.205e-02, dual nash satisfied: False, max product: 2.018e-01, rho: 1.807e-03, norm2 d0:8.614e-02
Epoch: 2, >0 satisfied: True, max g>0: -4.248e-05, dual >0 satisfied: False, max pro

In [88]:
torch.save(game_b12.state_dict(), './game_b12.pth')

In [99]:
v0_12 = calculate_initial_v(game_b12, alpha=0.08)
v0_12 = torch.FloatTensor(v0_12).to('cuda')

pi_12, v_12 = game_b12()
f0_12 = calculate_bellman_error(pi_12, v0_12)
f_12 = calculate_bellman_error(pi_12, v_12)

In [100]:
print(f_12.item(), f0_12.item())

9.595474243164062 14.2904052734375


In [101]:
game_b13 = copy_game(game_b12)
with torch.no_grad():
  nn.init.zeros_(game_b13.v)
  game_b13.v.data.add_(v0_12) 

In [102]:
game_b14, rho = optimize_game(game_b13, 1.0, n_epochs=100, verbose=False)

Epoch: 0, f: 1.286e+01, delta f: -1.003e+01%, nash satisfied: True, max g: -7.999e-02, dual nash satisfied: False, max product: 2.288e-01, rho: 1.940e-03, norm2 d0:1.085e-01
Epoch: 0, >0 satisfied: True, max g>0: -9.772e-07, dual >0 satisfied: False, max product >0: 1.247e-02, =1 satisfied: True, max g=1: -9.000e-06, dual =1 satisfied: False, max product =1: 1.205e-02
Epoch: 1, f: 1.247e+01, delta f: -1.272e+01%, nash satisfied: True, max g: -4.074e-02, dual nash satisfied: False, max product: 1.821e-01, rho: 1.940e-03, norm2 d0:7.073e-02
Epoch: 1, >0 satisfied: True, max g>0: -6.408e-05, dual >0 satisfied: False, max product >0: 1.019e-02, =1 satisfied: True, max g=1: -7.033e-05, dual =1 satisfied: False, max product =1: 9.916e-03
Epoch: 2, f: 1.188e+01, delta f: -1.684e+01%, nash satisfied: True, max g: -3.256e-02, dual nash satisfied: False, max product: 1.721e-01, rho: 1.940e-03, norm2 d0:6.468e-02
Epoch: 2, >0 satisfied: True, max g>0: -3.251e-05, dual >0 satisfied: False, max pro

In [103]:
torch.save(game_b14.state_dict(), './game_b14.pth')

In [None]:
v0_14 = calculate_initial_v(game_b14, alpha=0.075)
v0_14 = torch.FloatTensor(v0_14).to('cuda')

pi_14, v_14 = game_b14()
f0_14 = calculate_bellman_error(pi_14, v0_14)
f_14 = calculate_bellman_error(pi_14, v_14)

In [115]:
print(f_14.item(), f0_14.item())

7.52215576171875 11.959686279296875


In [116]:
game_b15 = copy_game(game_b14)
with torch.no_grad():
  nn.init.zeros_(game_b15.v)
  game_b15.v.data.add_(v0_14) 

In [117]:
game_b16, rho = optimize_game(game_b15, 1.0, n_epochs=50, verbose=False)

Epoch: 0, f: 1.062e+01, delta f: -1.119e+01%, nash satisfied: True, max g: -7.498e-02, dual nash satisfied: False, max product: 2.164e-01, rho: 2.103e-03, norm2 d0:9.677e-02
Epoch: 0, >0 satisfied: True, max g>0: -9.589e-07, dual >0 satisfied: False, max product >0: 1.199e-02, =1 satisfied: True, max g=1: -9.358e-06, dual =1 satisfied: False, max product =1: 1.158e-02
Epoch: 1, f: 1.017e+01, delta f: -1.493e+01%, nash satisfied: True, max g: -3.807e-02, dual nash satisfied: False, max product: 1.718e-01, rho: 2.103e-03, norm2 d0:6.329e-02
Epoch: 1, >0 satisfied: True, max g>0: -6.153e-05, dual >0 satisfied: False, max product >0: 9.762e-03, =1 satisfied: True, max g=1: -6.825e-05, dual =1 satisfied: False, max product =1: 9.496e-03
Epoch: 2, f: 9.436e+00, delta f: -2.111e+01%, nash satisfied: True, max g: -2.853e-02, dual nash satisfied: False, max product: 1.598e-01, rho: 2.103e-03, norm2 d0:5.686e-02
Epoch: 2, >0 satisfied: True, max g>0: -3.130e-05, dual >0 satisfied: False, max pro

In [118]:
torch.save(game_b16.state_dict(), './game_b16.pth')

In [119]:
v0_16 = calculate_initial_v(game_b16, alpha=0.075)
v0_16 = torch.FloatTensor(v0_16).to('cuda')

pi_16, v_16 = game_b16()
f0_16 = calculate_bellman_error(pi_16, v0_16)
f_16 = calculate_bellman_error(pi_16, v_16)

  return sp.linalg.solve(M, r, sym_pos=sym_pos)


In [120]:
print(f_16.item(), f0_16.item())

6.14483642578125 10.612747192382812


In [122]:
game_b17 = copy_game(game_b16)
with torch.no_grad():
  nn.init.zeros_(game_b17.v)
  game_b17.v.data.add_(v0_16) 

In [123]:
game_b18, rho = optimize_game(game_b17, 1.0, n_epochs=50, verbose=False)

Epoch: 0, f: 9.130e+00, delta f: -1.397e+01%, nash satisfied: True, max g: -7.498e-02, dual nash satisfied: False, max product: 1.921e-01, rho: 2.172e-03, norm2 d0:7.162e-02
Epoch: 0, >0 satisfied: True, max g>0: -2.502e-06, dual >0 satisfied: False, max product >0: 4.449e-03, =1 satisfied: True, max g=1: -1.395e-05, dual =1 satisfied: False, max product =1: 4.761e-03
Epoch: 1, f: 8.704e+00, delta f: -1.798e+01%, nash satisfied: True, max g: -3.793e-02, dual nash satisfied: False, max product: 1.418e-01, rho: 2.172e-03, norm2 d0:4.341e-02
Epoch: 1, >0 satisfied: True, max g>0: -5.279e-05, dual >0 satisfied: False, max product >0: 4.234e-03, =1 satisfied: True, max g=1: -5.478e-05, dual =1 satisfied: False, max product =1: 4.423e-03
Epoch: 2, f: 8.088e+00, delta f: -2.379e+01%, nash satisfied: True, max g: -2.954e-02, dual nash satisfied: False, max product: 1.269e-01, rho: 2.172e-03, norm2 d0:3.762e-02
Epoch: 2, >0 satisfied: True, max g>0: -2.656e-05, dual >0 satisfied: False, max pro

In [124]:
torch.save(game_b18.state_dict(), './game_b18.pth')

In [131]:
v0_18 = calculate_initial_v(game_b18, alpha=0.03)
v0_18 = torch.FloatTensor(v0_18).to('cuda')

pi_18, v_18 = game_b18()
f0_18 = calculate_bellman_error(pi_18, v0_18)
f_18 = calculate_bellman_error(pi_18, v_18)

  return sp.linalg.solve(M, r, sym_pos=sym_pos)


In [132]:
print(f_18.item(), f0_18.item())

5.9244537353515625 7.6738128662109375


In [133]:
game_b19 = copy_game(game_b18)
with torch.no_grad():
  nn.init.zeros_(game_b19.v)
  game_b19.v.data.add_(v0_18) 

In [135]:
game_b20, rho = optimize_game(game_b19, 1.0, n_epochs=50, verbose=False)

KeyboardInterrupt: ignored

In [136]:
torch.save(game_b20.state_dict(), './game_b20.pth')

In [146]:
v0_20 = calculate_initial_v(game_b20, alpha=0.025)
v0_20 = torch.FloatTensor(v0_20).to('cuda')

pi_20, v_20 = game_b20()
f0_20 = calculate_bellman_error(pi_20, v0_20)
f_20 = calculate_bellman_error(pi_20, v_20)

In [147]:
print(f_20.item(), f0_20.item())

5.57208251953125 7.0506591796875


In [148]:
game_b21 = copy_game(game_b20)
with torch.no_grad():
  nn.init.zeros_(game_b21.v)
  game_b21.v.data.add_(v0_20) 

In [150]:
game_b22, rho = optimize_game(game_b21, 1.0, n_epochs=75, verbose=False)

Epoch: 0, f: 6.560e+00, delta f: -6.957e+00%, nash satisfied: True, max g: -2.499e-02, dual nash satisfied: False, max product: 8.638e-02, rho: 2.095e-03, norm2 d0:2.701e-02
Epoch: 0, >0 satisfied: True, max g>0: -7.046e-06, dual >0 satisfied: False, max product >0: 8.280e-03, =1 satisfied: True, max g=1: -1.478e-05, dual =1 satisfied: False, max product =1: 8.392e-03
Epoch: 1, f: 6.292e+00, delta f: -1.076e+01%, nash satisfied: True, max g: -1.266e-02, dual nash satisfied: False, max product: 8.433e-02, rho: 2.095e-03, norm2 d0:2.262e-02
Epoch: 1, >0 satisfied: True, max g>0: -1.416e-05, dual >0 satisfied: False, max product >0: 8.361e-03, =1 satisfied: True, max g=1: -2.444e-05, dual =1 satisfied: False, max product =1: 8.457e-03
Epoch: 2, f: 6.046e+00, delta f: -1.425e+01%, nash satisfied: True, max g: -7.507e-03, dual nash satisfied: False, max product: 8.329e-02, rho: 2.095e-03, norm2 d0:2.110e-02
Epoch: 2, >0 satisfied: True, max g>0: -7.298e-06, dual >0 satisfied: False, max pro

In [151]:
game_b23, rho = optimize_game(game_b22, 1.0, n_epochs=75, verbose=False)

Epoch: 0, f: 5.010e+00, delta f: -9.037e-02%, nash satisfied: True, max g: -3.052e-05, dual nash satisfied: False, max product: 8.888e-02, rho: 1.854e-03, norm2 d0:2.135e-02
Epoch: 0, >0 satisfied: True, max g>0: -8.687e-06, dual >0 satisfied: False, max product >0: 9.823e-03, =1 satisfied: True, max g=1: -6.855e-06, dual =1 satisfied: False, max product =1: 9.900e-03
Epoch: 1, f: 5.005e+00, delta f: -2.008e-01%, nash satisfied: True, max g: -3.052e-05, dual nash satisfied: False, max product: 8.892e-02, rho: 1.854e-03, norm2 d0:2.136e-02
Epoch: 1, >0 satisfied: True, max g>0: -6.997e-06, dual >0 satisfied: False, max product >0: 9.826e-03, =1 satisfied: True, max g=1: -3.457e-06, dual =1 satisfied: False, max product =1: 9.907e-03
Epoch: 2, f: 5.001e+00, delta f: -2.812e-01%, nash satisfied: True, max g: -1.526e-05, dual nash satisfied: False, max product: 8.896e-02, rho: 1.854e-03, norm2 d0:2.137e-02
Epoch: 2, >0 satisfied: True, max g>0: -7.775e-06, dual >0 satisfied: False, max pro

In [153]:
game_b23, rho = optimize_game(game_b23, 1.0, n_epochs=20, verbose=False)

Epoch: 0, f: 4.327e+00, delta f: -7.405e-03%, nash satisfied: True, max g: -1.526e-05, dual nash satisfied: False, max product: 4.365e-02, rho: 1.805e-03, norm2 d0:8.871e-03
Epoch: 0, >0 satisfied: True, max g>0: -3.316e-06, dual >0 satisfied: False, max product >0: 4.789e-03, =1 satisfied: True, max g=1: -2.384e-06, dual =1 satisfied: False, max product =1: 4.822e-03
Epoch: 1, f: 4.324e+00, delta f: -7.440e-02%, nash satisfied: True, max g: -1.526e-05, dual nash satisfied: False, max product: 4.321e-02, rho: 1.805e-03, norm2 d0:8.796e-03
Epoch: 1, >0 satisfied: True, max g>0: -3.300e-06, dual >0 satisfied: False, max product >0: 4.744e-03, =1 satisfied: True, max g=1: -2.384e-06, dual =1 satisfied: False, max product =1: 4.777e-03
Epoch: 2, f: 4.324e+00, delta f: -7.651e-02%, nash satisfied: True, max g: -1.526e-05, dual nash satisfied: False, max product: 3.867e-02, rho: 1.805e-03, norm2 d0:8.049e-03
Epoch: 2, >0 satisfied: True, max g>0: -2.328e-06, dual >0 satisfied: False, max pro

In [154]:
torch.save(game_b23.state_dict(), './game_b23.pth')

In [None]:
v0_23 = calculate_initial_v(game_b23, alpha=0.017)
v0_23 = torch.FloatTensor(v0_23).to('cuda')

pi_23, v_23 = game_b23()
f0_23 = calculate_bellman_error(pi_23, v0_23)
f_23 = calculate_bellman_error(pi_23, v_23)

In [166]:
print(f_23.item(), f0_23.item())

4.31756591796875 5.162017822265625


In [167]:
game_b24 = copy_game(game_b23)
with torch.no_grad():
  nn.init.zeros_(game_b24.v)
  game_b24.v.data.add_(v0_23) 

In [168]:
game_b25, rho = optimize_game(game_b24, 1.0, n_epochs=75, verbose=False)

Epoch: 0, f: 5.107e+00, delta f: -1.062e+00%, nash satisfied: True, max g: -1.581e-02, dual nash satisfied: False, max product: 4.865e-02, rho: 1.932e-03, norm2 d0:1.111e-02
Epoch: 0, >0 satisfied: True, max g>0: -2.122e-06, dual >0 satisfied: False, max product >0: 2.029e-03, =1 satisfied: True, max g=1: -4.172e-06, dual =1 satisfied: False, max product =1: 2.072e-03
Epoch: 1, f: 4.978e+00, delta f: -3.562e+00%, nash satisfied: True, max g: -1.483e-02, dual nash satisfied: False, max product: 4.683e-02, rho: 1.932e-03, norm2 d0:1.055e-02
Epoch: 1, >0 satisfied: True, max g>0: -2.721e-06, dual >0 satisfied: False, max product >0: 2.003e-03, =1 satisfied: True, max g=1: -2.146e-06, dual =1 satisfied: False, max product =1: 2.043e-03
Epoch: 2, f: 4.929e+00, delta f: -4.514e+00%, nash satisfied: True, max g: -1.248e-02, dual nash satisfied: False, max product: 4.245e-02, rho: 1.932e-03, norm2 d0:9.322e-03
Epoch: 2, >0 satisfied: True, max g>0: -3.464e-06, dual >0 satisfied: False, max pro

In [169]:
torch.save(game_b25.state_dict(), './game_b25.pth')

In [192]:
v0_25 = calculate_initial_v(game_b25, alpha=0.025)
v0_25 = torch.FloatTensor(v0_25).to('cuda')

pi_25, v_25 = game_b25()
f0_25 = calculate_bellman_error(pi_25, v0_25)
f_25 = calculate_bellman_error(pi_25, v_25)

In [188]:
print(f_25.item(), f0_25.item())

4.2151641845703125 5.6512908935546875


In [193]:
game_b26 = copy_game(game_b25)
with torch.no_grad():
  nn.init.zeros_(game_b26.v)
  game_b26.v.data.add_(v0_25) 

In [194]:
game_b27, rho = optimize_game(game_b26, 1.0, n_epochs=75, verbose=False)

Epoch: 0, f: 5.130e+00, delta f: -9.229e+00%, nash satisfied: True, max g: -2.498e-02, dual nash satisfied: False, max product: 5.848e-02, rho: 1.933e-03, norm2 d0:1.285e-02
Epoch: 0, >0 satisfied: True, max g>0: -1.895e-06, dual >0 satisfied: False, max product >0: 2.174e-04, =1 satisfied: True, max g=1: -1.371e-06, dual =1 satisfied: False, max product =1: 1.680e-04
Epoch: 1, f: 5.079e+00, delta f: -1.013e+01%, nash satisfied: True, max g: -1.254e-02, dual nash satisfied: False, max product: 3.902e-02, rho: 1.933e-03, norm2 d0:8.163e-03
Epoch: 1, >0 satisfied: True, max g>0: -7.841e-06, dual >0 satisfied: False, max product >0: 6.758e-04, =1 satisfied: True, max g=1: -4.530e-06, dual =1 satisfied: False, max product =1: 1.876e-04
Epoch: 2, f: 5.003e+00, delta f: -1.147e+01%, nash satisfied: True, max g: -1.149e-02, dual nash satisfied: False, max product: 3.712e-02, rho: 1.933e-03, norm2 d0:7.825e-03
Epoch: 2, >0 satisfied: True, max g>0: -6.435e-06, dual >0 satisfied: False, max pro

KeyboardInterrupt: ignored

In [195]:
for s in S_str:
  print(s)
  print(pi_25['1'][s].detach().cpu().numpy())

('G',)
[[9.9999076e-01]
 [9.2387199e-06]]
('O', 1)
[[4.3605594e-03]
 [4.3490697e-03]
 [4.3556122e-03]
 [4.3272479e-03]
 [4.3490212e-03]
 [4.3102736e-03]
 [4.3414407e-03]
 [4.2158761e-03]
 [4.3019666e-03]
 [2.5135855e-04]
 [9.6083623e-01]
 [1.3709068e-06]]
('O', 2)
[[1.]]
('E1', 0, 0)
[[1.8945373e-06]
 [9.9999809e-01]]
('E1', 0, 1)
[[2.5282839e-06]
 [9.9999750e-01]]
('E1', 0, 2)
[[3.4774978e-06]
 [9.9999654e-01]]
('E1', 0, 3)
[[5.3463818e-06]
 [9.9999464e-01]]
('E1', 0, 4)
[[1.2350479e-05]
 [9.9998766e-01]]
('E1', 0, 5)
[[1.9552601e-06]
 [9.9999803e-01]]
('E1', 1, 0)
[[2.6374576e-06]
 [9.9999738e-01]]
('E1', 1, 1)
[[4.1492417e-06]
 [9.9999583e-01]]
('E1', 1, 2)
[[5.792298e-06]
 [9.999942e-01]]
('E1', 1, 3)
[[1.444188e-05]
 [9.999856e-01]]
('E1', 1, 4)
[[0.01042041]
 [0.98957956]]
('E1', 1, 5)
[[9.9999076e-01]
 [9.2387199e-06]]
('E2', 0, 0)
[[1.]]
('E2', 0, 1)
[[1.]]
('E2', 0, 2)
[[1.]]
('E2', 0, 3)
[[1.]]
('E2', 0, 4)
[[1.]]
('E2', 0, 5)
[[1.]]
('E2', 1, 0)
[[1.]]
('E2', 1, 1)
[[1.]]
('

In [133]:
game_new = copy_game(game_b)
pi_0, v_0 = game_b()
f_0 = calculate_bellman_error(pi_0, v_0)
ttype = game_new.transform_type

In [33]:
with torch.no_grad():    
  pi, v = game_new()
  pi_nolast, pi_last = game_new.pi_discriminated()
  g_nash = calculate_nash_restrictions(pi, v)
  grad_f_v, grad_f_pi = bellman_error_gradients(pi, v)
  grad_g_v, grad_g_pi = nash_restriction_gradients(pi, v)
  grad_f_pit, grad_g_pit = calculate_nash_pi_transformed_gradients(pi, grad_f_pi, grad_g_pi, transform_type=ttype)
  d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system = calculate_descent_direction(pi_nolast, pi_last, g_nash, 
                                                                                      grad_f_v, grad_f_pit, grad_g_v, grad_g_pit)
  d_v, d_pit, duals, rho = calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system,grad_f_v, grad_g_v, grad_f_pit, grad_g_pit, rho)
  grad_f_vector, grad_g_matrix = build_grad_tensors(grad_f_v, grad_f_pit, grad_g_v, grad_g_pit)
  d0_vector, norm_2_d0_wt, duals_0_vector, A_matrix, b_vector = calculate_descent_direction_w_tensors(pi_nolast, pi_last, g_nash, grad_f_vector, grad_g_matrix)
  d_vector, duals_vector, rho_wt = calculate_feasible_direction_w_tensors(duals_0_vector, norm_2_d0_wt, A_matrix, b_vector, rho, grad_f_vector, grad_g_matrix)
  d_v_wt, d_pit_wt, duals_wt = vec2dic(d_vector, duals_vector)
  # g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero = check_nash_KKT_conditions(pi, v, duals)
  # if ttype == 'nolast':
  #   pi_KKT = check_pi_KKT_conditions(pi_nolast, pi_last, duals)
  #   g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus = pi_KKT[:4] 
  #   g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one = pi_KKT[4:]

In [34]:
decrement0 = torch.einsum('ij,i->j', d0_v, grad_f_v.view(-1)).sum()
for s, player in state_player_str_pairs():
  NA = N_A_S[s][player]
  if NA > 1:
    decrement0 += (d0_pit[player][s].view(-1) * grad_f_pit[player][s].view(-1)).sum()

decrement0_wt = (d0_vector * grad_f_vector).sum()

decrement = torch.einsum('ij,i->j', d_v, grad_f_v.view(-1)).sum()
for s, player in state_player_str_pairs():
  NA = N_A_S[s][player]
  if NA > 1:
    decrement += (d_pit[player][s].view(-1) * grad_f_pit[player][s].view(-1)).sum()

decrement_wt = (d_vector * grad_f_vector).sum()

decrement_wt_dic = torch.einsum('ij,i->j', d_v_wt, grad_f_v.view(-1)).sum()
for s, player in state_player_str_pairs():
  NA = N_A_S[s][player]
  if NA > 1:
    decrement_wt_dic += (d_pit_wt[player][s].view(-1) * grad_f_pit[player][s].view(-1)).sum()

In [35]:
print(norm_2_d0, norm_2_d0_wt)
print(decrement0, decrement0_wt)
print(decrement, decrement_wt, decrement_wt_dic)

tensor(85.5610, device='cuda:0') 56.85142517089844
tensor(-124.8436, device='cuda:0') tensor(-153.9966, device='cuda:0')
tensor(2894.5242, device='cuda:0') tensor(-139.7838, device='cuda:0') tensor(-139.7838, device='cuda:0')


In [25]:
np.savetxt('grad_f.csv', grad_f_vector.cpu().numpy(), delimiter=';')
np.savetxt('grad_g.csv', grad_g_matrix.cpu().numpy(), delimiter=';')

In [71]:
A_eigenval, A_eigenvec = torch.eig(A_matrix)

In [None]:
A_eigenval

In [58]:
g_list = []
for player in players_str():
  for s in S_str:
    g_list.append(g_nash[player][s])
for player in players_str():
  g_list.append(-pi_nolast[player])
for player in players_str():
  g_list.append(-pi_last[player])
g_tensor = torch.cat(g_list, dim=0)
g_diag_matrix = torch.diag(g_tensor.view(-1))

grad_g_outer = torch.einsum('ik,jk->ij', grad_g_matrix, grad_g_matrix)

In [61]:
g_diag_eigenval, g_diag_eigenvec = torch.eig(g_diag_matrix)
grad_g_outer_eigenval, grad_g_outer_eigenvec = torch.eig(grad_g_outer)

In [35]:
for s in S_str:
  print(s)
  print(RM['1'][s].detach().cpu().numpy())

('G',)
[[3. 0.]
 [5. 1.]]
('O', 1)
[[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
('O', 2)
[[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.]]
('E1', 0, 0)
[[0.]
 [0.]]
('E1', 0, 1)
[[1.]
 [0.]]
('E1', 0, 2)
[[2.]
 [0.]]
('E1', 0, 3)
[[3.]
 [0.]]
('E1', 0, 4)
[[4.]
 [0.]]
('E1', 0, 5)
[[5.]
 [0.]]
('E1', 1, 0)
[[0.]
 [0.]]
('E1', 1, 1)
[[1.]
 [0.]]
('E1', 1, 2)
[[2.]
 [0.]]
('E1', 1, 3)
[[3.]
 [0.]]
('E1', 1, 4)
[[4.]
 [0.]]
('E1', 1, 5)
[[5.]
 [0.]]
('E2', 0, 0)
[[0. 0.]]
('E2', 0, 1)
[[-1.  0.]]
('E2', 0, 2)
[[-2.  0.]]
('E2', 0, 3)
[[-3.  0.]]
('E2', 0, 4)
[[-4.  0.]]
('E2', 0, 5)
[[-5.  0.]]
('E2', 1, 0)
[[0. 0.]]
('E2', 1, 1)
[[-1.  0.]]
('E2', 1, 2)
[[-2.  0.]]
('E2', 1, 3)
[[-3.  0.]]
('E2', 1, 4)
[[-4.  0.]]
('E2', 1, 5)
[[-5.  0.]]
('R1', 0)
[[3.]
 [5.]]
('R1', 1)
[[0.]
 [1.]]
('R2', 0)
[[3. 0.]]
('R2', 1)
[[5. 1.]]


In [41]:
P9 = transition_matrix(pi_9)
r9 = expected_reward(RM, pi_9)
val9 = torch.solve(r9, (torch.eye(P9.shape[0]).to('cuda') - beta*P9))[0]

In [43]:
v_9-val9

tensor([[164.4264, 164.4214],
        [161.1544, 161.1530],
        [161.1578, 161.1496],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7773],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7820, 162.7772],
        [162.7820, 162.7771],
        [163.7510, 162.7552],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7773],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7821, 162.7773],
        [162.7821, 162.7772],
        [162.7821, 162.7772],
        [162.7820, 162.7771],
        [162.7601, 163.7462],
        [159.5446, 159.5398],
        [160.2549, 159.5398],
        [159.5446, 159.5398],
        [159.5446, 160.2502]], device='cuda:0', grad_fn=<SubBackward0>)

In [44]:
pi_9_mod = {'1':{}, '2':{}}
for s, player in state_player_str_pairs():
  pi_9_mod[player][s] = pi_9[player][s]

In [48]:
pi_9_mod['1'][str(('G',))] = 1 - pi_9_mod['1'][str(('G',))]

In [49]:
P9_mod = transition_matrix(pi_9_mod)
r9_mod = expected_reward(RM, pi_9_mod)
val9_mod = torch.solve(r9_mod, (torch.eye(P9_mod.shape[0]).to('cuda') - beta*P9_mod))[0]

In [50]:
val9_mod-val9

tensor([[  67.3352, -101.0034],
        [  65.9952,  -98.9934],
        [  65.9952,  -98.9934],
        [  66.6619,  -99.9934],
        [  66.6618,  -99.9934],
        [  66.6618,  -99.9934],
        [  66.6618,  -99.9933],
        [  66.6618,  -99.9933],
        [  66.6619,  -99.9934],
        [  66.6618,  -99.9933],
        [  66.6618,  -99.9934],
        [  66.6618,  -99.9933],
        [  66.6618,  -99.9933],
        [  66.6618,  -99.9933],
        [  66.6513,  -99.9775],
        [  66.6618,  -99.9934],
        [  66.6618,  -99.9934],
        [  66.6618,  -99.9934],
        [  66.6619,  -99.9934],
        [  66.6618,  -99.9933],
        [  66.6618,  -99.9934],
        [  66.6619,  -99.9934],
        [  66.6619,  -99.9933],
        [  66.6619,  -99.9934],
        [  66.6619,  -99.9934],
        [  66.6618,  -99.9933],
        [  66.6513,  -99.9776],
        [  65.3352,  -98.0034],
        [  65.3352,  -98.0034],
        [  65.3352,  -98.0034],
        [  65.3352,  -98.0034]], device=