In [1]:
import numpy as np
import cvxpy as cp

In [2]:
# Load the expert policy table and the transition tensor from disk.
pi_exp = np.load('../data/pi.npy')
P = np.load('../data/P.npy')

# The first axis of P is indexed by state and the last axis by action
# (the middle axis is presumably the next state -- TODO confirm).
num_states, num_actions = P.shape[0], P.shape[-1]
gridsize = int(np.sqrt(num_states))

# Pastr[s]  : slice of P for state s under the expert's action, i.e. the
#             argmax of pi_exp[s].
# lPa[k][s] : slice of P for state s under the k-th of the remaining
#             (non-expert) actions, in increasing action order.
# Both arrays are filled in place; the resulting shapes are
# (num_states, P.shape[1]) and (num_actions - 1, num_states, P.shape[1]).
Pastr = np.empty((num_states, P.shape[1]), dtype=P.dtype)
lPa = np.empty((num_actions - 1, num_states, P.shape[1]), dtype=P.dtype)
for s_idx in range(num_states):
    astr = np.argmax(pi_exp[s_idx, :])
    a = np.delete(np.arange(num_actions), astr)
    Pastr[s_idx] = P[s_idx, :, astr]
    # Indexing with a scalar, a slice and an index array puts the array
    # axis first, so the right-hand side is (num_actions - 1, P.shape[1]).
    lPa[:, s_idx, :] = P[s_idx, :, a]

In [3]:
# Reward recovery as a convex program, with a bisection search on the L1
# regularisation weight `lbd`.
#
#   minimise    sum(s) + lbd * ||r||_1
#   subject to  D[i] @ H @ r + s[i] >= 0      for every state i
#               (Pastr - Pa) @ H @ r >= 0     for every non-expert action
#               0 <= r <= rmax
#
# NOTE(review): this looks like the linear-programming formulation of inverse
# reinforcement learning (Ng & Russell, 2000) -- confirm against the paper.

# problem information (input from user)
m = num_states  # number of states
gamma = 0.9  # discount factor
# Pastr (transition matrix of the optimal action) and lPa (transition
# matrices of the other actions) are used exactly as built in the previous
# cell.

# hyperparameters (input from user)
rmax = 100  # reward function bound
lbd_up = 100  # initial upper end of the bisection bracket for lbd
lbd_low = 0  # initial lower end of the bisection bracket for lbd
epsilon = 0.05  # bracket width at which the bisection is allowed to stop
zero_tol = 1e-6  # |optimal value| below this is treated as zero
max_iters = 100  # hard cap on the number of solves

r = cp.Variable(m)
s = cp.Variable(m)
lbd = cp.Parameter(nonneg=True)  # scalarization weight

constraints = []
H = np.linalg.inv(np.identity(m) - gamma * Pastr)
# D[i, k] = Pastr[i] - lPa[k][i]; stacking the per-action differences along
# axis 1 yields the same (m, num_other_actions, m) array as a per-state,
# per-action double loop.
D = np.stack([Pastr - Pa for Pa in lPa], axis=1)
for i in range(m):
    # s[i] must dominate the negated margin of every non-expert action at i.
    constraints.append(D[i] @ H @ r + s[i] >= 0)
for Pa in lPa:
    constraints.append((Pastr - Pa) @ H @ r >= 0)
constraints.append(rmax >= r)
constraints.append(r >= 0)

obj = cp.Minimize(cp.sum(s) + lbd * cp.norm(r, 1))
prob = cp.Problem(obj, constraints)

# Bisection on lbd. An optimal value of (numerically) zero -- which r = 0,
# s = 0 attains -- means lbd is too large, so the upper end moves down;
# otherwise the lower end moves up. The loop stops only after a solve with a
# non-zero optimum, so r.value holds that solution afterwards.
for iteration in range(max_iters):
    lbd.value = 0.5 * (lbd_up + lbd_low)
    opt_val = prob.solve()

    # An infeasible/unbounded solve yields +/-inf, which would otherwise be
    # mistaken for a non-zero optimum and silently move the bracket.
    if prob.status not in ("optimal", "optimal_inaccurate"):
        raise RuntimeError(
            f'solver returned status {prob.status!r} at lambda = {lbd.value}'
        )

    print(f'lambda: {lbd.value:.2f} (upper: {lbd_up:.2f}, lower: {lbd_low:.2f}), optimal value: {prob.value:.2f}')
    if np.abs(opt_val) < zero_tol:
        lbd_up = lbd.value
    elif lbd_up - lbd_low <= epsilon:
        break
    else:
        lbd_low = lbd.value
else:
    # Reached only if no solve with a non-zero optimum closed the bracket.
    raise RuntimeError(
        f'bisection on lambda did not finish within {max_iters} solves '
        f'(upper: {lbd_up}, lower: {lbd_low})'
    )

lambda: 50.00 (upper: 100.00, lower: 0.00), optimal value: 0.00
lambda: 25.00 (upper: 50.00, lower: 0.00), optimal value: 0.00
lambda: 12.50 (upper: 25.00, lower: 0.00), optimal value: 0.00
lambda: 6.25 (upper: 12.50, lower: 0.00), optimal value: 0.00
lambda: 3.12 (upper: 6.25, lower: 0.00), optimal value: -915.37
lambda: 4.69 (upper: 6.25, lower: 3.12), optimal value: 0.00
lambda: 3.91 (upper: 4.69, lower: 3.12), optimal value: 0.00
lambda: 3.52 (upper: 3.91, lower: 3.12), optimal value: 0.00
lambda: 3.32 (upper: 3.52, lower: 3.12), optimal value: -294.59
lambda: 3.42 (upper: 3.52, lower: 3.32), optimal value: 0.00
lambda: 3.37 (upper: 3.42, lower: 3.32), optimal value: -142.62
lambda: 3.39 (upper: 3.42, lower: 3.37), optimal value: -68.58
