# Single Node Single GPU Training in TensorFlow

In this sample notebook, you'll train the ResNet50 architecture to identify different species of birds. We only run a few epochs to save time, but once this works you'll have everything you need to build and run larger TensorFlow models on Saturn Cloud.

This notebook was tested on a T4-4XLarge instance, and we recommend not going much smaller with your instance size.

In [None]:
import tensorflow as tf
import keras
import time

In [None]:
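import s3fs

# The bucket is public, so anonymous access is enough
s3 = s3fs.S3FileSystem(anon=True)
# Download every <split>/<class>/<image>.jpg file into dataset/birds/
_ = s3.get(
    rpath="s3://saturn-public-data/100-bird-species/100-bird-species/*/*/*.jpg",
    lpath="dataset/birds/",
)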

The first time you run this notebook, the cell above downloads the image data, including the `train` and `valid` splits used below, from the public `saturn-public-data` S3 bucket into `dataset/birds/`. This only takes a few moments. This small sample includes only 61 classes, while the original dataset has 285.
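Before training, it can help to confirm that the download produced the layout `image_dataset_from_directory` expects: one subfolder per class under each split. The sketch below uses a hypothetical helper, `summarize_split` (not part of the original notebook), that counts class folders and `.jpg` files with only the standard library. With this sample, each split should report 61 classes, assuming the download kept the `<split>/<class>/` folders.

```python
from pathlib import Path


def summarize_split(split_dir):
    """Count class subfolders and .jpg files under one dataset split.

    Hypothetical helper. Assumes the layout <split_dir>/<class_name>/<image>.jpg,
    which is how image_dataset_from_directory infers labels.
    """
    split = Path(split_dir)
    # Each immediate subdirectory is treated as one class label; stray files are ignored
    class_dirs = sorted(p for p in split.iterdir() if p.is_dir())
    # Count .jpg files directly inside each class folder
    per_class = {d.name: len(list(d.glob("*.jpg"))) for d in class_dirs}
    return len(class_dirs), sum(per_class.values()), per_class


for split in ("dataset/birds/train", "dataset/birds/valid"):
    if Path(split).is_dir():  # skip splits that have not been downloaded yet
        n_classes, n_images, _ = summarize_split(split)
        print(f"{split}: {n_classes} classes, {n_images} images")
```

If the class counts differ between splits, the integer labels assigned to each split will not line up, so this is worth checking before a long training run.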

In [None]:
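def train_model_fit(n_epochs, base_lr, batchsize, classes):
    # ResNet50 trained from scratch (weights=None), with an output layer sized to `classes`
    model = tf.keras.applications.ResNet50(include_top=True, weights=None, classes=classes)

    # Data: one subfolder per class; labels are inferred as integer indices
    train_ds = (
        tf.keras.preprocessing.image_dataset_from_directory(
            "dataset/birds/train", image_size=(224, 224), batch_size=batchsize
        )
        .cache()  # keep decoded, resized batches in memory after the first epoch
        .shuffle(1000)  # reshuffle the cached batches every epoch
        .prefetch(tf.data.AUTOTUNE)  # overlap input loading with training; keep this last
    )

    valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
        "dataset/birds/valid", image_size=(224, 224), batch_size=batchsize
    ).prefetch(tf.data.AUTOTUNE)

    # `learning_rate` replaces the deprecated `lr` argument
    optimizer = tf.keras.optimizers.Adam(learning_rate=base_lr)
    # Integer labels from image_dataset_from_directory pair with the sparse loss
    model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])

    start = time.time()
    model.fit(train_ds, epochs=n_epochs, validation_data=valid_ds)
    elapsed = time.time() - start
    print(f"model training time: {elapsed:.1f} s")

    # SavedModel directory format (TF 2.x with Keras 2); Keras 3 expects a path ending in .keras
    tf.keras.models.save_model(model, "model/keras_single/")
    return model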

In [None]:
model_params = {"n_epochs": 3, "base_lr": 0.02, "classes": 61, "batchsize": 64}
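`batchsize` sets how many images go into each gradient step, so one epoch over N training images takes ceil(N / batchsize) steps; `image_dataset_from_directory` keeps the final partial batch rather than dropping it. The hypothetical helper below makes that arithmetic concrete. The image count used is illustrative only, not the actual size of this dataset.

```python
def steps_per_epoch(n_images, batchsize):
    """Batches per epoch when the final partial batch is kept (ceiling division)."""
    if n_images < 0 or batchsize <= 0:
        raise ValueError("n_images must be >= 0 and batchsize must be > 0")
    # Integer ceiling division avoids floating-point rounding for large counts
    return (n_images + batchsize - 1) // batchsize


# Illustrative numbers: 1,000 images at batchsize 64 gives
# 15 full batches plus one partial batch of 40 images.
print(steps_per_epoch(1000, 64))  # 16
```

Multiplying by `n_epochs` gives the total number of optimizer steps, which is a useful sanity check against the step counter Keras prints during `model.fit`.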

The following cell runs training and saves the trained model to disk in the notebook's working directory: a folder called `model` is created, and the model is written to `model/keras_single/`.

In [None]:
tester = train_model_fit(**model_params)