<a href="https://colab.research.google.com/github/supsi-dacd-isaac/TeachDecisionMakingUncertainty/blob/main/L09/Robust_MPC_the_need_of_feedback.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Robust MPC - the need of feedback


The following cell provides an intuitive example of why the control action of a robust controller can't be a single scalar fixed in advance. That is, the control action must either be a **feedback law**, or we **should consider more than one control action** for each future step, to be able to counteract all the possible adversarial events.

Suppose the system dynamics are a simple integrator, where the noise $w$ at each step can be either -0.5 or 0.5 with unknown probability, and we want to be sure the state stays within the interval $[-1, 1]$. The robust problem can be written as:

\begin{aligned}
\min_{U} \color{red}{\max_w} \
 &s_T^{\top} P s_T + \sum_{t=0}^{T-1}\left[s_t^{\top} Q s_t+u_t^{\top} R u_t\right]\\
&\text{s.t.} \quad s_{t+1}=s_t+u_t +\color{red}{w_t} \qquad \forall \  \color{red}{w_t \in \{-0.5, 0.5\}}\\
& \qquad \qquad s_{t=0} = s_0 \\
& \qquad \qquad s_t \in [-1, 1] \quad \forall t
\end{aligned}

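In this case it is impossible to counteract all the possible disturbances with a simple open-loop controller. The inputs $u_t$ are fixed in advance, while after $n$ steps the accumulated disturbance $\sum_t w_t$ ranges from $-0.5\,n$ to $0.5\,n$. Once this spread is wider than the admissible interval (i.e. for $n \ge 3$), no fixed input sequence keeps the state in $[-1, 1]$ for every realisation. A feedback law such as $u_t = -s_t$, instead, gives $s_{t+1} = w_t \in \{-0.5, 0.5\}$ and is always feasible.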
The following code generates a disturbance tree encoding all the possible future paths of $w_t$. The function `find_violating_sequence` then acts as the adversary: given your control sequence, it searches the tree for a disturbance sequence that drives the state out of bounds. Try to see if you can find a feasible control!

In [3]:
import itertools
import numpy as np

# Problem parameters
K = 5                    # number of control steps u_0 .. u_{K-1} (simulation horizon)
x_init = 0.0             # initial state s_0
x_bounds = (-1, +1)      # admissible state interval (lower, upper)
w_values = [-0.5, +0.5]  # disturbance values available at each step

def generate_disturbance_tree(n_levels=None, values=None):
    """Enumerate all disturbance histories, level by level.

    Level ``step`` lists every tuple of ``step`` disturbances drawn from
    ``values`` (in ``itertools.product`` order), so it holds
    ``len(values) ** step`` entries. Level 0 is the single empty history,
    i.e. the root of the tree.

    Args:
        n_levels: number of levels to build, root included. Defaults to the
            module-level horizon ``K`` (read at call time).
        values: possible disturbance values at each step. Defaults to the
            module-level ``w_values`` (read at call time).

    Returns:
        A list of ``max(n_levels, 1)`` lists of tuples.
    """
    if n_levels is None:
        n_levels = K
    if values is None:
        values = w_values

    tree = [[()]]  # Empty path for step 0 (the root)
    tree.extend(
        list(itertools.product(values, repeat=depth)) for depth in range(1, n_levels)
    )
    return tree

# Build the tree once for the global parameters. Its last level (index K - 1)
# holds every disturbance history w_1 .. w_{K-1} scanned by
# find_violating_sequence; plot_disturbance_tree draws all of its levels.
disturbance_tree = generate_disturbance_tree()

def simulate(x0, u_seq, w_seq):
    """Roll out the integrator s_{k+1} = s_k + u_k + w_k for K steps.

    Args:
        x0: initial state.
        u_seq: control inputs; entries 0 .. K-1 are read.
        w_seq: disturbances; entries 0 .. K-1 are read.

    Returns:
        List of the K + 1 visited states, starting with ``x0``.
    """
    state = x0
    trajectory = [state]
    for step in range(K):
        state = state + u_seq[step] + w_seq[step]
        trajectory.append(state)
    return trajectory

def find_violating_sequence(u_seq, tol=1e-9):
    """Search the disturbance tree for a sequence that drives the state out of bounds.

    Each leaf of ``disturbance_tree`` (a tuple w_1 .. w_{K-1}) is prefixed
    with a zero disturbance for step 0 and simulated from ``x_init`` under the
    open-loop controls ``u_seq``. The first sequence, in tree order, whose
    trajectory leaves ``[x_bounds[0], x_bounds[1]]`` is returned. This is *a*
    violating sequence, not necessarily the one with the largest violation.

    Args:
        u_seq: open-loop control inputs (K values are read by ``simulate``).
        tol: slack allowed outside the bounds. Floating-point round-off when
            summing slider values can push a state that should sit exactly on
            a bound marginally past it; such a state is not reported as a
            violation.

    Returns:
        ``(w_full, x_traj)``: the K disturbances and the K + 1 visited states.
        If no sequence violates the bounds, the zero-disturbance sequence and
        its trajectory are returned.
    """
    # Use the configured bounds rather than a hard-coded |x| <= 1.
    lower, upper = x_bounds
    for w_tail in disturbance_tree[K - 1]:
        # The tree root carries no disturbance, so step 0 is disturbance-free.
        w_full = [0.0] + list(w_tail)
        x_traj = simulate(x_init, u_seq, w_full)
        if any(x < lower - tol or x > upper + tol for x in x_traj):
            return w_full, x_traj
    no_disturbance = [0.0] * K
    return no_disturbance, simulate(x_init, u_seq, no_disturbance)

In [None]:
#@title Plotting Adversarial Actions
# --- Interactive plot: set the controls with the sliders and see how the adversary responds ---

import matplotlib.pyplot as plt
from ipywidgets import FloatSlider, VBox, HBox, interactive_output
import networkx as nx
import matplotlib
# Render all figure text with the DejaVu Sans font family.
matplotlib.rcParams.update({'font.family': 'DejaVu Sans'})


def node_id(step, path):
    """Return the unique string key of a disturbance-tree node.

    The key is ``"<step>-<w_1>_<w_2>_..."`` with every disturbance formatted
    to one decimal. The root (step 0) is always ``"0-root"``, whatever
    ``path`` is.
    """
    if step != 0:
        history = "_".join(format(w, ".1f") for w in path)
        return "{}-{}".format(step, history)
    return "{}-root".format(step)

def plot_violation(u0, u1, u2, u3, u4):
    """Plot the adversary's response to the open-loop controls u0 .. u4.

    The left panel shows the following for each time step, together with
    the state bounds:
    - the control inputs;
    - the disturbance sequence returned by ``find_violating_sequence``;
    - the resulting state trajectory.

    The right panel draws the disturbance tree with that sequence
    highlighted.
    """
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec

    controls = [u0, u1, u2, u3, u4]
    disturbances, states = find_violating_sequence(controls)
    t = np.arange(K + 1)
    t_inputs = t[:-1]
    bar_width = 0.25

    fig = plt.figure(figsize=(14, 5))
    layout = GridSpec(1, 2, width_ratios=[1.4, 1], figure=fig)

    # Left panel: inputs, disturbances and states side by side at each step.
    control_ax = fig.add_subplot(layout[0, 0])
    control_ax.set_title("Control 😇")
    bar_series = (
        (t_inputs - bar_width, controls, 'skyblue', '$u_k$'),
        (t_inputs, disturbances, 'lightcoral', '$w_k$'),
        (t + bar_width, states, 'lightgreen', '$x_k$'),
    )
    for positions, heights, colour, label in bar_series:
        control_ax.bar(positions, heights, width=bar_width, color=colour, label=label)

    # Dashed red lines mark the admissible state interval (upper, then lower).
    for bound in (x_bounds[1], x_bounds[0]):
        control_ax.axhline(bound, color='red', linestyle='--')

    control_ax.set_xticks(t)
    control_ax.set_xlabel('Time Step')
    control_ax.set_ylim(-2, 2)
    control_ax.legend()
    control_ax.grid(True)

    # Right panel: the tree, highlighting the disturbance path (step-0 entry dropped).
    tree_ax = fig.add_subplot(layout[0, 1])
    plot_disturbance_tree(disturbances[1:], tree_ax)
    # Set the title afterwards so it replaces the one plot_disturbance_tree sets.
    tree_ax.set_title("Worst-case Disturbance 😈")

    plt.tight_layout()
    plt.show()



def plot_disturbance_tree(selected_path, ax):
    """Draw the disturbance tree on ``ax`` and highlight one path in red.

    Nodes are placed with the time step on the x-axis. Each level is spread
    vertically over the height of the leaf level and centred on 0. Every node
    is labelled with the last disturbance of its history; the root is
    labelled "0.0".

    Args:
        selected_path: disturbances w_1, ..., w_{K-1} of the path to highlight.
            This is the disturbance sequence WITHOUT its leading step-0 entry
            (``plot_violation`` passes ``w_seq[1:]``). Prefixes that do not
            match a tree node simply leave nothing highlighted at that step.
        ax: matplotlib Axes to draw into. Its title is set to
            "Disturbance Tree" and its axes are turned off.
    """
    import networkx as nx

    G = nx.DiGraph()
    node_labels = {}
    node_positions = {}

    # Vertical extent = number of leaves. Derived from the tree instead of
    # 2 ** (K - 1), so it also holds for more than two disturbance values.
    max_width = len(disturbance_tree[K - 1])

    for step in range(K):
        level = disturbance_tree[step]
        n_nodes = len(level)
        vertical_spacing = max_width / n_nodes
        y_center_offset = (n_nodes - 1) / 2

        for i, path in enumerate(level):
            nid = node_id(step, path)
            G.add_node(nid)
            node_labels[nid] = f"{path[-1]:.1f}" if step > 0 else "0.0"
            node_positions[nid] = (step, (i - y_center_offset) * vertical_spacing)

            if step > 0:
                # The parent's history is this history minus its last disturbance.
                G.add_edge(node_id(step - 1, path[:-1]), nid)

    # Ids of the highlighted nodes: the root plus every prefix of selected_path
    # (selected_path[:step] is the history w_1 .. w_step of the step-th node).
    selected_ids = {node_id(0, ())}
    selected_ids.update(
        node_id(step, tuple(selected_path[:step])) for step in range(1, K)
    )
    node_colors = ["red" if n in selected_ids else "lightblue" for n in G.nodes]

    nx.draw(G, pos=node_positions, ax=ax, with_labels=False, node_color=node_colors, node_size=500)
    nx.draw_networkx_labels(G, pos=node_positions, labels=node_labels, font_size=8, ax=ax)
    ax.set_title("Disturbance Tree")
    ax.axis('off')


# One slider per control input u0 .. u{K-1}, each ranging over [-1, 1] in steps of 0.1.
# NOTE: plot_violation takes exactly u0 .. u4, so this layout assumes K == 5.
sliders = dict(
    (f'u{k}', FloatSlider(min=-1, max=1, step=0.1, value=0, description=f'u{k}'))
    for k in range(K)
)
slider_box = VBox(list(sliders.values()))

# Output widget for plot_violation, fed the slider values keyed by its parameter names.
out = interactive_output(plot_violation, sliders)

# Sliders on the left, plot output on the right.
display(HBox([slider_box, out]))
