In [1]:
import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertForMaskedLM, AdamW
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
import pandas as pd
import random
import numpy as np

from tqdm import tqdm
from torch import nn



[2024-04-20 23:47:37,632] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [2]:
    # sentence dataset struct
    class SentenceDataset(Dataset):
        def __init__(self, tokenizer, file_path, max_len, lan_codes=None, split=None):
            self.tokenizer = tokenizer
            self.sentences = pd.read_csv(file_path)
            train_ratio = 0.999
    
            if lan_codes is not None:
                # filter languages
                self.sentences = self.sentences[self.sentences['lan_code'].isin(lan_codes)]
            if split is not None:
                length = len(self.sentences)
                if split == 'train':
                    self.sentences = self.sentences.iloc[:int(length * train_ratio)]
                elif split == 'test':
                    self.sentences = self.sentences.iloc[int(length * train_ratio):]
                else:
                    raise ValueError(f"split {split} not recognized")
            self.max_len = max_len
            
            print(f"Done loading dataset for split {split}")
    
        def __len__(self):
            return len(self.sentences)
    
        def __getitem__(self, item):
            sentence = str(self.sentences.iloc[item]['sentence']).lower()
            encoding = self.tokenizer(sentence, max_length=self.max_len, padding='max_length', truncation=True)
            return {key: torch.tensor(val) for key, val in encoding.items()}

In [3]:
# training arguments
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def convert_weights_to_fp16(model: nn.Module, dtype=torch.bfloat16):
    """Convert applicable model parameters to fp16"""

    def _convert_weights_to_fp16(l):
        if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
            l.weight.data = l.weight.data.to(dtype)
            if l.bias is not None:
                l.bias.data = l.bias.data.to(dtype)

    model.apply(_convert_weights_to_fp16)

def reinitialize_weights(model):
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, torch.nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=model.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, torch.nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

In [4]:
args = {
    'batch_size': 60,
    'lr': 2e-5,
    'log_freq': 1,
    'eval_freq': 100,
    'epochs': 3,
    'mlm_prob': 0.4,
    'grad_accums': 20
}
lan_code = 'eng'

In [5]:


lan_codes = [lan_code]
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

train_dataset = SentenceDataset(tokenizer, 'data/big-language-detection/sentences.csv', max_len=512, split='train')
test_dataset = SentenceDataset(tokenizer, 'data/big-language-detection/sentences.csv', max_len=512, split='test')

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob)
train_dataloader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset,
                             batch_size=batch_size*3,
                             shuffle=True,
                             collate_fn=data_collator)

Done loading dataset for split train
Done loading dataset for split test


In [6]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
reinitialize_weights(model)    # randomize weights so we start from scratch
convert_weights_to_fp16(model)
model.to(device)
print(f"training with {sum([p.numel() for p in model.parameters() if p.requires_grad]) / 1e6:.2f}M parameters")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


training with 109.51M parameters


In [7]:
import wandb

wandb.init(project='ling')

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpranavputta-tonic[0m ([33mcvmlp-embodiment-transfer[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [None]:

optimizer = AdamW(model.parameters(), lr=lr)

def train_step(model, batch, optimizer, step_optim):
    inputs = {k: v.to(device) for k, v in batch.items()}
    with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
        outputs = model(**inputs)
    loss = outputs.loss
    loss.backward()
    if step_optim:
        optimizer.step()
        optimizer.zero_grad()
    return loss

@torch.no_grad()
def eval_step(model):
    # step has different meanings here lol but whatever
    total_loss = 0
    for batch in tqdm(test_dataloader, desc='Eval', total=len(test_dataloader)):
        inputs = {k: v.to(device) for k, v in batch.items()}
        with torch.amp.autocast(device_type='cuda', dtype=torch.bfloat16):
            outputs = model(**inputs)
        total_loss += outputs.loss
    return total_loss / len(test_dataloader)

def train(args):
    # train
    lr = args['lr']
    batch_size = args['batch_size']
    log_freq = args['log_freq']
    eval_freq = args['eval_freq']
    ckpt_freq = args['ckpt_freq']
    epochs = args['epochs']
    mlm_prob = args['mlm_prob']
    grad_accums = args['grad_accums']
    
    step = 0
    for epoch in (range(epochs)):  # Number of training epochs
        model.train()
        train_loss = 0.0
        for i, batch in tqdm(enumerate(train_dataloader), desc=f"Epoch {epoch}", total=len(train_dataloader)):
            if (i + 1) % grad_accums != 0:
                train_loss += train_step(model, batch, optimizer, step_optim=False)
                continue
       
            train_loss += train_step(model, batch, optimizer, step_optim=True)
            train_loss /= grad_accums
            step += 1                

            if (step) % eval_freq == 0:
                test_loss = eval_step(model)
                print("Eval Loss:" ,test_loss)
                wandb.log({"test.loss": test_loss})
    
            if (step) % log_freq == 0:
                print(f"Loss: {train_loss}")
                wandb.log({"train.loss": train_loss})

            if (step) % ckpt_freq == 0:
                os.makedirs("checkpoints", exist_ok=True)
                torch.save(f"checkpoints/{lan_code}_{step}.pt")

            train_loss = 0.

Epoch 0:   0%|                                                                                                                                | 20/172192 [00:20<52:25:22,  1.10s/it]

Loss: 10.50524616241455


Epoch 0:   0%|                                                                                                                                | 40/172192 [00:40<52:09:07,  1.09s/it]

Loss: 10.364455223083496


Epoch 0:   0%|                                                                                                                                | 60/172192 [00:59<52:10:57,  1.09s/it]

Loss: 10.260120391845703


Epoch 0:   0%|                                                                                                                                | 80/172192 [01:19<52:10:58,  1.09s/it]

Loss: 10.168787956237793


Epoch 0:   0%|                                                                                                                               | 100/172192 [01:39<52:09:46,  1.09s/it]

Loss: 10.092512130737305


Epoch 0:   0%|                                                                                                                               | 120/172192 [01:59<52:10:54,  1.09s/it]

Loss: 10.0335111618042


Epoch 0:   0%|                                                                                                                               | 140/172192 [02:18<52:11:06,  1.09s/it]

Loss: 9.950963020324707


Epoch 0:   0%|                                                                                                                               | 160/172192 [02:38<52:08:45,  1.09s/it]

Loss: 9.927413940429688


Epoch 0:   0%|▏                                                                                                                              | 180/172192 [02:58<52:08:44,  1.09s/it]

Loss: 9.921772956848145


Epoch 0:   0%|▏                                                                                                                              | 200/172192 [03:17<52:08:34,  1.09s/it]

Loss: 9.854147911071777


Epoch 0:   0%|▏                                                                                                                              | 220/172192 [03:37<52:08:44,  1.09s/it]

Loss: 9.843486785888672


Epoch 0:   0%|▏                                                                                                                              | 240/172192 [03:57<52:07:54,  1.09s/it]

Loss: 9.855245590209961


Epoch 0:   0%|▏                                                                                                                              | 260/172192 [04:17<52:06:31,  1.09s/it]

Loss: 9.800031661987305


Epoch 0:   0%|▏                                                                                                                              | 280/172192 [04:36<52:07:13,  1.09s/it]

Loss: 9.783282279968262


Epoch 0:   0%|▏                                                                                                                              | 300/172192 [04:56<52:07:12,  1.09s/it]

Loss: 9.743365287780762


Epoch 0:   0%|▏                                                                                                                              | 320/172192 [05:16<52:05:35,  1.09s/it]

Loss: 9.729315757751465


Epoch 0:   0%|▎                                                                                                                              | 340/172192 [05:35<52:05:29,  1.09s/it]

Loss: 9.751376152038574


Epoch 0:   0%|▎                                                                                                                              | 360/172192 [05:55<52:05:28,  1.09s/it]

Loss: 9.713427543640137


Epoch 0:   0%|▎                                                                                                                              | 380/172192 [06:15<52:05:18,  1.09s/it]

Loss: 9.688971519470215


Epoch 0:   0%|▎                                                                                                                              | 400/172192 [06:35<52:05:10,  1.09s/it]

Loss: 9.690367698669434


Epoch 0:   0%|▎                                                                                                                              | 420/172192 [06:54<52:03:54,  1.09s/it]

Loss: 9.675924301147461


Epoch 0:   0%|▎                                                                                                                              | 440/172192 [07:14<52:04:06,  1.09s/it]

Loss: 9.66103744506836


Epoch 0:   0%|▎                                                                                                                              | 460/172192 [07:34<52:03:02,  1.09s/it]

Loss: 9.62876033782959


Epoch 0:   0%|▎                                                                                                                              | 480/172192 [07:53<52:03:46,  1.09s/it]

Loss: 9.657923698425293


Epoch 0:   0%|▎                                                                                                                              | 500/172192 [08:13<52:02:37,  1.09s/it]

Loss: 9.61341381072998


Epoch 0:   0%|▍                                                                                                                              | 520/172192 [08:33<52:02:45,  1.09s/it]

Loss: 9.577372550964355


Epoch 0:   0%|▍                                                                                                                              | 540/172192 [08:52<52:02:42,  1.09s/it]

Loss: 9.593214988708496


Epoch 0:   0%|▍                                                                                                                              | 560/172192 [09:12<52:01:50,  1.09s/it]

Loss: 9.580409049987793


Epoch 0:   0%|▍                                                                                                                              | 580/172192 [09:32<52:01:44,  1.09s/it]

Loss: 9.558121681213379


Epoch 0:   0%|▍                                                                                                                              | 600/172192 [09:52<52:01:12,  1.09s/it]

Loss: 9.528478622436523


Epoch 0:   0%|▍                                                                                                                              | 620/172192 [10:11<52:00:42,  1.09s/it]

Loss: 9.548881530761719


Epoch 0:   0%|▍                                                                                                                              | 640/172192 [10:31<52:00:59,  1.09s/it]

Loss: 9.497949600219727


Epoch 0:   0%|▍                                                                                                                              | 660/172192 [10:51<52:00:20,  1.09s/it]

Loss: 9.531257629394531


Epoch 0:   0%|▌                                                                                                                              | 680/172192 [11:11<51:59:19,  1.09s/it]

Loss: 9.4994535446167


Epoch 0:   0%|▌                                                                                                                              | 700/172192 [11:30<51:59:30,  1.09s/it]

Loss: 9.500032424926758


Epoch 0:   0%|▌                                                                                                                              | 720/172192 [11:50<52:01:05,  1.09s/it]

Loss: 9.496455192565918


Epoch 0:   0%|▌                                                                                                                              | 740/172192 [12:10<51:59:34,  1.09s/it]

Loss: 9.475293159484863


Epoch 0:   0%|▌                                                                                                                              | 760/172192 [12:29<51:58:33,  1.09s/it]

Loss: 9.450366020202637


Epoch 0:   0%|▌                                                                                                                              | 780/172192 [12:49<51:58:09,  1.09s/it]

Loss: 9.425127983093262


Epoch 0:   0%|▌                                                                                                                              | 800/172192 [13:09<51:57:31,  1.09s/it]

Loss: 9.473881721496582


Epoch 0:   0%|▌                                                                                                                              | 820/172192 [13:29<51:57:40,  1.09s/it]

Loss: 9.439116477966309


Epoch 0:   0%|▌                                                                                                                              | 840/172192 [13:48<51:56:22,  1.09s/it]

Loss: 9.417978286743164


Epoch 0:   0%|▋                                                                                                                              | 860/172192 [14:08<51:55:56,  1.09s/it]

Loss: 9.425447463989258


Epoch 0:   1%|▋                                                                                                                              | 880/172192 [14:28<51:57:02,  1.09s/it]

Loss: 9.400135040283203


Epoch 0:   1%|▋                                                                                                                              | 900/172192 [14:47<51:56:17,  1.09s/it]

Loss: 9.444106101989746


Epoch 0:   1%|▋                                                                                                                              | 920/172192 [15:07<51:55:04,  1.09s/it]

Loss: 9.419078826904297


Epoch 0:   1%|▋                                                                                                                              | 940/172192 [15:27<51:54:03,  1.09s/it]

Loss: 9.40247631072998


Epoch 0:   1%|▋                                                                                                                              | 960/172192 [15:47<51:54:58,  1.09s/it]

Loss: 9.39191722869873


Epoch 0:   1%|▋                                                                                                                              | 980/172192 [16:06<51:54:13,  1.09s/it]

Loss: 9.34675407409668


Epoch 0:   1%|▋                                                                                                                             | 1000/172192 [16:26<51:54:05,  1.09s/it]

Loss: 9.349030494689941


Epoch 0:   1%|▋                                                                                                                             | 1020/172192 [16:46<51:53:02,  1.09s/it]

Loss: 9.337716102600098


Epoch 0:   1%|▊                                                                                                                             | 1040/172192 [17:05<51:53:52,  1.09s/it]

Loss: 9.360568046569824


Epoch 0:   1%|▊                                                                                                                             | 1060/172192 [17:25<51:53:09,  1.09s/it]

Loss: 9.355856895446777


Epoch 0:   1%|▊                                                                                                                             | 1080/172192 [17:45<51:52:40,  1.09s/it]

Loss: 9.340324401855469


Epoch 0:   1%|▊                                                                                                                             | 1100/172192 [18:04<51:52:31,  1.09s/it]

Loss: 9.340080261230469


Epoch 0:   1%|▊                                                                                                                             | 1120/172192 [18:24<51:51:11,  1.09s/it]

Loss: 9.335403442382812


Epoch 0:   1%|▊                                                                                                                             | 1140/172192 [18:44<51:51:44,  1.09s/it]

Loss: 9.349233627319336


Epoch 0:   1%|▊                                                                                                                             | 1160/172192 [19:04<51:50:30,  1.09s/it]

Loss: 9.31374454498291


Epoch 0:   1%|▊                                                                                                                             | 1180/172192 [19:23<51:49:53,  1.09s/it]

Loss: 9.282546043395996


Epoch 0:   1%|▉                                                                                                                             | 1200/172192 [19:43<51:50:05,  1.09s/it]

Loss: 9.292162895202637


Epoch 0:   1%|▉                                                                                                                             | 1220/172192 [20:03<51:51:46,  1.09s/it]

Loss: 9.26181697845459


Epoch 0:   1%|▉                                                                                                                             | 1240/172192 [20:22<51:48:28,  1.09s/it]

Loss: 9.307523727416992


Epoch 0:   1%|▉                                                                                                                             | 1260/172192 [20:42<51:48:43,  1.09s/it]

Loss: 9.221877098083496


Epoch 0:   1%|▉                                                                                                                             | 1280/172192 [21:02<51:48:34,  1.09s/it]

Loss: 9.25483512878418


Epoch 0:   1%|▉                                                                                                                             | 1300/172192 [21:22<51:47:52,  1.09s/it]

Loss: 9.259468078613281


Epoch 0:   1%|▉                                                                                                                             | 1320/172192 [21:41<51:48:55,  1.09s/it]

Loss: 9.290961265563965


Epoch 0:   1%|▉                                                                                                                             | 1340/172192 [22:01<51:47:40,  1.09s/it]

Loss: 9.19813346862793


Epoch 0:   1%|▉                                                                                                                             | 1360/172192 [22:21<51:47:08,  1.09s/it]

Loss: 9.215005874633789


Epoch 0:   1%|█                                                                                                                             | 1380/172192 [22:40<51:47:57,  1.09s/it]

Loss: 9.262953758239746


Epoch 0:   1%|█                                                                                                                             | 1400/172192 [23:00<51:45:19,  1.09s/it]

Loss: 9.236639022827148


Epoch 0:   1%|█                                                                                                                             | 1420/172192 [23:20<51:46:21,  1.09s/it]

Loss: 9.203132629394531


Epoch 0:   1%|█                                                                                                                             | 1440/172192 [23:40<51:45:33,  1.09s/it]

Loss: 9.230382919311523


Epoch 0:   1%|█                                                                                                                             | 1460/172192 [23:59<51:45:42,  1.09s/it]

Loss: 9.196354866027832


Epoch 0:   1%|█                                                                                                                             | 1480/172192 [24:19<51:44:40,  1.09s/it]

Loss: 9.189078330993652


Epoch 0:   1%|█                                                                                                                             | 1500/172192 [24:39<51:44:36,  1.09s/it]

Loss: 9.222686767578125


Epoch 0:   1%|█                                                                                                                             | 1520/172192 [24:58<51:44:22,  1.09s/it]

Loss: 9.176434516906738


Epoch 0:   1%|█▏                                                                                                                            | 1540/172192 [25:18<51:43:23,  1.09s/it]

Loss: 9.117759704589844


Epoch 0:   1%|█▏                                                                                                                            | 1560/172192 [25:38<51:42:41,  1.09s/it]

Loss: 9.174029350280762


Epoch 0:   1%|█▏                                                                                                                            | 1580/172192 [25:57<51:42:26,  1.09s/it]

Loss: 9.162071228027344


Epoch 0:   1%|█▏                                                                                                                            | 1600/172192 [26:17<51:42:46,  1.09s/it]

Loss: 9.174112319946289


Epoch 0:   1%|█▏                                                                                                                            | 1620/172192 [26:37<51:42:45,  1.09s/it]

Loss: 9.113444328308105


Epoch 0:   1%|█▏                                                                                                                            | 1640/172192 [26:57<51:41:02,  1.09s/it]

Loss: 9.101839065551758


Epoch 0:   1%|█▏                                                                                                                            | 1660/172192 [27:16<51:41:34,  1.09s/it]

Loss: 9.12109375


Epoch 0:   1%|█▏                                                                                                                            | 1680/172192 [27:36<51:41:12,  1.09s/it]

Loss: 9.14246940612793


Epoch 0:   1%|█▏                                                                                                                            | 1700/172192 [27:56<51:41:08,  1.09s/it]

Loss: 9.093134880065918


Epoch 0:   1%|█▎                                                                                                                            | 1720/172192 [28:15<51:40:56,  1.09s/it]

Loss: 9.120455741882324


Epoch 0:   1%|█▎                                                                                                                            | 1740/172192 [28:35<51:39:58,  1.09s/it]

Loss: 9.045050621032715


Epoch 0:   1%|█▎                                                                                                                            | 1760/172192 [28:55<51:39:42,  1.09s/it]

Loss: 9.037403106689453


Epoch 0:   1%|█▎                                                                                                                            | 1780/172192 [29:15<51:39:51,  1.09s/it]

Loss: 9.080626487731934


Epoch 0:   1%|█▎                                                                                                                            | 1800/172192 [29:34<51:38:55,  1.09s/it]

Loss: 9.034016609191895


Epoch 0:   1%|█▎                                                                                                                            | 1820/172192 [29:54<51:38:43,  1.09s/it]

Loss: 9.031122207641602


Epoch 0:   1%|█▎                                                                                                                            | 1840/172192 [30:14<51:38:27,  1.09s/it]

Loss: 9.029157638549805


Epoch 0:   1%|█▎                                                                                                                            | 1860/172192 [30:33<51:38:44,  1.09s/it]

Loss: 9.035408973693848


Epoch 0:   1%|█▍                                                                                                                            | 1880/172192 [30:53<51:37:41,  1.09s/it]

Loss: 9.004018783569336


Epoch 0:   1%|█▍                                                                                                                            | 1900/172192 [31:13<51:37:26,  1.09s/it]

Loss: 9.05422592163086


Epoch 0:   1%|█▍                                                                                                                            | 1920/172192 [31:32<51:36:37,  1.09s/it]

Loss: 9.057308197021484


Epoch 0:   1%|█▍                                                                                                                            | 1940/172192 [31:52<51:36:47,  1.09s/it]

Loss: 8.98967456817627


Epoch 0:   1%|█▍                                                                                                                            | 1960/172192 [32:12<51:35:59,  1.09s/it]

Loss: 9.042770385742188


Epoch 0:   1%|█▍                                                                                                                            | 1980/172192 [32:32<51:35:41,  1.09s/it]

Loss: 8.972175598144531


Epoch 0:   1%|█▍                                                                                                                            | 1999/172192 [32:50<46:24:35,  1.02it/s]
Eval:   0%|                                                                                                                                                   | 0/58 [00:00<?, ?it/s][A
Eval:   2%|██▍                                                                                                                                        | 1/58 [00:01<01:53,  2.00s/it][A
Eval:   3%|████▊                                                                                                                                      | 2/58 [00:02<00:52,  1.07it/s][A
Eval:   5%|███████▏                                                                                                                                   | 3/58 [00:03<00:51,  1.07it/s][A
Eval:   7%|█████████▌                                                         

Eval Loss: 

Epoch 0:   1%|█▍                                                                                                                           | 2000/172192 [33:47<838:49:16, 17.74s/it]

tensor(8.9732, device='cuda:0')
Loss: 9.03100299835205


Epoch 0:   1%|█▍                                                                                                                            | 2020/172192 [34:07<52:13:19,  1.10s/it]

Loss: 8.988593101501465


Epoch 0:   1%|█▍                                                                                                                            | 2040/172192 [34:26<51:35:25,  1.09s/it]

Loss: 8.961772918701172


Epoch 0:   1%|█▌                                                                                                                            | 2060/172192 [34:46<51:34:37,  1.09s/it]

Loss: 8.942352294921875


Epoch 0:   1%|█▌                                                                                                                            | 2080/172192 [35:06<51:35:01,  1.09s/it]

Loss: 9.022245407104492


Epoch 0:   1%|█▌                                                                                                                            | 2100/172192 [35:25<51:34:50,  1.09s/it]

Loss: 8.965218544006348


Epoch 0:   1%|█▌                                                                                                                            | 2111/172192 [35:36<46:22:28,  1.02it/s]

In [None]:
torch.cuda.is_available()

In [None]:
dataset.sentences.iloc[:100]