-
Notifications
You must be signed in to change notification settings - Fork 5
/
model.py
37 lines (28 loc) · 1.12 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch.optim
import torch.nn as nn
import modules.rrdb_denselayer
from modules.hinet import Hinet
import config as c
class PRIS(nn.Module):
    """Wrapper around an invertible ``Hinet`` plus two enhancement blocks.

    ``forward`` only runs the wrapped ``Hinet`` (``self.inbs``). The
    ``pre_enhance`` / ``post_enhance`` blocks are built here but are not
    called by ``forward``; presumably training/inference code invokes them
    directly — TODO confirm against callers.

    Args:
        in_1: forwarded to ``Hinet(in_1=...)``; presumably the channel count
            of the first input stream (default 3) — confirm against Hinet.
        in_2: forwarded to ``Hinet(in_2=...)``; presumably the channel count
            of the second input stream (default 3) — confirm against Hinet.
    """

    def __init__(self, in_1=3, in_2=3):
        super().__init__()
        self.inbs = Hinet(in_1=in_1, in_2=in_2)
        # (3, 3) are presumably the in/out channel counts of the residual
        # dense block — verify against modules.rrdb_denselayer.
        self.pre_enhance = modules.rrdb_denselayer.ResidualDenseBlock_out(3, 3)
        self.post_enhance = modules.rrdb_denselayer.ResidualDenseBlock_out(3, 3)

    def load_hinet(self, path, map_location="cpu"):
        """Load pretrained weights into ``self.inbs`` from a checkpoint file.

        The checkpoint must be a mapping whose ``'net'`` entry is a state
        dict. Entries whose key contains ``'tmp_var'`` are dropped, then the
        rest are loaded into ``self.inbs`` with the default (strict)
        ``load_state_dict``.

        Args:
            path: checkpoint path passed to ``torch.load``.
            map_location: forwarded to ``torch.load``. Defaults to ``"cpu"``
                so GPU-saved checkpoints also load on CPU-only machines;
                ``load_state_dict`` then copies the values onto each
                parameter's current device, so the loaded weights are the
                same regardless of this setting.

        Raises:
            KeyError: if the checkpoint has no ``'net'`` entry.
            RuntimeError: from ``load_state_dict`` on missing/unexpected keys
                or shape mismatches.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dicts = torch.load(path, map_location=map_location)
        network_state_dict = {
            k: v for k, v in state_dicts['net'].items() if 'tmp_var' not in k
        }
        self.inbs.load_state_dict(network_state_dict)

    def forward(self, x, rev=False):
        """Run the wrapped ``Hinet``; pass ``rev=True`` to run it in reverse."""
        if rev:
            return self.inbs(x, rev=True)
        return self.inbs(x)
def init_model(mod, init_scale=None):
    """Re-initialise the trainable parameters of ``mod`` in place.

    Every parameter with ``requires_grad=True`` is overwritten with standard
    Gaussian noise multiplied by ``init_scale``. Trainable parameters whose
    dotted name has ``'conv5'`` as its second-to-last component (e.g.
    ``'...conv5.weight'``) are then set to zero. Parameters with
    ``requires_grad=False`` are left untouched. New values are generated on
    each parameter's own device and in its own dtype.

    Args:
        mod: ``nn.Module`` whose parameters are re-initialised.
        init_scale: noise multiplier; when ``None`` (the default) the
            project config value ``c.init_scale`` is used.
    """
    scale = c.init_scale if init_scale is None else init_scale
    # In-place updates under no_grad keep the same Parameter objects and
    # stay out of autograd's history.
    with torch.no_grad():
        for key, param in mod.named_parameters():
            if not param.requires_grad:
                continue
            # randn_like matches shape, dtype and device, so this works
            # without a GPU and does not coerce non-float32 params.
            param.copy_(scale * torch.randn_like(param))
            parts = key.split('.')
            # A top-level parameter has no dot in its name; guard the index.
            if len(parts) >= 2 and parts[-2] == 'conv5':
                param.zero_()