diff --git a/Core/CBaseDataSampler.py b/Core/CBaseDataSampler.py
new file mode 100644
index 0000000..dc02b37
--- /dev/null
+++ b/Core/CBaseDataSampler.py
@@ -0,0 +1,148 @@
+import numpy as np
+import random
+from math import ceil
+from functools import lru_cache
+
+class CBaseDataSampler:
+    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+        '''
+        Base class for data sampling.
+
+        Parameters:
+        - storage: The storage object containing the samples.
+        - batch_size: The number of samples per batch.
+        - minFrames: The minimum number of frames required in a trajectory.
+        - defaults: Default parameters for sampling.
+        - maxT: Maximum time window for sampling frames.
+        - cumulative_time: If True, time is cumulative; otherwise, it's time deltas.
+        '''
+        self._storage = storage
+        self._defaults = defaults
+        self._batchSize = batch_size
+        self._maxT = maxT
+        self._minFrames = minFrames
+        self._samples = []
+        self._currentSample = None
+        self._cumulative_time = cumulative_time
+        return
+
+    def reset(self):
+        random.shuffle(self._samples)
+        self._currentSample = 0
+        return
+
+    def __len__(self):
+        return ceil(len(self._samples) / self._batchSize)
+
+    def _storeSample(self, idx):
+        # Store sample if it has enough frames
+        minInd = self._getTrajectoryBefore(idx)
+        if self._minFrames <= (idx - minInd):
+            self._samples.append(idx)
+        return
+
+    def add(self, sample):
+        idx = self._storage.add(sample)
+        self._storeSample(idx)
+        return idx
+
+    def addBlock(self, samples):
+        indexes = self._storage.addBlock(samples)
+        for idx in indexes:
+            self._storeSample(idx)
+            continue
+        return
+
+    def _getTrajectoryBefore(self, mainInd):
+        mainT = self._storage[mainInd]['time']
+        minT = mainT - self._maxT
+
+        minInd = mainInd
+        for ind in range(mainInd - 1, -1, -1):
+            if self._storage[ind]['time'] < minT: break
+            minInd = ind
+            continue
+        return minInd
+
+    @lru_cache(None)
+    def _trajectoryRange(self, mainInd):
+        '''
+        Returns indexes of samples that are within maxT from mainInd.
+        Returns (minInd, maxInd) where minInd <= mainInd <= maxInd
+        '''
+        mainT = self._storage[mainInd]['time']
+        maxT = mainT + self._maxT
+        maxInd = mainInd
+        for ind in range(mainInd, len(self._storage)):
+            if maxT < self._storage[ind]['time']: break
+            maxInd = ind
+            continue
+
+        minInd = self._getTrajectoryBefore(mainInd)
+        return minInd, maxInd
+
+    def _trajectory(self, mainInd):
+        minInd, maxInd = self._trajectoryRange(mainInd)
+        return list(range(minInd, mainInd + 1)), list(range(mainInd + 1, maxInd + 1))
+
+    def _prepareT(self, res):
+        T = np.array([self._storage[ind]['time'] for ind in res])
+        T -= T[0]
+        diff = np.diff(T, 1)
+        idx = np.nonzero(diff)[0]
+        if len(idx) < 1: return None  # All frames have the same time
+        if len(diff) == len(idx):
+            T = diff
+        else:
+            return None # Time is not consistent
+        T = np.insert(T, 0, 0.0)
+        assert len(res) == len(T)
+        # Convert to cumulative time if required
+        if self._cumulative_time:
+            T = np.cumsum(T)
+        return T
+
+    def _reshapeSteps(self, values, steps):
+        if steps is None:
+            return values
+
+        res = []
+        for x in values:
+            B, *s = x.shape
+            newShape = (B // steps, steps, *s)
+            res.append(x.reshape(newShape))
+            continue
+        return tuple(res)
+
+    @property
+    def totalSamples(self):
+        return len(self._storage)
+
+    def validSamples(self):
+        return list(sorted(self._samples))
+    
+    def _framesFor(self, mainInd, samples, steps, stepsSampling):
+        if 'uniform' == stepsSampling:
+            samples = random.sample(samples, steps - 1)
+        elif 'last' == stepsSampling:
+            samples = samples[-(steps - 1):]
+        elif isinstance(stepsSampling, dict):
+            candidates = list(samples)
+            maxFrames = stepsSampling['max frames']
+            candidates = candidates[::-1]
+            samples = []
+            left = steps - 1
+            for _ in range(left):
+                avl = min((maxFrames, 1 + len(candidates) - left))
+                ind = random.randint(0, avl - 1)
+                samples.append(candidates[ind])
+                candidates = candidates[ind+1:]
+                left -= 1
+                continue
+            pass
+        else:
+            raise ValueError('Unknown sampling method: ' + str(stepsSampling))
+
+        res = list(sorted(samples + [mainInd]))
+        assert len(res) == steps
+        return res
\ No newline at end of file
diff --git a/Core/CBaseModel.py b/Core/CBaseModel.py
new file mode 100644
index 0000000..7be84d0
--- /dev/null
+++ b/Core/CBaseModel.py
@@ -0,0 +1,59 @@
+import os
+import numpy as np
+from tensorflow.keras import layers as L
+
+class CBaseModel:
+  def __init__(self, model, embeddings, submodels):
+    self._model = model
+    self._embeddings = {
+      'userId': L.Embedding(embeddings['userId'], embeddings['size']),
+      'placeId': L.Embedding(embeddings['placeId'], embeddings['size']),
+      'screenId': L.Embedding(embeddings['screenId'], embeddings['size']),
+    }
+    self._submodels = submodels
+    return  
+
+  def replaceByEmbeddings(self, data):
+    data = dict(**data) # copy
+    for name, emb in self._embeddings.items():
+      data[name] = emb(data[name][..., 0])
+      continue
+    return data
+
+  def _modelFilename(self, folder, postfix=''):
+    postfix = '-' + postfix if postfix else ''
+    return os.path.join(folder, '%s%s.h5' % (self._model, postfix))
+  
+  def save(self, folder=None, postfix=''):
+    path = self._modelFilename(folder, postfix)
+    if 1 < len(self._submodels):
+      for i, model in enumerate(self._submodels):
+        model.save_weights(path.replace('.h5', '-%d.h5' % i))
+    else:
+      self._submodels[0].save_weights(path)
+
+    embeddings = {}
+    for nm in self._embeddings.keys():
+      weights = self._embeddings[nm].get_weights()[0]
+      embeddings[nm] = weights
+    
+    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
+    
+  def load(self, folder=None, postfix='', embeddings=False):
+    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
+    if 1 < len(self._submodels):
+      for i, model in enumerate(self._submodels):
+        model.load_weights(path.replace('.h5', '-%d.h5' % i))
+    else:
+      self._submodels[0].load_weights(path)
+      
+    if embeddings:
+      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
+      for nm, emb in self._embeddings.items():
+        w = embeddings[nm]
+        if not emb.built: emb.build((None, w.shape[0]))
+        emb.set_weights([w]) # replace embeddings
+    
+  def trainable_variables(self):
+    parts = list(self._embeddings.values()) + self._submodels
+    return sum([p.trainable_variables for p in parts], [])
diff --git a/Core/CDataSampler.py b/Core/CDataSampler.py
index ce81c72..aae5d40 100644
--- a/Core/CDataSampler.py
+++ b/Core/CDataSampler.py
@@ -1,306 +1,152 @@
+from .CBaseDataSampler import CBaseDataSampler
+import Core.CDataSampler_utils as DSUtils
+
 import numpy as np
-import random
-from math import ceil
-import Core.Utils as Utils
 from functools import lru_cache
-import Core.CDataSampler_utils as DSUtils
 
-class CDataSampler:
-  def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
-    '''
-    If cumulative_time is True, then time is a cumulative time from the start of the trajectory i.e. [0, 0.1, 0.2, 0.3, ...]
-    If cumulative_time is False, then time is a time delta between frames i.e. [0, 0.1, 0.1, 0.1, ...]
-    '''
-    self._storage = storage
-    self._defaults = defaults
-    self._batchSize = batch_size
-    self._maxT = maxT
-    self._minFrames = minFrames
-    self._samples = []
-    self._currentSample = None
-    self._cumulative_time = cumulative_time
-    return
-  
-  def reset(self):
-    random.shuffle(self._samples)
-    self._currentSample = 0
-    return
+'''
+This sampler are sample N frames from the dataset, where N is the number of timesteps.
+It returns the tuple (X, Y), where X is the input data and Y is the target data.
+To X could be applied some augmentations.
+X contains the following data:
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The time (cumulative or delta).
+  - The user ID, place ID, and screen ID.
+Y contains the target data.
+  - The target point.
+'''
+class CDataSampler(CBaseDataSampler):
+    def __init__(self, storage, batch_size, minFrames, defaults={}, maxT=1.0, cumulative_time=True):
+        super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
 
-  def __len__(self):
-    return ceil(len(self._samples) / self._batchSize)
+    def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
+        if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
+        if mainInd < steps: return False
 
-  def _storeSample(self, idx):
-    # store sample if it has enough frames
-    minInd = self._getTrajectoryBefore(idx)
-    if self._minFrames <= (idx - minInd):
-      self._samples.append(idx)
-    return
-  
-  def add(self, sample):
-    idx = self._storage.add(sample)
-    self._storeSample(idx)
-    return idx
-  
-  def addBlock(self, samples):
-    indexes = self._storage.addBlock(samples)
-    for idx in indexes:
-      self._storeSample(idx)
-      continue
-    return
+        samples, _ = self._trajectory(mainInd)
+        if len(samples) < (steps - 1): return False
+        # Try to sample valid frames
+        for _ in range(10):
+            res = self._framesFor(mainInd, samples, steps, stepsSampling)
+            T = self._prepareT(res)
+            if T is not None:
+                assert len(res) == len(T)
+                return [tuple(x) for x in zip(res, T)]
+            continue
+        return False
 
-  def _getTrajectoryBefore(self, mainInd):
-    mainT = self._storage[mainInd]['time']
-    minT = mainT - self._maxT
-    
-    minInd = mainInd
-    for ind in range(mainInd - 1, -1, -1):
-      if self._storage[ind]['time'] < minT: break
-      minInd = ind
-      continue
-    return minInd
-  
-  @lru_cache(None)
-  def _trajectoryRange(self, mainInd):
-    '''
-      Returns indexes of samples that are in the range of maxT from the mainInd
-      Returns (minInd, maxInd) where minInd <= mainInd <= maxInd
-    '''
-    mainT = self._storage[mainInd]['time']
-    maxT = mainT + self._maxT
-    maxInd = mainInd
-    for ind in range(mainInd, len(self._storage)):
-      if maxT < self._storage[ind]['time']: break
-      maxInd = ind
-      continue
-    
-    minInd = self._getTrajectoryBefore(mainInd)
-    return minInd, maxInd
+    def sample(self, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        N = kwargs.get('N', self._batchSize)
+        indexes = []
+        for _ in range(N):
+            added = False
+            while not added:
+                idx = self._samples[self._currentSample]
+                self._currentSample = (self._currentSample + 1) % len(self._samples)
 
-  def _trajectory(self, mainInd):
-    minInd, maxInd = self._trajectoryRange(mainInd)
-    return list(range(minInd, mainInd)), list(range(mainInd + 1, maxInd + 1))
-  
-  def _trajectory2keypoints(self, before, mainInd, after, N):
-    mainPt = self._storage[mainInd]['goal']
-    if 1 < N:
-      trajectory = []
-      trajectory.extend(before)
-      trajectory.append(mainInd)
-      trajectory.extend(after)
-      trajectory = np.array([self._storage[ind]['goal'] for ind in trajectory])
-      chunksN = max((1, len(trajectory) // N))
-      keypoints = [(mainPt, 0.0)]
-      for i in range(0, len(trajectory), chunksN):
-        x = trajectory[i:i+chunksN]
-        if 0 < len(x):
-          pt = np.mean(x, axis=0)
-          d = np.linalg.norm(pt - mainPt)
-          keypoints.append((pt, d))
-        continue
-      while len(keypoints) < N: keypoints.append(keypoints[0])
-      keypoints = sorted(keypoints, key=lambda x: x[1])
-      keypoints = [pt for pt, _ in keypoints]
-      keypoints = np.array(keypoints[:N])
-    else:
-      keypoints = np.array([mainPt])
-    return keypoints
+                sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+                if sampledSteps:
+                    # TODO: remove from samples?
+                    indexes.extend(sampledSteps)
+                    added = True
+                continue
 
-  def _prepareT(self, res):
-    T = np.array([self._storage[ind]['time'] for ind in res])
-    T -= T[0]
-    diff = np.diff(T, 1)
-    idx = np.nonzero(diff)[0]
-    if len(idx) < 1: return None # all frames have the same time
-    if len(diff) == len(idx):
-      T = diff
-    else:
-      # avg non-zero diff
-      dT = np.min(diff[idx])
-      T = np.append(T, T[-1] + dT)
-      idx = [0, *(1 + idx), len(T) - 1]
-      T = np.interp(np.arange(len(T) - 1), idx, T[idx])
-      T = np.diff(T, 1)
-      pass
-    T = np.insert(T, 0, 0.0)
-    assert len(res) == len(T)
-    # T is an array of time deltas like [0, 0.1, 0.1, 0.1, ...], convert it to cumulative time
-    if self._cumulative_time:
-      T = np.cumsum(T)
-    return T
-  
-  def _framesFor(self, mainInd, samples, steps, stepsSampling):
-    if 'uniform' == stepsSampling:
-      samples = random.sample(samples, steps - 1)
-    if 'last' == stepsSampling:
-      samples = samples[-(steps - 1):]
-      
-    if isinstance(stepsSampling, dict):
-      candidates = list(samples)
-      maxFrames = stepsSampling['max frames']
-      candidates = candidates[::-1]
-      samples = []
-      left = steps - 1
-      for _ in range(left):
-        avl = min((maxFrames, 1 + len(candidates) - left))
-        ind = random.randint(0, avl - 1)
-        samples.append(candidates[ind])
-        candidates = candidates[ind+1:]
-        left -= 1
-        continue
-      pass
-      
-    res = list(sorted(samples + [mainInd]))
-    assert len(res) == steps
-    return res
-  
-  def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
-    if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
-    if mainInd < steps: return False
-    
-    samples, _ = self._trajectory(mainInd)
-    if len(samples) < (steps - 1): return False
-    # Try to sample valid frames
-    for _ in range(10):
-      res = self._framesFor(mainInd, samples, steps, stepsSampling)
-      T = self._prepareT(res)
-      if T is not None:
-        assert len(res) == len(T)
-        return [tuple(x) for x in zip(res, T)]
-      continue
-    return False
-  
-  def sample(self, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    N = kwargs.get('N', self._batchSize)
-    indexes = []
-    for _ in range(N):
-      added = False
-      while not added:
-        idx = self._samples[self._currentSample]
-        self._currentSample = (self._currentSample + 1) % len(self._samples)
+        return self._indexes2XY(indexes, kwargs)
 
+    def sampleById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
         sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-        if sampledSteps:
-          # TODO: remove from samples?
-          indexes.extend(sampledSteps)
-          added = True
-      continue
+        if not sampledSteps: return None
+        return self._indexes2XY([*sampledSteps], kwargs)
+
+    def checkById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return False
+        return True
+
+    def sampleByIds(self, ids, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = []
+        rejected = []
+        accepted = []
+        for idx in ids:
+            sample = self._stepsFor(idx, steps=timesteps, **kwargs)
+            if sample:
+                accepted.append(idx)
+                sampledSteps.extend(sample)
+            else:
+                rejected.append(idx)
+                pass
+            continue
+
+        res = None
+        if 0 < len(sampledSteps):
+            res = self._indexes2XY(sampledSteps, kwargs)
+        return res, rejected, accepted
 
-    return self._indexes2XY(indexes, kwargs)
+    @lru_cache(None)
+    def _targetFor(self, ind):
+        mainPt = self._storage[ind]['goal']
+        keypoints = np.array(mainPt, np.float32)
+        return keypoints
 
-  def sampleById(self, idx, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-    if not sampledSteps: return None
-    return self._indexes2XY([*sampledSteps], kwargs)
-  
-  def checkById(self, idx, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
-    if not sampledSteps: return False
-    return True
-    
-  def sampleByIds(self, ids, **kwargs):
-    kwargs = {**self._defaults, **kwargs}
-    timesteps = kwargs.get('timesteps', None)
-    sampledSteps = []
-    rejected = []
-    accepted = []
-    for idx in ids:
-      sample = self._stepsFor(idx, steps=timesteps, **kwargs)
-      if sample:
-        accepted.append(idx)
-        sampledSteps.extend(sample)
-      else:
-        rejected.append(idx)
-        pass
-      continue
-    
-    res = None
-    if 0 < len(sampledSteps):
-      res = self._indexes2XY(sampledSteps, kwargs)
-    return res, rejected, accepted
-  
-  def _reshapeSteps(self, values, steps):
-    if steps is None: return values
-    
-    res = []
-    for x in values:
-      B, *s = x.shape
-      newShape = (B // steps, steps, *s)
-      res.append(x.reshape(newShape))
-      continue
-    return tuple(res)
-  
-  @lru_cache(None)
-  def _targetFor(self, ind, keypoints=1, past=True, future=True, **_):
-    before, after = self._trajectory(ind)
-    if not past: before = []
-    if not future: after = []
-    return self._trajectory2keypoints(before, ind, after, N=keypoints)
+    def _indexes2XY(self, indexesAndTime, kwargs):
+        timesteps = kwargs.get('timesteps', None)
+        samples = [self._storage[i] for i, _ in indexesAndTime]
 
-  def _indexes2XY(self, indexesAndTime, kwargs):
-    timesteps = kwargs.get('timesteps', None)
-    samples = [self._storage[i] for i, _ in indexesAndTime]
+        Y = ( np.array([ self._targetFor(i)  for i, _ in indexesAndTime], np.float32), )
+        Y = self._reshapeSteps(Y, timesteps)
+        ##############
+        userIds = np.unique([x['userId'] for x in samples])
+        assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
+        placeIds = np.unique([x['placeId'] for x in samples])
+        assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
+        screenIds = np.unique([x['screenId'] for x in samples])
+        assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
 
-    forecast = kwargs.get('forecast', {})
-    Y = ( np.array([
-      self._targetFor(i, **forecast) 
-      for i, _ in indexesAndTime
-    ], np.float32), )
-    Y = self._reshapeSteps(Y, timesteps)
-    ##############
-    userIds = np.unique([x['userId'] for x in samples])
-    assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
-    placeIds = np.unique([x['placeId'] for x in samples])
-    assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
-    screenIds = np.unique([x['screenId'] for x in samples])
-    assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
+        X = DSUtils.toTensor(
+            (
+                np.array([x['points'] for x in samples], np.float32),
+                np.array([x['left eye'] for x in samples]),
+                np.array([x['right eye'] for x in samples]),
+                np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
+            ),
+            (
+                kwargs.get('pointsNoise', 0.0),
+                kwargs.get('pointsDropout', 0.0),
 
-    X = DSUtils.toTensor(
-      (
-        np.array([x['points'] for x in samples], np.float32),
-        np.array([x['left eye'] for x in samples]),
-        np.array([x['right eye'] for x in samples]),
-        np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
-      ),
-      (
-        kwargs.get('pointsNoise', 0.0),
-        kwargs.get('pointsDropout', 0.0),
-      
-        kwargs.get('eyesAdditiveNoise', 0.0),
-        kwargs.get('eyesDropout', 0.0),
-        kwargs.get('brightnessFactor', 0.0),
-        kwargs.get('lightBlobFactor', 0.0),
+                kwargs.get('eyesAdditiveNoise', 0.0),
+                kwargs.get('eyesDropout', 0.0),
+                kwargs.get('brightnessFactor', 0.0),
+                kwargs.get('lightBlobFactor', 0.0),
 
-        timesteps
-      ),
-      userIds[0], placeIds[0], screenIds[0]
-    )
-    ###############
-    (Y, ) = Y
-    return(X, (Y.astype(np.float32), ))
+                timesteps
+            ),
+            userIds[0], placeIds[0], screenIds[0]
+        )
+        ###############
+        (Y, ) = Y
+        return (X, (Y.astype(np.float32), ))
 
-  @property
-  def totalSamples(self):
-    return len(self._storage)
-  
-  def validSamples(self):
-    return list(sorted(self._samples))
-##############
-if __name__ == '__main__':
-  import tensorflow as tf
-  gpus = tf.config.experimental.list_physical_devices('GPU')
-  tf.config.experimental.set_virtual_device_configuration(
-    gpus[0], [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024//2)]
-  )
-  import os
-  from Core.CSamplesStorage import CSamplesStorage
-  folder = os.path.dirname(os.path.dirname(__file__))
-  ds = CDataSampler( CSamplesStorage(), balancingMethod=dict(context='all') )
-  dsBlock = Utils.datasetFrom(os.path.join(folder, 'Data', 'Dataset'))
-  ds.addBlock(dsBlock)
-  exit(0)
\ No newline at end of file
+    def merge(self, samples, expected_batch_size):
+        Y = np.concatenate([Y for _, (Y, ) in samples], axis=0)
+        assert len(Y) == expected_batch_size, 'Invalid batch size: %d != %d' % (len(Y), expected_batch_size)
+        # X contains the clean and augmented data
+        # each dictionary contains the subkeys: points, left eye, right eye, time, userId, placeId, screenId
+        X = {}
+        for key in ['clean', 'augmented']:
+            for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId']:
+                data = [x[key][subkey] for x, _ in samples]
+                X[key][subkey] = np.concatenate(data, axis=0)
+                assert X[key][subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[key][subkey].shape[0], expected_batch_size)
+                continue
+            continue
+        return (X, (Y, ))
\ No newline at end of file
diff --git a/Core/CDataSamplerInpainting.py b/Core/CDataSamplerInpainting.py
new file mode 100644
index 0000000..58afd15
--- /dev/null
+++ b/Core/CDataSamplerInpainting.py
@@ -0,0 +1,236 @@
+from .CBaseDataSampler import CBaseDataSampler
+import Core.CDataSampler_utils as DSUtils
+from Core.Utils import FACE_MESH_POINTS
+
+import numpy as np
+import tensorflow as tf
+
+'''
+This sampler are sample N frames from the dataset, where N is the number of timesteps.
+Within the range of the N sampled frames, its samples K frames to be inpainted/reconstructed.
+It returns the tuple (X, Y), where X is the input data and Y is the target data.
+To X could be applied some augmentations.
+X contains the following data:
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The time (cumulative or delta).
+  - The target point.
+  - The user ID, place ID, and screen ID.
+Y contains the target data, K frames to be inpainted/reconstructed.
+  - The points of the face.
+  - The left eye.
+  - The right eye.
+  - The normalized time.
+  - The target point.
+'''
+class CDataSamplerInpainting(CBaseDataSampler):
+    def __init__(self, storage, batch_size, minFrames, keys, defaults={}, maxT=1.0, cumulative_time=True):
+        super().__init__(storage, batch_size, minFrames, defaults, maxT, cumulative_time)
+        self._keys = keys
+
+    def _stepsFor(self, mainInd, steps, stepsSampling='uniform', **_):
+        if (steps is None) or (1 == steps): return [(mainInd, 0.0)]
+        if mainInd < steps: return False
+
+        samples, _ = self._trajectory(mainInd)
+        if len(samples) < (steps - 1): return False
+        # Try to sample valid frames
+        for _ in range(10):
+            res = self._framesFor(mainInd, samples, steps, stepsSampling)
+            T = self._prepareT(res)
+            if T is not None:
+                assert len(res) == len(T)
+                return [tuple(x) for x in zip(res, T)]
+            continue
+        return False
+
+    def sample(self, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        N = kwargs.get('N', self._batchSize) // len(self._keys)
+        indexes = []
+        added = False
+        for _ in range(N):
+            added = False
+            while not added:
+                idx = self._samples[self._currentSample]
+                self._currentSample = (self._currentSample + 1) % len(self._samples)
+
+                sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+                if sampledSteps:
+                    # TODO: remove from samples?
+                    indexes.extend(sampledSteps)
+                    added = True
+                continue
+        if not added: return None, 0
+        return self._indexes2XY(indexes, kwargs)
+
+    def sampleById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return None
+        return self._indexes2XY([*sampledSteps], kwargs)
+
+    def checkById(self, idx, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = self._stepsFor(idx, steps=timesteps, **kwargs)
+        if not sampledSteps: return False
+        return True
+
+    def sampleByIds(self, ids, **kwargs):
+        kwargs = {**self._defaults, **kwargs}
+        timesteps = kwargs.get('timesteps', None)
+        sampledSteps = []
+        rejected = []
+        accepted = []
+        for idx in ids:
+            sample = self._stepsFor(idx, steps=timesteps, **kwargs)
+            if sample:
+                accepted.append(idx)
+                sampledSteps.extend(sample)
+            else:
+                rejected.append(idx)
+                pass
+            continue
+
+        res = None
+        if 0 < len(sampledSteps):
+            res, _ = self._indexes2XY(sampledSteps, kwargs)
+        return res, rejected, accepted
+
+    def _indexes2XY(self, indexesAndTime, kwargs):
+        timesteps = kwargs.get('timesteps', None)
+        assert timesteps is not None, 'The number of timesteps must be defined.'
+        B = len(indexesAndTime) // timesteps
+        samples = [self._storage[i] for i, _ in indexesAndTime]
+        ##############
+        userIds = np.unique([x['userId'] for x in samples])
+        assert 1 == len(userIds), 'Only one user is supported. Found: ' + str(userIds)
+        placeIds = np.unique([x['placeId'] for x in samples])
+        assert 1 == len(placeIds), 'Only one place is supported. Found: ' + str(placeIds)
+        screenIds = np.unique([x['screenId'] for x in samples])
+        assert 1 == len(screenIds), 'Only one screen is supported. Found: ' + str(screenIds)
+
+        X = DSUtils.toTensor(
+            (
+                np.array([x['points'] for x in samples], np.float32),
+                np.array([x['left eye'] for x in samples]),
+                np.array([x['right eye'] for x in samples]),
+                np.array([T for _, T in indexesAndTime], np.float32).reshape((-1, 1)),
+            ),
+            (
+                kwargs.get('pointsNoise', 0.0),
+                kwargs.get('pointsDropout', 0.0),
+
+                kwargs.get('eyesAdditiveNoise', 0.0),
+                kwargs.get('eyesDropout', 0.0),
+                kwargs.get('brightnessFactor', 0.0),
+                kwargs.get('lightBlobFactor', 0.0),
+
+                timesteps
+            ),
+            userIds[0], placeIds[0], screenIds[0]
+        )
+        for k in X.keys():
+            # add the target point to the X
+            targets = np.array([x['goal'] for x in samples], np.float32).reshape((B, timesteps, 2))
+            X[k]['target'] = tf.constant(targets, dtype=tf.float32)
+
+        if 1 == len(self._keys):
+            X = X[self._keys[0]]
+        else:
+            res = {}
+            k = self._keys[0]
+            subkeys = list(X[k].keys())
+            for k in subkeys:
+                values = [X[key][k] for key in self._keys]
+                res[k] = tf.concat(values, axis=0)
+                continue
+            X = res
+            indexesAndTime = indexesAndTime * len(self._keys)
+            B = len(self._keys) * B
+
+        ###############
+        # generate the target data
+        targets = kwargs.get('targets', {'keypoints': timesteps, 'total': timesteps})
+        K = targets.get('keypoints', timesteps)
+        assert K <= timesteps, 'The number of keypoints to be inpainted/reconstructed must be less or equal to the number of timesteps.'
+        T = targets.get('total', timesteps)
+        assert K <= T, 'The total number of frames to be inpainted/reconstructed must be less or equal to the total number of timesteps.'
+
+        samples_indexes = np.array([ i  for i, _ in indexesAndTime], np.int32)
+        samples_indexes = samples_indexes.reshape((-1, timesteps))
+        assert samples_indexes.shape[0] == B, 'Invalid number of samples: %d != %d' % (samples_indexes.shape[0], B)
+        targetsIdx = np.zeros((B, T), np.int32)
+        for i in range(B):
+            # sample K frames from the X
+            sampled = np.random.choice(samples_indexes[i], K, replace=False)
+            # repeat the sampled frames to fill the K frames
+            targetsIdx[i, :] = np.repeat(sampled, 1 + (T // K))[:T]
+            # sample the remaining frames
+            if K < T: # if need to sample more frames
+                startFrameIdx = samples_indexes[i, 0]
+                endFrameIdx = samples_indexes[i, -1]
+                allFrames = np.arange(startFrameIdx, endFrameIdx + 1)
+                # exclude the frames that are already sampled
+                allFrames = np.array([x for x in allFrames if x not in sampled], np.int32)
+                # sample the remaining frames
+                targetsIdx[i, K:] = np.random.choice(allFrames, T - K, replace=True)
+            continue
+        # targetsIdx contains the indexes of the frames to be inpainted/reconstructed
+        # we need to collect the data for the Y
+        # required data: points, left eye, right eye, time, target point
+        Y = {
+            'points': np.zeros((B, T, FACE_MESH_POINTS, 2), np.float32),
+            'left eye': np.zeros((B, T, 32, 32), np.float32),
+            'right eye': np.zeros((B, T, 32, 32), np.float32),
+            'time': np.zeros((B, T, 1), np.float32),
+            'target': np.zeros((B, T, 2), np.float32)
+        }
+        for i in range(B):
+            idxForSample = samples_indexes[i]
+            # stricly increasing indexes
+            assert np.all(0 < np.diff(idxForSample)), 'Invalid indexes: ' + str(idxForSample)
+            startT = self._storage[idxForSample[0]]['time']
+            endT = self._storage[idxForSample[-1]]['time']
+            duration = endT - startT
+            assert 0 < duration, 'Invalid duration: ' + str(duration)
+            targets_idx = np.sort(targetsIdx[i])
+            for j, idx in enumerate(targets_idx):
+                data = self._storage[idx]
+                Y['points'][i, j] = data['points']
+                # eyes should be cropped to 32x32, so we use the central crop
+                p = (data['left eye'].shape[0] - 32) // 2
+                Y['left eye'][i, j] = data['left eye'][p:p+32, p:p+32]
+                Y['right eye'][i, j] = data['right eye'][p:p+32, p:p+32]
+                Y['time'][i, j] = (data['time'] - startT) / duration
+                Y['target'][i, j] = data['goal']
+        # eyes in 0..255, so we need to normalize them
+        Y['left eye'] /= 255.0
+        Y['right eye'] /= 255.0
+        # check that time is between 0 and 1
+        assert np.all((0 <= Y['time']) & (Y['time'] <= 1)), 'Invalid time: ' + str(Y['time'])
+        for k, v in X.items():
+            assert B == v.shape[0], f'Invalid batch size for X[{k}]: {v.shape[0]} != {B} ({v.shape})'
+        for k, v in Y.items():
+            assert B == v.shape[0], f'Invalid batch size for Y[{k}]: {v.shape[0]} != {B} ({v.shape})'
+        return (X, Y), B
+    
+    def merge(self, samples, expected_batch_size):
+        X = {}
+        for subkey in ['points', 'left eye', 'right eye', 'time', 'userId', 'placeId', 'screenId', 'target']:
+            data = [x[subkey] for x, _ in samples]
+            X[subkey] = np.concatenate(data, axis=0)
+            assert X[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (X[subkey].shape[0], expected_batch_size)
+            continue
+        # 
+        Y = {}
+        for subkey in ['points', 'left eye', 'right eye', 'time', 'target']:
+            data = [y[subkey] for _, y in samples]
+            Y[subkey] = np.concatenate(data, axis=0)
+            assert Y[subkey].shape[0] == expected_batch_size, 'Invalid batch size: %d != %d' % (Y[subkey].shape[0], expected_batch_size)
+            continue
+        return (X, Y)
\ No newline at end of file
diff --git a/Core/CDataSampler_utils.py b/Core/CDataSampler_utils.py
index ef85cb2..90fa132 100644
--- a/Core/CDataSampler_utils.py
+++ b/Core/CDataSampler_utils.py
@@ -109,7 +109,7 @@ def toTensor(data, params, userId, placeId, screenId):
   ##########################
   # random crop 32x32 eyes
   fraction = 32.0 / 48.0
-  pos = tf.random.uniform((N, 2), minval=0.0, maxval=1.0 - fraction)
+  pos = tf.random.uniform((N, 2), minval=0.0, maxval=2.0 * fraction)
   boxes = tf.concat([pos, pos + fraction], axis=-1)
   tf.assert_equal(tf.shape(boxes), (N, 4))
   imgA = tf.image.crop_and_resize(
diff --git a/Core/CDatasetLoader.py b/Core/CDatasetLoader.py
index a526503..d0c3733 100644
--- a/Core/CDatasetLoader.py
+++ b/Core/CDatasetLoader.py
@@ -1,9 +1,7 @@
 import Core.Utils as Utils
-import os, glob
+import os
 from Core.CSamplesStorage import CSamplesStorage
-from Core.CDataSampler import CDataSampler
 import numpy as np
-import tensorflow as tf
 from enum import Enum
 
 class ESampling(Enum):
@@ -11,38 +9,39 @@ class ESampling(Enum):
   UNIFORM = 'uniform'
   
 class CDatasetLoader:
-  def __init__(self, folder, samplerArgs, sampling, stats):
-    # recursively find all 'train.npz' files
-    trainFiles = glob.glob(os.path.join(folder, '**', 'train.npz'), recursive=True)
-    if 0 == len(trainFiles):
-      raise Exception('No training dataset found in "%s"' % (folder, ))
-      exit(1)
-    
-    print('Found %d training datasets' % (len(trainFiles), ))
-
+  def __init__(self, folder, samplerArgs, sampling, stats, sampler_class, test_folders):
     self._datasets = []
-    for trainFile in trainFiles:
-      print('Loading %s' % (trainFile, ))
-      # extract the placeId, userId, and screenId
-      parts = os.path.split(trainFile)[0].split(os.path.sep)
-      placeId, userId, screenId = parts[-3], parts[-2], parts[-1]
-      ds = CDataSampler(
-        CSamplesStorage(
-          placeId=stats['placeId'].index(placeId),
-          userId=stats['userId'].index(userId),
-          screenId=stats['screenId'].index('%s/%s' % (placeId, screenId))
-        ),
-        **samplerArgs
-      )
-      ds.addBlock(Utils.datasetFrom(trainFile))
-      self._datasets.append(ds)
-      continue
+    for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
+      (place_id_index, user_id_index, screen_id_index) = ID
+      for test_folder in test_folders:
+        dataset = os.path.join(datasetFolder, test_folder)
+        if not os.path.exists(dataset):
+          continue
+        print('Loading %s' % (dataset, ))
+        print(f'ID: {ID}. Index: {1 + len(self._datasets)}')
+        ds = sampler_class(
+          CSamplesStorage(
+            placeId=place_id_index,
+            userId=user_id_index,
+            screenId=screen_id_index,
+          ),
+          **samplerArgs
+        )
+        ds.addBlock(Utils.datasetFrom(dataset))
+        self._datasets.append(ds)
+
+    if 0 == len(self._datasets):
+      raise Exception('No training dataset found in "%s"' % (folder, ))
     
-    print('Loaded %d datasets' % (len(self._datasets), ))
     validSamples = {
       i: len(ds.validSamples())
       for i, ds in enumerate(self._datasets)
     }
+    # ignore datasets with no valid samples
+    validSamples = {k: v for k, v in validSamples.items() if 0 < v}
+
+    print('Loaded %d datasets with %d valid samples' % (len(self._datasets), sum(validSamples.values())))
+
     dtype = np.uint8 if len(self._datasets) < 256 else np.uint32
     # create an array of dataset indices to sample from
     sampling = ESampling(sampling)
@@ -98,45 +97,18 @@ def _getBatchStats(self, batchSize):
 
   def sample(self, **kwargs):
     batchSize = kwargs.get('batch_size', self._batchSize)
-    resX = []
-    resY = []
+    samples = []
     totalSamples = 0
     # find the datasets ids and the number of samples to take from each dataset
-    datasetIds, counts = self._getBatchStats(batchSize)
-    for datasetId, N in zip(datasetIds, counts):
-      dataset = self._datasets[datasetId]
-      x, (y, ) = dataset.sample(N=N, **kwargs)
-      assert len(y) == N, 'Invalid number of samples: %d != %d' % (len(y), N)
-      resX.append(x)
-      resY.append(y)
-      totalSamples += len(y)
-      continue
+    while totalSamples < batchSize:
+      datasetIds, counts = self._getBatchStats(batchSize - totalSamples)
+      for datasetId, N in zip(datasetIds, counts):
+        dataset = self._datasets[datasetId]
+        sampled, N = dataset.sample(N=N, **kwargs)
+        if 0 < N:
+          samples.append(sampled)
+        totalSamples += N
+        continue
     
-    resY = np.concatenate(resY, axis=0)
-    assert resY.shape[0] == batchSize, 'Invalid shape: %s' % (resY.shape, )
-    assert resY.shape[-1] == 2, 'Invalid shape: %s' % (resY.shape, )
-    assert len(resY.shape) == 4, 'Invalid shape: %s' % (resY.shape, )
-
-    # sampled data has 'clean' and 'augmented' keys
-    output = {}
-    for nm in ['clean', 'augmented']:
-      keys = resX[0][nm].keys()
-      output[nm] = {k: tf.concat([x[nm][k] for x in resX], axis=0) for k in keys}
-      continue
-    return output, (resY,)
-  
-if __name__ == '__main__':
-  import cv2
-  folder = os.path.dirname(__file__)
-  ds = CDatasetLoader(
-    os.path.join(folder, 'Dataset'), batch_size=16, 
-    batchPerEpoch=1, steps=5
-  )
-  print(len(ds))
-  batchX, batchY = ds[0]
-  print(batchY[0].shape)
-  print(batchX[1].shape)
-  img = batchX[1][0, 0]
-  cv2.imshow('L', cv2.resize(img, (256, 256)))
-  cv2.waitKey()
-  pass
\ No newline at end of file
+    first_dataset = self._datasets[0]
+    return first_dataset.merge(samples, batchSize)
diff --git a/Core/CInpaintingTrainer.py b/Core/CInpaintingTrainer.py
new file mode 100644
index 0000000..2d457ae
--- /dev/null
+++ b/Core/CInpaintingTrainer.py
@@ -0,0 +1,106 @@
+import tensorflow as tf
+import time
+import NN.Utils as NNU
+import NN.networks as networks
+from Core.CBaseModel import CBaseModel
+
+class CInpaintingTrainer:
+  def __init__(self, timesteps, model='simple', KP=5, **kwargs):
+    stats = kwargs.get('stats', None)
+    embeddingsSize = kwargs.get('embeddingsSize', 64)
+    latentSize = kwargs.get('latentSize', 64)
+    embeddings = {
+      'userId': len(stats['userId']),
+      'placeId': len(stats['placeId']),
+      'screenId': len(stats['screenId']),
+      'size': embeddingsSize,
+    }
+
+    self._encoder = networks.InpaintingEncoderModel(
+      steps=timesteps, latentSize=latentSize,
+      embeddingsSize=embeddingsSize,
+      KP=KP,
+    )
+    self._decoder = networks.InpaintingDecoderModel(
+      latentSize=latentSize,
+      embeddingsSize=embeddingsSize,
+      KP=KP,
+    )
+    self._model = CBaseModel(
+       model=model, embeddings=embeddings, submodels=[self._encoder, self._decoder]
+    )
+    self.compile()
+    # add signatures to help tensorflow optimize the graph
+    specification = networks.InpaintingInputSpec()
+    self._trainStep = tf.function(
+      self._trainStep,
+      input_signature=[specification]
+    )
+    self._eval = tf.function(
+      self._eval,
+      input_signature=[specification]
+    )
+
+    if 'weights' in kwargs:
+      self.load(**kwargs['weights'])
+    return
+  
+  def compile(self):
+    self._optimizer = NNU.createOptimizer()
+
+  def _calcLoss(self, x, y, training):
+    losses = {}
+    x = self._model.replaceByEmbeddings(x)
+    latents = self._encoder(x, training=training)['latent']
+    decoderArgs = {
+      'keyPoints': latents,
+      'time': y['time'],
+      'userId': x['userId'],
+      'placeId': x['placeId'],
+      'screenId': x['screenId'],
+    }
+    predictions = self._decoder(decoderArgs, training=training)
+    losses = {}
+    for k in predictions.keys():
+      pred = predictions[k]
+      gt = y[k]
+      tf.assert_equal(tf.shape(pred), tf.shape(gt))
+      loss = tf.losses.mse(gt, pred)
+      losses[f"loss-{k}"] = tf.reduce_mean(loss)
+      
+    # calculate total loss and final loss
+    losses['loss'] = sum(losses.values())
+    return losses, losses['loss']
+  
+  def _trainStep(self, Data):
+    print('Instantiate _trainStep')
+    ###############
+    x, y = Data
+    with tf.GradientTape() as tape:
+      losses, loss = self._calcLoss(x, y, training=True)
+  
+    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
+    ###############
+    return losses
+
+  def fit(self, data):
+    t = time.time()
+    losses = self._trainStep(data)
+    losses = {k: v.numpy() for k, v in losses.items()}
+    return {'time': int((time.time() - t) * 1000), 'losses': losses}
+  
+  def _eval(self, xy):
+    print('Instantiate _eval')
+    x, y = xy
+    losses, loss = self._calcLoss(x, y, training=False)
+    return loss
+
+  def eval(self, data):
+    loss = self._eval(data)
+    return loss.numpy()
+    
+  def save(self, folder=None, postfix=''):
+    self._model.save(folder=folder, postfix=postfix)
+
+  def load(self, folder=None, postfix='', embeddings=False):
+    self._model.load(folder=folder, postfix=postfix, embeddings=embeddings)
\ No newline at end of file
diff --git a/Core/CModelDiffusion.py b/Core/CModelDiffusion.py
deleted file mode 100644
index 20a9259..0000000
--- a/Core/CModelDiffusion.py
+++ /dev/null
@@ -1,272 +0,0 @@
-import os
-import numpy as np
-import NN.networks as networks
-import tensorflow as tf
-import tensorflow_probability as tfp
-import NN.Utils as NNU
-import time
-from tensorflow.keras import layers as L
-
-# TODO: Implement the standard diffusion process (with the prediction of the noise, proper sampling, etc)
-class CModelDiffusion:
-  '''
-  Wrapper for the diffusion model to predict the gaze point
-  Diffusion T is equal to the stddev of the gaussian noise
-  '''
-  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=False, **kwargs):
-    if user is None:
-      user = {
-        'userId': 0,
-        'placeId': 0,
-        'screenId': 0,
-      }
-    else:
-      user = {
-        'userId': stats['userId'].index(user['userId']),
-        'placeId': stats['placeId'].index(user['placeId']),
-        'screenId': stats['screenId'].index(user['screenId']),
-      }
-    self._user = user
-
-    self._modelID = model
-    self._timesteps = timesteps
-    embeddings = {
-      'userId': len(stats['userId']),
-      'placeId': len(stats['placeId']),
-      'screenId': len(stats['screenId']),
-      'size': 64,
-    }
-    self._modelRaw = networks.Face2LatentModel(
-      steps=timesteps, latentSize=64, embeddings=embeddings,
-      diffusion=True
-    )
-    self._model = self._modelRaw['main']
-    self._embeddings = {
-      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
-      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
-      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
-    }
-    self._intermediateEncoders = {}
-    if use_encoders:
-      shapes = self._modelRaw['intermediate shapes']
-      for name, shape in shapes.items():
-        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
-        enc.build(shape)
-        self._intermediateEncoders[name] = enc
-        continue
-   
-    self._maxDiffusionT = 100.0
-    if 'weights' in kwargs:
-      self.load(**kwargs['weights'])
-    self.compile()
-    # add signatures to help tensorflow optimize the graph
-    specification = self._modelRaw['inputs specification']
-    self._trainStep = tf.function(
-      self._trainStep,
-      input_signature=[
-        (
-          { 'clean': specification, 'augmented': specification, },
-          ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
-        )
-      ]
-    )
-    self._eval = tf.function(
-      self._eval,
-      input_signature=[(
-        specification,
-        ( tf.TensorSpec(shape=(None, None, None, 2), dtype=tf.float32), )
-      )]
-    )
-
-    return
-  
-  def _step2mean(self, step):
-    step = tf.cast(step, tf.float32) / self._maxDiffusionT
-    step = tf.cast(step, tf.float32) + 1e-6
-    # step = tf.pow(step, 2.0) # make it decrease faster
-    return tf.clip_by_value(step, 1e-3, 1.0)
-  
-  def _replaceByEmbeddings(self, data):
-    data = dict(**data) # copy
-    for name, emb in self._embeddings.items():
-      data[name] = emb(data[name][..., 0])
-      continue
-    return data
-  
-  def _makeGaussian(self, mean, stddev):
-    stddev = tf.concat([stddev, stddev], axis=-1)
-    return tfp.distributions.MultivariateNormalDiag(mean, stddev)
-  
-  @tf.function
-  def _infer(self, data, training=False):
-    print('Instantiate _infer')
-    data = self._replaceByEmbeddings(data)
-    shp = tf.shape(data['userId'])
-    B, N = shp[0], self.timesteps
-    result = tf.zeros((B, N, 2), dtype=tf.float32)
-    for step in tf.range(self._maxDiffusionT, -1, -5):
-      mean = self._step2mean(
-        tf.fill((B, N, 1), step)
-      )
-      stepData = dict(**data)
-      stepData['diffusionT'] = mean
-      stepData['diffusionPoints'] = tf.random.normal((B, N, 2), mean=result, stddev=mean)
-      result = self._model(stepData, training=training)['result']
-    return result
-    
-  def predict(self, data, **kwargs):
-    B = self._timesteps
-    userId = kwargs.get('userId', self._user['userId'])
-    placeId = kwargs.get('placeId', self._user['placeId'])
-    screenId = kwargs.get('screenId', self._user['screenId'])
-    # put them as (1, B, ?)
-    data['userId'] = np.full((1, B, 1), userId, dtype=np.int32)
-    data['placeId'] = np.full((1, B, 1), placeId, dtype=np.int32)
-    data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)
-
-    data = self._replaceByEmbeddings(data) # replace embeddings
-    
-    result = self._infer(data)
-    return result.numpy()
-  
-  def __call__(self, data, startPos=None):
-    predictions = self.predict(data)
-    return {
-      'coords': predictions[0, -1, :],
-    }
-    
-  def compile(self):
-    self._optimizer = NNU.createOptimizer()
-    return
-
-  def _modelFilename(self, folder, postfix=''):
-    postfix = '-' + postfix if postfix else ''
-    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))
-  
-  def save(self, folder=None, postfix=''):
-    path = self._modelFilename(folder, postfix)
-    self._model.save_weights(path)
-    embeddings = {}
-    for nm in self._embeddings.keys():
-      weights = self._embeddings[nm].get_weights()[0]
-      embeddings[nm] = weights
-      continue
-    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
-    # save intermediate encoders
-    if self._intermediateEncoders:
-      encoders = {}
-      for nm, encoder in self._intermediateEncoders.items():
-        # save each variable separately
-        for ww in encoder.trainable_variables:
-          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
-        continue
-      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
-    return
-    
-  def load(self, folder=None, postfix='', embeddings=False):
-    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
-    self._model.load_weights(path)
-    if embeddings:
-      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
-      for nm, emb in self._embeddings.items():
-        w = embeddings[nm]
-        if not emb.built: emb.build((None, w.shape[0]))
-        emb.set_weights([w]) # replace embeddings
-        continue
-    
-    if self._intermediateEncoders:
-      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
-      if os.path.isfile(encodersName):
-        encoders = np.load(encodersName)
-        for nm, encoder in self._intermediateEncoders.items():
-          for ww in encoder.trainable_variables:
-            w = encoders['%s-%s' % (nm, ww.name)]
-            ww.assign(w)
-          continue
-    return
-  
-  def lock(self, isLocked):
-    self._model.trainable = not isLocked
-    return
-  
-  @property
-  def timesteps(self):
-    return self._timesteps
-  
-  def trainable_variables(self):
-    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
-    return sum([p.trainable_variables for p in parts], [])  
-  
-  def _pointLoss(self, ytrue, ypred):
-    # pseudo huber loss
-    delta = 0.01
-    tf.assert_equal(tf.shape(ytrue), tf.shape(ypred))
-    diff = tf.square(ytrue - ypred)
-    loss = tf.sqrt(diff + delta ** 2) - delta
-    tf.assert_equal(tf.shape(loss), tf.shape(ytrue))
-    return tf.reduce_mean(loss, axis=-1)
-
-  def _trainStep(self, Data):
-    print('Instantiate _trainStep')
-    ###############
-    x, (y, ) = Data
-    y = y[..., 0, :]
-    losses = {}
-    with tf.GradientTape() as tape:
-      data = x['augmented']
-      data = self._replaceByEmbeddings(data)
-      # add sampled T
-      B = tf.shape(y)[0]
-      N = self.timesteps
-      maxT = 100
-      diffusionT = tf.random.uniform((B, 1), minval=0, maxval=maxT, dtype=tf.int32)
-      # (B, 1) -> (B, N, 1)
-      diffusionT = tf.tile(diffusionT, (1, N))[..., None]
-      diffusionT = self._step2mean(diffusionT)
-      tf.assert_equal(tf.shape(diffusionT), (B, N, 1))
-      
-      # store the diffusion parameters
-      data['diffusionT'] = diffusionT
-      # sample the points
-      data['diffusionPoints'] = tf.random.normal((B, N, 2), mean=y, stddev=diffusionT)
-      predictions = self._model(data, training=True)
-    #   intermediate = predictions['intermediate']
-    #   assert len(intermediate) == 0, 'Intermediate predictions are not supported'
-      
-      predictedMean = predictions['result']
-      gaussian = self._makeGaussian(predictedMean, diffusionT)
-      losses['log_prob'] = tf.reduce_mean(
-        -gaussian.log_prob(y)
-      )
-      losses['points'] = self._pointLoss(y, predictedMean)
-      loss = sum(losses.values())
-      losses['loss'] = loss
-  
-    self._optimizer.minimize(loss, tape.watched_variables(), tape=tape)
-    ###############
-    return losses
-
-  def fit(self, data):
-    t = time.time()
-    losses = self._trainStep(data)
-    losses = {k: v.numpy() for k, v in losses.items()}
-    return {'time': int((time.time() - t) * 1000), 'losses': losses}
-  
-  def _eval(self, xy):
-    print('Instantiate _eval')
-    x, (y,) = xy
-    y = y[:, :, 0]
-    B, N = tf.shape(y)[0], tf.shape(y)[1]
-    
-    predictions = self._infer(x)
-    
-    mean = self._step2mean(tf.fill((B, N, 1), 0))
-    gaussian = self._makeGaussian(predictions, mean)
-    loss = tf.nn.sigmoid( -gaussian.log_prob(y) )
-    points = predictions
-    _, dist = NNU.normVec(y - predictions)
-    return loss, points, dist
-
-  def eval(self, data):
-    loss, sampled, dist = self._eval(data)
-    return loss.numpy(), sampled.numpy(), dist.numpy()
\ No newline at end of file
diff --git a/Core/CModelWrapper.py b/Core/CModelWrapper.py
index 2adea8e..e63da2d 100644
--- a/Core/CModelWrapper.py
+++ b/Core/CModelWrapper.py
@@ -1,63 +1,34 @@
-import os
 import numpy as np
 import NN.networks as networks
-import tensorflow as tf
-from tensorflow.keras import layers as L
-
+from Core.CBaseModel import CBaseModel
+  
 class CModelWrapper:
-  def __init__(self, timesteps, model='simple', user=None, stats=None, use_encoders=True, **kwargs):
-    if user is None:
-      user = {
-        'userId': 0,
-        'placeId': 0,
-        'screenId': 0,
-      }
-    else:
+  def __init__(self, timesteps, model='simple', user=None, stats=None, **kwargs):
+    if user is not None:
       user = {
         'userId': stats['userId'].index(user['userId']),
         'placeId': stats['placeId'].index(user['placeId']),
         'screenId': stats['screenId'].index(user['screenId']),
       }
     self._user = user
-
-    self._modelID = model
     self._timesteps = timesteps
     embeddings = {
       'userId': len(stats['userId']),
       'placeId': len(stats['placeId']),
       'screenId': len(stats['screenId']),
-      'size': 64,
+      'size': kwargs.get('embeddingSize', 64),
     }
     self._modelRaw = networks.Face2LatentModel(
-      steps=timesteps, latentSize=64, embeddings=embeddings
+      steps=timesteps, latentSize=kwargs.get('latentSize', 64), embeddings=embeddings
     )
-    self._model = self._modelRaw['main']
-    self._embeddings = {
-      'userId': L.Embedding(len(stats['userId']), embeddings['size']),
-      'placeId': L.Embedding(len(stats['placeId']), embeddings['size']),
-      'screenId': L.Embedding(len(stats['screenId']), embeddings['size']),
-    }
-    self._intermediateEncoders = {}
-    if use_encoders:
-      shapes = self._modelRaw['intermediate shapes']
-      for name, shape in shapes.items():
-        enc = networks.IntermediatePredictor(name='%s-encoder' % name)
-        enc.build(shape)
-        self._intermediateEncoders[name] = enc
-        continue
-   
+    NN =  self._network = self._modelRaw['main']
+    self._model = CBaseModel(model=model, embeddings=embeddings, submodels=[NN])
     if 'weights' in kwargs:
       self.load(**kwargs['weights'])
     return
   
-  def _replaceByEmbeddings(self, data):
-    data = dict(**data) # copy
-    for name, emb in self._embeddings.items():
-      data[name] = emb(data[name][..., 0])
-      continue
-    return data
-  
   def predict(self, data, **kwargs):
+    assert self._user is not None, 'User is not set'
     B = self._timesteps
     userId = kwargs.get('userId', self._user['userId'])
     placeId = kwargs.get('placeId', self._user['placeId'])
@@ -68,68 +39,21 @@ def predict(self, data, **kwargs):
     data['screenId'] = np.full((1, B, 1), screenId, dtype=np.int32)
 
     data = self._replaceByEmbeddings(data) # replace embeddings
-    return self._model(data, training=False)['result'].numpy()
+    return self._network(data, training=False)['result'].numpy()
   
   def __call__(self, data, startPos=None):
     predictions = self.predict(data)
-    return {
-      'coords': predictions[0, -1, :],
-    }
-
-  def _modelFilename(self, folder, postfix=''):
-    postfix = '-' + postfix if postfix else ''
-    return os.path.join(folder, '%s-%s%s.h5' % (self._modelID, 'model', postfix))
-  
-  def save(self, folder=None, postfix=''):
-    path = self._modelFilename(folder, postfix)
-    self._model.save_weights(path)
-    embeddings = {}
-    for nm in self._embeddings.keys():
-      weights = self._embeddings[nm].get_weights()[0]
-      embeddings[nm] = weights
-      continue
-    np.savez_compressed(path.replace('.h5', '-embeddings.npz'), **embeddings)
-    # save intermediate encoders
-    if self._intermediateEncoders:
-      encoders = {}
-      for nm, encoder in self._intermediateEncoders.items():
-        # save each variable separately
-        for ww in encoder.trainable_variables:
-          encoders['%s-%s' % (nm, ww.name)] = ww.numpy()
-        continue
-      np.savez_compressed(path.replace('.h5', '-intermediate-encoders.npz'), **encoders)
-    return
-    
-  def load(self, folder=None, postfix='', embeddings=False):
-    path = self._modelFilename(folder, postfix) if not os.path.isfile(folder) else folder
-    self._model.load_weights(path)
-    if embeddings:
-      embeddings = np.load(path.replace('.h5', '-embeddings.npz'))
-      for nm, emb in self._embeddings.items():
-        w = embeddings[nm]
-        if not emb.built: emb.build((None, w.shape[0]))
-        emb.set_weights([w]) # replace embeddings
-        continue
-    
-    if self._intermediateEncoders:
-      encodersName = path.replace('.h5', '-intermediate-encoders.npz')
-      if os.path.isfile(encodersName):
-        encoders = np.load(encodersName)
-        for nm, encoder in self._intermediateEncoders.items():
-          for ww in encoder.trainable_variables:
-            w = encoders['%s-%s' % (nm, ww.name)]
-            ww.assign(w)
-          continue
-    return
-  
-  def lock(self, isLocked):
-    self._model.trainable = not isLocked
-    return
+    return { 'coords': predictions[0, -1, :], }
   
   @property
   def timesteps(self):
     return self._timesteps
   
+  def save(self, folder=None, postfix=''):
+    self._model.save(folder=folder, postfix=postfix)
+
+  def load(self, folder=None, postfix='', embeddings=False):
+    self._model.load(folder=folder, postfix=postfix, embeddings=embeddings)
+
   def trainable_variables(self):
-    parts = list(self._embeddings.values()) + [self._model] + list(self._intermediateEncoders.values())
-    return sum([p.trainable_variables for p in parts], [])
\ No newline at end of file
+    return self._model.trainable_variables()
\ No newline at end of file
diff --git a/Core/CTestInpaintingLoader.py b/Core/CTestInpaintingLoader.py
new file mode 100644
index 0000000..f67a86b
--- /dev/null
+++ b/Core/CTestInpaintingLoader.py
@@ -0,0 +1,34 @@
+import tensorflow as tf
+import numpy as np
+import os, glob
+from functools import lru_cache
+
+class CTestInpaintingLoader(tf.keras.utils.Sequence):
+  def __init__(self, testFolder):
+    self._batchesNpz = [
+      f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
+    ]
+    self.on_epoch_end()
+    return
+  
+  @lru_cache(maxsize=1)
+  def parametersIDs(self):
+    batch, _ = self[0]
+    userId = batch['userId'][0, 0, 0]
+    placeId = batch['placeId'][0, 0, 0]
+    screenId = batch['screenId'][0, 0, 0]
+    return placeId, userId, screenId
+    
+  def on_epoch_end(self):
+    return
+
+  def __len__(self):
+    return len(self._batchesNpz)
+  
+  def __getitem__(self, idx):
+    with np.load(self._batchesNpz[idx]) as res:
+      res = {k: v for k, v in res.items()}
+      
+    X = {k.replace('X_', ''): v for k, v in res.items() if 'X_' in k}
+    Y = {k.replace('Y_', ''): v for k, v in res.items() if 'Y_' in k}
+    return(X, Y)
\ No newline at end of file
diff --git a/Core/CTestLoader.py b/Core/CTestLoader.py
index f3bc23d..be11247 100644
--- a/Core/CTestLoader.py
+++ b/Core/CTestLoader.py
@@ -5,25 +5,20 @@
 
 class CTestLoader(tf.keras.utils.Sequence):
   def __init__(self, testFolder):
-    self._folder = testFolder
     self._batchesNpz = [
       f for f in glob.glob(os.path.join(testFolder, 'test-*.npz'))
     ]
     self.on_epoch_end()
     return
-  
-  @property
-  def folder(self):
-    return self._folder
-  
+
   @lru_cache(maxsize=1)
   def parametersIDs(self):
     batch, _ = self[0]
     userId = batch['userId'][0, 0, 0]
     placeId = batch['placeId'][0, 0, 0]
     screenId = batch['screenId'][0, 0, 0]
-    return userId, placeId, screenId
-    
+    return placeId, userId, screenId
+      
   def on_epoch_end(self):
     return
 
@@ -35,15 +30,4 @@ def __getitem__(self, idx):
       res = {k: v for k, v in res.items()}
       
     Y = res.pop('y')
-    return(res, (Y, ))
-
-if __name__ == '__main__':
-  folder = os.path.dirname(__file__)
-  ds = CTestLoader(os.path.join(folder, 'test'))
-  print(len(ds))
-  batch, (y,) = ds[0]
-  for k, v in batch.items():
-    print(k, v.shape)
-  print()
-  print(y.shape)
-  pass
\ No newline at end of file
+    return(res, (Y, ))
\ No newline at end of file
diff --git a/Core/Utils.py b/Core/Utils.py
index e7d5c54..3173fcf 100644
--- a/Core/Utils.py
+++ b/Core/Utils.py
@@ -191,8 +191,9 @@ def setupGPU():
       PART_TO_INDECES[k].add(i)
 ###################################
 FACE_MESH_INVALID_VALUE = -10.0
+FACE_MESH_POINTS = 478
 def decodeLandmarks(landmarks, VISIBILITY_THRESHOLD, PRESENCE_THRESHOLD):
-  points = np.full((478, 2), fill_value=FACE_MESH_INVALID_VALUE, dtype=np.float32)
+  points = np.full((FACE_MESH_POINTS, 2), fill_value=FACE_MESH_INVALID_VALUE, dtype=np.float32)
   for idx, mark in enumerate(landmarks.landmark):
     if (
       (mark.HasField('visibility') and (mark.visibility < VISIBILITY_THRESHOLD)) 
@@ -271,6 +272,7 @@ def extractSessions(dataset, TDelta):
   res = []
   T = 0
   prevSession = 0
+  N = len(dataset['time'])
   for i, t in enumerate(dataset['time']):
     if TDelta < (t - T):
       if 1 < (i - prevSession):
@@ -280,10 +282,13 @@ def extractSessions(dataset, TDelta):
     T = t
     continue
   # if last session is not empty, then append it
-  if prevSession < len(dataset['time']):
-    res.append((prevSession, len(dataset['time'])))
+  if prevSession < N:
+    res.append((prevSession, N))
     pass
 
+  # remove sessions with less than 2 samples
+  res = [x for x in res if 1 < (x[1] - x[0])]
+
   # check that end of one session is equal or less than start of the next
   for i in range(1, len(res)):
     assert res[i-1][1] <= res[i][0]
@@ -296,4 +301,25 @@ def countSamplesIn(folder):
     with np.load(fn) as data:
       res += len(data['time'])
     continue
-  return res
\ No newline at end of file
+  return res
+
+def dataset_from_stats(stats, folder):
+  userId = stats['userId']
+  placeId = stats['placeId']
+  screenId = stats['screenId']
+  # screenId is a concatenation of placeId and screenId, to make it unique pair
+  PlaceAndScreenId = [x.split('/') for x in screenId]
+
+  blackList = stats.get('blacklist', [])
+  known = set([tuple(x) for x in blackList])
+  for screen_id_index, (place_id, screen_id) in enumerate(PlaceAndScreenId):
+    place_id_index = placeId.index(place_id)
+    # find user_id among all
+    for user_id_index, user_id in enumerate(userId):
+      datasetFolder = os.path.join(folder, place_id, user_id, screen_id)
+      if not os.path.exists(datasetFolder): continue
+      ID = (place_id_index, user_id_index, screen_id_index)
+      if ID in known: continue
+      known.add(ID)
+
+      yield (datasetFolder, ID)
\ No newline at end of file
diff --git a/NN/LagrangianInterpolation.py b/NN/LagrangianInterpolation.py
new file mode 100644
index 0000000..88b2942
--- /dev/null
+++ b/NN/LagrangianInterpolation.py
@@ -0,0 +1,63 @@
+import tensorflow as tf
+
+def lagrange_interpolation(x_values, y_values, x_targets):
+    """
+    Perform Lagrange Polynomial Interpolation using TensorFlow with batch support
+    and multidimensional y_values.
+
+    Parameters:
+    - x_values: Tensor of shape (batch_size, n), original x-values for each batch.
+    - y_values: Tensor of shape (batch_size, n, d), original y-values for each batch.
+    - x_targets: Tensor of shape (batch_size, m), x-values where interpolation is computed.
+
+    Returns:
+    - interpolated_values: Tensor of shape (batch_size, m, d), interpolated y-values for each batch.
+    """
+    # add extra points at -1 and 2, to smooth the interpolation at the edges
+    ones = tf.ones_like(x_values[:, :1])
+    x_values = tf.concat([ones * -1, x_values, ones * 2], axis=1)
+    y_values = tf.concat([y_values[:, :1], y_values, y_values[:, -1:]], axis=1)
+
+    batch_size = tf.shape(x_values)[0]
+    n = tf.shape(x_values)[1]
+    m = tf.shape(x_targets)[-1]
+    d = tf.shape(y_values)[2]
+
+    tf.assert_equal(tf.shape(x_values), (batch_size, n))
+    tf.assert_equal(tf.shape(y_values), (batch_size, n, d))
+    tf.assert_equal(tf.shape(x_targets), (batch_size, m))
+    # Reshape tensors for broadcasting
+    x_values_i = tf.reshape(x_values, (batch_size, n, 1, 1))        # Shape: (batch_size, n, 1, 1)
+    x_values_j = tf.reshape(x_values, (batch_size, 1, n, 1))        # Shape: (batch_size, 1, n, 1)
+
+    x_targets_k = tf.reshape(x_targets, (batch_size, 1, 1, m))  # Shape: (batch_size, 1, 1, m)
+
+    # Compute the denominators (x_i - x_j)
+    denominators = x_values_i - x_values_j                          # Shape: (batch_size, n, n, 1)
+    # Replace zeros on the diagonal with ones to avoid division by zero
+    denominators = tf.where(tf.equal(denominators, 0.0), tf.ones_like(denominators), denominators)
+
+    # Compute the numerators (x_k - x_j)
+    numerators = x_targets_k - x_values_j                           # Shape: (batch_size, 1, n, m)
+
+    # Compute the terms (x_k - x_j) / (x_i - x_j)
+    terms = numerators / denominators                               # Shape: (batch_size, n, n, m)
+
+    # Exclude the terms where i == j by setting them to 1
+    identity_matrix = tf.eye(n, batch_shape=[batch_size], dtype=tf.float64)  # Shape: (batch_size, n, n)
+    identity_matrix = tf.reshape(identity_matrix, (batch_size, n, n, 1))     # Shape: (batch_size, n, n, 1)
+    terms = tf.where(tf.equal(identity_matrix, 1.0), tf.ones_like(terms), terms)
+
+    # Compute the product over j for each i and x_k
+    basis_polynomials = tf.reduce_prod(terms, axis=2)               # Shape: (batch_size, n, m)
+
+    # Multiply each basis polynomial by the corresponding y_i
+    # Adjust shapes for broadcasting
+    basis_polynomials_expanded = tf.expand_dims(basis_polynomials, axis=-1)  # Shape: (batch_size, n, m, 1)
+    y_values_expanded = tf.expand_dims(y_values, axis=2)                     # Shape: (batch_size, n, 1, d)
+    products = basis_polynomials_expanded * y_values_expanded                # Shape: (batch_size, n, m, d)
+
+    # Sum over i to get the interpolated values
+    interpolated_values = tf.reduce_sum(products, axis=1)                    # Shape: (batch_size, m, d)
+
+    return interpolated_values
\ No newline at end of file
diff --git a/NN/networks.py b/NN/networks.py
index 9e0f947..30afcea 100644
--- a/NN/networks.py
+++ b/NN/networks.py
@@ -1,4 +1,4 @@
-from Core.Utils import setupGPU
+from Core.Utils import setupGPU, FACE_MESH_POINTS
 setupGPU() # dirty hack to setup GPU memory limit on startup
 
 import tensorflow as tf
@@ -7,7 +7,7 @@
 from NN.Utils import *
 from NN.EyeEncoder import eyeEncoder
 from NN.FaceMeshEncoder import FaceMeshEncoder
-import numpy as np
+from NN.LagrangianInterpolation import lagrange_interpolation
 
 class CTimeEncoderLayer(tf.keras.layers.Layer):
   def __init__(self, **kwargs):
@@ -136,7 +136,7 @@ def Step2LatentModel(latentSize, embeddingsSize):
 
 def _InputSpec():
   return {
-    'points': tf.TensorSpec(shape=(None, None, 478, 2), dtype=tf.float32),
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
     'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
     'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
     'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
@@ -146,7 +146,7 @@ def _InputSpec():
   }
 
 def Face2LatentModel(
-  pointsN=478, eyeSize=32, steps=None, latentSize=64,
+  pointsN=FACE_MESH_POINTS, eyeSize=32, steps=None, latentSize=64,
   embeddings=None,
   diffusion=False # whether to use diffusion model
 ):
@@ -219,14 +219,179 @@ def Face2LatentModel(
     'inputs specification': _InputSpec()
   }
   
-if __name__ == '__main__':
-  X = Face2LatentModel(steps=5, latentSize=64,
-    embeddings={
-      'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+##########################
+
+def InpaintingInputSpec():
+  XSpec = {
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
+    'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
+    'userId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'placeId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'screenId': tf.TensorSpec(shape=(None, None, 1), dtype=tf.int32),
+    'target': tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32),
+  }
+  YSpec = {
+    'points': tf.TensorSpec(shape=(None, None, FACE_MESH_POINTS, 2), dtype=tf.float32),
+    'left eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'right eye': tf.TensorSpec(shape=(None, None, 32, 32), dtype=tf.float32),
+    'time': tf.TensorSpec(shape=(None, None, 1), dtype=tf.float32),
+    'target': tf.TensorSpec(shape=(None, None, 2), dtype=tf.float32),
+  }
+  return (XSpec, YSpec)
+
+
+def InpaintingEncoderModel(latentSize, embeddingsSize, steps=5, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
+  points = L.Input((steps, pointsN, 2))
+  eyeL = L.Input((steps, eyeSize, eyeSize, 1))
+  eyeR = L.Input((steps, eyeSize, eyeSize, 1))
+  T = L.Input((steps, 1)) # accumulative time
+  target = L.Input((steps, 2))
+  userIdEmb = L.Input((steps, embeddingsSize))
+  placeIdEmb = L.Input((steps, embeddingsSize))
+  screenIdEmb = L.Input((steps, embeddingsSize))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])
+
+  Face2Step = Face2StepModel(pointsN, eyeSize, latentSize, embeddingsSize=emb.shape[-1])
+  stepsData = Face2Step({
+    'embeddings': emb,
+    'points': points,
+    'left eye': eyeL,
+    'right eye': eyeR,
+  })
+
+  diffT = T[:, 1:] - T[:, :-1]
+  diffT = L.Concatenate(-2)([tf.zeros_like(diffT[:, :1]), diffT])
+  combinedT = L.Concatenate(-1)([T, diffT])
+  encodedT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Time')(combinedT[..., None, :])[..., 0, :]
+
+  latent = stepsData['latent']
+  # add time encoding and target position
+  targetEncoded = CRolloutTimesteps(CCoordsEncodingLayer(32), name='Target')(target[..., None, :])[..., 0, :]
+  latent = L.Concatenate(-1)([latent, encodedT, targetEncoded])
+  # flatten the latent
+  latent = L.Reshape((-1,))(latent)
+
+  # compress the latent
+  latent_N = latent.shape[-1]
+  sizes = []
+  for i in range(1, 4):
+    for _ in range(i):
+       sizes.append(max(latent_N // i, latentSize))
+       
+  sizes.append(latentSize)
+  latent = sMLP(sizes=sizes, activation='relu', name='Compress')(latent)
+  keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+
+  # keyT shape: (B, KP, 1)
+  def transformKeyT(x):
+    t, x = x
+    B = tf.shape(x)[0]
+    return tf.tile(t, (B, 1))[..., None]
+  keyT = L.Lambda(transformKeyT)([keyT, latent])
+  # keyT shape: (B, KP, 1)
+  maxT = T[:, -1, None]
+  keyT = L.Concatenate(-1)([keyT, maxT * keyT]) # fractional time and absolute time
+  encodedKeyT = CRolloutTimesteps(CCoordsEncodingLayer(32), name='KeyTime')(keyT[..., None, :])[..., 0, :]
+
+  def combineKeys(x):
+    latent, keyT = x
+    latent = tf.tile(latent[..., None, :], (1, KP, 1))
+    return L.Concatenate(-1)([latent, keyT])
+  latent = L.Lambda(combineKeys)([latent, encodedKeyT])
+
+  latent = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineKeys')(latent)
+
+  main = tf.keras.Model(
+    inputs={
+      'points': points,
+      'left eye': eyeL,
+      'right eye': eyeR,
+      'time': T,
+      'target': target,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'latent': latent,
     }
   )
-  X['main'].summary(expand_nested=True)
-  X['Face2Step'].summary(expand_nested=False)
-  X['Step2Latent'].summary(expand_nested=False)
-  print(X['main'].outputs)
-  pass
\ No newline at end of file
+  return main
+ 
+def InpaintingDecoderModel(latentSize, embeddingsSize, pointsN=FACE_MESH_POINTS, eyeSize=32, KP=5):
+  latentKeyPoints = L.Input((KP, latentSize))
+  T = L.Input((None, 1))
+  userIdEmb = L.Input((None, embeddingsSize))
+  placeIdEmb = L.Input((None, embeddingsSize))
+  screenIdEmb = L.Input((None, embeddingsSize))
+
+  emb = L.Concatenate(-1)([userIdEmb, placeIdEmb, screenIdEmb])[:, :1, :]
+  # emb shape: (B, 1, 3 * embSize) 
+  def interpolateKeys(x):
+    latents, T = x
+    B = tf.shape(latents)[0]
+    keyT = tf.linspace(0.0, 1.0, KP)[None, :]
+    keyT = tf.tile(keyT, (B, 1))
+    return lagrange_interpolation(x_values=keyT, y_values=latents, x_targets=T[..., 0])
+  latents = L.Lambda(interpolateKeys, name='InterpolateKeys')([latentKeyPoints, T])
+  # latents shape: (B, N, latentSize)
+  def transformLatents(x):
+    latents, emb = x
+    N = tf.shape(latents)[1]
+    emb = tf.tile(emb, (1, N, 1))
+    return L.Concatenate(-1)([latents, emb])
+  latents = L.Lambda(transformLatents, name='CombineEmb')([latents, emb])
+  # process the latents
+  latents = sMLP(sizes=[latentSize] * 3, activation='relu', name='CombineEmb/MLP')(latents)
+  # decode the latents to the face points (FACE_MESH_POINTS, 2), two eyes (32, 32, 2) and the target (2) 
+  target = IntermediatePredictor(shift=0.5)(latents)
+  # two eyes
+  eyesN = eyeSize * eyeSize
+  eyes = sMLP(sizes=[eyesN] * 2, activation='relu')(latents)
+  eyes = L.Dense(eyesN * 2, 'sigmoid')(eyes)
+  eyes = L.Reshape((-1, eyeSize, eyeSize, 2))(eyes)
+  # face points
+  face = sMLP(sizes=[pointsN] * 2, activation='relu')(latents)
+  face = L.Dense(pointsN * 2)(face)
+  face = L.Reshape((-1, pointsN, 2))(face)
+
+  model = tf.keras.Model(
+    inputs={
+      'keyPoints': latentKeyPoints,
+      'time': T,
+      'userId': userIdEmb,
+      'placeId': placeIdEmb,
+      'screenId': screenIdEmb,
+    },
+    outputs={
+      'target': target,
+      'left eye': eyes[..., 0],
+      'right eye': eyes[..., 1],
+      'points': face,
+    }
+  )
+  return model
+
+
+if __name__ == '__main__':
+  # X = InpaintingEncoderModel(latentSize=256, embeddings={
+  #   'size': 64
+  # })
+  X = InpaintingDecoderModel(latentSize=256, embeddings={
+    'size': 64
+  })
+  X.summary(expand_nested=False)
+
+  # X = Face2LatentModel(steps=5, latentSize=64,
+  #   embeddings={
+  #     'userId': 1, 'placeId': 1, 'screenId': 1, 'size': 64
+  #   }
+  # )
+  # X['main'].summary(expand_nested=True)
+  # X['Face2Step'].summary(expand_nested=False)
+  # X['Step2Latent'].summary(expand_nested=False)
+  # print(X['main'].outputs)
+  # pass
\ No newline at end of file
diff --git a/scripts/check-dataset.py b/scripts/check-dataset.py
new file mode 100644
index 0000000..56adacf
--- /dev/null
+++ b/scripts/check-dataset.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+'''
+This script is load one by one the datasets and check how many unique samples are there
+'''
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from Core.CDataSampler import CDataSampler
+import Core.Utils as Utils
+import json
+from Core.CSamplesStorage import CSamplesStorage
+
+def samplesStream(params, filename, ID, batch_size, is_inpainting):
+  placeId, userId, screenId = ID
+  storage = CSamplesStorage(placeId=placeId, userId=userId, screenId=screenId)
+  if is_inpainting:
+    ds = CDataSamplerInpainting(
+      storage,
+      defaults=params, 
+      batch_size=batch_size, minFrames=params['timesteps'],
+      keys=['clean']
+    )
+  else:
+    ds = CDataSampler(
+      storage,
+      defaults=params, 
+      batch_size=batch_size, minFrames=params['timesteps'],
+    )
+  ds.addBlock(Utils.datasetFrom(filename))
+  
+  N = ds.totalSamples
+  for i in range(0, N, batch_size):
+    indices = list(range(i, min(i + batch_size, N)))
+    batch, rejected, accepted = ds.sampleByIds(indices)
+    if batch is None: continue
+
+    # main batch
+    x, y = batch
+    if not is_inpainting:
+      x = x['clean']
+    for idx in range(len(x['points'])):
+      yield idx
+  return
+
+def main(args):
+  params = dict(
+    timesteps=args.steps,
+    stepsSampling='uniform',
+    # no augmentations by default
+    pointsNoise=0.0, pointsDropout=0.0,
+    eyesDropout=0.0, eyesAdditiveNoise=0.0, brightnessFactor=1.0, lightBlobFactor=1.0,
+    targets=dict(keypoints=3, total=10),
+  )
+  folder = os.path.join(args.folder, 'Data', 'remote')
+
+  stats = None
+  with open(os.path.join(folder, 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  # enable all disabled datasets
+  stats['blacklist'] = []
+  for datasetFolder, ID in Utils.dataset_from_stats(stats, folder):
+    trainFile = os.path.join(datasetFolder, 'train.npz')
+    if not os.path.exists(trainFile):
+      continue
+    print('Processing', trainFile)
+
+    stream = samplesStream(params, trainFile, ID=ID, batch_size=64, is_inpainting=args.inpainting)
+    samplesN = 0
+    for _ in stream:
+      samplesN += 1
+      continue
+    print(f'Dataset has {samplesN} valid samples')
+    if samplesN <= args.min_samples:
+      print(f'Warning: dataset has less or equal to {args.min_samples} samples and will be disabled')
+      stats['blacklist'].append(ID)
+
+  with open(os.path.join(folder, 'stats.json'), 'w') as f:
+    json.dump(stats, f, indent=2, sort_keys=True, default=str)
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--steps', type=int, default=5)
+  parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+  parser.add_argument('--min-samples', type=int, default=0)
+  parser.add_argument('--inpainting', action='store_true', default=False)
+  main(parser.parse_args())
+  pass
\ No newline at end of file
diff --git a/scripts/create-test-dataset-inpainting.py b/scripts/create-test-dataset-inpainting.py
new file mode 100644
index 0000000..2fef601
--- /dev/null
+++ b/scripts/create-test-dataset-inpainting.py
@@ -0,0 +1,142 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-.
+import argparse, os, sys
+# add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+import numpy as np
+import Core.Utils as Utils
+from Core.CSamplesStorage import CSamplesStorage
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from collections import defaultdict
+import json
+import tensorflow as tf
+
+def samplesStream(params, take, filename, ID, batch_size):
+  if not isinstance(take, list): take = [take]
+  placeId, userId, screenId = ID
+  # use the stats to get the numeric values of the placeId, userId, and screenId  
+  ds = CDataSamplerInpainting(
+    CSamplesStorage(
+      placeId=placeId,
+      userId=userId,
+      screenId=screenId,
+    ),
+    defaults=params, 
+    batch_size=batch_size, minFrames=params['timesteps'],
+    keys=take
+  )
+  ds.addBlock(Utils.datasetFrom(filename))
+  
+  N = ds.totalSamples
+  for i in range(0, N, batch_size):
+    indices = list(range(i, min(i + batch_size, N)))
+    batch, rejected, accepted = ds.sampleByIds(indices)
+    if batch is None: continue
+
+    # main batch
+    x, y = batch
+    for idx in range(len(x['points'])):
+      resX = {}
+      for k, v in x.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resX[f'X_{k}'] = item
+        continue
+
+      resY = {}
+      for k, v in y.items():
+        item = v[idx, None]
+        if tf.is_tensor(item): item = item.numpy()
+        resY[f'Y_{k}'] = item
+        continue
+        
+      yield dict(**resX, **resY)
+      continue
+    continue
+  return
+
+def batches(stream, batch_size):
+  data = defaultdict(list)
+  for sample in stream:
+    for k, v in sample.items():
+      data[k].append(v)
+      continue
+
+    if batch_size <= len(data['X_points']):
+      yield data
+      data = defaultdict(list)
+    continue
+
+  if 0 < len(data['X_points']):
+    # copy data to match batch size
+    for k, v in data.items():
+      while len(v) < batch_size: v.extend(v)
+      data[k] = v[:batch_size]
+      continue
+    yield data
+  return
+############################################
+def generateTestDataset(outputFolder, stream):
+  # generate test dataset
+  ONE_MB = 1024 * 1024
+  totalSize = 0
+  if not os.path.exists(outputFolder):
+    os.makedirs(outputFolder, exist_ok=True)
+  for bIndex, batch in enumerate(stream):
+    fname = os.path.join(outputFolder, 'test-%d.npz' % bIndex)
+    # concatenate all arrays
+    batch = {k: np.concatenate(v, axis=0) for k, v in batch.items()}
+    np.savez_compressed(fname, **batch)
+    # get fname size
+    size = os.path.getsize(fname)
+    totalSize += size
+    print('%d | Size: %.1f MB | Total: %.1f MB' % (bIndex + 1, size / ONE_MB, totalSize / ONE_MB))
+    continue
+  print('Done')
+  return
+
+def main(args):
+  PARAMS = [
+    dict(      
+      timesteps=args.steps,
+      stepsSampling='uniform',
+      # no augmentations by default
+      pointsNoise=0.01, pointsDropout=0.0,
+      eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+      targets=dict(keypoints=3, total=10),
+    ),
+  ]
+  folder = os.path.join(ROOT_FOLDER, 'Data', 'remote')
+
+  stats = None
+  with open(os.path.join(folder, 'stats.json'), 'r') as f:
+    stats = json.load(f)
+
+  for datasetFolder, (place_id_index, user_id_index, screen_id_index) in Utils.dataset_from_stats(stats, folder):
+    trainFile = os.path.join(datasetFolder, 'test.npz')
+    if not os.path.exists(trainFile):
+      continue
+    print('Processing', trainFile)
+
+    for i, params in enumerate(PARAMS):
+      output = args.output
+      if 0 < i: output += '-%d' % i
+      targetFolder = os.path.join(datasetFolder, output)
+      ID = (place_id_index, user_id_index, screen_id_index)
+      stream = samplesStream(params, ['clean'], trainFile, ID=ID, batch_size=args.batch_size)
+      stream = batches(stream, batch_size=args.batch_size)
+      generateTestDataset(outputFolder=targetFolder, stream=stream)
+
+if __name__ == '__main__':
+  parser = argparse.ArgumentParser()
+  parser.add_argument('--steps', type=int, default=5, help='Number of timesteps')
+  parser.add_argument('--batch-size', type=int, default=512, help='Batch size of the test dataset')
+  parser.add_argument(
+    '--output', type=str, help='Output folder name',
+    default='test-inpainting'
+  )
+  args = parser.parse_args()
+  main(args)
+  pass
\ No newline at end of file
diff --git a/scripts/download-remote.py b/scripts/download-remote.py
index 64d8c6a..3656c99 100644
--- a/scripts/download-remote.py
+++ b/scripts/download-remote.py
@@ -8,7 +8,7 @@
     PlaceId
       UserId
         ScreenId
-          start_time.npz
+          *.npz
 '''
 import argparse, os, sys
 # add the root folder of the project to the path
@@ -20,6 +20,7 @@
 import shutil
 import numpy as np
 import requests
+from Core.Utils import FACE_MESH_POINTS
 
 folder = os.path.join(ROOT_FOLDER, 'Data')
 
@@ -65,10 +66,15 @@ def deserialize(buffer):
     offset += EYE_COUNT
     
     # Read points (float32)
-    sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*478, offset=offset) \
-      .reshape(478, 2)
-    assert np.all(-2 <= sample['points']) and np.all(sample['points'] <= 2), 'Invalid points'
-    offset += 4 * 2 * 478
+    sample['points'] = np.frombuffer(buffer, dtype='>f4', count=2*FACE_MESH_POINTS, offset=offset) \
+      .reshape(FACE_MESH_POINTS, 2)
+    # change the range [-d, 1.0 + d]
+    d = 0.5
+    points = sample['points']
+    is_valid = (-d <= points) & (points <= 1.0 + d)
+    is_valid = is_valid.all(axis=-1)
+    assert np.all(is_valid), 'Invalid points: %s' % points[~is_valid]
+    offset += 4 * 2 * FACE_MESH_POINTS
     
     # Read goal (float32)
     sample['goal'] = goal = np.frombuffer(buffer, dtype='>f4', count=2, offset=offset)
diff --git a/scripts/make-blacklist.py b/scripts/make-blacklist.py
deleted file mode 100644
index ed13c0f..0000000
--- a/scripts/make-blacklist.py
+++ /dev/null
@@ -1,103 +0,0 @@
-#!/usr/bin/env python
-# -*- coding: utf-8 -*-.
-'''
-This script performs the following steps:
-- Load the best model from the Data folder
-- Load the test datasets from the Data/test-main folder
-- Evaluate the model on the test datasets
-- Add each test dataset to blacklists if the model mean loss is greater than the threshold
-'''
-# TODO: add the W&B integration
-import argparse, os, sys
-# add the root folder of the project to the path
-ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
-sys.path.append(ROOT_FOLDER)
-
-import numpy as np
-from Core.CDatasetLoader import CDatasetLoader
-from Core.CTestLoader import CTestLoader
-from collections import defaultdict
-import time
-from Core.CModelTrainer import CModelTrainer
-import tqdm
-import json
-import glob
-
-def _eval(dataset, model):
-  T = time.time()
-  # evaluate the model on the val dataset
-  losses = []
-  predDist = []
-  for batchId in range(len(dataset)):
-    batch = dataset[batchId]
-    loss, _, dist = model.eval(batch)
-    predDist.append(dist)
-    losses.append(loss)
-    continue
-  
-  loss = np.mean(losses)
-  dist = np.mean(predDist)
-  T = time.time() - T
-  return loss, dist, T
-
-def evaluate(dataset, model):
-  loss, dist, T = _eval(dataset, model)
-  print('Test | %.2f sec | Loss: %.5f. Distance: %.5f' % (
-    T, loss, dist,
-  ))
-  return loss, dist
-
-def main(args):
-  timesteps = args.steps
-  folder = args.folder
-  stats = None
-  with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
-    stats = json.load(f)
-    
-  oldBadDatasets = [] # list of tuples (userId, placeId, screenId) strings
-  if os.path.exists(os.path.join(folder, 'blacklist.json')):
-    with open(os.path.join(folder, 'blacklist.json'), 'r') as f:
-      oldBadDatasets = json.load(f)
-    pass
-
-  model = dict(timesteps=timesteps, stats=stats, use_encoders=False)
-  assert args.model is not None, 'The model should be specified'
-  if args.model is not None:
-    model['weights'] = dict(folder=folder, postfix=args.model, embeddings=True)
-
-  model = CModelTrainer(**model)
-  # find folders with the name "/test-*/"
-  badDatasets = []
-  for nm in glob.glob(os.path.join(folder, 'test-main', 'test-*/')):
-    evalDataset = CTestLoader(nm)
-    loss, dist = evaluate(evalDataset, model)
-    if args.threshold < dist:
-      badDatasets.append(evalDataset.parametersIDs())
-    continue
-  # convert indices to the strings
-  res = []
-  for userId, placeId, screenId in badDatasets:
-    userId = stats['userId'][userId]
-    placeId = stats['placeId'][placeId]
-    screenId = stats['screenId'][screenId]
-    res.append((userId, placeId, screenId))
-    continue
-  res = oldBadDatasets + res # add the old blacklisted datasets
-  print('Blacklisted datasets:')
-  print(json.dumps(res, indent=2))
-  # save the blacklisted datasets
-  with open(os.path.join(folder, 'blacklist.json'), 'w') as f:
-    json.dump(res, f, indent=2)
-  return
-
-if __name__ == '__main__':
-  parser = argparse.ArgumentParser()
-  parser.add_argument('--steps', type=int, default=5)
-  parser.add_argument('--model', type=str, default='best')
-  parser.add_argument('--folder', type=str, default=os.path.join(ROOT_FOLDER, 'Data'))
-  parser.add_argument(
-    '--threshold', type=float, required=True,
-  )
-
-  main(parser.parse_args())
-  pass
\ No newline at end of file
diff --git a/scripts/preprocess-remote.py b/scripts/preprocess-remote.py
index 167dd29..294c0e4 100644
--- a/scripts/preprocess-remote.py
+++ b/scripts/preprocess-remote.py
@@ -179,7 +179,7 @@ def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, ski
       start, end, np.min(delta), np.max(delta), np.mean(delta), len(session_time), duration
     ))
     stats['deltas'].append(delta)
-    stats['durations'].append(duration)
+    stats['durations'].append([duration])
     continue
   ######################################################
   # split each session into training and testing sets
@@ -194,7 +194,7 @@ def processFolder(folder, timeDelta, testRatio, framesPerChunk, testPadding, ski
 
   if (0 == len(training)) or (0 == len(testing)):
     print('No training or testing sets found!')
-    return 0, 0, True
+    return 0, 0, True, None
 
   def saveSubset(filename, idx):
     print('%s: %d frames' % (filename, len(idx)))
@@ -215,17 +215,6 @@ def saveSubset(filename, idx):
   return len(testing), len(training), False, stats
 
 def main(args):
-  # blacklisted datasets
-  blacklisted = []
-  if args.blacklist is not None:
-    with open(args.blacklist, 'r') as f:
-      blacklisted = json.load(f)
-    pass
-  blacklisted = set([
-    '/'.join(item)
-    for item in blacklisted
-  ])
-  print(blacklisted)
   stats = {
     'placeId': [],
     'userId': [],
diff --git a/scripts/train-reconstruction.py b/scripts/train-reconstruction.py
new file mode 100644
index 0000000..1b584d5
--- /dev/null
+++ b/scripts/train-reconstruction.py
@@ -0,0 +1,194 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+import argparse, os, sys
+# Add the root folder of the project to the path
+ROOT_FOLDER = os.path.abspath(os.path.dirname(__file__) + '/../')
+sys.path.append(ROOT_FOLDER)
+
+import numpy as np
+from Core.CDataSamplerInpainting import CDataSamplerInpainting
+from Core.CDatasetLoader import CDatasetLoader
+from Core.CTestInpaintingLoader import CTestInpaintingLoader
+import Core.Utils as Utils
+from collections import defaultdict
+import time
+from Core.CInpaintingTrainer import CInpaintingTrainer
+import tqdm
+import json
+import wandb
+
+def _eval(dataset, model):
+    T = time.time()
+    # Evaluate the model on the validation dataset
+    loss = []
+    for batchId in range(len(dataset)):
+        batch = dataset[batchId]
+        loss_value = model.eval(batch)
+        loss.append(loss_value)
+    loss = np.mean(loss)
+    T = time.time() - T
+    return loss, T
+
+def evaluator(datasets, model, folder, args):
+    losses = [np.inf] * len(datasets)  # Initialize with infinity
+    def evaluate(onlyImproved=False, step=None):
+        totalLoss = []
+        eval_metrics = {}
+        for i, dataset in enumerate(datasets):
+            loss, T = _eval(dataset, model)
+            dataset_id = ', '.join([str(x) for x in dataset.parametersIDs()])
+            isImproved = loss < losses[i]
+            if (not onlyImproved) or isImproved:
+                print('Test %d / %d (%s) | %.2f sec | Loss: %.5f (%.5f).' % (
+                    i + 1, len(datasets), dataset_id, T, loss, losses[i],
+                ))
+            if isImproved:
+                print('Test %d / %d | Improved %.5f => %.5f,' % (
+                    i + 1, len(datasets), losses[i], loss,
+                ))
+                modelFolder = os.path.join(folder, f"model-{dataset_id}")
+                os.makedirs(modelFolder, exist_ok=True)
+                # keep only the best model across all runs
+                # name format: {model id}-{loss:.5f}-*.*
+                all_files = os.listdir(modelFolder)
+                all_losses = [f.split('-')[1] for f in all_files]
+                all_losses = list(set(all_losses))
+                print(f"Found losses: {all_losses}")
+                for loss_file in all_losses:
+                    if float(loss_file) > loss:
+                        # remove all files with this loss in folder
+                        to_remove = [os.path.join(modelFolder, f) for f in all_files if loss_file in f]
+                        for f in to_remove:
+                            os.remove(f)
+
+                model.save(modelFolder, postfix='%.5f' % loss)
+                losses[i] = loss
+            totalLoss.append(loss)
+            eval_metrics['eval_loss_(%s)' % dataset_id] = loss
+        mean_loss = np.mean(totalLoss)
+        if not onlyImproved:
+            print('Mean loss: %.5f' % mean_loss)
+        # Log evaluation metrics to wandb
+        if step is not None:
+            wandb.log(eval_metrics, step=step)
+        return mean_loss
+    return evaluate
+
+def _modelTrainingLoop(model, dataset):
+    def F(desc):
+        history = defaultdict(list)
+        # Use the tqdm progress bar
+        with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
+            dataset.on_epoch_start()
+            for step in range(len(dataset)):
+                sampled = dataset.sample()
+                stats = model.fit(sampled)
+                history['time'].append(stats['time'])
+                for k in stats['losses'].keys():
+                    history[k].append(stats['losses'][k])
+                # Add stats to the progress bar (mean of each history)
+                pbar.set_postfix({k: '%.5f' % np.mean(v) for k, v in history.items()})
+                pbar.update(1)
+            dataset.on_epoch_end()
+        return {k: np.mean(v) for k, v in history.items()}
+    return F
+
+def _trainer_from(args):
+    if args.trainer == 'default': return CInpaintingTrainer
+    raise Exception('Unknown trainer: %s' % (args.trainer, ))
+
+def main(args):
+    wandb.init(project=args.wandb_project, config=vars(args))  # Initialize wandb
+    timesteps = args.steps
+    folder = os.path.join(args.folder, 'Data')
+
+    stats = None
+    with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
+        stats = json.load(f)
+
+    trainer = _trainer_from(args)
+    trainDataset = CDatasetLoader(
+        os.path.join(folder, 'remote'),
+        stats=stats,
+        sampling=args.sampling,
+        samplerArgs=dict(
+            batch_size=args.batch_size,
+            minFrames=timesteps,
+            maxT=1.0,
+            defaults=dict(
+                timesteps=timesteps,
+                stepsSampling={'max frames': 10},
+                # No augmentations by default
+                pointsNoise=0.01, pointsDropout=0.01,
+                eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
+                targets=dict(keypoints=3, total=10),
+            ),
+            keys=['clean'],
+        ),
+        sampler_class=CDataSamplerInpainting,
+        test_folders=['train.npz'],
+    )
+    model = dict(timesteps=timesteps, stats=stats)
+    if args.model is not None:
+        model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
+    if args.modelId is not None:
+        model['model'] = args.modelId
+
+    model = trainer(**model)
+
+    evalDatasets = [
+        CTestInpaintingLoader(os.path.join(folderName, 'test-inpainting'))
+        for folderName, _ in Utils.dataset_from_stats(stats, os.path.join(folder, 'remote'))
+        if os.path.exists(os.path.join(folderName, 'test-inpainting'))
+    ]
+    eval_fn = evaluator(evalDatasets, model, folder, args)
+    bestLoss = eval_fn()  # Evaluate loaded model
+    bestEpoch = 0
+
+    def evalWrapper(eval_fn):
+        def f(epoch, onlyImproved=False, step=None):
+            nonlocal bestLoss, bestEpoch
+            newLoss = eval_fn(onlyImproved=onlyImproved, step=step)
+            if newLoss < bestLoss:
+                print('Improved %.5f => %.5f' % (bestLoss, newLoss))
+                bestLoss = newLoss
+                bestEpoch = epoch
+                model.save(folder, postfix='%.5f' % newLoss)
+            return
+        return f
+
+    eval_fn = evalWrapper(eval_fn)
+    trainStep = _modelTrainingLoop(model, trainDataset)
+    for epoch in range(args.epochs):
+        metrics = trainStep(
+            desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
+        )
+        wandb.log(metrics, step=epoch + 1)
+        model.save(folder, postfix='latest')
+        eval_fn(epoch, step=epoch + 1)
+        print('Passed %d epochs since the last improvement (best: %.5f)' % (epoch - bestEpoch, bestLoss))
+        if args.patience <= (epoch - bestEpoch):
+            print('Early stopping')
+            break
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--epochs', type=int, default=1000)
+    parser.add_argument('--batch-size', type=int, default=64)
+    parser.add_argument('--patience', type=int, default=5)
+    parser.add_argument('--steps', type=int, default=5)
+    parser.add_argument('--model', type=str)
+    parser.add_argument('--embeddings', default=False, action='store_true')
+    parser.add_argument('--folder', type=str, default=ROOT_FOLDER)
+    parser.add_argument('--modelId', type=str)
+    parser.add_argument(
+        '--trainer', type=str, default='default',
+        choices=['default']
+    )
+    parser.add_argument(
+        '--sampling', type=str, default='uniform',
+        choices=['uniform', 'as_is'],
+    )
+    parser.add_argument('--wandb-project', type=str, default='alternative-input-reconstruction')
+
+    main(parser.parse_args())
diff --git a/scripts/train.py b/scripts/train.py
index 2afab81..a580517 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -8,11 +8,11 @@
 
 import numpy as np
 from Core.CDatasetLoader import CDatasetLoader
+from Core.CDataSampler import CDataSampler
 from Core.CTestLoader import CTestLoader
 from collections import defaultdict
 import time
 from Core.CModelTrainer import CModelTrainer
-from Core.CModelDiffusion import CModelDiffusion
 import tqdm
 import json
 import glob
@@ -107,13 +107,13 @@ def evaluate(onlyImproved=False):
   return evaluate
 
 def _modelTrainingLoop(model, dataset):
-  def F(desc, sampleParams):
+  def F(desc):
     history = defaultdict(list)
     # use the tqdm progress bar
     with tqdm.tqdm(total=len(dataset), desc=desc) as pbar:
       dataset.on_epoch_start()
       for _ in range(len(dataset)):
-        sampled = dataset.sample(**sampleParams)
+        sampled = dataset.sample()
         assert 2 == len(sampled), 'The dataset should return a tuple with the input and the output'
         X, Y = sampled
         assert 'clean' in X, 'The input should contain the clean data'
@@ -140,54 +140,8 @@ def F(desc, sampleParams):
     return
   return F
 
-def _defaultSchedule(args):
-  return lambda epoch: dict()
-
-def _schedule_from_json(args):
-  with open(args.schedule, 'r') as f:
-    schedule = json.load(f)
-
-  # schedule is a dictionary of dictionaries, where the keys are the epochs
-  # transform it into a sorted list of tuples (epoch, params)
-  for k, v in schedule.items():
-    v = [(int(epoch), p) for epoch, p in v.items()]
-    schedule[k] = sorted(v, key=lambda x: x[0])
-    continue
-
-  def F(epoch):
-    res = {}
-    for k, v in schedule.items():
-      assert isinstance(v, list), 'The schedule should be a list of parameters'
-      # find the first epoch that is less or equal to the current one
-      smallest = [i for i, (e, _) in enumerate(v) if e <= epoch]
-      if 0 == len(smallest): continue
-      smallest = smallest[-1]
-
-      startEpoch, p = v[smallest]
-      value = p
-      # p could be a dictionary or value
-      if isinstance(p, list) and (2 == len(p)):
-        assert smallest + 1 < len(v), 'The last epoch should be the last one'
-        minV, maxV = [float(x) for x in p]
-        nextEpoch, _ = v[smallest + 1]
-        # linearly interpolate between the values
-        value = minV + (maxV - minV) * (epoch - startEpoch) / (nextEpoch - startEpoch)
-        pass
-      
-      res[k] = float(value)
-      continue
-
-    if args.debug and res:
-      print('Parameters for epoch %d:' % (epoch, ))
-      for k, v in res.items():
-        print('  %s: %.5f' % (k, v))
-        continue
-    return res
-  return F
-
 def _trainer_from(args):
   if args.trainer == 'default': return CModelTrainer
-  if args.trainer == 'diffusion': return CModelDiffusion
   raise Exception('Unknown trainer: %s' % (args.trainer, ))
 
 def averageModels(folder, model, noiseStd=0.0):
@@ -215,11 +169,6 @@ def averageModels(folder, model, noiseStd=0.0):
 def main(args):
   timesteps = args.steps
   folder = os.path.join(args.folder, 'Data')
-  if args.schedule is None:
-    getSampleParams = _defaultSchedule(args)
-  else:
-    getSampleParams = _schedule_from_json(args)
-
   stats = None
   with open(os.path.join(folder, 'remote', 'stats.json'), 'r') as f:
     stats = json.load(f)
@@ -240,9 +189,10 @@ def main(args):
         pointsNoise=0.01, pointsDropout=0.0,
         eyesDropout=0.1, eyesAdditiveNoise=0.01, brightnessFactor=1.5, lightBlobFactor=1.5,
       ),
-    )
+    ),
+    sampler_class=CDataSampler
   )
-  model = dict(timesteps=timesteps, stats=stats, use_encoders=args.with_enconders)
+  model = dict(timesteps=timesteps, stats=stats)
   if args.model is not None:
     model['weights'] = dict(folder=folder, postfix=args.model, embeddings=args.embeddings)
   if args.modelId is not None:
@@ -298,7 +248,6 @@ def performRandomSearch(epoch=0):
   for epoch in range(args.epochs):
     trainStep(
       desc='Epoch %.*d / %d' % (len(str(args.epochs)), epoch, args.epochs),
-      sampleParams=getSampleParams(epoch)
     )
     model.save(folder, postfix='latest')
     eval(epoch)
@@ -332,7 +281,7 @@ def performRandomSearch(epoch=0):
   parser.add_argument('--modelId', type=str)
   parser.add_argument(
     '--trainer', type=str, default='default',
-    choices=['default', 'diffusion']
+    choices=['default']
   )
   parser.add_argument(
     '--schedule', type=str, default=None,
@@ -344,9 +293,6 @@ def performRandomSearch(epoch=0):
     '--restarts', type=int, default=1,
     help='Number of times to restart the model reinitializing the weights'
   )
-  parser.add_argument(
-    '--with-enconders', default=False, action='store_true',
-  )
   parser.add_argument(
     '--sampling', type=str, default='uniform',
     choices=['uniform', 'as_is'],