In [42]:
from typing import Iterator, List, Dict
import torch
import torch.optim as optim
from allennlp.nn import util as nn_util
import numpy as np
from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField
from allennlp.data.dataset_readers import DatasetReader
from allennlp.common.file_utils import cached_path
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator
from allennlp.training.trainer import Trainer
from allennlp.predictors import SentenceTaggerPredictor
from allennlp.models import crf_tagger

In [2]:
from allennlp.common.params import Params

In [3]:
from allennlp.modules import  ConditionalRandomField

In [4]:
class PosDatasetReader(DatasetReader):
    """
    DatasetReader for PoS tagging data, one sentence per line, like
        The###DET dog###NN ate###V the###DET apple###NN
    """

    def __init__(self, token_indexers: Dict[str, TokenIndexer] = None) -> None:
        super().__init__(lazy=False)
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}

    def text_to_instance(self, tokens: List[Token], tags: List[str] = None) -> Instance:
        sentence_field = TextField(tokens, self.token_indexers)
        fields = {"sentence": sentence_field}

        if tags:
            label_field = SequenceLabelField(labels=tags, sequence_field=sentence_field)
            fields["labels"] = label_field

        return Instance(fields)

    def _read(self, file_path: str) -> Iterator[Instance]:
        with open(file_path) as f:
            for line in f:
                pairs = line.strip().split()
                sentence, tags = zip(*(pair.split("###") for pair in pairs))
                yield self.text_to_instance([Token(word) for word in sentence], tags)


In [5]:
class LstmCRFTagger(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.crf = ConditionalRandomField(num_tags=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}


In [6]:
class LstmTagger(Model):
    def __init__(self,
                 word_embeddings: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 vocab: Vocabulary) -> None:
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        self.hidden2tag = torch.nn.Linear(in_features=encoder.get_output_dim(),
                                          out_features=vocab.get_vocab_size('labels'))
        self.accuracy = CategoricalAccuracy()

    def forward(self,
                sentence: Dict[str, torch.Tensor],
                labels: torch.Tensor = None) -> torch.Tensor:
        mask = get_text_field_mask(sentence)
        embeddings = self.word_embeddings(sentence)
        encoder_out = self.encoder(embeddings, mask)
        tag_logits = self.hidden2tag(encoder_out)
        output = {"tag_logits": tag_logits}
        if labels is not None:
            self.accuracy(tag_logits, labels, mask)
            output["loss"] = sequence_cross_entropy_with_logits(tag_logits, labels, mask)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {"accuracy": self.accuracy.get_metric(reset)}


In [7]:
reader = PosDatasetReader()
train_dataset = reader.read(
    '/home/pding/OneDrive/kph/kph/trainan.txt')
validation_dataset = reader.read(
    '/home/pding/OneDrive/kph/kph/testan.txt')
vocab = Vocabulary.from_instances(train_dataset + validation_dataset)
EMBEDDING_DIM = 200
HIDDEN_DIM = 100


11543it [00:08, 1309.54it/s]
2885it [00:02, 1339.03it/s]
100%|██████████| 14428/14428 [00:03<00:00, 4553.81it/s]


In [12]:
#token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
#                            embedding_dim=EMBEDDING_DIM)

token_embedding = Embedding.from_params(
                            vocab=vocab,
                            params=Params({'pretrained_file':'/home/pding/Documents/glove/glove.840B.300d.txt',
                                           'embedding_dim' : EMBEDDING_DIM})
                            )
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

2196017it [00:13, 161347.29it/s]


In [8]:
token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                            embedding_dim=EMBEDDING_DIM)
word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

In [None]:
vocab.save_to_files("/tmp/vocabulary")

In [9]:
lstm = PytorchSeq2SeqWrapper(torch.nn.LSTM(EMBEDDING_DIM, HIDDEN_DIM, batch_first=True))


In [61]:
model = LstmTagger(word_embeddings, lstm, vocab)
model.cuda()
optimizer = optim.SGD(model.parameters(), lr=0.01)
iterator = BucketIterator(batch_size=300, biggest_batch_first=True, sorting_keys=[("sentence", "num_tokens")])
iterator.index_with(vocab)
trainer = Trainer(model=model,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  cuda_device=0,
                  patience=10,
                  num_epochs=10)


In [62]:
with open("model.th", 'rb') as f:
    model.load_state_dict(torch.load(f))

In [11]:
model2 = LstmCRFTagger(word_embeddings, lstm, vocab)
model2.cuda()
optimizer = optim.SGD(model2.parameters(), lr=0.01)
iterator = BucketIterator(batch_size=300, biggest_batch_first=True, sorting_keys=[("sentence", "num_tokens")])
iterator.index_with(vocab)
trainer2 = Trainer(model=model2,
                  optimizer=optimizer,
                  iterator=iterator,
                  train_dataset=train_dataset,
                  validation_dataset=validation_dataset,
                  cuda_device=0,
                  patience=10,
                  num_epochs=10)

In [12]:
with open("model2.th", 'rb') as f:
    model2.load_state_dict(torch.load(f))

In [63]:
trainer.train()

accuracy: 0.9021, loss: 1.0924 ||: 100%|██████████| 39/39 [00:09<00:00,  4.19it/s]
accuracy: 0.9375, loss: 0.8638 ||: 100%|██████████| 10/10 [00:00<00:00, 10.54it/s]
accuracy: 0.9384, loss: 0.7086 ||: 100%|██████████| 39/39 [00:06<00:00,  6.06it/s]
accuracy: 0.9375, loss: 0.5767 ||: 100%|██████████| 10/10 [00:00<00:00, 11.08it/s]
accuracy: 0.9384, loss: 0.4999 ||: 100%|██████████| 39/39 [00:06<00:00,  5.67it/s]
accuracy: 0.9375, loss: 0.4394 ||: 100%|██████████| 10/10 [00:01<00:00,  9.60it/s]
accuracy: 0.9384, loss: 0.4044 ||: 100%|██████████| 39/39 [00:07<00:00,  5.51it/s]
accuracy: 0.9375, loss: 0.3789 ||: 100%|██████████| 10/10 [00:00<00:00, 10.20it/s]
accuracy: 0.9384, loss: 0.3619 ||: 100%|██████████| 39/39 [00:06<00:00,  5.87it/s]
accuracy: 0.9375, loss: 0.3513 ||: 100%|██████████| 10/10 [00:00<00:00, 10.94it/s]
accuracy: 0.9384, loss: 0.3415 ||: 100%|██████████| 39/39 [00:16<00:00,  2.39it/s]
accuracy: 0.9375, loss: 0.3372 ||: 100%|██████████| 10/10 [00:00<00:00, 11.00it/s]
accu

{'best_epoch': 9,
 'peak_cpu_memory_MB': 3502.156,
 'peak_gpu_0_memory_MB': 2837,
 'training_duration': '0:01:28.724381',
 'training_start_epoch': 0,
 'training_epochs': 9,
 'epoch': 9,
 'training_accuracy': 0.9384352123415832,
 'training_loss': 0.31823092775467116,
 'training_cpu_memory_MB': 3502.156,
 'training_gpu_0_memory_MB': 2829,
 'validation_accuracy': 0.9374978140356399,
 'validation_loss': 0.3198258697986603,
 'best_validation_accuracy': 0.9374978140356399,
 'best_validation_loss': 0.3198258697986603}

In [14]:
trainer2.train()

accuracy: 0.9384, loss: 0.3140 ||: 100%|██████████| 39/39 [00:06<00:00,  5.88it/s]
accuracy: 0.9375, loss: 0.3160 ||: 100%|██████████| 10/10 [00:00<00:00, 10.33it/s]
accuracy: 0.9384, loss: 0.3125 ||: 100%|██████████| 39/39 [00:06<00:00,  5.95it/s]
accuracy: 0.9375, loss: 0.3146 ||: 100%|██████████| 10/10 [00:00<00:00, 10.64it/s]
accuracy: 0.9384, loss: 0.3111 ||: 100%|██████████| 39/39 [00:06<00:00,  5.96it/s]
accuracy: 0.9375, loss: 0.3134 ||: 100%|██████████| 10/10 [00:00<00:00, 10.68it/s]
accuracy: 0.9384, loss: 0.3099 ||: 100%|██████████| 39/39 [00:06<00:00,  5.98it/s]
accuracy: 0.9375, loss: 0.3121 ||: 100%|██████████| 10/10 [00:00<00:00, 10.60it/s]
accuracy: 0.9384, loss: 0.3089 ||: 100%|██████████| 39/39 [00:06<00:00,  5.95it/s]
accuracy: 0.9375, loss: 0.3110 ||: 100%|██████████| 10/10 [00:00<00:00, 10.89it/s]
accuracy: 0.9384, loss: 0.3080 ||: 100%|██████████| 39/39 [00:06<00:00,  5.75it/s]
accuracy: 0.9375, loss: 0.3102 ||: 100%|██████████| 10/10 [00:00<00:00, 10.38it/s]
accu

{'best_epoch': 9,
 'best_validation_accuracy': 0.9374978140356399,
 'best_validation_loss': 0.30665835589170454,
 'peak_cpu_memory_MB': 3439.356,
 'peak_gpu_0_memory_MB': 2813,
 'training_duration': '0:01:16.245005',
 'training_start_epoch': 0,
 'training_epochs': 9,
 'epoch': 9,
 'training_accuracy': 0.9384352123415832,
 'training_loss': 0.30433233693624157,
 'training_cpu_memory_MB': 3439.356,
 'training_gpu_0_memory_MB': 2813,
 'validation_accuracy': 0.9374978140356399,
 'validation_loss': 0.30665835589170454}

In [15]:
with open("model2_307.th", 'wb') as f:
    torch.save(model2.state_dict(), f)

In [69]:
predictor = SentenceTaggerPredictor(model, dataset_reader=reader)
tag_logits = predictor.predict("Exercise therapy is a promising nonpharmacological therapy in people with multiple sclerosis (MS). Although exercise training may induce a transient worsening of symptoms in some MS patients, it is generally considered safe and does not increase the risk of relapses. Exercise training can lead to clinically relevant improvements in physical function, but should be considered an adjunct to specific task-based training. Exercise has also shown positive effects on the brain, including improvements in brain volume and cognition. In summary, exercise therapy is a safe and potent nonpharmacological intervention in MS, with beneficial effects on both functional capacity and the brain. ")['tag_logits']
tag_ids = np.argmax(tag_logits, axis=-1)
print([model2.vocab.get_token_from_index(i, 'labels') for i in tag_ids])

['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


In [25]:
 predictor.predict("The dog ate the apple")

{'tag_logits': [[0.9840908646583557,
   -0.1494881808757782,
   -0.3755095601081848,
   -0.5525108575820923],
  [1.8476608991622925,
   -0.34337377548217773,
   -0.6281458735466003,
   -1.069543480873108],
  [2.6106700897216797,
   -0.5082828998565674,
   -0.9541940689086914,
   -1.522050142288208],
  [2.785907745361328,
   -0.544689416885376,
   -0.9694169163703918,
   -1.6415318250656128],
  [2.850036144256592,
   -0.5169174075126648,
   -0.9883534908294678,
   -1.6936215162277222]]}

In [46]:
from allennlp.data.iterators import DataIterator
from tqdm import tqdm
from scipy.special import expit # the sigmoid function

def tonp(tsr): return tsr.detach().cpu().numpy()

class Predictor:
    def __init__(self, model: Model, iterator: DataIterator,
                 cuda_device: int=-1) -> None:
        self.model = model
        self.iterator = iterator
        self.cuda_device = cuda_device
        
    def _extract_data(self, batch) -> np.ndarray:
        out_dict = self.model(**batch)
        return expit(tonp(out_dict["tag_logits"]))
    
    def predict(self, ds) -> np.ndarray:
        pred_generator = self.iterator(ds, num_epochs=1, shuffle=False)
        self.model.eval()
        pred_generator_tqdm = tqdm(pred_generator,
                                   total=self.iterator.get_num_batches(ds))
        preds = []
        with torch.no_grad():
            for batch in pred_generator_tqdm:
                batch = nn_util.move_to_device(batch, self.cuda_device)
                preds.append(self._extract_data(batch))
        return preds

In [32]:
from allennlp.data.iterators import BasicIterator
# iterate over the dataset without changing its order
seq_iterator = BasicIterator(batch_size=64)
seq_iterator.index_with(vocab)

In [64]:
predictor = Predictor(model, seq_iterator, cuda_device=0)

In [65]:
test_preds = predictor.predict(validation_dataset)

100%|██████████| 46/46 [00:01<00:00, 28.16it/s]


In [56]:
a1 = test_preds[0]

In [66]:
test_preds

[array([[[0.689646  , 0.44694638, 0.43020666, 0.41612455],
         [0.82185113, 0.40735516, 0.36402375, 0.341292  ],
         [0.8853623 , 0.35924545, 0.33473325, 0.28613812],
         ...,
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187],
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187],
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187]],
 
        [[0.6842742 , 0.4638188 , 0.42904213, 0.40697294],
         [0.8409406 , 0.3915321 , 0.35967204, 0.3132512 ],
         [0.89774764, 0.34103537, 0.31876275, 0.26019377],
         ...,
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187],
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187],
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187]],
 
        [[0.61869496, 0.4781061 , 0.45288125, 0.42933974],
         [0.8218461 , 0.40247455, 0.35611326, 0.33137822],
         [0.89349586, 0.3583878 , 0.3203543 , 0.26947114],
         ...,
         [0.6400122 , 0.4661164 , 0.44365484, 0.42744187],
        

In [67]:
np.argmax(a1, axis=-1)

array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]])

In [53]:
test_preds[1].shape

(64, 667, 4)

In [76]:
vars(train_dataset[0].fields['labels'])

{'labels': ('B',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B',
  'I',
  'I',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'B',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  '