-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_nyu.py
195 lines (179 loc) · 7.73 KB
/
train_nyu.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# -*- coding: utf-8 -*-
import sys
sys.path.append('/home/aistudio/external-libraries')
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
from utils.loss import *
from utils.print_time import *
from utils.save_log_to_excel import *
from dataloader_nyu import nyu_DataSet
from Res_ED_model import *
import time
import xlwt
from utils.ms_ssim import *
import os
# ----- Optimisation hyper-parameters -----
ednet_LR = 0.004  # learning rate used for the ednet optimizer
Atnet_LR = 0.0004  # learning rate used for the Atnet optimizer
EPOCH = 20  # number of training epochs
BATCH_SIZE = 4  # samples per batch
accumulation_steps = 2  # number of batches whose gradients are accumulated per optimizer step
# Number of batches that cover 10000 samples. The original note reads
# "loss drops by 50% after 10000 training samples" -- presumably a
# learning-rate decay interval; it is not referenced by the visible code.
itr_to_lr = 10000 // BATCH_SIZE
# Number of batches that cover 64 samples; the training log is written to
# the Excel sheet at this interval.
itr_to_excel = 64 // BATCH_SIZE

# ----- Excel logging cursors -----
excel_train_line = 1  # next row index to write in the training sheet
excel_val_line = 1  # next row index to write in the validation sheet

# ----- Loss bookkeeping -----
# Original author's note: because there are many loss terms, step-wise
# training is recommended to reduce GPU memory usage.
# Per the original note the 12 terms are: 2 reconstruction losses, 5 At-network
# losses, the dehazing losses (forward and backward) and an intermediate
# feature constraint.
weight_At = [1] * 5
weight_ed = [1, 1, 1, 1, 0.01]
weight_recon = [1, 1]
weight = [*weight_At, *weight_ed, *weight_recon]
loss_num = 12  # total number of logged loss components

# ----- Dataset locations -----
train_haze_path = '/home/aistudio/work/nyu/train/'  # hazy images, training split
val_haze_path = '/home/aistudio/work/nyu/val/'  # hazy images, validation split
gt_path = '/home/aistudio/work/nyu/gth/'
d_path = '/home/aistudio/work/nyu/depth/'

# ----- Output locations -----
# Each run writes into its own timestamped result directory.
_run_stamp = time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime())
save_path = f'./result_nyu_{_run_stamp}/'
save_model_ed_name = f'{save_path}ed_model.pt'  # where the ednet model is saved
save_model_At_name = f'{save_path}At_model.pt'  # where the Atnet model is saved
excel_save = f'{save_path}result.xls'  # where the Excel log is saved
# Intermediate models, kept for the next training stage on the NTIRE data
# (per the original author's note).
mid_save_ed_path = './mid_model/ednet_model.pt'
mid_save_At_path = './mid_model/Atnet_model.pt'
# Create the Excel workbook together with its training and validation sheets.
# init_excel is defined outside this script (it arrives through one of the
# star imports), so its internals are not visible here.
f, sheet_train, sheet_val = init_excel()

# Load the two pre-trained networks and move them to the GPU.
# NOTE(review): torch.load unpickles whole model objects here, which can
# execute arbitrary code -- only trusted files belong in ./pre_model/.
ednet_path = './pre_model/ednet_model.pt'
Atnet_path = './pre_model/Atnet_model.pt'
ednet = torch.load(ednet_path).cuda()
Atnet = torch.load(Atnet_path).cuda()

# Make sure both output directories exist. exist_ok=True replaces the
# check-then-create pattern, which could still raise if the directory
# appeared between the existence check and the makedirs call.
os.makedirs(save_path, exist_ok=True)
os.makedirs('./mid_model/', exist_ok=True)

# Freeze the decoder of ednet so that its parameters receive no gradients.
for param in ednet.decoder.parameters():
    param.requires_grad = False
# Transform applied to the data: conversion to tensors only.
transform = transforms.Compose([transforms.ToTensor()])

# Training data: hazy-image folder plus the shared ground-truth and depth
# folders, served in shuffled batches.
train_path_list = [train_haze_path, gt_path, d_path]
train_data = nyu_DataSet(transform, train_path_list)
train_data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Validation data: same layout, different hazy-image folder. The order is
# left unshuffled because the validation loss is accumulated over the whole
# set, so shuffling added nothing but nondeterminism.
val_path_list = [val_haze_path, gt_path, d_path]
val_data = nyu_DataSet(transform, val_path_list)
val_data_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# One Adam optimizer per network, each with its own learning rate.
ednet_optim = torch.optim.Adam(ednet.parameters(), lr=ednet_LR, weight_decay=1e-5)
Atnet_optim = torch.optim.Adam(Atnet.parameters(), lr=Atnet_LR, weight_decay=1e-5)

# Best validation loss seen so far. Infinity (instead of a large magic
# number) guarantees that any finite validation loss counts as an improvement.
min_loss = float('inf')
min_epoch = 0  # epoch index at which min_loss was reached
itr = 0  # total number of training batches processed across all epochs
start_time = time.time()
# Training / validation driver.
#
# Every epoch runs one pass over the training loader with gradient
# accumulation, then one pass over the validation loader without gradients.
# Per-component losses are logged to the Excel workbook, and both networks
# are saved whenever the summed validation loss improves.

# Offsets of the loss groups inside the flat per-component log list:
# the At-network terms start at 0, the encoder-decoder terms at 5 and the
# reconstruction terms at 10 (the same indices the script always used).
ED_OFFSET = 5
RECON_OFFSET = 10


def _accumulate(totals, parts, offset=0):
    """Add each entry of *parts* onto *totals*, starting at index *offset*."""
    for k, part in enumerate(parts):
        totals[offset + k] = totals[offset + k] + part


print("\nstart to train!")
for epoch in range(EPOCH):
    index = 0  # 1-based count of batches processed in this epoch
    train_loss = 0
    val_loss = 0
    loss_excel = [0] * loss_num
    ednet.train()
    Atnet.train()
    # NOTE(review): the batches are passed to the CUDA models exactly as the
    # loader yields them -- presumably nyu_DataSet already returns GPU
    # tensors; confirm.
    for haze, gt, A_haze, A_gt, t_haze, t_gt in train_data_loader:
        index += 1
        itr += 1
        J, A, t = Atnet(haze)
        # The third output of this first ednet call was always overwritten by
        # the second call below, so it is discarded explicitly.
        output, gt_scene, _ = ednet(gt, A_gt, t_gt)
        dehaze, hazy_scene, I = ednet(haze, A, t)

        # The loss groups are back-propagated one at a time to avoid running
        # out of GPU memory (original author's note). retain_graph=True is
        # needed on the first two calls because the later losses reuse parts
        # of the same graph (A and t feed the second ednet call).
        loss, temp_loss = loss_At_function([gt, A_haze, t_haze, J, A, t], weight_At)
        train_loss += loss.item()
        _accumulate(loss_excel, temp_loss)
        (loss / accumulation_steps).backward(retain_graph=True)

        loss, temp_loss = loss_ed_function([output, gt, dehaze, gt_scene, hazy_scene], weight_ed)
        train_loss += loss.item()
        _accumulate(loss_excel, temp_loss, ED_OFFSET)
        (loss / accumulation_steps).backward(retain_graph=True)

        loss, temp_loss = loss_recon_function([haze, I], weight_recon)
        train_loss += loss.item()
        _accumulate(loss_excel, temp_loss, RECON_OFFSET)
        (loss / accumulation_steps).backward()

        # index is already 1-based, so the step condition is index % steps.
        # The previous (index + 1) % steps test fired after batches 1, 3, 5...
        # which made the very first update use a single micro-batch.
        if index % accumulation_steps == 0:
            ednet_optim.step()
            Atnet_optim.step()
            ednet_optim.zero_grad()
            Atnet_optim.zero_grad()

        if index % itr_to_excel == 0:
            # Average the components over the logging window, report, reset.
            loss_excel = [component / itr_to_excel for component in loss_excel]
            print('epoch %d, %03d/%d, loss=%.5f' % (epoch + 1, index, len(train_data_loader), sum(loss_excel)))
            print_time(start_time, index, EPOCH, len(train_data_loader), epoch)
            excel_train_line = write_excel(sheet=sheet_train,
                                           data_type='train',
                                           line=excel_train_line,
                                           epoch=epoch,
                                           itr=itr,
                                           loss=loss_excel)
            f.save(excel_save)
            loss_excel = [0] * loss_num

    # Flush a trailing, incomplete accumulation group. The step is skipped
    # when nothing is pending so the optimizers are not advanced on
    # already-cleared gradients.
    if index % accumulation_steps != 0:
        ednet_optim.step()
        Atnet_optim.step()
    ednet_optim.zero_grad()
    Atnet_optim.zero_grad()

    # ----- validation -----
    loss_excel = [0] * loss_num
    with torch.no_grad():
        ednet.eval()
        Atnet.eval()
        for haze, gt, A_haze, A_gt, t_haze, t_gt in val_data_loader:
            J, A, t = Atnet(haze)
            output, gt_scene, _ = ednet(gt, A_gt, t_gt)
            # NOTE(review): validation feeds the dataset's A_haze / t_haze to
            # ednet, whereas training feeds Atnet's predictions (A, t) --
            # confirm this difference is intended.
            dehaze, hazy_scene, I = ednet(haze, A_haze, t_haze)

            loss, temp_loss = loss_At_function([gt, A_haze, t_haze, J, A, t], weight_At)
            val_loss += loss.item()
            _accumulate(loss_excel, temp_loss)

            loss, temp_loss = loss_ed_function([output, gt, dehaze, gt_scene, hazy_scene], weight_ed)
            val_loss += loss.item()
            _accumulate(loss_excel, temp_loss, ED_OFFSET)

            loss, temp_loss = loss_recon_function([haze, I], weight_recon)
            val_loss += loss.item()
            _accumulate(loss_excel, temp_loss, RECON_OFFSET)

    # Per-component validation losses averaged over the number of batches;
    # val_loss itself stays a plain sum over the whole validation set.
    loss_excel = [component / len(val_data_loader) for component in loss_excel]
    print('val_loss = %.5f' % val_loss)
    # NOTE(review): the first row written to the validation sheet carries the
    # epoch's summed training loss (a scalar); the second row carries the
    # averaged validation components. How write_excel handles a scalar is not
    # visible here.
    excel_val_line = write_excel(sheet=sheet_val,
                                 data_type='val',
                                 line=excel_val_line,
                                 epoch=epoch,
                                 itr=False,
                                 loss=train_loss)
    excel_val_line = write_excel(sheet=sheet_val,
                                 data_type='val',
                                 line=excel_val_line,
                                 epoch=epoch,
                                 itr=False,
                                 loss=loss_excel)
    f.save(excel_save)

    # Keep the networks whenever the summed validation loss improves: once in
    # this run's result directory and once in ./mid_model/.
    if val_loss < min_loss:
        min_loss = val_loss
        min_epoch = epoch
        torch.save(ednet, save_model_ed_name)
        torch.save(Atnet, save_model_At_name)
        torch.save(ednet, mid_save_ed_path)
        torch.save(Atnet, mid_save_At_path)
        print('saving the epoch %d model with %.5f' % (epoch + 1, min_loss))
print('Train is Done!')