In [1]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pip install torch
!pip install transformers
!pip install datasets

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [4]:
from datasets import load_dataset

# Load the 'release_v1' configuration of the "web_nlg" dataset via datasets.load_dataset.
data = load_dataset("web_nlg", 'release_v1')

Downloading data:   0%|          | 0.00/2.30M [00:00<?, ?B/s]

Generating full split:   0%|          | 0/14237 [00:00<?, ? examples/s]

In [None]:
# dataset -> WN18RR, FB15k-237

# Embedding

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [None]:
class TransE(Model):
    """TransE knowledge-graph embedding model.

    A triple (h, r, t) is scored by the ``p_norm``-norm of ``h + r - t``
    computed on the looked-up entity and relation embeddings.

    Args:
        ent_tot: Number of entities; forwarded to the ``Model`` base class.
        rel_tot: Number of relations; forwarded to the ``Model`` base class.
        dim: Size of every entity / relation embedding vector.
        p_norm: Order of the norm applied to ``h + r - t``.
        norm_flag: When true, h, r and t are L2-normalised before scoring.
        margin: Optional margin. When given, ``forward`` returns
            ``margin - distance`` instead of the raw distance.
        epsilon: Optional constant. When both ``margin`` and ``epsilon`` are
            given, embeddings are initialised uniformly in
            ``[-(margin + epsilon) / dim, (margin + epsilon) / dim]``;
            otherwise Xavier-uniform initialisation is used.

    NOTE(review): ``self.ent_tot`` / ``self.rel_tot`` are presumably set by
    ``Model.__init__`` -- the base class is not visible in this notebook.
    """

    def __init__(self, ent_tot, rel_tot, dim = 100, p_norm = 1, norm_flag = True, margin = None, epsilon = None):
        super(TransE, self).__init__(ent_tot, rel_tot)

        self.dim = dim
        self.margin = margin
        self.epsilon = epsilon
        self.norm_flag = norm_flag
        self.p_norm = p_norm

        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)

        # Identity comparison: `== None` can be overridden by the operand type.
        if margin is None or epsilon is None:
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            # Kept as a non-trainable parameter so it is stored with the model state.
            self.embedding_range = nn.Parameter(
                torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
            )
            bound = self.embedding_range.item()
            nn.init.uniform_(tensor=self.ent_embeddings.weight.data, a=-bound, b=bound)
            nn.init.uniform_(tensor=self.rel_embeddings.weight.data, a=-bound, b=bound)

        if margin is not None:
            # Replace the plain number with a frozen parameter.
            self.margin = nn.Parameter(torch.Tensor([margin]))
            self.margin.requires_grad = False
            self.margin_flag = True
        else:
            self.margin_flag = False

    def _calc(self, h, t, r, mode):
        """Return the flattened ``p_norm`` distance of ``h + r - t``.

        Args:
            h, t, r: Embedding tensors of head, tail and relation.
            mode: ``'normal'`` scores the tensors as given; any other value
                first reshapes all three to ``(-1, r.shape[0], dim)``.
                ``'head_batch'`` only changes the association order of the sum.
        """
        if self.norm_flag:
            h = F.normalize(h, 2, -1)
            r = F.normalize(r, 2, -1)
            t = F.normalize(t, 2, -1)
        if mode != 'normal':
            # r must be reshaped last: h and t use r's original first dimension.
            h = h.view(-1, r.shape[0], h.shape[-1])
            t = t.view(-1, r.shape[0], t.shape[-1])
            r = r.view(-1, r.shape[0], r.shape[-1])
        if mode == 'head_batch':
            score = h + (r - t)
        else:
            score = (h + r) - t
        score = torch.norm(score, self.p_norm, -1).flatten()

        return score

    def forward(self, data):
        """Score the triples described by ``data``.

        Args:
            data: Mapping with keys ``'batch_h'``, ``'batch_t'``, ``'batch_r'``
                (index tensors) and ``'mode'``.

        Returns:
            ``margin - distance`` when a margin was configured, otherwise the
            raw distance.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)  # bs x embedding dim
        r = self.rel_embeddings(batch_r)
        score = self._calc(h, t, r, mode)

        # margin ranking loss
        if self.margin_flag:
            return self.margin - score
        return score

    def regularization(self, data):
        """Return the mean squared value of the h, t and r embeddings, averaged over the three."""
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        regul = (torch.mean(h ** 2) +
                 torch.mean(t ** 2) +
                 torch.mean(r ** 2)) / 3

        return regul

    def predict(self, data):
        """Return the scores for ``data`` as a NumPy array.

        When a margin is configured, ``forward`` returns ``margin - distance``
        and the same transformation is applied once more here, i.e. the result
        is ``margin - (margin - distance)``.
        """
        score = self.forward(data)
        if self.margin_flag:
            score = self.margin - score
        return score.cpu().data.numpy()

In [None]:
# embedding lookup -> triple score calc (model by model) -> return; the vectors are learned well through this process.

In [None]:
import torch
import torch.nn as nn
from .Model import Model # ?

In [None]:
class DistMult(Model):
    """DistMult knowledge-graph embedding model.

    A triple (h, r, t) is scored by ``sum(h * r * t)`` over the embedding
    dimension.

    Args:
        ent_tot: Number of entities; forwarded to the ``Model`` base class.
        rel_tot: Number of relations; forwarded to the ``Model`` base class.
        dim: Size of every entity / relation embedding vector.
        margin: Optional constant used only for initialisation (see ``epsilon``).
        epsilon: Optional constant. When both ``margin`` and ``epsilon`` are
            given, embeddings are initialised uniformly in
            ``[-(margin + epsilon) / dim, (margin + epsilon) / dim]``;
            otherwise Xavier-uniform initialisation is used.

    NOTE(review): ``self.ent_tot`` / ``self.rel_tot`` are presumably set by
    ``Model.__init__`` -- the base class is not visible in this notebook.
    """

    def __init__(self, ent_tot, rel_tot, dim = 100, margin = None, epsilon = None):
        super(DistMult, self).__init__(ent_tot, rel_tot)

        self.dim = dim
        self.margin = margin
        self.epsilon = epsilon
        self.ent_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_embeddings = nn.Embedding(self.rel_tot, self.dim)

        # Identity comparison: `== None` can be overridden by the operand type.
        if margin is None or epsilon is None:
            nn.init.xavier_uniform_(self.ent_embeddings.weight.data)
            nn.init.xavier_uniform_(self.rel_embeddings.weight.data)
        else:
            # Kept as a non-trainable parameter so it is stored with the model state.
            self.embedding_range = nn.Parameter(
                torch.Tensor([(self.margin + self.epsilon) / self.dim]), requires_grad=False
            )
            bound = self.embedding_range.item()
            nn.init.uniform_(tensor=self.ent_embeddings.weight.data, a=-bound, b=bound)
            nn.init.uniform_(tensor=self.rel_embeddings.weight.data, a=-bound, b=bound)

    def _calc(self, h, t, r, mode):
        """Return the flattened sum over the last axis of ``h * r * t``.

        Args:
            h, t, r: Embedding tensors of head, tail and relation.
            mode: ``'normal'`` scores the tensors as given; any other value
                first reshapes all three to ``(-1, r.shape[0], dim)``.
                ``'head_batch'`` only changes the association order of the product.
        """
        if mode != 'normal':
            # r must be reshaped last: h and t use r's original first dimension.
            h = h.view(-1, r.shape[0], h.shape[-1])
            t = t.view(-1, r.shape[0], t.shape[-1])
            r = r.view(-1, r.shape[0], r.shape[-1])
        if mode == 'head_batch':
            score = h * (r * t)
        else:
            score = (h * r) * t
        score = torch.sum(score, -1).flatten()

        return score

    def forward(self, data):
        """Score the triples described by ``data``.

        Args:
            data: Mapping with keys ``'batch_h'``, ``'batch_t'``, ``'batch_r'``
                (index tensors) and ``'mode'``.
        """
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        mode = data['mode']
        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        score = self._calc(h, t, r, mode)

        return score

    def regularization(self, data):
        """Return the mean squared value of the h, t and r embeddings, averaged over the three."""
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']

        h = self.ent_embeddings(batch_h)
        t = self.ent_embeddings(batch_t)
        r = self.rel_embeddings(batch_r)
        regul = (torch.mean(h ** 2) + torch.mean(t ** 2) + torch.mean(r ** 2)) / 3

        return regul

    def l3_regularization(self):
        """Return the sum of the cubed 3-norms of the full entity and relation embedding tables."""
        return (self.ent_embeddings.weight.norm(p = 3)**3 + self.rel_embeddings.weight.norm(p = 3)**3)

    def predict(self, data):
        """Return the negated ``forward`` score for ``data`` as a NumPy array."""
        score = -self.forward(data)

        return score.cpu().data.numpy()

In [None]:
import torch
import torch.nn as nn
from .Model import Model

In [None]:
# complex space -> real dim, imaginary dim
class ComplEx(Model):
    """ComplEx knowledge-graph embedding model.

    Every entity and relation has a real and an imaginary embedding table
    (complex space -> real dim, imaginary dim). A triple is scored by the real
    part of ``<h, r, conj(t)>``.

    Args:
        ent_tot: Number of entities; forwarded to the ``Model`` base class.
        rel_tot: Number of relations; forwarded to the ``Model`` base class.
        dim: Size of each real / imaginary embedding vector.

    NOTE(review): ``self.ent_tot`` / ``self.rel_tot`` are presumably set by
    ``Model.__init__`` -- the base class is not visible in this notebook.
    """

    def __init__(self, ent_tot, rel_tot, dim = 100):
        super(ComplEx, self).__init__(ent_tot, rel_tot)

        self.dim = dim
        self.ent_re_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.ent_im_embeddings = nn.Embedding(self.ent_tot, self.dim)
        self.rel_re_embeddings = nn.Embedding(self.rel_tot, self.dim)
        self.rel_im_embeddings = nn.Embedding(self.rel_tot, self.dim)

        nn.init.xavier_uniform_(self.ent_re_embeddings.weight.data)
        nn.init.xavier_uniform_(self.ent_im_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_re_embeddings.weight.data)
        nn.init.xavier_uniform_(self.rel_im_embeddings.weight.data)

    # (a+bi)(c+di)
    def _calc(self, h_re, h_im, t_re, t_im, r_re, r_im):
        """Return ``Re(<h, r, conj(t)>)`` summed over the last axis.

        Argument order is head (re, im), tail (re, im), relation (re, im).
        """
        return torch.sum(
            h_re * t_re * r_re
            + h_im * t_im * r_re
            + h_re * t_im * r_im
            - h_im * t_re * r_im,
            -1
        )

    def forward(self, h, r, t, n):
        """Score a positive triple and its two corruptions.

        Args:
            h: Head entity indices.
            r: Relation indices.
            t: Tail entity indices.
            n: Negative (replacement) entity indices.

        Returns:
            Tuple ``(pos_score, neg_score, neg_score_tail)`` holding the scores
            of ``(h, r, t)``, ``(n, r, t)`` and ``(h, r, n)``.
        """
        h_re = self.ent_re_embeddings(h)
        h_im = self.ent_im_embeddings(h)
        t_re = self.ent_re_embeddings(t)
        t_im = self.ent_im_embeddings(t)
        r_re = self.rel_re_embeddings(r)
        r_im = self.rel_im_embeddings(r)
        n_re = self.ent_re_embeddings(n)
        n_im = self.ent_im_embeddings(n)
        pos_score = self._calc(h_re, h_im, t_re, t_im, r_re, r_im)
        # Head corrupted by the negative entity.
        neg_score = self._calc(n_re, n_im, t_re, t_im, r_re, r_im)
        # Tail corrupted by the negative entity. `n` is looked up in the entity
        # tables, so it belongs in the tail slots, not the relation slots.
        neg_score_tail = self._calc(h_re, h_im, n_re, n_im, r_re, r_im)

        return pos_score, neg_score, neg_score_tail

    def regularization(self, data):
        """Return the mean squared value of the six looked-up embedding tensors, averaged over the six."""
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        h_re = self.ent_re_embeddings(batch_h)
        h_im = self.ent_im_embeddings(batch_h)
        t_re = self.ent_re_embeddings(batch_t)
        t_im = self.ent_im_embeddings(batch_t)
        r_re = self.rel_re_embeddings(batch_r)
        r_im = self.rel_im_embeddings(batch_r)
        regul = (torch.mean(h_re ** 2) +
                 torch.mean(h_im ** 2) +
                 torch.mean(t_re ** 2) +
                 torch.mean(t_im ** 2) +
                 torch.mean(r_re ** 2) +
                 torch.mean(r_im ** 2)) / 6

        return regul

    def predict(self, data):
        """Return the negated triple score for ``data`` as a NumPy array.

        Args:
            data: Mapping with keys ``'batch_h'``, ``'batch_t'`` and
                ``'batch_r'`` (index tensors), the same layout
                ``regularization`` reads.
        """
        # `forward` takes (h, r, t, n) and returns a tuple, so the score is
        # computed directly from the looked-up embeddings here.
        batch_h = data['batch_h']
        batch_t = data['batch_t']
        batch_r = data['batch_r']
        score = -self._calc(
            self.ent_re_embeddings(batch_h),
            self.ent_im_embeddings(batch_h),
            self.ent_re_embeddings(batch_t),
            self.ent_im_embeddings(batch_t),
            self.rel_re_embeddings(batch_r),
            self.rel_im_embeddings(batch_r),
        )
        return score.cpu().data.numpy()

In [None]:
import torch, os
import json


def _collect_vocab(paths):
    """Return ``(entities, relations)``: the sets of names found in the given files.

    Each non-blank line is stripped and split on tabs; field 0 and field 2 are
    recorded as entities and field 1 as a relation.
    """
    entities, relations = set(), set()
    for path in paths:
        with open(path, 'r') as f:
            for line in f:
                fields = line.strip().split('\t')
                if fields == ['']:
                    # Blank line (e.g. trailing newline at end of file).
                    continue
                entities.add(fields[0])
                relations.add(fields[1])
                entities.add(fields[2])
    return entities, relations


# Vocabulary is built over all three splits so every name gets an id.
_entities, _relations = _collect_vocab(('train.txt', 'valid.txt', 'test.txt'))

# Sorting makes the id assignment deterministic across runs.
id2ent = sorted(_entities)
id2rel = sorted(_relations)

ent2id = {name: idx for idx, name in enumerate(id2ent)}
rel2id = {name: idx for idx, name in enumerate(id2rel)}

In [None]:
from torch.utils.data import Dataset
import torch
import json
import numpy as np

class DataSet(Dataset):
    """Triples read from a tab-separated id file, each paired with one random negative entity id.

    The constructor loads the vocabularies from ``ent2id.pt``, ``id2ent.pt``,
    ``rel2id.pt`` and ``id2rel.pt`` in the current working directory, then
    reads ``file_path`` line by line, parsing the first three tab-separated
    fields as integer head, relation and tail ids. One negative entity id is
    drawn per triple at load time with ``np.random.randint``.

    Items are ``(head, rel, tail, negative)`` tuples.
    """

    def __init__(self, file_path):
        self.len = 0
        self.head, self.rel, self.tail = [], [], []
        self.triple = []
        self.negative = []

        # NOTE(review): torch.load unpickles its input -- only load vocabulary
        # files that come from a trusted source.
        self.ent2id = torch.load('ent2id.pt')
        self.id2ent = torch.load('id2ent.pt')
        self.rel2id = torch.load('rel2id.pt')
        self.id2rel = torch.load('id2rel.pt')
        self.ent_tot = len(self.id2ent)
        self.rel_tot = len(self.id2rel)

        with open(file_path) as triple_file:
            for raw_line in triple_file:
                fields = raw_line.strip().split('\t')
                self.len += 1
                self.head.append(int(fields[0]))
                self.rel.append(int(fields[1]))
                self.tail.append(int(fields[2]))
                # Negative id sampled uniformly from [0, number of entities).
                self.negative.append(np.random.randint(0, len(self.id2ent)))

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        return (self.head[idx], self.rel[idx], self.tail[idx], self.negative[idx])

In [None]:
# 0 23 17000 positive -> true score high
# 0 23 74000 negative -> true score negative

In [None]:
def main():
    """Train a TransE model on ``data/train2id.txt`` and save its weights.

    Reads hyper-parameters from the module-level ``args`` object
    (``batch_size``, ``optimizer``, ``lr``, ``weight_decay``, ``margin``,
    ``epochs``, ``save_step``), trains with a margin ranking loss on CUDA,
    appends the mean loss of every epoch to ``transe_log.txt`` and writes the
    model state to ``transe.pt``.

    Raises:
        ValueError: If ``opts.optimizer`` is neither ``'Adam'`` nor ``'SGD'``.
    """
    # Function-scope imports: no visible notebook cell imports these names.
    import time

    from torch import optim
    from torch.utils.data import DataLoader

    # NOTE(review): `args` is expected to exist at module level; it is not
    # defined in the visible cells.
    opts = args
    print ("load data ...")
    train_data = DataSet('data/train2id.txt')
    train_loader = DataLoader(train_data, shuffle=True, batch_size=opts.batch_size)
    valid_data = DataSet('data/valid2id.txt')
    # Built from the validation split (was mistakenly built from train_data).
    valid_loader = DataLoader(valid_data, shuffle=False, batch_size=opts.batch_size)

    print("load model ...")
    # NOTE(review): this call and `model(h, r, t, n)` below assume a TransE
    # variant taking `opts` first and returning (pos_score, neg_score); the
    # TransE class defined earlier in this notebook has a different signature.
    model = TransE(opts, train_data.ent_tot, train_data.rel_tot)
    if opts.optimizer == 'Adam':
        optimizer = optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.weight_decay)
    elif opts.optimizer == 'SGD':
        optimizer = optim.SGD(model.parameters(), lr=opts.lr)
    else:
        raise ValueError("unsupported optimizer: {!r} (expected 'Adam' or 'SGD')".format(opts.optimizer))
    model.cuda()
    loss = nn.MarginRankingLoss(margin=opts.margin)
    loss.cuda()
    # Halve the learning rate every 500 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=500, gamma=0.5, last_epoch=-1)

    print("start training")
    for epoch in range(1, opts.epochs + 1):
        print("epoch : " + str(epoch))
        model.train()
        epoch_start = time.time()
        epoch_loss = []
        tot = 0
        for i, batch_data in enumerate(train_loader):
            optimizer.zero_grad()
            batch_h, batch_r, batch_t, batch_n = batch_data
            batch_h = torch.LongTensor(batch_h).cuda()
            batch_r = torch.LongTensor(batch_r).cuda()
            batch_t = torch.LongTensor(batch_t).cuda()
            batch_n = torch.LongTensor(batch_n).cuda()
            pos_score, neg_score = model(batch_h, batch_r, batch_t, batch_n)
            # Target -1 asks the loss to rank pos_score below neg_score.
            train_loss = loss(pos_score, neg_score, -torch.ones(pos_score.size(-1)).cuda())
            train_loss.backward()
            optimizer.step()
            batch_loss = train_loss.item()
            epoch_loss.append(batch_loss)
            tot += batch_h.size(0)
            print('\r{:>10} epoch {} progress {} loss: {}\n'.format('', epoch, tot / len(train_data),
                                                                    batch_loss), end='')
        scheduler.step()
        end = time.time()
        time_used = end - epoch_start
        mean_loss = sum(epoch_loss) / len(epoch_loss)
        print('one epoch time: {} minutes'.format(time_used / 60))
        print('{} epochs'.format(epoch))
        print('epoch {} loss: {}'.format(epoch, mean_loss))

        with open('transe_log.txt', 'a') as f:
            f.write('loss : ' + str(mean_loss) + '\n')

        if epoch % opts.save_step == 0:
            print("save model...")
            torch.save(model.state_dict(), 'transe.pt')

    print("save model...")
    torch.save(model.state_dict(), 'transe.pt')
    # The embedding export used to run before `model` existed (raising
    # UnboundLocalError); it now runs after training, where the message
    # announcing it was already printed.
    print("[Saving embeddings of whole entities & relations...]")
    save_embeddings(model, opts, train_data.id2ent, train_data.id2rel)
    print("[Embedding results are saved successfully.]")


if __name__ == '__main__':
    main()

# Modeling

In [None]:
import math
import pickle
from typing import Any, Dict, List

import jsonlines
import torch
import transformers
from pytorch_lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only

import config
import genie.metrics as CustomMetrics
import os
from genie.constrained_generation import Trie, get_information_extraction_prefix_allowed_tokens_fn_hf
from genie.datamodule.utils import TripletUtils
from genie.models import GenieHF

from .utils import label_smoothed_nll_loss # ?

In [None]:
# Module-level logger.
# NOTE(review): `general_utils` is not imported in any visible cell -- confirm where it comes from.
log = general_utils.get_logger(__name__)


class GeniePL(LightningModule):
    """
    A LightningModule organizes your PyTorch code into 5 sections:
        - Setup for all computations (init).
        - Train loop (training_step)
        - Validation loop (validation_step)
        - Test loop (test_step)
        - Optimizers (configure_optimizers)
    Read the docs:
        https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html
    """


    def __init__(self, hf_config=None, hparams_overrides=None, **kwargs):
        """Build the wrapped GenieHF model, its tokenizer, metrics and optional tries.

        Args:
            hf_config: Model config. When not None the model is constructed
                directly from it and a ``tokenizer`` or ``tokenizer_path``
                hyper-parameter is required; when None the model is loaded
                with ``GenieHF.from_pretrained(self.hparams.model_name_or_path)``.
            hparams_overrides: Optional dict applied on top of the saved
                hyper-parameters via ``general_utils.update``.
            **kwargs: Further hyper-parameters, read back through
                ``self.hparams`` (e.g. ``inference``, ``other_parameters``,
                ``tokenizer``, ``tokenizer_path``, ``model_name_or_path``);
                ``testing_output_parent_dir`` is also read directly from here.

        Raises:
            AssertionError: If ``hf_config`` is given but neither ``tokenizer``
                nor ``tokenizer_path`` is present in the hyper-parameters.
        """
        super().__init__()

        # this line ensures params passed to LightningModule will be saved to ckpt
        # it also allows to access params with 'self.hparams' attribute
        self.save_hyperparameters()

        if hparams_overrides is not None:
            # Overriding the hyper-parameters of a checkpoint at an arbitrary depth using a dict structure
            # The overrides entry itself is popped so it is not kept among the hyper-parameters.
            hparams_overrides = self.hparams.pop("hparams_overrides")
            general_utils.update(self.hparams, hparams_overrides)
            log.info("Some values of the original hparams were overridden")
            log.info("Hyper-parameters:")
            log.info(self.hparams)

        if self.hparams.hf_config is not None:
            # Initialization from a local, pre-trained GenIE PL checkpoint

            # Extra parameters, when present, are merged into the config before the model is built.
            if self.hparams.get("other_parameters", None) is not None:
                self.hparams.hf_config.update(self.hparams.other_parameters)

            self.model = GenieHF(self.hparams.hf_config)

            assert self.hparams.get("tokenizer", False) or self.hparams.get("tokenizer_path", False), (
                "If you initialize the model from a local checkpoint "
                "you need to either pass the tokenizer or the path to the tokenizer in the "
                "constructor "
            )

            # A tokenizer object takes precedence over a tokenizer path.
            if self.hparams.get("tokenizer", False):
                self.tokenizer = self.hparams["tokenizer"]
            else:
                self.tokenizer = transformers.BartTokenizer.from_pretrained(self.hparams.tokenizer_path)

        else:
            # Initialization from a HF model
            self.model, hf_config = GenieHF.from_pretrained(
                self.hparams.model_name_or_path,
                return_dict=True,
                other_parameters=self.hparams.get("other_parameters", None),
            )
            # The name "random" has no tokenizer of its own, so a fixed tokenizer id is used for it.
            self.tokenizer = transformers.BartTokenizer.from_pretrained(
                "martinjosifoski/genie-rw"
                if self.hparams.model_name_or_path == "random"
                else self.hparams.model_name_or_path
            )
            self.hparams.tokenizer = self.tokenizer  # Save in the checkpoint
            self.hparams.hf_config = hf_config  # Save in the checkpoint

        log.info("HF model config:")
        log.info(self.hparams.hf_config)

        # Metric objects updated in test_step_end and aggregated in test_epoch_end.
        self.ts_precision = CustomMetrics.TSPrecision()
        self.ts_recall = CustomMetrics.TSRecall()
        self.ts_f1 = CustomMetrics.TSF1()

        # The entity / relation tries are only loaded when generation is constrained.
        if not self.hparams.inference["free_generation"]:
            self.entity_trie = Trie.load(self.hparams.inference["entity_trie_path"])
            self.relation_trie = Trie.load(self.hparams.inference["relation_trie_path"])

        # Optional directory for the testing_output*.jsonl files; None means a relative path is used.
        self.testing_output_parent_dir = kwargs.get("testing_output_parent_dir", None)


    def forward(self, input_ids, attention_mask, decoder_attention_mask, labels=None, **kwargs):
        """Run the wrapped model and return its output unchanged.

        ``input_ids`` is passed positionally; the masks, ``labels`` and any
        extra keyword arguments are passed by keyword.
        """
        model_kwargs = dict(
            attention_mask=attention_mask,
            labels=labels,
            decoder_attention_mask=decoder_attention_mask,
            **kwargs,
        )
        return self.model(input_ids, **model_kwargs)


    def process_batch(self, batch):
        """Drop the first target column unless ``bos_as_first_token_generated`` is set.

        When the hyper-parameter is truthy (its default when absent is True)
        the batch is returned untouched. Otherwise the first position is
        sliced off ``trg_input_ids`` and ``trg_attention_mask`` in place and
        the same batch object is returned.
        """
        keep_bos = self.hparams.get("bos_as_first_token_generated", True)
        if not keep_bos:
            # remove the starting bos token from the target
            for key in ("trg_input_ids", "trg_attention_mask"):
                batch[key] = batch[key][:, 1:]

        return batch


    def training_step(self, batch, batch_idx=None):
        """Compute the label-smoothed loss for one training batch.

        Logs the NLL component as ``train-nll_loss`` on every step and returns
        ``{"loss": loss}``.
        """
        batch = self.process_batch(batch)

        # Only the logits of the model output are used; the loss is computed below.
        logits = self(
            input_ids=batch["src_input_ids"],
            attention_mask=batch["src_attention_mask"],
            labels=batch["trg_input_ids"],
            decoder_attention_mask=batch["trg_attention_mask"],
            use_cache=False,
        ).logits

        log_probs = logits.log_softmax(dim=-1)
        target_ids = batch["trg_input_ids"]
        target_mask = batch["trg_attention_mask"]

        # Note that pad_token_id used in trg_input_ids is 1, and not -100 used by the hugging face loss implementation
        loss, nll_loss = label_smoothed_nll_loss(
            log_probs,
            target_ids,
            target_mask,
            epsilon=self.hparams.eps,
            ignore_index=self.tokenizer.pad_token_id,
        )

        self.log("train-nll_loss", nll_loss.item(), on_step=True, on_epoch=False, prog_bar=True)

        return {"loss": loss}


    def validation_step(self, batch, batch_idx=None):
        """Compute the label-smoothed loss for one validation batch.

        Logs the NLL component as ``val-nll_loss`` (aggregated per epoch) and
        returns ``{"val-nll_loss": nll_loss}``.
        """
        batch = self.process_batch(batch)

        logits = self(
            input_ids=batch["src_input_ids"],
            attention_mask=batch["src_attention_mask"],
            labels=batch["trg_input_ids"],
            decoder_attention_mask=batch["trg_attention_mask"],
            use_cache=False,
        ).logits

        log_probs = logits.log_softmax(dim=-1)
        target_ids = batch["trg_input_ids"]
        target_mask = batch["trg_attention_mask"]

        # Note that pad_token_id used in trg_input_ids is 1, and not -100 used by the hugging face loss implementation
        loss, nll_loss = label_smoothed_nll_loss(
            log_probs,
            target_ids,
            target_mask,
            epsilon=self.hparams.eps,
            ignore_index=self.tokenizer.pad_token_id,
        )

        self.log("val-nll_loss", nll_loss.item(), on_step=False, on_epoch=True, prog_bar=True)

        return {"val-nll_loss": nll_loss}


    def test_step(self, batch, batch_idx):
        raw_input = [sample["src"] for sample in batch["raw"]]
        raw_target = [sample["trg"] for sample in batch["raw"]]
        ids = [sample["id"] for sample in batch["raw"]]

        # ==== Prediction related ===

        # Generate predictions
        if self.hparams.inference["free_generation"]:
            outputs = self.sample(
                batch,
                input_data_is_processed_batch=True,
                return_dict_in_generate=True,
                output_scores=True,
                testing=True,
                **self.hparams.inference["hf_generation_params"],
            )
        else:
            prefix_allowed_tokens_fn = get_information_extraction_prefix_allowed_tokens_fn_hf(
                self,
                raw_input,
                bos_as_first_token_generated=self.hparams.get("bos_as_first_token_generated", True),
                entities_trie=self.entity_trie,
                relations_trie=self.relation_trie,
            )
            outputs = self.sample(
                batch,
                input_data_is_processed_batch=True,
                return_dict_in_generate=True,
                output_scores=True,
                testing=True,
                prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                **self.hparams.inference["hf_generation_params"],
            )

        preds = []
        for lpreds in outputs:
            # lpreds is a list of <= `num_return_sequences` predictions
            pred = None

            if len(lpreds) > 0:
                score = lpreds[0]["log_prob"]
                if score != -1e9 and score != -math.inf:
                    pred = lpreds[0]["text"]

            preds.append(pred)

        return_object = {"ids": ids, "inputs": raw_input, "targets": raw_target, "predictions": preds}

        if self.hparams.inference["save_testing_data"] and self.hparams.inference["save_full_beams"]:
            return_object["full_predictions"] = outputs

        self._write_testing_output(return_object)

        return return_object


    def test_step_end(self, outputs: Dict[str, Any]):
        """Update and log the triple-set precision / recall / F1 for one test step.

        Args:
            outputs: The object returned by ``test_step``; its ``"predictions"``
                and ``"targets"`` entries (lists of text sequences) are read.
                It is indexed by string keys, hence a dict and not a list.
        """
        def to_triples(text):
            # Process the data in the format expected by the metrics
            return TripletUtils.convert_text_sequence_to_text_triples(
                text, verbose=self.hparams.inference["verbose_flag_in_convert_to_triple"]
            )

        predictions = [to_triples(text) for text in outputs["predictions"]]
        targets = [to_triples(text) for text in outputs["targets"]]

        # Update the metrics
        p = self.ts_precision(predictions, targets)
        r = self.ts_recall(predictions, targets)
        f1 = self.ts_f1(predictions, targets)

        # Log the per-step metric values
        self.log("test-precision_step", p, on_step=True, on_epoch=False, prog_bar=True)
        self.log("test-recall_step", r, on_step=True, on_epoch=False, prog_bar=True)
        self.log("test-f1_step", f1, on_step=True, on_epoch=False, prog_bar=True)


    def _write_testing_output(self, step_output):
        """Append the per-sample results of one test step to this rank's JSONL file.

        The file is ``testing_output_{global_rank}.jsonl``, placed inside
        ``self.testing_output_parent_dir`` when that is set. Each record holds
        ``id``, ``input``, ``target`` and ``prediction``, plus
        ``full_prediction`` when both ``save_testing_data`` and
        ``save_full_beams`` are enabled.
        """
        file_name = f"testing_output_{self.global_rank}.jsonl"
        parent_dir = self.testing_output_parent_dir
        output_path = file_name if parent_dir is None else os.path.join(parent_dir, file_name)

        with jsonlines.open(output_path, "a") as writer:
            records = []

            for idx in range(len(step_output["predictions"])):
                record = {
                    "id": step_output["ids"][idx],
                    "input": step_output["inputs"][idx],
                    "target": step_output["targets"][idx],
                    "prediction": step_output["predictions"][idx],
                }

                inference_cfg = self.hparams.inference
                if inference_cfg["save_testing_data"] and inference_cfg["save_full_beams"]:
                    record["full_prediction"] = step_output["full_predictions"][idx]

                records.append(record)

            writer.write_all(records)


    @rank_zero_only
    def _write_testing_outputs(self, outputs):
        """Write every gathered test-step result to ``testing_output.jsonl``.

        ``outputs`` is iterated as a list (one entry per process) of lists of
        step outputs. The file is overwritten and is placed inside
        ``self.testing_output_parent_dir`` when that is set. The
        ``rank_zero_only`` decorator restricts the call to rank zero.
        """
        file_name = "testing_output.jsonl"
        parent_dir = self.testing_output_parent_dir
        output_path = file_name if parent_dir is None else os.path.join(parent_dir, file_name)

        with jsonlines.open(output_path, "w") as writer:
            for process_output in outputs:
                for step_output in process_output:
                    records = []

                    for idx in range(len(step_output["predictions"])):
                        record = {
                            "id": step_output["ids"][idx],
                            "input": step_output["inputs"][idx],
                            "target": step_output["targets"][idx],
                            "prediction": step_output["predictions"][idx],
                        }

                        inference_cfg = self.hparams.inference
                        if inference_cfg["save_testing_data"] and inference_cfg["save_full_beams"]:
                            record["full_prediction"] = step_output["full_predictions"][idx]

                        records.append(record)

                    writer.write_all(records)


    def test_epoch_end(self, outputs):
        """Outputs is a list of either test_step outputs outputs"""
        # Log metrics aggregated across steps and processes (in ddp)
        self.log("test-precision", self.ts_precision.compute())
        self.log("test-recall", self.ts_recall.compute())
        self.log("test-f1", self.ts_f1.compute())

        if self.hparams.inference["save_testing_data"]:
            # TODO: Can achieve the same result by collating the testing_output_{rank}.jsonl files
            if torch.distributed.is_initialized():
                torch.distributed.barrier()
                gather = [None] * torch.distributed.get_world_size()
                torch.distributed.all_gather_object(gather, outputs)
                # Gather is a list of `num_gpu` elements, each being the outputs object passed to the test_epoch_end
                outputs = gather
            else:
                outputs = [outputs]

            self._write_testing_outputs(outputs)

        return {
            "test-acc": self.ts_precision.compute(),
            "test-recall": self.ts_precision.compute(),
            "test-f1": self.ts_precision.compute(),
        }


    def configure_optimizers(self):
        """Create the AdamW optimizer and the per-step learning-rate scheduler.

        Weight decay is applied to every parameter of ``self.model`` except
        those whose name contains ``"bias"`` or ``"LayerNorm.weight"``.

        Returns:
            ``([optimizer], [lr_dict])`` where ``lr_dict`` describes a
            scheduler that is stepped after every optimizer step.

        Raises:
            ValueError: If ``self.hparams.schedule_name`` is neither
                ``"linear"`` nor ``"polynomial"``.
        """
        # Apply weight decay to all parameters except for the biases and the weight for Layer Normalization
        no_decay = ["bias", "LayerNorm.weight"]

        # Per-parameter optimization.
        # Each dict defines a parameter group and contains the list of parameters to be optimized in a key `params`
        # Other keys should match keyword arguments accepted by the optimizers and
        # will be used as optimization params for the parameter group
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.hparams.weight_decay,
                # "betas": self.hparams.adam_betas,
                "betas": (0.9, 0.999),
                "eps": self.hparams.adam_eps,
            },
            {
                "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
                # "betas": self.hparams.adam_betas,
                "betas": (0.9, 0.999),
                "eps": self.hparams.adam_eps,
            },
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        schedule_name = self.hparams.schedule_name
        if schedule_name == "linear":
            scheduler = transformers.get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.hparams.warmup_updates,
                num_training_steps=self.hparams.total_num_updates,
            )
        elif schedule_name == "polynomial":
            scheduler = transformers.get_polynomial_decay_schedule_with_warmup(
                optimizer,
                num_warmup_steps=self.hparams.warmup_updates,
                num_training_steps=self.hparams.total_num_updates,
                lr_end=self.hparams.lr_end,
            )
        else:
            # Fail with a clear message instead of an UnboundLocalError on `scheduler` below.
            raise ValueError(
                f"Unsupported schedule_name {schedule_name!r}; expected 'linear' or 'polynomial'"
            )

        lr_dict = {
            "scheduler": scheduler,  # scheduler instance
            "interval": "step",  # The unit of the scheduler's step size. 'step' or 'epoch
            "frequency": 1,  # corresponds to updating the learning rate after every `frequency` epoch/step
            "name": f"LearningRateScheduler-{schedule_name}",  # Used by a LearningRateMonitor callback
        }

        return [optimizer], [lr_dict]


    @staticmethod
    def _convert_surface_form_triplets_to_ids(triplets, entity_name2id, relation_name2id):
        triplets = [[entity_name2id[s], relation_name2id[r], entity_name2id[o]] for s, r, o in triplets]

        return triplets


    @staticmethod
    def _convert_output_to_triplets(output_obj, entity_name2id, relation_name2id):
        """Convert generated text into textual (and optionally id) triplets.

        If the first element of ``output_obj`` is a string, every element is
        treated as a text sequence and a new list of triplet lists is
        returned, mapped to ids when both name-to-id mappings are given.
        Otherwise every element is treated as a dict with a ``"text"`` entry:
        ``"textual_triplets"`` (and ``"id_triplets"`` when both mappings are
        given) are added to it in place and ``output_obj`` itself is returned.
        """
        map_to_ids = entity_name2id is not None and relation_name2id is not None

        if not isinstance(output_obj[0], str):
            for sample in output_obj:
                sample["textual_triplets"] = TripletUtils.convert_text_sequence_to_text_triples(sample["text"])
                if map_to_ids:
                    sample["id_triplets"] = GeniePL._convert_surface_form_triplets_to_ids(
                        sample["textual_triplets"], entity_name2id, relation_name2id
                    )

            return output_obj

        converted = []
        for text in output_obj:
            triplets = TripletUtils.convert_text_sequence_to_text_triples(text)

            if map_to_ids:
                triplets = GeniePL._convert_surface_form_triplets_to_ids(triplets, entity_name2id, relation_name2id)

            converted.append(triplets)

        return converted


    def sample(
        self,
        input_data,
        input_data_is_processed_batch=False,
        testing=False,
        seed=None,
        prefix_allowed_tokens_fn=None,
        entity_trie=None,
        relation_trie=None,
        convert_to_triplets=False,
        surface_form_mappings={"entity_name2id": None, "relation_name2id": None},
        **kwargs,
    ):
        """Generate predictions for `input_data` with `self.model.generate`.

        The module is switched to eval mode for the duration of the call and its
        previous training mode is restored on every exit path (both return
        branches and exceptions).

        Args:
            input_data: Either a list of strings or, when
                `input_data_is_processed_batch` is True, a batch dict holding
                "src_input_ids" and "src_attention_mask" and optionally "raw"
                (a list of samples, each with a "src" entry).
            input_data_is_processed_batch: Whether `input_data` is already tokenized.
            testing: If True, seed generation with `self.hparams.inference["seed"]`.
            seed: Seed used when `testing` is False; no seeding happens when None.
            prefix_allowed_tokens_fn: Constraint function forwarded to `generate`.
                It is replaced by a trie-based one when both tries are given.
            entity_trie: Entity prefix trie; only used together with `relation_trie`.
            relation_trie: Relation prefix trie; only used together with `entity_trie`.
            convert_to_triplets: If True, post-process the decoded text with
                `GeniePL._convert_output_to_triplets`.
            surface_form_mappings: Keyword arguments forwarded to
                `GeniePL._convert_output_to_triplets`. The default dict is only
                unpacked here, never mutated.
            **kwargs: Overrides for `self.hparams.inference["hf_generation_params"]`.

        Returns:
            If `return_dict_in_generate` is set in the generation parameters: a list
            with one entry per input, each a list of dicts with keys "text" and
            "log_prob" (plus whatever the triplet conversion adds), sorted by
            "log_prob" in descending order.
            Otherwise: the decoded (optionally triplet-converted) outputs grouped
            into consecutive chunks of `num_return_sequences` elements.
        """
        was_training = self.training
        if was_training:
            self.eval()

        try:
            inference_parameters = self.hparams.inference["hf_generation_params"].copy()
            inference_parameters.update(kwargs)

            with torch.no_grad():
                # Get input_ids and attention masks
                if input_data_is_processed_batch:
                    input_ids = input_data["src_input_ids"]
                    attention_mask = input_data["src_attention_mask"]
                    # The raw source text is only collected when no constraining
                    # function was supplied by the caller.
                    if prefix_allowed_tokens_fn is None and "raw" in input_data:
                        raw_input = [sample["src"] for sample in input_data["raw"]]
                    else:
                        raw_input = None
                else:
                    tokenizer_output = {
                        k: v.to(self.device)
                        for k, v in self.tokenizer(
                            input_data,
                            return_tensors="pt",
                            padding=True,
                            max_length=self.hparams.max_input_length,
                            truncation=True,
                        ).items()
                    }  # input_ids and attention_masks with `num_sentences x max_length` dims
                    input_ids = tokenizer_output["input_ids"]
                    attention_mask = tokenizer_output["attention_mask"]
                    raw_input = input_data

                # If an entity and relation prefix trie were passed, construct the
                # corresponding constraining function (overrides the argument).
                if entity_trie is not None and relation_trie is not None:
                    prefix_allowed_tokens_fn = get_information_extraction_prefix_allowed_tokens_fn_hf(
                        self,
                        raw_input,
                        bos_as_first_token_generated=self.hparams.get("bos_as_first_token_generated", True),
                        entities_trie=entity_trie,
                        relations_trie=relation_trie,
                    )

                # Set the seed and generate the predictions
                if testing:
                    transformers.trainer_utils.set_seed(self.hparams.inference["seed"])
                elif seed is not None:
                    transformers.trainer_utils.set_seed(seed)

                # The popped keys are passed explicitly so they are not duplicated
                # by the `**inference_parameters` expansion.
                output = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    no_repeat_ngram_size=inference_parameters.pop("no_repeat_ngram_size", 0),
                    max_length=inference_parameters.pop("max_length", self.hparams.max_output_length),
                    early_stopping=inference_parameters.pop("early_stopping", False),
                    prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
                    **inference_parameters,
                )

                k = inference_parameters.get("num_return_sequences", 1)

                # Process the output and construct a return object
                if inference_parameters.get("return_dict_in_generate", False):
                    output["sequences"] = self.tokenizer.batch_decode(output["sequences"], skip_special_tokens=True)
                    output["sequences_scores"] = output["sequences_scores"].tolist()

                    assert len(output["sequences"]) == len(output["sequences_scores"])

                    # Group the flat outputs into chunks of `k` candidates per input
                    batch = [
                        (output["sequences"][i : i + k], output["sequences_scores"][i : i + k])
                        for i in range(0, len(output["sequences"]), k)
                    ]

                    output = []

                    for seqs, scores in batch:
                        output_obj = [{"text": seq, "log_prob": score} for seq, score in zip(seqs, scores)]

                        if convert_to_triplets:
                            output_obj = GeniePL._convert_output_to_triplets(output_obj, **surface_form_mappings)

                        output_obj = sorted(output_obj, key=lambda x: x["log_prob"], reverse=True)
                        output.append(output_obj)

                    # returns a list of `num_sentences` lists
                    # Where each inner list has `num_return_sequences` elements
                    # Where each dictionary has keys "text" and "log_prob" corresponding to a single predicted sequence
                    # The elements in the list are sorted in descending order with respect to the log_prob
                    return output

                # Returns a list of `num_sentences` decoded (textual) sequences
                output = self.tokenizer.batch_decode(output, skip_special_tokens=True)
                if convert_to_triplets:
                    output = GeniePL._convert_output_to_triplets(output, **surface_form_mappings)

                output = [output[i : i + k] for i in range(0, len(output), k)]

                return output
        finally:
            # Restore the caller's mode on every exit path, including the early
            # return of the `return_dict_in_generate` branch.
            if was_training:
                self.train()

In [None]:
import os
import jsonlines
from torch.utils.data import Dataset
import config
from genie.datamodule.utils import TripletUtils
from tqdm import tqdm

import genie.utils.general as utils


# Module-level logger obtained from the project's `genie.utils.general` helper.
log = utils.get_logger(__name__)

class Seq2SeqDataset(Dataset):
    """Dataset of `(id, source text, target text)` records for seq2seq training.

    Every element of `data` must be indexable with positions 0 (id), 1 (source)
    and 2 (target). Extra keyword arguments are stored in `self.params` and are
    read by `collate_fn` ("max_input_length", "max_output_length", "padding",
    "truncation" and the optional "target_padding_token_id").
    """

    def __init__(self, tokenizer, data, **kwargs):
        super().__init__()
        self.tokenizer = tokenizer
        self.data = data
        self.params = kwargs

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record `idx` as a dict with keys "id", "src" and "trg"."""
        record = self.data[idx]
        return {
            "id": record[0],
            "src": record[1],
            "trg": record[2],
        }

    @classmethod
    def from_src_target_files(cls, tokenizer, data_dir, data_split, **kwargs):
        """Build a dataset from line-aligned `<split>.source` / `<split>.target` files.

        Each record gets its zero-based line index as id, so that the records have
        the `(id, src, trg)` layout `__getitem__` expects. Pairing stops at the
        shorter of the two files.
        """
        source_path = os.path.join(data_dir, f"{data_split}.source")
        target_path = os.path.join(data_dir, f"{data_split}.target")

        with open(source_path) as fs, open(target_path) as ft:
            data = [(idx, s.strip(), t.strip()) for idx, (s, t) in enumerate(zip(fs, ft))]

        # kwargs must be expanded: __init__ accepts them as keywords only.
        return cls(tokenizer, data, **kwargs)

    def collate_fn(self, batch):
        """Tokenize a list of samples (as returned by `__getitem__`) into one batch.

        We assume that the model generates the decoder_ids and does any
        non-standard token processing on its own.

        Returns:
            A dict with the tokenizer outputs under "src_<key>" and "trg_<key>",
            plus "raw" holding the untouched input samples.
        """
        # Name of the `self.params` entry that bounds the length of each attribute
        length_param = {"src": "max_input_length", "trg": "max_output_length"}
        collated_batch = {}

        for attr_name in ("src", "trg"):
            tokenizer_output = self.tokenizer(
                [sample[attr_name] for sample in batch],
                return_tensors="pt",  # return PyTorch tensors
                return_attention_mask=True,
                padding=self.params["padding"],
                max_length=self.params[length_param[attr_name]],
                truncation=self.params["truncation"],
            )

            for k, v in tokenizer_output.items():
                collated_batch["{}_{}".format(attr_name, k)] = v

        target_padding_token_id = self.params.get("target_padding_token_id", None)
        if target_padding_token_id is not None:
            # Overwrite (in place) the tokenizer's padding id in the target ids
            trg_input_ids = collated_batch["trg_input_ids"]
            trg_input_ids.masked_fill_(trg_input_ids == self.tokenizer.pad_token_id, target_padding_token_id)

        collated_batch["raw"] = batch

        return collated_batch

In [None]:

    # NOTE(review): `full_codes`, `eos_token_id`, `bos_token_id`, `sentences` and
    # `encode_fn` are defined above this excerpt (the enclosing function header is
    # not visible here) - confirm their exact shape against the full definition.

    # Check that all tag encodings have the same length, and that the token at
    # position 1 (start of tag) and at position -2 (end of tag) are the same
    # for every tag.
    # Presumably each encoding looks like [bos, "<", tag, ">", eos] - TODO confirm.
    l = []
    s = []
    e = []
    for n, c in full_codes.items():
        l.append(len(c))
        s.append(c[1])
        e.append(c[-2])

    assert np.all(np.array(l) == l[0])
    assert np.all(np.array(s) == s[0])
    assert np.all(np.array(e) == e[0])

    # Keep only the token at position 2 of every tag encoding, keyed by tag name.
    codes = {n: full_codes[n][2] for n in full_codes}
    # Built before the special entries below are added, so it holds tag ids only.
    tag_codes = set(codes[k] for k in codes)

    codes["start_of_tag"] = s[0]
    codes["end_of_tag"] = e[0]

    codes["EOS"] = eos_token_id
    codes["BOS"] = bos_token_id

    # `status_codes[i]` names generation state i; `status_next_token_name[i]` is
    # the `codes` key of the tag that may be opened while in state i.
    status_codes = ["ob", "s", "r", "o"]
    status_next_token_name = ["subject_token", "relation_token", "object_token", "end_of_entity_token"]

    # Encoded input sentences with their first token replaced by EOS.
    if sentences is not None:
        sent_origs = [[codes["EOS"]] + encode_fn(sent)[1:] for sent in sentences]
    else:
        sent_origs = []


    def get_status(sent):
        """Return the current generation state as `(index, code)`.

        The state index is the number of complete tags (start-of-tag token, a tag
        token, end-of-tag token) found in `sent`, modulo 4; the code is the
        matching entry of `status_codes`.
        """
        complete_tags = sum(
            1
            for pos in range(len(sent) - 2)
            if sent[pos] == codes["start_of_tag"]
            and sent[pos + 1] in tag_codes
            and sent[pos + 2] == codes["end_of_tag"]
        )
        phase = complete_tags % 4

        return phase, status_codes[phase]


    def get_last_tag_pointer(sent):
        """Locate the last complete tag (start-of-tag, tag token, end-of-tag) in `sent`.

        Returns:
            `(start, end)`: the indices of the tag's first and last token, or None
            when `sent` contains no complete tag.
        """
        # A complete tag spans three tokens, so it can start no later than
        # len(sent) - 3; starting any further right would index past the end.
        for start in range(len(sent) - 3, -1, -1):
            if (
                sent[start] == codes["start_of_tag"]
                and sent[start + 1] in tag_codes
                and sent[start + 2] == codes["end_of_tag"]
            ):
                return start, start + 2

        return None


    def prefix_allowed_tokens_fn(batch_id, sent):
        """Return the token ids allowed as the next generated token.

        `sent` holds the ids generated so far for the output sequence and
        `batch_id` is the index of the input sentence this output belongs to.
        """
        generated = sent.tolist()
        n_generated = len(generated)

        # TODO: Figure out when and why the generation doesn't end after EOS is generated.
        # TODO: If this check is removed, output contains many "EOS EOS EOS..." at the end.
        if n_generated > 1 and generated[-1] == codes["EOS"]:
            return []

        # Force BOS as the first generated token; necessary if the model is
        # trained with [eos bos ... eos] as target.
        if bos_as_first_token_generated and n_generated == 1:
            return [codes["BOS"]]

        status, status_code = get_status(generated)
        sent_orig = sent_origs[batch_id] if len(sent_origs) != 0 else None

        # Inside a tag, right after its opening token: only the tag that belongs
        # to the current state may follow.
        if n_generated > 0 and generated[-1] == codes["start_of_tag"]:
            return [codes[status_next_token_name[status]]]

        # Inside a tag, after opening token and tag token: close the tag (nothing
        # is allowed if the token after the opening one is not a tag token).
        if n_generated > 1 and generated[-2] == codes["start_of_tag"]:
            return [codes["end_of_tag"]] if generated[-1] in tag_codes else []

        # Outside of a tag: the allowed tokens depend on the current state.
        return get_allowed_tokens(generated, sent_orig, status_code)


    def get_allowed_tokens(sent, sent_orig, status_code):
        """Return the allowed next tokens for the state named by `status_code`.

        "ob" allows opening a tag or ending the sequence; "s" and "o" are
        constrained by the entity trie and "r" by the relation trie. Any other
        code raises RuntimeError.
        """
        if status_code == "ob":
            return [codes["start_of_tag"], codes["EOS"]]

        if status_code in ("s", "o"):
            return _get_allowed_tokens(sent, sent_orig, entities_trie)

        if status_code == "r":
            return _get_allowed_tokens(sent, sent_orig, relations_trie)

        raise RuntimeError


    def _get_allowed_tokens(sent, sent_orig, trie):
        """Return the tokens `trie` allows after the last complete tag in `sent`.

        The trie is queried with everything generated after that tag. If EOS is
        among the continuations, it is replaced by the start-of-tag token.
        """
        _, tag_end = get_last_tag_pointer(sent)
        candidates = trie.get(sent[tag_end + 1 :])

        if codes["EOS"] in candidates:
            candidates.remove(codes["EOS"])
            candidates.append(codes["start_of_tag"])

        return candidates


    return prefix_allowed_tokens_fn

In [None]:
from collections import defaultdict
import jsonlines
import pickle
import os


def get_trie_from_strings(
    string_iterable,
    add_leading_space_flag=True,
    remove_leading_bos=True,
    output_folder_path=None,
    trie_name=None,
    tokenizer=None,
):
    """Tokenize every string in `string_iterable` and build a `Trie` from the ids.

    Args:
        string_iterable: Iterable of strings; one-shot iterables (e.g. generators)
            are supported because the input is materialized once.
        add_leading_space_flag: If True, prepend a space to each string before
            tokenizing it.
        remove_leading_bos: If True, drop the first token id of every encoding.
        output_folder_path: Folder to dump the trie to. Must be given together
            with `trie_name`, or both left as None.
        trie_name: File name (without extension) used when dumping.
        tokenizer: Callable returning a mapping with "input_ids". Defaults to the
            `BartTokenizer` loaded from "martinjosifoski/genie-rw".

    Returns:
        The constructed `Trie`.
    """
    assert (output_folder_path is None and trie_name is None) or (
        output_folder_path is not None and trie_name is not None
    )
    from tqdm import tqdm

    if tokenizer is None:
        from transformers import BartTokenizer

        tokenizer = BartTokenizer.from_pretrained("martinjosifoski/genie-rw")

    # Materialize once: the strings are traversed for encoding and again when
    # dumping, which would silently write nothing for an exhausted generator.
    strings = list(string_iterable)

    def encode_func(name):
        """Encode a single string according to the two flags."""
        text = f" {name}" if add_leading_space_flag else name
        token_ids = tokenizer(text)["input_ids"]
        return token_ids[1:] if remove_leading_bos else token_ids

    trie = Trie([encode_func(uniq_name) for uniq_name in tqdm(sorted(strings))])

    if output_folder_path is not None:
        trie.dump(output_folder_path=output_folder_path, file_name=trie_name, string_iterable=strings)

    return trie



class Trie(object):
    """Prefix tree over token-id sequences.

    Every node maps a token id to the sub-trie holding the continuations that
    follow it; a node with no continuations has an empty `_leaves` dict.
    """

    def __init__(self, sequences):
        """sequences is a list of lists,
        each of which corresponds to a sequence of tokens encoded by the tokenizer"""
        next_sets = defaultdict(list)  # a dict that returns an empty list when the key is not in it
        for seq in sequences:
            if len(seq) > 0:
                next_sets[seq[0]].append(seq[1:])

        self._leaves = {k: Trie(v) for k, v in next_sets.items()}
        # for the leaves of the trie _leaves == {}

    def get(self, indices):
        """Return the tokens that may follow the prefix `indices`.

        `indices` holds the vocabulary tokens that constitute the current prefix.
        An empty prefix yields all possible starting tokens; a prefix that is not
        in the trie yields an empty list. A new list is returned on every call.
        """
        node = self
        # Walk down one level per prefix token (iteratively, so long prefixes
        # neither recurse deeply nor get re-sliced at every step).
        for token in indices:
            child = node._leaves.get(token)
            if child is None:
                # the prefix isn't eligible
                return []
            node = child

        return list(node._leaves.keys())

    def dump(self, output_folder_path, file_name, string_iterable=None):
        """Pickle the trie to `<output_folder_path>/<file_name>.pickle`.

        If `string_iterable` is given, its items are also written to
        `<file_name>_original_strings.jsonl` in the same folder.
        """
        # The context manager guarantees the file is flushed and closed
        with open(os.path.join(output_folder_path, f"{file_name}.pickle"), "wb") as f:
            pickle.dump(self, f, protocol=4)

        if string_iterable is not None:
            with jsonlines.open(os.path.join(output_folder_path, f"{file_name}_original_strings.jsonl"), "w") as writer:
                writer.write_all(string_iterable)

    @staticmethod
    def load(path):
        """Load a pickled trie from `path`.

        NOTE(security): unpickling can execute arbitrary code; only load files
        that come from a trusted source.
        """
        with open(path, "rb") as f:
            trie = pickle.load(f)

        return trie