In [2]:
pip install catboost

Collecting catboost
  Obtaining dependency information for catboost from https://files.pythonhosted.org/packages/1c/e1/78e635a1e5f0066bd02a1ecfd658ad09fe30d275c65c2d0dd76fe253e648/catboost-1.2.7-cp311-cp311-win_amd64.whl.metadata
  Using cached catboost-1.2.7-cp311-cp311-win_amd64.whl.metadata (1.2 kB)
Collecting graphviz (from catboost)
  Obtaining dependency information for graphviz from https://files.pythonhosted.org/packages/00/be/d59db2d1d52697c6adc9eacaf50e8965b6345cc143f671e1ed068818d5cf/graphviz-0.20.3-py3-none-any.whl.metadata
  Using cached graphviz-0.20.3-py3-none-any.whl.metadata (12 kB)
Downloading catboost-1.2.7-cp311-cp311-win_amd64.whl (101.7 MB)
   ---------------------------------------- 0.0/101.7 MB ? eta -:--:--
   ---------------------------------------- 0.0/101.7 MB ? eta -:--:--
   ---------------------------------------- 0.0/101.7 MB 320.0 kB/s eta 0:05:18
   ---------------------------------------- 0.0/101.7 MB 325.1 kB/s eta 0:05:13
   ------------------------

In [3]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Activation, Dropout, Flatten, Dense
from catboost import CatBoostClassifier
from sklearn.metrics import accuracy_score
import numpy as np
import joblib

In [4]:
train_dir = r'Alzheimers Dataset\train'  
test_dir = r'Alzheimers Dataset\test'

In [5]:
# Set the input shape
input_shape = (128, 128, 3)
batch_size = 32

In [6]:
# Use ImageDataGenerator for data augmentation
data_generator = ImageDataGenerator(rescale=1./255)

In [7]:
# Load the train images
train_generator = data_generator.flow_from_directory(
    train_dir,
    target_size=(input_shape[0], input_shape[1]),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)

Found 5121 images belonging to 4 classes.


In [8]:
# Load the test images
test_generator = data_generator.flow_from_directory(
    test_dir,
    target_size=(input_shape[0], input_shape[1]),
    batch_size=batch_size,
    class_mode='categorical',
    shuffle=False)

Found 1279 images belonging to 4 classes.


In [9]:
# Define the model until the feature extraction layer
model = Sequential([
    Conv2D(32, (3, 3), input_shape=input_shape),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(64),
    Activation('relu'),
    Dropout(0.5)
])

  super().__init__(activity_regularizer=activity_regularizer, **kwargs)


In [10]:
# Extract features from train images 
train_features = model.predict(train_generator)
train_features = train_features.reshape(train_features.shape[0], -1)  
train_labels = train_generator.classes

  self._warn_if_super_not_called()


[1m161/161[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 319ms/step


In [11]:
# Extract features from test images 
test_features = model.predict(test_generator)
test_features = test_features.reshape(test_features.shape[0], -1)
test_labels = test_generator.classes

[1m40/40[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 344ms/step


In [12]:
# Use CatBoost classifier
cb_classifier = CatBoostClassifier(iterations=1000, learning_rate=0.1, depth=6)
cb_classifier.fit(train_features, train_labels)

0:	learn: 1.3217437	total: 190ms	remaining: 3m 9s
1:	learn: 1.2678871	total: 229ms	remaining: 1m 54s
2:	learn: 1.2242691	total: 255ms	remaining: 1m 24s
3:	learn: 1.1831662	total: 279ms	remaining: 1m 9s
4:	learn: 1.1522876	total: 304ms	remaining: 1m
5:	learn: 1.1229351	total: 327ms	remaining: 54.2s
6:	learn: 1.0981603	total: 354ms	remaining: 50.2s
7:	learn: 1.0754309	total: 379ms	remaining: 47s
8:	learn: 1.0561446	total: 409ms	remaining: 45s
9:	learn: 1.0364044	total: 439ms	remaining: 43.4s
10:	learn: 1.0183532	total: 465ms	remaining: 41.8s
11:	learn: 1.0025408	total: 490ms	remaining: 40.4s
12:	learn: 0.9893529	total: 517ms	remaining: 39.2s
13:	learn: 0.9768079	total: 541ms	remaining: 38.1s
14:	learn: 0.9662080	total: 567ms	remaining: 37.3s
15:	learn: 0.9554846	total: 593ms	remaining: 36.5s
16:	learn: 0.9452679	total: 621ms	remaining: 35.9s
17:	learn: 0.9362306	total: 651ms	remaining: 35.5s
18:	learn: 0.9277166	total: 676ms	remaining: 34.9s
19:	learn: 0.9192935	total: 701ms	remaining: 3

164:	learn: 0.5508235	total: 5.88s	remaining: 29.8s
165:	learn: 0.5491420	total: 5.93s	remaining: 29.8s
166:	learn: 0.5475080	total: 5.97s	remaining: 29.8s
167:	learn: 0.5460004	total: 6.01s	remaining: 29.8s
168:	learn: 0.5432082	total: 6.05s	remaining: 29.8s
169:	learn: 0.5411039	total: 6.09s	remaining: 29.7s
170:	learn: 0.5392916	total: 6.13s	remaining: 29.7s
171:	learn: 0.5374389	total: 6.16s	remaining: 29.7s
172:	learn: 0.5359634	total: 6.2s	remaining: 29.6s
173:	learn: 0.5341927	total: 6.23s	remaining: 29.6s
174:	learn: 0.5318635	total: 6.27s	remaining: 29.5s
175:	learn: 0.5306009	total: 6.3s	remaining: 29.5s
176:	learn: 0.5293966	total: 6.33s	remaining: 29.5s
177:	learn: 0.5282262	total: 6.37s	remaining: 29.4s
178:	learn: 0.5265716	total: 6.41s	remaining: 29.4s
179:	learn: 0.5255260	total: 6.44s	remaining: 29.3s
180:	learn: 0.5231824	total: 6.48s	remaining: 29.3s
181:	learn: 0.5222764	total: 6.51s	remaining: 29.3s
182:	learn: 0.5206621	total: 6.54s	remaining: 29.2s
183:	learn: 0.

325:	learn: 0.3578632	total: 12.3s	remaining: 25.4s
326:	learn: 0.3570301	total: 12.3s	remaining: 25.3s
327:	learn: 0.3565018	total: 12.3s	remaining: 25.3s
328:	learn: 0.3555772	total: 12.4s	remaining: 25.2s
329:	learn: 0.3544822	total: 12.4s	remaining: 25.2s
330:	learn: 0.3535645	total: 12.4s	remaining: 25.2s
331:	learn: 0.3521804	total: 12.5s	remaining: 25.1s
332:	learn: 0.3511265	total: 12.5s	remaining: 25.1s
333:	learn: 0.3503200	total: 12.6s	remaining: 25.1s
334:	learn: 0.3497057	total: 12.6s	remaining: 25s
335:	learn: 0.3490247	total: 12.6s	remaining: 25s
336:	learn: 0.3481774	total: 12.7s	remaining: 24.9s
337:	learn: 0.3468133	total: 12.7s	remaining: 24.9s
338:	learn: 0.3459864	total: 12.8s	remaining: 24.9s
339:	learn: 0.3453091	total: 12.8s	remaining: 24.8s
340:	learn: 0.3443771	total: 12.8s	remaining: 24.8s
341:	learn: 0.3438784	total: 12.9s	remaining: 24.7s
342:	learn: 0.3425631	total: 12.9s	remaining: 24.7s
343:	learn: 0.3414503	total: 12.9s	remaining: 24.7s
344:	learn: 0.34

487:	learn: 0.2518922	total: 18.4s	remaining: 19.3s
488:	learn: 0.2511888	total: 18.4s	remaining: 19.2s
489:	learn: 0.2504472	total: 18.5s	remaining: 19.2s
490:	learn: 0.2494827	total: 18.5s	remaining: 19.2s
491:	learn: 0.2488500	total: 18.6s	remaining: 19.2s
492:	learn: 0.2483112	total: 18.6s	remaining: 19.1s
493:	learn: 0.2478933	total: 18.7s	remaining: 19.1s
494:	learn: 0.2471760	total: 18.7s	remaining: 19.1s
495:	learn: 0.2466958	total: 18.8s	remaining: 19.1s
496:	learn: 0.2460861	total: 18.8s	remaining: 19s
497:	learn: 0.2456488	total: 18.9s	remaining: 19s
498:	learn: 0.2451139	total: 18.9s	remaining: 19s
499:	learn: 0.2448093	total: 19s	remaining: 19s
500:	learn: 0.2442440	total: 19.1s	remaining: 19s
501:	learn: 0.2437046	total: 19.1s	remaining: 19s
502:	learn: 0.2430337	total: 19.2s	remaining: 18.9s
503:	learn: 0.2425021	total: 19.2s	remaining: 18.9s
504:	learn: 0.2417559	total: 19.3s	remaining: 18.9s
505:	learn: 0.2414810	total: 19.3s	remaining: 18.9s
506:	learn: 0.2408759	tota

649:	learn: 0.1822821	total: 27.1s	remaining: 14.6s
650:	learn: 0.1821649	total: 27.2s	remaining: 14.6s
651:	learn: 0.1818987	total: 27.2s	remaining: 14.5s
652:	learn: 0.1816918	total: 27.3s	remaining: 14.5s
653:	learn: 0.1812963	total: 27.4s	remaining: 14.5s
654:	learn: 0.1809133	total: 27.4s	remaining: 14.4s
655:	learn: 0.1805204	total: 27.5s	remaining: 14.4s
656:	learn: 0.1800998	total: 27.5s	remaining: 14.4s
657:	learn: 0.1798784	total: 27.6s	remaining: 14.3s
658:	learn: 0.1794557	total: 27.6s	remaining: 14.3s
659:	learn: 0.1790506	total: 27.7s	remaining: 14.3s
660:	learn: 0.1785712	total: 27.8s	remaining: 14.2s
661:	learn: 0.1781829	total: 27.8s	remaining: 14.2s
662:	learn: 0.1780376	total: 27.9s	remaining: 14.2s
663:	learn: 0.1775395	total: 27.9s	remaining: 14.1s
664:	learn: 0.1773624	total: 28s	remaining: 14.1s
665:	learn: 0.1770478	total: 28.1s	remaining: 14.1s
666:	learn: 0.1765899	total: 28.1s	remaining: 14s
667:	learn: 0.1762339	total: 28.2s	remaining: 14s
668:	learn: 0.1758

808:	learn: 0.1366536	total: 37.1s	remaining: 8.76s
809:	learn: 0.1363964	total: 37.2s	remaining: 8.71s
810:	learn: 0.1361148	total: 37.2s	remaining: 8.68s
811:	learn: 0.1358482	total: 37.3s	remaining: 8.63s
812:	learn: 0.1356873	total: 37.4s	remaining: 8.59s
813:	learn: 0.1354767	total: 37.4s	remaining: 8.55s
814:	learn: 0.1351280	total: 37.5s	remaining: 8.51s
815:	learn: 0.1348465	total: 37.6s	remaining: 8.47s
816:	learn: 0.1346040	total: 37.6s	remaining: 8.43s
817:	learn: 0.1343619	total: 37.7s	remaining: 8.39s
818:	learn: 0.1340038	total: 37.8s	remaining: 8.36s
819:	learn: 0.1335570	total: 37.9s	remaining: 8.32s
820:	learn: 0.1334179	total: 38s	remaining: 8.28s
821:	learn: 0.1330128	total: 38.1s	remaining: 8.24s
822:	learn: 0.1327073	total: 38.1s	remaining: 8.2s
823:	learn: 0.1324491	total: 38.2s	remaining: 8.15s
824:	learn: 0.1322147	total: 38.2s	remaining: 8.11s
825:	learn: 0.1319004	total: 38.3s	remaining: 8.07s
826:	learn: 0.1316953	total: 38.4s	remaining: 8.03s
827:	learn: 0.1

970:	learn: 0.1042377	total: 47.1s	remaining: 1.41s
971:	learn: 0.1042056	total: 47.1s	remaining: 1.36s
972:	learn: 0.1039071	total: 47.2s	remaining: 1.31s
973:	learn: 0.1037907	total: 47.2s	remaining: 1.26s
974:	learn: 0.1036504	total: 47.3s	remaining: 1.21s
975:	learn: 0.1035381	total: 47.3s	remaining: 1.16s
976:	learn: 0.1034426	total: 47.4s	remaining: 1.11s
977:	learn: 0.1032685	total: 47.4s	remaining: 1.07s
978:	learn: 0.1030763	total: 47.5s	remaining: 1.02s
979:	learn: 0.1029812	total: 47.5s	remaining: 970ms
980:	learn: 0.1028983	total: 47.6s	remaining: 922ms
981:	learn: 0.1026483	total: 47.7s	remaining: 873ms
982:	learn: 0.1024993	total: 47.7s	remaining: 825ms
983:	learn: 0.1022417	total: 47.8s	remaining: 777ms
984:	learn: 0.1020632	total: 47.8s	remaining: 728ms
985:	learn: 0.1019116	total: 47.9s	remaining: 680ms
986:	learn: 0.1017885	total: 47.9s	remaining: 631ms
987:	learn: 0.1016640	total: 48s	remaining: 583ms
988:	learn: 0.1014460	total: 48s	remaining: 534ms
989:	learn: 0.10

<catboost.core.CatBoostClassifier at 0x1d089b11290>

In [13]:
# Make predictions on the test data
test_predictions = cb_classifier.predict(test_features)

In [14]:
# Calculate accuracy
test_accuracy = accuracy_score(test_labels, test_predictions)
print(f"Test Accuracy: {test_accuracy * 100:.2f}%")

Test Accuracy: 57.47%


In [15]:
# Save the Keras CNN model as an .h5 file
model.save('cnn_feature_extractor.h5')



In [16]:
# Save the trained CatBoost model as a .pkl file
joblib.dump(cb_classifier, 'catboost_classifier.pkl')

['catboost_classifier.pkl']

In [17]:
from keras.models import load_model
import joblib
import numpy as np
from tensorflow.keras.preprocessing import image
from catboost import CatBoostClassifier

# Load the saved models
cnn_model = load_model('cnn_feature_extractor.h5')
catboost_classifier = joblib.load('catboost_classifier.pkl')  

def classify_image(img_path):
    # Load and preprocess the new image
    img = image.load_img(img_path, target_size=(128, 128))  
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  
    img_array /= 255.0  

    # Extract features using the CNN model
    features = cnn_model.predict(img_array)
    features = features.reshape(1, -1)  
    
    # Classify the features with the CatBoost model
    prediction = catboost_classifier.predict(features)
    
    # Ensure prediction is an integer
    predicted_class = int(prediction[0])

    # Interpret the prediction
    class_labels = {0: 'MildDemented', 1: 'ModerateDemented', 2: 'NonDemented', 3: 'VeryMildDemented'}  
    result = class_labels[predicted_class]
    return result

# Test the function with a new image
img_path = r'Alzheimers Dataset\test\NonDemented\26 (62).jpg'
result = classify_image(img_path)
print(f"The image is classified as: {result}")




[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 211ms/step
The image is classified as: NonDemented
