In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
import pandas as pd
import random
import numpy as np

from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# sentence dataset struct
class SentenceDataset(Dataset):
    def __init__(self, tokenizer, file_path, max_len, lan_codes=None):
        self.tokenizer = tokenizer
        self.sentences = pd.read_csv(file_path)

        if lan_codes is not None:
            # filter languages
            self.sentences = self.sentences[self.sentences['lan_code'].isin(lan_codes)]
        self.max_len = max_len
        print("Done loading dataset")

    def __len__(self):
        return len(self.sentences)

    def __getitem__(self, item):
        sentence = str(self.sentences.iloc[item]['sentence']).lower()
        encoding = self.tokenizer(sentence, max_length=self.max_len, padding='max_length', truncation=True)
        return {key: torch.tensor(val) for key, val in encoding.items()}

In [3]:
# training arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def reinitialize_weights(model):
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

def train_model(lan_code, args): 
    lr = args['lr']
    batch_size = args['batch_size']
    log_freq = args['log_freq']
    epochs = args['epochs']
    
    lan_codes = [lan_code]
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    dataset = SentenceDataset(tokenizer, 'data/big-language-detection/sentences.csv', max_len=512)
    
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    reinitialize_weights(model)    # randomize weights so we start from scratch
    model.to(device)

    
    optimizer = AdamW(model.parameters(), lr=lr)

    # train
    for epoch in (range(epochs)):  # Number of training epochs
        model.train()
        for i, batch in tqdm(enumerate(dataloader), desc=f"Epoch {epoch}", total=len(dataloader)):
            inputs = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**inputs, labels=inputs['input_ids'])
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if (i + 1) % log_freq == 0:
                print(f"Loss: {loss}")
    
    os.makedirs("checkpoints", exist_ok=True)
    model.save_pretrained(f"checkpoints/{lan_code}.pt")
    
    print("Training complete.")


In [None]:
args = {
    'batch_size': 8,
    'lr': 1e-5,
    'log_freq': 1,
    'epochs': 3
}

train_model("ang", args)

Done loading dataset


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epoch 0:   0%|                                                                 | 1/1292727 [00:19<6869:09:35, 19.13s/it]