In [1]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from tqdm import tqdm
import torchvision.models as models
import net

def train_transform():
    """Build the preprocessing pipeline applied to each training image.

    Returns:
        A ``transforms.Compose`` that first resizes the image to 256x256 and
        then converts it to a tensor with ``transforms.ToTensor``.
    """
    return transforms.Compose([
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ])


class FlatFolderDataset(data.Dataset):
    """Dataset over the files that sit directly inside a single folder.

    Every regular file at the top level of ``root`` is treated as an image.
    Sub-directories are skipped and nothing is searched recursively.

    Args:
        root: Folder containing the image files.
        transform: Callable applied to each loaded RGB ``PIL.Image``; its
            return value is what ``__getitem__`` yields.
    """

    def __init__(self, root, transform):
        super().__init__()
        self.root = root
        # glob() returns entries in arbitrary, platform-dependent order and
        # also matches directories. Sorting makes index -> file stable across
        # runs; filtering to regular files avoids handing a directory to
        # Image.open.
        self.paths = sorted(
            entry for entry in Path(self.root).glob('*') if entry.is_file()
        )
        self.transform = transform

    def __getitem__(self, index):
        """Load the image at ``index``, convert it to RGB and transform it."""
        path = self.paths[index]
        # Open inside a context manager so the file handle is released right
        # away; convert() produces an independent in-memory image that stays
        # valid after the source file is closed.
        with Image.open(str(path)) as source:
            rgb_image = source.convert('RGB')
        return self.transform(rgb_image)

    def __len__(self):
        """Return the number of image files found in ``root``."""
        return len(self.paths)

    def name(self):
        """Return the dataset's identifying name."""
        return 'FlatFolderDataset'

# Assemble the network from the project's decoder and a VGG-16 feature stack.
decoder = net.decoder
# Load the decoder weights from decoder.pth. map_location="cpu" lets a
# checkpoint that was saved from a GPU load on a machine without CUDA; the
# values are copied into the decoder's existing parameters either way.
decoder.load_state_dict(torch.load("decoder.pth", map_location="cpu"))
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` - confirm against the installed torchvision version.
vgg = models.vgg16(pretrained=True).features
# Keep the first 31 child layers of the VGG-16 feature extractor.
vgg = nn.Sequential(*list(vgg.children())[:31])
network = net.Net(vgg, decoder)
network.train()
print(network)



Net(
  (enc_1): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (enc_2): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
  )
  (enc_3): Sequential(
    (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (2): ReLU(inplace=True)
    (3): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU(inplace=True)
    (5): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
  )
  (enc_4): Sequen

In [6]:
# Image folders of the two training sets. Raw strings keep the Windows
# backslashes literal; the previous f-strings had no placeholders and relied
# on unrecognised escapes such as "\g" and "\A" being passed through, which
# newer Python versions warn about.
CONTENT_DIR = r"D:\git\AdaIN\content\images"
STYLE_DIR = r"D:\git\AdaIN\style\images"
BATCH_SIZE = 2


def _endless_batches(dataset, batch_size):
    """Return an iterator that yields batches from ``dataset`` forever.

    A plain ``iter(DataLoader(...))`` is exhausted after one pass over the
    data, so a long loop calling ``next()`` on it ends in ``StopIteration``.
    This iterator restarts the loader whenever it runs out.

    Args:
        dataset: Dataset to draw samples from.
        batch_size: Number of samples per batch.

    Returns:
        An iterator whose ``next()`` always yields a batch of exactly
        ``batch_size`` samples.

    Raises:
        ValueError: If ``dataset`` has fewer than ``batch_size`` samples, in
            which case no full batch could ever be produced.
    """
    if len(dataset) < batch_size:
        raise ValueError(
            f"dataset has {len(dataset)} sample(s), need at least {batch_size}"
        )
    loader = data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,    # new random order (and content/style pairing) each pass
        drop_last=True,  # every batch has the same size for both iterators
    )

    def _cycle():
        while True:
            yield from loader

    return _cycle()


tf = train_transform()
content_dataset = FlatFolderDataset(CONTENT_DIR, tf)
style_dataset = FlatFolderDataset(STYLE_DIR, tf)

content_iter = _endless_batches(content_dataset, BATCH_SIZE)
style_iter = _endless_batches(style_dataset, BATCH_SIZE)

In [8]:
from torchvision.utils import save_image

# Training schedule and loss weighting.
NUM_ITERATIONS = 100000
CONTENT_WEIGHT = 1
STYLE_WEIGHT = 20
LOG_EVERY = 10            # iterations between loss printouts
SAMPLE_EVERY = 100        # iterations between saved example images
CHECKPOINT_EVERY = 10000  # iterations between saved network weights

# Use the GPU when one is present, otherwise fall back to the CPU instead of
# failing on the first .to('cuda') call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# save_image / torch.save do not create missing folders, so make sure the
# output directory exists before the first file is written.
output_dir = Path("output")
output_dir.mkdir(parents=True, exist_ok=True)

network.to(device)
optimizer = torch.optim.Adam(network.parameters())
for i in tqdm(range(NUM_ITERATIONS)):
    try:
        content_images = next(content_iter).to(device)
        style_images = next(style_iter).to(device)
    except StopIteration:
        # A finite batch iterator ran dry; stop cleanly rather than abort
        # the cell with a traceback.
        tqdm.write(f"Batch iterator exhausted after {i} iterations; stopping training.")
        break

    # The network returns the content loss, the style loss and the generated
    # images for this batch.
    loss_c, loss_s, generate_image = network(content_images, style_images)
    loss_c = CONTENT_WEIGHT * loss_c
    loss_s = STYLE_WEIGHT * loss_s
    loss = loss_c + loss_s

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    step = i + 1
    if step % LOG_EVERY == 0:
        # tqdm.write keeps the progress bar intact, unlike a bare print().
        tqdm.write(f"loss_c: {loss_c}, loss_s: {loss_s}")
    if step % SAMPLE_EVERY == 0:
        save_image(generate_image, str(output_dir / f"generate_image_{step}.png"))
        save_image(content_images, str(output_dir / f"content_image_{step}.png"))
        save_image(style_images, str(output_dir / f"style_image_{step}.png"))
    if step % CHECKPOINT_EVERY == 0:
        torch.save(network.state_dict(), str(output_dir / f"network_{step}.pth"))

  0%|          | 12/100000 [00:01<4:20:36,  6.39it/s]

loss_c: 1.0574362622245136e-17, loss_s: 0.0005502101848833263


  0%|          | 20/100000 [00:03<5:45:25,  4.82it/s]

loss_c: 3.294958991999841e-18, loss_s: 0.0002924492582678795


  0%|          | 32/100000 [00:05<4:19:58,  6.41it/s]

loss_c: 1.7618366728480995e-19, loss_s: 7.348533836193383e-05


  0%|          | 42/100000 [00:07<4:20:35,  6.39it/s]

loss_c: 1.9651245801891316e-19, loss_s: 0.0001109450458898209


  0%|          | 50/100000 [00:08<5:48:14,  4.78it/s]

loss_c: 2.9078649351046256e-18, loss_s: 3.876128903357312e-05


  0%|          | 62/100000 [00:10<4:19:33,  6.42it/s]

loss_c: 7.02783236336893e-18, loss_s: 2.816143023665063e-05


  0%|          | 72/100000 [00:12<4:18:13,  6.45it/s]

loss_c: 1.0249263451664691e-19, loss_s: 2.228541052318178e-05


  0%|          | 82/100000 [00:13<4:19:10,  6.43it/s]

loss_c: 1.0588076630566411e-19, loss_s: 3.865035978378728e-05


  0%|          | 92/100000 [00:15<4:19:58,  6.40it/s]

loss_c: 5.836065278388255e-19, loss_s: 1.63541262736544e-05


  0%|          | 100/100000 [00:16<6:40:07,  4.16it/s]

loss_c: 2.3793169964034016e-18, loss_s: 1.9861492546624504e-05


  0%|          | 110/100000 [00:18<6:09:33,  4.50it/s]

loss_c: 1.291725244562808e-18, loss_s: 1.0780080629047006e-05


  0%|          | 120/100000 [00:20<5:49:36,  4.76it/s]

loss_c: 2.117665086197006e-20, loss_s: 6.272150494623929e-05


  0%|          | 130/100000 [00:22<5:45:08,  4.82it/s]

loss_c: 1.5033987780816577e-17, loss_s: 0.000242191570578143


  0%|          | 140/100000 [00:23<5:50:26,  4.75it/s]

loss_c: 2.565663417618735e-18, loss_s: 6.662111991317943e-05


  0%|          | 150/100000 [00:25<5:52:43,  4.72it/s]

loss_c: 5.734413880092226e-18, loss_s: 0.0008478756062686443


  0%|          | 162/100000 [00:27<4:18:56,  6.43it/s]

loss_c: 1.8228155228767137e-18, loss_s: 3.387447941349819e-05


  0%|          | 170/100000 [00:28<5:57:05,  4.66it/s]

loss_c: 1.376428539288238e-18, loss_s: 2.5870936951832846e-05


  0%|          | 180/100000 [00:30<5:50:09,  4.75it/s]

loss_c: 7.157436676104963e-19, loss_s: 3.7176065234234557e-05


  0%|          | 192/100000 [00:32<4:48:59,  5.76it/s]

loss_c: 6.589916329638457e-18, loss_s: 4.281834844732657e-05


  0%|          | 200/100000 [00:34<6:37:46,  4.18it/s]

loss_c: 1.6940658945086007e-21, loss_s: 1.0487450708751567e-05


  0%|          | 212/100000 [00:35<4:15:16,  6.52it/s]

loss_c: 2.206521654778065e-18, loss_s: 0.00014062609989196062


  0%|          | 222/100000 [00:37<4:27:08,  6.22it/s]

loss_c: 1.7364258136774412e-19, loss_s: 4.8819816584000364e-05


  0%|          | 232/100000 [00:39<4:18:06,  6.44it/s]

loss_c: 1.2959604092990795e-18, loss_s: 8.049411007959861e-06


  0%|          | 241/100000 [00:40<5:15:24,  5.27it/s]

loss_c: 2.1252056646610396e-17, loss_s: 5.349251750885742e-06


  0%|          | 252/100000 [00:42<4:22:10,  6.34it/s]

loss_c: 1.4314856808597676e-18, loss_s: 1.0170628002015292e-06


  0%|          | 262/100000 [00:44<4:19:57,  6.39it/s]

loss_c: 6.506060067860281e-18, loss_s: 2.7243660952080972e-05


  0%|          | 272/100000 [00:46<4:35:56,  6.02it/s]

loss_c: 3.90483015864845e-19, loss_s: 2.5895853468682617e-05


  0%|          | 280/100000 [00:47<5:55:46,  4.67it/s]

loss_c: 8.470411544369405e-20, loss_s: 6.996945103310281e-07


  0%|          | 290/100000 [00:49<6:02:31,  4.58it/s]

loss_c: 1.497554974528639e-18, loss_s: 2.729400421230821e-06


  0%|          | 300/100000 [00:51<7:01:54,  3.94it/s]

loss_c: 6.928729508540177e-19, loss_s: 3.2416212434327463e-06


  0%|          | 312/100000 [00:53<4:18:17,  6.43it/s]

loss_c: 1.497554250745603e-18, loss_s: 1.9212158804293722e-05


  0%|          | 320/100000 [00:54<5:57:57,  4.64it/s]

loss_c: 2.244637310223896e-18, loss_s: 2.9639793865499087e-05


  0%|          | 332/100000 [00:56<4:23:20,  6.31it/s]

loss_c: 6.268043809681823e-20, loss_s: 2.0903611584799364e-05


  0%|          | 342/100000 [00:58<4:23:48,  6.30it/s]

loss_c: 2.3801632021700434e-18, loss_s: 2.699739980016602e-06


  0%|          | 348/100000 [00:59<4:44:19,  5.84it/s]


KeyboardInterrupt: 

  5%|▌         | 51/1000 [00:14<03:50,  4.12it/s]

第50次迭代，content损失为：0.000，style损失为：18.842


 10%|█         | 101/1000 [00:26<03:39,  4.09it/s]

第100次迭代，content损失为：0.000，style损失为：5.658


 15%|█▌        | 151/1000 [00:39<03:29,  4.04it/s]

第150次迭代，content损失为：0.000，style损失为：1.878


 20%|██        | 200/1000 [00:51<04:12,  3.17it/s]

第200次迭代，content损失为：0.000，style损失为：0.458


 21%|██        | 208/1000 [00:53<03:23,  3.89it/s]


KeyboardInterrupt: 

Net(
  (enc_1): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      (16): MaxP