In [None]:
!pip install transformers

In [None]:
!pip install datasets

In [None]:
!pip install pytorch_lightning==1.8.0

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics.functional as tm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from tokenizers import Tokenizer

from datasets import load_dataset

import pytorch_lightning as pl

from sklearn.metrics import multilabel_confusion_matrix
from transformers import BertForSequenceClassification, BertConfig

from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from IPython.display import display
from typing import List, Dict, Any, Tuple
import os
# Debugging aid: make CUDA kernel launches synchronous so a device-side error is
# reported at the call that triggered it (this slows GPU execution down).
# NOTE(review): both variables are set after `import torch`; they are expected to
# take effect only if CUDA has not been initialised yet in this process -- confirm.
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
# Expose only GPU index 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
class ApeachDataset(Dataset):
    """One split of a Hugging Face text-classification dataset, tokenised up front.

    The whole split is tokenised once in ``__init__`` and kept in memory as tensors,
    so ``__getitem__`` is a plain tensor lookup.

    Args:
        split: Name of the split to load (the notebook uses "train" and "test").
        tokenizer: A callable Hugging Face tokenizer (e.g. the object returned by
            ``AutoTokenizer.from_pretrained``). It is called on the list of texts and
            must return a mapping with "input_ids" and "attention_mask" tensors.
        max_length: Length every sequence is truncated to (and, with the default
            ``padding``, padded to).
        padding: Padding strategy forwarded to the tokenizer.
        dataset_name: Dataset identifier passed to ``load_dataset``.
        text_column: Column holding the raw text.
        label_column: Column holding the numeric label.
    """

    def __init__(self,
                 split: str,
                 tokenizer: Any,
                 max_length: int = 256,
                 padding: str = "max_length",
                 dataset_name: str = "jason9693/APEACH",
                 text_column: str = "text",
                 label_column: str = "class") -> None:
        super().__init__()
        split_data = load_dataset(dataset_name)[split]

        texts = split_data[text_column]
        inputs = tokenizer(texts, padding=padding, max_length=max_length, truncation=True, return_tensors="pt")
        self.input_ids = inputs["input_ids"]
        self.attention_masks = inputs["attention_mask"]

        # Labels are stored as float32, which is what the binary cross-entropy
        # losses in torch.nn expect as targets.
        self.labels = torch.tensor(split_data[label_column], dtype=torch.float32)

    def __len__(self) -> int:
        """Number of examples in the split (first dimension of ``input_ids``)."""
        return self.input_ids.shape[0]

    def __getitem__(self, index: Any) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(input_ids, attention_mask, label)`` for ``index``."""
        return self.input_ids[index], self.attention_masks[index], self.labels[index]

    def dataloader(self, **kwargs) -> DataLoader:
        """Wrap this dataset in a ``DataLoader``; ``kwargs`` are forwarded unchanged."""
        return DataLoader(self, **kwargs)

In [None]:
"""
monologg/koelectra-small-v3-discriminator
beomi/KcELECTRA-base
beomi/kcbert-base
beomi/kcbert-large
"""

# Checkpoint used for this run; the tokenizer is loaded from the same checkpoint name
# so that its vocabulary matches the model that is fine-tuned later.
huggingface_model_name = "monologg/koelectra-small-v3-discriminator"
tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name)

In [None]:
# Tokenisation / batching settings shared by both splits.
max_length = 64   # sequences are padded/truncated to this many tokens
batch_size = 128
labels = ['hate']  # a single label name -> the model is trained as a binary classifier

# Shuffle the training split so each epoch visits the examples in a new order instead
# of replaying identical batches in dataset order; evaluation keeps the original order.
train_dl = ApeachDataset("train", tokenizer, max_length=max_length).dataloader(batch_size=batch_size, shuffle=True)
val_dl = ApeachDataset("test", tokenizer, max_length=max_length).dataloader(batch_size=batch_size)



  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
def join_step_outputs(outputs):
    """Merge a list of per-step output dicts into one dict of tensors.

    The keys are taken from the first step. For each key, zero-dimensional tensors
    (e.g. a scalar loss) are stacked into a 1-D tensor, anything else is concatenated
    along dimension 0. Which of the two is used is decided by the first step's tensor.
    """
    merged = {}
    for key in outputs[0].keys():
        per_step = [step[key] for step in outputs]
        # Scalars have no dimension to concatenate along, so they need `stack`.
        combine = torch.stack if per_step[0].dim() == 0 else torch.cat
        merged[key] = combine(per_step, dim=0)
    return merged

class TextClassificationModule(pl.LightningModule):
    """Fine-tunes a Hugging Face sequence-classification model with Lightning.

    With a single entry in ``labels`` the task is binary: the model has one output
    unit and ``forward`` returns one sigmoid probability per sample. With more than
    one entry the task is single-label multi-class and ``forward`` returns a softmax
    distribution over the labels.

    Args:
        huggingface_model_name: Checkpoint name passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        labels: Label names; their count sizes the classification head and selects
            the loss and the accuracy variant.
        lr: Learning rate for the Adam optimizer.
    """

    def __init__(self, huggingface_model_name, labels, lr=5e-4):
        super().__init__()
        self.save_hyperparameters()
        self.model = AutoModelForSequenceClassification.from_pretrained(huggingface_model_name, num_labels=len(labels))
        # Alternative kept for reference: a small BERT built from a config instead of
        # a pretrained checkpoint.
        # config = {
        #     "max_position_embeddings": 300,
        #     "hidden_dropout_prob": 0.1,
        #     "hidden_act": "gelu",
        #     "initializer_range": 0.02, # 12 to 2
        #     "num_hidden_layers": 2,
        #     "pooler_num_attention_heads": 12,
        #     "type_vocab_size": 2,
        #     "vocab_size": 30000,
        #     "hidden_size": 128, # 768 to 128
        #     "attention_probs_dropout_prob": 0.1,
        #     "num_attention_heads": 2, # 12 to 2
        #     "intermediate_size": 512, # 3072 to 512,
        #     "num_labels": len(labels)
        # }
        # self.model = BertForSequenceClassification(
        #     BertConfig(**config)
        # )
        self.multiclass = len(labels) > 1
        # Both criteria consume the model's RAW logits. CrossEntropyLoss applies
        # log-softmax internally, so feeding it softmax output (as before) squashed the
        # loss; BCEWithLogitsLoss is the numerically stable form of sigmoid + BCELoss.
        self.criterion = nn.CrossEntropyLoss() if self.multiclass else nn.BCEWithLogitsLoss()
        self.labels = labels

    def configure_optimizers(self):
        """Optimise all parameters with Adam at the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.hparams.lr)

    def _raw_logits(self, input_ids, attention_mask=None):
        """Run the wrapped model and return its unnormalised logits."""
        return self.model(input_ids, attention_mask=attention_mask).logits

    def _to_probs(self, raw_logits):
        """Turn raw logits into probabilities (softmax, or sigmoid for the binary case)."""
        if self.multiclass:
            return raw_logits.softmax(dim=-1)
        # One output unit: drop the unit dimension so the result lines up with the labels.
        return raw_logits.sigmoid().squeeze(1).float()

    def forward(self, input_ids, attention_mask=None):
        """Return class probabilities for a batch of token ids.

        Binary: sigmoid probabilities with the output-unit dimension squeezed away.
        Multi-class: softmax probabilities over the last dimension.
        """
        return self._to_probs(self._raw_logits(input_ids, attention_mask))

    def _shared_step(self, batch):
        """Compute loss and probabilities for one ``(ids, masks, labels)`` batch."""
        ids, masks, labels = batch
        raw_logits = self._raw_logits(ids, masks)
        if self.multiclass:
            # CrossEntropyLoss takes class indices as targets.
            loss = self.criterion(raw_logits, labels.long())
        else:
            loss = self.criterion(raw_logits.squeeze(1).float(), labels)
        # The "logits" entry carries probabilities (key name kept for compatibility).
        # It is only used for metrics, so detach it: the per-step outputs are collected
        # until the epoch-end hook and must not keep every batch's autograd graph alive.
        return {"loss": loss, "logits": self._to_probs(raw_logits).detach(), "labels": labels}

    def _accuracy(self, probs, labels):
        """Accuracy of ``probs`` against ``labels`` for the configured task type."""
        if self.multiclass:
            return tm.accuracy(probs, labels.long(), task="multiclass", num_classes=len(self.labels))
        return tm.accuracy(probs, labels.int(), task='binary')

    def _log_epoch_metrics(self, outputs, loss_name, acc_name, loss_prog_bar):
        """Join the step outputs of an epoch and log mean loss and accuracy."""
        joined = join_step_outputs(outputs)
        self.log(loss_name, joined["loss"].mean(), prog_bar=loss_prog_bar)
        self.log(acc_name, self._accuracy(joined["logits"], joined["labels"]), prog_bar=True)

    def training_step(self, batch, batch_idx):
        """Return the loss (used for backprop) plus detached probabilities and labels."""
        return self._shared_step(batch)

    def training_epoch_end(self, outputs):
        """Log ``train_epoch_loss`` and ``train_acc`` over all steps of the epoch."""
        self._log_epoch_metrics(outputs, "train_epoch_loss", "train_acc", loss_prog_bar=False)

    def validation_step(self, batch, batch_idx):
        """Same computation as ``training_step``, on a validation batch."""
        return self._shared_step(batch)

    def validation_epoch_end(self, outputs):
        """Log ``val_epoch_loss`` and ``val_acc`` over all validation steps."""
        self._log_epoch_metrics(outputs, "val_epoch_loss", "val_acc", loss_prog_bar=True)

In [None]:
# val every n steps
# https://github.com/PyTorchLightning/pytorch-lightning/issues/2534

# Seed the Python, NumPy and PyTorch RNGs before the model is built: the classification
# head is newly initialised, so without a fixed seed two runs are not comparable.
SEED = 42
pl.seed_everything(SEED, workers=True)

MAX_EPOCHS = 1000                # upper bound; early stopping is expected to end the run
VAL_CHECK_EVERY_N_BATCHES = 20   # an int val_check_interval means "every N training batches"
EARLY_STOPPING_PATIENCE = 10     # validation checks without improvement before stopping

logger = pl.loggers.TensorBoardLogger(
    save_dir='.',
    name='lightning_logs'
)

# Keep the checkpoint with the lowest validation loss; the file name records the
# logger version and the validation accuracy at that point.
save_dir = "./ckpt/"
checkpoint_callback = ModelCheckpoint(
    monitor="val_epoch_loss",
    dirpath=save_dir,
    filename=f"hate_{logger.version}_" + "{val_acc:.4f}",
    mode="min",
)

module = TextClassificationModule(
    huggingface_model_name,
    labels=labels,
    lr=1e-4
)
callbacks = [
    EarlyStopping("val_epoch_loss", mode="min", patience=EARLY_STOPPING_PATIENCE),
    checkpoint_callback
]

trainer = pl.Trainer(max_epochs=MAX_EPOCHS,
                logger=logger,
                accelerator="gpu",
                val_check_interval=VAL_CHECK_EVERY_N_BATCHES,
                callbacks=callbacks)
trainer.fit(module, train_dl, val_dl)

Some weights of the model checkpoint at monologg/koelectra-small-v3-discriminator were not used when initializing ElectraForSequenceClassification: ['discriminator_predictions.dense_prediction.weight', 'discriminator_predictions.dense_prediction.bias', 'discriminator_predictions.dense.weight', 'discriminator_predictions.dense.bias']
- This IS expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at monologg/koelectra-small-v3-discriminator and are newly initialized

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]