In [0]:
pip install allennlp

In [2]:
# Mount Google Drive inside the Colab VM and switch into the homework folder.
# The relative 'trees/...' dataset paths read later in the notebook are
# resolved against this working directory.
from google.colab import drive
drive.mount('/content/gdrive')
%cd /content/gdrive/My Drive/Colab Notebooks/NLP Labs/hw 6

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
/content/gdrive/My Drive/Colab Notebooks/NLP Labs/hw 6


In [0]:
import sys

# Homework folder on the mounted Drive; presumably where the local modules
# imported below (sst_classifier, transformer_encoder) live — verify.
PROJECT_DIR = '/content/gdrive/My Drive/Colab Notebooks/NLP Labs/hw 6'

# Put the folder at the front of the import path exactly once, so re-running
# this cell does not keep stacking duplicate entries onto sys.path.
sys.path[:] = [entry for entry in sys.path if entry != PROJECT_DIR]
sys.path.insert(0, PROJECT_DIR)

In [0]:
import random

import numpy as np
import torch
import torch.optim as optim

from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import StanfordSentimentTreeBankDatasetReader
from allennlp.data.iterators import BucketIterator
from allennlp.data.token_indexers import PretrainedBertIndexer
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import archive_model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedBertEmbedder
from allennlp.training.trainer import Trainer

from sst_classifier import LstmClassifier
from transformer_encoder import TransformerSeq2VecEncoder

In [0]:
# --- Model / data configuration -------------------------------------------
# Second positional argument handed to TransformerSeq2VecEncoder below.
HIDDEN = 256
# First positional argument handed to TransformerSeq2VecEncoder below.
# NOTE(review): the model printout further down shows bert-base-uncased with
# 768-wide embeddings (Embedding(30522, 768)); confirm that 512 is really the
# width the encoder expects — TODO verify against transformer_encoder.py.
EMBEDDED = 512
# Passed as `max_pieces` to PretrainedBertIndexer.
max_sequence_length = 100

# --- Reproducibility --------------------------------------------------------
# Seed every RNG the training run can draw from (Python, NumPy, PyTorch CPU
# and CUDA) so that Restart & Run All produces comparable numbers. This is
# best-effort: some GPU kernels may still be non-deterministic.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)  # ignored when CUDA is unavailable

In [0]:
# Index text with the word-piece vocabulary of the pretrained BERT model,
# using max_sequence_length as the `max_pieces` setting.
token_indexer = PretrainedBertIndexer(
    pretrained_model='bert-base-uncased',
    max_pieces=max_sequence_length,
    do_lowercase=True,
)

# Stanford Sentiment Treebank reader whose 'tokens' field uses the BERT indexer.
reader = StanfordSentimentTreeBankDatasetReader(
    token_indexers=dict(tokens=token_indexer),
)

In [7]:
# Read the train and dev splits (in that order) with the SST reader.
# Both paths are relative to the current working directory.
train_dataset, dev_dataset = map(
    reader.read,
    ('trees/train.txt', 'trees/dev.txt'),
)

8544it [00:02, 4127.07it/s]
1101it [00:00, 5304.12it/s]


In [8]:
# Build the vocabulary over the train and dev instances together, with a
# minimum count of 3 for the 'tokens' namespace.
# NOTE(review): with a pretrained BERT indexer the word pieces presumably come
# from BERT's own vocabulary, so this threshold may have no effect — confirm.
all_instances = train_dataset + dev_dataset
vocab = Vocabulary.from_instances(all_instances, min_count=dict(tokens=3))

100%|██████████| 9645/9645 [00:00<00:00, 143105.49it/s]


In [0]:
# Pretrained BERT used as the token embedder (top_layer_only=True).
bert_embedder = PretrainedBertEmbedder(
    pretrained_model='bert-base-uncased',
    top_layer_only=True,
)

# Wrap the BERT embedder for the 'tokens' field. allow_unmatched_keys=True is
# presumably needed because the BERT indexer emits more keys than there are
# embedders — verify against the AllenNLP docs.
word_embeddings: TextFieldEmbedder = BasicTextFieldEmbedder(
    dict(tokens=bert_embedder),
    allow_unmatched_keys=True,
)

In [0]:
# Project-local sequence-to-vector encoder. The two positional arguments are
# presumably (input_dim, hidden_dim) — check transformer_encoder.py.
encoder = TransformerSeq2VecEncoder(
    EMBEDDED,
    HIDDEN,
    projection_dim=128,
    feedforward_hidden_dim=128,
    num_layers=2,
    num_attention_heads=2,
)

In [11]:
# Assemble the classifier from the BERT text-field embedder, the transformer
# encoder and the vocabulary (LstmClassifier is the project's own class).
model_bert = LstmClassifier(word_embeddings, encoder, vocab)
# Move the model to the GPU (the Trainer below is configured with
# cuda_device=0). As the cell's last expression, this call also displays the
# model's module tree.
model_bert.cuda()

LstmClassifier(
  (word_embeddings): BasicTextFieldEmbedder(
    (token_embedder_tokens): PretrainedBertEmbedder(
      (bert_model): BertModel(
        (embeddings): BertEmbeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (token_type_embeddings): Embedding(2, 768)
          (LayerNorm): BertLayerNorm()
          (dropout): Dropout(p=0.1)
        )
        (encoder): BertEncoder(
          (layer): ModuleList(
            (0): BertLayer(
              (attention): BertAttention(
                (self): BertSelfAttention(
                  (query): Linear(in_features=768, out_features=768, bias=True)
                  (key): Linear(in_features=768, out_features=768, bias=True)
                  (value): Linear(in_features=768, out_features=768, bias=True)
                  (dropout): Dropout(p=0.1)
                )
                (output): BertSelfOutput(
                  (dense): Linear(in_featu

In [0]:
# Adam over all model parameters, with a small learning rate and weight decay.
optimizer = optim.Adam(
    params=model_bert.parameters(),
    lr=1e-4,
    weight_decay=1e-5,
)

In [0]:
# Batches of 64 instances, bucketed by the number of tokens in the 'tokens' field.
iterator = BucketIterator(
    sorting_keys=[("tokens", "num_tokens")],
    batch_size=64,
)
# The iterator needs the vocabulary to index instances before batching.
iterator.index_with(vocab)

In [0]:
# Train on GPU 0 for at most 20 epochs, validating on the dev split.
# patience=5 is the Trainer's early-stopping setting.
trainer = Trainer(
    model=model_bert,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_dataset,
    validation_dataset=dev_dataset,
    patience=5,
    num_epochs=20,
    cuda_device=0,
)

In [15]:
# Run the training loop and keep the metrics dictionary it returns
# (displayed in the next cell).
history = trainer.train()

accuracy: 0.2960, precision: 0.2437, recall: 0.0266, f1_measure: 0.0479, loss: 1.5637 ||: 100%|██████████| 134/134 [00:20<00:00,  6.09it/s]
accuracy: 0.3942, precision: 0.0000, recall: 0.0000, f1_measure: 0.0000, loss: 1.3816 ||: 100%|██████████| 18/18 [00:02<00:00,  6.90it/s]
accuracy: 0.4158, precision: 0.3731, recall: 0.1145, f1_measure: 0.1752, loss: 1.3124 ||: 100%|██████████| 134/134 [00:20<00:00,  5.52it/s]
accuracy: 0.4251, precision: 0.5385, recall: 0.0504, f1_measure: 0.0921, loss: 1.3084 ||: 100%|██████████| 18/18 [00:02<00:00,  6.99it/s]
accuracy: 0.4526, precision: 0.4064, recall: 0.2225, f1_measure: 0.2876, loss: 1.2345 ||: 100%|██████████| 134/134 [00:19<00:00,  6.74it/s]
accuracy: 0.4296, precision: 0.4118, recall: 0.2518, f1_measure: 0.3125, loss: 1.2711 ||: 100%|██████████| 18/18 [00:02<00:00,  7.30it/s]
accuracy: 0.4772, precision: 0.5016, recall: 0.2793, f1_measure: 0.3588, loss: 1.1877 ||: 100%|██████████| 134/134 [00:19<00:00,  6.95it/s]
accuracy: 0.4242, precisio

In [16]:
# Show the metrics summary returned by trainer.train() (best-epoch and
# final-epoch training/validation metrics).
history

{'best_epoch': 6,
 'best_validation_accuracy': 0.4623069936421435,
 'best_validation_f1_measure': 0.31192660550454093,
 'best_validation_loss': 1.22759124967787,
 'best_validation_precision': 0.43037974683544306,
 'best_validation_recall': 0.2446043165467626,
 'epoch': 10,
 'peak_cpu_memory_MB': 4218.396,
 'peak_gpu_0_memory_MB': 2215,
 'training_accuracy': 0.5639044943820225,
 'training_cpu_memory_MB': 4218.396,
 'training_duration': '00:04:05',
 'training_epochs': 10,
 'training_f1_measure': 0.4976911236531062,
 'training_gpu_0_memory_MB': 2215,
 'training_loss': 1.0254951897841782,
 'training_precision': 0.5659276546091015,
 'training_recall': 0.4441391941391941,
 'training_start_epoch': 0,
 'validation_accuracy': 0.4604904632152589,
 'validation_f1_measure': 0.31818181818177166,
 'validation_loss': 1.3216620087623596,
 'validation_precision': 0.43209876543209874,
 'validation_recall': 0.2517985611510791}