# Load checkpoint

학습 중 저장된 모델을 다시 불러와 사용할 수 있습니다.   

In [1]:
import random
import os

import numpy as np
import pandas as pd
import datasets
from datasets import load_dataset, load_metric
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

from sklearn.model_selection import train_test_split

import torch
from torch.utils.data import Dataset, DataLoader

# Report how many CUDA devices PyTorch can see; nothing is printed when
# CUDA is unavailable.
if torch.cuda.is_available():
    print("사용가능한 gpu : ",torch.cuda.device_count())

사용가능한 gpu :  1


Reproduction을 위한 Seed 고정  
출처 : https://dacon.io/codeshare/2363?dtype=vote&s_id=0

In [2]:
RANDOM_SEED = 42


def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and every CUDA device), exports ``PYTHONHASHSEED`` and configures
    cuDNN for deterministic execution.

    Args:
        seed: Value applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Exported for any Python process started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # Seed all CUDA devices, not only the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune its algorithm choice per run, which
    # can change results between runs, so it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


seed_everything(RANDOM_SEED)

저장된 모델의 checkpoint 경로를 설정해줍니다. 

In [3]:
# Fine-tuned weights are restored from this directory.
save_checkpoint_path = "./checkpoints/checkpoint-3426"

# Hub identifier used to fetch the matching tokenizer.
model_checkpoint = "klue/bert-base"

# Per-device batch size shared by training and evaluation arguments.
batch_size = 32

huggingface에서 tokenizer를 가져오고 학습, 테스트 데이터를 토큰화합니다.

In [4]:
# Load the tokenizer identified by `model_checkpoint`, requesting the fast
# implementation via `use_fast=True`.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

In [5]:
# Read the labelled training table and the unlabelled test table in one pass;
# index_col=False keeps pandas from using the first column as the index.
dataset, test = (
    pd.read_csv(csv_path, index_col=False)
    for csv_path in ("data/train_data.csv", "data/test_data.csv")
)

In [6]:
# Hold out 20% of the labelled rows for validation; the fixed random_state
# makes the split repeatable.
dataset_train, dataset_val = train_test_split(
    dataset,
    test_size=0.2,
    random_state=RANDOM_SEED,
)

In [7]:
class BERTDataset(Dataset):
    """Map-style dataset of tokenized sentences with optional integer labels.

    When ``label_key`` is given the dataset is in ``"train"`` mode and every
    item carries a ``"label"`` entry; otherwise it is in ``"test"`` mode and
    items contain only the tokenizer output.
    """

    def __init__(self, dataset, sent_key, label_key, bert_tokenizer):
        """Tokenize every sentence up front and attach labels when available.

        Args:
            dataset: Mapping-like table; ``dataset[sent_key]`` and
                ``dataset[label_key]`` must be iterable.
            sent_key: Key of the column holding the raw sentences.
            label_key: Key of the column holding the labels, or ``None`` for
                unlabelled data.
            bert_tokenizer: Callable invoked as
                ``bert_tokenizer(text, truncation=True,
                return_token_type_ids=False)``; its result must support item
                assignment.
        """
        self.sentences = [
            bert_tokenizer(text, truncation=True, return_token_type_ids=False)
            for text in dataset[sent_key]
        ]

        # Identity check rather than `== None`, which array-like keys would
        # evaluate element-wise.
        self.mode = "train" if label_key is not None else "test"

        if self.mode == "train":
            self.labels = [np.int64(value) for value in dataset[label_key]]
            # Attach each label once here so __getitem__ is a pure lookup.
            for encoding, label in zip(self.sentences, self.labels):
                encoding["label"] = label
        else:
            # Placeholder labels; only their count is used (by __len__).
            self.labels = [np.int64(0)] * len(self.sentences)

    def __getitem__(self, i):
        """Return the tokenized sentence at index ``i`` (with ``"label"`` in train mode)."""
        return self.sentences[i]

    def __len__(self):
        """Return the number of examples."""
        return len(self.labels)


In [8]:
# Labelled splits: sentences come from "title", class ids from "topic_idx".
data_train = BERTDataset(
    dataset=dataset_train,
    sent_key="title",
    label_key="topic_idx",
    bert_tokenizer=tokenizer,
)
data_val = BERTDataset(
    dataset=dataset_val,
    sent_key="title",
    label_key="topic_idx",
    bert_tokenizer=tokenizer,
)

# The test table has no label column, so label_key is None.
data_test = BERTDataset(
    dataset=test,
    sent_key="title",
    label_key=None,
    bert_tokenizer=tokenizer,
)

`save_checkpoint_path`에서 학습된 모델을 가져옵니다. 

In [9]:
# Size of the classification head.
num_labels = 7

# Restore the fine-tuned sequence classifier from the saved checkpoint directory.
model = AutoModelForSequenceClassification.from_pretrained(
    save_checkpoint_path,
    num_labels=num_labels,
)

평가에 사용할 metric을 정의하고, 불러온 모델의 성능을 검증 데이터로 확인합니다.

In [10]:
# Load the GLUE "qnli" metric; the evaluation output of this notebook shows it
# reporting an accuracy value.
# NOTE(review): `load_metric` may be deprecated in newer `datasets` releases —
# confirm against the installed version.
metric = load_metric("glue", "qnli")

In [11]:
def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    Args:
        eval_pred: Pair of ``(predictions, labels)``; the highest-scoring
            index along axis 1 of ``predictions`` is taken as the predicted
            class.

    Returns:
        Whatever ``metric.compute`` returns for the predicted classes against
        the reference labels.
    """
    scores, references = eval_pred
    predicted_classes = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_classes, references=references)

In [12]:
# Metric used to pick the best checkpoint.
metric_name = "accuracy"

# Arguments for the Trainer; output goes to "test-nli".
args = TrainingArguments(
    "test-nli",
    evaluation_strategy="epoch",
    # load_best_model_at_end needs checkpoints saved on the same schedule as
    # evaluation; recent transformers releases raise a ValueError at
    # construction time when the two strategies differ.
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
)

In [13]:
# Bundle the restored model, its arguments, both labelled splits, the tokenizer
# and the metric callback into a single Trainer.
trainer = Trainer(
    args=args,
    model=model,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    train_dataset=data_train,
    eval_dataset=data_val,
)

In [14]:
# Evaluate the restored model on the Trainer's eval_dataset (data_val); the
# returned metrics dict is displayed as the cell output.
trainer.evaluate()

***** Running Evaluation *****
  Num examples = 9131
  Batch size = 32


{'eval_loss': 0.33201050758361816,
 'eval_accuracy': 0.8958493045668602,
 'eval_runtime': 7.4831,
 'eval_samples_per_second': 1220.223,
 'eval_steps_per_second': 38.22}

In [15]:
# Run inference on the unlabelled test split; element 0 of the returned tuple
# holds the per-class scores.
test_output = trainer.predict(data_test)

# Highest-scoring class index per example.
pred = np.argmax(test_output[0], 1)

# Fill the sample submission with the predicted topic ids.
submission = pd.read_csv('data/sample_submission.csv')
submission['topic_idx'] = pred

# to_csv does not create missing directories, so make sure the target exists.
os.makedirs("results", exist_ok=True)
submission.to_csv("results/klue-bert-load-model.csv", index=False)

***** Running Prediction *****
  Num examples = 9131
  Batch size = 32
