# Решение домашнего задания #8

In [6]:
import ast
from concurrent.futures import ThreadPoolExecutor
import os
from typing import Literal
import re

import kagglehub
from transformers import pipeline

import pandas as pd
import numpy as np
from tabulate import tabulate
from tqdm import tqdm

import swifter

In [None]:
HARD_PREDICTION_PROMPT = """Classify the following BBC news article into one of the following categories: {categories}.

Article:
{article}

What is the most appropriate category for this article? Write the name of the category only, without any additional text or explanation. Do not use any punctuation marks."""

SOFT_PREDICTION_PROMPT = """Given the following BBC news article, estimate the likelihood (in percentage) that it belongs to each of the following categories: {categories}. Ensure that the total adds up to 100%.

Article:
{article}

What are the predicted probabilities for each category? Write the probabilities in the following format:
Category1: <probabilty1>, Category2: <probabilty2>, Category3: <probabilty3>, ...
Do not use any additional text or explanations. Ensure that the percentages are whole numbers and sum to 100%."""

In [8]:
# path = kagglehub.dataset_download("adityajn105/flickr8k")                              # C:\\Users\\setday\\.cache\\kagglehub\\datasets\\adityajn105\\flickr8k\\versions\\1
# path = kagglehub.dataset_download("alfathterry/bbc-full-text-document-classification") # C:\\Users\\setday\\.cache\\kagglehub\\datasets\\alfathterry\\bbc-full-text-document-classification\\versions\\1

In [None]:
# path = kagglehub.dataset_download("alfathterry/bbc-full-text-document-classification")
path = "C:\\Users\\setday\\.cache\\kagglehub\\datasets\\alfathterry\\bbc-full-text-document-classification\\versions\\1"

data = pd.read_csv(path + "\\bbc_data.csv")
categories = data["labels"].unique()

In [204]:
data["soft_query"] = data["data"].apply(lambda data: f'User: {SOFT_PREDICTION_PROMPT.format(categories=", ".join(categories), article=data)}\n\nModel: Article category probabilities are:')
data["hard_query"] = data["data"].apply(lambda data: f'User: {HARD_PREDICTION_PROMPT.format(categories=", ".join(categories), article=data)}\n\nModel: Article category is:')

In [205]:
data

Unnamed: 0,data,labels,hard_query,soft_query,hard_answers
0,Musicians to tackle US red tape Musicians gro...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
1,"U2s desire to be number one U2, who have won ...",entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2,Rocker Doherty in on-stage fight Rock singer ...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
3,Snicket tops US box office chart The film ada...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
4,"Oceans Twelve raids box office Oceans Twelve,...",entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
...,...,...,...,...,...
2220,Warning over Windows Word files Writing a Mic...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",business
2221,Fast lifts rise into record books Two high-sp...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2222,Nintendo adds media playing to DS Nintendo is...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2223,Fast moving phone viruses appear Security fir...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment


In [12]:
pipe = pipeline("text-generation", model="meta-llama/Llama-3.2-1B", device='cuda', max_new_tokens=30)

In [None]:
def extract_category(texts: list[str], categories: list[str]):
    texts = [text.lower() for text in texts]
    categories = [category.lower() for category in categories]

    extracted = []
    for text in texts:
        for category in categories:
            if category in text:
                extracted.append(category)
                break
        else:
            extracted.append('none')
    return extracted

def get_categories(queries: list[str], categories: list[str], batch_size=32):
    results = []
    for i in tqdm(range(0, len(queries), batch_size)):
        batch_data = queries.iloc[i:i + batch_size].to_list()
        outs = pipe.predict(batch_data)
        outs = [
            out[0]["generated_text"][len(query):]
            for out, query in zip(outs, batch_data)
        ]
        results.extend(extract_category(outs, categories))
    return results

In [None]:
data["hard_answers"] = get_categories(data["hard_query"], categories)

  0%|          | 0/70 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id`:None for open-end generation.
Setting `pad_token_id` to `eos_token_id

In [4]:
data["hard_answers"].value_counts()

hard_answers
entertainment    1611
business          323
sport             174
politics           85
tech               30
none                2
Name: count, dtype: int64

In [11]:
data = pd.read_csv("tst.csv")

In [None]:
def get_categories_probs(queries: list[str], categories: list[str], batch_size=64):
    results = []
    for i in tqdm(range(0, len(queries), batch_size)):
        batch_data = queries.iloc[i:i + batch_size].to_list()
        outs = pipe.predict(batch_data)
        outs = [
            out[0]["generated_text"][len(query):]
            for out, query in zip(outs, batch_data)
        ]
        results.extend(outs)
    return results

def extract_categories_probs(answers: list[str], categories: list[str]):
    splits = [
        a.lower()\
            .replace("%", "")\
                .replace("=", " ")\
                    .replace(":", " ")\
                        .replace("(", "")\
                            .replace(")", "")\
                                .replace("\n", " | ")\
                                    .replace(",", " ")\
                                        .replace(". ", " ")\
                                            .split()\
        for a in answers
    ]
    extracted = [
        {
            category: float(proba)
            for category, proba in zip(s, s[1:])
            if category in categories and (proba.isdigit() or re.match(r'^-?\d+(?:\.\d*)$', proba) is not None)
        }
        for s in splits
    ]
    filtered = [
        e if all([category in e for category in categories])
        else {}
        for e in extracted
    ]
    return [
        {
            category: float(proba) / sum(f.values())
            for category, proba in f.items()
        }
        for f in filtered
    ]

data["soft_probs"] = extract_categories_probs(get_categories_probs(data["soft_query"], categories), categories)

In [118]:
data.to_csv("123.csv")

In [214]:
data

Unnamed: 0,data,labels,hard_query,soft_query,hard_answers
0,Musicians to tackle US red tape Musicians gro...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
1,"U2s desire to be number one U2, who have won ...",entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2,Rocker Doherty in on-stage fight Rock singer ...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
3,Snicket tops US box office chart The film ada...,entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
4,"Oceans Twelve raids box office Oceans Twelve,...",entertainment,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
...,...,...,...,...,...
2220,Warning over Windows Word files Writing a Mic...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",business
2221,Fast lifts rise into record books Two high-sp...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2222,Nintendo adds media playing to DS Nintendo is...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
2223,Fast moving phone viruses appear Security fir...,tech,User: Classify the following BBC news article ...,"User: Given the following BBC news article, es...",entertainment
