From 5562fbe55c8cd42b650e33a6e2bf8d8a974503f2 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Sat, 12 Oct 2024 13:12:24 +0200
Subject: [PATCH 01/10] basic model for frame inpainting

---
 NN/LagrangianInterpolation.py |  64 +++++++++++++
 NN/networks.py                | 168 ++++++++++++++++++++++++++++++++--
 2 files changed, 222 insertions(+), 10 deletions(-)
 create mode 100644 NN/LagrangianInterpolation.py

diff --git a/NN/LagrangianInterpolation.py b/NN/LagrangianInterpolation.py
new file mode 100644
index 0000000..7c53f9a
--- /dev/null
+++ b/NN/LagrangianInterpolation.py
@@ -0,0 +1,64 @@
+import tensorflow as tf
+
+def lagrange_interpolation(x_values, y_values, x_targets):
+    """
+    Perform Lagrange Polynomial Interpolation using TensorFlow with batch support
+    and multidimensional y_values.
+
+    Parameters:
+    - x_values: Tensor of shape (batch_size, n), original x-values for each batch.
+    - y_values: Tensor of shape (batch_size, n, d), original y-values for each batch.
+    - x_targets: Tensor of shape (batch_size, m), x-values where interpolation is computed.
+
+    Returns:
+    - interpolated_values: Tensor of shape (batch_size, m, d), interpolated y-values for each batch.
+    """
+    minX = tf.reduce_min(x_values, axis=1)
+    maxX = tf.reduce_max(x_values, axis=1)
+    # Check if x_targets in the range of x_values
+    tf.debugging.assert_greater_equal(x_targets, minX, message="x_targets out of range")
+    tf.debugging.assert_less_equal(x_targets, maxX, message="x_targets out of range")
+
+    batch_size = tf.shape(x_values)[0]
+    n = tf.shape(x_values)[1]
+    m = tf.shape(x_targets)[-1]
+    d = tf.shape(y_values)[2]
+
+    tf.assert_equal(tf.shape(x_values), (batch_size, n))
+    tf.assert_equal(tf.shape(y_values), (batch_size, n, d))
+    tf.assert_equal(tf.shape(x_targets), (batch_size, m))
+    # Reshape tensors for broadcasting
+    x_values_i = tf.reshape(x_values, (batch_size, n, 1, 1))        # Shape: (batch_size, n, 1, 1)
+    x_values_j = tf.reshape(x_values, (batch_size, 1, n, 1))        # Shape: (batch_size, 1, n, 1)
+
+    x_targets_k = tf.reshape(x_targets, (batch_size, 1, 1, m))  # Shape: (batch_size, 1, 1, m)
+
+    # Compute the denominators (x_i - x_j)
+    denominators = x_values_i - x_values_j                          # Shape: (batch_size, n, n, 1)
+    # Replace zeros on the diagonal with ones to avoid division by zero
+    denominators = tf.where(tf.equal(denominators, 0.0), tf.ones_like(denominators), denominators)
+
+    # Compute the numerators (x_k - x_j)
+    numerators = x_targets_k - x_values_j                           # Shape: (batch_size, 1, n, m)
+
+    # Compute the terms (x_k - x_j) / (x_i - x_j)
+    terms = numerators / denominators                               # Shape: (batch_size, n, n, m)
+
+    # Exclude the terms where i == j by setting them to 1
+    identity_matrix = tf.eye(n, batch_shape=[batch_size], dtype=tf.float64)  # Shape: (batch_size, n, n)
+    identity_matrix = tf.reshape(identity_matrix, (batch_size, n, n, 1))     # Shape: (batch_size, n, n, 1)
+    terms = tf.where(tf.equal(identity_matrix, 1.0), tf.ones_like(terms), terms)
+
+    # Compute the product over j for each i and x_k
+    basis_polynomials = tf.reduce_prod(terms, axis=2)               # Shape: (batch_size, n, m)
+
+    # Multiply each basis polynomial by the corresponding y_i
+    # Adjust shapes for broadcasting
+    basis_polynomials_expanded = tf.expand_dims(basis_polynomials, axis=-1)  # Shape: (batch_size, n, m, 1)
+    y_values_expanded = tf.expand_dims(y_values, axis=2)                     # Shape: (batch_size, n, 1, d)
+    products = basis_polynomials_expanded * y_values_expanded                # Shape: (batch_size, n, m, d)
+
+    # Sum over i to get the interpolated values
+    interpolated_values = tf.reduce_sum(products, axis=1)                    # Shape: (batch_size, m, d)
+
+    return interpolated_values
\ No newline at end of file
diff --git a/NN/networks.py b/NN/networks.py
index 9e0f947..719b3e5 100644
--- a/NN/networks.py
+++ b/NN/networks.py
@@ -1,3 +1,8 @@
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
 from Core.Utils import setupGPU
 setupGPU() # dirty hack to setup GPU memory limit on startup
 
@@ -7,7 +12,7 @@
 from NN.Utils import *
 from NN.EyeEncoder import eyeEncoder
 from NN.FaceMeshEncoder import FaceMeshEncoder
-import numpy as np
+from NN.LagrangianInterpolation import lagrange_interpolation
 
 class CTimeEncoderLayer(tf.keras.layers.Layer):
   def __init__(self, **kwargs):
@@ -219,14 +224,157 @@ def Face2LatentModel(
     'inputs specification': _InputSpec()
   }
   
-if __name__ == '__main__':
-  X = Face2LatentModel(steps=5, latentSize=64,
-    embeddings={
-      'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+##########################
+def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=478, eyeSize=32, KP=5):
+  points = L.Input((steps, pointsN, 2))
+  eyeL = L.Input((steps, eyeSize, eyeSize, 1))
+  eyeR = L.Input((steps, eyeSize, eyeSize, 1))
+  T = L.Input((steps, 1)) # accumulative time
+  target = L.Input((steps, 2))
+  userIdEmb = L.Input((steps, embeddings['size']))
+  placeIdEmb = L.Input((steps, embeddings['size']))
+  screenIdEmb = L.Input((steps, embeddings['size']))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
+
+  Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
+  stepsData = Face2Step({
+    'embeddings': emb,
+    'points': points,
+    'left eye': eyeL,
+    'right eye': eyeR,
+  })
+
+  diffT = T[:, 1:] - T[:, :-1]
+  diffT = L.Concatenate(-2)([tf.zeros_like(diffT[:, :1]), diffT])
+  combinedT = L.Concatenate(-1)([T, diffT])
+  encodedT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Time')(combinedT[..., None, :])[..., 0, :]
+
+  latent = stepsData['latent']
+  # add time encoding and target position
+  targetEncoded = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Target')(target[..., None, :])[..., 0, :]
+  latent = L.Concatenate(-1)([latent, encodedT, targetEncoded])
+  # flatten the latent
+  latent = L.Reshape((-1,))(latent)
+
+  # compress the latent
+  latent_N = latent.shape[-1]
+  sizes = []
+  for i in range(1, 4):
+    for _ in range(i):
+       sizes.append(max(latent_N // i, latentSize))
+       
+  sizes.append(latentSize)
+  latent = sMLP(sizes=sizes, activation='relu', name='Compress')(latent)
+  keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+
+  # keyT shape: (B, KP, 1)
+  def transformKeyT(x):
+    t, x = x
+    B = tf.shape(x)[0]
+    return tf.tile(t, (B, 1))[..., None]
+  keyT = L.Lambda(transformKeyT)([keyT, latent])
+  # keyT shape: (B, KP, 1)
+  maxT = T[:, -1, None]
+  keyT = L.Concatenate(-1)([keyT, maxT * keyT]) # fractional time and absolute time
+  encodedKeyT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='KeyTime')(keyT[..., None, :])[..., 0, :]
+
+  def combineKeys(x):
+    latent, keyT = x
+    latent = tf.tile(latent[..., None, :], (1, KP, 1))
+    return L.Concatenate(-1)([latent, keyT])
+  latent = L.Lambda(combineKeys)([latent, encodedKeyT])
+
+  latent = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineKeys')(latent)
+
+  main = tf.keras.Model(
+    inputs={
+      'points': points,
+      'left eye': eyeL,
+      'right eye': eyeR,
+      'time': T,
+      'target': target,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'latent': latent,
+    }
+  )
+  return main
+ 
+def InpaintingDecoderModel(latentSize, embeddings, pointsN=478, eyeSize=32, KP=5):
+  latentKeyPoints = L.Input((KP, latentSize))
+  T = L.Input((None, 1))
+  userIdEmb = L.Input((embeddings['size']))
+  placeIdEmb = L.Input((embeddings['size']))
+  screenIdEmb = L.Input((embeddings['size']))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[..., None, :]
+  # emb shape: (B, 1, 3 * embSize) 
+  def interpolateKeys(x):
+    latents, T = x
+    B = tf.shape(latents)[0]
+    keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+    keyT = tf.tile(keyT, (B, 1))
+    return lagrange_interpolation(x_values=keyT, y_values=latents, x_targets=T[..., 0])
+  latents = L.Lambda(interpolateKeys, name='InterpolateKeys')([latentKeyPoints, T])
+  # latents shape: (B, N, latentSize)
+  def transformLatents(x):
+    latents, emb = x
+    N = tf.shape(latents)[1]
+    emb = tf.tile(emb, (1, N, 1)) # (B, 1, 3 * embSize) -> (B, N, 3 * embSize)
+    return L.Concatenate(-1)([latents, emb])
+  latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
+  # process the latents
+  latents = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineEmb/MLP')(latents)
+  # decode the latents to the face points (478, 2), two eyes (32, 32, 2) and the target (2) 
+  target = IntermediatePredictor(shift=0.5)(latents)
+  # two eyes
+  eyesN = eyeSize * eyeSize
+  eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
+  eyes = L.Dense(eyesN * 2)(eyes)
+  eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
+  # face points
+  face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)
+  face = L.Dense(pointsN * 2)(face)
+  face = L.Reshape((-1, pointsN, 2))(face)
+
+  model = tf.keras.Model(
+    inputs={
+      'keyPoints': latentKeyPoints,
+      'time': T,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'target': target,
+      'left eye': eyes[:, :, 0],
+      'right eye': eyes[:, :, 1],
+      'face': face,
     }
   )
-  X['main'].summary(expand_nested=True)
-  X['Face2Step'].summary(expand_nested=False)
-  X['Step2Latent'].summary(expand_nested=False)
-  print(X['main'].outputs)
-  pass
\ No newline at end of file
+  return model
+
+
+if __name__ == '__main__':
+  # X = InpaintingEncoderModel(latentSize=256, embeddings={
+  #   'size': 64
+  # })
+  X = InpaintingDecoderModel(latentSize=256, embeddings={
+    'size': 64
+  })
+  X.summary(expand_nested=False)
+
+  # X = Face2LatentModel(steps=5, latentSize=64,
+  #   embeddings={
+  #     'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+  #   }
+  # )
+  # X['main'].summary(expand_nested=True)
+  # X['Face2Step'].summary(expand_nested=False)
+  # X['Step2Latent'].summary(expand_nested=False)
+  # print(X['main'].outputs)
+  # pass
\ No newline at end of file

From 66ab19d407456bb0367105b7d2d6a78e79a8822a Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Tue, 15 Oct 2024 19:16:46 +0200
Subject: [PATCH 02/10] fix

---
 scripts/preprocess-remote.py | 13 +------------
 1 file changed, 1 insertion(+), 12 deletions(-)

diff --git a/scripts/preprocess-remote.py b/scripts/preprocess-remote.py
index 167dd29..fddf377 100644
--- a/scripts/preprocess-remote.py
+++ b/scripts/preprocess-remote.py
@@ -179,7 +179,7 @@ def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, ski
       start, end, np.min(delta), np.max(delta), np.mean(delta), len(session_time), duration
     ))
     stats['deltas'].append(delta)
-    stats['durations'].append(duration)
+    stats['durations'].append([duration])
     continue
   ######################################################
   # split each session into training and testing sets
@@ -215,17 +215,6 @@ def saveSubset(filename, idx):
   return len(testing), len(training), False, stats
 
 def main(args):
-  # blacklisted datasets
-  blacklisted = []
-  if args.blacklist is not None:
-    with open(args.blacklist, 'r') as f:
-      blacklisted = json.load(f)
-    pass
-  blacklisted = set([
-    '/'.join(item)
-    for item in blacklisted
-  ])
-  print(blacklisted)
   stats = {
     'placeId': [],
     'userId': [],

From 32b33cfb238d64111977932f77a5f18691cd5a84 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Tue, 15 Oct 2024 20:42:43 +0200
Subject: [PATCH 03/10] wip

---
 Core/CBaseDataSampler.py        | 148 +++++++++++
 Core/CDataSampler.py            | 424 ++++++++++----------------------
 Core/CDataSamplerInpainting.py  | 214 ++++++++++++++++
 Core/CDataSampler_utils.py      |   2 +-
 Core/CDatasetLoader.py          |  45 +---
 Core/Utils.py                   |   3 +-
 NN/networks.py                  |  30 ++-
 scripts/download-remote.py      |   7 +-
 scripts/train-reconstruction.py | 268 ++++++++++++++++++++
 scripts/train.py                |   4 +-
 10 files changed, 802 insertions(+), 343 deletions(-)
 create mode 100644 Core/CBaseDataSampler.py
 create mode 100644 Core/CDataSamplerInpainting.py
 create mode 100644 scripts/train-reconstruction.py

diff --git a/Core/CBaseDataSampler.py b/Core/CBaseDataSampler.py
new file mode 100644
index 0000000..dc02b37
--- /dev/null
+++ b/Core/CBaseDataSampler.py
@@ -0,0 +1,148 @@
+import numpy as np
+import random
+from math import ceil
+from functools import lru_cache
+
+class CBaseDataSampler:
+    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+        '''
+        Base class for data sampling.
+
+        Parameters:
+        - storage: The storage object containing the samples.
+        - batch_size: The number of samples per batch.
+        - minFrames: The minimum number of frames required in a trajectory.
+        - defaults: Default parameters for sampling.
+        - maxT: Maximum time window for sampling frames.
+        - cumulative_time: If True, time is cumulative; otherwise, it's time deltas.
+        '''
+        self._storage = storage
+        self._defaults = defaults
+        self._batchSize = batch_size
+        self._maxT = maxT
+        self._minFrames = minFrames
+        self._samples = []
+        self._currentSample = None
+        self._cumulative_time = cumulative_time
+        return
+
+    def reset(self):
+        random.shuffle(self._samples)
+        self._currentSample = 0
+        return
+
+    def __len__(self):
+        return ceil(len(self._samples) / self._batchSize)
+
+    def _storeSample(self, idx):
+        # Store sample if it has enough frames
+        minInd = self._getTrajectoryBefore(idx)
+        if self._minFrames <= (idx - minInd):
+            self._samples.append(idx)
+        return
+
+    def add(self, sample):
+        idx = self._storage.add(sample)
+        self._storeSample(idx)
+        return idx
+
+    def addBlock(self, samples):
+        indexes = self._storage.addBlock(samples)
+        for idx in indexes:
+            self._storeSample(idx)
+            continue
+        return
+
+    def _getTrajectoryBefore(self, mainInd):
+        mainT = self._storage[mainInd]['time']
+        minT = mainT - self._maxT
+
+        minInd = mainInd
+        for ind in range(mainInd - 1, -1, -1):
+            if self._storage[ind]['time'] < minT: break
+            minInd = ind
+            continue
+        return minInd
+
+    @lru_cache(None)
+    def _trajectoryRange(self, mainInd):
+        '''
+        Returns indexes of samples that are within maxT from mainInd.
+        Returns (minInd, maxInd) where minInd <= mainInd <= maxInd
+        '''
+        mainT = self._storage[mainInd]['time']
+        maxT = mainT + self._maxT
+        maxInd = mainInd
+        for ind in range(mainInd, len(self._storage)):
+            if maxT < self._storage[ind]['time']: break
+            maxInd = ind
+            continue
+
+        minInd = self._getTrajectoryBefore(mainInd)
+        return minInd, maxInd
+
+    def _trajectory(self, mainInd):
+        minInd, maxInd = self._trajectoryRange(mainInd)
+        return list(range(minInd, mainInd + 1)), list(range(mainInd + 1, maxInd + 1))
+
+    def _prepareT(self, res):
+        T = np.array([self._storage[ind]['time'] for ind in res])
+        T -= T[0]
+        diff = np.diff(T, 1)
+        idx = np.nonzero(diff)[0]
+        if len(idx) < 1: return None  # All frames have the same time
+        if len(diff) == len(idx):
+            T = diff
+        else:
+            return None # Time is not consistent
+        T = np.insert(T, 0, 0.0)
+        assert len(res) == len(T)
+        # Convert to cumulative time if required
+        if self._cumulative_time:
+            T = np.cumsum(T)
+        return T
+
+    def _reshapeSteps(self, values, steps):
+        if steps is None:
+            return values
+
+        res = []
+        for x in values:
+            B, *s = x.shape
+            newShape = (B // steps, steps, *s)
+            res.append(x.reshape(newShape))
+            continue
+        return tuple(res)
+
+    @property
+    def totalSamples(self):
+        return len(self._storage)
+
+    def validSamples(self):
+        return list(sorted(self._samples))
+    
+    def _framesFor(self, mainInd, samples, steps, stepsSampling):
+        if 'uniform' == stepsSampling:
+            samples = random.sample(samples, steps - 1)
+        elif 'last' == stepsSampling:
+            samples = samples[-(steps - 1):]
+        elif isinstance(stepsSampling, dict):
+            candidates = list(samples)
+            maxFrames = stepsSampling['max frames']
+            candidates = candidates[::-1]
+            samples = []
+            left = steps - 1
+            for _ in range(left):
+                avl = min((maxFrames, 1 + len(candidates) - left))
+                ind = random.randint(0, avl - 1)
+                samples.append(candidates[ind])
+                candidates = candidates[ind+1:]
+                left -= 1
+                continue
+            pass
+        else:
+            raise ValueError('Unknown sampling method: ' + str(stepsSampling))
+
+        res = list(sorted(samples + [mainInd]))
+        assert len(res) == steps
+        return res
\ No newline at end of file
diff --git a/Core/CDataSampler.py b/Core/CDataSampler.py
index ce81c72..aae5d40 100644
--- a/Core/CDataSampler.py
+++ b/Core/CDataSampler.py
@@ -1,306 +1,152 @@
+from .CBaseDataSampler import CBaseDataSampler
+import Core.CDataSampler_utils as DSUtils
+
 import numpy as np
-import random
-from math import ceil
-import Core.Utils as Utils
 from functools import lru_cache
-import Core.CDataSampler_utils as DSUtils
 
-class CDataSampler:
-  def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
-    '''
-    If cumulative_time is True, then time is a cumulative time from the start of the trajectory i.e. [0, 0.1, 0.2, 0.3, ...]
-    If cumulative_time is False, then time is a time delta between frames i.e. [0, 0.1, 0.1, 0.1, ...]
-    '''
-    self._storage = storage
-    self._defaults = defaults
-    self._batchSize = batch_size
-    self._maxT = maxT
-    self._minFrames = minFrames
-    self._samples = []
-    self._currentSample = None
-    self._cumulative_time = cumulative_time
-    return
-  
-  def reset(self):
-    random.shuffle(self._samples)
-    self._currentSample = 0
-    return
+'''
+This sampler are sample N frames from the dataset, where N is the number of timesteps.
+It returns the tuple (X, Y), where X is the input data and Y is the target data.
+To X could be applied some augmentations.
+X contains the following data:
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The time (cumulative or delta).
+  - The user ID, place ID, and screen ID.
+Y contains the target data.
+  - The target point.
+'''
+class CDataSampler(CBaseDataSampler):
+    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+        super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
 
-  def __len__(self):
-    return ceil(len(self._samples) / self._batchSize)
+    def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
+        if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
+        if mainInd < steps: return False
 
-  def _storeSample(self, idx):
-    # store sample if it has enough frames
-    minInd = self._getTrajectoryBefore(idx)
-    if self._minFrames <= (idx - minInd):
-      self._samples.append(idx)
-    return
-  
-  def add(self, sample):
-    idx = self._storage.add(sample)
-    self._storeSample(idx)
-    return idx
-  
-  def addBlock(self, samples):
-    indexes = self._storage.addBlock(samples)
-    for idx in indexes:
-      self._storeSample(idx)
-      continue
-    return
+        samples, _ = self._trajectory(mainInd)
+        if len(samples) < (steps - 1): return False
+        # Try to sample valid frames
+        for _ in range(10):
+            res = self._framesFor(mainInd, samples, steps, stepsSampling)
+            T = self._prepareT(res)
+            if T is not None:
+                assert len(res) == len(T)
+                return [tuple(x) for x in zip(res, T)]
+            continue
+        return False
 
-  def _getTrajectoryBefore(self, mainInd):
-    mainT = self._storage[mainInd]['time']
-    minT = mainT - self._maxT
-    
-    minInd = mainInd
-    for ind in range(mainInd - 1, -1, -1):
-      if self._storage[ind]['time'] < minT: break
-      minInd = ind
-      continue
-    return minInd
-  
-  @lru_cache(None)
-  def _trajectoryRange(self, mainInd):
-    '''
-      Returns indexes of samples that are in the range of maxT from the mainInd
-      Returns (minInd, maxInd) where minInd <= mainInd <= maxInd
-    '''
-    mainT = self._storage[mainInd]['time']
-    maxT = mainT + self._maxT
-    maxInd = mainInd
-    for ind in range(mainInd, len(self._storage)):
-      if maxT < self._storage[ind]['time']: break
-      maxInd = ind
-      continue
-    
-    minInd = self._getTrajectoryBefore(mainInd)
-    return minInd, maxInd
+    def sample(self, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        N = kwargs.get('N', self._batchSize)
+        indexes = []
+        for _ in range(N):
+            added = False
+            while not added:
+                idx = self._samples[self._currentSample]
+                self._currentSample = (self._currentSample + 1) % len(self._samples)
 
-  def _trajectory(self, mainInd):
-    minInd, maxInd = self._trajectoryRange(mainInd)
-    return list(range(minInd, mainInd)), list(range(mainInd + 1, maxInd + 1))
-  
-  def _trajectory2keypoints(self, before, mainInd, after, N):
-    mainPt = self._storage[mainInd]['goal']
-    if 1 < N:
-      trajectory = []
-      trajectory.extend(before)
-      trajectory.append(mainInd)
-      trajectory.extend(after)
-      trajectory = np.array([self._storage[ind]['goal'] for ind in trajectory])
-      chunksN = max((1, len(trajectory) // N))
-      keypoints = [(mainPt, 0.0)]
-      for i in range(0, len(trajectory), chunksN):
-        x = trajectory[i:i+chunksN]
-        if 0 < len(x):
-          pt = np.mean(x, axis=0)
-          d = np.linalg.norm(pt - mainPt)
-          keypoints.append((pt, d))
-        continue
-      while len(keypoints) < N: keypoints.append(keypoints[0])
-      keypoints = sorted(keypoints, key=lambda x: x[1])
-      keypoints = [pt for pt, _ in keypoints]
-      keypoints = np.array(keypoints[:N])
-    else:
-      keypoints = np.array([mainPt])
-    return keypoints
+                sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+                if sampledSteps:
+                    # TODO: remove from samples?
+                    indexes.extend(sampledSteps)
+                    added = True
+                continue
 
-  def _prepareT(self, res):
-    T = np.array([self._storage[ind]['time'] for ind in res])
-    T -= T[0]
-    diff = np.diff(T, 1)
-    idx = np.nonzero(diff)[0]
-    if len(idx) < 1: return None # all frames have the same time
-    if len(diff) == len(idx):
-      T = diff
-    else:
-      # avg non-zero diff
-      dT = np.min(diff[idx])
-      T = np.append(T, T[-1] + dT)
-      idx = [0, *(1 + idx), len(T) - 1]
-      T = np.interp(np.arange(len(T) - 1), idx, T[idx])
-      T = np.diff(T, 1)
-      pass
-    T = np.insert(T, 0, 0.0)
-    assert len(res) == len(T)
-    # T is an array of time deltas like [0, 0.1, 0.1, 0.1, ...], convert it to cumulative time
-    if self._cumulative_time:
-      T = np.cumsum(T)
-    return T
-  
-  def _framesFor(self, mainInd, samples, steps, stepsSampling):
-    if 'uniform' == stepsSampling:
-      samples = random.sample(samples, steps - 1)
-    if 'last' == stepsSampling:
-      samples = samples[-(steps - 1):]
-      
-    if isinstance(stepsSampling, dict):
-      candidates = list(samples)
-      maxFrames = stepsSampling['max frames']
-      candidates = candidates[::-1]
-      samples = []
-      left = steps - 1
-      for _ in range(left):
-        avl = min((maxFrames, 1 + len(candidates) - left))
-        ind = random.randint(0, avl - 1)
-        samples.append(candidates[ind])
-        candidates = candidates[ind+1:]
-        left -= 1
-        continue
-      pass
-      
-    res = list(sorted(samples + [mainInd]))
-    assert len(res) == steps
-    return res
-  
-  def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
-    if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
-    if mainInd < steps: return False
-    
-    samples, _ = self._trajectory(mainInd)
-    if len(samples) < (steps - 1): return False
-    # Try to sample valid frames
-    for _ in range(10):
-      res = self._framesFor(mainInd, samples, steps, stepsSampling)
-      T = self._prepareT(res)
-      if T is not None:
-        assert len(res) == len(T)
-        return [tuple(x) for x in zip(res, T)]
-      continue
-    return False
-  
-  def sample(self, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    N = kwargs.get('N', self._batchSize)
-    indexes = []
-    for _ in range(N):
-      added = False
-      while not added:
-        idx = self._samples[self._currentSample]
-        self._currentSample = (self._currentSample + 1) % len(self._samples)
+        return self._indexes2XY(indexes, kwargs)
 
+    def sampleById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
         sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-        if sampledSteps:
-          # TODO: remove from samples?
-          indexes.extend(sampledSteps)
-          added = True
-      continue
+        if not sampledSteps: return None
+        return self._indexes2XY([*sampledSteps], kwargs)
+
+    def checkById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return False
+        return True
+
+    def sampleByIds(self, ids, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = []
+        rejected = []
+        accepted = []
+        for idx in ids:
+            sample = self._stepsFor(idx, steps=timesteps, **kwargs)
+            if sample:
+                accepted.append(idx)
+                sampledSteps.extend(sample)
+            else:
+                rejected.append(idx)
+                pass
+            continue
+
+        res = None
+        if 0 < len(sampledSteps):
+            res = self._indexes2XY(sampledSteps, kwargs)
+        return res, rejected, accepted
 
-    return self._indexes2XY(indexes, kwargs)
+    @lru_cache(None)
+    def _targetFor(self, ind):
+        mainPt = self._storage[ind]['goal']
+        keypoints = np.array(mainPt, np.float32)
+        return keypoints
 
-  def sampleById(self, idx, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-    if not sampledSteps: return None
-    return self._indexes2XY([*sampledSteps], kwargs)
-  
-  def checkById(self, idx, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-    if not sampledSteps: return False
-    return True
-    
-  def sampleByIds(self, ids, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = []
-    rejected = []
-    accepted = []
-    for idx in ids:
-      sample = self._stepsFor(idx, steps=timesteps, **kwargs)
-      if sample:
-        accepted.append(idx)
-        sampledSteps.extend(sample)
-      else:
-        rejected.append(idx)
-        pass
-      continue
-    
-    res = None
-    if 0 < len(sampledSteps):
-      res = self._indexes2XY(sampledSteps, kwargs)
-    return res, rejected, accepted
-  
-  def _reshapeSteps(self, values, steps):
-    if steps is None: return values
-    
-    res = []
-    for x in values:
-      B, *s = x.shape
-      newShape = (B // steps, steps, *s)
-      res.append(x.reshape(newShape))
-      continue
-    return tuple(res)
-  
-  @lru_cache(None)
-  def _targetFor(self, ind, keypoints=1, past=True, future=True, **_):
-    before, after = self._trajectory(ind)
-    if not past: before = []
-    if not future: after = []
-    return self._trajectory2keypoints(before, ind, after, N=keypoints)
+    def _indexes2XY(self, indexesAndTime, kwargs):
+        timesteps = kwargs.get('timesteps', None)
+        samples = [self._storage[i] for i, _ in indexesAndTime]
 
-  def _indexes2XY(self, indexesAndTime, kwargs):
-    timesteps = kwargs.get('timesteps', None)
-    samples = [self._storage[i] for i, _ in indexesAndTime]
+        Y = ( np.array([ self._targetFor(i)  for i, _ in indexesAndTime], np.float32), )
+        Y = self._reshapeSteps(Y, timesteps)
+        ##############
+        userIds = np.unique([x['userId'] for x in samples])
+        assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
+        placeIds = np.unique([x['placeId'] for x in samples])
+        assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
+        screenIds = np.unique([x['screenId'] for x in samples])
+        assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
 
-    forecast = kwargs.get('forecast', {})
-    Y = ( np.array([
-      self._targetFor(i, **forecast) 
-      for i, _ in indexesAndTime
-    ], np.float32), )
-    Y = self._reshapeSteps(Y, timesteps)
-    ##############
-    userIds = np.unique([x['userId'] for x in samples])
-    assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
-    placeIds = np.unique([x['placeId'] for x in samples])
-    assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
-    screenIds = np.unique([x['screenId'] for x in samples])
-    assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
+        X = DSUtils.toTensor(
+            (
+                np.array([x['points'] for x in samples], np.float32),
+                np.array([x['left eye'] for x in samples]),
+                np.array([x['right eye'] for x in samples]),
+                np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
+            ),
+            (
+                kwargs.get('pointsNoise', 0.0),
+                kwargs.get('pointsDropout', 0.0),
 
-    X = DSUtils.toTensor(
-      (
-        np.array([x['points'] for x in samples], np.float32),
-        np.array([x['left eye'] for x in samples]),
-        np.array([x['right eye'] for x in samples]),
-        np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
-      ),
-      (
-        kwargs.get('pointsNoise', 0.0),
-        kwargs.get('pointsDropout', 0.0),
-      
-        kwargs.get('eyesAdditiveNoise', 0.0),
-        kwargs.get('eyesDropout', 0.0),
-        kwargs.get('brightnessFactor', 0.0),
-        kwargs.get('lightBlobFactor', 0.0),
+                kwargs.get('eyesAdditiveNoise', 0.0),
+                kwargs.get('eyesDropout', 0.0),
+                kwargs.get('brightnessFactor', 0.0),
+                kwargs.get('lightBlobFactor', 0.0),
 
-        timesteps
-      ),
-      userIds[0], placeIds[0], screenIds[0]
-    )
-    ###############
-    (Y, ) = Y
-    return(X, (Y.astype(np.float32), ))
+                timesteps
+            ),
+            userIds[0], placeIds[0], screenIds[0]
+        )
+        ###############
+        (Y, ) = Y
+        return (X, (Y.astype(np.float32), ))
 
-  @property
-  def totalSamples(self):
-    return len(self._storage)
-  
-  def validSamples(self):
-    return list(sorted(self._samples))
-##############
-if __name__ == '__main__':
-  import tensorflow as tf
-  gpus = tf.config.experimental.list_physical_devices('GPU')
-  tf.config.experimental.set_virtual_device_configuration(
-    gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024//2)]
-  )
-  import os
-  from Core.CSamplesStorage import CSamplesStorage
-  folder = os.path.dirname(os.path.dirname(__file__))
-  ds = CDataSampler( CSamplesStorage(), balancingMethod=dict(context='all') )
-  dsBlock = Utils.datasetFrom(os.path.join(folder, 'Data', 'Dataset'))
-  ds.addBlock(dsBlock)
-  exit(0)
\ No newline at end of file
+    def merge(self, samples, expected_batch_size):
+        Y = np.concatenate([Y for _, (Y, ) in samples], axis=0)
+        assert len(Y) == expected_batch_size, 'Invalid batch size: %d != %d' % (len(Y), expected_batch_size)
+        # X contains the clean and augmented data
+        # each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
+        X = {}
+        for key in ['clean', 'augmented']:
+            for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId']:
+                data = [x[key][subkey] for x, _ in samples]
+                X[key][subkey] = np.concatenate(data, axis=0)
+                assert X[key][subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[key][subkey].shape[0], expected_batch_size)
+                continue
+            continue
+        return (X, (Y, ))
\ No newline at end of file
diff --git a/Core/CDataSamplerInpainting.py b/Core/CDataSamplerInpainting.py
new file mode 100644
index 0000000..f81342f
--- /dev/null
+++ b/Core/CDataSamplerInpainting.py
@@ -0,0 +1,214 @@
+from .CBaseDataSampler import CBaseDataSampler
+import Core.CDataSampler_utils as DSUtils
+from Core.Utils import FACE_MESH_POINTS
+
+import numpy as np
+
+'''
+This sampler are sample N frames from the dataset, where N is the number of timesteps.
+Within the range of the N sampled frames, its samples K frames to be inpainted/reconstructed.
+It returns the tuple (X, Y), where X is the input data and Y is the target data.
+To X could be applied some augmentations.
+X contains the following data:
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The time (cumulative or delta).
+  - The target point.
+  - The user ID, place ID, and screen ID.
+Y contains the target data, K frames to be inpainted/reconstructed.
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The normalized time.
+  - The target point.
+'''
+class CDataSamplerInpainting(CBaseDataSampler):
+    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+        super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
+
+    def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
+        if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
+        if mainInd < steps: return False
+
+        samples, _ = self._trajectory(mainInd)
+        if len(samples) < (steps - 1): return False
+        # Try to sample valid frames
+        for _ in range(10):
+            res = self._framesFor(mainInd, samples, steps, stepsSampling)
+            T = self._prepareT(res)
+            if T is not None:
+                assert len(res) == len(T)
+                return [tuple(x) for x in zip(res, T)]
+            continue
+        return False
+
+    def sample(self, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        N = kwargs.get('N', self._batchSize)
+        indexes = []
+        for _ in range(N):
+            added = False
+            while not added:
+                idx = self._samples[self._currentSample]
+                self._currentSample = (self._currentSample + 1) % len(self._samples)
+
+                sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+                if sampledSteps:
+                    # TODO: remove from samples?
+                    indexes.extend(sampledSteps)
+                    added = True
+                continue
+
+        return self._indexes2XY(indexes, kwargs)
+
+    def sampleById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return None
+        return self._indexes2XY([*sampledSteps], kwargs)
+
+    def checkById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return False
+        return True
+
+    def sampleByIds(self, ids, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = []
+        rejected = []
+        accepted = []
+        for idx in ids:
+            sample = self._stepsFor(idx, steps=timesteps, **kwargs)
+            if sample:
+                accepted.append(idx)
+                sampledSteps.extend(sample)
+            else:
+                rejected.append(idx)
+                pass
+            continue
+
+        res = None
+        if 0 < len(sampledSteps):
+            res = self._indexes2XY(sampledSteps, kwargs)
+        return res, rejected, accepted
+
+    def _indexes2XY(self, indexesAndTime, kwargs):
+        timesteps = kwargs.get('timesteps', None)
+        assert timesteps is not None, 'The number of timesteps must be defined.'
+        samples = [self._storage[i] for i, _ in indexesAndTime]
+        ##############
+        userIds = np.unique([x['userId'] for x in samples])
+        assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
+        placeIds = np.unique([x['placeId'] for x in samples])
+        assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
+        screenIds = np.unique([x['screenId'] for x in samples])
+        assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
+
+        X = DSUtils.toTensor(
+            (
+                np.array([x['points'] for x in samples], np.float32),
+                np.array([x['left eye'] for x in samples]),
+                np.array([x['right eye'] for x in samples]),
+                np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
+            ),
+            (
+                kwargs.get('pointsNoise', 0.0),
+                kwargs.get('pointsDropout', 0.0),
+
+                kwargs.get('eyesAdditiveNoise', 0.0),
+                kwargs.get('eyesDropout', 0.0),
+                kwargs.get('brightnessFactor', 0.0),
+                kwargs.get('lightBlobFactor', 0.0),
+
+                timesteps
+            ),
+            userIds[0], placeIds[0], screenIds[0]
+        )
+        X = X['clean']
+        ###############
+        # generate the target data
+        targets = kwargs.get('targets', {'keypoints': timesteps, 'total': timesteps})
+        K = targets.get('keypoints', timesteps)
+        assert K <= timesteps, 'The number of keypoints to be inpainted/reconstructed must be less or equal to the number of timesteps.'
+        T = targets.get('total', timesteps)
+        assert K <= T, 'The total number of frames to be inpainted/reconstructed must be less or equal to the total number of timesteps.'
+
+        samples_indexes = np.array([ i  for i, _ in indexesAndTime], np.int32)
+        samples_indexes = samples_indexes.reshape((-1, timesteps))
+        B = samples_indexes.shape[0]
+        targetsIdx = np.zeros((B, T), np.int32)
+        for i in range(B):
+            # sample K frames from the X
+            sampled = np.random.choice(samples_indexes[i], K, replace=False)
+            # repeat the sampled frames to fill the K frames
+            targetsIdx[i, :] = np.repeat(sampled, 1 + (T // K))[:T]
+            # sample the remaining frames
+            if K < T: # if need to sample more frames
+                startFrameIdx = samples_indexes[i, 0]
+                endFrameIdx = samples_indexes[i, -1]
+                allFrames = np.arange(startFrameIdx, endFrameIdx + 1)
+                # exclude the frames that are already sampled
+                allFrames = np.array([x for x in allFrames if x not in sampled], np.int32)
+                # sample the remaining frames
+                targetsIdx[i, K:] = np.random.choice(allFrames, T - K, replace=True)
+            continue
+        # targetsIdx contains the indexes of the frames to be inpainted/reconstructed
+        # we need to collect the data for the Y
+        # required data: points, left eye, right eye, time, target point
+        Y = {
+            'points': np.zeros((B, T, FACE_MESH_POINTS, 2), np.float32),
+            'left eye': np.zeros((B, T, 32, 32), np.float32),
+            'right eye': np.zeros((B, T, 32, 32), np.float32),
+            'time': np.zeros((B, T, 1), np.float32),
+            'target': np.zeros((B, T, 2), np.float32)
+        }
+        for i in range(B):
+            idxForSample = samples_indexes[i]
+            # stricly increasing indexes
+            assert np.all(0 < np.diff(idxForSample)), 'Invalid indexes: ' + str(idxForSample)
+            startT = self._storage[idxForSample[0]]['time']
+            endT = self._storage[idxForSample[-1]]['time']
+            duration = endT - startT
+            assert 0 < duration, 'Invalid duration: ' + str(duration)
+            targets_idx = np.sort(targetsIdx[i])
+            for j, idx in enumerate(targets_idx):
+                data = self._storage[idx]
+                Y['points'][i, j] = data['points']
+                # eyes should be cropped to 32x32, so we use the central crop
+                p = (data['left eye'].shape[0] - 32) // 2
+                Y['left eye'][i, j] = data['left eye'][p:p+32, p:p+32]
+                Y['right eye'][i, j] = data['right eye'][p:p+32, p:p+32]
+                Y['time'][i, j] = (data['time'] - startT) / duration
+                Y['target'][i, j] = data['goal']
+                
+        # check that time is between 0 and 1
+        assert np.all((0 <= Y['time']) & (Y['time'] <= 1)), 'Invalid time: ' + str(Y['time'])
+        B = Y['points'].shape[0] 
+        for k, v in X.items():
+            assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
+        for k, v in Y.items():
+            assert B == v.shape[0], f'Invalid batch size for Y[{k}]: {v.shape[0]} != {B} ({v.shape})'
+        return (X, Y)
+    
+    def merge(self, samples, expected_batch_size):
+        # each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
+        X = {}
+        for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId']:
+            data = [x[subkey] for x, _ in samples]
+            X[subkey] = np.concatenate(data, axis=0)
+            assert X[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[subkey].shape[0], expected_batch_size)
+            continue
+        # 
+        Y = {}
+        for subkey in ['points', 'left eye', 'right eye', 'time', 'target']:
+            data = [y[subkey] for _, y in samples]
+            Y[subkey] = np.concatenate(data, axis=0)
+            assert Y[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (Y[subkey].shape[0], expected_batch_size)
+            continue
+        return (X, (Y, ))
\ No newline at end of file
diff --git a/Core/CDataSampler_utils.py b/Core/CDataSampler_utils.py
index ef85cb2..90fa132 100644
--- a/Core/CDataSampler_utils.py
+++ b/Core/CDataSampler_utils.py
@@ -109,7 +109,7 @@ def toTensor(data, params, userId, placeId, screenId):
   ##########################
   # random crop 32x32 eyes
   fraction = 32.0 / 48.0
-  pos = tf.random.uniform((N, 2), minval=0.0, maxval=1.0 - fraction)
+  pos = tf.random.uniform((N, 2), minval=0.0, maxval=2.0 * fraction)
   boxes = tf.concat([pos, pos + fraction], axis=-1)
   tf.assert_equal(tf.shape(boxes), (N, 4))
   imgA = tf.image.crop_and_resize(
diff --git a/Core/CDatasetLoader.py b/Core/CDatasetLoader.py
index a526503..59c62f2 100644
--- a/Core/CDatasetLoader.py
+++ b/Core/CDatasetLoader.py
@@ -11,7 +11,7 @@ class ESampling(Enum):
   UNIFORM = 'uniform'
   
 class CDatasetLoader:
-  def __init__(self, folder, samplerArgs, sampling, stats):
+  def __init__(self, folder, samplerArgs, sampling, stats, sampler_class):
     # recursively find all 'train.npz' files
     trainFiles = glob.glob(os.path.join(folder, '**', 'train.npz'), recursive=True)
     if 0 == len(trainFiles):
@@ -26,7 +26,7 @@ def __init__(self, folder, samplerArgs, sampling, stats):
       # extract the placeId, userId, and screenId
       parts = os.path.split(trainFile)[0].split(os.path.sep)
       placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
-      ds = CDataSampler(
+      ds = sampler_class(
         CSamplesStorage(
           placeId=stats['placeId'].index(placeId),
           userId=stats['userId'].index(userId),
@@ -98,45 +98,16 @@ def _getBatchStats(self, batchSize):
 
   def sample(self, **kwargs):
     batchSize = kwargs.get('batch_size', self._batchSize)
-    resX = []
-    resY = []
+    samples = []
     totalSamples = 0
     # find the datasets ids and the number of samples to take from each dataset
     datasetIds, counts = self._getBatchStats(batchSize)
     for datasetId, N in zip(datasetIds, counts):
       dataset = self._datasets[datasetId]
-      x, (y, ) = dataset.sample(N=N, **kwargs)
-      assert len(y) == N, 'Invalid number of samples: %d != %d' % (len(y), N)
-      resX.append(x)
-      resY.append(y)
-      totalSamples += len(y)
+      sampled = dataset.sample(N=N, **kwargs)
+      samples.append(sampled)
+      totalSamples += N
       continue
     
-    resY = np.concatenate(resY, axis=0)
-    assert resY.shape[0] == batchSize, 'Invalid shape: %s' % (resY.shape, )
-    assert resY.shape[-1] == 2, 'Invalid shape: %s' % (resY.shape, )
-    assert len(resY.shape) == 4, 'Invalid shape: %s' % (resY.shape, )
-
-    # sampled data has 'clean' and 'augmented' keys
-    output = {}
-    for nm in ['clean', 'augmented']:
-      keys = resX[0][nm].keys()
-      output[nm] = {k: tf.concat([x[nm][k] for x in resX], axis=0) for k in keys}
-      continue
-    return output, (resY,)
-  
-if __name__ == '__main__':
-  import cv2
-  folder = os.path.dirname(__file__)
-  ds = CDatasetLoader(
-    os.path.join(folder, 'Dataset'), batch_size=16, 
-    batchPerEpoch=1, steps=5
-  )
-  print(len(ds))
-  batchX, batchY = ds[0]
-  print(batchY[0].shape)
-  print(batchX[1].shape)
-  img = batchX[1][0, 0]
-  cv2.imshow('L', cv2.resize(img, (256, 256)))
-  cv2.waitKey()
-  pass
\ No newline at end of file
+    first_dataset = self._datasets[0]
+    return first_dataset.merge(samples, batchSize)
diff --git a/Core/Utils.py b/Core/Utils.py
index e7d5c54..df78942 100644
--- a/Core/Utils.py
+++ b/Core/Utils.py
@@ -191,8 +191,9 @@ def setupGPU():
       PART_TO_INDECES[k].add(i)
 ###################################
 FACE_MESH_INVALID_VALUE = -10.0
+FACE_MESH_POINTS = 478
 def decodeLandmarks(landmarks, VISIBILITY_THRESHOLD, PRESENCE_THRESHOLD):
-  points = np.full((478, 2), fill_value=FACE_MESH_INVALID_VALUE, dtype=np.float32)
+  points = np.full((FACE_MESH_POINTS, 2), fill_value=FACE_MESH_INVALID_VALUE, dtype=np.float32)
   for idx, mark in enumerate(landmarks.landmark):
     if (
       (mark.HasField('visibility') and (mark.visibility < VISIBILITY_THRESHOLD)) 
diff --git a/NN/networks.py b/NN/networks.py
index 719b3e5..d90cd21 100644
--- a/NN/networks.py
+++ b/NN/networks.py
@@ -1,9 +1,4 @@
-import argparse, os, sys
-# add the root folder of the project to the path
-ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
-sys.path.append(ROOT_FOLDER)
-
-from Core.Utils import setupGPU
+from Core.Utils import setupGPU, FACE_MESH_POINTS
 setupGPU() # dirty hack to setup GPU memory limit on startup
 
 import tensorflow as tf
@@ -141,7 +136,7 @@ def Step2LatentModel(latentSize, embeddingsSize):
 
 def _InputSpec():
   return {
-    'points': tf.TensorSpec(shape=(None, None, 478, 2), dtype=tf.float32),
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
     'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
     'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
     'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
@@ -151,7 +146,7 @@ def _InputSpec():
   }
 
 def Face2LatentModel(
-  pointsN=478, eyeSize=32, steps=None, latentSize=64,
+  pointsN=FACE_MESH_POINTS, eyeSize=32, steps=None, latentSize=64,
   embeddings=None,
   diffusion=False # whether to use diffusion model
 ):
@@ -225,7 +220,20 @@ def Face2LatentModel(
   }
   
 ##########################
-def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=478, eyeSize=32, KP=5):
+
+def _InpaintingInputSpec():
+  return {
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
+    'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
+    'userId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'placeId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'screenId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'target': tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32),
+  }
+
+def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
   points = L.Input((steps, pointsN, 2))
   eyeL = L.Input((steps, eyeSize, eyeSize, 1))
   eyeR = L.Input((steps, eyeSize, eyeSize, 1))
@@ -304,7 +312,7 @@ def combineKeys(x):
   )
   return main
  
-def InpaintingDecoderModel(latentSize, embeddings, pointsN=478, eyeSize=32, KP=5):
+def InpaintingDecoderModel(latentSize, embeddings, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
   latentKeyPoints = L.Input((KP, latentSize))
   T = L.Input((None, 1))
   userIdEmb = L.Input((embeddings['size']))
@@ -329,7 +337,7 @@ def transformLatents(x):
   latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
   # process the latents
   latents = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineEmb/MLP')(latents)
-  # decode the latents to the face points (478, 2), two eyes (32, 32, 2) and the target (2) 
+  # decode the latents to the face points (FACE_MESH_POINTS, 2), two eyes (32, 32, 2) and the target (2) 
   target = IntermediatePredictor(shift=0.5)(latents)
   # two eyes
   eyesN = eyeSize * eyeSize
diff --git a/scripts/download-remote.py b/scripts/download-remote.py
index 64d8c6a..c698900 100644
--- a/scripts/download-remote.py
+++ b/scripts/download-remote.py
@@ -20,6 +20,7 @@
 import shutil
 import numpy as np
 import requests
+from Core.Utils import FACE_MESH_POINTS
 
 folder = os.path.join(ROOT_FOLDER, 'Data')
 
@@ -65,10 +66,10 @@ def deserialize(buffer):
     offset += EYE_COUNT
     
     # Read points (float32)
-    sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*478, offset=offset) \
-      .reshape(478, 2)
+    sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*FACE_MESH_POINTS, offset=offset) \
+      .reshape(FACE_MESH_POINTS, 2)
     assert np.all(-2 <= sample['points']) and np.all(sample['points'] <= 2), 'Invalid points'
-    offset += 4 * 2 * 478
+    offset += 4 * 2 * FACE_MESH_POINTS
     
     # Read goal (float32)
     sample['goal'] = goal = np.frombuffer(buffer, dtype='>f4', count=2, offset=offset)
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
new file mode 100644
index 0000000..ec58252
--- /dev/null
+++ b/scripts/train-reconstruction.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+# TODO: add the W&B integration
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+import numpy as np
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from Core.CDatasetLoader import CDatasetLoader
+from Core.CTestLoader import CTestLoader
+from collections import defaultdict
+import time
+from Core.CModelTrainer import CModelTrainer
+from Core.CModelDiffusion import CModelDiffusion
+import tqdm
+import json
+import glob
+
+def _eval(dataset, model):
+  T = time.time()
+  # evaluate the model on the val dataset
+  lossPerSample = {'loss': [], 'pos': []}
+  predV = []
+  predDist = []
+  Y = []
+  for batchId in range(len(dataset)):
+    _, (y,) = batch = dataset[batchId]
+    loss, predP, dist = model.eval(batch)
+    predV.append(predP)
+    predDist.append(dist)
+    Y.append(y[:, -1, 0])
+    for l, pos in zip(loss, y[:, -1]):
+      lossPerSample['loss'].append(l)
+      lossPerSample['pos'].append(pos[0])
+      continue
+    continue
+
+  loss = np.mean(lossPerSample['loss'])
+  dist = np.mean(predDist)
+  T = time.time() - T
+  return loss, dist, T
+
+def evaluator(datasets, model, folder, args):
+  losses = [np.inf] * len(datasets) # initialize with infinity
+  dists = [np.inf] * len(datasets) # initialize with infinity
+  def evaluate(onlyImproved=False):
+    totalLoss = []
+    totalDist = []
+    losses_dist = []
+    for i, dataset in enumerate(datasets):
+      loss, dist, T = _eval(dataset, model)
+      losses_dist.append((loss, losses[i], dist, dists[i]))
+      isImproved = loss < losses[i]
+      if (not onlyImproved) or isImproved:
+        print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (
+          i + 1, len(datasets), T, loss, losses[i], dist, dists[i]
+        ))
+      if isImproved:
+        print('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
+          i + 1, len(datasets), losses[i], loss, dists[i], dist
+        ))
+        model.save(folder, postfix='best-%d' % i) # save the model separately
+        losses[i] = loss
+        pass
+
+      dists[i] = min(dist, dists[i]) # track the best distance
+      totalLoss.append(loss)
+      totalDist.append(dist)
+      continue
+    if not onlyImproved:
+      print('Mean loss: %.5f | Mean distance: %.5f' % (
+        np.mean(totalLoss), np.mean(totalDist)
+      ))
+    return np.mean(totalLoss), losses_dist
+  return evaluate
+
+def _modelTrainingLoop(model, dataset):
+  def F(desc):
+    history = defaultdict(list)
+    # use the tqdm progress bar
+    with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
+      dataset.on_epoch_start()
+      for _ in range(len(dataset)):
+        sampled = dataset.sample()
+        assert 2 == len(sampled), 'The dataset should return a tuple with the input and the output'
+        X, Y = sampled
+        # print shapes of the sampled data
+        for k, v in X.items():
+          print('X', k, v.shape)
+          continue
+        for k, v in Y.items():
+          print('Y', k, v.shape)
+          continue
+        exit(0)
+        assert 'clean' in X, 'The input should contain the clean data'
+        assert 'augmented' in X, 'The input should contain the augmented data'
+        for nm in ['clean', 'augmented']:
+          item = X[nm]
+          assert 'points' in item, 'The input should contain the points'
+          assert 'left eye' in item, 'The input should contain the left eye'
+          assert 'right eye' in item, 'The input should contain the right eye'
+          assert 'time' in item, 'The input should contain the time'
+          assert 'userId' in item, 'The input should contain the userId'
+          assert 'placeId' in item, 'The input should contain the placeId'
+          assert 'screenId' in item, 'The input should contain the screenId'
+          continue
+        stats = model.fit(sampled)
+        history['time'].append(stats['time'])
+        for k in stats['losses'].keys():
+          history[k].append(stats['losses'][k])
+        # add stats to the progress bar (mean of each history)
+        pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
+        pbar.update(1)
+        continue
+      dataset.on_epoch_end()
+    return
+  return F
+
+def _trainer_from(args):
+  if args.trainer == 'default': return CModelTrainer
+  raise Exception('Unknown trainer: %s' % (args.trainer, ))
+
+def averageModels(folder, model, noiseStd=0.0):
+  TV = [np.zeros_like(x) for x in model.trainable_variables()]
+  N = 0
+  for nm in glob.glob(os.path.join(folder, '*.h5')):
+    if not('best' in nm): continue # only the best models
+    model.load(nm, embeddings=True)
+    # add the weights to the total
+    weights = model.trainable_variables()
+    for i in range(len(TV)):
+      TV[i] += weights[i].numpy()
+      continue
+    N += 1
+    continue
+
+  # average the weights
+  TV = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in TV]
+  for v, new in zip(model.trainable_variables(), TV):
+    v.assign(new)
+    continue
+  model.compile() # recompile the model with the new weights
+  return
+
+def main(args):
+  timesteps = args.steps
+  folder = os.path.join(args.folder, 'Data')
+
+  stats = None
+  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  trainer = _trainer_from(args)
+  trainDataset = CDatasetLoader(
+    os.path.join(folder, 'remote'),
+    stats=stats,
+    sampling=args.sampling,
+    samplerArgs=dict(
+      batch_size=args.batch_size,
+      minFrames=timesteps,
+      maxT=1.0,
+      defaults=dict(
+        timesteps=timesteps,
+        stepsSampling='uniform',
+        # no augmentations by default
+        pointsNoise=0.01, pointsDropout=0.0,
+        eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+      ),
+      targets=dict(
+        keypoints=3,
+        total=10
+      ),
+    ),
+    sampler_class=CDataSamplerInpainting,
+  )
+  model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
+  if args.model is not None:
+    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
+  if args.modelId is not None:
+    model['model'] = args.modelId
+
+  model = trainer(**model)
+  model._model.summary()
+
+  # find folders with the name "/test-*/"
+  evalDatasets = [
+    # CTestLoader(nm)
+    # for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
+  ]
+  eval = evaluator(evalDatasets, model, folder, args)
+  bestLoss, _ = eval() # evaluate loaded model
+  bestEpoch = 0
+  # wrapper for the evaluation function. It saves the model if it is better
+  def evalWrapper(eval):
+    def f(epoch, onlyImproved=False):
+      nonlocal bestLoss, bestEpoch
+      newLoss, losses = eval(onlyImproved=onlyImproved)
+      if newLoss < bestLoss:
+        print('Improved %.5f => %.5f' % (bestLoss, newLoss))
+        if onlyImproved: #details
+          for i, (loss, bestLoss_, dist, bestDist) in enumerate(losses):
+            print('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1, loss, bestLoss_, dist, bestDist))
+            continue
+          print('-' * 80)
+        bestLoss = newLoss
+        bestEpoch = epoch
+        model.save(folder, postfix='best')
+      return
+    return f
+  
+  eval = evalWrapper(eval)
+  trainStep = _modelTrainingLoop(model, trainDataset)
+  for epoch in range(args.epochs):
+    trainStep(
+      desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+    )
+    model.save(folder, postfix='latest')
+    eval(epoch)
+
+    print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
+    if args.patience <= (epoch - bestEpoch):
+      if 'stop' == args.on_patience:
+        print('Early stopping')
+        break
+    continue
+  return
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--epochs', type=int, default=1000)
+  parser.add_argument('--batch-size', type=int, default=64)
+  parser.add_argument('--patience', type=int, default=5)
+  parser.add_argument('--on-patience', type=str, default='stop', choices=['stop', 'reset'])
+  parser.add_argument('--steps', type=int, default=5)
+  parser.add_argument('--model', type=str)
+  parser.add_argument('--embeddings', default=False, action='store_true')
+  parser.add_argument(
+    '--average', default=False, action='store_true',
+    help='Load each model from the folder and average them weights'
+  )
+  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+  parser.add_argument('--modelId', type=str)
+  parser.add_argument(
+    '--trainer', type=str, default='default',
+    choices=['default', 'diffusion']
+  )
+  parser.add_argument(
+    '--schedule', type=str, default=None,
+    help='JSON file with the scheduler parameters for sampling the training dataset'
+  )
+  parser.add_argument('--debug', action='store_true')
+  parser.add_argument('--noise', type=float, default=1e-4)
+  parser.add_argument(
+    '--restarts', type=int, default=1,
+    help='Number of times to restart the model reinitializing the weights'
+  )
+  parser.add_argument(
+    '--with-enconders', default=False, action='store_true',
+  )
+  parser.add_argument(
+    '--sampling', type=str, default='uniform',
+    choices=['uniform', 'as_is'],
+  )
+
+  main(parser.parse_args())
+  pass
\ No newline at end of file
diff --git a/scripts/train.py b/scripts/train.py
index 2afab81..3b7a424 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 from Core.CDatasetLoader import CDatasetLoader
+from Core.CDataSampler import CDataSampler
 from Core.CTestLoader import CTestLoader
 from collections import defaultdict
 import time
@@ -240,7 +241,8 @@ def main(args):
         pointsNoise=0.01, pointsDropout=0.0,
         eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
       ),
-    )
+    ),
+    sampler_class=CDataSampler
   )
   model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
   if args.model is not None:

From 581824146331c9ea10bb92951f160f52921ad317 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Wed, 16 Oct 2024 19:59:24 +0200
Subject: [PATCH 04/10] Remove diffussion model

---
 Core/CModelDiffusion.py         | 272 --------------------------------
 scripts/train-reconstruction.py |   1 -
 scripts/train.py                |   4 +-
 3 files changed, 1 insertion(+), 276 deletions(-)
 delete mode 100644 Core/CModelDiffusion.py

diff --git a/Core/CModelDiffusion.py b/Core/CModelDiffusion.py
deleted file mode 100644
index 20a9259..0000000
--- a/Core/CModelDiffusion.py
+++ /dev/null
@@ -1,272 +0,0 @@
-import os
-import numpy as np
-import NN.networks as networks
-import tensorflow as tf
-import tensorflow_probability as tfp
-import NN.Utils as NNU
-import time
-from tensorflow.keras import layers as L
-
-# TODO: Implement the standard diffusion process (with the prediction of the noise, proper sampling, etc)
-class CModelDiffusion:
-  '''
-  Wrapper for the diffusion model to predict the gaze point
-  Diffusion T is equal to the stddev of the gaussian noise
-  '''
-  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=False, **kwargs):
-    if user is None:
-      user = {
-        'userId': 0,
-        'placeId': 0,
-        'screenId': 0,
-      }
-    else:
-      user = {
-        'userId': stats['userId'].index(user['userId']),
-        'placeId': stats['placeId'].index(user['placeId']),
-        'screenId': stats['screenId'].index(user['screenId']),
-      }
-    self._user = user
-
-    self._modelID = model
-    self._timesteps = timesteps
-    embeddings = {
-      'userId': len(stats['userId']),
-      'placeId': len(stats['placeId']),
-      'screenId': len(stats['screenId']),
-      'size': 64,
-    }
-    self._modelRaw = networks.Face2LatentModel(
-      steps=timesteps, latentSize=64, embeddings=embeddings,
-      diffusion=True
-    )
-    self._model = self._modelRaw['main']
-    self._embeddings = {
-      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
-      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
-      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
-    }
-    self._intermediateEncoders = {}
-    if use_encoders:
-      shapes = self._modelRaw['intermediate shapes']
-      for name, shape in shapes.items():
-        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
-        enc.build(shape)
-        self._intermediateEncoders[name] = enc
-        continue
-   
-    self._maxDiffusionT = 100.0
-    if 'weights' in kwargs:
-      self.load(**kwargs['weights'])
-    self.compile()
-    # add signatures to help tensorflow optimize the graph
-    specification = self._modelRaw['inputs specification']
-    self._trainStep = tf.function(
-      self._trainStep,
-      input_signature=[
-        (
-          { 'clean': specification, 'augmented': specification, },
-          ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
-        )
-      ]
-    )
-    self._eval = tf.function(
-      self._eval,
-      input_signature=[(
-        specification,
-        ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
-      )]
-    )
-
-    return
-  
-  def _step2mean(self, step):
-    step = tf.cast(step, tf.float32) / self._maxDiffusionT
-    step = tf.cast(step, tf.float32) + 1e-6
-    # step = tf.pow(step, 2.0) # make it decrease faster
-    return tf.clip_by_value(step, 1e-3, 1.0)
-  
-  def _replaceByEmbeddings(self, data):
-    data = dict(**data) # copy
-    for name, emb in self._embeddings.items():
-      data[name] = emb(data[name][..., 0])
-      continue
-    return data
-  
-  def _makeGaussian(self, mean, stddev):
-    stddev = tf.concat([stddev, stddev], axis=-1)
-    return tfp.distributions.MultivariateNormalDiag(mean, stddev)
-  
-  @tf.function
-  def _infer(self, data, training=False):
-    print('Instantiate _infer')
-    data = self._replaceByEmbeddings(data)
-    shp = tf.shape(data['userId'])
-    B, N = shp[0], self.timesteps
-    result = tf.zeros((B, N, 2), dtype=tf.float32)
-    for step in tf.range(self._maxDiffusionT, -1, -5):
-      mean = self._step2mean(
-        tf.fill((B, N, 1), step)
-      )
-      stepData = dict(**data)
-      stepData['diffusionT'] = mean
-      stepData['diffusionPoints'] = tf.random.normal((B, N, 2), mean=result, stddev=mean)
-      result = self._model(stepData, training=training)['result']
-    return result
-    
-  def predict(self, data, **kwargs):
-    B = self._timesteps
-    userId = kwargs.get('userId', self._user['userId'])
-    placeId = kwargs.get('placeId', self._user['placeId'])
-    screenId = kwargs.get('screenId', self._user['screenId'])
-    # put them as (1, B, ?)
-    data['userId'] = np.full((1, B, 1), userId, dtype=np.int32)
-    data['placeId'] = np.full((1, B, 1), placeId, dtype=np.int32)
-    data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)
-
-    data = self._replaceByEmbeddings(data) # replace embeddings
-    
-    result = self._infer(data)
-    return result.numpy()
-  
-  def __call__(self, data, startPos=None):
-    predictions = self.predict(data)
-    return {
-      'coords': predictions[0, -1, :],
-    }
-    
-  def compile(self):
-    self._optimizer = NNU.createOptimizer()
-    return
-
-  def _modelFilename(self, folder, postfix=''):
-    postfix = '-' + postfix if postfix else ''
-    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))
-  
-  def save(self, folder=None, postfix=''):
-    path = self._modelFilename(folder, postfix)
-    self._model.save_weights(path)
-    embeddings = {}
-    for nm in self._embeddings.keys():
-      weights = self._embeddings[nm].get_weights()[0]
-      embeddings[nm] = weights
-      continue
-    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
-    # save intermediate encoders
-    if self._intermediateEncoders:
-      encoders = {}
-      for nm, encoder in self._intermediateEncoders.items():
-        # save each variable separately
-        for ww in encoder.trainable_variables:
-          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
-        continue
-      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
-    return
-    
-  def load(self, folder=None, postfix='', embeddings=False):
-    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
-    self._model.load_weights(path)
-    if embeddings:
-      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
-      for nm, emb in self._embeddings.items():
-        w = embeddings[nm]
-        if not emb.built: emb.build((None, w.shape[0]))
-        emb.set_weights([w]) # replace embeddings
-        continue
-    
-    if self._intermediateEncoders:
-      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
-      if os.path.isfile(encodersName):
-        encoders = np.load(encodersName)
-        for nm, encoder in self._intermediateEncoders.items():
-          for ww in encoder.trainable_variables:
-            w = encoders['%s-%s' % (nm, ww.name)]
-            ww.assign(w)
-          continue
-    return
-  
-  def lock(self, isLocked):
-    self._model.trainable = not isLocked
-    return
-  
-  @property
-  def timesteps(self):
-    return self._timesteps
-  
-  def trainable_variables(self):
-    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
-    return sum([p.trainable_variables for p in parts], [])  
-  
-  def _pointLoss(self, ytrue, ypred):
-    # pseudo huber loss
-    delta = 0.01
-    tf.assert_equal(tf.shape(ytrue), tf.shape(ypred))
-    diff = tf.square(ytrue - ypred)
-    loss = tf.sqrt(diff + delta ** 2) - delta
-    tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
-    return tf.reduce_mean(loss, axis=-1)
-
-  def _trainStep(self, Data):
-    print('Instantiate _trainStep')
-    ###############
-    x, (y, ) = Data
-    y = y[..., 0, :]
-    losses = {}
-    with tf.GradientTape() as tape:
-      data = x['augmented']
-      data = self._replaceByEmbeddings(data)
-      # add sampled T
-      B = tf.shape(y)[0]
-      N = self.timesteps
-      maxT = 100
-      diffusionT = tf.random.uniform((B, 1), minval=0, maxval=maxT, dtype=tf.int32)
-      # (B, 1) -> (B, N, 1)
-      diffusionT = tf.tile(diffusionT, (1, N))[..., None]
-      diffusionT = self._step2mean(diffusionT)
-      tf.assert_equal(tf.shape(diffusionT), (B, N, 1))
-      
-      # store the diffusion parameters
-      data['diffusionT'] = diffusionT
-      # sample the points
-      data['diffusionPoints'] = tf.random.normal((B, N, 2), mean=y, stddev=diffusionT)
-      predictions = self._model(data, training=True)
-    #   intermediate = predictions['intermediate']
-    #   assert len(intermediate) == 0, 'Intermediate predictions are not supported'
-      
-      predictedMean = predictions['result']
-      gaussian = self._makeGaussian(predictedMean, diffusionT)
-      losses['log_prob'] = tf.reduce_mean(
-        -gaussian.log_prob(y)
-      )
-      losses['points'] = self._pointLoss(y, predictedMean)
-      loss = sum(losses.values())
-      losses['loss'] = loss
-  
-    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
-    ###############
-    return losses
-
-  def fit(self, data):
-    t = time.time()
-    losses = self._trainStep(data)
-    losses = {k: v.numpy() for k, v in losses.items()}
-    return {'time': int((time.time() - t) * 1000), 'losses': losses}
-  
-  def _eval(self, xy):
-    print('Instantiate _eval')
-    x, (y,) = xy
-    y = y[:, :, 0]
-    B, N = tf.shape(y)[0], tf.shape(y)[1]
-    
-    predictions = self._infer(x)
-    
-    mean = self._step2mean(tf.fill((B, N, 1), 0))
-    gaussian = self._makeGaussian(predictions, mean)
-    loss = tf.nn.sigmoid( -gaussian.log_prob(y) )
-    points = predictions
-    _, dist = NNU.normVec(y - predictions)
-    return loss, points, dist
-
-  def eval(self, data):
-    loss, sampled, dist = self._eval(data)
-    return loss.numpy(), sampled.numpy(), dist.numpy()
\ No newline at end of file
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
index ec58252..2293ab1 100644
--- a/scripts/train-reconstruction.py
+++ b/scripts/train-reconstruction.py
@@ -13,7 +13,6 @@
 from collections import defaultdict
 import time
 from Core.CModelTrainer import CModelTrainer
-from Core.CModelDiffusion import CModelDiffusion
 import tqdm
 import json
 import glob
diff --git a/scripts/train.py b/scripts/train.py
index 3b7a424..569bae7 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -13,7 +13,6 @@
 from collections import defaultdict
 import time
 from Core.CModelTrainer import CModelTrainer
-from Core.CModelDiffusion import CModelDiffusion
 import tqdm
 import json
 import glob
@@ -188,7 +187,6 @@ def F(epoch):
 
 def _trainer_from(args):
   if args.trainer == 'default': return CModelTrainer
-  if args.trainer == 'diffusion': return CModelDiffusion
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
 def averageModels(folder, model, noiseStd=0.0):
@@ -334,7 +332,7 @@ def performRandomSearch(epoch=0):
   parser.add_argument('--modelId', type=str)
   parser.add_argument(
     '--trainer', type=str, default='default',
-    choices=['default', 'diffusion']
+    choices=['default']
   )
   parser.add_argument(
     '--schedule', type=str, default=None,

From 79004872b6b4e77b3373ad5b552fc66354f78e68 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Thu, 17 Oct 2024 15:50:22 +0200
Subject: [PATCH 05/10] wip

---
 Core/CBaseModel.py              |  59 +++++++++++++++++
 Core/CDataSamplerInpainting.py  |  39 ++++++++---
 Core/CInpaintingTrainer.py      | 101 ++++++++++++++++++++++++++++
 Core/CModelWrapper.py           | 112 +++++---------------------------
 NN/LagrangianInterpolation.py   |   9 ++-
 NN/networks.py                  |  39 ++++++-----
 scripts/download-remote.py      |   9 ++-
 scripts/train-reconstruction.py |  48 +++-----------
 scripts/train.py                |  60 +----------------
 9 files changed, 255 insertions(+), 221 deletions(-)
 create mode 100644 Core/CBaseModel.py
 create mode 100644 Core/CInpaintingTrainer.py

diff --git a/Core/CBaseModel.py b/Core/CBaseModel.py
new file mode 100644
index 0000000..4bb0204
--- /dev/null
+++ b/Core/CBaseModel.py
@@ -0,0 +1,59 @@
+import os
+import numpy as np
+from tensorflow.keras import layers as L
+
+class CBaseModel:
+  def __init__(self, model, embeddings, submodels):
+    self._model = model
+    self._embeddings = {
+      'userId': L.Embedding(embeddings['userId'], embeddings['size']),
+      'placeId': L.Embedding(embeddings['placeId'], embeddings['size']),
+      'screenId': L.Embedding(embeddings['screenId'], embeddings['size']),
+    }
+    self._submodels = submodels
+    return  
+
+  def replaceByEmbeddings(self, data):
+    data = dict(**data) # copy
+    for name, emb in self._embeddings.items():
+      data[name] = emb(data[name][..., 0])
+      continue
+    return data
+
+  def _modelFilename(self, folder, postfix=''):
+    postfix = '-' + postfix if postfix else ''
+    return os.path.join(folder, '%s%s.h5' % (self._modelID, postfix))
+  
+  def save(self, folder=None, postfix=''):
+    path = self._modelFilename(folder, postfix)
+    if 1 < len(self._submodels):
+      for i, model in enumerate(self._submodels):
+        model.save_weights(path.replace('.h5', '-%d.h5' % i))
+    else:
+      self._submodels[0].save_weights(path)
+
+    embeddings = {}
+    for nm in self._embeddings.keys():
+      weights = self._embeddings[nm].get_weights()[0]
+      embeddings[nm] = weights
+    
+    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
+    
+  def load(self, folder=None, postfix='', embeddings=False):
+    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
+    if 1 < len(self._submodels):
+      for i, model in enumerate(self._submodels):
+        model.load_weights(path.replace('.h5', '-%d.h5' % i))
+    else:
+      self._submodels[0].load_weights(path)
+      
+    if embeddings:
+      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
+      for nm, emb in self._embeddings.items():
+        w = embeddings[nm]
+        if not emb.built: emb.build((None, w.shape[0]))
+        emb.set_weights([w]) # replace embeddings
+    
+  def trainable_variables(self):
+    parts = list(self._embeddings.values()) + self._submodels
+    return sum([p.trainable_variables for p in parts], [])
diff --git a/Core/CDataSamplerInpainting.py b/Core/CDataSamplerInpainting.py
index f81342f..ac8207c 100644
--- a/Core/CDataSamplerInpainting.py
+++ b/Core/CDataSamplerInpainting.py
@@ -3,6 +3,7 @@
 from Core.Utils import FACE_MESH_POINTS
 
 import numpy as np
+import tensorflow as tf
 
 '''
 This sampler are sample N frames from the dataset, where N is the number of timesteps.
@@ -24,8 +25,9 @@
   - The target point.
 '''
 class CDataSamplerInpainting(CBaseDataSampler):
-    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+    def __init__(self, storage, batch_size, minFrames, keys, defaults={}, maxT=1.0, cumulative_time=True):
         super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
+        self._keys = keys
 
     def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
         if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
@@ -46,7 +48,7 @@ def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
     def sample(self, **kwargs):
         kwargs = {**self._defaults, **kwargs}
         timesteps = kwargs.get('timesteps', None)
-        N = kwargs.get('N', self._batchSize)
+        N = kwargs.get('N', self._batchSize) // len(self._keys)
         indexes = []
         for _ in range(N):
             added = False
@@ -101,6 +103,7 @@ def sampleByIds(self, ids, **kwargs):
     def _indexes2XY(self, indexesAndTime, kwargs):
         timesteps = kwargs.get('timesteps', None)
         assert timesteps is not None, 'The number of timesteps must be defined.'
+        B = len(indexesAndTime) // timesteps
         samples = [self._storage[i] for i, _ in indexesAndTime]
         ##############
         userIds = np.unique([x['userId'] for x in samples])
@@ -130,7 +133,25 @@ def _indexes2XY(self, indexesAndTime, kwargs):
             ),
             userIds[0], placeIds[0], screenIds[0]
         )
-        X = X['clean']
+        for k in X.keys():
+            # add the target point to the X
+            targets = np.array([x['goal'] for x in samples], np.float32).reshape((B, timesteps, 2))
+            X[k]['target'] = tf.constant(targets, dtype=tf.float32)
+
+        if 1 == len(self._keys):
+            X = X[self._keys[0]]
+        else:
+            res = {}
+            k = self._keys[0]
+            subkeys = list(X[k].keys())
+            for k in subkeys:
+                values = [X[key][k] for key in self._keys]
+                res[k] = tf.concat(values, axis=0)
+                continue
+            X = res
+            indexesAndTime = indexesAndTime * len(self._keys)
+            B = len(self._keys) * B
+
         ###############
         # generate the target data
         targets = kwargs.get('targets', {'keypoints': timesteps, 'total': timesteps})
@@ -141,7 +162,7 @@ def _indexes2XY(self, indexesAndTime, kwargs):
 
         samples_indexes = np.array([ i  for i, _ in indexesAndTime], np.int32)
         samples_indexes = samples_indexes.reshape((-1, timesteps))
-        B = samples_indexes.shape[0]
+        assert samples_indexes.shape[0] == B, 'Invalid number of samples: %d != %d' % (samples_indexes.shape[0], B)
         targetsIdx = np.zeros((B, T), np.int32)
         for i in range(B):
             # sample K frames from the X
@@ -186,10 +207,11 @@ def _indexes2XY(self, indexesAndTime, kwargs):
                 Y['right eye'][i, j] = data['right eye'][p:p+32, p:p+32]
                 Y['time'][i, j] = (data['time'] - startT) / duration
                 Y['target'][i, j] = data['goal']
-                
+        # eyes in 0..255, so we need to normalize them
+        Y['left eye'] /= 255.0
+        Y['right eye'] /= 255.0
         # check that time is between 0 and 1
         assert np.all((0 <= Y['time']) & (Y['time'] <= 1)), 'Invalid time: ' + str(Y['time'])
-        B = Y['points'].shape[0] 
         for k, v in X.items():
             assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
         for k, v in Y.items():
@@ -197,9 +219,8 @@ def _indexes2XY(self, indexesAndTime, kwargs):
         return (X, Y)
     
     def merge(self, samples, expected_batch_size):
-        # each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
         X = {}
-        for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId']:
+        for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId', 'target']:
             data = [x[subkey] for x, _ in samples]
             X[subkey] = np.concatenate(data, axis=0)
             assert X[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[subkey].shape[0], expected_batch_size)
@@ -211,4 +232,4 @@ def merge(self, samples, expected_batch_size):
             Y[subkey] = np.concatenate(data, axis=0)
             assert Y[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (Y[subkey].shape[0], expected_batch_size)
             continue
-        return (X, (Y, ))
\ No newline at end of file
+        return (X, Y)
\ No newline at end of file
diff --git a/Core/CInpaintingTrainer.py b/Core/CInpaintingTrainer.py
new file mode 100644
index 0000000..0e54533
--- /dev/null
+++ b/Core/CInpaintingTrainer.py
@@ -0,0 +1,101 @@
+import tensorflow as tf
+import time
+import NN.Utils as NNU
+import NN.networks as networks
+from Core.CBaseModel import CBaseModel
+
+class CInpaintingTrainer:
+  def __init__(self, timesteps, model='simple', KP=5, **kwargs):
+    stats = kwargs.get('stats', None)
+    embeddingsSize = kwargs.get('embeddingsSize', 64)
+    latentSize = kwargs.get('latentSize', 64)
+    embeddings = {
+      'userId': len(stats['userId']),
+      'placeId': len(stats['placeId']),
+      'screenId': len(stats['screenId']),
+      'size': embeddingsSize,
+    }
+
+    self._encoder = networks.InpaintingEncoderModel(
+      steps=timesteps, latentSize=latentSize,
+      embeddingsSize=embeddingsSize,
+      KP=KP,
+    )
+    self._decoder = networks.InpaintingDecoderModel(
+      latentSize=latentSize,
+      embeddingsSize=embeddingsSize,
+      KP=KP,
+    )
+    self._model = CBaseModel(
+       model=model, embeddings=embeddings, submodels=[self._encoder, self._decoder]
+    )
+    self.compile()
+    # add signatures to help tensorflow optimize the graph
+    specification = networks.InpaintingInputSpec()
+    self._trainStep = tf.function(
+      self._trainStep,
+      input_signature=[specification]
+    )
+    self._eval = tf.function(
+      self._eval,
+      input_signature=[specification]
+    )
+    return
+  
+  def compile(self):
+    self._optimizer = NNU.createOptimizer()
+
+  def _trainStep(self, Data):
+    print('Instantiate _trainStep')
+    ###############
+    x, y = Data
+    losses = {}
+    with tf.GradientTape() as tape:
+      x = self._model.replaceByEmbeddings(x)
+      latents = self._encoder(x, training=True)['latent']
+      decoderArgs = {
+        'keyPoints': latents,
+        'time': y['time'],
+        'userId': x['userId'],
+        'placeId': x['placeId'],
+        'screenId': x['screenId'],
+      }
+      predictions = self._decoder(decoderArgs, training=True)
+      losses = {}
+      for k in predictions.keys():
+        pred = predictions[k]
+        gt = y[k]
+        tf.assert_equal(tf.shape(pred), tf.shape(gt))
+        loss = tf.losses.mse(gt, pred)
+        losses[f"loss-{k}"] = tf.reduce_mean(loss)
+        
+      # calculate total loss and final loss
+      losses['loss'] = loss = sum(losses.values())
+  
+    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
+    ###############
+    return losses
+
+  def fit(self, data):
+    t = time.time()
+    losses = self._trainStep(data)
+    losses = {k: v.numpy() for k, v in losses.items()}
+    return {'time': int((time.time() - t) * 1000), 'losses': losses}
+  
+  def _eval(self, xy):
+    print('Instantiate _eval')
+    x, (y,) = xy
+    x = self._replaceByEmbeddings(x)
+    y = y[:, :, 0]
+    predictions = self._model(x, training=False)
+    points = predictions['result'][:, :, :]
+    tf.assert_equal(tf.shape(points), tf.shape(y))
+
+    loss = self._pointLoss(y, points)
+    tf.assert_equal(tf.shape(loss), tf.shape(y)[:2])
+    _, dist = NNU.normVec(points - y)
+    return loss, points, dist
+
+  def eval(self, data):
+    loss, sampled, dist = self._eval(data)
+    return loss.numpy(), sampled.numpy(), dist.numpy()
\ No newline at end of file
diff --git a/Core/CModelWrapper.py b/Core/CModelWrapper.py
index 2adea8e..e63da2d 100644
--- a/Core/CModelWrapper.py
+++ b/Core/CModelWrapper.py
@@ -1,63 +1,34 @@
-import os
 import numpy as np
 import NN.networks as networks
-import tensorflow as tf
-from tensorflow.keras import layers as L
-
+from Core.CBaseModel import CBaseModel
+  
 class CModelWrapper:
-  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=True, **kwargs):
-    if user is None:
-      user = {
-        'userId': 0,
-        'placeId': 0,
-        'screenId': 0,
-      }
-    else:
+  def __init__(self, timesteps, model='simple', user=None, stats=None, **kwargs):
+    if user is not None:
       user = {
         'userId': stats['userId'].index(user['userId']),
         'placeId': stats['placeId'].index(user['placeId']),
         'screenId': stats['screenId'].index(user['screenId']),
       }
     self._user = user
-
-    self._modelID = model
     self._timesteps = timesteps
     embeddings = {
       'userId': len(stats['userId']),
       'placeId': len(stats['placeId']),
       'screenId': len(stats['screenId']),
-      'size': 64,
+      'size': kwargs.get('embeddingSize', 64),
     }
     self._modelRaw = networks.Face2LatentModel(
-      steps=timesteps, latentSize=64, embeddings=embeddings
+      steps=timesteps, latentSize=kwargs.get('latentSize', 64), embeddings=embeddings
     )
-    self._model = self._modelRaw['main']
-    self._embeddings = {
-      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
-      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
-      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
-    }
-    self._intermediateEncoders = {}
-    if use_encoders:
-      shapes = self._modelRaw['intermediate shapes']
-      for name, shape in shapes.items():
-        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
-        enc.build(shape)
-        self._intermediateEncoders[name] = enc
-        continue
-   
+    NN =  self._network = self._modelRaw['main']
+    self._model = CBaseModel(model=model, embeddings=embeddings, submodels=[NN])
     if 'weights' in kwargs:
       self.load(**kwargs['weights'])
     return
   
-  def _replaceByEmbeddings(self, data):
-    data = dict(**data) # copy
-    for name, emb in self._embeddings.items():
-      data[name] = emb(data[name][..., 0])
-      continue
-    return data
-  
   def predict(self, data, **kwargs):
+    assert self._user is not None, 'User is not set'
     B = self._timesteps
     userId = kwargs.get('userId', self._user['userId'])
     placeId = kwargs.get('placeId', self._user['placeId'])
@@ -68,68 +39,21 @@ def predict(self, data, **kwargs):
     data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)
 
     data = self._replaceByEmbeddings(data) # replace embeddings
-    return self._model(data, training=False)['result'].numpy()
+    return self._network(data, training=False)['result'].numpy()
   
   def __call__(self, data, startPos=None):
     predictions = self.predict(data)
-    return {
-      'coords': predictions[0, -1, :],
-    }
-
-  def _modelFilename(self, folder, postfix=''):
-    postfix = '-' + postfix if postfix else ''
-    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))
-  
-  def save(self, folder=None, postfix=''):
-    path = self._modelFilename(folder, postfix)
-    self._model.save_weights(path)
-    embeddings = {}
-    for nm in self._embeddings.keys():
-      weights = self._embeddings[nm].get_weights()[0]
-      embeddings[nm] = weights
-      continue
-    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
-    # save intermediate encoders
-    if self._intermediateEncoders:
-      encoders = {}
-      for nm, encoder in self._intermediateEncoders.items():
-        # save each variable separately
-        for ww in encoder.trainable_variables:
-          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
-        continue
-      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
-    return
-    
-  def load(self, folder=None, postfix='', embeddings=False):
-    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
-    self._model.load_weights(path)
-    if embeddings:
-      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
-      for nm, emb in self._embeddings.items():
-        w = embeddings[nm]
-        if not emb.built: emb.build((None, w.shape[0]))
-        emb.set_weights([w]) # replace embeddings
-        continue
-    
-    if self._intermediateEncoders:
-      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
-      if os.path.isfile(encodersName):
-        encoders = np.load(encodersName)
-        for nm, encoder in self._intermediateEncoders.items():
-          for ww in encoder.trainable_variables:
-            w = encoders['%s-%s' % (nm, ww.name)]
-            ww.assign(w)
-          continue
-    return
-  
-  def lock(self, isLocked):
-    self._model.trainable = not isLocked
-    return
+    return { 'coords': predictions[0, -1, :], }
   
   @property
   def timesteps(self):
     return self._timesteps
   
+  def save(self, folder=None, postfix=''):
+    self._model.save(folder=folder, postfix=postfix)
+
+  def load(self, folder=None, postfix='', embeddings=False):
+    self._model.load(folder=folder, postfix=postfix, embeddings=embeddings)
+
   def trainable_variables(self):
-    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
-    return sum([p.trainable_variables for p in parts], [])
\ No newline at end of file
+    return self._model.trainable_variables()
\ No newline at end of file
diff --git a/NN/LagrangianInterpolation.py b/NN/LagrangianInterpolation.py
index 7c53f9a..88b2942 100644
--- a/NN/LagrangianInterpolation.py
+++ b/NN/LagrangianInterpolation.py
@@ -13,11 +13,10 @@ def lagrange_interpolation(x_values, y_values, x_targets):
     Returns:
     - interpolated_values: Tensor of shape (batch_size, m, d), interpolated y-values for each batch.
     """
-    minX = tf.reduce_min(x_values, axis=1)
-    maxX = tf.reduce_max(x_values, axis=1)
-    # Check if x_targets in the range of x_values
-    tf.debugging.assert_greater_equal(x_targets, minX, message="x_targets out of range")
-    tf.debugging.assert_less_equal(x_targets, maxX, message="x_targets out of range")
+    # add extra points at -1 and 2, to smooth the interpolation at the edges
+    ones = tf.ones_like(x_values[:, :1])
+    x_values = tf.concat([ones * -1, x_values, ones * 2], axis=1)
+    y_values = tf.concat([y_values[:, :1], y_values, y_values[:, -1:]], axis=1)
 
     batch_size = tf.shape(x_values)[0]
     n = tf.shape(x_values)[1]
diff --git a/NN/networks.py b/NN/networks.py
index d90cd21..9bfed75 100644
--- a/NN/networks.py
+++ b/NN/networks.py
@@ -221,8 +221,8 @@ def Face2LatentModel(
   
 ##########################
 
-def _InpaintingInputSpec():
-  return {
+def InpaintingInputSpec():
+  XSpec = {
     'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
     'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
     'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
@@ -232,16 +232,25 @@ def _InpaintingInputSpec():
     'screenId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
     'target': tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32),
   }
+  YSpec = {
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
+    'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
+    'target': tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32),
+  }
+  return (XSpec, YSpec)
+
 
-def InpaintingEncoderModel(latentSize, embeddings, steps=5, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
+def InpaintingEncoderModel(latentSize, embeddingsSize, steps=5, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
   points = L.Input((steps, pointsN, 2))
   eyeL = L.Input((steps, eyeSize, eyeSize, 1))
   eyeR = L.Input((steps, eyeSize, eyeSize, 1))
   T = L.Input((steps, 1)) # accumulative time
   target = L.Input((steps, 2))
-  userIdEmb = L.Input((steps, embeddings['size']))
-  placeIdEmb = L.Input((steps, embeddings['size']))
-  screenIdEmb = L.Input((steps, embeddings['size']))
+  userIdEmb = L.Input((steps, embeddingsSize))
+  placeIdEmb = L.Input((steps, embeddingsSize))
+  screenIdEmb = L.Input((steps, embeddingsSize))
 
   emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
 
@@ -312,14 +321,14 @@ def combineKeys(x):
   )
   return main
  
-def InpaintingDecoderModel(latentSize, embeddings, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
+def InpaintingDecoderModel(latentSize, embeddingsSize, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
   latentKeyPoints = L.Input((KP, latentSize))
   T = L.Input((None, 1))
-  userIdEmb = L.Input((embeddings['size']))
-  placeIdEmb = L.Input((embeddings['size']))
-  screenIdEmb = L.Input((embeddings['size']))
+  userIdEmb = L.Input((None, embeddingsSize))
+  placeIdEmb = L.Input((None, embeddingsSize))
+  screenIdEmb = L.Input((None, embeddingsSize))
 
-  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[..., None, :]
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[:, :1, :]
   # emb shape: (B, 1, 3 * embSize) 
   def interpolateKeys(x):
     latents, T = x
@@ -332,7 +341,7 @@ def interpolateKeys(x):
   def transformLatents(x):
     latents, emb = x
     N = tf.shape(latents)[1]
-    emb = tf.tile(emb, (1, N, 1)) # (B, 1, 3 * embSize) -> (B, N, 3 * embSize)
+    emb = tf.tile(emb, (1, N, 1))
     return L.Concatenate(-1)([latents, emb])
   latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
   # process the latents
@@ -359,9 +368,9 @@ def transformLatents(x):
     },
     outputs={
       'target': target,
-      'left eye': eyes[:, :, 0],
-      'right eye': eyes[:, :, 1],
-      'face': face,
+      'left eye': eyes[..., 0],
+      'right eye': eyes[..., 1],
+      'points': face,
     }
   )
   return model
diff --git a/scripts/download-remote.py b/scripts/download-remote.py
index c698900..3656c99 100644
--- a/scripts/download-remote.py
+++ b/scripts/download-remote.py
@@ -8,7 +8,7 @@
     PlaceId
       UserId
         ScreenId
-          start_time.npz
+          *.npz
 '''
 import argparse, os, sys
 # add the root folder of the project to the path
@@ -68,7 +68,12 @@ def deserialize(buffer):
     # Read points (float32)
     sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*FACE_MESH_POINTS, offset=offset) \
       .reshape(FACE_MESH_POINTS, 2)
-    assert np.all(-2 <= sample['points']) and np.all(sample['points'] <= 2), 'Invalid points'
+    # change the range [-d, 1.0 + d]
+    d = 0.5
+    points = sample['points']
+    is_valid = (-d <= points) & (points <= 1.0 + d)
+    is_valid = is_valid.all(axis=-1)
+    assert np.all(is_valid), 'Invalid points: %s' % points[~is_valid]
     offset += 4 * 2 * FACE_MESH_POINTS
     
     # Read goal (float32)
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
index 2293ab1..893bcf3 100644
--- a/scripts/train-reconstruction.py
+++ b/scripts/train-reconstruction.py
@@ -12,7 +12,7 @@
 from Core.CTestLoader import CTestLoader
 from collections import defaultdict
 import time
-from Core.CModelTrainer import CModelTrainer
+from Core.CInpaintingTrainer import CInpaintingTrainer
 import tqdm
 import json
 import glob
@@ -83,28 +83,6 @@ def F(desc):
       dataset.on_epoch_start()
       for _ in range(len(dataset)):
         sampled = dataset.sample()
-        assert 2 == len(sampled), 'The dataset should return a tuple with the input and the output'
-        X, Y = sampled
-        # print shapes of the sampled data
-        for k, v in X.items():
-          print('X', k, v.shape)
-          continue
-        for k, v in Y.items():
-          print('Y', k, v.shape)
-          continue
-        exit(0)
-        assert 'clean' in X, 'The input should contain the clean data'
-        assert 'augmented' in X, 'The input should contain the augmented data'
-        for nm in ['clean', 'augmented']:
-          item = X[nm]
-          assert 'points' in item, 'The input should contain the points'
-          assert 'left eye' in item, 'The input should contain the left eye'
-          assert 'right eye' in item, 'The input should contain the right eye'
-          assert 'time' in item, 'The input should contain the time'
-          assert 'userId' in item, 'The input should contain the userId'
-          assert 'placeId' in item, 'The input should contain the placeId'
-          assert 'screenId' in item, 'The input should contain the screenId'
-          continue
         stats = model.fit(sampled)
         history['time'].append(stats['time'])
         for k in stats['losses'].keys():
@@ -118,7 +96,7 @@ def F(desc):
   return F
 
 def _trainer_from(args):
-  if args.trainer == 'default': return CModelTrainer
+  if args.trainer == 'default': return CInpaintingTrainer
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
 def averageModels(folder, model, noiseStd=0.0):
@@ -166,22 +144,20 @@ def main(args):
         # no augmentations by default
         pointsNoise=0.01, pointsDropout=0.0,
         eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+        targets=dict(keypoints=3, total=10),
       ),
-      targets=dict(
-        keypoints=3,
-        total=10
-      ),
+      keys=['clean', 'augmented'],
     ),
     sampler_class=CDataSamplerInpainting,
   )
-  model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
+  model = dict(timesteps=timesteps, stats=stats)
   if args.model is not None:
     model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
   if args.modelId is not None:
     model['model'] = args.modelId
 
   model = trainer(**model)
-  model._model.summary()
+#   model._model.summary()
 
   # find folders with the name "/test-*/"
   evalDatasets = [
@@ -194,6 +170,7 @@ def main(args):
   # wrapper for the evaluation function. It saves the model if it is better
   def evalWrapper(eval):
     def f(epoch, onlyImproved=False):
+      return
       nonlocal bestLoss, bestEpoch
       newLoss, losses = eval(onlyImproved=onlyImproved)
       if newLoss < bestLoss:
@@ -215,7 +192,7 @@ def f(epoch, onlyImproved=False):
     trainStep(
       desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
     )
-    model.save(folder, postfix='latest')
+    # model.save(folder, postfix='latest')
     eval(epoch)
 
     print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
@@ -243,11 +220,7 @@ def f(epoch, onlyImproved=False):
   parser.add_argument('--modelId', type=str)
   parser.add_argument(
     '--trainer', type=str, default='default',
-    choices=['default', 'diffusion']
-  )
-  parser.add_argument(
-    '--schedule', type=str, default=None,
-    help='JSON file with the scheduler parameters for sampling the training dataset'
+    choices=['default']
   )
   parser.add_argument('--debug', action='store_true')
   parser.add_argument('--noise', type=float, default=1e-4)
@@ -255,9 +228,6 @@ def f(epoch, onlyImproved=False):
     '--restarts', type=int, default=1,
     help='Number of times to restart the model reinitializing the weights'
   )
-  parser.add_argument(
-    '--with-enconders', default=False, action='store_true',
-  )
   parser.add_argument(
     '--sampling', type=str, default='uniform',
     choices=['uniform', 'as_is'],
diff --git a/scripts/train.py b/scripts/train.py
index 569bae7..a580517 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -107,13 +107,13 @@ def evaluate(onlyImproved=False):
   return evaluate
 
 def _modelTrainingLoop(model, dataset):
-  def F(desc, sampleParams):
+  def F(desc):
     history = defaultdict(list)
     # use the tqdm progress bar
     with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
       dataset.on_epoch_start()
       for _ in range(len(dataset)):
-        sampled = dataset.sample(**sampleParams)
+        sampled = dataset.sample()
         assert 2 == len(sampled), 'The dataset should return a tuple with the input and the output'
         X, Y = sampled
         assert 'clean' in X, 'The input should contain the clean data'
@@ -140,51 +140,6 @@ def F(desc, sampleParams):
     return
   return F
 
-def _defaultSchedule(args):
-  return lambda epoch: dict()
-
-def _schedule_from_json(args):
-  with open(args.schedule, 'r') as f:
-    schedule = json.load(f)
-
-  # schedule is a dictionary of dictionaries, where the keys are the epochs
-  # transform it into a sorted list of tuples (epoch, params)
-  for k, v in schedule.items():
-    v = [(int(epoch), p) for epoch, p in v.items()]
-    schedule[k] = sorted(v, key=lambda x: x[0])
-    continue
-
-  def F(epoch):
-    res = {}
-    for k, v in schedule.items():
-      assert isinstance(v, list), 'The schedule should be a list of parameters'
-      # find the first epoch that is less or equal to the current one
-      smallest = [i for i, (e, _) in enumerate(v) if e <= epoch]
-      if 0 == len(smallest): continue
-      smallest = smallest[-1]
-
-      startEpoch, p = v[smallest]
-      value = p
-      # p could be a dictionary or value
-      if isinstance(p, list) and (2 == len(p)):
-        assert smallest + 1 < len(v), 'The last epoch should be the last one'
-        minV, maxV = [float(x) for x in p]
-        nextEpoch, _ = v[smallest + 1]
-        # linearly interpolate between the values
-        value = minV + (maxV - minV) * (epoch - startEpoch) / (nextEpoch - startEpoch)
-        pass
-      
-      res[k] = float(value)
-      continue
-
-    if args.debug and res:
-      print('Parameters for epoch %d:' % (epoch, ))
-      for k, v in res.items():
-        print('  %s: %.5f' % (k, v))
-        continue
-    return res
-  return F
-
 def _trainer_from(args):
   if args.trainer == 'default': return CModelTrainer
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
@@ -214,11 +169,6 @@ def averageModels(folder, model, noiseStd=0.0):
 def main(args):
   timesteps = args.steps
   folder = os.path.join(args.folder, 'Data')
-  if args.schedule is None:
-    getSampleParams = _defaultSchedule(args)
-  else:
-    getSampleParams = _schedule_from_json(args)
-
   stats = None
   with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
     stats = json.load(f)
@@ -242,7 +192,7 @@ def main(args):
     ),
     sampler_class=CDataSampler
   )
-  model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
+  model = dict(timesteps=timesteps, stats=stats)
   if args.model is not None:
     model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
   if args.modelId is not None:
@@ -298,7 +248,6 @@ def performRandomSearch(epoch=0):
   for epoch in range(args.epochs):
     trainStep(
       desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
-      sampleParams=getSampleParams(epoch)
     )
     model.save(folder, postfix='latest')
     eval(epoch)
@@ -344,9 +293,6 @@ def performRandomSearch(epoch=0):
     '--restarts', type=int, default=1,
     help='Number of times to restart the model reinitializing the weights'
   )
-  parser.add_argument(
-    '--with-enconders', default=False, action='store_true',
-  )
   parser.add_argument(
     '--sampling', type=str, default='uniform',
     choices=['uniform', 'as_is'],

From e99e991001a25158498d922a59853d942b098d9e Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Fri, 18 Oct 2024 07:35:27 +0200
Subject: [PATCH 06/10] Can evaluate the inpainting model

---
 Core/CDataSamplerInpainting.py            |   7 +-
 Core/CDatasetLoader.py                    |  18 ++-
 Core/CInpaintingTrainer.py                |  64 +++++-----
 Core/CTestInpaintingLoader.py             |  26 ++++
 Core/CTestLoader.py                       |  26 +---
 scripts/create-test-dataset-inpainting.py | 148 ++++++++++++++++++++++
 scripts/make-blacklist.py                 | 103 ---------------
 scripts/train-reconstruction.py           |  58 +++------
 8 files changed, 241 insertions(+), 209 deletions(-)
 create mode 100644 Core/CTestInpaintingLoader.py
 create mode 100644 scripts/create-test-dataset-inpainting.py
 delete mode 100644 scripts/make-blacklist.py

diff --git a/Core/CDataSamplerInpainting.py b/Core/CDataSamplerInpainting.py
index ac8207c..58afd15 100644
--- a/Core/CDataSamplerInpainting.py
+++ b/Core/CDataSamplerInpainting.py
@@ -50,6 +50,7 @@ def sample(self, **kwargs):
         timesteps = kwargs.get('timesteps', None)
         N = kwargs.get('N', self._batchSize) // len(self._keys)
         indexes = []
+        added = False
         for _ in range(N):
             added = False
             while not added:
@@ -62,7 +63,7 @@ def sample(self, **kwargs):
                     indexes.extend(sampledSteps)
                     added = True
                 continue
-
+        if not added: return None, 0
         return self._indexes2XY(indexes, kwargs)
 
     def sampleById(self, idx, **kwargs):
@@ -97,7 +98,7 @@ def sampleByIds(self, ids, **kwargs):
 
         res = None
         if 0 < len(sampledSteps):
-            res = self._indexes2XY(sampledSteps, kwargs)
+            res, _ = self._indexes2XY(sampledSteps, kwargs)
         return res, rejected, accepted
 
     def _indexes2XY(self, indexesAndTime, kwargs):
@@ -216,7 +217,7 @@ def _indexes2XY(self, indexesAndTime, kwargs):
             assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
         for k, v in Y.items():
             assert B == v.shape[0], f'Invalid batch size for Y[{k}]: {v.shape[0]} != {B} ({v.shape})'
-        return (X, Y)
+        return (X, Y), B
     
     def merge(self, samples, expected_batch_size):
         X = {}
diff --git a/Core/CDatasetLoader.py b/Core/CDatasetLoader.py
index 59c62f2..f827bf7 100644
--- a/Core/CDatasetLoader.py
+++ b/Core/CDatasetLoader.py
@@ -43,6 +43,8 @@ def __init__(self, folder, samplerArgs, sampling, stats, sampler_class):
       i: len(ds.validSamples())
       for i, ds in enumerate(self._datasets)
     }
+    # ignore datasets with no valid samples
+    validSamples = {k: v for k, v in validSamples.items() if 0 < v}
     dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
     # create an array of dataset indices to sample from
     sampling = ESampling(sampling)
@@ -101,13 +103,15 @@ def sample(self, **kwargs):
     samples = []
     totalSamples = 0
     # find the datasets ids and the number of samples to take from each dataset
-    datasetIds, counts = self._getBatchStats(batchSize)
-    for datasetId, N in zip(datasetIds, counts):
-      dataset = self._datasets[datasetId]
-      sampled = dataset.sample(N=N, **kwargs)
-      samples.append(sampled)
-      totalSamples += N
-      continue
+    while totalSamples < batchSize:
+      datasetIds, counts = self._getBatchStats(batchSize - totalSamples)
+      for datasetId, N in zip(datasetIds, counts):
+        dataset = self._datasets[datasetId]
+        sampled, N = dataset.sample(N=N, **kwargs)
+        if 0 < N:
+          samples.append(sampled)
+        totalSamples += N
+        continue
     
     first_dataset = self._datasets[0]
     return first_dataset.merge(samples, batchSize)
diff --git a/Core/CInpaintingTrainer.py b/Core/CInpaintingTrainer.py
index 0e54533..4e25776 100644
--- a/Core/CInpaintingTrainer.py
+++ b/Core/CInpaintingTrainer.py
@@ -45,32 +45,36 @@ def __init__(self, timesteps, model='simple', KP=5, **kwargs):
   def compile(self):
     self._optimizer = NNU.createOptimizer()
 
+  def _calcLoss(self, x, y, training):
+    losses = {}
+    x = self._model.replaceByEmbeddings(x)
+    latents = self._encoder(x, training=training)['latent']
+    decoderArgs = {
+      'keyPoints': latents,
+      'time': y['time'],
+      'userId': x['userId'],
+      'placeId': x['placeId'],
+      'screenId': x['screenId'],
+    }
+    predictions = self._decoder(decoderArgs, training=training)
+    losses = {}
+    for k in predictions.keys():
+      pred = predictions[k]
+      gt = y[k]
+      tf.assert_equal(tf.shape(pred), tf.shape(gt))
+      loss = tf.losses.mse(gt, pred)
+      losses[f"loss-{k}"] = tf.reduce_mean(loss)
+      
+    # calculate total loss and final loss
+    losses['loss'] = sum(losses.values())
+    return losses, losses['loss']
+  
   def _trainStep(self, Data):
     print('Instantiate _trainStep')
     ###############
     x, y = Data
-    losses = {}
     with tf.GradientTape() as tape:
-      x = self._model.replaceByEmbeddings(x)
-      latents = self._encoder(x, training=True)['latent']
-      decoderArgs = {
-        'keyPoints': latents,
-        'time': y['time'],
-        'userId': x['userId'],
-        'placeId': x['placeId'],
-        'screenId': x['screenId'],
-      }
-      predictions = self._decoder(decoderArgs, training=True)
-      losses = {}
-      for k in predictions.keys():
-        pred = predictions[k]
-        gt = y[k]
-        tf.assert_equal(tf.shape(pred), tf.shape(gt))
-        loss = tf.losses.mse(gt, pred)
-        losses[f"loss-{k}"] = tf.reduce_mean(loss)
-        
-      # calculate total loss and final loss
-      losses['loss'] = loss = sum(losses.values())
+      losses, loss = self._calcLoss(x, y, training=True)
   
     self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
     ###############
@@ -84,18 +88,10 @@ def fit(self, data):
   
   def _eval(self, xy):
     print('Instantiate _eval')
-    x, (y,) = xy
-    x = self._replaceByEmbeddings(x)
-    y = y[:, :, 0]
-    predictions = self._model(x, training=False)
-    points = predictions['result'][:, :, :]
-    tf.assert_equal(tf.shape(points), tf.shape(y))
-
-    loss = self._pointLoss(y, points)
-    tf.assert_equal(tf.shape(loss), tf.shape(y)[:2])
-    _, dist = NNU.normVec(points - y)
-    return loss, points, dist
+    x, y = xy
+    losses, loss = self._calcLoss(x, y, training=False)
+    return loss
 
   def eval(self, data):
-    loss, sampled, dist = self._eval(data)
-    return loss.numpy(), sampled.numpy(), dist.numpy()
\ No newline at end of file
+    loss = self._eval(data)
+    return loss.numpy()
\ No newline at end of file
diff --git a/Core/CTestInpaintingLoader.py b/Core/CTestInpaintingLoader.py
new file mode 100644
index 0000000..e99291e
--- /dev/null
+++ b/Core/CTestInpaintingLoader.py
@@ -0,0 +1,26 @@
+import tensorflow as tf
+import numpy as np
+import os, glob
+from functools import lru_cache
+
+class CTestInpaintingLoader(tf.keras.utils.Sequence):
+  def __init__(self, testFolder):
+    self._batchesNpz = [
+      f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
+    ]
+    self.on_epoch_end()
+    return
+  
+  def on_epoch_end(self):
+    return
+
+  def __len__(self):
+    return len(self._batchesNpz)
+  
+  def __getitem__(self, idx):
+    with np.load(self._batchesNpz[idx]) as res:
+      res = {k: v for k, v in res.items()}
+      
+    X = {k.replace('X_', ''): v for k, v in res.items() if 'X_' in k}
+    Y = {k.replace('Y_', ''): v for k, v in res.items() if 'Y_' in k}
+    return(X, Y)
\ No newline at end of file
diff --git a/Core/CTestLoader.py b/Core/CTestLoader.py
index f3bc23d..afecbca 100644
--- a/Core/CTestLoader.py
+++ b/Core/CTestLoader.py
@@ -5,25 +5,12 @@
 
 class CTestLoader(tf.keras.utils.Sequence):
   def __init__(self, testFolder):
-    self._folder = testFolder
     self._batchesNpz = [
       f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
     ]
     self.on_epoch_end()
     return
   
-  @property
-  def folder(self):
-    return self._folder
-  
-  @lru_cache(maxsize=1)
-  def parametersIDs(self):
-    batch, _ = self[0]
-    userId = batch['userId'][0, 0, 0]
-    placeId = batch['placeId'][0, 0, 0]
-    screenId = batch['screenId'][0, 0, 0]
-    return userId, placeId, screenId
-    
   def on_epoch_end(self):
     return
 
@@ -35,15 +22,4 @@ def __getitem__(self, idx):
       res = {k: v for k, v in res.items()}
       
     Y = res.pop('y')
-    return(res, (Y, ))
-
-if __name__ == '__main__':
-  folder = os.path.dirname(__file__)
-  ds = CTestLoader(os.path.join(folder, 'test'))
-  print(len(ds))
-  batch, (y,) = ds[0]
-  for k, v in batch.items():
-    print(k, v.shape)
-  print()
-  print(y.shape)
-  pass
\ No newline at end of file
+    return(res, (Y, ))
\ No newline at end of file
diff --git a/scripts/create-test-dataset-inpainting.py b/scripts/create-test-dataset-inpainting.py
new file mode 100644
index 0000000..ffcf02c
--- /dev/null
+++ b/scripts/create-test-dataset-inpainting.py
@@ -0,0 +1,148 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+import numpy as np
+import Core.Utils as Utils
+from Core.CSamplesStorage import CSamplesStorage
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from collections import defaultdict
+import glob
+import json
+import shutil
+import tensorflow as tf
+
+BATCH_SIZE = 128 * 4
+
+def samplesStream(params, take, filename, stats):
+  if not isinstance(take, list): take = [take]
+  # filename is "{placeId}/{userId}/{screenId}/train.npz"
+  # extract the placeId, userId, and screenId
+  parts = os.path.split(filename)[0].split(os.path.sep)
+  placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
+  # use the stats to get the numeric values of the placeId, userId, and screenId  
+  ds = CDataSamplerInpainting(
+    CSamplesStorage(
+      placeId=stats['placeId'].index(placeId),
+      userId=stats['userId'].index(userId),
+      screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
+    ),
+    defaults=params, 
+    batch_size=BATCH_SIZE, minFrames=params['timesteps'],
+    keys=take
+  )
+  ds.addBlock(Utils.datasetFrom(filename))
+  
+  N = ds.totalSamples
+  for i in range(0, N, BATCH_SIZE):
+    indices = list(range(i, min(i + BATCH_SIZE, N)))
+    batch, rejected, accepted = ds.sampleByIds(indices)
+    if batch is None: continue
+
+    # main batch
+    x, y = batch
+    for idx in range(len(x['points'])):
+      resX = {}
+      for k, v in x.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resX[f'X_{k}'] = item
+        continue
+
+      resY = {}
+      for k, v in y.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resY[f'Y_{k}'] = item
+        continue
+        
+      yield dict(**resX, **resY)
+      continue
+    continue
+  return
+
+def batches(*params):
+  data = defaultdict(list)
+  for sample in samplesStream(*params):
+    for k, v in sample.items():
+      data[k].append(v)
+      continue
+
+    if BATCH_SIZE <= len(data['X_points']):
+      yield data
+      data = defaultdict(list)
+    continue
+
+  if 0 < len(data['X_points']):
+    # copy data to match batch size
+    for k, v in data.items():
+      while len(v) < BATCH_SIZE: v.extend(v)
+      data[k] = v[:BATCH_SIZE]
+      continue
+    yield data
+  return
+############################################
+def generateTestDataset(params, filename, stats, outputFolder):
+  # generate test dataset
+  ONE_MB = 1024 * 1024
+  totalSize = 0
+  if not os.path.exists(outputFolder):
+    os.makedirs(outputFolder, exist_ok=True)
+  for bIndex, batch in enumerate(batches(params, ['clean'], filename, stats)):
+    fname = os.path.join(outputFolder, 'test-%d.npz' % bIndex)
+    # concatenate all arrays
+    batch = {k: np.concatenate(v, axis=0) for k, v in batch.items()}
+    np.savez_compressed(fname, **batch)
+    # get fname size
+    size = os.path.getsize(fname)
+    totalSize += size
+    print('%d | Size: %.1f MB | Total: %.1f MB' % (bIndex + 1, size / ONE_MB, totalSize / ONE_MB))
+    continue
+  print('Done')
+  return
+
+def main(args):
+  PARAMS = [
+    dict(      
+      timesteps=args.steps,
+      stepsSampling='uniform',
+      # no augmentations by default
+      pointsNoise=0.01, pointsDropout=0.0,
+      eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+      targets=dict(keypoints=3, total=10),
+    ),
+  ]
+  folder = os.path.join(ROOT_FOLDER, 'Data', 'remote')
+
+  stats = None
+  with open(os.path.join(folder, 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  # remove all content from the output folder
+  shutil.rmtree(args.output, ignore_errors=True)
+  # recursively find the train file
+  trainFilename = glob.glob(os.path.join(folder, '**', 'test.npz'), recursive=True)
+  print('Found test files:', len(trainFilename))
+  for idx, filename in enumerate(trainFilename):
+    print('Processing', filename)
+    for params in PARAMS:
+      targetFolder = os.path.join(args.output, 'test-%d' % idx)
+      generateTestDataset(params, filename, stats, outputFolder=targetFolder)
+      continue
+  return
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--steps', type=int, default=5, help='Number of timesteps')
+  parser.add_argument('--batch-size', type=int, default=512, help='Batch size of the test dataset')
+  parser.add_argument(
+    '--output', type=str, help='Output folder',
+    default=os.path.join(ROOT_FOLDER, 'Data', 'test-inpainting')
+  )
+  args = parser.parse_args()
+  BATCH_SIZE = args.batch_size # TODO: fix this hack
+  main(args)
+  pass
\ No newline at end of file
diff --git a/scripts/make-blacklist.py b/scripts/make-blacklist.py
deleted file mode 100644
index ed13c0f..0000000
--- a/scripts/make-blacklist.py
+++ /dev/null
@@ -1,103 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-.
-'''
-This script performs the following steps:
-- Load the best model from the Data folder
-- Load the test datasets from the Data/test-main folder
-- Evaluate the model on the test datasets
-- Add each test dataset to blacklists if the model mean loss is greater than the threshold
-'''
-# TODO: add the W&B integration
-import argparse, os, sys
-# add the root folder of the project to the path
-ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
-sys.path.append(ROOT_FOLDER)
-
-import numpy as np
-from Core.CDatasetLoader import CDatasetLoader
-from Core.CTestLoader import CTestLoader
-from collections import defaultdict
-import time
-from Core.CModelTrainer import CModelTrainer
-import tqdm
-import json
-import glob
-
-def _eval(dataset, model):
-  T = time.time()
-  # evaluate the model on the val dataset
-  losses = []
-  predDist = []
-  for batchId in range(len(dataset)):
-    batch = dataset[batchId]
-    loss, _, dist = model.eval(batch)
-    predDist.append(dist)
-    losses.append(loss)
-    continue
-  
-  loss = np.mean(losses)
-  dist = np.mean(predDist)
-  T = time.time() - T
-  return loss, dist, T
-
-def evaluate(dataset, model):
-  loss, dist, T = _eval(dataset, model)
-  print('Test | %.2f sec | Loss: %.5f. Distance: %.5f' % (
-    T, loss, dist,
-  ))
-  return loss, dist
-
-def main(args):
-  timesteps = args.steps
-  folder = args.folder
-  stats = None
-  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
-    stats = json.load(f)
-    
-  oldBadDatasets = [] # list of tuples (userId, placeId, screenId) strings
-  if os.path.exists(os.path.join(folder, 'blacklist.json')):
-    with open(os.path.join(folder, 'blacklist.json'), 'r') as f:
-      oldBadDatasets = json.load(f)
-    pass
-
-  model = dict(timesteps=timesteps, stats=stats, use_encoders=False)
-  assert args.model is not None, 'The model should be specified'
-  if args.model is not None:
-    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=True)
-
-  model = CModelTrainer(**model)
-  # find folders with the name "/test-*/"
-  badDatasets = []
-  for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/')):
-    evalDataset = CTestLoader(nm)
-    loss, dist = evaluate(evalDataset, model)
-    if args.threshold < dist:
-      badDatasets.append(evalDataset.parametersIDs())
-    continue
-  # convert indices to the strings
-  res = []
-  for userId, placeId, screenId in badDatasets:
-    userId = stats['userId'][userId]
-    placeId = stats['placeId'][placeId]
-    screenId = stats['screenId'][screenId]
-    res.append((userId, placeId, screenId))
-    continue
-  res = oldBadDatasets + res # add the old blacklisted datasets
-  print('Blacklisted datasets:')
-  print(json.dumps(res, indent=2))
-  # save the blacklisted datasets
-  with open(os.path.join(folder, 'blacklist.json'), 'w') as f:
-    json.dump(res, f, indent=2)
-  return
-
-if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--steps', type=int, default=5)
-  parser.add_argument('--model', type=str, default='best')
-  parser.add_argument('--folder', type=str, default=os.path.join(ROOT_FOLDER, 'Data'))
-  parser.add_argument(
-    '--threshold', type=float, required=True,
-  )
-
-  main(parser.parse_args())
-  pass
\ No newline at end of file
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
index 893bcf3..f91b4f7 100644
--- a/scripts/train-reconstruction.py
+++ b/scripts/train-reconstruction.py
@@ -9,7 +9,7 @@
 import numpy as np
 from Core.CDataSamplerInpainting import CDataSamplerInpainting
 from Core.CDatasetLoader import CDatasetLoader
-from Core.CTestLoader import CTestLoader
+from Core.CTestInpaintingLoader import CTestInpaintingLoader
 from collections import defaultdict
 import time
 from Core.CInpaintingTrainer import CInpaintingTrainer
@@ -20,59 +20,44 @@
 def _eval(dataset, model):
   T = time.time()
   # evaluate the model on the val dataset
-  lossPerSample = {'loss': [], 'pos': []}
-  predV = []
-  predDist = []
-  Y = []
+  loss = []
   for batchId in range(len(dataset)):
-    _, (y,) = batch = dataset[batchId]
-    loss, predP, dist = model.eval(batch)
-    predV.append(predP)
-    predDist.append(dist)
-    Y.append(y[:, -1, 0])
-    for l, pos in zip(loss, y[:, -1]):
-      lossPerSample['loss'].append(l)
-      lossPerSample['pos'].append(pos[0])
-      continue
+    batch = dataset[batchId]
+    loss_value = model.eval(batch)
+    loss.append(loss_value)
     continue
 
-  loss = np.mean(lossPerSample['loss'])
-  dist = np.mean(predDist)
+  loss = np.mean(loss)
   T = time.time() - T
-  return loss, dist, T
+  return loss, T
 
 def evaluator(datasets, model, folder, args):
   losses = [np.inf] * len(datasets) # initialize with infinity
-  dists = [np.inf] * len(datasets) # initialize with infinity
   def evaluate(onlyImproved=False):
     totalLoss = []
-    totalDist = []
     losses_dist = []
     for i, dataset in enumerate(datasets):
-      loss, dist, T = _eval(dataset, model)
-      losses_dist.append((loss, losses[i], dist, dists[i]))
+      loss, T = _eval(dataset, model)
       isImproved = loss < losses[i]
       if (not onlyImproved) or isImproved:
-        print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (
-          i + 1, len(datasets), T, loss, losses[i], dist, dists[i]
+        print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f).' % (
+          i + 1, len(datasets), T, loss, losses[i],
         ))
       if isImproved:
-        print('Test %d / %d | Improved %.5f => %.5f, Distance: %.5f => %.5f' % (
-          i + 1, len(datasets), losses[i], loss, dists[i], dist
+        print('Test %d / %d | Improved %.5f => %.5f,' % (
+          i + 1, len(datasets), losses[i], loss,
         ))
-        model.save(folder, postfix='best-%d' % i) # save the model separately
+        # model.save(folder, postfix='best-%d' % i) # save the model separately
         losses[i] = loss
         pass
 
-      dists[i] = min(dist, dists[i]) # track the best distance
       totalLoss.append(loss)
-      totalDist.append(dist)
       continue
     if not onlyImproved:
-      print('Mean loss: %.5f | Mean distance: %.5f' % (
-        np.mean(totalLoss), np.mean(totalDist)
+      print('Mean loss: %.5f' % (
+        np.mean(totalLoss)
       ))
-    return np.mean(totalLoss), losses_dist
+    return np.mean(totalLoss)
   return evaluate
 
 def _modelTrainingLoop(model, dataset):
@@ -161,18 +146,17 @@ def main(args):
 
   # find folders with the name "/test-*/"
   evalDatasets = [
-    # CTestLoader(nm)
-    # for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/'))
+    CTestInpaintingLoader(nm)
+    for nm in glob.glob(os.path.join(folder, 'test-inpainting', 'test-*/'))
   ]
   eval = evaluator(evalDatasets, model, folder, args)
-  bestLoss, _ = eval() # evaluate loaded model
+  bestLoss = eval() # evaluate loaded model
   bestEpoch = 0
   # wrapper for the evaluation function. It saves the model if it is better
   def evalWrapper(eval):
     def f(epoch, onlyImproved=False):
-      return
       nonlocal bestLoss, bestEpoch
-      newLoss, losses = eval(onlyImproved=onlyImproved)
+      newLoss = eval(onlyImproved=onlyImproved)
       if newLoss < bestLoss:
         print('Improved %.5f => %.5f' % (bestLoss, newLoss))
         if onlyImproved: #details
@@ -182,7 +166,7 @@ def f(epoch, onlyImproved=False):
           print('-' * 80)
         bestLoss = newLoss
         bestEpoch = epoch
-        model.save(folder, postfix='best')
+        # model.save(folder, postfix='best')
       return
     return f
   

From 2bfcadaf373b8c178b5117e20ecac2f16a92fd65 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Sat, 19 Oct 2024 21:37:48 +0200
Subject: [PATCH 07/10] wip

---
 Core/CBaseModel.py                        |  2 +-
 Core/CDatasetLoader.py                    | 55 +++++++-------
 Core/CInpaintingTrainer.py                | 11 ++-
 Core/CTestInpaintingLoader.py             |  8 ++
 Core/CTestLoader.py                       | 10 ++-
 Core/Utils.py                             | 23 +++++-
 NN/networks.py                            |  2 +-
 scripts/check-dataset.py                  | 92 +++++++++++++++++++++++
 scripts/create-test-dataset-inpainting.py | 66 ++++++++--------
 scripts/train-reconstruction.py           | 71 +++++------------
 10 files changed, 217 insertions(+), 123 deletions(-)
 create mode 100644 scripts/check-dataset.py

diff --git a/Core/CBaseModel.py b/Core/CBaseModel.py
index 4bb0204..7be84d0 100644
--- a/Core/CBaseModel.py
+++ b/Core/CBaseModel.py
@@ -22,7 +22,7 @@ def replaceByEmbeddings(self, data):
 
   def _modelFilename(self, folder, postfix=''):
     postfix = '-' + postfix if postfix else ''
-    return os.path.join(folder, '%s%s.h5' % (self._modelID, postfix))
+    return os.path.join(folder, '%s%s.h5' % (self._model, postfix))
   
   def save(self, folder=None, postfix=''):
     path = self._modelFilename(folder, postfix)
diff --git a/Core/CDatasetLoader.py b/Core/CDatasetLoader.py
index f827bf7..d0c3733 100644
--- a/Core/CDatasetLoader.py
+++ b/Core/CDatasetLoader.py
@@ -1,9 +1,7 @@
 import Core.Utils as Utils
-import os, glob
+import os
 from Core.CSamplesStorage import CSamplesStorage
-from Core.CDataSampler import CDataSampler
 import numpy as np
-import tensorflow as tf
 from enum import Enum
 
 class ESampling(Enum):
@@ -11,40 +9,39 @@ class ESampling(Enum):
   UNIFORM = 'uniform'
   
 class CDatasetLoader:
-  def __init__(self, folder, samplerArgs, sampling, stats, sampler_class):
-    # recursively find all 'train.npz' files
-    trainFiles = glob.glob(os.path.join(folder, '**', 'train.npz'), recursive=True)
-    if 0 == len(trainFiles):
-      raise Exception('No training dataset found in "%s"' % (folder, ))
-      exit(1)
-    
-    print('Found %d training datasets' % (len(trainFiles), ))
-
+  def __init__(self, folder, samplerArgs, sampling, stats, sampler_class, test_folders):
     self._datasets = []
-    for trainFile in trainFiles:
-      print('Loading %s' % (trainFile, ))
-      # extract the placeId, userId, and screenId
-      parts = os.path.split(trainFile)[0].split(os.path.sep)
-      placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
-      ds = sampler_class(
-        CSamplesStorage(
-          placeId=stats['placeId'].index(placeId),
-          userId=stats['userId'].index(userId),
-          screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
-        ),
-        **samplerArgs
-      )
-      ds.addBlock(Utils.datasetFrom(trainFile))
-      self._datasets.append(ds)
-      continue
+    for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
+      (place_id_index, user_id_index, screen_id_index) = ID
+      for test_folder in test_folders:
+        dataset = os.path.join(datasetFolder, test_folder)
+        if not os.path.exists(dataset):
+          continue
+        print('Loading %s' % (dataset, ))
+        print(f'ID: {ID}. Index: {1 + len(self._datasets)}')
+        ds = sampler_class(
+          CSamplesStorage(
+            placeId=place_id_index,
+            userId=user_id_index,
+            screenId=screen_id_index,
+          ),
+          **samplerArgs
+        )
+        ds.addBlock(Utils.datasetFrom(dataset))
+        self._datasets.append(ds)
+
+    if 0 == len(self._datasets):
+      raise Exception('No training dataset found in "%s"' % (folder, ))
     
-    print('Loaded %d datasets' % (len(self._datasets), ))
     validSamples = {
       i: len(ds.validSamples())
       for i, ds in enumerate(self._datasets)
     }
     # ignore datasets with no valid samples
     validSamples = {k: v for k, v in validSamples.items() if 0 < v}
+
+    print('Loaded %d datasets with %d valid samples' % (len(self._datasets), sum(validSamples.values())))
+
     dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
     # create an array of dataset indices to sample from
     sampling = ESampling(sampling)
diff --git a/Core/CInpaintingTrainer.py b/Core/CInpaintingTrainer.py
index 4e25776..2d457ae 100644
--- a/Core/CInpaintingTrainer.py
+++ b/Core/CInpaintingTrainer.py
@@ -40,6 +40,9 @@ def __init__(self, timesteps, model='simple', KP=5, **kwargs):
       self._eval,
       input_signature=[specification]
     )
+
+    if 'weights' in kwargs:
+      self.load(**kwargs['weights'])
     return
   
   def compile(self):
@@ -94,4 +97,10 @@ def _eval(self, xy):
 
   def eval(self, data):
     loss = self._eval(data)
-    return loss.numpy()
\ No newline at end of file
+    return loss.numpy()
+    
+  def save(self, folder=None, postfix=''):
+    self._model.save(folder=folder, postfix=postfix)
+
+  def load(self, folder=None, postfix='', embeddings=False):
+    self._model.load(folder=folder, postfix=postfix, embeddings=embeddings)
\ No newline at end of file
diff --git a/Core/CTestInpaintingLoader.py b/Core/CTestInpaintingLoader.py
index e99291e..f67a86b 100644
--- a/Core/CTestInpaintingLoader.py
+++ b/Core/CTestInpaintingLoader.py
@@ -11,6 +11,14 @@ def __init__(self, testFolder):
     self.on_epoch_end()
     return
   
+  @lru_cache(maxsize=1)
+  def parametersIDs(self):
+    batch, _ = self[0]
+    userId = batch['userId'][0, 0, 0]
+    placeId = batch['placeId'][0, 0, 0]
+    screenId = batch['screenId'][0, 0, 0]
+    return placeId, userId, screenId
+    
   def on_epoch_end(self):
     return
 
diff --git a/Core/CTestLoader.py b/Core/CTestLoader.py
index afecbca..be11247 100644
--- a/Core/CTestLoader.py
+++ b/Core/CTestLoader.py
@@ -10,7 +10,15 @@ def __init__(self, testFolder):
     ]
     self.on_epoch_end()
     return
-  
+
+  @lru_cache(maxsize=1)
+  def parametersIDs(self):
+    batch, _ = self[0]
+    userId = batch['userId'][0, 0, 0]
+    placeId = batch['placeId'][0, 0, 0]
+    screenId = batch['screenId'][0, 0, 0]
+    return placeId, userId, screenId
+      
   def on_epoch_end(self):
     return
 
diff --git a/Core/Utils.py b/Core/Utils.py
index df78942..2446528 100644
--- a/Core/Utils.py
+++ b/Core/Utils.py
@@ -297,4 +297,25 @@ def countSamplesIn(folder):
     with np.load(fn) as data:
       res += len(data['time'])
     continue
-  return res
\ No newline at end of file
+  return res
+
+def dataset_from_stats(stats, folder):
+  userId = stats['userId']
+  placeId = stats['placeId']
+  screenId = stats['screenId']
+  # screenId is a concatenation of placeId and screenId, to make it unique pair
+  PlaceAndScreenId = [x.split('/') for x in screenId]
+
+  blackList = set(stats.get('blacklist', []))
+  known = set([tuple(x) for x in blackList])
+  for screen_id_index, (place_id, screen_id) in enumerate(PlaceAndScreenId):
+    place_id_index = placeId.index(place_id)
+    # find user_id among all
+    for user_id_index, user_id in enumerate(userId):
+      datasetFolder = os.path.join(folder, place_id, user_id, screen_id)
+      if not os.path.exists(datasetFolder): continue
+      ID = (place_id_index, user_id_index, screen_id_index)
+      if ID in known: continue
+      known.add(ID)
+
+      yield (datasetFolder, ID)
\ No newline at end of file
diff --git a/NN/networks.py b/NN/networks.py
index 9bfed75..30afcea 100644
--- a/NN/networks.py
+++ b/NN/networks.py
@@ -351,7 +351,7 @@ def transformLatents(x):
   # two eyes
   eyesN = eyeSize * eyeSize
   eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
-  eyes = L.Dense(eyesN * 2)(eyes)
+  eyes = L.Dense(eyesN * 2, 'sigmoid')(eyes)
   eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
   # face points
   face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)
diff --git a/scripts/check-dataset.py b/scripts/check-dataset.py
new file mode 100644
index 0000000..56adacf
--- /dev/null
+++ b/scripts/check-dataset.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+'''
+This script is load one by one the datasets and check how many unique samples are there
+'''
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from Core.CDataSampler import CDataSampler
+import Core.Utils as Utils
+import json
+from Core.CSamplesStorage import CSamplesStorage
+
+def samplesStream(params, filename, ID, batch_size, is_inpainting):
+  placeId, userId, screenId = ID
+  storage = CSamplesStorage(placeId=placeId, userId=userId, screenId=screenId)
+  if is_inpainting:
+    ds = CDataSamplerInpainting(
+      storage,
+      defaults=params, 
+      batch_size=batch_size, minFrames=params['timesteps'],
+      keys=['clean']
+    )
+  else:
+    ds = CDataSampler(
+      storage,
+      defaults=params, 
+      batch_size=batch_size, minFrames=params['timesteps'],
+    )
+  ds.addBlock(Utils.datasetFrom(filename))
+  
+  N = ds.totalSamples
+  for i in range(0, N, batch_size):
+    indices = list(range(i, min(i + batch_size, N)))
+    batch, rejected, accepted = ds.sampleByIds(indices)
+    if batch is None: continue
+
+    # main batch
+    x, y = batch
+    if not is_inpainting:
+      x = x['clean']
+    for idx in range(len(x['points'])):
+      yield idx
+  return
+
+def main(args):
+  params = dict(
+    timesteps=args.steps,
+    stepsSampling='uniform',
+    # no augmentations by default
+    pointsNoise=0.0, pointsDropout=0.0,
+    eyesDropout=0.0, eyesAdditiveNoise=0.0, brightnessFactor=1.0, lightBlobFactor=1.0,
+    targets=dict(keypoints=3, total=10),
+  )
+  folder = os.path.join(args.folder, 'Data', 'remote')
+
+  stats = None
+  with open(os.path.join(folder, 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  # enable all disabled datasets
+  stats['blacklist'] = []
+  for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
+    trainFile = os.path.join(datasetFolder, 'train.npz')
+    if not os.path.exists(trainFile):
+      continue
+    print('Processing', trainFile)
+
+    stream = samplesStream(params, trainFile, ID=ID, batch_size=64, is_inpainting=args.inpainting)
+    samplesN = 0
+    for _ in stream:
+      samplesN += 1
+      continue
+    print(f'Dataset has {samplesN} valid samples')
+    if samplesN <= args.min_samples:
+      print(f'Warning: dataset has less or equal to {args.min_samples} samples and will be disabled')
+      stats['blacklist'].append(ID)
+
+  with open(os.path.join(folder, 'stats.json'), 'w') as f:
+    json.dump(stats, f, indent=2, sort_keys=True, default=str)
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--steps', type=int, default=5)
+  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+  parser.add_argument('--min-samples', type=int, default=0)
+  parser.add_argument('--inpainting', action='store_true', default=False)
+  main(parser.parse_args())
+  pass
\ No newline at end of file
diff --git a/scripts/create-test-dataset-inpainting.py b/scripts/create-test-dataset-inpainting.py
index ffcf02c..2fef601 100644
--- a/scripts/create-test-dataset-inpainting.py
+++ b/scripts/create-test-dataset-inpainting.py
@@ -10,35 +10,28 @@
 from Core.CSamplesStorage import CSamplesStorage
 from Core.CDataSamplerInpainting import CDataSamplerInpainting
 from collections import defaultdict
-import glob
 import json
-import shutil
 import tensorflow as tf
 
-BATCH_SIZE = 128 * 4
-
-def samplesStream(params, take, filename, stats):
+def samplesStream(params, take, filename, ID, batch_size):
   if not isinstance(take, list): take = [take]
-  # filename is "{placeId}/{userId}/{screenId}/train.npz"
-  # extract the placeId, userId, and screenId
-  parts = os.path.split(filename)[0].split(os.path.sep)
-  placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
+  placeId, userId, screenId = ID
   # use the stats to get the numeric values of the placeId, userId, and screenId  
   ds = CDataSamplerInpainting(
     CSamplesStorage(
-      placeId=stats['placeId'].index(placeId),
-      userId=stats['userId'].index(userId),
-      screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
+      placeId=placeId,
+      userId=userId,
+      screenId=screenId,
     ),
     defaults=params, 
-    batch_size=BATCH_SIZE, minFrames=params['timesteps'],
+    batch_size=batch_size, minFrames=params['timesteps'],
     keys=take
   )
   ds.addBlock(Utils.datasetFrom(filename))
   
   N = ds.totalSamples
-  for i in range(0, N, BATCH_SIZE):
-    indices = list(range(i, min(i + BATCH_SIZE, N)))
+  for i in range(0, N, batch_size):
+    indices = list(range(i, min(i + batch_size, N)))
     batch, rejected, accepted = ds.sampleByIds(indices)
     if batch is None: continue
 
@@ -64,14 +57,14 @@ def samplesStream(params, take, filename, stats):
     continue
   return
 
-def batches(*params):
+def batches(stream, batch_size):
   data = defaultdict(list)
-  for sample in samplesStream(*params):
+  for sample in stream:
     for k, v in sample.items():
       data[k].append(v)
       continue
 
-    if BATCH_SIZE <= len(data['X_points']):
+    if batch_size <= len(data['X_points']):
       yield data
       data = defaultdict(list)
     continue
@@ -79,19 +72,19 @@ def batches(*params):
   if 0 < len(data['X_points']):
     # copy data to match batch size
     for k, v in data.items():
-      while len(v) < BATCH_SIZE: v.extend(v)
-      data[k] = v[:BATCH_SIZE]
+      while len(v) < batch_size: v.extend(v)
+      data[k] = v[:batch_size]
       continue
     yield data
   return
 ############################################
-def generateTestDataset(params, filename, stats, outputFolder):
+def generateTestDataset(outputFolder, stream):
   # generate test dataset
   ONE_MB = 1024 * 1024
   totalSize = 0
   if not os.path.exists(outputFolder):
     os.makedirs(outputFolder, exist_ok=True)
-  for bIndex, batch in enumerate(batches(params, ['clean'], filename, stats)):
+  for bIndex, batch in enumerate(stream):
     fname = os.path.join(outputFolder, 'test-%d.npz' % bIndex)
     # concatenate all arrays
     batch = {k: np.concatenate(v, axis=0) for k, v in batch.items()}
@@ -121,28 +114,29 @@ def main(args):
   with open(os.path.join(folder, 'stats.json'), 'r') as f:
     stats = json.load(f)
 
-  # remove all content from the output folder
-  shutil.rmtree(args.output, ignore_errors=True)
-  # recursively find the train file
-  trainFilename = glob.glob(os.path.join(folder, '**', 'test.npz'), recursive=True)
-  print('Found test files:', len(trainFilename))
-  for idx, filename in enumerate(trainFilename):
-    print('Processing', filename)
-    for params in PARAMS:
-      targetFolder = os.path.join(args.output, 'test-%d' % idx)
-      generateTestDataset(params, filename, stats, outputFolder=targetFolder)
+  for datasetFolder, (place_id_index, user_id_index, screen_id_index) in Utils.dataset_from_stats(stats, folder):
+    trainFile = os.path.join(datasetFolder, 'test.npz')
+    if not os.path.exists(trainFile):
       continue
-  return
+    print('Processing', trainFile)
+
+    for i, params in enumerate(PARAMS):
+      output = args.output
+      if 0 < i: output += '-%d' % i
+      targetFolder = os.path.join(datasetFolder, output)
+      ID = (place_id_index, user_id_index, screen_id_index)
+      stream = samplesStream(params, ['clean'], trainFile, ID=ID, batch_size=args.batch_size)
+      stream = batches(stream, batch_size=args.batch_size)
+      generateTestDataset(outputFolder=targetFolder, stream=stream)
 
 if __name__ == '__main__':
   parser = argparse.ArgumentParser()
   parser.add_argument('--steps', type=int, default=5, help='Number of timesteps')
   parser.add_argument('--batch-size', type=int, default=512, help='Batch size of the test dataset')
   parser.add_argument(
-    '--output', type=str, help='Output folder',
-    default=os.path.join(ROOT_FOLDER, 'Data', 'test-inpainting')
+    '--output', type=str, help='Output folder name',
+    default='test-inpainting'
   )
   args = parser.parse_args()
-  BATCH_SIZE = args.batch_size # TODO: fix this hack
   main(args)
   pass
\ No newline at end of file
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
index f91b4f7..6a6d6d1 100644
--- a/scripts/train-reconstruction.py
+++ b/scripts/train-reconstruction.py
@@ -10,12 +10,12 @@
 from Core.CDataSamplerInpainting import CDataSamplerInpainting
 from Core.CDatasetLoader import CDatasetLoader
 from Core.CTestInpaintingLoader import CTestInpaintingLoader
+import Core.Utils as Utils
 from collections import defaultdict
 import time
 from Core.CInpaintingTrainer import CInpaintingTrainer
 import tqdm
 import json
-import glob
 
 def _eval(dataset, model):
   T = time.time()
@@ -35,28 +35,26 @@ def evaluator(datasets, model, folder, args):
   losses = [np.inf] * len(datasets) # initialize with infinity
   def evaluate(onlyImproved=False):
     totalLoss = []
-    losses_dist = []
     for i, dataset in enumerate(datasets):
       loss, T = _eval(dataset, model)
       isImproved = loss < losses[i]
       if (not onlyImproved) or isImproved:
-        print('Test %d / %d | %.2f sec | Loss: %.5f (%.5f).' % (
-          i + 1, len(datasets), T, loss, losses[i],
+        dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
+        print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
+          i + 1, len(datasets), dataset_id, T, loss, losses[i],
         ))
       if isImproved:
         print('Test %d / %d | Improved %.5f => %.5f,' % (
           i + 1, len(datasets), losses[i], loss,
         ))
-        # model.save(folder, postfix='best-%d' % i) # save the model separately
+        model.save(folder, postfix='best-%d' % i) # save the model separately
         losses[i] = loss
         pass
 
       totalLoss.append(loss)
       continue
     if not onlyImproved:
-      print('Mean loss: %.5f' % (
-        np.mean(totalLoss)
-      ))
+      print('Mean loss: %.5f' % (np.mean(totalLoss), ))
     return np.mean(totalLoss)
   return evaluate
 
@@ -84,28 +82,6 @@ def _trainer_from(args):
   if args.trainer == 'default': return CInpaintingTrainer
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
-def averageModels(folder, model, noiseStd=0.0):
-  TV = [np.zeros_like(x) for x in model.trainable_variables()]
-  N = 0
-  for nm in glob.glob(os.path.join(folder, '*.h5')):
-    if not('best' in nm): continue # only the best models
-    model.load(nm, embeddings=True)
-    # add the weights to the total
-    weights = model.trainable_variables()
-    for i in range(len(TV)):
-      TV[i] += weights[i].numpy()
-      continue
-    N += 1
-    continue
-
-  # average the weights
-  TV = [(x / N) + np.random.normal(0.0, noiseStd, x.shape) for x in TV]
-  for v, new in zip(model.trainable_variables(), TV):
-    v.assign(new)
-    continue
-  model.compile() # recompile the model with the new weights
-  return
-
 def main(args):
   timesteps = args.steps
   folder = os.path.join(args.folder, 'Data')
@@ -125,15 +101,16 @@ def main(args):
       maxT=1.0,
       defaults=dict(
         timesteps=timesteps,
-        stepsSampling='uniform',
+        stepsSampling={'max frames': 10},
         # no augmentations by default
-        pointsNoise=0.01, pointsDropout=0.0,
+        pointsNoise=0.01, pointsDropout=0.01,
         eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
         targets=dict(keypoints=3, total=10),
       ),
-      keys=['clean', 'augmented'],
+      keys=['clean'],
     ),
     sampler_class=CDataSamplerInpainting,
+    test_folders=['train.npz'],
   )
   model = dict(timesteps=timesteps, stats=stats)
   if args.model is not None:
@@ -143,11 +120,11 @@ def main(args):
 
   model = trainer(**model)
 #   model._model.summary()
-
-  # find folders with the name "/test-*/"
+  
   evalDatasets = [
-    CTestInpaintingLoader(nm)
-    for nm in glob.glob(os.path.join(folder, 'test-inpainting', 'test-*/'))
+    CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
+    for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
+    if os.path.exists(os.path.join(folderName, 'test-inpainting'))
   ]
   eval = evaluator(evalDatasets, model, folder, args)
   bestLoss = eval() # evaluate loaded model
@@ -166,7 +143,7 @@ def f(epoch, onlyImproved=False):
           print('-' * 80)
         bestLoss = newLoss
         bestEpoch = epoch
-        # model.save(folder, postfix='best')
+        model.save(folder, postfix='best')
       return
     return f
   
@@ -176,14 +153,13 @@ def f(epoch, onlyImproved=False):
     trainStep(
       desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
     )
-    # model.save(folder, postfix='latest')
+    model.save(folder, postfix='latest')
     eval(epoch)
 
     print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
     if args.patience <= (epoch - bestEpoch):
-      if 'stop' == args.on_patience:
-        print('Early stopping')
-        break
+      print('Early stopping')
+      break
     continue
   return
 
@@ -192,26 +168,15 @@ def f(epoch, onlyImproved=False):
   parser.add_argument('--epochs', type=int, default=1000)
   parser.add_argument('--batch-size', type=int, default=64)
   parser.add_argument('--patience', type=int, default=5)
-  parser.add_argument('--on-patience', type=str, default='stop', choices=['stop', 'reset'])
   parser.add_argument('--steps', type=int, default=5)
   parser.add_argument('--model', type=str)
   parser.add_argument('--embeddings', default=False, action='store_true')
-  parser.add_argument(
-    '--average', default=False, action='store_true',
-    help='Load each model from the folder and average them weights'
-  )
   parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
   parser.add_argument('--modelId', type=str)
   parser.add_argument(
     '--trainer', type=str, default='default',
     choices=['default']
   )
-  parser.add_argument('--debug', action='store_true')
-  parser.add_argument('--noise', type=float, default=1e-4)
-  parser.add_argument(
-    '--restarts', type=int, default=1,
-    help='Number of times to restart the model reinitializing the weights'
-  )
   parser.add_argument(
     '--sampling', type=str, default='uniform',
     choices=['uniform', 'as_is'],

From 1b4e38b6201931e3bc90800f4e18edb813170023 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Sun, 20 Oct 2024 14:44:43 +0200
Subject: [PATCH 08/10] fix

---
 Core/Utils.py                | 8 ++++++--
 scripts/preprocess-remote.py | 2 +-
 2 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/Core/Utils.py b/Core/Utils.py
index 2446528..adab5fc 100644
--- a/Core/Utils.py
+++ b/Core/Utils.py
@@ -272,6 +272,7 @@ def extractSessions(dataset, TDelta):
   res = []
   T = 0
   prevSession = 0
+  N = len(dataset['time'])
   for i, t in enumerate(dataset['time']):
     if TDelta < (t - T):
       if 1 < (i - prevSession):
@@ -281,10 +282,13 @@ def extractSessions(dataset, TDelta):
     T = t
     continue
   # if last session is not empty, then append it
-  if prevSession < len(dataset['time']):
-    res.append((prevSession, len(dataset['time'])))
+  if prevSession < N:
+    res.append((prevSession, N))
     pass
 
+  # remove sessions with less than 2 samples
+  res = [x for x in res if 1 < (x[1] - x[0])]
+
   # check that end of one session is equal or less than start of the next
   for i in range(1, len(res)):
     assert res[i-1][1] <= res[i][0]
diff --git a/scripts/preprocess-remote.py b/scripts/preprocess-remote.py
index fddf377..294c0e4 100644
--- a/scripts/preprocess-remote.py
+++ b/scripts/preprocess-remote.py
@@ -194,7 +194,7 @@ def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, ski
 
   if (0 == len(training)) or (0 == len(testing)):
     print('No training or testing sets found!')
-    return 0, 0, True
+    return 0, 0, True, None
 
   def saveSubset(filename, idx):
     print('%s: %d frames' % (filename, len(idx)))

From 07a01d37257f5ce7454d88ce96561ca9211e53f2 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Sun, 20 Oct 2024 14:53:39 +0200
Subject: [PATCH 09/10] fix

---
 Core/Utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Core/Utils.py b/Core/Utils.py
index adab5fc..3173fcf 100644
--- a/Core/Utils.py
+++ b/Core/Utils.py
@@ -310,7 +310,7 @@ def dataset_from_stats(stats, folder):
   # screenId is a concatenation of placeId and screenId, to make it unique pair
   PlaceAndScreenId = [x.split('/') for x in screenId]
 
-  blackList = set(stats.get('blacklist', []))
+  blackList = stats.get('blacklist', [])
   known = set([tuple(x) for x in blackList])
   for screen_id_index, (place_id, screen_id) in enumerate(PlaceAndScreenId):
     place_id_index = placeId.index(place_id)

From 4c88b60ffcfc6eef1463adc9f5346974e61fc9c1 Mon Sep 17 00:00:00 2001
From: Green Wizard <forwork.anton@gmail.com>
Date: Sat, 26 Oct 2024 10:52:25 +0200
Subject: [PATCH 10/10] integrate w&b and other changes

---
 scripts/train-reconstruction.py | 314 ++++++++++++++++----------------
 1 file changed, 161 insertions(+), 153 deletions(-)

diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
index 6a6d6d1..1b584d5 100644
--- a/scripts/train-reconstruction.py
+++ b/scripts/train-reconstruction.py
@@ -1,8 +1,7 @@
 #!/usr/bin/env python
-# -*- coding: utf-8 -*-.
-# TODO: add the W&B integration
+# -*- coding: utf-8 -*-
 import argparse, os, sys
-# add the root folder of the project to the path
+# Add the root folder of the project to the path
 ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
 sys.path.append(ROOT_FOLDER)
 
@@ -16,171 +15,180 @@
 from Core.CInpaintingTrainer import CInpaintingTrainer
 import tqdm
 import json
+import wandb
 
 def _eval(dataset, model):
-  T = time.time()
-  # evaluate the model on the val dataset
-  loss = []
-  for batchId in range(len(dataset)):
-    batch = dataset[batchId]
-    loss_value = model.eval(batch)
-    loss.append(loss_value)
-    continue
-
-  loss = np.mean(loss)
-  T = time.time() - T
-  return loss, T
+    T = time.time()
+    # Evaluate the model on the validation dataset
+    loss = []
+    for batchId in range(len(dataset)):
+        batch = dataset[batchId]
+        loss_value = model.eval(batch)
+        loss.append(loss_value)
+    loss = np.mean(loss)
+    T = time.time() - T
+    return loss, T
 
 def evaluator(datasets, model, folder, args):
-  losses = [np.inf] * len(datasets) # initialize with infinity
-  def evaluate(onlyImproved=False):
-    totalLoss = []
-    for i, dataset in enumerate(datasets):
-      loss, T = _eval(dataset, model)
-      isImproved = loss < losses[i]
-      if (not onlyImproved) or isImproved:
-        dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
-        print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
-          i + 1, len(datasets), dataset_id, T, loss, losses[i],
-        ))
-      if isImproved:
-        print('Test %d / %d | Improved %.5f => %.5f,' % (
-          i + 1, len(datasets), losses[i], loss,
-        ))
-        model.save(folder, postfix='best-%d' % i) # save the model separately
-        losses[i] = loss
-        pass
+    losses = [np.inf] * len(datasets)  # Initialize with infinity
+    def evaluate(onlyImproved=False, step=None):
+        totalLoss = []
+        eval_metrics = {}
+        for i, dataset in enumerate(datasets):
+            loss, T = _eval(dataset, model)
+            dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
+            isImproved = loss < losses[i]
+            if (not onlyImproved) or isImproved:
+                print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
+                    i + 1, len(datasets), dataset_id, T, loss, losses[i],
+                ))
+            if isImproved:
+                print('Test %d / %d | Improved %.5f => %.5f,' % (
+                    i + 1, len(datasets), losses[i], loss,
+                ))
+                modelFolder = os.path.join(folder, f"model-{dataset_id}")
+                os.makedirs(modelFolder, exist_ok=True)
+                # keep only the best model across all runs
+                # name format: {model id}-{loss:.5f}-*.*
+                all_files = os.listdir(modelFolder)
+                all_losses = [f.split('-')[1] for f in all_files]
+                all_losses = list(set(all_losses))
+                print(f"Found losses: {all_losses}")
+                for loss_file in all_losses:
+                    if float(loss_file) > loss:
+                        # remove all files with this loss in folder
+                        to_remove = [os.path.join(modelFolder, f) for f in all_files if loss_file in f]
+                        for f in to_remove:
+                            os.remove(f)
 
-      totalLoss.append(loss)
-      continue
-    if not onlyImproved:
-      print('Mean loss: %.5f' % (np.mean(totalLoss), ))
-    return np.mean(totalLoss)
-  return evaluate
+                model.save(modelFolder, postfix='%.5f' % loss)
+                losses[i] = loss
+            totalLoss.append(loss)
+            eval_metrics['eval_loss_(%s)' % dataset_id] = loss
+        mean_loss = np.mean(totalLoss)
+        if not onlyImproved:
+            print('Mean loss: %.5f' % mean_loss)
+        # Log evaluation metrics to wandb
+        if step is not None:
+            wandb.log(eval_metrics, step=step)
+        return mean_loss
+    return evaluate
 
 def _modelTrainingLoop(model, dataset):
-  def F(desc):
-    history = defaultdict(list)
-    # use the tqdm progress bar
-    with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
-      dataset.on_epoch_start()
-      for _ in range(len(dataset)):
-        sampled = dataset.sample()
-        stats = model.fit(sampled)
-        history['time'].append(stats['time'])
-        for k in stats['losses'].keys():
-          history[k].append(stats['losses'][k])
-        # add stats to the progress bar (mean of each history)
-        pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
-        pbar.update(1)
-        continue
-      dataset.on_epoch_end()
-    return
-  return F
+    def F(desc):
+        history = defaultdict(list)
+        # Use the tqdm progress bar
+        with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
+            dataset.on_epoch_start()
+            for step in range(len(dataset)):
+                sampled = dataset.sample()
+                stats = model.fit(sampled)
+                history['time'].append(stats['time'])
+                for k in stats['losses'].keys():
+                    history[k].append(stats['losses'][k])
+                # Add stats to the progress bar (mean of each history)
+                pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
+                pbar.update(1)
+            dataset.on_epoch_end()
+        return {k: np.mean(v) for k, v in history.items()}
+    return F
 
 def _trainer_from(args):
-  if args.trainer == 'default': return CInpaintingTrainer
-  raise Exception('Unknown trainer: %s' % (args.trainer, ))
+    if args.trainer == 'default': return CInpaintingTrainer
+    raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
 def main(args):
-  timesteps = args.steps
-  folder = os.path.join(args.folder, 'Data')
-
-  stats = None
-  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
-    stats = json.load(f)
+    wandb.init(project=args.wandb_project, config=vars(args))  # Initialize wandb
+    timesteps = args.steps
+    folder = os.path.join(args.folder, 'Data')
 
-  trainer = _trainer_from(args)
-  trainDataset = CDatasetLoader(
-    os.path.join(folder, 'remote'),
-    stats=stats,
-    sampling=args.sampling,
-    samplerArgs=dict(
-      batch_size=args.batch_size,
-      minFrames=timesteps,
-      maxT=1.0,
-      defaults=dict(
-        timesteps=timesteps,
-        stepsSampling={'max frames': 10},
-        # no augmentations by default
-        pointsNoise=0.01, pointsDropout=0.01,
-        eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
-        targets=dict(keypoints=3, total=10),
-      ),
-      keys=['clean'],
-    ),
-    sampler_class=CDataSamplerInpainting,
-    test_folders=['train.npz'],
-  )
-  model = dict(timesteps=timesteps, stats=stats)
-  if args.model is not None:
-    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
-  if args.modelId is not None:
-    model['model'] = args.modelId
+    stats = None
+    with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
+        stats = json.load(f)
 
-  model = trainer(**model)
-#   model._model.summary()
-  
-  evalDatasets = [
-    CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
-    for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
-    if os.path.exists(os.path.join(folderName, 'test-inpainting'))
-  ]
-  eval = evaluator(evalDatasets, model, folder, args)
-  bestLoss = eval() # evaluate loaded model
-  bestEpoch = 0
-  # wrapper for the evaluation function. It saves the model if it is better
-  def evalWrapper(eval):
-    def f(epoch, onlyImproved=False):
-      nonlocal bestLoss, bestEpoch
-      newLoss = eval(onlyImproved=onlyImproved)
-      if newLoss < bestLoss:
-        print('Improved %.5f => %.5f' % (bestLoss, newLoss))
-        if onlyImproved: #details
-          for i, (loss, bestLoss_, dist, bestDist) in enumerate(losses):
-            print('Test %d | Loss: %.5f (%.5f). Distance: %.5f (%.5f)' % (i + 1, loss, bestLoss_, dist, bestDist))
-            continue
-          print('-' * 80)
-        bestLoss = newLoss
-        bestEpoch = epoch
-        model.save(folder, postfix='best')
-      return
-    return f
-  
-  eval = evalWrapper(eval)
-  trainStep = _modelTrainingLoop(model, trainDataset)
-  for epoch in range(args.epochs):
-    trainStep(
-      desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+    trainer = _trainer_from(args)
+    trainDataset = CDatasetLoader(
+        os.path.join(folder, 'remote'),
+        stats=stats,
+        sampling=args.sampling,
+        samplerArgs=dict(
+            batch_size=args.batch_size,
+            minFrames=timesteps,
+            maxT=1.0,
+            defaults=dict(
+                timesteps=timesteps,
+                stepsSampling={'max frames': 10},
+                # No augmentations by default
+                pointsNoise=0.01, pointsDropout=0.01,
+                eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+                targets=dict(keypoints=3, total=10),
+            ),
+            keys=['clean'],
+        ),
+        sampler_class=CDataSamplerInpainting,
+        test_folders=['train.npz'],
     )
-    model.save(folder, postfix='latest')
-    eval(epoch)
+    model = dict(timesteps=timesteps, stats=stats)
+    if args.model is not None:
+        model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
+    if args.modelId is not None:
+        model['model'] = args.modelId
+
+    model = trainer(**model)
+
+    evalDatasets = [
+        CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
+        for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
+        if os.path.exists(os.path.join(folderName, 'test-inpainting'))
+    ]
+    eval_fn = evaluator(evalDatasets, model, folder, args)
+    bestLoss = eval_fn()  # Evaluate loaded model
+    bestEpoch = 0
 
-    print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
-    if args.patience <= (epoch - bestEpoch):
-      print('Early stopping')
-      break
-    continue
-  return
+    def evalWrapper(eval_fn):
+        def f(epoch, onlyImproved=False, step=None):
+            nonlocal bestLoss, bestEpoch
+            newLoss = eval_fn(onlyImproved=onlyImproved, step=step)
+            if newLoss < bestLoss:
+                print('Improved %.5f => %.5f' % (bestLoss, newLoss))
+                bestLoss = newLoss
+                bestEpoch = epoch
+                model.save(folder, postfix='%.5f' % newLoss)
+            return
+        return f
+
+    eval_fn = evalWrapper(eval_fn)
+    trainStep = _modelTrainingLoop(model, trainDataset)
+    for epoch in range(args.epochs):
+        metrics = trainStep(
+            desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+        )
+        wandb.log(metrics, step=epoch + 1)
+        model.save(folder, postfix='latest')
+        eval_fn(epoch, step=epoch + 1)
+        print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
+        if args.patience <= (epoch - bestEpoch):
+            print('Early stopping')
+            break
 
 if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--epochs', type=int, default=1000)
-  parser.add_argument('--batch-size', type=int, default=64)
-  parser.add_argument('--patience', type=int, default=5)
-  parser.add_argument('--steps', type=int, default=5)
-  parser.add_argument('--model', type=str)
-  parser.add_argument('--embeddings', default=False, action='store_true')
-  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
-  parser.add_argument('--modelId', type=str)
-  parser.add_argument(
-    '--trainer', type=str, default='default',
-    choices=['default']
-  )
-  parser.add_argument(
-    '--sampling', type=str, default='uniform',
-    choices=['uniform', 'as_is'],
-  )
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--epochs', type=int, default=1000)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--patience', type=int, default=5)
+    parser.add_argument('--steps', type=int, default=5)
+    parser.add_argument('--model', type=str)
+    parser.add_argument('--embeddings', default=False, action='store_true')
+    parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+    parser.add_argument('--modelId', type=str)
+    parser.add_argument(
+        '--trainer', type=str, default='default',
+        choices=['default']
+    )
+    parser.add_argument(
+        '--sampling', type=str, default='uniform',
+        choices=['uniform', 'as_is'],
+    )
+    parser.add_argument('--wandb-project', type=str, default='alternative-input-reconstruction')
 
-  main(parser.parse_args())
-  pass
\ No newline at end of file
+    main(parser.parse_args())