Skip to content
This repository has been archived by the owner. It is now read-only.
Go to file

Latest commit


Git stats


Failed to load latest commit information.
Latest commit message
Commit time

starttf - Simplified Deeplearning for Tensorflow License: MIT

This repo aims to contain everything required to quickly develop a deep neural network with tensorflow. The idea is that if you use write a compatible SimpleSequence for data loading and networks based on the StartTFModel, you will automatically obey best practices and have super fast training speeds.


Properly install tensorflow or tensorflow-gpu please follow the official instructions carefully.

Then, simply pip install from the github repo.

pip install starttf


Extensions SimpleSequences from opendatalake.simple_sequence.SimpleSequence are supported. They work like keras.Sequence however with an augmentation and a preprocessing function.

For details checkout the readme of opendatalake.


Every model returns a dictionary containing output tensors and a dictionary containing debug tensors

  1. Model Base Classes
  2. Common Encoders
  3. Untrained Backbones

Simple to use tensorflow

Simple Training (No Boilerplate)

There are pre-implemented models which can be glued together and trained with just a few lines. However, before training you will have to create tf-records as shown in the section Simple TF Record Creation. This is actually a full main file.

# Import helpers
from starttf.estimators.tf_estimator import easy_train_and_evaluate
from starttf.utils.hyperparams import load_params

# Import a/your model (here one for mnist)
from mymodel import MyStartTFModel

# Import your loss (here an example)
from myloss import create_loss

# Load params (here for mnist)
hyperparams = load_params("hyperparams/experiment1.json")

# Train model
easy_train_and_evaluate(hyperparams, MyStartTFModel, create_loss, continue_training=False)

Quick Model Definition

Simply implement a create_model function. This model is only a feed forward model.

The model function returns a dictionary containing all layers that should be accessible from outside and a dictionary containing debug values that should be availible for loss or plotting in tensorboard.

import tensorflow as tf

from starttf.models.model import StartTFModel
from starttf.models.encoders import Encoder

Conv2D = tf.keras.layers.Conv2D

class ExampleModel(StartTFModel):
    def __init__(self, hyperparams):
        super(ExampleModel, self).__init__(hyperparams)
        num_classes = hyperparams.problem.number_of_categories

        # Create the vgg encoder
        self.encoder = Encoder(hyperparams)

        #Use the generated model 
        self.conv6 = Conv2D(filters=1024, kernel_size=(1, 1), padding="same", activation="relu")
        self.conv7 = Conv2D(filters=1024, kernel_size=(1, 1), padding="same", activation="relu")
        self.conv8 = Conv2D(filters=num_classes, kernel_size=(1, 1), padding="same", activation=None, name="probs")

    def call(self, input_tensor, training=False):
        Run the model.
        encoder, debug = self.encoder(input_tensor, training)
        result = self.conv6(encoder["features"])
        result = self.conv7(result)
        logits = self.conv8(result)
        probs = tf.nn.softmax(logits)
        return {"logits": logits, "probs": probs}, debug

Quick Loss Definition

def create_loss(model, labels, mode, hyper_params):
    metrics = {}
    losses = {}

    # Add loss
    labels = tf.reshape(labels["probs"], [-1, hyper_params.problem.number_of_categories])
    ce = tf.nn.softmax_cross_entropy_with_logits_v2(logits=model["logits"], labels=labels)
    loss_op = tf.reduce_mean(ce)

    # Add losses to dict. "loss" is the primary loss that is optimized.
    losses["loss"] = loss_op
    metrics['accuracy'] = tf.metrics.accuracy(labels=labels,

    return losses, metrics

Simple TF Record Creation

Fast training speed can be achieved by using tf records. However, usually tf records are a hastle to use the write_data method makes it simple.

from starttf.utils.hyperparams import load_params
from import write_data

from my_data import MySimpleSequence

# Load the hyper parameters.
hyperparams = load_params("hyperparams/experiment1.json")

# Get a generator and its parameters
training_data = MySimpleSequence(hyperparams)
validation_data = MySimpleSequence(hyperparams)

# Write the data
write_data(hyperparams, PHASE_TRAIN, training_data, 4)
write_data(hyperparams, PHASE_VALIDATION, validation_data, 2)

Tensorboard Integration

Tensorboard integration is simple.

Every loss in the losses dict is automatically added to tensorboard. If you also want debug images, you can add a tf.summary.image() in your create_loss method.

TF Estimator + Cluster Support

If you use the easy_train_and_evaluate method, a correctly configured TF Estimator is created. The estimator is then trained in a way that supports cluster training if you have a cluster.

You can’t perform that action at this time.