In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

import net

def train_transform(size=(256, 256)):
    """Build the preprocessing pipeline shared by content and style images.

    Args:
        size: Target size forwarded to ``transforms.Resize``. A ``(h, w)``
            pair resizes to exactly that shape (aspect ratio not preserved);
            an int matches the shorter edge instead. The default
            ``(256, 256)`` is the value this notebook has always used.

    Returns:
        A ``transforms.Compose`` that resizes a PIL image and then converts
        it to a tensor with ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(size=size),
        transforms.ToTensor(),
    ])


class MyDataset(data.Dataset):
    """Flat image-folder dataset that yields ``transform(rgb_image)`` per file.

    Files are discovered once, non-recursively, in ``path``. The list is
    sorted so a given index always maps to the same file; ``Path.glob``
    returns entries in filesystem order, which is not guaranteed to be stable.

    Args:
        path: Directory to scan.
        transform: Callable applied to each image after it is converted to
            RGB; its return value is what ``__getitem__`` yields.
        patterns: A glob pattern, or an iterable of glob patterns, selecting
            which files to load. Defaults to ``'*.jpg'``. On case-sensitive
            filesystems this does not match ``.JPG``/``.jpeg``/``.png``; pass
            e.g. ``('*.jpg', '*.jpeg', '*.png')`` to widen it.
    """

    def __init__(self, path, transform, patterns='*.jpg'):
        super(MyDataset, self).__init__()
        self.path = path
        if isinstance(patterns, str):
            patterns = (patterns,)
        root = Path(self.path)
        # The set de-duplicates files matched by several overlapping patterns;
        # is_file() drops directories whose names happen to match.
        matches = {p for pattern in patterns for p in root.glob(pattern) if p.is_file()}
        self.paths = sorted(matches)
        self.transform = transform

    def __getitem__(self, index):
        """Open the ``index``-th file, convert it to RGB and apply ``transform``."""
        path = self.paths[index]
        # The context manager releases the file handle; convert() returns a
        # new in-memory image, so it remains valid after the file is closed.
        with Image.open(str(path)) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

    def __len__(self):
        """Number of files found when the dataset was constructed."""
        return len(self.paths)

    def name(self):
        """Human-readable name of this dataset class."""
        return 'MyDataset'
# Build the model and load its pretrained parameters from network.pth.
# map_location="cpu" lets a checkpoint that was saved from GPU tensors load on
# any machine; the model is moved to its target device before inference.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
# Only load trusted checkpoints (on torch >= 1.13 consider weights_only=True).
network = net.Net()
network.load_state_dict(torch.load("network.pth", map_location="cpu"))
print(network)


Net(
  (enc_1): Sequential(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (1): ReflectionPad2d((1, 1, 1, 1))
    (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (enc_2): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
  )
  (enc_3): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
  )
  (enc_4): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(256, 256,

In [2]:
# Show the module hierarchy of the loaded model again (same text as the first cell).
print(str(network))

Net(
  (enc_1): Sequential(
    (0): Conv2d(3, 3, kernel_size=(1, 1), stride=(1, 1))
    (1): ReflectionPad2d((1, 1, 1, 1))
    (2): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1))
    (3): ReLU()
  )
  (enc_2): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
  )
  (enc_3): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (2): ReLU()
    (3): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=(0, 0), dilation=1, ceil_mode=True)
    (4): ReflectionPad2d((1, 1, 1, 1))
    (5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1))
    (6): ReLU()
  )
  (enc_4): Sequential(
    (0): ReflectionPad2d((1, 1, 1, 1))
    (1): Conv2d(256, 256,

In [3]:
# Content and style images go through the same preprocessing pipeline.
tf = train_transform()
content_dataset = MyDataset("content", tf)
style_dataset = MyDataset("style", tf)


def _endless_batches(dataset, seed):
    """Return an iterator that yields single-image batches from ``dataset`` forever.

    A plain ``iter(DataLoader(...))`` raises StopIteration once every image has
    been seen, so a loop with more steps than images would crash. Here each
    pass restarts the loader, which reshuffles with a generator seeded by
    ``seed``, so the sequence is reproducible across runs.

    Raises:
        FileNotFoundError: if ``dataset`` is empty (checked eagerly, because
            an empty loader would otherwise make the endless loop spin).
    """
    if len(dataset) == 0:
        raise FileNotFoundError(f"no images found in {dataset.path!r}")
    loader = data.DataLoader(
        dataset, batch_size=1, shuffle=True,
        generator=torch.Generator().manual_seed(seed))

    def _forever():
        while True:
            yield from loader

    return _forever()


content_iter = _endless_batches(content_dataset, seed=0)
style_iter = _endless_batches(style_dataset, seed=1)

In [4]:
# Inference: stylize N_STEPS content/style pairs and save the images to OUTPUT_DIR.
# The network is in evaluation mode. To train instead, switch eval() to train(),
# remove the torch.no_grad() context and step `optimizer` on `loss`.
# tqdm reports the elapsed time, i.e. the cost of the conversion.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network.to(device)
optimizer = torch.optim.Adam(network.parameters())  # unused in eval mode; kept for training

N_STEPS = 10         # number of content/style pairs to process
STYLE_RATE = 0.5     # "rate": strength of style transfer; changing it visibly trades loss_c against loss_s
CONTENT_WEIGHT = 1   # weight applied to loss_c
STYLE_WEIGHT = 10    # weight applied to loss_s
OUTPUT_DIR = Path("output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)  # save_image does not create missing folders

network.eval()
with torch.no_grad():  # inference only: skip building autograd graphs
    for i in tqdm(range(N_STEPS)):
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
        loss_c, loss_s, generate_image = network(content_images, style_images, STYLE_RATE)
        loss_c = CONTENT_WEIGHT * loss_c
        loss_s = STYLE_WEIGHT * loss_s
        loss = loss_c + loss_s
        # tqdm.write prints above the progress bar instead of breaking it apart.
        tqdm.write(f"loss_c: {loss_c}, loss_s: {loss_s}，loss: {loss}")
        save_image(generate_image, OUTPUT_DIR / f"{i+1}_generate_image.png")
        save_image(content_images, OUTPUT_DIR / f"{i+1}_content_image.png")
        save_image(style_images, OUTPUT_DIR / f"{i+1}_style_image.png")

 20%|██        | 2/10 [00:01<00:05,  1.44it/s]

loss_c: 0.626849889755249, loss_s: 22.883394241333008，loss: 23.510244369506836
loss_c: 0.9384181499481201, loss_s: 11.691937446594238，loss: 12.630355834960938


 40%|████      | 4/10 [00:01<00:01,  3.25it/s]

loss_c: 0.5885189771652222, loss_s: 17.896041870117188，loss: 18.484560012817383
loss_c: 0.824416995048523, loss_s: 30.870208740234375，loss: 31.694625854492188


 60%|██████    | 6/10 [00:02<00:00,  5.10it/s]

loss_c: 1.3897268772125244, loss_s: 27.496374130249023，loss: 28.88610076904297
loss_c: 1.3628208637237549, loss_s: 53.09851837158203，loss: 54.46133804321289


 80%|████████  | 8/10 [00:02<00:00,  7.04it/s]

loss_c: 1.4095178842544556, loss_s: 24.254638671875，loss: 25.664155960083008
loss_c: 1.1904929876327515, loss_s: 30.432701110839844，loss: 31.623193740844727
loss_c: 0.8729572296142578, loss_s: 14.61633586883545，loss: 15.489293098449707


100%|██████████| 10/10 [00:02<00:00,  3.95it/s]

loss_c: 0.692363977432251, loss_s: 9.698209762573242，loss: 10.390573501586914



