# Train the model
Adapted from the original `main.py` and integrated with AWS SageMaker.

## Install dependencies

This model requires torch >= 1.9

In [None]:
!pip install -r requirements.txt

In [None]:
import os
from datetime import datetime
import json

from tqdm.auto import tqdm
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

from data_gen import MyDevDataset, NeatDataset, dev_collate_fn, neat_collate_fn
from model_origin import SubjectModel, ObjectModel
import config
from config import create_parser, predicate2id, id2predicate
from utils import para_eval

Define a TensorBoard logger.

In [None]:
from torch.utils.tensorboard import SummaryWriter

# Optional run label; when set, it is prefixed to the timestamped log folder name.
logname = None
now = datetime.now()
dt_string = now.strftime("%m_%d_%H_%M")
run_folder = dt_string if logname is None else logname + '_' + dt_string
log_dir = os.path.join('logs', run_folder)
writer = SummaryWriter(log_dir=log_dir)

print("Logs are saved at:", log_dir)
print("Run this command at the current folder to launch tensorboard:")
print("tensorboard --logdir=logs")

In [None]:
# os.environ["CUDA_VISIBLE_DEVICES"]="0,1,2,3,4,5,6,7"

Read configs from module `config.py`, and define `device`.

In [None]:
# for macOS compatibility
#os.environ['KMP_DUPLICATE_LIB_OK']='True'

BERT_MODEL_NAME = config.bert_model_name
LEARNING_RATE = config.learning_rate
WORD_EMB_SIZE = config.word_emb_size # default bert embedding size
BATCH_SIZE = config.batch_size
BERT_DICT_LEN = config.bert_dict_len
TRAIN_PATH = config.train_path
DEV_PATH = config.dev_path
NUM_CLASSES = config.num_classes

torch.backends.cudnn.benchmark = True
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Download training data
Skip the download step if you have already done it.

In [None]:
#!wget https://dataset-bj.cdn.bcebos.com/qianyan/DuIE_2_0.zip

In [None]:
#!unzip -j DuIE_2_0.zip -d data

Transform the raw data into a format that is easier to use.

In [None]:
# !mkdir generated
# !python trans.py

## Load training data

Load the training and dev (test) data and define their data loaders.

In [None]:
# adjust batch size if needed
# BATCH_SIZE = 512

In [None]:
id2predicate, predicate2id = config.id2predicate, config.predicate2id

# Read the raw JSON with an explicit UTF-8 encoding (the encoding JSON mandates) so the
# result does not depend on the platform locale, and close the files as soon as they are read.
with open(TRAIN_PATH, encoding='utf-8') as train_file:
    train_data = json.load(train_file)
with open(DEV_PATH, encoding='utf-8') as dev_file:
    dev_data = json.load(dev_file)

train_dataset = NeatDataset(train_data, BERT_MODEL_NAME)
test_dataset = MyDevDataset(dev_data, BERT_MODEL_NAME)

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=BATCH_SIZE,            # mini batch size
    shuffle=True,                     # random shuffle for training
    num_workers=2,                    # subprocesses for loading data
    collate_fn=neat_collate_fn,       # merges samples into a training batch
    multiprocessing_context='spawn',
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,            # mini batch size
    # No shuffling for evaluation: keeps dev passes deterministic.
    # NOTE(review): assumes para_eval aggregates over the whole set, so order is irrelevant — confirm.
    shuffle=False,
    num_workers=2,                    # subprocesses for loading data
    collate_fn=dev_collate_fn,        # merges samples into an evaluation batch
    multiprocessing_context='spawn',
)

### Define models
When more than one GPU is available, the models are wrapped in `nn.DataParallel`, so each batch is split across the GPUs.

In [None]:
subject_model = SubjectModel(BERT_DICT_LEN, WORD_EMB_SIZE).to(device)
object_model = ObjectModel(WORD_EMB_SIZE, NUM_CLASSES).to(device)

# Split each batch across all visible GPUs when there is more than one.
n_gpus = torch.cuda.device_count()
if n_gpus > 1:
    print('Using', n_gpus, "GPUs!")
    subject_model, object_model = nn.DataParallel(subject_model), nn.DataParallel(object_model)

### Load model if needed
Run the cell below to resume from a saved checkpoint. Skip it to train from scratch.

In [None]:
# Set to False to train from scratch (the training cell then starts at epoch 0).
RESUME_FROM_CHECKPOINT = True
model_dir = 'save'
weight_name = 'att1'


def _load_checkpoint(model, path):
    """Load the state dict stored at `path` into `model`, with or without DataParallel on either side.

    A model wrapped in `nn.DataParallel` saves every key with a ``module.`` prefix,
    and a bare model saves without it. Stripping the prefix and loading into the
    unwrapped module lets both kinds of checkpoint load into both kinds of model.
    """
    state_dict = torch.load(path, map_location=device)
    prefix = 'module.'
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }
    target = model.module if isinstance(model, nn.DataParallel) else model
    target.load_state_dict(state_dict)


if RESUME_FROM_CHECKPOINT:
    breakpoint_epoch = 195 # 210 is saved in repo
    _load_checkpoint(subject_model, f"./{model_dir}/subject_{weight_name}_{breakpoint_epoch}")
    _load_checkpoint(object_model, f"./{model_dir}/object_{weight_name}_{breakpoint_epoch}")
else:
    # Drop a value left over from an earlier run of this cell so training really starts at 0.
    globals().pop('breakpoint_epoch', None)

### Define loss metrics

**Run this after reloading the model and before training**.

In [None]:
# A single Adam optimizer over both models' parameters, so one step updates them jointly.
params = [*subject_model.parameters(), *object_model.parameters()]
print("Using learning rate", LEARNING_RATE)
optimizer = torch.optim.Adam(params, lr=LEARNING_RATE)

### Define training and evaluate scripts

In [None]:
def train(subject_model, object_model, device, train_loader, optimizer, epoch, writer=None, log_interval=10,
          subject_threshold=0.6, object_threshold=0.5):
    """Run one training epoch of the subject and object models over `train_loader`.

    Args:
        subject_model: called as ``subject_model(token_ids, attention_mask=...)``;
            must return ``(subject_preds, hidden_states)``.
        object_model: called as ``object_model(hidden_states, subject_ids, attention_masks)``.
        device: device every batch tensor is moved to.
        train_loader: iterable of ``(token_ids, attention_masks, subject_ids,
            subject_labels, object_labels)`` batches. ``len(train_loader)`` is used
            for the TensorBoard global step when a writer is given.
        optimizer: optimizer over the parameters of both models.
        epoch: current epoch index (used for logging only).
        writer: optional TensorBoard ``SummaryWriter``.
        log_interval: print and log every this many steps.
        subject_threshold: predictions/labels above this value count as positive
            subjects when computing the logged recall.
        object_threshold: same as ``subject_threshold``, for objects.
    """
    subject_model.train()
    object_model.train()
    # Wrap the loader itself (not `enumerate(...)`) so tqdm can show the total length.
    train_tqdm = tqdm(train_loader, desc="Train")
    for step, batch in enumerate(train_tqdm):
        token_ids, attention_masks, subject_ids, subject_labels, object_labels = (
            tensor.to(device) for tensor in batch)

        # predict
        subject_preds, hidden_states = subject_model(token_ids, attention_mask=attention_masks)
        object_preds = object_model(hidden_states, subject_ids, attention_masks)

        # Masked BCE losses: positions where the attention mask is 0 contribute nothing.
        # NOTE(review): the mask is unsqueezed on dim 2 to broadcast over a trailing axis;
        # assumes preds are (bsz, sent_len, 2) and (bsz, sent_len, n_classes, 2) — confirm.
        mask = attention_masks.unsqueeze(dim=2)
        mask_total = torch.sum(mask)
        subject_loss = F.binary_cross_entropy(subject_preds, subject_labels, reduction='none')
        subject_loss = torch.sum(subject_loss * mask) / mask_total  # ()
        object_loss = F.binary_cross_entropy(object_preds, object_labels, reduction='none')
        object_loss = torch.mean(object_loss, dim=2)  # average over the class axis
        object_loss = torch.sum(object_loss * mask) / mask_total  # ()
        loss_sum = subject_loss + object_loss * 10  # object loss is weighted 10x
        loss_value = loss_sum.item()
        train_tqdm.set_postfix(loss=loss_value)

        # update
        optimizer.zero_grad()
        loss_sum.backward()
        optimizer.step()

        # The recall statistics below are only reported on logging steps.
        if step % log_interval != 0:
            continue

        with torch.no_grad():
            exists_subject = subject_labels.sum().item()
            correct_subject = torch.logical_and(subject_preds > subject_threshold,
                                                subject_labels > subject_threshold).sum().item()
            exists_object = object_labels.sum().item()
            correct_object = torch.logical_and(object_preds > object_threshold,
                                               object_labels > object_threshold).sum().item()

        print(f"epoch {epoch}, step: {step}, loss: {loss_value}, subject_recall: {correct_subject}/{exists_subject}, object_recall: {correct_object}/{exists_object}")
        if writer is not None:
            global_step = step + epoch * len(train_loader)
            writer.add_scalar('train/loss', loss_value, global_step)
            writer.add_scalar('train/loss_subject', subject_loss.item(), global_step)
            writer.add_scalar('train/loss_object', object_loss.item(), global_step)
            # Recall is undefined for a batch without positive labels; skip it instead of dividing by zero.
            if exists_subject > 0:
                writer.add_scalar('train/recall_subject', correct_subject / exists_subject, global_step)
            if exists_object > 0:
                writer.add_scalar('train/recall_object', correct_object / exists_object, global_step)


def evaluate(subject_model, object_model, loader, id2predicate, epoch, writer=None):
    """Score both models on `loader` with `para_eval`, print the scores and optionally log them.

    Returns:
        Tuple ``(f1, precision, recall)`` as produced by ``para_eval``.
    """
    for model in (subject_model, object_model):
        model.eval()
    scores = para_eval(subject_model, object_model, loader, id2predicate, epoch=epoch, writer=writer)
    f1, precision, recall = scores
    print(f"Eval epoch {epoch}: f1: {f1}, precision: {precision}, recall: {recall}")
    if writer:
        for tag, value in (('eval/f1', f1), ('eval/precision', precision), ('eval/recall', recall)):
            writer.add_scalar(tag, value, epoch)
    return f1, precision, recall

## Training

In [None]:
# Best dev F1 seen so far, and the epoch that produced it.
best_f1, best_epoch = 0, 0

In [None]:
starting_epoch = 0

# Resume right after the checkpoint epoch when the load cell defined `breakpoint_epoch`;
# otherwise train from epoch 0.
if 'breakpoint_epoch' not in globals():
    print("breakpoint epoch not defined, start training from epoch 0")
else:
    print("continue training from epoch", breakpoint_epoch)
    starting_epoch = breakpoint_epoch + 1

In [None]:
epoch_num: int = 500  # exclusive upper bound of the training-epoch range

In [None]:
# The original script saved under `args.logname`; `args` does not exist in this notebook.
# The equivalent here is `logname` from the logger cell, falling back to the same timestamp
# the log folder uses. To evaluate these checkpoints in the "Find the best model saved"
# section, set `weight_name` to this name.
ckpt_name = dt_string if logname is None else logname
os.makedirs('save', exist_ok=True)  # torch.save does not create missing directories
print("Saving checkpoints to save/ with name", ckpt_name)

for e in range(starting_epoch, epoch_num):
    train(subject_model, object_model, device, train_loader, optimizer, e, writer=writer, log_interval=10)
    f1, precision, recall = evaluate(subject_model, object_model, test_loader, id2predicate, e, writer)

    # Checkpoint every 5 epochs.
    if e % 5 == 0:
        torch.save(subject_model.state_dict(), f"save/subject_{ckpt_name}_{e}")
        torch.save(object_model.state_dict(), f"save/object_{ckpt_name}_{e}")

    if f1 >= best_f1:
        best_f1 = f1
        best_epoch = e

    print('f1: %.4f, precision: %.4f, recall: %.4f, bestf1: %.4f, bestepoch: %d \n ' % (
        f1, precision, recall, best_f1, best_epoch))

## Test the trained model on some texts
Extract the plain models from the `DataParallel` wrapper, if needed.

In [None]:
# Unwrap each model from nn.DataParallel if it was wrapped. An explicit type check
# replaces the former bare `except:`, which also hid unrelated errors (e.g. a NameError
# when the model cells had not been run).
was_wrapped = isinstance(subject_model, nn.DataParallel) or isinstance(object_model, nn.DataParallel)
if isinstance(subject_model, nn.DataParallel):
    subject_model = subject_model.module
if isinstance(object_model, nn.DataParallel):
    object_model = object_model.module

if was_wrapped:
    print("extracted model from DataParalell wrapper")
else:
    print("models are not wrapped by DataParalell")

Examine the model on some test data:

In [None]:
# Small loaders for eyeballing predictions. Both load in the main process (num_workers=0).
# The cells below pull only a handful of batches, and with worker processes every new
# iterator would start (spawn) fresh workers, which dominates the run time.
examine_train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=2,
    shuffle=True,
    num_workers=0,
    collate_fn=neat_collate_fn,
)
examine_test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=dev_collate_fn,
)

In [None]:
from utils import extract_spoes

to_print = 4
# Inference only: switch off training-mode behaviour (e.g. dropout) and autograd tracking.
subject_model.eval()
object_model.eval()
with torch.no_grad():
    # Draw consecutive samples from a single iterator instead of building a new one per sample.
    for step, (texts, tokens, spoes, att_masks, offset_mappings) in zip(range(to_print), examine_test_loader):
        print('Text: ', texts)
        print('Predicted SPOs:', extract_spoes(texts, tokens, offset_mappings, subject_model, object_model, id2predicate, attention_mask=att_masks))
        print('Gold SPOs:', spoes)

In [None]:
# visualize the model
with torch.no_grad():
    token_ids, attention_masks, subject_ids, subject_labels, object_labels = next(iter(examine_train_loader))
    writer.add_graph(subject_model, (token_ids, attention_masks))
    _, hidden_states = subject_model(token_ids, attention_mask=attention_masks)
    writer.add_graph(object_model, (hidden_states, subject_ids, attention_masks))

# Find the best model saved
Iterate through the checkpoints saved every 5 epochs (0, 5, 10, …) and keep track of the one with the best dev-set F1.

In [None]:
# Fresh, unwrapped (single-device) models that each saved checkpoint is loaded into.
subject_model, object_model = (
    SubjectModel(BERT_DICT_LEN, WORD_EMB_SIZE).to(device),
    ObjectModel(WORD_EMB_SIZE, NUM_CLASSES).to(device),
)

In [None]:
# Checkpoint naming used below: ./{model_dir}/{subject|object}_{weight_name}_{epoch}
model_dir, weight_name = 'save', 'att1'
breakpoint_epoch = 195 # 210 is saved in repo

In [None]:
def _strip_module_prefix(state_dict, prefix='module.'):
    """Return `state_dict` without the key prefix `nn.DataParallel` adds, so it loads into a bare model."""
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


best_epoch = 0
best_f1 = 0
for e in range(0, epoch_num, 5):
    subject_path = f"./{model_dir}/subject_{weight_name}_{e}"
    object_path = f"./{model_dir}/object_{weight_name}_{e}"
    # Only epochs that were actually trained have checkpoints; skip the rest instead of crashing.
    if not (os.path.exists(subject_path) and os.path.exists(object_path)):
        print(f"Epoch {e}: checkpoint not found, skipping")
        continue
    subject_model.load_state_dict(_strip_module_prefix(torch.load(subject_path, map_location=device)))
    object_model.load_state_dict(_strip_module_prefix(torch.load(object_path, map_location=device)))
    f1, precision, recall = evaluate(subject_model, object_model, test_loader, id2predicate, e, writer)
    if f1 > best_f1:
        best_f1 = f1
        best_epoch = e
    print('Epoch %d: f1: %.4f, precision: %.4f, recall: %.4f, bestf1: %.4f, bestepoch: %d \n ' % (
        e, f1, precision, recall, best_f1, best_epoch))

In [None]:
# Write out any pending TensorBoard events, then close the writer opened in the logger cell.
writer.flush()
writer.close()