In [None]:
%load_ext autoreload
%autoreload 2

import lovely_tensors as lt
lt.monkey_patch()

import numpy as np
import torch
import random
from matplotlib import pyplot as plt
from tqdm.auto import tqdm


import torch
import torch.nn as nn

from aim import Figure, Image, Run

from assembly_gym.envs.assembly_env import AssemblyEnv, Shape, Block
from assembly_gym.envs.gym_env import AssemblyGym, sparse_reward, tower_setup, bridge_setup, hard_tower_setup
from robotoddler.utils.actions import generate_actions, filter_actions
from assembly_gym.utils.rendering import plot_assembly_env, render_assembly_env

from robotoddler.training.successor_dqn import get_state_features, get_action_features, get_task_features,\
      update_target_net, train_policy_net, rollout_episode, EpsilonGreedy, log_episode
from robotoddler.models.cv import ConvNet, SuccessorMLP

from robotoddler.utils.replay_memory import ReplayBuffer
from robotoddler.utils.utils import init_weights

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device  # bare expression so the notebook echoes the selected device

In [None]:
# Build a tower environment (restricted to 2D) and look at the output of each
# feature extractor.
width = height = 64

env = AssemblyGym(
    **tower_setup(),
    reward_fct=sparse_reward,
    restrict_2d=True,
    assembly_env=AssemblyEnv(render=False),
)
plot_assembly_env(env)

obs, info = env.reset(**tower_setup())
available_actions = list(generate_actions(env, x_discr_ground=np.linspace(0.2, 0.8, 5)))
action = available_actions[3]

# Take one step so the state is more interesting to look at, then re-enumerate
# the actions that are available from the new state.
obs, _, _, _, info = env.step(action)
available_actions = list(generate_actions(env, x_discr_ground=np.linspace(0.2, 0.8, 5)))

# Compute every feature map at the same (width, height) resolution.
state_features, binary_features = get_state_features(obs, img_size=(width, height))
task_features, obstacle_features = get_task_features(obs, img_size=(width, height))
action_features = get_action_features(env, available_actions, img_size=(width, height))

# One grayscale panel per feature map; the action panel shows the action at index 4.
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15, 5))
panels = (
    (ax1, task_features, 'Task Features'),
    (ax2, state_features, 'State Features'),
    (ax3, obstacle_features, 'Obstacle Features'),
    (ax4, action_features[4], 'Action Features'),
)
for ax, feature_map, title in panels:
    ax.imshow(feature_map.squeeze(), cmap='gray')
    ax.set_title(title)

env.assembly_env.disconnect_client()

In [None]:
# Smoke test for SuccessorMLP: instantiate it, then move it to the selected device.
model = SuccessorMLP(img_size=(width, height))
model = model.to(device)

def init_weights(m):
    """Initialise a single module in place; intended for use with ``model.apply``.

    ``nn.Conv2d`` and ``nn.Linear`` layers get Xavier-uniform weights and, when
    they have a bias, a constant bias of 0.01. Any other module is left untouched.

    NOTE(review): this definition shadows the ``init_weights`` imported from
    ``robotoddler.utils.utils`` at the top of the notebook -- confirm that is intended.

    Args:
        m: The module to initialise.
    """
    # Exact type match (not isinstance), so subclasses of these layers are skipped.
    if type(m) not in (nn.Conv2d, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers built with ``bias=False`` expose ``bias`` as None; skip instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)

# Re-initialise the weights and put the network into inference mode.
model.apply(init_weights)
model.eval()

num_actions = 3
# Dummy inputs at the notebook resolution (width x height). x, o and z are single
# samples expanded along the first dimension to num_actions; a is sampled per action.
single_img = (1, 1, width, height)
per_action = (num_actions, -1, -1, -1)
x = torch.rand(*single_img, device=device).expand(per_action)
o = torch.rand(*single_img, device=device).expand(per_action)
a = torch.rand(num_actions, *single_img[1:], device=device)
z = torch.ones(*single_img, device=device).expand(per_action)
y = torch.ones(1, 6, device=device).expand((num_actions, -1))

q_values, succ_img, succ_bin = model(x, y, a, z, o)
for successor_output in (succ_img, succ_bin):
    print(successor_output)

# For the first action: softmax over dim 0, then keep index 1.
first_img_probs = succ_img[0].softmax(dim=0)
first_bin_probs = succ_bin[0].softmax(dim=0)
img = first_img_probs[1]
features = first_bin_probs[1]

print(features)
plt.imshow(img.detach().cpu().numpy(), cmap='gray', vmin=0, vmax=1)


In [None]:
# Smoke test for ConvNet: instantiate it, then move it to the selected device.
model = ConvNet(4, img_size=(width, height))
model = model.to(device)

def init_weights(m):
    """Initialise a single module in place; intended for use with ``model.apply``.

    ``nn.Conv2d`` and ``nn.Linear`` layers get Xavier-uniform weights and, when
    they have a bias, a constant bias of 0.01. Any other module is left untouched.

    NOTE(review): this definition shadows the ``init_weights`` imported from
    ``robotoddler.utils.utils`` at the top of the notebook -- confirm that is intended.

    Args:
        m: The module to initialise.
    """
    # Exact type match (not isinstance), so subclasses of these layers are skipped.
    if type(m) not in (nn.Conv2d, nn.Linear):
        return
    nn.init.xavier_uniform_(m.weight)
    # Layers built with ``bias=False`` expose ``bias`` as None; skip instead of crashing.
    if m.bias is not None:
        nn.init.constant_(m.bias, 0.01)

# Re-initialise the weights and put the network into inference mode.
model.apply(init_weights)
model.eval()

num_actions = 3
# Dummy inputs at the notebook resolution (width x height). x, o and z are single
# samples expanded along the first dimension to num_actions; a is sampled per action.
single_img = (1, 1, width, height)
per_action = (num_actions, -1, -1, -1)
x = torch.rand(*single_img, device=device).expand(per_action)
o = torch.rand(*single_img, device=device).expand(per_action)
a = torch.rand(num_actions, *single_img[1:], device=device)
z = torch.ones(*single_img, device=device).expand(per_action)
y = torch.ones(1, 6, device=device).expand((num_actions, -1))

# The second output of the forward pass is not inspected here.
outputs = model(x, y, a, z, o)
q_values, _, features = outputs
for output_tensor in (q_values, features):
    print(output_tensor)

# Successor Feature Q-Learning

In [None]:
%matplotlib inline

# This code is taken from robotoddler.training.successor_dqn.py

# Hyperparameters and run configuration for the training cell.
args = {
    'gamma': 0.95,
    'batch_size': 32,
    'evaluate_every': 10,
    'num_episodes': 1000,
    'learning_rate': 0.0001,
    'num_training_steps': 25,
    'tau': 0.01,
    'loss_function': 'mse_q_values',
    'verbose': True,
    'seed': 1,
    'tower_height': 2,
    'max_steps': 10,
    'image_size': (64, 64),
}

# Shortcuts for the two settings that are used most often below.
gamma, verbose = args['gamma'], args['verbose']

# Seed the Python, NumPy and torch RNGs from the configured seed.
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(args['seed'])

# Episode counter the training loop starts counting from.
episode = 0

# Candidate x positions on the ground; also used to draw the target location.
x_discr_ground = np.linspace(0.2, 0.8, 3)


def setup_fct():
    """Return a ``tower_setup`` configuration with one randomly placed target.

    The target's first coordinate is drawn uniformly from ``x_discr_ground``, the
    second is fixed at 0, and the third is ``0.02 + 0.05 * args['tower_height']``
    (presumably x, y and height -- TODO confirm against ``tower_setup``).
    """
    target_x = random.choice(x_discr_ground)
    target_top = 0.02 + 0.05 * args['tower_height']
    return tower_setup(targets=[(target_x, 0, target_top)])

env = AssemblyGym(
    reward_fct=sparse_reward,
    max_steps=args['max_steps'],
    restrict_2d=True,
    assembly_env=AssemblyEnv(render=False),
)

# Successor-feature MLPs: the policy net is built first, then the target net,
# both with the same architecture.
hidden_dims = [256, 128, 64, 128, 256]
policy_net, target_net = (
    SuccessorMLP(img_size=args['image_size'], hidden_dims=hidden_dims).to(device)
    for _ in range(2)
)

# Alternative: a standard conv net that predicts q values.
# policy_net = ConvNet(img_size=args['image_size']).to(device)
# target_net = ConvNet(img_size=args['image_size']).to(device)

# Only the policy net is optimised; the target net starts as a copy of its
# freshly initialised weights.
optimizer = torch.optim.Adam(policy_net.parameters(), lr=args['learning_rate'])
policy_net.apply(init_weights)
target_net.load_state_dict(policy_net.state_dict())

replay_buffer = ReplayBuffer(capacity=2000)

# Aim logging is disabled; to enable it, create a run instead:
# aim_run = aim.Run(experiment="SuccessorQLearning", repo=args['aim_repo'])
aim_run = None

# Policies: epsilon-greedy for training rollouts, pure argmax for evaluation.
eps_greedy = EpsilonGreedy(eps_start=0.5, gamma=0.999, eps_end=0.05, episode=episode)
greedy = lambda q, *_extra, **_ignored: torch.argmax(q)


# Training loop. Each iteration rolls out one episode with the epsilon-greedy policy,
# pushes its transitions to the replay buffer, trains the policy net for
# `num_training_steps` steps and updates the target net (using `tau`). Every
# `evaluate_every` episodes an additional episode is rolled out with the greedy policy.
it = tqdm(range(episode + 1, episode + args['num_episodes'] + 1), disable=not verbose)
for i in it:
    # rollout episode
    transitions, images = rollout_episode(env=env, 
                                    policy=eps_greedy.step(), 
                                    policy_net=policy_net, 
                                    setup_fct=setup_fct, 
                                    x_discr_ground=x_discr_ground, 
                                    img_size=args['image_size'],
                                    device=device)
    
    # add transitions to replay buffer
    replay_buffer.push(transitions)
    
    # train policy net for n steps
    losses = train_policy_net(policy_net=policy_net, 
                target_net=target_net, 
                optimizer=optimizer, 
                loss_fct=args['loss_function'],
                replay_buffer=replay_buffer, 
                gamma=gamma, 
                batch_size=args['batch_size'],
                n_steps=args['num_training_steps'],
                device=device,
                verbose=False)

    # update the target net
    update_target_net(policy_net=policy_net, target_net=target_net, tau=args['tau'])

    # logging
    log_info, fig = log_episode(
        episode=i, 
        transitions=transitions,
        policy=eps_greedy,
        losses=losses,
        context='training',
        gamma=gamma,
        aim_run=aim_run
    )
    it.set_postfix(episode=i, **log_info)


    # evaluate using greedy policy
    if i % args['evaluate_every'] == 0:
        # evaluate
        transitions, images = rollout_episode(env=env, 
                                    policy=greedy, 
                                    policy_net=policy_net, 
                                    setup_fct=setup_fct, 
                                    x_discr_ground=x_discr_ground, 
                                    img_size=args['image_size'],
                                    device=device,
                                    log_images=True)

        log_info, fig = log_episode(
            episode=i,
            transitions=transitions,
            images=images,
            log_images=True,
            losses=None,
            context='evaluation',
            gamma=gamma,
            aim_run=aim_run
        )
        print(log_info)
        plt.show()
        # plt.close(fig)

        # Report the current episode `i`; `episode` is the fixed start offset set
        # before the loop and would show a stale value on the progress bar.
        it.set_postfix(episode=i, **log_info)