In [14]:
import numpy as np
import copy
from multiprocessing import Pool
import multiprocessing

### Generate matrices

In [15]:
# Problem dimension and the seed that fixes every random draw of this cell.
d = 10
np.random.seed(1812)


def _draw_uniform(low, high, shape):
    """Draw an array of the given shape, uniform on [low, high), from the global RNG."""
    return np.random.uniform(low=low, high=high, size=shape)


# Random square matrix T; its QR factorisation provides the orthogonal factor Q.
T = _draw_uniform(-1.0, 1.0, (d, d))
Q, R_1 = np.linalg.qr(T, mode='complete')
A_12 = Q

# A_22 = Q^T Lambda_0 Q, with the diagonal of Lambda_0 drawn from [1, 2).
Lambda_0 = np.diag(_draw_uniform(1.0, 2.0, (d,)))
A_22 = np.dot(Q.T, np.dot(Lambda_0, Q))

# A_11 = R R^T + 5 I, with R a random square matrix.
R = _draw_uniform(-1.0, 1.0, (d, d))
A_11 = np.dot(R, R.T) + 5*np.identity(d)

# A_21 = Q^T Lambda_1, with the diagonal of Lambda_1 drawn from [-1, 1).
Lambda_1 = np.diag(_draw_uniform(-1.0, 1.0, (d,)))
A_21 = np.dot(Q.T, Lambda_1)

# True parameters, drawn after all the matrices so the random stream is unchanged.
theta_star = _draw_uniform(-1.0, 1.0, (d,))
w_star = _draw_uniform(-1.0, 1.0, (d,))

# Right-hand sides chosen so that (theta_star, w_star) solves the linear system.
b_1 = A_11 @ theta_star + A_12 @ w_star
b_2 = A_21 @ theta_star + A_22 @ w_star

### Check assumptions

In [16]:
# Delta = A_11 - A_12 A_22^{-1} A_21; its eigenvalues are printed for inspection.
A_22_inverse_A_21 = np.dot(np.linalg.inv(A_22), A_21)
Delta = np.subtract(A_11, np.dot(A_12, A_22_inverse_A_21))
eigvals, eigfuncs = np.linalg.eig(Delta)
print(eigvals)

[14.57061784 11.40444971 10.58819872  9.78375006  7.30257669  6.81006409
  4.86218813  5.02569884  5.25124318  5.46024408]


### Run GTD

In [20]:
# Number of iterations of the stochastic recursion.
N_iters = 1*10**4

# Both schedules have the form  C_0 / (N_0 + k)**deg  for k = 0, ..., N_iters-1.
# beta drives the theta update and gamma drives the w update in main_loop.
N_0_beta = 1*10**4
N_0_gamma = 1*10**4
deg_beta = 1.0
deg_gamma = 0.7
C_0_beta = 2*10**2
C_0_gamma = 10.0

# Built in one vectorised expression. deg_beta is now actually applied: the
# previous loop hard-coded an exponent of 1 for beta, so changing deg_beta
# had no effect (with deg_beta = 1.0 the values are unchanged).
_iteration_index = np.arange(N_iters)
beta = C_0_beta/(N_0_beta + _iteration_index)**deg_beta
gamma = C_0_gamma/(N_0_gamma + _iteration_index)**deg_gamma

In [21]:
def main_loop(j,beta,gamma):
    """Run one trajectory of the noisy two-time-scale linear recursion.

    Reads the notebook-level globals ``d``, ``N_iters``, ``A_11``, ``A_12``,
    ``A_21``, ``A_22``, ``b_1``, ``b_2``, ``Delta``, ``theta_star`` and
    ``w_star``.

    Parameters
    ----------
    j : int
        Run index; the global NumPy RNG is re-seeded with ``2020 + j``.
    beta : indexable of float, length >= N_iters
        Step sizes of the ``theta`` update.
    gamma : indexable of float, length >= N_iters
        Step sizes of the ``w`` update.

    Returns
    -------
    iterates : ndarray of shape (2, N_iters, d)
        ``iterates[0, k]`` is ``theta`` and ``iterates[1, k]`` is ``w`` after
        step ``k``.
    I_all : ndarray of shape (N_iters,)
        ``I_all[k] = sum_{i=0}^{k} beta[i]**2 * trace(G_{i:k} Sigma G_{i:k}^T)``
        with ``G_{i:k} = prod_{l=i}^{k} (I - beta[l] * Delta)``.
    """
    np.random.seed(2020+j)
    w_cur = np.random.randn(d)
    theta_cur = np.random.randn(d)
    V_funcs = np.zeros((N_iters,d))
    W_params = np.zeros((N_iters,d))
    I_all = np.zeros(N_iters)
    sigma_V = 0.1
    sigma_W = 0.5
    identity = np.eye(d)
    # Noise covariances: sigma^2 * (1 + ||theta*||^2 + ||w*||^2) * I.
    noise_scale = 1 + np.linalg.norm(theta_star)**2 + np.linalg.norm(w_star)**2
    Sigma_11 = sigma_V**2*(identity*noise_scale)
    Sigma_22 = sigma_W**2*(identity*noise_scale)
    A_22_inv = np.linalg.inv(A_22)
    Sigma = Sigma_11 + A_12.dot(A_22_inv.dot(Sigma_22.dot(A_22_inv.T.dot(A_12.T))))
    # Running matrix S_k = sum_i beta_i^2 G_{i:k} Sigma G_{i:k}^T, with I_k = trace(S_k).
    # All factors (I - beta_l * Delta) commute, hence
    #     S_k = M_k (S_{k-1} + beta_k^2 Sigma) M_k^T,   M_k = I - beta_k * Delta.
    # This replaces the former inner loop, which rebuilt the whole sum with a
    # matrix inversion per term (O(N_iters^2) inversions in total).
    weighted_cov = np.zeros((d,d))
    ###Main loop
    for N in range(N_iters):
        #generate noisy V (draw order matters for reproducibility)
        F_V = sigma_V*np.random.randn(d)
        A_V_theta = sigma_V*np.random.randn(d,d)
        A_V_w = sigma_V*np.random.randn(d,d)
        V = F_V - A_V_theta@theta_cur - A_V_w@w_cur
        #generate noisy W
        F_W = sigma_W*np.random.randn(d)
        A_W_theta = sigma_W*np.random.randn(d,d)
        A_W_w = sigma_W*np.random.randn(d,d)
        W = F_W - A_W_theta@theta_cur - A_W_w@w_cur
        #update
        theta_cur = theta_cur + beta[N]*(b_1 - A_11@theta_cur - A_12@w_cur + V)
        # NOTE(review): the w update reads the freshly updated theta_cur, while
        # the noise W above was built from the previous theta; kept as-is,
        # confirm this sequential (rather than simultaneous) update is intended.
        w_cur = w_cur + gamma[N]*(b_2 - A_21@theta_cur - A_22@w_cur + W)
        #compute I_k through the recursion described above
        contraction = identity - beta[N]*Delta
        weighted_cov = contraction.dot((weighted_cov + (beta[N]**2)*Sigma).dot(contraction.T))
        #save iterates and leading term
        V_funcs[N] = theta_cur
        W_params[N] = w_cur
        I_all[N] = np.trace(weighted_cov)
    return np.asarray([V_funcs,W_params]),I_all

In [22]:
# Single trajectory with run index 0: iterates in res_indep, leading term in I.
res_indep, I = main_loop(j=0, beta=beta, gamma=gamma)
print(np.shape(res_indep))

KeyboardInterrupt: 

In [None]:
# Euclidean distance of each stored iterate to the true parameters.
theta_path = res_indep[0]
w_path = res_indep[1]
norms_theta = np.zeros(N_iters)
norms_w = np.zeros(N_iters)

for i in range(N_iters):
    theta_error = theta_path[i] - theta_star
    w_error = w_path[i] - w_star
    norms_theta[i] = np.linalg.norm(theta_error)
    norms_w[i] = np.linalg.norm(w_error)

In [None]:
import matplotlib.pyplot as plt
# Squared errors of the single run against the leading term I_k (log y-axis).
N_start = 0
N_last = N_iters
plt.figure(figsize=(12,8))
plt.plot(np.arange(N_start,N_last), norms_theta[N_start:N_last]**2, color='r' ,label='Squared error on $\\theta$,  $\\|\\theta_k-\\theta^*\\|^2$')
plt.plot(np.arange(N_start,N_last), norms_w[N_start:N_last]**2, color='g' ,label='Squared error on $w$,  $\\|w_k-w^*\\|^2$')
# I_k is already a second-order quantity (sum of beta_i^2 * trace(...)), i.e.
# on the scale of the squared errors above, so it is plotted unsquared to
# match its legend entry. Previously I**2 was drawn under the label "I_k".
plt.plot(np.arange(N_start,N_last), I[N_start:N_last], color='b' ,label='Leading term  $I_k$')
plt.xlabel('iteration number',fontsize = 18)
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
import matplotlib
# Enlarge the tick labels on both axes for every figure created afterwards.
for _tick_group in ('xtick', 'ytick'):
    matplotlib.rc(_tick_group, labelsize=20)

In [None]:
import matplotlib.pyplot as plt
# Log-log view of the squared theta error of the single run next to beta_k.
N_start = 0
N_last = N_iters
plt.figure(figsize=(10,10))
_window = slice(N_start, N_last)
_curves = (
    (norms_theta[_window]**2, 'r', 'Squared error on $\\theta$,  $\\|\\theta_k-\\theta^*\\|^2$'),
    (beta[_window], 'b', '$\\beta_k$'),
)
for _values, _color, _label in _curves:
    plt.plot(np.arange(N_start, N_last), _values, color=_color, label=_label)
plt.xlabel('iteration number', fontsize=18)
plt.title('GTD($0$)', fontsize=20)
for _set_scale in (plt.yscale, plt.xscale):
    _set_scale('log')
plt.legend(fontsize=18, loc='lower left')
plt.grid(linestyle='--', linewidth=1.0)
# The figure is neither saved nor explicitly shown here; to export it, call
# plt.savefig("GTD_theta_squared.pdf").

In [None]:
import matplotlib.pyplot as plt
# Log-log view of the squared w error of the single run next to gamma_k.
N_start = 0
N_last = N_iters
plt.figure(figsize=(10,10))
_window = slice(N_start, N_last)
_curves = (
    (norms_w[_window]**2, 'g', 'Squared error of $w$,  $\\|w_k-w^*\\|^2$'),
    (gamma[_window], 'b', '$\\gamma_k$'),
)
for _values, _color, _label in _curves:
    plt.plot(np.arange(N_start, N_last), _values, color=_color, label=_label)
plt.xlabel('iteration number', fontsize=18)
plt.title('GTD($0$)', fontsize=20)
for _set_scale in (plt.yscale, plt.xscale):
    _set_scale('log')
plt.legend(fontsize=18, loc='lower left')
plt.grid(linestyle='--', linewidth=1.0)
# The figure is neither saved nor explicitly shown here; to export it, call
# plt.savefig("GTD_w_squared.pdf").

In [None]:
####

In [None]:
# Run main_loop once per entry of `powers`, in parallel on all available cores.
# NOTE(review): `alpha`, `s0` and `powers` are not defined in the visible part
# of this notebook, and the visible main_loop takes (j, beta, gamma) - confirm
# the intended arguments before running this cell.
nbcores = multiprocessing.cpu_count()
# The context manager releases the worker processes even when starmap raises;
# previously the pool was only closed on the success path.
with Pool(nbcores) as trav:
    res_indep = trav.starmap(main_loop, [(j,alpha,s0) for j in range (len(powers))])

In [None]:
# Convert the collected run results to an array and report its shape.
res_indep = np.asarray(res_indep)
print(np.shape(res_indep))

In [None]:
# Per-run, per-iteration Euclidean norms computed from the stacked results.
# Axis 1 of res_indep holds four series per run; indices 1 and 2 are labelled
# J_0 and J_1 and index 3 "transient" by the variable names below.
_norm_shape = (len(powers), N_iters)
norms = np.zeros(_norm_shape, dtype=float)
norms_J0_rem = np.zeros(_norm_shape, dtype=float)
norms_J1_rem = np.zeros(_norm_shape, dtype=float)
norms_transient = np.zeros(_norm_shape, dtype=float)

norms_J0 = np.zeros(_norm_shape, dtype=float)
norms_J1 = np.zeros(_norm_shape, dtype=float)
for j in range(len(powers)):
    for i in range(N_iters):
        _estimate = res_indep[j, 0, i, :]
        _j0_term = res_indep[j, 1, i, :]
        _j1_term = res_indep[j, 2, i, :]
        _transient_term = res_indep[j, 3, i, :]
        # Subtraction order is kept so the floating-point results are unchanged.
        _without_j0 = _estimate - _j0_term
        norms[j, i] = np.linalg.norm(_estimate - theta_star)
        norms_J0_rem[j, i] = np.linalg.norm(_without_j0 - theta_star)
        norms_J1_rem[j, i] = np.linalg.norm(_without_j0 - _j1_term - theta_star)
        norms_transient[j, i] = np.linalg.norm(_transient_term)
        norms_J0[j, i] = np.linalg.norm(_j0_term)
        norms_J1[j, i] = np.linalg.norm(_j1_term)

### Save results

### Plot graphics

In [None]:
import matplotlib.pyplot as plt
# Error norms of run j, with and without the J_0 / J_1 terms (log y-axis).
N_start = 0
j=3
plt.figure(figsize=(12,8))
_curves = (
    (norms, 'r', 'MSE error'),
    (norms_J0_rem, 'g', 'MSE error without J_0'),
    (norms_J1_rem, 'b', 'MSE error without J_0, J_1'),
)
for _series, _color, _label in _curves:
    plt.plot(np.arange(N_start, N_iters), _series[j][N_start:], color=_color, label=_label)
plt.xlabel('iteration number', fontsize=18)
plt.yscale('log')
plt.legend()
plt.show()

In [None]:
# Single-trajectory loop that updates a value estimate V together with the
# auxiliary recursions J0_cur, J1_cur and Transient_cur.
# NOTE(review): v0, N_s, N_a, Policy, P, Inds_nz, s0, alpha, A_star, V_funcs,
# J_0, J_1 and Transient are not defined in the visible part of the notebook.
# R and gamma are used here as indexable R[a,s0] and as a scalar factor, while
# the earlier cells bind R to a d x d random matrix and gamma to a step-size
# array - presumably this cell belongs to a different experiment; confirm.
#initialize the value estimate and the auxiliary terms
V = copy.deepcopy(v0)
J0_cur = np.zeros(N_s,dtype=float)
J1_cur = np.zeros(N_s,dtype=float)
Transient_cur = v0 - theta_star
###Main loop
for N in range(N_iters):
    #sample action from column s0 of Policy
    a = np.random.choice(N_a, 1, replace=True, p=Policy[:,s0])
    a=a[0]
    #sample next state among the indices Inds_nz[s0,a]
    s = np.random.choice(Inds_nz[s0,a], 1, replace=True, p=P[Inds_nz[s0,a],s0,a])
    s=s[0]
    #temporal-difference terms: eps is evaluated at theta_star (only entry s0
    #is non-zero), eps_TD at the current estimate V
    eps = np.zeros(N_s,dtype=float)
    eps[s0] = R[a,s0] + gamma*theta_star[s]-theta_star[s0]
    eps_TD = R[a,s0] + gamma*V[s]-V[s0]
    #calculate J1 (uses J0_cur from the previous iteration, so it must run
    #before the J0 update below)
    A_tilde = np.zeros((N_s,N_s),dtype=float)
    A_tilde[s0,s0] = 1.0
    # NOTE(review): when s == s0 this overwrites the 1.0 set on the previous
    # line with -gamma - confirm whether 1 - gamma was intended in that case.
    A_tilde[s0,s] = -gamma
    J1_cur = (np.eye(N_s) - alpha[N]*A_star)@J1_cur - alpha[N]*(A_tilde-A_star)@J0_cur
    #calculate transient term
    Transient_cur = (np.eye(N_s) - alpha[N]*A_tilde)@Transient_cur
    #calculate J0
    J0_cur = (np.eye(N_s) - alpha[N]*A_star)@J0_cur + alpha[N]*eps
    #TD update of the visited state only
    V[s0] = V[s0] + alpha[N]*eps_TD
    #save value function
    V_funcs[N] = V
    #save J_0
    J_0[N] = J0_cur
    #save J_1
    J_1[N] = J1_cur
    #save transient term
    Transient[N] = Transient_cur
    #update current state
    s0 = s

In [None]:
# Per-iteration Euclidean norms for the single stored trajectory.
norms, norms_J0_rem, norms_J1_rem, norms_transient = (np.zeros(N_iters) for _ in range(4))

norms_J0, norms_J1 = (np.zeros(N_iters) for _ in range(2))
for i in range(N_iters):
    _estimate = V_funcs[i, :]
    _j0_term = J_0[i, :]
    _j1_term = J_1[i, :]
    # Subtraction order is kept so the floating-point results are unchanged.
    _without_j0 = _estimate - _j0_term
    norms[i] = np.linalg.norm(_estimate - theta_star)
    norms_J0_rem[i] = np.linalg.norm(_without_j0 - theta_star)
    norms_J1_rem[i] = np.linalg.norm(_without_j0 - _j1_term - theta_star)
    norms_transient[i] = np.linalg.norm(Transient[i])
    norms_J0[i] = np.linalg.norm(_j0_term)
    norms_J1[i] = np.linalg.norm(_j1_term)

In [None]:
import matplotlib.pyplot as plt
# Error norms from iteration N_start onwards, on a linear y-axis.
N_start = 5*10**4
plt.figure(figsize=(12,8))
_curves = (
    (norms, 'r', 'MSE error'),
    (norms_J0_rem, 'g', 'MSE error without J_0'),
    (norms_J1_rem, 'b', 'MSE error without J_0, J_1'),
)
for _series, _color, _label in _curves:
    plt.plot(np.arange(N_start, N_iters), _series[N_start:], color=_color, label=_label)
plt.xlabel('iteration number', fontsize=18)
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt
# Norms of the transient, J_0 and J_1 terms from iteration N_start onwards (log y-axis).
N_start = 1
plt.figure(figsize=(12,8))
_curves = (
    (norms_transient, 'r', 'Norm of transient term'),
    (norms_J0, 'g', 'Norm of J_0'),
    (norms_J1, 'b', 'Norm of J_1'),
)
for _series, _color, _label in _curves:
    plt.plot(np.arange(N_start, N_iters), _series[N_start:], color=_color, label=_label)
plt.xlabel('iteration number', fontsize=18)
plt.yscale('log')
plt.legend()
plt.show()