In [None]:
import tensorflow as tf
import keras
import time

In [None]:
# Install the Weights & Biases client; %pip installs into the running kernel's environment.
%pip install wandb

In [None]:
import wandb
from wandb.keras import WandbCallback

# Authenticate with Weights & Biases up front, before any run is started.
wandb.login()

In [None]:
# Turn on memory growth for every visible GPU, then report how many physical
# and logical GPU devices TensorFlow sees.
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Use the stable API, matching list_physical_devices above, rather
        # than the tf.config.experimental alias.
        logical_gpus = tf.config.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)

In [None]:
def train_model_fit(n_epochs, base_lr, batchsize, classes):
    """Train an untrained ResNet50 on the birds dataset and log the run to W&B.

    Args:
        n_epochs: Number of epochs passed to ``model.fit``.
        base_lr: Learning rate handed to the Adam optimizer.
        batchsize: Batch size used when loading the training and validation images.
        classes: Number of output classes for the ResNet50 classification head.

    Returns:
        None.

    Side effects:
        Starts and finishes a wandb run whose config is the module-level
        ``wbargs`` dict (which must be defined before this function is called),
        reads images from ``datasets/birds/train`` and ``datasets/birds/valid``,
        prints the training time, and saves the trained model to
        ``model/keras_single/``.
    """
    # weights=None -> no pretrained weights are loaded.
    model = tf.keras.applications.ResNet50(include_top=True, weights=None, classes=classes)

    # --------- Start wandb --------- #
    # NOTE: the bracketed names are template placeholders; replace them with
    # your own entity / project strings before running this cell.
    wandb.init(entity=[YOURUSERNAMEHERE], project=[YOURPROJECTNAMEHERE], config=wbargs)

    # Data
    # Pipeline order matters: cache the loaded batches first, shuffle the
    # cached batches, and keep prefetch as the final stage so that input
    # preparation can overlap with the training step.
    train_ds = (
        tf.keras.preprocessing.image_dataset_from_directory(
            "datasets/birds/train", image_size=(224, 224), batch_size=batchsize
        )
        .cache()
        .shuffle(1000)
        .prefetch(2)
    )

    valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
        "datasets/birds/valid", image_size=(224, 224), batch_size=batchsize
    ).prefetch(2)

    # Build the optimizer from tf.keras (the same namespace as the model) and
    # use `learning_rate`; the `lr` alias is deprecated.
    optimizer = tf.keras.optimizers.Adam(learning_rate=base_lr)
    model.compile(loss="sparse_categorical_crossentropy", optimizer=optimizer, metrics=["accuracy"])

    # perf_counter is a monotonic clock, so the interval is not affected by
    # system clock adjustments.
    start = time.perf_counter()
    model.fit(train_ds, epochs=n_epochs, validation_data=valid_ds, callbacks=[WandbCallback()])
    elapsed = time.perf_counter() - start
    print("model training time", elapsed)
    wandb.log({"training_time": elapsed})

    # Close your wandb run
    wandb.run.finish()

    tf.keras.models.save_model(model, "model/keras_single/")

In [None]:
# Hyperparameters; unpacked as keyword arguments into train_model_fit.
model_params = dict(n_epochs=2, base_lr=0.02, classes=285, batchsize=64)

# Run configuration handed to wandb.init(config=...): the hyperparameters
# first, followed by descriptive metadata about this experiment.
wbargs = dict(model_params)
wbargs.update(
    {
        "Notes": "tf_v100_2x",
        "Tags": ["single", "gpu", "tensorflow"],
        "dataset": "Birds",
        "architecture": "ResNet50",
    }
)

In [None]:
# Launch training with the hyperparameters from model_params (unpacked as keyword arguments).
# NOTE(review): train_model_fit has no visible return statement, so `tester` is presumably None — confirm whether a return value was intended.
tester = train_model_fit(**model_params)