In [1]:
'''
Read a trained TD3 controller and convert its actor-network weights and biases to C-style array definitions.
'''

import numpy as np
import torch
import TD3
import utils

def print_param(param):
    """Flatten ``param`` and render its elements as a comma-separated string.

    As a side effect the shape of ``param`` (before flattening) is printed to
    stdout, which gives a quick visual check of the dimensions being exported.

    Args:
        param: Array-like accepted by ``np.shape`` / ``np.reshape``.

    Returns:
        str: The elements in ``np.reshape(param, -1)`` order, each converted
        with ``str()`` and separated by ``', '``. An empty input yields ``''``.
    """
    # Single-argument print(...) produces the same output on Python 2 and 3.
    print(np.shape(param))
    flat = np.reshape(param, -1)
    # join() places separators only between elements, so no last-index check
    # is needed and the string is built in one pass.
    return ', '.join(str(value) for value in flat)

def relu(x):
    """Rectified linear unit: ``x`` where it is positive, zero otherwise.

    Evaluated arithmetically as the average of ``|x|`` and ``x``, so it works
    elementwise on NumPy arrays as well as on plain numbers, and the division
    by ``2.0`` always produces a floating-point result.
    """
    magnitude = abs(x)
    return (magnitude + x) / 2.0

def tanh(x):
    """Hyperbolic tangent of ``x`` (scalar or array, elementwise).

    Delegates to ``np.tanh`` instead of evaluating
    ``(e^x - e^-x) / (e^x + e^-x)`` directly: ``np.exp`` overflows to ``inf``
    for large ``|x|``, which turns the quotient into ``inf / inf = nan``,
    whereas ``np.tanh`` saturates cleanly at +/-1.

    Args:
        x: Number or array-like.

    Returns:
        The hyperbolic tangent of ``x``, with the same shape as ``x``.
    """
    return np.tanh(x)


# Dimensions handed to the TD3 constructor below.
state_dim, action_dim = 2, 1
max_action = 1
dt = 0.02  # not referenced anywhere in the visible cells

# Hyper-parameter table. Only discount, tau, policy_noise, noise_clip and
# policy_freq are forwarded to TD3 in this cell; the remaining entries are
# not read by the visible code.
args = dict(
    start_timesteps=1e5,  # 1e4
    eval_freq=5e3,
    expl_noise=0.1,
    batch_size=256,
    discount=0.99,
    tau=0.005,
    policy_noise=0.2,
    noise_clip=0.5,
    policy_freq=2,
)

# Keyword arguments for TD3.TD3, assembled in a single expression.
kwargs = dict(
    state_dim=state_dim,
    action_dim=action_dim,
    max_action=max_action,
    discount=args['discount'],
    tau=args['tau'],
    # Target policy smoothing is scaled wrt the action scale
    policy_noise=args['policy_noise'] * max_action,
    noise_clip=args['noise_clip'] * max_action,
    policy_freq=args['policy_freq'],
)
policy = TD3.TD3(**kwargs)

# Load the saved policy. How the name 'TD3models/ctrlmodel_600' maps to files
# on disk is decided by TD3's load() and is not visible in this notebook.
policy.load('TD3models/ctrlmodel_600')

# Fixed probe state (two entries, matching state_dim) used as the reference
# input for comparing PyTorch against the hand-written forward pass.
state = np.array([0.1, 0.2])
action = policy.select_action(state)

# A single pre-formatted argument makes print(...) emit the same text on
# Python 2 and Python 3 (the old print statement is a SyntaxError on 3).
print('action from pytorch: ' + str(action))

# Pull the actor's three linear layers out as NumPy arrays.
# detach() drops autograd tracking and cpu() makes the conversion work even
# if the actor lives on a GPU; on a CPU model both are no-ops.
actor = policy.actor
W1 = actor.l1.weight.detach().cpu().numpy()
b1 = actor.l1.bias.detach().cpu().numpy()
W2 = actor.l2.weight.detach().cpu().numpy()
b2 = actor.l2.bias.detach().cpu().numpy()
W3 = actor.l3.weight.detach().cpu().numpy()
b3 = actor.l3.bias.detach().cpu().numpy()

# Single-argument print(...) calls give identical output on Python 2 and 3.
print(' '.join(str(np.shape(w)) for w in (W1, W2, W3)))

# Re-implement the actor by hand: two ReLU layers followed by a tanh output.
# Weights are applied as x @ W.T, i.e. each weight matrix is stored
# (outputs, inputs).
fc1 = relu(np.matmul(state, W1.T) + b1)
fc2 = relu(np.matmul(fc1, W2.T) + b2)
out = tanh(np.matmul(fc2, W3.T) + b3)
print('fc1: ' + str(fc1))
print('fc2: ' + str(fc2))
# NOTE(review): the PyTorch action printed earlier is not multiplied by 200,
# so the two printed numbers differ by that factor. 200 is presumably the
# output scale used by the C controller -- confirm and name the constant.
print('action from matmul: ' + str(out * 200))

action from pytorch: [0.9996142]
(8, 2) (8, 8) (1, 8)
fc1: [0.49692867 0.4095066  0.77692754 0.11859583 0.55293866 0.57295867
 0.96415396 0.        ]
fc2: [0.         1.03887431 1.34751408 1.53618617 0.         1.49733175
 0.         0.        ]
action from matmul: [199.92284131]


In [2]:
filename = 'NNparam.txt'

# (C identifier, array) pairs, in the order they are written to the file.
# Weight matrices are transposed so the exported layout matches the
# `x @ W.T` form used by the NumPy forward pass.
c_arrays = [
    ('w1', W1.T),
    ('w2', W2.T),
    ('w3', W3.T),
    ('b1', b1),
    ('b2', b2),
    ('b3', b3),
]

with open(filename, 'w') as f:
    for c_name, values in c_arrays:
        # Derive the C array dimensions from the data itself so the
        # declaration cannot drift from the network size (e.g. "[2][8]").
        dims = ''.join('[%d]' % n for n in np.shape(values))
        # print_param() also echoes each array's shape to stdout.
        f.write('float %s%s = {%s};\n' % (c_name, dims, print_param(values)))

(2, 8)
(8, 8)
(8, 1)
(8,)
(8,)
(1,)
