In [1]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import DataCollatorForSeq2Seq
from transformers import T5Tokenizer, T5ForConditionalGeneration
from torch.utils.data import DataLoader
import pickle
import pandas

model_checkpoint = "t5-small"

# Load the pretrained model first, then its tokenizer.
model = T5ForConditionalGeneration.from_pretrained(model_checkpoint)
# NOTE: the slow T5Tokenizer needs the `sentencepiece` package installed;
# without it this call raises the ImportError shown in the cell output.
tokenizer = T5Tokenizer.from_pretrained(model_checkpoint)

# Marker strings used by the linearized graph format: register each one as a
# single token, then resize the embedding matrix to the enlarged vocabulary.
new_words = ['<vertex>', '<r>', '<t>', '<h>', '[[', ']]', '<end>']
tokenizer.add_tokens(new_words)
model.resize_token_embeddings(len(tokenizer))


# Load the pickled dataset.
# NOTE(review): pickle.load can execute arbitrary code; only load a trusted DR.pkl.
with open('DR.pkl', 'rb') as file:
    data_dump = pickle.load(file)

# Examples with index < 2453 go to the training split, the rest to validation.
# Each example's 'text' becomes the model input and 'linearized' the target.
train_datadict = {'document': [], 'summary': []}
validate_datadict = {'document': [], 'summary': []}
for x, data_point in enumerate(data_dump):
    target_dict = train_datadict if x < 2453 else validate_datadict
    target_dict['document'].append(data_point['text'])
    target_dict['summary'].append(data_point['linearized'])

train_dataset, validate_dataset = (
    Dataset.from_dict(train_datadict),
    Dataset.from_dict(validate_datadict),
)
dataset = DatasetDict({'train': train_dataset, 'validation': validate_dataset})

def preprocess_data(examples, *, max_input_length=2048, max_target_length=1024):
    """Tokenize a batch of documents and their linearized target strings.

    Meant for ``Dataset.map(preprocess_data, batched=True)``: ``examples`` maps
    ``'document'`` and ``'summary'`` to lists of strings.

    No padding is applied here. ``DataCollatorForSeq2Seq`` pads every training
    batch dynamically and pads ``labels`` with its ``label_pad_token_id``
    (-100 by default), which the loss ignores. Padding at this stage would fill
    ``labels`` with the ordinary pad token id, and the model would then be
    trained to predict padding.

    Args:
        examples: batch dict with ``'document'`` and ``'summary'`` string lists.
        max_input_length: truncation length for the encoder inputs.
        max_target_length: truncation length for the target token ids.

    Returns:
        The tokenizer's encoding of the documents with an added ``'labels'``
        entry holding the target token ids.
    """
    model_inputs = tokenizer(
        examples['document'], max_length=max_input_length, truncation=True
    )
    # `text_target=` replaces the deprecated `as_target_tokenizer()` context.
    targets = tokenizer(
        text_target=examples['summary'], max_length=max_target_length, truncation=True
    )
    model_inputs['labels'] = targets['input_ids']
    return model_inputs


tokenized_datasets = dataset.map(preprocess_data, batched=True, remove_columns=['document', 'summary'])
# Pads inputs and labels per batch (labels with label_pad_token_id, -100 by default).
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors='pt')
batch_size = 8
# Stand-alone loaders for manual inspection. The Seq2SeqTrainer further down is
# given the datasets and collator directly, not these loaders.
train_data_loader = DataLoader(tokenized_datasets["train"], shuffle=True, batch_size=batch_size, collate_fn=data_collator)
# Evaluation order carries no information, so keep it deterministic.
eval_data_loader = DataLoader(tokenized_datasets["validation"], shuffle=False, batch_size=batch_size, collate_fn=data_collator)

def _delinearize_relations(strings):
    output = []
    for string in strings:
        ht = string.split('<h>')
        relation_type = ht.pop(0).replace(' ', '')
        final_split = ht[0].split('<t>')
        head = int(final_split.pop(0))
        if '<end>' in final_split[0]:
            # final_split[0] = final_split[0].replace('<pad>', '')
            tail = int(final_split.pop()[:-11])
        else:
            tail = int(final_split.pop())
        output.append({'r': relation_type, 'h': head, 't': tail})
    return output

# Takes a list of linearized strings and a list of per-article dicts aligned with it.
# Returns nothing: each dict in `dataset` is updated in place with the fields parsed
# from the corresponding string -- 'vertexList', 'relations', and 'linearized'
# (the string with '<pad>' removed).
def delinearize(linear_strings, dataset):
    """Parse linearized graph strings and store the parsed fields in ``dataset``.

    Each string is expected to look like
    ``<vertex> NAME [[ ID ]] ... <r> TYPE <h> HEAD <t> TAIL ... <end> </s>``.
    ``<pad>`` tokens anywhere in the string are removed first.

    Args:
        linear_strings: sequence of linearized strings.
        dataset: indexable of dicts aligned with ``linear_strings``. Entry ``x``
            is updated in place with:
            ``'relations'``: list of ``{'r', 'h', 't'}`` dicts;
            ``'vertexList'``: dict mapping int vertex id to vertex name;
            ``'linearized'``: the input string with ``<pad>`` removed.

    Returns:
        None.

    Raises:
        IndexError / ValueError: on strings that do not follow the format.
    """
    for x, string in enumerate(linear_strings):
        string = string.replace('<pad>', '')
        chunks = string.split('<r>')
        vertex_text = chunks.pop(0)
        relations = _delinearize_relations(chunks)
        vertex_data_form = {}
        # Text before the first <vertex> marker holds no vertex.
        for vertex in vertex_text.split('<vertex>')[1:]:
            pieces = vertex.split('[[')
            # Read the id up to ']]' instead of slicing off a fixed 3 characters.
            # Trailing '<end>' / '</s>' (graphs without relations) or different
            # spacing around the markers then no longer break parsing.
            vertex_id = int(pieces[1].split(']]')[0])
            vertex_data_form[vertex_id] = pieces[0].strip()
        dataset[x]['relations'] = relations
        dataset[x]['vertexList'] = vertex_data_form
        dataset[x]['linearized'] = string

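# helper function
# return the relation in `comparisons` with the same head, tail and relation type if one exists;
# otherwise return a list of the relations with the same head and tail vertices,
# or None if there are none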
def _match(relation, comparisons):
    output = []
    print('comparisons: ', comparisons)
    for comparison in comparisons:
        print('comparison: ', comparison)
        print('relation: ', relation)
        if comparison['h'] == relation['h'] and comparison['t'] == relation['t']:
          if comparison['r'] == relation['r']:
            return comparison
          else:
            output.append(comparison)
    if output == []:
        return None
    else:
        return output


# confusion matrix generator
def generate_confusion_matrix(pred_labels, true_labels, possible_labels):
    """Build a relation-type confusion matrix over a batch of articles.

    Rows index the predicted relation type and columns the gold relation type,
    both in the order of ``possible_labels``. One extra final row/column (index
    ``len(possible_labels)``) is a "no relation" bucket:

    * ``matrix[pred][-1]`` counts predictions whose (head, tail) pair has no
      gold relation in that article;
    * ``matrix[-1][gold]`` counts gold relations that no prediction accounted for.

    Relation types not found in ``possible_labels`` are mapped to the extra index
    instead of raising.

    Args:
        pred_labels: per-article lists of predicted relation dicts
            (``'r'``, ``'h'``, ``'t'`` keys).
        true_labels: per-article lists of gold relation dicts, indexed in step
            with ``pred_labels``.
        possible_labels: ordered list of valid relation type ids.

    Returns:
        A ``(len(possible_labels) + 1)``-square list of lists of ints.
    """
    none_idx = len(possible_labels)
    # First occurrence wins, matching list.index semantics.
    label_index = {}
    for i, label in enumerate(possible_labels):
        label_index.setdefault(label, i)

    def _index(relation_type):
        # Unknown types (e.g. garbage generations) go to the extra bucket.
        return label_index.get(relation_type.replace(' ', ''), none_idx)

    size = none_idx + 1
    matrix = [[0] * size for _ in range(size)]
    for x, article in enumerate(pred_labels):
        gold = true_labels[x]
        # Gold relations accounted for by some prediction, for this article only.
        matched = []
        for relation in article:
            pred_idx = _index(relation['r'])
            matches = _match(relation, gold)
            if matches is None:
                # Predicted a (head, tail) pair that has no gold relation.
                matrix[pred_idx][none_idx] += 1
            elif isinstance(matches, list):
                # Right pair, wrong relation type: count against the first gold type.
                matrix[pred_idx][_index(matches[0]['r'])] += 1
                matched.append(matches[0])
            else:
                # Exact match on head, tail and relation type.
                matrix[pred_idx][_index(matches['r'])] += 1
                matched.append(matches)
        # Gold relations that no prediction accounted for are misses.
        for label in gold:
            if label not in matched:
                matrix[none_idx][_index(label['r'])] += 1
    return matrix

# get the valid labels
def get_labels(filepath):
    """Return the top-level keys of the JSON object stored at ``filepath``.

    The file is parsed with ``pandas.read_json(..., typ='series')``, and the
    keys of the resulting Series are returned as a list.
    """
    with open(filepath, 'r') as handle:
        series = pandas.read_json(handle, typ='series')
    return list(series.keys())


def compute_accuracy(eval_pred):
    """Relation-level accuracy, for ``Seq2SeqTrainer(compute_metrics=...)``.

    Decodes the generated and reference token ids and parses both back into
    relation dicts with ``delinearize``. It then builds a confusion matrix over
    the relation types listed in ``rel_info.json`` and returns the fraction of
    matrix counts that lie on the diagonal of the real relation types.

    Args:
        eval_pred: ``(predictions, labels)`` pair of token-id arrays.

    Returns:
        ``{'accuracy': float}`` in [0, 1]; 0.0 when the matrix is empty.
    """
    import numpy as np

    predictions, labels = eval_pred
    # -100 marks ignored/padded positions (the collator pads labels with it).
    # It is not a valid token id, so decode it as <pad>, which delinearize strips.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)
    token_preds = tokenizer.batch_decode(predictions)
    token_labs = tokenizer.batch_decode(labels)

    # Reference sequences come from the dataset and are assumed well-formed.
    lab_dicts = [{'linearized': lab_lin} for lab_lin in token_labs]
    delinearize(token_labs, lab_dicts)

    # Generations may be malformed (especially early in training). A prediction
    # that cannot be parsed counts as predicting no relations, rather than
    # aborting the whole evaluation.
    pred_dicts = []
    for pred_lin in token_preds:
        pred_dict = {'linearized': pred_lin}
        try:
            delinearize([pred_lin], [pred_dict])
        except (ValueError, IndexError):
            pred_dict['relations'] = []
        pred_dicts.append(pred_dict)

    pred_rels = [article['relations'] for article in pred_dicts]
    lab_rels = [article['relations'] for article in lab_dicts]

    label_list = get_labels('rel_info.json')
    confusion_matrix = generate_confusion_matrix(pred_rels, lab_rels, label_list)

    num_correct = sum(confusion_matrix[i][i] for i in range(len(label_list)))
    # Normalize by the total count in the matrix, not the number of articles,
    # so the result is a true ratio (diagonal counts are part of the total).
    total = sum(sum(row) for row in confusion_matrix)
    return {'accuracy': num_correct / total if total else 0.0}

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir=f"{model_checkpoint}-DocRED",
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    # If unset, evaluation-time generation falls back to the model's generation
    # config max_length (20 tokens by default in transformers). That truncates
    # the linearized graphs that compute_accuracy parses. 1024 matches the
    # target truncation length used when tokenizing the summaries.
    generation_max_length=1024,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_accuracy,
)

trainer.train()

  from .autonotebook import tqdm as notebook_tqdm


ImportError: 
T5Tokenizer requires the SentencePiece library but it was not found in your environment. Checkout the instructions on the
installation page of its repo: https://github.com/google/sentencepiece#installation and follow the ones
that match your environment. Please note that you may need to restart your runtime after installation.
