In [None]:
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow_addons as tfa

from kaggle_datasets import KaggleDatasets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

from plotly.subplots import make_subplots
import plotly.express as px

# Sentinel that lets tf.data pick parallelism / prefetch buffer sizes at runtime.
AUTO=tf.data.AUTOTUNE


In [None]:
# Use a TPU when one is attached; otherwise fall back to the default
# (single-device CPU/GPU) strategy so the notebook still runs end to end.
try:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
except ValueError:
    # NOTE(review): TPUClusterResolver raises ValueError when no TPU can be found.
    resolver = None

if resolver is not None:
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
else:
    strategy = tf.distribute.get_strategy()

print('Replicas in sync:', strategy.num_replicas_in_sync)


In [None]:
# GCS root of the Kaggle dataset attached to this notebook; the TFRecord
# shards are globbed relative to it below.
gcsPath = KaggleDatasets().get_gcs_path()


In [None]:
# All three splits live under the same 512x512 TFRecord directory.
RECORD_DIR = gcsPath + '/tfrecords-jpeg-512x512'


def loadSplit(split):
    """Return a TFRecordDataset over every shard in RECORD_DIR/<split>/.

    Args:
        split: sub-directory name, e.g. 'train', 'val' or 'test'.

    Raises:
        FileNotFoundError: if the glob matches no files. An empty file list
            would otherwise yield a silently empty dataset that only fails
            much later in the notebook.
    """
    files = tf.io.gfile.glob(f'{RECORD_DIR}/{split}/*')
    if not files:
        raise FileNotFoundError(f'No TFRecord shards found for split {split!r} under {RECORD_DIR}')
    return tf.data.TFRecordDataset(files, num_parallel_reads=AUTO)


Train = loadSplit('train')
Val = loadSplit('val')
Test = loadSplit('test')


In [None]:
# Train and val are merged, so the model is fit on every labelled example
# (the validation pipeline is not used).
FullTrain = Train.concatenate(Val)


In [None]:
# Count the merged training examples with a graph-side reduce instead of a
# Python loop that pulls every record through eager mode.
NUM_TRAIN = int(FullTrain.reduce(0, lambda count, _: count + 1))
assert NUM_TRAIN > 0, 'FullTrain is empty - check the TFRecord paths'
# Kept for compatibility: later cells derive the example count from `temp`
# (one entry per training example).
temp = list(range(NUM_TRAIN))

# Inspect the schema of one record. Printing the whole Example would dump the
# raw JPEG bytes, so show only each feature's name and value kind.
for raw_record in FullTrain.take(1):
    example = tf.train.Example.FromString(raw_record.numpy())
    for name, feature in sorted(example.features.feature.items()):
        print(name, feature.WhichOneof('kind'))


In [None]:
def parseImage(EagerTensor, image_size=512):
    """Decode the JPEG stored in one serialized Example into a [0, 1] image.

    Args:
        EagerTensor: scalar string tensor holding one serialized tf.train.Example
            with an 'image' bytes feature.
        image_size: expected height and width of the stored JPEG (default 512).

    Returns:
        float32 tensor of shape (image_size, image_size, 3) with values in [0, 1].
    """
    FeatureMap = {
        'image': tf.io.FixedLenFeature([], tf.string),
    }
    Features = tf.io.parse_single_example(EagerTensor, FeatureMap)
    # channels=3 forces RGB output, so a grayscale JPEG cannot break the reshape.
    image = tf.image.decode_jpeg(Features['image'], channels=3)
    image = tf.reshape(image, (image_size, image_size, 3))
    return tf.cast(image, tf.float32) / 255.0


def parseLabel(EagerTensor, num_classes=104):
    """Return the one-hot encoded 'class' feature of one serialized Example.

    Args:
        EagerTensor: scalar string tensor holding one serialized tf.train.Example
            with an int64 'class' feature.
        num_classes: length of the one-hot vector (default 104).

    Returns:
        float32 tensor of shape (num_classes,).
    """
    FeatureMap = {
        'class': tf.io.FixedLenFeature([], tf.int64),
    }
    Features = tf.io.parse_single_example(EagerTensor, FeatureMap)
    return tf.one_hot(Features['class'], num_classes)


def parseId(EagerTensor):
    """Return the 'id' string feature of one serialized Example as a scalar string tensor."""
    FeatureMap = {
        'id': tf.io.FixedLenFeature([], tf.string),
    }
    Features = tf.io.parse_single_example(EagerTensor, FeatureMap)
    return Features['id']


# Sanity check: decode and display the first training image.
firstImage = next(iter(Train.map(parseImage)))
plt.imshow(firstImage)
plt.axis('off')
plt.title(f'First training image {tuple(firstImage.shape)}')
plt.show()

In [None]:

# The first N_FROZEN_LAYERS DenseNet layers keep their ImageNet weights.
N_FROZEN_LAYERS = 100

# Model and optimizer variables are created inside the strategy scope so they
# are placed on the devices the strategy manages.
with strategy.scope():
    Densenet201 = tf.keras.applications.DenseNet201(
        weights='imagenet',
        include_top=False,
        input_shape=(512, 512, 3),
    )
    # Fine-tune the backbone, but freeze its earliest layers.
    Densenet201.trainable = True
    for layer in Densenet201.layers[:N_FROZEN_LAYERS]:
        layer.trainable = False
    # NOTE(review): parseImage feeds pixels scaled to [0, 1]; the Keras DenseNet
    # weights expect tf.keras.applications.densenet.preprocess_input - consider it.

    model = tf.keras.Sequential([
        Densenet201,
        layers.Dropout(0.2),
        layers.GlobalAveragePooling2D(),
        layers.Dense(104, activation='softmax'),
    ])
    model.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy', tfa.metrics.F1Score(104)],
    )

model.summary()

In [None]:
# 16 examples per replica; the global batch scales with the replica count.
BATCHSIZE = 16 * strategy.num_replicas_in_sync
# Shuffle the small serialized records (before decoding) to keep the buffer's memory low.
SHUFFLE_BUFFER = 2048


def toTrainingPair(record):
    """Decode one serialized record into an (image, one-hot label) pair."""
    return parseImage(record), parseLabel(record)


TrainDS = (
    FullTrain
    .shuffle(SHUFFLE_BUFFER)   # reshuffled on every pass through the data
    .repeat()
    .map(toTrainingPair, num_parallel_calls=AUTO)
    # Static batch shape for the TPU; repeat() makes the stream infinite, so
    # no examples are lost.
    .batch(BATCHSIZE, drop_remainder=True)
    .prefetch(AUTO)
)
# ValDS=tf.data.Dataset.zip((ValImage,ValLabel)).batch(BATCHSIZE).prefetch(AUTO)

In [None]:
# Shrink the learning rate 5x when training accuracy stalls for 3 epochs.
ReduceLR = tf.keras.callbacks.ReduceLROnPlateau(monitor='accuracy', factor=0.2, patience=3, verbose=True)


class MyCallBack(tf.keras.callbacks.Callback):
    """Stop training once the epoch's training accuracy exceeds `threshold`."""

    def __init__(self, threshold=0.9995):
        super().__init__()
        self.threshold = threshold

    def on_epoch_end(self, epoch, log=None):
        """Request a stop from the model when the logged accuracy passes the threshold."""
        log = log or {}
        if log.get('accuracy', 0.0) > self.threshold:
            # fit() checks the *model's* stop_training flag; setting it on the
            # callback itself has no effect.
            self.model.stop_training = True


Stop = MyCallBack()

In [None]:
Epochs = 20
# `temp` holds one entry per training example, so len(temp) is the example
# count (temp[-1] would be count - 1). Never allow zero steps per epoch.
stepsPerEpoch = max(1, len(temp) // BATCHSIZE)
History = model.fit(
    TrainDS,
    epochs=Epochs,
    steps_per_epoch=stepsPerEpoch,
    callbacks=[ReduceLR, Stop],
    verbose=1,
)


In [None]:
# Per-epoch training curves recorded by model.fit (no validation metrics are tracked).
Loss, Acc, F1 = (History.history[key] for key in ('loss', 'accuracy', 'f1_score'))


In [None]:
# tfa F1Score defaults to average=None, so each epoch's entry is a per-class
# vector; average it per epoch (np.mean leaves an already-scalar score unchanged).
F1Mean = [float(np.mean(epochF1)) for epochF1 in F1]

fig = make_subplots(rows=1, cols=3, subplot_titles=('Accuracy', 'Loss', 'F1 Score'))
panels = [
    (Acc, 'TrainAcc', 'Accuracy'),
    (Loss, 'TrainLoss', 'Loss'),
    (F1Mean, 'TrainF1', 'F1 Score'),
]
for col, (values, traceName, yTitle) in enumerate(panels, start=1):
    fig.add_scatter(x=History.epoch, y=values, name=traceName, row=1, col=col)
    fig.update_xaxes(title_text='Epoch', row=1, col=col)
    fig.update_yaxes(title_text=yTitle, row=1, col=col)
fig.show()


In [None]:
# Decode test images in dataset order. `map` keeps element order by default,
# so row k of `result` corresponds to the k-th record of `Test`.
TestImage = Test.map(parseImage, num_parallel_calls=AUTO)
TestId = Test.map(parseId)
# NOTE(review): ids are read in a second pass over `Test`; alignment with
# `result` assumes both passes yield records in the same order - confirm if the
# reader options (e.g. determinism) are changed.
TestDS = TestImage.batch(BATCHSIZE).prefetch(AUTO)
result = model.predict(TestDS)


In [None]:
# Decode every serialized test id (bytes) into a Python str, in dataset order.
Id = [idTensor.numpy().decode('utf-8') for idTensor in TestId]
    

In [None]:
# Reuse the sample file only for its column layout; both columns are replaced.
sampleSubmission = pd.read_csv(r'../input/tpu-getting-started/sample_submission.csv')
predictedLabels = np.argmax(result, axis=1)
assert len(Id) == len(predictedLabels), (
    f'{len(Id)} test ids but {len(predictedLabels)} predictions'
)
submission = sampleSubmission.assign(id=Id, label=predictedLabels)
submission.to_csv(r'./submission.csv', index=False)
