In [None]:
import ast
import tempfile
import pandas as pd
import tensorflow as tf
import tensorflow_hub as hub
import matplotlib.pyplot as plt
import matplotlib.style as style
import seaborn as sns
from PIL import Image, ImageFile
from datetime import datetime
from tensorflow import keras
from tensorflow.keras.models import Sequential
from hentai import Utils, Hentai, Option
from pathlib import Path

## Package flags

In [None]:
# Tell PIL to load image files whose data is truncated instead of raising an error.
# NOTE(review): presumably needed because some downloaded pages are incomplete — confirm.
ImageFile.LOAD_TRUNCATED_IMAGES = True


## Data Collection

We will be using the doujin dataset obtained from nhentai.

In [None]:
N_SAMPLE = 0 # Number of random galleries to fetch and add to the metadata file (0 fetches nothing).

## Download dataset

In [None]:
# Fetch N_SAMPLE random galleries and turn each one into a metadata record.
# Every gallery is converted to a dict *before* the frame is built: calling
# DataFrame.apply on a frame of raw objects hands the lambda a whole column
# (a Series), which has no `dictionary` method.
sample_records = [
    Utils.get_random_hentai().dictionary(Option.all())
    for _ in range(N_SAMPLE)
]
samples_df = pd.DataFrame(sample_records)

data_path = Path("data")
metadata_path = data_path / "metadata.csv"
# to_csv does not create missing parent directories.
data_path.mkdir(parents=True, exist_ok=True)

# Skip the write entirely when nothing was fetched, so an empty frame never
# creates a header-less metadata file.
if not samples_df.empty:
    # Write the header only when the file is created; later runs append rows.
    is_new_file = not metadata_path.is_file()
    samples_df.to_csv(
        metadata_path,
        index=False,
        mode="w" if is_new_file else "a",
        header=is_new_file,
    )
print("Number of resampled samples: ", len(samples_df))


## Read dataset file

In [None]:
# Columns whose CSV cells are parsed back into Python objects with ast.literal_eval.
literal_columns = ["tag", "group", "parody", "character", "artist", "category", "image_urls"]
converters = dict.fromkeys(literal_columns, ast.literal_eval)
hentais_df = pd.read_csv(metadata_path, converters=converters)
hentais_df

## Download images

In [None]:
# Download every gallery that does not yet have a folder under data_path.
for _, row in hentais_df.iterrows():
    gallery_dir = data_path / str(row.id)
    if gallery_dir.is_dir():
        continue  # folder already exists, nothing to download
    Hentai(row.id).download(gallery_dir, progressbar=True)

## Data preparation

In [None]:
# Count how often each tag occurs across all galleries and keep the 50 most frequent.
tag_counts = hentais_df["tag"].explode().value_counts()
label_freq = tag_counts.sort_values(ascending=False).head(50)

style.use("fivethirtyeight")
fig, ax = plt.subplots(figsize=(12, 20))
# Horizontal bars: tag names on the y axis, counts on the x axis.
sns.barplot(y=label_freq.index.values, x=label_freq, order=label_freq.index, ax=ax)
ax.set_title("Label frequency", fontsize=14)
ax.set_xlabel("")
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt.show()

## Data sparsity

In [None]:
nobjs = 2 # Number of histogram panels to draw
ncols = 2 # Panels per row
nrows = nobjs // ncols # Rows needed for the panels
plt.figure(figsize=(14, 4 * nrows))
# (column, panel title) pairs, drawn left to right.
panels = [("num_favorites", "Favorites"), ("num_pages", "Pages")]
for position, (column, title) in enumerate(panels, start=1):
    hentais_df[column].plot.hist(ax=plt.subplot(nrows, ncols, position), bins=100, title=title)
plt.show()

We need to build the path of every image file, relative to the data directory, so that the training and validation images can be located.

In [None]:
def _relative_image_path(row):
    """Return '<gallery id>/<image file name>' for one per-image metadata row."""
    return str(Path(str(row["id"])) / Path(row["image_urls"]).name)


# One row per image URL; the gallery's tag list is repeated on each of its rows.
pages_df = hentais_df.explode("image_urls").reset_index()
filenames_df = pages_df.apply(_relative_image_path, axis=1).rename("filename")
labels_df = pages_df["tag"].rename("labels")
# From here on hentais_df holds only the two columns the image generator needs.
hentais_df = pd.concat([filenames_df, labels_df], axis=1)
hentais_df

## Image examples

In [None]:

nobjs = 8 # Maximum number of images to display
ncols = 4 # Number of columns in display
# Prefix every relative filename with the data directory.
image_paths = hentais_df["filename"].apply(lambda x: str(data_path / x))
# Never ask for more images than exist: Series.sample raises ValueError when the
# requested size exceeds the population (sampling without replacement).
samples = image_paths.sample(min(nobjs, len(image_paths)))
nrows = max(1, -(-len(samples) // ncols)) # Rows needed for the sampled images (ceiling division)
plt.figure(figsize=(14, 4 * nrows))
for i, img in enumerate(samples):
    ax = plt.subplot(nrows, ncols, i+1)
    # The context manager closes the file once convert() has produced its copy.
    with Image.open(img) as image_file:
        ax.imshow(image_file.convert("RGB"))

## TensorFlow dataset

In [None]:
# Multiply pixel values by 1/255 and reserve 20% of the rows for the "validation" subset.
hentais_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255, validation_split=0.2)

In [None]:
BATCH_SIZE = 25 # Images per batch from both data generators; chosen to be big enough to measure an F1-score
IMG_SIZE = 224 # Height and width (in pixels) images are resized to, matching the model's input shape

In [None]:
# Arguments shared by the training and validation iterators.
flow_options = dict(
    dataframe=hentais_df,
    directory="data",
    x_col="filename",
    y_col="labels",
    class_mode="categorical",
    batch_size=BATCH_SIZE,
    target_size=(IMG_SIZE, IMG_SIZE),
)
# Only the training iterator sets shuffle and seed explicitly.
train_ds = hentais_gen.flow_from_dataframe(shuffle=True, seed=44, subset="training", **flow_options)
val_ds = hentais_gen.flow_from_dataframe(subset="validation", **flow_options)

In [None]:
# Number of distinct classes registered by the training iterator; used as the
# width of the model's output layer.
nlabels = len(train_ds.class_indices)
print("Number of hentais labels: ", nlabels)

## Transfer learning feature extractor

In [None]:
# Non-trainable MobileNetV3 feature-vector module loaded from TF Hub.
feature_extractor = hub.KerasLayer(
    "https://tfhub.dev/google/imagenet/mobilenet_v3_large_075_224/feature_vector/5",
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    trainable=False
)
# Trainable head: one hidden layer, then one sigmoid unit per label.
classifier_head = [
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dense(nlabels, activation="sigmoid"),
]
model = Sequential([feature_extractor, *classifier_head])
model.summary()

## Train the model
Specify the learning rate and the number of training epochs (number of loops over the whole dataset).

In [None]:
LR = 1e-5 # Learning rate passed to the Adam optimizer; keep it small when transfer learning
EPOCHS = 30 # Upper bound on training epochs (the EarlyStopping callback may end training sooner)

In [None]:
# Binary cross-entropy pairs with the sigmoid output layer: each label is scored independently.
optimizer = keras.optimizers.Adam(learning_rate=LR)
model.compile(loss="binary_crossentropy", optimizer=optimizer, metrics=["accuracy"])

In [None]:
# Every artefact of a training run is written below the "job" directory.
output_dir = Path("job")
checkpoint_path = output_dir.joinpath("checkpoints")
tensorboard_path = output_dir.joinpath("tensorboard")
savedmodel_dir = output_dir.joinpath("export/savedmodel")
# Each export goes into its own folder named after the current local time (YYYYmmddHHMMSS).
timestamp = f"{datetime.now():%Y%m%d%H%M%S}"
model_export_path = savedmodel_dir.joinpath(timestamp)

In [None]:
# ReduceLROnPlateau defaults to patience=10, but EarlyStopping below halts training
# after 2 epochs without improvement, so with the default the learning rate could
# never be reduced. A patience of 1 lets the reduction fire before training stops.
callbacks = [
    keras.callbacks.ReduceLROnPlateau(patience=1),
    # Restore the best weights seen so far, so the model that is exported afterwards
    # is not the one from the last, non-improving epochs.
    keras.callbacks.EarlyStopping(patience=2, restore_best_weights=True),
    keras.callbacks.TensorBoard(str(tensorboard_path)),
    keras.callbacks.ModelCheckpoint(str(checkpoint_path)),
]

# `history` keeps the per-epoch metrics of this run.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds, callbacks=callbacks)

In [None]:
# Plot training vs. validation curves, one panel per metric.
# Axes objects expose set_title / set_xlabel / set_ylabel; `ax.title` is a Text
# object (not callable) and `ax.xlabel` / `ax.ylabel` do not exist, so the
# pyplot-style names cannot be used on an Axes.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))
for ax, metric in zip(axes, ("accuracy", "loss")):
    ax.plot(history.history[metric])
    ax.plot(history.history[f"val_{metric}"])
    ax.set_title(f"model {metric}")
    ax.set_ylabel(metric)
    ax.set_xlabel("epoch")
    ax.legend(["train", "validation"])
plt.show()

In [None]:
# Export the trained model in SavedModel format to the timestamped export folder.
tf.saved_model.save(model, str(model_export_path))