In [1]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn

In [2]:
class PerfumeEmbedding(nn.Module):
    """Embed a perfume from its top / middle / base note vectors.

    Each note level is processed by its own branch: a per-note encoder
    (``phi_*``), an attention-pooling scorer (``attn_*``) that collapses the
    variable number of notes into one vector, and a post-pooling projection
    (``rho_*``). The three level vectors then attend to each other through a
    multi-head self-attention layer; the pooled and attended vectors are
    concatenated and projected to the final embedding ``z``, which also feeds a
    linear classifier.

    Args:
        note_dim: Size of one input note vector.
        hidden: Width of the per-level hidden representation.
        z_dim: Size of the final perfume embedding.
        num_classes: Number of classifier outputs.
        dropout: Dropout probability used throughout the network.
        attn_hidden: Width of the hidden layer inside each attention scorer.
        num_heads: Number of heads of the cross-level attention. ``hidden``
            must be divisible by it (enforced by ``nn.MultiheadAttention``).

    The defaults of ``attn_hidden`` and ``num_heads`` reproduce the previously
    hard-coded architecture, and attribute names are unchanged, so existing
    checkpoints keep loading.
    """

    def __init__(self, note_dim=768, hidden=256, z_dim=128, num_classes=8, dropout=0.3,
                 attn_hidden=64, num_heads=4):
        super().__init__()

        # Modules are created in a fixed order (phi_*, attn_*, rho_*, cross_attn,
        # rho, classifier) so that seeded weight initialisation is reproducible.
        self.phi_top = self._make_encoder(note_dim, hidden, dropout)
        self.phi_mid = self._make_encoder(note_dim, hidden, dropout)
        self.phi_base = self._make_encoder(note_dim, hidden, dropout)

        self.attn_top = self._make_scorer(hidden, attn_hidden, dropout)
        self.attn_mid = self._make_scorer(hidden, attn_hidden, dropout)
        self.attn_base = self._make_scorer(hidden, attn_hidden, dropout)

        self.rho_top = self._make_projection(hidden, dropout)
        self.rho_mid = self._make_projection(hidden, dropout)
        self.rho_base = self._make_projection(hidden, dropout)

        self.cross_attn = nn.MultiheadAttention(
            embed_dim=hidden, num_heads=num_heads, batch_first=True, dropout=dropout
        )

        # Input is 6 * hidden wide: three pooled level vectors plus their three
        # cross-attended counterparts (see forward()).
        self.rho = nn.Sequential(
            nn.Linear(hidden * 6, z_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(z_dim, z_dim),
            nn.ReLU()
        )
        self.classifier = nn.Linear(z_dim, num_classes)

    @staticmethod
    def _make_encoder(note_dim, hidden, dropout):
        """Per-note encoder: Linear -> ReLU -> LayerNorm -> Dropout."""
        return nn.Sequential(
            nn.Linear(note_dim, hidden),
            nn.ReLU(),
            nn.LayerNorm(hidden),
            nn.Dropout(dropout)
        )

    @staticmethod
    def _make_scorer(hidden, attn_hidden, dropout):
        """Attention scorer that maps each encoded note to one scalar score."""
        return nn.Sequential(
            nn.Linear(hidden, attn_hidden),
            nn.Tanh(),
            nn.Dropout(dropout),
            nn.Linear(attn_hidden, 1)
        )

    @staticmethod
    def _make_projection(hidden, dropout):
        """Post-pooling projection: Linear -> ReLU -> Dropout."""
        return nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(dropout)
        )

    def aggregate(self, phi, attn_net, rho, notes):
        """Attention-pool the notes of one level into a single vector.

        Args:
            phi: Per-note encoder applied to ``notes``.
            attn_net: Scorer producing one scalar per encoded note.
            rho: Projection applied to the pooled vector.
            notes: Tensor whose last dimension is ``note_dim`` and whose
                dimension 1 enumerates the notes, i.e. expected shape
                ``(batch, n_notes, note_dim)``.

        Returns:
            The projected pooled representation, one vector per batch element.
        """
        h = phi(notes)
        scores = attn_net(h).squeeze(-1)
        # Normalise the scores over the note axis so the weights sum to 1 per sample.
        attn_weights = torch.softmax(scores, dim=1).unsqueeze(-1)
        weighted_h = (h * attn_weights).sum(dim=1)
        return rho(weighted_h)

    def forward(self, top_notes, mid_notes, base_notes):
        """Compute class logits and the perfume embedding.

        Args:
            top_notes: Note vectors of the top level, ``(batch, n_top, note_dim)``.
            mid_notes: Note vectors of the middle level, ``(batch, n_mid, note_dim)``.
            base_notes: Note vectors of the base level, ``(batch, n_base, note_dim)``.

        Returns:
            Tuple ``(logits, z)`` where ``logits`` has ``num_classes`` columns
            and ``z`` has ``z_dim`` columns.
        """
        h_top = self.aggregate(self.phi_top, self.attn_top, self.rho_top, top_notes)
        h_mid = self.aggregate(self.phi_mid, self.attn_mid, self.rho_mid, mid_notes)
        h_base = self.aggregate(self.phi_base, self.attn_base, self.rho_base, base_notes)

        # Treat the three levels as a length-3 sequence and let them attend to each other.
        h_seq = torch.stack([h_top, h_mid, h_base], dim=1)
        h_interact, _ = self.cross_attn(h_seq, h_seq, h_seq)
        h_top_i, h_mid_i, h_base_i = h_interact.unbind(dim=1)

        h_all = torch.cat([h_top, h_mid, h_base, h_top_i, h_mid_i, h_base_i], dim=-1)

        z = self.rho(h_all)
        logits = self.classifier(z)
        return logits, z

In [3]:
# Build the note -> embedding-vector lookup table.
df_note = pd.read_csv("data/note_embedding_v1.csv")
# Convert all embedding columns to float32 in one vectorised pass instead of
# building and casting an object-dtype Series per row via iterrows().
# The first column is skipped positionally (it is expected to be 'note');
# every remaining column is treated as one embedding dimension.
note_vectors = df_note.iloc[:, 1:].to_numpy(dtype=np.float32)
# If a note name occurs more than once, the last occurrence wins.
note2vec_dict = dict(zip(df_note['note'], note_vectors))

In [None]:
def note2vec(note_list):
    """Return the embedding vectors of the known notes in ``note_list``.

    Notes that have no entry in the module-level ``note2vec_dict`` are skipped
    silently, so the result can be shorter than the input (or empty). The
    order of the remaining notes is preserved.
    """
    found = []
    for name in note_list:
        if name in note2vec_dict:
            found.append(note2vec_dict[name])
    return found

In [5]:
# Read the perfume table; the inference loop below uses its 'name',
# 'top_notes', 'middle_notes' and 'base_notes' columns.
df_perfume = pd.read_csv(filepath_or_buffer="data/1976_clean.csv")

In [6]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PerfumeEmbedding().to(device)
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# from a CUDA session also loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
state_dict = torch.load("models/PEM_v2.pth", map_location=device)
model.load_state_dict(state_dict)

<All keys matched successfully>

In [10]:
def _split_notes(raw_notes):
    """Split a '、'-separated note string; non-string cells (e.g. NaN) yield no notes."""
    return raw_notes.split('、') if isinstance(raw_notes, str) else []


def _to_batch(vecs):
    """Stack per-note vectors into a (1, n_notes, note_dim) float32 tensor on `device`."""
    # Stacking into one ndarray first avoids the slow list-of-ndarrays tensor path.
    return torch.as_tensor(np.stack(vecs), dtype=torch.float32, device=device).unsqueeze(0)


# Width of the embedding z, read from the model (the classifier consumes z)
# rather than hard-coded, so the NaN placeholder always matches the model.
z_dim = model.classifier.in_features

model.eval()
embeddings = []
with torch.no_grad():
    for _, row in df_perfume.iterrows():
        top_vecs = note2vec(_split_notes(row['top_notes']))
        mid_vecs = note2vec(_split_notes(row['middle_notes']))
        base_vecs = note2vec(_split_notes(row['base_notes']))

        # Skip any sample where at least one note level has no known notes.
        # A NaN row keeps `embeddings` aligned with df_perfume; it is dropped below.
        if len(top_vecs) == 0 or len(mid_vecs) == 0 or len(base_vecs) == 0:
            print(f"Skip: {row['name']}")
            embeddings.append([np.nan] * z_dim)
            continue

        # Only the embedding z is needed here; the class logits are discarded.
        _, z = model(_to_batch(top_vecs), _to_batch(mid_vecs), _to_batch(base_vecs))
        # Remove only the batch axis so the result is always a 1-D vector.
        embeddings.append(z.squeeze(0).cpu().numpy())

embeddings_df = pd.DataFrame(embeddings)
# Attach names positionally: `embeddings` holds exactly one entry per row of
# df_perfume in iteration order, whatever index df_perfume carries.
embeddings_df.insert(0, 'name', df_perfume['name'].to_numpy())
# Drops the NaN placeholder rows of skipped perfumes.
embeddings_df = embeddings_df.dropna()
embeddings_df.to_csv("data/perfume_embedding_v2.csv", index=False)

Skip: Miu Miu L'Eau Bleue 春日花园女性淡香精
Skip: Marly Castley 卡斯利淡香精中性淡香精行动香氛
