## Noise Ground:

Having trained some networks, here we play around with the codes by adding some ambiguation noise and inspecting the decoded outputs.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import  DataLoader
from torchvision import utils
#
import dataTools as D
import tools as T
#
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 
from skimage.measure import compare_psnr,compare_ssim
from cv2 import imwrite, imread, IMWRITE_JPEG_QUALITY, COLOR_BGR2RGB
%precision 4

In [None]:
# Global constants shared by the cells below.
batch_size = 1000  # default mini-batch size; the database cell may override it

# Prefer the second GPU, but fall back to the first GPU or the CPU so the
# notebook still runs on single-GPU and CPU-only machines instead of failing
# at the first `.to(device)`.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

### Select your database:

In [None]:
# Choose the evaluation database. Each branch sets the image root folder, the
# text file listing the test split, and the image size tuple (whose first
# entry is the number of channels).
database_name = 'CelebA'
############################
if database_name == 'CelebA':
    batch_size = 100  # overrides the global default for this database
    root = '/path/to/CelebA/128_crop/'
    img_names_list_test = './dataset_splits/CelebA/CelebA_test.txt'
    img_size = (3, 128, 128)
elif database_name == 'CYale':
    root = '/path/to/CYale/'
    img_names_list_test = './dataset_splits/CYale/CYale_test.txt'
    img_size = (1, 168, 192)
else:
    # Fail here with a clear message rather than with a NameError further down.
    raise ValueError(
        "Unknown database_name %r; expected 'CelebA' or 'CYale'." % (database_name,)
    )
num_channel = img_size[0]

### Initialize the database class for the test split:

In [None]:
# Test-split dataset, built from the image root, the split list and the image size.
dataset = D.imgRead_fromList(root, img_names_list_test, img_size)
# Mini-batch loader over that dataset (shuffled, loaded in the main process).
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=0)
dataloader = DataLoader(dataset, **loader_options)

### Loading the trained model:

In [None]:
# Checkpoint of the trained autoencoder to evaluate.
model_path = './weights/CelebA_filts40-40-40-40-40-10_scale1-2-1-2-1-2_codes20_dim512_k256_stmp1587121097.082492.pth'
##########################################  Loading a trained model ##############################
# Architecture settings handed to the Autoencoder constructor; they have to
# describe the same network that produced the checkpoint above.
num_blocks = 6
num_filts = [40, 40, 40, 40, 40, 10]
scale_factor = [1, 2, 1, 2, 1, 2]
num_codes = 20
neck_dim = 512
# NOTE(review): the checkpoint file name contains "k256" while k is set to 128
# here -- confirm which value is intended.
k = 128
#######
from models import Autoencoder
net = Autoencoder(img_size, num_blocks, num_filts, scale_factor, num_codes, neck_dim, k)
# map_location re-targets the stored tensors to `device`, so a checkpoint that
# was saved from a different GPU (or on a GPU machine) still loads here.
net.load_state_dict(torch.load(model_path, map_location=device))
net.eval()
net.to(device)

### Run the network on the test set:

In [None]:
# Forward a single mini-batch through the network, without tracking gradients.
with torch.no_grad():
    for i, batch in enumerate(dataloader):
        inp = batch['image'].to(device)
        out, code = net(inp)
        # Squash the decoder output with a sigmoid.
        out = torch.sigmoid(out)
        # One mini-batch is all the following cells need.
        break
            

### Some basic visualization and evaluation of the reconstruction performance:

In [None]:
idx = 20  # batch index of the sample that gets displayed
ind_i = 10  # first batch index of the 8 samples that get written to disk
D.imShow(inp, idx=idx)
D.imShow(out, idx=idx)

# Relative squared reconstruction error over the whole mini-batch.
print(torch.norm(inp - out).pow(2) / torch.norm(inp).pow(2))

# Make sure the output folder exists before any image is written into it.
os.makedirs('samples', exist_ok=True)
utils.save_image(inp[ind_i:ind_i + 8, :, :, :], 'samples/inputs.png', nrow=8)
utils.save_image(out[ind_i:ind_i + 8, :, :, :], 'samples/reconstructed.png', nrow=8)

### To see the effect of sparsity on the reconstruction performance:

In [None]:
# Decode again after passing a CPU copy of the codes through T.KBest with 64
# (presumably this keeps only 64 entries per code -- confirm in tools.py).
code_prime = code.detach().clone().cpu()
code_prime = T.KBest(code_prime, 64)
# Inference only: decode under no_grad, like the forward pass that produced
# `code`, so no autograd graph is kept alive for the whole batch.
with torch.no_grad():
    out_prime = net.decoder(code_prime.to(device)).sigmoid_()
D.imShow(out_prime, idx=idx)
utils.save_image(out_prime[ind_i:ind_i + 8, :, :, :], 'samples/lessSparsity.png', nrow=8)

### To see the characteristics of code-maps:

In [None]:
# Zero out code-maps i_s .. i_e-1 (slice along axis 1) on a CPU copy of the
# codes, then decode to see what the remaining code-maps carry.
code_prime = code.detach().clone().cpu()
i_s = 0
i_e = 12
code_prime[:, i_s:i_e, :] = 0
# Inference only: decode under no_grad, like the forward pass that produced
# `code`, so no autograd graph is kept alive for the whole batch.
with torch.no_grad():
    out_prime = net.decoder(code_prime.to(device)).sigmoid_()
D.imShow(out_prime, idx=idx)
utils.save_image(out_prime[ind_i:ind_i + 8, :, :, :], 'samples/zeroed.png', nrow=8)

In [None]:
# Same experiment with a different range of code-maps: zero out code-maps
# i_s .. i_e-1 (slice along axis 1) on a CPU copy of the codes, then decode.
code_prime = code.detach().clone().cpu()
i_s = 12
i_e = 20
code_prime[:, i_s:i_e, :] = 0
# Inference only: decode under no_grad, like the forward pass that produced
# `code`, so no autograd graph is kept alive for the whole batch.
with torch.no_grad():
    out_prime = net.decoder(code_prime.to(device)).sigmoid_()
D.imShow(out_prime, idx=idx)
utils.save_image(out_prime[ind_i:ind_i + 8, :, :, :], 'samples/zeroed2.png', nrow=8)

### Ambiguation noise:

As the main application of the paper, let's add some ambiguating noise to the zero coefficients of the code-maps. The idea is to add minimal and indistinguishable noise, while maximally destroying the content of the resulting decoded image.

In [None]:
# Ambiguation noise on the complement of the support:
# Note that the number of added noise values is k_prime - k.
k_prime = 2 * k
code_tilde = code.detach().clone().cpu()
code_tilde = T.ambiguate(code_tilde, k_prime=k_prime)
# Inference only: decode under no_grad, like the forward pass that produced
# `code`, so no autograd graph is kept alive for the whole batch.
with torch.no_grad():
    out_tilde = net.decoder(code_tilde.to(device)).sigmoid_()
D.imShow(out_tilde, idx=idx)
utils.save_image(out_tilde[ind_i:ind_i + 8, :, :, :], 'samples/ambiguated.png', nrow=8)

### Attacks:

As a very basic attack, here we randomly set $k$ out of the $k'$ non-zero values to zero. Without any extra knowledge, the adversary has to make ${k' \choose k}$ such guesses to reconstruct the original image.

Can this system be attacked in other, more intricate ways? In future work, we will study different attacks.

In [None]:
# Basic attack: randomly selecting k out of k' on a CPU copy of the
# ambiguated codes, then decoding the guess.
code_hat = code_tilde.detach().clone().cpu()
code_hat = T.random_guess(code_hat, k)
# Inference only: decode under no_grad, like the forward pass that produced
# `code`, so no autograd graph is kept alive for the whole batch.
with torch.no_grad():
    out_hat = net.decoder(code_hat.to(device)).sigmoid_()
D.imShow(out_hat, idx=idx)
utils.save_image(out_hat[ind_i:ind_i + 8, :, :, :], 'samples/disambiguated.png', nrow=8)

#### Quantitative quality assessment:

* PSNR
* SSIM

In [None]:
def _as_numpy_images(batch):
    """Return one NumPy array per sample of `batch`, with axes 0 and 2 swapped.

    The batch is detached and copied to the CPU once; every sample then gets
    `squeeze(0)` and `transpose(0, 2)` applied before the NumPy conversion.
    """
    return [sample.squeeze(0).transpose(0, 2).numpy() for sample in batch.detach().cpu()]


def _psnr_and_ssim(references, candidates):
    """Return (list of PSNR values, list of SSIM values) for paired images."""
    psnrs = []
    ssims = []
    for ref, cand in zip(references, candidates):
        psnrs.append(compare_psnr(ref, cand))
        ssims.append(compare_ssim(ref, cand, multichannel=True))
    return psnrs, ssims


# NOTE(review): compare_psnr / compare_ssim come from skimage.measure as
# imported at the top of the notebook; newer scikit-image releases appear to
# provide these metrics under skimage.metrics instead -- confirm against the
# installed version.

# The input batch is the reference for all three comparisons, so convert it once.
reference_images = _as_numpy_images(inp)

# Plain reconstructions.
psnr_outputs, ssim_outputs = _psnr_and_ssim(reference_images, _as_numpy_images(out))

# Reconstructions decoded from the ambiguated codes.
psnr_tildes, ssim_tildes = _psnr_and_ssim(reference_images, _as_numpy_images(out_tilde))

# Reconstructions decoded from the randomly guessed codes.
psnr_hats, ssim_hats = _psnr_and_ssim(reference_images, _as_numpy_images(out_hat))

############################
# Batch averages, printed as PSNR then SSIM for each of the three cases.
print(np.mean(psnr_outputs))
print(np.mean(ssim_outputs))


print(np.mean(psnr_tildes))
print(np.mean(ssim_tildes))

print(np.mean(psnr_hats))
print(np.mean(ssim_hats))

In [None]:
# JPEG baseline: write every input image as a JPEG at quality 4, read it back
# and score the round trip with the same PSNR / SSIM calls.
psnr_jpg = []
ssim_jpg = []
for i in range(inp.shape[0]):

    img = inp[i, :, :, :].squeeze(0).transpose(0, 2).cpu().numpy()
    # Quantize to 8 bits explicitly with the 0..255 scale; multiplying by 256
    # pushed a pixel value of 1.0 past the 8-bit range.
    img_u8 = np.ascontiguousarray(np.clip(np.rint(img * 255), 0, 255).astype(np.uint8))
    imwrite('tmp.jpg', img_u8, [IMWRITE_JPEG_QUALITY, 4])
    # COLOR_BGR2RGB is a cvtColor conversion code, not an imread flag, so it is
    # no longer passed here. No channel swap is needed: the file is read back
    # with the same channel order it was written with.
    image = imread('tmp.jpg')
    # Undo the 8-bit scaling with the same factor used when writing.
    image = image.astype('float32') / 255
    psnr_jpg.append(compare_psnr(img, image))
    ssim_jpg.append(compare_ssim(img, image, multichannel=True))


print(np.mean(psnr_jpg))
print(np.mean(ssim_jpg))

In [None]:
## How many KBytes per compressed image:
# The model settings replace the literals 512, 128 and 20 so this figure
# follows the configuration; with the current settings the values are the same.
# NOTE(review): assumes the arguments are (code dimension, non-zeros per code,
# number of codes) -- confirm against tools.calculate_KBytes.
print(T.calculate_KBytes(neck_dim, k, num_codes))

In [None]:
## The ratio of the compressed image to the key:
# The model settings replace the literals 512, 128 and 20 so this ratio
# follows the configuration; with the current settings the values are the same.
# NOTE(review): assumes the last two arguments are (non-zeros per code,
# number of codes) -- confirm against tools.calculate_KBytes.
compressed_kbytes = T.calculate_KBytes(neck_dim, k, num_codes)
key_kbytes = T.calculate_KBytes(k_prime, k, num_codes)
print(compressed_kbytes / key_kbytes)