|
11 | 11 | "In this notebook we build a binarry classifier for the ATIS Dataset using [BERT](https://arxiv.org/abs/1810.04805), a pre-Trained NLP model open soucred by google in late 2018 that can be used for [Transfer Learning](https://towardsdatascience.com/transfer-learning-in-nlp-fecc59f546e4) on text data. This notebook has been adapted from this [Article](https://towardsdatascience.com/bert-for-dummies-step-by-step-tutorial-fb90890ffe03). The link for the dataset can be found [here](https://www.kaggle.com/siddhadev/ms-cntk-atis/data#).<br> This notebook requires a GPU to get setup. We suggest you to run this on your local machine only if you have a GPU setup or else you can use google colab."
|
12 | 12 | ]
|
13 | 13 | },
|
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": {}, |
| 17 | + "source": [ |
| 18 | + "## Imports" |
| 19 | + ] |
| 20 | + }, |
14 | 21 | {
|
15 | 22 | "cell_type": "code",
|
16 | 23 | "execution_count": 0,
|
|
115 | 122 | }
|
116 | 123 | ],
|
117 | 124 | "source": [
|
118 |
| - "#importing a few necessary packages and setting the DATA directory\n", |
119 | 125 | "\n",
|
| 126 | + "#if not using colab, comment below line\n", |
120 | 127 | "%tensorflow_version 1.x\n",
|
121 | 128 | "\n",
|
122 | 129 | "from torch.nn import Adam\n",
|
|
150 | 157 | "torch.cuda.get_device_name(0)"
|
151 | 158 | ]
|
152 | 159 | },
|
| 160 | + { |
| 161 | + "cell_type": "markdown", |
| 162 | + "metadata": {}, |
| 163 | + "source": [ |
| 164 | + "## Data Loading" |
| 165 | + ] |
| 166 | + }, |
153 | 167 | {
|
154 | 168 | "cell_type": "code",
|
155 | 169 | "execution_count": 0,
|
|
345 | 359 | "query_data_test, intent_data_test, intent_data_label_test, slot_data_test = load_atis('atis.test.pkl')\n"
|
346 | 360 | ]
|
347 | 361 | },
|
| 362 | + { |
| 363 | + "cell_type": "markdown", |
| 364 | + "metadata": {}, |
| 365 | + "source": [ |
| 366 | + "Let's look at a few training queries." |
| 367 | + ] |
| 368 | + }, |
348 | 369 | {
|
349 | 370 | "cell_type": "code",
|
350 | 371 | "execution_count": 0,
|
|
381 | 402 | "query_data_train"
|
382 | 403 | ]
|
383 | 404 | },
|
| 405 | + { |
| 406 | + "cell_type": "markdown", |
| 407 | + "metadata": {}, |
| 408 | + "source": [ |
| 409 | + "## Data Pre-processing\n", |
| 410 | + "We need to convert the sentences to tensors." |
| 411 | + ] |
| 412 | + }, |
384 | 413 | {
|
385 | 414 | "cell_type": "code",
|
386 | 415 | "execution_count": 0,
|
|
431 | 460 | ]
|
432 | 461 | },
|
433 | 462 | {
|
434 |
| - "cell_type": "code", |
435 |
| - "execution_count": 0, |
436 |
| - "metadata": { |
437 |
| - "colab": {}, |
438 |
| - "colab_type": "code", |
439 |
| - "id": "S9SMEwslo-ve" |
440 |
| - }, |
441 |
| - "outputs": [], |
442 |
| - "source": [] |
| 463 | + "cell_type": "markdown", |
| 464 | + "metadata": {}, |
| 465 | + "source": [ |
| 466 | + "BERT expects data to be in a specific format, i.e, [CLS] token1,token2,....[SEP]" |
| 467 | + ] |
443 | 468 | },
|
444 | 469 | {
|
445 | 470 | "cell_type": "code",
|
|
508 | 533 | "input_ids = pad_sequences(input_ids, maxlen=MAX_LEN, dtype=\"long\", truncating=\"post\", padding=\"post\")"
|
509 | 534 | ]
|
510 | 535 | },
|
| 536 | + { |
| 537 | + "cell_type": "markdown", |
| 538 | + "metadata": {}, |
| 539 | + "source": [ |
| 540 | + "Creating the BERT attention masks" |
| 541 | + ] |
| 542 | + }, |
511 | 543 | {
|
512 | 544 | "cell_type": "code",
|
513 | 545 | "execution_count": 0,
|
|
579 | 611 | "validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)\n"
|
580 | 612 | ]
|
581 | 613 | },
|
| 614 | + { |
| 615 | + "cell_type": "markdown", |
| 616 | + "metadata": {}, |
| 617 | + "source": [ |
| 618 | + "## Training" |
| 619 | + ] |
| 620 | + }, |
582 | 621 | {
|
583 | 622 | "cell_type": "code",
|
584 | 623 | "execution_count": 0,
|
|
913 | 952 | "model.cuda()"
|
914 | 953 | ]
|
915 | 954 | },
|
| 955 | + { |
| 956 | + "cell_type": "markdown", |
| 957 | + "metadata": {}, |
| 958 | + "source": [ |
| 959 | + "## Fine-Tuning BERT" |
| 960 | + ] |
| 961 | + }, |
916 | 962 | {
|
917 | 963 | "cell_type": "code",
|
918 | 964 | "execution_count": 0,
|
|
1149 | 1195 | "name": "python",
|
1150 | 1196 | "nbconvert_exporter": "python",
|
1151 | 1197 | "pygments_lexer": "ipython3",
|
1152 |
| - "version": "3.6.10" |
| 1198 | + "version": "3.6.12" |
1153 | 1199 | }
|
1154 | 1200 | },
|
1155 | 1201 | "nbformat": 4,
|
|
0 commit comments