In [7]:
import torch
from tqdm.auto import tqdm
from transformers import AdamW
from transformers import BertTokenizer, BertForMaskedLM
from datasets import load_dataset
import time

# Load the pretrained BERT tokenizer and masked-LM head.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased', return_dict=True)

dataset = load_dataset('gayanin/pubmed-gastro-maskfilling')

# Raw masked sentences (the slice is a cap; it is a no-op for smaller splits).
data_masked_highlights = dataset['train']['masked_highlights'][:150000]

# BERT's vocabulary has no '<mask>' entry: convert_tokens_to_ids('<mask>')
# falls back to the [UNK] id, and the tokenizer splits the literal text
# '<mask>' into '<', 'mask', '>'. Swapping ids after tokenization therefore
# turned every [UNK] into [MASK] and never touched the real placeholders.
# Substituting the tokenizer's own mask token in the text *before*
# tokenization yields exactly one [MASK] id per placeholder.
bert_masked_highlights = [
    text.replace('<mask>', tokenizer.mask_token)
    for text in data_masked_highlights
]

inputs = tokenizer(
    bert_masked_highlights,
    max_length=512,
    truncation=True,
    padding='max_length',
    return_tensors='pt'
)

# Build labels with vectorized boolean indexing instead of a Python loop over
# every (row, column) element of the padded id tensor.
mask_token_id = tokenizer.mask_token_id
inputs['labels'] = inputs['input_ids'].clone()
# -100 is the index the masked-LM loss ignores.
inputs['labels'][inputs['input_ids'] == mask_token_id] = -100
# Padding carries no signal; without this, [PAD] targets dominate the loss
# because every sequence is padded to 512 tokens.
inputs['labels'][inputs['attention_mask'] == 0] = -100

# NOTE(review): labels are ignored exactly at the [MASK] positions, so the
# loss only rewards reproducing visible tokens and never trains the model to
# fill a mask. A mask-filling objective needs the *unmasked* token ids at
# those positions (and -100 elsewhere); that requires the ground-truth text
# column, which is not visible in this notebook -- confirm and rebuild labels.

class MaskedTextDataset(torch.utils.data.Dataset):
    """Map-style dataset exposing pre-tokenized encodings one example at a time.

    `encodings` must support `encodings[key][index]` for the keys
    'input_ids', 'labels', 'attention_mask' and 'token_type_ids'.
    """

    def __init__(self, encodings):
        """Keep a reference to the encodings; nothing is copied."""
        self.encodings = encodings

    def __len__(self):
        """Number of examples, taken from the length of 'input_ids'."""
        return len(self.encodings['input_ids'])

    def __getitem__(self, index):
        """Return a dict holding the `index`-th entry of each encoding field."""
        field_names = ('input_ids', 'labels', 'attention_mask', 'token_type_ids')
        return {name: self.encodings[name][index] for name in field_names}

# Wrap the tokenized encodings and serve them in shuffled mini-batches of 16.
# Note: this rebinds `dataset`, which previously held the loaded Hugging Face dataset.
dataset = MaskedTextDataset(inputs)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


masked_pubmed_highlights_dataset.csv:   0%|          | 0.00/2.50M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/11772 [00:00<?, ? examples/s]

BertForMaskedLM(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [8]:
epochs = 2
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. eps and
# weight_decay are pinned to the transformers defaults (1e-6 and 0.0) so the
# optimizer hyperparameters stay what they were.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)
model.train()

for epoch in range(epochs):
    # One progress bar per epoch; the label and step count are fixed for the
    # whole epoch, so they are set once here rather than on every step.
    loop = tqdm(dataloader, desc=f"Epoch {epoch + 1}", dynamic_ncols=True)
    total_steps = len(loop)
    start_time = time.time()

    for step, batch in enumerate(loop):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        # Passing labels makes the model compute the masked-LM loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # Linear extrapolation of the remaining time from the mean step time so far.
        elapsed_time = time.time() - start_time
        steps_done = step + 1
        remaining_time = (elapsed_time / steps_done) * (total_steps - steps_done)

        loop.set_postfix(loss=loss.item(), elapsed=f"{elapsed_time:.2f}s", remaining=f"{remaining_time:.2f}s")



  0%|          | 0/736 [00:00<?, ?it/s]

  0%|          | 0/736 [00:00<?, ?it/s]