In [None]:
#%matplotlib qt
import tensorflow as tf
import numpy as np
#import matplotlib.pyplot as plt
#from mpl_toolkits.mplot3d import Axes3D  # sometimes needed to register 3D
#from matplotlib import colors
#from matplotlib.widgets import Slider, Button


from medmnist import OrganMNIST3D


train_dataset = OrganMNIST3D(split='train', size=28, download=True)
trainx = []
trainy = []

test_dataset = OrganMNIST3D(split='test', size=28, download=True)
testx = []
testy = []

val_dataset = OrganMNIST3D(split='train', size=28, download=True)
valx = []
valy = []

for i in range(len(train_dataset)):
    trainx.append(train_dataset[i][0])
    trainy.append(train_dataset[i][1])

for i in range(len(test_dataset)):
    testx.append(test_dataset[i][0])
    testy.append(test_dataset[i][1])

for i in range(len(val_dataset)):
    valx.append(val_dataset[i][0])
    valy.append(val_dataset[i][1])


trainx_tensor = tf.convert_to_tensor(trainx, dtype=tf.float16)
trainy_tensor = tf.convert_to_tensor(trainy, dtype=tf.float16)
testx_tensor = tf.convert_to_tensor(testx, dtype=tf.float16)
testy_tensor = tf.convert_to_tensor(testy, dtype=tf.float16)
valx_tensor = tf.convert_to_tensor(valx, dtype=tf.float16)
valy_tensor = tf.convert_to_tensor(valy, dtype=tf.float16)
# float16 doesn't run any faster on the 4090s, but it cuts memory usage in half!


In [None]:

def myNet():

    model = tf.keras.Sequential(layers = [
        # #images only have 1 color scale (greyscale)
        tf.keras.layers.InputLayer(shape=(1, 28, 28, 28)),

        #(3, 3, 3) slides a 3x3x3 cube over the height, depth, width
        # can change bias to be 0.01 if the learning is very slow, helps kickstart it if needed later
        # Using data_format channels first because the input has the channels(axis here) first instead of last
        tf.keras.layers.Conv3D(32, (3,3,3), activation= 'relu', bias_initializer='zeros', data_format='channels_first'),
        tf.keras.layers.Conv3D(64, (2,2,2), activation= 'relu', bias_initializer='zeros', data_format='channels_first'),
        #pool size of (2, 2, 2) changes the input shape of (28, 28, 28, 1) to (14, 14, 14, 1)
        tf.keras.layers.MaxPooling3D(pool_size=(2,2,2), data_format='channels_first'),
        tf.keras.layers.Dropout(0.15),


        tf.keras.layers.Conv3D(128, (3,3,3), activation= 'relu', bias_initializer='zeros', data_format='channels_first'),
        tf.keras.layers.MaxPooling3D(pool_size=(2,2,2),data_format='channels_first'),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(0.3),
        
        tf.keras.layers.Dense(11, activation='softmax', dtype='float32', name= 'output') #needs to be 11 because there are 10 classifications
    ])
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])
    return model

In [None]:

class saveNet(tf.keras.callbacks.Callback):
    def __init__(self, n):
        super().__init__()
        self.save_rate = n

    def on_epoch_end(self, epoch, logs=None, save_rate=15):
        if (epoch + 1) % self.save_rate == 0:
            filename = f'checkpoints/epoch_{epoch+1}.weights.h5'
            self.model.save_weights(filename, overwrite=True)

checkpoint = saveNet(10)
model = myNet()

training_history = model.fit(
    trainx_tensor, trainy_tensor,
    batch_size= 16, 
    epochs= 20,
    callbacks = [checkpoint]
    )



Epoch 1/3
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m43s[0m 1s/step - accuracy: 0.2348 - loss: 2.5616
Epoch 2/3
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 1s/step - accuracy: 0.5633 - loss: 1.3059
Epoch 3/3
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m42s[0m 1s/step - accuracy: 0.7559 - loss: 0.7802


In [4]:
test_loss, test_acc = model.evaluate(testx_tensor, testy_tensor)
print(f"Test Accuracy: {round(test_acc * 100, 2)}%")

[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 308ms/step - accuracy: 0.7705 - loss: 0.7439
Test Accuracy: 77.05%
