<a href="https://colab.research.google.com/github/winny0217/gpt2/blob/main/gpt2_4ipynb.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import torch.nn.functional as F

In [16]:
# 使用繁體中文 GPT-2 模型
model_name = "ckiplab/gpt2-base-chinese"
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name)

# 設置模型到GPU（如果可用）
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21128, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=21128, bias=False)
)

In [11]:
def get_three_word_suggestions(input_text, top_k=5, temperature=0.8):
    # 編碼輸入文字
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)

    suggestions = []

    with torch.no_grad():
        # 一次性生成三個字的詞語
        current_ids = input_ids.clone()

        for position in range(3):  # 生成三個字
            # 模型預測
            outputs = model(current_ids)
            logits = outputs.logits[0, -1, :] / temperature  # 應用溫度參數

            # 獲取前k個候選
            top_k_probs, top_k_ids = torch.topk(F.softmax(logits, dim=-1), top_k)

            # 解碼候選字符
            position_candidates = []
            for i in range(top_k):
                token_id = top_k_ids[i].item()
                token = tokenizer.decode([token_id]).strip()
                prob = top_k_probs[i].item()

                # 過濾特殊符號和空白
                if token and token not in ['<|endoftext|>', '[UNK]', '<pad>'] and len(token.strip()) > 0:
                    position_candidates.append({
                        'token': token,
                        'probability': prob,
                        'token_id': token_id
                    })

            suggestions.append(position_candidates[:top_k])

            # 為下一次預測選擇概率最高的字符
            if position_candidates:
                selected_token_id = position_candidates[0]['token_id']
                current_ids = torch.cat([current_ids, torch.tensor([[selected_token_id]], device=device)], dim=1)

    return suggestions

In [12]:
def generate_complete_words(input_text, num_words=5):
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    complete_words = []

    with torch.no_grad():
        for _ in range(num_words):
            current_ids = input_ids.clone()
            word = ""

            # 生成三個字的完整詞語
            for _ in range(3):
                outputs = model(current_ids)
                logits = outputs.logits[0, -1, :]

                # 使用較低的溫度確保連貫性，但保持一定隨機性
                probs = F.softmax(logits / 0.7, dim=-1)

                # 從概率分布中採樣，而不是總是選擇最高概率的
                next_token_id = torch.multinomial(probs, num_samples=1).item()
                next_token = tokenizer.decode([next_token_id]).strip()

                # 過濾無效字符
                if next_token and next_token not in ['<|endoftext|>', '[UNK]', '<pad>'] and len(next_token.strip()) > 0:
                    word += next_token
                    current_ids = torch.cat([current_ids, torch.tensor([[next_token_id]], device=device)], dim=1)
                else:
                    break

            if len(word) == 3 and word not in complete_words:
                complete_words.append(word)

    return complete_words

In [13]:
def interactive_word_completion():
    """
    互動式詞語補全系統
    """
    print("=== 繁體中文三字詞語補全系統 ===")
    print("輸入文字，系統將建議三個字的詞語")
    print("輸入 'quit' 結束程式\n")

    while True:
        try:
            input_text = input("請輸入文字: ").strip()

            if input_text.lower() == 'quit':
                print("感謝使用！")
                break

            if not input_text:
                print("請輸入有效的文字")
                continue

            print(f"\n輸入文字：{input_text}")
            print("=" * 50)

            # 方案一：逐字建議
            print("【逐字建議】")
            suggestions = get_three_word_suggestions(input_text, top_k=3)

            for i, position_suggestions in enumerate(suggestions):
                print(f"第{i+1}個字的候選：")
                for j, candidate in enumerate(position_suggestions):
                    print(f"  {j+1}. {candidate['token']} (機率: {candidate['probability']:.3f})")
                print()

            # 方案二：完整詞語建議
            print("【完整三字詞語建議】")
            complete_words = generate_complete_words(input_text, num_words=5)

            for i, word in enumerate(complete_words):
                print(f"{i+1}. {input_text}{word}")

            print("\n" + "=" * 50 + "\n")

        except KeyboardInterrupt:
            print("\n程式中斷，感謝使用！")
            break
        except Exception as e:
            print(f"發生錯誤：{e}")
            continue

In [15]:
# ==== 示範用法 ====
if __name__ == "__main__":
    # 啟動互動模式
    interactive_word_completion()

=== 繁體中文三字詞語補全系統 ===
輸入文字，系統將建議三個字的詞語
輸入 'quit' 結束程式

請輸入文字: 我想去

輸入文字：我想去
【逐字建議】
第1個字的候選：
  1. , (機率: 0.067)
  2. 」 (機率: 0.047)
  3. 。 (機率: 0.037)

第2個字的候選：
  1. 但 (機率: 0.186)
  2. 我 (機率: 0.171)
  3. 這 (機率: 0.061)

第3個字的候選：
  1. 是 (機率: 0.319)
  2. 我 (機率: 0.317)
  3. 這 (機率: 0.042)

【完整三字詞語建議】
1. 我想去」。昨
2. 我想去此時,
3. 我想去,也不
4. 我想去好的,
5. 我想去個月我


請輸入文字: quit
感謝使用！
