Skip to content

Commit 05140f0

Browse files
funding
1 parent 8fe566b commit 05140f0

14 files changed

+218
-148
lines changed

.github/FUNDING.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
patreon: GreenWizard
2+
buy_me_a_coffee: greenwizard89

FUNDING.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Support and Funding
2+
3+
If you want to support my quest for sarcastic humor, brilliant software achievements, and witty commentary, here’s how you can do it:
4+
5+
1. **Patreon**: All my amazing content there is free. But if you feel an irresistible urge to support me with monthly donations, head over to [my Patreon page](https://www.patreon.com/GreenWizard). Your support will help me keep delighting you with sarcasm and software revelations.
6+
7+
2. **Buy Me a Coffee**: Prefer one-time acts of generosity? You can buy me a coffee at [Buy Me a Coffee](https://buymeacoffee.com/greenwizard89). Because, honestly, my level of sarcasm and ability to write genius code directly correlate with the amount of caffeine I consume.
8+
9+
By supporting me, you ensure a continuous flow of sarcasm, wit, and cutting-edge software insights. And who wouldn't want more of that in their life?

NN/restorators/CARProcess.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ def __init__(self, predictions, sourceDistribution, sampler):
1616
def forward(self, x0, xT=None):
1717
B = tf.shape(x0)[0]
1818
# source distribution need to know the shape of the input, so we need to ensure it explicitly
19-
x0 = tf.ensure_shape(x0, (None, self.predictions))
19+
# x0 = tf.ensure_shape(x0, (None, self.predictions))
2020
sampled = self._sourceDistribution.sampleFor(x0)
2121
x1 = sampled['xT'] if xT is None else xT
2222

23-
tf.assert_equal(tf.shape(x0), (B, self.predictions))
23+
# tf.assert_equal(tf.shape(x0), (B, self.predictions))
2424
return self._sampler.train(x0=x0, x1=x1, T=sampled['T'], xT=xT)
2525

2626
def calculate_loss(self, x_hat, predicted, **kwargs):
27+
if hasattr(self._sampler, 'calculate_loss'):
28+
return self._sampler.calculate_loss(x_hat, predicted, **kwargs)
2729
lossFn = kwargs.get('lossFn', tf.losses.mae) # default loss function
2830
target = self.withExtraOutputs(x_hat['target'], **kwargs)
2931
tf.assert_equal(tf.shape(target), tf.shape(predicted))
@@ -45,11 +47,25 @@ def reverse(self, value, denoiser, modelT=None, index=0, **kwargs):
4547

4648
denoiser = self._makeDenoiser(denoiser, modelT)
4749
res = self._sampler.sample(value=value, model=denoiser, index=index, **kwargs)
48-
tf.assert_equal(tf.shape(res), tf.shape(value))
50+
# tf.assert_equal(tf.shape(res), tf.shape(value))
4951
return res
5052

5153
def targets(self, x_hat, values):
5254
return self._sampler.targets(x_hat, values[:, :self.predictions])
55+
56+
@property
57+
def channels(self):
58+
sampler = self._sampler
59+
if hasattr(sampler, 'channels'):
60+
return sampler.channels
61+
return super().channels
62+
63+
@property
64+
def predictions(self):
65+
sampler = self._sampler
66+
if hasattr(sampler, 'predictions'):
67+
return sampler.predictions
68+
return super().predictions
5369
# End of CARProcess
5470

5571
def autoregressive_restoration_from_config(config):

NN/restorators/IRestorationProcess.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import tensorflow as tf
22

3-
class IRestorationProcess:
3+
class IRestorationProcess(tf.keras.Model):
44
@staticmethod
55
def getOutputSize(outputs):
66
sizes = {
@@ -15,7 +15,9 @@ def getOutputSize(outputs):
1515
continue
1616
return channels
1717

18-
def __init__(self, predictions):
18+
def __init__(self, predictions, name=None, **kwargs):
19+
if name is None: name = self.__class__.__name__
20+
super().__init__(name=name, **kwargs)
1921
predictions = list(predictions)
2022
if 'rgb' not in predictions: predictions.insert(0, 'rgb')
2123
self._outputs = predictions
@@ -24,6 +26,9 @@ def __init__(self, predictions):
2426
print('[IRestorationProcess] Restorator:', self.__class__.__name__)
2527
return
2628

29+
def call(self, *args, **kwargs):
30+
raise RuntimeError('IRestorationProcess object cannot be called directly')
31+
2732
def forward(self, x0, xT=None):
2833
raise NotImplementedError()
2934

NN/restorators/samplers/CARSampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def solve(self, x_hat, step, value, interpolant, params, **kwargs):
176176
)
177177

178178
def directSolve(self, x_hat, values, interpolant, **kwargs):
179-
solved = interpolant.solve(x_hat=x_hat['x1'], xt=values, t=1.0).x0
179+
solved = interpolant.solve(x_hat=values, xt=x_hat['x1'], t=1.0).x0
180180
return solved
181181
# End of CARSamplingAlgorithm
182182

NN/restorators/samplers/CBasicInterpolantSampler.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,21 @@
22
from .ISamplingAlgorithm import ISamplingAlgorithm
33
from NN.utils import is_namedtuple
44

5-
class CBasicInterpolantSampler:
6-
def __init__(self, interpolant, algorithm):
5+
class IBasicInterpolantSampler(tf.keras.Model):
6+
@property
7+
def interpolant(self):
8+
raise NotImplementedError()
9+
10+
@tf.function
11+
def sample(self, value, model, index=0, **kwargs):
12+
raise NotImplementedError()
13+
14+
def targets(self, x_hat, values):
15+
raise NotImplementedError()
16+
17+
class CBasicInterpolantSampler(IBasicInterpolantSampler):
18+
def __init__(self, interpolant, algorithm, **kwargs):
19+
super().__init__(**kwargs)
720
self._interpolant = interpolant
821
self._algorithm = algorithm
922
return
@@ -17,7 +30,7 @@ def sample(self, value, model, index=0, **kwargs):
1730
kwargs = dict(**kwargs, interpolant=self._interpolant, index=index)
1831
# wrap algorithm with hook, if provided
1932
algorithm = kwargs.get('algorithmInterceptor', lambda x: x)( self._algorithm )
20-
assert isinstance(algorithm, ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'
33+
assert issubclass(type(algorithm), ISamplingAlgorithm), f'Algorithm must be an instance of ISamplingAlgorithm, but got {type(algorithm)}'
2134

2235
step = algorithm.firstStep(value=value, **kwargs)
2336
# CFakeObject is a namedtuple, so we need to check for it
@@ -39,6 +52,11 @@ def sample(self, value, model, index=0, **kwargs):
3952
# update value
4053
tf.assert_equal(tf.shape(value), tf.shape(solution.value))
4154
value = solution.value
55+
# for debugging, print euclidean distance to GT
56+
if 'GT' in kwargs:
57+
gt = kwargs['GT']
58+
dist = tf.reduce_sum(tf.square(value - gt), axis=-1)
59+
tf.print(f'Iteration {iteration}:', dist, summarize=10)
4260
iteration += 1
4361
continue
4462

NN/restorators/samplers/CDDIMInterpolantSampler.py

Lines changed: 2 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -1,109 +1,6 @@
11
import tensorflow as tf
2-
from Utils.utils import CFakeObject
3-
from NN.utils import normVec
4-
from .CBasicInterpolantSampler import CBasicInterpolantSampler, ISamplingAlgorithm
5-
6-
class CDDIMSamplingAlgorithm(ISamplingAlgorithm):
7-
def __init__(self, stochasticity, noiseProvider, schedule, steps, clipping, projectNoise):
8-
self._stochasticity = stochasticity
9-
self._noiseProvider = noiseProvider
10-
self._schedule = schedule
11-
self._steps = steps
12-
self._clipping = clipping
13-
self._projectNoise = projectNoise
14-
return
15-
16-
def _makeStep(self, current_step, steps, **kwargs):
17-
schedule = kwargs.get('schedule', self._schedule)
18-
eta = kwargs.get('eta', self._stochasticity)
19-
20-
T = steps[0][current_step]
21-
alpha_hat_t = schedule.parametersForT(T).alphaHat
22-
prevStepInd = steps[1][current_step]
23-
alpha_hat_t_prev = schedule.parametersForT(prevStepInd).alphaHat
24-
25-
stepVariance = schedule.varianceBetween(alpha_hat_t, alpha_hat_t_prev)
26-
sigma = tf.sqrt(stepVariance) * eta
27-
28-
return CFakeObject(
29-
steps=steps,
30-
current_step=current_step,
31-
active=(0 <= current_step),
32-
sigma=sigma,
33-
#
34-
T=T,
35-
t=alpha_hat_t,
36-
t_prev=alpha_hat_t_prev,
37-
t_prev_2=1.0 - alpha_hat_t_prev - tf.square(sigma),
38-
)
39-
40-
def firstStep(self, **kwargs):
41-
schedule = kwargs.get('schedule', self._schedule)
42-
assert schedule is not None, 'schedule is None'
43-
assert schedule.is_discrete, 'schedule is not discrete'
44-
steps = schedule.steps_sequence(
45-
startStep=kwargs.get('startStep', None),
46-
endStep=kwargs.get('endStep', None),
47-
config=kwargs.get('stepsConfig', self._steps),
48-
reverse=True, # reverse steps order to make it easier to iterate over them
49-
)
50-
51-
return self._makeStep(
52-
current_step=tf.size(steps[0]) - 1,
53-
steps=steps,
54-
**kwargs
55-
)
56-
57-
def nextStep(self, step, **kwargs):
58-
return self._makeStep(
59-
current_step=step.current_step - 1,
60-
steps=step.steps,
61-
**kwargs
62-
)
63-
64-
def inference(self, model, step, value, **kwargs):
65-
schedule = kwargs.get('schedule', self._schedule)
66-
return model(
67-
x=value,
68-
T=step.T,
69-
t=schedule.to_continuous(step.T),
70-
)
71-
72-
def _withNoise(self, value, sigma, x0, kwargs):
73-
noise_provider = kwargs.get('noiseProvider', self._noiseProvider)
74-
noise = noise_provider(shape=tf.shape(value), sigma=sigma)
75-
if not kwargs.get('projectNoise', self._projectNoise): return value + noise
76-
77-
_, L = normVec(value - x0)
78-
vec, _ = normVec(value + noise - x0)
79-
return x0 + L * vec # project noise back to the spherical manifold
80-
81-
def _withClipping(self, value, kwargs):
82-
clipping = kwargs.get('clipping', self._clipping)
83-
if clipping is None: return value
84-
return tf.clip_by_value(value, clip_value_min=clipping['min'], clip_value_max=clipping['max'])
85-
86-
def solve(self, x_hat, step, value, interpolant, **kwargs):
87-
solved = interpolant.solve(x_hat=x_hat, xt=value, t=step.t)
88-
x_prev = interpolant.interpolate(
89-
x0=solved.x0, x1=solved.x1,
90-
t=step.t_prev, t2=step.t_prev_2
91-
)
92-
x_prev = self._withNoise(x_prev, sigma=step.sigma, x0=solved.x0, kwargs=kwargs)
93-
x_prev = self._withClipping(x_prev, kwargs=kwargs)
94-
# return solution and additional information for debugging
95-
return CFakeObject(
96-
value=x_prev,
97-
x0=solved.x0,
98-
x1=solved.x1,
99-
T=step.T,
100-
current_step=step.current_step,
101-
sigma=step.sigma,
102-
)
103-
104-
def directSolve(self, x_hat, xt, interpolant):
105-
return interpolant.solve(x_hat=xt, xt=x_hat['xT'], t=x_hat['alphaHat']).x0
106-
# End of CDDIMSamplingAlgorithm
2+
from .CBasicInterpolantSampler import CBasicInterpolantSampler
3+
from .CDDIMSamplingAlgorithm import CDDIMSamplingAlgorithm
1074

1085
class CDDIMInterpolantSampler(CBasicInterpolantSampler):
1096
def __init__(

NN/restorators/samplers/CSamplerWatcher.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import tensorflow as tf
22
from NN.utils import is_namedtuple
3-
import NN.utils as NNU
43
from .CSamplingInterceptor import CSamplingInterceptor
54
from .ISamplerWatcher import ISamplerWatcher
65

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,48 @@
11
from .ISamplingAlgorithm import ISamplingAlgorithm
2+
from .ISamplerWatcher import ISamplerWatcher
23

3-
class CSamplingInterceptor(ISamplingAlgorithm):
4+
class CSamplingInterceptor(ISamplingAlgorithm, ISamplerWatcher):
45
def __init__(self, watcher, algorithm):
6+
assert issubclass(type(watcher), ISamplerWatcher), f'Invalid watcher: {watcher}'
57
self._watcher = watcher
68
self._algorithm = algorithm
79
return
810

11+
def interceptor(self):
12+
def F(algorithm):
13+
if isinstance(self._watcher, ISamplerWatcher):
14+
self._watcher = algorithm = self._watcher.interceptor()(algorithm)
15+
16+
if callable(self._watcher): # replace the watcher with the interceptor
17+
self._watcher = algorithm = self._watcher(algorithm)
18+
19+
# if self._algorithm is not None:
20+
# assert isinstance(algorithm, ISamplingAlgorithm), f'algorithm is not an instance of ISamplingAlgorithm: {algorithm}'
21+
return self
22+
return F
23+
924
def firstStep(self, **kwargs):
10-
res = self._algorithm.firstStep(**kwargs)
25+
res = self._algorithm.firstStep(**kwargs) if self._algorithm is not None else None
1126
self._watcher._onStart(value=kwargs['value'], kwargs=kwargs)
1227
return res
1328

29+
def _onStart(self, value, kwargs):
30+
return self._watcher._onStart(value=value, kwargs=kwargs)
31+
1432
def nextStep(self, **kwargs):
1533
self._watcher._onNextStep(iteration=kwargs['iteration'], kwargs=kwargs)
16-
res = self._algorithm.nextStep(**kwargs)
34+
res = self._algorithm.nextStep(**kwargs) if self._algorithm is not None else None
1735
return res
1836

37+
def _onNextStep(self, iteration, kwargs):
38+
return self._watcher._onNextStep(iteration=iteration, kwargs=kwargs)
39+
1940
def inference(self, **kwargs):
20-
return self._algorithm.inference(**kwargs)
41+
return self._algorithm.inference(**kwargs) if self._algorithm is not None else None
2142

2243
def solve(self, **kwargs):
23-
return self._algorithm.solve(**kwargs)
44+
return self._algorithm.solve(**kwargs) if self._algorithm is not None else None
45+
46+
def tracked(self, name):
47+
return self._watcher.tracked(name)
2448
# End of CSamplingInterceptor

0 commit comments

Comments
 (0)