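## NeLLoC (ANS)

Benchmarks lossless image compression with NeLLoC `LocalPixelCNN` models and ANS coding. For each configuration it checks exact reconstruction and reports average compression time, average decompression time and bits per dimension (BPD), for both the sequential coder (`parallel=False`) and the parallel coder (`parallel=True`).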

In [1]:
import time
import torch
from models.nelloc_model import *
from coders.nelloc_ans import *
from coders.pnelloc_ans import *
import numpy as np
from tqdm import tqdm
from models.utils import get_test_image

# All benchmarks below run on the CPU.
device = torch.device(type="cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def test(net,all_img, D, rf, p_prec=16, parallel=False):
    K=rf*2+1
    if parallel:
        time_length=np.arange(0,D+int((K+1)/2)*(D-1))
        index_matrix=np.zeros((D,D))
        for i in range(0,D):
            index_matrix[i:i+1,:]=time_length[i*int((K+1)/2): i*int((K+1)/2)+D].reshape(1,D)
        time_index=[]
        for t in time_length:
            time_index.append(list(zip(*np.where(index_matrix==t))))
    else:
        pass

    BPD_list=[]
    compression_time_list=[]
    decompression_time_list=[]
    for i in tqdm(range(0,all_img.size(0))):
        img=all_img[i].unsqueeze(0)
        if parallel:
            start = time.time()
            ans_stack=p_ans_compression(net,img,time_index,D,D,rf,p_prec)
            end = time.time()
        else:
            start = time.time()
            ans_stack=ans_compression(net,img,D,D,rf,p_prec)
            end = time.time()
        compression_time_list.append(end - start)
        BPD_list.append(ans_stack.get_length()/(D*D*3))
        
        if parallel:
            start = time.time()
            decode_img=p_ans_decompression(net,ans_stack,time_index,D,D,rf,p_prec)
            end = time.time()
        else:
            start = time.time()
            decode_img=ans_decompression(net,ans_stack,D,D,rf,p_prec)
            end = time.time()
        decompression_time_list.append(end - start)
        if (img-decode_img).sum().item()>0.:
            print('wrong')
    
    print('average compression time', np.mean(compression_time_list))
    print('average decompression time',np.mean(decompression_time_list))
    print('average BPD', np.mean(BPD_list))

In [3]:
net = LocalPixelCNN(res_num=0, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
D = 32
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]  # first 10 images, limited to D x D
test(net, test_images, D=D, rf=3, parallel=False)
test(net, test_images, D=D, rf=3, parallel=True)

100%|██████████| 10/10 [00:09<00:00,  1.10it/s]


average compression time 0.4507297992706299
average decompression time 0.4602129220962524
average BPD 3.39541015625


100%|██████████| 10/10 [00:04<00:00,  2.29it/s]

average compression time 0.21267218589782716
average decompression time 0.2230468511581421
average BPD 3.3938151041666664





In [4]:
net = LocalPixelCNN(res_num=1, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs1h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
D = 32
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]  # first 10 images, limited to D x D
test(net, test_images, D=D, rf=3, parallel=False)
test(net, test_images, D=D, rf=3, parallel=True)

100%|██████████| 10/10 [00:11<00:00,  1.15s/it]


average compression time 0.5721799373626709
average decompression time 0.5781133651733399
average BPD 3.31826171875


100%|██████████| 10/10 [00:05<00:00,  1.91it/s]

average compression time 0.25638842582702637
average decompression time 0.26610987186431884
average BPD 3.3184895833333337





In [5]:
net = LocalPixelCNN(res_num=3, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs3h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
D = 32
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]  # first 10 images, limited to D x D
test(net, test_images, D=D, rf=3, parallel=False)
test(net, test_images, D=D, rf=3, parallel=True)

100%|██████████| 10/10 [00:15<00:00,  1.53s/it]


average compression time 0.7565342426300049
average decompression time 0.7739993333816528
average BPD 3.2850260416666663


100%|██████████| 10/10 [00:06<00:00,  1.51it/s]

average compression time 0.32673208713531493
average decompression time 0.334816575050354
average BPD 3.2854166666666664





In [6]:
net = LocalPixelCNN(res_num=0, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
D = 64
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]  # first 10 images, limited to D x D
test(net, test_images, D=D, rf=3, parallel=False)
test(net, test_images, D=D, rf=3, parallel=True)

100%|██████████| 10/10 [00:36<00:00,  3.69s/it]


average compression time 1.815756893157959
average decompression time 1.8785051107406616
average BPD 3.0523763020833337


100%|██████████| 10/10 [00:14<00:00,  1.46s/it]

average compression time 0.7032831192016602
average decompression time 0.7568142890930176
average BPD 3.0521484375000005





In [7]:
net = LocalPixelCNN(res_num=0, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
D = 128
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]  # first 10 images, limited to D x D
test(net, test_images, D=D, rf=3, parallel=False)
test(net, test_images, D=D, rf=3, parallel=True)

100%|██████████| 10/10 [02:29<00:00, 14.97s/it]


average compression time 7.394210863113403
average decompression time 7.573542857170105
average BPD 2.9347513834635417


100%|██████████| 10/10 [00:42<00:00,  4.20s/it]

average compression time 1.9836273908615112
average decompression time 2.2172059297561644
average BPD 2.9345642089843755





In [20]:
import PIL.Image  # `import PIL` alone does not load the PIL.Image submodule used below

D = 1024
IMAGE_PATHS = ['img-1024/1.png', 'img-1024/2.png', 'img-1024/3.png']


def load_rgb_image(path, size):
    """Read ``path`` as RGB and return an int32 tensor of shape (1, 3, size, size).

    Raises AssertionError if the image is not ``size`` x ``size`` pixels.
    """
    with PIL.Image.open(path) as im:  # close the file handle deterministically
        pixels = np.asarray(im.convert('RGB'))
    img = torch.tensor(pixels, dtype=torch.int32).permute(2, 0, 1)  # HWC -> CHW
    assert tuple(img.shape) == (3, size, size), \
        f'{path}: expected 3x{size}x{size}, got {tuple(img.shape)}'
    return img.unsqueeze(0)


test_img1, test_img2, test_img3 = (load_rgb_image(p, D) for p in IMAGE_PATHS)
test_images = torch.cat((test_img1, test_img2, test_img3), 0)
print(test_images.size())

net = LocalPixelCNN(res_num=0, in_kernel=7, out_channels=100).to(device)
net.load_state_dict(torch.load('./model_save/nelloc_rs0h3.pth', map_location=device))
net.eval()  # inference only: switch any dropout/batch-norm layers to eval mode
# The parallel coder is benchmarked on these images in a separate cell.
test(net, test_images, D=D, rf=3, parallel=False)

torch.Size([3, 3, 1024, 1024])


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [47:03<00:00, 941.27s/it]

average compression time 465.3420154253642
average decompression time 475.91408737500507
average BPD 2.2237941953870983





In [19]:
# Parallel-coder run; relies on `net`, `test_images` and `D` defined in the preceding cell.
test(net=net, all_img=test_images, D=D, rf=3, parallel=True)

torch.Size([3, 3, 1024, 1024])


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [09:13<00:00, 184.52s/it]

average compression time 84.46455391248067
average decompression time 100.04404664039612
average BPD 2.2237012651231556



