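## Parallel NeLLoC (Arithmetic Coding)

This notebook compares sequential and parallel NeLLoC arithmetic coding on 32×32, 64×64 and 128×128 test images. For each setting it reports the average compression time, decompression time and bits per dimension (BPD). In the parallel scheme, pixel $(i, j)$ is assigned step $i \cdot \lfloor (K+1)/2 \rfloor + j$. All pixels that share a step are processed together, so each row starts a fixed number of steps after the row above it.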

In [1]:
import time
import torch
from model import *
import matplotlib.pyplot as plt
from compression import *
import numpy as np
from tqdm import tqdm
from scipy.io import loadmat
from utils import get_test_image
%matplotlib inline 

# Build the PixelCNN_light model on the CPU and load the pretrained checkpoint.
device = torch.device("cpu")
net = PixelCNN_light(
    in_kernel=5,
    in_channels=3,
    channels=100,
    out_channels=9,
).to(device)
net.load_state_dict(torch.load("./model_save/logistic_ks5.pt", map_location=device))

<All keys matched successfully>

In [2]:
def _build_time_index(D, K):
    """Group the pixels of a D x D grid by the step at which they are coded in parallel mode.

    Pixel (i, j) is assigned step ``i * offset + j`` with ``offset = (K + 1) // 2``,
    so each row starts ``offset`` steps after the row above it. Returns a list of
    length ``D + offset * (D - 1)`` whose t-th entry holds the (row, col) index
    tuples scheduled at step t, in row-major order (as produced by ``np.where``).
    """
    offset = (K + 1) // 2
    rows = np.arange(D).reshape(D, 1)
    cols = np.arange(D).reshape(1, D)
    step_of_pixel = rows * offset + cols  # step_of_pixel[i, j] == i * offset + j
    n_steps = D + offset * (D - 1)
    return [list(zip(*np.where(step_of_pixel == t))) for t in range(n_steps)]


def test(all_img, D, K, prec=6000, parallel=False, model=None):
    """Round-trip each image through NeLLoC arithmetic coding and print averages.

    For every image this compresses it, decompresses the resulting code, checks
    that the reconstruction is exact (printing 'wrong' otherwise), and records the
    wall-clock compression/decompression times and the bits per dimension.

    Args:
        all_img: images indexed along dim 0; each ``all_img[i]`` gets a leading
            batch dimension before coding. Presumably (N, 3, D, D) -- TODO confirm
            against ``get_test_image``.
        D: image height and width (images are treated as square).
        K: kernel size forwarded to the coding routines; in parallel mode it also
            sets the row offset ``(K + 1) // 2`` of the coding schedule.
        prec: precision assigned to ``getcontext().prec`` before coding each image
            (presumably the ``decimal`` context used by the arithmetic coder).
        parallel: if True, use ``ac_compression_parallel`` /
            ``ac_decompression_parallel`` with a precomputed step schedule.
        model: network used for coding; defaults to the notebook-level ``net``.
    """
    if model is None:
        model = net

    # The schedule depends only on D and K, so build it once for all images.
    time_index = _build_time_index(D, K) if parallel else None

    BPD_list = []
    compression_time_list = []
    decompression_time_list = []

    for i in tqdm(range(all_img.size(0))):
        img = all_img[i].unsqueeze(0)
        getcontext().prec = prec

        start = time.perf_counter()
        if parallel:
            code = ac_compression_parallel(model, img, time_index, K)
        else:
            code = ac_compression(model, img, K)
        compression_time_list.append(time.perf_counter() - start)
        # Assumes len(code) is the code length in bits and 3 colour channels
        # per pixel -- TODO confirm against the compression module.
        BPD_list.append(len(code) / (D * D * 3))

        start = time.perf_counter()
        if parallel:
            decode_img = ac_decompression_parallel(model, code, D, D, time_index, K)
        else:
            decode_img = ac_decompression(model, code, D, D, K)
        decompression_time_list.append(time.perf_counter() - start)

        # Sum absolute differences: a signed sum can cancel out (or go negative)
        # and hide a lossy round-trip.
        if abs(img - decode_img).sum().item() > 0:
            print('wrong')

    if parallel:
        print('Using parallelization')
    else:
        print('Not using parallelization')
    print('average compression time', np.mean(compression_time_list))
    print('average decompression time', np.mean(decompression_time_list))
    print('average BPD', np.mean(BPD_list))

In [3]:
# 32x32 test images, sequential then parallel, with the default coder precision.
D = 32
test_images = get_test_image(D)
test(test_images, D, K=5, parallel=False)
test(test_images, D, K=5, parallel=True)

100%|██████████| 10/10 [00:25<00:00,  2.55s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Not using parallelization
average compression time 1.2363938570022583
average decompression time 1.3162425994873046
average BPD 3.65078125


100%|██████████| 10/10 [00:13<00:00,  1.32s/it]

Using parallelization
average compression time 0.6016015768051147
average decompression time 0.7219446897506714
average BPD 3.6511067708333336





In [4]:
# 64x64 test images, sequential then parallel, with a larger coder precision.
D = 64
prec = 20000
test_images = get_test_image(D)
test(test_images, D, K=5, prec=prec, parallel=False)
test(test_images, D, K=5, prec=prec, parallel=True)

100%|██████████| 10/10 [01:49<00:00, 10.94s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Not using parallelization
average compression time 5.355291795730591
average decompression time 5.588459205627442
average BPD 3.5683430989583336


100%|██████████| 10/10 [01:00<00:00,  6.03s/it]

Using parallelization
average compression time 2.6736512422561645
average decompression time 3.3601449728012085
average BPD 3.5684488932291663





In [7]:
# 128x128 test images, sequential then parallel, with an even larger coder precision.
D = 128
prec = 100000
test_images = get_test_image(D)
test(test_images, D, K=5, prec=prec, parallel=False)
test(test_images, D, K=5, prec=prec, parallel=True)

100%|██████████| 10/10 [09:43<00:00, 58.33s/it]
  0%|          | 0/10 [00:00<?, ?it/s]

Not using parallelization
average compression time 26.590954637527467
average decompression time 31.736795520782472
average BPD 3.1255859375


100%|██████████| 10/10 [06:47<00:00, 40.75s/it]

Using parallelization
average compression time 17.660334014892577
average decompression time 23.08530488014221
average BPD 3.125579833984375



