# Image Classification on the CIFAR-10 Dataset

## 1. Import necessary libraries

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, Input
from tensorflow.keras.utils import to_categorical
from sklearn.metrics import classification_report, accuracy_score

## 2. Load and preprocess the dataset

In [None]:
# Fetch the CIFAR-10 train/test splits.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Rescale pixel values into the [0, 1] range as float32.
x_train, x_test = (images.astype('float32') / 255.0 for images in (x_train, x_test))

# One-hot encode the integer labels (the model is trained with categorical_crossentropy).
num_classes = 10
y_train, y_test = (to_categorical(labels, num_classes) for labels in (y_train, y_test))

# Human-readable names for the ten CIFAR-10 classes, indexed by label id.
class_names = [
    'Airplane', 'Automobile', 'Bird', 'Cat', 'Deer',
    'Dog', 'Frog', 'Horse', 'Ship', 'Truck',
]

## 3. Model Training

In [None]:
# Define the CNN model
def create_cnn_model(input_shape=(32, 32, 3), n_classes=None):
    """Build and compile a small three-stage CNN classifier.

    Architecture: three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with
    32 / 64 / 128 filters, then Flatten -> Dense(128, ReLU) -> Dropout(0.5)
    -> Dense(n_classes, softmax).

    Args:
        input_shape: Shape of a single input image as (height, width, channels).
            Defaults to (32, 32, 3), the CIFAR-10 image shape.
        n_classes: Number of output classes. When None, falls back to the
            notebook-level ``num_classes`` set in the preprocessing cell.

    Returns:
        A ``Sequential`` model compiled with the Adam optimizer and
        categorical cross-entropy loss (so labels must be one-hot encoded),
        reporting accuracy as a metric.
    """
    if n_classes is None:
        n_classes = num_classes  # notebook-level constant from the preprocessing cell

    model = Sequential()
    # Declare the input with an explicit Input layer. Passing `input_shape=`
    # to the first Conv2D makes Keras emit a UserWarning (the
    # `super().__init__(...)` line in the training output).
    model.add(Input(shape=input_shape))

    model.add(Conv2D(32, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(64, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Conv2D(128, (3, 3), activation='relu'))
    model.add(MaxPooling2D(pool_size=(2, 2)))

    model.add(Flatten())
    model.add(Dense(128, activation='relu'))
    model.add(Dropout(0.5))
    model.add(Dense(n_classes, activation='softmax'))  # Softmax for multi-class classification

    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    return model

# Create and summarize the model
cnn_model = create_cnn_model()
cnn_model.summary()

# Training hyperparameters, named so they are easy to find and tune.
EPOCHS = 10
BATCH_SIZE = 64
VALIDATION_SPLIT = 0.2  # fraction of the training data held out for validation

# Train the model. Keep the returned History so per-epoch loss/accuracy can be
# inspected or plotted later; assigning it also stops the History object's
# repr from being displayed as the cell's output.
history = cnn_model.fit(
    x_train,
    y_train,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
)

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 6ms/step - accuracy: 0.2584 - loss: 1.9719 - val_accuracy: 0.4735 - val_loss: 1.4437
Epoch 2/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.4854 - loss: 1.4209 - val_accuracy: 0.5339 - val_loss: 1.2885
Epoch 3/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 5ms/step - accuracy: 0.5611 - loss: 1.2286 - val_accuracy: 0.6054 - val_loss: 1.1067
Epoch 4/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 5ms/step - accuracy: 0.6032 - loss: 1.1209 - val_accuracy: 0.6415 - val_loss: 1.0145
Epoch 5/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 4ms/step - accuracy: 0.6382 - loss: 1.0255 - val_accuracy: 0.6686 - val_loss: 0.9510
Epoch 6/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 4ms/step - accuracy: 0.6609 - loss: 0.9529 - val_accuracy: 0.6604 - val_loss: 0.9671
Epoch 7/10
[1m625/625[0m 

<keras.src.callbacks.history.History at 0x7d0149112980>

## 4. Model Evaluation

In [None]:
# Evaluate the model on test data
test_loss, test_accuracy = cnn_model.evaluate(x_test, y_test, verbose=0)
print(f'Test accuracy: {test_accuracy:.4f}')

# Predictions on the test set (verbose=0, matching evaluate(), to avoid a progress bar)
y_pred_prob = cnn_model.predict(x_test, verbose=0)
y_pred = np.argmax(y_pred_prob, axis=1)  # predicted class index per image
y_true = np.argmax(y_test, axis=1)  # undo the one-hot encoding to get true class indices

# Per-class precision / recall / F1, labelled with the human-readable class names.
# Passing `labels` explicitly keeps every class in the report (and keeps
# target_names aligned) even if some class never appears in y_true or y_pred.
print("Classification Report:")
print(classification_report(
    y_true,
    y_pred,
    labels=np.arange(len(class_names)),
    target_names=class_names,
))


Test accuracy: 0.7069
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 4ms/step
Classification Report:
              precision    recall  f1-score   support

           0       0.70      0.81      0.75      1000
           1       0.85      0.83      0.84      1000
           2       0.67      0.55      0.60      1000
           3       0.49      0.54      0.52      1000
           4       0.62      0.70      0.66      1000
           5       0.61      0.63      0.62      1000
           6       0.72      0.83      0.78      1000
           7       0.75      0.73      0.74      1000
           8       0.87      0.73      0.80      1000
           9       0.88      0.71      0.78      1000

    accuracy                           0.71     10000
   macro avg       0.72      0.71      0.71     10000
weighted avg       0.72      0.71      0.71     10000



## 5. Visualise the predictions

In [None]:
# Plot some sample predictions with class names
def plot_samples(x, y_true, y_pred, num_samples=10, names=None, ncols=5):
    """Show the first ``num_samples`` images with their true and predicted labels.

    Args:
        x: Images indexable by position; each element is passed to ``imshow``.
        y_true: Integer class ids of the true labels, aligned with ``x``.
        y_pred: Integer class ids of the predicted labels, aligned with ``x``.
        num_samples: Number of leading samples to draw; clipped to ``len(x)``.
            Nothing is drawn when it is zero or negative.
        names: Class-name lookup indexed by class id. When None, falls back to
            the notebook-level ``class_names``.
        ncols: Maximum number of subplot columns; rows are added as needed.
    """
    if names is None:
        names = class_names  # notebook-level list from the preprocessing cell

    num_samples = min(num_samples, len(x))
    if num_samples <= 0:
        return

    # Size the grid from num_samples instead of a fixed 2x5 layout. Subplot
    # indices past the end of a fixed grid raise an error. The default of
    # 10 samples still gives a 2x5 grid in a (15, 6) figure.
    ncols = max(1, min(ncols, num_samples))
    nrows = (num_samples + ncols - 1) // ncols
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3 * nrows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')  # also hides any unused trailing cells in the last row
        if i >= num_samples:
            continue
        ax.imshow(x[i])
        ax.set_title(f'True: {names[y_true[i]]}\nPred: {names[y_pred[i]]}')

    fig.tight_layout()
    plt.show()

# Show the first ten test images together with their true and predicted labels.
plot_samples(x=x_test, y_true=y_true, y_pred=y_pred, num_samples=10)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz
[1m170498071/170498071[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step


  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


Epoch 1/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m67s[0m 101ms/step - accuracy: 0.2681 - loss: 1.9593 - val_accuracy: 0.4922 - val_loss: 1.4063
Epoch 2/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 101ms/step - accuracy: 0.4810 - loss: 1.4307 - val_accuracy: 0.5556 - val_loss: 1.2332
Epoch 3/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 100ms/step - accuracy: 0.5417 - loss: 1.2787 - val_accuracy: 0.6045 - val_loss: 1.1165
Epoch 4/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m85s[0m 106ms/step - accuracy: 0.6014 - loss: 1.1333 - val_accuracy: 0.6269 - val_loss: 1.0495
Epoch 5/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m63s[0m 101ms/step - accuracy: 0.6257 - loss: 1.0573 - val_accuracy: 0.6615 - val_loss: 0.9681
Epoch 6/10
[1m625/625[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m82s[0m 100ms/step - accuracy: 0.6613 - loss: 0.9725 - val_accuracy: 0.6692 - val_loss: 0.9419
Epoch 7/10

KeyboardInterrupt: 