In [1]:
import os
import sys
import cv2
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
from typing import Iterable, Tuple, List


# If the following line throws an error, do `pip install -e ./` from the base repo dir
from utils import ImageDataset

# Root folder of the local 256x256 dataset.
# Set the W281_DATA_DIR environment variable to point at your own copy instead
# of editing this cell; the literal below is only the fallback default.
# Example Windows location:
# base_fp = r'D:\Documents\MIDS\W281\2023-mids-w81-final-project-dataset\256x256'
base_fp = os.environ.get('W281_DATA_DIR', r'/Users/taehun.kim/mids/rendered_256x256/256x256')

2023-04-08 13:26:15.400326: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Load full dataset
# Builds the project's ImageDataset from the folder at base_fp. The loader
# reports the number of unique image IDs it found (see the cell output).
full_dataset = ImageDataset.from_directory(base_fp)

Loading dataset...
Loaded 12500 unique image IDs


In [3]:
# Dataset-wide constants reused by the input pipeline and the model cells.
IMAGE_SIZE = full_dataset.image_size
NUM_CLASSES = full_dataset.num_classes

# Sanity check: the first photo of the first dataset entry must load with
# exactly the shape the dataset advertises.
first_photo_shape = full_dataset[0].photos[0].load().shape
assert IMAGE_SIZE == first_photo_shape, 'Photos and sketches are not the same size'

In [42]:
# Partition the dataset by image ID into train / validation / test (80/10/10).
split_fractions = [0.8, 0.1, 0.1]
train, valid, test = full_dataset.split(split_fractions)

# Every partition must report the same number of classes.
assert train.num_classes == valid.num_classes and valid.num_classes == test.num_classes

# NOTE: this ID-level split is not required for training; the TensorFlow
# input pipeline performs its own train/validation/test split.

In [64]:
# Train/validation/test split (0.8 / 0.1 / 0.1) streamed from disk through a
# tf.data pipeline so the images never have to fit in RAM at once.
#
# The un-augmented sketches live under base_fp. Deriving the folder from
# base_fp (instead of repeating the absolute path) keeps this cell in sync
# when base_fp is changed for another machine.
no_aug_directory = os.path.join(base_fp, 'sketch', 'tx_000000000000')
SEED = 1  # Fixes the shuffle so the training/held-out partition is reproducible
validation_split = 0.2  # Held-out fraction; halved below into validation and test

# NOTE(review): this splits per file, not per image ID. The loader reports far
# more sketch files than unique image IDs, so sketches of the same photo may
# end up in different partitions -- confirm whether that is acceptable.
train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
    directory=no_aug_directory,
    image_size=IMAGE_SIZE[:2],  # spatial dimensions only, channels dropped
    label_mode='categorical',  # one-hot labels, matching categorical_crossentropy
    validation_split=validation_split,
    subset='both',  # return the (training, validation) pair from a single call
    seed=SEED,
    color_mode='rgb',
)

# Carve the held-out batches in two: the first half becomes the test set and
# the remainder stays as the validation set.
val_batch_count = tf.data.experimental.cardinality(val_ds)
test_ds = val_ds.take(val_batch_count // 2)
val_ds = val_ds.skip(val_batch_count // 2)

Found 75481 files belonging to 125 classes.
Using 60385 files for training.
Using 15096 files for validation.


In [60]:
# Build model: frozen ImageNet ResNet50 feature extractor followed by a small
# dense classification head.
#
# `classes` is intentionally not passed: Keras only honours it when
# include_top=True, so with include_top=False it has no effect.
pretrained_resnet = tf.keras.applications.ResNet50(
    include_top=False,
    input_shape=IMAGE_SIZE,
    pooling='avg',  # global average pooling -> one feature vector per image
)
pretrained_resnet.trainable = False  # only the dense head below is trained

# The ImageNet weights expect inputs transformed by the ResNet50-specific
# preprocess_input, while the input pipeline yields raw RGB pixel values.
# Doing the conversion inside the model guarantees that training, evaluation
# and inference all see identically preprocessed images.
inputs = tf.keras.Input(shape=IMAGE_SIZE)
preprocessed = tf.keras.applications.resnet50.preprocess_input(inputs)
# training=False keeps the frozen backbone (incl. batch norm) in inference mode.
features = pretrained_resnet(preprocessed, training=False)

layer = tf.keras.layers.Dense(1024, activation='relu')(features)
layer = tf.keras.layers.Dense(512, activation='relu')(layer)
# One softmax unit per class folder discovered by the training dataset.
outputs = tf.keras.layers.Dense(len(train_ds.class_names), activation='softmax')(layer)
model = tf.keras.Model(inputs, outputs)

In [61]:
# Configure training: Adam optimizer, categorical cross-entropy loss for the
# softmax output, and accuracy as the reported metric.
model.compile(loss='categorical_crossentropy', metrics=['accuracy'], optimizer='adam')

In [66]:
# Run tf.function-decorated code eagerly. Per the original author's note, the
# fit call raised a tf.function-related error without this setting.
# NOTE(review): eager execution is intended for debugging and is presumably a
# major contributor to the very long epoch time -- revisit once the
# underlying error is understood.
tf.config.run_functions_eagerly(True)

epochs = 10

# Stop once val_loss has failed to improve for 3 consecutive epochs and roll
# the model back to the best weights seen so far.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

history = model.fit(
    train_ds,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=[early_stopping],
)


Epoch 1/10
   5/1888 [..............................] - ETA: 5:57:34 - loss: 4.8739 - accuracy: 0.0000e+00