In [1]:
from datetime import datetime
from typing import Optional

import nlp
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader
from transformers import (
    AdamW,
    AutoModelForSequenceClassification,
    AutoConfig,
    AutoTokenizer,
    get_linear_schedule_with_warmup,
    glue_compute_metrics
)



In [2]:
class GLUEDataModule(pl.LightningDataModule):

    task_text_field_map = {
        'cola': ['sentence'],
        'sst2': ['sentence'],
        'mrpc': ['sentence1', 'sentence2'],
        'qqp': ['question1', 'question2'],
        'stsb': ['sentence1', 'sentence2'],
        'mnli': ['premise', 'hypothesis'],
        'qnli': ['question', 'sentence'],
        'rte': ['sentence1', 'sentence2'],
        'wnli': ['sentence1', 'sentence2'],
        'ax': ['premise', 'hypothesis']
    }

    glue_task_num_labels = {
        'cola': 2,
        'sst2': 2,
        'mrpc': 2,
        'qqp': 2,
        'stsb': 1,
        'mnli': 3,
        'qnli': 2,
        'rte': 2,
        'wnli': 2,
        'ax': 3
    }

    loader_columns = [
        'nlp_idx',
        'input_ids',
        'token_type_ids',
        'attention_mask',
        'start_positions',
        'end_positions',
        'labels'
    ]

    def __init__(
        self,
        model_name_or_path: str,
        task_name: str ='mrpc',
        max_seq_length: int = 128,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        **kwargs
    ):
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.task_name = task_name
        self.max_seq_length = max_seq_length
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size

        self.text_fields = self.task_text_field_map[task_name]
        self.num_labels = self.glue_task_num_labels[task_name]
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)

    @property
    def dataset(self):
        dataset = nlp.load_dataset('glue', self.task_name)

        for split in dataset.keys():
            dataset[split] = dataset[split].map(
                self.convert_to_features,
                batched=True,
                remove_columns=['label'],
            )
            columns = [c for c in self.dataset[split].column_names if c in GLUEDataModule.loader_columns]
            dataset[split].set_format(type="torch", columns=columns)
        
        return dataset

    @property
    def eval_splits(self):
        return [x for x in self.dataset.keys() if 'validation' in x]
    
    def prepare_data(self):
        nlp.load_dataset('glue', self.task_name)
        AutoTokenizer.from_pretrained(self.model_name_or_path, use_fast=True)
    
    def train_dataloader(self):
        return DataLoader(self.dataset['train'], batch_size=self.train_batch_size)
    
    def val_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset['validation'], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def test_dataloader(self):
        if len(self.eval_splits) == 1:
            return DataLoader(self.dataset['test'], batch_size=self.eval_batch_size)
        elif len(self.eval_splits) > 1:
            return [DataLoader(self.dataset[x], batch_size=self.eval_batch_size) for x in self.eval_splits]

    def convert_to_features(self, example_batch, indices=None):

        # Either encode single sentence or sentence pairs
        if len(self.text_fields) > 1:
            texts_or_text_pairs = list(zip(example_batch[self.text_fields[0]], example_batch[self.text_fields[1]]))
        else:
            texts_or_text_pairs = example_batch[self.text_fields[0]]

        # Tokenize the text/text pairs
        features = self.tokenizer.batch_encode_plus(
            texts_or_text_pairs,
            max_length=self.max_seq_length,
            pad_to_max_length=True,
            truncation=True
        )

        # Rename label to labels to make it easier to pass to model forward
        features['labels'] = example_batch['label']

        return features

In [3]:
class GLUETransformer(pl.LightningModule):
    def __init__(
        self,
        model_name_or_path: str,
        task_name: str,
        num_labels: int,
        learning_rate: float = 2e-5,
        adam_epsilon: float = 1e-8,
        warmup_steps: int = 0,
        weight_decay: float = 0.0,
        train_batch_size: int = 32,
        eval_batch_size: int = 32,
        eval_splits: Optional[list] = None,
        **kwargs
    ):
        print('a1')
        super().__init__()

        self.model_name_or_path =  model_name_or_path
        self.task_name = task_name
        self.num_labels = num_labels
        self.learning_rate = learning_rate
        self.adam_epsilon = adam_epsilon
        self.warmup_steps = warmup_steps
        self.weight_decay = weight_decay
        self.train_batch_size = train_batch_size
        self.eval_batch_size = eval_batch_size
        self.eval_splits = eval_splits
        self.hparams = kwargs

        print('a2')
        self.config = AutoConfig.from_pretrained(model_name_or_path, num_labels=num_labels)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, config=self.config)
        self.metric = nlp.load_metric(
            'glue',
            self.task_name,
            experiment_id=datetime.now().strftime("%d-%m-%Y_%H-%M-%S")
        )
        print('a3')
        self.total_steps = (
                (len(self.train_dataloader().dataset) // (self.train_batch_size * self.hparams.get('gpus', 1)))
                // self.hparams.get('accumulate_grad_batches', 1)
                * float(self.hparams.get('max_epochs', 1))
            )
        print('a4')

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        loss = outputs[0]
        return pl.TrainResult(loss)

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        outputs = self(**batch)
        val_loss, logits = outputs[:2]

        if self.num_labels >= 1:
            preds = torch.argmax(logits, axis=1)
        elif self.num_labels == 1:
            preds = logits.squeeze()

        labels = batch["labels"]

        return {'loss': val_loss, "preds": preds, "labels": labels}

    def validation_epoch_end(self, outputs):
        if self.task_name == 'mnli':
            for i, output in enumerate(outputs):
                # matched or mismatched
                split = self.eval_splits[i].split('_')[-1]
                preds = torch.cat([x['preds'] for x in output]).detach().cpu().numpy()
                labels = torch.cat([x['labels'] for x in output]).detach().cpu().numpy()
                loss = torch.stack([x['loss'] for x in output]).mean()
                if i == 0:
                    result = pl.EvalResult(checkpoint_on=loss)
                result.log(f'val_loss_{split}', loss, prog_bar=True)
                split_metrics = {f"{k}_{split}": v for k, v in self.metric.compute(preds, labels).items()}
                result.log_dict(split_metrics, prog_bar=True)
            return result

        preds = torch.cat([x['preds'] for x in outputs]).detach().cpu().numpy()
        labels = torch.cat([x['labels'] for x in outputs]).detach().cpu().numpy()
        loss = torch.stack([x['loss'] for x in outputs]).mean()
        result = pl.EvalResult(checkpoint_on=loss)
        result.log('val_loss', loss, prog_bar=True)
        result.log_dict(self.metric.compute(preds, labels), prog_bar=True)
        return result

    def configure_optimizers(self):
        "Prepare optimizer and schedule (linear warmup and decay)"
        model = self.model
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
                "weight_decay": self.weight_decay,
            },
            {
                "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
                "weight_decay": 0.0,
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.learning_rate, eps=self.adam_epsilon)

        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=self.warmup_steps, num_training_steps=self.total_steps
        )
        scheduler = {
            'scheduler': scheduler,
            'interval': 'step',
            'frequency': 1
        }
        return [optimizer], [scheduler]

In [4]:
def main():
    pl.seed_everything(42)
    print(1)
    dm = GLUEDataModule('bert-base-cased', 'sst2')
    print(2)
    model = GLUETransformer(
        'bert-base-cased',
        'sst2',
        num_labels=dm.num_labels,
        eval_splits=dm.eval_splits
    )
    print(3)
    trainer = pl.Trainer()
    print(4)
    return dm, model, trainer

In [5]:
dm, model, trainer = main()

1
2


KeyboardInterrupt: 

In [None]:
dm = GLUEDataModule('bert-base-cased', 'sst2')