In [None]:
# Mount Google Drive at /content/drive. Later cells read the style image from
# it and write the model checkpoint to it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# List the contents of the mount directory to confirm the mount succeeded
!ls /content/drive/

加载数据

In [None]:
import os

# Create the dataset directory (no error if it already exists)
os.makedirs('/content/coco', exist_ok=True)

# Download the COCO 2017 training images (`-c` resumes a partial download,
# `-P` sets the target directory) and extract them quietly into /content/coco/
!wget -c http://images.cocodataset.org/zips/train2017.zip -P /content/coco/
!unzip -q /content/coco/train2017.zip -d /content/coco/

# Download and extract the COCO 2017 validation images
!wget -c http://images.cocodataset.org/zips/val2017.zip -P /content/coco/
!unzip -q /content/coco/val2017.zip -d /content/coco/

# Download and extract the COCO 2017 train/val annotation files
!wget -c http://images.cocodataset.org/annotations/annotations_trainval2017.zip -P /content/coco/
!unzip -q /content/coco/annotations_trainval2017.zip -d /content/coco/


In [None]:
# List the dataset directory, then delete the downloaded .zip archives in it
# to free disk space (the extracted files are kept).
!ls /content/coco
!rm /content/coco/*.zip

导入必要的包

In [None]:
import torch
from torch import nn
import torchvision
import torch
from torchvision import transforms
from torchvision.datasets import CocoDetection
from PIL import Image
from matplotlib import pyplot as plt

配置参数

In [None]:
# Indices into the VGG19 feature stack at which `extract_feature` taps
# activations: content features are stored as-is, style features as Gram matrices.
# NOTE(review): presumably relu4_2 (content) and relu1_1..relu5_1 (style) —
# verify against torchvision's VGG19 layer layout.
content_layers = [22]
style_layers = [1, 6, 11, 20, 29]

# Multipliers applied to the summed content / style MSE terms in `compute_loss`.
content_weight = 2e9
style_weight = 6e3

# Image directory and annotation file handed to the COCO dataset wrapper.
root_dir = "/content/coco/train2017"
annFile = "/content/coco/annotations/instances_train2017.json"

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Style reference image, read from the mounted Google Drive.
style_image_path = "/content/drive/MyDrive/starry_night.jpg"

# Mini-batch size (also the number of copies made of the style image) and the
# intended optimizer learning rate.
batch_size = 24
learning_rate = 1e-4
print(device)

## 生成模型定义

降采样

In [None]:
class DownSample(nn.Module):
    """Reflection-padded convolution layer.

    The input is reflection-padded by ``kernel_size // 2`` on every side and
    then passed through an unpadded ``nn.Conv2d``; with ``stride > 1`` the
    spatial resolution is reduced.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super().__init__()
        half_kernel = kernel_size // 2
        self.padding = nn.ReflectionPad2d(half_kernel)
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Reflection-pad ``x`` and convolve the padded tensor."""
        padded = self.padding(x)
        return self.conv(padded)

残差块

In [None]:
class ResidualBlock(nn.Module):
    """Residual block: two reflection-padded 3x3 conv + instance-norm stages.

    A ReLU follows the first stage only; the block input is added to the
    output of the second stage, so the tensor shape is unchanged.
    """

    def __init__(self, channels):
        super().__init__()
        # First stage: pad -> conv -> affine instance norm (ReLU applied in forward).
        self.padding1 = nn.ReflectionPad2d(1)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()
        # Second stage: pad -> conv -> affine instance norm (no activation).
        self.padding2 = nn.ReflectionPad2d(1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both stages to ``x`` and add the skip connection."""
        hidden = self.padding1(x)
        hidden = self.conv1(hidden)
        hidden = self.in1(hidden)
        hidden = self.relu(hidden)

        hidden = self.padding2(hidden)
        hidden = self.conv2(hidden)
        hidden = self.in2(hidden)

        return hidden + x

升采样

In [None]:
class UpSample(nn.Module):
    """Optional nearest-neighbour upsampling followed by a reflection-padded conv.

    When ``upsample`` is truthy it is used as the ``scale_factor`` of a
    nearest-neighbour interpolation applied before padding and convolution;
    otherwise the layer is just a reflection-padded convolution.
    """

    # pylint: disable=too-many-arguments, too-many-positional-arguments
    def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
        super().__init__()
        self.upsample = upsample
        self.padding = nn.ReflectionPad2d(kernel_size // 2)
        self.conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Upsample ``x`` if configured, then pad and convolve it."""
        if not self.upsample:
            return self.conv2d(self.padding(x))

        enlarged = torch.nn.functional.interpolate(
            x, scale_factor=self.upsample, mode="nearest"
        )
        return self.conv2d(self.padding(enlarged))

模型定义

In [None]:
class TransferNet(nn.Module):
    """Feed-forward style-transfer (image transformation) network.

    Layout: a 9x9 stride-1 conv and two 3x3 stride-2 convs (each followed by
    affine instance norm and ReLU), five 128-channel residual blocks, two
    nearest-neighbour x2 upsampling convs (instance norm + ReLU), and a final
    9x9 conv back to 3 channels with no normalisation or activation.
    """

    def __init__(self):
        super().__init__()
        # Encoder: 3 -> 32 -> 64 -> 128 channels, resolution divided by 4.
        self.downsample1 = DownSample(3, 32, kernel_size=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.downsample2 = DownSample(32, 64, kernel_size=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.downsample3 = DownSample(64, 128, kernel_size=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)
        # Bottleneck: five residual blocks at 128 channels.
        self.residual1 = ResidualBlock(128)
        self.residual2 = ResidualBlock(128)
        self.residual3 = ResidualBlock(128)
        self.residual4 = ResidualBlock(128)
        self.residual5 = ResidualBlock(128)
        # Decoder: 128 -> 64 -> 32 channels, each step doubling the resolution.
        self.upsample1 = UpSample(128, 64, kernel_size=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.upsample2 = UpSample(64, 32, kernel_size=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        # Output projection to 3 channels; a plain padded conv (no resampling),
        # hence the DownSample layer despite the attribute name.
        self.upsample3 = DownSample(32, 3, kernel_size=9, stride=1)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an image batch to its stylised counterpart."""
        out = x

        encoder = (
            (self.downsample1, self.in1),
            (self.downsample2, self.in2),
            (self.downsample3, self.in3),
        )
        for conv, norm in encoder:
            out = self.relu(norm(conv(out)))

        bottleneck = (
            self.residual1,
            self.residual2,
            self.residual3,
            self.residual4,
            self.residual5,
        )
        for block in bottleneck:
            out = block(out)

        decoder = (
            (self.upsample1, self.in4),
            (self.upsample2, self.in5),
        )
        for conv, norm in decoder:
            out = self.relu(norm(conv(out)))

        return self.upsample3(out)

## 损失网络定义

特征提取

In [None]:
# Load the pretrained VGG19 feature extractor in eval mode, keep its first 36
# layers, freeze every parameter, and move it to the configured device.
vgg19 = (
    torchvision.models.vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
    .features.eval()[0:36]
    .requires_grad_(False)
    .to(device)
)

# Compute the Gram matrix ("Gama" in the function name is kept for callers)
def compute_gama_matrix(x: torch.Tensor):
    """Compute the batched, un-normalised Gram matrix of a feature map.

    Args:
        x: Feature tensor of shape ``(n, c, h, w)``.

    Returns:
        Tensor of shape ``(n, c, c)`` whose ``[b, i, j]`` entry is the inner
        product of channels ``i`` and ``j`` of sample ``b`` over all spatial
        positions.
    """
    n, c, h, w = x.shape
    # `reshape` instead of `view`: `view` raises on non-contiguous inputs
    # (e.g. cropped or transposed feature maps); `reshape` only copies when
    # it has to and is a no-op view otherwise.
    flat = x.reshape(n, c, h * w)
    gama_matrix = torch.bmm(flat, flat.transpose(1, 2))
    return gama_matrix

# Extract content and style features
def extract_feature(image_tensor: torch.Tensor):
    """Run an image batch through VGG19 and collect the tapped features.

    Args:
        image_tensor: Image batch fed to the first VGG19 layer.

    Returns:
        ``(content_features, style_features)``: the raw activations at the
        indices in ``content_layers``, and the Gram matrices of the
        activations at the indices in ``style_layers``. Each Gram matrix is
        divided by the element count of its activation and multiplied by the
        element count of ``image_tensor``.
    """
    content_features, style_features = [], []
    # Layers after the last tapped index cannot influence the result, so stop
    # there instead of evaluating the rest of the network on every call.
    last_needed = max((*content_layers, *style_layers), default=-1)
    x = image_tensor
    for i, layer in enumerate(vgg19):
        if i > last_needed:
            break
        x = layer(x)
        # NOTE(review): torchvision's VGG ReLUs are presumably in-place; if a
        # content tap is ever moved onto a conv index, the stored tensor would
        # be overwritten by the following ReLU — confirm before changing taps.
        if i in content_layers:
            content_features.append(x)
        if i in style_layers:
            style_features.append(compute_gama_matrix(x) / x.numel() * image_tensor.numel())
    return content_features, style_features

损失计算

In [None]:
def _match_style_batch(style_gram: torch.Tensor, generated_gram: torch.Tensor) -> torch.Tensor:
    """Return ``style_gram`` with the same batch size as ``generated_gram``."""
    batch = generated_gram.shape[0]
    if style_gram.shape[0] == batch:
        return style_gram
    if style_gram.shape[0] == 1:
        # A single style target is shared by every sample in the batch.
        return style_gram.expand_as(generated_gram)
    if style_gram.shape[0] > batch:
        # The notebook builds the targets by repeating one style image, so the
        # leading entries are a valid target for a shorter batch.
        return style_gram[:batch]
    raise ValueError(
        f"style target batch ({style_gram.shape[0]}) is smaller than the "
        f"generated batch ({batch})"
    )


def compute_loss(generated_image: torch.Tensor, content_image: torch.Tensor, style_features) -> torch.Tensor:
    """Compute the weighted content + style perceptual loss.

    Args:
        generated_image: Output batch of the transfer network.
        content_image: The batch that was fed to the transfer network.
        style_features: Precomputed style Gram matrices, one per entry of
            ``style_layers``. Their batch dimension may be 1, equal to, or
            larger than the batch size of ``generated_image``.

    Returns:
        ``content_loss * content_weight + style_loss * style_weight``, where
        each loss is the sum of per-layer mean-squared errors.

    Raises:
        ValueError: If a style target has fewer (but more than one) samples
            than the generated batch.
    """
    generated_content_features, generated_style_features = extract_feature(generated_image)
    content_features, _ = extract_feature(content_image)

    content_loss = 0
    for generated, target in zip(generated_content_features, content_features):
        content_loss += torch.nn.functional.mse_loss(generated, target)

    style_loss = 0
    for generated, target in zip(generated_style_features, style_features):
        # The DataLoader's last batch can be smaller than `batch_size`, while
        # the style targets were computed for a full batch; align them so
        # `mse_loss` does not fail on mismatched batch dimensions.
        style_loss += torch.nn.functional.mse_loss(
            generated, _match_style_batch(target, generated)
        )

    return content_loss * content_weight + style_loss * style_weight

## 定义数据集

In [None]:
# modified from example code provided by ChatGPT
import torch.utils


class CocoDataset(CocoDetection):
    """COCO dataset wrapper that yields images only.

    The parent class is initialised without transforms; the optional
    ``transform`` is stored here and applied to the image in ``__getitem__``.
    """

    def __init__(self, root: str, annFile: str, transform: transforms = None):
        super().__init__(root, annFile)
        self.transform = transform

    def __getitem__(self, index: int) -> torch.Tensor:
        """Return the image at ``index``, transformed if a transform is set."""
        # The parent returns an (image, annotations) pair; annotations are discarded.
        image = super().__getitem__(index)[0]
        return self.transform(image) if self.transform else image

# Preprocessing shared by the training images and the style image: resize to
# 256x256, convert to a tensor, then normalise each channel with the given
# mean / std.
transform = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Image-only COCO training set, served in shuffled mini-batches.
dataset = CocoDataset(root=root_dir, annFile=annFile, transform=transform)
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)


## 定义网络及训练

In [None]:
# Build the transfer network on the target device.
transformer = TransferNet().to(device)
# Take the step size from the configuration cell rather than a second,
# hard-coded value (this line previously used lr=1e-3 and ignored
# `learning_rate`); change `learning_rate` there to tune it.
optimizer = torch.optim.AdamW(transformer.parameters(), lr=learning_rate)
# Open the style image in a context manager so the file handle is released
# once the RGB copy has been made.
with Image.open(style_image_path) as style_file:
    style_image = style_file.convert("RGB")

# Optional visual check of the style image
# plt.figure(figsize=(10, 10), dpi=150)
# plt.imshow(style_image)

# Preprocess the style image and replicate it along a new batch dimension so
# its Gram matrices have the same batch size as a full training batch.
style_image = transform(style_image).unsqueeze(0).repeat(batch_size, 1, 1, 1).to(device)
print(style_image.shape)

# Precompute the style targets once; only the Gram matrices are needed.
_, style_features = extract_feature(style_image)

In [None]:
# Best (lowest) single-batch loss seen so far and a snapshot of the weights
# taken at that point.
best_model = None
best_loss = float('inf')

for epoch in range(1):
    for i, data in enumerate(dataloader):
        data = data.to(device)
        optimizer.zero_grad()
        loss = compute_loss(transformer(data), data, style_features)
        loss.backward()
        optimizer.step()
        # Convert to a Python float once: keeps `best_loss` a plain number
        # instead of a tensor and avoids repeated device synchronisation.
        loss_value = loss.item()
        print(f"Epoch: {epoch}, Batch: {i}, Loss: {loss_value}")
        if loss_value < best_loss:
            best_loss = loss_value
            # `state_dict()` returns references to the live parameters, so copy
            # them; otherwise `best_model` silently tracks the latest weights.
            # Copies are kept on the CPU so the checkpoint loads on any device.
            # Note: these are the weights *after* the optimizer step above.
            best_model = {
                name: tensor.detach().to("cpu", copy=True)
                for name, tensor in transformer.state_dict().items()
            }
            # save the best model
            torch.save(best_model, "/content/drive/MyDrive/best_model.pth")