In [7]:
labels = ["O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC"]

In [8]:

import torch

from torch.utils.data import TensorDataset
from transformers import BertTokenizer
from gdpr.data.default_parse import read_examples_from_file, convert_examples_to_features

tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

dataset = "eng.train"
examples = read_examples_from_file(".", mode=f"data/{dataset}")
pad_token_label_id = torch.nn.CrossEntropyLoss().ignore_index

features = convert_examples_to_features(
    examples,
    label_list=labels,
    max_seq_length=128,
    tokenizer=tokenizer,
    cls_token_at_end=False,
    cls_token=tokenizer.cls_token,
    cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id)

all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

tensordataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)


dataset_test = "eng.testa"
test_examples = read_examples_from_file(".", mode=f"data/{dataset_test}")

test_features = convert_examples_to_features(
    test_examples,
    label_list=labels,
    max_seq_length=128,
    tokenizer=tokenizer,
    cls_token_at_end=False,
    cls_token=tokenizer.cls_token,
    cls_token_segment_id=0,
    sep_token=tokenizer.sep_token,
    sep_token_extra=False,
    pad_on_left=False,
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id)


all_input_ids = torch.tensor([f.input_ids for f in test_features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in test_features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in test_features], dtype=torch.long)
all_label_ids = torch.tensor([f.label_ids for f in test_features], dtype=torch.long)

test_tensordataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)


In [9]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

eval_sampler = SequentialSampler(tensordataset)
eval_dataloader = DataLoader(tensordataset, sampler=eval_sampler, batch_size=1)

eval_sampler = SequentialSampler(test_tensordataset)
test_eval_dataloader = DataLoader(test_tensordataset, sampler=eval_sampler, batch_size=1)

In [1]:
from gdpr.models.lstm_model.lstm import LSTM

# model = LSTM(vocab_size=len(tokenizer.vocab), patch_size=128)
model = LSTM()

total_params = sum(p.numel() for p in model.parameters())
print(f'Total parameters: {total_params}')

Total parameters: 293513


In [9]:
import torch
tensor = torch.rand(3, 3, 128)

In [10]:
model(tensor).shape

torch.Size([3, 3, 3, 9])

In [5]:
from gdpr.train.bert_train_steps import train

optimizer = torch.optim.Adam(params=model.parameters(),
                             lr=3e-3,
                             betas=(0.9, 0.999),
                             weight_decay=0.3)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

epochs = 1

# print(train_step(model, eval_dataloader, torch.nn.CrossEntropyLoss(), optimizer, device))
results = train(model, eval_dataloader, test_eval_dataloader, optimizer, torch.nn.CrossEntropyLoss(), epochs, device)

  0%|          | 0/1 [00:00<?, ?it/s]

0.1111111111111111 batch:  0
0.4411643314869121 batch:  5
0.6273372466041087 batch:  10
0.7113309870769549 batch:  15
0.7485212677520892 batch:  20
0.7604127821953687 batch:  25
0.7646123683655054 batch:  30
0.7800744868559306 batch:  35
0.776697226554104 batch:  40
0.78346503225051 batch:  45
0.7956369945415431 batch:  50
0.8051141230240427 batch:  55
0.7955346468832674 batch:  60
0.7874362920918353 batch:  65
0.7948044047658956 batch:  70
0.7993233296820909 batch:  75
0.7970304998979795 batch:  80
0.8028717499039109 batch:  85
0.8044582489062251 batch:  90
0.7894112336275443 batch:  95
0.7939235403321128 batch:  100
0.7979184386473614 batch:  105
0.8002520049120573 batch:  110
0.8070761180008232 batch:  115
0.8088787693277332 batch:  120
0.80807480728422 batch:  125
0.8003637060918435 batch:  130
0.7991549599598806 batch:  135
0.7976559261628 batch:  140
0.7968953723510437 batch:  145
0.8019678921435744 batch:  150
0.7992317836514581 batch:  155
0.7974824395288321 batch:  160
0.79505

KeyboardInterrupt: 