In [None]:
import cv2

In [None]:
from functools import partial
import matplotlib.pyplot as plt

# from readTFRecords import *

# import tensorflow_hub as hub
# from tensorflow.keras.applications import MobileNet
from tensorflow.keras.applications import EfficientNetB0


import tensorflow as tf
from tensorflow import keras
import re
import numpy as np
import pandas as pd

from functools import partial

# Spatial size (height, width) every image is resized to before entering the model.
IMAGE_SIZE = (512, 512)
AUTOTUNE = tf.data.experimental.AUTOTUNE
# Class labels are the string digits '0' through '4'.
CLASSES = [str(digit) for digit in range(5)]
NUM_CLASSES = len(CLASSES)

def read_tfrecord(example, labeled=True):
    """Parse one serialized TFRecord example into tensors.

    Args:
        example: Scalar string tensor holding one serialized ``tf.train.Example``.
        labeled: When True, also parse the integer ``target`` field and return
            it together with ``image_name``.

    Returns:
        ``(image, label, image_name)`` when ``labeled`` is True, where ``image`` is
        the float32 tensor produced by ``decode_image``, ``label`` is int32 and
        ``image_name`` is a string tensor; otherwise just ``image``.
    """
    features = {
        "image": tf.io.FixedLenFeature([], tf.string),
        "image_name": tf.io.FixedLenFeature([], tf.string),
    }
    if labeled:
        # Only require the label when it is actually returned, so records that
        # carry no "target" field can still be parsed with labeled=False.
        features["target"] = tf.io.FixedLenFeature([], tf.int64)

    example = tf.io.parse_single_example(example, features)
    image = decode_image(example["image"])

    if not labeled:
        return image
    label = tf.cast(example["target"], tf.int32)
    # "image_name" is already parsed as tf.string, so no cast is needed.
    return image, label, example["image_name"]

def decode_image(image_data):
    """Decode JPEG bytes into a 3-channel float32 image tensor.

    The pixels are cast to float32 but not rescaled, so they keep the 0-255
    range of the decoded uint8 JPEG.
    """
    return tf.cast(tf.image.decode_jpeg(image_data, channels=3), tf.float32)
    
    
    
def load_dataset(filenames, labeled=True):
    """Build a dataset of parsed examples from TFRecord files.

    Args:
        filenames: TFRecord file path(s) passed to ``tf.data.TFRecordDataset``.
        labeled: Forwarded to ``read_tfrecord``. When True each element is an
            ``(image, label, image_name)`` triple; when False each element is
            just the image.

    Returns:
        A ``tf.data.Dataset`` whose element order may not be deterministic.
    """
    options = tf.data.Options()
    # Let elements be produced out of order for throughput; iteration order
    # can therefore differ between passes over the dataset.
    options.experimental_deterministic = False
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.with_options(options)
    # Honour the caller's `labeled` flag (it was previously ignored).
    return dataset.map(partial(read_tfrecord, labeled=labeled), num_parallel_calls=AUTOTUNE)

def input_preprocess(image, label, image_name):
    """Resize, clamp and convert an image to YUV, and one-hot encode its label.

    Args:
        image: Float image tensor, presumably in the 0-255 range produced by
            ``decode_image`` — confirm for other callers.
        label: Integer class index.
        image_name: Passed through unchanged.

    Returns:
        ``(yuv_image, one_hot_label, image_name)`` with ``yuv_image`` resized to
        ``IMAGE_SIZE`` and ``one_hot_label`` of length ``NUM_CLASSES``.
    """
    resized = tf.image.resize(image, size=IMAGE_SIZE)
    clamped = tf.clip_by_value(resized, clip_value_min=0, clip_value_max=255)
    # NOTE(review): tf.image.rgb_to_yuv documents its output as well defined only
    # for inputs in [0, 1], but it receives [0, 255] here. Kept as-is because any
    # model already trained on this pipeline expects exactly these values.
    yuv = tf.image.rgb_to_yuv(clamped)
    return yuv, tf.one_hot(label, NUM_CLASSES), image_name


def get_training_dataset(FILENAMES, BATCH_SIZE=12):
    """Build the infinite, shuffled, batched training pipeline.

    Args:
        FILENAMES: TFRecord shard paths to read.
        BATCH_SIZE: Number of examples per batch.

    Returns:
        A repeating ``tf.data.Dataset`` of batched
        ``(image, one_hot_label, image_name)`` elements.
    """
    preprocessed = load_dataset(FILENAMES, labeled=True).map(
        input_preprocess, num_parallel_calls=AUTOTUNE
    )
    # repeat() comes before shuffle(), so the buffer mixes examples across epoch
    # boundaries. NOTE(review): the buffer holds 2048 already-resized float32
    # images (roughly 6 GiB at 512x512x3); shrink it if memory is tight.
    shuffled = preprocessed.repeat().shuffle(2048)
    # Batch-level augmentation (e.g. a cutmix map) would slot in after batch().
    return shuffled.batch(BATCH_SIZE).prefetch(AUTOTUNE)

def get_validation_dataset(FILENAMES, BATCH_SIZE=12):
    """Build the single-pass, batched, cached evaluation pipeline.

    There is no repeat() or shuffle(); each pass yields every example once.

    Args:
        FILENAMES: TFRecord shard paths to read.
        BATCH_SIZE: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` of batched ``(image, one_hot_label, image_name)``.
    """
    batched = (
        load_dataset(FILENAMES, labeled=True)
        .map(input_preprocess, num_parallel_calls=AUTOTUNE)
        .batch(BATCH_SIZE)
    )
    # cache() without a filename keeps the preprocessed batches in memory after
    # the first complete pass, so later passes skip decoding and resizing.
    return batched.cache().prefetch(AUTOTUNE)



from sklearn.model_selection import train_test_split

# AUTOTUNE = tf.data.experimental.AUTOTUNE
# Directory holding the training TFRecord shards.
GCS_PATH = "../input/train_tfrecords"

# NOTE(review): FILENAMES and split_ind are not used by the fixed split below;
# they would only be needed for a percentage-based split.
FILENAMES = tf.io.gfile.glob(GCS_PATH + "/*tfrec")
split_ind = int(0.9 * len(FILENAMES))

# Fixed shard split: shards 00-03 for validation, 04-15 for training.
# The "-NNNN" filename suffix is each shard's record count; every shard is
# suffixed 1338 except the last (ld_train15), which is suffixed 1327.
TRAINING_FILENAMES = [
    "{}/ld_train{:02d}-1338.tfrec".format(GCS_PATH, shard) for shard in range(4, 15)
] + [GCS_PATH + "/ld_train15-1327.tfrec"]
VALID_FILENAMES = [
    "{}/ld_train{:02d}-1338.tfrec".format(GCS_PATH, shard) for shard in range(4)
]

TEST_FILENAMES = tf.io.gfile.glob("../input/test_tfrecords/*tfrec")
print("Train TFRecord Files:", len(TRAINING_FILENAMES))
print("Validation TFRecord Files:", len(VALID_FILENAMES))
print("Test TFRecord Files:", len(TEST_FILENAMES))

def count_data_items(filenames):
    """Total the record counts encoded in TFRecord shard filenames.

    Each shard's basename is expected to contain ``-<count>.``; for example,
    ``ld_train00-1338.tfrec`` is counted as 1338 records.

    Args:
        filenames: Iterable of shard paths.

    Returns:
        The summed count as a Python ``int`` (0 for an empty iterable).

    Raises:
        ValueError: If a basename does not contain ``-<digits>.``.
    """
    import os

    count_pattern = re.compile(r"-([0-9]+)\.")
    total = 0
    for filename in filenames:
        # Match on the basename only, so dashes/dots in directory names
        # (e.g. "data-2020.v1/") cannot be mistaken for the record count.
        match = count_pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError("cannot find record count in filename: {!r}".format(filename))
        total += int(match.group(1))
    return total

# Record counts per split, read from the shard filename suffixes.
NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES = (
    count_data_items(shards)
    for shards in (TRAINING_FILENAMES, VALID_FILENAMES, TEST_FILENAMES)
)

print('Dataset: {} training images, {} validation images, {} (unlabeled) test images'.format(
    NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES, NUM_TEST_IMAGES))

In [None]:
# Display the list of training shard paths.
TRAINING_FILENAMES

In [None]:
# Examples per training batch, and number of epochs covered by the LR schedule.
BATCH_SIZE, EPOCHS = 12, 30

In [None]:
from tensorflow.keras.experimental import CosineDecay
import efficientnet.keras as eff

# Total optimizer steps for the schedule: rounded steps-per-epoch times EPOCHS.
decay_steps = EPOCHS * int(round(NUM_TRAINING_IMAGES / BATCH_SIZE))
# Cosine decay from 1e-4 down to alpha * 1e-4 (= 3e-5) over decay_steps.
cosine_decay = CosineDecay(
    initial_learning_rate=1e-4,
    decay_steps=decay_steps,
    alpha=0.3,
)

from keras.backend import sigmoid

class SwishActivation(tf.keras.layers.Activation):
    """Keras ``Activation`` layer whose ``__name__`` is ``'swish_act'``.

    ``__name__`` is set explicitly on the instance, presumably so the object can
    be looked up under that name when registered as a custom object — confirm.
    """

    def __init__(self, activation, **kwargs):
        super().__init__(activation, **kwargs)
        self.__name__ = 'swish_act'

def swish_act(x, beta=1):
    """Swish activation: ``x * sigmoid(beta * x)``."""
    gate = sigmoid(beta * x)
    return x * gate

from keras.utils.generic_utils import get_custom_objects
# Register the swish activation under the name 'swish_act' in the Keras
# custom-object registry so a serialized reference to 'swish_act' can be
# resolved by name.
# NOTE(review): this is the standalone `keras` package's registry; whether
# tf.keras shares it depends on the installed keras/TF versions — confirm.
get_custom_objects().update({'swish_act': SwishActivation(swish_act)})

# Backbone: EfficientNetB5 from the `efficientnet` package (imported as `eff`),
# without its classification head and with weights=None, i.e. randomly
# initialised rather than pretrained, on a 3-channel IMAGE_SIZE input.
# NOTE(review): `efficientnet.keras` targets the standalone `keras` package while
# the rest of this graph uses tf.keras; `efficientnet.tfkeras` may match better.
inputs = tf.keras.layers.Input(shape=(IMAGE_SIZE[0], IMAGE_SIZE[1], 3))
model = eff.EfficientNetB5(include_top=False, input_tensor=inputs, weights=None)

# Classification head on top of the backbone's feature map.
# NOTE(review): the layer sequence (including the two back-to-back
# BatchNormalization layers) is kept exactly as-is; Keras auto-names layers by
# creation order, so reordering or removing layers could break loading of
# weights saved from this topology.
x = tf.keras.layers.GlobalAveragePooling2D(name="avg_pool")(model.output)
x = tf.keras.layers.BatchNormalization()(x)

# NOTE(review): top_dropout_rate is currently unused; no Dropout layer is applied.
top_dropout_rate = 0.1

x = tf.keras.layers.BatchNormalization()(x)

# Two Dense -> BatchNorm -> swish blocks (512 then 256 units).
x = tf.keras.layers.Dense(512)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation(swish_act)(x)

x = tf.keras.layers.Dense(256)(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.Activation(swish_act)(x)

# Softmax over NUM_CLASSES classes.
outputs = tf.keras.layers.Dense(NUM_CLASSES, activation="softmax", name="pred")(x)

# Rebinds `model` from the backbone to the full input -> softmax network.
model = tf.keras.Model(inputs, outputs, name="EfficientNet")

# Categorical cross-entropy on the one-hot targets, with label smoothing of 0.4.
loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.4)

# Adam driven by the cosine-decay learning-rate schedule.
model.compile(loss=loss, optimizer=tf.keras.optimizers.Adam(cosine_decay), 
          metrics=["accuracy"])

In [None]:
# Replace the freshly built, untrained model with a saved one. For a file written
# by model.save, load_model restores architecture, weights and compile state.
# NOTE(review): absolute, user-specific path; move it to a configurable constant.
# Deserializing presumably relies on the 'swish_act' and efficientnet custom
# objects registered earlier in this notebook — confirm.
model = tf.keras.models.load_model('/home/usmanr/Downloads/EfficientNetB5_yuv_smooting.h5')

In [None]:
# NOTE: despite its name, this is the *validation* split (VALID_FILENAMES), read
# once without shuffling or repeating, in batches of 120.
train_dataset = get_validation_dataset(FILENAMES=VALID_FILENAMES, BATCH_SIZE=120)

In [None]:
# Pull the first batch (images, one-hot labels, filenames) for inspection.
image, label, filename = next(iter(train_dataset.take(1)))

In [None]:
# Filenames of the inspected batch, as a Python list.
filename.numpy().tolist()

In [None]:
# Minimum of channel 2 (V in YUV) over every image in the batch.
# NOTE(review): if only the first image was intended, use image[0, ..., 2].
np.min(image[..., 2])

In [None]:
# Show the luma (Y) channel of the first image on a fixed 0-1 grey scale.
# A bare 2-D imshow would apply the default colour map and auto-scale to the
# data range, which makes the /255 normalisation ineffective.
plt.imshow(image[0, :, :, 0] / 255.0, cmap="gray", vmin=0.0, vmax=1.0)
plt.title("Y channel of first image")
plt.axis("off");

In [None]:
# Run the model over the *entire* evaluation dataset, keeping labels, predictions
# and filenames aligned batch by batch. Iterating the whole dataset (instead of
# a hard-coded number of batches) stays correct if BATCH_SIZE or the data change.
labels = []
preds = []
filenames = []

for i, (image, label, filename) in enumerate(train_dataset, start=1):
    labels += label.numpy().tolist()
    pred = model.predict(image)
    preds += pred.tolist()
    filenames += filename.numpy().tolist()
    print(i)

In [None]:
# Total vs. distinct filenames collected; equal counts mean no image was seen twice.
print(len(filenames), len(set(filenames)), sep="\n")

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
# Collapse one-hot labels and softmax outputs to class indices.
y_label = list(map(np.argmax, labels))
y_pred = list(map(np.argmax, preds))

In [None]:
# Fraction of evaluated images whose predicted class matches the label.
accuracy_score(y_label, y_pred)

In [None]:
# Collect (and print) filenames whose top softmax probability is below the
# threshold, i.e. images the model is unsure about.
# Despite its name, `mixed_precision` holds these low-confidence filenames; it is
# unrelated to mixed-precision training.
LOW_CONFIDENCE_THRESHOLD = 0.6

assert len(filenames) == len(preds), "filenames and preds are misaligned"
mixed_precision = []
for fname, prob in zip(filenames, preds):
    if np.max(prob) < LOW_CONFIDENCE_THRESHOLD:
        print(fname)
        mixed_precision.append(fname)

In [None]:
# First collected filename, decoded from bytes to str.
str(filenames[0], "utf-8")

In [None]:
# Display one low-confidence image from disk. cv2.imread returns pixels in BGR
# order (and None when the file cannot be read), while plt.imshow expects RGB.
# Separate variable names keep the `image` batch tensor from being overwritten.
sample_path = '../input/train_images/' + mixed_precision[1].decode("utf-8")
sample_bgr = cv2.imread(sample_path)
if sample_bgr is None:
    raise FileNotFoundError("cv2 could not read image: " + sample_path)
plt.imshow(cv2.cvtColor(sample_bgr, cv2.COLOR_BGR2RGB))
plt.title(sample_path)
plt.axis("off");

In [None]:
# Emit all low-confidence filenames on one line, each followed by a comma.
print("".join(name.decode("utf-8") + "," for name in mixed_precision), end="")