Skip to content

Commit

Permalink
使用 Colaboratory 创建
Browse files Browse the repository at this point in the history
  • Loading branch information
tjwei committed Sep 2, 2019
1 parent 4fb3f2d commit 272e974
Showing 1 changed file with 225 additions and 0 deletions.
225 changes: 225 additions & 0 deletions LSTM-imdb.ipynb
@@ -0,0 +1,225 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.3"
},
"colab": {
"name": "04-LSTM-imdb.ipynb",
"version": "0.3.2",
"provenance": []
}
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "tecE8g82xhGm",
"colab_type": "text"
},
"source": [
"* https://github.com/fchollet/keras/blob/master/examples/imdb_lstm.py"
]
},
{
"cell_type": "code",
"metadata": {
"id": "cK8bGPjnxhGu",
"colab_type": "code",
"colab": {},
"outputId": "d127f15a-51af-4dae-a375-ce45a338430d"
},
"source": [
"import numpy as np\n",
"from keras.preprocessing import sequence\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Embedding\n",
"from keras.layers import LSTM\n",
"from keras.datasets import imdb"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
],
"name": "stderr"
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "PD1ijEhhxhG7",
"colab_type": "code",
"colab": {}
},
"source": [
"vocabulary_size = 15000\n",
"(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=vocabulary_size)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "tPSBkgBkxhHE",
"colab_type": "code",
"colab": {}
},
"source": [
"maxlen = 80\n",
"x_train = sequence.pad_sequences(x_train, maxlen=maxlen)\n",
"x_test = sequence.pad_sequences(x_test, maxlen=maxlen)"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "C9nqoeMxxhHO",
"colab_type": "code",
"colab": {}
},
"source": [
"model = Sequential()\n",
"model.add(Embedding(vocabulary_size, 128))\n",
"model.add(LSTM(64, dropout=0.2, recurrent_dropout=0.2))\n",
"model.add(Dense(1, activation='sigmoid'))\n",
"model.compile(loss='binary_crossentropy',\n",
" optimizer='adam',\n",
" metrics=['accuracy'])"
],
"execution_count": 0,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Y5S4ZQDPxhHY",
"colab_type": "code",
"colab": {},
"outputId": "47ae875c-f484-4011-d610-4b7eaf70ce82"
},
"source": [
"from IPython.display import SVG, display\n",
"from keras.utils.vis_utils import model_to_dot\n",
"\n",
"SVG(model_to_dot(model, show_shapes=True).create(prog='dot', format='svg'))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/svg+xml": "<svg height=\"304pt\" viewBox=\"0.00 0.00 348.00 304.00\" width=\"348pt\" xmlns=\"http://www.w3.org/2000/svg\" xmlns:xlink=\"http://www.w3.org/1999/xlink\">\n<g class=\"graph\" id=\"graph0\" transform=\"scale(1 1) rotate(0) translate(4 300)\">\n<title>G</title>\n<polygon fill=\"white\" points=\"-4,4 -4,-300 344,-300 344,4 -4,4\" stroke=\"none\"/>\n<!-- 140521308467664 -->\n<g class=\"node\" id=\"node1\"><title>140521308467664</title>\n<polygon fill=\"none\" points=\"0,-249.5 0,-295.5 340,-295.5 340,-249.5 0,-249.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"96.5\" y=\"-268.8\">embedding_1_input: InputLayer</text>\n<polyline fill=\"none\" points=\"193,-249.5 193,-295.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"220.5\" y=\"-280.3\">input:</text>\n<polyline fill=\"none\" points=\"193,-272.5 248,-272.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"220.5\" y=\"-257.3\">output:</text>\n<polyline fill=\"none\" points=\"248,-249.5 248,-295.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"294\" y=\"-280.3\">(None, None)</text>\n<polyline fill=\"none\" points=\"248,-272.5 340,-272.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"294\" y=\"-257.3\">(None, None)</text>\n</g>\n<!-- 140521873195304 -->\n<g class=\"node\" id=\"node2\"><title>140521873195304</title>\n<polygon fill=\"none\" points=\"2.5,-166.5 2.5,-212.5 337.5,-212.5 337.5,-166.5 2.5,-166.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"83\" y=\"-185.8\">embedding_1: Embedding</text>\n<polyline fill=\"none\" points=\"163.5,-166.5 163.5,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-197.3\">input:</text>\n<polyline fill=\"none\" points=\"163.5,-189.5 218.5,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"191\" y=\"-174.3\">output:</text>\n<polyline fill=\"none\" points=\"218.5,-166.5 218.5,-212.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"278\" y=\"-197.3\">(None, None)</text>\n<polyline fill=\"none\" points=\"218.5,-189.5 337.5,-189.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"278\" y=\"-174.3\">(None, None, 128)</text>\n</g>\n<!-- 140521308467664&#45;&gt;140521873195304 -->\n<g class=\"edge\" id=\"edge1\"><title>140521308467664-&gt;140521873195304</title>\n<path d=\"M170,-249.366C170,-241.152 170,-231.658 170,-222.725\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"173.5,-222.607 170,-212.607 166.5,-222.607 173.5,-222.607\" stroke=\"black\"/>\n</g>\n<!-- 140521308467888 -->\n<g class=\"node\" id=\"node3\"><title>140521308467888</title>\n<polygon fill=\"none\" points=\"34,-83.5 34,-129.5 306,-129.5 306,-83.5 34,-83.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"83\" y=\"-102.8\">lstm_1: LSTM</text>\n<polyline fill=\"none\" points=\"132,-83.5 132,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"159.5\" y=\"-114.3\">input:</text>\n<polyline fill=\"none\" points=\"132,-106.5 187,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"159.5\" y=\"-91.3\">output:</text>\n<polyline fill=\"none\" points=\"187,-83.5 187,-129.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"246.5\" y=\"-114.3\">(None, None, 128)</text>\n<polyline fill=\"none\" points=\"187,-106.5 306,-106.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"246.5\" y=\"-91.3\">(None, 64)</text>\n</g>\n<!-- 140521873195304&#45;&gt;140521308467888 -->\n<g class=\"edge\" id=\"edge2\"><title>140521873195304-&gt;140521308467888</title>\n<path d=\"M170,-166.366C170,-158.152 170,-148.658 170,-139.725\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"173.5,-139.607 170,-129.607 166.5,-139.607 173.5,-139.607\" stroke=\"black\"/>\n</g>\n<!-- 140521063067544 -->\n<g class=\"node\" id=\"node4\"><title>140521063067544</title>\n<polygon fill=\"none\" points=\"53.5,-0.5 53.5,-46.5 286.5,-46.5 286.5,-0.5 53.5,-0.5\" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"104.5\" y=\"-19.8\">dense_1: Dense</text>\n<polyline fill=\"none\" points=\"155.5,-0.5 155.5,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"183\" y=\"-31.3\">input:</text>\n<polyline fill=\"none\" points=\"155.5,-23.5 210.5,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"183\" y=\"-8.3\">output:</text>\n<polyline fill=\"none\" points=\"210.5,-0.5 210.5,-46.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.5\" y=\"-31.3\">(None, 64)</text>\n<polyline fill=\"none\" points=\"210.5,-23.5 286.5,-23.5 \" stroke=\"black\"/>\n<text font-family=\"Times,serif\" font-size=\"14.00\" text-anchor=\"middle\" x=\"248.5\" y=\"-8.3\">(None, 1)</text>\n</g>\n<!-- 140521308467888&#45;&gt;140521063067544 -->\n<g class=\"edge\" id=\"edge3\"><title>140521308467888-&gt;140521063067544</title>\n<path d=\"M170,-83.3664C170,-75.1516 170,-65.6579 170,-56.7252\" fill=\"none\" stroke=\"black\"/>\n<polygon fill=\"black\" points=\"173.5,-56.6068 170,-46.6068 166.5,-56.6069 173.5,-56.6068\" stroke=\"black\"/>\n</g>\n</g>\n</svg>",
"text/plain": [
"<IPython.core.display.SVG object>"
]
},
"metadata": {
"tags": []
},
"execution_count": 5
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "_Lk8halnxhHh",
"colab_type": "code",
"colab": {},
"outputId": "eaf9595f-0864-427e-d077-a00af93c3123"
},
"source": [
"model.fit(x_train, y_train,\n",
" batch_size=32,\n",
" epochs=8,\n",
" validation_data=(x_test, y_test))"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"Train on 25000 samples, validate on 25000 samples\n",
"Epoch 1/8\n",
"25000/25000 [==============================] - 66s - loss: 0.4564 - acc: 0.7864 - val_loss: 0.3689 - val_acc: 0.8323\n",
"Epoch 2/8\n",
"25000/25000 [==============================] - 63s - loss: 0.3023 - acc: 0.8754 - val_loss: 0.3948 - val_acc: 0.8256\n",
"Epoch 3/8\n",
"25000/25000 [==============================] - 68s - loss: 0.2303 - acc: 0.9099 - val_loss: 0.4206 - val_acc: 0.8339\n",
"Epoch 4/8\n",
"25000/25000 [==============================] - 66s - loss: 0.1733 - acc: 0.9318 - val_loss: 0.4517 - val_acc: 0.8327\n",
"Epoch 5/8\n",
"25000/25000 [==============================] - 64s - loss: 0.1282 - acc: 0.9525 - val_loss: 0.5009 - val_acc: 0.8264\n",
"Epoch 6/8\n",
"25000/25000 [==============================] - 65s - loss: 0.0940 - acc: 0.9651 - val_loss: 0.6640 - val_acc: 0.8196\n",
"Epoch 7/8\n",
"25000/25000 [==============================] - 66s - loss: 0.0724 - acc: 0.9745 - val_loss: 0.7094 - val_acc: 0.8218\n",
"Epoch 8/8\n",
"25000/25000 [==============================] - 65s - loss: 0.0643 - acc: 0.9775 - val_loss: 0.6907 - val_acc: 0.8200\n"
],
"name": "stdout"
},
{
"output_type": "execute_result",
"data": {
"text/plain": [
"<keras.callbacks.History at 0x7fcdcc55c1d0>"
]
},
"metadata": {
"tags": []
},
"execution_count": 6
}
]
},
{
"cell_type": "code",
"metadata": {
"id": "WldAOPrJxhHp",
"colab_type": "code",
"colab": {},
"outputId": "5f4a7eb7-65ab-48d1-c227-456a1442aad5"
},
"source": [
"score, acc = model.evaluate(x_test, y_test, batch_size=32)\n",
"print('Test score:', score)\n",
"print('Test accuracy:', acc)"
],
"execution_count": 0,
"outputs": [
{
"output_type": "stream",
"text": [
"25000/25000 [==============================] - 12s \n",
"Test score: 0.690729214754\n",
"Test accuracy: 0.82004\n"
],
"name": "stdout"
}
]
}
]
}

0 comments on commit 272e974

Please sign in to comment.