In [18]:
import numpy as np
import matplotlib.pyplot as plt
import cv2

In [19]:
# Reference implementation (floating-point version)
# R[i] = sqrt(2) * cos(i * pi / 16): rotation coefficients used by llm_dct_8.
R = np.array([np.cos(i / 16.0 * np.pi) * np.sqrt(2.0) for i in range(8)])
INVSQRT2 = 1.0 / np.sqrt(2.0)

def llm_dct_8(x):
    """Compute the 8-point forward DCT of ``x`` in floating point.

    The transform is evaluated with a butterfly/rotation factorisation
    (the name suggests the Loeffler-Ligtenberg-Moschytz scheme) and the
    result is scaled by 0.5/sqrt(2), which makes it equal to the
    orthonormal DCT-II of the eight input samples.

    Parameters
    ----------
    x : array-like
        Sequence of numeric samples; elements 0..7 are read.  Any numeric
        dtype is accepted: the samples are converted to float64 first so
        that narrow unsigned types (e.g. a uint8 image row) cannot wrap
        around inside the butterflies.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape (8,) holding the DCT coefficients y[0..7].
    """
    # Work in float64: sums/differences of uint8 samples would otherwise overflow.
    x = np.asarray(x, dtype=np.float64)
    y = np.empty(8)

    # Stage 1: butterflies on mirrored sample pairs.
    t0 = x[0] + x[7]; t7 = x[0] - x[7]
    t1 = x[1] + x[6]; t6 = x[1] - x[6]
    t2 = x[2] + x[5]; t5 = x[2] - x[5]
    t3 = x[3] + x[4]; t4 = x[3] - x[4]

    # Even half: second butterfly, then one rotation for y[2] / y[6].
    e0 = t0 + t3; e3 = t0 - t3
    e1 = t1 + t2; e2 = t1 - t2

    y[0] = e0 + e1
    y[4] = e0 - e1
    y[2] = e2 * R[6] + e3 * R[2]
    y[6] = e3 * R[6] - e2 * R[2]

    # Odd half: two rotations on the difference terms.
    r47_hi = t4 * R[3] + t7 * R[5]
    r47_lo = t7 * R[3] - t4 * R[5]
    r56_hi = t5 * R[1] + t6 * R[7]
    r56_lo = t6 * R[1] - t5 * R[7]

    y[5] = r47_hi - r56_lo
    y[3] = r47_lo - r56_hi
    # The sum terms need an extra 1/sqrt(2) before the last butterfly.
    s_lo = (r47_lo + r56_hi) * INVSQRT2
    s_hi = (r47_hi + r56_lo) * INVSQRT2
    y[1] = s_lo + s_hi
    y[7] = s_lo - s_hi

    # Common output scaling of 0.5 / sqrt(2).
    y *= INVSQRT2 * 0.5
    return y

In [20]:
# Load the input image
img_path = "Mandrill.bmp"
img_src = cv2.imread(img_path)
if img_src is None:
    # cv2.imread reports a missing/unreadable file by returning None rather
    # than raising, which would otherwise surface as an AttributeError below.
    raise FileNotFoundError("cv2.imread could not read {!r}".format(img_path))
w = img_src.shape[1]
h = img_src.shape[0]
# plt.imshow(img_src[:,:,::-1])

# Convert to YCrCb; the two chroma planes are shrunk to half width and height
img_ycrcb = cv2.cvtColor(img_src, cv2.COLOR_BGR2YCrCb)
img_y = img_ycrcb[:,:,0]
img_cr = cv2.resize(img_ycrcb[:,:,1], (w//2, h//2))
img_cb = cv2.resize(img_ycrcb[:,:,2], (w//2, h//2))

# Optional side-by-side view of the three planes (disabled by default)
if False:
    # Display
    plt.subplot(131)
    plt.imshow(img_y, 'gray')
    plt.subplot(132)
    plt.imshow(img_cr, 'gray')
    plt.subplot(133)
    plt.imshow(img_cb, 'gray')

# First eight luma samples of the top row
print(img_y[0][0:8])

[145  49 137  62  71  92 153  74]


In [30]:
# Test data: the eight luma samples printed from the top row of the image
x = np.array([145, 49, 137, 62, 71, 92, 153, 74])

# Expected values, taken from the floating-point reference
y_exp = llm_dct_8(x)
print(y_exp)

[276.83230483   3.20388141  34.56059356  20.09426522 -27.93071786
  71.37149684  28.92776128  58.93695859]


In [45]:
# Fixed-point version
Q = 12                # number of fractional bits
QV = 1 << Q           # scale factor, 2**Q
QA = 1 << (Q - 1)     # rounding offset, 2**(Q-1)

# sqrt(2) * cos(i * pi / 16), rounded to Q fractional bits
QR = np.array(
    [int(np.round(np.cos(i / 16.0 * np.pi) * np.sqrt(2.0) * QV)) for i in range(8)],
    dtype=np.int64,
)
# 1 / sqrt(2), rounded to Q fractional bits
QINVSQRT2 = int(np.round(1.0 / np.sqrt(2.0) * QV))

# Fixed-point multiplication
def fix_mul(a, b):
    """Multiply ``a`` by the Q-format value ``b`` and drop the Q fractional bits.

    QA is added before the right shift by Q, so the product is rounded
    rather than truncated.
    """
    product = a * b
    return (product + QA) >> Q

# Fixed-point 8-point DCT, unrolled into explicit stages (st0 .. st4) that
# follow the same data flow as the floating-point reference.
x = np.array(x).astype(np.int64)

# Stage 0: butterflies on mirrored input pairs
c1 = x[0]; c2 = x[7]; st0_t0 = c1 + c2; st0_t7 = c1 - c2
c1 = x[1]; c2 = x[6]; st0_t1 = c1 + c2; st0_t6 = c1 - c2
c1 = x[2]; c2 = x[5]; st0_t2 = c1 + c2; st0_t5 = c1 - c2
c1 = x[3]; c2 = x[4]; st0_t3 = c1 + c2; st0_t4 = c1 - c2

# Stage 1: second butterfly on the sums; the differences pass through unchanged
st1_c0 = st0_t0 + st0_t3
st1_c3 = st0_t0 - st0_t3
st1_c1 = st0_t1 + st0_t2
st1_c2 = st0_t1 - st0_t2
st1_t4 = st0_t4
st1_t7 = st0_t7
st1_t5 = st0_t5
st1_t6 = st0_t6

# Stage 2: final even butterfly plus the three rotations (fixed-point multiplies)
st2_c0 = st1_c1 + st1_c0
st2_c1 = st1_c0 - st1_c1
st2_c2 = fix_mul(st1_c2, QR[6]) + fix_mul(st1_c3, QR[2])
st2_c3 = fix_mul(st1_c3, QR[6]) - fix_mul(st1_c2, QR[2])
st2_c4 = fix_mul(st1_t7, QR[3]) + fix_mul(st1_t4, -QR[5])
st2_c7 = fix_mul(st1_t4, QR[3]) - fix_mul(st1_t7, -QR[5])
st2_c6 = fix_mul(st1_t5, QR[1]) + fix_mul(st1_t6, QR[7])
st2_c5 = fix_mul(st1_t6, QR[1]) - fix_mul(st1_t5, QR[7])

# Stage 3: unscaled outputs; y1 / y7 come from the 1/sqrt(2)-scaled sum terms
st3_y0 = st2_c0
st3_y4 = st2_c1
st3_y2 = st2_c2
st3_y6 = st2_c3
st3_y5 = st2_c7 - st2_c5; st3_t7 = fix_mul((st2_c7 + st2_c5), QINVSQRT2)
st3_y3 = st2_c4 - st2_c6; st3_t4 = fix_mul((st2_c4 + st2_c6), QINVSQRT2)
st3_y1 = st3_t4 + st3_t7
st3_y7 = st3_t4 - st3_t7

# Stage 4: common 0.5/sqrt(2) output scaling.  Every coefficient uses the same
# rounded constant, so the eight multiplies are done as one vectorised call.
QOUTSCALE = (QINVSQRT2 + 1) // 2
st3_y = np.array(
    [st3_y0, st3_y1, st3_y2, st3_y3, st3_y4, st3_y5, st3_y6, st3_y7],
    dtype=np.int64,
)
st4_y = fix_mul(st3_y, QOUTSCALE)

# NOTE(review): st5_y is not read again in this cell - confirm whether the /16 is still needed.
st5_y = st4_y / 16
print(st4_y)
# Error of the fixed-point result against the floating-point reference (y_exp)
print('err mean : ', np.mean(st4_y - y_exp))
print('err std  : ', np.std(st4_y - y_exp))

[277   3  34  20 -28  71  29  58]
err mean :  -0.24956798490585885
err std  :  0.3392061842912857


In [34]:
# Dump the intermediate values of each stage, one line per stage
stage_dump = (
    ("st1 :", (st1_c0, st1_c3, st1_c1, st1_c2, st1_t4, st1_t7, st1_t5, st1_t6)),
    ("st2 :", (st2_c0, st2_c1, st2_c2, st2_c3, st2_c4, st2_c7, st2_c6, st2_c5)),
    ("st3 :", (st3_y0, st3_y4, st3_y2, st3_y6, st3_y5, st3_y3, st3_y1, st3_y7)),
    ("out :", (st4_y[0], st4_y[4], st4_y[2], st4_y[6], st4_y[5], st4_y[3], st4_y[1], st4_y[7])),
)
for stage_label, stage_values in stage_dump:
    print(stage_label, *stage_values)

st1 : 352 86 431 -27 -9 71 45 -104
st2 : 783 -79 97 82 90 45 33 -156
st3 : 783 -79 97 82 201 57 9 165
out : 277 -28 34 29 71 20 3 58


In [43]:
# Coefficient table with 32 fractional bits, printed as hex
Q32 = 32
Q32V = 1 << Q32           # scale factor, 2**32
Q32A = 1 << (Q32 - 1)     # rounding offset, 2**31

# sqrt(2) * cos(i * pi / 16), rounded to 32 fractional bits
Q32R = np.array(
    [int(np.round(np.cos(i / 16.0 * np.pi) * np.sqrt(2.0) * Q32V)) for i in range(8)],
    dtype=np.int64,
)
Q32INVSQRT2 = int(np.round(1.0 / np.sqrt(2.0) * Q32V))

for coeff in Q32R:
    print("r :", hex(coeff))
print("  1/sqrt(2) :", hex(Q32INVSQRT2))
print("0.5/sqrt(2) :", hex(Q32INVSQRT2//2))


r : 0x16a09e668
r : 0x163150b16
r : 0x14e7ae914
r : 0x12d062ef9
r : 0x100000000
r : 0xc9234e07
r : 0x8a8bd3df
r : 0x46a1577b
  1/sqrt(2) : 0xb504f334
0.5/sqrt(2) : 0x5a82799a
