In [1]:
import numpy as np
import sympy

In [2]:
# Four symbolic variables shared by every cell below: (x1, x2) carry the two
# bits of a row index and (y1, y2) the two bits of a column index, high bit
# first (see idx_to_expr / matrix_to_poly).
x1, x2, y1, y2 = sympy.symbols('x1, x2, y1, y2')

In [3]:
def idx_to_expr(idx: int, s1, s2):
    """Return the product that selects the 2-bit index ``idx`` in (s1, s2).

    The high bit of ``idx`` (value 2) is paired with ``s1`` and the low bit
    (value 1) with ``s2``. A set bit contributes the variable itself, a clear
    bit contributes ``1 - variable``. When (s1, s2) are the bits of ``idx``
    the product is 1; for the bits of any other index in 0..3 it is 0.

    Args:
        idx: Index in the range 0..3.
        s1: Value or symbol standing for the high bit.
        s2: Value or symbol standing for the low bit.

    Returns:
        The product of the two factors, of whatever type ``s1``/``s2``
        arithmetic produces.

    Raises:
        AssertionError: If ``idx`` is outside 0..3.
    """
    assert 0 <= idx < 4

    product = 1
    # Walk the bits high-to-low so the factors are multiplied in s1, s2 order.
    for mask, var in ((2, s1), (1, s2)):
        product *= var if idx & mask else (1 - var)

    return product

def matrix_to_poly(a: np.ndarray, field):
    """Build the polynomial in (x1, x2, y1, y2) that interpolates a 4x4 matrix.

    Each entry ``a[i][j]`` is multiplied by the selector of row ``i`` in
    (x1, x2) and the selector of column ``j`` in (y1, y2), both produced by
    ``idx_to_expr``, and the 16 terms are summed. Evaluating the result at the
    bits of (i, j) therefore gives ``a[i][j]`` reduced into ``field``.

    Args:
        a: Array of shape (4, 4).
        field: SymPy domain used for the coefficients (e.g. ``sympy.GF(p)``).

    Returns:
        A ``sympy.Poly`` over ``field`` whose generators are always
        ``(x1, x2, y1, y2)``, in that order.

    Raises:
        AssertionError: If ``a`` is not of shape (4, 4).
    """
    assert a.shape == (4, 4)

    expr = sum(
        a[i][j] * idx_to_expr(i, x1, x2) * idx_to_expr(j, y1, y2)
        for i in range(4)
        for j in range(4)
    )

    # Pass the generators explicitly. Letting SymPy infer them from the
    # expression drops any variable that cancels out (and yields no Poly at
    # all when the sum collapses to a constant, e.g. an all-equal matrix),
    # after which substituting x1/x2/y1/y2 into the result fails.
    return sympy.Poly(expr, x1, x2, y1, y2, domain=field)

def eval_matrix_poly_at(poly, i: int, j: int):
    """Evaluate ``poly`` at the Boolean point encoding matrix position (i, j).

    The row index ``i`` is split into bits for (x1, x2) and the column index
    ``j`` into bits for (y1, y2), high bit first, which is the same encoding
    ``idx_to_expr`` uses.

    Args:
        poly: Polynomial in x1, x2, y1, y2 supporting ``eval`` with a
            ``{generator: value}`` mapping.
        i: Row index in 0..3.
        j: Column index in 0..3.

    Returns:
        Whatever ``poly.eval`` returns once all four variables are substituted.

    Raises:
        AssertionError: If ``i`` or ``j`` is outside 0..3. Without this check
            an index such as 4 would silently be evaluated as index 0.
    """
    assert 0 <= i < 4 and 0 <= j < 4

    # Substitute plain 0/1 integers for every variable (the high bits used to
    # be passed as bool while the low bits were int).
    return poly.eval({
        x1: (i >> 1) & 1,
        x2: i & 1,
        y1: (j >> 1) & 1,
        y2: j & 1,
    })

def pretty_print_poly_evals(poly):
    """Print the values of ``poly`` at all 16 (row, column) index pairs.

    One output line per row index 0..3; each value on the line is followed by
    a single space, matching a 4x4 matrix layout.
    """
    for row in range(4):
        row_values = [eval_matrix_poly_at(poly, row, col) for col in range(4)]
        # sep=' ' between values plus a trailing ' ' before the newline.
        print(*row_values, end=' \n')

In [4]:
# Symmetric 0/1 matrix with a zero diagonal used as the running example.
A = np.array(
    [
        [0, 1, 1, 0],
        [1, 0, 1, 1],
        [1, 1, 0, 1],
        [0, 1, 1, 0],
    ]
)
# Matrix square of A (ordinary matrix product, not elementwise).
A2 = np.matmul(A, A)

In [5]:
# Reference value: sum over all entries of the elementwise product A2 * A.
(A2 * A).sum()

np.int64(12)

In [6]:
# Finite field of integers modulo the prime 389; used as the coefficient
# domain for the polynomials built below.
F389 = sympy.GF(389)

In [7]:
# Polynomials over F389 that interpolate A and A2 on the Boolean points
# (x1, x2, y1, y2) in {0, 1}^4.
p_A = matrix_to_poly(A, F389)
p_A2 = matrix_to_poly(A2, F389)

In [8]:
# Product of the two polynomials. At the Boolean point encoding (i, j) it
# takes the value A[i][j] * A2[i][j] (mod 389). Unlike its factors, it can
# have degree 2 in each variable.
g = p_A * p_A2
g

Poly(-60*x1**2*x2**2*y1**2*y2**2 + 60*x1**2*x2**2*y1**2*y2 - 15*x1**2*x2**2*y1**2 + 60*x1**2*x2**2*y1*y2**2 - 62*x1**2*x2**2*y1*y2 + 16*x1**2*x2**2*y1 - 15*x1**2*x2**2*y2**2 + 16*x1**2*x2**2*y2 - 4*x1**2*x2**2 + 60*x1**2*x2*y1**2*y2**2 - 68*x1**2*x2*y1**2*y2 + 19*x1**2*x2*y1**2 - 52*x1**2*x2*y1*y2**2 + 62*x1**2*x2*y1*y2 - 18*x1**2*x2*y1 + 11*x1**2*x2*y2**2 - 14*x1**2*x2*y2 + 4*x1**2*x2 - 15*x1**2*y1**2*y2**2 + 19*x1**2*y1**2*y2 - 6*x1**2*y1**2 + 11*x1**2*y1*y2**2 - 15*x1**2*y1*y2 + 5*x1**2*y1 - 2*x1**2*y2**2 + 3*x1**2*y2 - x1**2 + 60*x1*x2**2*y1**2*y2**2 - 52*x1*x2**2*y1**2*y2 + 11*x1*x2**2*y1**2 - 68*x1*x2**2*y1*y2**2 + 62*x1*x2**2*y1*y2 - 14*x1*x2**2*y1 + 19*x1*x2**2*y2**2 - 18*x1*x2**2*y2 + 4*x1*x2**2 - 62*x1*x2*y1**2*y2**2 + 62*x1*x2*y1**2*y2 - 15*x1*x2*y1**2 + 62*x1*x2*y1*y2**2 - 68*x1*x2*y1*y2 + 18*x1*x2*y1 - 15*x1*x2*y2**2 + 18*x1*x2*y2 - 6*x1*x2 + 16*x1*y1**2*y2**2 - 18*x1*y1**2*y2 + 5*x1*y1**2 - 14*x1*y1*y2**2 + 18*x1*y1*y2 - 6*x1*y1 + 3*x1*y2**2 - 4*x1*y2 + 2*x1 - 15*x2**2*y1

In [9]:
# Sum g over all 16 Boolean points; this is expected to match the direct
# np.sum(A2 * A) computed earlier (taken modulo 389).
sum(eval_matrix_poly_at(g, i,j) for i in range(4) for j in range(4))

12