In [1]:
import numpy as np 
import torch
import matplotlib.pyplot as plt
from imageio import imread, imwrite
from torch import nn
import random

In [2]:
import sys
sys.path.append("..")
# from steganogan.decoders import DenseDecoderNLayers
from steganogan.decoders import BasicDecoder, DenseDecoder, DenseDecoderNLayers
from steganogan import SteganoGAN

In [3]:
import torch
import torchvision
from torch.optim import LBFGS
import torch.nn.functional as F

In [4]:
from tqdm import tqdm, trange

In [5]:
def shuffle_params(m):
    if type(m)==nn.Conv2d or type(m)==nn.BatchNorm2d:
        param = m.weight
        m.weight.data = nn.Parameter(torch.tensor(np.random.normal(0, 1, param.shape)).float())
        
        param = m.bias
        m.bias.data = nn.Parameter(torch.zeros(len(param.view(-1))).float().reshape(param.shape))
    if type(m)==nn.BatchNorm2d:
        if "track_running_stats" in m.__dict__:
            m.track_running_stats=False
#         

In [6]:
class normLayer(nn.Module):
    def __init__(self):
        super(normLayer, self).__init__()
    def forward(self, x):
        b,c,h,w = x.shape
        assert b == 1
        mean = x.view(c, -1).mean(-1)
        std = x.view(c, -1).std(-1)
        x = x - mean.reshape([1, c, 1, 1])
        x = x / (std + 1e-7).reshape([1,c,1,1])
        return x

In [7]:
class BasicDecoder(nn.Module):
    """
    The BasicDecoder module takes an steganographic image and attempts to decode
    the embedded data tensor.

    Input: (N, 3, H, W)
    Output: (N, D, H, W)
    """

    def _conv2d(self, in_channels, out_channels):
        return nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            padding=1
        )

    def _build_models(self):
        modules = []

        modules.append(self._conv2d(3, self.hidden_size))
        modules.append(nn.LeakyReLU(inplace=True))
        modules.append(normLayer() if self.yan_norm else nn.BatchNorm2d(self.hidden_size))

        for i in range(self.layers-1):
            modules.append(self._conv2d(self.hidden_size, self.hidden_size))
            modules.append(nn.LeakyReLU(inplace=True))
            modules.append(normLayer() if self.yan_norm else nn.BatchNorm2d(self.hidden_size))

        modules.append(self._conv2d(self.hidden_size, self.data_depth))

        self.layers = nn.Sequential(*modules)

        return [self.layers]    

    def __init__(self, data_depth, hidden_size, layers = 3, yan_norm=False):
        super().__init__()
        self.version = '1'
        self.data_depth = data_depth
        self.hidden_size = hidden_size
        self.yan_norm = yan_norm
        self.layers = layers

        self._models = self._build_models()

    def forward(self, x):
        x = self._models[0](x)

        if len(self._models) > 1:
            x_list = [x]
            for layer in self._models[1:]:
                x = layer(torch.cat(x_list, dim=1))
                x_list.append(x)

        return x

In [8]:
num_bits = 3
yan_norm = False
# models

model = BasicDecoder(num_bits, hidden_size=128, layers = 3, yan_norm=yan_norm)
model.apply(shuffle_params)
model.to('cuda')

BasicDecoder(
  (layers): Sequential(
    (0): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.01, inplace)
    (2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.01, inplace)
    (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (6): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): LeakyReLU(negative_slope=0.01, inplace)
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
    (9): Conv2d(128, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [9]:
model.layers[2].weight.requires_grad

True

In [10]:
# load image
image = "/home/vk352/FaceDetection/datasets/sample/obama2.jpg"
image = imread(image, pilmode='RGB') 
image.shape

(612, 450, 3)

In [11]:
# extract a bit vector
# image = "/home/vk352/FaceDetection/datasets/sample/obama2.jpg"
image = "/home/vk352/FaceDetection/datasets/data512x512/00001.jpg"
image = imread(image, pilmode='RGB') / 255.0
image = torch.FloatTensor(image).permute(2, 1, 0).unsqueeze(0)
image = image.to('cuda')
out = model(image)
# image = self.decoder(image).view(-1) > 0

In [12]:
target = torch.bernoulli(torch.empty(out.shape).uniform_(0, 1)).to(out.device)
# target = torch.empty(out.shape).random_(256).to(out.device)
target.shape

torch.Size([1, 3, 512, 512])

In [13]:
criterion = torch.nn.BCEWithLogitsLoss(reduction='sum')
criterion1 = torch.nn.L1Loss(reduction='sum')
criterion2 = torch.nn.MSELoss(reduction='sum')

def get_loss(outputs, target, loss_mode):
    if loss_mode == "BCE":
        loss = criterion(outputs, target)
    elif loss_mode == "log":
        loss = -(target * 2 - 1) * outputs
        loss = torch.nn.functional.softplus(loss)  # log(1+exp(x))
        loss = torch.sum(loss)
    elif loss_mode == "hingelog":
        loss = -(target * 2 - 1) * outputs
        loss = torch.nn.functional.softplus(loss)  # log(1+exp(x))
        loss = torch.max(loss-hinge, torch.zeros(target.shape).to(target.device))
        loss = torch.sum(loss)
    elif loss_mode == "L1":
        outputs = F.sigmoid(outputs) * 255
        loss = criterion1(outputs, target)
    elif loss_mode == "L2":
        outputs = F.sigmoid(outputs) * 255
        loss = criterion2(outputs, target)
    return loss

In [14]:
loss_mode = "log"
hinge = 0.3

In [15]:
# lbfgs

final_err = 0

steps = 1000
eps = 0.2
adv_image = image.clone().detach()
max_iter = 20
alpha = 0.5


adv_image = image.clone().detach()
print("alpha:", alpha)
error = []

for i in trange(steps // max_iter):
    adv_image.requires_grad = True
    optimizer = LBFGS([adv_image], lr=alpha, max_iter=max_iter)

    def closure():
        outputs = model(adv_image)
        loss = get_loss(outputs, target, loss_mode)

        optimizer.zero_grad()
        loss.backward()
        return loss

    optimizer.step(closure)
    delta = torch.clamp(adv_image - image, min=-eps, max=eps)
    adv_image = torch.clamp(image + delta, min=0, max=1)
    adv_image = torch.clamp(adv_image*255, 0, 255).int().float()/255.
    adv_image = adv_image.detach()

    if loss_mode in ["L1", "L2"]:
        err = len(torch.nonzero(torch.abs(F.sigmoid(model(adv_image)).float().view(-1)*255-target.view(-1)) > 128)) / target.numel()
    else:
        err = len(torch.nonzero((model(adv_image)>0).float().view(-1) != target.view(-1))) / target.numel()
    print("error", err)
    error.append(err)

final_err = error[-1]


  0%|          | 0/50 [00:00<?, ?it/s]

alpha: 0.5


  2%|▏         | 1/50 [00:01<01:09,  1.41s/it]

error 0.2351099650065104


  4%|▍         | 2/50 [00:02<01:06,  1.39s/it]

error 0.14502588907877603


  6%|▌         | 3/50 [00:04<01:05,  1.39s/it]

error 0.09476470947265625


  8%|▊         | 4/50 [00:05<01:03,  1.39s/it]

error 0.06740824381510417


 10%|█         | 5/50 [00:06<01:02,  1.39s/it]

error 0.049496968587239586


 12%|█▏        | 6/50 [00:08<01:00,  1.39s/it]

error 0.039553324381510414


 14%|█▍        | 7/50 [00:09<00:59,  1.39s/it]

error 0.03298314412434896


 16%|█▌        | 8/50 [00:11<00:58,  1.39s/it]

error 0.028635660807291668


 18%|█▊        | 9/50 [00:12<00:56,  1.39s/it]

error 0.02581024169921875


 20%|██        | 10/50 [00:13<00:55,  1.39s/it]

error 0.023907979329427082


 22%|██▏       | 11/50 [00:15<00:54,  1.39s/it]

error 0.022412618001302082


 24%|██▍       | 12/50 [00:16<00:52,  1.39s/it]

error 0.021086374918619793


 26%|██▌       | 13/50 [00:18<00:51,  1.39s/it]

error 0.020894368489583332


 28%|██▊       | 14/50 [00:19<00:50,  1.39s/it]

error 0.019803365071614582


 30%|███       | 15/50 [00:20<00:48,  1.39s/it]

error 0.019182840983072918


 32%|███▏      | 16/50 [00:22<00:47,  1.39s/it]

error 0.019276936848958332


 34%|███▍      | 17/50 [00:23<00:46,  1.40s/it]

error 0.0188751220703125


 36%|███▌      | 18/50 [00:25<00:44,  1.40s/it]

error 0.018503824869791668


 38%|███▊      | 19/50 [00:26<00:43,  1.40s/it]

error 0.018362681070963543


 40%|████      | 20/50 [00:27<00:41,  1.40s/it]

error 0.018404642740885418


 42%|████▏     | 21/50 [00:29<00:40,  1.40s/it]

error 0.018187204996744793


 44%|████▍     | 22/50 [00:30<00:39,  1.40s/it]

error 0.018213907877604168


 46%|████▌     | 23/50 [00:32<00:37,  1.40s/it]

error 0.018128712972005207


 48%|████▊     | 24/50 [00:33<00:36,  1.40s/it]

error 0.01802825927734375


 50%|█████     | 25/50 [00:34<00:34,  1.40s/it]

error 0.018062591552734375


 52%|█████▏    | 26/50 [00:36<00:33,  1.40s/it]

error 0.017923990885416668


 54%|█████▍    | 27/50 [00:37<00:32,  1.40s/it]

error 0.018230438232421875


 56%|█████▌    | 28/50 [00:39<00:30,  1.40s/it]

error 0.018352508544921875


 58%|█████▊    | 29/50 [00:40<00:29,  1.40s/it]

error 0.018376668294270832


 60%|██████    | 30/50 [00:41<00:27,  1.40s/it]

error 0.01871490478515625


 62%|██████▏   | 31/50 [00:43<00:26,  1.40s/it]

error 0.018583933512369793


 64%|██████▍   | 32/50 [00:44<00:25,  1.40s/it]

error 0.01902008056640625


 66%|██████▌   | 33/50 [00:46<00:23,  1.40s/it]

error 0.019003550211588543


 68%|██████▊   | 34/50 [00:47<00:22,  1.40s/it]

error 0.019502003987630207


 70%|███████   | 35/50 [00:48<00:20,  1.40s/it]

error 0.019505818684895832


 72%|███████▏  | 36/50 [00:50<00:19,  1.40s/it]

error 0.019772847493489582


 74%|███████▍  | 37/50 [00:51<00:18,  1.40s/it]

error 0.0202178955078125


 76%|███████▌  | 38/50 [00:53<00:16,  1.40s/it]

error 0.020182291666666668


 78%|███████▊  | 39/50 [00:54<00:15,  1.40s/it]

error 0.020758310953776043


 80%|████████  | 40/50 [00:55<00:14,  1.40s/it]

error 0.020758310953776043


 82%|████████▏ | 41/50 [00:57<00:12,  1.40s/it]

error 0.021291097005208332


 84%|████████▍ | 42/50 [00:58<00:11,  1.40s/it]

error 0.021068572998046875


 86%|████████▌ | 43/50 [01:00<00:09,  1.40s/it]

error 0.0215911865234375


 88%|████████▊ | 44/50 [01:01<00:08,  1.40s/it]

error 0.021814982096354168


 90%|█████████ | 45/50 [01:02<00:07,  1.40s/it]

error 0.022150675455729168


 92%|█████████▏| 46/50 [01:04<00:05,  1.40s/it]

error 0.022486368815104168


 94%|█████████▍| 47/50 [01:05<00:04,  1.40s/it]

error 0.023042043050130207


 96%|█████████▌| 48/50 [01:07<00:02,  1.40s/it]

error 0.023382822672526043


 98%|█████████▊| 49/50 [01:08<00:01,  1.41s/it]

error 0.023919423421223957


100%|██████████| 50/50 [01:09<00:00,  1.40s/it]

error 0.024180094401041668





In [17]:
# plt.figure(figsize=(16, 10))
# # plt.plot(range(1000), np.array(bceacc) * 100, label="SGD")
# # for lr, err in final_err:
# plt.plot(np.arange(1, steps // max_iter + 1) * max_iter, np.array(err[:steps // max_iter]) * 100, label=f"LBFGS lr {lr}")
# plt.legend()
# plt.ylabel("Error Rate (%)")
# plt.xlabel("iterations")
# plt.title(f"{loss_mode} loss, {num_bits} bits, yan_norm {yan_norm}")
# # plt.ylim(0, 0.01)
# plt.show()