# LLM 으로 뉴스기사 분류해보기

In [1]:
import os
from dotenv import load_dotenv
from huggingface_hub import login

load_dotenv()

hugging_face_token = os.getenv("HUGGING_FACE_TOKEN")

login(hugging_face_token)

The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: write).
Your token has been saved to /Users/joyuiyeong/.cache/huggingface/token
Login successful


## GEMMA 모델과 Tokenizer 로드하기

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto")
print(model.device)
model

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

mps:0


GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

In [3]:
import torch


def get_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif torch.backends.mps.is_available():
        return torch.device("mps")
    else:
        return torch.device("cpu")


def clear_cache():
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif torch.backends.mps.is_available():
        torch.mps.empty_cache()


my_device = get_device()
my_device

device(type='mps')

## Zero-Shot 분류 함수 정의

In [4]:
def tokenize(device, text):
    tokenized_text = tokenizer(text, return_tensors="pt").to(device)
    return tokenized_text["input_ids"], tokenized_text["attention_mask"]


def zero_shot_classification(device, task_description, text, candidate_labels):
    question_input_ids, question_attention_mask = tokenize(
        device, task_description + text
    )
    scores = []
    for label in candidate_labels:
        label_input_ids, label_attention_mask = tokenize(device, label)
        num_label_tokens = label_input_ids.shape[-1] - 1

        input_ids = torch.concatenate(
            [question_input_ids, label_input_ids[..., 1:]], axis=-1
        )
        attention_mask = torch.concatenate(
            [question_attention_mask, label_attention_mask[..., 1:]], axis=-1
        )

        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        num_total_token = input_ids.shape[-1]
        score = sum(
            [
                logits[0, num_total_token - i, label_input_ids[0, i].item()]
                for i in range(num_label_tokens, 0, -1)
            ]
        )
        scores.append(score)

        del input_ids
        del attention_mask
        del logits

        clear_cache()
    return scores

## AG News 데이터셋 로드하기
- 4개의 뉴스 카테고리
    - 1: World, 2: Sports, 3: Business, 4: Science/Technology

In [5]:
from datasets import load_dataset

ds_ag_news = load_dataset("fancyzhx/ag_news")
ds_ag_news

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})

In [6]:
def preprocess_function(data):
    return tokenizer(data["text"])


tokenized_ds = ds_ag_news.map(preprocess_function, batched=True)

## test 데이터셋으로 분류해보기

In [7]:
from tqdm import tqdm

NUM_TEST = 50
total_correctness = 0
task_description = "A short news article is given. Decide which category the article belongs to. Article: "
candidate_labels = [
    "Answer: World",
    "Answer: Sports",
    "Answer: Business",
    "Answer: Science/Technology",
]

for i in tqdm(range(NUM_TEST)):
    text = tokenized_ds["test"][i]["text"]
    label = tokenized_ds["test"][i]["label"]

    scores = zero_shot_classification(
        device=my_device,
        task_description=task_description,
        text=text,
        candidate_labels=candidate_labels,
    )

    prediction = torch.argmax(torch.Tensor(scores)).item()
    if prediction == label:
        total_correctness += 1

100%|██████████| 50/50 [00:30<00:00,  1.62it/s]


In [8]:
print("Total Correctness: ", total_correctness)
print("Accuracy: ", total_correctness / NUM_TEST)

Total Correctness:  9
Accuracy:  0.18


In [9]:
import gc

gc.collect()

clear_cache()