In [2]:
#!g1.1
def apply_to_dict_values(dict, f):
    for key, value in dict.items():
        dict[key] = f(value)

In [3]:
#!g1.1
BERT_TYPE = 'bert-base-uncased'

In [4]:
#!g1.1
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel, BertConfig

class RCVDataset(Dataset):
    """RCV dataset."""

    def __init__(self, path):
        self.path = path
        self.data = pd.read_csv(self.path, sep='\t', header=None)
        self.tokenizer = BertTokenizer.from_pretrained(BERT_TYPE)

    def __len__(self):
        return self.data.shape[0]

    @staticmethod
    def target_to_tensor(target):
        return torch.tensor([float(label) for label in target])

    def __getitem__(self, idx):
        data = self.tokenizer(self.data.iloc[idx, 1], return_tensors="pt", max_length=512, padding="max_length", truncation=True) # max_len=512 !DocBERT
        apply_to_dict_values(data, lambda x: x.flatten())
        return data, RCVDataset.target_to_tensor(self.data.iloc[idx, 0])

In [5]:
#!g1.1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [6]:
#!g1.1
train_dataset = RCVDataset('./train_docbert.rcv')
val_dataset = RCVDataset('./valid_docbert.rcv')
test_dataset = RCVDataset('./test_docbert.rcv')

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=231508.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=28.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=466062.0), HTML(value='')))

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=570.0), HTML(value='')))







In [7]:
#!g1.1
BATCH_SIZE = 16 # !DocBERT
N_CLASSES = test_dataset[0][1].shape[0]
BATCH_SIZE, N_CLASSES

(16, 103)

In [8]:
#!g1.1
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [9]:
#!g1.1
config = BertConfig.from_pretrained(BERT_TYPE)
config.return_dict = True
bert = BertModel.from_pretrained(BERT_TYPE, config=config)

HBox(children=(HTML(value='Downloading'), FloatProgress(value=0.0, max=440473133.0), HTML(value='')))




Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [10]:
#!g1.1
import torch.nn as nn

class MultiLabelBert(nn.Module):
    def __init__(self, bert_model, n_classes, freeze_bert_weights=False):
        super(MultiLabelBert, self).__init__()
        self.bert_model = bert_model
        if freeze_bert_weights: # == false !DocBERT
            for param in self.bert_model.parameters():
                param.requires_grad = False
        self.n_bert_features = bert_model.pooler.dense.out_features
        self.n_classes = n_classes
        self.dense = nn.Linear(self.n_bert_features, self.n_classes)

    def forward(self, inputs):
        bert_output = self.bert_model(**inputs)
        return self.dense(bert_output.pooler_output)

In [11]:
#!g1.1
model = MultiLabelBert(bert, N_CLASSES).to(device)

In [12]:
#!g1.1
import torch.optim as optim

LEARNING_RATE = 2e-5

In [13]:
#!g1.1
%pip install wandb
from wandb_writer import WandbWriter

Defaulting to user installation because normal site-packages is not writeable
You should consider upgrading via the '/usr/local/bin/python3 -m pip install --upgrade pip' command.[0m


In [15]:
#!g1.1
from sklearn.metrics import classification_report

import tqdm
from sklearn.metrics import f1_score, hamming_loss

def test_model(
        model,
        test_dataloader,
        criterion,
        threshold,
        wandb_writer):

    model.eval()

    with torch.no_grad():
        total_loss = 0.
        all_targets = []
        all_preds = []
        for batch, targets in tqdm.tqdm(test_dataloader, f"Test epoch#{0}", leave=False):
            apply_to_dict_values(batch, lambda x: x.to(device))
            targets = targets[:, 1:].to(device)
            logits = model(batch)[:, 1:]
            all_targets.extend(targets.to('cpu').tolist())
            all_preds.extend((torch.sigmoid(logits) > threshold).type(torch.DoubleTensor).to('cpu').tolist())
#             loss = criterion(logits, targets)
#             total_loss += loss.item()

#         print("Validation loss", total_loss / len(val_dataloader))
        wandb_writer.add_scalar("Test F1 (micro)",(f1_score(all_targets, all_preds, average='micro')))
        wandb_writer.add_scalar("Test F1 (macro)",(f1_score(all_targets, all_preds, average='macro')))
        wandb_writer.add_scalar("Test Hamming loss",(hamming_loss(all_targets, all_preds)))
        print("Test F1 (micro)",(f1_score(all_targets, all_preds, average='micro')))
        print("Test F1 (macro)",(f1_score(all_targets, all_preds, average='macro')))
        print("Test Hamming loss",(hamming_loss(all_targets, all_preds)))

def train_model(
        model,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        optimizer,
        criterion,
        n_epochs,
        wandb_writer,
        wandb_iter_start = 0):

    wandb_iter = wandb_iter_start
    for epoch in range(1, n_epochs+1):
        model.train()
        total_loss = 0.
        all_targets = []
        all_preds = []
        for batch, targets in tqdm.tqdm(train_dataloader, f"Train epoch#{epoch}", leave=False):

            apply_to_dict_values(batch, lambda x: x.to(device))
            targets = targets.to(device)
            logits = model(batch)
            all_targets.extend(targets.to('cpu').tolist())
            all_preds.extend((torch.sigmoid(logits) > 0.5).type(torch.DoubleTensor).to('cpu').tolist())
            optimizer.zero_grad()
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

            wandb_writer.set_step(wandb_iter)
            wandb_writer.add_scalar("Batch train loss", loss.item())
            wandb_iter += 1

        wandb_writer.add_scalar("Train loss", total_loss / len(train_dataloader))
        wandb_writer.add_scalar("Train F1 (micro)",(f1_score(all_targets, all_preds, average='micro')))
        wandb_writer.add_scalar("Train F1 (macro)",(f1_score(all_targets, all_preds, average='macro')))
        wandb_writer.add_scalar("Train Hamming loss",(hamming_loss(all_targets, all_preds)))


        model.eval()

        with torch.no_grad():
            total_loss = 0.
            all_targets = []
            all_preds = []
            for batch, targets in tqdm.tqdm(val_dataloader, f"Val epoch#{epoch}", leave=False):
                apply_to_dict_values(batch, lambda x: x.to(device))
                targets = targets.to(device)
                logits = model(batch)
                all_targets.extend(targets.to('cpu').tolist())
                all_preds.extend((torch.sigmoid(logits) > 0.5).type(torch.DoubleTensor).to('cpu').tolist())
                loss = criterion(logits, targets)
                total_loss += loss.item()
            
            wandb_writer.add_scalar("Validation loss", total_loss / len(val_dataloader))
            wandb_writer.add_scalar("Validation F1 (micro)",(f1_score(all_targets, all_preds, average='micro')))
            wandb_writer.add_scalar("Validation F1 (macro)",(f1_score(all_targets, all_preds, average='macro')))
            wandb_writer.add_scalar("Validation Hamming loss",(hamming_loss(all_targets, all_preds)))
    
    test_model(model, test_dataloader, criterion, 0.5, wandb_writer)
    print(wandb_iter)

In [16]:
#!g1.1
print("Cuda memory allocated: {:.4} Gb".format(torch.cuda.max_memory_allocated('cuda') / 1024 ** 3))

Cuda memory allocated: 0.4092 Gb


In [17]:
#!g1.1
import gc

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats('cuda')
print("Cuda memory allocated: {:.4} Gb".format(torch.cuda.max_memory_allocated('cuda') / 1024 ** 3))

Cuda memory allocated: 0.4092 Gb


In [18]:
#!g1.1
model = MultiLabelBert(bert, N_CLASSES).to(device)
optimizer = optim.AdamW(params=[p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE, weight_decay=0.01)

In [19]:
#!g1.1
train_model(model, train_dataloader, val_dataloader, test_dataloader, optimizer, torch.nn.BCEWithLogitsLoss(), 10, WandbWriter("MultiLabelDocBert"), wandb_iter_start=0)



Test F1 (micro) 0.7969249240588023
Test F1 (macro) 0.49048497349046183
Test Hamming loss 0.010744050969123354
13030


In [20]:
#!g1.1
train_model(model, train_dataloader, val_dataloader, test_dataloader, optimizer, torch.nn.BCEWithLogitsLoss(), 1, WandbWriter("MultiLabelDocBert"), wandb_iter_start=18800)

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Batch train loss,█▅▄▄▃▃▂▃▃▃▃▂▂▂▂▂▂▁▂▂▁▂▂▁▂▂▂▁▁▁▁▁▂▁▂▂▁▁▁▂
Test F1 (macro),▁
Test F1 (micro),▁
Test Hamming loss,▁
Train F1 (macro),▁▂▃▅▅▆▆▇██
Train F1 (micro),▁▅▆▇▇▇████
Train Hamming loss,█▄▃▂▂▂▂▁▁▁
Train loss,█▄▃▂▂▂▁▁▁▁
Validation F1 (macro),▁▃▄▅▆▇▇▇██
Validation F1 (micro),▁▅▇▇██████

0,1
Batch train loss,0.02861
Test F1 (macro),0.49048
Test F1 (micro),0.79692
Test Hamming loss,0.01074
Train F1 (macro),0.62125
Train F1 (micro),0.93183
Train Hamming loss,0.00408
Train loss,0.01323
Validation F1 (macro),0.53689
Validation F1 (micro),0.84817




KeyboardInterrupt: 

In [None]:
#!g1.1
