In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)


import os
# Walk the Kaggle input mount and print every directory path found
# (sub-directory and file names are not printed).
for current_dir, _subdirs, _files in os.walk('/kaggle/input'):
    print(current_dir)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

from tensorflow import keras
import tensorflow as tf

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import LearningRateScheduler,ReduceLROnPlateau

from tensorflow.keras import layers
from tensorflow.keras.applications import MobileNetV2,EfficientNetB0,EfficientNetB4

import warnings
warnings.filterwarnings("ignore")

## 이미지 제너레이터

In [None]:
# Random augmentations applied to every image produced by this generator.
augmentation_options = dict(
    rotation_range=5,
    shear_range=0.3,
    zoom_range=0.3,
    width_shift_range=0.05,
    height_shift_range=0.05,
    horizontal_flip=True,
    vertical_flip=True,
)

# Pixel values are multiplied by 1/255, and 20% of the files are set aside
# as the "validation" subset.
img_datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    validation_split=0.2,
    **augmentation_options,
)

### 이미지 가져오기

In [None]:
# Training and validation subsets are both read from the same directory.
# Keras picks the files of each subset deterministically from the sorted file
# list, so two generators sharing the same validation_split fraction yield
# non-overlapping subsets.
train_dir = '../input/paddy-disease-classification/train_images'

# Training batches: augmented images, one-hot labels.
train_ds = img_datagen.flow_from_directory(
    train_dir,
    subset="training",
    class_mode='categorical',
    batch_size=16,
    target_size=(128,128),
)

# Validation images must not be randomly augmented, otherwise val_accuracy is
# measured on distorted inputs and changes from run to run. Only rescaling is
# applied here. The split fraction must stay equal to the one in img_datagen.
val_datagen = ImageDataGenerator(rescale=1.0/255.0, validation_split=0.2)
val_ds = val_datagen.flow_from_directory(
    train_dir,
    subset="validation",
    class_mode='categorical',
    batch_size=16,
    target_size=(128,128),
)

In [None]:
# Pull a single batch from the training iterator and report the array shapes.
image_batch, labels_batch = next(iter(train_ds))
print(image_batch.shape)
print(labels_batch.shape)

### 이미지 확인

In [None]:
# Show the sampled batch on a 4x4 grid. Each title is the class name decoded
# from the one-hot label rather than the raw label vector.
index_to_class = {index: name for name, index in train_ds.class_indices.items()}

# Never index past the end of the batch (the grid holds at most 16 images).
n_shown = min(len(image_batch), 16)

plt.figure(figsize=(15,15))
for i in range(n_shown):
    plt.subplot(4, 4, i + 1)
    plt.imshow(image_batch[i])
    plt.title(index_to_class[int(np.argmax(labels_batch[i]))])

plt.show()

## Xception 모델 사용

In [None]:
# Xception convolutional base with ImageNet weights and no classification
# head, built for 128x128 RGB inputs.
base_model = tf.keras.applications.Xception(
    include_top=False,
    weights='imagenet',
    input_shape=(128,128,3),
)

# Freeze the pretrained base so only the layers stacked on top are trained.
base_model.trainable = False

In [None]:
# Classification head stacked on the frozen base: a regularised conv layer,
# global pooling, two dense layers, and a 10-way softmax output. Dropout(0.5)
# follows every trainable stage.
model = tf.keras.Sequential([
    base_model,
    layers.BatchNormalization(),

    layers.Conv2D(
        128,
        (3,3),
        activation='relu',
        kernel_regularizer=tf.keras.regularizers.L2(0.001),
    ),
    layers.Dropout(0.5),

    layers.GlobalAveragePooling2D(),
    layers.Dropout(0.5),

    layers.Dense(256, activation='relu'),
    layers.Dropout(0.5),

    layers.Dense(
        256,
        activation='relu',
        kernel_regularizer=tf.keras.regularizers.L2(0.001),
    ),
    layers.Dropout(0.5),

    layers.Dense(10, activation='softmax'),
])
model.summary()

### 모델 알고리즘 설정

In [None]:
# Adam at a 1e-3 learning rate; categorical cross-entropy matches the one-hot
# labels produced with class_mode='categorical'.
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
model.compile(
    loss='categorical_crossentropy',
    optimizer=adam,
    metrics=['accuracy'],
)

### 콜백 설정 (학습률 감소 · 최적 모델 저장)

In [None]:
# Both callbacks watch the validation accuracy.
monitored_metric = 'val_accuracy'

# Multiplies the learning rate by 0.8 after 10 epochs without improvement.
# NOTE(review): the fit call visible in this notebook passes only save_best;
# add lr_scheduler to its callbacks list if LR reduction is actually wanted.
lr_scheduler = ReduceLROnPlateau(
    monitor=monitored_metric,
    factor=0.8,
    patience=10,
    verbose=1,
)

# Writes the model to Model.h5 whenever the monitored metric improves.
save_best = tf.keras.callbacks.ModelCheckpoint(
    filepath="Model.h5",
    monitor=monitored_metric,
    save_best_only=True,
    verbose=1,
)

### 모델 학습 및 시간 측정

In [None]:
%%time 

# fit() returns a History object. Store it under its own name so that `model`
# keeps referring to the Keras model and is not replaced by the training log.
history = model.fit(train_ds, validation_data=val_ds, epochs=5, callbacks=[save_best])

### 저장된 최적 모델로 테스트 예측

In [None]:
# Reload the checkpoint file written during training and use it for inference.
best_checkpoint_path = './Model.h5'
model = tf.keras.models.load_model(best_checkpoint_path)

In [None]:
test_path = '../input/paddy-disease-classification/test_images'

# Test images are only rescaled (no augmentation), and shuffle=False keeps the
# batch order aligned with test_gen.filenames.
# NOTE(review): classes=['.'] presumably makes the directory itself act as the
# single "class" so that images stored directly in test_path are found - confirm.
test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255)
test_gen = test_datagen.flow_from_directory(
    test_path,
    target_size=(128, 128),
    batch_size=16,
    classes=['.'],
    shuffle=False,
)

In [None]:
# Run the loaded model over every test batch; one row of class scores per image.
predict = model.predict(x=test_gen, verbose=1)

In [None]:
# Highest-scoring class index for each test image.
predicted_class_indices = predict.argmax(axis=1)
print(set(predicted_class_indices))

# Invert the generator's {class name: index} mapping to turn indices back into names.
class_to_index = train_ds.class_indices
inv_map = dict(zip(class_to_index.values(), class_to_index.keys()))
predictions = [inv_map[class_idx] for class_idx in predicted_class_indices]

## 제출용

In [None]:
# File names as listed by the test generator, in the same order as the predictions.
filenames = test_gen.filenames

results = pd.DataFrame({
    "image_id": filenames,
    "label": predictions,
})

# Remove the './' text from the file names. regex=False makes this a literal
# replacement on every pandas version; as a regex, '.' would match any character.
results["image_id"] = results["image_id"].str.replace('./', '', regex=False)

results.to_csv("submission.csv", index=False)
results