In [2]:
from curl_sac import RadSacAgent

import numpy as np
import torch
import argparse
import os
import math
import gym
import sys
import random
import time
import json
import dmc2gym
import copy
import retro
from pathlib import Path
import weightwatcher as ww


import utils
from logger import Logger
from video import VideoRecorder

from curl_sac import RadSacAgent
from torchvision import transforms
import data_augs as rad

In [3]:
def parse_args(file_path):
    """Load the saved run arguments from ``<file_path>/args.json``.

    Args:
        file_path: Directory that contains an ``args.json`` file.

    Returns:
        The decoded JSON content of ``args.json``.

    Raises:
        FileNotFoundError: If ``args.json`` does not exist in ``file_path``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    args_path = os.path.join(file_path, 'args.json')
    # Use a context manager so the file handle is closed deterministically
    # instead of being left open until garbage collection.
    with open(args_path) as args_file:
        return json.load(args_file)

def make_agent(obs_shape, action_shape, args, device):
    """Construct the agent selected by ``args.agent``.

    Args:
        obs_shape: Observation shape, forwarded to the agent constructor.
        action_shape: Action shape, forwarded to the agent constructor.
        args: Object exposing the hyperparameters read below as attributes
            (``agent``, ``hidden_dim``, ``discount``, ...).
        device: Device handle, forwarded to the agent constructor.

    Returns:
        A ``RadSacAgent`` when ``args.agent == 'rad_sac'``.

    Raises:
        ValueError: If ``args.agent`` names an unsupported agent. The previous
            ``assert '<message>'`` was always truthy, so the function silently
            returned ``None`` in that case.
    """
    if args.agent != 'rad_sac':
        raise ValueError('agent is not supported: %s' % args.agent)

    return RadSacAgent(
        obs_shape=obs_shape,
        action_shape=action_shape,
        device=device,
        hidden_dim=args.hidden_dim,
        discount=args.discount,
        init_temperature=args.init_temperature,
        alpha_lr=args.alpha_lr,
        alpha_beta=args.alpha_beta,
        actor_lr=args.actor_lr,
        actor_beta=args.actor_beta,
        actor_log_std_min=args.actor_log_std_min,
        actor_log_std_max=args.actor_log_std_max,
        actor_update_freq=args.actor_update_freq,
        critic_lr=args.critic_lr,
        critic_beta=args.critic_beta,
        critic_tau=args.critic_tau,
        critic_target_update_freq=args.critic_target_update_freq,
        encoder_type=args.encoder_type,
        encoder_feature_dim=args.encoder_feature_dim,
        encoder_lr=args.encoder_lr,
        encoder_tau=args.encoder_tau,
        num_layers=args.num_layers,
        num_filters=args.num_filters,
        log_interval=args.log_interval,
        detach_encoder=args.detach_encoder,
        latent_dim=args.latent_dim,
        data_augs=args.data_augs
    )







In [4]:
class dotdict(dict):
    """A ``dict`` whose entries can also be read, set and removed as attributes."""

    def __getattr__(self, name):
        # Only reached when normal attribute lookup fails; a missing key
        # yields None rather than raising.
        return dict.get(self, name)

    def __setattr__(self, name, value):
        # Attribute assignment stores a mapping entry.
        dict.__setitem__(self, name, value)

    def __delattr__(self, name):
        # Attribute deletion removes the mapping entry (KeyError if absent).
        dict.__delitem__(self, name)

In [5]:
# Rebuild the environment, replay buffer and agent from the arguments saved
# with a previous training run.
args = parse_args('tmp/cartpole-swingup-12-03-im84-b128-s23-pixel')
args = dotdict(args)
if args['seed'] == -1:
    # Store the drawn seed as a mapping entry. Writing to args.__dict__ only
    # creates an instance attribute: args['seed'] would stay -1 while
    # args.seed returned the new value, so the seeding calls below and the
    # environment constructor would disagree.
    args['seed'] = np.random.randint(1,1000000)
utils.set_seed_everywhere(args['seed'])

# Render at the pre-transform size only when the 'crop' augmentation is used;
# otherwise render directly at the final image size.
pre_transform_image_size = args['pre_transform_image_size'] if 'crop' in args['data_augs'] else args['image_size']
pre_image_size = args['pre_transform_image_size'] # record the pre transform image size for translation

env = dmc2gym.make(
    domain_name=args.domain_name,
    task_name=args.task_name,
    seed=args.seed,
    visualize_reward=False,
    from_pixels=(args.encoder_type == 'pixel'),
    height=pre_transform_image_size,
    width=pre_transform_image_size,
    frame_skip=args.action_repeat
)

env.seed(args['seed'])

# stack several consecutive frames together
if args.encoder_type == 'pixel':
    env = utils.FrameStack(env, k=args.frame_stack)
    

# # make directory
# ts = time.gmtime() 
# ts = time.strftime("%m-%d", ts)    
# env_name = args.domain_name + '-' + args.task_name
# exp_name = env_name + '-' + ts + '-im' + str(args.image_size) +'-b'  \
# + str(args.batch_size) + '-s' + str(args.seed)  + '-' + args.encoder_type
# args.work_dir = args.work_dir + '/'  + exp_name

# utils.make_dir(args.work_dir)
# video_dir = utils.make_dir(os.path.join(args.work_dir, 'video'))
# model_dir = utils.make_dir(os.path.join(args.work_dir, 'model'))
# buffer_dir = utils.make_dir(os.path.join(args.work_dir, 'buffer'))

# video = VideoRecorder(video_dir if args.save_video else None)

# with open(os.path.join(args.work_dir, 'args.json'), 'w') as f:
#     json.dump(vars(args), f, sort_keys=True, indent=4)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

action_shape = env.action_space.shape

# The agent sees the final image size; the replay buffer is sized for the
# pre-augmentation frames.
if args.encoder_type == 'pixel':
    obs_shape = (3*args.frame_stack, args.image_size, args.image_size)
    pre_aug_obs_shape = (3*args.frame_stack,pre_transform_image_size,pre_transform_image_size)
else:
    obs_shape = env.observation_space.shape
    pre_aug_obs_shape = obs_shape

replay_buffer = utils.ReplayBuffer(
    obs_shape=pre_aug_obs_shape,
    action_shape=action_shape,
    capacity=args.replay_buffer_capacity,
    batch_size=args.batch_size,
    device=device,
    image_size=args.image_size,
    pre_image_size=pre_image_size,
)

agent = make_agent(
    obs_shape=obs_shape,
    action_shape=action_shape,
    args=args,
    device=device
)

INFO:absl:MUJOCO_GL is not set, so an OpenGL backend will be chosen automatically.
INFO:absl:Failed to import OpenGL backend: glfw
INFO:OpenGL.acceleratesupport:No OpenGL_accelerate module loaded: No module named 'OpenGL_accelerate'
INFO:absl:Successfully imported OpenGL backend: egl
INFO:absl:MuJoCo library version is: 200
  "Box bound precision lowered by casting to {}".format(self.dtype)


In [6]:
# Restore the agent from the checkpoint written at the given training step.
# NOTE(review): the recorded run of this cell failed with FileNotFoundError
# for 'actor_0.pt' — confirm which step was actually saved in this directory.
checkpoint_dir = 'tmp/cartpole-swingup-12-03-im84-b128-s23-pixel/model'
checkpoint_step = 0
agent.load(checkpoint_dir, checkpoint_step)

FileNotFoundError: [Errno 2] No such file or directory: 'tmp/cartpole-swingup-12-03-im84-b128-s23-pixel/model/actor_0.pt'

In [None]:
# from utils import prepare_model


# Analyze the agent's actor network with WeightWatcher.
watcher = ww.WeightWatcher(model=agent.actor)
# Run the analysis and persist the result as CSV so it can be inspected
# outside the notebook.
details = watcher.analyze()
details.to_csv("esd_plot.csv")
# check csv to get which layers to input into the analyze function's arg in the next lines
# NOTE(review): confirm `get_esd` exists in the installed weightwatcher
# version (the public docs spell it `get_ESD`), and whether it needs a layer argument.
esd_layer_ = watcher.get_esd()

# Summarize the analysis results computed above.
summary = watcher.get_summary(details)
# Example follow-up: re-run the analysis with plots for a hand-picked subset of layers.
# layers_6 = [127,142,90,71,48,46] # rank and choose from esd_plot.csv
# details = watcher.analyze(plot=True, randomize=True, savefig="esd_plot_figure", layers=layers_6)
