# Backpropagation on GRU

In [1]:
import numpy as np
from numpy import tanh

In [2]:

# Define activation functions
def sigmoid(x):
    """Logistic sigmoid, 1 / (1 + exp(-x)), applied element-wise.

    Accepts anything `np.exp` accepts after negation (Python scalars or
    NumPy arrays) and returns a value of the matching shape.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator


def dsigmoid(x):
    """First derivative of the sigmoid, evaluated at the pre-activation `x`.

    Uses the identity sigma'(x) = sigma(x) * (1 - sigma(x)).
    """
    s = sigmoid(x)
    return s * (1 - s)


def dtanh(x):
    """First derivative of tanh, evaluated at the pre-activation `x`.

    Uses the identity tanh'(x) = 1 - tanh(x)^2.
    """
    activation = np.tanh(x)
    return 1 - activation**2


## forward pass

$$
h_t = (1-z_t) * h_{t-1} + z_t * \tilde{h}_t
$$
$$
\tilde{h}_t = g\left( W_h x_t + U_h (r_t * h_{t-1}) \right)
$$

and for the gates

$$
z_t = \sigma \left( W_z x_t + U_z h_{t-1} \right)
$$
$$
r_t = \sigma \left( W_r x_t + U_r h_{t-1}\right)
$$

For a simple example containing just one sample sequence $x \in \mathbb{R}^3$ (three scalar time steps), suppose that the weights at a certain point look like this:

In [3]:
# One training sample: `x` is iterated element by element in the forward pass
# (one scalar input per time step); `y` is the scalar target.
x = np.array([0.2, 0.3, 0.4])
y = 7.0

# Input weights (one entry per hidden unit; the hidden state has two units)
Wh = np.array([0.2, 0.9])  # candidate state \tilde{h}_t
Wz = np.array([0.1, 3.1])  # update gate z_t
Wr = np.array([2.3, 0.5])  # gate r_t

# Hidden-to-hidden weights, same order as above
Uh = np.array([
    [1.5, 2.6],
    [1.8, 3.6],
])
Uz = np.array([
    [0.1, 4.1],
    [0.2, 1.0],
])
Ur = np.array([
    [1.3, 7.1],
    [9.1, 4.5],
])

# Output weights: the prediction at step t is w.dot(h_t)
w = np.array([2.0, 4.0])

We can implement the forward pass like this:

In [4]:
# Containers for the per-step quantities. `h` is seeded with the initial
# hidden state h_0 = 0, so inside the loop h[t] is h_{t-1} and the state
# appended at step t ends up at h[t+1].
h = [np.zeros_like(Wh)]
h_ = []  # candidate states \tilde{h}_t
z = []   # update gate values z_t
r = []   # gate values r_t
y_ = []  # predictions \hat{y}_t

for t, xt in enumerate(x):
    # Calculate values of the gates. Each gate uses its own hidden-to-hidden
    # matrix: Uz for z_t and Ur for r_t (Uh belongs to the candidate only).
    zt = sigmoid(Wz.dot(xt) + Uz.dot(h[t]))
    rt = sigmoid(Wr.dot(xt) + Ur.dot(h[t]))

    # Calculate candidate update; rt scales the previous state element-wise
    h_t = tanh(Wh.dot(xt) + Uh.dot(rt * h[t]))

    # Calculate cell state as the zt-weighted mix of old state and candidate
    ht = (1-zt) * h[t] + zt * h_t

    # Calculate prediction at step t
    y_t = w.dot(ht)

    # Save variables to container
    h.append(ht)
    h_.append(h_t)
    z.append(zt)
    r.append(rt)
    y_.append(y_t)

## backward pass

We need the following gradients

* the candidate update $\frac{\partial loss}{\partial W_h}$ and $\frac{\partial loss}{\partial U_h}$
* the update for the update gate $\frac{\partial loss}{\partial W_z}$ and $\frac{\partial loss}{\partial U_z}$
* the update for the recurrent gate $\frac{\partial loss}{\partial W_r}$ and $\frac{\partial loss}{\partial U_r}$

### Weights for the candidate $\tilde{h}_t$

\begin{align}
\frac{\partial {loss}_t}{\partial W_h} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial \tilde{h}_t}\frac{\partial \tilde{h}_t}{\partial W_h} \\
                                   &= (y-\hat{y}_t) w * z_t * g^\prime \left( W_h x_t + U_h \cdot (r_t * h_{t-1}) \right) x_t
\end{align}

In [5]:
# Accumulate the per-step gradient contributions with respect to W_h.
dLossdWh = np.zeros_like(Wh)

for t, xt in enumerate(x):
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    candidate_input = Wh.dot(xt) + Uh.dot(r[t] * h[t])
    dh_tdWh = dtanh(candidate_input) * xt
    residual = y - y_[t]
    dLossdWh += residual * w * z[t] * dh_tdWh

\begin{align}
\frac{\partial {loss}_t}{\partial U_h} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial \tilde{h}_t}\frac{\partial \tilde{h}_t}{\partial U_h} \\
                                    &= (y-\hat{y}_t)\ \textbf{diag} \left[ w * z_t * g^\prime \left( W_h x_t + U_h \cdot (r_t * h_{t-1}) \right)\right] \begin{bmatrix} (r_t * h_{t-1})^T \\ (r_t * h_{t-1})^T \end{bmatrix}
\end{align}

In [6]:
# Accumulate the per-step gradient contributions with respect to U_h.
dLossdUh = np.zeros_like(Uh)

for t, xt in enumerate(x):
    # U_h multiplies (r_t * h_{t-1}) in the candidate, so that gated vector,
    # not the raw previous state, is the factor d(pre-activation)/dU_h.
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    gated_prev = r[t] * h[t]
    dy_dh_t = w * z[t] * dtanh(Wh.dot(xt) + Uh.dot(gated_prev))
    # Column vector times row vector broadcasts to the 2x2 shape of Uh:
    # entry (i, j) = dy_dh_t[i] * gated_prev[j].
    dLossdUh += (y-y_[t]) * dy_dh_t.reshape(-1, 1) * gated_prev

### Weights for the update gate $z_t$

\begin{align}
\frac{\partial {loss}_t}{\partial W_z} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial z_t}\frac{\partial z_t}{\partial W_z} \\
                                   &= (y-\hat{y}_t)\ w * [-h_{t-1} + \tilde{h}_t] * \sigma^\prime \left( W_z x_t + U_z h_{t-1} \right) x_t
\end{align}

In [7]:
# Accumulate the per-step gradient contributions with respect to W_z.
dLossdWz = np.zeros_like(Wz)

for t, xt in enumerate(x):
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    dztdWz = dsigmoid(Wz.dot(xt) + Uz.dot(h[t])) * xt
    # From h_t = (1 - z_t) * h_{t-1} + z_t * h~_t it follows that
    # dh_t/dz_t = -h_{t-1} + h~_t. The residual must use this step's own
    # prediction y_[t], not the `y_t` left over from the forward loop.
    dLossdWz += (y-y_[t]) * w * (-h[t] + h_[t]) * dztdWz

\begin{align}
\frac{\partial {loss}_t}{\partial U_z} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial z_t}\frac{\partial z_t}{\partial U_z} \\
                                   &= (y-\hat{y}_t)\ \textbf{diag} \left[ w * [-h_{t-1} + \tilde{h}_t] * \sigma^\prime \left( W_z x_t + U_z h_{t-1} \right)\right] \begin{bmatrix} h_{t-1}^T \\ h_{t-1}^T \end{bmatrix}
\end{align}

In [8]:
# Accumulate the per-step gradient contributions with respect to U_z.
dLossdUz = np.zeros_like(Uz)

for t, xt in enumerate(x):
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    # dh_t/dz_t = -h_{t-1} + h~_t (from h_t = (1 - z_t) * h_{t-1} + z_t * h~_t).
    # Parentheses, not list brackets: the whole difference is scaled by the
    # sigmoid derivative and the result stays a plain vector.
    dy_dzt = w * (-h[t] + h_[t]) * dsigmoid(Wz.dot(xt) + Uz.dot(h[t]))
    # Column vector times row vector broadcasts to the 2x2 shape of Uz:
    # entry (i, j) = dy_dzt[i] * h_{t-1}[j].
    dLossdUz += (y-y_[t]) * dy_dzt.reshape(-1, 1) * h[t]

### Weights for the recurrent gate $r_t$
\begin{align}
\frac{\partial {loss}_t}{\partial W_r} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial \tilde{h}_t}\frac{\partial \tilde{h}_t}{\partial W_r}
\end{align}

However, due to the entangled nature of $U_h$, $r_t$ and $h_{t-1}$ and the element-wise multiplication inside the parentheses it is not immediately obvious how to obtain the derivative with respect to any of the inner weight matrices. Thus, we will rewrite $\tilde{h}_t$ as a system of linear equations. For ease of notation denote

\begin{equation}
W_h = \begin{bmatrix} w_1^h \\ w_2^h \end{bmatrix} \quad U_h = \begin{bmatrix} u_{11}^h & u_{12}^h \\ u_{21}^h &
u_{22}^h \end{bmatrix} \quad W_r = \begin{bmatrix} w_{1}^r \\ w_{2}^r \end{bmatrix} \quad U_r = \begin{bmatrix} u_{11}^r & u_{12}^r \\ u_{21}^r &
u_{22}^r \end{bmatrix} \quad \tilde{h}_t = g\left( W_h x_t + U_h (r_t * h_{t-1}) \right) = \begin{pmatrix} g(\cdot)_1 \\ g(\cdot)_2  \end{pmatrix} \quad h_{t-1} = \begin{pmatrix} h^{t-1}_1 \\ h^{t-1}_2\end{pmatrix}.
\end{equation}

This way we get

\begin{align}
\tilde{h}_t &= g\left( W_h x_t + U_h \cdot (r_t * h_{t-1}) \right) \\
            &= g \begin{pmatrix} w_{1}^h x_t + u_{11}^h r_1 h^{t-1}_1 + u_{12}^h r_2 h^{t-1}_2 \\
                                 w_{2}^h x_t + u_{21}^h r_1 h^{t-1}_1 + u_{22}^h r_2 h^{t-1}_2
                                 \end{pmatrix} \\
            &= \begin{pmatrix} g(w_{1}^h x_t + u_{11}^h \sigma(w_{1}^r x_t + u_{11}^r h^{t-1}_1 + u_{12}^r h^{t-1}_2) h^{t-1}_1 + u^h_{12} \sigma(w_{2}^r x_t + u_{21}^r h^{t-1}_1 + u^r_{22} h^{t-1}_2) h^{t-1}_2) \\
                               g(w^h_{2} x_t + u^h_{21} \sigma(w^r_{1}x_t + u^r_{11} h^{t-1}_1 + u^r_{12} h^{t-1}_2) h^{t-1}_1 + u^h_{22} \sigma(w^r_{2}x_t + u^r_{21} h^{t-1}_1 + u^r_{22} h^{t-1}_2) h^{t-1}_2)
                                 \end{pmatrix} \\
\end{align}

This way we get the derivatives for

\begin{align}
\frac{\partial \tilde{h}_t}{\partial w_1^r} = \begin{pmatrix} g^\prime(\cdot)_1 u^h_{11} h_1^{t-1} \sigma^\prime(\cdot)_1 x_t \\
                                                              g^\prime(\cdot)_2 u^h_{21} h_1^{t-1} \sigma^\prime(\cdot)_1 x_t
                                              \end{pmatrix}
\end{align}

and
\begin{align}
\frac{\partial \tilde{h}_t}{\partial w_2^r} = \begin{pmatrix} g^\prime(\cdot)_1 u^h_{12} h_2^{t-1} \sigma^\prime(\cdot)_2 x_t \\
                                                              g^\prime(\cdot)_2 u^h_{22} h_2^{t-1} \sigma^\prime(\cdot)_2 x_t
                                              \end{pmatrix}
\end{align}

which we can combine and reformulate as

\begin{align}
\frac{\partial \tilde{h}_t}{\partial W_r} &= \begin{bmatrix} g_1^\prime(\cdot) u^h_{11} h_1^{t-1} \sigma^\prime(\cdot)_1 x_t &                                                                        g_1^\prime(\cdot) u^h_{12} h_2^{t-1} \sigma^\prime(\cdot)_2 x_t \\
                                                             g_2^\prime(\cdot) u^h_{21} h_1^{t-1} \sigma^\prime(\cdot)_1 x_t &
                                                             g_2^\prime(\cdot) u^h_{22} h_2^{t-1} \sigma^\prime(\cdot)_2 x_t
                                                             \end{bmatrix} \\
                                          &= \begin{bmatrix} g^\prime(\cdot)_1 & 0 \\
                                                             0 & g^\prime(\cdot)_2
                                             \end{bmatrix} \cdot U_h * \begin{bmatrix} (\sigma^\prime(\cdot) * h_{t-1})^T \\ (\sigma^\prime(\cdot) * h_{t-1})^T \end{bmatrix}\ x_t \\
                                          &= \textbf{diag}\left[ g^\prime(\cdot) \right] \cdot U_h * \begin{bmatrix} (\sigma^\prime(\cdot) * h_{t-1})^T \\ (\sigma^\prime(\cdot) * h_{t-1})^T \end{bmatrix}\ x_t
\end{align}

which in turn provides us with
\begin{align}
\frac{\partial {loss}_t}{\partial W_r} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial \tilde{h}_t}\frac{\partial \tilde{h}_t}{\partial W_r} \\
                                   &= (y-\hat{y}_t)\ \textbf{diag} \left[ w * z_t * g^\prime(\cdot) \right] \cdot U_h * \begin{bmatrix} (\sigma^\prime(\cdot) * h_{t-1})^T \\ (\sigma^\prime(\cdot) * h_{t-1})^T \end{bmatrix}\ x_t.
\end{align}

In our example this is a $2 \times 2$ Jacobian matrix which we need to transform in order to update our $2 \times 1$ vector $W_r$. As it stands, we have ordered the gradients column-wise. The first column of $\frac{\partial {loss}_t }{\partial W_r}$ contains the partial derivative with respect to the first element of $W_r$ and the second column of $\frac{\partial {loss}_t}{\partial W_r}$ contains the partial derivative with respect to the second element of $W_r$. So in order to obtain the total gradients we can compute the column-sums.

In [9]:
# Accumulate the per-step gradient contributions with respect to W_r.
dLossdWr = np.zeros_like(Wr)

for t, xt in enumerate(x):
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    # g'(.) is the tanh derivative at the candidate pre-activation, and
    # sigma'(.) is the sigmoid derivative at the r-gate pre-activation.
    dh_t = dtanh(Wh.dot(xt) + Uh.dot(r[t] * h[t]))
    drt = dsigmoid(Wr.dot(xt) + Ur.dot(h[t]))

    # jacobian[i, j] = (y - y_t) * w_i * z_i * g'_i * Uh[i, j] * h_j * sigma'_j * x_t
    # i.e. the contribution of output unit i to the derivative w.r.t. Wr[j].
    jacobian = (y-y_[t]) * (w * z[t] * dh_t).reshape(-1, 1) * Uh * (h[t] * drt) * xt
    # Column j collects every contribution to Wr[j]; sum over the rows.
    dLossdWr += jacobian.sum(axis=0)

Similarly we get for the gradients with respect to $U_r$

\begin{align}
\frac{\partial \tilde{h}_t}{\partial u^r_{11}} &= \begin{bmatrix} 
    g^\prime (\cdot)_1 u_{11}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_1 \\ 
    g^\prime (\cdot)_2 u_{21}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_1
\end{bmatrix}  \quad 
\frac{\partial \tilde{h}_t}{\partial u^r_{12}} = \begin{bmatrix} 
    g^\prime (\cdot)_1 u_{11}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_2 \\ 
    g^\prime (\cdot)_2 u_{21}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_2
\end{bmatrix}\\
\frac{\partial \tilde{h}_t}{\partial u^r_{21}} &= \begin{bmatrix} 
    g^\prime (\cdot)_1 u_{12}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_1 \\ 
    g^\prime (\cdot)_2 u_{22}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_1
\end{bmatrix} \quad 
\frac{\partial \tilde{h}_t}{\partial u^r_{22}} = \begin{bmatrix} 
    g^\prime (\cdot)_1 u_{12}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_2 \\ 
    g^\prime (\cdot)_2 u_{22}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_2
\end{bmatrix}
\end{align}

Where, again, we will sum all fitting elements of each gradient in order to get a $2 \times 2$ matrix. Therefore we get

\begin{align}
\frac{\partial \tilde{h}_t}{\partial U_r} &= \begin{bmatrix} g^\prime (\cdot)_1 u_{11}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_1 + 
    g^\prime (\cdot)_2 u_{21}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_1 & g^\prime (\cdot)_1 u_{11}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_2 + g^\prime (\cdot)_2 u_{21}^h h^{t-1}_1 \sigma^\prime (\cdot)_1 h^{t-1}_2 \\
    g^\prime (\cdot)_1 u_{12}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_1 + 
    g^\prime (\cdot)_2 u_{22}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_1 & g^\prime (\cdot)_1 u_{12}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_2 + 
    g^\prime (\cdot)_2 u_{22}^h h^{t-1}_2 \sigma^\prime (\cdot)_2 h^{t-1}_2\end{bmatrix} \\
    &= U_h^T \cdot \begin{bmatrix} g^\prime (\cdot)_1 & g^\prime (\cdot)_1 \\ g^\prime (\cdot)_2 & g^\prime (\cdot)_2 \end{bmatrix} * 
    \begin{bmatrix} \sigma^\prime (\cdot)_1 & \sigma^\prime (\cdot)_1 \\
                    \sigma^\prime (\cdot)_2 & \sigma^\prime (\cdot)_2 \end{bmatrix} * \begin{bmatrix}
                    h^{t-1}_1 h^{t-1}_1 & h^{t-1}_1 h^{t-1}_2 \\ h^{t-1}_2 h^{t-1}_1 & h^{t-1}_2 h^{t-1}_2
                    \end{bmatrix} \\
    &= \left[ \left( U_h^T \cdot g^\prime (\cdot) \right) * \sigma^\prime (\cdot) * h_{t-1} \right] h_{t-1}^T
\end{align}

which leaves us with
\begin{align}
\frac{\partial {loss}_t}{\partial U_r} &= \frac{\partial {loss}_t}{\partial \hat{y}_t}\frac{\partial \hat{y}_t}{\partial h_t}\frac{\partial h_t}{\partial \tilde{h}_t}\frac{\partial \tilde{h}_t}{\partial U_r} \\
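                                       &= (y-\hat{y}_t) \left[ \left( U_h^T \cdot \left( w * z_t * g^\prime (\cdot) \right) \right) * \sigma^\prime (\cdot) * h_{t-1} \right] h_{t-1}^T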
\end{align}


In [None]:
# Accumulate the per-step gradient contributions with respect to U_r.
dLossdUr = np.zeros_like(Ur)

for t, xt in enumerate(x):
    # `h` carries the initial state at index 0, so h[t] is h_{t-1}.
    # g'(.) is the tanh derivative at the candidate pre-activation, and
    # sigma'(.) is the sigmoid derivative at the r-gate pre-activation.
    dh_t = dtanh(Wh.dot(xt) + Uh.dot(r[t] * h[t]))
    drt = dsigmoid(Wr.dot(xt) + Ur.dot(h[t]))

    # Chain rule, element-wise:
    #   dLoss_t/dUr[j, k] = (y - y_t) * sum_i(w_i * z_i * g'_i * Uh[i, j])
    #                       * sigma'_j * h_{t-1}[j] * h_{t-1}[k]
    # The sum over i is Uh.T.dot(w * z_t * g'); the remaining factors form
    # an outer product indexed by (j, k).
    upstream = Uh.T.dot(w * z[t] * dh_t)
    dLossdUr += (y-y_[t]) * np.outer(upstream * drt * h[t], h[t])
    