In [None]:
# !pip install indic_nlp_library
!pip install sentencepiece
!pip install transformers

In [None]:
# from indicnlp.tokenize.indic_tokenize import trivial_tokenize_indic
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModel
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.nn import CrossEntropyLoss, MSELoss
from torch.optim import Adam

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Experiment tag; used as the prefix of the saved checkpoint file name.
EXP_DIR = 'en-de'

# Sentence-level files (TSV with a header row; see load_data for the column used).
SENT_TRAIN_FILE_PATH = 'train.ende.df.short.tsv'
SENT_VAL_FILE_PATH = 'dev.ende.df.short.tsv'
SENT_TEST_FILE_PATH = 'test20.ende.df.short.tsv'

# Word-level files: one sentence (or one tag sequence) per line.
WORD_TRAIN_SRC_FILE_PATH = 'train.src'
WORD_TRAIN_TGT_FILE_PATH = 'train.mt'
WORD_TRAIN_SRC_TAGS_FILE_PATH = 'train.source_tags'
WORD_TRAIN_TGT_TAGS_FILE_PATH = 'train.tags'

WORD_VAL_SRC_FILE_PATH = 'dev.src'
WORD_VAL_TGT_FILE_PATH = 'dev.mt'
WORD_VAL_SRC_TAGS_FILE_PATH = 'dev.source_tags'
WORD_VAL_TGT_TAGS_FILE_PATH = 'dev.tags'

WORD_TEST_SRC_FILE_PATH = 'test20.src'
WORD_TEST_TGT_FILE_PATH = 'test20.mt'
WORD_TEST_SRC_TAGS_FILE_PATH = 'test20.source_tags'
WORD_TEST_TGT_TAGS_FILE_PATH = 'test20.tags'

# Tokenisation / labelling options.
MAX_LEN = 256             # every example is padded/truncated to this many sub-tokens
SEP_TOKEN = '</s>'        # inserted twice between source and target text in load_data
LABEL_ALL_TOKENS = False  # False: only the first sub-token of a word carries a label

# Model and optimisation settings.
MODEL_TYPE = 'xlm-roberta-base'
NUM_EPOCHS = 5
NUM_ACC_STEPS = 1   # NOTE(review): not referenced by any cell visible here
BATCH_SIZE = 16
LR_RATE = 1e-5
PATIENCE = 10       # NOTE(review): EarlyStopping below is built with literals, not this
MIN_DELTA = 5.      # NOTE(review): same as PATIENCE - currently unused

# Create the output directory idempotently; exist_ok avoids the
# check-then-create race of os.path.exists + os.mkdir and makes re-running
# this cell safe.
OUTPUT_DIR = os.path.join(os.getcwd(), 'Outputs')
os.makedirs(OUTPUT_DIR, exist_ok=True)
BEST_MODEL_PATH = os.path.join(OUTPUT_DIR, EXP_DIR + '_' + 'best_model.pt')

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Module-level tokenizer for the checkpoint named by MODEL_TYPE; align_label and
# DataSequence read this global directly.
# NOTE(review): align_label calls .word_ids() on the encoding, which Hugging Face
# documents as available for fast tokenizers only - confirm a fast tokenizer is loaded.
tokenizer = AutoTokenizer.from_pretrained(MODEL_TYPE)

In [None]:
def load_data(sent_file_path, word_src_path, word_tgt_path, word_src_tags_path, word_tgt_tags_path, sep_token):
  """Read one data split and build model inputs plus word- and sentence-level labels.

  Args:
    sent_file_path: TSV file with a header row; the sentence score is taken
      from the second-to-last tab-separated column of every following row.
    word_src_path: file with one source sentence per line.
    word_tgt_path: file with one target (MT) sentence per line.
    word_src_tags_path: file with one space-separated source tag sequence per line.
    word_tgt_tags_path: file with one space-separated target tag sequence per line.
    sep_token: separator string placed twice between source and target text.

  Returns:
    (input_sents, word_labels, sent_labels) where
      input_sents[i] is "<src> <sep> <sep> <gap> w1 <gap> w2 ...",
      word_labels[i] is src tags + ["OK", "OK"] + target tags without the last one,
      sent_labels[i] is the float sentence score.

  Raises:
    OSError: if a file cannot be opened.
    ValueError / IndexError: if a sentence row has no parsable score column.
  """

  def _read_stripped_lines(path):
    # Context manager guarantees the handle is closed (the original leaked the
    # sentence file handle); UTF-8 is explicit so non-ASCII text is decoded the
    # same way regardless of the platform's default locale.
    with open(path, 'r', encoding='utf-8') as handle:
      return [str(line).strip() for line in handle]

  # Skip the header row, then pick the second-to-last column as the score.
  sent_rows = _read_stripped_lines(sent_file_path)[1:]
  sent_labels = [float(row.split('\t')[-2]) for row in sent_rows]

  src_sents = _read_stripped_lines(word_src_path)
  tgt_sents = _read_stripped_lines(word_tgt_path)
  src_tags = _read_stripped_lines(word_src_tags_path)
  tgt_tags = _read_stripped_lines(word_tgt_tags_path)

  input_sents = []
  for s, t in zip(src_sents, tgt_sents):
    # Put a literal "<gap>" marker in front of every target word.
    new_t = "<gap> " + " <gap> ".join(t.split(" "))
    input_sents.append(s + " " + sep_token + " " + sep_token + " " + new_t)

  word_labels = []
  for s, t in zip(src_tags, tgt_tags):
    # Two "OK" placeholders line up with the two separator tokens. The final
    # target tag is dropped - presumably a trailing gap tag that has no
    # matching "<gap>" marker in the input; TODO confirm against the data format.
    word_labels.append(s.split(" ") + ["OK"] + ["OK"] + t.split(" ")[:-1])

  return input_sents, word_labels, sent_labels

In [None]:
# Read the three splits; each call yields (inputs, word-level labels, sentence-level labels).
train_inputs, train_word_labels, train_sent_labels = load_data(
    SENT_TRAIN_FILE_PATH,
    WORD_TRAIN_SRC_FILE_PATH,
    WORD_TRAIN_TGT_FILE_PATH,
    WORD_TRAIN_SRC_TAGS_FILE_PATH,
    WORD_TRAIN_TGT_TAGS_FILE_PATH,
    SEP_TOKEN,
)
val_inputs, val_word_labels, val_sent_labels = load_data(
    SENT_VAL_FILE_PATH,
    WORD_VAL_SRC_FILE_PATH,
    WORD_VAL_TGT_FILE_PATH,
    WORD_VAL_SRC_TAGS_FILE_PATH,
    WORD_VAL_TGT_TAGS_FILE_PATH,
    SEP_TOKEN,
)
test_inputs, test_word_labels, test_sent_labels = load_data(
    SENT_TEST_FILE_PATH,
    WORD_TEST_SRC_FILE_PATH,
    WORD_TEST_TGT_FILE_PATH,
    WORD_TEST_SRC_TAGS_FILE_PATH,
    WORD_TEST_TGT_TAGS_FILE_PATH,
    SEP_TOKEN,
)

# One frame per split with the columns DataSequence reads: text, word_labels, sent_labels.
train_df = pd.DataFrame(dict(text=train_inputs, word_labels=train_word_labels, sent_labels=train_sent_labels))
val_df = pd.DataFrame(dict(text=val_inputs, word_labels=val_word_labels, sent_labels=val_sent_labels))
test_df = pd.DataFrame(dict(text=test_inputs, word_labels=test_word_labels, sent_labels=test_sent_labels))

In [None]:
# Word-level tag vocabulary and its inverse (class id -> tag string).
labels_to_ids = dict(OK=0, BAD=1)
ids_to_labels = {class_id: tag for tag, class_id in labels_to_ids.items()}

In [None]:
def align_label(text, labels, label_all_tokens):
  """Project word-level tags onto the sub-token sequence produced by the tokenizer.

  Uses the module-level `tokenizer`, `MAX_LEN` and `labels_to_ids`.

  Args:
    text: input string to tokenize (padded/truncated to MAX_LEN).
    labels: sequence of tag strings, indexed by the word index that the
      tokenizer's `word_ids()` assigns to each sub-token.
    label_all_tokens: if True, every sub-token of a word gets the word's label;
      if False, only the first sub-token does and the rest get -100.

  Returns:
    A list of ints, one per sub-token position: the label id, or -100 for
    positions without a word index, for non-first sub-tokens (when
    `label_all_tokens` is False), and for words whose tag is missing or unknown.
  """
  tokenized_input = tokenizer(text, padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors="pt")
  word_ids = tokenized_input.word_ids()

  label_ids = []
  previous_word_idx = None
  for word_idx in word_ids:
    if word_idx is None:
      # No word index for this position (word_ids() returned None).
      label_ids.append(-100)
    elif word_idx != previous_word_idx or label_all_tokens:
      # First sub-token of a word, or any sub-token when labelling all of them.
      try:
        label_ids.append(labels_to_ids[labels[word_idx]])
      except (IndexError, KeyError):
        # Best effort: fewer tags than words (IndexError) or a tag outside the
        # vocabulary (KeyError) is ignored by the loss. Anything else - e.g.
        # KeyboardInterrupt, which the previous bare `except:` swallowed -
        # now propagates.
        label_ids.append(-100)
    else:
      label_ids.append(-100)
    previous_word_idx = word_idx

  return label_ids

In [None]:
class DataSequence(torch.utils.data.Dataset):
    """Dataset pairing each tokenized text with its word-level and sentence-level labels.

    Everything is tokenized eagerly in the constructor using the module-level
    `tokenizer`, `MAX_LEN`, `LABEL_ALL_TOKENS` and `align_label`.
    """

    def __init__(self, df):
        """Tokenize `df['text']` and align `df['word_labels']`; keep `df['sent_labels']` as is."""
        word_level_tags = df['word_labels'].values.tolist()
        sentence_scores = df['sent_labels'].values.tolist()
        raw_texts = df['text'].values.tolist()

        # One encoding per example, padded/truncated to MAX_LEN.
        encodings = []
        for raw in raw_texts:
            encodings.append(
                tokenizer(str(raw), padding='max_length', max_length=MAX_LEN, truncation=True, return_tensors="pt")
            )
        self.texts = encodings

        # Sub-token level label ids (with -100 at ignored positions).
        aligned = []
        for raw, tags in zip(raw_texts, word_level_tags):
            aligned.append(align_label(raw, tags, LABEL_ALL_TOKENS))
        self.word_labels = aligned

        self.sent_labels = sentence_scores

    def __len__(self):
        """Number of examples (length of the sentence-label list)."""
        return len(self.sent_labels)

    def get_batch_data(self, idx):
        """Return the stored tokenizer output for example `idx`."""
        return self.texts[idx]

    def get_batch_word_labels(self, idx):
        """Return the aligned word labels of example `idx` as a LongTensor."""
        return torch.LongTensor(self.word_labels[idx])

    def get_batch_sent_labels(self, idx):
        """Return the sentence label of example `idx` as a float tensor."""
        return torch.tensor(self.sent_labels[idx], dtype=torch.float)

    def __getitem__(self, idx):
        """Return (encoding, word-label tensor, sentence-label tensor) for example `idx`."""
        return (
            self.get_batch_data(idx),
            self.get_batch_word_labels(idx),
            self.get_batch_sent_labels(idx),
        )

In [None]:
# Tokenize every split once up front (DataSequence does the work in __init__).
train_dataset = DataSequence(train_df)
val_dataset = DataSequence(val_df)
test_dataset = DataSequence(test_df)

# Only the training data is shuffled. Validation and test batches keep the
# file order so evaluation is repeatable and per-example predictions stay
# aligned with the input rows.
train_dataloader = DataLoader(train_dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=False)
test_dataloader = DataLoader(test_dataset, num_workers=2, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
class MTLModel(torch.nn.Module):
  """Shared encoder with two heads: per-token classification and sentence-level regression.

  The encoder is loaded from MODEL_TYPE. The token head is dropout + linear
  over every position; the sentence head is dropout -> dense -> tanh ->
  dropout -> linear applied to the hidden state at position 0.
  """

  def __init__(self, config):
    """Build the encoder and both heads from `config` (needs num_labels, hidden_size, dropout fields)."""
    super().__init__()

    self.num_labels = config.num_labels
    self.base_model = AutoModel.from_pretrained(MODEL_TYPE, config=config, add_pooling_layer=False)

    # Prefer the dedicated classifier dropout; fall back to the hidden dropout.
    drop_prob = config.classifier_dropout
    if drop_prob is None:
      drop_prob = config.hidden_dropout_prob

    hidden_size = config.hidden_size

    # Token-level head (the same dropout module is shared with the sentence head).
    self.dropout = nn.Dropout(drop_prob)
    self.classifier = nn.Linear(hidden_size, config.num_labels)

    # Sentence-level head.
    self.dense = nn.Linear(hidden_size, hidden_size)
    self.out_proj = nn.Linear(hidden_size, 1)

  def forward(self, input_ids, attn_mask, word_labels=None, sent_labels=None):
    """Run both heads.

    Returns:
      (logits, x) when either label argument is None, otherwise
      (logits, x, sent_loss, word_loss), where word_loss is the cross-entropy
      over the flattened token logits and sent_loss is the MSE between the
      squeezed regression output and the squeezed sentence labels.
    """
    hidden_states = self.base_model(input_ids, attn_mask)[0]

    # Token classification over every position.
    logits = self.classifier(self.dropout(hidden_states))

    # Regression from the hidden state at the first position.
    pooled = self.dropout(hidden_states[:, 0, :])
    pooled = torch.tanh(self.dense(pooled))
    x = self.out_proj(self.dropout(pooled))

    # Losses are only computed when both label sets are supplied.
    if word_labels is None or sent_labels is None:
      return logits, x

    word_loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), word_labels.view(-1))
    sent_loss = MSELoss()(x.squeeze(), sent_labels.squeeze())
    return logits, x, sent_loss, word_loss

In [None]:
# Start from the checkpoint's configuration and size the token head to the tag vocabulary.
config = AutoConfig.from_pretrained(MODEL_TYPE)
config.num_labels = len(labels_to_ids)

# Build the multi-task model and move it to the selected device in one step
# (Module.to returns the module itself).
model = MTLModel(config).to(device)

# A single Adam optimizer over every parameter (encoder and both heads).
optimizer = Adam(params=model.parameters(), lr=LR_RATE)

In [None]:
# Total number of model parameters, counting tensors that share the same storage only once:
# sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())

278045955

In [None]:
def validate(model, dataloader):
  """Evaluate `model` on `dataloader` and return summed metrics.

  Puts the model in eval mode (the caller is responsible for switching back
  to train mode) and runs every batch without gradient tracking. Uses the
  module-level `device`.

  Args:
    model: callable returning (logits, x, sent_loss, word_loss) when given
      input ids, attention mask, word labels and sentence labels.
    dataloader: iterable of (data, word_labels, sent_labels) batches, where
      `data` holds 'input_ids' and 'attention_mask' with a singleton dim 1.

  Returns:
    (total_val_loss, total_val_mse, total_val_acc):
      total_val_loss - sum over batches of sent_loss + word_loss (float),
      total_val_mse  - sum over batches of sent_loss (float),
      total_val_acc  - sum over sequences of the per-sequence token accuracy
                       on positions whose label is not -100 (0-dim tensor, or
                       the float 0. if the dataloader is empty).
  """
  model.eval()

  # Accumulators must start at zero: starting from np.inf (as before) makes
  # every returned total inf, which silently breaks best-model selection and
  # early stopping.
  total_val_loss, total_val_mse, total_val_acc = 0., 0., 0.

  # No gradients are needed for evaluation; this avoids building the autograd graph.
  with torch.no_grad():
    for data, word_labels, sent_labels in dataloader:

      input_ids = data['input_ids'].squeeze(1).to(device)
      attn_masks = data['attention_mask'].squeeze(1).to(device)
      word_labels = word_labels.to(device)
      sent_labels = sent_labels.to(device)

      logits, x, sent_loss, word_loss = model(input_ids, attn_masks, word_labels, sent_labels)
      loss = torch.add(sent_loss, word_loss)

      # Per-sequence token accuracy over labelled positions only.
      for i in range(logits.shape[0]):
        labelled = word_labels[i] != -100
        predictions = logits[i][labelled].argmax(dim=1)
        acc = (predictions == word_labels[i][labelled]).float().mean()
        total_val_acc += acc

      total_val_mse += sent_loss.item()
      total_val_loss += loss.item()

  return total_val_loss, total_val_mse, total_val_acc

In [None]:
class EarlyStopping():
    """Signal when the validation loss has been clearly worse than its best value too often.

    Attributes:
        patience: number of "bad" checks after which stopping is signalled.
        min_delta: margin above the best loss that a loss must exceed to count as bad.
        counter: bad checks seen since the best loss last improved.
        min_validation_loss: lowest loss seen so far (starts at +inf).
    """

    def __init__(self, patience=1, min_delta=0):
        self.patience, self.min_delta = patience, min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stopping(self, validation_loss):
        """Record `validation_loss`; return True once `patience` bad checks have accumulated.

        A loss strictly below the best so far becomes the new best and resets
        the counter. A loss strictly above best + min_delta increments the
        counter. Anything in between leaves the state untouched.
        """
        best_so_far = self.min_validation_loss

        if validation_loss < best_so_far:
            self.min_validation_loss, self.counter = validation_loss, 0
            return False

        if validation_loss > (best_so_far + self.min_delta):
            self.counter = self.counter + 1
            if self.counter >= self.patience:
                return True

        return False

# NOTE(review): these literals differ from the PATIENCE / MIN_DELTA constants
# in the configuration cell, which are not used here - confirm which is intended.
early_stopping = EarlyStopping(patience=3, min_delta=10)

In [None]:
# Training loop: one pass over train_dataloader per epoch, followed by
# validation, best-checkpoint saving, metric logging and an early-stopping check.
best_val_loss = np.inf
train_losses, val_losses = [], []
train_accs, val_accs = [], []
train_mses, val_mses = [], []

# Loss values are per-batch means, so they are averaged over batches when
# printed; accuracy is summed per sequence and averaged over sequences.
num_train_batches = max(len(train_dataloader), 1)
num_val_batches = max(len(val_dataloader), 1)

for epoch in range(NUM_EPOCHS):

  model.train()

  # Accumulators start at zero; starting from np.inf (as before) made every
  # epoch total inf and the reported metrics meaningless.
  total_train_acc, total_train_mse, total_train_loss = 0., 0., 0.
  for data, word_labels, sent_labels in tqdm(train_dataloader):

    input_ids = data['input_ids'].squeeze(1).to(device)
    attn_masks = data['attention_mask'].squeeze(1).to(device)
    word_labels = word_labels.to(device)
    sent_labels = sent_labels.to(device)

    optimizer.zero_grad()

    logits, x, sent_loss, word_loss = model(input_ids, attn_masks, word_labels, sent_labels)

    # Multi-task objective: unweighted sum of the two losses.
    loss = torch.add(sent_loss, word_loss)

    with torch.no_grad():
      # Per-sequence token accuracy over positions whose label is not -100.
      for i in range(logits.shape[0]):
        word_logits_clean = logits[i][word_labels[i] != -100]
        word_label_clean = word_labels[i][word_labels[i] != -100]
        predictions = word_logits_clean.argmax(dim=1)
        acc = (predictions == word_label_clean).float().mean()

        total_train_acc += acc

      total_train_mse += sent_loss.item()
      total_train_loss += loss.item()

    loss.backward()
    optimizer.step()

  total_val_loss, total_val_mse, total_val_acc = validate(model, val_dataloader)

  # Keep the checkpoint with the lowest summed validation loss.
  if best_val_loss >= total_val_loss:
    best_val_loss = total_val_loss
    torch.save(model.state_dict(), BEST_MODEL_PATH)

  # Record and print this epoch's metrics BEFORE the early-stopping check so
  # the final epoch is not silently dropped from the history.
  train_accs.append(total_train_acc)
  train_mses.append(total_train_mse)
  train_losses.append(total_train_loss)
  val_accs.append(total_val_acc)
  val_mses.append(total_val_mse)
  val_losses.append(total_val_loss)

  print(f'Epochs: {epoch + 1} | Train_Loss: {total_train_loss / num_train_batches: .3f} | Train_Accuracy: {total_train_acc / len(train_df): .3f} | Train_MSE: {total_train_mse / num_train_batches: .3f} | Val_Loss: {total_val_loss / num_val_batches: .3f} | Val_Accuracy: {total_val_acc / len(val_df): .3f} | Val_MSE: {total_val_mse / num_val_batches: .3f}')

  # Early stopping works on the summed validation loss (same scale as min_delta).
  if early_stopping.early_stopping(total_val_loss):
    break

In [None]:
# model.eval()

# for data, word_labels, sent_labels in tqdm(test_dataloader):

#     input_ids = data['input_ids'].squeeze(1).to(device)
#     attn_masks = data['attention_mask'].squeeze(1).to(device)
#     word_labels = word_labels.to(device)
#     sent_labels = sent_labels.to(device)

#     logits, x = model(input_ids, attn_masks)

#     for i in range(logits.shape[0]):
#         word_logits_clean = logits[i][word_labels[i] != -100]
#         word_label_clean = word_labels[i][word_labels[i] != -100]
#         predictions = word_logits_clean.argmax(dim=1)
#         acc = (predictions == word_label_clean).float().mean()