# Train model

## 1. Setup

In [None]:
import numpy as np
from tqdm import tqdm

import itertools
from sklearn.metrics import f1_score
from seqeval.metrics import f1_score as ner_f1_score
from seqeval.scheme import IOB2

import torch
from datasets import load_dataset
from datasets import Dataset as HFDataset
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import get_scheduler, pipeline

In [None]:
# Run configuration shared by the cells below.
DEVICE = 'cuda:1'  # torch device string the model and training batches are moved to
MODEL_NAME = 'microsoft/deberta-v3-base'  # checkpoint name passed to from_pretrained
MAX_LENGTH = 512  # max_length given to the tokenizer; also the cap used when padding batches
BATCH_SIZE = 16  # batch size of both the training and validation loaders
LR = 2e-5  # learning rate handed to the optimizer
NUM_EPOCHS = 10  # number of passes over train_loader
WARMUP_RATIO = 0.01  # fraction of the total training steps used as warmup steps

In [None]:
# Tag set of the token classifier; a tag's position in the list is its integer id.
LABELS = ['O', 'B-Entity', 'I-Entity']
NUM_LABELS = len(LABELS)
# Lookups in both directions between tag names and ids.
ID2LABEL = dict(enumerate(LABELS))
LABEL2ID = {name: idx for idx, name in ID2LABEL.items()}

## 2. Data

In [None]:
def find_mention_span(text, mention):
    """Return the character spans of every non-overlapping occurrence of ``mention``.

    Args:
        text: String to search in.
        mention: Substring to look for (matched exactly, case-sensitively).

    Returns:
        A list of ``(start, end)`` tuples in order of appearance, such that
        ``text[start:end] == mention``. The search resumes right after each
        match, so the returned spans never overlap. An empty ``mention``
        yields an empty list.
    """
    # An empty needle matches at offset 0 without advancing the search
    # position, which would loop forever; treat it as "no mentions".
    if not mention:
        return []

    spans = []
    start = text.find(mention)
    while start != -1:
        end = start + len(mention)
        spans.append((start, end))
        # Resume after the match via find()'s start offset rather than
        # re-slicing the text on every hit.
        start = text.find(mention, end)

    return spans


def pad_sequences(seqs, pad_val, maxlen):
    """Right-pad (and, when necessary, truncate) sequences into one 2-D tensor.

    Args:
        seqs: Non-empty iterable of sequences (lists or tuples) of numbers.
        pad_val: Value appended to sequences shorter than the target length.
        maxlen: Upper bound on the target length. A falsy value (``None`` or
            ``0``) means no bound: pad to the longest sequence.

    Returns:
        A ``torch.Tensor`` with one row per sequence and
        ``min(maxlen, longest)`` columns (``longest`` when unbounded).

    Raises:
        ValueError: If ``seqs`` is empty.
    """
    # Normalise to lists so tuple inputs can be concatenated with the padding.
    seqs = [list(seq) for seq in seqs]
    longest = max(len(seq) for seq in seqs)
    target_len = min(maxlen, longest) if maxlen else longest

    # Truncation keeps the rows rectangular when a sequence exceeds ``maxlen``;
    # without it torch.tensor would reject the ragged nested list.
    padded_seqs = [
        seq[:target_len] + [pad_val] * (target_len - len(seq))
        for seq in seqs
    ]
    return torch.tensor(padded_seqs)


class Dataset(torch.utils.data.Dataset):
    """Training dataset pairing each text with one randomly drawn entity type.

    Every example is tokenized as the pair ``(text, entity_type)`` and tagged
    per token with 0 (O), 1 (B-Entity) or 2 (I-Entity) according to the
    mentions listed for the drawn entity.
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for the record at ``idx``."""
        record = self.data[idx]
        text = record['text']

        # The entity is re-drawn on every access, so the same index can yield
        # different targets across reads.
        sampled = np.random.choice(record['entities'])
        entity_type = sampled['entity_type']
        mentions = sampled['entity_mentions']

        encoded = self.tokenizer(text, entity_type, truncation='only_first', max_length=self.max_length)
        label = [0] * len(encoded.input_ids)
        for mention in mentions:
            for char_start, char_end in find_mention_span(text, mention):
                first_tok = encoded.char_to_token(char_start)
                last_tok = encoded.char_to_token(char_end - 1)
                # Skip spans whose first or last character maps to no token.
                if first_tok is None or last_tok is None:
                    continue

                label[first_tok] = 1  # B-Entity
                for pos in range(first_tok + 1, last_tok + 1):
                    label[pos] = 2  # I-Entity

        return encoded['input_ids'], encoded['attention_mask'], label

    def collate_fn(self, batch):
        """Pad a list of ``__getitem__`` outputs into three batch tensors."""
        ids_batch, mask_batch, label_batch = zip(*batch)
        cap = self.max_length
        input_ids = pad_sequences(ids_batch, self.tokenizer.pad_token_id, cap)
        attention_mask = pad_sequences(mask_batch, 0, cap)
        # Padded label positions get -100, the value predict() filters out.
        labels = pad_sequences(label_batch, -100, cap)
        return input_ids, attention_mask, labels

    def get_dataloader(self, batch_size, shuffle):
        """Build a DataLoader over this dataset that batches with ``collate_fn``."""
        loader_cls = torch.utils.data.DataLoader
        return loader_cls(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

In [None]:
# Tokenizer for the checkpoint named by MODEL_NAME.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [None]:
# Load the dataset and keep its 'train' and 'validation' splits.
data = load_dataset('yongsun-yoon/open-ner-english')
train_data = data['train']
valid_data = data['validation']

In [None]:
# Wrap each split in the Dataset class defined above; only the training loader shuffles.
train_dataset = Dataset(train_data, tokenizer, MAX_LENGTH)
train_loader = train_dataset.get_dataloader(BATCH_SIZE, shuffle=True)

valid_dataset = Dataset(valid_data, tokenizer, MAX_LENGTH)
valid_loader = valid_dataset.get_dataloader(BATCH_SIZE, shuffle=False)

In [None]:
# Sanity check: pull one training batch and display the three tensor shapes.
input_ids, attention_mask, labels = next(iter(train_loader))
input_ids.shape, attention_mask.shape, labels.shape

## 3. Train

In [None]:
def predict(model, loader):
    """Run ``model`` over ``loader`` and collect per-example predictions.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` whose
            output exposes ``.logits``.
        loader: Iterable of ``(input_ids, attention_mask, labels)`` tensor
            batches.

    Returns:
        ``(total_preds, total_labels)``: two parallel lists holding one list
        of ints per example, restricted to positions whose label is not -100.
    """
    # Remember the caller's mode and restore it afterwards: leaving the model
    # in eval mode would silently disable train-only layers such as dropout
    # for any training that continues after this call.
    was_training = model.training
    model.eval()
    device = next(model.parameters()).device

    total_preds, total_labels = [], []
    try:
        for input_ids, attention_mask, labels in tqdm(loader):
            input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
            with torch.no_grad():
                outputs = model(input_ids, attention_mask)
            # Labels stay where the loader put them, so bring predictions to CPU.
            preds = outputs.logits.argmax(dim=-1).cpu()

            for pred, label in zip(preds, labels):
                keep = label != -100  # drop positions marked as padding
                total_preds.append(pred[keep].tolist())
                total_labels.append(label[keep].tolist())
    finally:
        model.train(was_training)

    return total_preds, total_labels


def token_f1_func(total_preds, total_labels):
    """Return the macro-averaged token-level F1 over all examples pooled together.

    Both arguments are lists of per-example id lists, as returned by ``predict``.
    """
    flat_pred = [tag for seq in total_preds for tag in seq]
    flat_true = [tag for seq in total_labels for tag in seq]
    return f1_score(flat_true, flat_pred, average='macro')


def entity_f1_func(total_preds, total_labels, LABELS):
    """Return the macro-averaged entity-level F1 (seqeval, strict mode, IOB2 scheme).

    ``total_preds`` and ``total_labels`` are lists of per-example id lists;
    ``LABELS`` maps each id to its tag string by position.
    """
    def to_tags(id_seqs):
        # Translate every integer id into its tag name.
        return [[LABELS[tag_id] for tag_id in ids] for ids in id_seqs]

    pred_tags = to_tags(total_preds)
    true_tags = to_tags(total_labels)
    return ner_f1_score(true_tags, pred_tags, average="macro", mode="strict", scheme=IOB2)

In [None]:
# Load the checkpoint with a token-classification head sized to the tag set,
# then put it in training mode on the target device.
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME, num_labels=NUM_LABELS, label2id=LABEL2ID, id2label=ID2LABEL)
_ = model.train().to(DEVICE)

# The scheduler is stepped once per batch, so the schedule length is epochs x batches.
num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = int(num_training_steps * WARMUP_RATIO)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# get_scheduler's positional order is (name, optimizer, num_warmup_steps,
# num_training_steps). Passing the counts by keyword keeps them from being
# swapped, which would stretch the warmup over the entire run.
scheduler = get_scheduler(
    'cosine',
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
# Best validation entity F1 seen so far. Starting below any achievable score
# guarantees the first epoch writes a checkpoint, which later cells load
# from 'ckpt'.
best_score = float('-inf')
for ep in range(NUM_EPOCHS):
    # Validation runs the model in eval mode; make sure every epoch trains in
    # training mode regardless of the state the previous validation pass left.
    model.train()

    pbar = tqdm(train_loader)
    for batch in pbar:
        optimizer.zero_grad()

        input_ids, attention_mask, labels = [b.to(DEVICE) for b in batch]
        # Passing labels makes the model return the loss alongside the logits.
        outputs = model(input_ids, attention_mask, labels=labels)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        scheduler.step()  # stepped per batch, matching num_training_steps

        log = {'loss': loss.item()}
        pbar.set_postfix(log)

    # Score the epoch on the validation loader.
    total_preds, total_labels = predict(model, valid_loader)
    entity_f1 = entity_f1_func(total_preds, total_labels, LABELS)
    token_f1 = token_f1_func(total_preds, total_labels)
    print(f'ep {ep:02d} | entity_f1 {entity_f1:.3f} | token_f1 {token_f1:.3f}')

    # Keep only the checkpoint with the best entity-level F1.
    if entity_f1 > best_score:
        tokenizer.save_pretrained('ckpt')
        model.save_pretrained('ckpt')
        best_score = entity_f1

## 4. Evaluate

In [None]:
def flatten_data(data):
    """Expand each record into one row per entity.

    Every entry of a record's ``'entities'`` becomes its own row carrying the
    record's ``'text'`` plus that entity's ``'entity_type'`` and
    ``'entity_mentions'``. The rows are returned as an ``HFDataset``.
    """
    rows = [
        {
            'text': record['text'],
            'entity_type': entity['entity_type'],
            'entity_mentions': entity['entity_mentions'],
        }
        for record in data
        for entity in record['entities']
    ]
    return HFDataset.from_list(rows)


class Dataset(torch.utils.data.Dataset):
    """Evaluation dataset over flattened rows.

    Each row already carries a single ``entity_type`` with its mentions, so
    nothing is sampled at read time. The row is tokenized as the pair
    ``(text, entity_type)`` and tagged per token with 0 (O), 1 (B-Entity) or
    2 (I-Entity).
    """

    def __init__(self, data, tokenizer, max_length):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for the row at ``idx``."""
        row = self.data[idx]
        text = row['text']
        entity_type = row['entity_type']
        mentions = row['entity_mentions']

        enc = self.tokenizer(text, entity_type, truncation='only_first', max_length=self.max_length)
        label = [0] * len(enc.input_ids)
        for mention in mentions:
            for begin, end in find_mention_span(text, mention):
                tok_begin = enc.char_to_token(begin)
                tok_end = enc.char_to_token(end - 1)
                # Skip spans whose first or last character maps to no token.
                if tok_begin is None or tok_end is None:
                    continue

                label[tok_begin] = 1  # B-Entity
                for pos in range(tok_begin + 1, tok_end + 1):
                    label[pos] = 2  # I-Entity

        return enc['input_ids'], enc['attention_mask'], label

    def collate_fn(self, batch):
        """Pad a list of ``__getitem__`` outputs into three batch tensors."""
        ids_batch, mask_batch, label_batch = zip(*batch)
        limit = self.max_length
        input_ids = pad_sequences(ids_batch, self.tokenizer.pad_token_id, limit)
        attention_mask = pad_sequences(mask_batch, 0, limit)
        # Padded label positions get -100, the value predict() filters out.
        labels = pad_sequences(label_batch, -100, limit)
        return input_ids, attention_mask, labels

    def get_dataloader(self, batch_size, shuffle):
        """Build a DataLoader over this dataset that batches with ``collate_fn``."""
        make_loader = torch.utils.data.DataLoader
        return make_loader(self, batch_size=batch_size, shuffle=shuffle, collate_fn=self.collate_fn)

In [None]:
# Reload the tokenizer and model from 'ckpt', the directory the training loop saves to.
tokenizer = AutoTokenizer.from_pretrained('ckpt')
model = AutoModelForTokenClassification.from_pretrained('ckpt')
# Inference only: eval mode, gradients disabled, moved to DEVICE.
_ = model.eval().requires_grad_(False).to(DEVICE)

In [None]:
# Reload the validation split and expand it to one row per entity, so every
# entity of every text is scored.
data = load_dataset('yongsun-yoon/open-ner-english')
valid_data = data['validation']
valid_data = flatten_data(valid_data)

In [None]:
# Uses the Dataset class redefined in this section, which reads flattened rows.
valid_dataset = Dataset(valid_data, tokenizer, MAX_LENGTH)
valid_loader = valid_dataset.get_dataloader(BATCH_SIZE, shuffle=False)

In [None]:
# NOTE(review): the next line looks like the output recorded from an earlier run - confirm.
# entity_f1 0.560 | token_f1 0.747
total_preds, total_labels = predict(model, valid_loader)
entity_f1 = entity_f1_func(total_preds, total_labels, LABELS)
token_f1 = token_f1_func(total_preds, total_labels)
print(f'entity_f1 {entity_f1:.3f} | token_f1 {token_f1:.3f}')

## 5. Test

In [None]:
def run(nlp, text, entity_type):
    """Query ``nlp`` for mentions of ``entity_type`` in ``text``.

    The query is ``text`` and ``entity_type`` joined by the tokenizer's
    separator token; whatever ``nlp`` returns for it is passed back unchanged.
    """
    separator = nlp.tokenizer.sep_token
    query = f'{text}{separator}{entity_type}'
    return nlp(query)

In [None]:
# Token-classification pipeline loaded from the local 'ckpt' directory.
nlp = pipeline('token-classification', 'ckpt', aggregation_strategy='simple')

In [None]:
# Example 1: a recipe instruction.
text = 'Heat the olive oil in a frying pan, add the onion and cook for 5 minutes until softened and starting to turn golden. Set aside.'

In [None]:
run(nlp, text, 'ingredient')

In [None]:
run(nlp, text, 'tool')

In [None]:
# Example 2: a restaurant/travel blurb.
text = 'Introducing the best 4 Korean BBQ restaurants in Jamsil, a hot place where Lotte Tower, Seokchon Lake and Sonridan-gil are located in.'

In [None]:
run(nlp, text, 'food')

In [None]:
run(nlp, text, 'place')

In [None]:
# Example 3: a sports report.
text = 'Anthony Edwards was the top scorer for the third game with a personal-best 21 points, and Team USA improved to 4-0 in exhibition play with a 108-86 win over Team Greece.'

In [None]:
run(nlp, text, 'person')

In [None]:
run(nlp, text, 'team')

In [None]:
run(nlp, text, 'score')

In [None]:
# Example 4: a news sentence.
text = """The depth and frequency of craters across the frontline city of Orikhiv are a blunt example of why Ukraine needs F-16 fighter jets urgently."""

In [None]:
run(nlp, text, 'weapon')

In [None]:
run(nlp, text, 'country')

## 6. Push to Hugging Face Hub

In [None]:
# Reload the saved checkpoint from 'ckpt' before uploading it.
tokenizer = AutoTokenizer.from_pretrained('ckpt')
model = AutoModelForTokenClassification.from_pretrained('ckpt')

In [None]:
# Upload the tokenizer and the model to the same Hub repository.
tokenizer.push_to_hub('yongsun-yoon/deberta-v3-base-open-ner')
model.push_to_hub('yongsun-yoon/deberta-v3-base-open-ner')

In [None]:
# Smoke test: build a pipeline from the Hub repository id and query it with
# the same "<text><sep_token><entity_type>" input format that run() builds.
nlp = pipeline('token-classification', 'yongsun-yoon/deberta-v3-base-open-ner', aggregation_strategy='simple')
text = 'Heat the olive oil in a frying pan, add the onion and cook for 5 minutes until softened and starting to turn golden. Set aside.'
entity_type = 'ingredient'
input_text = f'{text}{nlp.tokenizer.sep_token}{entity_type}'
nlp(input_text)

In [None]:
# Same query as the previous cell, reusing the `nlp` pipeline already in scope.
text = 'Heat the olive oil in a frying pan, add the onion and cook for 5 minutes until softened and starting to turn golden. Set aside.'
entity_type = 'ingredient'
input_text = f'{text}{nlp.tokenizer.sep_token}{entity_type}'
nlp(input_text)