In [1]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [2]:
# リファレンス(浮動小数点版)
R = np.ndarray((8))
for i in range(8):
    R[i] = np.cos(i / 16.0 * np.pi) * np.sqrt(2.0)
INVSQRT2 = 1.0 / np.sqrt(2.0)

def llm_dct_8(x):
    y = np.ndarray((8))   
    c1 = x[0]; c2 = x[7]; t0 = c1 + c2; t7 = c1 - c2
    c1 = x[1]; c2 = x[6]; t1 = c1 + c2; t6 = c1 - c2
    c1 = x[2]; c2 = x[5]; t2 = c1 + c2; t5 = c1 - c2
    c1 = x[3]; c2 = x[4]; t3 = c1 + c2; t4 = c1 - c2
    
    c0 = t0 + t3; c3 = t0 - t3
    c1 = t1 + t2; c2 = t1 - t2
    
    y[0] = c0 + c1
    y[4] = c0 - c1
    y[2] = c2 * R[6] + c3 * R[2]
    y[6] = c3 * R[6] - c2 * R[2]

    c3 = t4 * R[3] + t7 * R[5]
    c0 = t7 * R[3] - t4 * R[5]
    c2 = t5 * R[1] + t6 * R[7]
    c1 = t6 * R[1] - t5 * R[7]
    
    y[5] = c3 - c1; y[3] = c0 - c2
    c0 = (c0 + c2) * INVSQRT2
    c3 = (c3 + c1) * INVSQRT2
    y[1] = c0 + c3; y[7] = c0 - c3
    for i in range(8):
        y[i] *= INVSQRT2 * 0.5
    return y


def llm_dct_8x8(src):
    dst = np.ndarray((8, 8))
    for i in range(8):
        dst[i, :] = llm_dct_8(src[i, :])
    for i in range(8):
        dst[:, i] = llm_dct_8(dst[:, i])
    return dst

In [3]:
# 固定小数点版
Q = 12
QV = 2**Q
QA = 2**(Q-1)

QR = np.ndarray((8), dtype=np.int64)
for i in range(8):
    QR[i] = int(np.round(np.cos(i / 16.0 * np.pi) * np.sqrt(2.0) * QV))
QINVSQRT2 = int(np.round(1.0 / np.sqrt(2.0) * QV))

# 固定小数点乗算
def fix_mul(a, b):
    return ((a * b) + QA) >> Q

def fixpoint_llm_dct_8(x):
    x = np.array(x).astype(np.int64)
    c1 = x[0]; c2 = x[7]; st0_t0 = c1 + c2; st0_t7 = c1 - c2
    c1 = x[1]; c2 = x[6]; st0_t1 = c1 + c2; st0_t6 = c1 - c2
    c1 = x[2]; c2 = x[5]; st0_t2 = c1 + c2; st0_t5 = c1 - c2
    c1 = x[3]; c2 = x[4]; st0_t3 = c1 + c2; st0_t4 = c1 - c2

    st1_c0 = st0_t0 + st0_t3
    st1_c3 = st0_t0 - st0_t3
    st1_c1 = st0_t1 + st0_t2
    st1_c2 = st0_t1 - st0_t2
    st1_t4 = st0_t4
    st1_t7 = st0_t7
    st1_t5 = st0_t5
    st1_t6 = st0_t6

    st2_c0 = st1_c1 + st1_c0
    st2_c1 = st1_c0 - st1_c1
    st2_c2 = fix_mul(st1_c2, QR[6]) + fix_mul(st1_c3, QR[2])
    st2_c3 = fix_mul(st1_c3, QR[6]) - fix_mul(st1_c2, QR[2])
    st2_c4 = fix_mul(st1_t7, QR[3]) + fix_mul(st1_t4, -QR[5])
    st2_c7 = fix_mul(st1_t4, QR[3]) - fix_mul(st1_t7, -QR[5])
    st2_c6 = fix_mul(st1_t5, QR[1]) + fix_mul(st1_t6, QR[7])
    st2_c5 = fix_mul(st1_t6, QR[1]) - fix_mul(st1_t5, QR[7])

    st3_y0 = st2_c0
    st3_y4 = st2_c1
    st3_y2 = st2_c2
    st3_y6 = st2_c3
    st3_y5 = st2_c7 - st2_c5; st3_t7 = fix_mul((st2_c7 + st2_c5), QINVSQRT2)
    st3_y3 = st2_c4 - st2_c6; st3_t4 = fix_mul((st2_c4 + st2_c6), QINVSQRT2)
    st3_y1 = st3_t4 + st3_t7
    st3_y7 = st3_t4 - st3_t7

    st4_y = np.ndarray((8), dtype=np.int64)
    st4_y[0] = fix_mul(st3_y0, (QINVSQRT2+1) // 2)
    st4_y[1] = fix_mul(st3_y1, (QINVSQRT2+1) // 2)
    st4_y[2] = fix_mul(st3_y2, (QINVSQRT2+1) // 2)
    st4_y[3] = fix_mul(st3_y3, (QINVSQRT2+1) // 2)
    st4_y[4] = fix_mul(st3_y4, (QINVSQRT2+1) // 2)
    st4_y[5] = fix_mul(st3_y5, (QINVSQRT2+1) // 2)
    st4_y[6] = fix_mul(st3_y6, (QINVSQRT2+1) // 2)
    st4_y[7] = fix_mul(st3_y7, (QINVSQRT2+1) // 2)

    return st4_y

def fixpoint_llm_dct_8x8(src):
    dst = np.ndarray((8, 8))
    for i in range(8):
        dst[i, :] = fixpoint_llm_dct_8(src[i, :])
    for i in range(8):
        dst[:, i] = fixpoint_llm_dct_8(dst[:, i])
    return dst

In [4]:
# ファイル読み込み
img_src = cv2.imread("Mandrill.bmp")
w = img_src.shape[1]
h = img_src.shape[0]
# plt.imshow(img_src[:,:,::-1])

# YCbCr化 (色相は縮小)
img_ycrcb = cv2.cvtColor(img_src, cv2.COLOR_BGR2YCrCb)
img_y = img_ycrcb[:,:,0]
img_cr = cv2.resize(img_ycrcb[:,:,1], (w//2, h//2))
img_cb = cv2.resize(img_ycrcb[:,:,2], (w//2, h//2))

if False:
    # 表示
    plt.subplot(131)
    plt.imshow(img_y, 'gray')
    plt.subplot(132)
    plt.imshow(img_cr, 'gray')
    plt.subplot(133)
    plt.imshow(img_cb, 'gray')

#print(img_y[0:8][0:8])
for i in range(8):
    print("[", end='')
    for j in range(8):
        print(f"{img_y[i][j]}, ", end='')
    print("],")


[145, 49, 137, 62, 71, 92, 153, 74, ],
[77, 46, 98, 60, 108, 49, 72, 30, ],
[87, 132, 72, 67, 67, 89, 70, 35, ],
[43, 134, 103, 61, 83, 93, 89, 84, ],
[57, 56, 137, 66, 82, 147, 98, 134, ],
[33, 57, 71, 84, 88, 172, 167, 120, ],
[36, 90, 63, 80, 135, 113, 120, 103, ],
[29, 84, 83, 25, 94, 152, 72, 64, ],


In [5]:
for i in range(8):
    print("{", end='')
    for j in range(8):
        print(f"8'h{img_y[i][j]:02x}, ", end='')
    print("},")

{8'h91, 8'h31, 8'h89, 8'h3e, 8'h47, 8'h5c, 8'h99, 8'h4a, },
{8'h4d, 8'h2e, 8'h62, 8'h3c, 8'h6c, 8'h31, 8'h48, 8'h1e, },
{8'h57, 8'h84, 8'h48, 8'h43, 8'h43, 8'h59, 8'h46, 8'h23, },
{8'h2b, 8'h86, 8'h67, 8'h3d, 8'h53, 8'h5d, 8'h59, 8'h54, },
{8'h39, 8'h38, 8'h89, 8'h42, 8'h52, 8'h93, 8'h62, 8'h86, },
{8'h21, 8'h39, 8'h47, 8'h54, 8'h58, 8'hac, 8'ha7, 8'h78, },
{8'h24, 8'h5a, 8'h3f, 8'h50, 8'h87, 8'h71, 8'h78, 8'h67, },
{8'h1d, 8'h54, 8'h53, 8'h19, 8'h5e, 8'h98, 8'h48, 8'h40, },


In [6]:
# test データ
src_x = np.array([
[145, 49, 137, 62, 71, 92, 153, 74, ],
[77, 46, 98, 60, 108, 49, 72, 30, ],
[87, 132, 72, 67, 67, 89, 70, 35, ],
[43, 134, 103, 61, 83, 93, 89, 84, ],
[57, 56, 137, 66, 82, 147, 98, 134, ],
[33, 57, 71, 84, 88, 172, 167, 120, ],
[36, 90, 63, 80, 135, 113, 120, 103, ],
[29, 84, 83, 25, 94, 152, 72, 64, ],
])

# 期待値
y_exp = llm_dct_8x8(src_x)
print(np.round(y_exp).astype(np.int32))

[[693 -74 -26  33 -97 -18  24  32]
 [-18 107  32 -10  25  64  -3  48]
 [-22  17  -3  40  14  28   1  43]
 [ 72 -88  18  -6  12  74  32  24]
 [ 20   7  26 -20 -29  -3  56  13]
 [ 34  10  22  -4  -3  -4 -54 -11]
 [ 16  -8  27  29 -38  23   7 -37]
 [ 16  19  12   0  13   2  21  -3]]


In [7]:
Q = 12
x = src_x.astype(np.int64) << (Q - 8)
y = fixpoint_llm_dct_8x8(x)
print(y // 2**(Q - 8))
#print(y / 2**(Q - 8))
err = y / 2**(Q - 8) - y_exp
print('err mean : ', np.mean(err))
print('err std  : ', np.std(err))

[[692. -74. -26.  33. -97. -18.  24.  31.]
 [-19. 107.  32. -11.  25.  64.  -4.  48.]
 [-23.  16.  -3.  40.  13.  28.   1.  42.]
 [ 71. -88.  18.  -6.  11.  73.  32.  23.]
 [ 20.   6.  25. -21. -30.  -4.  55.  13.]
 [ 33.   9.  22.  -4.  -4.  -4. -54. -11.]
 [ 15.  -8.  27.  28. -38.  23.   7. -37.]
 [ 15.  19.  12.  -1.  13.   1.  21.  -4.]]
err mean :  -0.005282940593393613
err std  :  0.033746803678626175


In [8]:
y.astype(np.int32)

array([[11086, -1181,  -411,   531, -1551,  -283,   387,   510],
       [ -291,  1712,   519,  -167,   404,  1029,   -56,   772],
       [ -353,   270,   -48,   640,   217,   455,    21,   687],
       [ 1150, -1403,   292,   -92,   191,  1177,   513,   380],
       [  324,   104,   408,  -321,  -468,   -53,   892,   210],
       [  538,   157,   352,   -62,   -56,   -64,  -862,  -169],
       [  255,  -121,   432,   463,  -605,   375,   116,  -589],
       [  248,   309,   198,    -7,   208,    25,   337,   -53]],
      dtype=int32)