In [2]:
import numpy as np

test = np.load('test.npy')

In [5]:
test = test.reshape(18, -1)
print(test.shape)

(18, 18)


In [11]:
import sys
sys.path.append("..")
from mars.equilibrium_solver import NashEquilibriumECOSSolver, NashEquilibriumMWUSolver, NashEquilibriumParallelMWUSolver

all_dists, all_ne_values = NashEquilibriumParallelMWUSolver([test])
print(all_dists, all_ne_values)

[[[nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
   nan]
  [nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan nan
   nan]]] [nan]


In [31]:
all_dists, all_ne_values = NashEquilibriumECOSSolver(test)
print(all_dists, all_ne_values)

(array([3.68420245e-11, 8.70780568e-11, 1.34364502e-10, 1.73996001e-10,
       1.41773746e-10, 5.78266602e-11, 8.65573113e-10, 1.90947802e-11,
       9.99999998e-01, 7.38973225e-11, 1.85573915e-11, 2.68601907e-10,
       6.05324137e-12, 4.72596831e-11, 1.23540629e-11, 3.35665454e-11,
       1.25879486e-11, 7.67281956e-12]), array([7.91304485e-09, 2.33073186e-09, 4.57582617e-09, 5.39447349e-09,
       3.98977170e-09, 6.81164853e-09, 2.49367279e-09, 2.52734253e-09,
       3.79759084e-09, 5.76069426e-09, 2.75630817e-09, 1.21474592e-09,
       5.30544777e-09, 9.99999927e-01, 1.13202839e-09, 5.72141246e-09,
       6.82444074e-09, 3.95812496e-09])) 10.003111842652576


  warn("Converting G to a CSC matrix; may take a while.")
  warn("Converting A to a CSC matrix; may take a while.")


In [14]:
all_dists, all_ne_values = NashEquilibriumMWUSolver(test)
print(all_dists, all_ne_values)

[[0.00161182 0.00262086 0.00502843 0.0042215  0.00351083 0.0021654
  0.03430183 0.00098112 0.91959615 0.00241737 0.00148063 0.01328705
  0.00137177 0.00193724 0.00119753 0.00173841 0.0013012  0.00123086]
 [0.05688533 0.14625373 0.02096117 0.10587083 0.01607077 0.02896784
  0.01159129 0.01223895 0.02069162 0.0257301  0.01185669 0.00757924
  0.03861383 0.3414817  0.00596349 0.03017909 0.09532486 0.02373948]] 9.960371861288117


In [37]:
import copy

def NashEquilibriumParallelMWUSolver(A, Itr=5000, verbose=False):
    """ Solve mulitple Nash equilibrium with multiplicative weights udpate."""
    EPS = 1e-7 # prevent numerical problem
    A = np.array(A)
    matrix_num = A.shape[0]
    row_action_num = A.shape[1]
    col_action_num = A.shape[2]
    learning_rate = np.sqrt(np.log(row_action_num)/Itr)  # sqrt(log |A| / T)

    row_policy = np.ones(row_action_num)/row_action_num
    col_policy = np.ones(col_action_num)/col_action_num
    policies = np.array(matrix_num*[[row_policy, col_policy]])
    final_policy = copy.deepcopy(policies)

    for i in range(Itr):
        # for row player, maximizer
        policies_ = copy.deepcopy(policies)  # track old value before update (update is inplace)
        payoff_vec = np.einsum('nb,nab->na', policies_[:, 1], A) 
        policies[:, 0] = policies[:, 0] * np.exp(learning_rate*payoff_vec)

        # for col player, minimizer
        payoff_vec = np.einsum('na,nab->nb', policies_[:, 0], A) 
        policies[:, 1] = policies[:, 1] * np.exp(-learning_rate*payoff_vec)


        # above is unnormalized, normalize it to be distribution
        policies = policies/np.expand_dims(EPS+np.sum(policies, axis=-1), -1)
        # MWU is average-iterate coverging, so accumulate polices
        final_policy += policies

    final_policy = final_policy / (Itr+1)

    if verbose:
        print(f'For row player, strategy is {final_policy[:, 0]}')
        print(f'For column player, strategy is {final_policy[:, 1]}')
        print(learning_rate)
    
    nash_value = np.einsum('nb,nb->n', np.einsum('na,nab->nb', policies[:, 0], A), final_policy[:, 1])

    return final_policy, nash_value

all_dists, all_ne_values = NashEquilibriumParallelMWUSolver([test])
print(all_dists, all_ne_values)

[[[4.98451100e-04 8.18024637e-04 1.58012176e-03 1.32360372e-03
   1.09929940e-03 6.73543091e-04 1.07702293e-02 2.98566218e-04
   9.74833614e-01 7.53606379e-04 4.56938250e-04 4.18643654e-03
   4.22318211e-04 6.01910960e-04 3.67191054e-04 5.38370793e-04
   3.99965950e-04 3.77730465e-04]
  [5.20287984e-02 1.35215508e-01 1.42089288e-02 1.23702172e-01
   1.13354754e-02 2.39619668e-02 7.12724594e-03 7.33720067e-03
   1.24830220e-02 1.92213963e-02 7.41009006e-03 3.90878657e-03
   2.27921525e-02 4.60971303e-01 3.23877813e-03 2.08185517e-02
   6.02951682e-02 1.39433286e-02]]] [10.0260578]
