In [1]:
import json
import logging
import os
import random
import sys
from datetime import datetime

import spacy
import torch

from glirel.model import GLiREL
from glirel.model import load_config_as_namespace

# make the repo root importable so `train.py` (one level up) can be imported
sys.path.append('..')
from train import train, dirty_split_data_by_relation_type


## 🏎️ Load Model

In [3]:
# Pretrained GLiREL checkpoint that fine-tuning starts from.
PRETRAINED_MODEL_NAME = "jackboyla/glirel_beta"
model = GLiREL.from_pretrained(PRETRAINED_MODEL_NAME)



## ⚙️ Config

In [4]:
# All paths below are relative to the kernel's working directory (normally this notebook's folder).
config_file_path = '../configs/config_finetuning.yaml'
log_dir = '../logs/finetuning'

# The data paths live in this cell because the logging/config cell that follows records
# `train_path` on the config. If they were defined only in "Prep Data", Restart & Run All
# would raise NameError.
train_path = "../data/sample.jsonl"  # cannot be None
test_path = None  # can be None; if None, a test split is carved out of the training data

In [5]:

logger = logging.getLogger(__name__)

# Console logging for every logger (no-op on re-run once the root logger has handlers).
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                    handlers=[logging.StreamHandler()])


# load config
config = load_config_as_namespace(config_file_path)

config.log_dir = log_dir
# `train_path` must already be defined when this cell runs.
config.train_data = train_path

# set up logging
if config.log_dir is None:
    # fall back to a timestamped run directory
    current_time = datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
    config.log_dir = f'logs/{config.dataset_name}/{config.dataset_name}-{current_time}'
os.makedirs(config.log_dir, exist_ok=True)

log_file = "train.log"
log_file_path = os.path.join(config.log_dir, log_file)

# Send both this notebook's messages and the `train` module's messages (logger name "train",
# as seen in the cell outputs) to the log file; otherwise training logs only reach the console.
file_loggers = [logger, logging.getLogger("train")]

# Detach and close file handlers left over from a previous run of this cell. Without this,
# every re-run adds another handler (duplicated log lines) and the file below is deleted
# while still held open.
abs_log_file_path = os.path.abspath(log_file_path)
for target_logger in file_loggers:
    for old_handler in list(target_logger.handlers):
        if isinstance(old_handler, logging.FileHandler) and old_handler.baseFilename == abs_log_file_path:
            target_logger.removeHandler(old_handler)
            old_handler.close()

# start each run with a fresh log file
if os.path.exists(log_file_path):
    os.remove(log_file_path)
file_handler = logging.FileHandler(log_file_path)
file_handler.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler.setFormatter(formatter)
for target_logger in file_loggers:
    target_logger.addHandler(file_handler)


## 📓 Prep Data

In [6]:
train_path = "../data/sample.jsonl"  # cannot be None
test_path = None  # can be None; if None, a test split is carved out of the training data

SEED = 42  # makes the shuffle (and therefore the train/test split) reproducible across re-runs
TEST_SET_RATIO = 0.1


def _read_jsonl(path):
    """Return the JSON objects stored one per line in `path`, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]


assert os.path.exists(train_path), f"Train file not found: {train_path}"
data = _read_jsonl(train_path)
assert data, f"No examples found in {train_path}"

# Seed the global RNG rather than a private Random instance, so that any later use of `random`
# (e.g. possibly inside dirty_split_data_by_relation_type — TODO confirm) is reproducible too.
random.seed(SEED)
random.shuffle(data)


if test_path is None:
    # If no test set is provided, split the training data.
    # Allow roughly TEST_SET_RATIO of the examples in the test set, but always at least one.
    max_test_size = max(1, round(len(data) * TEST_SET_RATIO))
    print(f"Splitting training data into training and test sets with max test size: {max_test_size}")
    train_dataset, test_dataset = dirty_split_data_by_relation_type(
        data,
        num_unseen_rel_types=config.num_unseen_rel_types,
        max_test_size=max_test_size,
        )
else:
    train_dataset = data
    test_dataset = _read_jsonl(test_path)

print('Train dataset size:', len(train_dataset))
print('Test dataset size:', len(test_dataset))

2024-07-25 15:59:52,494 - train - INFO - Dirty splitting data...


Splitting training data into training and test sets with max test size: 1
Train dataset size: 4
Test dataset size: 1


## 🚀 Train!

In [7]:
# Get number of parameters (trainable and total)
num_params = sum(p.numel() for p in model.parameters())
num_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
logger.info(f"Number of trainable parameters: {num_trainable_params} / {num_params}")

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = model.to(DEVICE)

lr_encoder = float(config.lr_encoder)
lr_others = float(config.lr_others)

# NOTE(review): only these four submodules are handed to the optimizer. Any other trainable
# parameters the model may have are not updated; confirm this is intended.
optimizer = torch.optim.AdamW([
    # encoder
    {'params': model.token_rep_layer.parameters(), 'lr': lr_encoder},
    {'params': model.rnn.parameters(), 'lr': lr_others},
    # projection layers
    {'params': model.span_rep_layer.parameters(), 'lr': lr_others},
    {'params': model.prompt_rep_layer.parameters(), 'lr': lr_others},
])

# Mixed precision only on GPU. Compare the device *type*: a torch.device never compares equal
# to the string 'cuda', so `DEVICE == 'cuda'` would always be False and AMP would never turn on.
use_amp = DEVICE.type == "cuda"

logger.info("🚀 Relation extraction training started")
train(model,
      optimizer,
      train_dataset,
      config,
      eval_data=test_dataset,
      num_steps=config.num_steps,
      eval_every=config.eval_every,
      log_dir=config.log_dir,
      wandb_log=False,
      wandb_sweep=False,
      warmup_ratio=config.warmup_ratio,
      train_batch_size=config.train_batch_size,
      device=DEVICE,
      use_amp=use_amp,
)

2024-07-25 15:59:52,516 - __main__ - INFO - Number of trainable parameters: 466576896 / 466576896
2024-07-25 15:59:52,527 - __main__ - INFO - 🚀 Relation extraction training started
  0%|          | 0/21 [00:00<?, ?it/s]2024-07-25 15:59:56,139 - train - INFO - Step 0 | loss: 31.11713218688965 | x['rel_label']: torch.Size([4, 2]) | x['span_idx']: torch.Size([4, 2, 2]) | x['tokens']: [24, 28, 25, 30] | num candidate_classes: 4
step: 0 | epoch: 0 | loss: 31.12:   5%|▍         | 1/21 [00:05<01:58,  5.93s/it]2024-07-25 16:00:02,168 - train - INFO - Step 1 | loss: 31.591588973999023 | x['rel_label']: torch.Size([4, 2]) | x['span_idx']: torch.Size([4, 2, 2]) | x['tokens']: [24, 28, 25, 30] | num candidate_classes: 4
step: 1 | epoch: 1 | loss: 31.59:  10%|▉         | 2/21 [00:11<01:43,  5.46s/it]2024-07-25 16:00:06,928 - train - INFO - Step 2 | loss: 14.576828002929688 | x['rel_label']: torch.Size([4, 2]) | x['span_idx']: torch.Size([4, 2, 2]) | x['tokens']: [24, 28, 25, 30] | num candidate_cla

## 🥳 Load your finetuned model

In [10]:
# NOTE(review): this loads the checkpoint directory `model_{config.eval_every}`, which is
# presumably the one written at the *first* evaluation step, not necessarily the final model.
# Change CHECKPOINT_STEP to load a later checkpoint.
CHECKPOINT_STEP = config.eval_every
checkpoint_dir = os.path.join(config.log_dir, f"model_{CHECKPOINT_STEP}")
model = GLiREL.from_pretrained(checkpoint_dir)

nlp = spacy.load('en_core_web_sm')

text = 'Derren Nesbitt had a history of being cast in "Doctor Who", having played villainous warlord Tegana in the 1964 First Doctor serial "Marco Polo".'
doc = nlp(text)
tokens = [token.text for token in doc]

labels = ['country of origin', 'licensed to broadcast to', 'father', 'followed by', 'characters']

# entity spans: [start_token, end_token (inclusive), entity type, surface text]
ner = [[26, 27, 'PERSON', 'Marco Polo'], [22, 23, 'Q2989412', 'First Doctor']]

relations = model.predict_relations(tokens, labels, threshold=0.0, ner=ner, top_k=1)

print('Number of relations:', len(relations))

relations_by_score = sorted(relations, key=lambda rel: rel['score'], reverse=True)
print("\nDescending Order by Score:")
for rel in relations_by_score:
    print(f"{rel['head_text']} --> {rel['label']} --> {rel['tail_text']} | score: {rel['score']}")


config.json not found in /home/jackboylan/GLiREL/logs/finetuning/model_20


Number of relations: 2

Descending Order by Score:
['Marco', 'Polo'] --> characters --> ['First', 'Doctor'] | score: 0.7750263810157776
['First', 'Doctor'] --> characters --> ['Marco', 'Polo'] | score: 0.573470413684845
