From ecfe432023bdc8a447a841571ad03b813dc30e62 Mon Sep 17 00:00:00 2001 From: Tai-Long He <33795640+tailonghe@users.noreply.github.com> Date: Tue, 14 Jun 2022 14:18:10 -0400 Subject: [PATCH] Update to the demo notebook --- example.ipynb | 381 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 381 insertions(+) create mode 100644 example.ipynb diff --git a/example.ipynb b/example.ipynb new file mode 100644 index 0000000..6510eb9 --- /dev/null +++ b/example.ipynb @@ -0,0 +1,381 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "e0c3d03e", + "metadata": {}, + "outputs": [], + "source": [ + "from tensorflow.keras.optimizers import Adam\n", + "from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint\n", + "from model.core import RUnet_model\n", + "from utils.functions import r_score, mse_nonzero, data_split\n", + "import numpy as np" + ] + }, + { + "cell_type": "markdown", + "id": "cdac820e", + "metadata": {}, + "source": [ + "# Build and compile model" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "33d0d397", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"model\"\n", + "__________________________________________________________________________________________________\n", + " Layer (type) Output Shape Param # Connected to \n", + "==================================================================================================\n", + " model_input (InputLayer) [(None, 48, 376, 10 0 [] \n", + " )] \n", + " \n", + " lambda (Lambda) (None, 48, 376, 10) 0 ['model_input[0][0]'] \n", + " \n", + " Block1_Conv1 (Conv2D) (None, 48, 376, 128 11648 ['lambda[0][0]'] \n", + " ) \n", + " \n", + " Block1_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block1_Conv1[0][0]'] \n", + " ) \n", + " \n", + " Block1_MaxPool (MaxPooling2D) (None, 24, 188, 128 0 ['Block1_Conv2[0][0]'] \n", + " ) \n", + " \n", + " Block2_Conv1 (Conv2D) (None, 24, 188, 256 295168 
['Block1_MaxPool[0][0]'] \n", + " ) \n", + " \n", + " Block2_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block2_Conv1[0][0]'] \n", + " ) \n", + " \n", + " Block2_MaxPool (MaxPooling2D) (None, 12, 94, 256) 0 ['Block2_Conv2[0][0]'] \n", + " \n", + " Block3_Conv1 (Conv2D) (None, 12, 94, 512) 1180160 ['Block2_MaxPool[0][0]'] \n", + " \n", + " Block3_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block3_Conv1[0][0]'] \n", + " \n", + " Block3_MaxPool (MaxPooling2D) (None, 6, 47, 512) 0 ['Block3_Conv2[0][0]'] \n", + " \n", + " Block4_Permute1 (Permute) (None, 512, 6, 47) 0 ['Block3_MaxPool[0][0]'] \n", + " \n", + " Block4_Reshape (Reshape) (None, 512, 282) 0 ['Block4_Permute1[0][0]'] \n", + " \n", + " Block4_Permute2 (Permute) (None, 282, 512) 0 ['Block4_Reshape[0][0]'] \n", + " \n", + " LSTM1 (LSTM) (None, 282, 1024) 6295552 ['Block4_Permute2[0][0]'] \n", + " \n", + " Block5_Reshape (Reshape) (None, 6, 47, 1024) 0 ['LSTM1[0][0]'] \n", + " \n", + " Block5_UpConv (Conv2DTranspose (None, 12, 94, 512) 2097664 ['Block5_Reshape[0][0]'] \n", + " ) \n", + " \n", + " concatenate (Concatenate) (None, 12, 94, 1024 0 ['Block5_UpConv[0][0]', \n", + " ) 'Block3_Conv2[0][0]'] \n", + " \n", + " Block5_Conv1 (Conv2D) (None, 12, 94, 512) 4719104 ['concatenate[0][0]'] \n", + " \n", + " Block5_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block5_Conv1[0][0]'] \n", + " \n", + " Block6_UpConv (Conv2DTranspose (None, 24, 188, 256 524544 ['Block5_Conv2[0][0]'] \n", + " ) ) \n", + " \n", + " concatenate_1 (Concatenate) (None, 24, 188, 512 0 ['Block6_UpConv[0][0]', \n", + " ) 'Block2_Conv2[0][0]'] \n", + " \n", + " Block6_Conv1 (Conv2D) (None, 24, 188, 256 1179904 ['concatenate_1[0][0]'] \n", + " ) \n", + " \n", + " Block6_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block6_Conv1[0][0]'] \n", + " ) \n", + " \n", + " Block7_UpConv (Conv2DTranspose (None, 48, 376, 128 131200 ['Block6_Conv2[0][0]'] \n", + " ) ) \n", + " \n", + " concatenate_2 (Concatenate) (None, 48, 376, 256 0 ['Block7_UpConv[0][0]', 
\n", + " ) 'Block1_Conv2[0][0]'] \n", + " \n", + " Block7_Conv1 (Conv2D) (None, 48, 376, 128 295040 ['concatenate_2[0][0]'] \n", + " ) \n", + " \n", + " Block7_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block7_Conv1[0][0]'] \n", + " ) \n", + " \n", + " model_output_2 (Conv2D) (None, 48, 376, 2) 258 ['Block7_Conv2[0][0]'] \n", + " \n", + "==================================================================================================\n", + "Total params: 22,925,186\n", + "Trainable params: 22,925,186\n", + "Non-trainable params: 0\n", + "__________________________________________________________________________________________________\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "D:\\anaconda\\envs\\tf3\\lib\\site-packages\\keras\\optimizer_v2\\adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", + " super(Adam, self).__init__(name, **kwargs)\n" + ] + } + ], + "source": [ + "# build and compile a DL model for the first two levels (level1..level2)\n", + "level1 = 1\n", + "level2 = 2\n", + "\n", + "dl_model = RUnet_model(level1, level2)\n", + "opt = Adam(learning_rate=1e-5)  # `lr` is deprecated in Keras; `learning_rate` is the supported keyword\n", + "\n", + "dl_model.compile(optimizer=opt, loss=mse_nonzero, metrics=[r_score, mse_nonzero])\n", + "dl_model.info()" + ] + }, + { + "cell_type": "markdown", + "id": "2bad7378", + "metadata": {}, + "source": [ + "# Load data sets using data generator" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a64f192b", + "metadata": {}, + "outputs": [], + "source": [ + "from utils.data_generator import data_generator\n", + "import glob" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c02ff7f3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n", + " 
'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz']\n" + ] + } + ], + "source": [ + "# phase-1 training using preprocessed B-SOSE data sets (patterns use '/' so they work on any OS)\n", + "xfiles = np.array(sorted(glob.glob('example_data/phase1_example_data/X/X*.npz')))\n", + "yfiles = np.array(sorted(glob.glob('example_data/phase1_example_data/Y/Y*.npz')))\n", + "\n", + "print(xfiles, yfiles)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "2286ec1d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[4 5 2 6 1 9 8]\n", + "['example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'\n", + " 
'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz']\n", + "['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz']\n", + "['example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n", + " 'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n", + " 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz']\n" + ] + } + ], + "source": [ + "# split data set into training, valid, testing sets\n", + "xtrain, ytrain, xvalid, yvalid, xtest, ytest = data_split(xfiles, yfiles, 0.7225, 0.85, maskname='sample_train_valid_test_mask.npz')\n", + "print(xtrain, ytrain)\n", + "print(xvalid, yvalid)\n", + "print(xtest, ytest)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "8b8ff51e", + "metadata": {}, + "outputs": [], + "source": [ + "train_generator = data_generator(xtrain, ytrain, level1, level2, batch_size=5)\n", + "valid_generator = data_generator(xvalid, yvalid, level1, level2, batch_size=5)" + ] + }, + { + "cell_type": "markdown", + "id": "f86f71d5", + "metadata": {}, + "source": [ + "# Set up the training" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "396592b5", + "metadata": {}, + "outputs": [], + "source": [ + "from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "83ea3d0b", + "metadata": {}, + "outputs": [], + "source": [ + "csv_logger = CSVLogger( 'sample_log.csv' , 
append=True, separator=';')\n", + "earlystopper = EarlyStopping(patience=20, verbose=1)\n", + "checkpointer = ModelCheckpoint('checkpt.h5', verbose=1, save_best_only=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1407b218", + "metadata": {}, + "outputs": [], + "source": [ + "### For phase-2 training, need to load model weights from phase-1:\n", + "# dl_model.load_weights('path/to/phase1/model/weights')\n", + "\n", + "\n", + "dl_model.train(train_generator,\n", + " validation_data=valid_generator, epochs=250,\n", + " callbacks=[earlystopper, checkpointer, csv_logger])" + ] + }, + { + "cell_type": "markdown", + "id": "c28897d3", + "metadata": {}, + "source": [ + "# Save the model weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b8528c5d", + "metadata": {}, + "outputs": [], + "source": [ + "dl_model.save_model('sample_model.h5')" + ] + }, + { + "cell_type": "markdown", + "id": "657fdec4", + "metadata": {}, + "source": [ + "# Load model weights and predict DIC" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "ea1f74db", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sample_data/phase1_sample_data/X\\X_001.npz\n", + "X\\X_001.npz\n", + "sample_data/phase1_sample_data/X\\X_003.npz\n", + "X\\X_003.npz\n", + "sample_data/phase1_sample_data/X\\X_007.npz\n", + "X\\X_007.npz\n" + ] + } + ], + "source": [ + "from utils.functions import process_x, add_buffer\n", + "\n", + "dl_model.load_weights('path/to/pretrained/model/weights')\n", + "\n", + "for x in xtest:\n", + " xdata = np.load(x)['x']\n", + " xdata = xdata[np.newaxis, :, :, :] # fill the batch dimension, 1, 10, 56, 360\n", + " xdata = np.moveaxis(xdata, 1, -1) # rotate the dimension of the input vector to match the model: 1, 56, 360, 10\n", + " xdata = process_x(xdata) # apply scaling factors\n", + " xdata = add_buffer(xdata) # add buffer domain\n", + " pred = dl_model.predict(xdata) 
\n", + " pred = add_buffer(pred, direction=-1)[0] # remove buffer domain\n", + " filename = 'path/to/output/prediction_' + x.split('_')[-1].replace('.npz', '.npy') # np.save appends '.npy' to any other suffix, so name the file '.npy' up front\n", + " print(\" >>>>>>>>> Saving: \", filename)\n", + " np.save(filename, pred)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}