<a href="https://colab.research.google.com/github/tarod13/Stochastic_Games/blob/master/stochastic_games_herkovitz_v2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
import numpy as np
from scipy.optimize import linprog
import matplotlib.pyplot as plt
import itertools
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter

In [7]:
beta = 0.99  # discount factor

M = 5    # largest offer; offers range over 0..M
N_P = 2  # number of players
N_A = 2  # actions per player in the stage game
G = [('G',)]                           # stage game
O = [('O',i) for i in range(1,N_P+1)]  # player i makes an offer
# E1 / E2: player 1 / player 2 answers an offer (requested action i, offer j);
# R1 / R2: stage game where the opponent of the mover plays the requested
# action i. The first index is an action, so it ranges over N_A.
E1 = [('E1',i,j) for i in range(0,N_A) for j in range(0,M+1)]
E2 = [('E2',i,j) for i in range(0,N_A) for j in range(0,M+1)]
R1 = [('R1',i) for i in range(0,N_A)]
R2 = [('R2',i) for i in range(0,N_A)]
S = list(itertools.chain(G, O, E1, E2, R1, R2))
N_S = len(S)  # derived from S so it cannot disagree with the state list
# Per-player state subsets. Single states are wrapped in lists because
# itertools.chain iterates each argument (a bare tuple would be split).
# NOTE(review): S2 omits G although player 2 also acts there — confirm intent.
S1 = list(itertools.chain(G, [('O',1)], E1, R1))
S2 = list(itertools.chain([('O',2)], E2, R2))

# Stage-game payoff matrices of player 1 / player 2, indexed [a1, a2]
RG1 = np.array([[3.,0.],[5.,1.]])
RG2 = np.array([[3.,5.],[0.,1.]])

In [8]:
def get_state_index(state):
  """Position of `state` in the global state ordering.

  Accepts a state tuple (searched in `S`) or its string form (searched in
  `S_str`); both lists share the same order.
  """
  for catalogue in (S, S_str):
    if state in catalogue:
      return catalogue.index(state)
  assert False, 'Invalid state'

def player_dim(i):
  """Einsum axis letter of player `i`: 1 or '1' -> 'i', 2 or '2' -> 'j'.

  Reward matrices are indexed [player-1 action, player-2 action], so
  player 1 owns the first axis and player 2 the second.
  """
  if i in (1, '1'):
    return 'i'
  if i in (2, '2'):
    return 'j'
  assert False, 'Invalid player id'
  
def get_player_id(player):
  """Zero-based index of a string player label ('1' -> 0, '2' -> 1)."""
  for index, label in enumerate(('1', '2')):
    if player == label:
      return index
  assert False, 'Invalid player'

def players():
  """Integer player ids, 1 and 2."""
  return range(1, 3)

def players_str():
  """Fresh iterator over the string player labels '1' and '2'."""
  labels = list('12')
  return iter(labels)

def state_player_pairs():
  """Iterator over (state tuple, integer player id) pairs, states outermost."""
  player_ids = players()
  return itertools.product(S, player_ids)

S_str = list(map(str, S))  # string keys of the states, same order as S
def state_player_str_pairs():
  """Iterator over (state string, player label) pairs, states outermost."""
  player_labels = players_str()
  return itertools.product(S_str, player_labels)

In [19]:
class full_game(nn.Module):
  """Trainable strategy profile and value table of the game.

  Attributes:
    pi1, pi2: ParameterDicts mapping str(state) to an (n_actions, 1) column of
      action probabilities for player 1 / player 2. Columns start uniform; a
      player without a decision in a state has one action fixed at 1.0.
    v: (N_S, 2) parameter, one value column per player, initialised to zero.
  """

  def __init__(self):
    super().__init__()

    def uniform(n_actions):
      # Trainable column holding n_actions equal probabilities
      return Parameter(torch.full((n_actions, 1), 1.0 / n_actions))

    # Offer actions encode (offer, requested action) pairs, see transition_info
    n_offer_actions = (M + 1) * N_A

    self.pi1 = nn.ParameterDict()
    self.pi1[str(('G',))] = uniform(N_A)
    self.pi1[str(('O', 1))] = uniform(n_offer_actions)
    self.pi1[str(('O', 2))] = uniform(1)
    for state in E1:       # accept / reject
      self.pi1[str(state)] = uniform(2)
    for state in E2 + R2:  # no decision for player 1
      self.pi1[str(state)] = uniform(1)
    for state in R1:
      self.pi1[str(state)] = uniform(N_A)

    self.pi2 = nn.ParameterDict()
    self.pi2[str(('G',))] = uniform(N_A)
    self.pi2[str(('O', 1))] = uniform(1)
    self.pi2[str(('O', 2))] = uniform(n_offer_actions)
    for state in E2:       # accept / reject
      self.pi2[str(state)] = uniform(2)
    for state in E1 + R1:  # no decision for player 2
      self.pi2[str(state)] = uniform(1)
    for state in R2:
      self.pi2[str(state)] = uniform(N_A)

    self.v = Parameter(torch.zeros(N_S, 2))

  def forward(self):
    """Return (strategy dict keyed by player label, value table)."""
    return self.pi(), self.v

  def pi(self):
    """Strategies keyed by player label: {'1': pi1, '2': pi2}."""
    return {'1': self.pi1, '2': self.pi2}

  def pi2vec(self):
    """Per player, all strategy columns concatenated in S_str order."""
    strategies = self.pi()
    return {player: torch.cat([strategies[player][s] for s in S_str], dim=0)
            for player in players_str()}

  def pi_sum(self):
    """Per player, a (len(S_str), 1) column with each state's probability mass."""
    strategies = self.pi()
    return {player: torch.cat([strategies[player][s].sum(0, keepdim=True) for s in S_str], dim=0)
            for player in players_str()}

def copy_game(game_original):
  """Independent CUDA `full_game` holding a copy of `game_original`'s parameters."""
  duplicate = full_game()
  duplicate.load_state_dict(game_original.state_dict())
  return duplicate.to('cuda')

In [10]:
# Instantiate the model once to read how many actions each player has in every
# state (the strategy shapes are fixed by full_game.__init__).
game_b = full_game().to('cuda')

# N_A_S[str(state)] = {'1': #actions of player 1, '2': #actions of player 2}
N_A_S = {}
# Per-state action counts in S order (turned into (n, 1) CUDA tensors below)
N_A_S_tensors = {'1':[], '2':[]}
# Same, but only for states where that player has at least two actions
N_A_S_tensors_reduced = {'1':[], '2':[]}
N_A_total = {}    # total number of actions over all states, per player
N_A_reduced = {}  # total number of actions over states with >= 2 actions, per player
N_S_reduced = {}  # number of states with >= 2 actions, per player
with torch.no_grad():
  pi, v = game_b()
  for state in S:
    N_A1_state = pi['1'][str(state)].shape[0]
    N_A2_state = pi['2'][str(state)].shape[0]
    N_A_S[str(state)] = {'1': N_A1_state, '2': N_A2_state}
    N_A_S_tensors['1'].append(N_A1_state)
    N_A_S_tensors['2'].append(N_A2_state)
    if N_A1_state >= 2:
      N_A_S_tensors_reduced['1'].append(N_A1_state)
    if N_A2_state >= 2:
      N_A_S_tensors_reduced['2'].append(N_A2_state)
for player in players_str():
  N_A_S_tensors[player] = torch.FloatTensor(N_A_S_tensors[player]).to('cuda').view(-1,1)
  N_A_S_tensors_reduced[player] = torch.FloatTensor(N_A_S_tensors_reduced[player]).to('cuda').view(-1,1)
  N_A_total[player] = int(N_A_S_tensors[player].sum().item())
  N_A_reduced[player] = int(N_A_S_tensors_reduced[player].sum().item())
  N_S_reduced[player] = N_A_S_tensors_reduced[player].view(-1).shape[0]

# Number of restrictions: per player, one Nash restriction per action
# (N_A_total), one positivity restriction per action of a state with a choice
# (N_A_reduced) and one unit-sum restriction per such state (N_S_reduced) —
# the row blocks filled in build_grad_tensors.
n_restrictions = 0
for player in players_str():
  n_restrictions += N_A_total[player] + N_A_reduced[player] + N_S_reduced[player]
# Optimisation variables: both value columns plus the free strategy entries
n_vars = 2*N_S + N_A_reduced['1'] + N_A_reduced['2']

In [11]:
def more_than_one_action(vec):
  """True when `vec` has more than one row, i.e. more than one action."""
  return vec.shape[0] > 1


def more_than_one_action_in_s(player,s):
  """True when `player` ('1'/'2') has more than one action in state string `s`."""
  return N_A_S[s][player] > 1

def offer_accepted(action):
  """A responder accepts the pending offer by playing action 0."""
  return bool(action == 0)
  

def mask_inequality_restrictions(vec):
  """Copy of `vec` whose trailing equality rows are zeroed.

  The last N_S_reduced['1'] + N_S_reduced['2'] rows of the stacked
  restriction vector hold the equality (unit-sum) restrictions.
  """
  n_equality = N_S_reduced['1'] + N_S_reduced['2']
  result = vec.clone()
  result[-n_equality:, :] = 0.0
  return result

def mask_equality_restrictions(vec):
  """Copy of `vec` where every row except the trailing equality rows is zeroed."""
  n_equality = N_S_reduced['1'] + N_S_reduced['2']
  result = vec.clone()
  result[:-n_equality, :] = 0.0
  return result


def other_player(i):
  """Opponent of player `i`, keeping the id's type (1 <-> 2, '1' <-> '2')."""
  for mine, theirs in ((1, 2), (2, 1), ('1', '2'), ('2', '1')):
    if i == mine:
      return theirs
  assert False, 'Invalid player id'


def transition_info(state, actions):
  """Transition out of `state` under the joint action `actions`.

  Returns:
    ('R', [(next_state, prob), ...]) for G, R1 and R2, which move to either
      offer state with probability 0.5 regardless of `actions`;
    ('D', next_state) otherwise:
      ('O', p)  -> ('E<other>', requested action, offer), where
                   divmod(action of p, N_A) = (offer, requested action);
      E1 / E2   -> ('R2' / 'R1', requested action) if the responder
                   (player 1 / player 2) accepts, ('G',) if it rejects.
  """
  if 'G' in state or 'R1' in state or 'R2' in state:
    return ('R', [(('O', 1), 0.5), (('O', 2), 0.5)])
  if 'O' in state:
    _, proposer = state
    offeral, action_requested = divmod(actions[proposer - 1], N_A)
    return ('D', ('E' + str(other_player(proposer)), action_requested, offeral))
  if 'E1' in state or 'E2' in state:
    _, action_requested, _ = state
    responder = 0 if 'E1' in state else 1
    if offer_accepted(actions[responder]):
      return ('D', ('R2' if responder == 0 else 'R1', action_requested))
    return ('D', ('G',))
  assert False, 'Invalid state'

In [12]:
# (det, dep) flags per state: whether the transition is deterministic and
# whether it depends on the joint action. G, R1 and R2 move randomly to an
# offer state whatever the actions; O and E states move deterministically
# according to the actions (see transition_info).
transition_types = {}
for state in S:
  if 'G' in state or 'R1' in state or 'R2' in state:
    transition_types[str(state)] = (0,0) 
  else:
    transition_types[str(state)] = (1,1) # entries correspond to deterministic behaviour and dependence on actions

In [13]:
def rewards(state, actions):
  """Immediate rewards (r1, r2) in `state` under the joint action `actions`.

  G: stage-game payoffs RG1/RG2 at (actions[0], actions[1]).
  O: no reward.
  E1 / E2: if the responder (player 1 / player 2) accepts, it receives the
    offer and the proposer pays it; otherwise nothing.
  R1 / R2: stage-game payoffs with the non-moving player's action fixed to the
    requested action stored in the state.
  """
  if 'G' in state:
    a1, a2 = actions[0], actions[1]
    r1, r2 = RG1[a1, a2], RG2[a1, a2]
  elif 'O' in state:
    r1 = r2 = 0.0
  elif 'E1' in state:
    _, _, offeral = state
    r1, r2 = (offeral, -offeral) if offer_accepted(actions[0]) else (0, 0)
  elif 'E2' in state:
    _, _, offeral = state
    r1, r2 = (-offeral, offeral) if offer_accepted(actions[1]) else (0, 0)
  elif 'R1' in state:
    _, requested = state
    r1, r2 = RG1[actions[0], requested], RG2[actions[0], requested]
  elif 'R2' in state:
    _, requested = state
    r1, r2 = RG1[requested, actions[1]], RG2[requested, actions[1]]
  else:
    assert False, 'Invalid state'

  return r1, r2


def reward_matrices(s):
  """NumPy reward matrices (RM1, RM2) of state `s`, indexed [a1, a2].

  For G the stage-game matrices are copied; otherwise each entry comes from
  `rewards(s, [a1, a2])` over every joint action available in `s`.
  """
  if 'G' in s:
    return RG1.copy(), RG2.copy()
  shape = (N_A_S[str(s)]['1'], N_A_S[str(s)]['2'])
  RM1 = np.zeros(shape)
  RM2 = np.zeros(shape)
  for a1, a2 in itertools.product(range(shape[0]), range(shape[1])):
    RM1[a1, a2], RM2[a1, a2] = rewards(s, [a1, a2])
  return RM1, RM2

In [14]:
# Reward matrices of every state as CUDA float tensors:
# RM[player][str(state)] is indexed [player-1 action, player-2 action].
RM = {}
RM['1'] = {}
RM['2'] = {}
for s in S:
  RM1, RM2 = reward_matrices(s)
  RM['1'][str(s)] = torch.FloatTensor(RM1).to('cuda')
  RM['2'][str(s)] = torch.FloatTensor(RM2).to('cuda')

In [15]:
def player_consistent_reward_matrices():
  """Copies of RM with each player's own action on the first axis.

  Player 1's matrices are plain clones; player 2's are transposed so that
  rows index player 2's actions.
  """
  return {
      '1': {s: RM['1'][s].clone() for s in S_str},
      '2': {s: torch.t(RM['2'][s].clone()) for s in S_str},
  }

In [16]:
def next_value_matrices(s, v): # TODO: consider other 2 cases
  """Both players' next-state values for every joint action in state `s`.

  Returns a (n1, n2, 2) CUDA tensor whose [a1, a2, :] entry is row
  `v[next_state]` for deterministic action-dependent transitions, or the
  probability-weighted average of successor rows for random
  action-independent ones. Other transition types are left as zeros.
  """
  deterministic, action_dependent = transition_types[str(s)]
  n1, n2 = N_A_S[str(s)]['1'], N_A_S[str(s)]['2']
  values = torch.zeros((n1, n2, 2)).to('cuda')
  if deterministic and action_dependent:
    for a1, a2 in itertools.product(range(n1), range(n2)):
      _, successor = transition_info(s, [a1, a2])
      values[a1, a2, :] = v[get_state_index(successor), :]
  elif not deterministic and not action_dependent:
    _, outcomes = transition_info(s, [])
    expected = sum((v[get_state_index(ns), :].view(1, -1) * prob for ns, prob in outcomes),
                   torch.zeros(1, 2).to('cuda'))
    # Same expected value for every joint action
    values[:, :, :] = expected.view(-1)
  return values

def next_value_dictionary(v):
  """`next_value_matrices(state, v)` for every state, keyed by str(state)."""
  return {str(state): next_value_matrices(state, v) for state in S}


def transition_matrix(pi): # TODO: consider other 2 cases
  """(N_S, N_S) CUDA state-transition matrix induced by the strategy profile `pi`.

  Deterministic action-dependent states accumulate pi1[a1] * pi2[a2] onto
  the successor of each joint action; random action-independent states copy
  their fixed successor distribution. Rows of other transition types stay zero.
  """
  P = torch.zeros((N_S, N_S)).to('cuda')
  for state in S:
    key = str(state)
    p1, p2 = pi['1'][key], pi['2'][key]
    deterministic, action_dependent = transition_types[key]
    row = get_state_index(state)
    if deterministic and action_dependent:
      for a1, a2 in itertools.product(range(p1.shape[0]), range(p2.shape[0])):
        _, successor = transition_info(state, [a1, a2])
        col = get_state_index(successor)
        P[row, col] = P[row, col] + p1[a1, 0] * p2[a2, 0]
    elif not deterministic and not action_dependent:
      _, outcomes = transition_info(state, [])
      for successor, prob in outcomes:
        P[row, get_state_index(successor)] = prob
  return P


def partial_transition_matrices(pi): # TODO: consider other 2 cases
  """Per-action transition matrices of each player, with the opponent's strategy fixed.

  Returns a dict str(state) -> {'1': (n1, N_S), '2': (n2, N_S)} of CUDA
  tensors. Row a of player p's matrix is the next-state distribution when p
  plays a and the opponent follows `pi`. For random action-independent
  transitions every row is the same fixed distribution; other transition
  types stay zero.
  """
  matrices = {}
  for state in S:
    key = str(state)
    p1, p2 = pi['1'][key], pi['2'][key]
    n1, n2 = p1.shape[0], p2.shape[0]
    T1 = torch.zeros((n1, N_S)).to('cuda')
    T2 = torch.zeros((n2, N_S)).to('cuda')
    matrices[key] = {'1': T1, '2': T2}

    deterministic, action_dependent = transition_types[key]
    if deterministic and action_dependent:
      for a1, a2 in itertools.product(range(n1), range(n2)):
        _, successor = transition_info(state, [a1, a2])
        col = get_state_index(successor)
        # Player 1's row a1 weights by the opponent's probability and vice versa
        T1[a1, col] = T1[a1, col] + p2[a2, 0]
        T2[a2, col] = T2[a2, col] + p1[a1, 0]
    elif not deterministic and not action_dependent:
      _, outcomes = transition_info(state, [])
      for successor, prob in outcomes:
        col = get_state_index(successor)
        T1[:, col] = prob
        T2[:, col] = prob
  return matrices


def expected_reward(RM, pi):
  """Expected one-step reward of every state for both players under `pi`.

  Args:
    RM: dict player -> dict str(state) -> (n1, n2) reward tensor.
    pi: dict player -> dict str(state) -> (n_actions, 1) strategy column.

  Returns:
    (N_S, 2) CUDA tensor; entry [k, p-1] is pi1(s)^T RM_p(s) pi2(s) for state k.
  """
  expected = torch.zeros((N_S, 2)).to('cuda')
  for column, player in enumerate(['1', '2']):
    rewards_of_player = RM[player]
    for state in S:
      key = str(state)
      # Average over player 1's mixed action first -> (n2, 1) column
      marginal = torch.einsum('ij,ik->jk', rewards_of_player[key], pi['1'][key])
      expected[get_state_index(state), column] = (marginal * pi['2'][key]).sum()
  return expected


def partial_expected_reward_other(RM, pi):
  """Per-action expected reward of each player, averaged over the opponent's strategy.

  Returns:
    dict player -> dict str(state) -> (n_actions_of_player, 1) tensor whose
    entry a is sum_b RM_player(s)[a vs b] * pi_opponent(s)[b].
  """
  result = {p: {} for p in players_str()}
  for s, player in state_player_str_pairs():
    opponent = other_player(player)
    spec = 'ij,{}->{}'.format(player_dim(opponent), player_dim(player))
    opponent_strategy = pi[opponent][s].view(-1)
    result[player][s] = torch.einsum(spec, RM[player][s], opponent_strategy).view(-1, 1)
  return result


def partial_expected_reward(RM, pi):
  """Per-action expected rewards of both players, indexed by the acting player.

  result[p][q][s] is an (n_actions_of_p, 1) tensor whose entry a is
  sum_b RM_q(s)[a vs b] * pi_o(s)[b]: the reward player q expects when the
  acting player p commits to action a while the opponent o keeps its strategy.
  """
  result = {p: {'1': {}, '2': {}} for p in players_str()}
  for s, player in state_player_str_pairs():
    opponent = other_player(player)
    spec = 'ij,' + player_dim(opponent) + '->' + player_dim(player)
    opponent_strategy = pi[opponent][s].view(-1)
    for owner in players_str():
      result[player][owner][s] = torch.einsum(spec, RM[owner][s], opponent_strategy).view(-1, 1)
  return result


def bellman_projection(RM, pi, v):
  """Apply the Bellman operator of the fixed strategy profile `pi` to `v`.

  Returns:
    (N_S, 2) tensor: expected reward + beta * expected next-state value,
    one column per player.
  """
  continuation = torch.zeros((N_S, 2)).to('cuda')
  for state in S:
    values_after = next_value_matrices(state, v)  # (n1, n2, 2)
    p1 = pi['1'][str(state)].squeeze(1)
    p2 = pi['2'][str(state)].squeeze(1)
    averaged_over_p1 = torch.einsum('ijk,i->jk', values_after, p1)
    continuation[get_state_index(state), :] = torch.einsum('jk,j->k', averaged_over_p1, p2)
  return expected_reward(RM, pi) + beta * continuation


def bellman_partial_projection_other(RM, pi, v):
  """Per-action Q-values of each player with the opponent's strategy held fixed.

  Returns:
    dict player -> dict str(state) -> (n_actions_of_player, 1) tensor equal to
    r_player(s, a) + beta * E[v_player(s') | a], averaging over the opponent.
  """
  r_mean = partial_expected_reward_other(RM, pi)
  continuation = {'1': {}, '2': {}}
  for state in S:
    key = str(state)
    values_after = next_value_matrices(state, v)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)
    continuation['1'][key] = torch.einsum('ij,j->i', values_after[:, :, 0], p2).view(-1, 1)
    continuation['2'][key] = torch.einsum('ij,i->j', values_after[:, :, 1], p1).view(-1, 1)

  q_values = {'1': {}, '2': {}}
  for state, player_id in state_player_pairs():
    player, key = str(player_id), str(state)
    q_values[player][key] = r_mean[player][key] + beta * continuation[player][key]
  return q_values


def partial_next_values(pi, v):
  """Expected next-state values per action of one player, for both value columns.

  result[p][q][s] is an (n_actions_of_p, 1) tensor: the expected value of
  column q of `v` after player p plays each action while the opponent follows
  its strategy in `pi`.
  """
  result = {p: {'1': {}, '2': {}} for p in players_str()}
  for state in S:
    key = str(state)
    values_after = next_value_matrices(state, v)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)
    for owner in players():
      column = values_after[:, :, owner - 1]
      result['1'][str(owner)][key] = torch.einsum('ij,j->i', column, p2).view(-1, 1)
      result['2'][str(owner)][key] = torch.einsum('ij,i->j', column, p1).view(-1, 1)
  return result


def bellman_partial_projection(RM, pi, v):
  """Per-action Q-values for every (acting player, value owner) combination.

  result[p][q][s] = partial_expected_reward[p][q][s]
                    + beta * partial_next_values[p][q][s].
  """
  reward_part = partial_expected_reward(RM, pi)
  value_part = partial_next_values(pi, v)
  q_values = {p: {'1': {}, '2': {}} for p in players_str()}
  for s, player in state_player_str_pairs():
    for owner in players_str():
      q_values[player][owner][s] = reward_part[player][owner][s] + beta * value_part[player][owner][s]
  return q_values


def reward_baselines(RM, pi):
  """State-averaged expected reward of each player, as a (2, 1) NumPy array."""
  per_state = expected_reward(RM, pi)
  averaged = per_state.mean(0).view(-1, 1)
  return averaged.detach().cpu().numpy()

In [17]:
def cost_vector_fixed_policies(pi):
  """LP objective weights (1 - beta * sum_s P[s, s']) / N_S as an (N_S, 1) NumPy array."""
  inflow = transition_matrix(pi).sum(0)  # column sums: probability mass entering each state
  weights = (1 - beta * inflow).view(-1, 1)
  return weights.detach().cpu().numpy() / N_S

def restriction_matrices_fixed_policies(pi):
  """LP inequality matrices, one stacked NumPy array per player.

  For every state s and player p the block is -(1[s] - beta * P_p(s)), where
  P_p(s) is p's per-action transition matrix from
  `partial_transition_matrices` and 1[s] puts ones in column s. Blocks are
  stacked in state order: sum_s n_actions_p(s) rows and N_S columns.
  """
  partial_P = partial_transition_matrices(pi)
  rows = {'1': [], '2': []}
  for state, player_id in state_player_pairs():
    block = -beta * partial_P[str(state)][str(player_id)]
    column = get_state_index(state)
    block[:, column] = block[:, column] + 1
    rows[str(player_id)].append(block)
  return {p: -torch.cat(rows[p], dim=0).detach().cpu().numpy() for p in players_str()}

def restriction_vectors_fixed_policies(RM, pi, alpha=0.1):
  """LP right-hand sides -(r_p(s, a) + alpha), stacked per player as NumPy columns.

  r_p(s, a) is player p's per-action expected reward against the opponent's
  strategy (`partial_expected_reward_other`); `alpha` adds a margin.
  """
  r_mean = partial_expected_reward_other(RM, pi)
  pieces = {'1': [], '2': []}
  for s, player in state_player_str_pairs():
    pieces[player].append(r_mean[player][s].view(-1, 1))
  return {p: -(torch.cat(pieces[p], dim=0) + alpha).detach().cpu().numpy() for p in players_str()}


def parameters_fixed_policies(game_, alpha):
  """LP data (c, f0, A_ub, b_ub) for the strategies currently stored in `game_`."""
  strategies = game_()[0]
  cost = cost_vector_fixed_policies(strategies)
  baseline = reward_baselines(RM, strategies)
  inequality_matrices = restriction_matrices_fixed_policies(strategies)
  inequality_bounds = restriction_vectors_fixed_policies(RM, strategies, alpha)
  return cost, baseline, inequality_matrices, inequality_bounds


def calculate_initial_v(game_, alpha=10):
  """Initial value table from one LP per player, with `game_`'s strategies fixed.

  For each player p, solves min c^T v s.t. A_ub[p] v <= b_ub[p] with
  `scipy.optimize.linprog`, using linprog's default variable bounds.

  Args:
    game_: `full_game` whose strategies define the LP data.
    alpha: margin passed to `restriction_vectors_fixed_policies`.

  Returns:
    (N_S, 2) NumPy array; column p-1 is player p's LP solution.

  Raises:
    RuntimeError: if an LP fails. Without this check `res.x` (None) would be
      written into the table instead of a solution.
  """
  v0 = np.zeros((N_S, 2))
  c, _, A_ub, b_ub = parameters_fixed_policies(game_, alpha=alpha)
  for player_id in players():
    key = str(player_id)
    # linprog expects 1-D cost and right-hand-side vectors
    res = linprog(np.ravel(c), A_ub=A_ub[key], b_ub=np.ravel(b_ub[key]))
    if not res.success:
      raise RuntimeError(
          'Initial value LP failed for player {}: {}'.format(player_id, res.message))
    v0[:, player_id - 1] = res.x
  return v0

In [18]:
def calculate_nash_restrictions(pi, v):
  """Nash restriction values q_p(s, a) - v_p(s) for every state, player and action.

  q comes from `bellman_partial_projection_other`. The KKT check treats
  values <= 0 as satisfied.

  Returns:
    dict player -> dict str(state) -> (n_actions, 1) tensor.
  """
  q = bellman_partial_projection_other(RM, pi, v)
  restrictions = {'1': {}, '2': {}}
  for state, player_id in state_player_pairs():
    player, key = str(player_id), str(state)
    own_value = v[get_state_index(state), player_id - 1]
    restrictions[player][key] = q[player][key] - own_value
  return restrictions


def calculate_bellman_error(pi, v):
  """Original objective: summed Bellman residual sum_{s,p} (v - T_pi v)[s, p], a scalar tensor."""
  residual = v - bellman_projection(RM, pi, v)
  return residual.sum()


def check_nash_KKT_conditions(pi, v, lambda_nash, tol=1e-8):
  """Check feasibility and complementary slackness of the Nash restrictions.

  Args:
    pi, v: current strategies and value table.
    lambda_nash: dict player -> dual column whose leading rows are that
      player's Nash duals, ordered by state (`S_str`) then action.
    tol: tolerance on |lambda * g| for complementary slackness.

  Returns:
    (all g <= 0, max g, all |lambda * g| <= tol, max |lambda * g|)
  """
  with torch.no_grad():
    # Calculate restrictions
    g_nash = calculate_nash_restrictions(pi, v)

    g_nash_satisfied = True
    product_zero_satisfied = True

    # np.inf: the np.infty alias was removed in NumPy 2.0
    max_g_nash = -np.inf
    max_product_zero = -np.inf

    for player in players_str():
      remaining_duals = lambda_nash[player].clone()
      for s in S_str:
        NA = N_A_S[s][player]
        g = g_nash[player][s]
        g_nash_satisfied = g_nash_satisfied and torch.all(g <= 0)
        max_g_nash = max(max_g_nash, g.max())

        lambda_g_nash_product = g.view(-1) * remaining_duals[:NA,:].view(-1)
        product_zero_satisfied = product_zero_satisfied and torch.all(lambda_g_nash_product.abs() <= tol)
        max_product_zero = max(max_product_zero, lambda_g_nash_product.abs().max())
        # Consume this state's duals so the next state starts at row 0
        remaining_duals = remaining_duals[NA:,:]
    return g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero


def check_pi_KKT_conditions(pi_vector, pi_sum_vector, duals, tol=1e-8):
  """Check feasibility and complementary slackness of the strategy restrictions.

  Args:
    pi_vector: dict player -> (N_A_total[player], 1) strategies of every
      state concatenated in `S_str` order (as built by `full_game.pi2vec`).
    pi_sum_vector: dict player -> (len(S_str), 1) per-state strategy sums
      (as built by `full_game.pi_sum`).
    duals: dict player -> column laid out as [Nash duals (N_A_total),
      positivity duals (N_A_reduced), unit-sum duals (N_S_reduced)].
    tol: tolerance on |lambda * g| for complementary slackness.

  Returns:
    (all pi >= 0, max(-pi), positivity slackness ok, max |lambda * (-pi)|,
     all sums <= 1, max(sum - 1), unit-sum slackness ok, max |lambda * (sum - 1)|)
  """
  with torch.no_grad():
    g_pi_plus_satisfied = True
    g_pi_one_satisfied = True

    product_zero_plus_satisfied = True
    product_zero_one_satisfied = True

    # np.inf: the np.infty alias was removed in NumPy 2.0
    max_g_pi_plus = -np.inf
    max_g_pi_one = -np.inf

    max_product_zero_plus = -np.inf
    max_product_zero_one = -np.inf

    for player in players_str():
      g_pi_plus_satisfied = g_pi_plus_satisfied and torch.all(pi_vector[player] >= 0)
      g_pi_one_satisfied = g_pi_one_satisfied and torch.all(pi_sum_vector[player] <= 1)

      max_g_pi_plus = max(max_g_pi_plus, -pi_vector[player].min().item())
      # The worst value of the restriction sum(pi) - 1 is its maximum
      max_g_pi_one = max(max_g_pi_one, (pi_sum_vector[player]-1).max().item())

      # Positivity / unit-sum duals exist only for states where the player has
      # more than one action, so drop the fixed single-action entries first.
      action_mask = []
      state_mask = []
      for s in S_str:
        n_actions = N_A_S[s][player]
        action_mask.extend([n_actions > 1] * n_actions)
        state_mask.append(n_actions > 1)
      pi_reduced = pi_vector[player].view(-1)[
          torch.tensor(action_mask, dtype=torch.bool, device=pi_vector[player].device)]
      pi_sum_reduced = pi_sum_vector[player].view(-1)[
          torch.tensor(state_mask, dtype=torch.bool, device=pi_sum_vector[player].device)]

      NAt = N_A_total[player]
      NAr = N_A_reduced[player]
      lambda_g_pi_plus_product = -pi_reduced * duals[player][NAt:NAt+NAr,:].view(-1)
      product_zero_plus_satisfied = product_zero_plus_satisfied and torch.all(lambda_g_pi_plus_product.abs() <= tol)
      max_product_zero_plus = max(max_product_zero_plus, lambda_g_pi_plus_product.abs().max())
      lambda_g_pi_one_product = (pi_sum_reduced-1) * duals[player][NAt+NAr:,:].view(-1)
      product_zero_one_satisfied = product_zero_one_satisfied and torch.all(lambda_g_pi_one_product.abs() <= tol)
      max_product_zero_one = max(max_product_zero_one, lambda_g_pi_one_product.abs().max())

    return (g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus, 
            g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one)

In [None]:
def bellman_error_gradients(pi, v):
  """Gradients of `calculate_bellman_error` w.r.t. the value table and strategies.

  f = sum_{s,p} (v[s,p] - r_p(s) - beta * sum_s' P[s,s'] v[s',p]), so
    df/dv[s', p] = 1 - beta * sum_s P[s, s']   (same for both columns) and
    df/dpi_p(s)[a] = -sum_q Q_{p,q}(s, a), with Q from `bellman_partial_projection`.

  Returns:
    grad_v: (N_S, 1) CUDA tensor.
    grad_pi: dict player -> dict str(state) -> (n_actions, 1) tensor.
  """
  P = transition_matrix(pi)
  identity = torch.eye(N_S).to('cuda')
  # The continuation term of f is discounted, so its derivative carries beta
  # (same expression as the LP weights in cost_vector_fixed_policies).
  bellman_error_grad_v = (identity - beta * torch.t(P)).sum(1, keepdim=True)

  q_individual = bellman_partial_projection(RM, pi, v)
  bellman_error_grad_pi = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    total = 0.0
    for second_player in players_str():
      total = total - q_individual[player][second_player][s]
    bellman_error_grad_pi[player][s] = total
  return bellman_error_grad_v, bellman_error_grad_pi

def nash_restriction_gradients(pi, v):
  """Jacobian blocks of the Nash restrictions.

  Returns:
    grad_v: dict player -> dict str(state) -> (n_actions_of_player, N_S)
      derivative of that player's restrictions w.r.t. its own value column,
      beta * P_player(s) - 1[s] (ones in column s).
    grad_pi: dict player p -> dict str(state) -> (n_actions_of_opponent,
      n_actions_of_p) derivative of the OPPONENT's restrictions w.r.t.
      pi_p(s): R_o(s) + beta * V_o(next), opponent's actions on the rows.
  """
  P_partial = partial_transition_matrices(pi)
  values_after = next_value_dictionary(v)
  rewards_by_owner = player_consistent_reward_matrices()

  grad_v = {'1': {}, '2': {}}
  grad_pi = {'1': {}, '2': {}}
  for s, player in state_player_str_pairs():
    own_P = P_partial[s][player]
    indicator = torch.zeros_like(own_P)
    indicator[:, get_state_index(s)] = 1.0
    grad_v[player][s] = beta * own_P - indicator

    opponent = other_player(player)
    opponent_values = values_after[s][:, :, get_player_id(opponent)]
    if player == '1':
      # Put the opponent's (player 2's) actions on the rows
      opponent_values = torch.t(opponent_values)
    grad_pi[player][s] = rewards_by_owner[opponent][s] + beta * opponent_values
  return grad_v, grad_pi


def grad_bellman_v2vec(grad_f_v, grad_f_pi):
  """Stack the Bellman-error gradient into one (n_vars, 1) CUDA column.

  Layout: [grad_v for value column 1 (N_S), grad_v again for column 2 (N_S),
  player 1 strategy gradients (N_A_reduced['1']), player 2 strategy gradients
  (N_A_reduced['2'])]. Strategy parts only include states where that player
  has more than one action.
  """
  n_vars = 2*N_S + N_A_reduced['1'] + N_A_reduced['2']
  stacked = torch.zeros(n_vars, 1).to('cuda')

  for column in range(2):
    stacked[column*N_S:(column + 1)*N_S, :] = grad_f_v.clone()

  start = 2*N_S
  for player in players_str():
    pieces = [grad_f_pi[player][s].clone() for s in S_str
              if more_than_one_action_in_s(player, s)]
    stop = start + N_A_reduced[player]
    stacked[start:stop, :] = torch.cat(pieces, dim=0)
    start = stop

  return stacked

def grad_bellman_pi2mat(player, grad_g_pi):
  """Assemble d(opponent's Nash restrictions) / d(player's strategies) as one matrix.

  Rows follow the opponent's Nash restrictions (every action of every state
  in `S_str` order, N_A_total[opponent] rows). Columns follow `player`'s
  reduced strategy vector (states where `player` has more than one action,
  N_A_reduced[player] columns). States where `player` has a single action
  contribute zero rows and no columns.

  Args:
    player: '1' or '2'.
    grad_g_pi: dict player -> dict str(state) -> (n_opponent_actions,
      n_player_actions) block, as returned by `nash_restriction_gradients`.
  """
  # Needed to size the zero rows of skipped states; previously referenced
  # without being defined here (NameError).
  other_player_ = other_player(player)
  blocks = []
  pending_null_rows = 0  # opponent rows of skipped states not yet placed
  for s in S_str:
    if more_than_one_action_in_s(player, s):
      grad = grad_g_pi[player][s].clone()
      if pending_null_rows > 0:
        # Prepend zero rows for the skipped states before this one
        padded = torch.zeros(pending_null_rows + grad.shape[0], grad.shape[1],
                             dtype=grad.dtype, device=grad.device)
        padded[pending_null_rows:, :] = grad
        grad = padded
        pending_null_rows = 0
      blocks.append(grad)
    else:
      pending_null_rows += N_A_S[s][other_player_]
  if pending_null_rows > 0:
    # Trailing skipped states: append their zero rows to the last block
    last = blocks[-1]
    padded = torch.zeros(last.shape[0] + pending_null_rows, last.shape[1],
                         dtype=last.dtype, device=last.device)
    padded[:last.shape[0], :] = last
    blocks[-1] = padded

  J_gmi_pit = torch.block_diag(*blocks)
  return J_gmi_pit

def build_grad_tensors(grad_f_v, grad_f_pi, grad_g_v, grad_g_pi):
  """Assemble the objective gradient vector and the restriction Jacobian.

  Columns (variables), n_vars in total:
    [v column 1 (N_S), v column 2 (N_S),
     reduced pi of player 1 (N_A_reduced['1']), reduced pi of player 2 (N_A_reduced['2'])]
  Rows (restrictions), n_restrictions in total:
    [Nash player 1 (N_A_total['1']), Nash player 2 (N_A_total['2']),
     positivity player 1 (N_A_reduced['1']), positivity player 2 (N_A_reduced['2']),
     unit-sum player 1 (N_S_reduced['1']), unit-sum player 2 (N_S_reduced['2'])]

  Args:
    grad_f_v, grad_f_pi: Bellman-error gradients (see grad_bellman_v2vec).
    grad_g_v, grad_g_pi: Nash-restriction gradients, keyed player -> state.

  Returns:
    (grad_f_vector (n_vars, 1), grad_g_matrix (n_restrictions, n_vars)).
  """

  grad_f_vector = grad_bellman_v2vec(grad_f_v, grad_f_pi)
    
  grad_g_matrix = torch.zeros(n_restrictions, n_vars).to('cuda')
  
  for player in players_str():
    other_player_ = other_player(player)

    # fill jacobian with restriction gradients for v
    # (player's own Nash rows x player's own value column)
    grad_g_v_list = [grad_g_v[player][s].clone() for s in S_str]
    J_gi_vi = torch.cat(grad_g_v_list, dim=0)
    y0 = get_player_id(player) * N_A_total['1']
    yf = N_A_total['1'] + get_player_id(player) * N_A_total['2']
    x0 = get_player_id(player) * N_S
    xf = x0 + N_S
    grad_g_matrix[y0:yf,x0:xf] = J_gi_vi

    # fill jacobian with nash restriction gradients for pi 
    # (opponent's Nash rows x this player's reduced strategy columns)
    J_gmi_pit = grad_bellman_pi2mat(player, grad_g_pi)
    y0 = get_player_id(other_player_) * N_A_total['1']
    yf = N_A_total['1'] + get_player_id(other_player_) * N_A_total['2']
    x0 = 2*N_S + get_player_id(player) * N_A_reduced['1']
    xf = 2*N_S + N_A_reduced['1'] + get_player_id(player) * N_A_reduced['2']
    grad_g_matrix[y0:yf,x0:xf] = J_gmi_pit

    # fill jacobian with positivity restriction gradients for pi 
    # (g = -pi, so the block is -I; x0:xf still span this player's strategy columns)
    J_gpi_plus = -torch.eye(N_A_reduced[player]).to('cuda')
    y0 = N_A_total['1'] + N_A_total['2'] + get_player_id(player) * N_A_reduced['1']
    yf = (N_A_total['1'] + N_A_total['2'] + N_A_reduced['1'] 
          + get_player_id(player) * N_A_reduced['2'])
    grad_g_matrix[y0:yf,x0:xf] = J_gpi_plus

    # fill jacobian with unitary sum restriction gradients for pi
    # (one row of ones per state with a choice, block-diagonal over states)
    grad_gpi_one_list = []
    for s in S_str:
      NA = N_A_S[s][player]
      if more_than_one_action_in_s(player, s):
        grad_gpi_one_list.append(torch.ones(1,NA).to('cuda'))

    J_gpi_one = torch.block_diag(*grad_gpi_one_list)
    y0 = (N_A_total['1'] + N_A_total['2'] + N_A_reduced['1'] 
          + N_A_reduced['2'] + get_player_id(player) * N_S_reduced['1'])
    yf = (N_A_total['1'] + N_A_total['2'] + N_A_reduced['1'] + N_A_reduced['2'] 
          + N_S_reduced['1'] + get_player_id(player) * N_S_reduced['2'])
    grad_g_matrix[y0:yf,x0:xf] = J_gpi_one

  return grad_f_vector, grad_g_matrix


def g_dics2vec(g_nash, pi_vector, pi_sum_vector):
  """Stack all restriction values into one (n_restrictions, 1) column.

  The row layout matches the rows of `grad_g_matrix` in `build_grad_tensors`.
  It also matches the mask helpers, which treat the final
  N_S_reduced['1'] + N_S_reduced['2'] rows as equality restrictions:
    1. Nash restrictions of player 1, then player 2: every action of every
       state in `S_str` order;
    2. positivity restrictions -pi of player 1, then player 2, only for
       states where that player has more than one action;
    3. unit-sum restrictions sum(pi) - 1 of player 1, then player 2, for the
       same states.
  Single-action states are fixed at probability 1 and carry no positivity or
  unit-sum restriction, so including them would misalign the rows.

  Args:
    g_nash: dict player -> dict str(state) -> (n_actions, 1) Nash restrictions.
    pi_vector: dict player -> (N_A_total[player], 1) strategies concatenated
      in `S_str` order (see `full_game.pi2vec`).
    pi_sum_vector: dict player -> (len(S_str), 1) per-state strategy sums
      (see `full_game.pi_sum`).
  """
  nash_rows = [g_nash[player][s] for player in players_str() for s in S_str]
  positivity_rows = []
  unit_sum_rows = []
  for player in players_str():
    offset = 0  # start of state s's actions inside pi_vector[player]
    for state_pos, s in enumerate(S_str):
      n_actions = N_A_S[s][player]
      if n_actions > 1:
        positivity_rows.append(-pi_vector[player][offset:offset + n_actions])
        unit_sum_rows.append(pi_sum_vector[player][state_pos:state_pos + 1] - 1)
      offset += n_actions
  g_vector = torch.cat(nash_rows + positivity_rows + unit_sum_rows, dim=0)
  return g_vector


def calculate_descent_direction(g_vector, grad_f_vector, grad_g_matrix):
  """Solve the first linear system for the descent direction d0 and its duals.

  With G = grad_g_matrix: A = diag(r_value * g_ineq) - G G^T and
  b = G grad_f - g_eq; solve A lambda0 = b and set d0 = -grad_f - G^T lambda0.
  `r_value` is read from module scope and is not defined in this section
  (NOTE(review): confirm it is set before calling).

  Returns:
    (d0_vector, ||d0||^2 as a float, duals_0_vector, A_matrix, b_vector)
  """
  g_vector_ineq = mask_inequality_restrictions(g_vector)
  g_vector_eq = mask_equality_restrictions(g_vector)
  g_diag_matrix = torch.diag(r_value * g_vector_ineq.view(-1))

  A_matrix = g_diag_matrix - torch.einsum('ik,jk->ij', grad_g_matrix, grad_g_matrix) 
  b_vector = torch.einsum('ij,jk->ik', grad_g_matrix, grad_f_vector) - g_vector_eq 
  # torch.solve(B, A) is deprecated and gone from recent PyTorch releases;
  # torch.linalg.solve takes (A, B) and returns only the solution.
  duals_0_vector = torch.linalg.solve(A_matrix, b_vector)

  d0_vector = - grad_f_vector - torch.einsum('ij,ik->jk', grad_g_matrix, duals_0_vector)
  norm_2_d0 = d0_vector.pow(2).sum().item() 
  return d0_vector, norm_2_d0, duals_0_vector, A_matrix, b_vector


def calculate_feasible_direction(g_vector, c_vector, duals_0_vector, 
                                 norm_2_d0, A_matrix, b_vector, rho,
                                  grad_f_vector, grad_g_matrix):
  """Deflect the descent direction into a feasible direction d.

  Solves A x = 1 to obtain div = sum(lambda0) + sum(c) - g_eq . x. When
  div > 0 and (1 - alpha) / div < rho, rho is reduced to half of that value.
  Then solves A lambda = b - rho * ||d0||^2 and sets d = -grad_f - G^T lambda.
  `alpha` is read from module scope and is not defined in this section
  (NOTE(review): confirm it is set before calling).

  Returns:
    (d_vector, duals_vector, rho actually used)
  """
  # torch.solve(B, A) is deprecated and gone from recent PyTorch releases;
  # torch.linalg.solve takes (A, B) and returns only the solution.
  b_unitary = torch.linalg.solve(A_matrix, torch.ones_like(b_vector))
  g_vector_eq = mask_equality_restrictions(g_vector)
  dot = (g_vector_eq.view(-1) * b_unitary.view(-1)).sum()

  div = (duals_0_vector.sum() + c_vector.sum() - dot).item()
  if div > 0:
    rho_1 = (1-alpha) / div
    if rho_1 < rho:
      rho = 0.5 * rho_1

  duals_vector = torch.linalg.solve(A_matrix, b_vector - rho * norm_2_d0)
  d_vector = - grad_f_vector - torch.einsum('ij,ik->jk', grad_g_matrix, duals_vector)
  return d_vector, duals_vector, rho


def calculate_auxiliary_bellman_error(f, g_vector, c_vector):
  """Merit value theta = f - sum(c * g), penalising the restrictions with weights `c_vector`."""
  weighted_restrictions = (c_vector.view(-1) * g_vector.view(-1)).sum()
  return f - weighted_restrictions


def feasible_gradient_descent(game_0, f, g_vector, c_vector, d_vector,
                              grad_f_vector, grad_g_matrix, d_v, d_pi, 
                              duals_vector, max_steps=1200, 
                              verbose=False):
  """Backtracking line search along a feasible descent direction.

  Starting from step size 1, the step `step_size * (d_v, d_pi)` is applied to
  a copy of `game_0`. The first step is accepted that satisfies both:
    * theta = f - c^T g drops to at most theta + eta * step_size * decrement,
      where decrement = d . (grad_f + G^T c);
    * every restriction satisfies g_new <= gamma * g_old. gamma is gamma_0
      where the dual is non-negative, 1 where it is negative, and 0 on the
      trailing unit-sum (equality) rows.
  After each rejected trial the step size is divided by `nu`.
  `gamma_0`, `eta` and `nu` are read from module scope and are not defined in
  this section (NOTE(review): confirm they are set before calling).

  Args:
    game_0: current `full_game`; it is copied, never modified.
    f, g_vector: Bellman error and stacked restrictions at `game_0`.
    c_vector: restriction weights of the merit function.
    d_vector: stacked search direction.
    d_v, d_pi: its value and per-state strategy parts (as from `vec2dic`).
    grad_f_vector, grad_g_matrix: gradients at `game_0`.
    duals_vector: restriction duals, used to build gamma.
    max_steps: maximum number of trial step sizes (at least 1).
    verbose: print diagnostics after each trial.

  Returns:
    (last trial game, whether it was accepted, number of trials, its Bellman error)
  """
  # Calculate loss function
  theta = calculate_auxiliary_bellman_error(f, g_vector, c_vector)
  
  # Calculate required decrement in the loss function
  grad_theta_vector = grad_f_vector + torch.einsum('ij,ik->jk', grad_g_matrix, c_vector)
  decrement = (d_vector.view(-1) * grad_theta_vector.view(-1)).sum().item()

  # Rows [0, n_nash) of the stacked restrictions are both players' Nash
  # restrictions; the remaining rows are the strategy restrictions.
  n_nash = N_A_total['1'] + N_A_total['2']
  n_equality = N_S_reduced['1'] + N_S_reduced['2']

  # gamma depends only on the duals, so it is built once.
  gamma = gamma_0 * torch.ones_like(duals_vector)
  gamma[duals_vector < 0] = 1.0
  if n_equality > 0:
    # Slice over all trailing equality rows (the same rows the mask helpers use)
    gamma[-n_equality:, :] = 0.0
  g_bound = g_vector * gamma

  # Procedure to find feasible step size 
  found_feasible_step_size = False
  step_size = 1.0
  n_step = 0
  max_steps = max(1, max_steps)
  while (not found_feasible_step_size) and (n_step < max_steps):
    game_temp = copy_game(game_0)
    # Update parameters performing step in feasible descent direction 
    game_temp.v.data.add_(step_size * d_v)
    for s in S_str:
      # Single-action states have no direction (their strategy is fixed)
      if more_than_one_action_in_s('1', s):
        game_temp.pi1[s].data.add_(step_size * d_pi['1'][s])
      if more_than_one_action_in_s('2', s):
        game_temp.pi2[s].data.add_(step_size * d_pi['2'][s])

    pi_temp, v_temp = game_temp()
    pi_vector_temp = game_temp.pi2vec()
    pi_sum_vector_temp = game_temp.pi_sum()
    f_temp = calculate_bellman_error(pi_temp, v_temp)
    g_nash_temp = calculate_nash_restrictions(pi_temp, v_temp)
    g_vector_temp = g_dics2vec(g_nash_temp, pi_vector_temp, pi_sum_vector_temp)
    theta_temp = calculate_auxiliary_bellman_error(f_temp, g_vector_temp, c_vector)

    theta_decreased = theta_temp <= theta + eta * step_size * decrement
    g_v_valid = torch.all(g_vector_temp[:n_nash] <= g_bound[:n_nash])
    g_pi_valid = torch.all(g_vector_temp[n_nash:] <= g_bound[n_nash:])

    # Check if the current step size results in a feasible descent direction
    found_feasible_step_size = bool(theta_decreased and g_v_valid and g_pi_valid)

    n_step += 1
    if verbose:
      print(theta_decreased.item(), g_v_valid.item(), g_pi_valid.item())
      print("Step: {}, Step size: {:.3e}, Found feasible: {}".format(
          n_step, step_size, found_feasible_step_size))
    if not found_feasible_step_size:
      step_size /= nu
  return game_temp, found_feasible_step_size, n_step, f_temp


def vec2dic(d_vector, duals_vector):
  """Unpack the flat search-direction and dual vectors into per-component containers.

  Row layout implied by the offsets below (NOTE(review): confirm against the code
  that packs these vectors):
    d_vector:     [v rows (2*N_S) | player-1 policy rows (N_A1 - N_S) | player-2 policy rows (N_A2 - N_S)]
    duals_vector: [group A: N_A1 | N_A2 | group B: N_A1 - N_S | N_A2 - N_S | group C: R1 | R2]
  where N_Ap = N_A_S_tensors[p].sum() and Rp = number of entries of
  N_A_S_tensors_reduced[p]. Groups A/B/C presumably hold the duals of the Nash,
  policy-positivity and policy-sum restrictions respectively — TODO confirm.

  Args:
    d_vector: column tensor of shape (rows, 1) with the search direction.
    duals_vector: column tensor of shape (rows, 1) with the dual variables.

  Returns:
    d_v: (N_S, 2) tensor; column i holds rows i*N_S:(i+1)*N_S of d_vector
      (presumably the value direction of player i+1).
    d_pi: {'1': {state_str: tensor or None}, '2': {...}}. For each state with
      NA = N_A_S[s][player] > 1 the (NA-1, 1) slice of that player's policy rows,
      otherwise None.
    duals: {'1': tensor, '2': tensor}, each a column of length
      2*N_Ap - N_S + Rp holding that player's group A, B and C duals in order.
  """
  # Allocate outputs on the same device as the inputs instead of a hard-coded
  # 'cuda', so this pure re-indexing also works on CPU.
  d_v = torch.zeros((N_S, 2), device=d_vector.device)
  for i in range(0, 2):
    d_v[:, i] = d_vector[i*N_S:(i+1)*N_S, :].view(-1)

  N_A_total = {}
  N_S_reduced = {}
  for player in players_str():
    N_A_total[player] = int(N_A_S_tensors[player].sum().item())
    N_S_reduced[player] = N_A_S_tensors_reduced[player].view(-1).shape[0]

  # Each state contributes NA-1 policy rows, so player p owns N_A_total[p] - N_S
  # consecutive rows placed after the 2*N_S value rows (player 1 first).
  d_pi = {'1': {}, '2': {}}
  for player in players_str():
    y0 = 2*N_S + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = N_S + N_A_total['1'] + get_player_id(player) * (N_A_total['2'] - N_S)
    d_pi_vector = d_vector[y0:yf, :]
    n = 0
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d_pi[player][s] = d_pi_vector[n:n+NA-1, :]
      else:
        d_pi[player][s] = None
      n = n + NA-1

  duals = {}
  for player in players_str():
    n_A = N_A_total[player]
    n_red = N_S_reduced[player]
    n_restrictions = 2*n_A - N_S + n_red
    duals[player] = torch.zeros(n_restrictions, 1, device=duals_vector.device)

    # Group A: n_A rows, player 1's block first.
    y0 = get_player_id(player) * N_A_total['1']
    yf = N_A_total['1'] + get_player_id(player) * N_A_total['2']
    duals[player][:n_A, :] = duals_vector[y0:yf, :]

    # Group B: n_A - N_S rows, starting after both group-A blocks.
    y0 = N_A_total['1'] + N_A_total['2'] + get_player_id(player) * (N_A_total['1'] - N_S)
    yf = 2*N_A_total['1'] + N_A_total['2'] - N_S + get_player_id(player) * (N_A_total['2'] - N_S)
    duals[player][n_A:2*n_A-N_S, :] = duals_vector[y0:yf, :]

    # Group C: n_red rows, starting after both group-B blocks. An explicit start
    # index is used because a negative slice `-n_red:` selects the whole column
    # when n_red == 0.
    y0 = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + get_player_id(player) * N_S_reduced['1']
    yf = 2*(N_A_total['1'] + N_A_total['2'] - N_S) + N_S_reduced['1'] + get_player_id(player) * N_S_reduced['2']
    duals[player][n_restrictions - n_red:, :] = duals_vector[y0:yf, :]

  return d_v, d_pi, duals


def update_c_vector(c_vector, duals_0_vector):
  """Return an updated copy of the restriction-weight vector `c_vector`.

  Every entry with c_i < -1.2 * duals_0_i is raised to -2 * duals_0_i. All other
  entries are kept, and the result keeps c_vector's dtype. The result is then
  passed through `mask_equality_restrictions`. `c_vector` itself is not modified.
  """
  too_small = c_vector < -1.2 * duals_0_vector
  # Cast so the selected values are rounded to c_vector's dtype, as an in-place
  # assignment into a copy of c_vector would do.
  replacement = (-2 * duals_0_vector).to(c_vector.dtype)
  raised = torch.where(too_small, replacement, c_vector)
  return mask_equality_restrictions(raised)


def optimize_game(game_0, rho, n_epochs=100, verbose=False):
  """Iteratively minimize the game's Bellman error with a feasible-direction method.

  Each epoch evaluates the current game, builds gradients of the Bellman error and
  of the Nash restrictions, and computes a direction with
  `calculate_descent_direction` / `calculate_feasible_direction`. It then attempts
  a step with `feasible_gradient_descent`. The loop stops early as soon as no
  feasible step size is found.

  Args:
    game_0: initial game module. The working iterate starts as copy_game(game_0);
      game_0 itself is only evaluated.
    rho: parameter forwarded to `calculate_feasible_direction`, which returns its
      updated value.
    n_epochs: maximum number of outer iterations.
    verbose: forwarded to `feasible_gradient_descent`. When True, a message is
      also printed when the loop stops because no feasible step was found.

  Returns:
    (game_new, rho): the last accepted iterate and the latest value of rho.
  """
  game_new = copy_game(game_0)
  ttype = game_new.transform_type
  # NOTE(review): `n_restrictions` is a notebook-level global defined outside this
  # function; it must match the number of rows of the restriction/dual vectors.
  c_vector = torch.ones(n_restrictions,1).to('cuda')
  c_vector = mask_equality_restrictions(c_vector)
  with torch.no_grad():
    # Baseline error, used only for the relative-change report. Evaluated without
    # autograd so its computation graph is not kept alive for the whole run.
    pi_0, v_0 = game_0()
    f_0 = calculate_bellman_error(pi_0, v_0)
    for epoch in range(0, n_epochs):
      # Evaluate the current iterate: objective, restrictions and their gradients.
      pi, v = game_new()
      pi_vector = game_new.pi_vector()
      pi_sum_vector = game_new.pi_sum_vector()
      f = calculate_bellman_error(pi, v)
      g_nash = calculate_nash_restrictions(pi, v)
      g_vector = g_dics2g_vec(g_nash, pi_vector, pi_sum_vector)
      grad_f_v, grad_f_pi = bellman_error_gradients(pi, v)
      grad_g_v, grad_g_pi = nash_restriction_gradients(pi, v)

      # Directions: first the descent direction (and its duals), which also drive
      # the update of the restriction weights c, then the feasible direction d.
      grad_f_vector, grad_g_matrix = build_grad_tensors(grad_f_v, grad_f_pi, grad_g_v, grad_g_pi)
      d0_vector, norm_2_d0, duals_0_vector, A_matrix, b_vector = calculate_descent_direction(g_vector, grad_f_vector, grad_g_matrix)
      c_vector = update_c_vector(c_vector, duals_0_vector)

      d_vector, duals_vector, rho = calculate_feasible_direction(g_vector, c_vector, duals_0_vector,
                                                                 norm_2_d0, A_matrix, b_vector, rho,
                                                                 grad_f_vector, grad_g_matrix)
      d_v, d_pi, duals = vec2dic(d_vector, duals_vector)

      # KKT diagnostics at the current iterate (printed only after an accepted step).
      g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero = check_nash_KKT_conditions(pi, v, duals)
      if ttype == 'nolast':
        pi_KKT = check_pi_KKT_conditions(pi_vector, pi_sum_vector, duals)
        g_pi_plus_satisfied, max_g_pi_plus, product_zero_plus_satisfied, max_product_zero_plus = pi_KKT[:4]
        g_pi_one_satisfied, max_g_pi_one, product_zero_one_satisfied, max_product_zero_one = pi_KKT[4:]

      # Step-size search along d.
      game_temp, found_feasible_step_size, n_step, f_new = feasible_gradient_descent(game_new, f, g_vector, c_vector, d_vector,
                                                                                     grad_f_vector, grad_g_matrix, d_v, d_pi,
                                                                                     duals_vector, verbose=verbose) # TODO: check if passing wrong grad_f
      if not found_feasible_step_size:
        if verbose:
          print('Epoch: {}, no feasible step size found after {} trials; stopping.'.format(epoch, n_step))
        break

      game_new = game_temp
      delta_f = (f_new - f_0) / f_0 * 100
      print('Epoch: {}, f: {:.3e}, delta f: {:.3e}%, nash satisfied: {}, max g: {:.3e}, dual nash satisfied: {}, max product: {:.3e}, rho: {:.3e}, norm2 d0:{:.3e}'.format(
          epoch, f_new.item(), delta_f.item(), g_nash_satisfied, max_g_nash.item(), product_zero_satisfied, max_product_zero.item(), rho, norm_2_d0.item()))
      if ttype == 'nolast':
        print('Epoch: {}, >0 satisfied: {}, max g>0: {:.3e}, dual >0 satisfied: {}, max product >0: {:.3e}, =1 satisfied: {}, max g=1: {:.3e}, dual =1 satisfied: {}, max product =1: {:.3e}'.format(
          epoch, g_pi_plus_satisfied, max_g_pi_plus.item(), product_zero_plus_satisfied, max_product_zero_plus.item(), g_pi_one_satisfied, max_g_pi_one.item(), product_zero_one_satisfied, max_product_zero_one.item()))

  return game_new, rho

# Optimization

Build the game, set the hyperparameters of the feasible-direction method, and run the optimization.

In [None]:
# Build the game on the GPU, then shift its value parameters `v` by the
# estimate returned from calculate_initial_v.
game_b = nolast_game().to('cuda')

v0 = torch.as_tensor(calculate_initial_v(game_b, alpha=0.01), dtype=torch.float32, device='cuda')
with torch.no_grad():
  # In-place update of the parameter's storage; nothing is recorded by autograd.
  game_b.v.data.add_(v0)

In [None]:
# Hyperparameters of the feasible-direction optimizer (read as notebook globals).
alpha = 0.5        # NOTE(review): not referenced in this section — confirm it is still used
gamma_0 = 0.5      # scales the restriction tolerance in the step-size acceptance test
eta = 0.01         # sufficient-decrease coefficient in the step-size acceptance test
nu = 1 / 0.98      # each rejected trial divides the step size by nu (i.e. multiplies by 0.98)
rho_0 = 1.0        # initial value of rho
rho = rho_0        # current rho; passed to and returned by optimize_game
r_value = 1.0      # NOTE(review): not referenced in this section — confirm it is still used
c = 1              # NOTE(review): not referenced in this section — confirm it is still used

In [None]:
# Optimize starting from game_b (left untouched); the returned rho replaces the global one.
game_b2, rho = optimize_game(game_0=game_b, rho=rho, n_epochs=100, verbose=False)