In [1]:
import torch
from torch.utils.data import DataLoader
from torch import optim, nn
from transformers import BertTokenizerFast

from data_handling.trc_dataset import TRCDataset
from model.trc_model import TRCModel
from trainer.trainer import Trainer
from trainer.training_utils import get_parameters

# Choose the compute device. `torch.backends.cuda.is_built()` only reports whether this
# PyTorch build was compiled with CUDA support, not whether a usable GPU is present, so
# it could pick 'cuda' on a GPU-less machine and fail at the first `.to(device)`.
# `torch.cuda.is_available()` checks for an actually usable CUDA device.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print('device:', device)

# ---- Configuration ----
SEED = 42  # seeds torch's RNG: DataLoader shuffling and freshly initialised weights
BATCH_SIZE = 4
MODEL_CHECKPOINT = 'onlplab/alephbert-base'
TRAINING_LAYERS = 52  # forwarded to get_parameters(); presumably limits which parameters are trained — TODO confirm
LEARNING_RATE = 1e-5
MAX_EPOCHS = 10
LABELS = ['BEFORE', 'AFTER', 'EQUAL', 'VAGUE']  # output classes; len(LABELS) sets the model's output size

# Seed before building the loaders and the model so that runs are reproducible.
torch.manual_seed(SEED)

data_paths = {
    'train': 'data_handling/split_data/train.csv',
    'test': 'data_handling/split_data/test.csv'}

# Tokenizer plus four added marker tokens; the ids of the two start markers are passed
# to the Trainer as `entity_markers`. Presumably [א1]...[/א1] and [א2]...[/א2] wrap the
# two events of each example in the dataset text — confirm against TRCDataset.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_CHECKPOINT)
tokenizer.add_tokens(['[א1]', '[/א1]', '[א2]', '[/א2]'])
E1_start = tokenizer.convert_tokens_to_ids('[א1]')
E2_start = tokenizer.convert_tokens_to_ids('[א2]')

train_set = TRCDataset(data_path=data_paths['train'])
test_set = TRCDataset(data_path=data_paths['test'])

train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

print(f'train: {len(train_set)}\ntest: {len(test_set)}')

# The tokenizer is handed to the model, presumably so it can resize its embedding table
# for the added marker tokens — confirm in TRCModel.
model = TRCModel(output_size=len(LABELS), tokenizer=tokenizer, check_point=MODEL_CHECKPOINT, architecture='ESS')

trainer = Trainer(model, tokenizer=tokenizer,
                  optimizer=optim.Adam(get_parameters(model.named_parameters(), TRAINING_LAYERS), lr=LEARNING_RATE),
                  criterion=nn.CrossEntropyLoss(),
                  entity_markers=(E1_start, E2_start),
                  labels=LABELS,
                  device=device)

# NOTE(review): the test split doubles as the validation set here. Any epoch/model
# selection driven by the validation metrics therefore leaks test information, and the
# final test scores will be optimistic. Consider carving a validation split out of the
# training data instead.
trainer.train(train_loader=train_loader,
              valid_loader=test_loader,
              max_epochs=MAX_EPOCHS)


device: cpu
train: 2180
test: 756


Some weights of the model checkpoint at onlplab/alephbert-base were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at onlplab/alephbert-base and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias

TRCModel(
  (lm): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(52004, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropo


100%|██████████| 545/545 [02:14<00:00,  3.99it/s]
  0%|          | 0/545 [00:00<?, ?it/s]          A

Epoch [1/10], Step [545/5450], Train Loss: 1.0917, Valid Loss: 1.0276, precision: 0.50, recall: 0.57, F1: 0.52


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))

  0%|          | 0/545 [00:00<?, ?it/s]          A

Epoch [2/10], Step [1090/5450], Train Loss: 0.8536, Valid Loss: 1.2231, precision: 0.52, recall: 0.51, F1: 0.37


100%|██████████| 545/545 [02:16<00:00,  4.06it/s][E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument
[E thread_pool.cpp:113] Exception in thread pool task: mutex lock failed: Invalid argument


KeyboardInterrupt: 

In [None]:
# Evaluate the trained model on the test split; print_report=True asks the Trainer to
# also print its evaluation report.
# NOTE(review): in the saved outputs, training was interrupted (KeyboardInterrupt) and this
# cell never ran (In [None]). Re-run training to completion before trusting these numbers.
eval_loader = test_loader
trainer.evaluate(eval_loader, print_report=True)