[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/enter-opy/genre-classification/blob/main/notebooks/spectrograms.ipynb)

In [25]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from sklearn.model_selection import train_test_split

In [26]:
# Spatial size (height, width) every spectrogram is resized to before it
# reaches the VGG16 backbone, and the number of images per training batch.
IMG_HEIGHT = IMG_WIDTH = 224
IMG_SIZE = (IMG_HEIGHT, IMG_WIDTH)
BATCH_SIZE = 32

In [27]:
# Index every spectrogram under the data directory (one sub-folder per genre).
# Images are resized to IMG_SIZE, scaled to [0, 1] and labelled one-hot.
#
# NOTE: `validation_split` on ImageDataGenerator only takes effect when
# `flow_from_directory` is also given `subset="training"/"validation"`. It was
# never passed here, so the argument was dead and every image was listed
# anyway; it has been dropped to avoid implying a split that never happened.
datagen = ImageDataGenerator(rescale=1.0 / 255)

generator = datagen.flow_from_directory(
    "../Data/images_original",
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode="categorical",
)

# Batched tf.data view of the iterator. Keras directory iterators cycle over
# the files endlessly and keep their own position, so this dataset is infinite
# and every iteration resumes the same underlying iterator: bound it with
# take()/steps, and do not expect take()/skip() on it to give disjoint
# image-level splits.
dataset = tf.data.Dataset.from_generator(
    lambda: generator,
    output_signature=(
        # Shapes follow IMG_SIZE and the discovered class count instead of
        # repeating 224 / 10 by hand.
        tf.TensorSpec(shape=(None, *IMG_SIZE, 3), dtype=tf.float32),
        tf.TensorSpec(shape=(None, generator.num_classes), dtype=tf.float32),
    ),
)

Found 999 images belonging to 10 classes.


In [28]:
# Split the spectrograms 60 / 20 / 20 into train / validation / test at the
# *file* level, stratified by genre, so that no image lands in more than one
# split.
#
# Why not take()/skip() on `dataset`: that dataset yields *batches* of
# BATCH_SIZE images from an endlessly cycling DirectoryIterator. take(599)
# therefore took 599 batches (~19k images, i.e. the same 999 files many times
# over), validation and test were drawn from those same files, the shuffle()
# placed before take()/skip() re-drew the splits every epoch, and skip() left
# the test set infinite.
SPLIT_SEED = 42  # fixed so the split is reproducible across runs

all_paths = np.asarray(generator.filepaths)
all_labels = np.asarray(generator.classes)
num_classes = generator.num_classes

train_paths, holdout_paths, train_labels, holdout_labels = train_test_split(
    all_paths,
    all_labels,
    train_size=0.6,
    stratify=all_labels,
    random_state=SPLIT_SEED,
)
val_paths, test_paths, val_labels, test_labels = train_test_split(
    holdout_paths,
    holdout_labels,
    test_size=0.5,
    stratify=holdout_labels,
    random_state=SPLIT_SEED,
)

# Sizes are now counts of images, derived from the data rather than from a
# hard-coded 999.
train_size = len(train_paths)
validation_size = len(val_paths)
test_size = len(test_paths)

assert train_size + validation_size + test_size == len(all_paths)
assert not (set(train_paths) & set(val_paths) or set(train_paths) & set(test_paths)
            or set(val_paths) & set(test_paths)), "splits overlap"


def _load_example(path, label):
    """Decode one image into a float32 (*IMG_SIZE, 3) tensor in [0, 1] plus a one-hot label."""
    raw = tf.io.read_file(path)
    image = tf.io.decode_image(raw, channels=3, expand_animations=False)
    # "nearest" mirrors flow_from_directory's default interpolation.
    image = tf.image.resize(image, IMG_SIZE, method="nearest")
    # Same scaling as ImageDataGenerator(rescale=1.0 / 255).
    image = tf.cast(image, tf.float32) / 255.0
    return image, tf.one_hot(label, num_classes)


def _make_split(paths, labels, *, training):
    """Build a batched, prefetched tf.data pipeline over one fixed split."""
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if training:
        # Reshuffles the *order* within the training split each epoch; split
        # membership is fixed above and never changes.
        ds = ds.shuffle(len(paths), seed=SPLIT_SEED, reshuffle_each_iteration=True)
    ds = ds.map(_load_example, num_parallel_calls=tf.data.AUTOTUNE)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)


train_dataset = _make_split(train_paths, train_labels, training=True)
validation_dataset = _make_split(val_paths, val_labels, training=False)
test_dataset = _make_split(test_paths, test_labels, training=False)

In [29]:
# ImageNet-pretrained VGG16 convolutional backbone without its dense top,
# frozen so that only the new classification head is trained. The input shape
# follows IMG_SIZE instead of repeating 224 here.
#
# NOTE(review): the Keras docs say VGG16 expects inputs passed through
# `tf.keras.applications.vgg16.preprocess_input`, but this notebook feeds
# images rescaled to [0, 1]. Consider matching the backbone's preprocessing.
base_model = VGG16(weights="imagenet", include_top=False, input_shape=(*IMG_SIZE, 3))
base_model.trainable = False

In [30]:
# Classification head on top of the frozen backbone features:
# flatten -> ReLU dense blocks, each followed by dropout -> softmax over genres.
HEAD_UNITS = (512, 256, 128)
DROPOUT_RATE = 0.2
# Derived from the class sub-folders found on disk instead of a hard-coded 10,
# so the softmax width cannot silently disagree with the one-hot labels.
NUM_CLASSES = generator.num_classes

x = Flatten()(base_model.output)
for units in HEAD_UNITS:
    x = Dense(units, activation="relu")(x)
    x = Dropout(DROPOUT_RATE)(x)
x = Dense(NUM_CLASSES, activation="softmax")(x)

In [31]:
# Join the backbone's input to the softmax head built above into one model.
model = Model(
    outputs=x,
    inputs=base_model.input,
)

In [32]:
# Categorical cross-entropy matches the one-hot labels produced by
# class_mode="categorical"; accuracy is tracked alongside the loss.
model.compile(
    loss="categorical_crossentropy",
    optimizer="adam",
    metrics=["accuracy"],
)

In [33]:
# Print a layer-by-layer overview of the assembled backbone + head model.
model.summary()

In [None]:
EPOCHS = 50

# Keep the History object so the per-epoch loss/accuracy curves can be
# inspected or plotted later, instead of being reachable only via Out[N].
history = model.fit(train_dataset, validation_data=validation_dataset, epochs=EPOCHS)

Epoch 1/50


In [14]:
# Sanity check: the GPUs TensorFlow can see (an empty list means none).
gpus = tf.config.list_physical_devices(device_type="GPU")
gpus

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]