In [2]:

from sklearn.datasets import fetch_20newsgroups
import torch
from transformers import DistilBertTokenizerFast
from galileo import Galileo
import pandas as pd
import numpy as np


class NewsgroupDataset(torch.utils.data.Dataset):
    """One split of 20 Newsgroups, tokenized for DistilBERT and logged to Galileo.

    Building the dataset fetches the requested split (headers, footers and
    quotes stripped), logs every sample to the Galileo logger, and tokenizes
    all texts up front with the ``distilbert-base-uncased`` fast tokenizer.
    """

    def __init__(self, split: str, g: Galileo):
        """Load ``split``, log its samples through ``g`` and tokenize the texts.

        Args:
            split: Subset name forwarded to ``fetch_20newsgroups``. ``"train"``
                is logged with ``logger_mode="training"``; any other value is
                logged with ``logger_mode="test"``.
            g: Galileo session; its ``logger`` receives one input record per
                sample and, for the training split only, the label names.
        """
        raw = fetch_20newsgroups(subset=split, remove=('headers', 'footers', 'quotes'))

        # Column order (text, label) matches what the rest of the class reads.
        self.dataset = pd.DataFrame({"text": raw.data, "label": raw.target})

        is_train = split == "train"
        mode = "training" if is_train else "test"

        # The positional row number doubles as the sample id sent to Galileo.
        samples = zip(self.dataset["text"], self.dataset["label"])
        for sample_id, (text, label) in enumerate(samples):
            record = {
                "id": sample_id,
                "text": text,
                "gold": str(label),
            }
            g.logger.log_input(record, logger_mode=mode)

        if is_train:
            g.logger.log_labels(raw.target_names)

        bert_tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
        all_texts = self.dataset["text"].tolist()
        self.encodings = bert_tokenizer(all_texts, truncation=True, padding=True)

    def __getitem__(self, idx):
        """Return ``(idx, input_ids, attention_mask, label)`` for sample ``idx``."""
        token_ids = torch.tensor(self.encodings["input_ids"][idx])
        mask = torch.tensor(self.encodings["attention_mask"][idx])
        target = self.dataset["label"][idx]
        return idx, token_ids, mask, target

    def __len__(self):
        """Return the number of samples in the split."""
        return self.dataset.shape[0]

In [3]:
import pytorch_lightning as pl
from transformers import DistilBertForSequenceClassification, AdamW, DistilBertConfig
import torch.nn.functional as F
import torchmetrics

class LightningDistilBERT(pl.LightningModule):
    """DistilBERT sequence classifier (20 labels) that logs outputs to Galileo.

    ``forward`` returns log-probabilities; the step methods compute NLL loss
    on them and track accuracy with one ``torchmetrics.Accuracy`` per stage.
    """

    def __init__(self, g: Galileo):
        """Load the pretrained model and keep a handle on the Galileo session.

        Args:
            g: Galileo session whose ``logger`` receives per-sample outputs.
        """
        super().__init__()
        self.model = DistilBertForSequenceClassification.from_pretrained(
            'distilbert-base-uncased', config=DistilBertConfig(num_labels=20)
        )

        self.g = g

        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    def forward(self, x, attention_mask, x_idxs=None, epoch=None, logging=False):
        """Run the model and optionally log per-sample probabilities.

        Args:
            x: Batch of token ids passed to the underlying model.
            attention_mask: Attention mask matching ``x``.
            x_idxs: Sample ids for the batch; required for logging.
            epoch: Epoch value recorded with each logged output; required for
                logging.
            logging: When true (and ``x_idxs`` and ``epoch`` are both given),
                one record per sample is sent to ``self.g.logger.log_output``.

        Returns:
            Log-softmax of the model logits over dimension 1.
        """
        out = self.model(x, attention_mask=attention_mask)
        log_probs = F.log_softmax(out.logits, dim=1)

        if logging and x_idxs is not None and epoch is not None:
            # Probabilities are exp(log_probs), so no second softmax is needed.
            # Detach and move the whole batch to the host once instead of
            # once per sample.
            batch_probs = log_probs.detach().exp().cpu().tolist()
            for sample_idx, prob in zip(x_idxs, batch_probs):
                self.g.logger.log_output(
                    {
                        "epoch": epoch,
                        "id": int(sample_idx),
                        "emb": [0.0],  # placeholder embedding value
                        "prob": prob,
                    }
                )
        return log_probs

    def _shared_step(self, batch, metric, metric_name, epoch=None, logging=False):
        """Compute NLL loss for ``batch``, update ``metric`` and log it."""
        x_idxs, x, attention_mask, y = batch
        log_probs = self(x, attention_mask, x_idxs, epoch, logging)
        loss = F.nll_loss(log_probs, y)
        metric(torch.argmax(log_probs, 1), y)
        self.log(metric_name, metric, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        """Model training step; logs outputs to Galileo with the current epoch."""
        return self._shared_step(
            batch, self.train_acc, "train_acc", epoch=self.current_epoch, logging=True
        )

    def validation_step(self, batch, batch_idx):
        """Model validation step; nothing is sent to Galileo here."""
        return self._shared_step(batch, self.val_acc, "val_acc")

    def test_step(self, batch, batch_idx):  # Using this to log galileo
        """Model test step; logs outputs to Galileo with epoch set to -1."""
        return self._shared_step(
            batch, self.test_acc, "test_acc", epoch=-1, logging=True
        )

    def configure_optimizers(self):
        """Return AdamW over the trainable parameters with lr=1e-5."""
        return torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.parameters()), lr=1e-5
        )

In [5]:
from pytorch_lightning.callbacks import Callback


class MyCustomCallback(Callback):
    """Callback that prints a marker line from three trainer hooks."""

    @staticmethod
    def _announce(message):
        """Print ``message`` to standard output."""
        print(message)

    def on_init_start(self, trainer):
        """Print a marker from the ``on_init_start`` hook."""
        self._announce("#Atin Starting to init trainer!")

    def on_init_end(self, trainer):
        """Print a marker from the ``on_init_end`` hook."""
        self._announce("#Atin Trainer is init now")

    def on_train_end(self, trainer, pl_module):
        """Print a marker from the ``on_train_end`` hook."""
        self._announce("#Atin Do something when training ends")

In [None]:
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

# Initialize the Galileo session shared by the model and both datasets.
g = Galileo(namespace="newsgroup_demo")

# Single place to tune the batch size used by both loaders.
BATCH_SIZE = 8

model = LightningDistilBERT(g)
train_dataloader = torch.utils.data.DataLoader(
    NewsgroupDataset("train", g), batch_size=BATCH_SIZE, shuffle=True
)
# The held-out split is only evaluated, so it is not shuffled; shuffling would
# only make validation batches differ from run to run.
val_dataloader = torch.utils.data.DataLoader(
    NewsgroupDataset("test", g), batch_size=BATCH_SIZE, shuffle=False
)

In [None]:
# Pass gpus=-1 when CUDA is available, otherwise None.
# NOTE(review): -1 presumably selects every visible GPU; confirm against the
# installed Lightning version.
gpu_setting = None
if torch.cuda.is_available():
    gpu_setting = -1

trainer = pl.Trainer(gpus=gpu_setting, max_epochs=2, callbacks=[MyCustomCallback()])

# Train for the configured epochs, validating on the held-out loader.
trainer.fit(model, train_dataloader, val_dataloader)