In [1]:
%matplotlib inline
from JSAnimation import IPython_display
from matplotlib import animation
import matplotlib.pyplot as plt

import gym, pickle,os, policies,torch

In [2]:
def display_frames_as_gif(frames, name="untitled", out_dir="gifs", interval=50, fps=60):
    """
    Save a sequence of frames as a GIF and display it inline with playback controls.

    Parameters
    ----------
    frames : sequence of image arrays
        Frames to animate (e.g. the output of ``env.render(mode='rgb_array')``).
        The first frame's shape sets the figure size at 72 dpi, so every frame
        is expected to have the same height/width.
    name : str
        File stem; the GIF is written to ``<out_dir>/<name>.gif``.
    out_dir : str
        Directory for the GIF; created if it does not exist.
    interval : int
        Delay between frames in the inline player, in milliseconds.
    fps : int
        Frame rate of the saved GIF. With the defaults the GIF (60 fps) plays
        3x faster than the inline player (50 ms -> 20 fps); pass matching
        values if the two should agree.

    Raises
    ------
    ValueError
        If ``frames`` is empty.
    """
    if len(frames) == 0:
        raise ValueError("display_frames_as_gif: 'frames' is empty, nothing to render")

    # shape[0] is rows (height), shape[1] is columns (width).
    height, width = frames[0].shape[:2]
    fig = plt.figure(figsize=(width / 72.0, height / 72.0), dpi=72)
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])
        return (patch,)

    anim = animation.FuncAnimation(fig, animate, frames=len(frames), interval=interval)

    os.makedirs(out_dir, exist_ok=True)
    # The 'imagemagick' writer requires ImageMagick to be installed on the system.
    anim.save(os.path.join(out_dir, "%s.gif" % name), writer='imagemagick', fps=fps)
    display(IPython_display.display_animation(anim, default_mode='loop'))

    # The animation has already been rendered above; close the figure so the
    # inline backend does not also emit a static copy and the figure is freed.
    plt.close(fig)


In [7]:
# Helpers: train a policy from saved constraint data, then roll it out in the environment to record frames
def get_train_policy(path, env, epoch=500):
    """
    Train a ``policies.Policy_quad`` on constraint data stored under ``path``.

    The directory is scanned for files whose name contains ``"constraints.p"``.
    The last one in sorted order is loaded, because ``os.listdir`` order is
    arbitrary and sorting makes the choice deterministic. Its ``"constraints"``
    entry is unpacked as (state, action, reward, info) tuples. Only states and
    actions are used for training.

    Parameters
    ----------
    path : str
        Run directory containing the constraint pickle(s); a trailing slash is optional.
    env : gym.Env
        Used only for its observation/action space sizes.
    epoch : int
        Passed through to ``policy.train(..., epoch=epoch)``.

    Returns
    -------
    The trained policy object.

    Raises
    ------
    FileNotFoundError
        If no constraint file is found in ``path``.
    ValueError
        If the loaded constraint list is empty.
    """
    constraint_files = sorted(f for f in os.listdir(path) if "constraints.p" in f)
    if not constraint_files:
        raise FileNotFoundError("no '*constraints.p*' file found in %r" % (path,))

    # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    with open(os.path.join(path, constraint_files[-1]), 'rb') as fh:
        constraints = pickle.load(fh)["constraints"]
    if len(constraints) == 0:
        raise ValueError("constraint file %r contains no constraints" % (constraint_files[-1],))

    policy = policies.Policy_quad(env.observation_space.shape[0], env.action_space.shape[0])
    states, actions, _rewards, _info = zip(*constraints)
    policy.train(torch.tensor(states), torch.tensor(actions), epoch=epoch)
    return policy
    
def getframes(policy, env, steps=1000):
    """
    Roll out ``policy`` in ``env`` and collect rendered RGB frames.

    Parameters
    ----------
    policy : callable
        Called with the current observation as a float32 tensor with a leading
        batch dimension of size 1; must return a tensor convertible via
        ``.detach().numpy()``.
    env : gym.Env
        Environment whose ``reset()`` returns a numpy observation and whose
        ``render(mode='rgb_array')`` returns an image.
    steps : int
        Maximum number of environment steps; the rollout stops early when
        ``env.step`` reports ``done``.

    Returns
    -------
    list
        One frame per executed step, rendered *before* that step's action is
        applied (so the state after the final step is not included).
    """
    frames = []
    observation = env.reset()
    for _ in range(steps):
        # 'rgb_array' returns the image so it can be collected; depending on
        # the rendering backend a viewer window may still appear.
        frames.append(env.render(mode='rgb_array'))
        obs = torch.from_numpy(observation).unsqueeze(0).float()
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            action = policy(obs).detach().numpy()
        observation, _reward, done, _info = env.step(action)
        if done:
            break
    return frames

In [None]:
# Run directory holding the saved constraint pickle(s) for this experiment.
path = "results/Hopper-v2/05_16_03_08/"
env = gym.make('Hopper-v2')
try:
    policy_1000 = get_train_policy(path, env)
    # Roll out the policy trained just above and record its frames.
    frames_1000 = getframes(policy_1000, env)
finally:
    # Release the simulator/viewer even if training or the rollout fails.
    env.close()
display_frames_as_gif(frames_1000)