In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytorch_lightning.loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader
from datasets import TxtDataset

In [2]:
class Config(dict):
    """A dict whose entries are also readable as attributes.

    ``Config(lr=3e-3)['lr']`` and ``Config(lr=3e-3).lr`` give the same value.
    Entries added later through :meth:`set` are mirrored the same way; plain
    item assignment (``conf['k'] = v``) does not update the attribute.
    """

    def __init__(self, **kwargs):
        super().__init__()
        # Route every constructor entry through ``set`` so dict item and
        # attribute are always created together.
        for key, val in kwargs.items():
            self.set(key, val)

    def set(self, key, val):
        """Store ``val`` under ``key`` as a dict item and as an attribute."""
        self[key] = val
        setattr(self, key, val)

In [3]:
class WordAVGModel(pl.LightningModule):
    """Word-averaging text classifier trained on the cnews splits.

    Token ids are embedded, passed through dropout, averaged over the sequence
    axis and mapped to class logits by a single linear layer.

    ``conf`` must provide ``vocab_size``, ``embed_size``, ``num_classes``,
    ``dropout``, ``lr`` and ``data_path``. ``batch_size`` (default 32) and
    ``maxlen`` (default 512) are optional and default to the values this
    model previously hard-coded.
    """

    def __init__(self, conf):
        super().__init__()
        self.conf = conf
        self._batch_size = getattr(conf, 'batch_size', 32)
        self._maxlen = getattr(conf, 'maxlen', 512)
        # padding_idx=0 keeps the embedding of id 0 fixed at zero.
        self.embedding = nn.Embedding(conf.vocab_size, conf.embed_size, padding_idx=0)
        self.fc = nn.Linear(conf.embed_size, conf.num_classes)
        self.dropout = nn.Dropout(conf.dropout)

    def forward(self, x):
        """Return class logits of shape [B, num_classes].

        ``x`` presumably holds token ids of shape [B, seq] — TODO confirm
        against TxtDataset.
        """
        embedded = self.dropout(self.embedding(x))  # [B, seq, embed_size]
        # Plain mean over the sequence axis (same result as the former
        # avg_pool2d with a (seq, 1) window).
        # NOTE(review): if TxtDataset pads with id 0, padding positions
        # contribute zeros to the average, diluting short texts; a masked mean
        # may be preferable.
        pooled = embedded.mean(dim=1)  # [B, embed_size]
        return self.fc(pooled)

    def _load_split(self, name):
        """Read ``cnews.<name>.txt`` under ``conf.data_path`` as a TxtDataset."""
        return TxtDataset(f"{self.conf.data_path}/cnews.{name}.txt", maxlen=self._maxlen)

    def prepare_data(self):
        # Lightning runs this hook on one process only, so dataset state is
        # built in ``setup`` (which runs on every process). Kept as an alias
        # for callers that invoke it directly.
        self.setup()

    def setup(self, stage=None):
        """Build the train/val/test datasets once; later calls are no-ops."""
        if getattr(self, 'train_set', None) is None:
            self.train_set = self._load_split('train')
            self.val_set = self._load_split('val')
            self.test_set = self._load_split('test')

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self._batch_size, shuffle=True)

    def val_dataloader(self):
        # No shuffling: the epoch metric is an average of per-batch values, so
        # shuffling only makes it nondeterministic.
        return DataLoader(self.val_set, batch_size=self._batch_size, shuffle=False)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self._batch_size, shuffle=False)

    def _process_one_batch(self, batch, flag='train'):
        """Compute cross-entropy loss and accuracy for one batch.

        Logs ``{flag}_loss`` and ``{flag}_accuracy`` and returns the loss.
        """
        x, y = batch
        logits = self(x).reshape(-1, self.conf.num_classes)
        target = y.reshape(-1)
        loss = F.cross_entropy(logits, target)
        self.log(f'{flag}_loss', loss)

        # Computed on-device: avoids a GPU->CPU copy per step.
        acc = (logits.argmax(dim=-1) == target).float().mean()
        self.log(f'{flag}_accuracy', acc)

        return loss

    def training_step(self, batch, batch_nb):
        return self._process_one_batch(batch, flag='train')

    def validation_step(self, batch, batch_nb):
        return self._process_one_batch(batch, flag='val')

    def test_step(self, batch, batch_nb):
        return self._process_one_batch(batch, flag='test')

    def configure_optimizers(self):
        """Adam over all parameters with learning rate ``conf.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.conf.lr)

In [4]:
def main(conf):
    """Train a WordAVGModel, then evaluate its best checkpoint on the test split.

    The checkpoint with the lowest ``val_loss`` is kept; TensorBoard logs go to
    ``logs/``. Optional ``conf`` entries:
      * ``max_epochs`` — number of training epochs (default 10).
      * ``seed`` — if set, passed to ``pl.seed_everything`` before the model is
        built, making weight initialisation and data order reproducible.
    """
    seed = getattr(conf, 'seed', None)
    if seed is not None:
        pl.seed_everything(seed)

    model = WordAVGModel(conf)
    tb_logger = pl_loggers.TensorBoardLogger('logs/')
    # NOTE(review): ``filepath=`` and ``checkpoint_callback=<ModelCheckpoint>``
    # are the PL ~1.0 API; later releases use ``dirpath``/``filename`` and
    # ``callbacks=[...]`` — confirm against the pinned Lightning version.
    ckpt = ModelCheckpoint(
        filepath=conf.model_name,
        verbose=False,
        monitor='val_loss',
        mode='min',
    )
    trainer = pl.Trainer(
        max_epochs=getattr(conf, 'max_epochs', 10),
        logger=tb_logger,
        checkpoint_callback=ckpt,
    )

    trainer.fit(model)
    # Evaluate the lowest-val_loss weights rather than the last epoch's.
    trainer.test(ckpt_path=trainer.checkpoint_callback.best_model_path)

In [5]:
# Experiment configuration. The data directory is resolved relative to the
# user's home (~/datasets/cnews) instead of a hard-coded absolute user path,
# so the notebook runs unchanged on other machines.
conf = Config(
    model_name='word_avg',
    data_path=Path.home() / 'datasets' / 'cnews',
    vocab_size=50000,
    embed_size=300,
    num_classes=10,
    lr=3e-3,
    dropout=0.2,
)

In [6]:
# Train, keep the lowest-val_loss checkpoint, then evaluate it on the test split.
main(conf=conf)

it/s, loss=0.012, v_num=0]
Epoch 3:  93%|█████████▎| 1600/1720 [04:41<00:21,  5.68it/s, loss=0.012, v_num=0]
Epoch 3:  93%|█████████▎| 1605/1720 [04:41<00:20,  5.69it/s, loss=0.012, v_num=0]
Epoch 3:  94%|█████████▎| 1610/1720 [04:42<00:19,  5.71it/s, loss=0.012, v_num=0]
Epoch 3:  94%|█████████▍| 1615/1720 [04:42<00:18,  5.72it/s, loss=0.012, v_num=0]
Epoch 3:  94%|█████████▍| 1620/1720 [04:42<00:17,  5.74it/s, loss=0.012, v_num=0]
Epoch 3:  94%|█████████▍| 1625/1720 [04:42<00:16,  5.76it/s, loss=0.012, v_num=0]
Epoch 3:  95%|█████████▍| 1630/1720 [04:42<00:15,  5.77it/s, loss=0.012, v_num=0]
Epoch 3:  95%|█████████▌| 1635/1720 [04:42<00:14,  5.79it/s, loss=0.012, v_num=0]
Epoch 3:  95%|█████████▌| 1640/1720 [04:42<00:13,  5.80it/s, loss=0.012, v_num=0]
Epoch 3:  96%|█████████▌| 1645/1720 [04:42<00:12,  5.82it/s, loss=0.012, v_num=0]
Epoch 3:  96%|█████████▌| 1650/1720 [04:42<00:12,  5.83it/s, loss=0.012, v_num=0]
Epoch 3:  96%|█████████▌| 1655/1720 [04:42<00:11,  5.85it/s, loss=0.012