In [None]:
import numpy as np, pandas as pd
import requests, io, os, datetime, re, json
import tensorflow as tf
import keras
import matplotlib.pyplot as plt
from tensorflow.keras import datasets, layers, models, losses, Model
from keras.applications import resnet50
from keras.preprocessing import image
import dask
import time
# Record the TensorFlow version in the cell output so results can be tied to it.
print(tf.version.VERSION)

In [None]:
import wandb

try:
    # Canonical location of the Keras integration.
    from wandb.integration.keras import WandbCallback
except ImportError:
    # Fallback for wandb versions where only the `wandb.keras` alias is importable.
    from wandb.keras import WandbCallback

# Authenticate this kernel with W&B before any run is started.
wandb.login()

In [None]:
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        # Memory growth needs to be the same across GPUs, so enable it on every one.
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Stable (non-experimental) API, consistent with list_physical_devices above.
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized; if they
        # already were in this kernel, report it rather than crash the cell.
        print(e)

In [None]:
def _make_datasets(train_dir, valid_dir, batchsize, image_size, shuffle_buffer):
    """Build the batched train/validation tf.data pipelines from class-per-folder image directories."""
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        train_dir,
        image_size=image_size,
        batch_size=batchsize,
    )
    # Order matters: cache the loaded batches once, reshuffle the cached batches
    # every epoch (cache replays identical elements each iteration), and keep
    # prefetch LAST so it overlaps the whole input pipeline with training.
    train_ds = train_ds.cache().shuffle(shuffle_buffer).prefetch(tf.data.AUTOTUNE)

    # Validation order does not affect the metrics, so skip the shuffle.
    valid_ds = tf.keras.preprocessing.image_dataset_from_directory(
        valid_dir,
        image_size=image_size,
        batch_size=batchsize,
        shuffle=False,
    ).prefetch(tf.data.AUTOTUNE)
    return train_ds, valid_ds


def train_multigpu(n_epochs, classes, base_lr, batchsize, wbargs, scale_batch=False, scale_lr=False,
                   entity=None, project=None,
                   train_dir='datasets/birds/train', valid_dir='datasets/birds/valid',
                   save_path='model/keras_multi/', image_size=(224, 224), shuffle_buffer=1000):
    """Train a from-scratch ResNet50 with MirroredStrategy and log the run to W&B.

    Args:
        n_epochs: Number of epochs passed to ``model.fit``.
        classes: Number of output classes of the ResNet50 head.
        base_lr: Adam learning rate before optional replica scaling.
        batchsize: Batch size given to ``image_dataset_from_directory`` before
            optional replica scaling (it is the global batch split across replicas).
        wbargs: Dict logged as the W&B run config; it is copied, not mutated.
        scale_batch: If True, multiply ``batchsize`` by the number of replicas.
        scale_lr: If True, multiply ``base_lr`` by the number of replicas.
        entity: W&B entity; ``None`` lets ``wandb.init`` use its own default.
        project: W&B project; ``None`` lets ``wandb.init`` use its own default.
        train_dir: Training image directory (one sub-folder per class).
        valid_dir: Validation image directory (one sub-folder per class).
        save_path: Destination passed to ``tf.keras.models.save_model`` after the run.
        image_size: ``(height, width)`` images are resized to.
        shuffle_buffer: Buffer size (in batches) for per-epoch shuffling of training data.

    Returns:
        The trained ``tf.keras.Model``.

    NOTE(review): no rescaling/normalization is applied to images before the
    ResNet50 (``weights=None``) — confirm this is intended.
    """
    strategy = tf.distribute.MirroredStrategy()
    n_replicas = strategy.num_replicas_in_sync
    print('Number of devices: %d' % n_replicas)

    if scale_batch:
        batchsize = batchsize * n_replicas
    if scale_lr:
        base_lr = base_lr * n_replicas

    # Log the flags plus the values actually used after scaling.
    run_config = {
        **wbargs,
        "scale_batch": scale_batch,
        "scale_lr": scale_lr,
        "effective_batchsize": batchsize,
        "effective_lr": base_lr,
    }

    # Variables (model + optimizer state) must be created inside the strategy scope
    # so they are mirrored across replicas.
    with strategy.scope():
        model = tf.keras.applications.ResNet50(
            include_top=True,
            weights=None,
            classes=classes)
        # `learning_rate` (not the removed `lr` alias); tf.keras to match the model above.
        optimizer = tf.keras.optimizers.Adam(learning_rate=base_lr)
        model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])

    train_ds, valid_ds = _make_datasets(train_dir, valid_dir, batchsize, image_size, shuffle_buffer)

    # The context manager finishes the W&B run even if training raises.
    with wandb.init(entity=entity, project=project, config=run_config):
        start = time.time()
        model.fit(
            train_ds,
            epochs=n_epochs,
            validation_data=valid_ds,
            callbacks=[WandbCallback()],
        )
        training_time = time.time() - start
        print("model training time", training_time)
        wandb.log({"training_time": training_time})

    tf.keras.models.save_model(model, save_path)
    return model


In [None]:
# Hyper-parameters forwarded verbatim to train_multigpu(**model_params).
model_params = dict(
    n_epochs=50,
    base_lr=0.02,
    batchsize=64,
    classes=285,
    scale_batch=True,
)

# W&B run config: every hyper-parameter plus descriptive run metadata.
wbargs = dict(
    model_params,
    Notes="tf_v100_8x",
    Tags=['multi', 'gpu', 'tensorflow'],
    dataset="Birds",
    architecture="ResNet50",
)


In [None]:
# Launch the multi-GPU training run with the hyper-parameters and W&B config defined earlier.
tester_plain = train_multigpu(**model_params, wbargs=wbargs)
