In [72]:
from torch import nn
from bound_propagation import BoundModelFactory, HyperRectangle
from bound_propagation.polynomial import Pow
import torch
import numpy as np

class NNDM(nn.Sequential):
    """Residual MLP over a concatenated 28-feature input producing 24 outputs.

    The MLP (28 -> 64 -> tanh -> 24) output is added to the first 24 input
    columns, so the network predicts an increment on top of those columns.

    NOTE(review): if bound_propagation's factory maps this subclass to its plain
    Sequential bound module, the overridden ``forward`` (the residual add) would
    not be reflected in the computed bounds — confirm against the library.
    """

    def __init__(self):
        layers = (
            nn.Linear(28, 64),
            nn.Tanh(),
            nn.Linear(64, 24),
        )
        super(NNDM, self).__init__(*layers)

    def forward(self, x):
        # x is indexed as (batch, features); only the first 24 columns are carried over.
        delta = super().forward(x)
        carried = x[:, :24]
        return delta + carried

    
class HHead(nn.Sequential):
    """Head that squares each of its 24 inputs and linearly maps them to 2 outputs.

    The weights are currently left at their default initialisation; the intended
    h function is still a TODO.
    """

    def __init__(self):
        square = Pow(2)
        readout = nn.Linear(24, 2)
        super().__init__(square, readout)
        # TODO: implement h function
        # NOTE(review): the sketch below is a 2x4 weight, but the readout is
        # Linear(24, 2) (weight shape 2x24); x_0_max / x_2_max are not defined here.
        # self[1].weight.data = torch.tensor([[-1/x_0_max**2, 0, 0, 0],
        #                                     [0, 0, -1/x_2_max**2, 0]])
        # self[1].bias.data = torch.tensor([1., 1.])

class CombinedModel(nn.Sequential):
    """Sequential chain of ``NNDM`` followed by ``HHead``.

    Children are registered under the names 'nndm' and 'hhead', in that order.
    """

    def __init__(self):
        super(CombinedModel, self).__init__()
        # NNDM is constructed before HHead, as in the sequential add order.
        children = (('nndm', NNDM()), ('hhead', HHead()))
        for child_name, child in children:
            self.add_module(child_name, child)
    

# Compose NNDM -> HHead into one network, then wrap it for bound propagation.
net = CombinedModel()

factory = BoundModelFactory()
boundnet = factory.build(net)

In [73]:
import torch.nn.functional as F

class Actor(nn.Module):
    """Deterministic policy MLP mapping observations to actions squashed into [-1, 1].

    Args:
        n_observations: length of the observation vector (last input dimension).
        n_actions: number of continuous action outputs.
    """

    def __init__(self, n_observations, n_actions):
        super(Actor, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        """Return ``tanh(layer3(relu(layer2(relu(layer1(x))))))``."""
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Previously layer3 was applied a second time to its own n_actions-dim
        # output, which fails with a shape mismatch (and discarded the tanh).
        return torch.tanh(self.layer3(x))
    
class Critic(nn.Module):
    """MLP scoring a concatenated state-action vector with a single scalar.

    Args:
        no_state_actions: length of the concatenated state+action input.
    """

    def __init__(self, no_state_actions):
        super(Critic, self).__init__()
        width = 128
        self.layer1 = nn.Linear(no_state_actions, width)
        self.layer2 = nn.Linear(width, width)
        self.layer3 = nn.Linear(width, 1)

    def forward(self, x):
        hidden = x
        # Two ReLU hidden layers, then an unactivated scalar readout.
        for layer in (self.layer1, self.layer2):
            hidden = F.relu(layer(hidden))
        return self.layer3(hidden)
    
# Sizes match the (1, 24) state and (1, 4) action tensors used later; the
# critic consumes their concatenation (24 + 4 = 28 features).
actor_network = Actor(24, 4)
critic_network = Critic(24 + 4)

# TODO: load trained weights once available, e.g.
# actor_network.load_state_dict(torch.load(PATH_TO_MODEL_WEIGHTS))

In [74]:
def create_action_partitions(env, partitions):
    """Split the environment's box action space into a uniform grid of HyperRectangles.

    Each action dimension ``d`` is cut into ``partitions`` equal-width intervals
    between ``env.action_space.low[d]`` and ``env.action_space.high[d]``. Every
    combination of one interval per dimension becomes one ``HyperRectangle``.

    Args:
        env: environment whose ``action_space`` exposes ``shape``, ``low`` and ``high``.
        partitions: number of intervals per action dimension (must be >= 1).

    Returns:
        list of ``partitions ** num_actions`` HyperRectangles. The first action
        dimension varies slowest and the last varies fastest.

    Raises:
        ValueError: if ``partitions`` is smaller than 1.
    """
    import itertools

    if partitions < 1:
        raise ValueError(f"partitions must be >= 1, got {partitions}")

    action_space = env.action_space
    num_actions = action_space.shape[0]
    action_low = np.asarray(action_space.low)
    action_high = np.asarray(action_space.high)
    dtype = action_low.dtype

    # Per-dimension interval edges. linspace pins the last edge exactly at
    # `high` instead of accumulating rounding error from repeated additions.
    edges = [
        np.linspace(action_low[d], action_high[d], partitions + 1, dtype=dtype)
        for d in range(num_actions)
    ]

    res = []
    # product() varies the last index fastest, matching nested per-dimension loops.
    for cell in itertools.product(range(partitions), repeat=num_actions):
        lower = np.array([edges[d][i] for d, i in enumerate(cell)], dtype=dtype)
        upper = np.array([edges[d][i + 1] for d, i in enumerate(cell)], dtype=dtype)
        res.append(HyperRectangle(lower, upper))

    return res

In [75]:
def get_lower_bound(model, state, action, epsilon):
    """Return the first element of ``model.crown(...).lower`` around (state, action).

    ``state`` and ``action`` are concatenated along dim 1 and flattened into a
    single row. An input box is built with ``HyperRectangle.from_eps``
    (presumably that row +/- ``epsilon``), and ``model.crown`` is evaluated on it.

    Args:
        model: a bound-propagation model exposing ``crown``.
        state: state tensor concatenated first along dim 1.
        action: action tensor concatenated after ``state`` along dim 1.
        epsilon: radius passed to ``HyperRectangle.from_eps``.

    Returns:
        ``crown_bounds.lower[0]`` reshaped to ``(-1, n_inputs)``, where ``n_inputs``
        is the length of the concatenated row.
    """
    joint = torch.cat((state, action), dim=1).view(1, -1)
    n_inputs = joint.shape[1]
    box = HyperRectangle.from_eps(joint, epsilon)
    # NOTE(review): presumably `lower` is (A, b) for a linear bound A x + b;
    # only A is returned and b (lower[1]) is discarded — confirm callers account for it.
    coefficients = model.crown(box).lower[0]
    return coefficients.view(-1, n_inputs)

In [76]:
def create_bound_matrices(partitions, state, env, model=None, epsilon=0.1):
    """Compute per-partition lower-bound coefficients, split into action and state parts.

    For each partition, the partition center is used as the action.
    ``get_lower_bound`` is evaluated at (state, center). The resulting coefficient
    rows are split into the trailing ``num_actions`` columns (action part) and the
    leading columns (state part). The state part is then contracted with ``state``.

    Args:
        partitions: iterable of HyperRectangles exposing ``center``.
        state: state tensor, matrix-multiplied as ``state.t()``.
        env: environment whose ``action_space.shape[0]`` gives the action count.
        model: bound model to query. Defaults to the module-level ``boundnet``
            (the previous hard-wired behaviour).
        epsilon: radius passed to ``get_lower_bound`` (previously fixed at 0.1).

    Returns:
        list of ``(partition, action_coeffs, state_term)`` tuples of numpy arrays.
    """
    if model is None:
        # Resolved at call time so existing callers keep using the global bound net.
        model = boundnet

    num_actions = env.action_space.shape[0]
    res = []
    for partition in partitions:
        # as_tensor accepts either a numpy array or an existing tensor center.
        action = torch.as_tensor(partition.center).view(1, -1)
        # NOTE(review): the box around the center has radius `epsilon`, which may
        # not cover the whole partition — confirm this is intended.
        bounds = get_lower_bound(model, state, action, epsilon)
        # Explicit split index: `[:, -num_actions:]` would select every column
        # when num_actions == 0.
        n_state = bounds.shape[1] - num_actions
        action_bounds = bounds[:, n_state:]
        state_bounds = bounds[:, :n_state]
        state_vec = state_bounds @ state.to(state_bounds.dtype).t()
        res.append((partition, action_bounds.detach().numpy(), state_vec.detach().numpy()))

    return res

In [77]:
import gymnasium as gym


# Seed torch so the random probe state/action are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

env = gym.make("BipedalWalker-v3")

# Random probe inputs; sizes match the 24-dim state / 4-dim action the networks expect.
state = torch.rand(1, 24)
action = torch.rand(1, 4)

# 2 intervals per action dimension -> 2 ** num_actions partitions.
partitions = create_action_partitions(env, 2)
bound_matrices = create_bound_matrices(partitions, state, env)

In [78]:
class InfeasibilityError(Exception):
    """Raised when no action satisfies the safety criteria.

    Attributes:
        message: human-readable explanation, also used as the exception's str().
    """

    def __init__(self, message="No safe action to take"):
        super().__init__(message)
        # Kept on the instance for callers that read `err.message` directly.
        self.message = message
