# Alzheimer’s Disease Classification

Alzheimer’s is a brain disease that affects memory and thinking. It’s the most common cause of dementia and gets worse over time. Early detection helps with better care and planning.

In this session, we’ll use MRI images and a simple deep learning model (CNN) to classify brain scans into stages of Alzheimer’s.


# Dataset Overview

This dataset has MRI brain scans showing different stages of Alzheimer’s disease. The images are saved in Parquet format, where each row contains one image stored as encoded bytes.

- MRI Images: Brain scan data.
- Labels: Four classes showing Alzheimer’s stage.
- Train/Test Split: Already divided into training and testing sets.

## Class Labels

Each MRI scan belongs to **one of four categories**:

| Label | Class Name           | Description                                |
|-------|----------------------|--------------------------------------------|
| **0** | Mild Demented        | Early signs of dementia.                   |
| **1** | Moderate Demented    | Noticeable memory loss and confusion.      |
| **2** | Non Demented         | Healthy brain, no Alzheimer’s symptoms.    |
| **3** | Very Mild Demented   | Slight cognitive decline, minimal symptoms.|


# Machine Learning Workflow
(Recall what we did yesterday.)

1. Load the MRI image dataset  
2. Preprocess the data  
3. Build a CNN model  
4. Train the model on training data  
5. Evaluate on test data  
6. Use it to make predictions

We’ll go through each of these steps in the notebook.

# Load MRI Image Data

In [None]:
# Importing necessary libraries for machine learning and visualization
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#suppress warnings
import warnings
warnings.filterwarnings('ignore')

In [None]:
# #Mount Google Drive
# from google.colab import drive
# drive.mount('/content/drive')


# train_data_file = "/content/drive/MyDrive/SAU/Workshop BIO-ML/Lab Workbooks/Alzheimer MRI Disease Classification Dataset/Data/train-00000-of-00001-c08a401c53fe5312.parquet"
# test_data_file = "/content/drive/MyDrive/SAU/Workshop BIO-ML/Lab Workbooks/Alzheimer MRI Disease Classification Dataset/Data/test-00000-of-00001-44110b9df98c5585.parquet"

# df_train = pd.read_parquet(train_data_file)
# df_test = pd.read_parquet(test_data_file)

In [None]:
# We could store the dataset in Google Drive and read it from there (see the previous cell),
# but participants would then have to authorize Google Drive access on every run.
# Downloading the shared files with gdown avoids that step.
!pip install -q gdown

import gdown

# file ID from the shareable link
file_id_train = '1bk8qGUTmG6rJinUBi9viwmAbeEU9ZkPf'
file_id_test = '1F0n85kMms-UT3OacnpeW_6LA3bl_vc2m'

# Downloadable URL
url_train = f'https://drive.google.com/uc?id={file_id_train}'
url_test = f'https://drive.google.com/uc?id={file_id_test}'

# Download the file to the current working directory
gdown.download(url_train, 'train.parquet', quiet=False)
gdown.download(url_test, 'test.parquet', quiet=False)


df_train = pd.read_parquet('train.parquet')
df_test = pd.read_parquet('test.parquet')

# Note: the data is stored in the Parquet file format.
# Parquet is commonly used for large datasets; it stores data column-wise
# instead of row-wise, which allows high compression and fast reads.

In [None]:
# Preview the first rows; try df_train.info() to see the column types
df_train.head()

# Data Preprocessing

The Parquet files store each MRI scan as an encoded byte string inside a dictionary (under the key `bytes`). To use the scans for deep learning, we convert these bytes into grayscale image arrays using `np.frombuffer` and `cv2.imdecode`.
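
To see why a decoding step is needed after `np.frombuffer`, here is a minimal sketch that uses only the standard library and NumPy. It assumes the stored bytes are an encoded image file such as PNG (the actual encoding used in this dataset is not stated here), and `make_gray_png` is a hypothetical helper written only for this illustration. `np.frombuffer` gives back the raw file bytes as a 1-D array, not the pixel grid, so a decoder such as `cv2.imdecode` still has to turn them into an image.

```python
import struct
import zlib

import numpy as np


def make_gray_png(pixels: np.ndarray) -> bytes:
    """Hypothetical helper: encode a 2-D uint8 array as a minimal grayscale PNG."""
    height, width = pixels.shape

    def chunk(tag: bytes, data: bytes) -> bytes:
        crc = zlib.crc32(tag + data) & 0xFFFFFFFF
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", crc)

    # IHDR: width, height, bit depth 8, color type 0 (grayscale), three zero flags
    ihdr = struct.pack(">IIBBBBB", width, height, 8, 0, 0, 0, 0)
    # Each scanline is prefixed with filter type 0 (no filtering)
    raw = b"".join(b"\x00" + row.tobytes() for row in pixels)
    return (b"\x89PNG\r\n\x1a\n" + chunk(b"IHDR", ihdr)
            + chunk(b"IDAT", zlib.compress(raw)) + chunk(b"IEND", b""))


pixels = np.arange(16, dtype=np.uint8).reshape(4, 4)   # tiny stand-in "scan"
image_dict = {"bytes": make_gray_png(pixels)}          # same layout as one 'image' cell

# frombuffer exposes the encoded file bytes as a flat uint8 array
buf = np.frombuffer(image_dict["bytes"], np.uint8)
is_png = buf[:8].tobytes() == b"\x89PNG\r\n\x1a\n"      # PNG signature
width, height = struct.unpack(">II", buf[16:24].tobytes())  # size stored in the header

print("buffer shape:", buf.shape)
print("looks like PNG:", is_png, "| image size:", (width, height))
```

The buffer is longer than the 16 pixels it encodes because it also holds the file header and compressed data, which is why the notebook passes it through an image decoder.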

In [None]:
# cv2 library is part of OpenCV(Open Source computer vision library) used for
# preprocessing images before feeding them to machine learning models

# MRI scans are typically grayscale; using IMREAD_GRAYSCALE ensures consistent shape and format
import cv2
def dict_to_image(image_dict):
    if isinstance(image_dict, dict) and 'bytes' in image_dict:
        byte_string = image_dict['bytes']
        nparr = np.frombuffer(byte_string, np.uint8)
        img = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
        return img
    else:
        raise TypeError(f"Expected dictionary with 'bytes' key, got {type(image_dict)}")

Apply the function above to convert every image in the dataset.

In [None]:
df_train['img_arr'] = df_train['image'].apply(dict_to_image)
df_train.drop("image", axis=1, inplace=True)
df_train.head()

In [None]:
# Create a dictionary to map labels to class names
label_mapping = {
    0: "Mild_Demented",
    1: "Moderate_Demented",
    2: "Non_Demented",
    3: "Very_Mild_Demented"
}

df_train['class_name'] = df_train['label'].map(label_mapping)

#df_train.head()

# Explore the Data

In [None]:
#Plot some of the MRI images
fig, ax = plt.subplots(2, 3, figsize=(15, 5))
axs = ax.flatten()
for axes in axs:
    rand = np.random.randint(0, len(df_train))
    axes.imshow(df_train.iloc[rand]['img_arr'], cmap="gray")
    axes.set_title(df_train.iloc[rand]['class_name'])
plt.tight_layout()
plt.show()



In [None]:
import seaborn as sns

plt.figure(figsize=(8, 5))

# Countplot to visualize the distribution of classes
sns.countplot(data=df_train, x='class_name', palette="viridis")

# Add labels and title
plt.xlabel("Alzheimer Category")
plt.ylabel("Number of Images")
plt.title("Distribution of Alzheimer MRI Categories")
plt.xticks(rotation=20)


plt.show()


In [None]:
df_train['class_name'].value_counts()


*The dataset is imbalanced: Non_Demented and Very_Mild_Demented dominate, while Moderate_Demented has very few images.*
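
One common way to account for class imbalance is to give rare classes a larger weight in the loss. The sketch below is not part of this notebook's training code, and the counts are hypothetical placeholders (use the output of `value_counts()` for real numbers). It computes "balanced" weights as `n_samples / (n_classes * count)`.

```python
import numpy as np

# Hypothetical class counts in label order 0..3 (not the real dataset counts)
counts = np.array([700, 50, 2500, 1800], dtype=np.float64)

# Balanced weighting: rare classes get proportionally larger weights
class_weights = counts.sum() / (len(counts) * counts)

for label, (n, w) in enumerate(zip(counts, class_weights)):
    print(f"label {label}: {int(n):5d} images -> weight {w:.2f}")
```

Weights like these could be passed to a weighted loss, for example the `weight` argument of `nn.CrossEntropyLoss`, if you want to experiment.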


# Train-Validation Split

In [None]:
from sklearn.model_selection import train_test_split
train_df, val_df = train_test_split(df_train, test_size=0.2, stratify=df_train['class_name'], random_state=42)

In [None]:
print(train_df.shape,val_df.shape)

* Normalize images:
Pixel values in image data typically range from 0 to 255 (images are stored as 8-bit integers). Dividing by 255 scales these values to the range 0 to 1.
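
A minimal sketch of this scaling on a tiny stand-in array (the values are made up for illustration). Note that dividing a `uint8` array by `255.0` produces `float64`, so the sketch casts to `float32`, the type the model input is later converted to.

```python
import numpy as np

# Tiny stand-in for a grayscale scan with 8-bit pixel values
img = np.array([[0, 64], [128, 255]], dtype=np.uint8)

scaled64 = img / 255.0                  # NumPy promotes the result to float64
scaled = scaled64.astype(np.float32)    # float32 halves the memory use

print(img.dtype, "->", scaled64.dtype, "->", scaled.dtype)
print("min:", scaled.min(), "max:", scaled.max())
```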


In [None]:
#image pixel values (0, 255) -> (0, 1)
train_df['img_arr'] = train_df['img_arr'].apply(lambda x: x / 255.0)
val_df['img_arr'] = val_df['img_arr'].apply(lambda x: x / 255.0)

In [None]:
print("Training Set Class Distribution:\n", train_df['class_name'].value_counts())
print("\nValidation Set Class Distribution:\n", val_df['class_name'].value_counts())


# Model Building/Training

CNN: Convolution, Pooling, Dense/Fully Connected layers

Convolution and Max/Average pooling: https://colab.research.google.com/drive/1EsBZi0Y-pu47BZRNI2JlG3Cpyb9RoSyY?usp=sharing
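
To make convolution and pooling concrete, here is a small NumPy-only sketch. The image and kernel are made-up examples, and the helper names `conv2d_valid` and `max_pool2x2` are hypothetical. Like `nn.Conv2d`, the convolution here slides the kernel without flipping it (cross-correlation) and uses no padding.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide the kernel over the image (no padding, stride 1)."""
    kh, kw = kernel.shape
    oh, ow = image.shape[0] - kh + 1, image.shape[1] - kw + 1
    out = np.zeros((oh, ow), dtype=np.float32)
    for i in range(oh):
        for j in range(ow):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out


def max_pool2x2(fmap):
    """Keep the maximum of each non-overlapping 2x2 block."""
    h, w = fmap.shape[0] // 2 * 2, fmap.shape[1] // 2 * 2
    return fmap[:h, :w].reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


# 6x6 image: dark left half, bright right half (a vertical edge)
image = np.zeros((6, 6), dtype=np.float32)
image[:, 3:] = 1.0

# Simple vertical-edge detector
kernel = np.array([[-1, 0, 1]] * 3, dtype=np.float32)

features = conv2d_valid(image, kernel)   # 6x6 -> 4x4
pooled = max_pool2x2(features)           # 4x4 -> 2x2

print(features)
print(pooled)
```

The feature map responds only where the window straddles the edge, and pooling halves each spatial dimension while keeping the strongest response. In a CNN the kernel values are learned during training instead of being chosen by hand.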

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score

# Check device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# === Prepare Train and Validation Data ===
X_train = np.stack(train_df['img_arr'].values).reshape(-1, 1, 128, 128).astype(np.float32)
y_train = train_df['label'].values.astype(np.int64)

X_val = np.stack(val_df['img_arr'].values).reshape(-1, 1, 128, 128).astype(np.float32)
y_val = val_df['label'].values.astype(np.int64)

train_dataset = TensorDataset(torch.tensor(X_train), torch.tensor(y_train))
val_dataset = TensorDataset(torch.tensor(X_val), torch.tensor(y_val))

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

# CNN Model
class CNNModel(nn.Module):
    def __init__(self):
        super(CNNModel, self).__init__()
        # Convolutional layer: detects patterns in image patches
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        # Pooling layer: reduces spatial dimensions while retaining important features
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        # Fully connected (dense) layer: combines features for final prediction
        self.fc1 = nn.Linear(64 * 30 * 30, 128)
        self.fc2 = nn.Linear(128, 4)  # 4 classes


    # Forward method defines how input flows through the network
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # (1, 128, 128) -> conv: (32, 126, 126) -> pool: (32, 63, 63)
        x = self.pool(F.relu(self.conv2(x)))   # conv: (64, 61, 61) -> pool: (64, 30, 30)

        # Flatten the 2D features into 1D before passing to dense layers
        x = x.view(-1, 64 * 30 * 30)

        x = F.relu(self.fc1(x))

        return self.fc2(x)

model = CNNModel().to(device)
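
The `64 * 30 * 30` input size of `fc1` follows from how each layer changes the image size. The sketch below reproduces that arithmetic in plain Python, assuming 128x128 inputs, 3x3 convolutions without padding, and 2x2 max pooling as in the model above. The helper names `conv_out` and `pool_out` are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution layer."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel=2, stride=2):
    """Output width/height of a pooling layer."""
    return (size - kernel) // stride + 1


size = 128
size = conv_out(size, 3)   # conv1: 128 -> 126
size = pool_out(size)      # pool:  126 -> 63
size = conv_out(size, 3)   # conv2: 63 -> 61
size = pool_out(size)      # pool:  61 -> 30

flat_features = 64 * size * size   # 64 channels from conv2
print("final feature map:", size, "x", size)
print("flattened features:", flat_features)
```

If you change the input size, kernel size, or add padding, rerun this arithmetic and update `fc1` to match.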


## Model Architecture
Uncomment the code in the cell below to plot the model architecture. Some additional libraries will be installed, so it may take a while!

In [None]:
# Model Architecture
# !pip install torchviz
# import torch
# from torchviz import make_dot

# def plot_model(model, input_data, filename="model_architecture"):
#     """Plots the model architecture and saves it as <filename>.png."""
#     dot = make_dot(model(input_data), params=dict(model.named_parameters()))
#     dot.format = 'png'
#     dot.render(filename)

# # Example usage
# model1 = CNNModel()  # CNN model
# input_data = torch.randn(1, 1, 128, 128)  # dummy input
# plot_model(model1, input_data)

In [None]:
# Define loss function to measure prediction error
criterion = nn.CrossEntropyLoss()
# Define optimizer to adjust model weights based on loss
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training with Early Stopping
num_epochs = 20 #10 #50
patience = 5
best_val_loss = np.inf
early_stop_counter = 0

train_losses, val_losses = [], []

# Epochs: one epoch is one complete pass of the model through the entire training dataset.
# If you have 1000 images and train for 10 epochs, your model will see
# all 1000 images 10 times (usually in a different order each time). Imagine studying for an
# exam. You don't read your notes just once; you review them several times to really
# understand and retain the information. That's what epochs do for a model!

for epoch in range(num_epochs):
    # Training mode: enables training-time behavior in layers such as dropout and batch norm
    model.train()
    running_loss = 0.0
    batch_count = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch_count += 1
        if batch_count % 16 == 0:
            print(f"Epoch: {epoch+1}, iteration: {batch_count}/{len(train_loader)}, loss : {loss.item()}")

    train_loss = running_loss / len(train_loader)
    train_losses.append(train_loss)

    # Evaluation mode: switches layers such as dropout and batch norm to inference behavior
    model.eval()
    val_loss = 0.0
    # Disable gradient calculation for validation/inference
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            val_loss += loss.item()

    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    # Early Stopping Check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("Early stopping triggered.")
            break
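
The early stopping check above can be traced on its own. The sketch below replays the same logic on a hypothetical list of validation losses (made-up numbers, with `patience = 3` to keep the trace short) and reports which epoch would be saved as the best model and where training would stop.

```python
import math

# Hypothetical validation losses, one per epoch (not from a real run)
val_losses = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.68, 0.69, 0.70, 0.71]
patience = 3

best_val_loss = math.inf
best_epoch = None
stopped_at = None
early_stop_counter = 0

for epoch, val_loss in enumerate(val_losses, start=1):
    if val_loss < best_val_loss:
        # Improvement: remember it (the notebook saves the model here) and reset the counter
        best_val_loss = val_loss
        best_epoch = epoch
        early_stop_counter = 0
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            stopped_at = epoch
            break

print(f"best epoch: {best_epoch} (val loss {best_val_loss})")
print(f"stopped at epoch: {stopped_at}")
```

Note that the counter resets at epoch 6 when the loss improves again, so a short bad stretch does not end training.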



## Summary of Model Parameters

In [None]:
!pip install torchinfo
from torchinfo import summary
summary(model, input_size=(1, 1, 128, 128))
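
As a cross-check, the parameter counts can be worked out by hand. This sketch assumes the layer sizes defined in `CNNModel` above; the helper names `conv2d_params` and `linear_params` are hypothetical. It counts weights plus biases per layer.

```python
def conv2d_params(in_channels, out_channels, kernel):
    """Weights (in * out * k * k) plus one bias per output channel."""
    return in_channels * out_channels * kernel * kernel + out_channels


def linear_params(in_features, out_features):
    """Weights (in * out) plus one bias per output unit."""
    return in_features * out_features + out_features


layers = {
    "conv1": conv2d_params(1, 32, 3),
    "conv2": conv2d_params(32, 64, 3),
    "fc1": linear_params(64 * 30 * 30, 128),
    "fc2": linear_params(128, 4),
}
total = sum(layers.values())

for name, n in layers.items():
    print(f"{name:6s} {n:>10,d}  ({100 * n / total:.2f}%)")
print(f"total  {total:>10,d}")
```

Almost all parameters sit in `fc1`, the first fully connected layer; the two convolution layers together account for well under 1% of the total.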

In [None]:
# Load best model
model.load_state_dict(torch.load('best_model.pth', map_location=device))

# Evaluate on Validation Set
model.eval()
all_preds = []
with torch.no_grad():
    for images, _ in val_loader:
        images = images.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())

print("Validation Classification Report:\n", classification_report(y_val, all_preds))



In [None]:
# Confusion Matrix
plt.figure(figsize=(6,6))
sns.heatmap(confusion_matrix(y_val, all_preds), annot=True, fmt='d', cmap='Blues',
            xticklabels=['Mild', 'Moderate', 'Non-Demented', 'Very Mild'],
            yticklabels=['Mild', 'Moderate', 'Non-Demented', 'Very Mild'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Validation Confusion Matrix')
plt.show()
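
The per-class numbers in the classification report can be derived directly from a confusion matrix. The sketch below uses a hypothetical 4x4 matrix (made-up counts, rows are actual classes, columns are predicted classes) to show how precision and recall are computed, and why overall accuracy can hide weak performance on a rare class.

```python
import numpy as np

# Hypothetical confusion matrix: rows = actual, columns = predicted
# Class order: Mild, Moderate, Non-Demented, Very Mild
cm = np.array([
    [80, 0,   5,  15],
    [ 2, 6,   0,   2],
    [ 3, 0, 480,  17],
    [10, 0,  30, 310],
])

correct = np.diag(cm)
recall = correct / cm.sum(axis=1)      # of all actual class-k scans, how many were found
precision = correct / cm.sum(axis=0)   # of all scans predicted as class k, how many were right
accuracy = correct.sum() / cm.sum()

print("accuracy :", round(float(accuracy), 4))
print("recall   :", np.round(recall, 3))
print("precision:", np.round(precision, 3))
```

In this made-up example the overall accuracy is above 90%, yet only 60% of the rare "Moderate" scans are found, which is why the per-class rows matter on an imbalanced dataset.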



In [None]:
# Plot Training vs. Validation Loss
plt.figure(figsize=(8,5))
plt.plot(range(1, len(train_losses)+1), train_losses, 'bo-', label='Training Loss')
plt.plot(range(1, len(val_losses)+1), val_losses, 'r^-', label='Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Training vs. Validation Loss')
plt.legend()
plt.grid()
plt.show()



# Prediction

In [None]:
# Prediction on Test Data
df_test['img_arr'] = df_test['image'].apply(dict_to_image)
df_test.drop("image", axis=1, inplace=True)

X_test = np.stack(df_test['img_arr'].values).reshape(-1, 1, 128, 128).astype(np.float32) / 255.0
y_test = df_test['label'].values.astype(np.int64)

test_dataset = TensorDataset(torch.tensor(X_test), torch.tensor(y_test))
test_loader = DataLoader(test_dataset, batch_size=32)

model.eval()
all_test_preds = []
with torch.no_grad():
    for images, _ in test_loader:
        images = images.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_test_preds.extend(preds.cpu().numpy())

accuracy = accuracy_score(y_test, all_test_preds)
print(f"Test Accuracy: {accuracy:.4f}")
print("Test Classification Report:\n", classification_report(y_test, all_test_preds))
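
The model returns raw scores (logits) for the four classes, and `torch.argmax` picks the largest one. To read those scores as probabilities, softmax can be applied. The sketch below shows the idea in NumPy with hypothetical logits (made-up values, not real model outputs).

```python
import numpy as np


def softmax(logits):
    """Convert each row of raw scores into probabilities that sum to 1."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)


# Hypothetical logits for two scans, four classes each
logits = np.array([
    [2.0, -1.0, 0.5, 0.1],
    [-0.3, 0.2, 3.1, 2.9],
])
probs = softmax(logits)

print(np.round(probs, 3))
print("predicted labels:", probs.argmax(axis=1))
```

Softmax does not change which class wins, but it shows how confident the prediction is: the second scan's top two classes are close, which argmax alone would hide.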

# Grad-CAM Visualization

Grad-CAM produces a heatmap overlay on the original input image, highlighting which parts of the image the CNN focused on when making its prediction.

- Red/hot regions: The model found these parts most important for its decision.
- Green regions: Medium importance. These regions contributed to the decision, but not as strongly.
- Cool/blue regions: These were less relevant.


Questions:
- Does the focus area in the Grad-CAM overlay make clinical sense?
- Is the model consistent across similar predictions?
- Trust in AI: Would a doctor feel confident with this kind of explanation?

Note: Try multiple examples to check whether the model consistently focuses on medically relevant regions.
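
The core Grad-CAM computation is short: average each channel's gradient to get one weight per channel, take the weighted sum of the activation maps, keep only positive values, and rescale to the range 0 to 1. The NumPy sketch below walks through these steps on tiny hand-made arrays (stand-ins for illustration, not real activations or gradients from the model).

```python
import numpy as np

# Stand-in activations for 2 channels on a 4x4 feature map
activations = np.zeros((2, 4, 4), dtype=np.float32)
activations[0, 1:3, 1:3] = 1.0   # channel 0 fires in the center
activations[1, 0, :] = 1.0       # channel 1 fires along the top row

# Stand-in gradients of the class score with respect to each activation
gradients = np.empty((2, 4, 4), dtype=np.float32)
gradients[0] = 0.5     # channel 0 supports the predicted class
gradients[1] = -0.25   # channel 1 argues against it

weights = gradients.mean(axis=(1, 2))              # one weight per channel
cam = np.tensordot(weights, activations, axes=1)   # weighted sum over channels
cam = np.maximum(cam, 0)                           # ReLU: keep positive evidence only
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)   # rescale to [0, 1]

print(np.round(cam, 2))
```

The center region ends up "hot" because its channel has a positive weight, while the top row is zeroed by the ReLU because its channel has a negative weight.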

In [None]:
class GradCAM:
    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Register hooks
        self.target_layer.register_forward_hook(self.save_activation)
        # register_backward_hook is deprecated; the full hook has the same signature
        self.target_layer.register_full_backward_hook(self.save_gradient)

    def save_activation(self, module, input, output):
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)

        if class_idx is None:
            class_idx = output.argmax().item()

        self.model.zero_grad()
        loss = output[0, class_idx]
        loss.backward()

        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)

        cam = F.relu(cam)
        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)
        cam = cam.squeeze().cpu().numpy()

        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)  # normalize
        return cam


In [None]:
def show_gradcam_on_image(img_tensor, cam, alpha=0.4):
    img = img_tensor.squeeze().cpu().numpy()
    if img.ndim == 2:
        img = np.stack([img]*3, axis=-1)

    cam_resized = cv2.resize(cam, (img.shape[1], img.shape[0]))
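    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # OpenCV returns BGR; convert to RGB so matplotlib shows important regions in red
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255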

    overlay = heatmap * alpha + img / img.max()
    overlay = overlay / overlay.max()

    plt.imshow(overlay)
    plt.axis('off')
    plt.title("Grad-CAM Overlay")
    plt.show()


In [None]:
# Load an image tensor
rand = np.random.randint(0, len(df_train))
sample_img = df_train.iloc[rand]['img_arr']
input_tensor = torch.tensor(sample_img, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device) / 255.0

# Apply Grad-CAM to conv2
gradcam = GradCAM(model, model.conv2)
cam = gradcam.generate(input_tensor)

# Visualize
show_gradcam_on_image(input_tensor.cpu(), cam)


# Quiz

In [None]:
# Quiz: MRI and CNN Understanding

# Question 1
ans = input("1. Why do we use grayscale images instead of color for brain MRI scans?\n(a) Grayscale images are more artistic\n(b) Brain MRIs are naturally in grayscale and color adds noise\n(c) Grayscale saves storage but reduces accuracy\n(d) CNNs cannot work with color images\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! Medical imaging like MRI is naturally in grayscale to emphasize tissue contrast.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 2
ans = input("\n2. What does a Convolutional Neural Network (CNN) learn from an image?\n(a) Exact pixel values\n(b) Patterns and features like edges, shapes, and textures\n(c) Text labels on the image\n(d) The file name of the image\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! CNNs learn spatial features like edges and textures that help in classification.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 3
ans = input("\n3. What does a Conv2D layer do in a CNN model?\n(a) Adds text annotations to images\n(b) Converts images to 1D arrays\n(c) Applies filters to detect features like edges and patterns\n(d) Shrinks the image size\nYour answer: ")
if ans.lower() == 'c':
    print("Correct! Conv2D applies filters to extract features like edges and textures from images.")
else:
    print("Incorrect. The correct answer is (c).")

# Question 4
ans = input("\n4. What is the purpose of a MaxPooling2D layer in CNNs?\n(a) To shuffle the image pixels\n(b) To reduce the size of the feature maps and keep only the most important information\n(c) To make the image larger\n(d) To increase model accuracy directly\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! MaxPooling2D helps reduce the image size while retaining important features.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 5
ans = input("\n5. Why is it important to resize all MRI images to the same shape (e.g., 128x128)?\n(a) So they look nicer\n(b) Because models only accept images of equal shape\n(c) To reduce brightness\n(d) To improve color accuracy\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! Neural networks need consistent input dimensions.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 6
ans = input("\n6. What is the role of softmax at the output of a CNN classifier?\n(a) To detect edges\n(b) To convert features into class probabilities\n(c) To shuffle the data\n(d) To normalize pixel values\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! Softmax turns outputs into probabilities for each class. In our PyTorch model it is applied inside nn.CrossEntropyLoss during training.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 7
ans = input("\n7. Why is it useful to use MRI scans for detecting Alzheimer's disease?\n(a) MRIs are colorful\n(b) They show brain structure changes that relate to disease\n(c) Doctors prefer images over tests\n(d) They contain the patient’s DNA\nYour answer: ")
if ans.lower() == 'b':
    print("Correct! MRI scans reveal structural changes in the brain linked to neurodegeneration.")
else:
    print("Incorrect. The correct answer is (b).")

# Question 8
ans = input("\n8. What does it mean when we say a model is 'trained'?\n(a) It has learned to make predictions based on patterns in data\n(b) It memorized the patient names\n(c) It runs faster\n(d) It creates new diseases\nYour answer: ")
if ans.lower() == 'a':
    print("Correct! Training is the process of learning from examples to make predictions.")
else:
    print("Incorrect. The correct answer is (a).")

# Question about epochs
ans = input("\n9. What is an 'epoch' in machine learning?\n(a) A single pass through the entire training dataset\n(b) A type of neural network\n(c) A mistake in the data\n(d) A hardware device used for training\nYour answer: ")
if ans.lower() == 'a':
    print("Correct! An epoch is one full pass through all the training data.")
else:
    print("Incorrect. The correct answer is (a). An epoch means the model has seen all training samples once.")


# Question about multiple epochs
ans = input("\n10. Why do we train a model for multiple epochs?\n(a) To help the model learn better by refining its predictions over time\n(b) To waste computer resources\n(c) To make the dataset larger\n(d) To create more labels\nYour answer: ")
if ans.lower() == 'a':
    print("Correct! Training for multiple epochs allows the model to gradually improve and reduce errors.")
else:
    print("Incorrect. The correct answer is (a). Multiple epochs help the model learn from the data more effectively.")

# Question: Why use Early Stopping?
ans = input("\n11. Why do we use 'early stopping' during model training?\n"
            "(a) To save electricity\n"
            "(b) To avoid overfitting by stopping when validation loss stops improving\n"
            "(c) To make training easier to watch\n"
            "(d) To restart the training from scratch\n"
            "Your answer: ")

if ans.lower() == 'b':
    print("Correct! Early stopping helps prevent overfitting by stopping training when the model stops improving on validation data.")
else:
    print("Incorrect. The correct answer is (b). Early stopping monitors performance on validation data and stops training to avoid overfitting.")


# Question: MRI vs fMRI
# ans = input("\n12. What is a key difference between MRI and fMRI scans?\n"
#             "(a) MRI captures brain activity while fMRI captures brain structure\n"
#             "(b) MRI uses colors to show brain functions\n"
#             "(c) fMRI captures changes in brain activity over time\n"
#             "(d) fMRI is only used in animals\n"
#             "Your answer: ")

# if ans.lower() == 'c':
#     print("Correct! fMRI tracks brain activity by measuring changes in blood flow over time.")
# else:
#     print("Incorrect. The correct answer is (c). Unlike MRI, which gives static anatomical images, fMRI captures brain activity over time.")

### Common Alzheimer’s Terminology

| Abbreviation | Full Form                          | Meaning / Use |
|--------------|-------------------------------------|----------------|
| **AD**       | Alzheimer’s Disease                 | The most common type of dementia, causing memory loss and cognitive decline. |
| **MCI**      | Mild Cognitive Impairment           | A stage between normal aging and dementia; noticeable decline but not severely disabling. |
| **CN**       | Cognitively Normal                  | Individuals with normal cognitive functioning; used as control in studies. |
| **MRI**      | Magnetic Resonance Imaging          | Brain scan to detect structural changes in the brain. |
| **fMRI**     | Functional MRI                      | Measures brain activity through blood flow changes. |
| **CSF**      | Cerebrospinal Fluid                 | Fluid around brain/spine; tested for Alzheimer’s biomarkers. |
| **Aβ (A-beta)** | Amyloid Beta                     | Protein that forms brain plaques in Alzheimer’s disease. |
| **Tau**      | Tau Protein                         | Protein that forms tangles inside neurons in Alzheimer’s. |
| **PET**      | Positron Emission Tomography        | Imaging that visualizes amyloid and tau buildup. |
| **NIA-AA**   | National Institute on Aging – Alzheimer’s Association | Sets research and diagnostic guidelines for Alzheimer’s. |



### Difference in AI Model Design: MRI vs fMRI

| **Aspect**             | **MRI (Structural MRI)**                                             | **fMRI (Functional MRI)**                                                             |
|------------------------|----------------------------------------------------------------------|----------------------------------------------------------------------------------------|
| **Type of Data**       | Static 3D anatomical image (e.g., brain structure)                   | 4D time-series data (3D images over time) capturing brain activity                    |
| **Typical Use**        | Detecting structural changes (e.g., atrophy in Alzheimer’s)          | Studying brain function, networks, and responses to tasks/stimuli                    |
| **Input to Model**     | Single 2D or 3D grayscale image                                       | Sequence of 3D images over time (like a video)                                       |
| **Preprocessing**      | Image resizing, normalization, skull-stripping                      | All MRI steps **plus** temporal filtering, motion correction, time-series alignment  |
| **Model Type**         | 2D or 3D CNN                                                         | 3D CNN, RNN (like LSTM), or hybrid (CNN + RNN)                                       |
| **Data Size**          | Smaller, easier to train on modest hardware                          | Larger, requires more memory and often downsampling                                  |
| **Example Architecture**| 2D CNN for image classification                                     | CNN + LSTM for activity pattern recognition over time                                |
| **Applications**       | Disease staging, tumor detection, brain structure classification     | Brain region connectivity analysis, mental task decoding, emotion detection          |

---

### Simplified Analogy

- **MRI** is like a **photo** of the brain.
- **fMRI** is like a **video** showing how the brain is working over time.

---

### Summary

| Modality | Focus | Suitable Models |
|----------|-------|-----------------|
| **MRI**  | Brain **structure** | **CNNs (2D or 3D)** |
| **fMRI** | Brain **function** over time | Models that handle time, such as **CNN + RNN** |


# Notes
- Add the Google Drive access code so users know how to access the dataset, and explain why gdown was used (to allow simpler access).

- Simplify the code and explanations so that participants can get something out of them and are motivated to explore on their own.

- Add an image for pooling in addition to the one for convolution. We can discuss kernels/filters, convolution, and pooling (which shrinks the feature maps), and how the two form the core of a CNN for feature engineering, followed by fully connected layers that lead to the prediction.

- If possible, convey the importance of convolution: it allows sparse connections and weight sharing, which greatly reduces the number of parameters in the network.

- Also discuss how, for identifying a shape such as a car, detecting edges is helpful (we don't care about the car's color, etc.). In the past, SIFT/HOG features were used for feature engineering before applying ML techniques; with CNNs, the parameters of the filters are learned from the data.

- Instead of using handcrafted features as in the past, we learn the parameters of the filters in a CNN.

- And not just one filter, but multiple layers of filters, in addition to learning the weights of the classifier (using backpropagation).

- Multiple learned convolutions at every layer, and multiple such layers.

- Discuss, if relevant, why a CNN and not a feed-forward network: convey how CNNs capture and benefit from the structure of the image (interactions between neighboring pixels are the most informative). This leads to sparse connectivity and reduces the number of parameters in the network.

- Further, the network learns more and more abstract representations at each layer; at some point they no longer make much sense to the naked eye, but they still carry information for the neural network's predictions.

- Add more discussion points everywhere, especially in the introductory workbook.

- PyTorch CNN model: forward, Conv2d, etc.
- Reference notebook: https://www.kaggle.com/code/aasthadata/alzheimer-mri-cnn-0-96 (in TensorFlow). Share this link if someone wants to use TensorFlow.

- Use GPUs for training; it should be faster. Otherwise, getting through even 1-2 epochs during the demo would be difficult.
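
The point about sparse connections and weight sharing can be put in numbers. The sketch below compares the parameter count of the first convolution layer of the notebook's model (1 input channel, 32 filters of size 3x3) with a hypothetical fully connected layer that maps the same 128x128 input to the same number of output values. The dense layer is an illustration only and is not part of the model.

```python
# First conv layer of the notebook's model: 32 filters of size 3x3 on 1 channel
conv_params = 1 * 32 * 3 * 3 + 32            # shared weights + one bias per filter

# Hypothetical dense layer producing the same output size (32 maps of 126x126)
in_pixels = 128 * 128
out_units = 32 * 126 * 126
dense_params = in_pixels * out_units + out_units   # every input connects to every output

print(f"conv layer : {conv_params:,d} parameters")
print(f"dense layer: {dense_params:,d} parameters")
print(f"ratio      : about {dense_params // conv_params:,d}x")
```

The convolution layer reuses the same small set of weights at every image position, which is why it needs millions of times fewer parameters here.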


In [None]:
#End