Quick showcase of the training loop.

# Part I: Data Preparation with TFRecords

In [None]:
# Assuming the data is already parsed and the graphs are built.
# The information needed to build each graph (nodes, edges, senders, receivers, ...) and the targets (monopoles, ...) is written to .tfrecords files.
# Schema used when parsing a serialized tf.train.Example: every field is
# declared as a scalar (shape []) string feature.
# The 'smiles' and 'file_names' fields are currently disabled.
_RECORD_FIELDS = (
    'nodes',
    'edges',
    'coordinates',
    'n_node',
    'n_edge',
    'senders',
    'receivers',
    'monopoles',
    'dipoles',
    'quadrupoles',
)
feature_description = {
    field: tf.io.FixedLenFeature([], tf.string) for field in _RECORD_FIELDS
}


def _serialized_feature(tensor):
    """Return a bytes `tf.train.Feature` holding `tf.io.serialize_tensor(tensor)`."""
    payload = tf.io.serialize_tensor(tensor).numpy()
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[payload]))


# Write one tf.train.Example per entry of `data` into the shard file.
# The `with` block closes the writer when the loop finishes (or raises), so
# the records are flushed to disk instead of depending on garbage collection.
with tf.io.TFRecordWriter(tf_record_folder.format(shard_index)) as writer:
    for graphs, monopoles, dipoles, quadrupoles in data:
        #graph = build_graph(coords, elements, cutoff=CUTOFF, num_kernels=NUM_KERNELS)
        # NOTE(review): `coordinates` is not bound by this loop; it must be
        # defined elsewhere (or `data` should also yield it) -- confirm.
        batch = {
            'nodes': _serialized_feature(graphs.nodes),
            'edges': _serialized_feature(graphs.edges),
            'coordinates': _serialized_feature(coordinates),
            'n_node': _serialized_feature(graphs.n_node),
            'n_edge': _serialized_feature(graphs.n_edge),
            'senders': _serialized_feature(graphs.senders),
            'receivers': _serialized_feature(graphs.receivers),
            'monopoles': _serialized_feature(monopoles),
            'dipoles': _serialized_feature(dipoles),
            'quadrupoles': _serialized_feature(quadrupoles),
            # Currently not written: 'molecular_dipoles', 'octupoles',
            # 'smiles', 'file_names'.
        }
        example = tf.train.Example(features=tf.train.Features(feature=batch)).SerializeToString()
        writer.write(example)

# Part II: Reading TFRecords

In [None]:
# Data is read using the TFRecordDataset functionality. Set the parameters (number of records, batch size, etc.) accordingly.
# dtype used to decode every floating-point field of a record.
dtype_record = tf.float32


def load_data(record):
    """Decode one serialized example into a graph and its targets.

    Args:
        record: a serialized tf.train.Example whose fields follow
            `feature_description`.

    Returns:
        A tuple `(graph, coords, monopoles, dipoles, quadrupoles)` where
        `graph` is a `gn.graphs.GraphsTuple` without globals and
        `quadrupoles` has been passed through `D_Q`.
    """
    parsed = tf.io.parse_single_example(record, feature_description)

    def decode(key, out_type=dtype_record):
        # Each field holds a tensor serialized as a string.
        return tf.io.parse_tensor(parsed[key], out_type=out_type)

    nodes = decode('nodes')
    edges = decode('edges')
    coords = decode('coordinates')
    # Sizes and connectivity are decoded as int32.
    n_node = decode('n_node', tf.int32)
    n_edge = decode('n_edge', tf.int32)
    senders = decode('senders', tf.int32)
    receivers = decode('receivers', tf.int32)
    monopoles = decode('monopoles')
    dipoles = decode('dipoles')
    quadrupoles = D_Q(decode('quadrupoles'))

    graph = gn.graphs.GraphsTuple(
        nodes,
        edges,
        globals=None,
        receivers=receivers,
        senders=senders,
        n_node=n_node,
        n_edge=n_edge,
    )
    return graph, coords, monopoles, dipoles, quadrupoles

DATASET_FOLDER = FOLDER_PATH + 'training_data/shard_{}.tf_record'
# Random permutation of the shard indices, drawn once when the dataset is built.
shard_order = np.random.choice(NUM_FILES, NUM_FILES, replace=False)
dataset = tf.data.TFRecordDataset(
    [DATASET_FOLDER.format(x) for x in shard_order],
    num_parallel_reads=2,
)
# Input pipeline: repeat indefinitely, decode records in parallel, drop
# records that raise, shuffle within a small buffer, and prefetch LAST so
# that all preceding stages are overlapped with the training step.
dataset = (
    dataset
    .repeat()
    .map(load_data, num_parallel_calls=tf.data.AUTOTUNE)
    .apply(tf.data.experimental.ignore_errors())
    .shuffle(32, reshuffle_each_iteration=True)
    .prefetch(tf.data.AUTOTUNE)
)

# Part III: Functions for Training

In [None]:
# Helper functions for training
def train_step_shared(model, batch):
    """Run one optimization step on the sum of the three losses.

    The summed loss is differentiated with respect to all of
    `model.trainable_variables`, the gradients are clipped to a global norm
    of 1.0 and applied with the module-level `optimizer`.

    Args:
        model: the model passed on to `get_loss`; must expose
            `trainable_variables`.
        batch: one dataset element, passed on to `get_loss`.

    Returns:
        The tuple `(monopoles_predicted, dipoles_predicted,
        quadrupoles_predicted)` produced by the forward pass.
    """
    # A single gradient call is made, so a non-persistent tape is enough; its
    # resources are released as soon as the gradient has been computed.
    with tf.GradientTape() as tape:
        loss_mono, loss_dipo, loss_quad, monopoles_predicted, dipoles_predicted, quadrupoles_predicted = get_loss(model, batch)
        loss_total = loss_mono + loss_dipo + loss_quad
    gradients = tape.gradient(loss_total, model.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, 1e0)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return monopoles_predicted, dipoles_predicted, quadrupoles_predicted
    
def train_step(model, batch):
    """Run one optimization step with a separate optimizer per multipole head.

    Each loss (monopole, dipole, quadrupole) is differentiated only with
    respect to the variables returned by the matching `variables_*` helper.
    Each gradient list is clipped to a global norm of 1.0 on its own and
    applied with `optimizer_mono`, `optimizer_dipo` or `optimizer_quad`.

    Args:
        model: the model passed on to `get_loss` and the `variables_*` helpers.
        batch: one dataset element, passed on to `get_loss`.

    Returns:
        The tuple `(monopoles_predicted, dipoles_predicted,
        quadrupoles_predicted)` produced by the forward pass.
    """
    # The tape is queried three times below, hence persistent=True.
    with tf.GradientTape(persistent=True) as tape:
        (loss_mono, loss_dipo, loss_quad,
         monopoles_predicted, dipoles_predicted, quadrupoles_predicted) = get_loss(model, batch)

    head_variables = [variables_mono(model), variables_dipo(model), variables_quad(model)]
    head_losses = [loss_mono, loss_dipo, loss_quad]

    # Stage 1: all gradients, stage 2: all clipping, stage 3: all updates.
    head_gradients = [
        tape.gradient(loss, variables)
        for loss, variables in zip(head_losses, head_variables)
    ]
    head_gradients = [
        tf.clip_by_global_norm(gradients, 1e0)[0] for gradients in head_gradients
    ]
    head_optimizers = [optimizer_mono, optimizer_dipo, optimizer_quad]
    for head_optimizer, gradients, variables in zip(head_optimizers, head_gradients, head_variables):
        head_optimizer.apply_gradients(zip(gradients, variables))

    # Drop the reference so the persistent tape's resources can be released.
    del tape
    return monopoles_predicted, dipoles_predicted, quadrupoles_predicted

def get_loss(model, batch):
    """Run the forward pass and evaluate the three losses.

    Args:
        model: callable invoked as `model(batch[0], batch[1])` and returning
            the predicted monopoles, dipoles and quadrupoles.
        batch: indexable element; entries 2, 3 and 4 are used as the
            monopole, dipole and quadrupole references respectively.

    Returns:
        `(loss_mono, loss_dipo, loss_quad, monopoles_predicted,
        dipoles_predicted, quadrupoles_predicted)`.
    """
    predictions = model(batch[0], batch[1])
    monopoles_predicted, dipoles_predicted, quadrupoles_predicted = predictions
    return (
        get_loss_mono(monopoles_predicted, batch[2]),
        get_loss_dipo(dipoles_predicted, batch[3]),
        get_loss_quad(quadrupoles_predicted, batch[4]),
        monopoles_predicted,
        dipoles_predicted,
        quadrupoles_predicted,
    )

def get_loss_mono(monopoles_predicted, monopoles_ref):
    """Return the mean squared difference between predicted and reference monopoles."""
    squared_errors = tf.math.squared_difference(monopoles_predicted, monopoles_ref)
    return tf.reduce_mean(squared_errors)

def get_loss_dipo(dipoles_predicted, dipoles_ref):
    """Return the mean squared difference between predicted and reference dipoles."""
    squared_errors = tf.math.squared_difference(dipoles_predicted, dipoles_ref)
    return tf.reduce_mean(squared_errors)

# Boolean 3x3 mask keeping the diagonal and everything above it
# (band_part with num_lower=0, num_upper=-1).
mask = tf.linalg.band_part(tf.ones((3, 3), dtype=np.bool_), 0, -1)


def get_loss_quad(quadrupoles_predicted, quadrupoles_ref):
    """Return the mean squared difference over the entries selected by `mask`.

    The mask is applied starting at axis 1.
    NOTE(review): assumes axes 1 and 2 of both inputs are the 3x3 tensor
    axes -- confirm against the model output.
    """
    squared_errors = tf.math.squared_difference(quadrupoles_predicted, quadrupoles_ref)
    selected_errors = tf.boolean_mask(squared_errors, mask, axis=1)
    return tf.reduce_mean(selected_errors)

def variables_mono(model):
    """Collect the trainable variables used for the monopole loss.

    Returns a new list with, in order, the variables of
    `model.embedding_mono`, of every layer in `model.gns_mono`, and of
    `model.mono`.
    """
    collected = list(model.embedding_mono.trainable_variables)
    collected += [
        variable
        for layer in model.gns_mono
        for variable in layer.trainable_variables
    ]
    # NOTE: `model.gn` variables are deliberately left out (the line adding
    # them was commented out).
    collected += model.mono.trainable_variables
    return collected
    
def variables_dipo(model):
    """Collect the trainable variables used for the dipole loss.

    Returns a new list with, in order, the variables of
    `model.embedding_dipo`, of every layer in `model.gns_dipo`, and of
    `model.dipo`.
    """
    collected = list(model.embedding_dipo.trainable_variables)
    collected += [
        variable
        for layer in model.gns_dipo
        for variable in layer.trainable_variables
    ]
    # NOTE: `model.gn` variables are deliberately left out (the line adding
    # them was commented out).
    collected += model.dipo.trainable_variables
    return collected
    
def variables_quad(model):
    """Collect the trainable variables used for the quadrupole loss.

    Returns a new list with, in order, the variables of
    `model.embedding_quad`, of every layer in `model.gns_quad`, and of
    `model.quad`.
    """
    collected = list(model.embedding_quad.trainable_variables)
    collected += [
        variable
        for layer in model.gns_quad
        for variable in layer.trainable_variables
    ]
    # NOTE: `model.gn` variables are deliberately left out (the line adding
    # them was commented out).
    collected += model.quad.trainable_variables
    return collected

@tf.function(experimental_relax_shapes=True)
def get_outer_products(vectors):
    """Return `D_Q` applied to the outer product of each vector with itself.

    A trailing axis is appended to `vectors`; multiplying that column form
    by its matrix transpose broadcasts to the outer product over the last
    axis of the input.
    """
    columns = tf.expand_dims(vectors, axis=-1)
    rows = tf.linalg.matrix_transpose(columns)
    return D_Q(columns * rows)
    
@tf.function(experimental_relax_shapes=True)
def D_Q(quadrupoles):
    """Subtract one third of the trace from every diagonal entry.

    Off-diagonal entries are returned unchanged.
    NOTE(review): the divisor 3 suggests 3x3 matrices in the last two
    axes -- confirm.
    """
    third_of_trace = tf.expand_dims((tf.linalg.trace(quadrupoles) / 3), axis=-1)
    shifted_diagonal = tf.linalg.diag_part(quadrupoles) - third_of_trace
    return tf.linalg.set_diag(quadrupoles, shifted_diagonal)

In [None]:
# Training loop: one `train_step` per dataset element. The predictions of
# the most recent step stay available under the names below.
# NOTE(review): nothing in this loop stops it; it ends only when `dataset`
# is exhausted -- confirm a stopping criterion is applied elsewhere.
for idb, batch in enumerate(dataset):
    predictions = train_step(model, batch)
    monopoles_predicted, dipoles_predicted, quadrupoles_predicted = predictions