In [None]:
# Mount Google Drive at /content/drive; the project directory used by the
# following cell lives underneath this mount point.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
cd "drive/MyDrive/자연어처리/WSD_KOR"

/content/drive/MyDrive/자연어처리/WSD_KOR


In [None]:
%pip install wandb

Collecting wandb
  Downloading wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m61.9 MB/s[0m eta [36m0:00:00[0m
Collecting docker-pycreds>=0.4.0 (from wandb)
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting gitpython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.43-py3-none-any.whl (207 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m26.8 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-2.3.1-py2.py3-none-any.whl (289 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m289.0/289.0 kB[0m [31m39.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting setproctitle (from wandb)
  Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86

In [None]:
from sklearn.metrics import accuracy_score, f1_score, classification_report
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
import os
import yaml
import wandb

from prompt import prompt_generator, PromptDataset
from utils import read_config, read_files
from baseline import Baseline

In [None]:
class Train(Baseline):
    """Training / evaluation driver built on top of ``Baseline``.

    ``self.config``, ``self.model``, ``self.tokenizer`` and ``self.device`` are
    read by this class but not created here; they are expected to be provided
    by ``Baseline``.

    Intended call order: ``Train(config)`` -> ``ready_for_train()`` ->
    ``train()`` -> ``evaluation()``.
    """

    def __init__(self, config):
        """Read hyper-parameters from ``config`` and optionally restore weights.

        Args:
            config: Mapping with the keys ``batch_size``, ``num_epochs``,
                ``lr``, ``use_all_corpus`` and ``run_name`` (required), and
                ``load_checkpoint``, ``checkpoint_load_path`` and
                ``use_wandb`` (optional; default to ``False`` / ``""`` /
                ``False`` so configs written without them keep working).
        """
        super().__init__(config)
        self.batch_size = config["batch_size"]
        self.num_epochs = config["num_epochs"]
        self.lr = config["lr"]
        self.use_all_corpus = config["use_all_corpus"]
        self.run_name = config["run_name"]
        self.load_checkpoint = config.get("load_checkpoint", False)
        self.checkpoint_load_path = config.get("checkpoint_load_path", "")
        self.use_wandb = config.get("use_wandb", False)
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
        self.optimizer = None
        self.criterion = None
        self.prompt_format = None

        if self.load_checkpoint:
            # map_location="cpu" lets a checkpoint saved on a GPU be restored on
            # a CPU-only runtime; load_state_dict copies the tensors into the
            # model's existing parameters wherever those live.
            # NOTE(review): torch.load unpickles the file, so only load
            # checkpoints from a trusted source.
            state_dict = torch.load(self.checkpoint_load_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
            print(f"Checkpoint loaded from {self.checkpoint_load_path}")

    def ready_for_train(self):
        """Build the data loaders, optimizer and loss used by ``train``/``evaluation``.

        With ``use_all_corpus`` the training contexts form the training set and
        the validation contexts are split in half into validation and test
        sets; otherwise only the validation contexts are used and split into
        train/validation/test by ``prompt_generator``.
        """
        print("Reading files ...")
        train_contexts, val_contexts, dictionary = read_files(self.config)

        print("Generating prompt ...")
        if self.use_all_corpus:
            train_data = prompt_generator(train_contexts, dictionary, use_all=True, split=False)
            # length: 3,390,121
            val_data, test_data = prompt_generator(val_contexts, dictionary, use_all=True, split=True, test_size=0.5)
            # length: 374,927 / 2 each
        else:
            train_data, val_data, test_data = prompt_generator(val_contexts, dictionary, use_all=False, split=True, test_size=0.3)
            # length: 262,448 | 56,239 | 56,240
        # Keep one example input so the prompt layout is stored with the run config.
        self.prompt_format = test_data["inputs"][0]

        print("Making datasets ...")
        train_dataset = PromptDataset(train_data, tokenizer=self.tokenizer)
        val_dataset = PromptDataset(val_data, tokenizer=self.tokenizer)
        test_dataset = PromptDataset(test_data, tokenizer=self.tokenizer)

        print("Converting to dataloader ...")
        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
        # Only the training data needs shuffling; the evaluation metrics do not
        # depend on sample order.
        self.val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False)

        self.optimizer = AdamW(self.model.parameters(), lr=self.lr)
        self.criterion = torch.nn.CrossEntropyLoss()
        print("Ready for training !")

        print(f"Train dataloader length: {len(self.train_loader)}")
        print(f"Validation dataloader length: {len(self.val_loader)}")
        print(f"Test dataloader length: {len(self.test_loader)}")

    def _log_metrics(self, metrics, step=None):
        """Send ``metrics`` to Weights & Biases when ``use_wandb`` is enabled."""
        if not self.use_wandb:
            return
        if step is None:
            wandb.log(metrics)
        else:
            wandb.log(metrics, step=step)

    def _predict(self, loader, desc):
        """Run the model over ``loader`` without gradients.

        Each batch is unpacked as ``(input_ids, attention_masks, labels)``; the
        first element of the model output is used as the per-class scores and
        the prediction is the argmax over dimension 1.

        Returns:
            A ``(labels, predictions)`` pair of flat Python lists.
        """
        device = self.device
        self.model.eval()
        all_labels = []
        all_preds = []
        with torch.no_grad():
            for input_ids, attention_masks, labels in tqdm(loader, desc=desc):
                input_ids = input_ids.to(device)
                attention_masks = attention_masks.to(device)
                scores = self.model(input_ids=input_ids, attention_mask=attention_masks)[0]
                all_preds.extend(torch.argmax(scores, dim=1).cpu().tolist())
                all_labels.extend(labels.cpu().tolist())
        return all_labels, all_preds

    def train(self):
        """Fine-tune the model for ``num_epochs`` epochs.

        Writes ``<run_name>/config.yaml`` (the config plus the example prompt),
        saves ``<run_name>/checkpoints/epoch_<k>.pt`` after every epoch, and
        prints a validation classification report per epoch. The training loss
        is logged every 50 steps and validation accuracy / weighted F1 once per
        epoch when ``use_wandb`` is enabled.

        Raises:
            RuntimeError: If ``ready_for_train()`` has not been called.
        """
        if self.train_loader is None or self.val_loader is None:
            raise RuntimeError("Call ready_for_train() before train().")

        device = self.device
        print(f"Using Device: {device}")

        self.model.to(device)

        checkpoint_dir = os.path.join(self.run_name, "checkpoints")
        os.makedirs(checkpoint_dir, exist_ok=True)

        config_path = os.path.join(self.run_name, "config.yaml")
        self.config["prompt_format"] = self.prompt_format
        with open(config_path, 'w', encoding="utf-8") as f:
            yaml.dump(self.config, f, allow_unicode=True)
        print(f"Config saved at {config_path}")

        num_steps = len(self.train_loader)

        for epoch in range(self.num_epochs):
            self.model.train()
            # Wrap the loader itself (not enumerate) so tqdm knows the total
            # number of batches and can show a percentage / ETA.
            progress = tqdm(self.train_loader, desc=f"Epoch {epoch+1} train")
            for step, (input_ids, attention_masks, labels) in enumerate(progress):
                input_ids, attention_masks, labels = input_ids.to(device), attention_masks.to(device), labels.to(device)
                outputs = self.model(input_ids=input_ids, attention_mask=attention_masks)[0]
                loss = self.criterion(outputs, labels)
                if step % 50 == 0:
                    # Global step so that the loss curve is continuous across epochs.
                    self._log_metrics({"train_loss": loss.item()}, step=epoch * num_steps + step)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch+1}.pt")
            torch.save(self.model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at {checkpoint_path}")

            total_labels, total_preds = self._predict(self.val_loader, f"Epoch {epoch+1} validation")

            print(f"Epoch {epoch+1} validation result:")
            print(classification_report(y_true=total_labels, y_pred=total_preds, digits=4))

            acc = accuracy_score(y_true=total_labels, y_pred=total_preds)
            f1_weighted = f1_score(y_true=total_labels, y_pred=total_preds, average="weighted")
            # The key has to be spelled exactly like the metric registered with
            # wandb.define_metric ("val_f1_weighted"); the previous spelling
            # "val_f1_weigthed" never matched it.
            self._log_metrics({"val_acc": acc, "val_f1_weighted": f1_weighted})

    def evaluation(self):
        """Evaluate on the test loader and save the classification report.

        Prints the report, logs test accuracy / weighted F1 when ``use_wandb``
        is enabled, and writes the report to
        ``<run_name>/evaluation_report.txt``.

        Raises:
            RuntimeError: If ``ready_for_train()`` has not been called.
        """
        if self.test_loader is None:
            raise RuntimeError("Call ready_for_train() before evaluation().")

        self.model.to(self.device)
        total_labels, total_preds = self._predict(self.test_loader, "Evalutation ...")

        print("Evaluation result:")
        report = classification_report(y_true=total_labels, y_pred=total_preds, digits=4)
        print(report)

        acc = accuracy_score(y_true=total_labels, y_pred=total_preds)
        f1_weighted = f1_score(y_true=total_labels, y_pred=total_preds, average="weighted")
        self._log_metrics({"test_acc": acc, "test_f1_weighted": f1_weighted})

        output_dir = self.run_name
        # makedirs also handles a nested run_name and an already existing directory.
        os.makedirs(output_dir, exist_ok=True)

        report_path = os.path.join(output_dir, "evaluation_report.txt")
        with open(report_path, "w", encoding="utf-8") as f:
            f.write(report)
        print(f"Evaluation result saved at {report_path}")

In [None]:
import sys

# The notebook kernel is started with its own command-line arguments. Keep only
# the program name so that argument parsing done later in the notebook
# (presumably inside read_config — confirm) does not see them.
sys.argv = sys.argv[:1]

In [None]:
# Load the run configuration and build the trainer from it; `config` and
# `train` are reused by the cells below.
config = read_config()
train = Train(config)
# Display the model held by the trainer as the cell output.
train.model

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at monologg/koelectra-base-v3-discriminator and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ElectraForSequenceClassification(
  (electra): ElectraModel(
    (embeddings): ElectraEmbeddings(
      (word_embeddings): Embedding(35000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): ElectraEncoder(
      (layer): ModuleList(
        (0-11): 12 x ElectraLayer(
          (attention): ElectraAttention(
            (self): ElectraSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): ElectraSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): L

In [None]:
# Show the configuration the trainer is holding, as a sanity check before training.
train.config

{'data_path': 'Data/',
 'dict_path': 'Dict/',
 'train_fname': 'processed_train.csv',
 'val_fname': 'processed_eval.csv',
 'dict_fname': 'processed_dictionary.json',
 'use_all_corpus': False,
 'run_name': '0526-lr-1e-5',
 'checkpoint_load_path': '',
 'batch_size': 32,
 'num_epochs': 5,
 'lr': 1e-05}

In [None]:
# Read the data, build the prompts and create the train/validation/test loaders,
# the optimizer and the loss. Must run before train() and evaluation().
train.ready_for_train()

Reading files ...
Preprocessing for prompt ...
Making datasets ...
Converting to dataloader ...
Ready for training !
Train dataloader length: 8202
Validation dataloader length: 1758
Test dataloader length: 1758


In [None]:
# Start a Weights & Biases run in the "NLP-WSD-KOR" project and attach the run config to it.
wandb.init(project="NLP-WSD-KOR", config=config)

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [None]:
# Record the example prompt and the run name on the W&B run, then declare which
# summary statistic W&B should keep for each metric.
# NOTE(review): the names passed to define_metric only take effect if they are
# spelled exactly like the keys passed to wandb.log inside Train — confirm.
wandb.config["prompt_format"] = train.prompt_format
wandb.run.name = train.config["run_name"]
wandb.define_metric("train_loss", summary="min")
wandb.define_metric("val_acc", summary="max")
wandb.define_metric("val_f1_weighted", summary="max")



True

In [None]:
# Run the full training loop: one checkpoint and one validation report per epoch.
train.train()

Using Device: cuda
Config saved at 0526-lr-1e-5/config.yaml


Epoch 1 train: 8202it [59:02,  2.32it/s]


Checkpoint saved at 0526-lr-1e-5/checkpoints/epoch_1.pt


Epoch 1 validation: 100%|██████████| 1758/1758 [04:33<00:00,  6.42it/s]


Epoch 1 validation result:
              precision    recall  f1-score   support

           0       0.97      0.95      0.96     36362
           1       0.92      0.95      0.93     19877

    accuracy                           0.95     56239
   macro avg       0.95      0.95      0.95     56239
weighted avg       0.95      0.95      0.95     56239



Epoch 2 train: 8202it [59:01,  2.32it/s]


Checkpoint saved at 0526-lr-1e-5/checkpoints/epoch_2.pt


Epoch 2 validation: 100%|██████████| 1758/1758 [04:33<00:00,  6.43it/s]


Epoch 2 validation result:
              precision    recall  f1-score   support

           0       0.97      0.96      0.97     36362
           1       0.94      0.95      0.94     19877

    accuracy                           0.96     56239
   macro avg       0.95      0.96      0.96     56239
weighted avg       0.96      0.96      0.96     56239



Epoch 3 train: 8202it [59:01,  2.32it/s]


Checkpoint saved at 0526-lr-1e-5/checkpoints/epoch_3.pt


Epoch 3 validation: 100%|██████████| 1758/1758 [04:33<00:00,  6.43it/s]


Epoch 3 validation result:
              precision    recall  f1-score   support

           0       0.97      0.97      0.97     36362
           1       0.95      0.95      0.95     19877

    accuracy                           0.96     56239
   macro avg       0.96      0.96      0.96     56239
weighted avg       0.96      0.96      0.96     56239



Epoch 4 train: 8202it [59:00,  2.32it/s]


Checkpoint saved at 0526-lr-1e-5/checkpoints/epoch_4.pt


Epoch 4 validation: 100%|██████████| 1758/1758 [04:33<00:00,  6.43it/s]


Epoch 4 validation result:
              precision    recall  f1-score   support

           0       0.98      0.97      0.97     36362
           1       0.95      0.95      0.95     19877

    accuracy                           0.96     56239
   macro avg       0.96      0.96      0.96     56239
weighted avg       0.97      0.96      0.96     56239



Epoch 5 train: 8202it [59:02,  2.32it/s]


Checkpoint saved at 0526-lr-1e-5/checkpoints/epoch_5.pt


Epoch 5 validation: 100%|██████████| 1758/1758 [04:33<00:00,  6.42it/s]


Epoch 5 validation result:
              precision    recall  f1-score   support

           0       0.97      0.97      0.97     36362
           1       0.95      0.95      0.95     19877

    accuracy                           0.96     56239
   macro avg       0.96      0.96      0.96     56239
weighted avg       0.97      0.96      0.97     56239



In [None]:
# Evaluate the trained model on the held-out test loader and save the report
# under the run directory.
train.evaluation()

Evalutation ...: 100%|██████████| 1758/1758 [04:29<00:00,  6.53it/s]

Evaluation result:
              precision    recall  f1-score   support

           0       0.97      0.97      0.97     36363
           1       0.95      0.95      0.95     19877

    accuracy                           0.96     56240
   macro avg       0.96      0.96      0.96     56240
weighted avg       0.96      0.96      0.96     56240






In [None]:
# Close the W&B run so that all logged data is uploaded.
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_acc,▁
test_f1_weigthed,▁
train_loss,▄▆▅▂▅█▆▂▂▂▂▂▁▂▁▄▂▃▄▁▂▂▂▁▂▁▂▂▁▂▃▂▂▁▂▁▂▁▂▂

0,1
test_acc,0.96394
test_f1_weigthed,0.96396
