# Relation extraction with BERT

---

The goal of this repo is to show how to use [BERT](https://arxiv.org/abs/1810.04805)
to [extract relation](https://en.wikipedia.org/wiki/Relationship_extraction) from text.

Used libraries:
- [Transformers](https://huggingface.co/transformers/index.html)
- [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/)

Used datasets:
- SemEval 2010 Task 8 - [paper](https://arxiv.org/pdf/1911.10422.pdf) - [download](https://github.com/sahitya0000/Relation-Classification/blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true)
-  Google IISc Distant Supervision (GIDS) - [paper](https://arxiv.org/pdf/1804.06987.pdf) - [download](https://drive.google.com/open?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI)

## Install dependencies

This project uses [Python 3.7+](https://www.python.org/downloads/release/python-378/)

In [2]:
!rm -r ./checkpoint
!rm -r ./data/processed

In [3]:
!pip install requests numpy pandas \
    scikit-learn pytorch-lightning torch \
    transformers sklearn==0.0 tqdm




In [4]:
!pip install sentencepiece



## Import needed modules

In [5]:
import json
import multiprocessing
import os
import pickle
import shutil
import zipfile
from abc import ABC, abstractmethod
from typing import Tuple
from urllib.parse import urlparse

import requests
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer as LightningTrainer
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import Tensor, nn
from torch.nn import functional as F
from torch.optim import AdamW
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
#from transformers import *

## Define constants

Change the following constant to `False` if you are not running on Kaggle environment:

In [6]:
KAGGLE = False

Other constants:

In [7]:
from transformers import BertModel, BertTokenizer, DistilBertModel, DistilBertTokenizer, RobertaModel, RobertaTokenizer
# --- Directory ---
ROOT_DIR = os.path.abspath('.')
RAW_DATA_DIR = os.path.join(ROOT_DIR, 'input') if KAGGLE else os.path.join(ROOT_DIR, 'data/raw')
PROCESSED_DATA_DIR = os.path.join(ROOT_DIR, 'data/processed') 
CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoint')

# --- Datasets ---
DATASET_MAPPING = {
    'SemEval2010Task8': {
        'dir': os.path.join(RAW_DATA_DIR,'SemEval2010_task8_all_data'),
        'url': 'https://github.com/sahitya0000/Relation-Classification/'
               'blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true',
        'num_classes': 10,
    },
    'GIDS': {
        'dir': os.path.join(RAW_DATA_DIR,'gids_data'),
        'url': 'https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download',
        'num_classes': 5
    },
    'Karst': {
        'dir': os.path.join(RAW_DATA_DIR,'Karst'),
        'url':'',
        'num_classes':14
    }
}
#DATASET_NAME = 'GIDS'
DATASET_NAME = 'Karst'
#DATASET_NAME = 'SemEval2010Task8'
# --- BERT ---
SUB_START_CHAR = '{'
SUB_END_CHAR = '}'
OBJ_START_CHAR = '['
OBJ_END_CHAR = ']'

# --- BERT Model ---
# See https://huggingface.co/transformers/pretrained_models.html for the full list

BERT_VARIANT_MAPPING = {
    'bert': {
        'model': BertModel,
        'tokenizer': BertTokenizer,
        'pretrain_weight': 'bert-base-uncased',
        'available_pretrain_weights': ['bert-base-uncased', 'bert-base-cased']
    },
    'distilbert': {
        'model': DistilBertModel,
        'tokenizer': DistilBertTokenizer,
        'pretrain_weight': 'distilbert-base-uncased',
        'available_pretrain_weights': ['distilbert-base-uncased', 'distilbert-base-cased']
    },
    'roberta': {
        'model': RobertaModel,
        'tokenizer': RobertaTokenizer,
        'pretrain_weight': 'roberta-base',
        'available_pretrain_weights': ['roberta-base', 'distilroberta-base']
    },
}
#BERT_VARIANT = 'distilbert'
#BERT_VARIANT = 'bert'
BERT_VARIANT = 'roberta'

## Download data

This part **CAN BE SKIPPED** if this notebook is running on Kaggle environment since the dataset has already been included.

First, we install `gdown` to download files from Google Drive

In [None]:
!pip install gdown
import gdown



Some download util functions:

In [None]:
def download_from_url(url: str, save_path: str, chunk_size: int = 2048) -> None:
    with open(save_path, "wb") as f:
        print(f"Downloading...\nFrom: {url}\nTo: {save_path}")
        response = requests.get(url, stream=True)
        for data in tqdm(response.iter_content(chunk_size=chunk_size)):
            f.write(data)

def download_from_google_drive(url: str, save_path: str) -> None:
    gdown.download(url, save_path, use_cookies=False)

def extract_zip(zip_file_path: str, extract_dir: str, remove_zip_file=True):
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        print("Extracting to " + extract_dir)
        zip_ref.extractall(extract_dir)

    if remove_zip_file:
        print("Removing zip file")
        os.unlink(zip_file_path)

The download function itself:

In [None]:
def download(dataset_name, dataset_url, dataset_dir, force_redownload: bool):
    print(f"\n---> Downloading dataset {dataset_name} <---")
    
    # create raw data dir
    if not os.path.exists(RAW_DATA_DIR):
        print("Creating raw data directory " + RAW_DATA_DIR)
        os.makedirs(RAW_DATA_DIR)
    
    # check data has been downloaded
    if os.path.exists(dataset_dir):
        if force_redownload:
            print(f"Removing old raw data {dataset_dir}")
            shutil.rmtree(dataset_dir)
        else:
            print(f"Directory {dataset_dir} exists, skip downloading.")
            return


    # download
    tmp_file_path = os.path.join(RAW_DATA_DIR, dataset_name + '.zip')
    if urlparse(dataset_url).netloc == 'drive.google.com':
        download_from_google_drive(dataset_url, tmp_file_path)
    else:
        download_from_url(dataset_url, tmp_file_path)

    # unzip
    extract_zip(tmp_file_path, RAW_DATA_DIR)

Download all datasets:

In [None]:
for dataset_name, dataset_info in DATASET_MAPPING.items():
    download(
        dataset_name,
        dataset_url=dataset_info['url'],
        dataset_dir=dataset_info['dir'],
        force_redownload=False
    )


---> Downloading dataset SemEval2010Task8 <---
Downloading...
From: https://github.com/sahitya0000/Relation-Classification/blob/master/corpus/SemEval2010_task8_all_data.zip?raw=true
To: /content/data/raw/SemEval2010Task8.zip


0it [00:00, ?it/s]

Extracting to /content/data/raw
Removing zip file

---> Downloading dataset GIDS <---


Downloading...
From: https://drive.google.com/uc?id=1gTNAbv8My2QDmP-OHLFtJFlzPDoCG4aI&export=download
To: /content/data/raw/GIDS.zip

  0%|          | 0.00/76.4M [00:00<?, ?B/s][A
  6%|▌         | 4.72M/76.4M [00:00<00:06, 11.1MB/s][A
 27%|██▋       | 20.4M/76.4M [00:00<00:01, 47.5MB/s][A
 45%|████▍     | 34.1M/76.4M [00:00<00:01, 42.2MB/s][A
 61%|██████    | 46.7M/76.4M [00:00<00:00, 57.1MB/s][A
 82%|████████▏ | 62.9M/76.4M [00:01<00:00, 78.8MB/s][A
100%|██████████| 76.4M/76.4M [00:01<00:00, 57.3MB/s]


Extracting to /content/data/raw
Removing zip file

---> Downloading dataset Karst <---
Directory /content/data/raw/Karst exists, skip downloading.


## Preprocess

The abstract preprocessor

In [8]:
from transformers import PreTrainedTokenizer
class AbstractPreprocessor(ABC):
    DATASET_NAME = ''

    def __init__(self, tokenizer: PreTrainedTokenizer):
        self.tokenizer = tokenizer

    def preprocess_data(self, reprocess: bool):
        print(f"\n---> Preprocessing {self.DATASET_NAME} dataset <---")
        
        # create processed data dir
        if not os.path.exists(PROCESSED_DATA_DIR):
            print("Creating processed data directory " + PROCESSED_DATA_DIR)
            os.makedirs(PROCESSED_DATA_DIR)

        # stop preprocessing if file existed
        pickled_file_names = [self.get_pickle_file_name(k) for k in ('train', 'val', 'test')]
        existed_files = [fn for fn in pickled_file_names if os.path.exists(fn)]
        if existed_files:
            file_text = "- " + "\n- ".join(existed_files)
            if not reprocess:
                print("The following files already exist:")
                print(file_text)
                print("Preprocessing is skipped. See option --reprocess.")
                return
            else:
                print("The following files will be overwritten:")
                print(file_text)

        self._preprocess_data()

    @abstractmethod
    def _preprocess_data(self):
        pass

    def _pickle_data(self, data, file_name):
        print(f"Saving to pickle file {file_name}")
        with open(file_name, 'wb') as f:
            pickle.dump(data, f)

    @classmethod
    def get_pickle_file_name(cls, key: str):
        return os.path.join(PROCESSED_DATA_DIR, f'{cls.DATASET_NAME.lower()}_{key}.pkl')

For each dataset, define a preprocessor:

In [9]:
LABEL_ENCODER = None

class KarstPreprocessor(AbstractPreprocessor):
    DATASET_NAME = 'Karst'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['Karst']['dir'],
                                       'SEM_EVAL_FILE_KARST_TRAIN.txt')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['Karst']['dir'],
                                      'SEM_EVAL_FILE_KARST_TEST.txt')
    RAW_TRAIN_DATA_SIZE = 2084
    RAW_TEST_DATA_SIZE = 233
    RANDOM_SEED = 2020
    VAL_DATA_PROPORTION = 0.1

    def _preprocess_data(self):
        print("Processing training data")
        train_data = self._get_data_from_file(
            self.RAW_TRAIN_FILE_NAME,
            self.RAW_TRAIN_DATA_SIZE
        )

        print("Processing test data")
        test_data = self._get_data_from_file(
            self.RAW_TEST_FILE_NAME,
            self.RAW_TEST_DATA_SIZE
        )

        print("Encoding labels to integers")
        le = LabelEncoder()
        le.fit(train_data['labels'])
        global LABEL_ENCODER
        LABEL_ENCODER = le
        #train_data['labels_string'] = train_data['labels']
        train_data['labels'] = le.transform(train_data['labels']).tolist()
        #test_data['labels_string'] = test_data['labels']
        test_data['labels'] = le.transform(test_data['labels']).tolist()

        print("Splitting train & validate data")
        train_data, val_data = self._train_val_split(train_data)

        self._pickle_data(train_data, self.get_pickle_file_name('train'))
        self._pickle_data(val_data, self.get_pickle_file_name('val'))
        self._pickle_data(test_data, self.get_pickle_file_name('test'))

    def _train_val_split(self, original_data):
        k = list(original_data.keys())[0]
        indies = list(range(len(original_data[k])))
        train_indies, val_indies = train_test_split(
            indies,
            test_size=self.VAL_DATA_PROPORTION,
            random_state=self.RANDOM_SEED
        )
        train_data = {k: self._get_sample(v, train_indies) for k, v in original_data.items()}
        val_data = {k: self._get_sample(v, val_indies) for k, v in original_data.items()}

        return train_data, val_data

    def _get_sample(self, data, indies):
        return [data[i] for i in indies]

    def _get_data_from_file(self, file_name: str, dataset_size: int):
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for _ in tqdm(range(dataset_size)):
                raw_sentences.append(self._process_sentence(f.readline()))
                labels.append(self._process_label(f.readline()))
                f.readline()
                f.readline()
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['labels'] = labels
        return data

    def _process_sentence(self, sentence: str):
        # TODO distinguish e1 e2 sub obj
        return sentence.split("\t")[1][1:-2] \
            .replace("<e1>", SUB_START_CHAR) \
            .replace("</e1>", SUB_END_CHAR) \
            .replace("<e2>", OBJ_START_CHAR) \
            .replace("</e2>", OBJ_END_CHAR)

    def _process_label(self, label: str):
        return label[:]

class SemEval2010Task8Preprocessor(AbstractPreprocessor):
    DATASET_NAME = 'SemEval2010Task8'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                       'SemEval2010_task8_training/TRAIN_FILE.TXT')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['SemEval2010Task8']['dir'],
                                      'SemEval2010_task8_testing_keys/TEST_FILE_FULL.TXT')
    RAW_TRAIN_DATA_SIZE = 8000
    RAW_TEST_DATA_SIZE = 2717
    RANDOM_SEED = 2020
    VAL_DATA_PROPORTION = 0.2

    def _preprocess_data(self):
        print("Processing training data")
        train_data = self._get_data_from_file(
            self.RAW_TRAIN_FILE_NAME,
            self.RAW_TRAIN_DATA_SIZE
        )

        print("Processing test data")
        test_data = self._get_data_from_file(
            self.RAW_TEST_FILE_NAME,
            self.RAW_TEST_DATA_SIZE
        )

        print("Encoding labels to integers")
        le = LabelEncoder()
        le.fit(train_data['labels'])
        train_data['labels'] = le.transform(train_data['labels']).tolist()
        test_data['labels'] = le.transform(test_data['labels']).tolist()

        print("Splitting train & validate data")
        train_data, val_data = self._train_val_split(train_data)

        self._pickle_data(train_data, self.get_pickle_file_name('train'))
        self._pickle_data(val_data, self.get_pickle_file_name('val'))
        self._pickle_data(test_data, self.get_pickle_file_name('test'))

    def _train_val_split(self, original_data):
        k = list(original_data.keys())[0]
        indies = list(range(len(original_data[k])))
        train_indies, val_indies = train_test_split(
            indies,
            test_size=self.VAL_DATA_PROPORTION,
            random_state=self.RANDOM_SEED
        )
        train_data = {k: self._get_sample(v, train_indies) for k, v in original_data.items()}
        val_data = {k: self._get_sample(v, val_indies) for k, v in original_data.items()}

        return train_data, val_data

    def _get_sample(self, data, indies):
        return [data[i] for i in indies]

    def _get_data_from_file(self, file_name: str, dataset_size: int):
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for _ in tqdm(range(dataset_size)):
                raw_sentences.append(self._process_sentence(f.readline()))
                labels.append(self._process_label(f.readline()))
                f.readline()
                f.readline()
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['labels'] = labels
        return data

    def _process_sentence(self, sentence: str):
        # TODO distinguish e1 e2 sub obj
        return sentence.split("\t")[1][1:-2] \
            .replace("<e1>", SUB_START_CHAR) \
            .replace("</e1>", SUB_END_CHAR) \
            .replace("<e2>", OBJ_START_CHAR) \
            .replace("</e2>", OBJ_END_CHAR)

    def _process_label(self, label: str):
        return label[:-8]


class GIDSPreprocessor(AbstractPreprocessor):
    DATASET_NAME = 'GIDS'
    RAW_TRAIN_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_train.json')
    RAW_VAL_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_dev.json')
    RAW_TEST_FILE_NAME = os.path.join(DATASET_MAPPING['GIDS']['dir'], 'gids_test.json')

    def _preprocess_data(self):
        print("Processing validate data")
        val_data = self._get_data_from_file(self.RAW_VAL_FILE_NAME)
        le = LabelEncoder()
        le.fit(val_data['labels'])
        val_data['labels'] = le.transform(val_data['labels']).tolist()
        self._pickle_data(val_data, self.get_pickle_file_name('val'))
        del val_data

        print("Processing train data")
        train_data = self._get_data_from_file(self.RAW_TRAIN_FILE_NAME)
        train_data['labels_string'] = train_data['labels']
        train_data['labels'] = le.transform(train_data['labels']).tolist()
        self._pickle_data(train_data, self.get_pickle_file_name('train'))
        del train_data
        
        print("Processing test data")
        test_data = self._get_data_from_file(self.RAW_TEST_FILE_NAME)
        test_data['labels_string'] = test_data['labels']
        test_data['labels'] = le.transform(test_data['labels']).tolist()
        self._pickle_data(test_data, self.get_pickle_file_name('test'))
        del test_data

    def _get_data_from_file(self, file_name: str):
        raw_sentences = []
        labels = []
        with open(file_name) as f:
            for line in tqdm(f.readlines()):
                dt = json.loads(line)
                sentence = " ".join(dt['sent'])

                # add subject markup
                new_sub = SUB_START_CHAR + dt['sub'].replace('_', '') + SUB_END_CHAR # TODO keep _ or not?
                new_obj = OBJ_START_CHAR + dt['obj'].replace('_', '') + OBJ_END_CHAR
                sentence = sentence.replace(dt['sub'], new_sub).replace(dt['obj'], new_obj)
                raw_sentences.append(sentence)
                labels.append(dt['rel'])
        data = self.tokenizer(raw_sentences, truncation=True, padding=True)
        data['labels'] = labels
        return data

Factory function to get preprocessor:

In [10]:
def get_preprocessor_class(dataset_name: str):
    return globals()[f'{dataset_name}Preprocessor']

def get_preprocessor(dataset_name: str)-> AbstractPreprocessor:
    bert_model_info = BERT_VARIANT_MAPPING[BERT_VARIANT]
    bert_pretrain_weight = bert_model_info['pretrain_weight']
    tokenizer = bert_model_info['tokenizer'].from_pretrained(bert_pretrain_weight)
    preprocessors_class = get_preprocessor_class(dataset_name)
    return preprocessors_class(tokenizer)

Preprocess data:

In [11]:
preprocessor = get_preprocessor(DATASET_NAME)
preprocessor.preprocess_data(reprocess=False)


---> Preprocessing Karst dataset <---
Creating processed data directory /content/data/processed
Processing training data


  0%|          | 0/2084 [00:00<?, ?it/s]

Processing test data


  0%|          | 0/233 [00:00<?, ?it/s]

Encoding labels to integers
Splitting train & validate data
Saving to pickle file /content/data/processed/karst_train.pkl
Saving to pickle file /content/data/processed/karst_val.pkl
Saving to pickle file /content/data/processed/karst_test.pkl


## Model

### Dataset

In [12]:
class GenericDataset(Dataset):

    def __init__(self, dataset_name: str, subset: str):
        preprocessor_class = get_preprocessor_class(dataset_name)
        if subset not in ['train', 'val', 'test']:
            raise ValueError('subset must be train, val or test')
        with open(preprocessor_class.get_pickle_file_name(subset), 'rb') as f:
            self.data = pickle.load(f)

    def __getitem__(self, index: int) -> Tuple[Tensor, Tensor, Tensor]:
        return (torch.tensor(self.data['input_ids'][index]),
                torch.tensor(self.data['attention_mask'][index]),
                torch.tensor(self.data['labels'][index]))

    def __len__(self) -> int:
        return len(self.data['labels'])

### Torch Lightning Module

In [13]:
class BERTModule(LightningModule):

    def __init__(self, bert_variant, dataset_name, batch_size, learning_rate):
        super().__init__()
        self.save_hyperparameters()

        bert_info = BERT_VARIANT_MAPPING[bert_variant]
        bert_model_class = bert_info['model']
        bert_pretrain_weight = bert_info['pretrain_weight']
        self.bert = bert_model_class.from_pretrained(bert_pretrain_weight, output_attentions=True)

        dataset_info = DATASET_MAPPING[dataset_name]
        self.num_classes = dataset_info['num_classes']
        self.linear = nn.Linear(self.bert.config.hidden_size, self.num_classes)

    def train_dataloader(self) -> DataLoader:
        return self.__get_dataloader('train')

    def val_dataloader(self) -> DataLoader:
        return self.__get_dataloader('val')

    def test_dataloader(self) -> DataLoader:
        return self.__get_dataloader('test')

    def __get_dataloader(self, subset: str) -> DataLoader:
        print(f"Loading {subset} data")
        return DataLoader(
            GenericDataset(self.hparams.dataset_name, subset),
            batch_size=self.hparams.batch_size,
            shuffle=(subset == 'train'),
            num_workers=multiprocessing.cpu_count() + 1
        )

    def configure_optimizers(self) -> Optimizer:
        return AdamW(
            [p for p in self.parameters() if p.requires_grad],
            lr=self.hparams.learning_rate,
            eps=1e-08
        )

    def forward(self, input_ids, attention_mask) -> Tensor:
        #bert_output, cc 
        aa = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        bert_output = aa["last_hidden_state"]
        #print(bert_output)
        #breakpoint()
        bert_cls = bert_output[:, 0]
        #print(bert_cls)
        logits = self.linear(bert_cls)
        return logits

    def training_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)
        tensorboard_logs = {'train_loss': loss}

        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)

        loss = F.cross_entropy(y_hat, label)

        a, y_hat = torch.max(y_hat, dim=1)
        y_hat = y_hat.cpu()
        label = label.cpu()
        


        return {
            'val_loss': loss,
            'val_pre': torch.tensor(precision_score(label, y_hat, average='weighted')),
            'val_rec': torch.tensor(recall_score(label, y_hat, average='weighted')),
            'val_acc': torch.tensor(accuracy_score(label, y_hat)),
            'val_f1': torch.tensor(f1_score(label, y_hat, average='weighted'))
        }

    def validation_epoch_end(self, outputs) -> dict:
        avg_val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        avg_val_pre = torch.stack([x['val_pre'] for x in outputs]).mean()
        avg_val_rec = torch.stack([x['val_rec'] for x in outputs]).mean()
        avg_val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()
        avg_val_f1 = torch.stack([x['val_f1'] for x in outputs]).mean()

        tensorboard_logs = {
            'val_loss': avg_val_loss,
            'avg_val_pre': avg_val_pre,
            'avg_val_rec': avg_val_rec,
            'avg_val_acc': avg_val_acc,
            'avg_val_f1': avg_val_f1,
        }
        print(tensorboard_logs)
        return {'val_loss': avg_val_loss, 'progress_bar': tensorboard_logs}

    def test_step(self, batch, batch_nb) -> dict:
        input_ids, attention_mask, label = batch

        y_hat = self(input_ids, attention_mask)
        a, y_hat = torch.max(y_hat, dim=1)
        
        y_hat = y_hat.cpu()
        label = label.cpu()
        
        test_pre = precision_score(label, y_hat, average='weighted')
        test_rec = recall_score(label, y_hat, average='weighted')
        test_acc = accuracy_score(label, y_hat)
        test_f1 = f1_score(label, y_hat, average='weighted')
        
        #label_string = label_string.cpu()
        dict_res = {}
        for i in range(self.num_classes):
          #label_str = label_string[label==i][0]
          label_curr = label[label==i]
          label_curr2 = torch.clone(label_curr)
          #breakpoint()
          label_curr2[label_curr==i]=1
          label_curr2[label_curr!=i]=0
          #breakpoint()
          #print(label_curr, i)
          y_hat_curr = y_hat[label==i]
          y_hat_curr2 = torch.clone(y_hat_curr)
          y_hat_curr2[y_hat_curr==i] = 1
          y_hat_curr2[y_hat_curr!=i] = 0
          p = precision_score(label_curr2, y_hat_curr2)
          r = recall_score(label_curr2, y_hat_curr2)
          a = accuracy_score(label_curr2, y_hat_curr2)
          f = f1_score(label_curr2, y_hat_curr2)
          dict_res[i] = (p,r,a,f, len(y_hat_curr2))

        return {
            'test_pre': torch.tensor(test_pre),
            'test_rec': torch.tensor(test_rec),
            'test_acc': torch.tensor(test_acc),
            'test_f1': torch.tensor(test_f1),
            'dict': dict_res
        }

    def test_epoch_end(self, outputs) -> dict:
        avg_test_pre = torch.stack([x['test_pre'] for x in outputs]).mean()
        avg_test_rec = torch.stack([x['test_rec'] for x in outputs]).mean()
        avg_test_acc = torch.stack([x['test_acc'] for x in outputs]).mean()
        avg_test_f1 = torch.stack([x['test_f1'] for x in outputs]).mean()

        all_dicts = [x['dict'] for x in outputs]
        all_res = {}
        for d in all_dicts:
          for k in d:
            p = d[k][0]
            r = d[k][1]
            a = d[k][2]
            f = d[k][3]
            l = d[k][4]
            if k in all_res:
              all_res[k] = (all_res[k][0]+p*l,all_res[k][1]+r*l,all_res[k][2]+a*l,all_res[k][3]+f*l, all_res[k][4]+l)
            else:
              all_res[k] = (p*l,r*l,a*l,f*l, l)
        for d in all_res:
          print(str(LABEL_ENCODER.inverse_transform([d])), "precision", str(all_res[d][0]/all_res[d][4]), "recall", str(all_res[d][1]/all_res[d][4]), "acc",str(all_res[d][2]/all_res[d][4]),"f1", str(all_res[d][3]/all_res[d][4]), "num", str(all_res[d][4]))
        #  pass
        tensorboard_logs = {
            'avg_test_pre': avg_test_pre,
            'avg_test_rec': avg_test_rec,
            'avg_test_acc': avg_test_acc,
            'avg_test_f1': avg_test_f1,
        }
        print(tensorboard_logs)
        return {'progress_bar': tensorboard_logs}



## Trainer

In [14]:
GPUS = 1
MIN_EPOCHS = 3
MAX_EPOCHS = 5

trainer = LightningTrainer(
    gpus=GPUS,
    min_epochs=MIN_EPOCHS,
    max_epochs=MAX_EPOCHS,
    default_root_dir=CHECKPOINT_DIR,
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


## Training

In [15]:
BATCH_SIZE = 16
LEARNING_RATE = 2e-05

model = BERTModule(
    bert_variant=BERT_VARIANT,
    dataset_name=DATASET_NAME,
    batch_size=BATCH_SIZE,
    learning_rate=LEARNING_RATE
)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
trainer.fit(model)

Missing logger folder: /content/checkpoint/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name   | Type         | Params
----------------------------------------
0 | bert   | RobertaModel | 124 M 
1 | linear | Linear       | 10.8 K
----------------------------------------
124 M     Trainable params
0         Non-trainable params
124 M     Total params
498.626   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Loading val data


  cpuset_checked))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(2.6360, device='cuda:0'), 'avg_val_pre': tensor(0.0098, dtype=torch.float64), 'avg_val_rec': tensor(0.0938, dtype=torch.float64), 'avg_val_acc': tensor(0.0938, dtype=torch.float64), 'avg_val_f1': tensor(0.0176, dtype=torch.float64)}
Loading train data


  _warn_prf(average, modifier, msg_start, len(result))
  cpuset_checked))


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(1.0547, device='cuda:0'), 'avg_val_pre': tensor(0.6627, dtype=torch.float64), 'avg_val_rec': tensor(0.7098, dtype=torch.float64), 'avg_val_acc': tensor(0.7098, dtype=torch.float64), 'avg_val_f1': tensor(0.6586, dtype=torch.float64)}


  cpuset_checked))


Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(0.6146, device='cuda:0'), 'avg_val_pre': tensor(0.8421, dtype=torch.float64), 'avg_val_rec': tensor(0.8259, dtype=torch.float64), 'avg_val_acc': tensor(0.8259, dtype=torch.float64), 'avg_val_f1': tensor(0.8259, dtype=torch.float64)}


  cpuset_checked))


Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(0.6952, device='cuda:0'), 'avg_val_pre': tensor(0.8090, dtype=torch.float64), 'avg_val_rec': tensor(0.7812, dtype=torch.float64), 'avg_val_acc': tensor(0.7812, dtype=torch.float64), 'avg_val_f1': tensor(0.7791, dtype=torch.float64)}


  cpuset_checked))


Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(0.6014, device='cuda:0'), 'avg_val_pre': tensor(0.8675, dtype=torch.float64), 'avg_val_rec': tensor(0.8393, dtype=torch.float64), 'avg_val_acc': tensor(0.8393, dtype=torch.float64), 'avg_val_f1': tensor(0.8406, dtype=torch.float64)}


  cpuset_checked))


Validation: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


{'val_loss': tensor(0.5915, device='cuda:0'), 'avg_val_pre': tensor(0.8702, dtype=torch.float64), 'avg_val_rec': tensor(0.8482, dtype=torch.float64), 'avg_val_acc': tensor(0.8482, dtype=torch.float64), 'avg_val_f1': tensor(0.8515, dtype=torch.float64)}


## Testing

In [17]:
trainer.test(model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Loading test data


  cpuset_checked))


Testing: 0it [00:00, ?it/s]

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, "true nor predicted", "F-score is", len(true_sum))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  avg = a.mean(axis)
  ret = ret.dtype.type(ret / rcount)
  _warn_prf(average, "true nor predicted", "F-score is", 

['Composition_medium(e1,e2)\n'] precision 0.7619047619047619 recall 0.5714285714285714 acc nan f1 0.634920634920635 num 21
['Composition_medium(e2,e1)\n'] precision nan recall nan acc nan f1 nan num 0
['Genus(e1,e2)\n'] precision 1.0 recall 0.9910714285714286 acc 0.9910714285714286 f1 0.9953007518796992 num 112
['Genus(e2,e1)\n'] precision nan recall nan acc nan f1 nan num 0
['Has_cause(e1,e2)\n'] precision 0.9333333333333333 recall 0.7333333333333333 acc 0.7333333333333333 f1 0.8072222222222222 num 30
['Has_cause(e2,e1)\n'] precision nan recall nan acc nan f1 nan num 0
['Has_form(e1,e2)\n'] precision 0.8823529411764706 recall 0.7647058823529411 acc nan f1 0.8095238095238094 num 17
['Has_form(e2,e1)\n'] precision nan recall nan acc nan f1 nan num 0
['Has_function(e1,e2)\n'] precision 0.2857142857142857 recall 0.21428571428571427 acc nan f1 0.23809523809523808 num 14
['Has_function(e2,e1)\n'] precision nan recall nan acc nan f1 nan num 0
['Has_location(e1,e2)\n'] precision 0.92307692307

[{}]

In [18]:
LABEL_ENCODER.inverse_transform([0, 1, 2])

array(['Composition_medium(e1,e2)\n', 'Composition_medium(e2,e1)\n',
       'Genus(e1,e2)\n'], dtype='<U26')