In [1]:

import pickle
import time
import datetime
from normalizerFunctions import Training_Corpus
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils import data
from pytorch_pretrained_bert import BertTokenizer
from tqdm import tqdm_notebook as tqdm
from Normalization_Dataset import Normalization
from sklearn.model_selection import train_test_split
from transformers import get_linear_schedule_with_warmup


In [14]:
with open('pickled_wus.pkl', 'rb') as file:
    wus_corpus = pickle.load(file)
with open('pickled_archimob.pkl', 'rb') as file:
    archimob_corpus = pickle.load(file)
with open('pickled_train_corpus.pkl', 'rb') as file:
    norm_corpus = pickle.load(file)

In [3]:

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(torch.backends.mps.is_available())

True


In [4]:
# define sequence padding 
def pad(batch):
    '''Pads to the longest sample'''
    f = lambda x: [sample[x] for sample in batch]
    words = f(0)
    is_heads = f(2)
    labels = f(3)
    seqlens = f(-1)
    maxlen = np.array(seqlens).max()
    f = lambda x, seqlen: [sample[x] + [0] * (seqlen - len(sample[x])) for sample in batch] # 0: <pad>
    x = f(1, maxlen)
    y = f(-2, maxlen)
    f = torch.LongTensor
    return words, f(x), is_heads, labels, f(y), seqlens

In [5]:
# define Token_Classifier model class
from pytorch_pretrained_bert import BertModel

class Token_Classifier(nn.Module):
    def __init__(self, vocab_size=None):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-swiss-lm')

        self.fc = nn.Linear(768, vocab_size) # this is just the ouput shape for BertModel
        self.device = device

    def forward(self, x, y):
        x = x.to(device) # (N, L). int64
        y = y.to(device) # (N, L). int64
        with torch.set_grad_enabled(self.training):
            self.bert.train(self.training)
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        logits = self.fc(enc)
        y_hat = logits.argmax(-1)
        return logits, y, y_hat
    
    def normalize(self, x):
        self.bert.eval()
        with torch.no_grad():
            encoded_layers, _ = self.bert(x)
            enc = encoded_layers[-1]
        logits = self.fc(enc)
        prediction = logits.argmax(-1)
        return prediction

In [11]:
def get_Dataloader(corpus, name):
    tokenizer = BertTokenizer.from_pretrained('bert-swiss-lm', do_lower_case=True, never_split=corpus.multigrams)
    labels = ["<pad>"] + list(corpus.labels)
    label2idx = {label:idx for idx, label in enumerate(labels)}
    idx2label = {idx:label for idx, label in enumerate(labels)}
    # create training / validation split and load data into batches 
    train_data, val_data = train_test_split(corpus.word_norm_pairs)
    train_dataset = Normalization(train_data, label2idx, corpus.multigrams)
    val_dataset = Normalization(val_data,label2idx, corpus.multigrams)

    train_iter = data.DataLoader(dataset=train_dataset,
                                batch_size=8,
                                shuffle=True,
                                num_workers=0,
                                collate_fn=pad)
    val_iter = data.DataLoader(dataset=val_dataset,
                                batch_size=8,
                                shuffle=False,
                                num_workers=0,
                                collate_fn=pad)
    with open("idx2label_"+name+".pickle", "wb") as file:
        pickle.dump(idx2label, file)
    with open("label2idx_"+name+".pickle", "wb") as file:
        pickle.dump(label2idx, file)
    print("Data loaded")
    return train_iter, val_iter, label2idx, idx2label

In [8]:
def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [9]:
def get_normalizer(corpus,name):
    train_iter, val_iter, label2idx,idx2label = get_Dataloader(corpus,name)
    model = Token_Classifier(vocab_size=len(label2idx))
    model.to(device)
    if torch.cuda.device_count() > 1:
        print(True)
        model = nn.DataParallel(model)
    criterion = nn.CrossEntropyLoss(ignore_index=0)
    optimizer = optim.Adam(model.parameters(), lr = 0.0001)
    epochs = 3
    # performance and monitoring metrics 
    training_stats = []
    total_t0 = time.time()
    total_steps = len(train_iter) * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer, 
                                                num_warmup_steps = 0, # Default value in run_glue.py
                                                num_training_steps = total_steps)
    # Full training loop
    for epoch in range(0,epochs):
        print("")
        print('======== Epoch {:} / {:} ========'.format(epoch + 1, epochs))
        print('Training...')
        t0 = time.time()
        total_train_loss = 0
        model.train()
        for step, batch in enumerate(train_iter):
            # Progress update every 500 batches.
            if step % 500 == 0 and not step == 0:
                elapsed = format_time(time.time() - t0)
                print('Batch {} of {}.Elapsed: {:}.'.format(step, len(train_iter), elapsed))
            words, x, is_heads, labels, y, seqlens = batch
            _y = y # for monitoring
            optimizer.zero_grad()
            logits, y, _ = model(x, y) # logits: (N, L, VOCAB), y: (N, L)
            logits = logits.view(-1, logits.shape[-1]) # (N*L, VOCAB)
            y = y.view(-1)  # (N*L,)
            loss = criterion(logits, y)
            total_train_loss += loss.item()
            loss.backward()
            optimizer.step()
            scheduler.step()
            if step%1000==0: # monitoring
                print("step: {}, loss: {}".format(step, loss.item()))
        avg_train_loss = total_train_loss / len(train_iter)
        training_time = format_time(time.time() - t0)
        print("")
        print("  Average training loss: {0:.2f}".format(avg_train_loss))
        print("  Training epoch took: {:}".format(training_time))
        # Validation: 
        print("")
        print("Running Validation...")
        t0 = time.time()
        model.eval()
        total_eval_loss = 0
        total = 0
        hits = 0
        words_unnormed = 0
        with open("norm_test_"+name+".txt", 'w') as fout: # generates a results file with the word, the true label, and the prediction
            for batch in val_iter:
                with torch.no_grad():        
                    b_utterances, x, b_is_heads, b_labels, y, seqlens = batch
                    _, _, b_predictions = model(x, y) # logits: (N, L, VOCAB), y: (N, L)
                    b_predictions = b_predictions.detach().cpu().numpy() # pred_ids.cpu().numpy().tolist() alternative?
                    assert len(b_utterances)==len(b_labels)== len(b_predictions)                
                    for utterance, utterance_labels, utterance_preds, is_heads in zip(b_utterances, b_labels, b_predictions, b_is_heads):
                        utterance_preds = [pred for head, pred in zip(is_heads, utterance_preds) if head == 1]
                        for pred in utterance_preds:
                            try:
                                test = idx2label[pred]
                            except KeyError:
                                idx2label[pred] = '<pad>'
                        preds = [idx2label[pred] for pred in utterance_preds]
                        words = utterance.split()
                        labels = utterance_labels.split()
                        assert len(preds)==len(words)== len(labels)
                        for w, l, p in zip(words[1:-1], labels[1:-1], preds[1:-1]):
                            if w == l:
                                words_unnormed += 1
                            if l == p:
                                hits += 1
                            total += 1
                            fout.write("{} {} {}\n".format(w, l, p))
                        fout.write("\n")   
        avg_val_loss = total_eval_loss / len(val_iter)
        validation_time = format_time(time.time() - t0)
        print("  Validation Loss: {0:.2f}".format(avg_val_loss))
        print("  Validation took: {:}".format(validation_time))
        accuracy = 100*hits/total
        unnormed = 100*words_unnormed/total
        print("Epoch {} accuracy: ".format(epoch+1),accuracy)
        print(unnormed)
        Err_Red_rate = (accuracy - unnormed)/(100 - unnormed) # all are percentages
        print(Err_Red_rate)
        print("Epoch {} error reduction rate: ".format(epoch+1),Err_Red_rate)
        training_stats.append(
            {
                'epoch': epoch + 1,
                'Training Loss': avg_train_loss,
                'Valid. Loss': avg_val_loss,
                'Error Reduction': Err_Red_rate,
                'Training Time': training_time,
                'Validation Time': validation_time
            }
        )
        
    print("")
    print("Training complete!")
    print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))
    torch.save(model, 'token_classifier_'+name+'.pt')


In [13]:
get_normalizer(archimob_corpus,'archimob')

Data loaded

Training...
step: 0, loss: 10.34687328338623
Batch 500 of 7729.Elapsed: 0:02:28.
Batch 1000 of 7729.Elapsed: 0:04:52.
step: 1000, loss: 2.212703227996826
Batch 1500 of 7729.Elapsed: 0:07:15.
Batch 2000 of 7729.Elapsed: 0:09:42.
step: 2000, loss: 2.44248366355896
Batch 2500 of 7729.Elapsed: 0:12:04.
Batch 3000 of 7729.Elapsed: 0:14:37.
step: 3000, loss: 2.2548458576202393
Batch 3500 of 7729.Elapsed: 0:17:09.
Batch 4000 of 7729.Elapsed: 0:19:39.
step: 4000, loss: 0.9408369064331055
Batch 4500 of 7729.Elapsed: 0:22:16.
Batch 5000 of 7729.Elapsed: 0:24:47.
step: 5000, loss: 1.481183648109436
Batch 5500 of 7729.Elapsed: 0:27:19.
Batch 6000 of 7729.Elapsed: 0:29:46.
step: 6000, loss: 1.1357345581054688
Batch 6500 of 7729.Elapsed: 0:32:11.
Batch 7000 of 7729.Elapsed: 0:34:37.
step: 7000, loss: 1.7781896591186523
Batch 7500 of 7729.Elapsed: 0:37:02.

  Average training loss: 1.70
  Training epoch took: 0:38:12

Running Validation...
  Validation Loss: 0.00
  Validation took: 0:02:

In [None]:
import Normalization_Dataset 
import importlib
NormalizationDataset = importlib.reload(Normalization_Dataset)
from Normalization_Dataset import Normalization
get_normalizer(wus_corpus,'wus')