<a href="https://colab.research.google.com/github/udlbook/udlbook/blob/main/Blogs/BorealisGradientFlow.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gradient flow

This notebook replicates some of the results in the Borealis AI [blog](https://www.borealisai.com/research-blogs/gradient-flow/) on gradient flow.  


In [None]:
# Import relevant libraries
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import expm
from matplotlib import cm
from matplotlib.colors import ListedColormap

Create the three data points that are used to train the linear model in the blog.  Each input point is a column in $\mathbf{X}$ and consists of the $x$ position in the plot and the value 1, which is used to allow the model to fit bias terms neatly.

In [None]:
# Training inputs: one example per column; row 0 is the x position, row 1 is the constant 1
X = np.vstack(([0.2, 0.4, 0.8], np.ones(3)))
# Training targets stored as a column vector (one row per example)
y = np.array([-0.1, 0.15, 0.3]).reshape(-1, 1)
# D is the number of rows of X (parameters), I the number of columns (training examples)
D, I = X.shape

print("X=\n", X)
print("y=\n", y)

In [None]:
# Scatter the three training points (x position against target value)
fig, ax = plt.subplots()
ax.plot(X[0:1, :], y.T, "ro")
# Fix the viewport and label both axes in a single call
ax.set(xlim=[0, 1], ylim=[-0.5, 0.5], xlabel="x", ylabel="y")
plt.show()

Compute the evolution of the residuals, loss, and parameters as a function of time.

In [None]:
# Discretized time to evaluate quantities at
t_all = np.arange(0, 20, 0.01)
nT = t_all.shape[0]

# Initial parameters, and initial function output at training points
phi_0 = np.array([[-0.05], [-0.4]])
f_0 = X.T @ phi_0

# Precompute the pseudoinverse term (X X^T)^{-1} X by solving the linear system
# (X X^T) Z = X instead of forming an explicit matrix inverse
XXTInvX = np.linalg.solve(X @ X.T, X)

# Quantities that do not depend on time are computed once, outside the loop
XTX = X.T @ X
initial_residual = f_0 - y
identity_I = np.identity(I)  # sized from the data rather than a hard-coded 3

# Create arrays to hold function at data points over time, residual over time, parameters over time
f_all = np.zeros((I, nT))
f_minus_y_all = np.zeros((I, nT))
phi_t_all = np.zeros((D, nT))

# For each time, compute function, residual, and parameters at that time.
for t, t_now in enumerate(t_all):
    # The matrix exponential is the expensive step, so evaluate it once per time
    # and reuse it for both the function values and the parameters
    decay = expm(-XTX * t_now)
    f = y + decay @ initial_residual
    f_all[:, t : t + 1] = f
    f_minus_y_all[:, t : t + 1] = f - y
    phi_t_all[:, t : t + 1] = phi_0 - XXTInvX @ (identity_I - decay) @ initial_residual

Plot the results that were calculated in the previous cell

In [None]:
# NOTE: every label containing LaTeX is a raw string so that "\phi" is passed to
# matplotlib verbatim instead of being parsed as an (invalid) "\p" escape sequence.

# Plot function at data points
fig, ax = plt.subplots()
ax.plot(t_all, f_all[0, :], "r-", label=r"$f[x_{0},\phi]$")
ax.plot(t_all, f_all[1, :], "g-", label=r"$f[x_{1},\phi]$")
ax.plot(t_all, f_all[2, :], "b-", label=r"$f[x_{2},\phi]$")
ax.set_xlim([0, np.max(t_all)])
ax.set_ylim([-0.5, 0.5])
ax.set_xlabel("t")
ax.set_ylabel("f")
plt.legend(loc="lower right")
plt.show()

# Plot residual
fig, ax = plt.subplots()
ax.plot(t_all, f_minus_y_all[0, :], "r-", label=r"$f[x_{0},\phi]-y_{0}$")
ax.plot(t_all, f_minus_y_all[1, :], "g-", label=r"$f[x_{1},\phi]-y_{1}$")
ax.plot(t_all, f_minus_y_all[2, :], "b-", label=r"$f[x_{2},\phi]-y_{2}$")
ax.set_xlim([0, np.max(t_all)])
ax.set_ylim([-0.5, 0.5])
ax.set_xlabel("t")
ax.set_ylabel("f-y")
plt.legend(loc="lower right")
plt.show()

# Plot loss (half the sum of squared residuals at each time)
fig, ax = plt.subplots()
square_error = 0.5 * np.sum(f_minus_y_all * f_minus_y_all, axis=0)
ax.plot(t_all, square_error, "k-")
ax.set_xlim([0, np.max(t_all)])
ax.set_ylim([-0.0, 0.25])
ax.set_xlabel("t")
ax.set_ylabel("Loss")
plt.show()

# Plot parameters
fig, ax = plt.subplots()
ax.plot(t_all, phi_t_all[0, :], "c-", label=r"$\phi_{0}$")
ax.plot(t_all, phi_t_all[1, :], "m-", label=r"$\phi_{1}$")
ax.set_xlim([0, np.max(t_all)])
ax.set_ylim([-1, 1])
ax.set_xlabel("t")
ax.set_ylabel(r"$\phi$")
plt.legend(loc="lower right")
plt.show()

Define the model and the loss function

In [None]:
# Straight-line model: phi[0] is the intercept and phi[1] is the slope
def model(phi, x):
    """Evaluate the line phi[0] + phi[1] * x at the input(s) x."""
    return phi[0] + phi[1] * x


# Least-squares loss: half the sum of squared residuals over the training data
def compute_loss(data_x, data_y, model, phi):
    """Return 0.5 * sum((model(phi, data_x) - data_y) ** 2)."""
    residual = model(phi, data_x) - data_y
    return 0.5 * np.sum(residual * residual)

Draw the loss function

In [None]:
def draw_loss_function(compute_loss, X, y, model, phi_iters):
    """Draw the loss surface over (intercept, slope) and overlay a parameter path.

    The loss is evaluated on a regular grid covering [-1, 1) in both intercept
    and slope (step 0.005) and shown as a filled contour plot with contour
    lines on top. The path in ``phi_iters`` is then drawn as a green line.

    Args:
        compute_loss: Callable invoked as ``compute_loss(X, y, model, phi)``
            where ``phi`` is a 2x1 array ``[[intercept], [slope]]``; its return
            value is stored as the loss at that grid point.
        X: Input data, forwarded unchanged to ``compute_loss``.
        y: Target data, forwarded unchanged to ``compute_loss``.
        model: Model callable, forwarded unchanged to ``compute_loss``.
        phi_iters: 2-D array of parameters over time; row 1 is plotted on the
            horizontal (intercept) axis and row 0 on the vertical (slope) axis.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Define pretty colormap
    # (each entry is a 24-bit colour written as a hexadecimal "rrggbb" string)
    my_colormap_vals_hex = (
        "2a0902",
        "2b0a03",
        "2c0b04",
        "2d0c05",
        "2e0c06",
        "2f0d07",
        "300d08",
        "310e09",
        "320f0a",
        "330f0b",
        "34100b",
        "35110c",
        "36110d",
        "37120e",
        "38120f",
        "39130f",
        "3a1410",
        "3b1411",
        "3c1511",
        "3d1612",
        "3e1613",
        "3f1713",
        "401714",
        "411814",
        "421915",
        "431915",
        "451a16",
        "461b16",
        "471b17",
        "481c17",
        "491d18",
        "4a1d18",
        "4b1e19",
        "4c1f19",
        "4d1f1a",
        "4e201b",
        "50211b",
        "51211c",
        "52221c",
        "53231d",
        "54231d",
        "55241e",
        "56251e",
        "57261f",
        "58261f",
        "592720",
        "5b2821",
        "5c2821",
        "5d2922",
        "5e2a22",
        "5f2b23",
        "602b23",
        "612c24",
        "622d25",
        "632e25",
        "652e26",
        "662f26",
        "673027",
        "683027",
        "693128",
        "6a3229",
        "6b3329",
        "6c342a",
        "6d342a",
        "6f352b",
        "70362c",
        "71372c",
        "72372d",
        "73382e",
        "74392e",
        "753a2f",
        "763a2f",
        "773b30",
        "783c31",
        "7a3d31",
        "7b3e32",
        "7c3e33",
        "7d3f33",
        "7e4034",
        "7f4134",
        "804235",
        "814236",
        "824336",
        "834437",
        "854538",
        "864638",
        "874739",
        "88473a",
        "89483a",
        "8a493b",
        "8b4a3c",
        "8c4b3c",
        "8d4c3d",
        "8e4c3e",
        "8f4d3f",
        "904e3f",
        "924f40",
        "935041",
        "945141",
        "955242",
        "965343",
        "975343",
        "985444",
        "995545",
        "9a5646",
        "9b5746",
        "9c5847",
        "9d5948",
        "9e5a49",
        "9f5a49",
        "a05b4a",
        "a15c4b",
        "a35d4b",
        "a45e4c",
        "a55f4d",
        "a6604e",
        "a7614e",
        "a8624f",
        "a96350",
        "aa6451",
        "ab6552",
        "ac6552",
        "ad6653",
        "ae6754",
        "af6855",
        "b06955",
        "b16a56",
        "b26b57",
        "b36c58",
        "b46d59",
        "b56e59",
        "b66f5a",
        "b7705b",
        "b8715c",
        "b9725d",
        "ba735d",
        "bb745e",
        "bc755f",
        "bd7660",
        "be7761",
        "bf7862",
        "c07962",
        "c17a63",
        "c27b64",
        "c27c65",
        "c37d66",
        "c47e67",
        "c57f68",
        "c68068",
        "c78169",
        "c8826a",
        "c9836b",
        "ca846c",
        "cb856d",
        "cc866e",
        "cd876f",
        "ce886f",
        "ce8970",
        "cf8a71",
        "d08b72",
        "d18c73",
        "d28d74",
        "d38e75",
        "d48f76",
        "d59077",
        "d59178",
        "d69279",
        "d7937a",
        "d8957b",
        "d9967b",
        "da977c",
        "da987d",
        "db997e",
        "dc9a7f",
        "dd9b80",
        "de9c81",
        "de9d82",
        "df9e83",
        "e09f84",
        "e1a185",
        "e2a286",
        "e2a387",
        "e3a488",
        "e4a589",
        "e5a68a",
        "e5a78b",
        "e6a88c",
        "e7aa8d",
        "e7ab8e",
        "e8ac8f",
        "e9ad90",
        "eaae91",
        "eaaf92",
        "ebb093",
        "ecb295",
        "ecb396",
        "edb497",
        "eeb598",
        "eeb699",
        "efb79a",
        "efb99b",
        "f0ba9c",
        "f1bb9d",
        "f1bc9e",
        "f2bd9f",
        "f2bfa1",
        "f3c0a2",
        "f3c1a3",
        "f4c2a4",
        "f5c3a5",
        "f5c5a6",
        "f6c6a7",
        "f6c7a8",
        "f7c8aa",
        "f7c9ab",
        "f8cbac",
        "f8ccad",
        "f8cdae",
        "f9ceb0",
        "f9d0b1",
        "fad1b2",
        "fad2b3",
        "fbd3b4",
        "fbd5b6",
        "fbd6b7",
        "fcd7b8",
        "fcd8b9",
        "fcdaba",
        "fddbbc",
        "fddcbd",
        "fddebe",
        "fddfbf",
        "fee0c1",
        "fee1c2",
        "fee3c3",
        "fee4c5",
        "ffe5c6",
        "ffe7c7",
        "ffe8c9",
        "ffe9ca",
        "ffebcb",
        "ffeccd",
        "ffedce",
        "ffefcf",
        "fff0d1",
        "fff2d2",
        "fff3d3",
        "fff4d5",
        "fff6d6",
        "fff7d8",
        "fff8d9",
        "fffada",
        "fffbdc",
        "fffcdd",
        "fffedf",
        "ffffe0",
    )
    # Parse each hex string into a single integer
    my_colormap_vals_dec = np.array(
        [int(element, base=16) for element in my_colormap_vals_hex]
    )
    # Split each integer into its red, green and blue bytes (values 0..255)
    r = np.floor(my_colormap_vals_dec / (256 * 256))
    g = np.floor((my_colormap_vals_dec - r * 256 * 256) / 256)
    b = np.floor(my_colormap_vals_dec - r * 256 * 256 - g * 256)
    # Stack the channels into one (r, g, b) row per colour, scaled to [0, 1]
    my_colormap = ListedColormap(np.vstack((r, g, b)).transpose() / 255.0)

    # Make grid of intercept/slope values to plot
    intercepts_mesh, slopes_mesh = np.meshgrid(
        np.arange(-1.0, 1.0, 0.005), np.arange(-1.0, 1.0, 0.005)
    )
    loss_mesh = np.zeros_like(slopes_mesh)
    # Compute loss for every set of parameters
    # (one compute_loss call per grid point, with phi = [[intercept], [slope]])
    for idslope, slope in np.ndenumerate(slopes_mesh):
        loss_mesh[idslope] = compute_loss(
            X, y, model, np.array([[intercepts_mesh[idslope]], [slope]])
        )

    fig, ax = plt.subplots()
    fig.set_size_inches(8, 8)
    # Filled contours for the surface, then 40 semi-transparent grey contour lines on top
    ax.contourf(intercepts_mesh, slopes_mesh, loss_mesh, 256, cmap=my_colormap)
    ax.contour(intercepts_mesh, slopes_mesh, loss_mesh, 40, colors=["#80808080"])
    # The y-limits are given as [1, -1], so the slope axis is drawn inverted
    ax.set_ylim([1, -1])
    ax.set_xlim([-1, 1])

    # Overlay the parameter path: row 1 of phi_iters against row 0
    ax.plot(phi_iters[1, :], phi_iters[0, :], "g-")
    ax.set_xlabel("Intercept")
    ax.set_ylabel("Slope")
    plt.show()

In [None]:
# Show the loss surface with the parameter path phi_t_all overlaid; only the
# x-position row of X is passed, together with the targets as a row vector
draw_loss_function(
    compute_loss=compute_loss,
    X=X[0:1, :],
    y=y.T,
    model=model,
    phi_iters=phi_t_all,
)

Draw the evolution of the function

In [None]:
fig, ax = plt.subplots()
ax.plot(X[0:1, :], y.T, "ro")
x_vals = np.arange(0, 1, 0.001)

# One entry per snapshot of the fitted line: (column of phi_t_all, line style, legend label)
snapshots = (
    (0, "r-", "t=0.00"),
    (10, "g-", "t=0.10"),
    (30, "b-", "t=0.30"),
    (200, "c-", "t=2.00"),
    (1999, "y-", "t=20.0"),
)
for column, line_style, legend_label in snapshots:
    # Row 0 of phi_t_all multiplies x and row 1 is the additive offset
    ax.plot(
        x_vals,
        phi_t_all[0, column] * x_vals + phi_t_all[1, column],
        line_style,
        label=legend_label,
    )

ax.set(xlim=[0, 1], ylim=[-0.5, 0.5], xlabel="x", ylabel="y")
plt.legend(loc="upper left")
plt.show()

In [None]:
# ML solution: (X X^T)^{-1} X y
MLParams = np.linalg.inv(X @ X.T) @ X @ y

# MAP solution: same form, with (sigma_sq / sigma_sq_p) times the identity added to X X^T
sigma_sq_p = 3.0
sigma_sq = 0.05
regulariser = np.identity(X.shape[0]) * sigma_sq / sigma_sq_p
MAPParams = np.linalg.inv(X @ X.T + regulariser) @ X @ y

Finally, we predict both the mean and the uncertainty in the fitted model as a function of time

In [None]:
# Define x positions to make predictions (appending a 1 to each column)
x_predict = np.arange(0, 1, 0.01)[None, :]
x_predict = np.concatenate((x_predict, np.ones_like(x_predict)))
nX = x_predict.shape[1]

# Create variables to store evolution of mean and variance of prediction over time
predict_mean_all = np.zeros((nT, nX))
predict_var_all = np.zeros((nT, nX))

# Initial covariance (sized from the number of rows of X rather than a hard-coded 2)
sigma_sq_p = 2.0
cov_init = sigma_sq_p * np.identity(X.shape[0])

# Quantities that do not depend on time are computed once, outside the loop.
# The identity is sized from the number of training points rather than a hard-coded 3.
XTX = X.T @ X
identity_I = np.identity(X.shape[1])
x_predict_proj = x_predict.T @ XXTInvX

# Run through each time computing a and b and hence mean and variance of prediction
for t, t_now in enumerate(t_all):
    # Evaluate the matrix exponential once per time and share it between a and b
    one_minus_decay = identity_I - expm(-XTX * t_now)
    a = x_predict.T @ (XXTInvX @ one_minus_decay @ y)
    b = x_predict.T - x_predict_proj @ one_minus_decay @ X.T
    predict_mean_all[t : t + 1, :] = a.T
    predict_cov = b @ cov_init @ b.T
    # We just want the diagonal of the covariance to plot the uncertainty
    predict_var_all[t : t + 1, :] = np.reshape(np.diag(predict_cov), (1, nX))

Plot the mean and variance at various times

In [None]:
def plot_mean_var(
    X, y, x_predict, predict_mean_all, predict_var_all, this_t, sigma_sq=0.00001
):
    """Plot the predicted mean with a +/- one standard deviation band at one time.

    Args:
        X: Training inputs; row 0 is plotted as the x positions of the data.
        y: Training targets, plotted (transposed) against row 0 of X.
        x_predict: Prediction inputs; row 0 gives the x positions of the curve.
        predict_mean_all: 2-D array whose row ``this_t`` is the predicted mean
            at each prediction position.
        predict_var_all: 2-D array whose row ``this_t`` is the predicted
            variance at each prediction position.
        this_t: Row index (time step) to plot.
        sigma_sq: Extra variance added to the predicted variance before the
            square root is taken to form the band.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    x_positions = x_predict[0, :]
    mean = predict_mean_all[this_t, :]
    # Variances add, so sigma_sq is added directly (not its square root) before
    # taking the square root. Tiny negative values from round-off are clamped
    # to zero so the square root stays real.
    total_var = np.maximum(predict_var_all[this_t, :] + sigma_sq, 0.0)
    std = np.sqrt(total_var)

    fig, ax = plt.subplots()
    ax.plot(X[0:1, :], y.T, "ro")
    ax.plot(x_positions, mean, "r-")
    ax.fill_between(x_positions, mean - std, mean + std, color="lightgray")
    ax.set_xlim([0, 1])
    ax.set_ylim([-0.5, 0.5])
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    plt.show()


# Show the prediction and its uncertainty band at a selection of time steps
for time_index in (0, 40, 80, 200, 500, 1000):
    plot_mean_var(
        X, y, x_predict, predict_mean_all, predict_var_all, this_t=time_index
    )