导入包

In [None]:
import torch
from mainNet import EstNet
from dataloader import ImageLoder
from torch.utils.data import DataLoader, random_split
from lossFunction import Gram_Style_Loss, Content_Loss, TV_Loss
from tqdm import tqdm
import utils
# from torchmetrics import MeanAbsoluteError
# import numpy as np

设置是否启用 cuDNN 自动调优（benchmark），以及训练设备

In [None]:
# Uncomment to let cuDNN benchmark and cache the fastest conv algorithms
# (useful when input sizes are fixed, as they are here: resize=224).
# torch.backends.cudnn.benchmark = True

# Train on the GPU when one is visible; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

加载网络（如需预载预训练权重，请取消下方注释行）

In [None]:
# Identifies the checkpoint to resume from. These names are only consumed by
# the commented-out pretrained-load variant below.
load_model_name = 'EstNet'
load_model_version = 'alpha_0.1'

# Build a freshly initialised network and place it on the training device.
# To resume from a checkpoint instead, use:
# ESTNet = EstNet(load_preTrained_model= f'./pretrained_models/{load_model_name}_{load_model_version}.pth').to(device)
ESTNet = EstNet()
ESTNet = ESTNet.to(device)

加载数据：训练集、验证集、测试集

In [None]:
# COCO train2014 holds 82783 images; `dataset_usage` selects the fraction used.
dataset_usage = 1.
used_items_num = round(82783 * dataset_usage)
# Forward slashes resolve on both Windows and POSIX; the previous mixed
# separator path ("./data\train2014") only worked on Windows.
coco2014 = ImageLoder("./data/train2014",
                      datanum=used_items_num,
                      if_random=False,
                      preload=False,
                      resize=224,
                      normalized=True,
                      std=0.5, mean=0.5,
                      double_output=True)

# Derive the split sizes from what was actually loaded: random_split requires
# the lengths to sum to len(dataset) exactly, which can be smaller than the
# requested count if the directory holds fewer images.
used_items_num = len(coco2014)
print('used_items_num: ', used_items_num)

random_seed = 2014
train_data_rate = 0.8
eval_data_rate = 0.1
train_items_num = round(train_data_rate * used_items_num)
eval_items_num = round(eval_data_rate * used_items_num)
# The test split takes the remainder, so the three lengths always add up to
# len(coco2014) regardless of rounding above.
test_items_num = used_items_num - train_items_num - eval_items_num

train_dataset, eval_dataset, test_dataset = random_split(
    dataset=coco2014,
    lengths=[train_items_num, eval_items_num, test_items_num],
    # A fixed seed reproduces the same split every session, so held-out items
    # never migrate into the training set between runs.
    generator=torch.Generator(device='cpu').manual_seed(random_seed),
)

# Settings shared by all three loaders. pin_memory uses page-locked host
# buffers, which speeds up host->GPU transfers.
loader_kwargs = dict(batch_size=4, num_workers=4, pin_memory=True, prefetch_factor=2)

# Only training benefits from reshuffling every epoch; a fixed order for the
# eval/test loaders keeps their results reproducible run to run.
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
eval_dataloader = DataLoader(dataset=eval_dataset, shuffle=False, **loader_kwargs)
test_dataloader = DataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

加载优化器、学习率调度器、损失函数

In [None]:
# Adam over every parameter of the network, starting learning rate 1e-3.
learning_rate = 1e-3
optimizer = torch.optim.Adam(params=ESTNet.parameters(), lr=learning_rate)

In [None]:
# Multiply the LR by 0.8 whenever the monitored value (the per-epoch loss sum
# passed to scheduler.step in the training loop) has not decreased for 10
# consecutive steps; the LR is never pushed below 1e-4.
# NOTE(review): with epochs = 2 this can never trigger (patience is 10), and
# `verbose` is deprecated in recent PyTorch releases — consider logging
# scheduler.get_last_lr() instead.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.8,
    patience=10,
    verbose=True,
    min_lr=1e-4,
)

In [None]:
# The three loss terms summed in the training loop, each moved to the training
# device: style (Gram-matrix based, judging by the class name), content, and
# total-variation. Each is constructed and moved in turn, left to right.
style_loss, content_loss, tv_loss = (
    Gram_Style_Loss().to(device),
    Content_Loss().to(device),
    TV_Loss().to(device),
)

设置总训练轮次，验证间隔轮次

In [None]:
# epochs: number of full passes over the training set.
# eval_distance: intended spacing (in epochs) between validation passes.
# NOTE(review): eval_distance is not read by the training loop in this
# notebook, which currently performs no validation.
epochs, eval_distance = 2, 30

训练（注：当前循环只进行训练，尚未执行验证）

In [None]:
# Training loop: one optimisation step per batch. The sum of batch losses for
# each epoch is printed, appended to train_log, and fed to the plateau scheduler.
ESTNet.train()
train_log = []

for e in range(epochs):
    loss_all = 0.0
    with tqdm(train_dataloader, desc=f'epoch{e}', leave=False, unit='batch') as t:
        # Each batch yields a pair: `c` is scored by the content loss and `s`
        # by the style loss below.
        for c, s in t:
            # The loaders use pin_memory=True, so non_blocking lets these
            # host->GPU copies overlap with queued work (ignored on CPU).
            c = c.to(device, non_blocking=True)
            s = s.to(device, non_blocking=True)
            output = ESTNet(c, s)

            loss = style_loss(output, s) + content_loss(output, c) + tv_loss(output)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() forces a device synchronisation: read it once, after the
            # step has been queued, and reuse the Python float.
            loss_value = loss.item()
            loss_all += loss_value
            t.set_postfix_str(f'loss: {loss_value}')
        t.write(f'epoch{e} loss_all: {loss_all}')
    train_log.append(loss_all)
    # The scheduler was built with mode='min', so it watches for the epoch
    # loss sum to stop decreasing.
    scheduler.step(loss_all)
    

训练过程可视化

In [None]:
# Hand the per-epoch loss history (one summed loss per epoch) to the project
# helper; presumably it renders a loss curve — see utils.train_process_visualable.
utils.train_process_visualable(
    train_log,
)

模型保存

In [None]:
from pathlib import Path

# Checkpoint identity; the file is written as <name>_<version>.pth.
save_model_name = 'EstNet'
save_model_version = 'beta_0.1'

# torch.save does not create missing parent directories. Create the checkpoint
# folder up front so a finished training run is not lost to a
# FileNotFoundError at the very last step.
save_dir = Path('./pretrained_models')
save_dir.mkdir(parents=True, exist_ok=True)
torch.save(ESTNet.state_dict(), save_dir / f'{save_model_name}_{save_model_version}.pth')

测试

In [None]:
# Switch the network out of training mode (equivalent to ESTNet.eval()).
ESTNet.train(False)


结果测试

In [None]:
# NOTE(review): "StarSky" (presumably Van Gogh's Starry Night) is a classic
# *style* image and Lenna a classic *content* image — confirm the roles below
# are not swapped.
content_img = utils.get_pilimg(r'data/test/StarSky.jpg')
style_img = utils.get_pilimg(r'data/test/Lenna.png')

utils.img_display(content_img, 'content')
utils.img_display(style_img, 'style')

# Derive the GPU flag from `device` (where ESTNet was moved) rather than
# hard-coding it, so the inputs always land on the same device as the model.
use_cuda = torch.device(device).type == 'cuda'
# Same resize and normalisation (mean = std = 0.5) as the training ImageLoder.
content_tensor = utils.pilimg2tensor(content_img, cuda=use_cuda, resize=(224, 224), mean=0.5, std=0.5)
style_tensor = utils.pilimg2tensor(style_img, cuda=use_cuda, resize=(224, 224), mean=0.5, std=0.5)

In [None]:
ESTNet.eval()
# Inference only: no_grad skips building the autograd graph, which saves
# memory, and returns a tensor that does not require grad. The latter matters
# in case tensor2img converts to NumPy without detaching first.
with torch.no_grad():
    result_tensor = ESTNet(content_tensor, style_tensor)
# Undo the mean = std = 0.5 normalisation applied to the inputs.
result_img = utils.tensor2img(result_tensor, mean=0.5, std=0.5)
utils.img_display(result_img, 'test')