#### Installing dependencies and initializing the environment:
Run the following commands in a console, not in the Jupyter server; then launch the server from the freshly created virtual environment.

In [None]:
# Install pipenv for the current user.
!pip install --user pipenv
# Create a `venv` directory, then run `pipenv install` with lock-file handling skipped.
# NOTE(review): the trailing `venv` is passed to `pipenv install` as an argument — confirm that is intended.
!mkdir venv && python -m pipenv install --skip-lock venv
# Open a shell inside the pipenv-managed environment.
!python -m pipenv shell
# Presumably registers nbstripout for this repository so notebook outputs are stripped — TODO confirm.
!nbstripout --install

#### General environment initialization:

In [None]:
%matplotlib inline

import tensorflow as tf

from tensorflow.python.framework.config import list_physical_devices
from tensorflow.python.platform.test import is_built_with_cuda


# Report the TensorFlow build in use.
print(f"TensorFlow version: {tf.__version__}, built {'with' if is_built_with_cuda() else 'without'} CUDA")

# A non-empty list of physical GPU devices means TensorFlow can see at least one GPU.
gpu_available = bool(list_physical_devices('GPU'))
print(f"GPU is{' ' if gpu_available else ' not '}available on the device")

# Without a GPU, only point the user at the setup guide.
if not gpu_available:
    print("Consider following the guide https://www.tensorflow.org/install/gpu for model training")

# Number of images to download for each image class.
IMAGES = 7400
# NOTE(review): hard-coded credential committed with the notebook; it is not referenced in the visible cells —
# consider loading it from the environment or removing it.
API_KEY = "24844585-de839c0e13ca422a989916f16"

# Batch size passed to `image_dataset_from_directory`.
BATCH_SIZE = 32
# Image size passed to the dataset loader; also the spatial part of the model input shape.
IMG_SIZE = (299, 299)

# Number of epochs passed to `model.fit`; also used as the total of the epoch progress bar.
INITIAL_EPOCHS = 16

#### Run the following cell to set up the Google Colab environment:

In [None]:
from google.colab import drive


# Mount Google Drive under ./drive so that the paths below point into it.
drive.mount("./drive")
# Root directory holding one sub-directory of images per class.
# NOTE(review): "Hogweb" differs from the "Hogweed" folder used for MODEL_DIR — confirm this is not a typo.
IMAGE_DIR = "./drive/MyDrive/Hogweb"
# Directory the converted model file is written to.
MODEL_DIR = "./drive/MyDrive/Hogweed"

#### Run the following cell to set up the local environment:

In [None]:
# Root directory holding one sub-directory of images per class.
IMAGE_DIR = "./data"
# Directory the converted model file is written to.
MODEL_DIR = "./model"

#### Support function initialization:

In [None]:
from matplotlib.pyplot import imread, subplots, close
from IPython.display import display
from matplotlib.axes import Axes
from random import sample
from os import listdir
from PIL import Image


def remove_axis(plt: Axes):
    """Hide both the x axis and the y axis of the given plot."""
    for axis in (plt.axes.xaxis, plt.axes.yaxis):
        axis.set_visible(False)

def print_images_line(images: [Image]):
    """Display the given images side by side in a single row.

    A figure with one column per image is created, shown in the notebook output and then
    closed. Nothing is drawn for an empty sequence.

    :param images: images accepted by ``Axes.imshow``.
    """
    if len(images) == 0:
        # `subplots` cannot build a grid with zero columns.
        return

    fig, axes = subplots(1, len(images), figsize=(15, 15))
    if len(images) == 1:
        # For a 1x1 grid `subplots` returns a bare Axes instead of an array of them.
        axes = [axes]

    for ax, image in zip(axes, images):
        remove_axis(ax)
        ax.imshow(image)
    display(fig)
    # Close the figure so it is not rendered a second time by the inline backend.
    close(fig)

def print_random_images(images: [Image], num: int = 5):
    """Display `num` images picked at random, without repetition, from `images` in one row."""
    chosen = sample(images, num)
    print_images_line(chosen)

def print_images_from_dir(path: str, num: int = 5):
    """Display `num` randomly chosen entries of the directory `path` as a row of images."""
    picked = sample(listdir(path), num)
    loaded = []
    for image in picked:
        loaded.append(imread(f"{path}/{image}"))
    print_images_line(loaded)

def naturalize_urls(urls: list[str]) -> list[str]:
    """Return a new list in which every "medium" substring of each URL is replaced by "large"."""
    resized = []
    for address in urls:
        resized.append(address.replace("medium", "large"))
    return resized

#### Use following links to access image sets:
Random images are available [here](https://storage.googleapis.com/openimages/2018_04/validation/validation-images-with-rotation.csv).
This is the validation set of the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) dataset, but it contains enough images for hogweed classification.

Plant images are available [here](https://www.inaturalist.org/observations/export) after signing in to iNaturalist.
* Use this request `has[]=photos&quality_grade=any&identifications=any&iconic_taxa[]=Plantae&projects[]=leningrad-oblast-flora` to retrieve different plant images.
* Use this request `has[]=photos&quality_grade=any&identifications=any&iconic_taxa[]=Plantae&taxon_id=499936` to retrieve hogweed images.

Before running the following cell, make sure the `.csv` files obtained with the requests and links above are placed in the `./datasets` directory and named as follows:
* The random image dataset file name should be `other.csv`.
* The hogweed image dataset file name should be `hogweed.csv`.
* The other plants image dataset file name should be `cetera.csv`.

In [None]:
from imagehash import average_hash
from os import makedirs, system
from tqdm.notebook import tqdm
from pandas import read_csv
from shutil import rmtree
from requests import get
from io import BytesIO


def clear_path(path: str) -> str:
    """Recreate `path` as an empty directory and return it.

    Whatever exists at `path` is removed first, with removal errors ignored. The directory
    is then created again, together with any missing parent directories.
    """
    rmtree(path=path, ignore_errors=True)
    makedirs(name=path, exist_ok=False)
    return path

def download_and_save(img_source: str, img_type: str, count: int, source: [str], timeout: float = 30):
    """Download `count` distinct images and store them under ``{IMAGE_DIR}/{img_type}``.

    URLs are taken from `source` in order. An image is kept only when its average hash has
    not been seen before; it is saved as ``{img_type}{index}.jpg``, where ``index`` is the
    position of its URL in `source`. URLs that cannot be fetched, decoded or saved are
    skipped. Afterwards a few of the stored images are displayed.

    :param img_source: human-readable name of the image source, used in log messages only.
    :param img_type: image class name; used as the sub-directory and the file name prefix.
    :param count: number of distinct images to store.
    :param source: indexable collection of image URLs.
    :param timeout: seconds to wait for each HTTP request before the URL is skipped.
    :raises RuntimeError: when `source` is exhausted before `count` images were stored.
    """
    print(f"Downloading {img_type} images from {img_source}")
    img_hashes = set()
    img_code = 0

    with tqdm(total=count, desc="Downloading", unit="img") as bar:
        while len(img_hashes) < count:
            if img_code >= len(source):
                raise RuntimeError(f"The source set contains only {len(source)} elements, {img_code}th was requested!")

            try:
                response = get(source[img_code], timeout=timeout)
                img_data = Image.open(BytesIO(response.content))
                img_hash = average_hash(img_data)
                if img_hash not in img_hashes:
                    # Record the hash only after a successful save, so a failed save is not counted.
                    img_data.save(f"{IMAGE_DIR}/{img_type}/{img_type}{img_code}.jpg")
                    img_hashes.add(img_hash)
                    bar.update()
            except Exception:
                # Best effort: a broken URL or image is skipped and the next one is tried.
                # Only `Exception` is caught, so KeyboardInterrupt still stops the download.
                pass
            img_code += 1

    # At this point `img_code` equals the number of URLs that were attempted.
    print(f"Downloaded {img_code} images finished, {img_code - count} collisions or wrong images found")
    print(f"{img_source} dataset {img_type} image samples:")
    print_images_from_dir(f"{IMAGE_DIR}/{img_type}")


from os import walk
from os.path import join


print(f"Train set should contain hogweed, cetera and other images, {IMAGES} of each kind")

# Start from empty per-class directories so images from earlier runs do not mix in.
print("Preparing directories")
for img_type in ("hogweed", "cetera", "other"):
    clear_path(f"{IMAGE_DIR}/{img_type}")

# Training images come from rows whose license is not CC0; the CC0 rows are sampled
# separately when the validation CSV is built.
frame = read_csv("./datasets/hogweed.csv")
frame = frame[frame["license"] != "CC0"]["image_url"]
download_and_save("iNaturalist", "hogweed", IMAGES, naturalize_urls(frame.to_list()))

# The "cetera" class is every other plant, so hogweed observations are filtered out.
frame = read_csv("./datasets/cetera.csv")
frame = frame[(frame["license"] != "CC0") & (frame["scientific_name"] != "Heracleum sosnowskyi")]["image_url"]
download_and_save("iNaturalist", "cetera", IMAGES, naturalize_urls(frame.to_list()))

# `other.csv` is the OpenImages validation list; pass a plain list like the calls above.
frame = read_csv("./datasets/other.csv")
download_and_save("OpenImages", "other", IMAGES, frame["OriginalURL"].to_list())

# Remove every `.ipynb_checkpoints` directory below the working directory without a shell.
print("Removing ipynb caches")
for root, dirs, _ in walk("."):
    if ".ipynb_checkpoints" in dirs:
        rmtree(join(root, ".ipynb_checkpoints"), ignore_errors=True)
        # Do not descend into the directory that was just removed.
        dirs.remove(".ipynb_checkpoints")

In [None]:
from pandas import read_csv, DataFrame


# Number of validation URLs collected per class. One shared constant keeps the URL list
# and the label list the same length, because `zip` would silently truncate a mismatch.
VALIDATION_SAMPLES = 50

# Validation images are sampled from the CC0-licensed rows only.
frame = read_csv("./datasets/hogweed.csv")
hogweed_frame = frame[frame["license"] == "CC0"].sample(VALIDATION_SAMPLES)["image_url"]
frame = read_csv("./datasets/cetera.csv")
cetera_frame = frame[(frame["scientific_name"] != "Heracleum sosnowskyi") & (frame["license"] == "CC0")].sample(VALIDATION_SAMPLES)["image_url"]

# Class order: hogweed, cetera, other — `classes` below must follow the same order.
URLs = (
    naturalize_urls(hogweed_frame.to_list())
    + naturalize_urls(cetera_frame.to_list())
    + [f"https://source.unsplash.com/random?sig={num}" for num in range(VALIDATION_SAMPLES)]
)
classes = (['hogweed'] * VALIDATION_SAMPLES) + (['cetera'] * VALIDATION_SAMPLES) + (['other'] * VALIDATION_SAMPLES)
DataFrame(list(zip(classes, URLs)), columns=['class', 'url']).to_csv(path_or_buf="./datasets/validation.csv", index=False)

In [None]:
from tensorflow.python.keras.layers import RandomFlip, RandomRotation
from tensorflow.keras.utils import image_dataset_from_directory
from tensorflow.python.data.experimental import cardinality
from tensorflow.python.data import AUTOTUNE
from tensorflow.keras import Sequential
from tensorflow import expand_dims


# Load every class directory below IMAGE_DIR as one shuffled dataset of one-hot-labelled batches.
dataset = image_dataset_from_directory(
    IMAGE_DIR,
    shuffle=True,
    label_mode='categorical',
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
)
batches_num = cardinality(dataset)

# The first 9/10 of the batches are used for training; the test set is one ninth of that size.
# NOTE(review): the source dataset is shuffled, so the take/skip split may not select the same
# images on every pass — confirm that train and test stay disjoint.
train_power = batches_num * 9 // 10
train_dataset = dataset.take(train_power).prefetch(buffer_size=AUTOTUNE)
test_dataset = dataset.skip(train_power).take(train_power // 9).prefetch(buffer_size=AUTOTUNE)

data_augmentation = Sequential([
    RandomFlip('horizontal'),
    RandomRotation(0.1),
    # RandomContrast(0.01)
])

print(f"Train set has {cardinality(train_dataset)} batches, test set has {cardinality(test_dataset)} batches; train set will be augmented:")
# Preview: for the first image of each of five batches, show five augmented variants in a row.
for i in range(5):
    imgs = []
    for img, _ in train_dataset.skip(i).take(1):
        first_image = expand_dims(img[0], 0)
        for j in range(5):
            # Divide by 255 so that `imshow` receives floats in the 0..1 range.
            imgs.append(data_augmentation(first_image)[0] / 255)
    print_images_line(imgs)

In [None]:
from tensorflow.keras.applications.xception import preprocess_input, Xception
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import Input, Model


# Model input: IMG_SIZE images with three colour channels.
INPUT_SHAPE = IMG_SIZE + (3,)

# Xception without its classification head, frozen so that only the new head is trained.
base_model = Xception(input_shape=INPUT_SHAPE, include_top=False)
base_model.trainable = False

inputs = Input(shape=INPUT_SHAPE, name="re_shaper")
x = data_augmentation(inputs)
x = preprocess_input(x)
# The frozen base model is called in inference mode.
x = base_model(x, training=False)
x = GlobalAveragePooling2D(name="pooling")(x)
x = Dropout(0.2, name="dropout")(x)
# No activation here: the head outputs raw logits, one per class.
outputs = Dense(3, name="predictor")(x)
model = Model(inputs, outputs, name="hogweed_detector")

# TODO: check learning rate!
# TODO: check optimizer!
base_learning_rate = 0.001
# The head emits logits, so the loss must be told to apply softmax itself; with the default
# `from_logits=False` the logits would be treated as probabilities.
model.compile(optimizer=Adam(learning_rate=base_learning_rate), loss=CategoricalCrossentropy(from_logits=True), metrics=['accuracy'])

model.summary()

In [None]:
from tensorflow.python.keras.callbacks import Callback, EarlyStopping
from tensorflow.python.data.experimental import cardinality
from IPython.display import display, update_display
from matplotlib.pyplot import subplots, close
from tqdm.notebook import tqdm


class CustomCallback(Callback):
    """Keras callback that shows live loss/accuracy graphs and tqdm progress bars.

    One figure is redrawn after every training batch. A second figure holds a training
    and a testing graph and is redrawn after every epoch. Progress is reported through
    an epoch bar and a per-epoch batch bar.
    """

    def __init__(self):
        super().__init__()
        # Metric values collected after every training batch.
        self.batch_train_accuracy = []
        self.batch_train_losses = []

        # Metric values collected after every epoch.
        self.epoch_train_accuracy = []
        self.epoch_train_losses = []
        self.epoch_test_accuracy = []
        self.epoch_test_losses = []

        self.batch_figure, self.batch_train_graph = subplots(1, 1, figsize=(15, 5))
        self.batch_figure.suptitle("Accuracy/Loss graph")
        self.batch_train_graph.set_title("Training")
        self.batch_train_graph.set_xlabel("Batches")
        self.batch_train_graph.set_ylabel("Accuracy/Losses")
        self.batch_train_graph.grid(visible=True)
        self.batch_display_id = "batch_graph"
        # Closed right away; the figure is shown explicitly through `display` later.
        close(self.batch_figure)

        self.epoch_figure, (self.epoch_train_graph, self.epoch_test_graph) = subplots(1, 2, figsize=(15, 5))
        self.epoch_figure.suptitle("Accuracy/Loss graph")
        self.epoch_train_graph.set_title("Training")
        self.epoch_test_graph.set_title("Testing")
        for graph in (self.epoch_train_graph, self.epoch_test_graph):
            graph.set_xlabel("Epochs")
            graph.set_ylabel("Accuracy/Losses")
            graph.grid(visible=True)
        self.epoch_display_id = "epoch_graph"
        close(self.epoch_figure)

        # Progress bars are created when training / an epoch starts.
        self.epochs_bar = None
        self.batches_bar = None

    @staticmethod
    def _progress_text(logs):
        """Format the current loss and accuracy for a progress bar description."""
        return f"Loss: {round(logs['loss'], 3)}; Accuracy: {round(logs['accuracy'] * 100, 3)}"

    def redraw_graph(self, figure, graph, display_id, losses_arr, accuracy_arr):
        """Redraw `graph` from the given histories and refresh the displayed `figure`.

        :param figure: figure that owns `graph`.
        :param graph: axes to clear and redraw.
        :param display_id: id of the notebook display to update.
        :param losses_arr: non-empty sequence of loss values.
        :param accuracy_arr: non-empty sequence of accuracy values.
        """
        # `cla` wipes the title and the labels, so remember them and restore them below.
        title = graph.get_title()
        x_label = graph.get_xlabel()
        y_label = graph.get_ylabel()

        graph.cla()
        graph.plot(losses_arr, color='blue', label="losses")
        # Mark the most recent value with a dot.
        graph.scatter(len(losses_arr) - 1, losses_arr[-1], c='blue')
        graph.plot(accuracy_arr, color='red', label="accuracy")
        graph.scatter(len(accuracy_arr) - 1, accuracy_arr[-1], c='red')

        graph.set_ylim(bottom=0)
        graph.set_title(title)
        graph.set_xlabel(x_label)
        graph.set_ylabel(y_label)
        graph.legend()
        graph.grid(visible=True)

        figure.canvas.draw()
        update_display(figure, display_id=display_id)

    def on_train_begin(self, logs=None):
        """Show both figures and create the epoch progress bar."""
        display(self.batch_figure, display_id=self.batch_display_id)
        display(self.epoch_figure, display_id=self.epoch_display_id)

        self.epochs_bar = tqdm(desc="Training", unit="epoch", total=INITIAL_EPOCHS)

    def on_epoch_begin(self, epoch, logs=None):
        """Create a fresh batch progress bar for the epoch that is starting."""
        self.batches_bar = tqdm(desc="Training", unit="batch", total=cardinality(train_dataset).numpy())

    def on_epoch_end(self, epoch, logs=None):
        """Record the epoch metrics, redraw the epoch graphs and advance the epoch bar."""
        self.epoch_train_losses.append(logs['loss'])
        self.epoch_train_accuracy.append(logs['accuracy'])
        self.redraw_graph(self.epoch_figure, self.epoch_train_graph, self.epoch_display_id, self.epoch_train_losses, self.epoch_train_accuracy)

        # Validation metrics are present only when `fit` was given validation data.
        if 'val_loss' in logs and 'val_accuracy' in logs:
            self.epoch_test_losses.append(logs['val_loss'])
            self.epoch_test_accuracy.append(logs['val_accuracy'])
            self.redraw_graph(self.epoch_figure, self.epoch_test_graph, self.epoch_display_id, self.epoch_test_losses, self.epoch_test_accuracy)

        self.batches_bar.set_description(self._progress_text(logs))
        self.batches_bar.close()
        self.epochs_bar.update()

    def on_train_batch_end(self, batch, logs=None):
        """Record the batch metrics, redraw the batch graph and advance the batch bar."""
        self.batch_train_losses.append(logs['loss'])
        self.batch_train_accuracy.append(logs['accuracy'])
        self.redraw_graph(self.batch_figure, self.batch_train_graph, self.batch_display_id, self.batch_train_losses, self.batch_train_accuracy)

        self.batches_bar.set_description(self._progress_text(logs))
        self.batches_bar.update()

    def on_train_end(self, logs=None):
        """Close the epoch progress bar."""
        self.epochs_bar.close()



# Keras' own console output is disabled (verbose=0); CustomCallback reports progress instead.
# EarlyStopping ends training after 3 epochs without improvement of its monitored metric
# and restores the best weights seen.
history = model.fit(
    train_dataset,
    epochs=INITIAL_EPOCHS,
    validation_data=test_dataset,
    verbose=0,
    callbacks=[
        CustomCallback(),
        EarlyStopping(patience=3, restore_best_weights=True),
    ],
)

In [None]:
from tensorflow.lite.python.lite import TFLiteConverter


from os import makedirs


# `fit` returns a History object; the per-epoch metric lists live in its `history` dict.
# NOTE(review): this is the accuracy of the last epoch run; with `restore_best_weights=True`
# the exported weights may come from an earlier epoch — confirm this is the number wanted.
final_accuracy = history.history['accuracy'][-1]

# Convert before opening the target file so a failed conversion leaves no empty file behind.
lite_model = tf.lite.TFLiteConverter.from_keras_model(model).convert()

# The model directory may not exist yet.
makedirs(MODEL_DIR, exist_ok=True)
with open(f"{MODEL_DIR}/detector-{final_accuracy}.tflite", 'wb') as file:
    file.write(lite_model)