In [1]:
import torch
import json
import numpy as np
from transformers import AutoTokenizer, AutoModel
import torch.nn as nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import json

# Build the tag -> index vocabulary from every label tag in the training set.
# Replace with the path to your local training set.
train_path = "train_data.jsonl"

tag_set = set()

with open(train_path, "r", encoding="utf-8") as f:
    for line in f:
        line = line.strip()
        # Skip blank / whitespace-only lines: json.loads("") would raise.
        if not line:
            continue
        data = json.loads(line)
        # `or []` also tolerates records that carry an explicit "label_tags": null.
        tag_set.update(data.get("label_tags") or [])

# Sorting makes the id assignment deterministic across runs (string-set
# iteration order is not, because of hash randomization).
tag2id = {tag: i for i, tag in enumerate(sorted(tag_set))}

# Save the mapping so later cells can reload it.
with open("tag2id.json", "w", encoding="utf-8") as f:
    json.dump(tag2id, f, ensure_ascii=False, indent=2)

print("✅ 已生成 tag2id.json，标签数：", len(tag2id))


✅ 已生成 tag2id.json，标签数： 668


In [3]:
# Reload the tag -> index mapping from disk and derive its inverse
# (index -> tag), which is used to decode classifier output positions.
with open("tag2id.json", "r", encoding="utf-8") as f:
    tag2id = json.load(f)
id2tag = dict((idx, name) for name, idx in tag2id.items())
# Width of the classifier head: one logit per known tag.
num_labels = len(tag2id)

# Model architecture definition (must stay identical to the one used in training).
class BERTTagRecommender(nn.Module):
    """Tag scorer: a pretrained transformer encoder followed by a two-layer MLP head.

    The attribute names ``encoder`` / ``classifier`` and the layer order inside
    ``classifier`` determine the state_dict keys, so they must not change or a
    saved checkpoint will no longer load.

    Args:
        model_name: identifier passed to ``AutoModel.from_pretrained``.
        output_dim: number of logits produced (one per tag).
    """

    def __init__(self, model_name, output_dim):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        hidden_size = self.encoder.config.hidden_size
        # Order fixes the Sequential indices: 0 = Linear, 1 = ReLU, 2 = Linear.
        head_layers = [
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_ids, attention_mask):
        """Return raw (pre-sigmoid) logits computed from the encoder's pooler output."""
        encoder_out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled = encoder_out.pooler_output
        return self.classifier(pooled)

In [4]:
# Load the tokenizer, rebuild the model skeleton, then restore the fine-tuned weights.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = BERTTagRecommender(model_name, num_labels)
# The checkpoint is fed straight into load_state_dict, so it is expected to be a
# plain state_dict. weights_only=True restricts unpickling to tensors/containers,
# so a tampered checkpoint file cannot execute arbitrary code on load.
state_dict = torch.load(
    "checkpoint_model.pt",
    map_location=torch.device("cpu"),
    weights_only=True,
)
model.load_state_dict(state_dict)
model.eval()

BERTTagRecommender(
  (encoder): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-5): 6 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e-12, eleme

In [5]:
def predict_tags(context_tags, target_character, top_k=8, threshold=0.2, max_length=64):
    """Recommend tags for ``target_character`` given a list of context tags.

    Relies on the module-level ``tokenizer``, ``model`` and ``id2tag`` created
    in earlier cells.

    Args:
        context_tags: tag strings, joined with ", " to form the prompt.
        target_character: character tag appended after a " [SEP] " separator.
        top_k: maximum number of tags to return.
        threshold: only tags whose sigmoid probability is strictly greater than
            this are kept, so fewer than ``top_k`` results may be returned.
        max_length: tokenizer truncation/padding length.
            NOTE(review): presumably must match the training-time value — confirm.

    Returns:
        List of ``(tag, probability)`` tuples in descending probability order.
    """
    text = ", ".join(context_tags) + " [SEP] " + target_character
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_length,
    )
    # Send inputs to wherever the model's weights live so this keeps working
    # if the model is moved off the CPU.
    device = next(model.parameters()).device
    with torch.no_grad():
        logits = model(encoded["input_ids"].to(device), encoded["attention_mask"].to(device))
    # Take the single batch row explicitly: squeeze() would also drop the label
    # axis when there is only one label, leaving a 0-d array argsort can't slice.
    probs = torch.sigmoid(logits)[0].cpu().numpy()
    top_indices = probs.argsort()[::-1][:top_k]
    return [(id2tag[int(i)], float(probs[i])) for i in top_indices if probs[i] > threshold]

In [26]:
# Example: restyle a ganyu_(genshin_impact) prompt as lumine_(genshin_impact).
# NOTE(review): these context tags mix space-separated ("bare shoulders") and
# underscore-separated ("looking_at_viewer") spellings, while the predicted tags
# use underscores — confirm which spelling the training data used.
context_tags = [
    "1girl", "ahoge", "bare shoulders", "cleavage", "breasts", "flower", "dress",
    "looking_at_viewer", "bell", "blinking",
    "blue hair", "blush", "purple eyes",
]
target_character = "lumine_(genshin_impact)"

result = predict_tags(context_tags, target_character)
print("🎯 推荐标签：", result)

🎯 推荐标签： [('hair_between_eyes', 0.9984135627746582), ('white_flower', 0.9977719187736511), ('flower', 0.9949317574501038), ('hair_flower', 0.9917808175086975), ('hair_ornament', 0.9823211431503296), ('yellow_eyes', 0.9625009894371033), ('short_hair_with_long_locks', 0.9360777139663696), ('blonde_hair', 0.8160218000411987)]
