<a href="https://colab.research.google.com/github/sharon-kurant/archive_explorer/blob/main/training_and_inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive, remounting even if it is already mounted.
from google.colab import drive

drive.mount(mountpoint='/content/drive', force_remount=True)

Mounted at /content/drive


## Load imports and requirements

In [2]:
import os
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Dense, Flatten, Dropout
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import pathlib

# tf.data autotuning sentinel (lets tf.data choose buffer sizes / parallelism at runtime).
# NOTE(review): not referenced by any of the visible cells in this notebook.
AUTO = tf.data.AUTOTUNE

## Path constants

In [None]:
# Tagged training data: a parent folder holding one subfolder of images per class.
TRAIN_PATH = "/path/to/train/dataset/"

# Untagged images to classify: a folder containing the images.
INFERENCE_PATH = "/path/to/inference/dataset"

# Location the trained model is saved to and loaded back from.
SAVE_MODEL_PATH = "/path/to/save/model"

## Hyperparameter constants

Tune these values to match your dataset and hardware.

In [4]:
# Training schedule.
EPOCHS = 3
BATCH_SIZE = 4
LEARNING_RATE = 0.001

# Target (height, width) every image is resized to when loaded.
IMAGE_SIZE = (512, 512)

# Fraction of the tagged images held out as the validation subset.
TRAIN_VAL_SPLIT = 0.2

# Loss for integer class labels, plus the metrics reported during training.
LOSS_FUNCTION = 'sparse_categorical_crossentropy'
METRICS = ['accuracy']

## Train process

#### Load training and validation dataset

In [5]:
# Training subset of the tagged images; class labels are inferred from the
# subfolder names. seed, shuffle and validation_split must match the
# validation subset's settings, otherwise the two subsets can overlap.
train_ds = tf.keras.utils.image_dataset_from_directory(
    directory=TRAIN_PATH,
    labels="inferred",
    subset="training",
    validation_split=TRAIN_VAL_SPLIT,
    seed=123,
    shuffle=True,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    crop_to_aspect_ratio=True,
)

Found 19 files belonging to 3 classes.
Using 16 files for training.


In [6]:
# Validation subset of the same tagged folder. seed, shuffle and
# validation_split must match the training subset's settings, otherwise
# the two subsets can overlap.
val_ds = tf.keras.utils.image_dataset_from_directory(
    directory=TRAIN_PATH,
    labels="inferred",
    subset="validation",
    validation_split=TRAIN_VAL_SPLIT,
    seed=123,
    shuffle=True,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    crop_to_aspect_ratio=True,
)

Found 19 files belonging to 3 classes.
Using 3 files for validation.


#### Load pretrained ResNet50 model and add a custom classification head

In [7]:
# Number of target classes, inferred from the training subfolder names.
num_classes = len(train_ds.class_names)

# ResNet50 backbone without its ImageNet classifier. pooling='avg' makes it emit
# one pooled feature vector per image, so no Flatten layer is needed.
# (`classes` is only honored when include_top=True, so it is not passed here.)
pretrained_model = tf.keras.applications.ResNet50(
    include_top=False,
    weights='imagenet',
    input_shape=(*IMAGE_SIZE, 3),
    pooling='avg')

# Freeze the ResNet50 weights so only the new classification head is trained.
pretrained_model.trainable = False

# The datasets yield raw RGB pixel values, but the ImageNet weights expect
# inputs transformed by resnet50.preprocess_input. Doing it inside the model
# keeps training, inference and the saved model consistent.
inputs = keras.Input(shape=(*IMAGE_SIZE, 3))
x = tf.keras.applications.resnet50.preprocess_input(inputs)
# training=False keeps the backbone's BatchNorm layers in inference mode.
x = pretrained_model(x, training=False)

# New fully connected classification head with dropout.
x = Dense(512, activation='relu')(x)
x = Dropout(0.4)(x)
x = Dense(256, activation='relu')(x)
x = Dropout(0.3)(x)
x = Dense(128, activation='relu')(x)
outputs = Dense(num_classes, activation='softmax')(x)

model = keras.Model(inputs, outputs)

#### Train the model

In [8]:
# Configure the optimizer, loss and metrics, then train on the training
# subset while reporting metrics on the validation subset after each epoch.
model.compile(
    optimizer=Adam(learning_rate=LEARNING_RATE),
    loss=LOSS_FUNCTION,
    metrics=METRICS,
)
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS,
)

Epoch 1/3
Epoch 2/3
Epoch 3/3


## Inference

#### Load inference dataset

In [9]:
# Untagged images to classify. labels=None yields image batches only, and
# shuffle=False keeps batches in the same order as inference_ds.file_paths.
inference_ds = tf.keras.utils.image_dataset_from_directory(
    directory=INFERENCE_PATH,
    labels=None,
    shuffle=False,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    crop_to_aspect_ratio=True,
)

Found 3 files belonging to 1 classes.


#### Apply trained model on the inference set

In [10]:
# Per-class probabilities: one row per image, in inference_ds order.
confidence = model.predict(inference_ds)
# Index of the most probable class for each image...
class_idx = confidence.argmax(axis=1)
# ...and the matching class (subfolder) name from the training set.
class_name = [train_ds.class_names[i] for i in class_idx]



#### Save model predictions to predictions.csv

In [11]:
# One row per inference image: source file, predicted index and class name.
predictions_df = pd.DataFrame({
    'path': inference_ds.file_paths,
    'class_idx': class_idx,
    'class_name': class_name,
})
# Written relative to the notebook's current working directory.
predictions_df.to_csv(path_or_buf="predictions.csv", index=False)

## Save model (optional)

In [12]:
# Persist the trained model to SAVE_MODEL_PATH.
# NOTE(review): a path without a file extension relies on the TF SavedModel
# format; Keras 3 requires a ".keras" (or ".h5") suffix — confirm the installed version.
model.save(filepath=SAVE_MODEL_PATH)

## Load saved model and predict (optional)

In [13]:
# Reload the model previously written to SAVE_MODEL_PATH.
loaded_model = keras.saving.load_model(filepath=SAVE_MODEL_PATH)

In [14]:
# Raw per-class probabilities from the reloaded model (one row per inference image).
loaded_model.predict(x=inference_ds)



array([[0.18967162, 0.16748136, 0.642847  ],
       [0.45626593, 0.01939861, 0.5243355 ],
       [0.1826696 , 0.3698817 , 0.44744876]], dtype=float32)