# Build a Classification Model for Figure Types

In [1]:
!pip install tensorflow



In [11]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Flatten, Dense, Dropout
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from PIL import Image

# Notebook-wide settings
target_size = (224, 224)  # size passed to Image.resize and used for the VGG16 input shape
batch_size = 32  # mini-batch size handed to model.fit
num_classes = 9  # Change this to the number of classes in your dataset

# Augmenting image generator: pixel rescaling plus random geometric perturbations.
# NOTE(review): none of the visible cells consume train_datagen (model.fit is given
# the raw arrays). load_and_preprocess_image already divides by 255, so feeding its
# arrays through this generator would rescale twice -- confirm before using it.
train_datagen = ImageDataGenerator(
    # intensity scaling
    rescale=1.0 / 255,
    # random geometric transforms
    rotation_range=20,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    # how pixels uncovered by a transform are filled
    fill_mode='nearest',
)

def load_and_preprocess_image(image_path):
    """Load one image file as a normalized RGB array.

    The image is converted to RGB when it is stored in another mode, resized
    to the module-level ``target_size`` and divided by 255.

    Args:
        image_path: Path of the image file to read.

    Returns:
        A floating-point numpy array holding the resized RGB pixels divided
        by 255.
    """
    # Image.open is lazy and keeps the underlying file open; the context
    # manager releases the handle as soon as the pixels have been processed.
    with Image.open(image_path) as source:
        rgb = source if source.mode == "RGB" else source.convert("RGB")
        # resize() returns a new, fully loaded image, so it stays valid
        # after the source file is closed.
        resized = rgb.resize(target_size)
    return np.array(resized) / 255.0


def load_train_data(dataset_directory):
    """Load every image below ``dataset_directory``, labelled by folder name.

    Each immediate sub-directory is treated as one category; every file inside
    it is loaded with ``load_and_preprocess_image`` and labelled with the
    sub-directory's name. Entries of ``dataset_directory`` that are not
    directories, and entries of a category folder that are not regular files,
    are skipped.

    Args:
        dataset_directory: Path of the folder that contains one sub-directory
            per category.

    Returns:
        A tuple ``(data, labels)`` of numpy arrays: ``data`` stacks the
        per-image arrays and ``labels`` holds the matching category names,
        in the same order.
    """
    data = []
    labels = []

    # os.listdir returns entries in arbitrary order; sorting keeps the sample
    # order (and therefore any seeded shuffle/split applied later) reproducible.
    for category_name in sorted(os.listdir(dataset_directory)):
        category_path = os.path.join(dataset_directory, category_name)
        if not os.path.isdir(category_path):
            continue
        for image_name in sorted(os.listdir(category_path)):
            image_path = os.path.join(category_path, image_name)
            if not os.path.isfile(image_path):
                # A nested folder cannot be opened as an image; skip it
                # instead of failing the whole load.
                continue
            data.append(load_and_preprocess_image(image_path))
            labels.append(category_name)

    return np.array(data), np.array(labels)


In [None]:
# Load every training image together with its category (folder) name
train_data, train_labels = load_train_data('C:\\Users\\USER\\Documents\\figures\\train_set')

# Convert labels to one-hot encoding
from sklearn.preprocessing import LabelBinarizer
label_binarizer = LabelBinarizer()
train_labels = label_binarizer.fit_transform(train_labels)

# Fail early, with a readable message, if the encoded label width does not
# match the size of the softmax layer built below.
if train_labels.shape[1] != num_classes:
    raise ValueError(
        f"Encoded labels have {train_labels.shape[1]} column(s) but num_classes is "
        f"{num_classes}; classes found: {list(label_binarizer.classes_)}"
    )

# Load pre-trained VGG16 model (convolutional base only, ImageNet weights).
# target_size is the (width, height) pair given to PIL's resize, whereas Keras
# expects (height, width, channels), hence the swapped indices.
base_model = VGG16(include_top=False, weights='imagenet', input_shape=(target_size[1], target_size[0], 3))

# Freeze the layers in the base model so only the new head is trained
for layer in base_model.layers:
    layer.trainable = False

# Add custom classification layers on top
model = Sequential([
    base_model,
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax')
])

# One-hot targets + softmax output -> categorical cross-entropy
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the samples for validation; the fixed seed keeps the shuffle repeatable.
# NOTE(review): the results are written back onto train_data/train_labels, so
# re-running this cell splits the already reduced training set again.
train_data, val_data, train_labels, val_labels = train_test_split(
    train_data,
    train_labels,
    random_state=42,
    test_size=0.2,
)

In [4]:
# Train for 10 epochs (the VGG16 layers are frozen, so only the added head
# learns), evaluating on the held-out split after every epoch.
model.fit(
    train_data,
    train_labels,
    validation_data=(val_data, val_labels),
    batch_size=batch_size,
    epochs=10,
)

# Persist the trained network to an .h5 file in the working directory
model.save('image_classification_model_with_vgg16.h5')

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


  saving_api.save_model(


In [5]:
# NOTE(review): same call and file name as the save at the end of the training
# cell, so this just rewrites that file -- presumably redundant; confirm and remove.
model.save('image_classification_model_with_vgg16.h5')

In [12]:
# Reload the full training set from disk. This rebinds train_data/train_labels:
# if the split/encoding cells were run before, their arrays are replaced by all
# images and the raw category-name strings returned by load_train_data.
train_data, train_labels = load_train_data('C:\\Users\\USER\\Documents\\figures\\train_set')


In [None]:
# Class index -> figure-type name, numbered from 0 in the order listed.
# NOTE(review): presumably this order mirrors label_binarizer.classes_ -- confirm.
type_num_to_type_name = dict(enumerate((
    "bar_graphs",
    "code",
    "diagrams",
    "heatmap",
    "images",
    "line_graphs",
    "others",
    "pie_graphs",
    "scatter_plot",
)))

In [13]:
# Display the label array (category-name strings when it comes straight from load_train_data).
# NOTE(review): this echoes every label; a summary such as
# np.unique(train_labels, return_counts=True) would keep the output short.
train_labels

array(['bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_graphs',
       'bar_graphs', 'bar_graphs', 'bar_graphs', 'bar_g