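## ShearLoC (ANS)

Lossless-compression benchmark for ShearLoC. Each test image is encoded and then decoded with an ANS coder driven by a sheared local PixelCNN. The notebook reports the average compression time, decompression time and bits per dimension (BPD) for several model depths and image sizes.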

In [1]:
import time
import torch
from models.shearloc_model import *
from coders.shearloc_ans import *
import numpy as np
from tqdm import tqdm
from models.utils import get_test_image,shear_quantity
%matplotlib inline 

device = torch.device("cpu")  # target for every `.to(device)` / `map_location=device` below

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def test(net,all_img, D, O, K, p_prec=16):
    BPD_list=[]
    compression_time_list=[]
    decompression_time_list=[]
    quantity=shear_quantity(D,O)
    for i in tqdm(range(0,all_img.size(0))):
        img=all_img[i].unsqueeze(0)
        start = time.time()
        ans_stack=ans_compression(net,img,quantity,K,p_prec)
        end = time.time()
        compression_time_list.append(end - start)
        BPD_list.append(ans_stack.get_length()/(D*D*3))
        

        start = time.time()
        decode_img=ans_decompression(net,ans_stack,quantity,K,p_prec)
        end = time.time()
        decompression_time_list.append(end - start)
        if (img-decode_img).sum().item()>0.:
            print('wrong')
    
    print('average compression time', np.mean(compression_time_list))
    print('average decompression time',np.mean(decompression_time_list))
    print('average BPD', np.mean(BPD_list))



In [3]:
h = 3                          ## dependency horizon
o = h + 1                      ## shear offset
kh, kw = h + 1, (o + 1) * h    ## CNN kernel height and width (kw == o*h + h)
mix_num = 10                   ## number of components in the discretized logistic mixture distribution

In [4]:
D = 32  ## image side length
# First 10 test images, cropped to D x D (loaded once and reused by the next two cells).
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]

res = 0  ## number of resnet blocks
# mix_num * 10 output channels: presumably 10 parameters per mixture component
# (PixelCNN++-style RGB logistic mixture) -- confirm in models.shearloc_model.
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Build the checkpoint name from res and h so it cannot drift from the architecture:
# with strict=False below, a checkpoint for a different res would load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
test(net, test_images, D=D, O=o, K=(kh, kw))


100%|██████████| 10/10 [00:04<00:00,  2.32it/s]

average compression time 0.21366124153137206
average decompression time 0.217557692527771
average BPD 3.3937825520833336





In [5]:
res = 1  ## number of resnet blocks
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Checkpoint name follows res and h. With strict=False, a checkpoint trained with
# a different number of res blocks would otherwise load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
# Reuses the D=32 `test_images` and `D` set in the earlier D=32 cell.
test(net, test_images, D=D, O=o, K=(kh, kw))

100%|██████████| 10/10 [00:04<00:00,  2.27it/s]

average compression time 0.21780450344085694
average decompression time 0.22225584983825683
average BPD 3.3184895833333337





In [6]:
res = 3  ## number of resnet blocks
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Checkpoint name follows res and h. With strict=False, a checkpoint trained with
# a different number of res blocks would otherwise load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
# Reuses the D=32 `test_images` and `D` set in the earlier D=32 cell.
test(net, test_images, D=D, O=o, K=(kh, kw))

100%|██████████| 10/10 [00:04<00:00,  2.06it/s]

average compression time 0.2405768394470215
average decompression time 0.24517347812652587
average BPD 3.2854166666666664





In [7]:
D = 64  ## image side length
# First 10 test images, cropped to D x D.
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]

res = 0  ## number of resnet blocks
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Checkpoint name follows res and h. With strict=False, a checkpoint trained with
# a different number of res blocks would otherwise load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
test(net, test_images, D=D, O=o, K=(kh, kw))

100%|██████████| 10/10 [00:12<00:00,  1.20s/it]

average compression time 0.5880004644393921
average decompression time 0.6124924182891845
average BPD 3.0521484375000005





In [8]:
D = 128  ## image side length
# First 10 test images, cropped to D x D.
test_images = get_test_image(D)[0:10, :, 0:D, 0:D]

res = 0  ## number of resnet blocks
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Checkpoint name follows res and h. With strict=False, a checkpoint trained with
# a different number of res blocks would otherwise load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
test(net, test_images, D=D, O=o, K=(kh, kw))

100%|██████████| 10/10 [00:32<00:00,  3.24s/it]

average compression time 1.5599421262741089
average decompression time 1.6827704668045045
average BPD 2.9345642089843755





In [6]:
# `import PIL` alone does not guarantee the PIL.Image submodule is loaded; import it explicitly.
from PIL import Image

D = 1024  ## image side length
image_paths = ['img-1024/1.png', 'img-1024/2.png', 'img-1024/3.png']
image_tensors = []
for path in image_paths:
    # Open in a context manager so the underlying file handle is closed.
    with Image.open(path) as pil_img:
        pixels = np.asarray(pil_img.convert('RGB'))  # (H, W, 3)
    img = torch.tensor(pixels, dtype=torch.int32).permute(2, 0, 1)  # -> (3, H, W)
    # `test` assumes D x D images, so fail loudly instead of reshaping silently.
    assert img.shape == (3, D, D), '{}: expected (3, {}, {}), got {}'.format(path, D, D, tuple(img.shape))
    image_tensors.append(img.unsqueeze(0))
test_images = torch.cat(image_tensors, 0)
print(test_images.size())

res = 0  ## number of resnet blocks
net = LocalPixelCNN(res_num=res, kernel_size=[kh, kw], out_channels=mix_num * 10).to(device)
# Checkpoint name follows res and h. With strict=False, a checkpoint trained with
# a different number of res blocks would otherwise load partially without error.
dict_loaded = torch.load('./model_save/nelloc_rs{}h{}.pth'.format(res, h), map_location=device)
# Apply `shear` (offset o) to the input-conv weights, then crop to the (kh, kw) kernel of `net`.
a = shear(dict_loaded['in_cnn.weight'], offset=o)[:, :, :kh, :kw]
dict_loaded['in_cnn.weight'] = a.clone()
net.load_state_dict(dict_loaded, strict=False)
test(net, test_images, D=D, O=o, K=(kh, kw))

torch.Size([3, 3, 1024, 1024])


100%|██████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [06:47<00:00, 135.76s/it]

average compression time 62.75702730814616
average decompression time 73.00248901049297
average BPD 2.223701265123155



