In [17]:
!pip install pillow numpy torch opencv-python matplotlib seaborn torchvision tqdm scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting joblib>=1.2.0 (from scikit-learn)
  Downloading joblib-1.4.0-py3-none-any.whl.metadata (5.4 kB)
Collecting threadpoolctl>=2.0.0 (from scikit-learn)
  Downloading threadpoolctl-3.5.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.4.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m12.2/12.2 MB[0m [31m91.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m:01[0m
[?25hDownloading joblib-1.4.0-py3-none-any.whl (301 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.2/301.2 kB[0m [31m28.0 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading threadpoolctl-3.5.0-py3-none-any.whl (18 kB)
Installing collected packages: threadpoolctl, joblib, scikit-learn
Successfully installed joblib-1.4.0 scikit-learn-1.4.2 threadpoolctl-3.5.0


In [24]:
import cv2
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from tqdm import tqdm
import numpy as np

# Define the folders containing the images
folders = {
    0: "../data/Dataset/Non_Demented/",
    1: "../data/Dataset/Very_Mild_Demented/",
    2: "../data/Dataset/Mild_Demented/",
    3: "../data/Dataset/Moderate_Demented/",
}

# Initialize lists to store the images and labels
X = []
y = []

corrupted = 0

# Iterate over the folders and load each image
for label, path in folders.items():
    for filename in tqdm(os.listdir(path)):
        # Load the image
        img_path = os.path.join(path, filename)
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            img = np.zeros((128,128))
            corrupted += 1
        img_tensor = torch.from_numpy(img).float() 
        X.append(img_tensor)
        y.append(label)

X = torch.stack(X)
# Normalize the images
X = X / 255.0
y = torch.tensor(y)

print(X.shape, y.shape)

# Calculate the count of each class
class_counts = y.bincount()
num_samples = y.size(0)
class_weights = 1. / class_counts

# Create a list of weights for each sample
sample_weights = class_weights[y]

# Set up the sampler
sampler = WeightedRandomSampler(weights=sample_weights, num_samples=num_samples, replacement=True)

# Create a dataset and a dataloader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=32, sampler=sampler)
print(f'Corrupted images: {corrupted}')

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 16 * 16, 128)
        self.fc2 = nn.Linear(128, 4)  # Assuming 4 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 64 * 16 * 16)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x
    
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

num_epochs = 10  # You can adjust this based on your specific needs

from sklearn.model_selection import train_test_split

# Split the data into training and validation sets
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# Create DataLoaders for both training and validation sets
train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


def calculate_accuracy(outputs, labels):
    _, predicted = torch.max(outputs.data, 1)
    total = labels.size(0)
    correct = (predicted == labels).sum().item()
    return 100 * correct / total

# Starting the training and validation process
for epoch in range(num_epochs):
    model.train()  # Set the model to training mode
    running_loss = 0.0
    for images, labels in tqdm(train_loader):  # Use train_loader for training
        images = images.unsqueeze(1).to(device)
        labels = labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    print(f'Epoch {epoch + 1}, Training Loss: {running_loss / len(train_loader)}')
    
    # Validation phase
    model.eval()  # Set the model to evaluation mode
    val_running_loss = 0.0
    val_accuracy = 0.0
    with torch.no_grad():  # No gradient updates during validation
        for images, labels in tqdm(val_loader):  # Use val_loader for validation
            images = images.unsqueeze(1).to(device)
            labels = labels.to(device)
            
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item()
            val_accuracy += calculate_accuracy(outputs, labels)
    
    avg_val_loss = val_running_loss / len(val_loader)
    avg_val_accuracy = val_accuracy / len(val_loader)
    print(f'Epoch {epoch + 1}, Validation Loss: {avg_val_loss}, Validation Accuracy: {avg_val_accuracy}% \n')

100%|██████████| 3201/3201 [01:00<00:00, 52.85it/s]
100%|██████████| 2241/2241 [00:42<00:00, 53.22it/s]
100%|██████████| 897/897 [00:16<00:00, 53.65it/s]
100%|██████████| 65/65 [00:01<00:00, 54.00it/s]


torch.Size([6404, 128, 128]) torch.Size([6404])
Corrupted images: 4


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
100%|██████████| 161/161 [00:00<00:00, 164.83it/s]


Epoch 1, Training Loss: 1.042286263489575


100%|██████████| 41/41 [00:00<00:00, 296.89it/s]


Epoch 1, Validation Loss: 1.0475603647348357, Validation Accuracy: 53.734756097560975%


100%|██████████| 161/161 [00:00<00:00, 162.83it/s]


Epoch 2, Training Loss: 0.9360838706448952


100%|██████████| 41/41 [00:00<00:00, 381.21it/s]


Epoch 2, Validation Loss: 0.899555199029969, Validation Accuracy: 57.08841463414634%


100%|██████████| 161/161 [00:00<00:00, 163.64it/s]


Epoch 3, Training Loss: 0.8841765771001022


100%|██████████| 41/41 [00:00<00:00, 392.12it/s]


Epoch 3, Validation Loss: 0.8898717761039734, Validation Accuracy: 58.53658536585366%


100%|██████████| 161/161 [00:00<00:00, 164.77it/s]


Epoch 4, Training Loss: 0.796242917731682


100%|██████████| 41/41 [00:00<00:00, 395.02it/s]


Epoch 4, Validation Loss: 0.7493277584634176, Validation Accuracy: 66.23475609756098%


100%|██████████| 161/161 [00:00<00:00, 167.13it/s]


Epoch 5, Training Loss: 0.6802735380504442


100%|██████████| 41/41 [00:00<00:00, 410.48it/s]


Epoch 5, Validation Loss: 0.6225895838039678, Validation Accuracy: 75.3048780487805%


100%|██████████| 161/161 [00:00<00:00, 165.53it/s]


Epoch 6, Training Loss: 0.5544711965951861


100%|██████████| 41/41 [00:00<00:00, 398.82it/s]


Epoch 6, Validation Loss: 0.7564518887822221, Validation Accuracy: 63.8719512195122%


100%|██████████| 161/161 [00:00<00:00, 167.08it/s]


Epoch 7, Training Loss: 0.418086557469753


100%|██████████| 41/41 [00:00<00:00, 392.24it/s]


Epoch 7, Validation Loss: 0.4042680508843282, Validation Accuracy: 83.61280487804878%


100%|██████████| 161/161 [00:00<00:00, 166.42it/s]


Epoch 8, Training Loss: 0.28489164003859396


100%|██████████| 41/41 [00:00<00:00, 384.25it/s]


Epoch 8, Validation Loss: 0.34969379291756125, Validation Accuracy: 85.59451219512195%


100%|██████████| 161/161 [00:00<00:00, 164.11it/s]


Epoch 9, Training Loss: 0.22482332392879154


100%|██████████| 41/41 [00:00<00:00, 399.86it/s]


Epoch 9, Validation Loss: 0.31154799938383626, Validation Accuracy: 87.04268292682927%


100%|██████████| 161/161 [00:00<00:00, 167.48it/s]


Epoch 10, Training Loss: 0.1502990144426408


100%|██████████| 41/41 [00:00<00:00, 396.11it/s]

Epoch 10, Validation Loss: 0.16749838984957555, Validation Accuracy: 93.4451219512195%



