In [1]:
import os
import torch
import shutil
import logging
from tqdm import trange, tqdm

from torch.optim.adam import Adam
from torch.utils.data import RandomSampler, SequentialSampler, DataLoader

from ML_model import BiLSTM_CNN_CRF
from ML_dataloader import load_word_matrix
from ML_utils import get_labels, get_test_texts, set_seed

  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


In [2]:
# Module-level logger, named after the current module via ``__name__``.
logger = logging.getLogger(__name__)

In [3]:
class Trainer(object):
    def __init__(self, args, train_dataset=None, dev_dataset=None, test_dataset=None):
        self.args = args
        self.train_dataset = train_dataset
        self.dev_dataset = dev_dataset
        self.test_dataset = test_dataset

        self.label_lst = get_labels(args)
        self.num_labels = len(self.label_lst)

        # Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
        self.pad_token_label_id = args.ignore_index

        self.pretrained_word_matrix = None
        if not args.no_w2v:
            self.pretrained_word_matrix = load_word_matrix(args, self.word_vocab)

        self.model = BiLSTM_CNN_CRF(args, self.pretrained_word_matrix)

        # GPU or CPU
        self.device = "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu"
        self.model.to(self.device)

        self.test_texts = None
        if args.write_pred:
            self.test_texts = get_test_texts(args)
            # Empty the origianl prediction files
            if os.path.exists(args.pred_dir):
                shutil.rmtree(args.pred_dir)


    def train(self):
        train_sampler = RandomSampler(self.train_dataset)
        train_dataloader = DataLoader(self.train_dataset, sampler=train_sampler, batch_size = self.args.train_batch_size)

        # Optimizer and schedule (linear warmup and decay)
        optimizer = Adam(self.model.parameters(), lr = self.args.learning_rate)

        # Train!
        logger.info("***** Running Training *****")
        logger.info(f"   Num examples = {len(self.train_dataset)}")
        logger.info(f"   Num Epochs = {self.args.num_train_epochs}")
        logger.info(f"   Batch size = {self.args.train_batch_size}")

        global_step = 0
        tr_loss = 0.0
        self.model.zero_grad()

        train_iterator = trange(int(self.args.num_train_epochs), desc = "Epoch")
        set_seed(self.args)

        for _ in train_iterator:
            epoch_iterator = tqdm(train_dataloader, desc = "Iteration")
            for step, batch in enumerate(epoch_iterator):
                self.model.train()
                batch = tuple(t.to(self.device) for t in batch)         # GPU or CPU

                inputs = {'word_ids' : batch[0],
                          'char_ids' : batch[1],
                          'mask' : batch[2],
                          'label_ids' : batch[3]}
                outputs = self.model(**inputs)
                loss = outputs[0]

                loss.backward()

                tr_loss += loss.item()

                optimizer.step()
                self.model.zero_grad()
                global_step += 1

                if self.args.logging_steps > 0 and global_step % self.args.logging_steps == 0:
                    self.evaluate("test", global_step)

                if self.args.save_steps > 0 and global_step % self.args.save_steps == 0:
                    self.save_model()

        return global_step, tr_loss / global_step

    def evaluate(self, mode, step):
        if mode == 'test':
            dataset = self.test_dataset
        elif mode == 'dev':
            dataset = self.dev_dataset
        elif mode == 'train':
            dataset = self.train_dataset
        else:
            raise Exception("Only train, dev and test dataset available")

        eval_sampler = SequentialSampler(dataset)
        eval_dataloader = DataLoader(dataset, sampler =  eval_sampler, batch_size = self.args.eval_batch_size)
        
        # Eval!
        logger.info("***** Running evaluation on %s dataset *****", mode)
        logger.info("   Num examples = %d", len(dataset))
        logger.inof("   Batch size = %d", self.args.eval_batch_size)
        eval_loss = 0.0
        nb_eval_stpes = 0
        preds = None
        out_label_ids = None

        for batch in tqdm(eval_dataloader, desc = "Evaluating"):
            self.model.eval()
            batch = tuple(t.to(self.device) for t in batch)
            with torch.no_grad():
                inputs = {
                    'word_ids' : batch[0],
                    'char_ids' : batch[1],
                    'mask' : batch[2],
                    'label_ids' : batch[3]
                }
                outputs = self.model(**inputs)
