In [None]:
!pip install transformers
!pip install datasets
!pip install evaluate
!pip install seqeval
!pip install sentencepiece

In [None]:
import inspect
import os
import pickle
import re
from dataclasses import dataclass
from functools import reduce

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report
import nltk
import matplotlib.pyplot as plt
import pandas as pd
from tensorflow.keras.utils import pad_sequences
from tqdm import tqdm

from transformers import (
    TrainingArguments,
    Trainer,
    AutoConfig,
    AutoTokenizer,
    AutoModelForTokenClassification,
    T5Tokenizer,
    T5EncoderModel,
    T5ForConditionalGeneration,
    DataCollator,
    T5TokenizerFast,
)
from transformers import DataCollatorForTokenClassification, DataCollatorForSeq2Seq
import torch

from datasets import load_dataset, load_from_disk
import evaluate

In [None]:
# Project root (Google Drive in Colab by default). Set TR_PROJECT_DIR to use a
# different location. os.path.join(..., '') guarantees a trailing separator,
# because later cells build file names as f'{path}<name>'.
path = os.path.join(os.environ.get('TR_PROJECT_DIR', '/content/drive/MyDrive/tr_project/'), '')

In [None]:
# nltk.word_tokenize needs the Punkt models: older NLTK releases load 'punkt',
# newer ones load 'punkt_tab'. Fetch both so either version works.
for resource in ('punkt', 'punkt_tab'):
  nltk.download(resource)

In [None]:
# Load the preprocessed corpus. `with` closes the file handles after reading.
# NOTE(review): pickle.load can execute arbitrary code; only load these files
# from a trusted source.
with open(f'{path}all_sentences_processed.pkl', 'rb') as fh:
  all_sentences = pickle.load(fh)
with open(f'{path}all_labels_processed.pkl', 'rb') as fh:
  all_labels = pickle.load(fh)
# Sentences and labels are later paired by position (zip), so lengths must match.
assert len(all_sentences) == len(all_labels), (len(all_sentences), len(all_labels))

In [None]:
# Corpus size (N) and the number of sentences to subsample from it (n).
N, n = len(all_sentences), 200_000

In [None]:
# Seeded generator so the subsample (and everything derived from it) is
# reproducible; never request more samples than there are sentences.
rng = np.random.default_rng(42)
indices = rng.choice(N, min(n, N), replace=False)

In [None]:
# Keep only the sampled sentences/labels, in the order given by `indices`.
# NOTE(review): this overwrites the full lists, so the cell is not safe to
# re-run without reloading the data.
all_labels = [all_labels[idx] for idx in indices]
all_sentences = [all_sentences[idx] for idx in indices]

In [None]:
# Sanity check: size of the subsample that the rest of the notebook works on.
len(all_sentences)

In [None]:
def apos_pos_analysis(texts, tokenize=None):
  """Report how far from the end of a word apostrophes occur.

  For every word containing an apostrophe, the offset is
  ``len(word) - word.find("'")`` (1 means the apostrophe is the last character).
  Each time a new maximum is seen, ``'<word>: <offset>'`` is printed; at the end
  the overall maximum is printed.

  Args:
    texts: iterable of raw strings.
    tokenize: callable splitting a string into words; defaults to
      ``nltk.word_tokenize``.

  Returns:
    The maximum offset, or None (after printing a message) when no word
    contains an apostrophe.
  """
  if tokenize is None:
    tokenize = nltk.word_tokenize

  max_offset = 0
  for text in texts:
    for word in tokenize(text):
      pos = word.find("'")
      if pos == -1:
        continue
      offset = len(word) - pos  # always >= 1
      if offset > max_offset:
        max_offset = offset
        print(f'{word}: {offset}')

  if max_offset == 0:
    print('No apostrophes found.')
    return None
  print(max_offset)
  return max_offset

In [None]:
def index_safe(s, c):
  """Return the index of the first occurrence of ``c`` in ``s``, or -1 if absent.

  Works for any sequence with an ``.index`` method (str, list, tuple). Only
  "not found" (ValueError) is mapped to -1. Other errors, e.g. an unsupported
  argument type, are raised rather than silently hidden.
  """
  try:
    return s.index(c)
  except ValueError:
    return -1

def count(s, c):
  """Count the elements of sequence ``s`` that are equal to ``c``.

  Elements are compared one by one. For strings this means ``c`` must be a
  single character to ever match (``count('abab', 'ab') == 0``), unlike
  ``str.count``.
  """
  return sum(1 for item in s if item == c)

In [None]:
# Word tokens (letters/digits/underscore plus ' and ") or single punctuation
# marks. Raw string: '\w', '\-', '\(' are regex escapes, not Python escapes.
_LABEL_TOKEN_RE = re.compile(r'[\w\'"]+|[.,;:?!\-\(\)]')


def label_seq(sent, labels):
  """Strip apostrophes / double quotes from ``sent`` and record where they were.

  ``sent`` is split into word and punctuation tokens, and each token gets an
  integer label:

  * 0: nothing recorded;
  * 1-6: the token has an apostrophe at distance
    ``len(token) - token.find("'")`` from its end (larger distances are not
    labelled, but the apostrophe is still removed);
  * 7: exactly one double quote, at the start of the token;
  * 8: exactly one double quote, elsewhere in the token;
  * 9: two or more double quotes.

  A quote label overrides an apostrophe label on the same token. Processed
  tokens have every ' and " removed.

  Only ``len(labels)`` is used. A counter is incremented once per token for an
  apostrophe and once for double quotes. As soon as it equals ``len(labels)``,
  the remaining tokens are copied through unchanged (marks intact) with label 0.
  NOTE(review): the check is ``==``, so a token carrying both marks can jump
  past ``len(labels)`` and the early stop never triggers.

  Args:
    sent: raw sentence string.
    labels: sequence whose length is the expected number of marked tokens.

  Returns:
    (tokens, token_labels): two lists of the same length.
  """
  tokens = _LABEL_TOKEN_RE.findall(sent)
  new_seq = []
  token_labels = [0] * len(tokens)
  n_found = 0

  for i, tok in enumerate(tokens):
    apos = tok.find("'")
    if apos != -1:
      offset = len(tok) - apos
      if offset <= 6:
        token_labels[i] = offset
      n_found += 1

    n_quotes = tok.count('"')
    if n_quotes > 0:
      n_found += 1
    if n_quotes == 1:
      token_labels[i] = 7 if tok.startswith('"') else 8
    elif n_quotes > 1:
      token_labels[i] = 9

    new_seq.append(tok.replace('"', '').replace("'", ""))

    if n_found == len(labels):
      new_seq.extend(tokens[i + 1:])
      break

  return new_seq, token_labels


# Spot-check label_seq on one sampled sentence.
index = 200
res = label_seq(all_sentences[index], all_labels[index])
tokens_out, labels_out = res
all_sentences[index], tokens_out, str(labels_out), len(labels_out)

In [None]:
# Load the tokenizer from the same checkpoint as the model trained later in
# this notebook ('t5-small'), so the two cannot drift apart.
tokenizer = T5TokenizerFast.from_pretrained("t5-small")

In [None]:
# Run label_seq over every sampled sentence. Keep the (tokens, labels) pairs
# as well as the two halves separately.
seq_labels = [label_seq(sent, sent_labels) for sent, sent_labels in zip(all_sentences, all_labels)]
sentences = [pair[0] for pair in seq_labels]
labels = [pair[1] for pair in seq_labels]

In [None]:
# Class distribution over the real (unpadded) token labels. Padding the lists
# to a fixed length would add spurious 0s, and truncating would drop labels of
# long sentences, so the per-sentence lists are flattened instead.
flat_labels = np.fromiter((tag for tags in labels for tag in tags), dtype=int)
np.unique(flat_labels, return_counts=True)

In [None]:
# 90/10 split, stratified on each sentence's highest label id. default=0 keeps
# sentences that produced no tokens from crashing max(). A fixed random_state
# makes the split reproducible.
train_seq_labels, test_seq_labels = train_test_split(
    seq_labels,
    test_size=0.1,
    stratify=[max(tags, default=0) for tags in labels],
    random_state=42,
)

In [None]:
# One JSON object per line ({"text": [...], "label": [...]}) for load_dataset("json").
df = pd.DataFrame(data=train_seq_labels, columns=['text', 'label'])
df.to_json(path_or_buf="train_data.json", orient="records", lines=True)

In [None]:
# Same JSON Lines layout for the held-out split.
df = pd.DataFrame(data=test_seq_labels, columns=['text', 'label'])
df.to_json(path_or_buf="test_data.json", orient="records", lines=True)

In [None]:
# Read both JSON Lines files back as a DatasetDict with 'train' and 'test' splits.
data_files = {"train": "train_data.json", "test": "test_data.json"}
dataset = load_dataset("json", data_files=data_files)

In [None]:
# Inspect the DatasetDict: the train/test splits, their sizes and columns.
dataset

In [None]:
def tokenize_function(examples):
    """Tokenise ``examples["text"]`` with truncation and padding='max_length'.

    With no explicit ``max_length``, padding goes up to the tokenizer's
    ``model_max_length``.

    NOTE(review): not used in this notebook; the dataset is tokenised with
    ``tokenize_and_align_labels``. The ``text`` entries are word lists, which
    would presumably also need ``is_split_into_words=True`` here.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True, padding='max_length')

In [None]:
def align_labels_with_tokens(labels, word_ids):
    """Spread per-word labels over sub-word tokens using B/I-style ids.

    Args:
        labels: one integer label per word.
        word_ids: for each token, the index of the word it belongs to, or None
            for special/padding tokens.

    Returns:
        One id per token:
        * -100 for tokens with no word (None);
        * 0 for tokens of an unlabelled word;
        * ``2 * k - 1`` on the first token of a word with label k > 0;
        * ``2 * k`` on that word's following tokens.
        A word counts as "new" whenever its id differs from the previous
        token's id (including after a None).
    """
    aligned = []
    prev_word = None
    for word_id in word_ids:
        starts_new_word = word_id != prev_word
        prev_word = word_id
        if word_id is None:
            aligned.append(-100)
            continue
        tag = labels[word_id]
        if tag > 0:
            tag = 2 * tag - 1 if starts_new_word else 2 * tag
        aligned.append(tag)
    return aligned

In [None]:
def tokenize_and_align_labels(examples):
    """Tokenise a batch of pre-split sentences and build T5 text targets.

    ``examples["text"]`` holds word lists and ``examples["label"]`` the matching
    per-word label lists. Each word's label is spread over its sub-tokens with
    ``align_labels_with_tokens``. The resulting ids are written out as a
    sequence of number strings, which becomes the decoder target.

    Returns the encoder inputs plus:
      * ``labels``: target token ids, with padding set to -100 so the loss
        ignores it;
      * ``decoder_attention_mask``: the target attention mask.

    ``decoder_input_ids`` is deliberately NOT returned: T5ForConditionalGeneration
    builds it by shifting ``labels`` one step right. Passing the unshifted
    target ids as decoder inputs would let the decoder see the very token it
    must predict, so it would learn to copy.
    NOTE(review): a dataset cached under 'tokenized_dataset_t5' by older code
    still contains decoder_input_ids; delete it so this function runs again.
    """
    tokenized_inputs = tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=512,
        is_split_into_words=True,
        return_tensors="pt",
    )

    target_texts = []
    for i, word_labels in enumerate(examples["label"]):
      word_ids = tokenized_inputs.word_ids(i)
      token_labels = align_labels_with_tokens(word_labels, word_ids)
      # Special/padding tokens get -100 and carry no label. Leave them out of the
      # target instead of spelling "-100" once per pad position, which would
      # swamp the target and push real labels past the 512-token limit.
      target_texts.append([str(t) for t in token_labels if t != -100])

    tokenized_labels = tokenizer(
        target_texts,
        padding="max_length",
        truncation=True,
        max_length=512,
        is_split_into_words=True,
        return_tensors="pt",
    )
    tokenized_inputs["decoder_attention_mask"] = tokenized_labels["attention_mask"]
    labels = tokenized_labels["input_ids"].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    tokenized_inputs["labels"] = labels

    return tokenized_inputs

In [None]:
# Tokenisation results are cached on disk and reused as-is whenever the
# directory exists. Delete it after changing the tokenisation code (the last
# cell of this notebook removes it).
cache_dir = 'tokenized_dataset_t5'
if not os.path.exists(cache_dir):
  tokenized_dataset = (
      dataset
      .map(tokenize_and_align_labels, batched=True, load_from_cache_file=False)
      .remove_columns(['label', 'text'])
  )
  tokenized_dataset.save_to_disk(cache_dir)
else:
  tokenized_dataset = load_from_disk(cache_dir)

In [None]:
# Every example was already padded to max_length=512 during tokenisation, so
# dynamic padding is disabled here.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=None, padding=False)

In [None]:
# Entity-level precision/recall/F1 over B-/I- tag sequences.
seqeval = evaluate.load(path="seqeval")

In [None]:
# Base labels 0..9 as produced by label_seq. Each base label k > 0 yields two
# class ids: 2k-1 (first sub-token) and 2k (following sub-tokens); 0 stays 0.
base_classes = np.arange(0, 10)
inside_tag_ids = 2 * base_classes
begin_tag_ids = 2 * base_classes[1:] - 1
classes = [int(c) for c in np.sort(np.concatenate([inside_tag_ids, begin_tag_ids]))]
classes

In [None]:
# Tag names for the ids in `classes`: 0 is 'O'. For each pair (odd id, next
# even id) the odd one is B- and the even one is I-. The first six pairs
# (apostrophe offsets 1-6) get prefix AP; the rest (quote labels 7-9) get QU.
id2label = {0: 'O'}
for pos, (b_id, i_id) in enumerate(zip(classes[1::2], classes[2::2])):
  # pos < 6 is equivalent to the odd class index 2*pos+1 being below 13.
  prefix = 'AP' if pos < 6 else 'QU'
  tag = f'{prefix}-{base_classes[pos] + 1}'
  id2label[b_id] = f'B-{tag}'
  id2label[i_id] = f'I-{tag}'


label2id = {name: class_id for class_id, name in id2label.items()}

In [None]:
# Inspect the class id -> tag name mapping.
id2label

In [None]:
# Inspect the reverse mapping (tag name -> class id).
label2id

In [None]:
# Pretrained T5-small seq2seq model; displaying it prints the architecture.
model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name_or_path='t5-small')
model

In [None]:
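# Token-classification kwargs, unused because T5ForConditionalGeneration generates the label ids as text: num_labels=len(id2label), id2label=id2label, label2id=label2id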

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce model logits to predicted token ids before metrics are computed.

    When ``logits`` is a tuple/list, its first element is taken as the LM
    logits (presumably the vocabulary logits followed by encoder states for T5;
    confirm). A bare tensor is used as-is rather than being indexed by batch.
    NOTE(review): these are argmax ids under teacher forcing, not generated text.

    Returns:
        (pred_ids, labels). ``compute_metrics`` later takes element 0.
    """
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels


In [None]:
def to_int_safe(s):
  """Parse ``s`` as an int, returning 0 when it is not a valid integer.

  Only conversion failures (ValueError/TypeError/OverflowError) are mapped to
  0; anything else propagates instead of being swallowed.
  """
  try:
    return int(s)
  except (ValueError, TypeError, OverflowError):
    return 0

In [None]:
label_list = list(label2id.keys())


def _ids_to_tags(decoded):
  """Map a decoded string of space-separated label ids to seqeval tag names.

  Non-integer tokens and ids outside ``label_list`` (e.g. "-100") become 'O'.
  """
  tags = []
  for tok in decoded.split():
    label_id = to_int_safe(tok)
    tags.append(label_list[label_id] if 0 <= label_id < len(label_list) else 'O')
  return tags


def compute_metrics(p):
  """Per-class and overall seqeval F1 for the T5 label-id targets.

  ``p`` is the Trainer's (predictions, labels) pair. ``predictions`` is what
  ``preprocess_logits_for_metrics`` returned, so element 0 holds the predicted
  token ids. Positions whose label token is -100 are dropped. Both sides are
  decoded to text, turned into tag sequences, and scored with seqeval.

  Returns:
    dict with ``<TYPE>_f1`` and ``<TYPE>_number`` for every tag type seqeval
    reports (AP-1..AP-6, QU-7..QU-9), plus ``overall_f1``.
  """
  predictions, labels = p
  predictions = predictions[0]

  true_predictions = [
      [p_id for p_id, l_id in zip(prediction, label) if l_id != -100]
      for prediction, label in zip(predictions, labels)
  ]
  true_labels = [
      [l_id for p_id, l_id in zip(prediction, label) if l_id != -100]
      for prediction, label in zip(predictions, labels)
  ]

  decoded_preds = tokenizer.batch_decode(true_predictions, skip_special_tokens=True)
  decoded_labels = tokenizer.batch_decode(true_labels, skip_special_tokens=True)

  pred_tags = [_ids_to_tags(seq) for seq in decoded_preds]
  label_tags = [_ids_to_tags(seq) for seq in decoded_labels]
  # seqeval requires equal lengths: pad short predictions with 'O' and cut long
  # ones (decoded predictions can split into more items than the reference).
  pred_tags = [
      (pred + ['O'] * len(ref))[:len(ref)]
      for pred, ref in zip(pred_tags, label_tags)
  ]

  results = seqeval.compute(predictions=pred_tags, references=label_tags)

  res = {}
  for c in base_classes[1:]:
    pre = 'AP' if c < 7 else 'QU'
    label_type = f"{pre}-{c}"
    if label_type in results:
      res[f"{label_type}_f1"] = results[label_type]['f1']
      res[f"{label_type}_number"] = results[label_type]['number']
    else:
      print(f"{label_type} does not exist.")
  res["overall_f1"] = results["overall_f1"]

  return res

In [None]:
# Checkpoints/outputs live under the project root, so they follow `path`
# instead of repeating the Drive location.
train_path = f'{path}t5_train_all/'

In [None]:
# The pip cell installs an unpinned transformers, and two keyword names changed
# across releases: TrainingArguments 'evaluation_strategy' -> 'eval_strategy'
# (the old name was removed) and Trainer 'tokenizer' -> 'processing_class'.
# Pass whichever name the installed version accepts.
_args_params = inspect.signature(TrainingArguments.__init__).parameters
_eval_key = 'eval_strategy' if 'eval_strategy' in _args_params else 'evaluation_strategy'

training_args = TrainingArguments(
    output_dir=train_path,
    learning_rate=1e-5,
    per_device_train_batch_size=12,
    per_device_eval_batch_size=2,
    num_train_epochs=10,
    weight_decay=0.01,
    save_strategy="epoch",
    **{_eval_key: "epoch"},
)

_trainer_params = inspect.signature(Trainer.__init__).parameters
_tokenizer_key = 'processing_class' if 'processing_class' in _trainer_params else 'tokenizer'

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    **{_tokenizer_key: tokenizer},
)

In [None]:
# Check which GPU the runtime has (and its memory usage) before training.
!nvidia-smi

In [None]:
# Train from scratch. A checkpoint is saved to output_dir every epoch; pass a
# checkpoint path (or True) as resume_from_checkpoint to continue a run.
trainer.train(resume_from_checkpoint=None)

In [None]:
# Delete the on-disk tokenization cache so the next run re-tokenizes from
# scratch instead of loading a possibly stale 'tokenized_dataset_t5'.
!rm -r tokenized_dataset_t5