In [1]:
import os
from pathlib import Path

import torch
import torch.nn as nn
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import BertTokenizer, BertModel
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from datasets import BertDataset

In [2]:
class Config(dict):
    """Dictionary of settings whose entries are also exposed as attributes.

    Values passed to the constructor or written through :meth:`set` are
    stored twice: as a mapping item and as an instance attribute. Only these
    two paths keep both views in sync; plain item assignment or plain
    attribute assignment updates just one of them.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as a dict item and as an attribute."""
        super().__init__(**kwargs)
        for name in kwargs:
            setattr(self, name, kwargs[name])

    def set(self, key, val):
        """Write ``val`` under ``key`` as a dict item, then mirror it as an attribute."""
        self.__setitem__(key, val)
        setattr(self, key, val)

In [3]:
class BertTextClassify(pl.LightningModule):
    """BERT encoder followed by dropout and a linear head for text classification.

    ``conf`` must expose ``bert_model_pretrained``, ``num_classes``,
    ``data_path``, ``lr`` and ``adam_epsilon`` as attributes. Two further
    attributes are optional:

    * ``batch_size`` -- batch size for all three dataloaders (default 32).
    * ``weight_decay`` -- decay applied to every parameter except biases and
      LayerNorm weights (default 0.0).
    """

    def __init__(self, conf):
        """Load the pretrained tokenizer and encoder and build the classifier head."""
        super().__init__()
        self.conf = conf
        self.tokenizer = BertTokenizer.from_pretrained(conf.bert_model_pretrained)
        self.bert = BertModel.from_pretrained(conf.bert_model_pretrained)
        self.dropout = nn.Dropout(self.bert.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.bert.config.hidden_size, conf.num_classes)

    def forward(self, input_ids, attention_mask):
        """Return classification logits for a batch of tokenized inputs."""
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # Element 1 of the encoder output is used as the pooled sentence representation.
        pooled_output = self.dropout(outputs[1])
        return self.classifier(pooled_output)

    def prepare_data(self):
        """Build the train/val/test datasets from the files under ``conf.data_path``."""
        # NOTE(review): Lightning's documentation advises against assigning state in
        # prepare_data (it is not run on every process in distributed training).
        # Kept here for compatibility; consider moving this to setup() -- TODO confirm.
        self.train_set = BertDataset(f"{self.conf.data_path}/cnews.train.txt", self.tokenizer)
        self.val_set = BertDataset(f"{self.conf.data_path}/cnews.val.txt", self.tokenizer)
        self.test_set = BertDataset(f"{self.conf.data_path}/cnews.test.txt", self.tokenizer)

    def _make_loader(self, dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        batch_size = getattr(self.conf, 'batch_size', 32)
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled loader over the training set."""
        return self._make_loader(self.train_set, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation set; evaluation data is not shuffled."""
        return self._make_loader(self.val_set, shuffle=False)

    def test_dataloader(self):
        """Loader over the test set; evaluation data is not shuffled."""
        return self._make_loader(self.test_set, shuffle=False)

    def _process_one_batch(self, batch, flag='train'):
        """Run one batch, log ``{flag}_loss`` and ``{flag}_accuracy``, and return the loss.

        ``batch`` is a mapping with the keys ``'input_ids'``, ``'attention_mask'``
        and ``'label'``.
        """
        input_ids, attention_mask, y = batch['input_ids'], batch['attention_mask'], batch['label']
        logits = self(input_ids, attention_mask).view(-1, self.conf.num_classes)
        targets = y.view(-1)

        loss = nn.functional.cross_entropy(logits, targets)
        self.log(f'{flag}_loss', loss)

        # Accuracy is computed in torch so no per-step copy to the CPU is needed.
        acc = (logits.argmax(dim=-1) == targets).float().mean()
        self.log(f'{flag}_accuracy', acc)

        return loss

    def training_step(self, batch, batch_nb):
        """Compute and log the training loss and accuracy for one batch."""
        return self._process_one_batch(batch, flag='train')

    def validation_step(self, batch, batch_nb):
        """Compute and log the validation loss and accuracy for one batch."""
        return self._process_one_batch(batch, flag='val')

    def test_step(self, batch, batch_nb):
        """Compute and log the test loss and accuracy for one batch."""
        return self._process_one_batch(batch, flag='test')

    def configure_optimizers(self):
        """Create AdamW with biases and LayerNorm weights exempt from weight decay."""
        no_decay = ('bias', 'LayerNorm.weight')
        # Defaults to 0.0 when the config does not define a weight decay.
        weight_decay = getattr(self.conf, 'weight_decay', 0.0)

        decay_params, no_decay_params = [], []
        for name, param in self.named_parameters():
            if any(nd in name for nd in no_decay):
                no_decay_params.append(param)
            else:
                decay_params.append(param)

        optimizer_grouped_parameters = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]
        return torch.optim.AdamW(
            optimizer_grouped_parameters, lr=self.conf.lr, eps=self.conf.adam_epsilon
        )

In [3]:
def main(conf):
    """Fine-tune the classifier, then evaluate the best checkpoint on the test set.

    Args:
        conf: configuration object passed to ``BertTextClassify``. It must also
            provide ``model_name`` (checkpoint file stem, optionally with a
            directory part) and may provide ``max_epochs`` (default 10).
    """
    model = BertTextClassify(conf)
    tb_logger = pl_loggers.TensorBoardLogger('logs/')

    # ``filepath`` was replaced by ``dirpath``/``filename`` in ModelCheckpoint.
    # A bare model name resolves to the current working directory.
    ckpt_target = Path(conf.model_name).resolve()
    ckpt = ModelCheckpoint(
        dirpath=str(ckpt_target.parent),
        filename=ckpt_target.name,
        verbose=False,
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=getattr(conf, 'max_epochs', 10),
        logger=tb_logger,
        # Checkpoint callbacks are passed through ``callbacks`` rather than
        # the deprecated ``checkpoint_callback=<callback>`` form.
        callbacks=[ckpt],
    )

    trainer.fit(model)
    # Read the best path from the callback created above, so it does not depend
    # on which callback the trainer exposes as ``checkpoint_callback``.
    trainer.test(ckpt_path=ckpt.best_model_path)


In [4]:
# Experiment settings. The pretrained-model directory and the dataset directory
# can be overridden with the BERT_MODEL_PRETRAINED and CNEWS_DATA_PATH
# environment variables; the original local paths remain the fallbacks.
conf = Config(
    model_name='bert_text_classify',
    bert_model_pretrained=os.environ.get(
        'BERT_MODEL_PRETRAINED', r'/Users/liuzhi/models/torch/bert-base-chinese'
    ),
    data_path=Path(os.environ.get('CNEWS_DATA_PATH', r'/Users/liuzhi/datasets/cnews')),
    num_classes=10,
    lr=2e-5,
    adam_epsilon=1e-8,
)

In [5]:
# Uncomment to fine-tune the classifier and evaluate the best checkpoint on the test set.
# main(conf)