# DistilBERT + IMDb Text Classification

This notebook fine-tunes the [DistilBERT](https://huggingface.co/distilbert/distilbert-base-uncased) model on the [IMDb](https://huggingface.co/datasets/imdb) dataset for text classification.

* Reference: [Text classification](https://huggingface.co/docs/transformers/en/tasks/sequence_classification)

Run the cell below to install the packages required for this exercise.

In [1]:
!pip install datasets evaluate transformers



## 1. Loading the dataset

In [2]:
from datasets import load_dataset

imdb = load_dataset("imdb")
print(imdb)

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label'],
        num_rows: 50000
    })
})


In [3]:
imdb["test"][0]

{'text': 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as 

In [4]:
imdb["unsupervised"][0]

{'text': 'This is just a precious little diamond. The play, the script are excellent. I cant compare this movie with anything else, maybe except the movie "Leon" wonderfully played by Jean Reno and Natalie Portman. But... What can I say about this one? This is the best movie Anne Parillaud has ever played in (See please "Frankie Starlight", she\'s speaking English there) to see what I mean. The story of young punk girl Nikita, taken into the depraved world of the secret government forces has been exceptionally over used by Americans. Never mind the "Point of no return" and especially the "La femme Nikita" TV series. They cannot compare the original believe me! Trash these videos. Buy this one, do not rent it, BUY it. BTW beware of the subtitles of the LA company which "translate" the US release. What a disgrace! If you cant understand French, get a dubbed version. But you\'ll regret later :)',
 'label': -1}

## 2. Loading the tokenizer

In [5]:
from transformers import AutoTokenizer

model_id = "distilbert/distilbert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(model_id)

In [6]:
print(f"vocab_size: {tokenizer.vocab_size}")
print(f"model_max_length: {tokenizer.model_max_length}")
print(f"max_len_single_sentence: {tokenizer.max_len_single_sentence}")
print(f"max_len_sentences_pair: {tokenizer.max_len_sentences_pair}")
print(f"model_input_names: {tokenizer.model_input_names}")

vocab_size: 30522
model_max_length: 512
max_len_single_sentence: 510
max_len_sentences_pair: 509
model_input_names: ['input_ids', 'attention_mask']
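
The two `max_len_*` values follow from `model_max_length` once the special tokens are subtracted. The sketch below reproduces that arithmetic by hand, assuming the BERT-style layouts `[CLS] A [SEP]` for a single sentence and `[CLS] A [SEP] B [SEP]` for a pair. The helper name `content_budget` is made up for illustration and is not part of the library.

```python
MODEL_MAX_LENGTH = 512  # tokenizer.model_max_length printed above


def content_budget(num_sequences: int) -> int:
    """Tokens left for real text after reserving the special tokens.

    Assumes one [CLS] at the start and one [SEP] after each sequence.
    """
    special_tokens = 1 + num_sequences
    return MODEL_MAX_LENGTH - special_tokens


single = content_budget(1)  # one sentence: [CLS] + [SEP]
pair = content_budget(2)    # sentence pair: [CLS] + [SEP] + [SEP]
print(single, pair)
```

With the tokenizer loaded, `tokenizer.num_special_tokens_to_add(pair=False)` and `tokenizer.num_special_tokens_to_add(pair=True)` should report the same overhead.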


## 3. Preprocessing the data

In [7]:
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)

In [8]:
tokenized_imdb = imdb.map(preprocess_function, batched=True)
print(tokenized_imdb)

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 25000
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 25000
    })
    unsupervised: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 50000
    })
})


In [9]:
print(tokenized_imdb["train"])
print(tokenized_imdb["train"][0].keys())

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 25000
})
dict_keys(['text', 'label', 'input_ids', 'attention_mask'])


In [10]:
print(imdb["train"][0]["text"])
print(tokenized_imdb["train"][0]["input_ids"])
print(tokenizer.convert_ids_to_tokens(tokenized_imdb["train"][0]["input_ids"]))

I rented I AM CURIOUS-YELLOW from my video store because of all the controversy that surrounded it when it was first released in 1967. I also heard that at first it was seized by U.S. customs if it ever tried to enter this country, therefore being a fan of films considered "controversial" I really had to see this for myself.<br /><br />The plot is centered around a young Swedish drama student named Lena who wants to learn everything she can about life. In particular she wants to focus her attentions to making some sort of documentary on what the average Swede thought about certain political issues such as the Vietnam War and race issues in the United States. In between asking politicians and ordinary denizens of Stockholm about their opinions on politics, she has sex with her drama teacher, classmates, and married men.<br /><br />What kills me about I AM CURIOUS-YELLOW is that 40 years ago, this was considered pornographic. Really, the sex and nudity scenes are few and far between, eve

## 4. Data collator

In [11]:
train_ds = tokenized_imdb["train"]
print(train_ds)
print(type(train_ds[0]))
print(type(train_ds["input_ids"]))

Dataset({
    features: ['text', 'label', 'input_ids', 'attention_mask'],
    num_rows: 25000
})
<class 'dict'>
<class 'list'>


In [12]:
# Slice once: indexing a Dataset with a slice returns a dict of column lists,
# which avoids re-reading a whole column for every element.
first_two = train_ds[:2]
samples_for_test = {
    # "text" is left out; only the tokenized fields and the label are collated.
    "input_ids": first_two["input_ids"],
    "attention_mask": first_two["attention_mask"],
    "label": first_two["label"],
}

In [13]:
from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [14]:
collated_samples_for_test = data_collator(samples_for_test)
print(collated_samples_for_test)

{'input_ids': tensor([[  101,  1045, 12524,  1045,  2572,  8025,  1011,  3756,  2013,  2026,
          2678,  3573,  2138,  1997,  2035,  1996,  6704,  2008,  5129,  2009,
          2043,  2009,  2001,  2034,  2207,  1999,  3476,  1012,  1045,  2036,
          2657,  2008,  2012,  2034,  2009,  2001,  8243,  2011,  1057,  1012,
          1055,  1012,  8205,  2065,  2009,  2412,  2699,  2000,  4607,  2023,
          2406,  1010,  3568,  2108,  1037,  5470,  1997,  3152,  2641,  1000,
          6801,  1000,  1045,  2428,  2018,  2000,  2156,  2023,  2005,  2870,
          1012,  1026,  7987,  1013,  1028,  1026,  7987,  1013,  1028,  1996,
          5436,  2003,  8857,  2105,  1037,  2402,  4467,  3689,  3076,  2315,
         14229,  2040,  4122,  2000,  4553,  2673,  2016,  2064,  2055,  2166,
          1012,  1999,  3327,  2016,  4122,  2000,  3579,  2014,  3086,  2015,
          2000,  2437,  2070,  4066,  1997,  4516,  2006,  2054,  1996,  2779,
         25430, 14728,  2245,  2055,  

In [15]:
print(collated_samples_for_test.keys())

dict_keys(['input_ids', 'attention_mask', 'labels'])


The output shows that the `label` key has disappeared from the batch and a `labels` key has appeared in its place. This happens because `DataCollatorWithPadding` renames the key as follows.

```python
if "label" in batch:
    batch["labels"] = batch["label"]
    del batch["label"]
if "label_ids" in batch:
    batch["labels"] = batch["label_ids"]
    del batch["label_ids"]
```

This step prepares the value that is passed to the model's `forward()` method as the `labels` argument.
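
Besides renaming the key, the collator pads every sequence in a batch to the length of the longest one. The following is a plain-Python sketch of that idea, not the library's implementation: the function name `pad_batch`, the toy token ids, and the pad id `0` are assumptions for illustration (in practice, check `tokenizer.pad_token_id`), and the real collator returns tensors rather than lists.

```python
PAD_TOKEN_ID = 0  # assumed pad id; verify with tokenizer.pad_token_id


def pad_batch(features: dict) -> dict:
    """Pad each sequence to the longest in the batch and rename `label`."""
    max_len = max(len(ids) for ids in features["input_ids"])
    batch = {"input_ids": [], "attention_mask": []}
    for ids, mask in zip(features["input_ids"], features["attention_mask"]):
        n_pad = max_len - len(ids)
        batch["input_ids"].append(ids + [PAD_TOKEN_ID] * n_pad)
        # 0 in the mask tells the model to ignore the padded positions.
        batch["attention_mask"].append(mask + [0] * n_pad)
    if "label" in features:
        # Same renaming as the snippet above: `label` becomes `labels`.
        batch["labels"] = list(features["label"])
    return batch


# Toy batch with two sequences of different lengths (ids are illustrative).
toy = {
    "input_ids": [[101, 2023, 102], [101, 2023, 3185, 2001, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
    "label": [0, 1],
}
padded = pad_batch(toy)
print(padded)
```

Padding per batch rather than to the full 512 tokens keeps batches of short reviews small, which is the main reason to pad in the collator instead of in the preprocessing step.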

## 5. Loading the model

In [16]:
import torch

# Check whether a GPU is available and choose the device accordingly
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = "cpu"
print(f"device: {device}")

device: cuda


In [17]:
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
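
The two dictionaries have to be exact inverses of each other. One way to avoid a mismatch is to write only `id2label` and derive the reverse mapping from it, as in this small sketch:

```python
id2label = {0: "NEGATIVE", 1: "POSITIVE"}

# Derive the reverse mapping instead of typing it by hand.
label2id = {label: idx for idx, label in id2label.items()}

# Round-trip check: every id maps to a label and back to the same id.
assert all(label2id[label] == idx for idx, label in id2label.items())
print(label2id)
```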

In [18]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

# Reuse model_id from the tokenizer cell so both load the same checkpoint.
model = AutoModelForSequenceClassification.from_pretrained(
    model_id, num_labels=2, id2label=id2label, label2id=label2id
).to(device)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 6. Fine-tuning

In [19]:
import evaluate

accuracy = evaluate.load("accuracy")

In [20]:
import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)
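
To make the metric concrete, the sketch below recomputes accuracy by hand on made-up logits, without NumPy or `evaluate`. The names `argmax` and `accuracy_from_logits` and the toy values are illustrative assumptions; the function mirrors what `compute_metrics` does (take the highest-scoring class per row, then compare with the labels).

```python
def argmax(row):
    """Index of the largest value in a list of scores."""
    return max(range(len(row)), key=row.__getitem__)


def accuracy_from_logits(logits, labels):
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Four made-up examples with two class scores each (NEGATIVE, POSITIVE).
toy_logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5], [1.2, 0.7]]
toy_labels = [0, 1, 0, 0]

result = accuracy_from_logits(toy_logits, toy_labels)
print(result)
```

Three of the four toy predictions match their labels, so the reported accuracy is 0.75.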

In [22]:
import os
os.environ["WANDB_DISABLED"] = "true"

In [23]:
training_args = TrainingArguments(
    output_dir="distilbert-base-uncased-finetuned-imdb",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.2216 | 0.217927 | 0.91608 |
| 2 | 0.1487 | 0.232638 | 0.93176 |


TrainOutput(global_step=3126, training_loss=0.20488157290643558, metrics={'train_runtime': 3111.0056, 'train_samples_per_second': 16.072, 'train_steps_per_second': 1.005, 'total_flos': 6556904415524352.0, 'train_loss': 0.20488157290643558, 'epoch': 2.0})
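
The `global_step` value can be checked against the training arguments. The sketch below assumes a single device and no gradient accumulation. It also compares the two epochs from the results above: `load_best_model_at_end=True` is set without `metric_for_best_model`, and in that case the Trainer is documented to fall back to the evaluation loss, so the checkpoint restored at the end may not be the one with the highest accuracy.

```python
import math

train_rows = 25000  # size of the train split
batch_size = 16     # per_device_train_batch_size
epochs = 2          # num_train_epochs

# Assumption: one device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(train_rows / batch_size)
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)

# Evaluation results copied from the training output above.
eval_log = [
    {"epoch": 1, "eval_loss": 0.217927, "eval_accuracy": 0.91608},
    {"epoch": 2, "eval_loss": 0.232638, "eval_accuracy": 0.93176},
]
best_by_loss = min(eval_log, key=lambda row: row["eval_loss"])
best_by_accuracy = max(eval_log, key=lambda row: row["eval_accuracy"])
print(best_by_loss["epoch"], best_by_accuracy["epoch"])
```

If accuracy is the criterion that matters, passing `metric_for_best_model="accuracy"` to `TrainingArguments` is one option to consider.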

## 7. Saving and using the model

In [24]:
finetuned_model_path = "./fine-tuned-distilbert-imdb-text-classification"
tokenizer.save_pretrained(finetuned_model_path)
model.save_pretrained(finetuned_model_path)

tokenizer = AutoTokenizer.from_pretrained(finetuned_model_path)
model = AutoModelForSequenceClassification.from_pretrained(finetuned_model_path).to(device)

In [25]:
text = "This was a masterpiece. Not completely faithful to the books, but enthralling from beginning to end. Might be my favorite of the three."

In [26]:
inputs = tokenizer(text, return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items() if isinstance(v, torch.Tensor)}

In [27]:
with torch.no_grad():
    logits = model(**inputs).logits

In [28]:
predicted_class_id = logits.argmax(dim=-1).item()
model.config.id2label[predicted_class_id]

'POSITIVE'
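
The predicted label alone does not say how confident the model is. A softmax over the logits turns them into probabilities; the sketch below shows the computation in plain Python on made-up logits (the values `[-2.0, 2.0]` are an assumption, not output from the fine-tuned model).

```python
import math


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    peak = max(logits)
    # Subtracting the maximum keeps exp() numerically stable.
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


id2label = {0: "NEGATIVE", 1: "POSITIVE"}
example_logits = [-2.0, 2.0]  # made-up values for illustration

probs = softmax(example_logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(id2label[best], round(probs[best], 3))
```

With the tensors from the cells above, `torch.softmax(logits, dim=-1)` should give the corresponding probabilities for the real prediction.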