In [None]:
from kaggle_datasets import KaggleDatasets
# Ask kaggle_datasets for the GCS path of the 'tpu-getting-started' dataset;
# the TFRecord globs further down are built on top of this prefix.
GCS_DS_PATH = KaggleDatasets().get_gcs_path('tpu-getting-started')
print(GCS_DS_PATH)

In [None]:
import tensorflow as tf
from functools import partial
import matplotlib.pyplot as plt

# Use a TPU when one can be reached; otherwise fall back to TensorFlow's
# default distribution strategy so the notebook still runs.
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    print("Device:", tpu.master())
    strategy = tf.distribute.TPUStrategy(tpu)
except Exception as tpu_error:
    # `Exception` (not a bare `except:`) keeps the best-effort fallback while
    # letting KeyboardInterrupt / SystemExit propagate; report why we fell back.
    print("TPU unavailable, using the default strategy:", tpu_error)
    strategy = tf.distribute.get_strategy()
print("Number of replicas:", strategy.num_replicas_in_sync)

In [None]:
# Shared input-pipeline settings used by the dataset helpers below.
AUTOTUNE = tf.data.AUTOTUNE  # let tf.data pick parallelism / buffer sizes
BATCH_SIZE = 64  # number of examples per batch
IMAGE_SIZE = [224, 224]  # spatial size every decoded image is reshaped to

In [None]:
# Collect the TFRecord shards of each split from the 224x224 JPEG folder
# under GCS_DS_PATH, then report how many files each glob matched.
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_DS_PATH + "/tfrecords-jpeg-224x224/train/*.tfrec")
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_DS_PATH + "/tfrecords-jpeg-224x224/val/*.tfrec")
TEST_FILENAMES = tf.io.gfile.glob(GCS_DS_PATH + "/tfrecords-jpeg-224x224/test/*.tfrec")
print("Train TFRecord Files:", len(TRAINING_FILENAMES))
print("val TFRecord Files:", len(VALIDATION_FILENAMES))
print("Test TFRecord Files:", len(TEST_FILENAMES))

In [None]:
def decode_image(image):
    """Turn encoded JPEG bytes into a float32 tensor of shape IMAGE_SIZE + [3].

    Args:
        image: scalar string tensor holding the JPEG-encoded bytes.

    Returns:
        The decoded 3-channel image, cast to float32 (values are not rescaled)
        and reshaped to the notebook-wide IMAGE_SIZE.
    """
    target_shape = list(IMAGE_SIZE) + [3]
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=3), tf.float32)
    return tf.reshape(pixels, target_shape)

In [None]:
def read_tfrecord(example, train):
    """Parse one serialized TFRecord example.

    Args:
        example: a serialized example, as yielded by a TFRecordDataset.
        train: when truthy the record's "class" feature is parsed as well.

    Returns:
        `(image, label)` with an int32 label when `train` is truthy,
        otherwise just the decoded image.
    """
    # Every record carries the encoded image; "class" is only requested
    # (and therefore only required) for labeled data.
    feature_spec = {"image": tf.io.FixedLenFeature([], tf.string)}
    if train:
        feature_spec["class"] = tf.io.FixedLenFeature([], tf.int64)

    parsed = tf.io.parse_single_example(example, feature_spec)
    image = decode_image(parsed["image"])
    if not train:
        return image
    return image, tf.cast(parsed["class"], tf.int32)

In [None]:
def load_dataset(filenames, train=True):
    """Create an unbatched dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file paths to read.
        train: forwarded to `read_tfrecord`; truthy yields (image, label)
            pairs, falsy yields images only.

    Returns:
        A `tf.data.Dataset` whose elements come from `read_tfrecord`.
    """
    # Determinism is switched off: elements may be produced out of order in
    # exchange for throughput.
    options = tf.data.Options()
    options.experimental_deterministic = False

    parse_example = partial(read_tfrecord, train=train)
    return (
        tf.data.TFRecordDataset(filenames)
        .with_options(options)
        .map(parse_example, num_parallel_calls=AUTOTUNE)
    )

In [None]:
def get_dataset(filenames, train=True, shuffle_buffer=2048):
    """Build the shuffled, batched and prefetched input pipeline for a split.

    Args:
        filenames: TFRecord file paths to read.
        train: forwarded to `load_dataset`; truthy yields (image, label)
            batches, falsy yields image batches only.
        shuffle_buffer: size of the shuffle buffer. Pass 0 or None to skip
            shuffling altogether.

    Returns:
        A `tf.data.Dataset` yielding batches of BATCH_SIZE elements.
    """
    dataset = load_dataset(filenames, train=train)
    if shuffle_buffer:
        dataset = dataset.shuffle(shuffle_buffer)
    dataset = dataset.batch(BATCH_SIZE)
    # Prefetch goes last so that complete batches (not single examples) are
    # prepared in the background while the previous batch is being consumed.
    dataset = dataset.prefetch(buffer_size=AUTOTUNE)
    return dataset

In [None]:
# Build one pipeline per split. The test split is read with train=False, so
# it yields images only, without labels.
train_dataset = get_dataset(TRAINING_FILENAMES)
validation_dataset = get_dataset(VALIDATION_FILENAMES)
test_dataset = get_dataset(TEST_FILENAMES,train=False)

# Pull a single training batch to visualise below.
image_batch, label_batch = next(iter(train_dataset))


def show_batch(image_batch, label_batch):
    """Plot up to 25 images of a batch in a 5x5 grid, titled by class id.

    Args:
        image_batch: indexable batch of images; each image is divided by
            255 before being handed to `imshow`.
        label_batch: integer class ids aligned with `image_batch`.
    """
    plt.figure(figsize=(10, 10))
    # A batch may hold fewer than 25 images (e.g. the last one of a dataset).
    for n in range(min(25, len(image_batch))):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        # Labels are multi-class ids, so show the id itself rather than a
        # two-way benign/malignant caption.
        plt.title(str(int(label_batch[n])))
        plt.axis("off")


# Convert the eager tensors to NumPy arrays before plotting them.
show_batch(image_batch.numpy(), label_batch.numpy())

In [None]:
# Learning rate starts at 0.01 and is multiplied by 0.96 every 20 optimizer
# steps; staircase=True applies the decay in discrete jumps.
initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)

# Write the model to disk only when the monitored metric improves
# (Keras monitors `val_loss` by default).
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "melanoma_model.h5", save_best_only=True
)

# Stop after 10 epochs without improvement and roll back to the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)

In [None]:
def make_model():
    """Build and compile a 104-class classifier on a frozen ResNet50 backbone.

    The backbone is created with Keras' default weights and `include_top=False`,
    then marked non-trainable; only the pooling/dense head on top is trained.

    Returns:
        A compiled `tf.keras.Model` mapping IMAGE_SIZE x 3 images to softmax
        probabilities over 104 classes.
    """
    backbone = tf.keras.applications.ResNet50(
        input_shape=(*IMAGE_SIZE, 3), include_top=False
    )
    backbone.trainable = False

    image_input = tf.keras.layers.Input([*IMAGE_SIZE, 3])
    # Apply the ResNet-specific input preprocessing inside the graph.
    features = backbone(tf.keras.applications.resnet.preprocess_input(image_input))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    hidden = tf.keras.layers.Dense(512, activation="relu")(pooled)
    hidden = tf.keras.layers.Dropout(0.2)(hidden)
    class_probs = tf.keras.layers.Dense(104, activation="softmax")(hidden)

    classifier = tf.keras.Model(inputs=image_input, outputs=class_probs)
    # Labels are integer ids, hence the sparse variant of cross-entropy.
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

In [None]:
# Create and compile the model inside the strategy scope so that it is built
# under the distribution strategy selected earlier.
with strategy.scope():
    model = make_model()

# Train for at most 20 epochs; the callbacks checkpoint the model and may
# stop training early.
history = model.fit(
    train_dataset,
    epochs=20,
    validation_data=validation_dataset,
    callbacks=[checkpoint_cb, early_stopping_cb],
)

In [None]:
#model.save(GCS_DS_PATH + "/model_1")

In [None]:
def make_model2():
    """Build and compile a 104-class classifier on a frozen Xception backbone.

    The backbone is created with Keras' default weights and `include_top=False`,
    then marked non-trainable; only the pooling/dense head on top is trained.

    Returns:
        A compiled `tf.keras.Model` mapping IMAGE_SIZE x 3 images to softmax
        probabilities over 104 classes.
    """
    backbone = tf.keras.applications.Xception(
        input_shape=(*IMAGE_SIZE, 3), include_top=False
    )
    backbone.trainable = False

    image_input = tf.keras.layers.Input([*IMAGE_SIZE, 3])
    # Apply the Xception-specific input preprocessing inside the graph.
    features = backbone(tf.keras.applications.xception.preprocess_input(image_input))
    pooled = tf.keras.layers.GlobalAveragePooling2D()(features)
    hidden = tf.keras.layers.Dense(512, activation="relu")(pooled)
    hidden = tf.keras.layers.Dropout(0.2)(hidden)
    class_probs = tf.keras.layers.Dense(104, activation="softmax")(hidden)

    classifier = tf.keras.Model(inputs=image_input, outputs=class_probs)
    # Labels are integer ids, hence the sparse variant of cross-entropy.
    classifier.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=lr_schedule),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    return classifier

In [None]:
# Build the second (Xception-based) model inside the strategy scope. This
# must call make_model2(); make_model() would just train another ResNet50.
with strategy.scope():
    model2 = make_model2()

# A ModelCheckpoint instance remembers its best score across fit() calls and
# always writes to the same path, so this run gets its own callback and file.
checkpoint_cb2 = tf.keras.callbacks.ModelCheckpoint(
    "xception_model.h5", save_best_only=True
)

history2 = model2.fit(
    train_dataset,
    epochs=20,
    validation_data=validation_dataset,
    callbacks=[checkpoint_cb2, early_stopping_cb],
)

In [None]:
def show_batch_predictions(image_batch, classifier=None):
    """Plot up to 25 images in a 5x5 grid, titled with the predicted class id.

    Args:
        image_batch: batch of images (first axis is the batch axis); each
            image is divided by 255 before being handed to `imshow`.
        classifier: Keras model used for prediction. Defaults to the
            notebook-level `model`.
    """
    if classifier is None:
        classifier = model
    plt.figure(figsize=(10, 10))
    # A batch may hold fewer than 25 images.
    count = min(25, int(image_batch.shape[0]))
    if count == 0:
        return
    # One predict call for the whole grid instead of one per image; the
    # array's own argmax avoids needing a NumPy import here.
    predicted_ids = classifier.predict(image_batch[:count]).argmax(axis=-1)
    for n in range(count):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.title(str(predicted_ids[n]))
        plt.axis("off")


# Take one batch from the test pipeline (images only, since it was built with
# train=False) and display the model's predictions for it.
image_batch = next(iter(test_dataset))

show_batch_predictions(image_batch)

In [None]:
def show_batch_predictions(image_batch, classifier=None):
    """Plot up to 25 images in a 5x5 grid, titled with the predicted class id.

    Args:
        image_batch: batch of images (first axis is the batch axis); each
            image is divided by 255 before being handed to `imshow`.
        classifier: Keras model used for prediction. Defaults to the
            notebook-level `model2`.
    """
    if classifier is None:
        classifier = model2
    plt.figure(figsize=(10, 10))
    # A batch may hold fewer than 25 images.
    count = min(25, int(image_batch.shape[0]))
    if count == 0:
        return
    # One predict call for the whole grid instead of one per image; the
    # array's own argmax avoids needing a NumPy import here.
    predicted_ids = classifier.predict(image_batch[:count]).argmax(axis=-1)
    for n in range(count):
        plt.subplot(5, 5, n + 1)
        plt.imshow(image_batch[n] / 255.0)
        plt.title(str(predicted_ids[n]))
        plt.axis("off")


# Take one batch from the test pipeline (images only, since it was built with
# train=False) and display the second model's predictions for it.
image_batch = next(iter(test_dataset))

show_batch_predictions(image_batch)