## Noise Ground:

Having trained some networks, here we play around with the codes by adding ambiguation noise to them and looking at the decoded outputs.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import  DataLoader
from torchvision import transforms, utils
import torch.nn.functional as F
#
from cv2 import imwrite, imread, IMWRITE_JPEG_QUALITY, COLOR_BGR2RGB
import os
#
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 
import dataTools as D
import Tools as T

from skimage.measure import compare_psnr,compare_ssim

%precision 5

In [None]:
database_name = 'CelebA'
############################
# Per-dataset settings. The last entry of im_size is the number of channels;
# the model below is built with shape (im_size[2], im_size[0], im_size[1]).
batch_size = 100  # used by the DataLoader for every dataset
if database_name == 'CelebA':
    root = 'path/to/data/CelebA/128_crop/'
    im_size = (128, 128, 3)
elif database_name == 'CYale':
    root = 'path/to/data/CYale'
    im_size = (168, 192, 1)
else:
    raise ValueError('Unknown database_name: {!r}'.format(database_name))
#######################################
num_channel = im_size[2]
device = torch.device("cuda:1")

### To avoid mixing test and training data, here I read the list of test image names from a file I saved earlier.

In [None]:
# Test-time transform: only the tensor conversion, no normalization or augmentation.
fwd_transform = transforms.Compose([transforms.ToTensor()])

In [None]:
# The list of test image names lives next to the checkpoint (same ./state_dict/<date> folder),
# so evaluation only sees images that were held out from training.
test_names_path = ('./state_dict/2019-10-14/'
                   'test_set_CelebA_filts40-40-40-40-40-10_scale1-2-1-2-1-2_codes20_dim512_k128.txt')
dataset = D.dataRead_fromName(root, im_size, test_names_path, transform=fwd_transform)
dataloader = DataLoader(dataset,
                        batch_size=batch_size,
                        shuffle=False,  # keep the file order so runs are reproducible
                        num_workers=1)

In [None]:
##########################################  Loading a trained model ##############################
# Architecture of the checkpoint to load; these values must match the saved model.
num_filts = [40, 40, 40, 40, 40, 10]   # filters per block
scale_factor = [1, 2, 1, 2, 1, 2]      # scale factor per block
num_blocks = len(num_filts)            # one entry per block -> 6
num_codes, neck_dim, k = 20, 512, 128  # k: sparsity value the model was trained with
##################################
def list2str(L):
    """Join the items of ``L`` with '-' after converting each one to ``str``.

    Used to encode hyper-parameter lists in file names,
    e.g. ``[40, 40, 10] -> '40-40-10'``. An empty input gives ``''``.
    Any iterable is accepted, not only sequences.
    """
    return '-'.join(str(item) for item in L)
#################################
from models import Autoencoder

# The network is given its input shape as (channels, im_size[0], im_size[1]).
input_shape = (im_size[2], im_size[0], im_size[1])
net = Autoencoder(input_shape,
                  num_blocks, num_filts, scale_factor, num_codes, neck_dim, k).to(device)
#########################
# The checkpoint file name encodes the dataset and the architecture hyper-parameters.
model_name = '{}_filts{}_scale{}_codes{}_dim{}_k{}.pth'.format(
    database_name,
    list2str(num_filts),
    list2str(scale_factor),
    num_codes,
    neck_dim,
    k,
)

print(model_name)
date = '2019-10-14'
##########################
checkpoint_path = os.path.join('./state_dict/', date, model_name)
net.load_state_dict(torch.load(checkpoint_path))
net.eval()

In [None]:
############################### Run the network on the test set (first batch only) ###############
with torch.no_grad():
    # Only the first batch is used by the analysis, so take it directly.
    first_batch = next(iter(dataloader))
    inputs = first_batch['image'].to(device)
    outputs, code = net(inputs)
    # In-place sigmoid so the reconstructions lie in (0, 1).
    outputs.sigmoid_()
            

In [None]:
# Show the first test image and its reconstruction in two separate figures.
# Assumes (N, C, H, W) batches (ToTensor + DataLoader): permute each image to
# (H, W, C) for imshow, and drop a singleton channel so grayscale data also works.
plt.figure()
plt.imshow(inputs[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())
plt.figure()
plt.imshow(outputs[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())

In [None]:
############################ Decoding with 64 retained code entries (the model was trained with k):
# NOTE(review): assumes T.KBest(x, n) keeps the n largest entries of each code map — confirm in Tools.
with torch.no_grad():
    code_k64 = T.KBest(torch.clone(code).cpu(), 64)
    outputs_k64 = net.decoder(code_k64.to(device)).sigmoid_()
    # (N, C, H, W) -> (H, W, C) for imshow; singleton channel dropped for grayscale.
    plt.imshow(outputs_k64[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())

In [None]:
############################ Zeroing a contiguous range of code maps (indices i_s .. i_e-1 along dim 1)
with torch.no_grad():
    code_prime = torch.clone(code).cpu()
    i_s = 0
    i_e = 12
    code_prime[:, i_s:i_e, :] = 0
    outputs_prime2 = net.decoder(code_prime.to(device)).sigmoid_()
    # (N, C, H, W) -> (H, W, C) for imshow; singleton channel dropped for grayscale.
    plt.imshow(outputs_prime2[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())

In [None]:
############################ Zeroing every third code map (indices 0, 3, 6, ... along dim 1)
with torch.no_grad():
    code_prime = torch.clone(code).cpu()
    code_prime[:, ::3, :] = 0
    outputs_prime1 = net.decoder(code_prime.to(device)).sigmoid_()
    # (N, C, H, W) -> (H, W, C) for imshow; singleton channel dropped for grayscale.
    plt.imshow(outputs_prime1[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())

In [None]:
############################ Ambiguation noise on the complement of support
# NOTE(review): semantics of T.ambiguate(code, k) live in Tools — confirm there.
with torch.no_grad():
    # Ambiguate a CPU copy of the network's own code.
    code_tilde = T.ambiguate(torch.clone(code).cpu(), k)
    outputs_tilde = net.decoder(code_tilde.to(device)).sigmoid_()
    # (N, C, H, W) -> (H, W, C) for imshow; singleton channel dropped for grayscale.
    plt.imshow(outputs_tilde[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())

In [None]:
########################## Randomly selecting k out of k':
# Starts from the ambiguated code; NOTE(review): see Tools for what T.random_guess selects.
with torch.no_grad():
    code_hat = T.random_guess(torch.clone(code_tilde).cpu(), k)
    outputs_hat = net.decoder(code_hat.to(device)).sigmoid_()
    # (N, C, H, W) -> (H, W, C) for imshow; singleton channel dropped for grayscale.
    plt.imshow(outputs_hat[0].permute(1, 2, 0).squeeze(-1).cpu().numpy())
    


In [None]:
print(code.shape)
# Overlay entries 0..199 of code[10, 5, :] for the original code and its 64-entry version.
sample_idx, map_idx, span = 10, 5, slice(0, 200)
plt.plot(code[sample_idx, map_idx, span].cpu().numpy())
plt.plot(code_k64[sample_idx, map_idx, span].cpu().numpy())

In [None]:
# torch.nonzero returns one row per non-zero entry, so shape[0] is the non-zero count.
for code_version in (code, code_tilde, code_hat):
    print(torch.nonzero(code_version).shape)

In [None]:
# Non-zero count of the whole code tensor, flattened to 1-D.
print(torch.nonzero(code.view(-1)).shape)

In [None]:
# Histogram (50 bins) of the non-zero values in code[10, 1, :].
# Values are moved to the CPU first: matplotlib cannot convert a CUDA tensor to numpy.
code_row = code[10, 1, :]
nonzero_values = code_row[code_row != 0].cpu().numpy()
_ = plt.hist(nonzero_values, bins=50)


In [None]:
# Compare flattened values 7000..7999 of the first reconstruction under each decoding.
window = slice(7000, 8000)
plt.plot(outputs[0].view(-1, 1).cpu().numpy()[window])
plt.plot(outputs_tilde[0].view(-1, 1).cpu().numpy()[window])
plt.plot(outputs_hat[0].view(-1, 1).cpu().numpy()[window])

In [None]:
#### Saving some samples:
# Save 8 consecutive test images (starting at ind_i) and each reconstruction as one-row grids.
ind_i = 30
os.makedirs('samples', exist_ok=True)  # the samples/ folder may not exist yet
sample_grids = [
    ('samples/inputs.png', inputs),
    ('samples/outputs.png', outputs),
    ('samples/outputs_k64.png', outputs_k64),
    ('samples/tildes.png', outputs_tilde),
    ('samples/hats.png', outputs_hat),
    ('samples/primes_1.png', outputs_prime1),
    ('samples/primes_2.png', outputs_prime2),
]
for grid_path, images in sample_grids:
    utils.save_image(images[ind_i:ind_i + 8], grid_path, nrow=8)

#### Quality assessment:

* PSNR
* SSIM

In [None]:
# Relative squared reconstruction error over the whole batch: ||x - x_hat||^2 / ||x||^2.
residual = inputs - outputs
print(residual.norm().pow(2) / inputs.norm().pow(2))

In [None]:
def _as_hwc(image):
    """Convert one image tensor, assumed (C, H, W), to an (H, W, C) numpy array for skimage."""
    return image.permute(1, 2, 0).cpu().numpy()


# Reconstructions to score against the inputs, keyed by the suffix of their result lists.
reconstructions = {
    'outputs': outputs,
    'outputs_k64': outputs_k64,
    'tildes': outputs_tilde,
    'hats': outputs_hat,
}
psnr_scores = {name: [] for name in reconstructions}
ssim_scores = {name: [] for name in reconstructions}

for i in range(inputs.shape[0]):
    # Convert the reference once per sample rather than once per metric.
    reference = _as_hwc(inputs[i])
    for name, recon in reconstructions.items():
        estimate = _as_hwc(recon[i])
        psnr_scores[name].append(compare_psnr(reference, estimate))
        ssim_scores[name].append(compare_ssim(reference, estimate, multichannel=True))

# Per-sample score lists, kept as separate module-level names.
psnr_outputs, ssim_outputs = psnr_scores['outputs'], ssim_scores['outputs']
psnr_outputs_k64, ssim_outputs_k64 = psnr_scores['outputs_k64'], ssim_scores['outputs_k64']
psnr_tildes, ssim_tildes = psnr_scores['tildes'], ssim_scores['tildes']
psnr_hats, ssim_hats = psnr_scores['hats'], ssim_scores['hats']

In [None]:
# Mean PSNR then mean SSIM for: plain decoding, 64-entry decoding,
# ambiguated code, and random guess.
for scores in (psnr_outputs, ssim_outputs,
               psnr_outputs_k64, ssim_outputs_k64,
               psnr_tildes, ssim_tildes,
               psnr_hats, ssim_hats):
    print(np.mean(scores))

In [None]:
# JPEG baseline: compress each test image at quality 4 and score the decoded result.
from cv2 import imencode, imdecode, IMREAD_UNCHANGED

psnr_jpg = []
ssim_jpg = []
for i in range(inputs.shape[0]):
    # Assumed (C, H, W) tensor in [0, 1] -> (H, W, C) float array.
    img = inputs[i].permute(1, 2, 0).cpu().numpy()
    # Quantize with the 8-bit scale 255 (not 256, which would saturate 1.0 and bias the read-back).
    img_u8 = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)
    # OpenCV treats 3-channel data as BGR, so reverse the channels around encode/decode.
    ok, jpeg_bytes = imencode('.jpg', np.ascontiguousarray(img_u8[:, :, ::-1]),
                              [IMWRITE_JPEG_QUALITY, 4])
    if not ok:
        raise RuntimeError('JPEG encoding failed for image {}'.format(i))
    # IMREAD_UNCHANGED keeps grayscale as 2-D; reshape restores the (H, W, C) layout.
    decoded = imdecode(jpeg_bytes, IMREAD_UNCHANGED).reshape(img_u8.shape)
    image = decoded[:, :, ::-1].astype('float32') / 255
    psnr_jpg.append(compare_psnr(img, image))
    ssim_jpg.append(compare_ssim(img, image, multichannel=True))

In [None]:
# Display `image`, the last JPEG-decoded test image produced by the JPEG baseline loop.
plt.imshow(image)

In [None]:
# Mean PSNR and mean SSIM of the JPEG baseline over the batch.
for jpeg_scores in (psnr_jpg, ssim_jpg):
    print(np.mean(jpeg_scores))

In [None]:
def _support_bits(m, k, L):
    """Return ``m * L * H(k / m)`` in bits, where H is the binary entropy (base 2).

    H(0) = H(1) = 0 by continuity, so ``k == 0`` and ``k == m`` give 0 instead of NaN.

    Raises:
        ValueError: if ``m <= 0`` or ``k`` is outside ``[0, m]``.
    """
    if m <= 0 or not 0 <= k <= m:
        raise ValueError('expected m > 0 and 0 <= k <= m, got m={}, k={}'.format(m, k))
    p = k / m
    if p == 0 or p == 1:
        entropy = 0.0
    else:
        entropy = -p * np.log2(p) - (1 - p) * np.log2(1 - p)
    return entropy * m * L


def calculate_KBytes(m, k, L):
    """Return ``m * L * H(k/m)`` bits converted to KiB (divided by 8 * 1024).

    NOTE(review): presumably the cost of storing the supports of L code maps of
    length m with k non-zeros each — confirm.

    Raises:
        ValueError: if ``m <= 0`` or ``k`` is outside ``[0, m]``.
    """
    return _support_bits(m, k, L) / (8 * 1024)


def calculate_psnr(m, k, L, im_size):
    """Return ``m * L * H(k/m)`` bits divided by ``prod(im_size)``.

    Despite its name, this is a rate in bits per image element (per pixel and
    channel), not a PSNR; the name is kept for existing callers.

    Raises:
        ValueError: if ``m <= 0`` or ``k`` is outside ``[0, m]``.
    """
    return _support_bits(m, k, L) / np.prod(im_size)


print(calculate_KBytes(512, 128, 20))