-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
381 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,381 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "e0c3d03e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from tensorflow.keras.optimizers import Adam\n", | ||
"from keras.callbacks import CSVLogger, EarlyStopping, ModelCheckpoint\n", | ||
"from model.core import RUnet_model\n", | ||
"from utils.functions import r_score, mse_nonzero, data_split\n", | ||
"import numpy as np" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cdac820e", | ||
"metadata": {}, | ||
"source": [ | ||
"# Build and compile model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "33d0d397", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Model: \"model\"\n", | ||
"__________________________________________________________________________________________________\n", | ||
" Layer (type) Output Shape Param # Connected to \n", | ||
"==================================================================================================\n", | ||
" model_input (InputLayer) [(None, 48, 376, 10 0 [] \n", | ||
" )] \n", | ||
" \n", | ||
" lambda (Lambda) (None, 48, 376, 10) 0 ['model_input[0][0]'] \n", | ||
" \n", | ||
" Block1_Conv1 (Conv2D) (None, 48, 376, 128 11648 ['lambda[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block1_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block1_Conv1[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block1_MaxPool (MaxPooling2D) (None, 24, 188, 128 0 ['Block1_Conv2[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block2_Conv1 (Conv2D) (None, 24, 188, 256 295168 ['Block1_MaxPool[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block2_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block2_Conv1[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block2_MaxPool (MaxPooling2D) (None, 12, 94, 256) 0 ['Block2_Conv2[0][0]'] \n", | ||
" \n", | ||
" Block3_Conv1 (Conv2D) (None, 12, 94, 512) 1180160 ['Block2_MaxPool[0][0]'] \n", | ||
" \n", | ||
" Block3_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block3_Conv1[0][0]'] \n", | ||
" \n", | ||
" Block3_MaxPool (MaxPooling2D) (None, 6, 47, 512) 0 ['Block3_Conv2[0][0]'] \n", | ||
" \n", | ||
" Block4_Permute1 (Permute) (None, 512, 6, 47) 0 ['Block3_MaxPool[0][0]'] \n", | ||
" \n", | ||
" Block4_Reshape (Reshape) (None, 512, 282) 0 ['Block4_Permute1[0][0]'] \n", | ||
" \n", | ||
" Block4_Permute2 (Permute) (None, 282, 512) 0 ['Block4_Reshape[0][0]'] \n", | ||
" \n", | ||
" LSTM1 (LSTM) (None, 282, 1024) 6295552 ['Block4_Permute2[0][0]'] \n", | ||
" \n", | ||
" Block5_Reshape (Reshape) (None, 6, 47, 1024) 0 ['LSTM1[0][0]'] \n", | ||
" \n", | ||
" Block5_UpConv (Conv2DTranspose (None, 12, 94, 512) 2097664 ['Block5_Reshape[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" concatenate (Concatenate) (None, 12, 94, 1024 0 ['Block5_UpConv[0][0]', \n", | ||
" ) 'Block3_Conv2[0][0]'] \n", | ||
" \n", | ||
" Block5_Conv1 (Conv2D) (None, 12, 94, 512) 4719104 ['concatenate[0][0]'] \n", | ||
" \n", | ||
" Block5_Conv2 (Conv2D) (None, 12, 94, 512) 2359808 ['Block5_Conv1[0][0]'] \n", | ||
" \n", | ||
" Block6_UpConv (Conv2DTranspose (None, 24, 188, 256 524544 ['Block5_Conv2[0][0]'] \n", | ||
" ) ) \n", | ||
" \n", | ||
" concatenate_1 (Concatenate) (None, 24, 188, 512 0 ['Block6_UpConv[0][0]', \n", | ||
" ) 'Block2_Conv2[0][0]'] \n", | ||
" \n", | ||
" Block6_Conv1 (Conv2D) (None, 24, 188, 256 1179904 ['concatenate_1[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block6_Conv2 (Conv2D) (None, 24, 188, 256 590080 ['Block6_Conv1[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block7_UpConv (Conv2DTranspose (None, 48, 376, 128 131200 ['Block6_Conv2[0][0]'] \n", | ||
" ) ) \n", | ||
" \n", | ||
" concatenate_2 (Concatenate) (None, 48, 376, 256 0 ['Block7_UpConv[0][0]', \n", | ||
" ) 'Block1_Conv2[0][0]'] \n", | ||
" \n", | ||
" Block7_Conv1 (Conv2D) (None, 48, 376, 128 295040 ['concatenate_2[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" Block7_Conv2 (Conv2D) (None, 48, 376, 128 147584 ['Block7_Conv1[0][0]'] \n", | ||
" ) \n", | ||
" \n", | ||
" model_output_2 (Conv2D) (None, 48, 376, 2) 258 ['Block7_Conv2[0][0]'] \n", | ||
" \n", | ||
"==================================================================================================\n", | ||
"Total params: 22,925,186\n", | ||
"Trainable params: 22,925,186\n", | ||
"Non-trainable params: 0\n", | ||
"__________________________________________________________________________________________________\n" | ||
] | ||
}, | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"D:\\anaconda\\envs\\tf3\\lib\\site-packages\\keras\\optimizer_v2\\adam.py:105: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.\n", | ||
" super(Adam, self).__init__(name, **kwargs)\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# build a DL model for first two levels\n", | ||
"level1 = 1\n", | ||
"level2 = 2\n", | ||
"\n", | ||
"dl_model = RUnet_model(level1, level2)\n", | ||
"opt = Adam(lr=1e-5) \n", | ||
"\n", | ||
"dl_model.compile(optimizer=opt, loss=mse_nonzero, metrics=[r_score, mse_nonzero])\n", | ||
"dl_model.info()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "2bad7378", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load data sets using data generator" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "a64f192b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from utils.data_generator import data_generator\n", | ||
"import glob" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"id": "c02ff7f3", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# phase-1 training using preprocessed B-SOSE data sets\n", | ||
"xfiles = np.array(sorted(glob.glob('example_data\\phase1_example_data\\X/X*.npz')))\n", | ||
"yfiles = np.array(sorted(glob.glob('example_data\\phase1_example_data\\Y/Y*.npz')))\n", | ||
"\n", | ||
"print(xfiles, yfiles)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"id": "2286ec1d", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[4 5 2 6 1 9 8]\n", | ||
"['example_data\\\\phase1_example_data\\\\X\\\\X_004.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1030.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_002.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1031.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_001.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1034.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_1033.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_004.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1030.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_002.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1031.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_001.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1034.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_1033.npz']\n", | ||
"['example_data\\\\phase1_example_data\\\\X\\\\X_000.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_000.npz']\n", | ||
"['example_data\\\\phase1_example_data\\\\X\\\\X_1032.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\X\\\\X_003.npz'] ['example_data\\\\phase1_example_data\\\\Y\\\\Y_1032.npz'\n", | ||
" 'example_data\\\\phase1_example_data\\\\Y\\\\Y_003.npz']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# split data set into training, valid, testing sets\n", | ||
"xtrain, ytrain, xvalid, yvalid, xtest, ytest = data_split(xfiles, yfiles, 0.7225, 0.85, maskname='sample_train_valid_test_mask.npz')\n", | ||
"print(xtrain, ytrain)\n", | ||
"print(xvalid, yvalid)\n", | ||
"print(xtest, ytest)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 17, | ||
"id": "8b8ff51e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"train_generator = data_generator(xtrain, ytrain, level1, level2, batch_size=5)\n", | ||
"valid_generator = data_generator(xvalid, yvalid, level1, level2, batch_size=5)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f86f71d5", | ||
"metadata": {}, | ||
"source": [ | ||
"# Set up the training" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 18, | ||
"id": "396592b5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from keras.callbacks import EarlyStopping, ModelCheckpoint, CSVLogger" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 26, | ||
"id": "83ea3d0b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"csv_logger = CSVLogger( 'sample_log.csv' , append=True, separator=';')\n", | ||
"earlystopper = EarlyStopping(patience=20, verbose=1)\n", | ||
"checkpointer = ModelCheckpoint('checkpt.h5', verbose=1, save_best_only=True)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1407b218", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"### For phase-2 training, need to load model weights from phase-1:\n", | ||
"# dl_model.load_weights('path/to/phase1/model/weights')\n", | ||
"\n", | ||
"\n", | ||
"dl_model.train(train_generator,\n", | ||
" validation_data=valid_generator, epochs=250,\n", | ||
" callbacks=[earlystopper, checkpointer, csv_logger])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c28897d3", | ||
"metadata": {}, | ||
"source": [ | ||
"# Save the model weights" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "b8528c5d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"dl_model.save_model('sample_model.h5')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "657fdec4", | ||
"metadata": {}, | ||
"source": [ | ||
"# Load model weights and predict DIC" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"id": "ea1f74db", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"sample_data/phase1_sample_data/X\\X_001.npz\n", | ||
"X\\X_001.npz\n", | ||
"sample_data/phase1_sample_data/X\\X_003.npz\n", | ||
"X\\X_003.npz\n", | ||
"sample_data/phase1_sample_data/X\\X_007.npz\n", | ||
"X\\X_007.npz\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from utils.functions import process_x, add_buffer\n", | ||
"\n", | ||
"dl_model.load_weights('path/to/pretrained/model/weights')\n", | ||
"\n", | ||
"for x in xtest:\n", | ||
" xdata = np.load(x)['x']\n", | ||
" xdata = xdata[np.newaxis, :, :, :] # fill the batch dimension, 1, 10, 56, 360\n", | ||
" xdata = np.moveaxis(xdata, 1, -1) # rotate the dimension of the input vector to match the model: 1, 56, 360, 10\n", | ||
" xdata = process_x(xdata) # apply scaling factors\n", | ||
" xdata = add_buffer(xdata) # add buffer domain\n", | ||
" pred = dl_model.predict(xdata) \n", | ||
" pred = add_buffer(pred, direction=-1)[0] # remove buffer domain\n", | ||
" filename = 'path/to/output/prediction_' + x.split('_')[-1]\n", | ||
" print(\" >>>>>>>>> Saving: \", filename)\n", | ||
" np.save(filename, pred)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.10" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |