In [68]:
import numpy as np
from scipy.linalg import lstsq
import  numpy.random as random
from operator import itemgetter

In [69]:
class ReplayMemory:
    def __init__(self, capacity):
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, batch):
        if len(self.buffer) < self.capacity:
            append_len = min(self.capacity - len(self.buffer), len(batch))
            self.buffer.extend([None] * append_len)

        if self.position + len(batch) < self.capacity:
            self.buffer[self.position : self.position + len(batch)] = batch
            self.position += len(batch)
        else:
            self.buffer[self.position : len(self.buffer)] = batch[:len(self.buffer) - self.position]
            self.buffer[:len(batch) - len(self.buffer) + self.position] = batch[len(self.buffer) - self.position:]
            self.position = len(batch) - len(self.buffer) + self.position

    def sample(self, batch_size):
        if batch_size > len(self.buffer):
            batch_size = len(self.buffer)
        batch = random.sample(self.buffer, int(batch_size))
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def sample_all_batch(self, batch_size):
        idxes = np.random.randint(0, len(self.buffer), batch_size)
        batch = list(itemgetter(*idxes)(self.buffer))
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def return_all(self):
        return self.buffer

    def __len__(self):
        return len(self.buffer)


In [83]:
class ENV:
    def __init__(self,A,B,Q,R,target_state=np.ones(3)):
        self.A=A
        self.B=B
        self.Q=Q
        self.R=R
        self.target_state=target_state
        self.current_action=None
        self.current_state=None

    def reset(self):
        self.current_state=np.ones(len(self.A))
        return self.current_state

    def step(self,action):
        mean = self.A@np.array([self.current_state]).T + self.B@action
        print(mean)
        next_state=np.random.multivariate_normal(mean.T[0],np.eye(len(mean)))
        part_1=np.array([next_state])@self.Q@np.array([next_state]).T
        part_2=action.T@self.R@action
        
        self.current_action=action
        
        # if next_state==self.target_state:
        #     return part_1+part_2,next_state,True
        return part_1+part_2,next_state,False

In [77]:
def New_Estimate(A_hat,B_hat,Q_hat,R_hat,D_real):
    if len(D_real)==0:
        return A_hat,B_hat,Q_hat,R_hat
    dim_state=D_real[0][0].shape
    dim_action=D_real[0][1].shape
    X=np.zeros(shape=(dim_state+dim_action,len(D_real)))
    Y=np.zeros(shape=(dim_state,len(D_real)))
    for i in range(len(D_real)):
        Y[:,i]=D_real[i][2]
        X[:,i]=np.concatenate(D_real[i][0],D_real[i][1])
    A=X.T
    b=Y.T
    total_hat=lstsq(A,b) # need to split
    return total_hat

In [78]:
def Sample_state(env_pool):
    space=[x[0] for x in env_pool] 
    return np.random.choice(space)

def gradient_with_model(A,B):
    
    grad=4
    return grad 

def gradient_with_exp(D_fake):
    
    grad=4
    return grad

In [79]:
def get_fake_traj(S_t,horiz_len,A_hat,B_hat,Q_hat,R_hat,K_t):
    i=0
    holder=[]
    while i<horiz_len:
        prev=S_t
        u_T=K_t*S_t
        S_t=A_hat*S_t+B_hat*u_T+np.random.normal(0,1)
        R_t=0.5*[S_t]@Q_hat@[S_t].T + 0.5*[u_T]@R_hat@[u_T].T
        done_t=False
        if S_t==np.zeros_like(S_t):
            done_t=True
        holder.append((prev,u_T,R_t,S_t,done_t))
        i+=1
    return holder

def get_from_env(K_t,env,len_traj):
    
    i=0
    Is_done=False
    holder=[]
    while i<len_traj and Is_done!=True:
        u_T=K_t@np.array([env.current_state]).T
        R_t,S_t,Is_done=env.step(u_T)
        holder.append((env.current_state,u_T,R_t,S_t,Is_done))
        env.current_state=S_t
        if Is_done:
            break
        i+=1  
    return holder

In [84]:
# A ->>3*3
# B ->>3*3
# C->eye(3)
# K ->>3*3



np.random.seed(0)

# True parameters of the env

A=np.diag([1,-2,3])
B=np.diag([1,2,3])
C=np.eye(3)
Q=np.diag([1,2,3])
R=np.diag([5,1,3])



A_hat=np.random.rand(3,3)   # Initial theta
B_hat=np.random.rand(3,3)   # Initial theta
Q_hat=np.diag(np.ones(3))   # Initial theta
R_hat=np.diag(np.ones(3))   # Initial theta

K=np.random.rand(3,3)   # Initial phi
K_t=K

D_real = ReplayMemory(10000)   # Real data
D_fake = ReplayMemory(10000)  # Fake data


env=ENV(A,B,Q,R)
env.reset() # Reset
################################################################

num_of_epcoh=20
N=10
E=10
M=10
G=10
horiz_len=10
num_of_rollouts=10

################################################################

In [85]:
for n in range(num_of_epcoh):
    A_hat,B_hat,Q_hat,R_hat=New_Estimate(A_hat,B_hat,Q_hat,R_hat,D_real) # Regression 
    for e in range(E):
        S_t=np.random.rand(len(A_hat)) # Update D_real
        sup=get_from_env(K_t,env,20)
        D_real.push(sup)
        for m in range(M):
            S_t=Sample_state(D_real)    # Random sampling
            D_fake.push(get_fake_traj(S_t,horiz_len,A_hat,B_hat,K_t))
            # Update D_fake with horiz_len
        for g in range(G):
            K_t+=gradient_with_model(A_hat,B_hat)    # Known parameter
            K_t+=gradient_with_exp(D_fake)    # unKnown parameter (From trajectories -off policy settings)

[[3.62678724]
 [2.08233421]
 [5.7046462 ]]
[[16.95595767]
 [18.18348684]
 [31.51907013]]
[[ 74.08588398]
 [ 54.05371311]
 [147.96640313]]
[[327.72185404]
 [292.28803595]
 [641.48750381]]
[[1462.76801088]
 [1208.86506616]
 [2869.38324889]]
[[ 6461.51055323]
 [ 5516.44036869]
 [12686.31163105]]
[[28705.51120661]
 [24192.04879594]
 [56396.40728274]]
[[127276.06377619]
 [107859.94770423]
 [250062.48752869]]
[[ 564874.59052313]
 [ 477623.01495941]
 [1109953.56646449]]
[[2506191.18041402]
 [2121125.95896799]
 [4924556.36584909]]


KeyboardInterrupt: 