In [1]:
from nlp_datamodule import NLPDataModule
from datasets import load_dataset
from torch.utils.data import DataLoader
from pl_bolts.models import LogisticRegression
from pytorch_lightning import Trainer
import torch

In [6]:
class TextClassificationDataModuleHF(NLPDataModule):
    """IMDB text-classification data module backed by Hugging Face ``datasets``.

    The text-processing steps (``normalization``, ``tokenization``,
    ``cleaning``, ``numericalization``, ``build_vocab``) and the ``vocab`` /
    ``word2index`` attributes are not defined here; they are expected to come
    from ``NLPDataModule`` (not visible in this notebook).

    Args:
        max_len: Target sequence length forwarded to ``numericalization``.
        batch_size: Batch size used by every dataloader.
    """

    def __init__(self, max_len: int = 500, batch_size: int = 32):
        super().__init__()
        self.max_len = max_len
        self.batch_size = batch_size

    # NOTE: this can be then done automatically
    def pipeline(self, data, stage=None):
        """Normalize, tokenize and clean the ``"text"`` field of one example.

        Written to be used with ``Dataset.map``: it receives an example dict
        and returns it with ``data["text"]`` replaced.

        For ``stage == "test"`` the text is also numericalized right away,
        because the vocabulary is expected to exist already (it is built in
        the fit branch of ``setup``). For any other stage numericalization is
        deferred to a second ``map`` pass, after the vocabulary has been built
        from the cleaned text.
        """
        data["text"] = self.normalization(data["text"])
        data["text"] = self.tokenization(data["text"])
        data["text"] = self.cleaning(data["text"])
        if stage == "test":
            # Call the parent implementation directly: the override below
            # expects (and returns) a whole example dict, not the token list.
            data["text"] = super().numericalization(
                data["text"], max_len=self.max_len, pad=self.word2index["<pad>"]
            )
        return data

    # NOTE: Here I am forced to rewrite because I need to output a dict
    # but there may be a better solution
    def numericalization(self, data, max_len, pad):
        """Dict-in / dict-out wrapper around the parent ``numericalization``.

        ``Dataset.map`` requires the mapped function to return a dict, so the
        parent result is written back into ``data["text"]``.
        """
        data["text"] = super().numericalization(data["text"], max_len, pad)
        return data

    def setup(self, stage=None):
        """Load and preprocess the IMDB splits needed for ``stage``.

        * ``stage in ("fit", None)``: loads the IMDB ``train`` split, records
          ``num_classes``, cleans the text, builds the vocabulary if
          ``self.vocab`` is ``None``, numericalizes, and splits 80/20 into
          ``self.train_ds`` / ``self.val_ds``.
        * ``stage == "test"``: loads the IMDB ``test`` split into
          ``self.test_ds``. This branch reads ``self.word2index`` (through
          ``pipeline``), so it presumably requires the vocabulary to have
          been built beforehand -- verify against ``NLPDataModule``.
        """
        if stage == 'fit' or stage is None:
            ds = load_dataset("imdb", split="train")
            self.num_classes = ds.features["label"].num_classes
            ds = ds.map(self.pipeline, fn_kwargs={"stage": stage})

            # only after the text is clean I want to build vocab
            if self.vocab is None:
                self.build_vocab(ds["text"])
            ds = ds.map(
                self.numericalization,
                fn_kwargs={"max_len": self.max_len, "pad": self.word2index["<pad>"]},
            )
            ds = ds.train_test_split(test_size=.2)

            self.train_ds = ds["train"]
            self.val_ds = ds["test"]
            self.train_ds.set_format(type='torch', columns=['text', 'label'])
            self.val_ds.set_format(type='torch', columns=['text', 'label'])

        if stage == 'test':
            self.test_ds = load_dataset("imdb", split="test")
            self.test_ds = self.test_ds.map(self.pipeline, fn_kwargs={"stage": stage})
            self.test_ds.set_format(type='torch', columns=['text', 'label'])

    def train_dataloader(self):
        """Return the dataloader over ``self.train_ds``."""
        return DataLoader(self.train_ds, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def val_dataloader(self):
        """Return the dataloader over ``self.val_ds``.

        ``val_dataloader`` is the hook name the Lightning ``Trainer`` looks
        up; the split itself is stored as ``self.val_ds`` by ``setup``.
        """
        return DataLoader(self.val_ds, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def validation_dataloader(self):
        """Backward-compatible alias for :meth:`val_dataloader`."""
        return self.val_dataloader()

    def test_dataloader(self):
        """Return the dataloader over ``self.test_ds``."""
        return DataLoader(self.test_ds, batch_size=self.batch_size, collate_fn=self.collate_fn)

    @staticmethod
    def collate_fn(batches):
        """Stack a list of example dicts into an ``(x, y)`` pair of tensors.

        ``x`` stacks the ``"text"`` tensors and casts them to float; ``y``
        stacks the ``"label"`` tensors unchanged. ``torch.stack`` requires all
        ``"text"`` tensors to have the same shape.
        """
        x = torch.stack([sample["text"] for sample in batches]).float()
        y = torch.stack([sample["label"] for sample in batches])
        return x, y

In [7]:
# Build the IMDB data module with its defaults (max_len=500, batch_size=32).
dm = TextClassificationDataModuleHF()
# prepare_data() is not overridden by the class above, so this runs whatever
# the base class provides.
dm.prepare_data()
# No stage is given, so setup() takes its 'fit' branch only: it loads and
# preprocesses the train split, builds the vocabulary if needed, and creates
# the train/validation datasets. The test split is not loaded here.
dm.setup()

Reusing dataset imdb (/Users/49796/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)
Loading cached processed dataset at /Users/49796/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-1bbed78b80c4854d.arrow
Building vocab: 100%|██████████| 25000/25000 [00:00<00:00, 43807.71it/s]
Loading cached processed dataset at /Users/49796/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3/cache-0a76965b8264aa9a.arrow


In [9]:
# The model input width is the sequence length: each example is fed as a
# vector of max_len values (presumably the padded/truncated token ids cast to
# float by collate_fn -- confirm against NLPDataModule.numericalization).
model = LogisticRegression(input_dim=dm.max_len, num_classes=dm.num_classes)
# fast_dev_run=True is a smoke test: per the log below, it runs the loops on a
# single batch, so the reported loss is not a meaningful training result.
trainer = Trainer(fast_dev_run=True)
trainer.fit(model, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using a single batch

  | Name   | Type   | Params
----------------------------------
0 | linear | Linear | 1 K   
Epoch 0: 100%|██████████| 1/1 [00:00<00:00, 12.27it/s, loss=2210.080, v_num=6, train_ce_loss=2.21e+3]


1

In [11]:
# The test split is only prepared when setup() is called with stage="test";
# the earlier dm.setup() call did not load it.
dm.setup(stage="test")
# The trainer was created with fast_dev_run=True, so the metrics below come
# from a single test batch and should be read as a smoke test only.
trainer.test(model, datamodule=dm)

Reusing dataset imdb (/Users/49796/.cache/huggingface/datasets/imdb/plain_text/1.0.0/90099cb476936b753383ba2ae6ab2eae419b2e87f71cd5189cb9c8e5814d12a3)
Testing: 0it [00:00, ?it/s]
--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_acc': tensor(0.1562),
 'test_ce_loss': tensor(6827.8779),
 'test_loss': tensor(6827.8779)}
--------------------------------------------------------------------------------
Testing: 100%|██████████| 1/1 [00:00<00:00, 13.65it/s]


[{'test_ce_loss': 6827.8779296875,
  'test_acc': 0.15625,
  'test_loss': 6827.8779296875}]