In [77]:
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    LogitsProcessorList,
    MinLengthLogitsProcessor,
    TopKLogitsWarper,
    TemperatureLogitsWarper,
    StoppingCriteriaList,
    MaxLengthCriteria,
)
import torch

tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
model = AutoModelForCausalLM.from_pretrained("facebook/xglm-564M")

In [98]:
def get_response(prompt_text):
    input_prompt = f'คำสุภาพของ \"{prompt_text}\" คือ'
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids

    # instantiate logits processors
    logits_processor = LogitsProcessorList(
        [
            MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
        ]
    )
    # instantiate logits processors
    logits_warper = LogitsProcessorList(
        [
            TopKLogitsWarper(50),
            TemperatureLogitsWarper(0.7),
        ]
    )

    stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=len(prompt_text)*2)])

    torch.manual_seed(0)
    outputs = model.sample(
        input_ids,
        logits_processor=logits_processor,
        logits_warper=logits_warper,
        stopping_criteria=stopping_criteria,
    )

    return (tokenizer.batch_decode(outputs, skip_special_tokens=True))[0]

In [99]:
import pandas as pd

test_df = pd.read_csv("test.csv")

In [100]:
non_polite = test_df[test_df["labels"] == "ไม่สุภาพ"]

In [101]:
from tqdm import tqdm
polite = []
for _, row in tqdm(non_polite.iterrows()):
    polite.append(get_response(row["text"]))

40it [04:17,  6.44s/it]


In [105]:
polite_filtered = []
for p in polite:
    idx = p.find("คือ")
    text = p[idx+3:].strip()
    polite_filtered.append(text)

In [107]:
non_polite["polite"] = polite_filtered

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  non_polite["polite"] = polite_filtered


In [92]:
non_polite.drop(columns=["labels"]).to_csv("result.csv", index=False)