In [17]:
%%capture
import torch
import modules.model2sim as sim
import torch.nn as nn
import gym
import config
from IPython import display
import matplotlib.pyplot as plt
from IPython.utils import io


def Binarize(tensor, include_zero = False):
    """Quantize a tensor element-wise using its sign.

    Args:
        tensor: Object exposing tensor arithmetic and ``.sign()``
            (a ``torch.Tensor`` in this file).
        include_zero: When False (default) the result is simply
            ``tensor.sign()``. When True the result is the mean of
            ``sign(tensor + 0.5)`` and ``sign(tensor - 0.5)``: -1 below
            -0.5, 0 strictly between -0.5 and 0.5, +1 above 0.5, and
            -0.5 / +0.5 exactly at the two thresholds.

    Returns:
        A new tensor with the quantized values; ``tensor`` is not modified.
    """
    if not include_zero:
        return tensor.sign()

    shifted_up = (tensor + 0.5).sign()
    shifted_down = (tensor - 0.5).sign()
    return (shifted_up + shifted_down) / 2

class BinarizeLinear(nn.Linear):
    """Linear layer that binarizes its weights (and usually its input) on every forward pass.

    The real-valued weights are snapshotted once in ``self.weight.org``; each
    call overwrites ``self.weight.data`` with ``Binarize`` of that snapshot.
    """

    def __init__(self, *kargs, **kwargs):
        """Forward all arguments unchanged to ``nn.Linear``."""
        super().__init__(*kargs, **kwargs)

    def forward(self, input):
        """Apply the binarized linear map to ``input``.

        Side effects: unless ``input.size(1) == 784``, ``input.data`` is
        replaced by its binarized version, so the caller's tensor is
        overwritten. ``self.weight.data`` is overwritten as well.

        Args:
            input: Tensor with at least two dimensions (``input.size(1)`` is
                read).

        Returns:
            ``linear(input, binarized_weight)`` plus the bias, if the layer
            has one.
        """
        # Inputs whose second dimension is 784 are fed through unbinarized.
        # NOTE(review): presumably raw 28x28 image pixels - confirm.
        is_raw_input = input.size(1) == 784
        if not is_raw_input:
            input.data = Binarize(input.data)

        # Keep one copy of the original weights; the binarized weights are
        # always derived from that copy, never from previously binarized data.
        weight = self.weight
        if not hasattr(weight, 'org'):
            weight.org = weight.data.clone()
        weight.data = Binarize(weight.org)

        out = nn.functional.linear(input, weight)
        if self.bias is None:
            return out

        # The bias snapshot is refreshed on every call (unlike the weights).
        self.bias.org = self.bias.data.clone()
        out += self.bias.view(1, -1).expand_as(out)
        return out

class PositiveBinarizeLinear(nn.Linear):
    """Linear layer with binarized weights whose input is first reduced to {0, 1}.

    Non-positive input entries become 0 and positive entries become 1 before
    the linear map is applied. The real-valued weights are snapshotted once in
    ``self.weight.org`` and re-binarized from that snapshot on every call.
    """

    def __init__(self, *kargs, **kwargs):
        """Forward all arguments unchanged to ``nn.Linear``."""
        super().__init__(*kargs, **kwargs)

    def forward(self, input):
        """Apply the binarized linear map to the positive part of ``input``.

        Side effects: ``input.data`` and ``self.weight.data`` are overwritten.

        Args:
            input: Tensor accepted by ``nn.functional.linear``.

        Returns:
            ``linear(input, binarized_weight)`` plus the bias, if the layer
            has one.
        """
        # Keep only strictly positive entries (everything else -> 0), then
        # take the sign, leaving 1 where the input was positive and 0 elsewhere.
        raw = input.data
        positive_part = torch.where(raw > 0, raw, torch.zeros_like(raw))
        input.data = Binarize(positive_part)

        # Binarized weights are always derived from the saved original copy.
        weight = self.weight
        if not hasattr(weight, 'org'):
            weight.org = weight.data.clone()
        weight.data = Binarize(weight.org)

        out = nn.functional.linear(input, weight)
        if self.bias is None:
            return out

        # The bias snapshot is refreshed on every call (unlike the weights).
        self.bias.org = self.bias.data.clone()
        out += self.bias.view(1, -1).expand_as(out)
        return out

class Predefined_policy(nn.Module):
    """Two-layer binarized network with hand-written initial weights.

    ``fc`` maps four binary features to two units and ``fcp`` sums the
    positive part of those two units into a single score.
    """

    def __init__(self):
        """Create both bias-free layers and overwrite their weights with fixed values."""
        super().__init__()
        self.fc = BinarizeLinear(4, 2, bias = False)
        self.fcp = PositiveBinarizeLinear(2, 1, bias = False)

        # Row 0: feature 0 minus feature 2; row 1: feature 1 minus feature 3.
        hidden_weights = torch.tensor([[1.0, 0.0, -1.0, 0.0],
                                       [0.0, 1.0, 0.0, -1.0]])
        # The single output unit adds both hidden units.
        output_weights = torch.tensor([[1.0, 1.0]])
        self.fc.weight = nn.Parameter(hidden_weights)
        self.fcp.weight = nn.Parameter(output_weights)

        # Not used inside this class; kept as empty lists for external code.
        self.saved_log_probs = []
        self.rewards = []

    def parseInput(self, x):
        """Encode entries 2 and 3 of ``x[0]`` as four 0/1 features.

        Args:
            x: Indexable batch whose first row has at least four entries.
                NOTE(review): presumably a CartPole observation where index 2
                is the pole angle and index 3 its angular velocity - confirm.

        Returns:
            Tensor of shape ``[1, 4]``:
            ``[theta > 0, w > 0, |theta| < 0.03, |theta| >= 0.03]``.
        """
        theta, w = x[0][2:4]
        features = [
            float(theta > 0),
            float(w > 0),
            float(abs(theta) < 0.03),
            float(abs(theta) >= 0.03),
        ]
        return torch.tensor([features])

    def forward(self, x):
        """Encode ``x`` and return the output of the two stacked layers."""
        features = self.parseInput(x)
        hidden = self.fc(features)
        return self.fcp(hidden)


def parseInput(x):
    """Encode entries 2 and 3 of ``x`` as a vector of four 0/1 features.

    Args:
        x: Sliceable sequence with at least four entries.
            NOTE(review): presumably a CartPole observation where index 2 is
            the pole angle and index 3 its angular velocity - confirm.

    Returns:
        Tensor of shape ``[4]``:
        ``[theta > 0, w > 0, |theta| < 0.03, |theta| >= 0.03]``.
    """
    theta, w = x[2:4]
    return torch.tensor([
        float(theta > 0),
        float(w > 0),
        float(abs(theta) < 0.03),
        float(abs(theta) >= 0.03),
    ])

# Build the hand-wired policy and load its saved parameters onto the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
model = Predefined_policy()
model.load_state_dict(torch.load(config.TRAINED_MODELS_DIR + "predefined-CartPole.pt",map_location=torch.device('cpu')))

# Wrap the environment in a Monitor writing into the current directory;
# force=True lets it overwrite results of an earlier run.
env = gym.make('CartPole-v1')
env = gym.wrappers.Monitor(env, "./", force=True)
try:
    state = env.reset()
    for t in range(1, 1000):  # Cap the episode length so the loop always terminates
        # A fresh simulator is built from the model for every environment step.
        net = sim.SimfromModel(model)
        # Silence whatever the simulator prints while it runs.
        with io.capture_output() as captured:
            res = net.stimulate(parseInput(state), simLen = 30, verbosity=1)
        # The action is the number of recorded entries in the "times" list
        # of the first result.
        action = len(res[0][0]["times"])
        state, reward, done, _ = env.step(action)
        # Optional live rendering inside the notebook:
        #plt.imshow(env.render(mode='rgb_array', close=False))
        #display.display(plt.gcf())
        #display.clear_output(wait=True)

        if done:
            print("done!")
            break
finally:
    # Always close the monitored environment - also when the step cap is
    # reached or a step raises - so it is not left open.
    env.close()

# This shouldn't take more than a minute.

In [15]:
import base64
from IPython.core.display import HTML

# Read the file named by the first element of the last entry in env.videos.
# Opened read-only inside a context manager so the handle is closed right
# away and no write permission on the file is needed.
with open(env.videos[-1][0], 'rb') as video_file:
    video = video_file.read()
# Inline the bytes as a base64 data URI so the player needs no external file.
encoded = base64.b64encode(video)
HTML(data='<video width="360" height="auto" alt="test" controls><source src="data:video/mp4;base64,{0}" type="video/mp4" /></video>'.format(encoded.decode('ascii')))
