In [1]:
labels = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

In [2]:
dataset = "eng.train"

import torch

from torch.utils.data import TensorDataset
from transformers import BertTokenizer
from gdpr.data.default_parse import read_examples_from_file, convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
examples = read_examples_from_file(".", mode=f"data/{dataset}")
pad_token_label_id = torch.nn.CrossEntropyLoss().ignore_index

features = convert_examples_to_features(
    examples,
    label_list=labels,
    max_seq_length=128,
    tokenizer=tokenizer,
    cls_token_at_end=False,
    cls_token=tokenizer.cls_token,
    cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id)

all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

tensordataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)


dataset_test = "eng.testa"
test_examples = read_examples_from_file(".", mode=f"data/{dataset_test}")

test_features = convert_examples_to_features(
    test_examples,
    label_list=labels,
    max_seq_length=128,
    tokenizer=tokenizer,
    cls_token_at_end=False,
    cls_token=tokenizer.cls_token,
    cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id)


all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_ids for f in test_features], dtype=torch.long)

test_tensordataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)


In [3]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

eval_sampler = SequentialSampler(tensordataset)
eval_dataloader = DataLoader(tensordataset, sampler=eval_sampler, batch_size=1)

eval_sampler = SequentialSampler(test_tensordataset)
test_eval_dataloader = DataLoader(test_tensordataset, sampler=eval_sampler, batch_size=1)

In [4]:
from gdpr.models.bert_model.bert import BERT

model = BERT(vocab_size=len(tokenizer.vocab), patch_size=128)

total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params}')

Total parameters: 3427689


In [5]:
from gdpr.train.bert_train_steps import train

optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=3e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.3)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

epochs = 1

# print(train_step(model, eval_dataloader, torch.nn.CrossEntropyLoss(), optimizer, device))
results = train(model, eval_dataloader, test_eval_dataloader, optimizer, torch.nn.CrossEntropyLoss(), epochs, device)

  0%|          | 0/1 [00:00<?, ?it/s]

0.0 batch:  0
0.44403714565004887 batch:  5
0.6279302101736118 batch:  10
0.7117386495309882 batch:  15
0.7488318677170669 batch:  20
0.7606636513978509 batch:  25
0.7648227747933937 batch:  30
0.7802556701688343 batch:  35
0.7762755942481286 batch:  40
0.7830892295430102 batch:  45
0.7952980352367394 batch:  50
0.8048054279428821 batch:  55
0.7952512546776118 batch:  60
0.780115140342531 batch:  65
0.7866853909108527 batch:  70
0.7917384614754059 batch:  75
0.7899138334324479 batch:  80
0.796168843116608 batch:  85
0.7981236337006422 batch:  90
0.7834065462972521 batch:  95
0.7882161147508451 batch:  100
0.7924802312538893 batch:  105
0.7950587618155883 batch:  110
0.802106721589547 batch:  115
0.8041147198756006 batch:  120
0.8034998074135212 batch:  125
0.7959633245368203 batch:  130
0.7943972284343847 batch:  135
0.7930669085211869 batch:  140
0.792463512847842 batch:  145
0.7976827829550482 batch:  150
0.7950840177061539 batch:  155
0.793463486190525 batch:  160
0.7911604886623154

KeyboardInterrupt: 