In [1]:
import numpy as np
from PIL import Image
from scipy.fft import dct, idct
from matplotlib import pylab as pylab
from matplotlib import pyplot as plt

# Render matplotlib figures inline in the notebook output.
%matplotlib inline
# Default size for every figure created after this cell: 8 x 8 (matplotlib figsize units are inches).
pylab.rcParams['figure.figsize'] = (8, 8)

In [2]:
def mdct4(x):
    """Forward MDCT of one frame, evaluated with a single N/4-point complex FFT.

    Parameters
    ----------
    x : numpy.ndarray
        One-dimensional frame whose length N is a multiple of four.

    Returns
    -------
    numpy.ndarray
        Real array holding the N/2 transform coefficients.

    Raises
    ------
    ValueError
        If the frame length is not a multiple of four.
    """
    frame_len = x.shape[0]
    if frame_len % 4:
        raise ValueError("MDCT4 only defined for vectors of length multiple of four.")
    half = frame_len // 2
    quarter = frame_len // 4

    # Rotate the frame by a quarter of its length; the samples that wrapped
    # around to the front change sign. np.roll returns a new array, so the
    # caller's data is not modified.
    rotated = np.roll(x, quarter)
    rotated[:quarter] = -rotated[:quarter]

    idx = np.arange(quarter)
    even = 2 * idx
    twiddle = np.exp(-1j * 2 * np.pi * (idx + 1. / 8.) / frame_len)

    # Pack the rotated real samples into N/4 complex values: even-indexed
    # samples are paired with the mirrored odd-indexed ones.
    real_part = rotated.take(even) - rotated.take(frame_len - 1 - even)
    imag_part = rotated.take(half + even) - rotated.take(half - 1 - even)
    packed = real_part - 1j * imag_part

    # Pre-twiddle, FFT, post-twiddle and scale.
    spectrum = np.fft.fft(0.5 * packed * twiddle, quarter)
    spectrum = (2. / np.sqrt(frame_len)) * twiddle * spectrum

    # Interleave: real parts fill the even slots in ascending order, negated
    # imaginary parts fill the odd slots in descending order.
    coefs = np.zeros(half)
    coefs[even] = spectrum.real
    coefs[half - 1 - even] = -spectrum.imag
    return coefs

def imdct4(x):
    """Inverse MDCT of one coefficient vector, evaluated with one complex FFT.

    Parameters
    ----------
    x : numpy.ndarray
        One-dimensional array with an even number of coefficients.

    Returns
    -------
    numpy.ndarray
        Real array with twice as many samples as there are input coefficients.

    Raises
    ------
    ValueError
        If the number of coefficients is odd.
    """
    n_coefs = x.shape[0]
    if n_coefs % 2:
        raise ValueError("iMDCT4 only defined for even-length vectors.")
    half = n_coefs // 2
    out_len = n_coefs * 2

    idx = np.arange(half)
    even = 2 * idx
    twiddle = np.exp(-1j * 2 * np.pi * (idx + 1. / 8.) / out_len)

    # Pack even-indexed coefficients (ascending) with odd-indexed ones
    # (descending) into half as many complex values.
    packed = np.take(x, even) + 1j * np.take(x, n_coefs - 1 - even)

    # Pre-twiddle, FFT, post-twiddle and scale.
    spectrum = np.fft.fft(0.5 * twiddle * packed, half)
    spectrum = ((8 / np.sqrt(out_len)) * twiddle) * spectrum

    # Even slots: real parts go in the first half, imaginary parts in the second.
    rotated = np.zeros(out_len)
    rotated[even] = spectrum.real
    rotated[n_coefs + even] = spectrum.imag

    # Odd slots are the negated mirror image of the even slots.
    odd = np.arange(1, out_len, 2)
    rotated[odd] = -rotated[out_len - 1 - odd]

    # Undo the quarter-length rotation: the tail moves to the front and the
    # leading quarter is appended with its sign flipped.
    return np.concatenate((rotated[half:], -rotated[:half]))

In [32]:
def mdct(x):
    """Forward MDCT of one frame, computed as a DCT-IV of the folded frame.

    The frame is split along axis 0 into four equal quarters ``(a, b, c, d)``,
    folded into the half-length sequence ``(-c_r - d, a - b_r)`` (``_r`` means
    reversed along axis 0), passed through scipy's orthonormal DCT-IV and
    halved. For a length-N frame this equals

        X[k] = 1/sqrt(N) * sum_n x[n] * cos(2*pi/N * (n + 1/2 + N/4) * (k + 1/2))

    Parameters
    ----------
    x : array_like
        Input frame. The transform runs along axis 0, whose length N must be a
        multiple of four; any trailing axes are transformed independently
        (e.g. an ``(N, channels)`` array).

    Returns
    -------
    numpy.ndarray
        Array with N/2 coefficients along axis 0 and the trailing axes of
        ``x`` unchanged.

    Raises
    ------
    ValueError
        If N is not a multiple of four.
    """
    x = np.asarray(x)
    # Promote integer/bool input before folding: negating unsigned integers
    # would otherwise wrap around silently instead of producing negative values.
    if not np.issubdtype(x.dtype, np.inexact):
        x = x.astype(np.float64)

    N = x.shape[0]

    if N % 4 != 0:
        raise ValueError("MDCT only defined for vectors of length multiple of four.")

    N4 = N // 4

    a = x[0*N4:1*N4]
    b = x[1*N4:2*N4]
    c = x[2*N4:3*N4]
    d = x[3*N4:4*N4]

    # Reverse only along the time axis so extra (channel) axes stay aligned.
    br = np.flip(b, axis=0)
    cr = np.flip(c, axis=0)

    folded = np.concatenate([-cr - d, a - br], axis=0)

    return dct(folded, type=4, norm='ortho', orthogonalize=True, axis=0) / 2


def imdct(y):
    """Inverse MDCT of one coefficient block, computed via an inverse DCT-IV.

    The coefficients are passed through scipy's orthonormal inverse DCT-IV
    along axis 0; the result ``z`` is extended to ``(z, -z_r, -z)`` (``_r``
    means reversed along axis 0), doubled, and the middle section of twice the
    input length is returned. Applied to the output of ``mdct`` for a frame
    split into quarters ``(a, b, c, d)``, this yields the time-aliased frame
    ``(a - b_r, b - a_r, c + d_r, d + c_r)``.

    Parameters
    ----------
    y : array_like
        Coefficients. The transform runs along axis 0, whose length must be
        even; any trailing axes are transformed independently.

    Returns
    -------
    numpy.ndarray
        Array with ``2 * y.shape[0]`` samples along axis 0 and the trailing
        axes of ``y`` unchanged.

    Raises
    ------
    ValueError
        If the number of coefficients along axis 0 is odd.
    """
    y = np.asarray(y)
    n_coefs = y.shape[0]

    # An odd coefficient count would make the output length 2 * n_coefs not
    # divisible by four.
    if n_coefs % 2 != 0:
        raise ValueError("IMDCT is only defined for vectors lengths multiple of two.")

    quarter = n_coefs // 2  # a quarter of the output frame length

    z = idct(y, type=4, norm='ortho', orthogonalize=True, axis=0)

    # Extend along the time axis only, so extra (channel) axes stay aligned.
    extended = np.concatenate([z, -np.flip(z, axis=0), -z], axis=0) * 2

    return extended[quarter:quarter + 2 * n_coefs]

In [37]:
# Benchmark input: 1600 integer samples (a multiple of four, as both transforms require).
x=np.arange(1600)

# Time the FFT-based implementation against the scipy DCT-IV-based one on the same frame.
%timeit mdct4(x)
%timeit mdct(x)

86 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
22.5 µs ± 383 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
