train.py
import os
import sys

import numpy as np
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader

import loss
from data.data_provider import SingleLoader, SingleLoader_raw
from model.mwcnn import Model
from option import args
from utils.metric import calculate_psnr
from utils.training_util import save_checkpoint, MovingAverage, load_checkpoint
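# Training entry point for the MWCNN denoiser on paired noisy/clean patches.
# A typical launch (assuming option.py exposes the args.* fields used below
# as CLI flags of the same names) might look like:
#
#   python train.py --data_type rgb --noise_dir ./noisy --gt_dir ./clean \
#       --image_size 64 --batch_size 16 --lr 1e-4 --epochs 18 \
#       --checkpoint ./checkpoints --save_every 100 --loss_every 10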
if __name__ == "__main__":
    torch.set_num_threads(4)
    torch.manual_seed(args.seed)
    # Build the paired noisy/clean dataset for the requested data type.
    if args.data_type == 'rgb':
        data_set = SingleLoader(noise_dir=args.noise_dir, gt_dir=args.gt_dir, image_size=args.image_size)
    elif args.data_type == 'raw':
        data_set = SingleLoader_raw(noise_dir=args.noise_dir, gt_dir=args.gt_dir, image_size=args.image_size)
    else:
        sys.exit("Data type not valid: expected 'rgb' or 'raw', got '{}'".format(args.data_type))
    data_loader = DataLoader(
        data_set,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers
    )
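    # Each batch is a (noise, gt) pair of image tensors; both are moved to
    # the training device below, and gt serves as the regression target.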
    loss_func = loss.CharbonnierLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_dir = args.checkpoint
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)
    model = Model(args).to(device)
    optimizer = optim.Adam(
        model.parameters(),
        lr=args.lr
    )
    # Decay the learning rate by 0.8 at each listed epoch milestone.
    scheduler = lr_scheduler.MultiStepLR(optimizer, [2, 4, 6, 8, 10, 12, 14, 16], 0.8)
    global_step = 0
    # Moving average of the training loss over the last args.save_every steps.
    average_loss = MovingAverage(args.save_every)
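    # MovingAverage is assumed to keep a bounded cache of the most recent
    # args.save_every loss values and expose update()/get_value(); a minimal
    # sketch of that contract (hypothetical, for reference only):
    #
    #   class MovingAverage:
    #       def __init__(self, n):
    #           self._cache = collections.deque(maxlen=n)
    #       def update(self, x):
    #           self._cache.append(x)
    #       def get_value(self):
    #           return sum(self._cache) / len(self._cache)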
    if args.restart:
        start_epoch = 0
        global_step = 0
        best_loss = np.inf
        print('=> training from scratch (restart requested).')
    else:
        try:
            checkpoint = load_checkpoint(checkpoint_dir, device.type == 'cuda', 'latest')
            start_epoch = checkpoint['epoch']
            global_step = checkpoint['global_iter']
            best_loss = checkpoint['best_loss']
            model.model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print('=> loaded checkpoint (epoch {}, global_step {})'.format(start_epoch, global_step))
        except Exception:
            start_epoch = 0
            global_step = 0
            best_loss = np.inf
            print('=> no checkpoint file to be loaded.')
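    # Main optimization loop: one Charbonnier-loss step per batch, with
    # periodic checkpointing (every save_every steps) and PSNR logging
    # (every loss_every steps).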
    for epoch in range(start_epoch, args.epochs):
        for step, (noise, gt) in enumerate(data_loader):
            noise = noise.to(device)
            gt = gt.to(device)
            pred = model(noise, 0)
            # Named loss_value to avoid shadowing the imported `loss` module.
            loss_value = loss_func(pred, gt)
            optimizer.zero_grad()
            loss_value.backward()
            optimizer.step()
            # Store a detached scalar so the moving average does not hold
            # onto the autograd graph of every batch.
            average_loss.update(loss_value.item())
            if global_step % args.save_every == 0:
                if average_loss.get_value() < best_loss:
                    is_best = True
                    best_loss = average_loss.get_value()
                else:
                    is_best = False
                save_dict = {
                    'epoch': epoch,
                    'global_iter': global_step,
                    'state_dict': model.model.state_dict(),
                    'best_loss': best_loss,
                    'optimizer': optimizer.state_dict(),
                }
                save_checkpoint(save_dict, is_best, checkpoint_dir, global_step)
            if global_step % args.loss_every == 0:
                print(global_step, "PSNR :", calculate_psnr(pred, gt))
                print("avg loss :", average_loss.get_value())
            global_step += 1
        # Step the LR schedule once per epoch.
        scheduler.step()
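    # To resume an interrupted run, relaunch with args.restart unset: the
    # 'latest' checkpoint written by save_checkpoint above is reloaded at
    # startup.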