In [None]:
from typing import Optional, Dict, Any, Callable

## PyTorch recap

In [None]:
import torch

In [None]:
class MyModel(torch.nn.Module):
    """Minimal PyTorch module consisting of a single linear layer.

    Args:
        in_features: Size of each input sample. Defaults to 10.
        out_features: Size of each output sample. Defaults to 2.
    """

    def __init__(self, in_features: int = 10, out_features: int = 2) -> None:
        super().__init__()
        # The layer sizes used to be hard-coded; the defaults keep `MyModel()` unchanged.
        self.linear = torch.nn.Linear(in_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

In [None]:
# Instantiate the toy model and feed it one random sample with 10 features.
my_model = MyModel()
batch = torch.rand(1, 10)
my_model(batch)

## Data loading

In [None]:
from datasets import load_dataset

In [None]:
# Load the "concat" configuration of the `pietrolesci/ag_news` dataset.
dataset = load_dataset(path="pietrolesci/ag_news", name="concat")
dataset

## Data preprocessing

In [None]:
from transformers import AutoTokenizer
from torch.utils.data import DataLoader

In [None]:
# Tokenizer for the `bert-base-uncased` checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    "bert-base-uncased",
)

In [None]:
def _tokenize_batch(examples):
    """Tokenize the ``text`` column of a batch with truncation enabled."""
    return tokenizer(examples["text"], truncation=True)


dataset = dataset.map(_tokenize_batch, batched=True)

In [None]:
# Pull the two splits out of the tokenized dataset dictionary.
train_dataset = dataset["train"]
test_dataset = dataset["test"]

In [None]:
# Split starting from `dataset["train"]` so that re-running this cell does not
# split an already-split `train_dataset` again.
_split = dataset["train"].train_test_split(test_size=0.3)
train_dataset, val_dataset = _split["train"], _split["test"]

# Expose only the columns the collator can turn into tensors. The test split
# needs the same treatment, otherwise its raw `text` column reaches the collator.
columns_to_keep = ['label', 'input_ids', 'token_type_ids', 'attention_mask']
train_dataset = train_dataset.with_format(columns=columns_to_keep)
val_dataset = val_dataset.with_format(columns=columns_to_keep)
test_dataset = dataset["test"].with_format(columns=columns_to_keep)

In [None]:
# This cell is the first one to use `DataCollatorWithPadding`, so import it here;
# otherwise the cell raises a NameError when the notebook is run top to bottom.
from transformers import DataCollatorWithPadding

# Collator that pads the examples of a batch and returns PyTorch tensors.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,
    return_tensors="pt",
)

train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=data_collator)
val_dataloader = DataLoader(val_dataset, batch_size=1, collate_fn=data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=data_collator)

In [None]:
# Preview one collated batch of 32 training examples.
# (The original referenced an undefined name `t`; the training split is meant.)
next(iter(DataLoader(train_dataset, batch_size=32, collate_fn=data_collator)))

## All in one place: DataModule

In [None]:
from pytorch_lightning import LightningDataModule
from transformers import DataCollatorWithPadding

In [None]:
class AGNewsDataModule(LightningDataModule):
    """Data module that loads, tokenizes and splits the `pietrolesci/ag_news` dataset.

    Args:
        batch_size: Batch size used by every dataloader.
        val_perc: Fraction of the training split held out for validation.
            A falsy value (``0`` / ``None``) disables the validation split.
    """

    def __init__(self, batch_size: int = 32, val_perc: float = 0.3) -> None:
        super().__init__()
        self.batch_size = batch_size
        self.val_perc = val_perc

    def setup(self, stage: Optional[str] = None) -> None:
        """Load tokenizer and dataset, tokenize the text and record the splits.

        Sets ``tokenizer``, ``train_dataset``, ``val_dataset`` (``None`` when no
        validation split is requested) and ``test_dataset``. ``stage`` is not used.
        """
        # load tokenizer; a local alias keeps the mapped lambda from closing over `self`
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = tokenizer

        # load dataset
        dataset = load_dataset("pietrolesci/ag_news", "concat")

        # tokenize with the module's own tokenizer (not a notebook global), with
        # truncation enabled
        dataset = dataset.map(lambda ex: tokenizer(ex["text"], truncation=True), batched=True)
        columns_to_keep = ['label', 'input_ids', 'token_type_ids', 'attention_mask']

        # train-val split and record datasets
        train_dataset, test_dataset = dataset["train"], dataset["test"]
        self.test_dataset = test_dataset.with_format(columns=columns_to_keep)

        # Always define the attribute so `val_dataloader` can test it safely.
        self.val_dataset = None
        if self.val_perc:
            # Honour the configured fraction instead of a hard-coded 0.3.
            _split = train_dataset.train_test_split(test_size=self.val_perc)
            train_dataset, val_dataset = _split["train"], _split["test"]
            self.val_dataset = val_dataset.with_format(columns=columns_to_keep)

        self.train_dataset = train_dataset.with_format(columns=columns_to_keep)

    def train_dataloader(self) -> DataLoader:
        """Return the dataloader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def val_dataloader(self) -> Optional[DataLoader]:
        """Return the validation dataloader, or ``None`` if no validation split was made."""
        if self.val_dataset is None:
            return None
        return DataLoader(self.val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    def test_dataloader(self) -> DataLoader:
        """Return the dataloader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)

    @property
    def collate_fn(self) -> Callable:
        """Build a padding collator around the module's tokenizer (new object per access)."""
        return DataCollatorWithPadding(
            tokenizer=self.tokenizer,
            padding=True,
            return_tensors="pt",
        )

    @property
    def num_classes(self) -> Optional[int]:
        """Number of classes of the ``label`` feature, or ``None`` before ``setup`` has run."""
        if hasattr(self, "train_dataset"):
            return self.train_dataset.features["label"].num_classes
        return None
    

In [None]:
# Build the data module with its default settings and invoke both hooks explicitly.
datamodule = AGNewsDataModule()
datamodule.prepare_data()
datamodule.setup(stage=None)

In [None]:
# Draw the first batch from the data module's training dataloader.
train_batches = iter(datamodule.train_dataloader())
next(train_batches)

## Model

In [None]:
from pytorch_lightning import LightningModule
from transformers import AutoModelForSequenceClassification, AdamW, get_constant_schedule_with_warmup

In [None]:
class TransformerModel(LightningModule):
    """Sequence classifier wrapping a pretrained transformer.

    Args:
        model_name: Identifier passed to ``AutoModelForSequenceClassification.from_pretrained``.
        num_classes: Number of labels of the classification head.
        learning_rate: Learning rate given to the AdamW optimizer.
        num_warmup_steps: Warmup steps of the constant-with-warmup schedule.
    """

    def __init__(self, model_name: str, num_classes: int, learning_rate: float = 0.00001, num_warmup_steps: int = 2_000) -> None:
        super().__init__()
        self.model_name = model_name
        self.num_classes = num_classes
        self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name, num_labels=self.num_classes)
        self.learning_rate = learning_rate
        self.num_warmup_steps = num_warmup_steps

    def common_step(self, batch: Any, stage: str) -> torch.Tensor:
        """Run a forward pass, log the loss as ``{stage}_loss`` and return it."""
        loss = self(batch).loss
        self.log(f"{stage}_loss", loss)
        return loss

    def forward(self, batch: Any) -> Any:
        """Unpack ``batch`` as keyword arguments to the wrapped model and return its output.

        The returned object is the model's output (``.loss`` and ``.logits`` are
        read from it in this class), not a bare tensor.
        """
        return self.model(**batch)

    def training_step(self, batch: Any, batch_idx: int = 0, optimizer_idx: int = 0) -> torch.Tensor:
        """Compute and log the training loss for ``batch``."""
        return self.common_step(batch, "train")

    def validation_step(self, batch: Any, batch_idx: int = 0) -> torch.Tensor:
        """Compute and log the validation loss for ``batch``."""
        return self.common_step(batch, "val")

    def test_step(self, batch: Any, batch_idx: int = 0) -> torch.Tensor:
        """Compute and log the test loss for ``batch``."""
        return self.common_step(batch, "test")

    def configure_optimizers(self) -> Dict[str, Any]:
        """Return AdamW over the trainable parameters plus a per-step warmup schedule."""
        optimizer = AdamW(filter(lambda p: p.requires_grad, self.parameters()), lr=self.learning_rate)
        scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=self.num_warmup_steps)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # The former `"monitor": "loss"` entry was dropped: no metric with that
                # name is logged (only `{stage}_loss`), and this schedule is stepped
                # unconditionally rather than driven by a monitored value.
                "frequency": 1,
                "interval": "step",
            }
        }


In [None]:
# Size the classification head from the number of classes found in the data module.
model = TransformerModel(
    model_name="bert-base-uncased",
    num_classes=datamodule.num_classes,
)

## Train!

In [None]:
from pytorch_lightning import Trainer

In [None]:
# Trainer configured with `fast_dev_run` enabled.
trainer = Trainer(
    fast_dev_run=True,
)

In [None]:
# Fit the model on the data provided by the data module.
trainer.fit(model, datamodule=datamodule)