<a href="https://colab.research.google.com/github/tarod13/Stochastic_Games/blob/master/stochastic_games_herkovitz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
from scipy.optimize import linprog
import matplotlib.pyplot as plt
import itertools
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter

In [None]:
# Discount factor applied to next-state values in the Bellman projections.
beta = 0.99

M = 5    # Offers attached to 'E' states range over 0..M.
N_P = 2  # NOTE(review): presumably the number of players; unused in the visible code.
N_A = 2  # Number of actions per player in the base matrix game (size of RG1/RG2).
# One 'G' state, two 'O' states and, per base action, (M+1) 'E1' states,
# (M+1) 'E2' states, one 'R1' state and one 'R2' state.
N_S = 3 + N_A * (2*(M+1) + 2)
G = [('G',)]
O = [('O',i) for i in range(1,3)]
E1 = [('E1',i,j) for i in range(0,N_A) for j in range(0,M+1)]
E2 = [('E2',i,j) for i in range(0,N_A) for j in range(0,M+1)]
R1 = [('R1',i) for i in range(0,N_A)]
R2 = [('R2',i) for i in range(0,N_A)]
# Global state ordering; the position in S is the row index used for `v`.
S = list(itertools.chain(G, O, E1, E2, R1, R2))
# Per-player state subsets. Single states must be wrapped in a list:
# chaining the bare tuple ('O',1) would splat it into 'O' and 1, and chaining
# one outer list would keep E2 and R2 as nested lists instead of states.
S1 = list(itertools.chain(G, [('O',1)], E1, R1))
S2 = list(itertools.chain([('O',2)], E2, R2))

# Reward matrices of the 'G' state for players 1 and 2, indexed [a1, a2].
RG1 = np.array([[3.,0.],[5.,1.]])
RG2 = np.array([[3.,5.],[0.,1.]])

In [None]:
def get_state_index(state):
  """Return the position of `state` in the global state ordering.

  `state` may be a state tuple from `S` or its string form from `S_str`;
  both lists share the same ordering, so the index is the same either way.

  Raises:
    AssertionError: if `state` is in neither list. Raised explicitly so the
      check survives `python -O`, which strips `assert` statements.
  """
  for known_states in (S, S_str):
    try:
      return known_states.index(state)
    except ValueError:
      continue
  raise AssertionError('Invalid state')

def player_dim(i):
  """Return the einsum index letter of a player's action axis.

  Player 1 (given as 1 or '1') maps to 'i' and player 2 (2 or '2') to 'j',
  matching the 'ij' layout of the reward matrices.

  Raises:
    AssertionError: for any other id. Raised explicitly so the check is not
      stripped under `python -O`.
  """
  if i in (1, '1'):
    return 'i'
  if i in (2, '2'):
    return 'j'
  raise AssertionError('Invalid player id')
  
def get_player_id(player):
  """Return the zero-based column index of a player given as '1' or '2'.

  Raises:
    AssertionError: for any other value (including the ints 1 and 2).
      Raised explicitly so the check is not stripped under `python -O`.
  """
  if player == '1':
    return 0
  if player == '2':
    return 1
  raise AssertionError('Invalid player')

def players():
  """Return the integer player ids, 1 and 2, as a range."""
  first_player, last_player = 1, 2
  return range(first_player, last_player + 1)

def players_str():
  """Return a fresh iterator over the string player ids '1' and '2'."""
  labels = ['1', '2']
  return iter(labels)

def state_player_pairs():
  """Iterate over every (state tuple, integer player id) combination, state-major."""
  player_ids = players()
  return itertools.product(S, player_ids)

# String form of every state, in the same order as S; used as dictionary keys.
S_str = list(map(str, S))
def state_player_str_pairs():
  """Iterate over every (state string, player string) combination, state-major."""
  player_labels = players_str()
  return itertools.product(S_str, player_labels)

In [None]:
class log_game(nn.Module):
  """Game model holding log-policies, state values and squared-root dual variables.

  For every state there is one column parameter per player (`log_pi1`,
  `log_pi2`) whose number of rows is that player's number of actions in the
  state, a value table `v` of shape (N_S, 2), and one dual column per player
  and state (`sqrt_lambda_nash1`, `sqrt_lambda_nash2`) shaped like the
  corresponding log-policy.
  """

  @staticmethod
  def _filled(shape, init_fn):
    """Return a new Parameter of the given shape initialised in place by `init_fn`."""
    param = Parameter(torch.Tensor(*shape))
    init_fn(param)
    return param

  def __init__(self):
    super().__init__()
    zeros, ones = nn.init.zeros_, nn.init.ones_
    n_offer_actions = (M+1)*N_A

    # (state, number of actions, initialiser) in registration order; the
    # order fixes the order of `self.parameters()`.
    layout_1 = (
        [(('G',), N_A, zeros), (('O',1), n_offer_actions, zeros), (('O',2), 1, ones)]
        + [(state, 2, zeros) for state in E1]
        + [(state, 1, ones) for state in E2 + R2]
        + [(state, N_A, zeros) for state in R1]
    )
    layout_2 = (
        [(('G',), N_A, zeros), (('O',1), 1, ones), (('O',2), n_offer_actions, zeros)]
        + [(state, 2, zeros) for state in E2]
        + [(state, 1, ones) for state in E1 + R1]
        + [(state, N_A, zeros) for state in R2]
    )

    self.log_pi1 = nn.ParameterDict()
    for state, n_actions, init_fn in layout_1:
      self.log_pi1[str(state)] = self._filled((n_actions, 1), init_fn)

    self.log_pi2 = nn.ParameterDict()
    for state, n_actions, init_fn in layout_2:
      self.log_pi2[str(state)] = self._filled((n_actions, 1), init_fn)

    self.v = self._filled((N_S, 2), zeros)

    # Dual variables for the Nash inequalities, stored as square roots so the
    # squared value returned by `lambda_nash` is never negative.
    self.sqrt_lambda_nash1 = nn.ParameterDict()
    self.sqrt_lambda_nash2 = nn.ParameterDict()
    for key in map(str, S):
      self.sqrt_lambda_nash1[key] = self._filled(tuple(self.log_pi1[key].shape), zeros)
      self.sqrt_lambda_nash2[key] = self._filled(tuple(self.log_pi2[key].shape), zeros)

    # Every parameter is finally re-drawn uniformly, which overwrites the
    # constant initialisations above.
    for param in self.parameters():
      nn.init.uniform_(param)

  def forward(self):
    """Return the tuple (policies, state values, dual variables)."""
    return self.pi(), self.v, self.lambda_nash()

  def pi(self):
    """Return {'1': {...}, '2': {...}} mapping state strings to normalised policies.

    Each log-policy column is shifted by its maximum, exponentiated and
    divided by its sum.
    """
    policies = {'1':{}, '2':{}}
    for key in map(str, S):
      for player, log_policies in (('1', self.log_pi1), ('2', self.log_pi2)):
        shifted = log_policies[key] - log_policies[key].max()
        weights = torch.exp(shifted + 1e-20)
        policies[player][key] = weights / weights.sum()
    return policies

  def lambda_nash(self):
    """Return the dual variables per player and state as squared parameters."""
    return {
        '1': {str(s): self.sqrt_lambda_nash1[str(s)].pow(2) for s in S},
        '2': {str(s): self.sqrt_lambda_nash2[str(s)].pow(2) for s in S},
    }

In [None]:
class nolast_game(nn.Module):
  """Game model that stores every policy without its last probability.

  For a state where a player has n actions, only the first n-1 probabilities
  are parameters (`pi_nolast1`, `pi_nolast2`); the last one is recovered as
  one minus their sum. States with a single action get a parameter with zero
  rows. `v` is the (N_S, 2) value table.
  """
  def __init__(self):
    super().__init__()
    # Player 1: free probabilities per state (rows = number of actions - 1).
    self.pi_nolast1 = nn.ParameterDict()
    self.pi_nolast1[str(('G',))] = Parameter(torch.Tensor(N_A-1,1))
    nn.init.zeros_(self.pi_nolast1[str(('G',))])
    self.pi_nolast1[str(('O',1))] = Parameter(torch.Tensor((M+1)*N_A-1,1))
    nn.init.zeros_(self.pi_nolast1[str(('O',1))])    
    self.pi_nolast1[str(('O',2))] = Parameter(torch.Tensor(1-1,1))
    
    for state in E1:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(2-1,1))
      nn.init.zeros_(self.pi_nolast1[str(state)])
      
    for state in E2 + R2:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(1-1,1))

    for state in R1:
      self.pi_nolast1[str(state)] = Parameter(torch.Tensor(N_A-1,1))
      nn.init.zeros_(self.pi_nolast1[str(state)])

    # Player 2: mirror image of the layout above.
    self.pi_nolast2 = nn.ParameterDict()
    self.pi_nolast2[str(('G',))] = Parameter(torch.Tensor(N_A-1,1))
    nn.init.zeros_(self.pi_nolast2[str(('G',))])
    self.pi_nolast2[str(('O',1))] = Parameter(torch.Tensor(1-1,1))
    self.pi_nolast2[str(('O',2))] = Parameter(torch.Tensor((M+1)*N_A-1,1))
    nn.init.zeros_(self.pi_nolast2[str(('O',2))])

    for state in E2:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(2-1,1))
      nn.init.zeros_(self.pi_nolast2[str(state)])
      
    for state in E1 + R1:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(1-1,1))

    for state in R2:
      self.pi_nolast2[str(state)] = Parameter(torch.Tensor(N_A-1,1))
      nn.init.zeros_(self.pi_nolast2[str(state)])

    self.v = Parameter(torch.Tensor(N_S,2))
    nn.init.zeros_(self.v)    

  def forward(self):
    """Return the tuple (policies, state values)."""
    return self.pi(), self.v
  
  def pi(self):
    """Return {'1': {...}, '2': {...}} mapping state strings to full policy columns.

    Each column holds the stored probabilities followed by the implied last
    probability, 1 minus the sum of the others.
    """
    pi = {'1':{}, '2':{}}
    for s in S:
      NA1 = self.pi_nolast1[str(s)].shape[0]+1
      pi1 = torch.zeros(NA1,1).to('cuda')
      if NA1 > 1:
        pi1[:-1,:] = self.pi_nolast1[str(s)]
      # The last entry is still zero here, so the sum covers the free part only.
      pi1[-1,0] = 1.0 - pi1.sum()
      
      NA2 = self.pi_nolast2[str(s)].shape[0]+1
      pi2 = torch.zeros(NA2,1).to('cuda')
      if NA2 > 1:
        pi2[:-1,:] = self.pi_nolast2[str(s)]
      pi2[-1,0] = 1.0 - pi2.sum()

      pi['1'][str(s)] = pi1
      pi['2'][str(s)] = pi2
    return pi

  def pi_discriminated(self):
    """Return the free and the implied probabilities as stacked tensors.

    Returns:
      (pi_nolast, pi_last): two dicts keyed by '1' and '2'. `pi_nolast[p]`
      is the concatenation over states (in `S_str` order) of player p's
      stored probabilities; `pi_last[p]` is a column with one implied last
      probability per state.
    """
    pi_nolast = {'1':[], '2':[]}
    pi_last = {'1':[], '2':[]}
    for s in S_str:
      # The accumulators are keyed by the strings '1'/'2', not by ints.
      pi_nolast['1'].append(self.pi_nolast1[s])
      pi_last['1'].append(1.0 - self.pi_nolast1[s].sum())
      pi_nolast['2'].append(self.pi_nolast2[s])
      pi_last['2'].append(1.0 - self.pi_nolast2[s].sum())
    for player in players_str():
      pi_nolast[player] = torch.cat(pi_nolast[player])
      # Stack the per-state scalars into a column. detach() keeps the result
      # free of autograd history, as a tensor freshly constructed from the
      # list of values would be.
      pi_last[player] = torch.stack(pi_last[player]).detach().float().view(-1,1).to('cuda')
    return pi_nolast, pi_last 

In [None]:
# Instantiate the game model on the GPU together with an Adam optimizer.
# NOTE(review): `game` is not defined in the visible part of the notebook;
# the three values unpacked from `game_b()` below match `log_game.forward`
# -- confirm which class `game` is bound to.
game_b = game().to('cuda')
optimizer = optim.Adam(game_b.parameters(), lr=1e-3)

# N_A_S[state_str][player] -> number of actions of that player in the state.
# N_A_S_tensors[player]    -> the same counts, per state in S order.
N_A_S = {}
N_A_S_tensors = {'1':[], '2':[]}
with torch.no_grad():
  pi, v, lambda_nash = game_b()
  for state in S:
    # The number of rows of each policy column is the action count.
    N_A1_state = pi['1'][str(state)].shape[0]
    N_A2_state = pi['2'][str(state)].shape[0]
    N_A_S[str(state)] = {'1': N_A1_state, '2': N_A2_state}
    N_A_S_tensors['1'].append(N_A1_state)
    N_A_S_tensors['2'].append(N_A2_state)
# Convert the per-state counts into float column tensors on the GPU.
for player in players_str():
  N_A_S_tensors[player] = torch.FloatTensor(N_A_S_tensors[player]).to('cuda').view(-1,1)

In [None]:
def offer_accepted(action):
  """Return True when `action` equals 0 (the accepting action), else False."""
  return bool(action == 0)
  

def other_player(i):
  """Return the opponent of player `i`, in the same representation.

  Integer ids map 1 <-> 2 and string ids map '1' <-> '2'.

  Raises:
    AssertionError: for any other id. Raised explicitly so the check is not
      stripped under `python -O`.
  """
  for player_id, opponent in ((1, 2), (2, 1), ('1', '2'), ('2', '1')):
    if i == player_id:
      return opponent
  raise AssertionError('Invalid player id')


def transition_info(state, actions):
  """Describe where the game goes from `state` under the joint `actions`.

  Args:
    state: a state tuple from `S`.
    actions: sequence [a1, a2] of action indices; not read for 'G', 'R1'
      and 'R2' states.

  Returns:
    ('R', [(next_state, probability), ...]) for the random transitions out
    of 'G', 'R1' and 'R2', or ('D', next_state) for deterministic ones.

  Raises:
    AssertionError: if the state tag is unknown. Raised explicitly so the
      check is not stripped under `python -O`.
  """
  if 'G' in state or 'R1' in state or 'R2' in state:
    # Either player is drawn to make the next offer with equal probability.
    return ('R', [(('O',1), 0.5), (('O',2), 0.5)])

  if 'O' in state:
    _, id_player = state
    chosen = actions[id_player-1]
    # The offering player's action index encodes (offer, requested action).
    offeral = chosen // N_A
    action_requested = chosen % N_A
    return ('D', ('E'+str(other_player(id_player)), action_requested, offeral))

  if 'E1' in state:
    _, action_requested, _ = state
    if offer_accepted(actions[0]):
      return ('D', ('R2', action_requested))
    return ('D', ('G',))

  if 'E2' in state:
    _, action_requested, _ = state
    if offer_accepted(actions[1]):
      return ('D', ('R1', action_requested))
    return ('D', ('G',))

  raise AssertionError('Invalid state')

In [None]:
# transition_types[state_str] = (deterministic, action_dependent):
# (0,0) for the random, action-independent transitions out of 'G', 'R1' and
# 'R2' states, (1,1) for every other state.
transition_types = {}
for state in S:
  has_random_transition = any(tag in state for tag in ('G', 'R1', 'R2'))
  transition_types[str(state)] = (0,0) if has_random_transition else (1,1)

In [None]:
def rewards(state, actions):
  """Return the immediate rewards (r1, r2) of both players.

  Args:
    state: a state tuple from `S`.
    actions: sequence [a1, a2] of action indices of players 1 and 2.

  Raises:
    AssertionError: if the state tag is unknown. Raised explicitly so the
      check is not stripped under `python -O`.
  """
  if 'G' in state:
    return RG1[actions[0], actions[1]], RG2[actions[0], actions[1]]

  if 'O' in state:
    return 0.0, 0.0

  if 'E1' in state:
    # Player 1 decides; accepting transfers the offer from player 2 to player 1.
    _, _, offeral = state
    if offer_accepted(actions[0]):
      return offeral, -offeral
    return 0, 0

  if 'E2' in state:
    # Player 2 decides; accepting transfers the offer from player 1 to player 2.
    _, _, offeral = state
    if offer_accepted(actions[1]):
      return -offeral, offeral
    return 0, 0

  if 'R1' in state:
    # Player 2's action is fixed to the index stored in the state.
    _, i = state
    return RG1[actions[0], i], RG2[actions[0], i]

  if 'R2' in state:
    # Player 1's action is fixed to the index stored in the state.
    _, i = state
    return RG1[i, actions[1]], RG2[i, actions[1]]

  raise AssertionError('Invalid state')


def reward_matrices(s):
  """Return NumPy reward matrices (RM1, RM2) for state `s`, indexed [a1, a2].

  The 'G' state returns copies of RG1/RG2; every other state is tabulated by
  calling `rewards` for each joint action.
  """
  if 'G' in s:
    return RG1.copy(), RG2.copy()

  action_counts = N_A_S[str(s)]
  n_rows, n_cols = action_counts['1'], action_counts['2']
  RM1 = np.zeros((n_rows, n_cols))
  RM2 = np.zeros((n_rows, n_cols))
  for a1, a2 in itertools.product(range(n_rows), range(n_cols)):
    RM1[a1,a2], RM2[a1,a2] = rewards(s, [a1,a2])
  return RM1, RM2

In [None]:
# RM[player][state_str]: float reward matrix of `player` in that state,
# indexed [a1, a2] and stored on the GPU.
RM = {'1': {}, '2': {}}
for s in S:
  RM1, RM2 = reward_matrices(s)
  state_key = str(s)
  RM['1'][state_key] = torch.FloatTensor(RM1).to('cuda')
  RM['2'][state_key] = torch.FloatTensor(RM2).to('cuda')

In [None]:
def player_consistent_reward_matrices():
  """Return copies of the reward matrices with each player's own actions on the rows.

  Player 1's matrices are cloned as they are; player 2's are cloned and
  transposed.
  """
  return {
      '1': {s: RM['1'][s].clone() for s in S_str},
      '2': {s: torch.t(RM['2'][s].clone()) for s in S_str},
  }

In [None]:
def next_value_matrices(s, v): # TODO: consider other 2 cases
  """Return a (N_A1, N_A2, 2) tensor of next-state values for state `s`.

  Entry [a1, a2, :] is the value row of the state reached under the joint
  action (a1, a2) when the transition is deterministic and action-dependent,
  or the probability-weighted mean of the next value rows when it is random
  and action-independent. The two remaining combinations leave zeros.
  """
  det, dep = transition_types[str(s)]
  n_rows = N_A_S[str(s)]['1']
  n_cols = N_A_S[str(s)]['2']
  joint_actions = list(itertools.product(range(n_rows), range(n_cols)))
  vs = torch.zeros((n_rows,n_cols,2)).to('cuda')

  if det and dep:
    for a1, a2 in joint_actions:
      _, next_state = transition_info(s, [a1,a2])
      vs[a1,a2,:] = v[get_state_index(next_state),:]
  elif not (det or dep):
    _, transition_dic = transition_info(s, [])
    mean_next_v = torch.zeros(1,2).to('cuda')
    for next_state, transition_prob in transition_dic:
      mean_next_v = mean_next_v + v[get_state_index(next_state),:].view(1,-1) * transition_prob
    for a1, a2 in joint_actions:
      vs[a1,a2,:] = mean_next_v.view(-1)
  return vs

def next_value_dictionary(v):
  """Return {state_str: next_value_matrices(state, v)} for every state in S."""
  return {str(s): next_value_matrices(s, v) for s in S}


def transition_matrix(pi): # TODO: consider other 2 cases
  """Return the (N_S, N_S) state transition matrix induced by the policies `pi`.

  Row s holds the probability of each next state: the product of both
  players' action probabilities summed per reached state for deterministic
  action-dependent transitions, or the fixed probabilities for random
  action-independent ones.
  """
  P = torch.zeros((N_S,N_S)).to('cuda')
  for state in S:
    key = str(state)
    strategy_1, strategy_2 = pi['1'][key], pi['2'][key]
    det, dep = transition_types[key]
    row = get_state_index(state)

    if det and dep:
      joint_actions = itertools.product(range(strategy_1.shape[0]), range(strategy_2.shape[0]))
      for a1, a2 in joint_actions:
        _, next_state = transition_info(state, [a1,a2])
        col = get_state_index(next_state)
        joint_prob = strategy_1[a1,0] * strategy_2[a2,0]
        P[row, col] = P[row, col] + joint_prob
    elif not (det or dep):
      _, transition_dic = transition_info(state, [])
      for next_state, transition_prob in transition_dic:
        P[row, get_state_index(next_state)] = transition_prob
  return P


def partial_transition_matrices(pi): # TODO: consider other 2 cases
  """Return per-state transition matrices conditioned on one player's own action.

  result[state_str][player] has shape (actions of player, N_S); row a holds
  the next-state distribution when `player` plays a and the opponent follows
  its policy in `pi`.
  """
  result = {}
  for s in S:
    key = str(s)
    strategy_1, strategy_2 = pi['1'][key], pi['2'][key]
    n_rows, n_cols = strategy_1.shape[0], strategy_2.shape[0]
    P1 = torch.zeros((n_rows,N_S)).to('cuda')
    P2 = torch.zeros((n_cols,N_S)).to('cuda')
    result[key] = {'1': P1, '2': P2}

    # The filling rule depends on whether the transition is deterministic
    # or random, and on whether it depends on the actions.
    det, dep = transition_types[key]
    if det and dep:
      for a1, a2 in itertools.product(range(n_rows), range(n_cols)):
        _, next_state = transition_info(s, [a1,a2])
        col = get_state_index(next_state)
        # Each player's row accumulates the opponent's action probability.
        P1[a1, col] = P1[a1, col] + strategy_2[a2,0]
        P2[a2, col] = P2[a2, col] + strategy_1[a1,0]
    elif not (det or dep):
      _, transition_dic = transition_info(s, [])
      for next_state, transition_prob in transition_dic:
        col = get_state_index(next_state)
        P1[:, col] = transition_prob
        P2[:, col] = transition_prob
  return result


def expected_reward(RM, pi):
  """Return the (N_S, 2) expected immediate reward of each player per state.

  Column 0 is player 1 and column 1 is player 2; each entry averages the
  player's reward matrix over both policies in `pi`.
  """
  r_mean = torch.zeros((N_S,2)).to('cuda')
  for column, player in enumerate(('1', '2')):
    player_rewards = RM[player]
    for s in S:
      key = str(s)
      # Average over player 1's actions first, then over player 2's.
      over_player_1 = torch.einsum('ij,ik->jk', player_rewards[key], pi['1'][key])
      r_mean[get_state_index(s), column] = (over_player_1 * pi['2'][key]).sum()
  return r_mean


def partial_expected_reward_other(RM, pi):
  """Return each player's expected reward per own action, given the opponent's policy.

  Returns:
    r_mean[player][state_str]: column with one entry per action of `player`,
    obtained by averaging the player's own reward matrix over the opponent's
    policy in `pi`.
  """
  r_mean = {player: {} for player in players_str()}

  for s, player in state_player_str_pairs():
    opponent = other_player(player)
    opponent_strategy = pi[opponent][s].view(-1)
    # Contract the opponent's action axis and keep the player's own axis.
    formula = 'ij,'+player_dim(opponent)+'->'+player_dim(player)
    r_mean[player][s] = torch.einsum(formula, RM[player][s], opponent_strategy).view(-1,1)
  return r_mean


def partial_expected_reward(RM, pi):
  """Return expected rewards of every player per action of every player.

  Returns:
    r_mean[player][second_player][state_str]: column with one entry per
    action of `player`, holding the reward of `second_player` averaged over
    the policy of `player`'s opponent.
  """
  # One reward dictionary for each combination of players.
  r_mean = {player: {'1':{}, '2':{}} for player in players_str()}

  for s, player in state_player_str_pairs():
    opponent = other_player(player) # Player whose policy is averaged out
    opponent_strategy = pi[opponent][s].view(-1)
    formula = 'ij,'+player_dim(opponent)+'->'+player_dim(player)
    for second_player in players_str():
      # RM[second_player] is the reward matrix whose expectation is taken.
      r_mean[player][second_player][s] = torch.einsum(
          formula, RM[second_player][s], opponent_strategy).view(-1,1)
  return r_mean


def bellman_projection(RM, pi, v):
  """Return the (N_S, 2) Bellman backup of `v` under the policies `pi`.

  The result is the expected immediate reward plus `beta` times the
  expected next-state value.
  """
  expected_next_v = torch.zeros((N_S,2)).to('cuda')
  for s in S:
    key = str(s)
    next_values = next_value_matrices(s, v)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)
    # Average over player 1's actions, then over player 2's.
    over_player_1 = torch.einsum('ijk,i->jk', next_values, p1)
    expected_next_v[get_state_index(s),:] = torch.einsum('jk,j->k', over_player_1, p2)
  return expected_reward(RM, pi) + beta * expected_next_v


def bellman_partial_projection_other(RM, pi, v):
  """Return each player's action values given the opponent's policy.

  Returns:
    projection[player][state_str]: column with one entry per action of
    `player`: expected own reward plus `beta` times the expected own
    next-state value, both averaged over the opponent's policy.
  """
  r_mean = partial_expected_reward_other(RM, pi)
  projection = {'1':{}, '2':{}}
  for s in S:
    key = str(s)
    next_values = next_value_matrices(s, v)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)
    # Mean next value under the opponent's strategy; one entry per own action.
    next_v_1 = torch.einsum('ij,j->i', next_values[:,:,0], p2).view(-1,1)
    next_v_2 = torch.einsum('ij,i->j', next_values[:,:,1], p1).view(-1,1)
    projection['1'][key] = r_mean['1'][key] + beta * next_v_1
    projection['2'][key] = r_mean['2'][key] + beta * next_v_2
  return projection


def partial_next_values(pi, v):
  """Return expected next-state values for every combination of players.

  Returns:
    next_v[acting][valued][state_str]: column with one entry per action of
    player `acting`, holding the next-state value of player `valued`
    averaged over the policy of `acting`'s opponent.
  """
  next_v = {acting: {'1':{}, '2':{}} for acting in players_str()}

  for s in S:
    key = str(s)
    next_values = next_value_matrices(s, v)
    p1 = pi['1'][key].squeeze(1)
    p2 = pi['2'][key].squeeze(1)

    for column, valued in enumerate(('1', '2')):
      values_of_player = next_values[:,:,column]
      # Averaging over player 2 leaves player 1's actions, and vice versa.
      next_v['1'][valued][key] = torch.einsum('ij,j->i', values_of_player, p2).view(-1,1)
      next_v['2'][valued][key] = torch.einsum('ij,i->j', values_of_player, p1).view(-1,1)
  return next_v


def bellman_partial_projection(RM, pi, v):
  """Return action values for every combination of players.

  Returns:
    projection[player][second_player][state_str]: the partial expected
    reward plus `beta` times the partial expected next value, as produced by
    `partial_expected_reward` and `partial_next_values`.
  """
  r_mean = partial_expected_reward(RM, pi)
  next_v = partial_next_values(pi, v)

  projection = {player: {'1':{}, '2':{}} for player in players_str()}
  for s, player in state_player_str_pairs():
    for second_player in players_str():
      immediate = r_mean[player][second_player][s]
      future = next_v[player][second_player][s]
      projection[player][second_player][s] = immediate + beta * future
  return projection


def reward_baselines(RM, pi):
  """Return a (2, 1) NumPy column with each player's expected reward averaged over states."""
  per_state = expected_reward(RM, pi)
  per_player = per_state.mean(0).view(-1,1)
  return per_player.detach().cpu().numpy()

In [None]:
def cost_vector_fixed_policies(pi):
  """Return the (N_S, 1) NumPy cost vector of the fixed-policy linear program.

  Entry s is (1 - beta * column sum of the transition matrix at s) / N_S.
  """
  incoming_probability = transition_matrix(pi).sum(0)
  weights = (1 - beta * incoming_probability).view(-1,1)
  return weights.detach().cpu().numpy() / N_S

def restriction_matrices_fixed_policies(pi):
  """Return per-player inequality matrices of the fixed-policy linear program.

  For each state and own action the row is -(e_s - beta * P(. | s, a)),
  where P comes from `partial_transition_matrices`; rows are stacked over
  states in S order and returned as NumPy arrays keyed by '1' and '2'.
  """
  partial_P = partial_transition_matrices(pi)
  blocks = {'1':[], '2':[]}
  for s, player_id in state_player_pairs():
    block = - beta * partial_P[str(s)][str(player_id)]
    own_column = get_state_index(s)
    block[:, own_column] = block[:, own_column] + 1
    blocks[str(player_id)].append(block)
  return {
      player: -torch.cat(blocks[player], dim=0).detach().cpu().numpy()
      for player in players_str()
  }

def restriction_vectors_fixed_policies(RM, pi, alpha=0.1):
  """Return per-player right-hand sides of the fixed-policy linear program.

  Each vector is -(expected reward per own action + alpha), stacked over
  states in S order and returned as a NumPy column keyed by '1' and '2'.
  """
  r_mean = partial_expected_reward_other(RM, pi)
  columns = {'1':[], '2':[]}
  for s, player in state_player_str_pairs():
    columns[player].append(r_mean[player][s].view(-1,1))
  return {
      player: -(torch.cat(columns[player], dim=0)+alpha).detach().cpu().numpy()
      for player in players_str()
  }


def parameters_fixed_policies(game_, alpha):
  """Return (c, f0, A_ub, b_ub) of the linear program for the current policies of `game_`.

  The policies are the first element of the model's forward output.
  """
  policies = game_()[0]
  return (
      cost_vector_fixed_policies(policies),
      reward_baselines(RM, policies),
      restriction_matrices_fixed_policies(policies),
      restriction_vectors_fixed_policies(RM, policies, alpha),
  )


def calculate_initial_v(game_):
  """Return an (N_S, 2) NumPy array of initial state values.

  One linear program is solved per player with `scipy.optimize.linprog`,
  using the parameters from `parameters_fixed_policies` with alpha=10; the
  solution fills that player's column.
  """
  c, f0, A_ub, b_ub = parameters_fixed_policies(game_, alpha=10)
  v0 = np.zeros((N_S,2))
  for column, player in enumerate(('1', '2')):
    solution = linprog(c, A_ub=A_ub[player], b_ub=b_ub[player])
    v0[:,column] = solution.x
  return v0

In [None]:
def calculate_nash_restrictions(pi, v):
  """Return the Nash restriction values g[player][state_str].

  Each entry is the player's action values (given the opponent's policy)
  minus the player's state value, one row per own action.
  """
  # Dictionary with a column of 'q'-values for each player and state.
  q_estimated = bellman_partial_projection_other(RM, pi, v)
  g_nash = {'1':{}, '2':{}}
  for s, player_id in state_player_pairs():
    player, key = str(player_id), str(s)
    state_value = v[get_state_index(s), player_id-1]
    g_nash[player][key] = q_estimated[player][key] - state_value
  return g_nash


def calculate_bellman_error(pi, v):
  """Return the target function: the sum over states and players of v minus its Bellman backup."""
  backup = bellman_projection(RM, pi, v)
  residual = v - backup
  return residual.sum()


def lagrangian_nash(pi, v, lambda_nash, c=1):
  """Return (lagrangian, f_bellman) for the Nash-constrained problem.

  `f_bellman` is the summed Bellman error and the Lagrangian adds the sum of
  every Nash restriction multiplied by its dual variable. `c` is accepted
  but not used: the quadratic penalty term it scaled is disabled.
  """
  # Original target function: Bellman approximation error.
  f_bellman = (v - bellman_projection(RM, pi, v)).sum()

  # Restrictions weighted by their dual variables, accumulated in
  # (state, player) order.
  g_nash = calculate_nash_restrictions(pi, v)
  nash_restriction_sum = 0.0
  for s, player_id in state_player_str_pairs():
    weighted = g_nash[player_id][s].view(-1,1) * lambda_nash[player_id][s].view(-1,1)
    nash_restriction_sum = nash_restriction_sum + weighted.sum()

  lagrangian = f_bellman + nash_restriction_sum
  return lagrangian, f_bellman


def check_KKT_conditions(pi, v, lambda_nash, tol=1e-8):
  """Check primal feasibility and complementary slackness of the Nash restrictions.

  Args:
    pi: policies, as returned by the game model.
    v: (N_S, 2) state values.
    lambda_nash: dual variables per player, keyed by '1' and '2'. Each value
      is either a column tensor stacking the duals over states in `S_str`
      order, or a dict mapping state strings to dual columns (the layout
      returned by `log_game.lambda_nash`).
    tol: largest |dual * restriction| still counted as zero.

  Returns:
    (g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero)
  """
  with torch.no_grad():
    # Calculate restrictions
    g_nash = calculate_nash_restrictions(pi, v)

    g_nash_satisfied = True
    product_zero_satisfied = True

    # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
    max_g_nash = -np.inf
    max_product_zero = -np.inf

    for player in players_str():
      player_duals = lambda_nash[player]
      duals_per_state = isinstance(player_duals, dict)
      remaining_duals = None if duals_per_state else player_duals.clone()
      for s in S_str:
        NA = N_A_S[s][player]
        g_nash_satisfied = g_nash_satisfied and torch.all(g_nash[player][s] <= 0)
        max_g_nash = max(max_g_nash, g_nash[player][s].max())

        if duals_per_state:
          state_duals = player_duals[s].view(-1)
        else:
          # Consume the next NA entries of the stacked dual vector.
          state_duals = remaining_duals[:NA,:].view(-1)
          remaining_duals = remaining_duals[NA:,:]

        lambda_g_nash_product = g_nash[player][s].view(-1) * state_duals
        product_zero_satisfied = product_zero_satisfied and torch.all(lambda_g_nash_product.abs() <= tol)
        max_product_zero = max(max_product_zero, lambda_g_nash_product.abs().max())
    return g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero


def check_nash_conditions(pi, v):
  """Check whether every Nash restriction is non-positive.

  Returns:
    (g_nash_satisfied, max_g_nash): whether all restrictions hold, and the
    largest restriction value over all players, states and actions.
  """
  with torch.no_grad():
    # Calculate restrictions
    g_nash = calculate_nash_restrictions(pi, v)
    g_nash_satisfied = True
    # np.inf rather than the np.infty alias, which NumPy 2.0 removed.
    max_g_nash = -np.inf

    for s, player_id in state_player_str_pairs():
      g_nash_satisfied = g_nash_satisfied and torch.all(g_nash[player_id][s] <= 0)
      max_g_nash = max(max_g_nash, g_nash[player_id][s].max())
    return g_nash_satisfied, max_g_nash
  

# def optimize_game(game_, n_epochs, optimizers, params_dual, 
#                   print_each=50, save_each=10):
#   losses = []
#   costs = []
#   delta_cost = 0.0
#   previous_cost = np.infty
#   opti_primal, opti_dual = optimizers
  
#   for epoch in range(0, n_epochs):
#     pi, v, lambda_nash = game_()
#     loss, cost = lagrangian_nash(pi, v, lambda_nash)
#     delta_cost = cost.item() - previous_cost

#     previous_cost = cost.item()
#     torch.save(game_.state_dict(), './game_temp.pth')

#     opti_primal.zero_grad()
#     opti_dual.zero_grad()    
#     loss.backward()
    
#     for p in params_dual:
#       if p.grad is not None:
#         p.grad.data.mul_(-1)

#     opti_primal.step()
#     opti_dual.step()

#     if (epoch + 1) % print_each == 0:
#       with torch.no_grad():
#         nash_satisfied, max_nash_restrictions, KKT_zero_product_satisfied, max_KKT_product = check_KKT_conditions(pi, v, lambda_nash)
#       print("Epoch: {}, Loss: {:.3f}, f: {:.3f}, nash: {}, max g: {:.3f}, KKT product: {}, max product: {:.3e}".format(
#           epoch+1,loss.item(),cost.item(), nash_satisfied, max_nash_restrictions,
#           KKT_zero_product_satisfied, max_KKT_product))

#     if (epoch + 1) % save_each == 0:
#       losses.append(loss.item())
#       costs.append(cost.item())
#   return losses, costs

In [None]:
def bellman_error_gradients(pi, v):
  """Return the gradients of the summed Bellman error with respect to v and pi.

  Returns:
    (grad_v, grad_pi): `grad_v` is an (N_S, 1) column and
    `grad_pi[player][state_str]` a column with one entry per action of
    `player`.
  """
  P = transition_matrix(pi)
  # The error is sum(v - r - beta * P v), so its gradient in v is
  # (I - beta * P^T) 1. The discount factor must multiply P here, as it does
  # in the Bellman backup itself.
  bellman_error_grad_v = (torch.eye(N_S).to('cuda') - beta * torch.t(P)).sum(1, keepdim=True)

  q_individual = bellman_partial_projection(RM, pi, v)
  bellman_error_grad_pi = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    # The error sums over both players' values, so both action-value
    # columns contribute with a negative sign.
    gradient = 0.0
    for second_player in players_str():
      gradient = gradient - q_individual[player][second_player][s]
    bellman_error_grad_pi[player][s] = gradient
  return bellman_error_grad_v, bellman_error_grad_pi

def nash_restriction_gradients(pi, v):
  """Return the Jacobians of the Nash restrictions with respect to v and pi.

  Returns:
    (grad_v, grad_pi), both keyed [player][state_str]. `grad_v` is
    beta * P(. | s, a) minus the indicator of the state's own column;
    `grad_pi` is the opponent's consistent reward matrix plus beta times the
    opponent's next-value matrix (transposed for player '1').
  """
  P_partial = partial_transition_matrices(pi)
  next_value_dic = next_value_dictionary(v)
  consistent_RM = player_consistent_reward_matrices()

  grad_v = {'1':{}, '2':{}}
  grad_pi = {'1':{}, '2':{}}

  for s, player in state_player_str_pairs():
    own_state_indicator = torch.zeros_like(P_partial[s][player])
    own_state_indicator[:, get_state_index(s)] = 1.0
    grad_v[player][s] = beta * P_partial[s][player]  - own_state_indicator

    opponent = other_player(player)
    opponent_next_values = next_value_dic[s][:,:,get_player_id(opponent)]
    if player == '1':
      opponent_next_values = torch.t(opponent_next_values)
    grad_pi[player][s] = consistent_RM[opponent][s] + beta * opponent_next_values
  return grad_v, grad_pi


def calculate_jacobian_pi_log_pi(pi):
  """Return diag(p) - p p^T for every player and state, keyed [player][state_str]."""
  jacobians = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    probabilities = pi[player][s].view(-1)
    outer_product = torch.einsum('i,j->ij', probabilities, probabilities)
    jacobians[player][s] = torch.diag(probabilities) - outer_product
  return jacobians


def calculate_jacobian_pi_pi_nolast():
  """Return the Jacobian of a full policy with respect to its stored probabilities.

  For a player with NA > 1 actions in a state the entry is an (NA, NA-1)
  matrix: identity on the first NA-1 rows and -1 on the last row. States
  where the player has a single action map to None.
  """
  jacobians = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    NA = N_A_S[s][player]
    if NA <= 1:
      jacobians[player][s] = None
      continue
    jacobian = torch.eye(NA, m=NA-1).to('cuda')
    jacobian[-1,:] = -1.0
    jacobians[player][s] = jacobian
  return jacobians


def calculate_nash_pi_transformed_gradients(pi, grad_f_pi, grad_g_pi, transform_type='nolast'):
  """Apply the chain rule to express policy gradients in the transformed parameters.

  Args:
    pi: policies keyed [player][state_str].
    grad_f_pi: target-function gradients keyed [player][state_str].
    grad_g_pi: restriction Jacobians keyed [player][state_str].
    transform_type: 'nolast' (policy without its last probability) or 'log'
      (log-policy parameterisation).

  Returns:
    (grad_f_pi_transformed, grad_g_pi_transformed), keyed like the inputs;
    entries are None where the parameter Jacobian is None.

  Raises:
    AssertionError: for an unknown `transform_type`. Raised explicitly so
      the check is not stripped under `python -O`.
  """
  if transform_type == 'nolast':
    jacobian_pi_pi_transformed = calculate_jacobian_pi_pi_nolast()
  elif transform_type == 'log':
    jacobian_pi_pi_transformed = calculate_jacobian_pi_log_pi(pi)
  else:
    raise AssertionError('Invalid transformation')

  grad_f_pi_transformed = {'1':{}, '2':{}}
  grad_g_pi_transformed = {'1':{}, '2':{}}

  for s, player in state_player_str_pairs():
    jacobian_pi = jacobian_pi_pi_transformed[player][s]

    if jacobian_pi is None:
      # No free parameter for this player in this state.
      grad_f_pi_transformed[player][s] = None
      grad_g_pi_transformed[player][s] = None
      continue

    # J^T grad_f for the target function, and J_g J for the restrictions.
    grad_f_pi_transformed[player][s] = torch.einsum('ij,iw->jw', jacobian_pi, grad_f_pi[player][s])
    grad_g_pi_transformed[player][s] = torch.einsum('ij,jk->ik', grad_g_pi[player][s], jacobian_pi)
  return grad_f_pi_transformed, grad_g_pi_transformed

def calculate_descent_direction(pi_nolast_tensor, pi_last_tensor, g_nash, 
                                grad_f_v, grad_f_pit, grad_g_v, grad_g_pit):
  """Solve the per-player dual systems and build the initial descent direction.

  For each player a 3x3 block linear system A * duals = b is assembled and
  solved; the duals are then combined with the gradients to form the
  direction in the values (`d0_v`) and in the transformed policies
  (`d0_pit`).

  Returns:
    (d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system), where
    `norm_2_d0` accumulates the squared entries of the direction and the
    last three are dictionaries keyed by player.

  NOTE(review): `r_value` is read as a global and is not defined in the
  visible part of the notebook -- confirm where it is set.
  NOTE(review): `pi_nolast_tensor` and `pi_last_tensor` are used unchanged
  for both players; confirm whether per-player tensors were intended.
  """
  block_ids = [str(i) for i in range(1,3+1)]
  duals_0 = {}
  A_system = {'1': {}, '2': {}}
  b_system = {'1': {}, '2': {}}
  for player in players_str():
    for i in block_ids:
      A_system[player][i] = {}     

  d0_v = torch.zeros((N_S,2)).to('cuda')  
  d0_pit = {'1':{}, '2':{}}
  norm_2_d0 = 0.0

  for player in players_str():
    # Plain int, so it can be used as a matrix size and as a slice bound.
    N_A_total = int(N_A_S_tensors[player].sum().item())
    J_gi_vi = torch.cat(list(grad_g_v[player].values()))
    A_system[player]['1']['1'] = (r_value*torch.diag(torch.cat(list(g_nash[player].values())).view(-1)) - 
                                  torch.einsum('as,bs->ab', J_gi_vi, J_gi_vi))
    # The identity is created on the same device as the tensor it is combined with.
    A_system[player]['2']['2'] = (r_value*torch.diag(-pi_nolast_tensor.view(-1)) - 
                                  torch.eye(N_A_total-N_S, device=pi_nolast_tensor.device))
    A_system[player]['3']['3'] = (r_value*torch.diag(-pi_last_tensor.view(-1)) - 
                                  torch.diag(N_A_S_tensors[player].view(-1)-1.0))
    
    b_system[player]['1'] = torch.einsum('as,sw->aw', J_gi_vi, grad_f_v)

    # Per-state pieces of the policy-related blocks, assembled below.
    A_system_pit = {i: {j: [] for j in block_ids} for i in block_ids}
    b_system_pit = {i: [] for i in block_ids}

    other_player_ = other_player(player)
    for s in S_str:
      NA = N_A_S[s][other_player_]
      if NA > 1:
        A_system_pit['1']['1'].append(torch.einsum('as,bs->ab', grad_g_pit[other_player_][s], grad_g_pit[other_player_][s])) # TODO: is other_player_ ok?
        A_system_pit['2']['3'].append(-torch.ones(NA-1,1).to('cuda'))
        A_system_pit['1']['2'].append(-grad_g_pit[other_player_][s])
        A_system_pit['1']['3'].append(grad_g_pit[other_player_][s].sum(1, keepdim=True))

        b_system_pit['1'].append(torch.einsum('as,sw->aw', grad_g_pit[other_player_][s], grad_f_pit[other_player_][s]))
        b_system_pit['2'].append(-grad_f_pit[other_player_][s])
        b_system_pit['3'].append(grad_f_pit[other_player_][s].sum())

    A_system[player]['1']['1'] = A_system[player]['1']['1'] - torch.block_diag(*A_system_pit['1']['1'])
    for i,j in [('2','3'), ('1','2'), ('1','3')]:
      A_system[player][i][j] = torch.block_diag(*A_system_pit[i][j])
    # The lower-triangular blocks are the transposes of the upper ones.
    for i,j in [('3','2'), ('2','1'), ('3','1')]:
      A_system[player][i][j] = torch.t(A_system[player][j][i])
    
    b_system[player]['1'] = b_system[player]['1'] + torch.cat(b_system_pit['1'])   
    b_system[player]['2'] = torch.cat(b_system_pit['2']) 
    b_system[player]['3'] = torch.FloatTensor(b_system_pit['3']).view(-1,1).to('cuda')   

    # Concatenate blocks by explicit index: the blocks were inserted in a
    # different order than their column position.
    block_rows = [
        torch.cat([A_system[player][i][j] for j in block_ids], dim=1)
        for i in block_ids
    ]
    A_system[player] = torch.cat(block_rows, dim=0)
    b_system[player] = torch.cat([b_system[player][i] for i in block_ids], dim=0)

    # torch.solve(B, A) has been removed from PyTorch; torch.linalg.solve
    # takes the matrix first and returns only the solution.
    duals_0[player] = torch.linalg.solve(A_system[player], b_system[player])
    
    d0_v[:,get_player_id(player)] = -torch.einsum('as,a->s', J_gi_vi, duals_0[player].view(-1)[:N_A_total]) - grad_f_v.view(-1)
    norm_2_d0 += d0_v[:,get_player_id(player)].pow(2).sum()

    # NOTE(review): the three passes below consume `duals_0[player]` from
    # its first entry and size the slices with N_A_S[s][player], while the
    # blocks above were sized with the other player's action counts --
    # confirm the intended layout of the dual vector.
    remaining_duals = duals_0[player].clone()
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d0_pit[other_player_][s] = -torch.einsum('ba,bw->aw', grad_g_pit[other_player_][s], remaining_duals[:NA-1,:]) - grad_f_pit[other_player_][s]
      else:
        d0_pit[other_player_][s] = 0.0
      remaining_duals = remaining_duals[NA-1:,:]
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d0_pit[other_player_][s] = d0_pit[other_player_][s] + remaining_duals[:NA-1,:]
        assert remaining_duals[:NA-1,:].shape == (NA-1,1), 'Incorrect dual size'
        assert d0_pit[other_player_][s].shape == (NA-1,1), 'Incorrect gradient size'
      remaining_duals = remaining_duals[NA-1:,:]
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d0_pit[other_player_][s] = d0_pit[other_player_][s] - remaining_duals[0,0] * torch.ones(NA-1,1).to('cuda')
        assert d0_pit[other_player_][s].shape == (NA-1,1), 'Incorrect gradient size'
        # Single-action states hold the float placeholder 0.0 and add
        # nothing to the norm, so only tensor entries are accumulated.
        norm_2_d0 += d0_pit[other_player_][s].pow(2).sum()
      remaining_duals = remaining_duals[1:,:]
  return d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system


def _solve_linear_system(A, b):
  """Return the solution x of the square linear system ``A x = b``.

  Uses ``torch.linalg.solve`` when the installed PyTorch provides it and falls
  back to the legacy ``torch.solve`` (argument order ``(b, A)``, returns a
  ``(solution, LU)`` pair) otherwise.
  """
  linalg = getattr(torch, 'linalg', None)
  if linalg is not None and hasattr(linalg, 'solve'):
    return linalg.solve(A, b)
  return torch.solve(b, A)[0]


def calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, A, b, grad_g_v, grad_g_pit, rho):
  """Deflect the descent direction (d0_v, d0_pit) using a second linear solve.

  For every player the system ``A[player] x = 1`` is solved, the solution is
  scaled by ``-rho * norm_2_d0`` and used both to correct the duals and to
  displace the primal direction.

  Args:
    d0_v: tensor whose column ``get_player_id(player)`` holds the value
      direction of that player.
    d0_pit: ``{player: {state: tensor}}`` with the policy directions.
    norm_2_d0: scalar multiplied into the dual corrections (squared norm of the
      descent direction, as accumulated by the caller).
    duals_0: ``{player: tensor}`` with the duals of the first solve.
    A, b: ``{player: tensor}`` system matrix and right-hand side per player;
      ``b`` is only used for its shape/dtype/device.
    grad_g_v: ``{player: {key: tensor}}``; the values are concatenated into one
      Jacobian.
    grad_g_pit: ``{player: {state: tensor}}`` constraint gradients w.r.t. the
      transformed policies.
    rho: current deflection coefficient.

  Returns:
    ``(d_v, d_pit, duals, rho)``: the corrected directions, the corrected duals
    and the (possibly reduced) ``rho``. ``rho`` becomes a tensor when it is
    reduced here.

  Relies on the module-level names ``alpha``, ``S_str``, ``N_A_S``,
  ``N_A_S_tensors``, ``other_player``, ``get_player_id``, ``players_str`` and
  ``state_player_str_pairs``.
  """
  # Shrink rho when the sum of all duals is positive and the bound
  # (1 - alpha) / sum(duals) falls below the current value.
  div = duals_0['1'].sum() + duals_0['2'].sum()
  if div > 0:
    rho_1 = (1-alpha) / div
    if rho_1 < rho:
      rho = 0.5 * rho_1

  duals = {}
  d_v = d0_v.clone()
  d_pit = {'1':{}, '2':{}}
  for s, player in state_player_str_pairs():
    d_pit[player][s] = d0_pit[player][s].clone()

  for player in players_str():
    unit_rhs = torch.ones_like(b[player])
    dual_corrections = -rho * norm_2_d0 * _solve_linear_system(A[player], unit_rhs)

    N_A_total = N_A_S_tensors[player].sum()
    J_gi_vi = torch.cat(list(grad_g_v[player].values()))
    d_v[:,get_player_id(player)] -= torch.einsum('as,a->s', J_gi_vi, dual_corrections.view(-1)[:N_A_total])

    other_player_ = other_player(player)
    # The correction vector is consumed front to back in three passes over the
    # states: NA-1 rows per state, NA-1 rows per state, then one row per state.
    # NOTE(review): the first pass strides by NA-1 while d_v above reads the
    # first N_A_total entries -- confirm the intended block layout.
    remaining_corrections = dual_corrections.clone()
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        grad_displacement = - torch.einsum('ba,bw->aw', grad_g_pit[other_player_][s], remaining_corrections[:NA-1,:])
        d_pit[other_player_][s] = d_pit[other_player_][s] + grad_displacement # TODO: check mistake in previous version
      remaining_corrections = remaining_corrections[NA-1:,:]
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        d_pit[other_player_][s] = d_pit[other_player_][s] + remaining_corrections[:NA-1,:]
      remaining_corrections = remaining_corrections[NA-1:,:]
    for s in S_str:
      NA = N_A_S[s][player]
      if NA > 1:
        # Build the ones vector on the device of the corrections instead of
        # assuming CUDA.
        ones = torch.ones(NA-1, 1, device=remaining_corrections.device, dtype=remaining_corrections.dtype)
        d_pit[other_player_][s] = d_pit[other_player_][s] - remaining_corrections[0,0] * ones
      remaining_corrections = remaining_corrections[1:,:]

    duals[player] = duals_0[player] + dual_corrections
  return d_v, d_pit, duals, rho


def copy_game(game_original):
  """Return a new `game` instance on CUDA carrying the state dict of `game_original`."""
  source_state = game_original.state_dict()
  clone = game()
  clone = clone.to('cuda')
  clone.load_state_dict(source_state)
  return clone


def _nash_restrictions_reduced(g_temp, g, gammas):
  """Check ``g_temp <= gamma * g`` for every player and state.

  ``gammas[player]`` is consumed front to back, ``N_A_S[s][player]`` rows per
  state, so each state is compared against its own slice.
  """
  for player in players_str():
    remaining_gammas = gammas[player]
    for s in S_str:
      NA = N_A_S[s][player]
      if not torch.all(g_temp[player][s] <= remaining_gammas[:NA,:] * g[player][s]):
        return False
      remaining_gammas = remaining_gammas[NA:,:]
  return True


def feasible_gradient_descent(game_0, pi, v, g, grad_f_v, grad_f_pi, d_v, d_pit, duals,
                              max_steps=1000, verbose=False):
  """Backtracking line search along the direction (d_v, d_pit).

  Starting from step size 1.0, a copy of `game_0` is displaced along the
  direction and accepted when (a) the Bellman error satisfies the
  Armijo-type test ``f_temp <= f + step_size * eta * decrement`` and (b) the
  Nash restrictions satisfy ``g_temp <= gamma * g`` element-wise, with
  ``gamma = gamma_0`` where the dual is non-negative and ``1.0`` where it is
  negative. After every trial the step size is divided by ``nu``.

  Args:
    game_0: game whose parameters are the starting point (left unmodified).
    pi, v: current policies and values, used to evaluate the current error.
    g: ``{player: {state: tensor}}`` current Nash restrictions.
    grad_f_v, grad_f_pi: gradients of the Bellman error used for the
      directional derivative.
    d_v, d_pit: search direction for the values and the log-policies.
    duals: ``{player: tensor}`` duals that select the gamma of each restriction.
    max_steps: maximum number of trials (at least one is always performed).
    verbose: print the outcome of every trial.

  Returns:
    ``(game_temp, found_feasible_step_size, n_step, f_temp)`` for the last
    trial performed.

  Relies on the module-level names ``eta``, ``nu``, ``gamma_0``, ``S_str``,
  ``N_A_S``, ``copy_game``, ``calculate_bellman_error`` and
  ``calculate_nash_restrictions``.
  """
  f = calculate_bellman_error(pi, v)
  # Directional derivative of the Bellman error along the search direction.
  decrement = torch.einsum('ij,i->j', d_v, grad_f_v.view(-1)).sum()
  for s, player in state_player_str_pairs():
    decrement += (d_pit[player][s] * grad_f_pi[player][s]).sum()

  # The gammas depend only on the duals, so build them once.
  gammas = {}
  for player in players_str():
    gammas[player] = gamma_0 * torch.ones_like(duals[player])
    gammas[player][duals[player] < 0] = 1.0

  found_feasible_step_size = False
  step_size = 1.0
  n_step = 0
  max_steps = max(1, max_steps)  # guarantees game_temp / f_temp are defined
  while (not found_feasible_step_size) and (n_step < max_steps):
    game_temp = copy_game(game_0)
    game_temp.v.data.add_(step_size * d_v)
    for s in S_str:
      game_temp.log_pi1[s].data.add_(step_size * d_pit['1'][s])
      game_temp.log_pi2[s].data.add_(step_size * d_pit['2'][s])

    # Only the first two forward outputs (policies, values) are needed; index
    # them so the call works whether or not forward returns extra outputs.
    outputs = game_temp()
    pi_temp, v_temp = outputs[0], outputs[1]
    f_temp = calculate_bellman_error(pi_temp, v_temp)
    if f_temp <= f + step_size * eta * decrement:
      g_temp = calculate_nash_restrictions(pi_temp, v_temp)
      found_feasible_step_size = _nash_restrictions_reduced(g_temp, g, gammas)
    step_size /= nu
    n_step += 1
    if verbose:
      print("Step: {}, Found feasible step size: {}".format(n_step, found_feasible_step_size))
  return game_temp, found_feasible_step_size, n_step, f_temp


def optimize_game(game_0, rho, n_epochs=100):
  """Run up to `n_epochs` feasible-direction iterations starting from `game_0`.

  Each epoch evaluates the Nash restrictions and the gradients at the current
  game, computes the descent direction and its feasible correction, checks the
  KKT conditions and performs a line search. The epoch result is printed and
  adopted when the line search succeeds; otherwise the loop stops.

  Args:
    game_0: starting game; it is copied, not modified.
    rho: initial deflection coefficient handed to
      `calculate_feasible_direction`, which may reduce it between epochs.
    n_epochs: maximum number of epochs.

  Returns:
    The game reached after the last successful epoch (a copy of `game_0` when
    no epoch succeeded).
  """
  game_new = copy_game(game_0)
  with torch.no_grad():
    # Baseline error, only used to report the relative change; evaluated under
    # no_grad so that no autograd graph is kept alive. Only the first two
    # forward outputs (policies, values) are used, whatever else is returned.
    pi_0, v_0 = game_0()[:2]
    f_0 = calculate_bellman_error(pi_0, v_0)

    for epoch in range(0, n_epochs):
      pi, v = game_new()[:2]
      pi_nolast, pi_last = game_new.pi_discriminated()
      g_nash = calculate_nash_restrictions(pi, v)
      grad_f_v, grad_f_pi = bellman_error_gradients(pi, v)
      grad_g_v, grad_g_pi = nash_restriction_gradients(pi, v)
      grad_f_pit, grad_g_pit = calculate_nash_pi_transformed_gradients(pi, grad_f_pi, grad_g_pi, transform_type='nolast')
      d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system = calculate_descent_direction(pi_nolast, pi_last, g_nash,
                                                                                         grad_f_v, grad_f_pit, grad_g_v, grad_g_pit)
      d_v, d_pit, duals, rho = calculate_feasible_direction(d0_v, d0_pit, norm_2_d0, duals_0, A_system, b_system, grad_g_v, grad_g_pit, rho)
      g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero = check_KKT_conditions(pi, v, duals)
      game_temp, found_feasible_step_size, n_step, f_new = feasible_gradient_descent(game_new, pi, v, g_nash, grad_f_v, grad_f_pi, d_v, d_pit, duals)

      if not found_feasible_step_size:
        # Make the early stop visible instead of ending silently.
        print('Epoch: {}, no feasible step size found after {} trials; stopping'.format(epoch, n_step))
        break

      game_new = game_temp
      delta_f = (f_new - f_0) / f_0 * 100
      print('Epoch: {}, f: {:.3e}, delta f: {:.3e}%, nash satisfied: {}, max g: {:.3e}, dual nash satisfied: {}, max product: {:.3e}, rho: {:.3e}, norm2 d0:{:.3e}'.format(
          epoch, f_new.item(), delta_f.item(), g_nash_satisfied, max_g_nash, product_zero_satisfied, max_product_zero, rho, norm_2_d0))

  return game_new

#**Optimization**

In [None]:
# Instantiate the game on the GPU and create one Adam optimizer for the primal
# parameters (v and both log-policies) and one for the sqrt_lambda_nash*
# parameters.
game_b = game().to('cuda')
params_primal = [game_b.v, *game_b.log_pi1.parameters(), *game_b.log_pi2.parameters()]
params_dual = [*game_b.sqrt_lambda_nash1.parameters(), *game_b.sqrt_lambda_nash2.parameters()]
optim_primal = optim.Adam(params_primal, lr=1e-3)
optim_dual = optim.Adam(params_dual, lr=1e-4)
optimizers = [optim_primal, optim_dual]

In [None]:
# Add the initial value estimate to the game's value parameter in place; the
# last expression is left bare so the notebook displays the updated tensor.
v0 = torch.FloatTensor(calculate_initial_v(game_b)).to('cuda')
game_b.v.data.add_(v0)

tensor([[1173.4198, 1142.6169],
        [1170.6841, 1141.3544],
        [1171.7487, 1141.1611],
        [1172.2035, 1142.7819],
        [1172.1705, 1142.1958],
        [1171.7673, 1141.0442],
        [1173.0305, 1140.8301],
        [1174.1405, 1140.8645],
        [1174.7847, 1139.5341],
        [1171.6725, 1140.8264],
        [1171.7070, 1139.6440],
        [1172.3224, 1139.3293],
        [1172.9286, 1139.0126],
        [1174.3192, 1137.9739],
        [1175.3152, 1138.2100],
        [1172.5956, 1141.3788],
        [1172.1632, 1141.6699],
        [1170.5911, 1142.0400],
        [1170.7838, 1142.9020],
        [1170.4510, 1143.7872],
        [1169.6688, 1144.0585],
        [1170.8831, 1141.2123],
        [1170.5372, 1141.4984],
        [1168.3439, 1142.2196],
        [1169.5537, 1144.1621],
        [1168.4790, 1144.4883],
        [1167.6859, 1145.6003],
        [1173.8311, 1140.5386],
        [1169.8845, 1142.5107],
        [1171.0253, 1144.9800],
        [1171.3553, 1140.4384]], device=

In [None]:
# Hyper-parameters read as globals by the optimization routines.
rho_0 = 1.0      # initial deflection coefficient
rho = rho_0      # value handed to optimize_game
alpha = 0.1      # enters the bound (1 - alpha) / sum(duals) on rho
gamma_0 = 0.1    # factor on g in the line-search test for non-negative duals
eta = 0.1        # sufficient-decrease coefficient of the line search
nu = 1/0.98      # the step size is divided by nu after every trial
r_value = 1.0    # NOTE(review): not used in the visible part of the notebook

In [None]:
# Run ten epochs of the feasible-direction optimization from game_b.
game_b2 = optimize_game(game_0=game_b, rho=rho, n_epochs=10)

Epoch: 0, f: 5.691e+02, delta f: -1.415e+01%, nash satisfied: True, max g: -9.134e+00, dual nash satisfied: False, max product: 5.194e+00, rho: 2.553e-02, norm2 d0:6.737e+01
Epoch: 1, f: 5.231e+02, delta f: -2.108e+01%, nash satisfied: True, max g: -6.760e+00, dual nash satisfied: False, max product: 6.076e+00, rho: 2.553e-02, norm2 d0:7.789e+01
Epoch: 2, f: 4.986e+02, delta f: -2.478e+01%, nash satisfied: True, max g: -5.141e+00, dual nash satisfied: False, max product: 6.316e+00, rho: 2.553e-02, norm2 d0:8.278e+01
Epoch: 3, f: 4.670e+02, delta f: -2.955e+01%, nash satisfied: True, max g: -4.553e+00, dual nash satisfied: False, max product: 5.793e+00, rho: 2.553e-02, norm2 d0:7.420e+01
Epoch: 4, f: 4.221e+02, delta f: -3.631e+01%, nash satisfied: True, max g: -4.210e+00, dual nash satisfied: False, max product: 4.919e+00, rho: 2.553e-02, norm2 d0:5.857e+01
Epoch: 5, f: 3.707e+02, delta f: -4.407e+01%, nash satisfied: True, max g: -2.988e+00, dual nash satisfied: False, max product: 3.

In [None]:
# Scratch 3x3 dict-of-dicts keyed '1'..'3', each entry a fresh 2x3 identity-like block.
A = {}
for i, j in itertools.product(range(1, 3+1), repeat=2):
  A.setdefault(str(i), {})[str(j)] = np.eye(2, 3)

In [None]:
# Stack the blocks of row '1' along axis 0 and wrap the result as a tensor.
X = torch.from_numpy(np.concatenate([*A['1'].values()]))

In [None]:
# Display the device on which the tensor built by torch.from_numpy lives.
X.device

device(type='cpu')