-
Notifications
You must be signed in to change notification settings - Fork 20
/
ours.py
140 lines (114 loc) · 4.24 KB
/
ours.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import warnings
from .priors import JointAutoregressiveHierarchicalPriors
from .our_utils import *
from compressai.layers import *
from .waseda import Cheng2020Anchor
class InvCompress(Cheng2020Anchor):
    """Cheng2020Anchor variant with custom analysis/synthesis transforms.

    The parent's ``g_a`` / ``g_s`` attributes are set to ``None``; encoding
    instead runs ``enh -> inv -> attention`` and decoding runs the same three
    modules in the opposite order with ``rev=True``. The hyper-prior, context
    model and entropy coders (``h_a``, ``h_s``, ``entropy_bottleneck``,
    ``gaussian_conditional``, ``context_prediction``, ``entropy_parameters``)
    are not defined here and are presumably inherited from the parent class.
    """

    def __init__(self, N=192, **kwargs):
        """Build the model with ``N`` channels.

        Args:
            N: channel count handed to the parent constructor, to
                ``InvComp`` (as ``M``) and to ``AttModule``.
            **kwargs: accepted for signature compatibility; not forwarded
                to the parent constructor.
        """
        super().__init__(N=N)
        # The parent's transforms are disabled; g_a_func / g_s_func take over.
        self.g_a = None
        self.g_s = None
        # Sub-modules are registered in this order (enh, inv, attention).
        self.enh = EnhModule(64)
        self.inv = InvComp(M=N)
        self.attention = AttModule(N)

    def g_a_func(self, x):
        """Analysis transform: apply ``enh``, then ``inv``, then ``attention``."""
        enhanced = self.enh(x)
        transformed = self.inv(enhanced)
        return self.attention(transformed)

    def g_s_func(self, x):
        """Synthesis transform: the analysis modules in reverse order, each
        called with ``rev=True``."""
        out = self.attention(x, rev=True)
        out = self.inv(out, rev=True)
        out = self.enh(out, rev=True)
        return out

    def forward(self, x):
        """Run the full encode / entropy-model / decode pass.

        Args:
            x: input tensor fed to the analysis transform.

        Returns:
            dict with ``"x_hat"`` (the reconstruction) and ``"likelihoods"``
            (a dict holding the ``"y"`` and ``"z"`` likelihood tensors).
        """
        latent = self.g_a_func(x)
        hyper_latent = self.h_a(latent)
        z_hat, z_likelihoods = self.entropy_bottleneck(hyper_latent)
        hyper_params = self.h_s(z_hat)

        # Additive noise while training, plain dequantization otherwise.
        quant_mode = "noise" if self.training else "dequantize"
        y_hat = self.gaussian_conditional.quantize(latent, quant_mode)

        # Combine hyper-prior parameters with the context-model output, then
        # split the result along dim 1 into scales and means.
        ctx_params = self.context_prediction(y_hat)
        joint_params = torch.cat((hyper_params, ctx_params), dim=1)
        scales_hat, means_hat = self.entropy_parameters(joint_params).chunk(2, 1)

        _, y_likelihoods = self.gaussian_conditional(
            latent, scales_hat, means=means_hat
        )
        reconstruction = self.g_s_func(y_hat)

        likelihoods = {"y": y_likelihoods, "z": z_likelihoods}
        return {"x_hat": reconstruction, "likelihoods": likelihoods}

    @classmethod
    def from_state_dict(cls, state_dict):
        """Return a new model instance from `state_dict`."""
        # The channel count is read from the first dimension of the
        # "h_a.0.weight" entry.
        channels = state_dict["h_a.0.weight"].size(0)
        model = cls(channels)
        model.load_state_dict(state_dict)
        return model

    def compress(self, x):
        """Encode ``x`` into entropy-coded strings.

        Args:
            x: input tensor fed to the analysis transform.

        Returns:
            dict with ``"strings"`` (``[y_strings, z_strings]``), ``"shape"``
            (the last two dimensions of ``z``) and ``"y"`` (the unquantized
            output of the analysis transform).
        """
        model_device = next(self.parameters()).device
        if model_device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        y = self.g_a_func(x)
        z = self.h_a(y)

        z_strings = self.entropy_bottleneck.compress(z)
        z_hat = self.entropy_bottleneck.decompress(z_strings, z.size()[-2:])
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Pad all four spatial borders by half the context kernel.
        padded_y = F.pad(y, (padding, padding, padding, padding))

        # One string per batch element, encoded one sample at a time.
        y_strings = [
            self._compress_ar(
                padded_y[idx : idx + 1],
                params[idx : idx + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )
            for idx in range(y.size(0))
        ]

        return {
            "strings": [y_strings, z_strings],
            "shape": z.size()[-2:],
            "y": y,
        }

    def decompress(self, strings, shape):
        """Decode the strings produced by ``compress`` into a reconstruction.

        Args:
            strings: two-element list; ``strings[0]`` holds the per-sample
                ``y`` strings and ``strings[1]`` the ``z`` strings.
            shape: shape argument passed to ``entropy_bottleneck.decompress``
                (``compress`` returns it under the ``"shape"`` key).

        Returns:
            dict with ``"x_hat"``, the reconstruction clamped in place to
            the range [0, 1].
        """
        assert isinstance(strings, list) and len(strings) == 2

        model_device = next(self.parameters()).device
        if model_device != torch.device("cpu"):
            warnings.warn(
                "Inference on GPU is not recommended for the autoregressive "
                "models (the entropy coder is run sequentially on CPU)."
            )

        y_strings, z_strings = strings[0], strings[1]

        z_hat = self.entropy_bottleneck.decompress(z_strings, shape)
        params = self.h_s(z_hat)

        s = 4  # scaling factor between z and y
        kernel_size = 5  # context prediction kernel size
        padding = (kernel_size - 1) // 2

        y_height = z_hat.size(2) * s
        y_width = z_hat.size(3) * s

        # Zero buffer with a `padding`-wide border on every spatial side; it
        # is handed to _decompress_ar one batch slice at a time.
        padded_shape = (
            z_hat.size(0),
            self.M,
            y_height + 2 * padding,
            y_width + 2 * padding,
        )
        y_hat = torch.zeros(padded_shape, device=z_hat.device)

        for idx, sample_string in enumerate(y_strings):
            self._decompress_ar(
                sample_string,
                y_hat[idx : idx + 1],
                params[idx : idx + 1],
                y_height,
                y_width,
                kernel_size,
                padding,
            )

        # Negative padding crops the border added above.
        y_hat = F.pad(y_hat, (-padding, -padding, -padding, -padding))
        x_hat = self.g_s_func(y_hat).clamp_(0, 1)
        return {"x_hat": x_hat}