In [None]:
%config InlineBackend.figure_format = 'svg'
%env MUJOCO_GL=egl
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from dm_control import suite
from dm_control.suite.wrappers import pixels
from models import Encoder, Decoder, RewardModel, RSSM
from mpc import MPC
from replay import ExpReplay
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
from utils import display_img, display_video, preprocess_img

# Shared runtime configuration: compute device, animation limits and RNG seeding.
SEED = 0
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# animation.embed_limit is in MB; 2**128 effectively removes the cap on inline animations.
matplotlib.rcParams['animation.embed_limit'] = 2**128
# NumPy RNG used for the random seed-episode actions.
random_state = np.random.RandomState(SEED)
# Seed torch too, so model initialisation and torch-side sampling are reproducible.
torch.manual_seed(SEED)

In [None]:
# For animations to render inline in Jupyter, install ffmpeg, then uncomment the line
# below and set the path to the location of the ffmpeg executable.
# plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'

In [None]:
SEED_EPS = 3
TRAIN_EPS = 100
UPDATES = 100
ACTION_REPEAT = 8
BATCH_SZ = 50
CHUNK_LEN = 50

In [None]:
# Cartpole swing-up, wrapped so observations are pixels rather than the internal state.
# NOTE(review): frames are rendered separately via env.physics.render below, so the
# wrapper's own pixel observation appears unused — confirm before relying on it.
env = pixels.Wrapper(suite.load('cartpole', 'swingup'))
act_spec = env.action_spec()
action_dim = act_spec.shape[0]

data = ExpReplay(BATCH_SZ, CHUNK_LEN, action_dim)

In [None]:
# Generate random seed data
total_reward_seed = 0
t = 0
for i in range(SEED_EPS):
    state = env.reset()
    reward = 0
    while not state.last():
        t += 1
        action = random_state.uniform(act_spec.minimum, act_spec.maximum, action_dim)
        reward = state.reward
        if reward is None: reward = 0
        total_reward_seed += reward
        frame = env.physics.render(camera_id=0, height=200, width=200)
        frame = preprocess_img(frame).to(device)
        data.append(frame, torch.as_tensor(action), torch.as_tensor(reward))
        state = env.step(action)
print("Avg reward per ep: ",total_reward_seed/SEED_EPS)
print("Avg timesteps per ep: ", t/SEED_EPS)

In [None]:
# World-model components, all moved to the selected device (same construction order).
enc, dec = Encoder().to(device), Decoder().to(device)
reward_model = RewardModel().to(device)
rssm = RSSM(action_dim).to(device)

# A single Adam optimizer jointly updates every world-model component.
params = [
    *enc.parameters(),
    *dec.parameters(),
    *reward_model.parameters(),
    *rssm.parameters(),
]
optimizer = optim.Adam(params, lr=1e-3, eps=1e-4)

planner = MPC(action_dim)

In [None]:
rewards_list = []   # total environment reward of each collected episode
losses_list = []    # summed model-fitting loss of each training iteration
observations = []   # raw rendered frames of each collected episode (for videos)

# Alternate model fitting and on-policy data collection for TRAIN_EPS episodes.
for episode in tqdm(range(TRAIN_EPS)):
    # ---- MODEL FITTING ----
    total_loss = 0
    for _ in range(UPDATES):
        obs, actions, rewards = data.sample_batch()
        # NOTE(review): the first axis of the sampled tensors is unrolled as time and the
        # second is treated as the batch (rows of the recurrent states) — confirm against
        # ExpReplay.sample_batch.
        det_state = torch.zeros(BATCH_SZ, 200, device=device)
        stoc_state = torch.zeros(BATCH_SZ, 30, device=device)
        obs_loss, reward_loss, kl_loss = 0, 0, 0
        for step in range(len(obs)):
            # Transition from the previous (det, stoc) state with the action that led to obs[step].
            det_state = rssm.drnn(det_state, stoc_state, actions[step])
            _, prior_mean, prior_dev = rssm.ssm_prior(det_state)
            posterior_state, post_mean, post_dev = rssm.ssm_posterior(det_state, enc(obs[step]))
            # Filter: the next transition is conditioned on the posterior sample.
            stoc_state = posterior_state
            obs_loss += F.mse_loss(dec(det_state, stoc_state), obs[step])
            reward_loss += F.mse_loss(reward_model(torch.cat((det_state, stoc_state), dim=1)), rewards[step])
            # KL(posterior || prior), summed over latent dims and averaged over the batch.
            # assumes *_dev are standard deviations (not variances) — TODO confirm in models.RSSM
            kl_loss += torch.distributions.kl_divergence(
                torch.distributions.Normal(post_mean, post_dev),
                torch.distributions.Normal(prior_mean, prior_dev),
            ).sum(dim=1).mean()
        loss = obs_loss + reward_loss + kl_loss
        optimizer.zero_grad()
        loss.backward()
        # Clip only after backward(); before it the gradients were just zeroed.
        torch.nn.utils.clip_grad_norm_(params, 1000., norm_type=2)
        optimizer.step()
        total_loss += loss.item()
    losses_list.append(total_loss)

    # ---- DATA COLLECTION ----
    eps_reward = 0
    vid = []
    with torch.no_grad():
        state = env.reset()
        det_state = torch.zeros(200, device=device)
        stoc_state = torch.zeros(30, device=device)
        action = torch.zeros(action_dim, device=device)
        frame = preprocess_img(env.physics.render(camera_id=0, height=200, width=200)).to(device)
        while not state.last():
            det_state = rssm.drnn(det_state, stoc_state, action.to(device))
            stoc_state, _, _ = rssm.ssm_posterior(det_state, enc(frame))
            stoc_state = stoc_state.squeeze()
            action = planner.get_action(det_state, stoc_state, rssm, reward_model)
            # dm_control consumes NumPy arrays; a (possibly CUDA) tensor must be converted.
            env_action = action.cpu().numpy()
            # Hold the action for ACTION_REPEAT steps and store the reward summed over them.
            step_reward = 0
            for _ in range(ACTION_REPEAT):
                state = env.step(env_action)
                step_reward += state.reward
                if state.last():
                    break
            eps_reward += step_reward
            raw_frame = env.physics.render(camera_id=0, height=200, width=200)
            vid.append(raw_frame)
            frame = preprocess_img(raw_frame).to(device)
            data.append(frame, action, torch.as_tensor(step_reward))
    rewards_list.append(eps_reward)
    observations.append(vid)
    print("Loss: ", total_loss)
    print("Reward: ", eps_reward)

In [None]:
# One replay batch (observations, actions, rewards) for the shape checks below.
(o, a, r) = data.sample_batch()


In [None]:
# Zero deterministic/stochastic states: unbatched, and with a leading dimension of 50.
det, stoc = torch.zeros(200, device=device), torch.zeros(30, device=device)
det_b, stoc_b = torch.zeros(50, 200, device=device), torch.zeros(50, 30, device=device)

In [None]:
# Reward-model output shapes for an unbatched and a batched concatenated state.
rew = reward_model(torch.cat([det, stoc]))
print(rew.size())
rew_b = reward_model(torch.cat([det_b, stoc_b], dim=-1))
print(rew_b.size())

In [None]:
torch.cat([det_b, stoc_b], dim=-1).size()

In [None]:
# Decoder output shape for an unbatched state.
de = dec(det, stoc)
print(de.size())

In [None]:
# Decoder output shape for a batched state.
de_b = dec(det_b, stoc_b)
print(de_b.size())

In [None]:
# Deterministic transition shape for a single (unbatched) action.
first_action = a[0][0]
d = rssm.drnn(det, stoc, first_action)
print(d.size())

In [None]:
# Deterministic transition shape for the batched actions a[0].
first_actions = a[0]
d = rssm.drnn(det_b, stoc_b, first_actions)
print(d.size())

In [None]:
# Shapes of the three outputs (state, mean, dev) of the prior for an unbatched state.
prior = rssm.ssm_prior(det)
for component in prior:
    print(component.shape)

In [None]:
# Shapes of the three outputs of the prior for a batched state.
prior_b = rssm.ssm_prior(det_b)
for component in prior_b:
    print(component.shape)

In [None]:
# Shapes of the three outputs of the posterior for a single encoded observation.
first_embedding = enc(o[0][0])
posterior = rssm.ssm_posterior(det, first_embedding)
for component in posterior:
    print(component.shape)

In [None]:
enc(o[0]).size()

In [None]:
# Shapes of the three outputs of the posterior for the batched observations o[0].
batch_embedding = enc(o[0])
posterior_b = rssm.ssm_posterior(det_b, batch_embedding)
for component in posterior_b:
    print(component.shape)