In [1]:
import numpy as np

In [2]:
def softmax(X):
    return np.exp(X) / np.sum(np.exp(X), axis=0)





In [9]:
XOR_INPUTS = [np.array([[1, 0], [1, 0]]), np.array([[0, 1], [1, 0]]), np.array([[1, 0], [0, 1]]), np.array([[0, 1], [0, 1]]), ]
INPUTS = [(0, 0), (1, 0), (0, 1), (1, 1)]

Output of encoder is
$$
H = \sigma(X W_Q W_K^T X^T) X W_V
$$

Output of decoder times linear layer is
$$
D = \sigma(q^T H) H W
$$
Where $W = W_V^{dec} W'$ absorbs the decoder values matrix and $q^T = v_{SOS} W_Q^{dec} (W_K^{dec})^T$ absorbs several weight matrices. 

Without the FFNN at the end of the encoder, I'm actually, very suspicious that this might just be a mostly-linear function, i.e. the input appears right here:
$$
D = \sigma(q^T \sigma(X W_Q W_K^T X^T) X W_V) \sigma(X W_Q W_K^T X^T) \textcolor{red}{X} W_V W 
$$

In [146]:

# Compute Q, K, V for encoder self-attention
def get_QKV(X, Wq, Wk, Wv):

    Q = X @ Wq
    K = X @ Wk
    V = X @ Wv
    return Q, K, V

# Compute output of encoder
def encoder(X, Wq, Wk, Wv):
    Q, K, V = get_QKV(X, Wq, Wk, Wv)
    Z = softmax(Q @ K.T) @ V
    return Z

def rowwise_FFNN(X, W1, W2):
    """With ReLU activation function."""
    return np.maximum(0, X @ W1) @ W2

# Compute output of decoder with cross attention
def sigma_qkt(memory, v):   
    return softmax(v @ memory)

def decoder_linear(memory, v, W):
    return sigma_qkt(memory, v) @ memory @ W


In [165]:
Wq = np.array([[1, 1], [1, -1]])
Wk = np.array([[1, 1], [1, -1]])
for pair, X in zip(INPUTS, XOR_INPUTS):
    
    S = softmax(X @ Wq @ Wk.T @ X.T)
    print(f"pair {pair} S:")
    print(S)

pair (0, 0) S:
[[0.5 0.5]
 [0.5 0.5]]
pair (1, 0) S:
[[0.88079708 0.11920292]
 [0.11920292 0.88079708]]
pair (0, 1) S:
[[0.88079708 0.11920292]
 [0.11920292 0.88079708]]
pair (1, 1) S:
[[0.5 0.5]
 [0.5 0.5]]


In [163]:
# Specify query vector for the SOS token
# This abosrbs the Wq and Wk matrices for decoder


# Specify encoder Wq, Wk, Wv for self-attention
Wq = np.array([[1, 1], [1, -1]])
Wk = np.array([[1, 1], [1, -1]])
Wv = np.array(
    [[1, 1/2],
     [1/2, -1]]

)
# Wq = np.random.rand(2, 2)
# Wk = np.random.rand(2, 2)
# Wv = np.random.rand(2, 2)

W1 = np.array([[1, 0], [0, 1]])
W2 = np.array([[1, 0], [0, 1]])

# DECODER
# vqk_vec = np.array([-1, .2])
vqk_vec = np.random.rand(2)
W = np.array(
    [[1, 0],
     [0, 1]]
)

for pair, X in zip(INPUTS, XOR_INPUTS):
    print(f"{pair} encoder")
    enc_out = encoder(X, Wq, Wk, Wv)
    print(enc_out)
    memory = rowwise_FFNN(enc_out, W1, W2)
    print(memory)

    sigma_q_Kt = sigma_qkt(memory, vqk_vec)
    print("sigma_qT_H", sigma_q_Kt)

    decoder_prelinear = sigma_qkt(memory, vqk_vec) @ memory
    print("     decoder_prelinear", decoder_prelinear)

    decoder = decoder_linear(memory, vqk_vec, W)
    # print("decoder_linear", decoder)  
    print(f"output: {pair} ->", softmax(decoder))
    print()

# print(Wq, Wk, Wv, vqk_vec)

(0, 0) encoder
[[1.  0.5]
 [1.  0.5]]
[[1.  0.5]
 [1.  0.5]]
sigma_qT_H [0.66702773 0.33297227]
     decoder_prelinear [1.  0.5]
output: (0, 0) -> [0.62245933 0.37754067]

(1, 0) encoder
[[ 0.55960146 -0.82119562]
 [ 0.94039854  0.32119562]]
[[0.55960146 0.        ]
 [0.94039854 0.32119562]]
sigma_qT_H [0.6966924 0.3033076]
     decoder_prelinear [0.67510011 0.09742107]
output: (1, 0) -> [0.64053318 0.35946682]

(0, 1) encoder
[[ 0.94039854  0.32119562]
 [ 0.55960146 -0.82119562]]
[[0.94039854 0.32119562]
 [0.55960146 0.        ]]
sigma_qT_H [0.69134289 0.30865711]
     decoder_prelinear [0.82286281 0.22205631]
output: (0, 1) -> [0.6458408 0.3541592]

(1, 1) encoder
[[ 0.5 -1. ]
 [ 0.5 -1. ]]
[[0.5 0. ]
 [0.5 0. ]]
sigma_qT_H [0.66702773 0.33297227]
     decoder_prelinear [0.5 0. ]
output: (1, 1) -> [0.62245933 0.37754067]



In [167]:
v = np.array([1, 1])
Sigma = np.array([[1, 0], [0, 1]])
W = np.array([
    [2, 1], 
    [3, 1]])

for pair, X in zip(INPUTS, XOR_INPUTS):
    print(v.T @ Sigma @ X @ W)

[4 2]
[5 2]
[5 2]
[6 2]
