Skip to content

Commit

Permalink
Update to the demo notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
tailonghe committed Jun 14, 2022
1 parent 0d128c7 commit ecfe432
Showing 1 changed file with 381 additions and 0 deletions.
381 changes: 381 additions & 0 deletions example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,381 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "e0c3d03e",
"metadata": {},
"outputs": [],
"source": [
"from tensorflow.keras.optimizers import Adam\n",
"from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint\n",
"from model.core import RUnet_model\n",
"from utils.functions import r_score, mse_nonzero, data_split\n",
"import numpy as np"
]
},
{
"cell_type": "markdown",
"id": "cdac820e",
"metadata": {},
"source": [
"# Build and compile model"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "33d0d397",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Model: \"model\"\n",
"__________________________________________________________________________________________________\n",
" Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
" model_input (InputLayer) [(None, 48, 376, 10 0 [] \n",
" )] \n",
" \n",
" lambda (Lambda) (None, 48, 376, 10) 0 ['model_input[0][0]'] \n",
" \n",
" Block1_Conv1 (Conv2D) (None, 48, 376, 128 11648 ['lambda[0][0]'] \n",
" ) \n",
" \n",
" Block1_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block1_Conv1[0][0]'] \n",
" ) \n",
" \n",
" Block1_MaxPool (MaxPooling2D) (None, 24, 188, 128 0 ['Block1_Conv2[0][0]'] \n",
" ) \n",
" \n",
" Block2_Conv1 (Conv2D) (None, 24, 188, 256 295168 ['Block1_MaxPool[0][0]'] \n",
" ) \n",
" \n",
" Block2_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block2_Conv1[0][0]'] \n",
" ) \n",
" \n",
" Block2_MaxPool (MaxPooling2D) (None, 12, 94, 256) 0 ['Block2_Conv2[0][0]'] \n",
" \n",
" Block3_Conv1 (Conv2D) (None, 12, 94, 512) 1180160 ['Block2_MaxPool[0][0]'] \n",
" \n",
" Block3_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block3_Conv1[0][0]'] \n",
" \n",
" Block3_MaxPool (MaxPooling2D) (None, 6, 47, 512) 0 ['Block3_Conv2[0][0]'] \n",
" \n",
" Block4_Permute1 (Permute) (None, 512, 6, 47) 0 ['Block3_MaxPool[0][0]'] \n",
" \n",
" Block4_Reshape (Reshape) (None, 512, 282) 0 ['Block4_Permute1[0][0]'] \n",
" \n",
" Block4_Permute2 (Permute) (None, 282, 512) 0 ['Block4_Reshape[0][0]'] \n",
" \n",
" LSTM1 (LSTM) (None, 282, 1024) 6295552 ['Block4_Permute2[0][0]'] \n",
" \n",
" Block5_Reshape (Reshape) (None, 6, 47, 1024) 0 ['LSTM1[0][0]'] \n",
" \n",
" Block5_UpConv (Conv2DTranspose (None, 12, 94, 512) 2097664 ['Block5_Reshape[0][0]'] \n",
" ) \n",
" \n",
" concatenate (Concatenate) (None, 12, 94, 1024 0 ['Block5_UpConv[0][0]', \n",
" ) 'Block3_Conv2[0][0]'] \n",
" \n",
" Block5_Conv1 (Conv2D) (None, 12, 94, 512) 4719104 ['concatenate[0][0]'] \n",
" \n",
" Block5_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block5_Conv1[0][0]'] \n",
" \n",
" Block6_UpConv (Conv2DTranspose (None, 24, 188, 256 524544 ['Block5_Conv2[0][0]'] \n",
" ) ) \n",
" \n",
" concatenate_1 (Concatenate) (None, 24, 188, 512 0 ['Block6_UpConv[0][0]', \n",
" ) 'Block2_Conv2[0][0]'] \n",
" \n",
" Block6_Conv1 (Conv2D) (None, 24, 188, 256 1179904 ['concatenate_1[0][0]'] \n",
" ) \n",
" \n",
" Block6_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block6_Conv1[0][0]'] \n",
" ) \n",
" \n",
" Block7_UpConv (Conv2DTranspose (None, 48, 376, 128 131200 ['Block6_Conv2[0][0]'] \n",
" ) ) \n",
" \n",
" concatenate_2 (Concatenate) (None, 48, 376, 256 0 ['Block7_UpConv[0][0]', \n",
" ) 'Block1_Conv2[0][0]'] \n",
" \n",
" Block7_Conv1 (Conv2D) (None, 48, 376, 128 295040 ['concatenate_2[0][0]'] \n",
" ) \n",
" \n",
" Block7_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block7_Conv1[0][0]'] \n",
" ) \n",
" \n",
" model_output_2 (Conv2D) (None, 48, 376, 2) 258 ['Block7_Conv2[0][0]'] \n",
" \n",
"==================================================================================================\n",
"Total params: 22,925,186\n",
"Trainable params: 22,925,186\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"D:\\anaconda\\envs\\tf3\\lib\\site-packages\\keras\\optimizer_v2\\adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n",
" super(Adam, self).__init__(name, **kwargs)\n"
]
}
],
"source": [
"# build a DL model for first two levels\n",
"level1 = 1\n",
"level2 = 2\n",
"\n",
"dl_model = RUnet_model(level1, level2)\n",
"opt = Adam(lr=1e-5) \n",
"\n",
"dl_model.compile(optimizer=opt, loss=mse_nonzero, metrics=[r_score, mse_nonzero])\n",
"dl_model.info()"
]
},
{
"cell_type": "markdown",
"id": "2bad7378",
"metadata": {},
"source": [
"# Load data sets using data generator"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "a64f192b",
"metadata": {},
"outputs": [],
"source": [
"from utils.data_generator import data_generator\n",
"import glob"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "c02ff7f3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz']\n"
]
}
],
"source": [
"# phase-1 training using preprocessed B-SOSE data sets\n",
"xfiles = np.array(sorted(glob.glob('example_data\\phase1_example_data\\X/X*.npz')))\n",
"yfiles = np.array(sorted(glob.glob('example_data\\phase1_example_data\\Y/Y*.npz')))\n",
"\n",
"print(xfiles, yfiles)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "2286ec1d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[4 5 2 6 1 9 8]\n",
"['example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz']\n",
"['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz']\n",
"['example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n",
" 'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n",
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz']\n"
]
}
],
"source": [
"# split data set into training, valid, testing sets\n",
"xtrain, ytrain, xvalid, yvalid, xtest, ytest = data_split(xfiles, yfiles, 0.7225, 0.85, maskname='sample_train_valid_test_mask.npz')\n",
"print(xtrain, ytrain)\n",
"print(xvalid, yvalid)\n",
"print(xtest, ytest)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "8b8ff51e",
"metadata": {},
"outputs": [],
"source": [
"train_generator = data_generator(xtrain, ytrain, level1, level2, batch_size=5)\n",
"valid_generator = data_generator(xvalid, yvalid, level1, level2, batch_size=5)"
]
},
{
"cell_type": "markdown",
"id": "f86f71d5",
"metadata": {},
"source": [
"# Set up the training"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "396592b5",
"metadata": {},
"outputs": [],
"source": [
"from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "83ea3d0b",
"metadata": {},
"outputs": [],
"source": [
"csv_logger = CSVLogger( 'sample_log.csv' , append=True, separator=';')\n",
"earlystopper = EarlyStopping(patience=20, verbose=1)\n",
"checkpointer = ModelCheckpoint('checkpt.h5', verbose=1, save_best_only=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1407b218",
"metadata": {},
"outputs": [],
"source": [
"### For phase-2 training, need to load model weights from phase-1:\n",
"# dl_model.load_weights('path/to/phase1/model/weights')\n",
"\n",
"\n",
"dl_model.train(train_generator,\n",
" validation_data=valid_generator, epochs=250,\n",
" callbacks=[earlystopper, checkpointer, csv_logger])"
]
},
{
"cell_type": "markdown",
"id": "c28897d3",
"metadata": {},
"source": [
"# Save the model weights"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b8528c5d",
"metadata": {},
"outputs": [],
"source": [
"dl_model.save_model('sample_model.h5')"
]
},
{
"cell_type": "markdown",
"id": "657fdec4",
"metadata": {},
"source": [
"# Load model weights and predict DIC"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "ea1f74db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"sample_data/phase1_sample_data/X\\X_001.npz\n",
"X\\X_001.npz\n",
"sample_data/phase1_sample_data/X\\X_003.npz\n",
"X\\X_003.npz\n",
"sample_data/phase1_sample_data/X\\X_007.npz\n",
"X\\X_007.npz\n"
]
}
],
"source": [
"from utils.functions import process_x, add_buffer\n",
"\n",
"dl_model.load_weights('path/to/pretrained/model/weights')\n",
"\n",
"for x in xtest:\n",
" xdata = np.load(x)['x']\n",
" xdata = xdata[np.newaxis, :, :, :] # fill the batch dimension, 1, 10, 56, 360\n",
" xdata = np.moveaxis(xdata, 1, -1) # rotate the dimension of the input vector to match the model: 1, 56, 360, 10\n",
" xdata = process_x(xdata) # apply scaling factors\n",
" xdata = add_buffer(xdata) # add buffer domain\n",
" pred = dl_model.predict(xdata) \n",
" pred = add_buffer(pred, direction=-1)[0] # remove buffer domain\n",
" filename = 'path/to/output/prediction_' + x.split('_')[-1]\n",
" print(\" >>>>>>>>> Saving: \", filename)\n",
" np.save(filename, pred)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
}
},
"nbformat": 4,
"nbformat_minor": 5
}

0 comments on commit ecfe432

Please sign in to comment.