In [2]:
from __future__ import division
from torchvision import models
from torchvision import transforms

from PIL import Image
import argparse
import torch
import torchvision
import torch.nn as nn
import numpy as np
from torch.utils.tensorboard import SummaryWriter

In [3]:
%load_ext tensorboard

In [4]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [5]:
device

device(type='cuda')

In [6]:
class VGGNet(nn.Module):
  """Frozen VGG-19 feature extractor that returns activations of selected layers.

  The network consists of the first 29 modules of torchvision's pretrained
  ``vgg19().features``. ``forward`` feeds the input through them in order and
  collects the output of every module whose name is listed in ``self.select``.
  """

  def __init__(self):
    super(VGGNet,self).__init__()
    # Names (nn.Sequential uses the stringified index) of the modules whose
    # outputs forward() returns.
    self.select = ['0', '5', '10', '19', '28']
    self.vgg = models.vgg19(pretrained=True).features[:29]
    # In this notebook only the generated image is optimised, never the
    # network. Freezing the weights stops autograd from computing and storing
    # a gradient for every kernel on each backward pass; gradients with
    # respect to the *input* still flow through unchanged.
    for param in self.vgg.parameters():
      param.requires_grad_(False)

  def forward(self,x):
    """Run ``x`` through the network and collect the selected activations.

    Args:
      x: input tensor handed to the first module of ``self.vgg``.

    Returns:
      A list of tensors in network order, one for each module of ``self.vgg``
      whose name appears in ``self.select``.
    """
    features = []
    for name, layer in self.vgg._modules.items():
      x = layer(x)
      if name in self.select:
        features.append(x)
    return features


In [7]:
# Per-channel mean / std used to normalise images before they enter the network.
NORM_MEAN = (0.485, 0.456, 0.406)
NORM_STD = (0.229, 0.224, 0.225)

# PIL image -> network input: convert to a tensor, then apply
# (x - mean) / std on every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=NORM_MEAN, std=NORM_STD),
])

# Exact inverse of the normalisation above, written as another Normalize:
#   x = (y - (-mean / std)) / (1 / std) = y * std + mean
# The constants are derived from NORM_MEAN / NORM_STD instead of being typed in
# as rounded literals, which left a small error in every de-normalised pixel.
denorm = transforms.Normalize(
    mean=tuple(-m / s for m, s in zip(NORM_MEAN, NORM_STD)),
    std=tuple(1.0 / s for s in NORM_STD),
)

def load_img(img_path, transform=None, max_size=None, shape=None):
  """Load an image from disk, optionally resize it, and prepare it for the network.

  Args:
    img_path: path of the image file, handed to ``Image.open``.
    transform: optional callable applied to the PIL image. Its result gets a
      leading batch dimension via ``unsqueeze(0)`` and is moved to ``device``.
    max_size: if given, the image is rescaled (aspect ratio kept) so that its
      longer side becomes ``max_size`` pixels.
    shape: if given, the image is resized to exactly this ``(width, height)``;
      applied after ``max_size``.

  Returns:
    With ``transform``: the transformed image with a leading batch dimension,
    on ``device``. Without ``transform``: the (possibly resized) RGB PIL image.
  """
  # Decode inside the context manager so the file handle is released, and force
  # three channels: RGBA / palette / greyscale files would otherwise not match
  # the 3-channel normalisation applied by the notebook's transform.
  with Image.open(img_path) as raw:
    img = raw.convert('RGB')

  # resize the img so that its longer side equals max_size
  if max_size:
    scale = max_size / max(img.size)
    # PIL wants a (width, height) pair of plain ints; int() truncates exactly
    # like the previous numpy astype(int) did.
    size = tuple(int(side * scale) for side in img.size)
    # Image.ANTIALIAS was an alias of LANCZOS and is removed in Pillow 10.
    img = img.resize(size, Image.LANCZOS)

  if shape:
    img = img.resize(tuple(shape), Image.LANCZOS)

  if transform is None:
    # A PIL image has no .to(); there is nothing to batch or move to a device.
    return img

  # the input of network is (N, C, H, W)
  # so make the single img into (1, C, H, W)
  return transform(img).unsqueeze(0).to(device)


In [9]:
# Content image: loaded first because its size defines the working resolution
# (longer side scaled to 400 px).
origin_img = load_img('/content/content.png', transform=transform, max_size=400)

# Style image: resized to the content image's spatial size. Tensors are laid
# out (N, C, H, W) while load_img's `shape` is (width, height), hence dims 3, 2.
style_img = load_img(
    '/content/style.jpg',
    transform=transform,
    shape=[origin_img.size(3), origin_img.size(2)],
)

# Image being optimised: starts as a copy of the content image and is the
# tensor the optimizer updates. (Starting from torch.randn noise of the same
# shape is the alternative initialisation.)
genrated = origin_img.clone().requires_grad_(True)

# Feature extractor, moved to the target device and switched to eval mode.
VGG = VGGNet().to(device).eval()


Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

In [10]:
# Sanity check: the style image was resized to the content image's size, so
# both tensors should report the same shape.
print(origin_img.shape, style_img.shape)

torch.Size([1, 3, 272, 400]) torch.Size([1, 3, 272, 400])


In [20]:
# Hyper-parameters read by the optimizer setup and by the training loop.
config = dict(
    iter_num=3000,    # iteration bound of the optimisation loop
    lr=10,            # initial Adam learning rate
    content_w=0.001,  # multiplier of the content loss in the total loss
    style_w=1,        # multiplier of the style loss in the total loss
)

In [21]:
# Adam optimises a single tensor: the pixels of the generated image.
optimizer = torch.optim.Adam(
    params=[genrated],
    lr=config['lr'],
    betas=[0.5, 0.999],
)

# Idea: let the style change sharply first, then pull the result back towards the content.

# The individual feature layers could also be given different weights in the loss.

In [27]:
# TensorBoard event files for this run are written to the style_transfer3 directory.
writer = SummaryWriter(log_dir="style_transfer3")

In [28]:



def gram_matrix(feat):
  """Return the (C, C) Gram matrix of a single-image feature map of shape (1, C, H, W)."""
  _, channels, height, width = feat.shape
  flat = feat.view(channels, height * width)
  return flat @ flat.T


# Work on local copies of the schedule values and start from a fresh optimizer.
# The loop used to divide config['lr'] and config['style_w'] in place, so
# re-running this cell silently started from the already-decayed values.
lr = config['lr']
style_w = config['style_w']
content_w = config['content_w']
optimizer = torch.optim.Adam([genrated], lr=lr, betas=[0.5, 0.999])

# The content and style images never change, so their targets are computed
# once instead of on every iteration. no_grad keeps them out of the autograd
# graph; without it, reusing them across iterations fails because backward()
# frees the graph they were built on.
with torch.no_grad():
  content_targets = VGG(origin_img)
  style_targets = [gram_matrix(feat) for feat in VGG(style_img)]

# + 1 so that exactly config['iter_num'] optimisation steps are taken and the
# final step is logged (range(1, n) stopped one step short).
for i in range(1, config['iter_num'] + 1):
  genrated_features = VGG(genrated)

  content_loss = 0
  style_loss = 0
  for gen_feat, content_feat, style_gram in zip(genrated_features, content_targets, style_targets):
    # Content term: mean squared difference of the raw activations.
    content_loss += torch.mean((content_feat - gen_feat) ** 2)

    # Style term: mean squared difference of the Gram matrices, scaled by the
    # number of elements in the feature map.
    C, H, W = gen_feat.shape[1:]
    style_loss += torch.mean((style_gram - gram_matrix(gen_feat)) ** 2) / (C * H * W)

  total_loss = content_loss * content_w + style_loss * style_w

  # .item() logs plain Python floats rather than tensors attached to the graph.
  writer.add_scalar(tag="content_loss", scalar_value=content_loss.item(), global_step=i)
  writer.add_scalar(tag="style_loss", scalar_value=style_loss.item(), global_step=i)
  writer.add_scalar(tag="total_loss", scalar_value=total_loss.item(), global_step=i)

  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # Schedule: after 1000 steps the style weight drops by 10x; every 500 steps
  # the learning rate drops by 10x.
  if i == 1000:
    style_w /= 10
  if i % 500 == 0:
    lr /= 10
    # NOTE: building a new optimizer (rather than editing param_groups) also
    # resets Adam's running moment estimates; kept to preserve the original
    # optimisation behaviour.
    optimizer = torch.optim.Adam([genrated], lr=lr, betas=[0.5, 0.999])

  if i % 50 == 0:
    print(f"current iter {i}")
    # Snapshot of the current image: drop the batch dimension, undo the input
    # normalisation and clamp to the displayable [0, 1] range. Done on a
    # detached copy so neither the graph nor the optimised tensor is touched.
    with torch.no_grad():
      img = denorm(genrated.detach().clone().squeeze(0)).clamp_(0, 1)
    writer.add_image(tag="img", img_tensor=img, global_step=i)

# Make sure everything logged so far is written to disk.
writer.flush()


    

current iter 50
current iter 100
current iter 150
current iter 200
current iter 250
current iter 300
current iter 350
current iter 400
current iter 450
current iter 500
current iter 550
current iter 600
current iter 650
current iter 700
current iter 750
current iter 800
current iter 850
current iter 900
current iter 950
current iter 1000
current iter 1050
current iter 1100
current iter 1150
current iter 1200
current iter 1250
current iter 1300
current iter 1350
current iter 1400
current iter 1450
current iter 1500
current iter 1550
current iter 1600
current iter 1650
current iter 1700
current iter 1750
current iter 1800
current iter 1850
current iter 1900
current iter 1950
current iter 2000
current iter 2050
current iter 2100
current iter 2150
current iter 2200
current iter 2250
current iter 2300
current iter 2350
current iter 2400
current iter 2450
current iter 2500
current iter 2550
current iter 2600
current iter 2650
current iter 2700
current iter 2750
current iter 2800
current iter

In [15]:
! tensorboard --logdir=run


NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.7.0 at http://localhost:6006/ (Press CTRL+C to quit)
^C


In [23]:
! rm -rf style_transfer3

In [25]:
! pwd

/content


In [26]:
! ls

content.png  sample_data  style.jpg


In [None]:
! find 