In [1]:
# IOB2 tag set: the outside tag "O" first, then a B-/I- pair for each entity
# type. The order is significant because a label's position is its class id.
labels = ["O"] + [
    f"{prefix}-{entity}"
    for entity in ("MISC", "PER", "ORG", "LOC")
    for prefix in ("B", "I")
]

In [2]:

import torch

from torch.utils.data import TensorDataset
from transformers import BertTokenizer
from gdpr.data.default_parse import read_examples_from_file, convert_examples_to_features

# Load the tokenizer of the "bert-base-cased" checkpoint. It is reused below to
# convert examples into features and to size the LSTM vocabulary
# (len(tokenizer.vocab)).
tokenizer = BertTokenizer.from_pretrained("bert-base-cased")




In [3]:
# Maximum sequence length handed to convert_examples_to_features for every split.
MAX_SEQ_LENGTH = 128

# Label id assigned to padded positions. Reusing CrossEntropyLoss's own
# ignore_index means a default-constructed loss skips those positions.
pad_token_label_id = torch.nn.CrossEntropyLoss().ignore_index


def build_tensor_dataset(split_name):
    """Read one data split and convert it into a ``TensorDataset``.

    The train and test splits go through exactly the same pipeline, so it
    lives in one place instead of being copy-pasted per split.

    Args:
        split_name: Name of the split (e.g. ``"eng.train"``). It is forwarded
            to ``read_examples_from_file`` as ``mode=f"data/{split_name}"``.
            NOTE(review): how the reader turns ``mode`` into a file path is
            not visible from this notebook.

    Returns:
        A ``(examples, features, tensor_dataset)`` tuple. The dataset holds
        four ``torch.long`` tensors, in this order: input ids, input mask,
        segment ids, label ids.
    """
    split_examples = read_examples_from_file(".", mode=f"data/{split_name}")

    split_features = convert_examples_to_features(
        split_examples,
        label_list=labels,
        max_seq_length=MAX_SEQ_LENGTH,
        tokenizer=tokenizer,
        cls_token_at_end=False,
        cls_token=tokenizer.cls_token,
        cls_token_segment_id=0,
        sep_token=tokenizer.sep_token,
        sep_token_extra=False,
        pad_on_left=False,
        pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
        pad_token_segment_id=0,
        pad_token_label_id=pad_token_label_id)

    # One tensor per feature attribute; the order fixes the column order of
    # every batch produced from the dataset.
    field_names = ("input_ids", "input_mask", "segment_ids", "label_ids")
    tensors = [
        torch.tensor([getattr(feature, field) for feature in split_features],
                     dtype=torch.long)
        for field in field_names
    ]
    return split_examples, split_features, TensorDataset(*tensors)


dataset = "eng.train"
examples, features, tensordataset = build_tensor_dataset(dataset)

dataset_test = "eng.testa"
test_examples, test_features, test_tensordataset = build_tensor_dataset(dataset_test)

# Kept only for backward compatibility: the original cell left these names
# bound to the *test* split tensors after it finished.
all_input_ids, all_input_mask, all_segment_ids, all_label_ids = test_tensordataset.tensors

KeyboardInterrupt: 

In [None]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Loader over the eng.train tensors: one example per batch, in dataset order.
eval_dataloader = DataLoader(
    tensordataset,
    sampler=SequentialSampler(tensordataset),
    batch_size=1,
)

# Loader over the eng.testa tensors: 16 examples per batch, in dataset order.
eval_sampler = SequentialSampler(test_tensordataset)
test_eval_dataloader = DataLoader(
    test_tensordataset,
    sampler=eval_sampler,
    batch_size=16,
)

In [None]:
from gdpr.models.lstm_model.lstm import LSTM

# The model's vocabulary size is the number of entries in the tokenizer vocab.
model = LSTM(
    vocab_size=len(tokenizer.vocab),
)

# Sum the element counts of every parameter tensor the model exposes.
total_params = sum(param.numel() for param in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 22959241


In [None]:
from gdpr.train.lstm_train_steps import train

# Move the model to its target device *before* building the optimizer, as the
# torch.optim documentation recommends.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# NOTE(review): weight_decay=0.3 is unusually large for Adam -- confirm it is
# intentional.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=3e-3,
    betas=(0.9, 0.999),
    weight_decay=0.3,
)

epochs = 1

# Train on the eng.train loader and evaluate on the eng.testa loader. The
# original call passed the training loader twice, so the reported evaluation
# numbers were computed on training data and test_eval_dataloader was unused.
# NOTE(review): assumes the third positional argument of `train` is the
# evaluation loader -- confirm against gdpr.train.lstm_train_steps.
results = train(
    model,
    eval_dataloader,
    test_eval_dataloader,
    optimizer,
    torch.nn.CrossEntropyLoss(),
    epochs,
    device,
)

  0%|          | 0/1 [00:00<?, ?it/s]

0.0 batch:  0
0.021179537308569567 batch:  5
0.2607326940758026 batch:  10
0.45929035721374434 batch:  15
0.5564903116658335 batch:  20
0.6053108561257008 batch:  25
0.6345268819844937 batch:  30
0.6680564291389482 batch:  35
0.6783399075830951 batch:  40
0.6957987262111324 batch:  45
0.716565424388379 batch:  50
0.7331025144916968 batch:  55
0.7294256292142286 batch:  60
0.7275981595997439 batch:  65
0.7411922977008126 batch:  70
0.7492383349239212 batch:  75
0.7500371714829067 batch:  80
0.7586105917455285 batch:  85
0.7626290225147868 batch:  90
0.7497606127773268 batch:  95
0.7562358215041832 batch:  100
0.7620084424056549 batch:  105
0.7659595760686259 batch:  110
0.7742618110902983 batch:  115
0.7774204255126844 batch:  120
0.777864810445959 batch:  125
0.7713067625680199 batch:  130
0.7711662864185506 batch:  135
0.7706597587754178 batch:  140
0.7708237312440513 batch:  145
0.7767595504109858 batch:  150
0.7748314015897857 batch:  155
0.7738398333075844 batch:  160
0.77212790966

KeyboardInterrupt: 