# Model Training Notebook

This notebook trains the CNN model for fashion-item classification on the dataset stored in Google Drive (the dataset location is read from `config.data_dir`).

In [1]:
# Import necessary libraries
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from src.data.dataset import FashionDataset
from src.models.cnn_model import CNNModel
from src.training.train import train_model
from src.utils.config import Config

# Load configuration settings (data directory, input image size, number of
# classes, and model output directory are all read from this object below).
config = Config()

# Load the training and validation splits from the configured data directory.
train_dataset = FashionDataset(data_dir=config.data_dir, split='train')
val_dataset = FashionDataset(data_dir=config.data_dir, split='val')

# Initialize the model; input shape is (height, width, channels) from the config.
model = CNNModel(input_shape=(config.image_height, config.image_width, config.channels),
                 num_classes=config.num_classes)

# Train the model.
# NOTE(review): train_model presumably returns a Keras `History` object, since
# its `.history` dict is read below -- confirm against src/training/train.py.
history = train_model(model, train_dataset, val_dataset, config)

# Save the trained model. Create the output directory first so saving does not
# rely on it already existing (e.g. a freshly mounted Google Drive folder).
os.makedirs(config.model_dir, exist_ok=True)
model_path = os.path.join(config.model_dir, 'fashion_cnn_model.h5')
# NOTE(review): the '.h5' suffix selects Keras' legacy HDF5 format; newer Keras
# versions recommend the native '.keras' format. Kept unchanged because other
# code may load this exact path.
model.save(model_path)

# Display training history: per-epoch training vs. validation accuracy.
# NOTE(review): assumes the model was compiled with metrics=['accuracy'] so that
# 'accuracy' / 'val_accuracy' keys exist in history.history -- confirm.
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['val_accuracy'], label='val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.title('Model Accuracy')
plt.show()