In [1]:
import os
from art.params import *

base_dir = LOCAL_DATA_PATH
train_dir = os.path.join(base_dir, "train")
validation_dir = os.path.join(base_dir, "validation")
test_dir = os.path.join(base_dir, "test")

In [None]:
from art.load_data import load_data
from art.clean import clean_data, get_most_common

data = load_data(path="/Users/poloniki/code/melisa/art/label_list/label_list.csv")
cleaned_data = clean_data(data)
df = get_most_common(cleaned_data, "Style", 0.8)

In [3]:
df.Style.unique()

array(['impressionism', 'conceptual_art', 'surrealism',
       'early_renaissance', 'baroque', 'rococo', 'neoclassicism',
       'romanticism', 'ukiyo_e', 'naïve_art_primitivism', 'realism',
       'abstract_art', 'expressionism', 'symbolism', 'academicism',
       'post_impressionism', 'art_nouveau_modern', 'cubism',
       'abstract_expressionism', 'magic_realism',
       'mannerism_late_renaissance', 'northern_renaissance',
       'high_renaissance', 'op_art', 'art_informel', 'minimalism',
       'color_field_painting', 'pop_art'], dtype=object)

In [4]:
import os
from PIL import Image

def is_image_valid(image_path):
    try:
        img = Image.open(image_path)
        img.verify()  # Verify that it is an image
        img.close()   # Close the image to avoid resource leaks

        # Reopen the image to ensure it can be read
        img = Image.open(image_path)
        img.load()    # Ensure the image can be fully loaded
        img.close()
        return True
    except (IOError, SyntaxError, OSError) as e:
        return False

def check_and_remove_invalid_images(base_dir):
    for root, dirs, files in os.walk(base_dir):
        for file in files:
            file_path = os.path.join(root, file)
            if not is_image_valid(file_path):
                print(f"Removing invalid image: {file_path}")
                os.remove(file_path)

# Set the path to the directory containing your image dataset
dataset_dir = LOCAL_DATA_PATH

# Check and remove empty or invalid images
check_and_remove_invalid_images(dataset_dir)


Removing invalid image: /Users/poloniki/.lewagon/art_data/validation.cache
Removing invalid image: /Users/poloniki/.lewagon/art_data/train.cache


KeyboardInterrupt: 

In [10]:
import tensorflow as tf
from tensorflow.keras.preprocessing import image_dataset_from_directory



# Specify the batch size and image dimensions
batch_size = 32
img_height = 224
img_width = 224

# Load the training dataset with augmentation
train_dataset = image_dataset_from_directory(
    train_dir,

    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='categorical',  # For multi-class classification
    labels='inferred'
)

# Load the validation dataset
validation_dataset = image_dataset_from_directory(
    validation_dir,

    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='categorical',  # For multi-class classification
    labels='inferred'
)

# Create data augmentation layers
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.RandomFlip('horizontal'),
    tf.keras.layers.RandomRotation(0.2),
    tf.keras.layers.RandomWidth(0.2),
    tf.keras.layers.RandomHeight(0.2)
])

# Prefetch the datasets for better performance
AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)


Found 77293 files belonging to 28 classes.
Found 13828 files belonging to 28 classes.


In [11]:
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, GlobalAveragePooling2D
from tensorflow.keras.optimizers import Adam

# Load the pre-trained VGG16 model
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_height, img_width, 3))

# Freeze the layers of the base model
for layer in base_model.layers:
    layer.trainable = False

# Create a sequential model and add the VGG16 model
model = Sequential([
    data_augmentation,          # Adding data augmentation directly into the model
    base_model,
    GlobalAveragePooling2D(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(28, activation='softmax')  # Number of classes inferred from dataset
])

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001), loss='categorical_crossentropy', metrics=['accuracy'])


In [13]:
from tensorflow.keras.callbacks import EarlyStopping

early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True)

# Train the model
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=300,
    callbacks=[early_stopping],
)


Epoch 1/300


  83/2416 [>.............................] - ETA: 14:22 - loss: 6.7537 - accuracy: 0.0602

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 9:50 - loss: 4.6755 - accuracy: 0.0829



Epoch 2/300
  84/2416 [>.............................] - ETA: 6:59 - loss: 2.8355 - accuracy: 0.1853

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 5:26 - loss: 2.8136 - accuracy: 0.1880



Epoch 3/300
  84/2416 [>.............................] - ETA: 4:22 - loss: 2.5860 - accuracy: 0.2310

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 3:52 - loss: 2.5892 - accuracy: 0.2363



Epoch 4/300
  85/2416 [>.............................] - ETA: 3:34 - loss: 2.4677 - accuracy: 0.2607

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 3:12 - loss: 2.4679 - accuracy: 0.2607



Epoch 5/300
  87/2416 [>.............................] - ETA: 3:08 - loss: 2.3818 - accuracy: 0.2769

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 2:43 - loss: 2.3846 - accuracy: 0.2817



Epoch 6/300
  85/2416 [>.............................] - ETA: 2:55 - loss: 2.3139 - accuracy: 0.2971

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 2:36 - loss: 2.3265 - accuracy: 0.2929



Epoch 7/300
  87/2416 [>.............................] - ETA: 2:58 - loss: 2.2675 - accuracy: 0.3118

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 2:07 - loss: 2.2877 - accuracy: 0.3035



Epoch 8/300
  85/2416 [>.............................] - ETA: 2:10 - loss: 2.2247 - accuracy: 0.3213

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 1:55 - loss: 2.2401 - accuracy: 0.3172



Epoch 9/300
  86/2416 [>.............................] - ETA: 1:52 - loss: 2.1851 - accuracy: 0.3292

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 1:49 - loss: 2.2012 - accuracy: 0.3210



Epoch 10/300
  86/2416 [>.............................] - ETA: 1:57 - loss: 2.1540 - accuracy: 0.3401

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 438/2416 [====>.........................] - ETA: 1:44 - loss: 2.1666 - accuracy: 0.3325



Epoch 11/300
  87/2416 [>.............................] - ETA: 2:06 - loss: 2.1293 - accuracy: 0.3445

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 1:39 - loss: 2.1535 - accuracy: 0.3360



Epoch 12/300
  83/2416 [>.............................] - ETA: 1:57 - loss: 2.1187 - accuracy: 0.3441

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 440/2416 [====>.........................] - ETA: 1:39 - loss: 2.1313 - accuracy: 0.3397



Epoch 13/300
  87/2416 [>.............................] - ETA: 2:00 - loss: 2.0835 - accuracy: 0.3502

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 1:35 - loss: 2.1009 - accuracy: 0.3508



Epoch 14/300
  86/2416 [>.............................] - ETA: 1:54 - loss: 2.1099 - accuracy: 0.3467

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 439/2416 [====>.........................] - ETA: 1:36 - loss: 2.1018 - accuracy: 0.3482



Epoch 15/300
  86/2416 [>.............................] - ETA: 1:50 - loss: 2.0710 - accuracy: 0.3648

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 442/2416 [====>.........................] - ETA: 1:33 - loss: 2.0887 - accuracy: 0.3585



Epoch 16/300
  87/2416 [>.............................] - ETA: 1:44 - loss: 2.0412 - accuracy: 0.3534

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 441/2416 [====>.........................] - ETA: 1:28 - loss: 2.0685 - accuracy: 0.3566



Epoch 17/300
  86/2416 [>.............................] - ETA: 1:45 - loss: 2.0537 - accuracy: 0.3652

Corrupt JPEG data: 4015 extraneous bytes before marker 0xe2


 442/2416 [====>.........................] - ETA: 1:32 - loss: 2.0332 - accuracy: 0.3681



