In [1]:
import torch
import torch.optim as optim
import torch.nn as nn
import torchvision.transforms as transforms
from torchvision.models import vgg19
from PIL import Image
import copy
import os

In [2]:
# 存储模型的目录
os.environ['TORCH_HOME'] = './model'  

# 加载图像

In [3]:
def load_image(image_path, max_size=400):
    image = Image.open(image_path).convert('RGB')
    # 如果图片过大则调整大小
    size = min(max_size, max(image.size))
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    # 添加批次维度
    image = transform(image).unsqueeze(0)  
    return image

# 保存图像

In [4]:
def save_image(tensor, path):
    # 取消归一化并返回结果
    unnormalize = transforms.Normalize(
        mean=[-2.118, -2.036, -1.804],
        std=[4.367, 4.464, 4.444]
    )
    tensor = unnormalize(tensor)
    image = tensor.clone().detach()
    # 去掉批次维度
    image = image.squeeze(0)  
    image = transforms.ToPILImage()(image)
    image.save(path)

# 定义特征提取模型

In [5]:
# 定义VGG19模型，只提取特定层的特征
class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        # 只使用前21层
        self.features = vgg19(pretrained=True).features[:21].eval()  

    def forward(self, x):
        features = []
        for i, layer in enumerate(self.features):
            x = layer(x)
            # 选择特定层的输出
            if i in {0, 5, 10, 19, 21}:  
                features.append(x)
        return features

# 计算内容损失

In [6]:
class ContentLoss(nn.Module):
    def __init__(self, target):
        super(ContentLoss, self).__init__()
        self.target = target.detach()

    def forward(self, input):
        # input：输入，target：目标
        # 输入和目标的内容误差
        return nn.functional.mse_loss(input, self.target)

# 计算风格损失

In [7]:
def gram_matrix(input):
    batch_size, channels, height, width = input.size()
    features = input.view(batch_size * channels, height * width)
    G = torch.mm(features, features.t())
    return G.div(batch_size * channels * height * width)

In [8]:
class StyleLoss(nn.Module):
    def __init__(self, target):
        super(StyleLoss, self).__init__()
        # 目标的风格
        self.target = gram_matrix(target).detach()

    def forward(self, input):
        # 输入的风格
        G = gram_matrix(input)
        # 输入和目标的风格误差
        return nn.functional.mse_loss(G, self.target)

# 图像风格迁移

In [9]:
def style_transfer(content_img, style_img, num_steps=1000, style_weight=1e9, content_weight=1):
    # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device("cpu")
    content_img = content_img.to(device)
    style_img = style_img.to(device)
    model = VGG().to(device)
    # 提取风格特征和内容特征
    style_features = model(style_img)
    content_features = model(content_img)
    # 初始化输入图像（使用内容图像作为初始图像）
    input_img = content_img.clone().requires_grad_(True).to(device)
    # 定义优化器
    optimizer = optim.LBFGS([input_img])
    style_losses = []
    content_losses = []
    # 初始化损失模块
    for sf, cf in zip(style_features, content_features):
        content_losses.append(ContentLoss(cf))
        style_losses.append(StyleLoss(sf))
    run = [0]
    while run[0] <= num_steps:
        def closure():
            optimizer.zero_grad()
            input_features = model(input_img)
            content_loss = 0
            style_loss = 0
            # 累加内容损失
            for cl, input_f in zip(content_losses, input_features):
                content_loss += content_weight * cl(input_f)
            # 累加风格损失
            for sl, input_f in zip(style_losses, input_features):
                style_loss += style_weight * sl(input_f)
            loss = content_loss + style_loss
            loss.backward()
            run[0] += 1
            if run[0] % 50 == 0:
                print(f'Step {run[0]}, Content Loss: {content_loss.item():4f}, Style Loss: {style_loss.item():4f}')
            return loss
        optimizer.step(closure)
    return input_img

In [10]:
def main():
    if __name__ == '__main__':
        content_image_path = 'image/content_image.png'
        style_image_path = 'image/style_image.png'
        output_image_path = 'image/output_image.jpg'
    
        content_img = load_image(content_image_path)
        style_img = load_image(style_image_path)
        result = style_transfer(content_img, style_img)
    
        save_image(result, output_image_path)
        print(f"风格迁移完成，图像已保存为 {output_image_path}")

In [11]:
if __name__ == '__main__':
    main()



Step 50, Content Loss: 29.754610, Style Loss: 2030.043701
Step 100, Content Loss: 30.590519, Style Loss: 700.369385
Step 150, Content Loss: 31.174225, Style Loss: 365.433228
Step 200, Content Loss: 31.571880, Style Loss: 250.382965
Step 250, Content Loss: 31.778713, Style Loss: 197.025391
Step 300, Content Loss: 31.942102, Style Loss: 163.526123
Step 350, Content Loss: 32.088745, Style Loss: 142.871002
Step 400, Content Loss: 32.167210, Style Loss: 128.662079
Step 450, Content Loss: 32.240772, Style Loss: 118.532410
Step 500, Content Loss: 32.298084, Style Loss: 110.577751
Step 550, Content Loss: 32.349014, Style Loss: 104.269287
Step 600, Content Loss: 32.393829, Style Loss: 99.337952
Step 650, Content Loss: 32.429863, Style Loss: 95.298645
Step 700, Content Loss: 32.460632, Style Loss: 92.166992
Step 750, Content Loss: 32.481163, Style Loss: 89.454445
Step 800, Content Loss: 32.517361, Style Loss: 87.031052
Step 850, Content Loss: 32.539078, Style Loss: 84.895279
Step 900, Content Lo