In [3]:
import torch
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

import gym
import d4rl

from rlkit.torch.networks import Mlp



In [5]:
# Id of the gym environment to instantiate. The `d4rl` import above presumably
# registers this id with gym — TODO confirm.
env_name = 'halfcheetah-medium-v0'
env = gym.make(env_name)

# Element counts of the observation / action bound arrays; their sum is used
# further down as the input width of the networks.
obs_dim, action_dim = (
    space.low.size
    for space in (env.observation_space, env.action_space)
)




## Dataset test


In [9]:
# Fetch the dataset object exposed by the environment and pull out the two
# entries this notebook works with.
dataset = env.get_dataset()
obs, actions = dataset['observations'], dataset['actions']

In [11]:
# Convert the arrays to float32 tensors. `torch.as_tensor` replaces the legacy
# `torch.Tensor(...)` constructor: the dtype is stated explicitly, and no copy
# is made when the input is already float32 (which also makes re-running this
# cell cheap, since `obs` / `actions` are rebound in place).
obs = torch.as_tensor(obs, dtype=torch.float32)
actions = torch.as_tensor(actions, dtype=torch.float32)

# Pair each observation with its action.
# NOTE: this rebinds `dataset` from the raw mapping loaded above to a TensorDataset.
dataset = TensorDataset(obs, actions)
# shuffle=False keeps the dataset order, so the first batch inspected below is
# the same on every run.
dataloader = DataLoader(dataset, batch_size=4, shuffle=False)

In [12]:
# Take the first (observations, actions) mini-batch from the loader.
first_batch = next(iter(dataloader))
x, y = first_batch

In [13]:
# Shape of the observation batch (batch size x observation width).
x.shape

torch.Size([4, 17])

## Checkpoint

In [7]:
M = 64  # width of each of the two hidden layers

# Build the network and its target with the same architecture: the input is an
# observation concatenated with an action, the output is a single scalar.
# The generator constructs them in order (network first, then target_network)
# and gives each its own `hidden_sizes` list.
network, target_network = (
    Mlp(
        input_size=obs_dim + action_dim,
        output_size=1,
        hidden_sizes=[M, M],
    )
    for _ in range(2)
)

In [8]:
# Restore the weights of both networks from a saved checkpoint.
path = '../models/Nov-03-2020_1648_halfcheetah-medium-v0.pt'

# map_location='cpu' remaps every stored tensor onto the CPU, so a checkpoint
# written from a GPU run still loads on a machine without CUDA; both networks
# in this notebook are built on the CPU.
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
checkpoint = torch.load(path, map_location='cpu')
network.load_state_dict(checkpoint['network_state_dict'])
target_network.load_state_dict(checkpoint['target_state_dict'])

print('Loading model: {}'.format(path))

Loading model: ../models/Nov-03-2020_1648_halfcheetah-medium-v0.pt


In [14]:
# Join observations and actions along the feature axis (dim=1) to form the
# network input, then show the resulting shape.
data = torch.cat([x, y], dim=1)
print(data.shape)

torch.Size([4, 23])


In [15]:
# Run the same batch through the network and then through the target network.
out1, out2 = network(data), target_network(data)

In [16]:
# Display the output of `network` for the batch.
out1

tensor([[0.0020],
        [0.0073],
        [0.0038],
        [0.0107]], grad_fn=<AddmmBackward>)

In [17]:
# Display the output of `target_network` for the same batch.
out2


tensor([[0.0022],
        [0.0087],
        [0.0043],
        [0.0110]], grad_fn=<AddmmBackward>)

In [18]:
# Element-wise squared difference between the two networks' outputs.
torch.pow(out1 - out2, 2)

tensor([[4.0721e-08],
        [2.0440e-06],
        [2.1009e-07],
        [6.3520e-08]], grad_fn=<PowBackward0>)

In [19]:
# Element-wise absolute difference between the two networks' outputs.
tmp = abs(out1 - out2)
# A bare assignment displays nothing; end on the name so a fresh run of this
# cell actually shows the tensor recorded as its output.
tmp

tensor([[0.0002],
        [0.0014],
        [0.0005],
        [0.0003]], grad_fn=<AbsBackward>)

In [None]:
# NOTE(review): leftover scratch cell (its execution count is `None`, i.e. it
# was not run); it only prints an empty line and looks safe to delete.
print()

## Logger


### Checkpoint

In [5]:
import pickle

# Load the pickled `params.pkl` written into an experiment's log directory.
# NOTE(review): pickle.load can execute arbitrary code — only open files you trust.
# NOTE(review): this rebinds `path` and `data`, which earlier cells use for the
# model checkpoint path and the network input batch.
path = '../logs/sac-d4rl/sac_d4rl_2020_11_02_19_13_40_0000--s-0/params.pkl'
with open(path, "rb") as params_file:
    data = pickle.load(params_file)

In [6]:
# Display the object that was unpickled from `params.pkl`.
data

119547037146038801333356

### Experiment

In [8]:
from rlkit.launchers.launcher_util import setup_logger

# Initialise logging for a throwaway experiment with an empty variant.
# The call is left as the cell's last expression so its return value (the log
# directory in the recorded output) is displayed.
exp_name, variant = 'aaa', {}
setup_logger(exp_name, variant=variant, base_log_dir='../tmp/')


2020-11-04 11:11:12.293254 UTC | Variant:
2020-11-04 11:11:12.296242 UTC | {}


'../tmp/aaa/aaa_2020_11_04_11_11_12_0000--s-0'

In [11]:
# Same setup as before, now passing an explicit exp_id; in the recorded output
# the returned directory name carries `_0001` instead of `_0000`.
# NOTE(review): the recorded log lines show prefixes from earlier calls piling
# up, so repeated setup_logger calls in one kernel appear to accumulate state.
exp_name, variant = 'aaa', {}
setup_logger(exp_name, variant=variant, exp_id=1, base_log_dir='../tmp/')


2020-11-04 11:13:47.133939 UTC | [aaa_2020_11_04_11_11_12_0000--s-0] [aaa_2020_11_04_11_12_04_0000--s-0] Variant:
2020-11-04 11:13:47.140471 UTC | [aaa_2020_11_04_11_11_12_0000--s-0] [aaa_2020_11_04_11_12_04_0000--s-0] {}


'../tmp/aaa/aaa_2020_11_04_11_13_47_0001--s-0'