In [1]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

In [2]:
# 設定裝置為 GPU 或 CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# 載入訓練好的模型和 tokenizer
model_path = "./NetflixGPT-english"  # 修改為你訓練模型的儲存路徑
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
model.to(device)
model.eval()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50260, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50260, bias=False)
)

In [4]:
# 定義生成描述的函數
def generate_description(title, max_length=100):
    # 構建輸入文本
    input_text = f"<|startoftext|>Title: {title}<|sep|>Description:"
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    
    # 使用模型進行生成
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_length=max_length,
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizer.eos_token_id,
            early_stopping=True
        )
    
    # 解碼生成的描述
    generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
    return generated_text.replace(input_text, "").strip()

In [5]:
test_titles = ["Stranger Things", "Breaking Bad", "The Crown"]  # 測試用的標題
for title in test_titles:
    print(f"Title: {title}")
    print("Generated Description:", generate_description(title))
    print("-" * 50)

Title: Stranger Things




Generated Description: Title: Stranger ThingsDescription: When a mysterious stranger steals a series of TV sets, the show's creators must find a way to stop him before he wreaks havoc on the world.
--------------------------------------------------
Title: Breaking Bad
Generated Description: Title: Breaking BadDescription: When a group of criminals is caught in a web of deceit, they must use their skills to save the day and save their town.
--------------------------------------------------
Title: The Crown
Generated Description: Title: The CrownDescription: A young man's life is turned upside down when he's forced to marry a woman he met on a dating site.
--------------------------------------------------
