In [3]:
# import relevant packages
import numpy as np 
import torch
import matplotlib.pyplot as plt
from imageio import imread, imwrite
from torch import nn
import random
import argparse
from PIL import Image
from skimage.metrics import peak_signal_noise_ratio
from skimage.metrics import structural_similarity
from steganogan import SteganoGAN

import torch
from torch.optim import LBFGS
import torch.nn.functional as F

# Set seeds. The target message is drawn with torch.bernoulli, so torch's RNG must be
# seeded for the message (and everything optimized against it) to be reproducible;
# numpy and Python's `random` are seeded as well so every imported RNG is covered.
seed = 11111
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)  # also seeds CUDA generators

ModuleNotFoundError: No module named 'DiffJPEG'

In [None]:
# Set parameters.
# `mode` selects the FNNS variant (refer to the paper for details):
#   "pretrained-d"  - start from the clean cover image and optimize it against a pretrained decoder
#   "pretrained-de" - start from the image SteganoGAN's encoder produced, then optimize it
#   "random"        - listed by the paper; NOTE(review): no cell below branches on "random",
#                     so in this notebook it currently behaves like "pretrained-d"
MODES = ("random", "pretrained-de", "pretrained-d")
mode = "pretrained-d"
# Fail fast on a misspelled mode instead of silently falling back to the cover image.
assert mode in MODES, f"mode must be one of {MODES}, got {mode!r}"

steps = 2000    # optimization budget: the loop runs steps // max_iter L-BFGS rounds
max_iter = 10   # max L-BFGS iterations per optimizer.step() call
alpha = 0.1     # L-BFGS learning rate
eps = 0.3       # L-inf bound on the per-pixel perturbation (pixel values are scaled to [0, 1])
num_bits = 1    # number of message channels in the target tensor (bits per pixel)

# some pre-trained steganoGAN models can be found here: https://drive.google.com/drive/folders/1-U2NDKUfqqI-Xd5IqT1nkymRQszAlubu?usp=sharing
model_path = "/home/vk352/FaceDetection/SteganoGAN/research/models/celeba_basic_1_1_mse10.steg"


In [None]:
# Cover image to hide the message in, and the file SteganoGAN's encoder writes to.
input_im = "/home/vk352/FaceDetection/datasets/div2k/val/512/0801.jpg"
output_im = "steganographic.png"

# Pretrained SteganoGAN on the GPU; its decoder is the network FNNS later optimizes against.
steganogan = SteganoGAN.load(
    path=model_path,
    cuda=True,
    verbose=True,
)

In [None]:
# Encode a message with SteganoGAN, then build the starting image tensor for FNNS.
inp_image = imread(input_im, pilmode='RGB')

# Random binary message; you can add a custom target message here instead.
# Shape (1, num_bits, W, H): spatial axes are swapped to match the (C, W, H) layout of `image` below.
target = torch.bernoulli(torch.empty(1, num_bits, inp_image.shape[1], inp_image.shape[0]).uniform_(0, 1)).to('cuda')

steganogan.encode(input_im, output_im, target)
output = steganogan.decode(output_im)

# "pretrained-de" starts from SteganoGAN's encoded image; the other modes start from the clean
# cover, which is already loaded as `inp_image`, so it is not read from disk a second time.
if mode == "pretrained-de":
    start_pixels = imread(output_im, pilmode='RGB')
else:
    start_pixels = inp_image

# (H, W, 3) uint8 -> float in [0, 1] -> (1, C, W, H) float32 tensor on the GPU.
image = torch.from_numpy(np.asarray(start_pixels) / 255.0).float().permute(2, 1, 0).unsqueeze(0)
image = image.to('cuda')

In [None]:
# Initial statistics: distortion introduced by SteganoGAN's encoder and the bit error rate of
# its own decoder on the encoded file.
im1 = np.asarray(imread(input_im, pilmode='RGB'), dtype=float)
im2 = np.asarray(imread(output_im, pilmode='RGB'), dtype=float)
print("PSNR:", peak_signal_noise_ratio(im1, im2, data_range=255))
# channel_axis replaces the `multichannel` flag deprecated in scikit-image 0.19.
print("SSIM:", structural_similarity(im1, im2, data_range=255, channel_axis=-1))
# Fraction of message bits that differ from the target (Python 3 `/` is already true division).
err = (target != output.float()).sum().item() / target.numel()
print("Initial error:", err)

In [None]:
# FNNS optimization: optimize the pixels of `adv_image` directly with L-BFGS so the decoder's
# logits match `target`, projecting after every round onto an L-inf ball around `image` and [0, 1].
# Only `adv_image` is handed to the optimizer, so the decoder's weights are never updated.
model = steganogan.decoder
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')

# Forward pass used only to discover the decoder's device; no autograd graph is needed.
with torch.no_grad():
    out = model(image)
target = target.to(out.device)

# Adaptive perturbation-budget schedule (previously inline magic numbers).
RELAXED_EPS = 0.7        # budget allowed while only a few bits are still wrong
NEAR_ZERO_ERR = 0.00001  # e.g. for a 512x512 image with num_bits=1, one wrong bit is ~3.8e-6
SUCCESS_ROUNDS = 10      # stop after this many rounds that decode with zero bit errors

# Work on a local budget so the configured `eps` is never overwritten; re-running this cell
# therefore always starts from the configured value, and the reset below honours it.
cur_eps = eps
count = 0

adv_image = image.clone().detach()

for _ in range(steps // max_iter):
    adv_image.requires_grad = True
    optimizer = LBFGS([adv_image], lr=alpha, max_iter=max_iter)

    def closure():
        # L-BFGS may evaluate the closure several times per step; reset gradients on each call.
        optimizer.zero_grad()
        loss = criterion(model(adv_image), target)
        loss.backward()
        return loss

    optimizer.step(closure)

    with torch.no_grad():
        # Projection: clip the perturbation to [-cur_eps, cur_eps], then the pixels to [0, 1].
        delta = torch.clamp(adv_image - image, min=-cur_eps, max=cur_eps)
        adv_image = torch.clamp(image + delta, min=0, max=1).detach()
        # Bit error rate: a bit decodes as 1 when its logit is positive.
        predicted_bits = (model(adv_image) > 0).float().view(-1)
        err = (predicted_bits != target.view(-1)).sum().item() / target.numel()
    print("Error:", err)

    if err == 0:
        # Perfect decode: count it and restore the configured budget (was a hard-coded 0.3).
        count += 1
        cur_eps = eps
    elif err < NEAR_ZERO_ERR:
        # Only a handful of bits wrong: temporarily allow a larger perturbation.
        cur_eps = RELAXED_EPS
    if count == SUCCESS_ROUNDS:
        break

In [None]:
# Final statistics: distortion and bit error of the FNNS image, first as optimized (float),
# then after quantizing to 8 bits, writing it to disk and reading it back.
cover = np.asarray(imread(input_im, pilmode='RGB'), dtype=float)
# (1, C, W, H) tensor -> (H, W, C) array in [0, 1], undoing the permute used to build `image`.
adv_hwc = adv_image.detach().squeeze().permute(2, 1, 0).cpu().numpy()

print("PSNR:", peak_signal_noise_ratio(cover, adv_hwc * 255, data_range=255))
# channel_axis replaces the `multichannel` flag deprecated in scikit-image 0.19.
print("SSIM:", structural_similarity(cover, adv_hwc * 255, data_range=255, channel_axis=-1))
print("Error:", err)

# Round to the nearest 8-bit level: a bare astype(np.uint8) truncates, pushing every pixel
# down by up to one level, which lowers PSNR and can flip decoded bits.
lbfgsimg = np.round(adv_hwc * 255).astype(np.uint8)

# NOTE: this overwrites the SteganoGAN-encoded image previously written to `output_im`, so
# re-running the initial-statistics cell afterwards measures the FNNS image instead.
Image.fromarray(lbfgsimg).save(output_im)
image_read = np.asarray(imread(output_im, pilmode='RGB')) / 255.0
image_read = torch.from_numpy(image_read).float().permute(2, 1, 0).unsqueeze(0).to('cuda')

print("\nAfter writing to file and reading from file")
im1 = cover
im2 = np.asarray(imread(output_im, pilmode='RGB'), dtype=float)
print("PSNR:", peak_signal_noise_ratio(im1, im2, data_range=255))
print("SSIM:", structural_similarity(im1, im2, data_range=255, channel_axis=-1))
with torch.no_grad():
    read_bits = (model(image_read) > 0).float().view(-1)
print("Error:", (read_bits != target.view(-1)).sum().item() / target.numel())