In [1]:
!pip install avro
!pip install tensorflow==2.14.1
!pip install tensorflow-probability
!pip install tensorflow-io
!pip install keras-tuner
!pip install optuna
!pip install optuna-integration
!pip install biopython
!pip install nglview
#!pip install -q tensorflow-cloud

Collecting avro
  Downloading avro-1.11.3.tar.gz (90 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/90.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━[0m [32m81.9/90.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m90.6/90.6 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: avro
  Building wheel for avro (pyproject.toml) ... [?25l[?25hdone
  Created wheel for avro: filename=avro-1.11.3-py2.py3-none-any.whl size=123911 sha256=23b9e21d4bfe69cedec99bdd64c82cae26ffa17e16240cd8d931d7ec1b4967cb
  Stored in directory: /root/.cache/pip/wheels/1d/f6/41/0e0399396af07060e64d4e32c8bd259b48b98a4a114df31294
Successfully built avro
In

In [2]:
#!gcloud config set project shaker-388116
#!gcloud auth application-default login

In [3]:
import collections

from google.colab import auth
from google.colab import data_table

from google.cloud import storage

from avro import datafile
from avro import schema
import io as ior
from avro import io
import math
import numpy as np
import optuna
from sklearn import neighbors
import keras_tuner as kt
import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_probability as tfp
import random
import sys
import time

from Bio import PDB
import nglview

tfd = tfp.distributions



In [4]:
from google.colab import output
output.enable_custom_widget_manager()

In [5]:
try:
  tpu = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
  print('Running on TPU ', tpu.cluster_spec().as_dict()['worker'])
except ValueError:
  raise BaseException('ERROR: Not connected to a TPU runtime; please see the previous cell in this notebook for instructions!')

tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
tpu_strategy = tf.distribute.TPUStrategy(tpu)

Running on TPU  ['10.27.128.162:8470']


In [6]:
auth.authenticate_user()

In [7]:
def GetStaticVocab(vocab):
  return tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(
          tf.constant(vocab),
          tf.constant(range(2, len(vocab)+2), dtype=tf.int64),
          key_dtype=tf.string,
          value_dtype=tf.int64), 1)

In [8]:
client = storage.Client('shaker-388116')
bucket = client.bucket('unreplicated-training-data')
blob = bucket.get_blob('training_examples_summary/data-00000-of-00001.avro')
reader = datafile.DataFileReader(blob.open('rb'), io.DatumReader())
residue_names = []
atom_names = []
for x in reader:
  residue_names = x['residue_names']
  atom_names = x['atom_names']
with tf.device('/CPU:0'):
  residue_names_preprocessor = GetStaticVocab(residue_names)
  atom_names_preprocessor = GetStaticVocab(atom_names)



---



---



---



In [9]:
raw_tf_examples = tf.data.TFRecordDataset(
    ['gs://unreplicated-training-data/training_examples/data-{0}-of-00037.tfrecord'.format(str(i).zfill(5))
    for i in range(37)])
feature_spec = {
      'name': tf.io.FixedLenFeature([], tf.string, default_value=''),
      'residue_names': tf.io.FixedLenFeature([], tf.string, default_value=''),
      'atom_names': tf.io.FixedLenFeature([], tf.string, default_value=''),
      'normalized_coordinates': tf.io.FixedLenFeature([], tf.string, default_value='')}


def _ConvertToFeatures(x, feature_spec, residue_names_preprocessor,
                       atom_names_preprocessor):
  example = tf.io.parse_single_example(x, feature_spec)
  residue_names = residue_names_preprocessor.lookup(
      tf.io.parse_tensor(example['residue_names'], tf.string))
  atom_names = atom_names_preprocessor.lookup(
      tf.io.parse_tensor(example['atom_names'], tf.string))
  normalized_coordinates = tf.io.parse_tensor(
      example['normalized_coordinates'], tf.float32)
  return {
      'name': example['name'],
      'residue_names': residue_names,
      'atom_names': atom_names,
      'normalized_coordinates': normalized_coordinates}

ds = raw_tf_examples.map(lambda x: _ConvertToFeatures(
    x, feature_spec, residue_names_preprocessor, atom_names_preprocessor))

In [10]:
residue_lookup_size = len(residue_names)+2
atom_lookup_size = len(atom_names)+2

In [11]:
Z_EMBEDDING_SIZE = 61
COND_EMBEDDING_SIZE = 6
ENCODER_CONVOLVE_SIZE = 10

In [12]:
TIMESTEP_EMBEDDING_DIMS = 10
AMINO_ACID_EMBEDDING_DIMS = 20
tf.debugging.enable_traceback_filtering()

In [13]:
def AttentionLayer(num_heads, inputs):
  return tf.keras.layers.LayerNormalization()(tf.keras.layers.Add()([inputs, tf.keras.layers.MultiHeadAttention(num_heads, 10)(inputs, inputs, inputs)]))

In [14]:
def FeedForwardLayer(num_layers, output_size, inputs):
  t = inputs
  for i in range(num_layers):
    t = tf.keras.layers.Dense(100, 'gelu')(t)
  t = tf.keras.layers.Dense(output_size)(t)
  return tf.keras.layers.LayerNormalization()(
      tf.keras.layers.Add()([inputs, t]))

In [15]:
def TransformerLayer(num_transformers, num_heads, num_dnn_layers, output_size, inputs):
  x = inputs
  for i in range(num_transformers):
    x = AttentionLayer(num_heads, x)
    x = FeedForwardLayer(num_dnn_layers, output_size, x)
  return x

In [16]:
def PositionalEmbedding(z):
  pos_indices = tf.expand_dims(tf.expand_dims(tf.range(tf.shape(z)[1], dtype='float32'), 0) * tf.ones(tf.shape(z)[:-1]), -1)
  half_dim = AMINO_ACID_EMBEDDING_DIMS // 2
  pemb = tf.math.log(20_000.0) / (half_dim - 1)
  pemb = tf.math.exp(tf.range(half_dim, dtype='float') * - pemb)
  pemb = pos_indices * pemb
  pemb = tf.keras.layers.concatenate([tf.math.sin(pemb), tf.math.cos(pemb)])
  return pemb

In [17]:
def DecoderModel():
  z_0_rescaled = tf.keras.Input(shape=(None, Z_EMBEDDING_SIZE),
                                name='z_0_rescaled')
  cond = tf.keras.Input(shape=(None, COND_EMBEDDING_SIZE),
                        name='cond')
  pemb = PositionalEmbedding(z_0_rescaled)

  base_inputs = tf.keras.layers.concatenate(inputs=[
      z_0_rescaled, cond, pemb])
  convolved_inputs = tf.keras.layers.Conv1DTranspose(32, ENCODER_CONVOLVE_SIZE, padding='same')(base_inputs)
  concatenated_inputs = tf.keras.layers.concatenate(inputs=[base_inputs, convolved_inputs])
  transformer_output = TransformerLayer(1, 5, 5, 119, concatenated_inputs)

  scale_diag = tf.Variable(1.0)
  loc = tf.keras.layers.Dense(3)(tf.keras.layers.Dense(100, 'gelu')(tf.keras.layers.Dense(100, 'gelu')(transformer_output)))

  return tf.keras.Model(inputs=[z_0_rescaled, cond],
                        outputs=[loc,
                                 tf.keras.layers.Identity()(
                                    scale_diag*tf.ones_like(loc))])

In [18]:
class DecoderTrain(object):
  def __init__(self, model):
    self._model = model

  # Decodes a latent representation to a probability distribution on the
  # location of each atom.
  #
  # Args:
  #  z_0_rescaled: A tenstor for the latent distribution at step 0.
  #    Rescaled by alpha. z_0_rescaled should have dimensions:
  #      (batch_size, num_atoms, num_channels)
  #  cond: A tensor with context information. Should have dimensions:
  #      (batch_size, num_atoms, num_channels)
  #
  # Returns: A multivariate distibution on R^(batch_size X num_atoms)
  def decode(self, z_0_rescaled, cond, training):
    outputs = self._model({
        'z_0_rescaled': z_0_rescaled,
        'cond': cond
    }, training=training)
    return tfd.MultivariateNormalDiag(
            loc=outputs[0], scale_diag=outputs[1])

  def trainable_weights(self):
    return self._model.trainable_weights

  def save(self, location):
    self._model.save(location, overwrite=True, save_format='tf',
                     options=tf.saved_model.SaveOptions())

In [19]:
def EncoderModel():
  normalized_coordinates = tf.keras.Input(shape=(None, 3),
                                          name='normalized_coordinates')
  cond = tf.keras.Input(shape=(None, COND_EMBEDDING_SIZE), name='cond')

  pemb = PositionalEmbedding(normalized_coordinates)

  encoded_coordinates = tf.keras.layers.concatenate(
      inputs=[normalized_coordinates, cond, pemb])


  convolved_coordinates = tf.keras.layers.SeparableConv1D(
      32, ENCODER_CONVOLVE_SIZE, activation='gelu', padding='same')(encoded_coordinates)
  concatenate_inputs = tf.keras.layers.concatenate(inputs=[
      convolved_coordinates, encoded_coordinates])
  transformer_output = TransformerLayer(1, 5, 5, 61, concatenate_inputs)


  return tf.keras.Model(inputs=[normalized_coordinates, cond],
                        outputs=tf.keras.layers.Identity()(transformer_output))

In [20]:
class EncoderTrain(object):
  def __init__(self, model):
    self._model = model

  # Encodes a latent distribution for each atom.
  #
  # Args:
  #  normalized_coordinates: The atom's coordinates, normalized to have mean 0.
  #  cond: A tensor with context information. Should have dimensions:
  #      (batch_size, num_atoms, num_channels)
  #
  # Returns: A latent distribution with dimensions:
  #   (batch_size, num_atoms, num_channels)
  def encode(self, normalized_coordinates, cond, training):
    to_return = self._model({
        'normalized_coordinates': normalized_coordinates,
        'cond': cond
    }, training=training)
    return to_return

  def trainable_weights(self):
    return self._model.trainable_weights

  def save(self, location):
    self._model.save(location, overwrite=True, save_format='tf', options=tf.saved_model.SaveOptions())

In [21]:
def CondModel():
  residue_names = tf.keras.Input(shape=(None,), name='residue_names')
  atom_names = tf.keras.Input(shape=(None,), name='atom_names')

  residue_embeddings = tf.keras.layers.Embedding(
      input_dim=residue_lookup_size,
      output_dim=3)(residue_names)
  atom_embeddings = tf.keras.layers.Embedding(
      input_dim=atom_lookup_size,
      output_dim=3)(atom_names)

  cond_out = tf.keras.layers.concatenate(
      inputs=[residue_embeddings, atom_embeddings])
  cond_out = tf.keras.layers.Dense(100, 'gelu')(cond_out)
  cond_out = tf.keras.layers.Dense(COND_EMBEDDING_SIZE)(cond_out)

  return tf.keras.Model(inputs=[residue_names, atom_names], outputs=cond_out)

In [22]:
class CondTrain(object):
  def __init__(self, model):
    self._model = model

  # Returns a conditioning for the Inverse Problem.
  # Args:
  #   residue_names: Integer Tensor representing the residue names of each atom.
  #     Should have shape (batch_size, num_atoms)
  #   atom_names: Integer Tensor representing the atom names of each atom.
  #     Should have shape (batch_size, num_atoms)
  # Returns: The conditioning of the inverse problem. Shoud have shape
  #   (batch_size, num_atoms, num_channels).
  def conditioning(self, residue_names, atom_names, training):
    return self._model({'residue_names': residue_names,
                        'atom_names': atom_names}, training=training)

  def trainable_weights(self):
    return self._model.trainable_weights

  def save(self, location):
    self._model.save(location, overwrite=True, save_format='tf', options=tf.saved_model.SaveOptions())

In [23]:
def ScoreModel():
  z = tf.keras.Input(shape=(None, Z_EMBEDDING_SIZE), name='z')
  gamma = tf.keras.Input(shape=[], name='gamma')
  cond = tf.keras.Input(shape=(None, COND_EMBEDDING_SIZE), name='cond')

  # Compute timestep embedding
  t = gamma * 1000
  t = tf.expand_dims(tf.expand_dims(t, -1) * tf.ones(tf.shape(z)[:-1]), -1)
  half_dim = TIMESTEP_EMBEDDING_DIMS // 2
  temb = tf.math.log(10_000.0) / (half_dim - 1)
  temb = tf.math.exp(tf.range(half_dim, dtype='float') * - temb)
  temb = t * temb
  temb = tf.keras.layers.concatenate([tf.math.sin(temb), tf.math.cos(temb)])

  # Compute Amino Acid Positional Embedding
  pos_indices = tf.expand_dims(tf.expand_dims(tf.range(tf.shape(z)[1], dtype='float32'), 0) * tf.ones(tf.shape(z)[:-1]), -1)
  half_dim = AMINO_ACID_EMBEDDING_DIMS // 2
  pemb = tf.math.log(20_000.0) / (half_dim - 1)
  pemb = tf.math.exp(tf.range(half_dim, dtype='float') * - pemb)
  pemb = pos_indices * pemb
  pemb = tf.keras.layers.concatenate([tf.math.sin(pemb), tf.math.cos(pemb)])

  base_features = tf.keras.layers.concatenate(
      inputs=[z, cond, temb, pemb])
  score_convolve_layer = tf.keras.layers.Conv1DTranspose(
      64, ENCODER_CONVOLVE_SIZE, padding='same', activation='gelu')
  concatenated_features = tf.keras.layers.concatenate(
      inputs=[base_features, score_convolve_layer(base_features)])

  transformer_output = TransformerLayer(1, 5, 5, 161, concatenated_features)
  score = tf.keras.layers.Dense(Z_EMBEDDING_SIZE)(transformer_output)

  return tf.keras.Model(inputs=[z, gamma, cond], outputs=score)

In [24]:
def PerfectScoreModel(perfect_knowledge):
  z = tf.keras.Input(shape=(None, Z_EMBEDDING_SIZE), name='z')
  gamma = tf.keras.Input(shape=[], name='gamma')
  cond = tf.keras.Input(shape=(None, COND_EMBEDDING_SIZE), name='cond')

  a = tf.math.sqrt(1 - tf.math.sigmoid(-1*gamma))
  var = tf.math.sigmoid(-1*gamma)
  score = tf.divide((z - a*perfect_knowledge), tf.math.sqrt(var))
  return tf.keras.Model(inputs=[z,gamma,cond], outputs=score)

In [25]:
class ScoreTrain(object):
  def __init__(self, model):
    self._model = model

  # Returns an estimate of the error in z.
  # Args:
  #   z: The latent space embedding with an error introduced.
  #     Should have shape (batch_size, num_atoms, num_channels)
  #   gamma: The value of gamma used in the variance preserving map used to
  #     construct z. Should have shape (batch_size,)
  #   cond: The conditioning passed in to guide the reconstruction.
  #     Should have shape (batch_size, num_atoms, num_channels).
  # Returns: An estimate of the epsilon error introduced. Shoud have shape
  #   (batch_size, num_atoms, num_channels).
  def score(self, z, gamma, cond, training):
    score_val =  self._model({'z': z,
                         'gamma': gamma,
                         'cond': cond}, training=training)
    return score_val

  def trainable_weights(self):
    return self._model.trainable_weights

  def save(self, location):
    self._model.save(location, overwrite=True, save_format='tf', options=tf.saved_model.SaveOptions())

In [37]:
class GammaModule(tf.Module):
  def __init__(self):
    self._l1 = tf.keras.layers.Dense(
        1, kernel_constraint=tf.keras.constraints.NonNeg())
    self._l2 = tf.keras.layers.Dense(
        1024, activation='sigmoid',
        kernel_constraint=tf.keras.constraints.NonNeg())
    self._l3 = tf.keras.layers.Dense(
        1, kernel_constraint=tf.keras.constraints.NonNeg())

  @tf.function
  def GetGamma(self, ts):
    l1_t = self._l1(ts)
    return -1*(l1_t + self._l3(self._l2(self._l1(ts))))

In [126]:
from keras.layers import concatenate
class DiffusionModel:
  def __init__(self, residue_lookup_size, atom_lookup_size,
               gamma_module, decoder, encoder, conditioner, scorer):
    self._timesteps = 10000

    self._gamma_module = gamma_module
    self._decoder = decoder
    self._encoder = encoder
    self._conditioner = conditioner
    self._scorer = scorer

  def gamma(self, ts):
    return self._gamma_module.GetGamma(ts)
    #gamma_max = tf.math.abs(self._gamma_module.gamma_max)
    #gamma_min = -1*tf.math.abs(self._gamma_module.gamma_min)
    #return gamma_max + (gamma_min - gamma_max) * ts

  def sigma2(self, gamma):
    return tf.math.sigmoid(-1*gamma)

  def alpha(self, gamma):
    return tf.math.sqrt(1-self.sigma2(gamma))

  def variance_preserving_map(self, x, gamma, eps):
    a = self.alpha(gamma)
    var = self.sigma2(gamma)
    s1 = tf.expand_dims(tf.expand_dims(a, axis=-1), axis=-1)*x
    s2 =  (
        tf.expand_dims(
            tf.expand_dims(tf.math.sqrt(var), axis=-1), axis=-1) *
        eps)
    return s1 + s2

  def trainable_weights(self):
    return (self._decoder.trainable_weights() +
            self._encoder.trainable_weights() +
            self._conditioner.trainable_weights() +
            self._scorer.trainable_weights() +
            list(self._gamma_module.trainable_variables))

  def decoder_weights(self):
    return self._decoder.trainable_weights()

  def scale_diag_weights(self):
    return [self._scale_diag]

  def loc_decoder_weights(self):
    trainable_weights = []
    for x in self._loc_decoder_layers:
      trainable_weights.extend(x.trainable_weights)
    return trainable_weights

  def recon_loss(self, x, f, f_mask, cond, training):
    g_0 = self.gamma(tf.constant([[0]]))[0][0]
    eps_0 = tf.random.normal(tf.shape(f))
    z_0 = self.variance_preserving_map(f, g_0, eps_0)
    z_0_rescaled = z_0 / self.alpha(g_0)
    prob_dist = self._decoder.decode(z_0_rescaled, cond, training)
    loss_recon = -tf.reduce_sum(
        tf.math.multiply(prob_dist.log_prob(x), f_mask), axis=[-1])
    return loss_recon, tf.reduce_sum(tf.math.multiply(
        tf.math.abs(x-prob_dist.mean()),tf.expand_dims(f_mask, -1)))/tf.reduce_sum(f_mask)

  def latent_loss(self, f, f_mask):
    g_1 = self.gamma(tf.constant([[1.0]]))[0][0]
    var_1 = self.sigma2(g_1)
    mean1_sqr = (1. - var_1) * tf.square(f)
    loss_klz = 0.5 * tf.reduce_sum(
        tf.math.multiply(mean1_sqr + var_1 - tf.math.log(var_1) - 1., tf.expand_dims(f_mask, -1)),
        axis=[-1, -2])
    return loss_klz

  def diffusion_loss(self, t, f, f_mask, cond, training):
    # sample z_t.
    g_t = tf.squeeze(self.gamma(tf.expand_dims(t, -1)), -1)
    eps = tf.random.normal(tf.shape(f))
    z_t = self.variance_preserving_map(f, g_t, eps)
    # compute predicted noise
    eps_hat = self._scorer.score(z_t, g_t, cond, training)
    # MSE of predicted noise
    loss_diff_se = tf.reduce_sum(
        tf.math.multiply(tf.square(eps - eps_hat), tf.expand_dims(f_mask, -1)), axis=[-1, -2])
    loss_diff_mse = tf.reduce_sum(loss_diff_se)/tf.reduce_sum(f_mask)

    # loss for finite depth T, i.e. discrete time
    T = self._timesteps
    s = t - (1.0/T)
    g_s = tf.squeeze(self.gamma(tf.expand_dims(s, -1)), -1)
    loss_diff = 0.5 * T * tf.math.expm1(g_s - g_t) * loss_diff_se
    return loss_diff, loss_diff_mse

  def compute_model_loss(self, training_data, training=True):
    x = training_data['normalized_coordinates']
    cond = self._conditioner.conditioning(training_data['residue_names'],
                                          training_data['atom_names'],
                                          training)

    n_batch = tf.shape(x)[0]

    # 1. RECONSTRUCTION LOSS
    # add noise and reconstruct
    f = self._encoder.encode(x, cond, training)
    x_mask = tf.cast(
        tf.math.reduce_any(
            tf.math.greater(tf.math.abs(x), 1e-6), axis=[-1]), tf.float32)
    loss_recon, recon_diff = self.recon_loss(x, f, x_mask, cond, training)

    # 2. LATENT LOSS
    # KL z1 with N(0,1) prior
    loss_klz = self.latent_loss(f, x_mask)

    # 3. Diffusion Loss.
    # Sample time steps.
    # Use anithetic time sampling.
    t0 = tf.random.uniform(shape=[])
    t = tf.math.floormod(t0 +
                         tf.range(0, 1, 1./tf.cast(n_batch, 'float32'),
                                  dtype='float32'), 1.0)

    # Discretize timesteps.
    T = self._timesteps
    t = tf.math.ceil(t*T) / T

    loss_diff, loss_diff_mse = self.diffusion_loss(t, f, x_mask, cond, training)
    return (loss_diff, loss_klz, loss_recon, loss_diff_mse, recon_diff)

  @tf.function(reduce_retracing=True)
  def sample_step(self, i, T, z_t, cond):
    eps = tf.random.normal(tf.shape(z_t))
    eps = tf.zeros(tf.shape(z_t))
    t =  tf.cast((T - i) / T, 'float32')
    s = tf.cast((T - i - 1) / T, 'float32')

    g_s = self.gamma(s)
    g_t = self.gamma(t)
    sigma2_t = self.sigma2(g_t)
    sigma2_s = self.sigma2(g_s)

    sigma_t = tf.math.sqrt(self.sigma2(g_t))
    sigma_s = tf.math.sqrt(self.sigma2(g_s))

    alpha_t = self.alpha(g_t)
    alpha_s = self.alpha(g_s)

    alpha_t_s = alpha_t/alpha_s
    sigma2_t_s = sigma2_t - tf.math.square(alpha_t_s)*sigma2_s

    eps_hat_cond = self._scorer.score(z_t, g_t, cond, training=False)

    #a = self.sigma2(g_s)
    #b = self.sigma2(g_t)
    #c = -tf.math.expm1(g_t - g_s)
    x = (z_t -sigma_t*eps_hat_cond)/self.alpha(g_t)
    z_s = (alpha_t_s * sigma2_s * z_t / sigma2_t) + (alpha_s * sigma2_t_s * x /sigma2_t)
    #z_s = tf.math.sqrt(a / b) * (z_t - sigma_t * c * eps_hat_cond) + tf.math.sqrt((1. - a) * c) * eps
    return z_s

  def reconstruct(self, t, training_data):
    # Compute x and the conditioning.
    x = training_data['normalized_coordinates']
    cond = self._conditioner.conditioning(training_data['residue_names'],
                                          training_data['atom_names'], training=False)
    # Encode x into the embedding space.
    z_0 = self._encoder.encode(x, cond, training=False)

    # Introduce the error.
    T = self._timesteps
    tn = math.ceil(t * T)
    t = tn / T
    print("t", t)
    g_t = self.gamma(t)
    eps  = tf.random.normal(tf.shape(z_0))
    print('true eps', eps)
    z_with_error = self.variance_preserving_map(z_0, g_t, eps)
    print('z_0', z_0)
    z_t = z_with_error
    print('z_t', z_t)

    # Remove the error.
    for i in range(T-tn, T):
      if i%100==0:
        print(i)
      z_t = self.sample_step(tf.constant(i), T, z_t, cond)

    # Decode from the embedding space.
    g0 = self.gamma(0)
    z_0_rescaled = z_t /  self.alpha(g0)
    print('z_0', z_t)
    print('z_0_rescaled', z_0_rescaled)
    return (self._decoder.decode(z_with_error / self.alpha(g0), cond, training=False),
            self._decoder.decode(z_0_rescaled, cond, training=False),
            z_0, z_with_error, z_t)

  # Computes the diffusion loss at multiple timesteps.
  def MSEAtTimesteps(self, ts, training_data):
    x = training_data['normalized_coordinates']
    cond = self._conditioner.conditioning(training_data['residue_names'],
                                          training_data['atom_names'])
    z_0 = self._encoder.encode(x, cond)
    x_mask = tf.cast(tf.math.reduce_any(tf.math.not_equal(x, 0), axis=[-1]), tf.float32)
    for t in ts:
      print(self.diffusion_loss(t, z_0, x_mask, cond))

  def set_scorer(self, scorer):
    self._scorer = scorer

  def set_gamma_module(self, gamma_module):
    self._gamma_module = gamma_module

  def save(self, location):
    self._decoder.save(location + '/decoder_model')
    self._encoder.save(location + '/encoder_model')
    self._conditioner.save(location + '/conditioner_model')
    self._scorer.save(location + '/scorer_model')
    tf.saved_model.save(self._gamma_module, location + '/gamma_module')

In [127]:
def LoadDiffusionModel(location_prefix):
  return DiffusionModel(
      residue_lookup_size, atom_lookup_size,
      tf.saved_model.load(location_prefix + '/gamma_module'),
      DecoderTrain(tf.keras.models.load_model(location_prefix+'/decoder_model')),
      EncoderTrain(tf.keras.models.load_model(location_prefix+'/encoder_model')),
      CondTrain(tf.keras.models.load_model(location_prefix+'/conditioner_model')),
      ScoreTrain(tf.keras.models.load_model(location_prefix+'/scorer_model')))

In [128]:
diffusion_model = DiffusionModel(
    residue_lookup_size, atom_lookup_size, GammaModule(),
    DecoderTrain(DecoderModel()),
    EncoderTrain(EncoderModel()), CondTrain(CondModel()),
    ScoreTrain(ScoreModel()))

In [129]:
print(diffusion_model._gamma_module.trainable_variables)

()


In [130]:
#diffusion_model = LoadDiffusionModel(
#    'gs://variational_shaker_models/multiple_protein_test_with_more_transformers2/version_7400')

In [131]:
#print(diffusion_model._gamma_module.gamma_min)
#print(diffusion_model._gamma_module.gamma_max)
print(dir(diffusion_model._gamma_module))
diffusion_model._gamma_module=GammaModule()

['GetGamma', '_TF_MODULE_IGNORED_PROPERTIES', '__annotations__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_add_trackable_child', '_add_variable_with_custom_getter', '_checkpoint_dependencies', '_deferred_dependencies', '_delete_tracking', '_deserialization_dependencies', '_deserialize_from_proto', '_export_to_saved_model_graph', '_flatten', '_gather_saveables_for_checkpoint', '_handle_deferred_dependencies', '_l1', '_l2', '_l3', '_lookup_dependency', '_maybe_initialize_trackable', '_name_based_attribute_restore', '_name_based_restores', '_no_dependency', '_object_identifier', '_preload_simple_restoration', '_restore_from_tensors', '_self_name_based_restores', '_self_saveable_object_facto

## Train the Model

In [132]:
optimizer = tf.keras.optimizers.Adam()
#optimizer.build(diffusion_model.trainable_weights())

In [133]:
@tf.function(reduce_retracing=True)
def train_step(training_data):
  with tf.GradientTape() as tape:
    l1, l2, l3, loss_diff_mse, recon_diff = diffusion_model.compute_model_loss(training_data)
    loss = tf.reduce_mean(l1+l2+l3)
  trainable_weights = diffusion_model.trainable_weights()
  grads = tape.gradient(loss, trainable_weights)
  optimizer.apply_gradients(zip(grads, trainable_weights))
  return tf.reduce_mean(l1), tf.reduce_mean(l2), tf.reduce_mean(l3), tf.reduce_mean(loss_diff_mse/Z_EMBEDDING_SIZE), recon_diff

In [134]:
@tf.function(reduce_retracing=True)
def recon_step(training_data):
  with tf.GradientTape() as tape:
    l1, l2, l3, loss_diff_mse = diffusion_model.compute_model_loss(training_data)
    loss = tf.reduce_mean(l3)
  trainable_weights = diffusion_model.decoder_weights()
  grads = tape.gradient(loss, trainable_weights)
  optimizer.apply_gradients(zip(grads, trainable_weights))
  return l1, l2, l3, loss_diff_mse

In [135]:
epochs = 2
batch_size = 32
train_ds = ds.repeat().shuffle(1000).padded_batch(
    batch_size,
    padded_shapes={
        'name': [],
        'residue_names': [None],
        'atom_names': [None],
        'normalized_coordinates': [None, 3],
    }).prefetch(10)
# train_ds = ds.take(1).repeat().batch(batch_size).prefetch(10)

In [None]:
for epoch in range(epochs):
  print("Start of epoch %d" % (epoch,))
  for step, training_data in train_ds.enumerate():
    l1, l2, l3, loss_diff_mse, recon_diff = train_step(training_data)
    if step % 10==0:
      #print("Training Loss (for one batch) at step %d: %.4f"
      #  % (step, float(tf.reduce_mean(l1+l2+l3))))
      print("Training L1 Loss (for one batch) at step %d: %.4f"
        % (step, l1))
      print("Training L2 Loss (for one batch) at step %d: %.4f"
        % (step, l2))
      print("Training L3 Loss (for one batch) at step %d: %.4f"
        % (step, l3))
      print("loss_diff_mse (for one batch) at step %d: %.4f"
        % (step, loss_diff_mse))
      print("recon_diff (for one batch) at step %d: %.4f"
        % (step, float(recon_diff)))
      #print("l1: ", l1)
      #print("l2: ", l2)
      #print("l3: ", l3)
      print("Seen so far: %s samples" % ((step + 1) * batch_size))
    if step % 200==0:
      diffusion_model.save(
          'gs://variational_shaker_models/train_gamma/version_%d' % step)

Start of epoch 0




Training L1 Loss (for one batch) at step 0: 22282.1250
Training L2 Loss (for one batch) at step 0: 30498.4531
Training L3 Loss (for one batch) at step 0: 268968.2812
loss_diff_mse (for one batch) at step 0: 2.4146
recon_diff (for one batch) at step 0: 26.8692
Seen so far: tf.Tensor(32, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 10: 14537.4668
Training L2 Loss (for one batch) at step 10: 0.4629
Training L3 Loss (for one batch) at step 10: 266706.9688
loss_diff_mse (for one batch) at step 10: 1.5531
recon_diff (for one batch) at step 10: 26.7438
Seen so far: tf.Tensor(352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 20: 8653.0996
Training L2 Loss (for one batch) at step 20: 1.0670
Training L3 Loss (for one batch) at step 20: 247891.2812
loss_diff_mse (for one batch) at step 20: 1.0119
recon_diff (for one batch) at step 20: 26.0558
Seen so far: tf.Tensor(672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 30: 9647.5010
Training L2 Loss (for one batch) at step 30: 1.8017
Training L3 Loss (for one batch) at step 30: 351183.8750
loss_diff_mse (for one batch) at step 30: 0.9908
recon_diff (for one batch) at step 30: 28.0808
Seen so far: tf.Tensor(992, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4



Training L1 Loss (for one batch) at step 200: 446.3371
Training L2 Loss (for one batch) at step 200: 7054.3179
Training L3 Loss (for one batch) at step 200: 15841.9141
loss_diff_mse (for one batch) at step 200: 0.1101
recon_diff (for one batch) at step 200: 5.7317
Seen so far: tf.Tensor(6432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 210: 542.4154
Training L2 Loss (for one batch) at step 210: 7018.5918
Training L3 Loss (for one batch) at step 210: 16454.3633
loss_diff_mse (for one batch) at step 210: 0.1230
recon_diff (for one batch) at step 210: 5.5668
Seen so far: tf.Tensor(6752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 220: 388.6824
Training L2 Loss (for one batch) at step 220: 7367.2095
Training L3 Loss (for one batch) at step 220: 15710.7158
loss_diff_mse (for one batch) at step 220: 0.0893
recon_diff (for one batch) at step 220: 5.4422
Seen so far: tf.Tensor(7072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 230: 287.5043
Training L2 Loss (for one batch) at step 230: 6788.0283
Training L3 Loss (for one batch) at step 230: 15051.6396
loss_diff_mse (for one batch) at step 230: 0.0676
recon_diff (for one batch) at step 230: 5.4535
Seen so far: tf.Tensor(7392, shape=(), dtype=int64) samples
Training L1 Loss (for one



Training L1 Loss (for one batch) at step 400: 171.4029
Training L2 Loss (for one batch) at step 400: 7529.3086
Training L3 Loss (for one batch) at step 400: 17441.4863
loss_diff_mse (for one batch) at step 400: 0.0282
recon_diff (for one batch) at step 400: 5.3529
Seen so far: tf.Tensor(12832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 410: 158.8262
Training L2 Loss (for one batch) at step 410: 6384.1338
Training L3 Loss (for one batch) at step 410: 13966.4141
loss_diff_mse (for one batch) at step 410: 0.0279
recon_diff (for one batch) at step 410: 5.0599
Seen so far: tf.Tensor(13152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 420: 162.1557
Training L2 Loss (for one batch) at step 420: 6648.6519
Training L3 Loss (for one batch) at step 420: 14520.5742
loss_diff_mse (for one batch) at step 420: 0.0276
recon_diff (for one batch) at step 420: 5.1198
Seen so far: tf.Tensor(13472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 430: 149.6530
Training L2 Loss (for one batch) at step 430: 6780.7817
Training L3 Loss (for one batch) at step 430: 15350.3848
loss_diff_mse (for one batch) at step 430: 0.0236
recon_diff (for one batch) at step 430: 5.1333
Seen so far: tf.Tensor(13792, shape=(), dtype=int64) samples
Training L1 Loss (for 



Training L1 Loss (for one batch) at step 600: 174.6348
Training L2 Loss (for one batch) at step 600: 5992.1143
Training L3 Loss (for one batch) at step 600: 13048.0703
loss_diff_mse (for one batch) at step 600: 0.0229
recon_diff (for one batch) at step 600: 4.8477
Seen so far: tf.Tensor(19232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 610: 187.0374
Training L2 Loss (for one batch) at step 610: 6414.1362
Training L3 Loss (for one batch) at step 610: 13425.6670
loss_diff_mse (for one batch) at step 610: 0.0236
recon_diff (for one batch) at step 610: 4.8516
Seen so far: tf.Tensor(19552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 620: 215.1224
Training L2 Loss (for one batch) at step 620: 7176.3770
Training L3 Loss (for one batch) at step 620: 15323.4512
loss_diff_mse (for one batch) at step 620: 0.0237
recon_diff (for one batch) at step 620: 4.8989
Seen so far: tf.Tensor(19872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 630: 167.5121
Training L2 Loss (for one batch) at step 630: 6377.5078
Training L3 Loss (for one batch) at step 630: 14131.3838
loss_diff_mse (for one batch) at step 630: 0.0198
recon_diff (for one batch) at step 630: 4.9410
Seen so far: tf.Tensor(20192, shape=(), dtype=int64) samples
Training L1 Loss (for 



Training L1 Loss (for one batch) at step 800: 249.3883
Training L2 Loss (for one batch) at step 800: 5504.9707
Training L3 Loss (for one batch) at step 800: 12195.2568
loss_diff_mse (for one batch) at step 800: 0.0246
recon_diff (for one batch) at step 800: 4.6885
Seen so far: tf.Tensor(25632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 810: 245.5468
Training L2 Loss (for one batch) at step 810: 5956.6328
Training L3 Loss (for one batch) at step 810: 12774.6973
loss_diff_mse (for one batch) at step 810: 0.0224
recon_diff (for one batch) at step 810: 4.6348
Seen so far: tf.Tensor(25952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 820: 194.8010
Training L2 Loss (for one batch) at step 820: 5279.8672
Training L3 Loss (for one batch) at step 820: 11580.9541
loss_diff_mse (for one batch) at step 820: 0.0195
recon_diff (for one batch) at step 820: 4.6320
Seen so far: tf.Tensor(26272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 830: 197.1440
Training L2 Loss (for one batch) at step 830: 5741.1533
Training L3 Loss (for one batch) at step 830: 14581.8086
loss_diff_mse (for one batch) at step 830: 0.0177
recon_diff (for one batch) at step 830: 5.1332
Seen so far: tf.Tensor(26592, shape=(), dtype=int64) samples
Training L1 Loss (for 



Training L1 Loss (for one batch) at step 1000: 253.3547
Training L2 Loss (for one batch) at step 1000: 5965.6572
Training L3 Loss (for one batch) at step 1000: 13228.7793
loss_diff_mse (for one batch) at step 1000: 0.0173
recon_diff (for one batch) at step 1000: 4.5900
Seen so far: tf.Tensor(32032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 1010: 279.0957
Training L2 Loss (for one batch) at step 1010: 6675.6030
Training L3 Loss (for one batch) at step 1010: 16011.9082
loss_diff_mse (for one batch) at step 1010: 0.0168
recon_diff (for one batch) at step 1010: 4.8362
Seen so far: tf.Tensor(32352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1020: 253.6082
Training L2 Loss (for one batch) at step 1020: 5656.8418
Training L3 Loss (for one batch) at step 1020: 12461.6533
loss_diff_mse (for one batch) at step 1020: 0.0180
recon_diff (for one batch) at step 1020: 4.5877
Seen so far: tf.Tensor(32672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1030: 523.8007
Training L2 Loss (for one batch) at step 1030: 6264.0430
Training L3 Loss (for one batch) at step 1030: 14120.2041
loss_diff_mse (for one batch) at step 1030: 0.0331
recon_diff (for one batch) at step 1030: 4.6532
Seen so far: tf.Tensor(32992, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 1200: 354.2581
Training L2 Loss (for one batch) at step 1200: 5268.2681
Training L3 Loss (for one batch) at step 1200: 12422.0000
loss_diff_mse (for one batch) at step 1200: 0.0190
recon_diff (for one batch) at step 1200: 4.4395
Seen so far: tf.Tensor(38432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 1210: 511.1753
Training L2 Loss (for one batch) at step 1210: 2615.6877
Training L3 Loss (for one batch) at step 1210: 16391.8379
loss_diff_mse (for one batch) at step 1210: 0.0282
recon_diff (for one batch) at step 1210: 5.6314
Seen so far: tf.Tensor(38752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1220: 437.5104
Training L2 Loss (for one batch) at step 1220: 5455.9463
Training L3 Loss (for one batch) at step 1220: 11460.5879
loss_diff_mse (for one batch) at step 1220: 0.0250
recon_diff (for one batch) at step 1220: 4.4096
Seen so far: tf.Tensor(39072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1230: 422.8927
Training L2 Loss (for one batch) at step 1230: 7597.2744
Training L3 Loss (for one batch) at step 1230: 11096.2012
loss_diff_mse (for one batch) at step 1230: 0.0220
recon_diff (for one batch) at step 1230: 4.0315
Seen so far: tf.Tensor(39392, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 1400: 516.6407
Training L2 Loss (for one batch) at step 1400: 4640.4893
Training L3 Loss (for one batch) at step 1400: 11297.9707
loss_diff_mse (for one batch) at step 1400: 0.0215
recon_diff (for one batch) at step 1400: 4.2714
Seen so far: tf.Tensor(44832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 1410: 441.0288
Training L2 Loss (for one batch) at step 1410: 4673.2256
Training L3 Loss (for one batch) at step 1410: 12166.4277
loss_diff_mse (for one batch) at step 1410: 0.0173
recon_diff (for one batch) at step 1410: 4.4035
Seen so far: tf.Tensor(45152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1420: 490.2620
Training L2 Loss (for one batch) at step 1420: 5472.7051
Training L3 Loss (for one batch) at step 1420: 15628.4785
loss_diff_mse (for one batch) at step 1420: 0.0167
recon_diff (for one batch) at step 1420: 4.7837
Seen so far: tf.Tensor(45472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1430: 467.4472
Training L2 Loss (for one batch) at step 1430: 5191.3975
Training L3 Loss (for one batch) at step 1430: 14633.4238
loss_diff_mse (for one batch) at step 1430: 0.0158
recon_diff (for one batch) at step 1430: 4.5304
Seen so far: tf.Tensor(45792, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 1600: 660.3947
Training L2 Loss (for one batch) at step 1600: 3249.6909
Training L3 Loss (for one batch) at step 1600: 11799.6914
loss_diff_mse (for one batch) at step 1600: 0.0177
recon_diff (for one batch) at step 1600: 4.2867
Seen so far: tf.Tensor(51232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 1610: 753.6392
Training L2 Loss (for one batch) at step 1610: 3695.5828
Training L3 Loss (for one batch) at step 1610: 13356.9395
loss_diff_mse (for one batch) at step 1610: 0.0195
recon_diff (for one batch) at step 1610: 4.4893
Seen so far: tf.Tensor(51552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1620: 696.1642
Training L2 Loss (for one batch) at step 1620: 3650.4873
Training L3 Loss (for one batch) at step 1620: 12441.4199
loss_diff_mse (for one batch) at step 1620: 0.0186
recon_diff (for one batch) at step 1620: 4.6244
Seen so far: tf.Tensor(51872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1630: 620.7281
Training L2 Loss (for one batch) at step 1630: 3231.0176
Training L3 Loss (for one batch) at step 1630: 10048.8945
loss_diff_mse (for one batch) at step 1630: 0.0171
recon_diff (for one batch) at step 1630: 4.0784
Seen so far: tf.Tensor(52192, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 1800: 1042.3284
Training L2 Loss (for one batch) at step 1800: 2652.1650
Training L3 Loss (for one batch) at step 1800: 10119.6270
loss_diff_mse (for one batch) at step 1800: 0.0205
recon_diff (for one batch) at step 1800: 3.8739
Seen so far: tf.Tensor(57632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 1810: 1171.8296
Training L2 Loss (for one batch) at step 1810: 2948.7139
Training L3 Loss (for one batch) at step 1810: 12102.0234
loss_diff_mse (for one batch) at step 1810: 0.0214
recon_diff (for one batch) at step 1810: 4.1677
Seen so far: tf.Tensor(57952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1820: 1041.3417
Training L2 Loss (for one batch) at step 1820: 2884.1970
Training L3 Loss (for one batch) at step 1820: 9772.0635
loss_diff_mse (for one batch) at step 1820: 0.0201
recon_diff (for one batch) at step 1820: 3.8121
Seen so far: tf.Tensor(58272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 1830: 1010.7114
Training L2 Loss (for one batch) at step 1830: 2796.9668
Training L3 Loss (for one batch) at step 1830: 9383.2383
loss_diff_mse (for one batch) at step 1830: 0.0185
recon_diff (for one batch) at step 1830: 3.6050
Seen so far: tf.Tensor(58592, shape=(), dtype=int64) samples
Traini



Training L1 Loss (for one batch) at step 2000: 1240.2562
Training L2 Loss (for one batch) at step 2000: 1624.8923
Training L3 Loss (for one batch) at step 2000: 9670.1309
loss_diff_mse (for one batch) at step 2000: 0.0172
recon_diff (for one batch) at step 2000: 3.8149
Seen so far: tf.Tensor(64032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 2010: 1180.0657
Training L2 Loss (for one batch) at step 2010: 1956.1213
Training L3 Loss (for one batch) at step 2010: 9838.0625
loss_diff_mse (for one batch) at step 2010: 0.0158
recon_diff (for one batch) at step 2010: 3.7499
Seen so far: tf.Tensor(64352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2020: 1236.9504
Training L2 Loss (for one batch) at step 2020: 1895.9197
Training L3 Loss (for one batch) at step 2020: 8721.5684
loss_diff_mse (for one batch) at step 2020: 0.0173
recon_diff (for one batch) at step 2020: 3.5342
Seen so far: tf.Tensor(64672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2030: 1288.5400
Training L2 Loss (for one batch) at step 2030: 1974.8945
Training L3 Loss (for one batch) at step 2030: 9269.4453
loss_diff_mse (for one batch) at step 2030: 0.0166
recon_diff (for one batch) at step 2030: 3.5035
Seen so far: tf.Tensor(64992, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 2200: 1540.1877
Training L2 Loss (for one batch) at step 2200: 1626.3549
Training L3 Loss (for one batch) at step 2200: 7764.5322
loss_diff_mse (for one batch) at step 2200: 0.0182
recon_diff (for one batch) at step 2200: 3.0838
Seen so far: tf.Tensor(70432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 2210: 1766.0934
Training L2 Loss (for one batch) at step 2210: 1819.9814
Training L3 Loss (for one batch) at step 2210: 9655.2090
loss_diff_mse (for one batch) at step 2210: 0.0175
recon_diff (for one batch) at step 2210: 3.3040
Seen so far: tf.Tensor(70752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2220: 2027.6409
Training L2 Loss (for one batch) at step 2220: 1102.1907
Training L3 Loss (for one batch) at step 2220: 13981.3809
loss_diff_mse (for one batch) at step 2220: 0.0183
recon_diff (for one batch) at step 2220: 4.2190
Seen so far: tf.Tensor(71072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2230: 2003.7568
Training L2 Loss (for one batch) at step 2230: 2155.5996
Training L3 Loss (for one batch) at step 2230: 11753.5186
loss_diff_mse (for one batch) at step 2230: 0.0201
recon_diff (for one batch) at step 2230: 3.6947
Seen so far: tf.Tensor(71392, shape=(), dtype=int64) samples
Train



Training L1 Loss (for one batch) at step 2400: 1493.8928
Training L2 Loss (for one batch) at step 2400: 836.4334
Training L3 Loss (for one batch) at step 2400: 7127.4380
loss_diff_mse (for one batch) at step 2400: 0.0146
recon_diff (for one batch) at step 2400: 3.1120
Seen so far: tf.Tensor(76832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 2410: 1640.3486
Training L2 Loss (for one batch) at step 2410: 1009.8364
Training L3 Loss (for one batch) at step 2410: 8335.5918
loss_diff_mse (for one batch) at step 2410: 0.0138
recon_diff (for one batch) at step 2410: 3.1328
Seen so far: tf.Tensor(77152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2420: 1559.5298
Training L2 Loss (for one batch) at step 2420: 466.2689
Training L3 Loss (for one batch) at step 2420: 10447.2578
loss_diff_mse (for one batch) at step 2420: 0.0132
recon_diff (for one batch) at step 2420: 3.9925
Seen so far: tf.Tensor(77472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2430: 1622.7246
Training L2 Loss (for one batch) at step 2430: 902.7407
Training L3 Loss (for one batch) at step 2430: 8100.9492
loss_diff_mse (for one batch) at step 2430: 0.0150
recon_diff (for one batch) at step 2430: 3.3682
Seen so far: tf.Tensor(77792, shape=(), dtype=int64) samples
Training



Training L1 Loss (for one batch) at step 2600: 1626.4446
Training L2 Loss (for one batch) at step 2600: 1130.3893
Training L3 Loss (for one batch) at step 2600: 7031.7051
loss_diff_mse (for one batch) at step 2600: 0.0145
recon_diff (for one batch) at step 2600: 2.8246
Seen so far: tf.Tensor(83232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 2610: 1594.7263
Training L2 Loss (for one batch) at step 2610: 1128.1980
Training L3 Loss (for one batch) at step 2610: 7529.1504
loss_diff_mse (for one batch) at step 2610: 0.0142
recon_diff (for one batch) at step 2610: 3.0180
Seen so far: tf.Tensor(83552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2620: 1682.7400
Training L2 Loss (for one batch) at step 2620: 1160.6362
Training L3 Loss (for one batch) at step 2620: 8221.5215
loss_diff_mse (for one batch) at step 2620: 0.0134
recon_diff (for one batch) at step 2620: 3.0577
Seen so far: tf.Tensor(83872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2630: 1533.0469
Training L2 Loss (for one batch) at step 2630: 932.4302
Training L3 Loss (for one batch) at step 2630: 7291.1938
loss_diff_mse (for one batch) at step 2630: 0.0138
recon_diff (for one batch) at step 2630: 3.0888
Seen so far: tf.Tensor(84192, shape=(), dtype=int64) samples
Training



Training L1 Loss (for one batch) at step 2800: 1900.2917
Training L2 Loss (for one batch) at step 2800: 1059.1509
Training L3 Loss (for one batch) at step 2800: 8564.0273
loss_diff_mse (for one batch) at step 2800: 0.0133
recon_diff (for one batch) at step 2800: 3.0055
Seen so far: tf.Tensor(89632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 2810: 1888.3101
Training L2 Loss (for one batch) at step 2810: 1057.1665
Training L3 Loss (for one batch) at step 2810: 8259.3828
loss_diff_mse (for one batch) at step 2810: 0.0130
recon_diff (for one batch) at step 2810: 2.8267
Seen so far: tf.Tensor(89952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2820: 2038.8024
Training L2 Loss (for one batch) at step 2820: 1112.8822
Training L3 Loss (for one batch) at step 2820: 9026.3691
loss_diff_mse (for one batch) at step 2820: 0.0131
recon_diff (for one batch) at step 2820: 2.9633
Seen so far: tf.Tensor(90272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 2830: 1728.6257
Training L2 Loss (for one batch) at step 2830: 897.8396
Training L3 Loss (for one batch) at step 2830: 9057.5889
loss_diff_mse (for one batch) at step 2830: 0.0127
recon_diff (for one batch) at step 2830: 3.2187
Seen so far: tf.Tensor(90592, shape=(), dtype=int64) samples
Training



Training L1 Loss (for one batch) at step 3000: 1521.2040
Training L2 Loss (for one batch) at step 3000: 531.5060
Training L3 Loss (for one batch) at step 3000: 6020.2393
loss_diff_mse (for one batch) at step 3000: 0.0119
recon_diff (for one batch) at step 3000: 2.6517
Seen so far: tf.Tensor(96032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 3010: 2178.4434
Training L2 Loss (for one batch) at step 3010: 764.5062
Training L3 Loss (for one batch) at step 3010: 8081.6523
loss_diff_mse (for one batch) at step 3010: 0.0129
recon_diff (for one batch) at step 3010: 2.7239
Seen so far: tf.Tensor(96352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3020: 1908.9268
Training L2 Loss (for one batch) at step 3020: 646.4077
Training L3 Loss (for one batch) at step 3020: 6592.7334
loss_diff_mse (for one batch) at step 3020: 0.0140
recon_diff (for one batch) at step 3020: 2.6999
Seen so far: tf.Tensor(96672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3030: 1978.2717
Training L2 Loss (for one batch) at step 3030: 711.5209
Training L3 Loss (for one batch) at step 3030: 7161.0557
loss_diff_mse (for one batch) at step 3030: 0.0134
recon_diff (for one batch) at step 3030: 2.6792
Seen so far: tf.Tensor(96992, shape=(), dtype=int64) samples
Training L



Training L1 Loss (for one batch) at step 3200: 1831.7810
Training L2 Loss (for one batch) at step 3200: 768.2295
Training L3 Loss (for one batch) at step 3200: 7724.6719
loss_diff_mse (for one batch) at step 3200: 0.0117
recon_diff (for one batch) at step 3200: 2.7835
Seen so far: tf.Tensor(102432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 3210: 1647.8274
Training L2 Loss (for one batch) at step 3210: 641.1381
Training L3 Loss (for one batch) at step 3210: 6182.7217
loss_diff_mse (for one batch) at step 3210: 0.0120
recon_diff (for one batch) at step 3210: 2.5520
Seen so far: tf.Tensor(102752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3220: 2875.7944
Training L2 Loss (for one batch) at step 3220: 806.3392
Training L3 Loss (for one batch) at step 3220: 7665.1289
loss_diff_mse (for one batch) at step 3220: 0.0184
recon_diff (for one batch) at step 3220: 2.7863
Seen so far: tf.Tensor(103072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3230: 2064.6624
Training L2 Loss (for one batch) at step 3230: 765.9614
Training L3 Loss (for one batch) at step 3230: 7587.0435
loss_diff_mse (for one batch) at step 3230: 0.0131
recon_diff (for one batch) at step 3230: 2.7451
Seen so far: tf.Tensor(103392, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 3400: 1930.4717
Training L2 Loss (for one batch) at step 3400: 536.6824
Training L3 Loss (for one batch) at step 3400: 7072.4746
loss_diff_mse (for one batch) at step 3400: 0.0116
recon_diff (for one batch) at step 3400: 2.5768
Seen so far: tf.Tensor(108832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 3410: 2102.3320
Training L2 Loss (for one batch) at step 3410: 656.4194
Training L3 Loss (for one batch) at step 3410: 8067.8271
loss_diff_mse (for one batch) at step 3410: 0.0111
recon_diff (for one batch) at step 3410: 2.6219
Seen so far: tf.Tensor(109152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3420: 1693.7909
Training L2 Loss (for one batch) at step 3420: 461.9528
Training L3 Loss (for one batch) at step 3420: 5816.2676
loss_diff_mse (for one batch) at step 3420: 0.0113
recon_diff (for one batch) at step 3420: 2.3737
Seen so far: tf.Tensor(109472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3430: 2170.0015
Training L2 Loss (for one batch) at step 3430: 538.2037
Training L3 Loss (for one batch) at step 3430: 8600.7656
loss_diff_mse (for one batch) at step 3430: 0.0112
recon_diff (for one batch) at step 3430: 2.8323
Seen so far: tf.Tensor(109792, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 3600: 2185.0156
Training L2 Loss (for one batch) at step 3600: 638.8967
Training L3 Loss (for one batch) at step 3600: 7528.4731
loss_diff_mse (for one batch) at step 3600: 0.0110
recon_diff (for one batch) at step 3600: 2.4367
Seen so far: tf.Tensor(115232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 3610: 2001.4406
Training L2 Loss (for one batch) at step 3610: 544.3393
Training L3 Loss (for one batch) at step 3610: 6609.2246
loss_diff_mse (for one batch) at step 3610: 0.0113
recon_diff (for one batch) at step 3610: 2.4256
Seen so far: tf.Tensor(115552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3620: 1756.9668
Training L2 Loss (for one batch) at step 3620: 347.2194
Training L3 Loss (for one batch) at step 3620: 6561.6396
loss_diff_mse (for one batch) at step 3620: 0.0103
recon_diff (for one batch) at step 3620: 2.6353
Seen so far: tf.Tensor(115872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3630: 2137.6206
Training L2 Loss (for one batch) at step 3630: 542.9495
Training L3 Loss (for one batch) at step 3630: 6322.6318
loss_diff_mse (for one batch) at step 3630: 0.0125
recon_diff (for one batch) at step 3630: 2.3295
Seen so far: tf.Tensor(116192, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 3800: 2155.4863
Training L2 Loss (for one batch) at step 3800: 485.4724
Training L3 Loss (for one batch) at step 3800: 6659.2500
loss_diff_mse (for one batch) at step 3800: 0.0116
recon_diff (for one batch) at step 3800: 2.3668
Seen so far: tf.Tensor(121632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 3810: 2084.3438
Training L2 Loss (for one batch) at step 3810: 517.4825
Training L3 Loss (for one batch) at step 3810: 6980.1787
loss_diff_mse (for one batch) at step 3810: 0.0110
recon_diff (for one batch) at step 3810: 2.5005
Seen so far: tf.Tensor(121952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3820: 2174.0884
Training L2 Loss (for one batch) at step 3820: 542.2327
Training L3 Loss (for one batch) at step 3820: 6941.7734
loss_diff_mse (for one batch) at step 3820: 0.0112
recon_diff (for one batch) at step 3820: 2.3641
Seen so far: tf.Tensor(122272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 3830: 2529.0269
Training L2 Loss (for one batch) at step 3830: 661.4733
Training L3 Loss (for one batch) at step 3830: 9730.2158
loss_diff_mse (for one batch) at step 3830: 0.0112
recon_diff (for one batch) at step 3830: 2.8673
Seen so far: tf.Tensor(122592, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 4000: 1824.6067
Training L2 Loss (for one batch) at step 4000: 451.8321
Training L3 Loss (for one batch) at step 4000: 6479.1118
loss_diff_mse (for one batch) at step 4000: 0.0096
recon_diff (for one batch) at step 4000: 2.3468
Seen so far: tf.Tensor(128032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 4010: 2268.8599
Training L2 Loss (for one batch) at step 4010: 549.2101
Training L3 Loss (for one batch) at step 4010: 7625.3643
loss_diff_mse (for one batch) at step 4010: 0.0105
recon_diff (for one batch) at step 4010: 2.4442
Seen so far: tf.Tensor(128352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4020: 1898.2255
Training L2 Loss (for one batch) at step 4020: 444.5411
Training L3 Loss (for one batch) at step 4020: 6249.7617
loss_diff_mse (for one batch) at step 4020: 0.0105
recon_diff (for one batch) at step 4020: 2.3137
Seen so far: tf.Tensor(128672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4030: 1853.1147
Training L2 Loss (for one batch) at step 4030: 459.1529
Training L3 Loss (for one batch) at step 4030: 6571.5078
loss_diff_mse (for one batch) at step 4030: 0.0100
recon_diff (for one batch) at step 4030: 2.3781
Seen so far: tf.Tensor(128992, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 4200: 2225.0872
Training L2 Loss (for one batch) at step 4200: 378.5668
Training L3 Loss (for one batch) at step 4200: 6957.4302
loss_diff_mse (for one batch) at step 4200: 0.0109
recon_diff (for one batch) at step 4200: 2.4583
Seen so far: tf.Tensor(134432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 4210: 2274.8059
Training L2 Loss (for one batch) at step 4210: 419.4561
Training L3 Loss (for one batch) at step 4210: 7297.4624
loss_diff_mse (for one batch) at step 4210: 0.0100
recon_diff (for one batch) at step 4210: 2.3568
Seen so far: tf.Tensor(134752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4220: 1994.3290
Training L2 Loss (for one batch) at step 4220: 422.6694
Training L3 Loss (for one batch) at step 4220: 6399.4717
loss_diff_mse (for one batch) at step 4220: 0.0104
recon_diff (for one batch) at step 4220: 2.3481
Seen so far: tf.Tensor(135072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4230: 1796.9205
Training L2 Loss (for one batch) at step 4230: 361.9186
Training L3 Loss (for one batch) at step 4230: 5385.1689
loss_diff_mse (for one batch) at step 4230: 0.0105
recon_diff (for one batch) at step 4230: 2.1111
Seen so far: tf.Tensor(135392, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 4400: 1780.6255
Training L2 Loss (for one batch) at step 4400: 308.5016
Training L3 Loss (for one batch) at step 4400: 5678.9531
loss_diff_mse (for one batch) at step 4400: 0.0100
recon_diff (for one batch) at step 4400: 2.1944
Seen so far: tf.Tensor(140832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 4410: 1869.0417
Training L2 Loss (for one batch) at step 4410: 342.3677
Training L3 Loss (for one batch) at step 4410: 6270.1768
loss_diff_mse (for one batch) at step 4410: 0.0092
recon_diff (for one batch) at step 4410: 2.1998
Seen so far: tf.Tensor(141152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4420: 1774.4305
Training L2 Loss (for one batch) at step 4420: 269.9074
Training L3 Loss (for one batch) at step 4420: 6917.7358
loss_diff_mse (for one batch) at step 4420: 0.0086
recon_diff (for one batch) at step 4420: 2.4738
Seen so far: tf.Tensor(141472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4430: 2104.5249
Training L2 Loss (for one batch) at step 4430: 378.8951
Training L3 Loss (for one batch) at step 4430: 6724.1846
loss_diff_mse (for one batch) at step 4430: 0.0100
recon_diff (for one batch) at step 4430: 2.2421
Seen so far: tf.Tensor(141792, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 4600: 1970.6548
Training L2 Loss (for one batch) at step 4600: 191.1458
Training L3 Loss (for one batch) at step 4600: 6114.5674
loss_diff_mse (for one batch) at step 4600: 0.0089
recon_diff (for one batch) at step 4600: 2.1702
Seen so far: tf.Tensor(147232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 4610: 2115.2173
Training L2 Loss (for one batch) at step 4610: 229.6938
Training L3 Loss (for one batch) at step 4610: 6804.7129
loss_diff_mse (for one batch) at step 4610: 0.0096
recon_diff (for one batch) at step 4610: 2.4329
Seen so far: tf.Tensor(147552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4620: 1948.9631
Training L2 Loss (for one batch) at step 4620: 247.1150
Training L3 Loss (for one batch) at step 4620: 6258.3599
loss_diff_mse (for one batch) at step 4620: 0.0089
recon_diff (for one batch) at step 4620: 2.1963
Seen so far: tf.Tensor(147872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4630: 1910.6577
Training L2 Loss (for one batch) at step 4630: 249.0180
Training L3 Loss (for one batch) at step 4630: 6240.3389
loss_diff_mse (for one batch) at step 4630: 0.0092
recon_diff (for one batch) at step 4630: 2.3573
Seen so far: tf.Tensor(148192, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 4800: 2258.3838
Training L2 Loss (for one batch) at step 4800: 254.7568
Training L3 Loss (for one batch) at step 4800: 6506.1709
loss_diff_mse (for one batch) at step 4800: 0.0100
recon_diff (for one batch) at step 4800: 2.2249
Seen so far: tf.Tensor(153632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 4810: 2359.9805
Training L2 Loss (for one batch) at step 4810: 272.1558
Training L3 Loss (for one batch) at step 4810: 8121.8809
loss_diff_mse (for one batch) at step 4810: 0.0089
recon_diff (for one batch) at step 4810: 2.4663
Seen so far: tf.Tensor(153952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4820: 1992.9500
Training L2 Loss (for one batch) at step 4820: 247.4921
Training L3 Loss (for one batch) at step 4820: 6537.4385
loss_diff_mse (for one batch) at step 4820: 0.0090
recon_diff (for one batch) at step 4820: 2.3285
Seen so far: tf.Tensor(154272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 4830: 1965.4113
Training L2 Loss (for one batch) at step 4830: 188.7900
Training L3 Loss (for one batch) at step 4830: 7051.9131
loss_diff_mse (for one batch) at step 4830: 0.0083
recon_diff (for one batch) at step 4830: 2.3863
Seen so far: tf.Tensor(154592, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 5000: 1813.7310
Training L2 Loss (for one batch) at step 5000: 209.9601
Training L3 Loss (for one batch) at step 5000: 5920.1758
loss_diff_mse (for one batch) at step 5000: 0.0090
recon_diff (for one batch) at step 5000: 2.2708
Seen so far: tf.Tensor(160032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 5010: 1717.4641
Training L2 Loss (for one batch) at step 5010: 162.2313
Training L3 Loss (for one batch) at step 5010: 6660.3857
loss_diff_mse (for one batch) at step 5010: 0.0081
recon_diff (for one batch) at step 5010: 2.5414
Seen so far: tf.Tensor(160352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5020: 1424.9670
Training L2 Loss (for one batch) at step 5020: 120.7762
Training L3 Loss (for one batch) at step 5020: 5044.9180
loss_diff_mse (for one batch) at step 5020: 0.0076
recon_diff (for one batch) at step 5020: 2.1637
Seen so far: tf.Tensor(160672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5030: 2255.9180
Training L2 Loss (for one batch) at step 5030: 263.4375
Training L3 Loss (for one batch) at step 5030: 7342.8916
loss_diff_mse (for one batch) at step 5030: 0.0100
recon_diff (for one batch) at step 5030: 2.6306
Seen so far: tf.Tensor(160992, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 5200: 1988.1350
Training L2 Loss (for one batch) at step 5200: 340.0449
Training L3 Loss (for one batch) at step 5200: 6290.9180
loss_diff_mse (for one batch) at step 5200: 0.0094
recon_diff (for one batch) at step 5200: 2.2150
Seen so far: tf.Tensor(166432, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 5210: 2043.2917
Training L2 Loss (for one batch) at step 5210: 406.8706
Training L3 Loss (for one batch) at step 5210: 6394.2686
loss_diff_mse (for one batch) at step 5210: 0.0100
recon_diff (for one batch) at step 5210: 2.3005
Seen so far: tf.Tensor(166752, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5220: 1845.5273
Training L2 Loss (for one batch) at step 5220: 434.1484
Training L3 Loss (for one batch) at step 5220: 6189.6240
loss_diff_mse (for one batch) at step 5220: 0.0091
recon_diff (for one batch) at step 5220: 2.0454
Seen so far: tf.Tensor(167072, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5230: 1746.6218
Training L2 Loss (for one batch) at step 5230: 436.3430
Training L3 Loss (for one batch) at step 5230: 5847.4429
loss_diff_mse (for one batch) at step 5230: 0.0097
recon_diff (for one batch) at step 5230: 2.2244
Seen so far: tf.Tensor(167392, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 5400: 1730.5803
Training L2 Loss (for one batch) at step 5400: 211.7805
Training L3 Loss (for one batch) at step 5400: 5696.8613
loss_diff_mse (for one batch) at step 5400: 0.0081
recon_diff (for one batch) at step 5400: 1.9857
Seen so far: tf.Tensor(172832, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 5410: 2004.7444
Training L2 Loss (for one batch) at step 5410: 217.3983
Training L3 Loss (for one batch) at step 5410: 7521.3896
loss_diff_mse (for one batch) at step 5410: 0.0078
recon_diff (for one batch) at step 5410: 2.3619
Seen so far: tf.Tensor(173152, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5420: 1914.0593
Training L2 Loss (for one batch) at step 5420: 263.9742
Training L3 Loss (for one batch) at step 5420: 7691.6240
loss_diff_mse (for one batch) at step 5420: 0.0084
recon_diff (for one batch) at step 5420: 2.7512
Seen so far: tf.Tensor(173472, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5430: 1895.8901
Training L2 Loss (for one batch) at step 5430: 289.9676
Training L3 Loss (for one batch) at step 5430: 6243.3193
loss_diff_mse (for one batch) at step 5430: 0.0085
recon_diff (for one batch) at step 5430: 2.1296
Seen so far: tf.Tensor(173792, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 5600: 1945.5905
Training L2 Loss (for one batch) at step 5600: 283.8875
Training L3 Loss (for one batch) at step 5600: 6189.3594
loss_diff_mse (for one batch) at step 5600: 0.0083
recon_diff (for one batch) at step 5600: 2.0071
Seen so far: tf.Tensor(179232, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 5610: 1836.7085
Training L2 Loss (for one batch) at step 5610: 243.1093
Training L3 Loss (for one batch) at step 5610: 5648.9917
loss_diff_mse (for one batch) at step 5610: 0.0085
recon_diff (for one batch) at step 5610: 1.9213
Seen so far: tf.Tensor(179552, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5620: 1943.4138
Training L2 Loss (for one batch) at step 5620: 345.7940
Training L3 Loss (for one batch) at step 5620: 6197.1475
loss_diff_mse (for one batch) at step 5620: 0.0089
recon_diff (for one batch) at step 5620: 2.0101
Seen so far: tf.Tensor(179872, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5630: 1575.5812
Training L2 Loss (for one batch) at step 5630: 271.6025
Training L3 Loss (for one batch) at step 5630: 5267.8008
loss_diff_mse (for one batch) at step 5630: 0.0086
recon_diff (for one batch) at step 5630: 2.0913
Seen so far: tf.Tensor(180192, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 5800: 1466.1311
Training L2 Loss (for one batch) at step 5800: 304.9022
Training L3 Loss (for one batch) at step 5800: 5278.1450
loss_diff_mse (for one batch) at step 5800: 0.0081
recon_diff (for one batch) at step 5800: 2.0494
Seen so far: tf.Tensor(185632, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 5810: 1769.6431
Training L2 Loss (for one batch) at step 5810: 413.4027
Training L3 Loss (for one batch) at step 5810: 6254.4165
loss_diff_mse (for one batch) at step 5810: 0.0081
recon_diff (for one batch) at step 5810: 2.0245
Seen so far: tf.Tensor(185952, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5820: 1718.1156
Training L2 Loss (for one batch) at step 5820: 261.3358
Training L3 Loss (for one batch) at step 5820: 6375.1157
loss_diff_mse (for one batch) at step 5820: 0.0077
recon_diff (for one batch) at step 5820: 2.1687
Seen so far: tf.Tensor(186272, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 5830: 1417.7817
Training L2 Loss (for one batch) at step 5830: 226.1984
Training L3 Loss (for one batch) at step 5830: 5060.8428
loss_diff_mse (for one batch) at step 5830: 0.0076
recon_diff (for one batch) at step 5830: 1.9285
Seen so far: tf.Tensor(186592, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 6000: 1847.5237
Training L2 Loss (for one batch) at step 6000: 263.1575
Training L3 Loss (for one batch) at step 6000: 5485.1777
loss_diff_mse (for one batch) at step 6000: 0.0089
recon_diff (for one batch) at step 6000: 1.9873
Seen so far: tf.Tensor(192032, shape=(), dtype=int64) samples




Training L1 Loss (for one batch) at step 6010: 1931.0552
Training L2 Loss (for one batch) at step 6010: 339.4039
Training L3 Loss (for one batch) at step 6010: 5804.7305
loss_diff_mse (for one batch) at step 6010: 0.0088
recon_diff (for one batch) at step 6010: 1.8819
Seen so far: tf.Tensor(192352, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 6020: 1603.6149
Training L2 Loss (for one batch) at step 6020: 121.6573
Training L3 Loss (for one batch) at step 6020: 6866.4336
loss_diff_mse (for one batch) at step 6020: 0.0067
recon_diff (for one batch) at step 6020: 2.4166
Seen so far: tf.Tensor(192672, shape=(), dtype=int64) samples
Training L1 Loss (for one batch) at step 6030: 1920.8838
Training L2 Loss (for one batch) at step 6030: 162.2104
Training L3 Loss (for one batch) at step 6030: 6808.8623
loss_diff_mse (for one batch) at step 6030: 0.0075
recon_diff (for one batch) at step 6030: 2.1932
Seen so far: tf.Tensor(192992, shape=(), dtype=int64) samples
Trainin



Training L1 Loss (for one batch) at step 6200: 1664.1899
Training L2 Loss (for one batch) at step 6200: 239.5022
Training L3 Loss (for one batch) at step 6200: 5323.0972
loss_diff_mse (for one batch) at step 6200: 0.0080
recon_diff (for one batch) at step 6200: 1.9652
Seen so far: tf.Tensor(198432, shape=(), dtype=int64) samples




In [None]:
test_optimizer = tf.keras.optimizers.Adam()

In [None]:
@tf.function(reduce_retracing=True)
def test_step6(some_data, training_data):
  with tf.GradientTape() as tape:
    l1, l2, l3, loss_diff_mse, recon_diff = diffusion_model.compute_model_loss(some_data, training=True)
    loss = tf.reduce_mean(l1+l2+l3)
  trainable_weights = diffusion_model.trainable_weights()
  grads = tape.gradient(loss, trainable_weights)
  test_optimizer.apply_gradients(zip(grads, trainable_weights))

  l11, l21, l31, loss_diff_mse1, recon_diff1 = diffusion_model.compute_model_loss(training_data)
  return tf.reduce_mean(loss_diff_mse1/Z_EMBEDDING_SIZE), recon_diff1

In [None]:
@tf.function(reduce_retracing=True)
def test_step8(training_data):
  l11, l21, l31, loss_diff_mse1, recon_diff1 = diffusion_model.compute_model_loss(training_data, training=False)
  return tf.reduce_mean(loss_diff_mse1/Z_EMBEDDING_SIZE), recon_diff1

In [None]:
some_data = None
for step, t in train_ds.enumerate():
  some_data = t
  break

for step, training_data in train_ds.enumerate():
  print(test_step8(training_data))

In [None]:
print(batch_size)

# Test the model on a single protein

In [None]:
from google.colab import output
output.enable_custom_widget_manager()

Support for third party widgets will remain active for the duration of the session. To disable support:

In [None]:
# Load the test protein
blob = client.bucket("public-datasets-deepmind-alphafold-v4").blob(
    'AF-A0A5C2FU82-F1'+'-model_v4.cif')
parser = PDB.FastMMCIFParser()
structure = parser.get_structure('AF-A0A5C2FU82-F1', blob.open())


In [None]:
def PreProcessPDBStructure(pdb_structure):
    residue_names = []
    atom_names = []
    coords = []
    for r in pdb_structure.get_residues():
        for a in r.get_atoms():
            residue_names.append(r.get_resname())
            atom_names.append(a.get_name())
            coords.append(a.get_coord())
    residue_names = np.array(residue_names)
    atom_names = np.array(atom_names)
    normalized_coordinates = np.array(coords)
    normalized_coordinates -= np.mean(coords, 0)


    return {
        'name': pdb_structure.get_id(),
        'residue_names': residue_names,
        'atom_names': atom_names,
        'normalized_coordinates': normalized_coordinates,
    }

In [None]:
def _FeaturesFromPreprocessedStructure(
    preprocessed_structure, residue_names_preprocessor,
    atom_names_preprocessor):
  residue_names = residue_names_preprocessor.lookup(
      tf.constant(preprocessed_structure['residue_names']))
  atom_names = atom_names_preprocessor.lookup(
      tf.constant(preprocessed_structure['atom_names']))
  normalized_coordinates = tf.constant(preprocessed_structure['normalized_coordinates'])
  return {
      'residue_names': tf.expand_dims(residue_names, 0),
      'atom_names': tf.expand_dims(atom_names, 0),
      'normalized_coordinates': tf.expand_dims(normalized_coordinates, 0)}

In [None]:
def UpdateStructure(structure, new_coordinates):
  loc = 0
  for atom in structure.get_atoms():
    atom.set_coord(new_coordinates[loc])
    loc+=1
  structure.atom_to_internal_coordinates(True)

## Original Structure

In [None]:
preprocessed_structure = PreProcessPDBStructure(structure)
original_data = _FeaturesFromPreprocessedStructure(
    preprocessed_structure, residue_names_preprocessor, atom_names_preprocessor)

In [None]:
#conditioning = diffusion_model._conditioner.conditioning(
#    original_data['residue_names'], original_data['atom_names'])
#encoding = diffusion_model._encoder.encode(
#    original_data['normalized_coordinates'], conditioning)
#diffusion_model.set_scorer(
#    ScoreTrain(PerfectScoreModel(encoding)))

#gamma_module = tf.Module()
#gamma_module.gamma_min = -6.0
#gamma_module.gamma_max = 10.0
#diffusion_model.set_gamma_module(gamma_module)

In [None]:
nglview.show_biopython(structure)

In [None]:
UpdateStructure(structure, preprocessed_structure['normalized_coordinates'])
nglview.show_biopython(structure)

In [None]:
print(preprocessed_structure['normalized_coordinates'])

In [None]:
(error_dist, true_dist, z_0, z_t, new_z_0) = diffusion_model.reconstruct(1, original_data)

In [None]:
print(error_dist.mean())

In [None]:
print(true_dist.mean())

In [None]:
print(preprocessed_structure['normalized_coordinates'])

In [None]:
def NetImprovement(true_solution, error_solution, actual_solution):
  return  (tf.reduce_mean(tf.math.abs(true_solution - actual_solution))/
           tf.reduce_mean(tf.math.abs(true_solution - error_solution)))

In [None]:
print(tf.reduce_mean(tf.math.abs(preprocessed_structure['normalized_coordinates'] - true_dist.mean()[0])))

In [None]:
print(tf.reduce_mean(tf.math.abs(preprocessed_structure['normalized_coordinates'] - error_dist.mean()[0])))

In [None]:
print(
  NetImprovement(preprocessed_structure['normalized_coordinates'],
                 error_dist.mean()[0], true_dist.mean()[0]))

In [None]:
print(NetImprovement(z_0, z_t, new_z_0))

In [None]:
print(new_z_0)

In [None]:
import scipy
import numpy as np
import matplotlib.pyplot as plt

In [None]:
UpdateStructure(structure, true_dist.mean()[0])
nglview.show_biopython(structure)

In [None]:
def PlotToDist(timesteps, gamma_min, gamma_max, x_0):
  ts = np.arange(timesteps+1)/timesteps
  gammas = gamma_max + (gamma_min - gamma_max) *ts
  def sigmoid(g):
    return 1/(1 + np.exp(-g))
  sigma2s = sigmoid(gammas)
  x_norm = tf.norm(x_0).numpy()
  e_norm = tf.math.sqrt(tf.math.reduce_sum(tf.ones_like(x_0))).numpy()

  plt.plot(ts, (1-np.sqrt(1-sigma2s)) * x_norm + np.sqrt(sigma2s) * e_norm)
  print((1-np.sqrt(1-sigma2s)) * x_norm + np.sqrt(sigma2s) * e_norm)
  plt.axvline(x=0.9)
  plt.axvline(x=1)
  plt.ylim([0,25])

In [None]:
PlotToDist(10000, -6, 6, z_0)

In [None]:
print(tf.norm(z_0 - new_z_0))
print(tf.norm(z_0 - z_t))

In [None]:
print(true_dist.mean())
print(tf.norm(true_dist.mean()[0] - preprocessed_structure['normalized_coordinates'])/811)

In [None]:
print(true_dist.mean()[0] - preprocessed_structure['normalized_coordinates'])
print(tf.norm(true_dist.mean()[0] - preprocessed_structure['normalized_coordinates'], ord=1)/811)

In [None]:
def DecoderPerformance(gammas_to_test):
  cond = diffusion_model._conditioner.conditioning(
      original_data['residue_names'], original_data['residue_names'], training=False)
  emb = diffusion_model._encoder.encode(
      original_data['normalized_coordinates'], cond, training=False)
  eps = tf.random.normal(tf.shape(emb))
  errors = []
  for g in gammas_to_test:
    emb_with_error = diffusion_model.variance_preserving_map(emb, g, eps) / diffusion_model.alpha(g)
    solution = diffusion_model._decoder.decode(emb_with_error, cond, training=False)
    errors.append(tf.norm(solution.mean()[0] - preprocessed_structure['normalized_coordinates'], ord=1)/811)
  plt.plot(gammas_to_test, errors)

In [None]:
print(tf.norm(perfect_solution.mean()[0] - preprocessed_structure['normalized_coordinates'], ord=1)/811)
print(perfect_solution.mean()[0] - preprocessed_structure['normalized_coordinates'])
DecoderPerformance([float(x) for x in np.arange(10, 20, 0.5, dtype=np.float)])

# Debug Model Issues

In [None]:
diffusion_model.compute_model_loss(original_data)