<a href="https://colab.research.google.com/github/winny0217/gpt2/blob/main/gpt2_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [19]:
from transformers import GPT2LMHeadModel, AutoTokenizer
import torch

# 選擇 GPT-2 中文模型
model_name = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

if torch.cuda.is_available():
    model.to("cuda")

def generate_next_phrase(text, max_new_tokens=5):
    input_ids = tokenizer.encode(text, return_tensors="pt")
    if torch.cuda.is_available():
        input_ids = input_ids.to("cuda")

    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,  # 最多補 5 個 token
        do_sample=True,
        top_k=50,
        temperature=0.8,
        repetition_penalty=1.2,
        pad_token_id=tokenizer.eos_token_id
    )
    generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    added_text = generated_text[len(text):].strip()

    # 控制補的字數（三個字）
    phrase = ''.join([ch for ch in added_text if ch.strip()])[:3]
    return phrase, generated_text

# ==== 測試 ====
sentence = "我今天想去"
phrase, full_sentence = generate_next_phrase(sentence)
print(f"原句：{sentence}")
print(f"補的短語：約三個字 → {phrase}")
print(f"完整句子：{full_sentence}")

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


原句：我今天想去
補的短語：約三個字 → 想去可
完整句子：我 今 天 想 去 可 看 到 老 鼠


In [15]:
from transformers import pipeline, GPT2LMHeadModel, AutoTokenizer
import torch

# 使用中文 GPT-2
model_name = "uer/gpt2-chinese-cluecorpussmall"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)


if torch.cuda.is_available():
    model.to("cuda")

# 計算句子困惑度（Perplexity）
def calculate_perplexity(sentence):
    encodings = tokenizer(sentence, return_tensors='pt')
    if torch.cuda.is_available():
        encodings = {k: v.to("cuda") for k, v in encodings.items()}

    with torch.no_grad():
        outputs = model(**encodings, labels=encodings["input_ids"])
    loss = outputs.loss
    return torch.exp(loss).item()  # perplexity

# 生成候選詞並過濾
def suggest_next_word(sentence, max_new_tokens=3, num_return_sequences=5):
    results = generator(
        sentence,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_k=50,
        temperature=0.8,
        num_return_sequences=num_return_sequences
    )

    candidates = []
    for r in results:
        text = r['generated_text'][len(sentence):].strip()
        word = ''.join([ch for ch in text if ch.strip()])[:3]
        if word:
            candidates.append(word)

    # 根據困惑度排序（越小越通順）
    scored_candidates = []
    for word in candidates:
        new_sentence = sentence + word
        perplexity = calculate_perplexity(new_sentence)
        scored_candidates.append((word, perplexity))

    scored_candidates.sort(key=lambda x: x[1])
    return scored_candidates[:3]  # 取前3個最通順的詞

Device set to use cuda:0


In [16]:
sentence = "我今天想去"
print(f"初始句子：{sentence}")

while True:
    candidates = suggest_next_word(sentence)
    print("建議詞（已按語境一致性排序）:")
    for i, (word, score) in enumerate(candidates):
        print(f"{i}. {word} (PPL={score:.2f})")

    choice = input("選擇(輸入序號採用，n重新生成，exit結束): ").strip().lower()

    if choice == "exit":
        break
    elif choice.isdigit() and 0 <= int(choice) < len(candidates):
        sentence += candidates[int(choice)][0]
        print(f"目前句子：{sentence}")
    else:
        print("重新生成...")

初始句子：我今天想去
建議詞（已按語境一致性排序）:
0. 大拇指 (PPL=36.61)
1. 到这里 (PPL=54.71)
2. 到这家 (PPL=59.19)
選擇(輸入序號採用，n重新生成，exit結束): n
重新生成...
建議詞（已按語境一致性排序）:
0. 里面吃 (PPL=48.85)
1. 歌城唱 (PPL=56.43)
2. 东街逛 (PPL=79.78)
選擇(輸入序號採用，n重新生成，exit結束): 2
目前句子：我今天想去东街逛
建議詞（已按語境一致性排序）:
0. 到这家 (PPL=64.98)
1. 想到去 (PPL=109.77)
2. 想起这 (PPL=124.97)
選擇(輸入序號採用，n重新生成，exit結束): exit
