In [35]:
from L96 import *
from EnKF import *
from utils import *
from parameterizations import *
import gym

In [7]:
# Load the saved initial conditions for the X and Y variables.
initX = np.load('./data/initX.npy')
initY = np.load('./data/initY.npy')

In [9]:
# Reference run: build the two-level model from the loaded initial conditions
# and integrate it forward.
# NOTE(review): the argument to iterate() is presumably a duration in model
# time units, not a step count — confirm against L96.py.
l96_tru = L96TwoLevel(X_init=initX, Y_init=initY)
l96_tru.iterate(1)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




In [11]:
# Keep an independent copy of the reference model's current X state.
X_tru = l96_tru.X.copy()

In [52]:
class L96TwoLevelRL(L96TwoLevel):
    """L96TwoLevel variant whose X update takes an externally supplied term ``B``.

    Only ``step_with_B`` is added; construction is delegated unchanged to the
    parent class.
    """

    def __init__(self, **kwargs):
        # All keyword arguments are forwarded to L96TwoLevel untouched.
        super().__init__(**kwargs)

    def step_with_B(self, B):
        """Advance X by one step, then add ``B * dt`` and record history.

        The four stage evaluations call ``_rhs_X_dt`` with ``B=0``, so ``B``
        does not enter the stages; it is added afterwards as a separate
        increment. ``self.Y`` is not modified by this method.

        Parameters
        ----------
        B : array-like with a ``.copy()`` method
            External term added to X as ``B * self.dt``; a copy is stored in
            the B history on save steps.
        """
        # NOTE(review): the stages are used directly as increments, so
        # _rhs_X_dt presumably returns a tendency already scaled by dt —
        # confirm against L96.py.
        x_now = self.X
        stage_1 = self._rhs_X_dt(x_now, B=0)
        stage_2 = self._rhs_X_dt(x_now + stage_1 / 2, B=0)
        stage_3 = self._rhs_X_dt(x_now + stage_2 / 2, B=0)
        stage_4 = self._rhs_X_dt(x_now + stage_3, B=0)

        # Weighted four-stage combination, applied in place.
        rk_increment = 1 / 6 * (stage_1 + 2 * stage_2 + 2 * stage_3 + stage_4)
        self.X += rk_increment

        # External term, added as its own in-place increment.
        self.X += B * self.dt

        self.step_count += 1
        if self.step_count % self.save_steps != 0:
            return

        # Save step: record X, the per-row mean of Y and of Y**2 (Y viewed as
        # a K x J array), and the B that was applied.
        y_rows = self.Y.reshape(self.K, self.J)
        first_moment = y_rows.mean(1)
        second_moment = (y_rows**2).mean(1)
        self._history_X.append(self.X.copy())
        self._history_Y_mean.append(first_moment.copy())
        self._history_Y2_mean.append(second_moment.copy())
        self._history_B.append(B.copy())
        if not self.noYhist:
            self._history_Y.append(self.Y.copy())

In [58]:
class L96Gym(gym.Env):
    """Gym environment in which an agent supplies the term ``B`` of an L96 forecast.

    On construction a reference ``L96TwoLevel`` model is integrated for
    ``lead_time`` and its final X is stored as ``fc_target``. Each episode
    steps a fresh ``L96TwoLevelRL`` model ``nsteps`` times with the agent's
    actions. The reward is 0 on every step except the last, where it is the
    negative mean squared difference between the forecast X and ``fc_target``.

    Parameters
    ----------
    lead_time : float
        Forecast length; also passed to ``iterate`` of the reference model.
    X_init, Y_init : array-like
        Initial conditions. Both are given to the reference model; only
        ``X_init`` is given to the forecast model.
    dt : float
        Time step of the forecast model. The reference model is built without
        ``dt`` and therefore uses ``L96TwoLevel``'s own default.
    action_bounds : tuple(float, float)
        Lower and upper bound of the one-element action space.

    Notes
    -----
    ``action_space`` and ``observation_space`` both have shape ``(1,)``, while
    states are returned as ``X[:, None]`` (one row per X element).
    """

    def __init__(self, lead_time, X_init, Y_init, dt=0.01, action_bounds=(-20,20)):
        self.lead_time = lead_time
        self.X_init, self.Y_init = X_init, Y_init
        self.step_count = 0
        self.dt = dt
        # Floor division on floats loses a step to rounding noise
        # (1 // 0.01 == 99.0), so round the ratio to the nearest integer.
        self.nsteps = int(round(self.lead_time / self.dt))

        # Use the constructor arguments, not notebook-level globals.
        self.l96_tru = L96TwoLevel(X_init=self.X_init, Y_init=self.Y_init)
        self.l96_tru.iterate(lead_time)
        self.fc_target = self.l96_tru.X.copy()

        self.action_space = gym.spaces.Box(
            low=np.array([action_bounds[0]]),
            high=np.array([action_bounds[1]])
        )
        self.observation_space = gym.spaces.Box(-np.array([np.inf]), np.array([np.inf]))

    def reset(self):
        """Start a new episode and return the initial state as ``X[:, None]``."""
        # Without this, every episode after the first would end on its first step.
        self.step_count = 0
        self.l96 = L96TwoLevelRL(noYhist=True, X_init=self.X_init, dt=self.dt)
        # step_with_B updates X in place; return a copy so the caller's array
        # does not change on later steps.
        return self.l96.X[:, None].copy()

    def step(self, action):
        """Apply ``action`` as B for one model step.

        Returns
        -------
        state : ndarray
            Copy of the forecast X with a trailing axis of length 1.
        reward : float
            0 before the final step; negative MSE against ``fc_target`` on it.
        done : bool
            True once ``nsteps`` steps have been taken since the last reset.
        info : dict
            Always empty.
        """
        self.l96.step_with_B(action)
        self.step_count += 1
        state = self.l96.X
        done = self.step_count >= self.nsteps
        reward = -((state - self.fc_target)**2).mean() if done else 0
        # Copy for the same reason as in reset(); info is a dict per the Gym API.
        return state[:, None].copy(), reward, done, {}

In [59]:
# Build the environment with lead_time=1 from the loaded initial conditions,
# then reset it; reset() creates the forecast model and returns its initial state.
env = L96Gym(1, initX, initY)
env.reset()

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




array([ 7.22218496,  2.99854488,  1.42502222,  3.22644313,  3.42971768,
       -2.85367383,  5.61076055,  6.09349954,  2.29462152,  4.53825525,
        6.2793666 ,  0.91918918, -5.64401153,  1.45303851,  8.56648317,
        3.90647132,  0.84966233, -2.14808322, -1.81945225,  4.21658127,
        5.85408652,  1.78285521, -1.21044663, -0.97350462,  5.30371467,
        4.31567286,  3.71369269,  2.24476462, -1.04748423, -2.42960882,
        6.78435451,  5.16713111,  3.79624924,  1.52825985, -1.75016476,
       -2.58994597])

In [60]:
# Inspect the action and observation spaces declared by the environment.
env.action_space, env.observation_space

(Box(1,), Box(1,))

In [61]:
# Roll out one full episode with an all-zero action (B = 0 at every step).
zero_action = np.zeros(X_tru.shape)
done = False
while not done:
    state, reward, done, _ = env.step(zero_action)

In [65]:
# Shape and number of dimensions of the last state returned by env.step().
state.shape, state.ndim

((36, 1), 2)

In [34]:
class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Each entry is ``[state, action, reward, next_state, done]``. Once
    ``capacity`` entries are held, pushing a new one drops the oldest.
    """

    def __init__(self, capacity):
        # deque(maxlen=...) evicts from the left when full.
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Store one transition, or one transition per row of a batched state.

        If ``len(state) == 1`` the arguments are stored as a single entry.
        Otherwise ``state``, ``action`` and ``next_state`` are iterated in
        parallel along their first axis and one entry is stored per element;
        ``reward`` and ``done`` are shared by all of them.

        State, action and next_state are copied into new arrays. Callers may
        pass arrays (or views of arrays) that are later modified in place;
        storing references would silently rewrite buffered transitions.
        """
        if len(state) == 1:
            self.buffer.append(
                [np.array(state), np.array(action), reward, np.array(next_state), done]
            )
        else:
            for s, a, ns in zip(state, action, next_state):
                self.buffer.append([np.array(s), np.array(a), reward, np.array(ns), done])

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct entries uniformly at random.

        Returns five arrays (state, action, reward, next_state, done), each
        stacked along a new leading axis. Raises ``ValueError`` (from
        ``random.sample``) if ``batch_size`` exceeds the number of entries.
        """
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)