"
+ ]
+ },
+ "metadata": {
+ "tags": []
+ },
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Saving train_PlayMusic_full.json to train_PlayMusic_full.json\n",
+ "Saving validate_PlayMusic.json to validate_PlayMusic.json\n"
+ ]
+ }
+ ],
"source": [
- "train_loc = \"Data/snips/train_PlayMusic_full.json\"\n",
- "test_loc = \"Data/snips/validate_PlayMusic.json\"\n",
+ "try:\n",
+ " from google.colab import files\n",
+ "# upload \"train_PlayMusic_full.json\" and \"validate_PlayMusic.json\" here\n",
+ " uploaded = files.upload()\n",
+ " train_loc = \"train_PlayMusic_full.json\"\n",
+ " test_loc = \"validate_PlayMusic.json\"\n",
+ " \n",
+ "except ModuleNotFoundError:\n",
+ " train_loc = \"Data/snips/train_PlayMusic_full.json\"\n",
+ " test_loc = \"Data/snips/validate_PlayMusic.json\"\n",
"\n",
"train_file = json.load(open(train_loc, encoding= \"iso-8859-2\"))\n",
"test_file = json.load(open(test_loc, encoding= \"iso-8859-2\"))"
@@ -62,7 +137,8 @@
"ExecuteTime": {
"end_time": "2020-01-28T19:01:10.791023Z",
"start_time": "2020-01-28T19:01:10.786037Z"
- }
+ },
+ "id": "LwUcnYOwPpU0"
},
"outputs": [],
"source": [
@@ -77,7 +153,8 @@
"ExecuteTime": {
"end_time": "2020-01-28T19:01:10.798005Z",
"start_time": "2020-01-28T19:01:10.792054Z"
- }
+ },
+ "id": "crYoV0k9PpU0"
},
"outputs": [],
"source": [
@@ -109,7 +186,12 @@
"ExecuteTime": {
"end_time": "2020-01-28T19:01:11.033375Z",
"start_time": "2020-01-28T19:01:10.799002Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "2n3GrvhXPpU1",
+ "outputId": "cecb04b9-5666-44de-906e-1f49eb5ac448"
},
"outputs": [
{
@@ -2226,18 +2308,73 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:11.037401Z",
"start_time": "2020-01-28T19:01:11.034387Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "6lkzF_a3PpU1",
+ "outputId": "ae9b6255-109a-44a7-edd3-a9c4af53b061"
},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "--2021-06-09 11:23:24-- http://nlp.stanford.edu/data/glove.6B.zip\n",
+ "Resolving nlp.stanford.edu (nlp.stanford.edu)... 171.64.67.140\n",
+ "Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:80... connected.\n",
+ "HTTP request sent, awaiting response... 302 Found\n",
+ "Location: https://nlp.stanford.edu/data/glove.6B.zip [following]\n",
+ "--2021-06-09 11:23:24-- https://nlp.stanford.edu/data/glove.6B.zip\n",
+ "Connecting to nlp.stanford.edu (nlp.stanford.edu)|171.64.67.140|:443... connected.\n",
+ "HTTP request sent, awaiting response... 301 Moved Permanently\n",
+ "Location: http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip [following]\n",
+ "--2021-06-09 11:23:25-- http://downloads.cs.stanford.edu/nlp/data/glove.6B.zip\n",
+ "Resolving downloads.cs.stanford.edu (downloads.cs.stanford.edu)... 171.64.64.22\n",
+ "Connecting to downloads.cs.stanford.edu (downloads.cs.stanford.edu)|171.64.64.22|:80... connected.\n",
+ "HTTP request sent, awaiting response... 200 OK\n",
+ "Length: 862182613 (822M) [application/zip]\n",
+ "Saving to: ‘glove.6B/glove.6B.zip’\n",
+ "\n",
+ "glove.6B.zip 100%[===================>] 822.24M 5.12MB/s in 2m 42s \n",
+ "\n",
+ "2021-06-09 11:26:07 (5.08 MB/s) - ‘glove.6B/glove.6B.zip’ saved [862182613/862182613]\n",
+ "\n",
+ "Archive: glove.6B/glove.6B.zip\n",
+ " inflating: glove.6B/glove.6B.50d.txt \n",
+ " inflating: glove.6B/glove.6B.100d.txt \n",
+ " inflating: glove.6B/glove.6B.200d.txt \n",
+ " inflating: glove.6B/glove.6B.300d.txt \n"
+ ]
+ }
+ ],
"source": [
- "BASE_DIR = 'Data'\n",
- "GLOVE_DIR = os.path.join(BASE_DIR, 'glove.6B')\n",
+ "try :\n",
+ " from google.colab import files\n",
+ " !wget -P glove.6B http://nlp.stanford.edu/data/glove.6B.zip\n",
+ " !unzip glove.6B/glove.6B.zip -d glove.6B\n",
+ " BASE_DIR='.'\n",
+ "except ModuleNotFoundError :\n",
+ " if not os.path.exists(os.getcwd()+'\\\\Data\\\\glove.6B'):\n",
+ " os.mkdir(os.getcwd()+'\\\\Data\\\\glove.6B')\n",
"\n",
+ " url='http://nlp.stanford.edu/data/glove.6B.zip' \n",
+ " path=os.getcwd()+'\\Data' \n",
+ " wget.download(url,path) \n",
+ "\n",
+ " temp=path+'\\glove.6B.zip' \n",
+ " file = ZipFile(temp) \n",
+ " file.extractall(path+'\\glove.6B') \n",
+ " file.close() \n",
+ " \n",
+ " BASE_DIR = 'Data' \n",
+ " \n",
+ "GLOVE_DIR = os.path.join(BASE_DIR, 'glove.6B')\n",
"MAX_SEQUENCE_LENGTH = 300\n",
"MAX_NUM_WORDS = 20000 \n",
"EMBEDDING_DIM = 100 \n",
@@ -2246,12 +2383,17 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.707203Z",
"start_time": "2020-01-28T19:01:11.039359Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "IEGeeSVEPpU2",
+ "outputId": "14ba9578-6b14-43e9-9d2c-6a5213d59a88"
},
"outputs": [
{
@@ -2259,7 +2401,7 @@
"output_type": "stream",
"text": [
"Preparing embedding matrix.\n",
- "Found 400001 word vectors in Glove embeddings.\n"
+ "Found 400000 word vectors in Glove embeddings.\n"
]
}
],
@@ -2289,22 +2431,33 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.719134Z",
"start_time": "2020-01-28T19:01:22.709201Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 35
+ },
+ "id": "MSwvgNxaPpU3",
+ "outputId": "163ad86a-d238-4b70-a17c-3ca2a26d833e"
},
"outputs": [
{
"data": {
+ "application/vnd.google.colaboratory.intrinsic+json": {
+ "type": "string"
+ },
"text/plain": [
"'I need to hear the song Aspro Mavro from Bill Szymczyk on Youtube'"
]
},
- "execution_count": 9,
- "metadata": {},
+ "execution_count": 8,
+ "metadata": {
+ "tags": []
+ },
"output_type": "execute_result"
}
],
@@ -2317,12 +2470,17 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.773986Z",
"start_time": "2020-01-28T19:01:22.720185Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "jNhOyHPbPpU3",
+ "outputId": "1c559d69-f92c-4933-a4c8-7d960f509fc2"
},
"outputs": [
{
@@ -2344,12 +2502,13 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.785954Z",
"start_time": "2020-01-28T19:01:22.775018Z"
- }
+ },
+ "id": "R32CcFGaPpU3"
},
"outputs": [],
"source": [
@@ -2422,12 +2581,13 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.792969Z",
"start_time": "2020-01-28T19:01:22.786951Z"
- }
+ },
+ "id": "UWvHwHgjPpU4"
},
"outputs": [],
"source": [
@@ -2443,12 +2603,13 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.799982Z",
"start_time": "2020-01-28T19:01:22.793978Z"
- }
+ },
+ "id": "u30ziBm0PpU5"
},
"outputs": [],
"source": [
@@ -2469,12 +2630,13 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.807948Z",
"start_time": "2020-01-28T19:01:22.800914Z"
- }
+ },
+ "id": "Kn-9DabIPpU5"
},
"outputs": [],
"source": [
@@ -2502,12 +2664,13 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 14,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:22.813930Z",
"start_time": "2020-01-28T19:01:22.808893Z"
- }
+ },
+ "id": "zkjAIvckPpU5"
},
"outputs": [],
"source": [
@@ -2523,19 +2686,43 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 15,
"metadata": {
"ExecuteTime": {
"end_time": "2020-01-28T19:01:34.092725Z",
"start_time": "2020-01-28T19:01:22.814878Z"
- }
+ },
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "duWlQXq8PpU6",
+ "outputId": "76ab9975-37fc-4573-db35-fcda9b7fd642",
+ "scrolled": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Training a Sequence classification model with CRF\n",
+ "Training a Sequence classification model with CRF\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1515: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " average, \"true nor predicted\", 'F-score is', len(true_sum)\n",
+ "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, msg_start, len(result))\n",
+ "/usr/local/lib/python3.7/dist-packages/sklearn/metrics/_classification.py:1272: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
+ " _warn_prf(average, modifier, msg_start, len(result))\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
"0.8560889758746073\n",
" precision recall f1-score support\n",
"\n",
@@ -2633,20 +2820,6 @@
" music_item-2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 1 3\n",
"Done with sequence model\n"
]
- },
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/home/samyak/OpenSource/practical-nlp/env/lib64/python3.8/site-packages/sklearn/metrics/_classification.py:1464: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no true nor predicted samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(\n",
- "/home/samyak/OpenSource/practical-nlp/env/lib64/python3.8/site-packages/sklearn/utils/validation.py:68: FutureWarning: Pass labels=['O', 'year-1', 'genre-1', 'genre-2', 'genre-3', 'genre-4', 'genre-5', 'genre-6', 'service-1', 'service-2', 'playlist-1', 'playlist-2', 'playlist-3', 'playlist-4', 'playlist-5', 'playlist-6', 'album-1', 'album-2', 'album-3', 'album-4', 'album-5', 'album-6', 'album-7', 'album-8', 'sort-1', 'sort-2', 'track-1', 'track-2', 'track-3', 'track-4', 'track-5', 'track-6', 'track-7', 'track-8', 'artist-1', 'artist-2', 'artist-3', 'artist-4', 'artist-5', 'artist-6', 'music_item-1', 'music_item-2'] as keyword args. From version 0.25 passing these as positional arguments will result in an error\n",
- " warnings.warn(\"Pass {} as keyword args. From version 0.25 \"\n",
- "/home/samyak/OpenSource/practical-nlp/env/lib64/python3.8/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(average, modifier, msg_start, len(result))\n",
- "/home/samyak/OpenSource/practical-nlp/env/lib64/python3.8/site-packages/sklearn/metrics/_classification.py:1221: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
- " _warn_prf(average, modifier, msg_start, len(result))\n"
- ]
}
],
"source": [
@@ -2656,16 +2829,14 @@
"train_seq(feats, labels, devfeats, devlabels)\n",
"print(\"Done with sequence model\")"
]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
+ "colab": {
+ "collapsed_sections": [],
+ "name": "04_CRF_SNIPS_slots.ipynb",
+ "provenance": []
+ },
"kernelspec": {
"display_name": "Python 3",
"language": "python",
@@ -2681,7 +2852,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.6.12"
+ "version": "3.6.13"
},
"toc": {
"base_numbering": 1,
@@ -2727,5 +2898,5 @@
}
},
"nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 1
}
From d2ebc54935ea13b38021c07af1af4fe9673b21be Mon Sep 17 00:00:00 2001
From: Kumar Apurva <66004696+KUMAR-APURVA@users.noreply.github.com>
Date: Sun, 13 Jun 2021 19:59:16 +0530
Subject: [PATCH 05/15] [Ch6Nb03] Added missing functions and missing files.
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
1. Added ‘if torch.cuda.is_available()’ to check if we are using gpu.
2. Added a try except block to upload files to colab.
3. Added ‘get_data’ and ‘get_data2’ functions as they were missing.
4. Fixed path for windows and colab.
5. ‘load_atis’ and ‘load_data’ functions were missing, so added it.
6. ‘atis.test.pkl’ and ‘atis.train.pkl’ files were missing, so added these files.
7. Added a if condition for using ‘model.cuda()’ only when we are using gpu.
---
Ch6/03_BERT_ATIS_Binary.ipynb | 983 ++++++++++++++--------------------
1 file changed, 404 insertions(+), 579 deletions(-)
diff --git a/Ch6/03_BERT_ATIS_Binary.ipynb b/Ch6/03_BERT_ATIS_Binary.ipynb
index a870309..25d3d4c 100644
--- a/Ch6/03_BERT_ATIS_Binary.ipynb
+++ b/Ch6/03_BERT_ATIS_Binary.ipynb
@@ -3,7 +3,6 @@
{
"cell_type": "markdown",
"metadata": {
- "colab_type": "text",
"id": "b7We9WSKH7DN"
},
"source": [
@@ -13,92 +12,61 @@
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "vKbMCwXXgqC9"
+ },
"source": [
"## Imports"
]
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": 1,
"metadata": {
"colab": {
- "base_uri": "https://localhost:8080/",
- "height": 546
+ "base_uri": "https://localhost:8080/"
},
- "colab_type": "code",
"id": "Mk2-vK00E2ms",
- "outputId": "beab72e6-87de-4c58-9a56-00e371139a0e"
+ "outputId": "e2dec994-33b8-46fe-a0bd-0f09ed0ee6ec"
},
"outputs": [
- {
- "data": {
- "text/html": [
- "\n",
- "The default version of TensorFlow in Colab will soon switch to TensorFlow 2.x.
\n",
- "We recommend you upgrade now \n",
- "or ensure your notebook will continue to use TensorFlow 1.x via the %tensorflow_version 1.x magic:\n",
- "more info.
\n"
- ],
- "text/plain": [
- ""
- ]
- },
- "metadata": {
- "tags": []
- },
- "output_type": "display_data"
- },
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting pytorch-pretrained-bert\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/d7/e0/c08d5553b89973d9a240605b9c12404bcf8227590de62bae27acbcfe076b/pytorch_pretrained_bert-0.6.2-py3-none-any.whl (123kB)\n",
- "\r",
- "\u001b[K |██▋ | 10kB 22.7MB/s eta 0:00:01\r",
- "\u001b[K |█████▎ | 20kB 2.1MB/s eta 0:00:01\r",
- "\u001b[K |████████ | 30kB 3.1MB/s eta 0:00:01\r",
- "\u001b[K |██████████▋ | 40kB 2.1MB/s eta 0:00:01\r",
- "\u001b[K |█████████████▎ | 51kB 2.5MB/s eta 0:00:01\r",
- "\u001b[K |███████████████▉ | 61kB 3.0MB/s eta 0:00:01\r",
- "\u001b[K |██████████████████▌ | 71kB 3.5MB/s eta 0:00:01\r",
- "\u001b[K |█████████████████████▏ | 81kB 3.9MB/s eta 0:00:01\r",
- "\u001b[K |███████████████████████▉ | 92kB 4.4MB/s eta 0:00:01\r",
- "\u001b[K |██████████████████████████▌ | 102kB 3.4MB/s eta 0:00:01\r",
- "\u001b[K |█████████████████████████████▏ | 112kB 3.4MB/s eta 0:00:01\r",
- "\u001b[K |███████████████████████████████▊| 122kB 3.4MB/s eta 0:00:01\r",
- "\u001b[K |████████████████████████████████| 133kB 3.4MB/s \n",
+ "\u001b[K |████████████████████████████████| 133kB 7.9MB/s \n",
"\u001b[?25hCollecting pytorch-nlp\n",
"\u001b[?25l Downloading https://files.pythonhosted.org/packages/4f/51/f0ee1efb75f7cc2e3065c5da1363d6be2eec79691b2821594f3f2329528c/pytorch_nlp-0.5.0-py3-none-any.whl (90kB)\n",
- "\r",
- "\u001b[K |███▋ | 10kB 26.7MB/s eta 0:00:01\r",
- "\u001b[K |███████▎ | 20kB 32.9MB/s eta 0:00:01\r",
- "\u001b[K |███████████ | 30kB 39.6MB/s eta 0:00:01\r",
- "\u001b[K |██████████████▌ | 40kB 44.0MB/s eta 0:00:01\r",
- "\u001b[K |██████████████████▏ | 51kB 46.7MB/s eta 0:00:01\r",
- "\u001b[K |█████████████████████▉ | 61kB 48.7MB/s eta 0:00:01\r",
- "\u001b[K |█████████████████████████▌ | 71kB 50.1MB/s eta 0:00:01\r",
- "\u001b[K |█████████████████████████████ | 81kB 50.9MB/s eta 0:00:01\r",
- "\u001b[K |████████████████████████████████| 92kB 12.3MB/s \n",
- "\u001b[?25hRequirement already satisfied: boto3 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.10.47)\n",
- "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.17.5)\n",
- "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (4.28.1)\n",
- "Requirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (1.3.1)\n",
- "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (2.21.0)\n",
- "Requirement already satisfied: regex in /usr/local/lib/python3.6/dist-packages (from pytorch-pretrained-bert) (2019.12.20)\n",
- "Requirement already satisfied: s3transfer<0.3.0,>=0.2.0 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (0.2.1)\n",
- "Requirement already satisfied: botocore<1.14.0,>=1.13.47 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (1.13.47)\n",
- "Requirement already satisfied: jmespath<1.0.0,>=0.7.1 in /usr/local/lib/python3.6/dist-packages (from boto3->pytorch-pretrained-bert) (0.9.4)\n",
- "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (2019.11.28)\n",
- "Requirement already satisfied: idna<2.9,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (2.8)\n",
- "Requirement already satisfied: chardet<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (3.0.4)\n",
- "Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->pytorch-pretrained-bert) (1.24.3)\n",
- "Requirement already satisfied: python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\" in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->pytorch-pretrained-bert) (2.6.1)\n",
- "Requirement already satisfied: docutils<0.16,>=0.10 in /usr/local/lib/python3.6/dist-packages (from botocore<1.14.0,>=1.13.47->boto3->pytorch-pretrained-bert) (0.15.2)\n",
- "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.6/dist-packages (from python-dateutil<3.0.0,>=2.1; python_version >= \"2.7\"->botocore<1.14.0,>=1.13.47->boto3->pytorch-pretrained-bert) (1.12.0)\n",
- "Installing collected packages: pytorch-pretrained-bert, pytorch-nlp\n",
- "Successfully installed pytorch-nlp-0.5.0 pytorch-pretrained-bert-0.6.2\n"
+ "\u001b[K |████████████████████████████████| 92kB 10.6MB/s \n",
+ "\u001b[?25hRequirement already satisfied: torch>=0.4.1 in /usr/local/lib/python3.7/dist-packages (from pytorch-pretrained-bert) (1.8.1+cu101)\n",
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.7/dist-packages (from pytorch-pretrained-bert) (1.19.5)\n",
+ "Requirement already satisfied: requests in /usr/local/lib/python3.7/dist-packages (from pytorch-pretrained-bert) (2.23.0)\n",
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.7/dist-packages (from pytorch-pretrained-bert) (4.41.1)\n",
+ "Collecting boto3\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/7c/e1/1b164502f455035def771ec7a31f705351b7f953695d57ce26219aaf21a9/boto3-1.17.90-py2.py3-none-any.whl (131kB)\n",
+ "\u001b[K |████████████████████████████████| 133kB 22.1MB/s \n",
+ "\u001b[?25hRequirement already satisfied: regex in /usr/local/lib/python3.7/dist-packages (from pytorch-pretrained-bert) (2019.12.20)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.7/dist-packages (from torch>=0.4.1->pytorch-pretrained-bert) (3.7.4.3)\n",
+ "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.7/dist-packages (from requests->pytorch-pretrained-bert) (1.24.3)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.7/dist-packages (from requests->pytorch-pretrained-bert) (2020.12.5)\n",
+ "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.7/dist-packages (from requests->pytorch-pretrained-bert) (3.0.4)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.7/dist-packages (from requests->pytorch-pretrained-bert) (2.10)\n",
+ "Collecting botocore<1.21.0,>=1.20.90\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/4a/ac/617d3ac25ea905279deb06edd82d6c19ca272006d6dcf232b837b75c3dde/botocore-1.20.90-py2.py3-none-any.whl (7.6MB)\n",
+ "\u001b[K |████████████████████████████████| 7.6MB 25.6MB/s \n",
+ "\u001b[?25hCollecting s3transfer<0.5.0,>=0.4.0\n",
+ "\u001b[?25l Downloading https://files.pythonhosted.org/packages/63/d0/693477c688348654ddc21dcdce0817653a294aa43f41771084c25e7ff9c7/s3transfer-0.4.2-py2.py3-none-any.whl (79kB)\n",
+ "\u001b[K |████████████████████████████████| 81kB 10.0MB/s \n",
+ "\u001b[?25hCollecting jmespath<1.0.0,>=0.7.1\n",
+ " Downloading https://files.pythonhosted.org/packages/07/cb/5f001272b6faeb23c1c9e0acc04d48eaaf5c862c17709d20e3469c6e0139/jmespath-0.10.0-py2.py3-none-any.whl\n",
+ "Requirement already satisfied: python-dateutil<3.0.0,>=2.1 in /usr/local/lib/python3.7/dist-packages (from botocore<1.21.0,>=1.20.90->boto3->pytorch-pretrained-bert) (2.8.1)\n",
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.7/dist-packages (from python-dateutil<3.0.0,>=2.1->botocore<1.21.0,>=1.20.90->boto3->pytorch-pretrained-bert) (1.15.0)\n",
+ "\u001b[31mERROR: botocore 1.20.90 has requirement urllib3<1.27,>=1.25.4, but you'll have urllib3 1.24.3 which is incompatible.\u001b[0m\n",
+ "Installing collected packages: jmespath, botocore, s3transfer, boto3, pytorch-pretrained-bert, pytorch-nlp\n",
+ "Successfully installed boto3-1.17.90 botocore-1.20.90 jmespath-0.10.0 pytorch-nlp-0.5.0 pytorch-pretrained-bert-0.6.2 s3transfer-0.4.2\n",
+ "TensorFlow 1.x selected.\n"
]
},
{
@@ -107,36 +75,28 @@
"text": [
"Using TensorFlow backend.\n"
]
- },
- {
- "data": {
- "text/plain": [
- "'Tesla T4'"
- ]
- },
- "execution_count": 1,
- "metadata": {
- "tags": []
- },
- "output_type": "execute_result"
}
],
"source": [
+ "import warnings\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "# install\n",
+ "!pip install pytorch-pretrained-bert pytorch-nlp\n",
"\n",
- "#if not using colab, comment below line\n",
- "%tensorflow_version 1.x\n",
+ "try : \n",
+ " from google.colab import files\n",
+ " %tensorflow_version 1.x\n",
+ " \n",
+ "except ModuleNotFoundError :\n",
+ " print(\"Not Using Colab\")\n",
"\n",
- "from torch.nn import Adam\n",
"DATA_DIR=\".\"\n",
"import os\n",
"import numpy as np\n",
"import pickle\n",
"import tensorflow as tf\n",
"\n",
- "\n",
- "# install\n",
- "!pip install pytorch-pretrained-bert pytorch-nlp\n",
- "\n",
"# BERT imports\n",
"import torch\n",
"from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler\n",
@@ -149,31 +109,34 @@
"import io\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
- "% matplotlib inline\n",
+ "%matplotlib inline\n",
"\n",
"# specify GPU device\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
- "n_gpu = torch.cuda.device_count()\n",
- "torch.cuda.get_device_name(0)"
+ "if torch.cuda.is_available():\n",
+ " n_gpu = torch.cuda.device_count()\n",
+ " torch.cuda.get_device_name(0)"
]
},
{
"cell_type": "markdown",
- "metadata": {},
+ "metadata": {
+ "id": "xhRHs0sugqC_"
+ },
"source": [
"## Data Loading"
]
},
{
"cell_type": "code",
- "execution_count": 0,
+ "execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
- "height": 109,
+ "height": 162,
"resources": {
"http://localhost:8080/nbextensions/google.colab/files.js": {
- "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7Ci8vIE1heCBhbW91bnQgb2YgdGltZSB0byBibG9jayB3YWl0aW5nIGZvciB0aGUgdXNlci4KY29uc3QgRklMRV9DSEFOR0VfVElNRU9VVF9NUyA9IDMwICogMTAwMDsKCmZ1bmN0aW9uIF91cGxvYWRGaWxlcyhpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IHN0ZXBzID0gdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKTsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIC8vIENhY2hlIHN0ZXBzIG9uIHRoZSBvdXRwdXRFbGVtZW50IHRvIG1ha2UgaXQgYXZhaWxhYmxlIGZvciB0aGUgbmV4dCBjYWxsCiAgLy8gdG8gdXBsb2FkRmlsZXNDb250aW51ZSBmcm9tIFB5dGhvbi4KICBvdXRwdXRFbGVtZW50LnN0ZXBzID0gc3RlcHM7CgogIHJldHVybiBfdXBsb2FkRmlsZXNDb250aW51ZShvdXRwdXRJZCk7Cn0KCi8vIFRoaXMgaXMgcm91Z2hseSBhbiBhc3luYyBnZW5lcmF0b3IgKG5vdCBzdXBwb3J0ZWQgaW4gdGhlIGJyb3dzZXIgeWV0KSwKLy8gd2hlcmUgdGhlcmUgYXJlIG11bHRpcGxlIGFzeW5jaHJvbm91cyBzdGVwcyBhbmQgdGhlIFB5dGhvbiBzaWRlIGlzIGdvaW5nCi8vIHRvIHBvbGwgZm9yIGNvbXBsZXRpb24gb2YgZWFjaCBzdGVwLgovLyBUaGlzIHVzZXMgYSBQcm9taXNlIHRvIGJsb2NrIHRoZSBweXRob24gc2lkZSBvbiBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcCwKLy8gdGhlbiBwYXNzZXMgdGhlIHJlc3VsdCBvZiB0aGUgcHJldmlvdXMgc3RlcCBhcyB0aGUgaW5wdXQgdG8gdGhlIG5leHQgc3RlcC4KZnVuY3Rpb24gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpIHsKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIGNvbnN0IHN0ZXBzID0gb3V0cHV0RWxlbWVudC5zdGVwczsKCiAgY29uc3QgbmV4dCA9IHN0ZXBzLm5leHQob3V0cHV0RWxlbWVudC5sYXN0UHJvbWlzZVZhbHVlKTsKICByZXR1cm4gUHJvbWlzZS5yZXNvbHZlKG5leHQudmFsdWUucHJvbWlzZSkudGhlbigodmFsdWUpID0+IHsKICAgIC8vIENhY2hlIHRoZSBsYXN0IHByb21pc2UgdmFsdWUgdG8gbWFrZSBpdCBhdmFpbGFibGUgdG8gdGhlIG5leHQKICAgIC8vIHN0ZXAgb2YgdGhlIGdlbmVyYXRvci4KICAgIG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSA9IHZhbHVlOwogICAgcmV0dXJuIG5leHQudmFsdWUucmVzcG9uc2U7CiAgfSk7Cn0KCi8qKgogKiBHZW5lcmF0b3IgZnVuY3Rpb24gd2hpY2ggaXMgY2FsbGVkIGJldHdlZW4gZWFjaCBhc3luYyBzdGVwIG9mIHRoZSB1cGxvYWQKICogcHJvY2Vzcy4KICogQHBhcmFtIHtzdHJpbmd9IGlucHV0SWQgRWxlbWVudCBJRCBvZiB0aGUgaW5wdXQgZmlsZSBwaWNrZXIgZWxlbWVudC4KICogQHBhcmFtIHtzdHJpbmd9IG91dHB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIG91dHB1dCBkaXNwbGF5LgogKiBAcmV0dXJuIHshSXRlcmFibGU8IU9iamVjdD59IEl0ZXJhYmxlIG9mIG5leHQgc3RlcHMuCiAqLwpmdW5jdGlvbiogdXBsb2FkRmlsZXNTdGVwKGlucHV0SWQsIG91dHB1dElkKSB7CiAgY29uc3QgaW5wdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQoaW5wdXRJZCk7CiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gZmFsc2U7CgogIGNvbnN0IG91dHB1dEVsZW1lbnQgPSBkb2N1bWVudC5nZXRFbGVtZW50QnlJZChvdXRwdXRJZCk7CiAgb3V0cHV0RWxlbWVudC5pbm5lckhUTUwgPSAnJzsKCiAgY29uc3QgcGlja2VkUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBpbnB1dEVsZW1lbnQuYWRkRXZlbnRMaXN0ZW5lcignY2hhbmdlJywgKGUpID0+IHsKICAgICAgcmVzb2x2ZShlLnRhcmdldC5maWxlcyk7CiAgICB9KTsKICB9KTsKCiAgY29uc3QgY2FuY2VsID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnYnV0dG9uJyk7CiAgaW5wdXRFbGVtZW50LnBhcmVudEVsZW1lbnQuYXBwZW5kQ2hpbGQoY2FuY2VsKTsKICBjYW5jZWwudGV4dENvbnRlbnQgPSAnQ2FuY2VsIHVwbG9hZCc7CiAgY29uc3QgY2FuY2VsUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICBjYW5jZWwub25jbGljayA9ICgpID0+IHsKICAgICAgcmVzb2x2ZShudWxsKTsKICAgIH07CiAgfSk7CgogIC8vIENhbmNlbCB1cGxvYWQgaWYgdXNlciBoYXNuJ3QgcGlja2VkIGFueXRoaW5nIGluIHRpbWVvdXQuCiAgY29uc3QgdGltZW91dFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgc2V0VGltZW91dCgoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9LCBGSUxFX0NIQU5HRV9USU1FT1VUX01TKTsKICB9KTsKCiAgLy8gV2FpdCBmb3IgdGhlIHVzZXIgdG8gcGljayB0aGUgZmlsZXMuCiAgY29uc3QgZmlsZXMgPSB5aWVsZCB7CiAgICBwcm9taXNlOiBQcm9taXNlLnJhY2UoW3BpY2tlZFByb21pc2UsIHRpbWVvdXRQcm9taXNlLCBjYW5jZWxQcm9taXNlXSksCiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdzdGFydGluZycsCiAgICB9CiAgfTsKCiAgaWYgKCFmaWxlcykgewogICAgcmV0dXJuIHsKICAgICAgcmVzcG9uc2U6IHsKICAgICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICAgIH0KICAgIH07CiAgfQoKICBjYW5jZWwucmVtb3ZlKCk7CgogIC8vIERpc2FibGUgdGhlIGlucHV0IGVsZW1lbnQgc2luY2UgZnVydGhlciBwaWNrcyBhcmUgbm90IGFsbG93ZWQuCiAgaW5wdXRFbGVtZW50LmRpc2FibGVkID0gdHJ1ZTsKCiAgZm9yIChjb25zdCBmaWxlIG9mIGZpbGVzKSB7CiAgICBjb25zdCBsaSA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2xpJyk7CiAgICBsaS5hcHBlbmQoc3BhbihmaWxlLm5hbWUsIHtmb250V2VpZ2h0OiAnYm9sZCd9KSk7CiAgICBsaS5hcHBlbmQoc3BhbigKICAgICAgICBgKCR7ZmlsZS50eXBlIHx8ICduL2EnfSkgLSAke2ZpbGUuc2l6ZX0gYnl0ZXMsIGAgKwogICAgICAgIGBsYXN0IG1vZGlmaWVkOiAkewogICAgICAgICAgICBmaWxlLmxhc3RNb2RpZmllZERhdGUgPyBmaWxlLmxhc3RNb2RpZmllZERhdGUudG9Mb2NhbGVEYXRlU3RyaW5nKCkgOgogICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAnbi9hJ30gLSBgKSk7CiAgICBjb25zdCBwZXJjZW50ID0gc3BhbignMCUgZG9uZScpOwogICAgbGkuYXBwZW5kQ2hpbGQocGVyY2VudCk7CgogICAgb3V0cHV0RWxlbWVudC5hcHBlbmRDaGlsZChsaSk7CgogICAgY29uc3QgZmlsZURhdGFQcm9taXNlID0gbmV3IFByb21pc2UoKHJlc29sdmUpID0+IHsKICAgICAgY29uc3QgcmVhZGVyID0gbmV3IEZpbGVSZWFkZXIoKTsKICAgICAgcmVhZGVyLm9ubG9hZCA9IChlKSA9PiB7CiAgICAgICAgcmVzb2x2ZShlLnRhcmdldC5yZXN1bHQpOwogICAgICB9OwogICAgICByZWFkZXIucmVhZEFzQXJyYXlCdWZmZXIoZmlsZSk7CiAgICB9KTsKICAgIC8vIFdhaXQgZm9yIHRoZSBkYXRhIHRvIGJlIHJlYWR5LgogICAgbGV0IGZpbGVEYXRhID0geWllbGQgewogICAgICBwcm9taXNlOiBmaWxlRGF0YVByb21pc2UsCiAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgYWN0aW9uOiAnY29udGludWUnLAogICAgICB9CiAgICB9OwoKICAgIC8vIFVzZSBhIGNodW5rZWQgc2VuZGluZyB0byBhdm9pZCBtZXNzYWdlIHNpemUgbGltaXRzLiBTZWUgYi82MjExNTY2MC4KICAgIGxldCBwb3NpdGlvbiA9IDA7CiAgICB3aGlsZSAocG9zaXRpb24gPCBmaWxlRGF0YS5ieXRlTGVuZ3RoKSB7CiAgICAgIGNvbnN0IGxlbmd0aCA9IE1hdGgubWluKGZpbGVEYXRhLmJ5dGVMZW5ndGggLSBwb3NpdGlvbiwgTUFYX1BBWUxPQURfU0laRSk7CiAgICAgIGNvbnN0IGNodW5rID0gbmV3IFVpbnQ4QXJyYXkoZmlsZURhdGEsIHBvc2l0aW9uLCBsZW5ndGgpOwogICAgICBwb3NpdGlvbiArPSBsZW5ndGg7CgogICAgICBjb25zdCBiYXNlNjQgPSBidG9hKFN0cmluZy5mcm9tQ2hhckNvZGUuYXBwbHkobnVsbCwgY2h1bmspKTsKICAgICAgeWllbGQgewogICAgICAgIHJlc3BvbnNlOiB7CiAgICAgICAgICBhY3Rpb246ICdhcHBlbmQnLAogICAgICAgICAgZmlsZTogZmlsZS5uYW1lLAogICAgICAgICAgZGF0YTogYmFzZTY0LAogICAgICAgIH0sCiAgICAgIH07CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPQogICAgICAgICAgYCR7TWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCl9JSBkb25lYDsKICAgIH0KICB9CgogIC8vIEFsbCBkb25lLgogIHlpZWxkIHsKICAgIHJlc3BvbnNlOiB7CiAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgIH0KICB9Owp9CgpzY29wZS5nb29nbGUgPSBzY29wZS5nb29nbGUgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYiA9IHNjb3BlLmdvb2dsZS5jb2xhYiB8fCB7fTsKc2NvcGUuZ29vZ2xlLmNvbGFiLl9maWxlcyA9IHsKICBfdXBsb2FkRmlsZXMsCiAgX3VwbG9hZEZpbGVzQ29udGludWUsCn07Cn0pKHNlbGYpOwo=",
+ "data": "Ly8gQ29weXJpZ2h0IDIwMTcgR29vZ2xlIExMQwovLwovLyBMaWNlbnNlZCB1bmRlciB0aGUgQXBhY2hlIExpY2Vuc2UsIFZlcnNpb24gMi4wICh0aGUgIkxpY2Vuc2UiKTsKLy8geW91IG1heSBub3QgdXNlIHRoaXMgZmlsZSBleGNlcHQgaW4gY29tcGxpYW5jZSB3aXRoIHRoZSBMaWNlbnNlLgovLyBZb3UgbWF5IG9idGFpbiBhIGNvcHkgb2YgdGhlIExpY2Vuc2UgYXQKLy8KLy8gICAgICBodHRwOi8vd3d3LmFwYWNoZS5vcmcvbGljZW5zZXMvTElDRU5TRS0yLjAKLy8KLy8gVW5sZXNzIHJlcXVpcmVkIGJ5IGFwcGxpY2FibGUgbGF3IG9yIGFncmVlZCB0byBpbiB3cml0aW5nLCBzb2Z0d2FyZQovLyBkaXN0cmlidXRlZCB1bmRlciB0aGUgTGljZW5zZSBpcyBkaXN0cmlidXRlZCBvbiBhbiAiQVMgSVMiIEJBU0lTLAovLyBXSVRIT1VUIFdBUlJBTlRJRVMgT1IgQ09ORElUSU9OUyBPRiBBTlkgS0lORCwgZWl0aGVyIGV4cHJlc3Mgb3IgaW1wbGllZC4KLy8gU2VlIHRoZSBMaWNlbnNlIGZvciB0aGUgc3BlY2lmaWMgbGFuZ3VhZ2UgZ292ZXJuaW5nIHBlcm1pc3Npb25zIGFuZAovLyBsaW1pdGF0aW9ucyB1bmRlciB0aGUgTGljZW5zZS4KCi8qKgogKiBAZmlsZW92ZXJ2aWV3IEhlbHBlcnMgZm9yIGdvb2dsZS5jb2xhYiBQeXRob24gbW9kdWxlLgogKi8KKGZ1bmN0aW9uKHNjb3BlKSB7CmZ1bmN0aW9uIHNwYW4odGV4dCwgc3R5bGVBdHRyaWJ1dGVzID0ge30pIHsKICBjb25zdCBlbGVtZW50ID0gZG9jdW1lbnQuY3JlYXRlRWxlbWVudCgnc3BhbicpOwogIGVsZW1lbnQudGV4dENvbnRlbnQgPSB0ZXh0OwogIGZvciAoY29uc3Qga2V5IG9mIE9iamVjdC5rZXlzKHN0eWxlQXR0cmlidXRlcykpIHsKICAgIGVsZW1lbnQuc3R5bGVba2V5XSA9IHN0eWxlQXR0cmlidXRlc1trZXldOwogIH0KICByZXR1cm4gZWxlbWVudDsKfQoKLy8gTWF4IG51bWJlciBvZiBieXRlcyB3aGljaCB3aWxsIGJlIHVwbG9hZGVkIGF0IGEgdGltZS4KY29uc3QgTUFYX1BBWUxPQURfU0laRSA9IDEwMCAqIDEwMjQ7CgpmdW5jdGlvbiBfdXBsb2FkRmlsZXMoaW5wdXRJZCwgb3V0cHV0SWQpIHsKICBjb25zdCBzdGVwcyA9IHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCk7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICAvLyBDYWNoZSBzdGVwcyBvbiB0aGUgb3V0cHV0RWxlbWVudCB0byBtYWtlIGl0IGF2YWlsYWJsZSBmb3IgdGhlIG5leHQgY2FsbAogIC8vIHRvIHVwbG9hZEZpbGVzQ29udGludWUgZnJvbSBQeXRob24uCiAgb3V0cHV0RWxlbWVudC5zdGVwcyA9IHN0ZXBzOwoKICByZXR1cm4gX3VwbG9hZEZpbGVzQ29udGludWUob3V0cHV0SWQpOwp9CgovLyBUaGlzIGlzIHJvdWdobHkgYW4gYXN5bmMgZ2VuZXJhdG9yIChub3Qgc3VwcG9ydGVkIGluIHRoZSBicm93c2VyIHlldCksCi8vIHdoZXJlIHRoZXJlIGFyZSBtdWx0aXBsZSBhc3luY2hyb25vdXMgc3RlcHMgYW5kIHRoZSBQeXRob24gc2lkZSBpcyBnb2luZwovLyB0byBwb2xsIGZvciBjb21wbGV0aW9uIG9mIGVhY2ggc3RlcC4KLy8gVGhpcyB1c2VzIGEgUHJvbWlzZSB0byBibG9jayB0aGUgcHl0aG9uIHNpZGUgb24gY29tcGxldGlvbiBvZiBlYWNoIHN0ZXAsCi8vIHRoZW4gcGFzc2VzIHRoZSByZXN1bHQgb2YgdGhlIHByZXZpb3VzIHN0ZXAgYXMgdGhlIGlucHV0IHRvIHRoZSBuZXh0IHN0ZXAuCmZ1bmN0aW9uIF91cGxvYWRGaWxlc0NvbnRpbnVlKG91dHB1dElkKSB7CiAgY29uc3Qgb3V0cHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKG91dHB1dElkKTsKICBjb25zdCBzdGVwcyA9IG91dHB1dEVsZW1lbnQuc3RlcHM7CgogIGNvbnN0IG5leHQgPSBzdGVwcy5uZXh0KG91dHB1dEVsZW1lbnQubGFzdFByb21pc2VWYWx1ZSk7CiAgcmV0dXJuIFByb21pc2UucmVzb2x2ZShuZXh0LnZhbHVlLnByb21pc2UpLnRoZW4oKHZhbHVlKSA9PiB7CiAgICAvLyBDYWNoZSB0aGUgbGFzdCBwcm9taXNlIHZhbHVlIHRvIG1ha2UgaXQgYXZhaWxhYmxlIHRvIHRoZSBuZXh0CiAgICAvLyBzdGVwIG9mIHRoZSBnZW5lcmF0b3IuCiAgICBvdXRwdXRFbGVtZW50Lmxhc3RQcm9taXNlVmFsdWUgPSB2YWx1ZTsKICAgIHJldHVybiBuZXh0LnZhbHVlLnJlc3BvbnNlOwogIH0pOwp9CgovKioKICogR2VuZXJhdG9yIGZ1bmN0aW9uIHdoaWNoIGlzIGNhbGxlZCBiZXR3ZWVuIGVhY2ggYXN5bmMgc3RlcCBvZiB0aGUgdXBsb2FkCiAqIHByb2Nlc3MuCiAqIEBwYXJhbSB7c3RyaW5nfSBpbnB1dElkIEVsZW1lbnQgSUQgb2YgdGhlIGlucHV0IGZpbGUgcGlja2VyIGVsZW1lbnQuCiAqIEBwYXJhbSB7c3RyaW5nfSBvdXRwdXRJZCBFbGVtZW50IElEIG9mIHRoZSBvdXRwdXQgZGlzcGxheS4KICogQHJldHVybiB7IUl0ZXJhYmxlPCFPYmplY3Q+fSBJdGVyYWJsZSBvZiBuZXh0IHN0ZXBzLgogKi8KZnVuY3Rpb24qIHVwbG9hZEZpbGVzU3RlcChpbnB1dElkLCBvdXRwdXRJZCkgewogIGNvbnN0IGlucHV0RWxlbWVudCA9IGRvY3VtZW50LmdldEVsZW1lbnRCeUlkKGlucHV0SWQpOwogIGlucHV0RWxlbWVudC5kaXNhYmxlZCA9IGZhbHNlOwoKICBjb25zdCBvdXRwdXRFbGVtZW50ID0gZG9jdW1lbnQuZ2V0RWxlbWVudEJ5SWQob3V0cHV0SWQpOwogIG91dHB1dEVsZW1lbnQuaW5uZXJIVE1MID0gJyc7CgogIGNvbnN0IHBpY2tlZFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgaW5wdXRFbGVtZW50LmFkZEV2ZW50TGlzdGVuZXIoJ2NoYW5nZScsIChlKSA9PiB7CiAgICAgIHJlc29sdmUoZS50YXJnZXQuZmlsZXMpOwogICAgfSk7CiAgfSk7CgogIGNvbnN0IGNhbmNlbCA9IGRvY3VtZW50LmNyZWF0ZUVsZW1lbnQoJ2J1dHRvbicpOwogIGlucHV0RWxlbWVudC5wYXJlbnRFbGVtZW50LmFwcGVuZENoaWxkKGNhbmNlbCk7CiAgY2FuY2VsLnRleHRDb250ZW50ID0gJ0NhbmNlbCB1cGxvYWQnOwogIGNvbnN0IGNhbmNlbFByb21pc2UgPSBuZXcgUHJvbWlzZSgocmVzb2x2ZSkgPT4gewogICAgY2FuY2VsLm9uY2xpY2sgPSAoKSA9PiB7CiAgICAgIHJlc29sdmUobnVsbCk7CiAgICB9OwogIH0pOwoKICAvLyBXYWl0IGZvciB0aGUgdXNlciB0byBwaWNrIHRoZSBmaWxlcy4KICBjb25zdCBmaWxlcyA9IHlpZWxkIHsKICAgIHByb21pc2U6IFByb21pc2UucmFjZShbcGlja2VkUHJvbWlzZSwgY2FuY2VsUHJvbWlzZV0pLAogICAgcmVzcG9uc2U6IHsKICAgICAgYWN0aW9uOiAnc3RhcnRpbmcnLAogICAgfQogIH07CgogIGNhbmNlbC5yZW1vdmUoKTsKCiAgLy8gRGlzYWJsZSB0aGUgaW5wdXQgZWxlbWVudCBzaW5jZSBmdXJ0aGVyIHBpY2tzIGFyZSBub3QgYWxsb3dlZC4KICBpbnB1dEVsZW1lbnQuZGlzYWJsZWQgPSB0cnVlOwoKICBpZiAoIWZpbGVzKSB7CiAgICByZXR1cm4gewogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbXBsZXRlJywKICAgICAgfQogICAgfTsKICB9CgogIGZvciAoY29uc3QgZmlsZSBvZiBmaWxlcykgewogICAgY29uc3QgbGkgPSBkb2N1bWVudC5jcmVhdGVFbGVtZW50KCdsaScpOwogICAgbGkuYXBwZW5kKHNwYW4oZmlsZS5uYW1lLCB7Zm9udFdlaWdodDogJ2JvbGQnfSkpOwogICAgbGkuYXBwZW5kKHNwYW4oCiAgICAgICAgYCgke2ZpbGUudHlwZSB8fCAnbi9hJ30pIC0gJHtmaWxlLnNpemV9IGJ5dGVzLCBgICsKICAgICAgICBgbGFzdCBtb2RpZmllZDogJHsKICAgICAgICAgICAgZmlsZS5sYXN0TW9kaWZpZWREYXRlID8gZmlsZS5sYXN0TW9kaWZpZWREYXRlLnRvTG9jYWxlRGF0ZVN0cmluZygpIDoKICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgICAgJ24vYSd9IC0gYCkpOwogICAgY29uc3QgcGVyY2VudCA9IHNwYW4oJzAlIGRvbmUnKTsKICAgIGxpLmFwcGVuZENoaWxkKHBlcmNlbnQpOwoKICAgIG91dHB1dEVsZW1lbnQuYXBwZW5kQ2hpbGQobGkpOwoKICAgIGNvbnN0IGZpbGVEYXRhUHJvbWlzZSA9IG5ldyBQcm9taXNlKChyZXNvbHZlKSA9PiB7CiAgICAgIGNvbnN0IHJlYWRlciA9IG5ldyBGaWxlUmVhZGVyKCk7CiAgICAgIHJlYWRlci5vbmxvYWQgPSAoZSkgPT4gewogICAgICAgIHJlc29sdmUoZS50YXJnZXQucmVzdWx0KTsKICAgICAgfTsKICAgICAgcmVhZGVyLnJlYWRBc0FycmF5QnVmZmVyKGZpbGUpOwogICAgfSk7CiAgICAvLyBXYWl0IGZvciB0aGUgZGF0YSB0byBiZSByZWFkeS4KICAgIGxldCBmaWxlRGF0YSA9IHlpZWxkIHsKICAgICAgcHJvbWlzZTogZmlsZURhdGFQcm9taXNlLAogICAgICByZXNwb25zZTogewogICAgICAgIGFjdGlvbjogJ2NvbnRpbnVlJywKICAgICAgfQogICAgfTsKCiAgICAvLyBVc2UgYSBjaHVua2VkIHNlbmRpbmcgdG8gYXZvaWQgbWVzc2FnZSBzaXplIGxpbWl0cy4gU2VlIGIvNjIxMTU2NjAuCiAgICBsZXQgcG9zaXRpb24gPSAwOwogICAgZG8gewogICAgICBjb25zdCBsZW5ndGggPSBNYXRoLm1pbihmaWxlRGF0YS5ieXRlTGVuZ3RoIC0gcG9zaXRpb24sIE1BWF9QQVlMT0FEX1NJWkUpOwogICAgICBjb25zdCBjaHVuayA9IG5ldyBVaW50OEFycmF5KGZpbGVEYXRhLCBwb3NpdGlvbiwgbGVuZ3RoKTsKICAgICAgcG9zaXRpb24gKz0gbGVuZ3RoOwoKICAgICAgY29uc3QgYmFzZTY0ID0gYnRvYShTdHJpbmcuZnJvbUNoYXJDb2RlLmFwcGx5KG51bGwsIGNodW5rKSk7CiAgICAgIHlpZWxkIHsKICAgICAgICByZXNwb25zZTogewogICAgICAgICAgYWN0aW9uOiAnYXBwZW5kJywKICAgICAgICAgIGZpbGU6IGZpbGUubmFtZSwKICAgICAgICAgIGRhdGE6IGJhc2U2NCwKICAgICAgICB9LAogICAgICB9OwoKICAgICAgbGV0IHBlcmNlbnREb25lID0gZmlsZURhdGEuYnl0ZUxlbmd0aCA9PT0gMCA/CiAgICAgICAgICAxMDAgOgogICAgICAgICAgTWF0aC5yb3VuZCgocG9zaXRpb24gLyBmaWxlRGF0YS5ieXRlTGVuZ3RoKSAqIDEwMCk7CiAgICAgIHBlcmNlbnQudGV4dENvbnRlbnQgPSBgJHtwZXJjZW50RG9uZX0lIGRvbmVgOwoKICAgIH0gd2hpbGUgKHBvc2l0aW9uIDwgZmlsZURhdGEuYnl0ZUxlbmd0aCk7CiAgfQoKICAvLyBBbGwgZG9uZS4KICB5aWVsZCB7CiAgICByZXNwb25zZTogewogICAgICBhY3Rpb246ICdjb21wbGV0ZScsCiAgICB9CiAgfTsKfQoKc2NvcGUuZ29vZ2xlID0gc2NvcGUuZ29vZ2xlIHx8IHt9OwpzY29wZS5nb29nbGUuY29sYWIgPSBzY29wZS5nb29nbGUuY29sYWIgfHwge307CnNjb3BlLmdvb2dsZS5jb2xhYi5fZmlsZXMgPSB7CiAgX3VwbG9hZEZpbGVzLAogIF91cGxvYWRGaWxlc0NvbnRpbnVlLAp9Owp9KShzZWxmKTsK",
"headers": [
[
"content-type",
@@ -186,17 +149,17 @@
}
}
},
- "colab_type": "code",
"id": "ntUZndfrE9W8",
- "outputId": "5798e9c4-0b51-430b-cf75-6b139a55cc10"
+ "outputId": "fe735e2b-7702-40e8-dbd2-fc7113e78901"
},
"outputs": [
{
"data": {
"text/html": [
"\n",
- " \n",
- "