In [16]:
import pytorch_lightning as pl
import torch
from datasets import DatasetDict, load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split
from transformers import BertForSequenceClassification

In [2]:
# Load the "imdb" dataset through `datasets.load_dataset`; the log output below
# shows it being reused from the local Hugging Face cache.
imdb = load_dataset("imdb")

Reusing dataset imdb (/Users/pucktada/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


  0%|          | 0/3 [00:00<?, ?it/s]

In [13]:
# Hold out 5,000 of the 25,000 training reviews for validation.
# `Dataset.train_test_split` is used instead of `torch.utils.data.random_split`
# so that both parts remain `datasets.Dataset` objects (they are stored in a
# `DatasetDict` and need `.map`), and the fixed seed makes the split repeatable.
SPLIT_SEED = 42
_train_val = imdb["train"].train_test_split(test_size=5000, seed=SPLIT_SEED)
imdb_train, imdb_val = _train_val["train"], _train_val["test"]

In [2]:
# Upper bound on the encoded sequence length; passed as `max_length` to the
# tokenizer calls in this notebook.
MAX_LEN = 256
# Number of target classes; passed as `num_labels` when the BERT classifier is built.
NUM_LABELS = 2

In [3]:
from transformers import BertTokenizer

# Tokenizer for the "bert-base-uncased" checkpoint, with lowercasing enabled.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)

# Encode one sentence. The tokenizer is called directly (rather than through the
# deprecated `encode_plus`) so this cell matches `preprocess_function` below.
text = "Hello NLP lovers!"
inputs = tokenizer(text, add_special_tokens=True, max_length=MAX_LEN, truncation=True)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

In [4]:
# Show the token ids produced for the sample sentence.
input_ids

[101, 7592, 17953, 2361, 10205, 999, 102]

In [5]:
# Show the segment ids produced for the sample sentence (all 0 for a single sentence, see output).
token_type_ids

[0, 0, 0, 0, 0, 0, 0]

In [15]:
# Encode a sentence pair: the second positional argument is the paired text.
# The tokenizer is called directly (rather than through the deprecated
# `encode_plus`) so this cell matches `preprocess_function` below.
sent1 = "It is an excellent day for a picnic!"
sent2 = "In a day like this, I want to go for a picnic!"
inputs = tokenizer(sent1, sent2, add_special_tokens=True, max_length=MAX_LEN, truncation=True)
input_ids, token_type_ids = inputs["input_ids"], inputs["token_type_ids"]

In [None]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` entry of one example or of a batch of examples.

    Args:
        examples: Mapping with a ``"text"`` entry (a string, or a list of
            strings when used with ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer's encoding of the text, truncated to ``MAX_LEN`` tokens
        so it agrees with the other tokenizer calls in this notebook. No
        padding is applied here.
    """
    return tokenizer(examples["text"], truncation=True, max_length=MAX_LEN)

In [5]:
# Display the loaded DatasetDict to inspect its splits, features and row counts.
imdb

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})

In [21]:
# Bundle the train/validation parts and the official test split under one mapping.
imdb_datadict = DatasetDict(
    dict(train=imdb_train, val=imdb_val, test=imdb["test"])
)

In [22]:
class MyDataModule(pl.LightningDataModule):
    """Serves tokenized IMDB batches for training and testing.

    Each split is tokenized with ``preprocess_function`` the first time its
    dataloader is requested and the result is reused afterwards. Batches are
    dicts of tensors with the tokenizer's fields plus ``"label"``.

    Args:
        batch_size: Number of examples per batch.
    """

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        # Tokenized splits keyed by name, filled lazily by `_loader`.
        self._encoded = {}

    @staticmethod
    def _collate(features):
        """Pad a list of encoded examples to a common length and return tensors."""
        # `tokenizer.pad` pads the tokenizer fields to the longest example in the
        # batch and converts every remaining entry (here: "label") to a tensor.
        return dict(tokenizer.pad(features, padding=True, return_tensors="pt"))

    def _loader(self, name, dataset, shuffle):
        """Build a DataLoader over ``dataset``, tokenizing it once per ``name``."""
        if name not in self._encoded:
            # The raw "text" column is dropped because strings cannot be
            # collated into tensors.
            self._encoded[name] = dataset.map(
                preprocess_function, batched=True, remove_columns=["text"]
            )
        return DataLoader(
            self._encoded[name],
            shuffle=shuffle,  # True -> random sampling, False -> sequential order
            batch_size=self.batch_size,
            collate_fn=self._collate,
        )

    def train_dataloader(self):
        """Return a shuffled loader over the IMDB training split."""
        return self._loader("train", imdb["train"], shuffle=True)

    # Validation loading is currently disabled.
    #def val_dataloader(self):
    #    return self._loader("val", imdb_datadict["val"], shuffle=False)

    def test_dataloader(self):
        """Return a sequential loader over the IMDB test split."""
        return self._loader("test", imdb["test"], shuffle=False)

    #def predict_dataloader(self):
    #    pass


In [None]:
class Model(pl.LightningModule):
    """Fine-tunes ``bert-base-uncased`` as a ``NUM_LABELS``-way sequence classifier.

    Batches are expected to be mappings with ``"label"``, ``"input_ids"``,
    ``"attention_mask"`` and ``"token_type_ids"`` entries.
    """

    def __init__(self):
        super().__init__()
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=NUM_LABELS
        )

    def configure_optimizers(self):
        """Return AdamW (lr 2e-5) with weight decay disabled for biases and LayerNorm."""
        # "gamma"/"beta" are kept for checkpoints that use those parameter
        # names; "LayerNorm.weight" covers the naming used by current
        # Hugging Face BERT weights.
        no_decay = ("bias", "LayerNorm.weight", "gamma", "beta")
        decayed, undecayed = [], []
        for name, param in self.model.named_parameters():
            target = undecayed if any(tag in name for tag in no_decay) else decayed
            target.append(param)
        # `weight_decay` is the per-group key torch.optim.AdamW reads; any other
        # key would leave the optimizer's default decay in effect for the group.
        param_groups = [
            {"params": decayed, "weight_decay": 0.01},
            {"params": undecayed, "weight_decay": 0.0},
        ]
        return AdamW(param_groups, lr=2e-5)

    def _common_step(self, batch, batch_idx):
        """Run the classifier on one batch and return ``(loss, logits)``."""
        outputs = self.model(
            batch["input_ids"],
            token_type_ids=batch["token_type_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["label"],
        )
        # Positional indexing works whether the model returns a plain tuple or
        # a ModelOutput; with labels supplied, loss comes first, then logits.
        return outputs[0], outputs[1]

    def _eval_step(self, batch, batch_idx, prefix):
        """Compute loss and accuracy counts for one evaluation batch and log them.

        Returns a dict with ``"<prefix>_loss"``, ``"correct_count"`` and
        ``"batch_size"``.
        """
        labels = batch["label"]
        loss, logits = self._common_step(batch, batch_idx)
        labels_hat = torch.argmax(logits, dim=1)
        correct_count = torch.sum(labels == labels_hat)
        batch_size = len(labels)
        self.log_dict(
            {
                f"{prefix}_loss": loss,
                f"{prefix}_acc": correct_count.float() / batch_size,
            },
            prog_bar=True,
        )
        return {
            f"{prefix}_loss": loss,
            "correct_count": correct_count,
            "batch_size": batch_size,
        }

    def training_step(self, batch, batch_idx):
        """Return the training loss for one batch and log it as ``train_loss``."""
        loss, _ = self._common_step(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True)
        return {"loss": loss}

    # Validation is currently disabled, like the data module's val_dataloader.
    #def validation_step(self, batch, batch_idx):
    #    return self._eval_step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        """Evaluate one test batch; logs ``test_loss`` and ``test_acc``."""
        return self._eval_step(batch, batch_idx, "test")

    def test_end(self, outputs):
        """Aggregate a list of ``test_step`` outputs into overall loss and accuracy.

        NOTE(review): recent Lightning releases do not appear to call a
        ``test_end`` hook, which is why ``test_step`` also logs its metrics
        itself; this method remains available for manual aggregation.

        Args:
            outputs: List of dicts as returned by ``test_step``.

        Returns:
            Dict with ``"progress_bar"`` and ``"log"`` entries, each holding
            ``test_loss`` (mean over batches) and ``test_acc`` (correct / total).
            Both are empty dicts when ``outputs`` is empty.
        """
        if not outputs:
            return {"progress_bar": {}, "log": {}}
        total = sum(out["batch_size"] for out in outputs)
        test_acc = sum(out["correct_count"] for out in outputs).float() / total
        test_loss = sum(out["test_loss"] for out in outputs) / len(outputs)
        metrics = {"test_loss": test_loss, "test_acc": test_acc}
        return {"progress_bar": metrics, "log": metrics}
    
