In [2]:
import numpy as np
from numpy import cos, sin
from numpy.testing import assert_allclose

In [8]:
def layer(x: np.ndarray, θ: np.ndarray, w: float) -> tuple:
    """
    Each layer is the product of three rotations.
    
    Parmeters
    ---------
    θi : (3) array
        Bias parameters of each rotation.
    wi : float
        Weight of the X rotation.
    
    Returns
    -------
    A : (G,2,2) array
        Unitary matrix of the layer. 
    """
    ϕ = w * x + θ[0]
    Rx = np.array([[cos(ϕ/2), -1j * sin(ϕ/2)], [-1j * sin(ϕ/2), cos(ϕ/2)]])
    Ry = np.array([[cos(θ[1]/2), -sin(θ[1]/2)], [sin(θ[1]/2), cos(θ[1]/2)]])
    Rz = np.array([[cos(θ[2]/2) - 1j * sin(θ[2]/2), 0], [0, cos(θ[2]/2) + 1j * sin(θ[2]/2)]])

    Ui = np.einsum('mn, np, pqi -> mqi', Rz, Ry, Rx)
    return np.moveaxis(Ui, -1, 0)

In [9]:
def layer_v2(x: np.ndarray, θ: np.ndarray, w: float) -> tuple:
    """
    Each layer is the product of three rotations.
    
    Parmeters
    ---------
    θi : (3) array
        Bias parameters of each rotation.
    wi : float
        Weight of the X rotation.
    
    Returns
    -------
    A : (G,2,2) array
        Unitary matrix of the layer. 
    """
    ϕ = w * x + θ[0]
    Rx = np.array([[cos(ϕ/2), -1j * sin(ϕ/2)], [-1j * sin(ϕ/2), cos(ϕ/2)]])
    Ry = np.array([[cos(θ[1]/2), -sin(θ[1]/2)], [sin(θ[1]/2), cos(θ[1]/2)]])
    Rz = np.array([[cos(θ[2]/2) - 1j * sin(θ[2]/2), 0], [0, cos(θ[2]/2) + 1j * sin(θ[2]/2)]])

    return np.einsum('mn,np,pqi->imq', Rz, Ry, Rx)

In [11]:
x = np.linspace(1,100,500)
θ = np.array([2,3,4])
w = 3

assert_allclose(layer(x, θ, w), layer_v2(x, θ, w))

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 2000 / 2000 (100%)
Max absolute difference: 1.
Max relative difference: 2.60770824
 x: array([[[ 0.566409-0.196898j, -0.371053-0.709036j],
        [ 0.371053-0.709036j,  0.566409+0.196898j]],
...
 y: array([[[-0.433591-0.196898j, -1.371053-0.709036j],
        [-0.628947-0.709036j, -0.433591+0.196898j]],
...

In [6]:
a = np.arange(2*3*4)
a.reshape(2,3,4)

array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [8]:
a.reshape(2*3,4)

array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11],
       [12, 13, 14, 15],
       [16, 17, 18, 19],
       [20, 21, 22, 23]])