<a href="https://colab.research.google.com/github/sutummala/AutismNet/blob/main/Autism_ConvMixer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into the Colab filesystem; the .npy patch files loaded
# later in this notebook are read from paths under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# TensorFlow Addons is imported in this notebook as `tfa` and supplies the
# AdamW optimizer used for training.
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install tensorflow_addons



In [None]:
from tensorflow.keras import layers
from tensorflow import keras
from sklearn.model_selection import RepeatedStratifiedKFold, cross_val_score
import matplotlib.pyplot as plt
import tensorflow_addons as tfa
import tensorflow as tf
import numpy as np

In [None]:
def normalize(input):
  """Min-max scale each sample of ``input`` independently to the range [0, 1].

  Args:
    input: Sequence or array of numeric arrays, indexed along its first axis
      (one sample per index). The name shadows the ``input`` builtin but is
      kept so existing callers are unaffected.

  Returns:
    ``np.ndarray`` stacking the rescaled samples in their original order.
    A sample whose values are all equal is mapped to zeros instead of the
    NaNs that a 0/0 division would produce. An empty ``input`` yields an
    empty array.
  """
  n_samples = np.shape(input)[0]
  if n_samples == 0:
    # Nothing to report or rescale; avoids indexing input[0] on empty data.
    return np.array([])
  print(f'shape of input is {np.shape(input[0])}')
  norm_input = []
  for i in range(n_samples):
    sample = np.asarray(input[i])
    low = np.min(sample)
    span = np.max(sample) - low
    # Constant sample: dividing by 1 keeps dtype promotion identical to the
    # regular path while producing zeros rather than NaN.
    norm_input.append((sample - low) / (span if span != 0 else 1))
  return np.array(norm_input)

In [None]:
# Load both classes and min-max normalise every patch.
# Label convention used throughout the notebook: healthy -> 0, autism -> 1.
healthy_data = normalize(
    np.load('/content/drive/My Drive/Datasets/Autism/healthy_patches.npy')
)
healthy_labels = np.zeros(len(healthy_data))

autism_data = normalize(
    np.load('/content/drive/My Drive/Datasets/Autism/autism_patches.npy')
)
autism_labels = np.ones(len(autism_data))

# A second batch of patch files was previously concatenated here and is
# currently disabled:
#   '/content/drive/My Drive/Datasets/Autism/healthy_patches_2.npy'
#   '/content/drive/My Drive/Datasets/Autism/autism_patches_2.npy'

# The first 90% of the healthy patches fixes the per-class training size. The
# same cut-off is applied to the autism patches, so every patch past it (in
# either class) is held out for testing.
index = int(0.9 * healthy_data.shape[0])

# Training pool: equal numbers of healthy and autism patches, with a trailing
# channel axis added for Conv3D.
X = np.concatenate([healthy_data[:index], autism_data[:index]])
X_train = X[..., np.newaxis]
y_train = np.concatenate([healthy_labels[:index], autism_labels[:index]])

# Held-out test set: whatever remains of each class after the cut-off.
healthy_test = healthy_data[index:, ..., np.newaxis]
autism_test = autism_data[index:, ..., np.newaxis]
healthy_test_labels = healthy_labels[index:]
autism_test_labels = autism_labels[index:]

X_test = np.concatenate([healthy_test, autism_test])
y_test = np.concatenate([healthy_test_labels, autism_test_labels])

print(f'no.of healthy patches for testing are {len(healthy_test_labels)}')
print(f'no.of diseased patches for testing are {len(autism_test_labels)}')

shape of input is (32, 32, 32)
shape of input is (32, 32, 32)
no.of healthy patches for testing are 368
no.of diseased patches for testing are 728


In [None]:
# Carve a single stratified train/validation partition out of the training
# pool. random_state pins the shuffle so the partition is reproducible across
# kernel restarts.
folds = RepeatedStratifiedKFold(n_splits = 5, n_repeats = 1, random_state = 42)

# Only one partition is consumed downstream (run_experiment reads input_cv /
# input_test), so take the last fold directly instead of overwriting the
# variables on every loop iteration.
*_, (train_index, test_index) = folds.split(X_train, y_train)

input_cv, targets_cv = X_train[train_index], y_train[train_index]
# NOTE: despite the "_test" suffix, this fold is passed to model.fit as
# validation data; the held-out test set is X_test / y_test.
input_test, targets_test = X_train[test_index], y_train[test_index]

print(f'shape of input for CV is {input_cv.shape}')
print(f'input size for cross-validation is {len(targets_cv)}')
# Healthy patches carry label 0, so count the zeros; np.nonzero on the raw
# labels would count the autism class (label 1) instead.
print(f'no.of healthy in CV are {np.shape(np.nonzero(targets_cv == 0))[1]}')

print(f'shape of input for testing is {input_test.shape}')
print(f'input size for testing is {len(targets_test)}')
print(f'no.of healthy in test are {np.shape(np.nonzero(targets_test == 0))[1]}')

shape of input for CV is (5287, 32, 32, 32, 1)
input size for cross-validation is 5287
no.of healthy in CV are 2643
shape of input for testing is (1321, 32, 32, 32, 1)
input size for testing is 1321
no.of healthy in test are 661


In [None]:
# Optimiser settings, read by run_experiment when it builds AdamW.
learning_rate = 1e-3
weight_decay = 1e-4

# Training schedule.
num_epochs = 100
# NOTE(review): batch_size is defined here, but model.fit in run_experiment is
# called without a batch_size argument - confirm whether 128 was intended.
batch_size = 128

In [None]:
def activation_block(x):
    """Apply a GELU activation followed by batch normalization to ``x``."""
    activated = layers.Activation("gelu")(x)
    normalized = layers.BatchNormalization()(activated)
    return normalized


def conv_stem(x, filters: int, patch_size: int):
    """Embed the input volume as patches.

    A Conv3D whose kernel size and stride both equal ``patch_size`` maps each
    patch to ``filters`` channels; the result is passed through
    ``activation_block``.
    """
    patch_embedding = layers.Conv3D(
        filters,
        kernel_size=patch_size,
        strides=patch_size,
    )
    return activation_block(patch_embedding(x))


def conv_mixer_block(x, filters: int, kernel_size: int):
    """One ConvMixer-style block: spatial mixing with a residual, then a 1x1x1 conv.

    NOTE(review): the spatial-mixing step is a regular Conv3D with a single
    output filter, not a true depthwise convolution; its one-channel output is
    added to the block input, which presumably relies on the Add layer
    broadcasting over the channel axis - confirm this is intended.
    """
    skip = x

    # Spatial mixing (single-filter Conv3D, "same" padding keeps the grid size).
    spatial = layers.Conv3D(1, kernel_size=kernel_size, padding="same")(x)
    spatial = activation_block(spatial)

    # Residual connection around the spatial-mixing step.
    merged = layers.Add()([spatial, skip])

    # Pointwise (kernel size 1) convolution mixes information across channels.
    pointwise = layers.Conv3D(filters, kernel_size=1)(merged)
    return activation_block(pointwise)


def get_conv_mixer_256_8(
    image_size=32, filters=256, depth=16, kernel_size=5, patch_size=4, num_classes=1
):
    """Assemble a 3D ConvMixer-style classifier for cubic single-channel volumes.

    Based on ConvMixer (https://openreview.net/pdf?id=TVHS5Y4dNvM), built here
    from Conv3D layers.
    NOTE(review): the function name says 256/8 but the default ``depth`` is
    16 - confirm which configuration is intended.

    Args:
        image_size: Edge length of the cubic input volume.
        filters: Channel width used by the stem and every mixer block.
        depth: Number of ConvMixer blocks stacked after the stem.
        kernel_size: Kernel size of the spatial convolution in each block.
        patch_size: Kernel size and stride of the patch-embedding stem.
        num_classes: Number of units in the final sigmoid Dense layer.

    Returns:
        A ``keras.Model`` taking inputs of shape
        ``(image_size, image_size, image_size, 1)``.
    """
    volume_shape = (image_size,) * 3 + (1,)
    inputs = keras.Input(shape=volume_shape)

    # Patch embedding; the input is fed in as-is (no rescaling layer).
    features = conv_stem(inputs, filters, patch_size)

    # Stack of ConvMixer blocks.
    for _block in range(depth):
        features = conv_mixer_block(features, filters, kernel_size)

    # Classification head: global average pooling, then sigmoid output(s).
    pooled = layers.GlobalAvgPool3D()(features)
    outputs = layers.Dense(num_classes, activation="sigmoid")(pooled)

    return keras.Model(inputs=inputs, outputs=outputs)

In [None]:
def run_experiment(model):
    """Compile, train and evaluate ``model`` using the notebook's global data.

    Reads the module-level ``learning_rate``, ``weight_decay`` and
    ``num_epochs``, trains on ``input_cv``/``targets_cv``, validates on
    ``input_test``/``targets_test`` and finally evaluates on
    ``X_test``/``y_test``.

    NOTE(review): ``fit`` is called without a ``batch_size`` argument, so the
    module-level ``batch_size`` is not used here.

    Returns:
        ``(history, model)``: the object returned by ``model.fit`` and the
        same model after the best checkpointed weights have been loaded.
    """
    model.compile(
        optimizer=tfa.optimizers.AdamW(
            learning_rate=learning_rate, weight_decay=weight_decay
        ),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )

    # Keep only the weights from the epoch with the best validation accuracy.
    best_weights_path = "/tmp/checkpoint"
    keep_best = keras.callbacks.ModelCheckpoint(
        filepath=best_weights_path,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    history = model.fit(
        x=input_cv,
        y=targets_cv,
        validation_data=(input_test, targets_test),
        epochs=num_epochs,
        callbacks=[keep_best],
    )

    # Restore the best checkpoint before scoring the held-out test set.
    model.load_weights(best_weights_path)
    _, test_accuracy = model.evaluate(X_test, y_test)
    print(f"Test accuracy: {round(test_accuracy * 100, 2)}%")

    return history, model

In [None]:
# Build the ConvMixer with its default hyperparameters and print its layer table.
conv_mixer_model = get_conv_mixer_256_8()
conv_mixer_model.summary()
# Train and evaluate; run_experiment returns the fit history and the model it was given.
history, conv_mixer_model = run_experiment(conv_mixer_model)

Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            [(None, 32, 32, 32,  0                                            
__________________________________________________________________________________________________
conv3d (Conv3D)                 (None, 16, 16, 16, 2 2304        input_1[0][0]                    
__________________________________________________________________________________________________
activation (Activation)         (None, 16, 16, 16, 2 0           conv3d[0][0]                     
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 16, 16, 16, 2 1024        activation[0][0]                 
______________________________________________________________________________________________