In [1]:
import os
import argparse
import numpy as np
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset,DataLoader

import torchvision.utils as utils
import torchvision.transforms as transforms
import torchvision.models as models


import copy
import glob
import torch.utils.data as udata
import h5py
import cv2
import random
import pywt
import matplotlib.pyplot as plt
from PIL import Image
import math
from skimage.measure.simple_metrics import compare_psnr
import scipy.io as sio
import random
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
plt.rcParams['font.sans-serif'] = [u'SimHei']
plt.rcParams['axes.unicode_minus'] = False
%matplotlib inline

In [2]:
# (Translated from the original note.) This code feeds the style-synthesis
# result directly into the residual network; the mean squared error between the
# original image and the "Xidian" image is used as the loss, producing the
# Xidian image.  NOTE(review): only inference happens in the cells shown here;
# batch_size / lr / target_path are not referenced by them.
opts = argparse.Namespace(
    image_channel=3,  # NOTE(review): original note asked "why is the channel count 6?" — the value is 3
    batch_size=4,
    lr=1e-3,
    target_path='l2.png',  # target image
)

use_cuda = torch.cuda.is_available()
device = "cuda:0" if use_cuda else "cpu"
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
# Side length every image is resized to; the same value is used with or without a GPU.
imsize = 128
print(use_cuda)
print(device)

True
cuda:0


In [3]:

# Resize every image to imsize x imsize, then convert it to a CHW tensor.
loader = transforms.Compose([
    transforms.Resize([imsize, imsize]),  # scale imported image
    transforms.ToTensor()])  # transform it into a torch tensor


def image_loader(image_name):
    """Load an image file as a (1, 3, imsize, imsize) tensor via ``loader``.

    The image is forced to RGB so grayscale, palette and RGBA files also yield
    the 3 channels the network's first convolution (and ``image_unloader``)
    expect.

    Args:
        image_name: path of the image file to read.

    Returns:
        Tensor of shape (1, 3, imsize, imsize).
    """
    # The context manager closes the file handle; convert() materialises the
    # pixel data, so the converted copy stays valid after the file is closed.
    with Image.open(image_name) as img:
        rgb = img.convert('RGB')
    # Add the fake batch dimension required by the network's input.
    return loader(rgb).unsqueeze(0)

unloader = transforms.ToPILImage()  # reconvert into PIL image


def image_unloader(tensor):
    """Convert a single-image tensor back into a PIL image.

    Accepts a (3, H, W) tensor or a (1, 3, H, W) tensor with a fake batch
    dimension, on any device, with or without autograd history.

    Args:
        tensor: image tensor holding exactly 3 * H * W elements.

    Returns:
        PIL image produced by ``unloader``.
    """
    # detach() drops autograd history; clone() leaves the caller's tensor untouched.
    image = tensor.detach().clone().cpu()
    # Remove the fake batch dimension.  H and W come from the tensor itself
    # (not the global imsize), so other image sizes also work; reshape also
    # copes with non-contiguous input.
    image = image.reshape(3, image.shape[-2], image.shape[-1])
    return unloader(image)


def imshow(tensor, title=None):
    """Display a single-image tensor with matplotlib, optionally with a title."""
    plt.imshow(image_unloader(tensor))
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

def pil2tensor(pil_img):
    """Turn an in-memory PIL image into a batched tensor using ``loader``.

    Returns a tensor with a leading batch axis of size 1, as the network expects.
    """
    batched = Variable(loader(pil_img)).unsqueeze(0)  # prepend the fake batch axis
    return batched

In [4]:
# Load the verification dataset; loadmat's default appendmat=True adds ".mat".
mat = sio.loadmat('ver')
x, x_name = mat['img'], mat['img_name'].squeeze()  # images, and their file names

In [5]:
# 3x3 convolution helper — start of the network definition
# Residual network
def conv3x3(in_channels, out_channels, stride=1):
    """Build a bias-free 3x3 convolution with padding 1.

    With stride=1 the spatial size is preserved; a larger stride downsamples.
    """
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)

# Residual block
class ResidualBlock(nn.Module):
    """Basic residual block: conv3x3-BN-ReLU-conv3x3-BN, add shortcut, ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution.
        downsample: optional module applied to the input to build the shortcut.
            Any falsy value (e.g. None) means an identity shortcut, in which
            case stride must be 1 and in_channels == out_channels so that the
            addition shapes match.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super().__init__()
        # These attribute names become state_dict keys; keep them stable.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        # Main path: conv -> BN -> ReLU -> conv -> BN.
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        # Shortcut path: identity unless a downsample module was supplied.
        shortcut = self.downsample(x) if self.downsample else x
        y += shortcut
        return self.relu(y)

# ResNet
class DenoiseNet2(nn.Module):
    """Encoder / densely connected residual trunk / decoder network.

    Maps a (N, 3, H, W) batch to a (N, 3, H, W) batch squashed by a final
    Sigmoid.  The single 2x max-pool in ``e1`` is undone by the stride-2
    transposed convolution in ``l9``, so H and W should be even for the output
    size to equal the input size.

    Args:
        block: residual block class, called as
            ``block(in_channels, out_channels, stride, downsample)``
            (e.g. ``ResidualBlock``).

    Shape comments give per-sample (C, H, W).  The attribute names (e1, e2,
    l1 ... l10) and the layer order inside each Sequential form the
    state_dict keys, so they must not be renamed or reordered.
    """

    def __init__(self, block):
        super(DenoiseNet2, self).__init__()
        self.e1 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )  # -> (64, H/2, W/2)
        self.e2 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )  # -> (128, H/2, W/2)
        self.l1 = self.make_block(block, 128, 128)  # -> (128, H/2, W/2)
        # l2..l5 are densely connected: each takes the concatenation of l1 and
        # all previous lN outputs along the channel axis (see forward).
        self.l2 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(128, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  # -> (64, H/2, W/2)
        self.l3 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(128 + 64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  # -> (64, H/2, W/2)
        self.l4 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(128 + 64 + 64, 64, 3, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  # -> (64, H/2, W/2)
        self.l5 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(128 + 64 + 64 + 64, 64, 1, 1, 0),  # 1x1 fusion conv
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  # -> (64, H/2, W/2)
        self.l6 = self.make_block(block, 128 + 64, 128 + 64)  # -> (192, H/2, W/2)
        self.l7 = self.make_block(block, 128 + 64, 128 + 64)  # -> (192, H/2, W/2)
        self.l8 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(128 + 64, 128, 3, 1, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
        )  # -> (128, H/2, W/2)
        self.l9 = nn.Sequential(
            # [input_c, output_c, kernel_size, stride, padding, output_padding]:
            # stride 2 with output_padding 1 doubles the spatial size back.
            nn.ConvTranspose2d(128 + 128, 64, 3, 2, 1, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )  # -> (64, H, W)
        self.l10 = nn.Sequential(
            # param [input_c, output_c, kernel_size, stride, padding]
            nn.Conv2d(64, 3, 1, 1, 0),
            nn.Sigmoid()
        )  # -> (3, H, W), values squashed by Sigmoid

    def make_block(self, block, in_c, out_c):
        """Wrap one stride-1 ``block`` in an nn.Sequential.

        The Sequential wrapper keeps the ``lN.0.*`` state_dict key layout.
        """
        # Pass downsample=None (ResidualBlock's default, meaning identity
        # shortcut) instead of False; both are falsy, so the shortcut is unchanged.
        layers = [block(in_c, out_c, 1, None)]
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network on a (N, 3, H, W) batch; returns (N, 3, H, W)."""
        e1 = self.e1(x)                               # (64, H/2, W/2)
        e2 = self.e2(e1)                              # (128, H/2, W/2)
        l1 = self.l1(e2)                              # (128, H/2, W/2)
        l2 = self.l2(l1)                              # (64, H/2, W/2)
        in_l3 = torch.cat((l1, l2), dim=1)            # 128 + 64 channels
        l3 = self.l3(in_l3)
        in_l4 = torch.cat((l1, l2, l3), dim=1)        # 128 + 64 + 64 channels
        l4 = self.l4(in_l4)
        in_l5 = torch.cat((l1, l2, l3, l4), dim=1)    # 128 + 3 * 64 channels
        l5 = self.l5(in_l5)
        in_l6 = torch.cat((l1, l5), dim=1)            # 128 + 64 channels
        l6 = self.l6(in_l6)
        l7 = self.l7(l6)
        l8 = self.l8(l7)                              # (128, H/2, W/2)
        # Skip connection from the encoder output before upsampling.
        in_l9 = torch.cat((e2, l8), dim=1)            # 128 + 128 channels
        l9 = self.l9(in_l9)                           # (64, H, W)
        l10 = self.l10(l9)                            # (3, H, W)
        return l10

In [6]:

# Build the extraction network from ResidualBlock units and move it to the selected device.
net = DenoiseNet2(block=ResidualBlock)
net = net.to(device)

In [7]:
# Load the trained weights.  map_location remaps tensors saved on another device
# (e.g. a GPU checkpoint opened on a CPU-only machine) onto `device`.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
net.load_state_dict(torch.load('extract_l2_1.pth', map_location=device))

<All keys matched successfully>

In [8]:
# Run the network on every image of the verification set.  For each image, save
# the network input, the network output, and the matching style source image.
source_path = 'style'  # source directory: holds all the style images
save_path = './verify/zero-watermark'  # network inputs: attacked composite images
save_res_path = './verify/res'  # network outputs: all extraction results
save_source_path = './verify/use_source'  # style sources used for the composites
dataset = x
dataset_name = x_name

# PIL's Image.save() does not create missing directories.
for out_dir in (save_path, save_res_path, save_source_path):
    os.makedirs(out_dir, exist_ok=True)

# Inference only: eval() makes BatchNorm use the loaded running statistics
# instead of the statistics of the single current image, and no_grad() avoids
# building an autograd graph for every forward pass.
net.eval()
with torch.no_grad():
    for i in range(dataset.shape[0]):
        # .strip() removes padding around the stored name
        # (presumably MATLAB char-array padding).
        img_name = dataset_name[i].strip()
        # Use a loop-local name rather than `x`.  Overwriting the module-level
        # `x` loaded from ver.mat would make `dataset` a PIL image on re-run.
        # float32 matches the network weights.  NOTE(review): assumes
        # dataset[i] is a (3, H, W) array already scaled like ToTensor output
        # ([0, 1]) — TODO confirm.
        inp = torch.tensor(dataset[i], dtype=torch.float32).unsqueeze(0).to(device)
        pred = net(inp)
        image_unloader(inp).save(os.path.join(save_path, img_name))  # attacked composite image
        image_unloader(pred).save(os.path.join(save_res_path, img_name))  # extraction result
        # Round-trip the style source through image_loader so it is saved at
        # imsize x imsize.  (The original author flagged this step:
        # "something feels off here".)
        source = image_loader(os.path.join(source_path, img_name))
        image_unloader(source).save(os.path.join(save_source_path, img_name))

In [9]:
# Disabled preview: runs the network on every file in IMAGE_PATH and shows each
# input/output pair with imshow.  Uncomment to use.
# NOTE(review): os.path.isdir(img_path) below checks the bare file name relative
# to the current working directory, not IMAGE_PATH.  It should presumably be
# os.path.isdir(os.path.join(IMAGE_PATH, img_path)).  Consider net.eval() and
# torch.no_grad() around the forward pass as well.
# IMAGE_PATH = './st_test'
# # SAVE_PATH = './res'
# # source_path = 'res1' # source: holds all the style images
# # save_path = './test1' # training input images: attacked composite images
# # save_res_path = './test_res1' # outputs: all extraction results
# # save_source_path = './test_source1' # sources used for the training composites

# image_pathes = os.listdir(IMAGE_PATH)
# images = []
# for img_path in image_pathes:
#     if os.path.isdir(img_path):
#         continue
#     tmp = image_loader(os.path.join(IMAGE_PATH,img_path)).detach().to(device)
#     pred = net(tmp)
#     print('_____________________________________')
#     imshow(tmp,"input")
#     imshow(pred,"output")