<a href="https://colab.research.google.com/github/sai-sreekhar/ML-Learning/blob/main/LUKE/Supervised_relation_extraction_with_LukeForEntityPairClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In this notebook, we are going to fine-tune [`LukeForEntityPairClassification`](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) on a supervised **relation extraction** dataset.

The goal for the model is to predict, given a sentence and the character spans of two entities within the sentence, the relationship between the entities.
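
To make "character spans" concrete, here is a minimal sketch of how the two spans could be computed for a sentence. The sentence, the entity names and the `find_span` helper are made up for illustration and are not taken from the dataset; each span is a `(start, end)` pair of character offsets, with `end` exclusive.

```python
from typing import Tuple

def find_span(sentence: str, mention: str) -> Tuple[int, int]:
    """Return (start, end) character offsets of the first occurrence of `mention`."""
    start = sentence.find(mention)
    if start == -1:
        raise ValueError(f"{mention!r} not found in sentence")
    return start, start + len(mention)

# hypothetical example sentence with two entities
sentence = "Acme Corp was founded by Jane Doe."
entity_spans = [find_span(sentence, "Acme Corp"), find_span(sentence, "Jane Doe")]

print(entity_spans)
# slicing the sentence with a span gives back the entity text
print([sentence[start:end] for start, end in entity_spans])
```

A list of two such tuples is the shape that the `entity_spans` argument of the tokenizer takes later in this notebook.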

The authors of LUKE fine-tuned this model on the [TACRED](https://nlp.stanford.edu/projects/tacred/) dataset, an important supervised relation extraction dataset from Stanford University, and obtained state-of-the-art results on it.

* Paper: https://arxiv.org/abs/2010.01057
* Original repository: https://github.com/studio-ousia/luke

In [None]:
!pip install -q transformers 

In [None]:
!pip install -q pytorch-lightning wandb

## Read in data

Let's download the data from the web, hosted on Dropbox.

In [None]:
import requests, zipfile, io

def download_data():
    url = "https://www.dropbox.com/s/izi2x4sjohpzoot/relation_extraction_dataset.zip?dl=1"
    r = requests.get(url)
    z = zipfile.ZipFile(io.BytesIO(r.content))
    z.extractall()

download_data()

Each row in the dataframe consists of a news article and a sentence in which a certain relationship was found (such as "invested_in" or "founded_by"). The data was gathered using a set of patterns, so it might contain some noise.

In [None]:
import pandas as pd

df = pd.read_pickle("relation_extraction_dataset.pkl")
df.reset_index(drop=True, inplace=True)
df.head()
# print(df.iloc[0]['sentence'])

Let's create 2 dictionaries, one that maps each label to a unique integer, and one that does it the other way around.

In [None]:
id2label = dict()
for idx, label in enumerate(df.string_id.value_counts().index):
  id2label[idx] = label

As we can see, there are 7 labels (7 unique relationships):

In [None]:
id2label

In [None]:
label2id = {v:k for k,v in id2label.items()}
label2id
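
As a small illustration of the two mappings, the sketch below builds them for a toy list of labels using only the standard library. The toy labels are assumptions for the example (only "invested_in" and "founded_by" are mentioned in this notebook), and `Counter.most_common()` stands in for the frequency ordering that `value_counts()` gives.

```python
from collections import Counter

# toy labels; "acquired_by" is a hypothetical label used only for this example
toy_labels = ["founded_by", "invested_in", "founded_by", "acquired_by", "founded_by", "invested_in"]

# order labels from most to least frequent, then number them
ordered = [label for label, _ in Counter(toy_labels).most_common()]
toy_id2label = dict(enumerate(ordered))
toy_label2id = {label: idx for idx, label in toy_id2label.items()}

print(toy_id2label)
print(toy_label2id)
```

The two dictionaries are inverses of each other, so mapping a label to its id and back returns the original label.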

In [None]:
df.shape

## Define the PyTorch dataset and dataloaders

Next, we define regular PyTorch datasets and corresponding dataloaders. In PyTorch, you define a `Dataset` class that inherits from `torch.utils.data.Dataset` and implements 3 methods: `__init__` (which initializes the dataset with data), `__len__` (which returns the number of elements in the dataset) and `__getitem__` (which returns a single item from the dataset).

In our case, each item of the dataset consists of a sentence, the spans of 2 entities in the sentence, and a label of the relationship. We use `LukeTokenizer` (available in the Transformers library) to turn these into the inputs expected by the model, which are `input_ids`, `entity_ids`, `attention_mask`, `entity_attention_mask` and `entity_position_ids`.

For more information regarding these inputs, refer to the [docs](https://huggingface.co/transformers/model_doc/luke.html#lukeforentitypairclassification) of `LukeForEntityPairClassification`.
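
The three-method protocol described above can be sketched without any PyTorch dependency. The `ToyDataset` class and its rows are hypothetical; in the real notebook the class inherits from `torch.utils.data.Dataset` and `__getitem__` returns tokenized tensors instead of raw dictionaries.

```python
class ToyDataset:
    """Map-style dataset sketch: construct with data, report a length, index by position."""

    def __init__(self, rows):
        # store the underlying data
        self.rows = rows

    def __len__(self):
        # number of elements in the dataset
        return len(self.rows)

    def __getitem__(self, idx):
        # return a single item
        return self.rows[idx]

toy = ToyDataset([
    {"sentence": "first sentence", "label": 0},
    {"sentence": "second sentence", "label": 1},
])
print(len(toy), toy[1])
```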


In [None]:
from transformers import LukeTokenizer
from torch.utils.data import Dataset, DataLoader
import torch

tokenizer = LukeTokenizer.from_pretrained("studio-ousia/luke-base", task="entity_pair_classification")

class RelationExtractionDataset(Dataset):
    """Relation extraction dataset."""

    def __init__(self, data):
        """
        Args:
            data : Pandas dataframe.
        """
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data.iloc[idx]

        sentence = item.sentence
        entity_spans = [tuple(x) for x in item.entity_spans]

        encoding = tokenizer(sentence, entity_spans=entity_spans, padding="max_length", truncation=True, return_tensors="pt")

        for k,v in encoding.items():
          encoding[k] = encoding[k].squeeze()

        encoding["label"] = torch.tensor(label2id[item.string_id])

        return encoding

Here we instantiate the class defined above three times, creating a training dataset, a validation dataset and a test dataset.

In [None]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, shuffle=True)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42, shuffle=False)

# define the dataset
train_dataset = RelationExtractionDataset(data=train_df)
valid_dataset = RelationExtractionDataset(data=val_df)
test_dataset = RelationExtractionDataset(data=test_df)
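
Because the second split is applied to what remains after the first one, the resulting proportions are not 80/20/20. A short worked calculation, ignoring the rounding that happens on real row counts:

```python
from fractions import Fraction

test_fraction = Fraction(1, 5)              # first split: test_size=0.2 of all rows
remaining = 1 - test_fraction               # 80% of the rows are left
val_fraction = remaining * Fraction(1, 5)   # second split: test_size=0.2 of the remainder
train_fraction = remaining - val_fraction

for name, frac in [("train", train_fraction), ("val", val_fraction), ("test", test_fraction)]:
    print(f"{name}: {float(frac):.0%}")
```

So roughly 64% of the rows end up in the training set, 16% in the validation set and 20% in the test set.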

In [None]:
train_dataset[0].keys()

Let's define the corresponding dataloaders (which allow us to iterate over the elements of the dataset):

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=2)
test_dataloader = DataLoader(test_dataset, batch_size=2)
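
Conceptually, a dataloader without shuffling walks over the dataset in chunks of `batch_size`. The sketch below shows only that chunking in plain Python; `iter_batches` is a hypothetical helper, and the real `DataLoader` additionally collates the items of a batch into tensors and can shuffle the order.

```python
def iter_batches(items, batch_size):
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]

toy_items = list(range(10))
toy_batches = list(iter_batches(toy_items, batch_size=4))
print(toy_batches)
```

Note that the last batch can be smaller than `batch_size` when the dataset size is not a multiple of it.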

Let's verify an example of a batch:

In [None]:
batch = next(iter(train_dataloader))
tokenizer.decode(batch["input_ids"][1])

In [None]:
id2label[batch["label"][1].item()]

## Define a PyTorch LightningModule

Let's define the model as a PyTorch LightningModule. A `LightningModule` is actually an `nn.Module`, but with some extra functionality.

For more information regarding how to define this, see the [docs](https://pytorch-lightning.readthedocs.io/en/latest/) of PyTorch Lightning.

In [None]:
from transformers import LukeForEntityPairClassification
from torch.optim import AdamW  # transformers.AdamW is deprecated; use the PyTorch optimizer instead
import pytorch_lightning as pl

class LUKE(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.model = LukeForEntityPairClassification.from_pretrained("studio-ousia/luke-base", num_labels=len(label2id))

    def forward(self, input_ids, entity_ids, entity_position_ids, attention_mask, entity_attention_mask):     
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, entity_ids=entity_ids, 
                             entity_attention_mask=entity_attention_mask, entity_position_ids=entity_position_ids)
        return outputs
    
    def common_step(self, batch, batch_idx):
        labels = batch['label']
        del batch['label']
        outputs = self(**batch)
        logits = outputs.logits

        criterion = torch.nn.CrossEntropyLoss() # multi-class classification
        loss = criterion(logits, labels)
        predictions = logits.argmax(-1)
        correct = (predictions == labels).sum().item()
        accuracy = correct/batch['input_ids'].shape[0]

        return loss, accuracy
      
    def training_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # log metrics for each training step,
        # and the average across the epoch
        self.log("training_loss", loss, on_step=True, on_epoch=True)
        self.log("training_accuracy", accuracy, on_step=True, on_epoch=True)

        return loss

    def validation_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        self.log("validation_loss", loss, on_epoch=True)
        self.log("validation_accuracy", accuracy, on_epoch=True)

        return loss

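    def test_step(self, batch, batch_idx):
        loss, accuracy = self.common_step(batch, batch_idx)     
        # log the test metrics so that trainer.test() reports them
        self.log("test_loss", loss)
        self.log("test_accuracy", accuracy)

        return loss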

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=5e-5)
        return optimizer

    def train_dataloader(self):
        return train_dataloader

    def val_dataloader(self):
        return valid_dataloader

    def test_dataloader(self):
        return test_dataloader
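
The accuracy computed in `common_step` boils down to taking the argmax of each row of logits and counting matches with the labels. Here is the same computation in plain Python on made-up numbers; `batch_accuracy` and the toy values are for illustration only.

```python
def batch_accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    # argmax per row: index of the largest logit
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(pred == label) for pred, label in zip(predictions, labels))
    return correct / len(labels)

# toy batch of 4 examples and 3 classes
toy_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3], [0.0, 0.1, 3.0], [0.9, 0.8, 0.1]]
toy_targets = [1, 0, 2, 1]
print(batch_accuracy(toy_logits, toy_targets))
```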

Let's verify a forward pass on a batch:

In [None]:
batch = next(iter(valid_dataloader))
labels = batch["label"]
batch.keys()

In [None]:
batch["input_ids"].shape

In [None]:
model = LUKE()
del batch["label"]
outputs = model(**batch)

The initial loss should be around -ln(1/number of classes) = -ln(1/7) = 1.95:

In [None]:
criterion = torch.nn.CrossEntropyLoss()

initial_loss = criterion(outputs.logits, labels)
print("Initial loss:", initial_loss)
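
The expected value can be checked by hand: a classifier that assigns equal probability to all 7 classes has a cross-entropy of ln(7), whatever the true label is. A small pure-Python check, where `cross_entropy` is a hand-written helper for a single example and not the PyTorch implementation:

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy of one example: log-sum-exp of the logits minus the target logit."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

num_classes = 7
uniform_logits = [0.0] * num_classes  # equal logits give a uniform distribution
expected = -math.log(1 / num_classes)

print(round(expected, 2))
print(round(cross_entropy(uniform_logits, target=3), 2))
```

A freshly initialized classification head will not be exactly uniform, so the measured initial loss should only be close to this value.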

## Train the model

Let's train the model. We use early stopping to avoid overfitting the training dataset, and we log everything to Weights and Biases, which will give us beautiful charts of the loss and accuracy plotted over time.

If you haven't already, you can create an account on the [website](https://wandb.ai/site), then log in in a web browser, and run the cell below: 

In [None]:
import wandb

wandb.login()

In [None]:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import EarlyStopping

wandb_logger = WandbLogger(name='luke-first-run-12000-articles-bis', project='LUKE')
# for early stopping, see https://pytorch-lightning.readthedocs.io/en/1.0.0/early_stopping.html?highlight=early%20stopping
# the monitored name must match the metric logged in validation_step
early_stop_callback = EarlyStopping(
    monitor='validation_loss',
    patience=2,
    strict=False,
    verbose=False,
    mode='min'
)

trainer = Trainer(logger=wandb_logger, callbacks=[early_stop_callback])
trainer.fit(model)
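
The idea behind early stopping with a patience of 2 can be sketched in a few lines. This is a simplified illustration, not the PyTorch Lightning implementation; `stopping_epoch` and the loss values are made up.

```python
def stopping_epoch(val_losses, patience=2, min_delta=0.0):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    wait = 0  # number of consecutive checks without improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None

# validation loss improves, then fails to improve twice in a row
print(stopping_epoch([1.0, 0.8, 0.9, 0.85, 0.7]))
```

In this toy run training stops at epoch 3, so the later improvement to 0.7 is never reached. That is the trade-off a small patience value makes.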

In [None]:
trainer.test()

## Evaluation

Instead of calling `trainer.test()`, we can also manually evaluate the model on the entire test set:

In [None]:
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score

# evaluate the model trained above (`loaded_model` is only created in the Inference section)
model.eval()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

predictions_total = []
labels_total = []
for batch in tqdm(test_dataloader):
    # get the inputs;
    labels = batch["label"]
    del batch["label"]

    # move everything to the selected device (GPU if available)
    for k,v in batch.items():
      batch[k] = batch[k].to(device)

    # forward pass, without tracking gradients
    with torch.no_grad():
        outputs = model.model(**batch)
    logits = outputs.logits
    predictions = logits.argmax(-1)
    predictions_total.extend(predictions.tolist())
    labels_total.extend(labels.tolist())

In [None]:
print("Accuracy on test set:", accuracy_score(labels_total, predictions_total))

## Inference

Here we test the trained model on a new, unseen sentence.

In [None]:
loaded_model = LUKE.load_from_checkpoint(checkpoint_path="/content/drive/Shareddrives/Datascouts/epoch=3-step=7699.ckpt")

In [None]:
test_df.iloc[0].sentence

In [None]:
import torch.nn.functional as F

idx = 2
text = test_df.iloc[idx].sentence
entity_spans = test_df.iloc[idx].entity_spans  # character-based entity spans 
entity_spans = [tuple(x) for x in entity_spans]

inputs = tokenizer(text, entity_spans=entity_spans, return_tensors="pt")

# inference only: switch to eval mode and disable gradient tracking
loaded_model.eval()
with torch.no_grad():
    outputs = loaded_model.model(**inputs)
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print("Sentence:", text)
print("Ground truth label:", test_df.iloc[idx].string_id)
print("Predicted label:", id2label[predicted_class_idx])
print("Confidence:", F.softmax(logits, -1).max().item())
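
The confidence printed above is the largest softmax probability. For reference, here is how that number comes out of raw logits in plain Python; the `softmax` helper and the logit values are made up for this example.

```python
import math

def softmax(logits):
    """Turn a list of logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

toy_scores = [2.0, 1.0, 0.1]
probs = softmax(toy_scores)
predicted_idx = max(range(len(probs)), key=probs.__getitem__)
confidence = probs[predicted_idx]
print(predicted_idx, round(confidence, 3))
```

Keep in mind that a high softmax probability is not a calibrated measure of correctness, so it is best read as a relative score.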