Commit 5133abc

wip

1 parent bfc2416 commit 5133abc
19 files changed: +736 -364 lines

NN/Nerf2D.py

Lines changed: 23 additions & 3 deletions

```diff
@@ -88,22 +88,31 @@ def _withExtraLatents(self, latents, src, points):
     )
     return tf.concat([latents, extraData], axis=-1)
 
+  def _extractR(self, YData):
+    R = None
+    if 'blur R' in YData:
+      R = YData['blur R']
+      tf.assert_equal(R, R[:, 0:1], "R must be the same for all points")
+      R = R[:, 0]
+    return R
+
   def train_step(self, data):
     (src, YData) = data
     src = ensure4d(src)
     x0 = YData['sampled']
     positions = YData['positions']
+    B, N = tf.shape(positions)[0], tf.shape(positions)[1]
     # remove this keys from the dictionary
     YData = {k: v for k, v in YData.items() if k not in ['sampled', 'positions']}
 
     with tf.GradientTape() as tape:
-      encodedSrc = self._encoder(src=src, training=True)
+      encodedSrc = self._encoder(src=src, training=True, R=self._extractR(YData))
       latents = self._extractLatents(encodedSrc=encodedSrc, positions=positions, training=True)
       # train the restorator
       residual = self._withResidual(src, points=positions)
       latents = self._withExtraLatents(latents, src=src, points=positions)
      # flatten values
-      BN = tf.shape(positions)[0] * tf.shape(positions)[1]
+      BN = B * N
       latents = tf.reshape(latents, (BN, tf.shape(latents)[-1]))
       positions = tf.reshape(positions, (BN, 2))
       x0 = tf.reshape(x0, (BN, tf.shape(x0)[-1]))
@@ -117,6 +126,7 @@ def train_step(self, data):
         positions=positions,
         params={**self._lossParams, **params},
       )['loss']
+      tf.assert_equal(tf.shape(loss), (BN, ))
 
     self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
     self._loss.update_state(loss)
@@ -153,7 +163,17 @@ def _inference(
     tf.assert_equal(tf.shape(initialValues)[:1], (B, ))
     initialValues = tf.reshape(initialValues, (B, N, C))
 
-    encoded = self._encoder(src, training=False, params=encoderParams)
+    R = None
+    if 'blurRadius' in reverseArgs:
+      R = reverseArgs['blurRadius']
+      if not tf.is_tensor(R):
+        R = tf.convert_to_tensor(R, dtype=tf.float32)
+      R = tf.reshape(R, [1, 1])
+      R = tf.tile(R, [B, 1])
+    else:
+      R = tf.zeros((B, 1), dtype=tf.float32)
+      pass
+    encoded = self._encoder(src, training=False, params=encoderParams, R=R)
     def getChunk(ind, sz):
       posC = pos[ind:ind+sz]
       sz = tf.shape(posC)[0]
```
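The commit threads an optional blur radius `R` through the encoder: during training it is read from `YData['blur R']` (asserted constant across the points of each sample), and during inference it comes from `reverseArgs['blurRadius']`, defaulting to zeros. A minimal, self-contained sketch of the inference-side normalization (the helper name `normalize_blur_radius` is ours, for illustration; the real logic lives inline in `_inference`):

```python
import tensorflow as tf

# Hypothetical helper mirroring the normalization added to `_inference`:
# a missing radius becomes zeros; a scalar is broadcast to every batch item.
def normalize_blur_radius(reverseArgs, B):
  if 'blurRadius' not in reverseArgs:
    return tf.zeros((B, 1), dtype=tf.float32)
  R = reverseArgs['blurRadius']
  if not tf.is_tensor(R):
    R = tf.convert_to_tensor(R, dtype=tf.float32)
  R = tf.reshape(R, [1, 1])  # scalar -> (1, 1)
  return tf.tile(R, [B, 1])  # (1, 1) -> (B, 1)

print(normalize_blur_radius({}, B=3))                   # zeros, shape (3, 1)
print(normalize_blur_radius({'blurRadius': 0.25}, B=3)) # 0.25 broadcast to (3, 1)
```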

NN/RestorationModel/CRestorationModel.py

Lines changed: 7 additions & 3 deletions

```diff
@@ -60,8 +60,11 @@ def _addRadius(self, latents, R=None, fakeR=0.0, training=False):
     B = tf.shape(latents)[0]
     if self._blurRadiusEncoder is not None:
       if R is None:
-        R = tf.constant([[fakeR]], dtype=tf.float32)
-        encodedR = self._blurRadiusEncoder(R, training=training)
+        if not tf.is_tensor(fakeR):
+          fakeR = tf.convert_to_tensor(fakeR, dtype=tf.float32)
+
+        fakeR = tf.reshape(fakeR, [1, 1])
+        encodedR = self._blurRadiusEncoder(fakeR, training=training)
         encodedR = tf.tile(encodedR, [B, 1])
       else:
         encodedR = self._blurRadiusEncoder(R, training=training)
@@ -81,7 +84,8 @@ def call(self, latents, pos, T, V, residual, R=None, training=False):
   def reverse(self, latents, pos, reverseArgs, training, value, residual, index):
     EPos = self._encodePos(pos, training=training, args=reverseArgs.get('decoder', {}))
     latents = self._addResiduals(latents, residual)
-    latents = self._addRadius(latents, R=None, training=training)
+    fakeR = reverseArgs.get('blurRadius', 0.0)
+    latents = self._addRadius(latents, R=None, fakeR=fakeR, training=training)
 
     def denoiser(x, t, mask=None):
       args = dict(condition=latents, coords=EPos, timestep=t, V=x)
```
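`_addRadius` previously wrapped `fakeR` in `tf.constant([[fakeR]])`, which fails once `fakeR` is already a tensor, as it now can be, coming from `reverseArgs.get('blurRadius', 0.0)`. The new code converts and reshapes it instead. A small sketch of the same pattern, with `self._blurRadiusEncoder` replaced by a stand-in `Dense` layer (an assumption for illustration, not the repo's encoder):

```python
import tensorflow as tf

# Stand-in for self._blurRadiusEncoder: any layer mapping (1, 1) -> (1, D).
blurRadiusEncoder = tf.keras.layers.Dense(4)

def add_fake_radius(latents, fakeR=0.0):
  B = tf.shape(latents)[0]
  if not tf.is_tensor(fakeR):
    fakeR = tf.convert_to_tensor(fakeR, dtype=tf.float32)
  fakeR = tf.reshape(fakeR, [1, 1])    # works for a float or a scalar tensor
  encodedR = blurRadiusEncoder(fakeR)  # encode once, shape (1, D)
  encodedR = tf.tile(encodedR, [B, 1]) # share across the batch
  return tf.concat([latents, encodedR], axis=-1)

latents = tf.zeros((5, 8))
print(add_fake_radius(latents, fakeR=0.5).shape)  # (5, 12)
```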

NN/encoders/CEncoderHead.py

Lines changed: 302 additions & 0 deletions

```diff
@@ -0,0 +1,302 @@
+# a bit hacky way to make encoder head obtain input shape dynamically
+import tensorflow as tf
+import tensorflow.keras.layers as L
+from NN.utils import sMLP
+from NN.encoding import CCoordsGridLayer, CCoordsEncodingLayer
+from NN.layers import MixerConvLayer, Patches, TransformerBlock
+from Utils.utils import dumb_deepcopy
+
+def block_params_from_config(config):
+  layers = config.get('layers', None)
+  if not(layers is None): return layers
+
+  defaultConvParams = {
+    'kernel size': config.get('kernel size', 3),
+    'activation': config.get('activation', 'relu'),
+    'name': config.get('name', 'Conv2D'),
+  }
+  convBefore = config['conv before']
+  # if convBefore is an integer, then it's the same for all layers
+  if isinstance(convBefore, int):
+    convParams = { 'channels': config['channels'], **defaultConvParams }
+    convBefore = [convParams] * convBefore # repeat the same parameters
+    pass
+  assert isinstance(convBefore, list), 'convBefore must be a list'
+  # if convBefore is a list of integers, then each integer is the number of channels
+  if (0 < len(convBefore)) and isinstance(convBefore[0], int):
+    convBefore = [ {'channels': sz, **defaultConvParams} for sz in convBefore ]
+    pass
+
+  # add separately last layer
+  lastConvParams = {
+    'channels': config.get('channels last', config['channels']),
+    'kernel size': config.get('kernel size last', defaultConvParams['kernel size']),
+    'activation': config.get('final activation', defaultConvParams['activation']),
+    'name': config.get('last name', 'Conv2D'),
+  }
+  return convBefore + [lastConvParams]
+
+def conv_block_from_config(data, config, defaults, name='CB'):
+  config = {**defaults, **config} # merge defaults and config
+  convParams = block_params_from_config(config)
+  # apply convolutions to the data
+  for i, parameters in enumerate(convParams):
+    parameters = dumb_deepcopy(parameters)
+    Name = parameters.get('name', 'Conv2D')
+    if 'Conv2D' == Name:
+      data = L.Conv2D(
+        filters=parameters['channels'],
+        padding='same',
+        kernel_size=parameters['kernel size'],
+        activation=parameters['activation'],
+        name='%s/conv-%d' % (name, i)
+      )(data)
+      continue
+
+    if 'MLP Mixer' == Name:
+      data = MixerConvLayer(
+        token_mixing=parameters.get('token mixing', 512),
+        channel_mixing=parameters.get('channel mixing', 512),
+        name='%s/conv-mixer-%d' % (name, i)
+      )(data)
+      continue
+
+    if 'Patches' == Name:
+      data = Patches(
+        patch_size=parameters['patch size'],
+        name='%s/patches-%d' % (name, i)
+      )(data)
+      continue
+
+    if 'CoordsGrid' == Name:
+      parameters = {k: v for k, v in parameters.items() if k not in ['name']}
+      parameters['name'] = '%s/coordsGrid-%d' % (name, i)
+      data = CCoordsGridLayer(
+        CCoordsEncodingLayer(
+          N=parameters.get('N', 32),
+          **parameters
+        ),
+        name='%s/coordsGrid-%d' % (name, i)
+      )(data)
+      continue
+
+    if 'Transformer' == Name:
+      parameters = {k: v for k, v in parameters.items()}
+      parameters['name'] = '%s/transformer-%d' % (name, i)
+      parameters['intermediate_dim'] = parameters.pop('intermediate dim', 512)
+      parameters['num_heads'] = parameters.pop('num heads', 8)
+      data = TransformerBlock(**parameters)(data)
+      continue
+
+    if 'Reshape' == Name:
+      shape = list(parameters['shape'])
+      for j, sz in enumerate(shape):
+        if sz <= -2:
+          sz = data.shape[sz + 1]
+        shape[j] = sz
+        continue
+      data = L.Reshape(
+        shape,
+        name='%s/reshape-%d' % (name, i)
+      )(data)
+      continue
+
+    if 'MLP' == Name:
+      parameters['name'] = '%s/mlp-%d' % (name, i)
+      data = sMLP(**parameters)(data)
+      continue
+
+    raise NotImplementedError('Unknown layer: {}'.format(Name))
+  return data
+
+def _createGCMv2(dataShape, config, latentDim, name):
+  data = L.Input(shape=dataShape)
+
+  res = data
+  for i, blockConfig in enumerate(config['downsample steps']):
+    # downsample
+    res = L.Conv2D(
+      filters=blockConfig['channels'],
+      kernel_size=blockConfig['kernel size'],
+      strides=2,
+      padding='same',
+      activation='relu',
+      name=name + '/downsample-%d' % (i + 1,)
+    )(res)
+    # convolutions
+    for layerId in range(blockConfig['layers']):
+      res = L.Conv2D(
+        filters=blockConfig['channels'],
+        kernel_size=blockConfig['kernel size'],
+        padding='same',
+        activation='relu',
+        name=name + '/downsample-%d/layer-%d' % (i + 1, layerId + 1)
+      )(res)
+      continue
+    continue
+
+  return tf.keras.Model(inputs=[data], outputs=res, name=name)
+
+def _createGlobalContextModel(X, config, latentDim, name):
+  model = config.get('name', 'v1')
+  if 'v1' == model: # simple convolutional model
+    res = conv_block_from_config(
+      data=X, config=config, defaults={
+        'conv before': 0, # by default, no convolutions before the last layer
+      }
+    )
+    # calculate global context
+    latent = L.Flatten()(res)
+    context = sMLP(sizes=config['mlp'], activation='relu', name=name + '/globalMixer')(latent)
+    context = L.Dense(latentDim, activation=config['final activation'], name=name + '/dense-latent')(context)
+    return context # end of 'v1' model
+
+  if 'v2' == model:
+    res = data = L.Input(shape=X.shape[1:])
+    res = _createGCMv2(res.shape[1:], config, latentDim, name)(res)
+    # calculate global context
+    latent = L.Flatten()(res)
+    context = sMLP(sizes=config['mlp'], activation='relu', name=name + '/globalMixer')(latent)
+    context = L.Dense(latentDim, activation=config['final activation'], name=name + '/dense-latent')(context)
+    model = tf.keras.Model(inputs=[data], outputs=context, name=name)
+    return model(X) # end of 'v2' model
+
+  raise NotImplementedError('Unknown global context model: {}'.format(model))
+
+def _withPositionConfig(config, name):
+  if config is None:
+    print('[Encoder] Positions: No')
+    return lambda x, _: x
+
+  print('[Encoder] Positions: Yes')
+
+  if isinstance(config, bool): config = { 'N': 32 }
+  assert isinstance(config, dict), 'config must be a dictionary'
+
+  def withPosition(x, i):
+    if not config.get('stage-%d' % i, True): return x
+
+    encoding = config.get('encoding', {})
+    encoding = dict(**encoding)
+    encoding['N'] = config.get('stage-%d N' % i, config.get('N', 32))
+    return CCoordsGridLayer(
+      CCoordsEncodingLayer(**encoding, name='%s/coordsGrid-%d/encoding' % (name, i)),
+      name='%s/coordsGrid-%d' % (name, i)
+    )(x)
+  return withPosition
+
+##################
+def createEncoderHead_full(
+  imgWidth,
+  config,
+  channels, downsampleSteps, latentDim,
+  ConvBeforeStage, ConvAfterStage,
+  localContext, globalContext,
+  positionsConfigs,
+  name
+):
+  assert config is not None, 'config must be a dictionary'
+  assert isinstance(downsampleSteps, list) and (0 < len(downsampleSteps)), 'downsampleSteps must be a list of integers'
+  data = L.Input(shape=(imgWidth, imgWidth, channels))
+
+  withPosition = _withPositionConfig(positionsConfigs, name)
+  res = data
+  intermediate = []
+  for i, sz in enumerate(downsampleSteps):
+    if config.get('use downsampling', True):
+      res = L.Conv2D(sz, 3, strides=2, padding='same', activation='relu')(res)
+    res = withPosition(res, i) # add position encoding if needed
+    for _ in range(ConvBeforeStage):
+      res = L.Conv2D(sz, 3, padding='same', activation='relu')(res)
+
+    # local context
+    if not(localContext is None):
+      intermediate.append(
+        conv_block_from_config(
+          data=res, config=localContext, defaults={
+            'channels': sz,
+            'channels last': latentDim, # last layer should have latentDim channels
+          },
+          name='%s/intermediate-%d' % (name, i)
+        )
+      )
+    ################################
+    for _ in range(ConvAfterStage):
+      res = L.Conv2D(sz, 3, padding='same', activation='relu')(res)
+    continue
+
+  if not(globalContext is None): # global context
+    res = withPosition(res, len(downsampleSteps))
+    context = _createGlobalContextModel(res, globalContext, latentDim, name + '/globalContext')
+  else: # no global context
+    # return dummy context to keep the interface consistent
+    context = L.Lambda(
+      lambda x: tf.zeros((tf.shape(x)[0], 1), dtype=res.dtype)
+    )(res)
+
+  return tf.keras.Model(
+    inputs=[data],
+    outputs={
+      'intermediate': intermediate, # intermediate representations
+      'context': context, # global context
+    },
+    name=name
+  )
+
+class CEncoderHead(tf.keras.Model):
+  def __init__(self,
+    config,
+    downsampleSteps, latentDim,
+    ConvBeforeStage, ConvAfterStage,
+    localContext, globalContext,
+    positionsConfigs,
+    **kwargs
+  ):
+    super().__init__(**kwargs)
+    self._config = config
+    self._downsampleSteps = downsampleSteps
+    self._latentDim = latentDim
+    self._ConvBeforeStage = ConvBeforeStage
+    self._ConvAfterStage = ConvAfterStage
+    self._localContext = localContext
+    self._globalContext = globalContext
+    self._positionsConfigs = positionsConfigs
+    return
+
+  def build(self, inputShape):
+    H, W, C = inputShape[1:]
+    self._encoderHead = createEncoderHead_full(
+      imgWidth=H, config=self._config,
+      channels=C, downsampleSteps=self._downsampleSteps, latentDim=self._latentDim,
+      ConvBeforeStage=self._ConvBeforeStage, ConvAfterStage=self._ConvAfterStage,
+      localContext=self._localContext, globalContext=self._globalContext,
+      positionsConfigs=self._positionsConfigs,
+      name=self.name + '/EncoderHead'
+    )
+    self._encoderHead.build(inputShape)
+    return super().build(inputShape)
+
+  def call(self, src, training=None):
+    return self._encoderHead(src, training=training)
+'''
+Simple encoder that takes image as input and returns corresponding latent vector with intermediate representations
+'''
+def createEncoderHead(
+  config,
+  downsampleSteps, latentDim,
+  ConvBeforeStage, ConvAfterStage,
+  localContext, globalContext,
+  positionsConfigs,
+  name
+):
+  return CEncoderHead(
+    config=config,
+    downsampleSteps=downsampleSteps,
+    latentDim=latentDim,
+    ConvBeforeStage=ConvBeforeStage,
+    ConvAfterStage=ConvAfterStage,
+    localContext=localContext,
+    globalContext=globalContext,
+    positionsConfigs=positionsConfigs,
+    name=name
+  )
```
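`block_params_from_config` accepts three spellings of a block: an explicit `layers` list (returned as-is), an integer `conv before` (that many copies of the default conv), or a list of channel counts. In every case a final layer built from `channels last`/`kernel size last`/`final activation` is appended. For instance, with the illustrative config below (values are assumptions, not taken from the repo's configs), the expected expansion is:

```python
config = {
  'conv before': [16, 32],  # two Conv2D layers with 16 and 32 channels
  'channels': 64,           # channels of the appended final layer
}
print(block_params_from_config(config))
# [{'channels': 16, 'kernel size': 3, 'activation': 'relu', 'name': 'Conv2D'},
#  {'channels': 32, 'kernel size': 3, 'activation': 'relu', 'name': 'Conv2D'},
#  {'channels': 64, 'kernel size': 3, 'activation': 'relu', 'name': 'Conv2D'}]
```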

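Because `CEncoderHead` builds the underlying functional model in `build()`, the image width and channel count are taken from the first input it sees rather than from the constructor. A hedged usage sketch (all config values here are assumptions for illustration, not the repo's actual configs):

```python
import tensorflow as tf
from NN.encoders.CEncoderHead import createEncoderHead

head = createEncoderHead(
  config={'use downsampling': True},
  downsampleSteps=[16, 32, 64],  # three stages, channels per stage
  latentDim=8,
  ConvBeforeStage=1, ConvAfterStage=1,
  localContext={'conv before': 1, 'channels': 16},
  globalContext=None,            # dummy zero context branch
  positionsConfigs=None,         # no positional encodings
  name='encoder_head'
)
out = head(tf.zeros((2, 64, 64, 3)))           # shapes are inferred here, in build()
print([t.shape for t in out['intermediate']])  # three (2, H_i, W_i, latentDim) maps
print(out['context'].shape)                    # (2, 1) dummy context
```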