#### Installing dependencies and initializing the environment:
Run the following commands in a console (without the leading `!`), not in the Jupyter server; then launch the server using the freshly created virtual environment.

In [None]:
# Install pipenv for the current user.
!pip3 install --user pipenv
# Install the project dependencies through pipenv without generating a lock file.
!python3 -m pipenv install --skip-lock
# Open a shell inside the pipenv-managed virtual environment.
!python3 -m pipenv shell
# NOTE(review): presumably registers nbstripout for this repository so notebook outputs are stripped on commit — confirm.
!nbstripout --install

#### General environment initialization:

In [None]:
%matplotlib inline

import tensorflow as tf

from tensorflow.python.framework.config import list_physical_devices
from tensorflow.python.platform.test import is_built_with_cuda


# Report which TensorFlow build is in use and whether it can see a GPU.
print(f"TensorFlow version: {tf.__version__}, built {'with' if is_built_with_cuda() else 'without'} CUDA")

# An empty device list is falsy, so bool() answers "is there at least one GPU?".
gpu_available = bool(list_physical_devices('GPU'))
print(f"GPU is{' ' if gpu_available else ' not '}available on the device")
if not gpu_available:
    print("Consider following the guide https://www.tensorflow.org/install/gpu for model training")

# Number of images downloaded for each class.
IMAGES = 7100
# Height and width every image is resized to when the dataset is loaded.
IMG_SIZE = (299, 299)
# Class (and image sub-directory) names, in label order.
CLASS_NAMES = ['hogweed', 'cetera', 'other']

# Images per dataset batch.
BATCH_SIZE = 8
# Upper bound of epochs for training the model head.
INITIAL_EPOCHS = 16
# Additional epochs allowed for fine-tuning.
FINE_TUNING_EPOCHS = 16
# Minimal validation-accuracy improvement (EarlyStopping min_delta) that counts as progress.
EARLY_STOPPING = 0.0001

#### Run the following cell to set up the Google Colab environment:

In [None]:
from google.colab import drive


# Mount Google Drive under ./drive; the directories below are located inside it.
drive.mount("./drive")
# Root directory of the per-class image folders.
# NOTE(review): this folder is named "Hogweb" while the model folder is "Hogweed" — confirm the difference is intentional.
IMAGE_DIR = "./drive/MyDrive/Hogweb"
# Directory the trained model files are written to.
MODEL_DIR = "./drive/MyDrive/Hogweed"

#### Run the following cell to set up the local environment:

In [None]:
# Root directory of the per-class image folders.
IMAGE_DIR = "./data"
# Directory the trained model files are written to.
MODEL_DIR = "./models"

#### Support functions and classes initialization:

In [None]:
from matplotlib.pyplot import imread, subplots, close
from IPython.display import display, update_display
from tensorflow.keras.callbacks import Callback
from matplotlib.axes import Axes
from tqdm.notebook import tqdm
from random import sample
from os import listdir
from PIL import Image


class LoggingCallback(Callback):
    """Keras callback that shows live accuracy/loss graphs and tqdm progress bars.

    Three graphs are redrawn while the model trains: per-batch training metrics,
    per-epoch training metrics and per-epoch validation metrics. Progress-bar
    totals are read from the parameters Keras hands to the callback
    (``self.params``), so the class does not depend on notebook globals and
    reports correct totals for every ``fit`` call.
    """

    def __init__(self):
        # Imported locally: only needed to build a unique display id.
        from uuid import uuid4

        def configure_graph(graph, measure: str, activity: str, param: str):
            """Set the title, axis labels and grid of a single graph."""
            graph.set_title(f"{measure} {activity} {param}")
            graph.set_xlabel(measure)
            graph.set_ylabel(param)
            graph.grid(visible=True)

        super().__init__()

        # Metric history backing each of the three graphs.
        self.batch_train_accuracy = []
        self.batch_train_loss = []
        self.epoch_train_accuracy = []
        self.epoch_train_loss = []
        self.epoch_validation_accuracy = []
        self.epoch_validation_loss = []

        self.figure, (self.batch_train_graph, self.epoch_train_graph, self.epoch_validation_graph) = subplots(3, 1, figsize=(15, 15))

        self.figure.suptitle("Accuracy/Loss graphs")
        configure_graph(self.batch_train_graph, "Batch", "training", "Accuracy and Loss")
        configure_graph(self.epoch_train_graph, "Epoch", "training", "Accuracy and Loss")
        configure_graph(self.epoch_validation_graph, "Epoch", "validation", "Accuracy and Loss")

        # Each instance gets its own display id: updates are addressed by id, so a
        # shared constant id would make a later training run overwrite the graphs
        # displayed by an earlier one.
        self.figure_display_id = f"figure-{uuid4().hex}"
        # Closing the figure keeps it from being rendered a second time by the
        # inline backend; it is shown explicitly through display().
        close(self.figure)

        self.epochs_bar = None
        self.batches_bar = None

    def redraw_graph(self, graph, accuracy_arr, loss_arr):
        """Redraw one graph from the given histories and refresh the displayed figure.

        Both arrays must be non-empty: the last value of each is highlighted.
        """
        # The axes are cleared below, so remember the captions and restore them afterwards.
        title = graph.get_title()
        x_label = graph.get_xlabel()
        y_label = graph.get_ylabel()

        graph.cla()
        graph.plot(accuracy_arr, color='red', label="accuracy")
        # Mark the most recent value with a dot.
        graph.scatter(len(accuracy_arr) - 1, accuracy_arr[-1], c='red')
        graph.plot(loss_arr, color='blue', label="loss")
        graph.scatter(len(loss_arr) - 1, loss_arr[-1], c='blue')

        graph.set_ylim(bottom=0)
        graph.set_title(title)
        graph.set_xlabel(x_label)
        graph.set_ylabel(y_label)
        graph.legend()
        graph.grid(visible=True)

        self.figure.canvas.draw()
        update_display(self.figure, display_id=self.figure_display_id)

    def on_train_begin(self, logs=None):
        """Show the figure and open the epoch progress bar."""
        display(self.figure, display_id=self.figure_display_id)

        # `params` is filled in by Keras before training; a missing value just
        # produces a bar without a known total.
        total_epochs = (self.params or {}).get('epochs')
        self.epochs_bar = tqdm(desc="Training", unit="epoch", total=total_epochs)

    def on_epoch_begin(self, epoch, logs=None):
        """Open a fresh batch progress bar for the epoch that is starting."""
        # When training resumes with initial_epoch > 0, move the epoch bar
        # forward so it agrees with the epoch index Keras reports.
        if self.epochs_bar.n < epoch:
            self.epochs_bar.update(epoch - self.epochs_bar.n)

        steps = (self.params or {}).get('steps')
        self.batches_bar = tqdm(desc="Training", unit="batch", total=steps)

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch metrics, redraw the epoch graphs and advance the epoch bar."""
        self.epoch_train_accuracy.append(logs['accuracy'])
        self.epoch_train_loss.append(logs['loss'])
        self.redraw_graph(self.epoch_train_graph, self.epoch_train_accuracy, self.epoch_train_loss)

        # Validation metrics are absent when fit() runs without validation data.
        if 'val_accuracy' in logs and 'val_loss' in logs:
            self.epoch_validation_accuracy.append(logs['val_accuracy'])
            self.epoch_validation_loss.append(logs['val_loss'])
            self.redraw_graph(self.epoch_validation_graph, self.epoch_validation_accuracy, self.epoch_validation_loss)

        self.batches_bar.set_description(f"Loss: {round(logs['loss'], 3)}; Accuracy: {round(logs['accuracy'] * 100, 3)}")
        self.batches_bar.close()
        self.epochs_bar.update()

    def on_train_batch_end(self, batch, logs=None):
        """Record the batch metrics, redraw the batch graph and advance the batch bar."""
        self.batch_train_accuracy.append(logs['accuracy'])
        self.batch_train_loss.append(logs['loss'])
        self.redraw_graph(self.batch_train_graph, self.batch_train_accuracy, self.batch_train_loss)

        self.batches_bar.set_description(f"Loss: {round(logs['loss'], 3)}; Accuracy: {round(logs['accuracy'] * 100, 3)}")
        self.batches_bar.update()

    def on_train_end(self, logs=None):
        """Close the epoch progress bar."""
        if self.epochs_bar is not None:
            self.epochs_bar.close()


def remove_axis(plt: Axes):
    """Hide the x axis and then the y axis of the given plot."""
    for axis_name in ("xaxis", "yaxis"):
        getattr(plt.axes, axis_name).set_visible(False)

def print_images_line(images: [Image]):
    """Display the given images side by side in one row, without axes.

    :param images: sequence of images accepted by ``Axes.imshow``; an empty
        sequence displays nothing.
    """
    if len(images) == 0:
        # subplots() cannot build a grid with zero columns, so there is nothing to show.
        return

    # squeeze=False keeps `axes` two-dimensional; without it a single image
    # yields a bare Axes object that cannot be iterated.
    fig, axes = subplots(1, len(images), figsize=(15, 15), squeeze=False)
    for ax, image in zip(axes[0], images):
        remove_axis(ax)
        ax.imshow(image)
    display(fig)
    # Close the figure so it is not rendered a second time at the end of the cell.
    close(fig)

def print_random_images(images: [Image], num: int = 5):
    """Display `num` images drawn at random, without repetition, from `images`."""
    # sample() already returns a new list, so it can be handed over as is.
    chosen = sample(images, num)
    print_images_line(chosen)

def print_images_from_dir(path: str, num: int = 5):
    """Display up to `num` randomly chosen image files from the directory `path`.

    :param path: directory whose entries are read as images.
    :param num: number of images to show; reduced to the number of entries
        when the directory holds fewer.
    """
    names = listdir(path)
    # sample() raises ValueError when asked for more items than are available.
    picked = sample(names, min(num, len(names)))
    if not picked:
        # Empty directory (or num <= 0): nothing to display.
        return
    print_images_line([imread(f"{path}/{name}") for name in picked])

def naturalize_urls(urls: list[str]) -> list[str]:
    """Return a copy of `urls` with every "medium" substring replaced by "large"."""
    enlarged = []
    for address in urls:
        enlarged.append(address.replace("medium", "large"))
    return enlarged

#### Use following links to access image sets:
Random images are available [here](https://storage.googleapis.com/openimages/2018_04/validation/validation-images-with-rotation.csv).
This is the validation set of the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) dataset, but it contains enough images for hogweed classification.

Plant images are available [here](https://www.inaturalist.org/observations/export) after signing in to iNaturalist.
* Use this request `has[]=photos&quality_grade=any&identifications=any&iconic_taxa[]=Plantae&projects[]=leningrad-oblast-flora` to retrieve different plant images.
* Use this request `has[]=photos&quality_grade=any&identifications=any&iconic_taxa[]=Plantae&taxon_id=499936` to retrieve hogweed images.

Please make sure the `.csv` files obtained with the requests and links above are placed in the `./datasets` directory before running the following cell, and that they are named as follows:
* The random image dataset file name should be `other.csv`.
* The hogweed image dataset file name should be `hogweed.csv`.
* The other plants image dataset file name should be `cetera.csv`.

In [None]:
from pandas import read_csv, concat
from imagehash import average_hash
from os import makedirs, system
from tqdm.notebook import tqdm
from shutil import rmtree
from requests import get
from io import BytesIO
from PIL import Image


def clear_path(path: str) -> str:
    """Turn `path` into an empty directory and return the same path.

    Whatever is stored under `path` is removed first; the directory is then
    created again (together with missing parents).
    """
    # Removal errors, such as the directory not existing yet, are ignored on purpose.
    rmtree(path, ignore_errors=True)
    makedirs(path, exist_ok=True)
    return path

def download_and_save(img_source: str, img_type: str, count: int, source: [str], timeout: float = 30.0):
    """Download `count` distinct images of one class and show a few samples.

    URLs are tried in order. An image is kept only if its average hash has not
    been seen before and it contains pixel data; it is then saved as
    ``{IMAGE_DIR}/{img_type}/{img_type}{index}.jpg``. URLs that cannot be
    downloaded, decoded or saved are skipped.

    :param img_source: human-readable name of the dataset, used in messages only.
    :param img_type: class name; also the target sub-directory and file prefix.
    :param count: number of distinct images to collect.
    :param source: indexable collection of image URLs.
    :param timeout: seconds to wait for each HTTP request before skipping the URL.
    :raises RuntimeError: when `source` runs out before `count` images are collected.
    """
    print(f"Downloading {img_type} images from {img_source}, {len(source)} images provided")
    img_hashes = set()
    img_code = 0

    with tqdm(total=count, desc="Downloading", unit="img") as bar:
        while len(img_hashes) < count:
            if img_code >= len(source):
                # Every URL was tried; the ones that were not saved were duplicates or failures.
                rejected = img_code - len(img_hashes)
                raise RuntimeError(f"The source set contains only {len(source)} elements, {img_code}th was requested! ({rejected} collisions)")

            try:
                response = get(source[img_code], timeout=timeout)
                img_data = Image.open(BytesIO(response.content))
                img_hash = average_hash(img_data, hash_size=32)
                if img_hash not in img_hashes and len(img_data.getdata()) > 0:
                    img_data.save(f"{IMAGE_DIR}/{img_type}/{img_type}{img_code}.jpg")
                    img_hashes.add(img_hash)
                    bar.update()
            except Exception:
                # Best effort: a URL that fails to download, decode or save is simply
                # skipped. Only Exception is caught so that KeyboardInterrupt and
                # SystemExit still stop the loop.
                pass
            img_code += 1

    # img_code now equals the number of URLs that were tried.
    print(f"Downloaded {img_code} images finished, {img_code - len(img_hashes)} collisions or wrong images found")
    print(f"{img_source} dataset {img_type} image samples:")
    print_images_from_dir(f"{IMAGE_DIR}/{img_type}")


print(f"Train set should contain hogweed, cetera and other images, {IMAGES} of each kind")

print("Preparing directories")
for class_dir in ("hogweed", "cetera", "other"):
    clear_path(f"{IMAGE_DIR}/{class_dir}")

# Hogweed: all non-CC0 photos plus the CC0 photos except the last 50 of them.
frame = read_csv("./datasets/hogweed.csv")
open_frame = frame.loc[frame["license"] == "CC0", "image_url"]
restricted_urls = frame.loc[frame["license"] != "CC0", "image_url"]
frame = concat([restricted_urls, open_frame.head(len(open_frame) - 50)])
download_and_save("iNaturalist", "hogweed", IMAGES, naturalize_urls(frame.to_list()))

# Other plants: same selection, with "Heracleum sosnowskyi" rows excluded.
frame = read_csv("./datasets/cetera.csv")
not_hogweed = frame["scientific_name"] != "Heracleum sosnowskyi"
open_frame = frame.loc[(frame["license"] == "CC0") & not_hogweed, "image_url"]
restricted_urls = frame.loc[(frame["license"] != "CC0") & not_hogweed, "image_url"]
frame = concat([restricted_urls, open_frame.head(len(open_frame) - 50)])
download_and_save("iNaturalist", "cetera", IMAGES, naturalize_urls(frame.to_list()))

# Random images: URLs are taken from the "OriginalURL" column unchanged.
frame = read_csv("./datasets/other.csv")
download_and_save("OpenImages", "other", IMAGES, frame["OriginalURL"])

print("Removing ipynb caches")
# Deletes every .ipynb_checkpoints directory found below the current directory.
system("rm -rf `find -type d -name .ipynb_checkpoints`")

### Use this cell to regenerate `./datasets/test.csv` file
It will contain the last 50 open-source (CC0) images from each plant dataset (cetera, hogweed), which were used in neither the training nor the validation dataset, plus 50 open-source images from [unsplash](https://unsplash.com/).
The unsplash images are not guaranteed to be unique, but it is the only clearly open-source image source that was found.

The generated file may be used with result testing script (`./test.py`) as follows: `python test.py -n ./models/[YOUR MODEL FILE NAME].tflite -s ./datasets/test.csv`.

In [None]:
from pandas import read_csv, DataFrame


# Last 50 CC0 hogweed photos.
frame = read_csv("./datasets/hogweed.csv")
hogweed_frame = frame[frame["license"] == "CC0"].tail(50)["image_url"]
# Last 50 CC0 photos of other plants ("Heracleum sosnowskyi" rows excluded).
frame = read_csv("./datasets/cetera.csv")
cetera_frame = frame[(frame["scientific_name"] != "Heracleum sosnowskyi") & (frame["license"] == "CC0")].tail(50)["image_url"]

# One URL list per class, in the same order as CLASS_NAMES (hogweed, cetera, other).
# NOTE(review): confirm that the source.unsplash.com endpoint is still being served.
urls_by_class = [
    naturalize_urls(hogweed_frame.to_list()),
    naturalize_urls(cetera_frame.to_list()),
    [f"https://source.unsplash.com/random?sig={num}" for num in range(50)],
]

URLs = [url for urls in urls_by_class for url in urls]
# Label every URL by the list it came from instead of assuming 50 URLs per
# class: a dataset with fewer than 50 CC0 rows would otherwise shift the labels.
classes = [cls for cls, urls in zip(CLASS_NAMES, urls_by_class) for _ in urls]
DataFrame(list(zip(classes, URLs)), columns=['class', 'url']).to_csv(path_or_buf="./datasets/test.csv", index=False)

### Image dataset creation, splitting in train and validation datasets
**NB!** The first import line, `from tensorflow.python.keras.layers import RandomFlip, RandomRotation`, causes an [error](https://github.com/keras-team/keras/issues/15699) when saving the model in the current latest TensorFlow version (2.7). This is unexpected behavior and should be fixed in the near future.

In [None]:
from tensorflow.python.keras.layers import RandomFlip, RandomRotation
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.keras.backend import expand_dims
from tensorflow.python.data import AUTOTUNE
from tensorflow.keras import Sequential


# Load every class directory as one batched dataset of resized images.
dataset = image_dataset_from_directory(
    IMAGE_DIR,
    class_names=CLASS_NAMES,
    label_mode='categorical',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)
batches_num = dataset.cardinality()

# The first nine tenths of the batches are used for training; validation takes
# up to a ninth of that amount from the batches that follow.
train_power = batches_num * 9 // 10
train_dataset = dataset.take(train_power).prefetch(buffer_size=AUTOTUNE)
validation_dataset = dataset.skip(train_power).take(train_power // 9).prefetch(buffer_size=AUTOTUNE)

data_augmentation = Sequential([RandomFlip('horizontal'), RandomRotation(0.1)])

print(f"Train set has {train_dataset.cardinality()} batches, validation set has {validation_dataset.cardinality()} batches; train set will be augmented:")
# Preview: for five batches, show five augmented variants of the batch's first image.
for offset in range(5):
    imgs = []
    for batch, _ in train_dataset.skip(offset).take(1):
        first_image = expand_dims(batch[0], 0)
        # Pixel values are divided by 255 before being displayed.
        imgs.extend(data_augmentation(first_image)[0] / 255 for _ in range(5))
    print_images_line(imgs)

### Model initialization and compilation

In [None]:
from tensorflow.keras.applications.xception import preprocess_input, Xception
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Input, Model


# Image size plus three colour channels.
INPUT_SHAPE = IMG_SIZE + (3,)

# Xception without its classification head, average-pooled; frozen while the new head is trained.
base_model = Xception(input_shape=INPUT_SHAPE, pooling='avg', include_top=False)
base_model.trainable = False

inputs = Input(shape=INPUT_SHAPE, name="re_shaper")
x = data_augmentation(inputs)
x = preprocess_input(x)
# training=False is passed explicitly, so the base model is always called in inference mode.
x = base_model(x, training=False)
x = Dropout(0.2, name="dropout")(x)
# One output per class; no activation, so the model produces logits (see from_logits below).
outputs = Dense(len(CLASS_NAMES), name="predictor")(x)
model = Model(inputs, outputs, name="hogweed_detector")

base_learning_rate = 0.0001
model.compile(optimizer=Adam(learning_rate=base_learning_rate), loss=CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

model.summary()

### Training top layers of the model (its 'head')
Early stopping is used to prevent overfitting; training usually stops after 12-14 epochs.

In [None]:
from tensorflow.keras.callbacks import EarlyStopping


print(f"Training model with {len(model.trainable_variables)} variables enabled.")
# Live graphs plus early stopping on validation accuracy (best weights are restored).
callbacks = [
    LoggingCallback(),
    EarlyStopping(monitor='val_accuracy', min_delta=EARLY_STOPPING, patience=1, restore_best_weights=True),
]
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=INITIAL_EPOCHS,
    callbacks=callbacks,
    verbose=0,
)

# Fewer recorded epochs than requested means early stopping ended the run.
epochs = len(history.history['val_accuracy'])
if epochs < INITIAL_EPOCHS:
    last_val_accuracy = round(history.history['val_accuracy'][-1], 3)
    print(f"Training stopped early with validation accuracy {last_val_accuracy} after {epochs} epochs.")

### Model preparing for fine-tuning
Only the top 80% of the layers are fine-tuned, so as not to disturb the base image-detection weights.

In [None]:
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.optimizers import RMSprop


# The first fifth of the base model's layers stays frozen.
fine_tune_from = len(base_model.layers) // 5
base_model.trainable = True
for index, layer in enumerate(base_model.layers):
    # Batch-normalization layers are kept frozen as well; isinstance also
    # matches subclasses of BatchNormalization.
    if index < fine_tune_from or isinstance(layer, BatchNormalization):
        layer.trainable = False

# Recompile with a learning rate ten times smaller than the one used for the head.
further_learning_rate = 0.00001
model.compile(optimizer=RMSprop(learning_rate=further_learning_rate), loss=CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

model.summary()

### Training middle and top layers of the model (fine-tuning)
Early stopping is used to prevent overfitting; training usually stops after 1-2 epochs.

In [None]:
from tensorflow.keras.callbacks import EarlyStopping


print(f"Training model with {len(model.trainable_variables)} variables enabled.")
# Continue from the index of the last epoch recorded by the previous fit() call.
start_epoch = history.epoch[-1]
total_epochs = INITIAL_EPOCHS + FINE_TUNING_EPOCHS
callbacks = [LoggingCallback(), EarlyStopping(patience=1, restore_best_weights=True, monitor='val_accuracy', min_delta=EARLY_STOPPING)]
history = model.fit(train_dataset, epochs=total_epochs, validation_data=validation_dataset, initial_epoch=start_epoch, verbose=0, callbacks=callbacks)

# fit() runs epochs start_epoch .. total_epochs - 1, so a full run records
# total_epochs - start_epoch entries; fewer means early stopping ended it.
epochs = len(history.history['val_accuracy'])
if epochs < total_epochs - start_epoch:
    print(f"Training stopped early with validation accuracy {round(history.history['val_accuracy'][-1], 3)} after {epochs} epochs.")

### Saving model
The model is saved in both `.h5` and `.tflite` formats for mobile use.

In [None]:
from tensorflow import lite
from os import makedirs


# The file name carries the training accuracy of the last epoch of the most recent fit() call.
model_name = f"hogweed_detector-{round(history.history['accuracy'][-1], 5)}"

makedirs(MODEL_DIR, exist_ok=True)
model.save(f"{MODEL_DIR}/{model_name}.h5")

# Convert before opening the target file, so a failed conversion does not
# leave an empty .tflite file behind.
lite_model = lite.TFLiteConverter.from_keras_model(model).convert()
with open(f"{MODEL_DIR}/{model_name}.tflite", 'wb') as file:
    file.write(lite_model)