In [28]:
import numpy as np
from utils.ode import StudentTeacherODE

## Simplified Expressions

#### Start of Training

In [87]:
# sigmoid
def du_dt_start(Q, V, h2, v1, alpha):
    derivative = np.zeros((2, 2)).astype(float)
    for (i, p), _ in np.ndenumerate(derivative):
        s = 0
        for m, head_unit in enumerate(v1):
            nom = 2 * V[p][m] * (1 + Q[i][i])
            den = np.pi * np.sqrt(2 * (1 + Q[i][i])) * (1 + Q[i][i])
            s += head_unit * (nom / den)
        derivative[i][p] = alpha * h2[i] * s
    return derivative

## Start of Training

In [97]:
overlaps = {
    "Q": np.array([[0.1, 0], [0, 0.1]]),
    "R": np.array([[0, 0]]),
    "U": np.array([[0, 0]]),
    "T": np.array([[1]]),
    "S": np.array([[1]]),
    "V": np.array([[0.1]]),
    "h1": np.array([[0.01, 0]]),
    "h2": np.array([[0.01, 0]])
}

t1_head = np.array([[0.2, 0.3]])
t2_head = np.array([[0.1, 0.4]])


assert overlaps["Q"][0][0] - overlaps["R"][0][0] * overlaps["R"][0][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][0][0] - overlaps["R"][0][1] * overlaps["R"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][1][1] - overlaps["R"][0][1] * overlaps["R"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][1][1] - overlaps["R"][1][1] * overlaps["R"][1][1] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"

assert overlaps["Q"][0][0] - overlaps["U"][0][0] * overlaps["U"][0][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][0][0] - overlaps["U"][0][1] * overlaps["U"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][1][1] - overlaps["U"][0][1] * overlaps["U"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][1][1] - overlaps["U"][1][1] * overlaps["U"][1][1] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"

IndexError: index 1 is out of bounds for axis 0 with size 1

In [80]:
stode = StudentTeacherODE(overlaps=overlaps, nonlinearity="sigmoid", learning_rate=0.5, teacher_head1=t1_head, teacher_head2=t2_head)

In [81]:
drdt = stode.dr_dt()

In [82]:
drdt

array([[0.00042921, 0.00064381],
       [0.        , 0.        ]])

In [83]:
dudt = stode.du_dt()

In [88]:
dudt

array([[4.29208963e-05, 6.43813444e-05],
       [0.00000000e+00, 0.00000000e+00]])

In [85]:
dudt_simp = du_dt_start(overlaps["Q"], overlaps["V"], overlaps["h2"], t1_head, 0.5)

In [86]:
dudt_simp

array([[4.29208963e-05, 6.43813444e-05],
       [0.00000000e+00, 0.00000000e+00]])

## Switch After Convergence

In [95]:
overlaps = {
    "Q": np.array([[1, 0], [0, 1]]),
    "R": np.array([[1, 0], [0, 1]]),
    "U": np.array([[0, 0], [0, 0]]),
    "T": np.array([[1, 0], [0, 1]]),
    "S": np.array([[1, 0], [0, 1]]),
    "V": np.array([[1, 0], [0, 1]]),
    "h1": np.array([0.01, 0]),
    "h2": np.array([0, 0])
}

t1_head = np.array([0.2, 0.3])
t2_head = np.array([0.1, 0.4])

assert overlaps["Q"][0][0] - overlaps["R"][0][0] * overlaps["R"][0][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][0][0] - overlaps["R"][0][1] * overlaps["R"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][1][1] - overlaps["R"][0][1] * overlaps["R"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"
assert overlaps["Q"][1][1] - overlaps["R"][1][1] * overlaps["R"][1][1] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,m)"

assert overlaps["Q"][0][0] - overlaps["U"][0][0] * overlaps["U"][0][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][0][0] - overlaps["U"][0][1] * overlaps["U"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][1][1] - overlaps["U"][0][1] * overlaps["U"][1][0] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"
assert overlaps["Q"][1][1] - overlaps["U"][1][1] * overlaps["U"][1][1] >= 0, "Covariance matrix constraint (positive semi-definiteness) violated for C(i,p)"