In [None]:
%load_ext autoreload
%autoreload 2

import lovely_tensors as lt
lt.monkey_patch()

import numpy as np
from collections import namedtuple, deque
import torch
import gc
import math
import random
from collections import namedtuple, deque
from matplotlib import pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim.lr_scheduler import MultiStepLR

from aim import Figure, Image, Run

from assembly_gym.envs.assembly_env import AssemblyEnv, Shape, Block
from assembly_gym.envs.gym_env import AssemblyGym, sparse_reward, tower_setup, bridge_setup, hard_tower_setup
from assembly_gym.utils.geometry import align_frames_2d
from assembly_gym.utils.rendering import plot_assembly_env, render_assembly_env

from robotoddler.utils.actions import generate_actions

# Use the GPU when torch reports one as available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Bare expression: the notebook displays the selected device as the cell output.
device

In [None]:
# define image-based features

def get_state_features(observation, xlim=(0, 1), ylim=(0, 1), width=512, height=512):
    """Rasterize the blocks stored under ``observation['blocks']`` into a single boolean image.

    All keyword arguments are forwarded unchanged to ``render_blocks``.
    """
    placed_blocks = observation['blocks']
    return render_blocks(placed_blocks, xlim=xlim, ylim=ylim, width=width, height=height)

def get_obstacle_features(env, xlim=(0, 1), ylim=(0, 1), width=512, height=512):
    """Rasterize ``env.assembly_env.obstacles`` into a single boolean image.

    All keyword arguments are forwarded unchanged to ``render_blocks``.
    """
    obstacles = env.assembly_env.obstacles
    return render_blocks(obstacles, xlim=xlim, ylim=ylim, width=width, height=height)

def action_image_features(env, actions, xlim=(0,1), ylim=(0, 1), width=512, height=512):
    """Return one boolean image per action, each showing only the block that action would create.

    Blocks are obtained from ``env.create_block(action)`` and rendered individually
    with ``render_blocks``; the result is a list in the same order as ``actions``.
    """
    # Phase 1: build every candidate block (same call order as iterating ``actions``).
    candidate_blocks = []
    for action in actions:
        candidate_blocks.append(env.create_block(action))

    # Phase 2: rasterize each block on its own canvas.
    images = []
    for block in candidate_blocks:
        images.append(render_blocks([block], xlim=xlim, ylim=ylim, width=width, height=height))
    return images

def get_binary_features(observation):
    """Collect the scalar status flags of an observation into a 1-D numpy array.

    The entries are read from ``observation`` in this fixed order:
    stable, collision, collision_block, collision_obstacle, collision_floor,
    collision_boundary.
    """
    flag_names = (
        'stable',
        'collision',  # ToDo obstacle vs block collision
        'collision_block',
        'collision_obstacle',
        'collision_floor',
        'collision_boundary',
    )
    return np.array([observation[name] for name in flag_names])

def get_task_features(env, xlim=(0, 1), ylim=(0, 1), width=512, height=512):
    """Rasterize the task: one cube-shaped block is placed at every position in ``env.targets``.

    The cube shape is loaded from the relative path '../assembly_gym/shapes/cube.urdf',
    so the result depends on the notebook's working directory.
    """
    cube = Shape(urdf_file='../assembly_gym/shapes/cube.urdf')
    target_blocks = []
    for target in env.targets:
        target_blocks.append(Block(shape=cube, position=target))
    return render_blocks(target_blocks, xlim=xlim, ylim=ylim, width=width, height=height)

def render_blocks(blocks, xlim, ylim, width=512, height=512):
    """Rasterize a collection of blocks into one boolean occupancy image.

    A regular grid of sample points spanning ``xlim`` x ``ylim`` is built, every block is
    queried via ``block.contains_2d(points)``, and the per-block masks are OR-ed together.

    Args:
        blocks: iterable of objects exposing ``contains_2d(positions)``, where ``positions``
            is an (N, 2) array of (x, y) points and the result has one entry per point.
        xlim: (min, max) extent sampled along x.
        ylim: (min, max) extent sampled along y.
        width: number of samples along x (image columns).
        height: number of samples along y (image rows).

    Returns:
        Boolean array of shape ``(height, width)``; all False when ``blocks`` is empty.

    Note:
        The image used to be allocated as ``(width, height)`` while the sample points were laid
        out as ``(height, width)`` (np.meshgrid's default 'xy' indexing), which scrambled the
        pixels whenever ``width != height``. For square images the output is unchanged.
    """
    xs = np.linspace(*xlim, width)
    ys = np.linspace(*ylim, height)
    # Default 'xy' indexing: X and Y both have shape (len(ys), len(xs)) == (height, width).
    X, Y = np.meshgrid(xs, ys)
    positions = np.vstack([X.ravel(), Y.ravel()]).T

    # Allocate the canvas with the same layout as the sample grid so the reshape below is valid.
    image = np.zeros(X.shape, dtype=bool)
    for block in blocks:
        image = image | block.contains_2d(positions).reshape(image.shape)

    return image

In [None]:
def reduce_available_actions(state_features, obstacle_features, available_actions, action_features):
    """
    Prune actions whose block image overlaps the current state image or the obstacle image.

    An action is kept only when the element-wise product of its feature image with
    ``state_features`` sums to zero AND the product with ``obstacle_features`` sums to zero.

    Returns the list of kept actions together with the matching rows of ``action_features``.
    """
    keep_flags = []
    kept_actions = []
    for idx, candidate in enumerate(available_actions):
        footprint = action_features[idx]
        # Short-circuits exactly like the original: the obstacle check only runs if the state check passes.
        if torch.sum(footprint * state_features) == 0 and torch.sum(footprint * obstacle_features) == 0:
            keep_flags.append(True)
            kept_actions.append(candidate)
        else:
            keep_flags.append(False)

    keep_mask = torch.tensor(keep_flags, dtype=torch.bool)
    return kept_actions, action_features[keep_mask]

In [None]:
# # load shapes
# trapezoid = Shape(urdf_file='../assembly_gym/shapes/trapezoid.urdf')
# vblock = Shape(urdf_file='../assembly_gym/shapes/v_block.urdf')
# cube = Shape(urdf_file='../assembly_gym/shapes/cube.urdf')
# tblock = Shape(urdf_file='../assembly_gym/shapes/t_block.urdf')

# trapezoid.num_faces_2d, trapezoid.urdf_file

In [None]:
# setup environment and test features
# Image resolution used by every feature image (and by the models) in the rest of the notebook.
width = height = 64

env = AssemblyGym(**tower_setup(), 
                  reward_fct=sparse_reward,
                  restrict_2d=True, 
                  assembly_env=AssemblyEnv(render=False))
plot_assembly_env(env)

# Discretization passed to generate_actions: 10 evenly spaced values in [0, 1] and a single zero offset.
x_discr_ground=np.linspace(0, 1, 10)
x_block_offset=[0.]

obs, info = env.reset()
available_actions = [*generate_actions(env, x_discr_ground, x_block_offset)]
# Arbitrary pick (index 5) just to get a non-empty state for the illustration below.
action = available_actions[5]

# make the state a bit more interesting
obs, _, _, _, info = env.step(action)
# The action set is regenerated after the step, since it is derived from the environment.
available_actions = [*generate_actions(env, x_discr_ground, x_block_offset)]

# get all the features
state_features = get_state_features(obs, width=width, height=height)
task_features = get_task_features(env, width=width, height=height)
obstacle_features = get_obstacle_features(env, width=width, height=height)
action_features = action_image_features(env, available_actions, width=width, height=height)
binary_features = get_binary_features(obs)
print(binary_features)

# plotting / illustration
fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, figsize=(15, 5))

# The y axis of every panel is inverted after imshow (imshow draws row 0 at the top by default).
ax1.imshow(task_features, cmap='gray')
ax1.invert_yaxis()
ax1.set_title('Task Features')

ax2.imshow(state_features, cmap='gray')
ax2.invert_yaxis()
ax2.set_title('State Features')

ax3.imshow(obstacle_features, cmap='gray')
ax3.invert_yaxis()
ax3.set_title('Obstacle Features')

# Index 4 is an arbitrary example action; this assumes at least five actions are available.
ax4.imshow(action_features[4], cmap='gray')
ax4.invert_yaxis()
ax4.set_title('Action Features')

# This demo environment is not reused below; a new one is created in the training setup cell.
env.assembly_env.disconnect_client()

In [None]:
# replay buffer definition
# because the available actions depend on the state, we need to store a bit more to avoid re-evaluting the environment when training

# One replay-buffer entry. Besides the usual (state, action, reward, next state, done) data it
# also carries the action set of the next state and the task / obstacle images, so the
# optimization routines can evaluate the target network without querying the environment again.
# The optimizers in this notebook torch.stack the *_features and reward fields across a batch,
# read next_actions_features.shape[0] as the number of next actions, and use `done` as a boolean.
# noinspection PyTypeChecker
Transition = namedtuple('Transition',
                        ('state_features',
                         'binary_features',
                         'action',
                         'action_features', 
                         'reward',
                         'next_state_features', 
                         'next_binary_features',
                         'next_available_actions',
                         'next_actions_features',
                         'task_features',
                         'obstacle_features',
                         'done'))

def tensor_size_MB(a):
    """Return the size of tensor ``a``'s elements in MB (bytes / 1024 / 1024)."""
    total_bytes = a.element_size() * a.nelement()
    return total_bytes / 1024 / 1024

class ReplayBuffer:
    """Fixed-capacity FIFO store of transitions with uniform random sampling.

    Backed by a ``deque`` with ``maxlen=capacity``: once full, pushing a new transition
    silently discards the oldest one.
    """

    def __init__(self, capacity):
        self.episode = 0
        self.memory = deque([], maxlen=capacity)

    def push(self, transition):
        """Append a transition to the buffer."""
        self.memory.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct stored transitions uniformly at random (via ``random.sample``)."""
        return random.sample(self.memory, batch_size)

    def size_in_MB(self):
        """Estimate the tensor memory held by the buffer, in MB.

        Only the tensor fields of the *first* stored transition are measured and the result is
        multiplied by the number of entries, so this is an extrapolation, not an exact total.
        Returns 0 for an empty buffer.
        """
        if not self.memory:
            return 0

        first = self.memory[0]
        field_values = (getattr(first, field) for field in Transition._fields)
        per_entry_mb = sum(tensor_size_MB(value) for value in field_values if isinstance(value, torch.Tensor))
        return per_entry_mb * len(self.memory)

    def __len__(self):
        return len(self.memory)

In [None]:
# Here we define two models
# - one conv net that maps input states to q-values and successor features for only the binary variables (not the images)
# - one UNet that maps input states to q-values and successor features for the images
# (Johannes) The UNet currently does not work well; it is probably not a good architecture for generating images
# that deviate significantly from the input, or something else is going wrong ...

class ConvBlock(nn.Module):
    """Two 3x3 convolutions with padding 1 (spatial size preserved), each followed by ReLU.

    The BatchNorm layers are constructed, and therefore show up in the state dict, but they
    are currently not applied in ``forward``.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        same_size_3x3 = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(in_c, out_c, **same_size_3x3)
        self.bn1 = nn.BatchNorm2d(out_c)
        self.conv2 = nn.Conv2d(out_c, out_c, **same_size_3x3)
        self.bn2 = nn.BatchNorm2d(out_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        # conv -> relu -> conv -> relu; bn1 / bn2 are deliberately left out of the chain here.
        hidden = self.relu(self.conv1(inputs))
        return self.relu(self.conv2(hidden))

class EncoderBlock(nn.Module):
    """A ``ConvBlock`` followed by 2x2 max pooling.

    ``forward`` returns both the pre-pooling activation (usable as a skip connection) and
    the pooled activation, in that order.
    """
    def __init__(self, in_c, out_c):
        super().__init__()
        self.conv = ConvBlock(in_c, out_c)
        self.pool = nn.MaxPool2d((2, 2))

    def forward(self, inputs):
        full_res = self.conv(inputs)
        return full_res, self.pool(full_res)
    
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear -> ReLU.

    Note that the ReLU is also applied to the output layer, so all outputs are non-negative.
    """
    def __init__(self, in_d, out_d, hidden_d=64):
        super().__init__()
        self.fc1 = nn.Linear(in_d, hidden_d)
        self.fc2 = nn.Linear(hidden_d, out_d)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        hidden = self.relu(self.fc1(inputs))
        return self.relu(self.fc2(hidden))
    

class DecoderBlock(nn.Module):
    """Upsample by a factor of 2 with a transposed convolution, then apply a ``ConvBlock``.

    With ``add_skip=True`` the upsampled tensor is concatenated with ``skip`` along the
    channel dimension before the ConvBlock, which is therefore built with ``2 * out_c`` input
    channels. With ``add_skip=False`` the ``skip`` argument is ignored.
    """
    def __init__(self, in_c, out_c, add_skip=True):
        super().__init__()
        self.add_skip = add_skip
        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        conv_in_c = 2 * out_c if add_skip else out_c
        self.conv = ConvBlock(conv_in_c, out_c)

        # idea (not implemented): fully connected upsampling of the features
        # self.feature_up = mlp()

    def forward(self, inputs, skip=None):
        upsampled = self.up(inputs)
        if not self.add_skip:
            return self.conv(upsampled)
        return self.conv(torch.cat([upsampled, skip], dim=1))


class SuccessorNet(nn.Module):
    """
    An MLP with a bottleneck that predicts successor images and features.

    All input images are flattened, concatenated with the feature vector and passed through
    a stack of Linear + ReLU layers (the ReLU is applied after the last layer as well).
    The output vector is split into a 2-channel image and a (2, num_features) block.
    """
    def __init__(self, in_channels=4, img_size=(512, 512), num_features=6, hidden_dims=None):
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [128, 64, 128]

        self.img_size = img_size
        self.relu = nn.ReLU()

        num_pixels = img_size[0] * img_size[1]
        input_dim = in_channels * num_pixels + num_features
        output_dim = 2 * num_pixels + 2 * num_features

        # Consecutive pairs of this list are the (fan_in, fan_out) of each Linear layer.
        layer_dims = [input_dim] + hidden_dims + [output_dim]
        self.layers = nn.ModuleList(
            [nn.Linear(fan_in, fan_out) for fan_in, fan_out in zip(layer_dims[:-1], layer_dims[1:])]
        )

    def forward(self, state, features, action, task, obstacles):
        img_size = self.img_size
        num_img_outputs = 2 * img_size[0] * img_size[1]

        # Stack the images channel-wise, flatten per sample, and append the feature vector.
        images = torch.cat([state, action, task, obstacles], dim=1)
        x = torch.cat([images.view(state.shape[0], -1), features], dim=1)

        for layer in self.layers:
            x = self.relu(layer(x))

        # First part of the output is the 2-channel image, the remainder the 2 x num_features block.
        img_out = x[:, :num_img_outputs].view(-1, 2, img_size[0], img_size[1])
        features_out = x[:, num_img_outputs:].view(-1, 2, features.shape[1])
        return img_out, features_out


class UNet(nn.Module):
    """UNet-style encoder/decoder that outputs a 2-channel image plus a (2, num_features) block.

    Four encoder stages (16, 32, 64, 128 channels) feed a 256-channel bottleneck; four decoder
    stages mirror them, optionally using the encoder activations as skip connections
    (``add_skip``). A separate MLP maps the flattened bottleneck, concatenated with the input
    feature vector, to the feature output. A fifth encoder/decoder stage (e4 / d1) existed in
    the original code but is disabled.
    """
    def __init__(self, in_channels, img_size=(512, 512), num_features=6, add_skip=True):
        super().__init__()
        self.add_skip = add_skip

        # --- encoder ---
        self.e0 = EncoderBlock(in_channels, 16)
        self.e1 = EncoderBlock(16, 32)
        self.e2 = EncoderBlock(32, 64)
        self.e3 = EncoderBlock(64, 128)
        # self.e4 = EncoderBlock(128, 256)

        # --- bottleneck ---
        # self.b = ConvBlock(256, 512)
        self.b = ConvBlock(128, 256)

        # --- decoder ---
        # self.d1 = DecoderBlock(512, 256, add_skip=add_skip)
        self.d2 = DecoderBlock(256, 128, add_skip=add_skip)
        self.d3 = DecoderBlock(128, 64, add_skip=add_skip)
        self.d4 = DecoderBlock(64, 32, add_skip=add_skip)
        self.d5 = DecoderBlock(32, 16, add_skip=add_skip)

        self.out = nn.Conv2d(16, 2, kernel_size=1, padding=0)

        # Determine the flattened bottleneck size by pushing a dummy image through the encoder.
        with torch.no_grad():
            probe = torch.zeros(1, in_channels, *img_size)
            for encoder in (self.e0, self.e1, self.e2, self.e3):
                probe = encoder(probe)[-1]  # keep only the pooled output
            probe = self.b(probe)
        self.bottleneck_size = probe.shape[1] * probe.shape[2] * probe.shape[3]

        self.mlp = MLP(self.bottleneck_size + num_features, 2 * num_features)

    def forward(self, state, features, action, task, obstacles):
        stacked = torch.cat([state, action, task, obstacles], dim=1)

        # encoder: s* are the pre-pooling activations, p* the pooled ones
        s0, p0 = self.e0(stacked)
        s1, p1 = self.e1(p0)
        s2, p2 = self.e2(p1)
        s3, p3 = self.e3(p2)

        # bottleneck
        b = self.b(p3)

        # decoder: each stage receives the matching encoder activation, or None without skips
        # (DecoderBlock ignores its ``skip`` argument when built with add_skip=False)
        if self.add_skip:
            skips = (s3, s2, s1, s0)
        else:
            skips = (None, None, None, None)

        decoded = b
        for decoder, skip in zip((self.d2, self.d3, self.d4, self.d5), skips):
            decoded = decoder(decoded, skip)

        succ_image = self.out(decoded)

        # feature head operates on the flattened bottleneck plus the input features
        bottleneck_flat = b.view(-1, self.bottleneck_size)
        features_out = self.mlp(torch.cat([bottleneck_flat, features], dim=1))
        features_out = features_out.view(-1, 2, features_out.shape[1] // 2)

        return succ_image, features_out
    

class ConvNet(nn.Module):
    """
    A plain conv net for classification / predicting q-values.

    Five ConvBlocks (16 -> 256 channels) with 2x2 max pooling after the first four; the
    flattened result is concatenated with the input feature vector and mapped by an MLP to
    ``1 + 2 * num_features`` outputs: one q-value and a (2, num_features) feature block.
    """
    def __init__(self, in_channels, img_size=(512, 512), num_features=6):
        super().__init__()
        self.pool = nn.MaxPool2d((2, 2))

        # convolutional trunk
        self.c0 = ConvBlock(in_channels, 16)
        self.c1 = ConvBlock(16, 32)
        self.c2 = ConvBlock(32, 64)
        self.c3 = ConvBlock(64, 128)
        self.c4 = ConvBlock(128, 256)

        # Determine the flattened trunk output size by pushing a dummy image through it.
        with torch.no_grad():
            probe = self._encode(torch.zeros(1, in_channels, *img_size))
        self.bottleneck_size = probe.shape[1] * probe.shape[2] * probe.shape[3]

        self.mlp = MLP(self.bottleneck_size + num_features, 2 * num_features + 1)

    def _encode(self, x):
        """Run the convolutional trunk: four (ConvBlock, pool) stages followed by a final ConvBlock."""
        for block in (self.c0, self.c1, self.c2, self.c3):
            x = self.pool(block(x))
        return self.c4(x)

    def forward(self, state, features, action, task, obstacles):
        stacked = torch.cat([state, action, task, obstacles], dim=1)
        encoded = self._encode(stacked)

        out = self.mlp(torch.cat([encoded.view(-1, self.bottleneck_size), features], dim=1))
        # column 0 is the q-value, the remaining columns form the (2, num_features) block
        q = out[:, 0]
        features_out = out[:, 1:].view(-1, 2, features.shape[1])

        return q, features_out
    

def compute_reward(state, task):
    """
    Reward as the inner product between the predicted state image and the task image.

    ``state`` holds 2-channel logits; a softmax over dim 1 is taken and channel 1 is used as
    the image. It is multiplied element-wise with ``task`` (dim 1 squeezed if it has size 1)
    and summed over the two trailing image dimensions, giving one value per sample.
    """
    occupancy = torch.softmax(state, dim=1)[:, 1]
    overlap = occupancy * task.squeeze(1)
    return overlap.sum(dim=(1, 2))

In [None]:
# Smoke test for the UNet: build it without skip connections and run one forward pass on dummy data.
model = UNet(4, add_skip=False, img_size=(width, height)).to(device)

def init_weights(m):
    """Xavier-initialize the weights of Conv2d / Linear modules and set their biases to 0.01."""
    if type(m) in (nn.Conv2d, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)
model.eval()

num_actions = 3
# Dummy inputs of size width x height. The state (x), obstacle (o) and task (z) images and the
# feature vector (y) are created once and expanded across the action batch; only the action
# images (a) differ per row.
x = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
o = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
a = torch.rand(num_actions, 1, width, height, device=device)
z = torch.ones(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
y = torch.ones(1, 6, device=device).expand((num_actions, -1))

img_out, features_out = model(state=x, features=y, action=a, task=z, obstacles=o)
print(img_out)
print(features_out)

# Softmax over the 2-way dimension of the first sample; entry 1 is what gets displayed.
img = torch.softmax(img_out[0], dim=0)[1]
features = torch.softmax(features_out[0], dim=0)[1]

print(features)
plt.imshow(img.detach().cpu().numpy(), cmap='gray')


In [None]:
# Smoke test for the SuccessorNet: run one forward pass on dummy data and show the predicted image.
model = SuccessorNet(img_size=(width, height)).to(device)

def init_weights(m):
    """Xavier-initialize the weights of Conv2d / Linear modules and set their biases to 0.01."""
    if type(m) in (nn.Conv2d, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)
model.eval()

num_actions = 3
# Dummy inputs of size width x height. The state (x), obstacle (o) and task (z) images and the
# feature vector (y) are created once and expanded across the action batch; only the action
# images (a) differ per row.
x = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
o = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
a = torch.rand(num_actions, 1, width, height, device=device)
z = torch.ones(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
y = torch.ones(1, 6, device=device).expand((num_actions, -1))

img_out, features_out = model(state=x, features=y, action=a, task=z, obstacles=o)
print(img_out)
print(features_out)

# Softmax over the 2-way dimension of the first sample; entry 1 is what gets displayed.
img = torch.softmax(img_out[0], dim=0)[1]
features = torch.softmax(features_out[0], dim=0)[1]

print(features)
plt.imshow(img.detach().cpu().numpy(), cmap='gray', vmin=0, vmax=1)


In [None]:
# Smoke test for the ConvNet: run one forward pass on dummy data and print q-values and features.

model = ConvNet(4, img_size=(width, height)).to(device)

def init_weights(m):
    """Xavier-initialize the weights of Conv2d / Linear modules and set their biases to 0.01."""
    if type(m) in (nn.Conv2d, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

model.apply(init_weights)
model.eval()

num_actions = 3
# Dummy inputs of size width x height. The state (x), obstacle (o) and task (z) images and the
# feature vector (y) are created once and expanded across the action batch; only the action
# images (a) differ per row.
x = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
o = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
a = torch.rand(num_actions, 1, width, height, device=device)
z = torch.ones(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
y = torch.ones(1, 6, device=device).expand((num_actions, -1))

# One q-value per action plus the (2, num_features) feature block per action.
q, features = model(state=x, features=y, action=a, task=z, obstacles=o)
print(q)
print(features)

# Q-Learning

In [None]:
def select_action(model, state_features, binary_features, action_features, task_features, obstacle_features, epsilon=None, return_succ_features=False):
    """
    Epsilon-greedy action selection based on the model's predictions.

    With probability ``1 - epsilon`` (always, if ``epsilon`` is None) every candidate action is
    scored in one batched forward pass and the arg-max is returned; otherwise an action index
    is drawn uniformly at random. All model evaluations run under ``torch.no_grad()``.

    Args:
        model: UNet / SuccessorNet (scored via ``compute_reward`` on the predicted successor
            image) or ConvNet (scored by its q-value output) for the greedy branch.
        state_features, binary_features, task_features, obstacle_features: features of the
            current state; the greedy branch expands them along dim 0 to the number of actions.
        action_features: one feature image per candidate action; dim 0 is the action index.
        epsilon: exploration probability, or None for purely greedy selection.
        return_succ_features: if True, also return the model outputs for the selected action.

    Returns:
        The selected index (a 0-dim tensor in the greedy branch, a Python int when exploring).
        If ``return_succ_features`` is True, a tuple ``(index, first_output, succ_bin)`` where
        ``first_output`` is the successor image for UNet / SuccessorNet and the q-value for
        ConvNet -- the return values are semantically different between those model families.

    Raises:
        TypeError: in the greedy branch, if ``model`` is none of the supported classes
            (this case used to silently return None).
    """
    num_actions = action_features.shape[0]

    if epsilon is None or random.random() > epsilon:
        predicts_images = isinstance(model, (UNet, SuccessorNet))
        if not predicts_images and not isinstance(model, ConvNet):
            raise TypeError(f"unsupported model type for greedy action selection: {type(model).__name__}")

        with torch.no_grad():
            first_out, succ_bin = model(state_features.expand((num_actions, -1, -1, -1)),
                                        binary_features.expand((num_actions, -1)),
                                        action_features,
                                        task_features.expand((num_actions, -1, -1, -1)),
                                        obstacle_features.expand((num_actions, -1, -1, -1)))
            # Image models are scored by the overlap of the predicted image with the task,
            # the ConvNet directly outputs one q-value per action.
            q_values = compute_reward(first_out, task_features) if predicts_images else first_out
            i = torch.argmax(q_values)

        if return_succ_features:
            return i, first_out[i], succ_bin[i]
        return i

    # exploration: uniformly random action
    i = np.random.randint(0, num_actions)
    if return_succ_features:
        # Only the sampled action is evaluated, so the model sees a batch of size one.
        with torch.no_grad():
            first_out, succ_bin = model(state_features.unsqueeze(0),
                                        binary_features.unsqueeze(0),
                                        action_features[i].unsqueeze(0),
                                        task_features.unsqueeze(0),
                                        obstacle_features.unsqueeze(0))
        # The batch has a single row: index it with 0, not with the action index i
        # (indexing with i raised an IndexError for every i > 0).
        return i, first_out[0], succ_bin[0]
    return i


In [None]:
def optimize_successor_net(policy_net, target_net, optimizer, scheduler, memory, gamma, n_steps=10, batch_size=16, verbose=True):
    """
    Perform n_steps of optimization on the policy_net (UNet or SuccessorNet) using the transitions in memory.

    The active loss is the MSE between the predicted successor image (softmax channel 1) and the
    target ``(1 - gamma) * state_image + gamma * next_successor_image``, where the next successor
    image comes from the target net's greedy action (or the stored next state image if the
    transition is terminal). Alternative losses are kept commented out below.

    Args:
        policy_net: network being trained; switched to train mode for the update and back to
            eval mode afterwards.
        target_net: network used to compute the bootstrap targets (evaluated without gradients).
        optimizer: optimizer over ``policy_net``'s parameters.
        scheduler: optional LR scheduler, stepped once per optimization step (may be None).
        memory: replay buffer supporting ``len()`` and ``sample(batch_size)``.
        gamma: discount factor.
        n_steps: number of gradient steps.
        batch_size: number of transitions per step.
        verbose: show a message and a progress bar.

    Returns:
        List with the loss of every step, or None if ``memory`` holds fewer than ``batch_size``
        transitions (in which case ``policy_net`` is left untouched).
    """
    # Check the buffer *before* switching modes, so the early exit cannot leave policy_net in train mode.
    if len(memory) < batch_size:
        return

    policy_net.train()

    if verbose:
        print(f"optimizing for {n_steps} steps...")
    it = tqdm(range(n_steps), disable=not verbose)
    losses = []

    # loss functions (stateless, so they are created once)
    mse = nn.MSELoss()
    cross_entropy = nn.CrossEntropyLoss()  # only needed by the commented-out alternatives below

    # Pre-bind everything that is released after the loop so the cleanup also works for n_steps == 0.
    transitions = batch = state_action_values = None
    next_state_values = next_state_successor_features = next_state_binary_features = None

    for _ in it:
        transitions = memory.sample(batch_size)
        batch = Transition(*zip(*transitions))

        # policy net prediction for current state
        successor_state_features, successor_binary_features = policy_net(torch.stack(batch.state_features),
                                                                         torch.stack(batch.binary_features),
                                                                         torch.stack(batch.action_features), 
                                                                         torch.stack(batch.task_features), 
                                                                         torch.stack(batch.obstacle_features))
        
        state_action_values = compute_reward(successor_state_features, torch.stack(batch.task_features))

        # Compute next state predictions based on the target net
        next_state_values = torch.zeros(batch_size, device=device)
        next_state_successor_features = torch.zeros((batch_size, *successor_state_features.shape[-2:]), device=device)
        next_state_binary_features = torch.zeros((batch_size, successor_binary_features.shape[-1]), device=device)
        
        with torch.no_grad():
            # TODO: can we batch this?
            for j, transition in enumerate(transitions):
                if not transition.done:
                    num_actions = transition.next_actions_features.shape[0]
                    # greedy action of the target net in the next state (epsilon is left at None)
                    _, _next_state_succ_img, _next_state_succ_bin = select_action(target_net, state_features=transition.next_state_features.expand((num_actions, -1, -1, -1)), 
                                              binary_features=transition.next_binary_features.expand((num_actions, -1)), 
                                              action_features=transition.next_actions_features, 
                                              task_features=transition.task_features.expand((num_actions, -1, -1, -1)), 
                                              obstacle_features=transition.obstacle_features.expand((num_actions, -1, -1, -1)), return_succ_features=True)
                    
                    next_state_values[j] = compute_reward(_next_state_succ_img.unsqueeze(0), transition.task_features.unsqueeze(0))
                    next_state_binary_features[j] = _next_state_succ_bin.softmax(dim=0)[1]
                    next_state_successor_features[j] = _next_state_succ_img.softmax(dim=0)[1]

                else:
                    # done state
                    # backpropagate zero future reward and static future state
                    next_state_values[j] = 0
                    next_state_binary_features[j] = transition.next_binary_features
                    next_state_successor_features[j] = transition.next_state_features.squeeze()

        # reward loss
        # loss = mse(state_action_values, torch.stack(batch.reward) + gamma * next_state_values)

        # state feature cross-entropy loss
        # state_target = (1-gamma) * torch.stack(batch.state_features).squeeze(1) + gamma * next_state_successor_features
        # state_target = torch.stack([1 - state_target, state_target], dim=1)
        # loss = cross_entropy(successor_state_features, state_target) 

        # state feature mse loss
        state_target = (1-gamma) * torch.stack(batch.state_features).squeeze(1) + gamma * next_state_successor_features
        loss = mse(successor_state_features.softmax(dim=1)[:,1], state_target)

        # additional features loss
        # add_features_target = (1-gamma) * torch.stack(batch.binary_features) + gamma * next_state_binary_features
        # loss += cross_entropy(successor_binary_features, torch.stack([1 - add_features_target, add_features_target], dim=1)) 

        # Optimize the model
        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        # torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        it.set_postfix(loss=loss.item())
        losses.append(loss.item())

    # drop the references to the last batch before asking torch to release cached memory
    del transitions, batch, state_action_values, next_state_values, next_state_successor_features, next_state_binary_features
    torch.cuda.empty_cache()
    gc.collect()

    policy_net.eval()
    return losses
    

def optimize_convnet(policy_net, target_net, optimizer, scheduler, memory, gamma, n_steps=10, batch_size=16, verbose=True):
    """
    Perform n_steps of optimization on the policy_net using the transitions in memory.
    This is essentially standard Q-learning with a target network: the loss is the MSE between
    the predicted q-value and ``reward + gamma * q_target(next_state, greedy action)``, with the
    bootstrap term set to zero for terminal transitions.

    Args:
        policy_net: q-network being trained (returns ``(q_values, successor_binary_features)``).
        target_net: network used to compute the bootstrap targets (evaluated without gradients).
        optimizer: optimizer over ``policy_net``'s parameters.
        scheduler: optional LR scheduler, stepped once per optimization step (may be None).
        memory: replay buffer supporting ``len()`` and ``sample(batch_size)``.
        gamma: discount factor.
        n_steps: number of gradient steps.
        batch_size: number of transitions per step.
        verbose: show a message and a progress bar.

    Returns:
        List with the loss of every step, or None if ``memory`` holds fewer than ``batch_size``
        transitions.
    """
    
    if len(memory) < batch_size:
        return
    
    if verbose:
        print(f"optimizing for {n_steps} steps...")
    it = tqdm(range(n_steps), disable=not verbose)
    losses = []

    # loss functions (stateless, so they are created once)
    mse = nn.MSELoss()
    cross_entropy = nn.CrossEntropyLoss()  # only needed by the commented-out alternative below

    # Pre-bind everything that is released after the loop so the cleanup also works for n_steps == 0.
    transitions = batch = state_action_values = None
    next_state_values = next_state_binary_features = None

    for _ in it:
        transitions = memory.sample(batch_size)
        batch = Transition(*zip(*transitions))

        # current state q-values predictions
        state_action_values, successor_binary_features = policy_net(torch.stack(batch.state_features),
                                                                         torch.stack(batch.binary_features),
                                                                         torch.stack(batch.action_features), 
                                                                         torch.stack(batch.task_features), 
                                                                         torch.stack(batch.obstacle_features))
        
        # next state q-values and predictions
        next_state_values = torch.zeros(batch_size, device=device)
        next_state_binary_features = torch.zeros((batch_size, successor_binary_features.shape[-1]), device=device)
        
        with torch.no_grad():
            # TODO: can we batch this?
            for j, transition in enumerate(transitions):
                if not transition.done:
                    num_actions = transition.next_actions_features.shape[0]
                    # greedy action of the target net in the next state (epsilon is left at None)
                    _, _next_state_q_values, _next_state_succ_bin = select_action(target_net, state_features=transition.next_state_features.expand((num_actions, -1, -1, -1)), 
                                              binary_features=transition.next_binary_features.expand((num_actions, -1)), 
                                              action_features=transition.next_actions_features, 
                                              task_features=transition.task_features.expand((num_actions, -1, -1, -1)), 
                                              obstacle_features=transition.obstacle_features.expand((num_actions, -1, -1, -1)), return_succ_features=True)
                    
                    next_state_values[j] = _next_state_q_values
                    next_state_binary_features[j] = _next_state_succ_bin.softmax(dim=0)[1]

                else:
                    # done state
                    # backpropagate zero future reward and static future state
                    next_state_values[j] = 0
                    next_state_binary_features[j] = transition.next_binary_features

        # reward loss
        q_target = torch.stack(batch.reward) + gamma * next_state_values
        loss = mse(state_action_values, q_target)

        # additional features loss
        # add_features_target = (1-gamma) * torch.stack(batch.binary_features) + gamma * next_state_binary_features
        # loss += cross_entropy(successor_binary_features, torch.stack([1 - add_features_target, add_features_target], dim=1)) 

        # Optimize the model
        optimizer.zero_grad()
        loss.backward()
        # In-place gradient clipping
        # torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        it.set_postfix(loss=loss.item())
        losses.append(loss.item())

    # Drop the references to the last batch before asking torch to release cached memory.
    # (The former UNet-only branch deleted a variable that never exists in this function.)
    del transitions, batch, state_action_values, next_state_values, next_state_binary_features
    torch.cuda.empty_cache()
    gc.collect()
    return losses
    

In [None]:
# free memory
gc.collect()
torch.cuda.empty_cache()

# initialization for training

BATCH_SIZE = 32
GAMMA = 0.95
EPS_START = 0.5
EPS_END = 0.05
EPS_DECAY = 1000
TAU = 0.01
OPT_STEPS = 25  # opt steps per round
SHOW_PLOTS = False  # if true, there are more debug outputs in the notebook, so don't run it too long it might crash eventually
SEED = 0

# Seed every RNG the training draws from *before* the networks are created:
# - torch drives the weight initialization,
# - Python's `random` drives the epsilon test in select_action and ReplayBuffer.sample,
# - numpy drives the random action index in select_action and the target sampling.
# (Previously only numpy was seeded, so runs were not reproducible.)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


def init_weights(m):
    """Xavier-initialize the weights of Conv2d / Linear modules and set their biases to 0.01."""
    if type(m) in (nn.Conv2d, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)

# UNet: Doesn't really work
# add_skip = False
# policy_net = UNet(4, add_skip=add_skip, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# target_net = UNet(4, add_skip=add_skip, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)


# CNN: Works ok, just goal-conditioned q-learning without successor images
        
# policy_net = ConvNet(4, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# target_net = ConvNet(4, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# a smaller learning rate works better for the CNN:
# optimizer = optim.Adam(policy_net.parameters(), lr=0.0001)

# SuccessorNet: Works well, currently trained to predict successor state images with MSE loss
# Cross entropy didn't work well, but maybe it's worth trying again
hidden_dims = [256, 128, 64, 128, 256]
policy_net = SuccessorNet(img_size=(width, height), hidden_dims=hidden_dims).to(device) # DQN(n_observations, n_actions).to(device)
target_net = SuccessorNet(img_size=(width, height), hidden_dims=hidden_dims).to(device) # DQN(n_observations, n_actions).to(device)
optimizer = optim.Adam(policy_net.parameters(), lr=0.001)


# initialization: the target net starts as an exact copy of the freshly initialized policy net
policy_net.apply(init_weights)
target_net.load_state_dict(policy_net.state_dict())


# scheduler = MultiStepLR(optimizer, milestones=[5000,10000], gamma=0.8)
scheduler = None
memory = ReplayBuffer(2000)


steps_done = 0

# ground placement discretization
x_discr_ground = np.linspace(0.2, 0.8, 3)

# setup environment
env = AssemblyGym(**tower_setup(targets=[(0.5, 0, 0. + 0.05)]), 
                  reward_fct=sparse_reward,
                  restrict_2d=True, 
                  assembly_env=AssemblyEnv(render=False))
state, info = env.reset()

# setup logging
run = Run(experiment="AssemblyDQN", repo='../aim-data/')
run['setup'] = 'tower'

In [None]:
# actual training

# Long run when a GPU is available, short run otherwise.
num_episodes = 2000 if torch.cuda.is_available() else 50

def extract_img(logits_img):
    """Turn the first element of a batch of logits into a NumPy image.

    A softmax is taken over dimension 0 of `logits_img[0]` and the slice at
    index 1 of that dimension is returned as a detached CPU NumPy array.
    """
    first_sample = logits_img[0]
    probabilities = torch.softmax(first_sample, dim=0)
    return probabilities[1].detach().cpu().numpy()


episode_rewards = []
tq = tqdm(range(num_episodes))

target_net.eval()
policy_net.eval()

# Optimisation and prediction plots only start once the episode index exceeds this value;
# earlier episodes just fill the replay buffer.
WARMUP_EPISODES = 10
# Upper bound on the number of environment steps in one episode.
MAX_EPISODE_STEPS = 10


def _to_tensor(array):
    """Convert an array-like into a float32 tensor on `device`."""
    return torch.tensor(array, dtype=torch.float32, device=device)


def _candidate_actions(state_features, obstacle_features):
    """Collect the candidate actions for the current state of `env`.

    Generates the actions, renders one image per action, and passes both
    through `reduce_available_actions`; its result (unpacked by the callers
    as `(actions, action_features)`) is returned unchanged.
    """
    actions = [*generate_actions(env, x_discr_ground=x_discr_ground)]
    features = np.array(action_image_features(env, actions, width=width, height=height))
    features = _to_tensor(features).unsqueeze(1)
    return reduce_available_actions(state_features, obstacle_features, actions, features)


losses = []
for i_episode in tq:
    # Initialize a random task environment and get its state
    tower_height = 0.02 + 0.05 * 2 # np.random.randint(0, 3)
    observation, info = env.reset(**tower_setup(targets=[(np.random.choice(x_discr_ground, 1, replace=False).item(), 0, tower_height)]))

    # initialize features
    task_features = _to_tensor(get_task_features(env, width=width, height=height)).unsqueeze(0)
    obstacle_features = _to_tensor(get_obstacle_features(env, width=width, height=height)).unsqueeze(0)
    state_features = _to_tensor(get_state_features(observation, width=width, height=height)).unsqueeze(0)
    binary_features = _to_tensor(get_binary_features(observation))
    available_actions, action_features = _candidate_actions(state_features, obstacle_features)

    # episode
    episode_reward = 0
    # exponential decay from EPS_START towards EPS_END over the episodes
    epsilon = EPS_END + (EPS_START - EPS_END) * \
            math.exp(-1. * i_episode / EPS_DECAY)

    # every 5th episode is an evaluation round with epsilon forced to 0
    eval_round = (i_episode % 5) == 0
    if eval_round:
        epsilon = 0.

    # logging
    context = {'eval_round' : eval_round}
    run.track(epsilon, name='epsilon', step=i_episode, context=context)
    run.track(tower_height, name='tower_height', step=i_episode, context=context)

    for t in range(MAX_EPISODE_STEPS):

        imax = select_action(model=policy_net,
                             state_features=state_features,
                             binary_features=binary_features,
                             action_features=action_features,
                             task_features=task_features,
                             obstacle_features=obstacle_features,
                             epsilon=epsilon,
                             return_succ_features=False)

        selected_action, selected_action_features = available_actions[imax], action_features[imax]

        # After the warm-up, log state / action / predicted successor / task images for this step
        if i_episode > WARMUP_EPISODES and isinstance(policy_net, (SuccessorNet, UNet)):
            with torch.no_grad():
                fig, axes = plt.subplots(ncols=4, figsize=(5, 10))
                (ax1, ax2, ax3, ax4) = axes
                ax1.imshow(state_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
                ax1.set_title('state')

                ax2.imshow(selected_action_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
                ax2.set_title('action')

                succ_img, succ_bin = policy_net(state_features.unsqueeze(0),
                                                binary_features.unsqueeze(0),
                                                selected_action_features.unsqueeze(0),
                                                task_features.unsqueeze(0),
                                                obstacle_features.unsqueeze(0))
                ax3.imshow(extract_img(succ_img), cmap='gray', vmin=0, vmax=1)
                ax3.set_title('prediction')

                # plot task features
                ax4.imshow(task_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
                ax4.set_title('task')

                for ax in axes:
                    ax.axis('off')
                    ax.invert_yaxis()

                # Snapshot the figure for aim while it is still open, so the logged
                # image never depends on rendering an already closed figure.
                aim_figure = Image(fig)

                if SHOW_PLOTS:
                    plt.show()
                plt.close(fig)

                run.track(aim_figure, name=f"pred_step_{t}", step=i_episode)
                del succ_img, succ_bin

        # environment step
        observation, reward, terminated, truncated, info = env.step(selected_action)

        # overwrite reward: overlap of the chosen action image with the task image,
        # normalised by the sum of the task image
        reward = torch.sum(selected_action_features * task_features) / torch.sum(task_features)
        episode_reward += GAMMA**t * reward
        done = terminated or truncated

        # compute features
        next_state_features = get_state_features(observation, width=width, height=height)
        next_binary_features = get_binary_features(observation)

        next_state_features = _to_tensor(next_state_features).unsqueeze(0)
        next_binary_features = _to_tensor(next_binary_features)

        # compute available_actions and features
        next_available_actions, next_actions_features = _candidate_actions(next_state_features, obstacle_features)

        # Store the transition in memory
        transition = Transition(state_features=state_features,
                                binary_features=binary_features,
                                action_features=selected_action_features,
                                action=selected_action,
                                reward=reward,
                                next_state_features=next_state_features,
                                next_binary_features=next_binary_features,
                                next_available_actions=next_available_actions,
                                next_actions_features=next_actions_features,
                                task_features=task_features,
                                obstacle_features=obstacle_features,
                                done=done)
        memory.push(transition)

        # Move to the next state
        state_features = next_state_features
        binary_features = next_binary_features
        available_actions = next_available_actions
        action_features = next_actions_features

        if done:
            break

    # `t` is the zero-based index of the last executed step, so the step count is t + 1
    num_steps = t + 1

    if i_episode > WARMUP_EPISODES:
        # Perform one round of optimization (on the policy network)
        policy_net.train()
        if isinstance(policy_net, (UNet, SuccessorNet)):
            losses += optimize_successor_net(policy_net=policy_net,
                                             target_net=target_net,
                                             optimizer=optimizer,
                                             scheduler=scheduler,
                                             memory=memory,
                                             gamma=GAMMA,
                                             n_steps=OPT_STEPS,
                                             batch_size=BATCH_SIZE,
                                             verbose=SHOW_PLOTS)
        elif isinstance(policy_net, ConvNet):
            losses += optimize_convnet(policy_net=policy_net,
                                       target_net=target_net,
                                       optimizer=optimizer,
                                       scheduler=scheduler,
                                       memory=memory,
                                       gamma=GAMMA,
                                       n_steps=OPT_STEPS,
                                       batch_size=BATCH_SIZE,
                                       verbose=True)

        policy_net.eval()

        # Soft update of the target network's weights
        # θ′ ← τ θ + (1 −τ )θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key]*TAU + target_net_state_dict[key]*(1-TAU)
        target_net.load_state_dict(target_net_state_dict)

    # free memory
    del state_features, binary_features, action_features, next_state_features, next_binary_features, next_actions_features
    del task_features, obstacle_features
    torch.cuda.empty_cache()
    gc.collect()

    # Read the memory statistics once per episode; report 0 MB of GPU memory on CPU-only machines
    buffer_size_mb = memory.size_in_MB()
    gpu_memory_mb = torch.cuda.memory_allocated(0) / 1024 / 1024 if torch.cuda.is_available() else 0.0

    # logging
    run.track(episode_reward, name='reward', step=i_episode, context=context)
    run.track(num_steps, name='num_steps', step=i_episode, context=context)

    # track memory usage
    run.track(buffer_size_mb, name='replay_buffer_size', step=i_episode, context=context)
    run.track(gpu_memory_mb, name='total_gpu_memory', step=i_episode, context=context)

    tq.set_postfix(reward=episode_reward.cpu().numpy(),
                   steps=num_steps,
                   height=tower_height,
                   epsilon=epsilon,
                   memory=f"{buffer_size_mb:.0f} MB",
                   total_memory=f"{gpu_memory_mb:.0f} MB")
    episode_rewards.append(episode_reward)

# Code to debug the training, e.g. by fitting on a set of transitions

In [None]:
# transitions = memory.sample(3)

# for i, transition in enumerate(transitions):
#     print(transition.reward, transition.done)
# plot state features of each transition
# fig, axes = plt.subplots(nrows=3, ncols=4, figsize=(10, 15))
# for i, transition in enumerate(transitions):
#     (ax1, ax2, ax3, ax4) = axes[i]
#     ax1.imshow(transition.state_features.squeeze().cpu().numpy(), cmap='gray')
#     ax1.set_title('state')

#     ax2.imshow(transition.action_features.squeeze().cpu().numpy(), cmap='gray')
#     ax2.set_title('action')

#     ax3.imshow(transition.next_state_features.squeeze().cpu().numpy(), cmap='gray')
#     ax3.set_title('next state')

#     # ax4.imshow(transition.next_binary_features.cpu().numpy(), cmap='gray')
#     # ax4.set_title('next binary')

#     for ax in axes[i]:
#         ax.axis('off')
#         ax.invert_yaxis()

In [None]:
# Debugging cell: report how many transitions the replay buffer currently holds
# and re-create freshly initialised policy/target networks.
print(len(memory))
# start a new loss history
losses = []

# Disabled alternative: ConvNet
# policy_net = ConvNet(4, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# policy_net.apply(init_weights)

# target_net = ConvNet(4, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# target_net.load_state_dict(policy_net.state_dict())

# Within this cell `add_skip` is only referenced by the disabled UNet variant below.
add_skip=True
# Disabled alternative: UNet
# policy_net = UNet(4, add_skip=add_skip, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# policy_net.apply(init_weights)

# target_net = UNet(4, add_skip=add_skip, img_size=(width, height)).to(device) # DQN(n_observations, n_actions).to(device)
# target_net.load_state_dict(policy_net.state_dict())

# Active choice: SuccessorNet. The policy network is built and re-initialised first ...
hidden_dims = [256, 128, 64, 128, 256]
policy_net = SuccessorNet(img_size=(width, height), hidden_dims=hidden_dims).to(device) # DQN(n_observations, n_actions).to(device)
policy_net.apply(init_weights)

# ... and the target network then receives a copy of the policy network's weights.
target_net = SuccessorNet(img_size=(width, height), hidden_dims=hidden_dims).to(device) # DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

In [None]:
# Fit the policy network on the transitions currently held in `memory`
# with a new Adam optimizer.
optimizer = optim.Adam(policy_net.parameters(), lr=0.0001) 

policy_net.train()
losses += optimize_successor_net(policy_net=policy_net,
                        target_net=target_net,
                        optimizer=optimizer, 
                        scheduler=scheduler,
                        memory=memory,
                        gamma=GAMMA,
                        n_steps=500, 
                        batch_size=32, 
                        verbose=True)
# Return to inference mode once optimisation is done, as the training loop does,
# so that later cells do not query the network in train mode by accident.
policy_net.eval()

In [None]:
plt.plot(losses)

# sample predictions

# never ask the buffer for more transitions than it holds
num_predictions = min(10, len(memory))
transitions = memory.sample(num_predictions)

# Query the network the same way the training loop does: in eval mode and without autograd.
policy_net.eval()

# squeeze=False keeps `axes` two-dimensional even when only a single row is plotted
fig, axes = plt.subplots(nrows=num_predictions, ncols=4, figsize=(4*2, num_predictions*2 ), squeeze=False)
with torch.no_grad():
    for i, transition in enumerate(transitions):
        (ax1, ax2, ax3, ax4) = axes[i]
        # fixed [0, 1] colour range, matching the plots logged during training
        ax1.imshow(transition.state_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax1.set_title('state')

        ax2.imshow(transition.action_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax2.set_title('action')

        succ_img, succ_bin = policy_net(transition.state_features.unsqueeze(0), 
                                        transition.binary_features.unsqueeze(0), 
                                        transition.action_features.unsqueeze(0), 
                                        transition.task_features.unsqueeze(0), 
                                        transition.obstacle_features.unsqueeze(0))
        ax3.imshow(extract_img(succ_img), cmap='gray', vmin=0, vmax=1)
        ax3.set_title('prediction')

        ax4.imshow(transition.task_features.squeeze().cpu().numpy(), cmap='gray', vmin=0, vmax=1)
        ax4.set_title('task')

        for ax in axes[i]:
            ax.axis('off')
            ax.invert_yaxis()

fig.tight_layout()

In [None]:
# import torch
# import torch.nn as nn
# import torch.optim as optim

# # Generate a random target tensor
# num_actions = 1

# x = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
# o = torch.rand(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
# a = torch.rand(num_actions, 1, width, height, device=device)
# z = torch.ones(1, 1, width, height, device=device).expand((num_actions, -1, -1 ,-1))
# y = torch.ones(1, 6, device=device).expand((num_actions, -1))


# target = torch.rand(num_actions, 1, height, width, device=device)
# print(target.shape)
# target = torch.tensor(task, device=device, dtype=torch.float32).view(num_actions, 1, height, width)
# print(target)


# # plot target
# plt.imshow(target[0, 0].detach().cpu().numpy(), cmap='gray')


# # Define the network architecture
# model = UNet(4, add_skip=False, img_size=(width, height)).to(device)
# model.apply(init_weights)

# # Define the loss function
# # criterion = nn.MSELoss()
# criterion = nn.CrossEntropyLoss()

# # Define the optimizer
# # optimizer = optim.SGD(model.parameters(), lr=0.01)

# # use Adam optimizer
# optimizer = optim.Adam(model.parameters(), lr=0.01)

# # Training loop
# num_epochs = 200
# it = tqdm(range(num_epochs))
# for epoch in it:
#     # Forward pass
#     output, _ = model(state=x, features=y, action=a, task=z, obstacles=o)
#     # print(output.shape, target.shape)

#     # Compute the loss
#     loss = criterion(output, target.squeeze(1).long())
    
#     # Backward pass
#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
    
#     # Print the loss for monitoring
#     it.set_postfix(loss=loss.item())

#     if epoch % 100 == 0:
#         # plot
#         img = output[0].softmax(dim=0)[1]

#         fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
#         ax1.imshow(target[0, 0].detach().cpu().numpy(), cmap='gray')
#         ax1.set_title('Target')
#         ax2.imshow(img.detach().cpu().numpy(), cmap='gray')
#         ax2.set_title('Output')
#         plt.show(fig)

#     # (f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")

# # Evaluate the trained network on the target tensor
# # output, _ = model(x, y, a, z, o)
# # evaluation_loss = criterion(output, target.squeeze(1).long())
# # print(f"Evaluation Loss: {evaluation_loss.item()}")



In [None]:
def MLP(in_dim=10, out_dim=10, hidden_dim=[2, 2]):
    """Print the (input, output) size of every layer of an MLP.

    Args:
        in_dim: input size of the first layer.
        out_dim: output size of the last layer.
        hidden_dim: sizes of the hidden layers, in order. Any iterable is
            accepted; it is only read, never modified.
    """
    # Build a new list instead of appending to `hidden_dim`: the default list is
    # created once and shared by all calls, so appending to it would add one more
    # layer on every call (and would also modify a list passed in by the caller).
    layer_dims = [*hidden_dim, out_dim]

    for cur_dim in layer_dims:
        print(f"adding layer {in_dim, cur_dim}")
        in_dim = cur_dim

# Call MLP twice with its default arguments; comparing the two printouts shows
# whether the default `hidden_dim` list carries state from one call to the next.
print("first MLP")
MLP()
print("second MLP")
MLP()