In [9]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
import json

def load_jsonl(path):
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(line) for line in f]

train_data = load_jsonl("../data/iwslt2017_train.jsonl")
val_data   = load_jsonl("../data/iwslt2017_validation.jsonl")
test_data  = load_jsonl("../data/iwslt2017_test.jsonl")

zh_train = [x["zh"] for x in train_data]
en_train = [x["en"] for x in train_data]
zh_val   = [x["zh"] for x in val_data]
en_val   = [x["en"] for x in val_data]
zh_test  = [x["zh"] for x in test_data]
en_test  = [x["en"] for x in test_data]

print(len(zh_train), len(en_train))
print(len(zh_val), len(en_val))
print(len(zh_test), len(en_test))

231266 231266
879 879
8549 8549


In [10]:
from utils import Vocab, count_parameters, save_model
src_vocab = Vocab(zh_train)
tgt_vocab = Vocab(en_train)

In [11]:
print(len(src_vocab.stoi))
print(len(tgt_vocab.stoi))

546813
135869


In [12]:
mx = 0
for(s, t) in zip(zh_train, en_train):
    mx = max(mx, len(s.split()), len(t.split()))
print("Max length:", mx)

Max length: 89


In [16]:
import torch
from model import Transformer
from utils import load_model, BOS, EOS, PAD

# 选择设备（有 CUDA 则用第一张可用 GPU，否则使用 CPU）
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 根据已有的词表实例化模型（使用词表长度作为 vocab size）
src_vocab_size = len(src_vocab.stoi)
tgt_vocab_size = len(tgt_vocab.stoi)

# 这里的超参可根据你的训练 config 调整
model = Transformer(src_vocab_size, tgt_vocab_size, d_model=256, N=4, h=4, d_ff=1024, dropout=0.1, max_len=128)
model = model.to(device)

checkpoint_path = "/opt/data/private/zxh/results/model_bs_64_all*10.pt"  # 替换为实际路径
# 使用 utils.load_model 加载权重（map_location 指定到当前 device）
try:
    load_model(model, checkpoint_path, map_location=device)
    print(f"Loaded checkpoint -> {checkpoint_path}")
except Exception as e:
    # 更宽容的回退：尝试直接用 torch.load 并处理常见 dict 包装形式
    print(f"load_model failed with: {e}. Trying fallback load...")
    state = torch.load(checkpoint_path, map_location=device)
    if isinstance(state, dict):
        # 常见的可能键名：'state_dict', 'model_state', 'model'
        for k in ("state_dict", "model_state", "model"):
            if k in state:
                model.load_state_dict(state[k])
                break
        else:
            # 假设直接就是 state_dict
            model.load_state_dict(state)
    else:
        model.load_state_dict(state)
    print("Fallback load succeeded")

model.eval()

# 准备输入句子：使用已有的 Vocab.encode 以保证和训练时一致的格式（包含 BOS/EOS/pad）
prompt = "I like apple"  # 如果你的 src 是中文请按分词方式调整
ids = src_vocab.encode(prompt, max_len=32)
# encode 返回已经包含 BOS/EOS 与 padding 的 id 列表
src_tensor = torch.tensor([ids], dtype=torch.long, device=device)  # batch_size=1

# 简单的贪心解码（逐步生成）——加入重复抑制与 logits 屏蔽以减少重复输出
max_gen_len = 32
with torch.no_grad():
    memory = model.encode(src_tensor)
    ys = torch.tensor([[BOS]], dtype=torch.long, device=device)  # 已有 BOS
    prev_tok = None
    for step in range(max_gen_len):
        # 解码当前已生成序列
        dec = model.decode(ys, memory)
        logits = model.out(dec[:, -1, :])  # (batch=1, vocab)

        # 屏蔽不应生成的特殊 token（例如 BOS、PAD）
        logits[0, PAD] = -1e9
        logits[0, BOS] = -1e9
        # 可以在需要时屏蔽 EOS（这里保留 EOS，让模型终止）

        # 简单的“禁止立即重复”策略：不允许生成与上一步相同的 token
        if prev_tok is not None:
            logits[0, prev_tok] = -1e9

        # 取贪心
        next_tok = logits.argmax(dim=-1).item()
        ys = torch.cat([ys, torch.tensor([[next_tok]], dtype=torch.long, device=device)], dim=1)
        prev_tok = next_tok
        if next_tok == EOS:
            break

    # 将生成 id 列表（去掉 BOS/EOS/PAD）转换为 token 并合并为字符串输出
    gen_ids = ys.squeeze(0).tolist()
    # 去掉开头 BOS
    if gen_ids and gen_ids[0] == BOS:
        gen_ids = gen_ids[1:]
    # 截断到 EOS
    if EOS in gen_ids:
        gen_ids = gen_ids[:gen_ids.index(EOS)]
    # 去掉 PAD
    gen_ids = [i for i in gen_ids if i != PAD]

    output_tokens = [tgt_vocab.itos.get(i, "<unk>") for i in gen_ids]
    output = " ".join(output_tokens)

print("Input:", prompt)
print("Output:", output)

# 诊断信息，帮助判断是否为 vocab/加载问题
print("Generated ids:", gen_ids)
# 打印前 20 logits topk 作快速检查（最后一步）
try:
    topk = torch.topk(torch.softmax(logits, dim=-1), k=10)
    print("Top tokens:")
    for score, idx in zip(topk.values.squeeze(0).tolist(), topk.indices.squeeze(0).tolist()):
        print(idx, tgt_vocab.itos.get(idx, "<unk>"), f"{score:.4f}")
except Exception:
    pass


Loaded checkpoint -> /opt/data/private/zxh/results/model_bs_64_all*10.pt
Input: I like apple
Output: prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states, prevented states,
Generated ids: [106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246, 106406, 121246]
Top tokens:
121246 states, 1.0000
71116 essentially 0.0000
131965 vicious 0.0000
124472 tactics, 0.0000
92710 meaning 0.0000
98536 nutrition, 0.0000
129014 tweet 0.0000
66307 discipline, 0.0000
130341 universes, 0.0000
106083 pregnancy 0.0000


In [1]:
import os

save_path = "results/model_none_ps.pt"

# 获取文件名（不含路径）
filename = os.path.basename(save_path)
# 去掉扩展名并加上 _loss.png
loss_path = os.path.splitext(filename)[0] + "_loss.png"

print(loss_path)


model_none_ps_loss.png
