In [None]:
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder
from sklearn.manifold import TSNE
from datetime import datetime
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [4]:
class PerfumeDataset(Dataset):
    """Dataset of perfumes described by three note groups and a fragrance label.

    Each item is a dict with keys ``name``, ``top_notes``, ``middle_notes``,
    ``base_notes`` and ``fragrance``. The three note entries are float32 tensors
    with one row per note that was found in ``note2vec``; notes that are not in
    ``note2vec`` are silently skipped.
    """

    # Fallback embedding width, used only when it cannot be read from ``note2vec``.
    DEFAULT_EMBED_DIM = 768

    def __init__(self, names, top_notes_list, middle_notes_list, base_notes_list, fragrance_labels, note2vec):
        """Store the parallel per-perfume sequences and the note -> vector mapping.

        Args:
            names: perfume names, one per sample.
            top_notes_list / middle_notes_list / base_notes_list: for each
                perfume, an iterable of note names for that group.
            fragrance_labels: per-perfume label, indexable by sample position.
            note2vec: mapping from note name to its embedding vector.
        """
        self.names = names
        self.top_notes_list = top_notes_list
        self.middle_notes_list = middle_notes_list
        self.base_notes_list = base_notes_list
        self.fragrance_labels = fragrance_labels
        self.note2vec = note2vec
        # Width of one note embedding; keeps empty groups shape-compatible
        # with non-empty ones whatever the embedding size is.
        self.embed_dim = self._infer_embed_dim(note2vec)

    @classmethod
    def _infer_embed_dim(cls, note2vec):
        """Return the length of the first vector in ``note2vec`` (or the default)."""
        try:
            first_vector = next(iter(note2vec.values()))
        except (AttributeError, StopIteration):
            # Not a dict-like mapping, or no entries at all.
            return cls.DEFAULT_EMBED_DIM
        shape = np.shape(first_vector)
        return int(shape[-1]) if shape else cls.DEFAULT_EMBED_DIM

    def notes_to_vector(self, notes):
        """Stack the embeddings of the known notes into a ``[N, embed_dim]`` float32 tensor.

        Returns a ``[0, embed_dim]`` tensor when none of ``notes`` is in ``note2vec``.
        """
        vectors = [self.note2vec[note] for note in notes if note in self.note2vec]
        if not vectors:
            return torch.zeros(0, self.embed_dim)
        return torch.tensor(np.stack(vectors), dtype=torch.float32)

    def __len__(self):
        """Number of perfumes."""
        return len(self.names)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` as a dict (see class docstring)."""
        return {
            'name': self.names[idx],
            'top_notes': self.notes_to_vector(self.top_notes_list[idx]),
            'middle_notes': self.notes_to_vector(self.middle_notes_list[idx]),
            'base_notes': self.notes_to_vector(self.base_notes_list[idx]),
            'fragrance': self.fragrance_labels[idx],
        }

In [5]:
def perfume_collate_fn(batch):
    """Collate perfume samples into zero-padded note tensors plus validity masks.

    Args:
        batch: non-empty list of dicts with keys ``name``, ``top_notes``,
            ``middle_notes``, ``base_notes`` (each a ``[N_i, D]`` tensor, where
            ``N_i`` may be 0) and ``fragrance`` (integer label).

    Returns:
        dict with
          * ``name``: list of names,
          * ``<group>_notes``: ``[B, max_len, D]`` tensor, zero-padded per group,
          * ``<group>_mask``: ``[B, max_len]`` float32 tensor, 1.0 on real notes
            and 0.0 on padding (groups: top, middle, base),
          * ``fragrance``: ``[B]`` long tensor of labels.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("perfume_collate_fn received an empty batch")

    def pad_group(key):
        """Pad every sample's ``key`` tensor to the longest one in the batch."""
        tensors = [item[key] for item in batch]
        max_len = max(t.shape[0] for t in tensors)
        # Read the embedding width from the data rather than assuming a fixed
        # size; prefer a non-empty tensor, whose width is certainly real.
        dim = next((t.shape[1] for t in tensors if t.shape[0] > 0), tensors[0].shape[1])
        padded = torch.zeros(len(tensors), max_len, dim)
        mask = torch.zeros(len(tensors), max_len, dtype=torch.float32)
        for row, tensor in enumerate(tensors):
            count = tensor.shape[0]
            if count > 0:
                padded[row, :count] = tensor
                mask[row, :count] = 1.0
        return padded, mask

    top_notes, top_mask = pad_group('top_notes')
    middle_notes, middle_mask = pad_group('middle_notes')
    base_notes, base_mask = pad_group('base_notes')

    return {
        'name': [item['name'] for item in batch],
        'top_notes': top_notes,        # [B, max_top, D]
        'top_mask': top_mask,          # [B, max_top]
        'middle_notes': middle_notes,  # [B, max_mid, D]
        'middle_mask': middle_mask,    # [B, max_mid]
        'base_notes': base_notes,      # [B, max_base, D]
        'base_mask': base_mask,        # [B, max_base]
        'fragrance': torch.tensor([item['fragrance'] for item in batch], dtype=torch.long),
    }

In [None]:
class PerfumeEmbedding(nn.Module):
    """Encode a perfume from its top / middle / base note embeddings.

    Each note group is encoded per note (``phi_*``), pooled with learned
    attention weights (``attn_*``) and transformed (``rho_*``). The three group
    vectors then attend to each other through multi-head self-attention; the
    pooled and interacted vectors are concatenated and projected to the perfume
    embedding ``z``, from which class logits are computed.

    Args:
        note_dim: width of one input note embedding.
        hidden: width of the per-group hidden representation.
        z_dim: width of the final perfume embedding.
        num_classes: number of fragrance classes predicted by the classifier.
        dropout: dropout probability used throughout the network.
    """

    def __init__(self, note_dim=768, hidden=256, z_dim=128, num_classes=8, dropout=0.3):
        super().__init__()

        def make_phi():
            # Per-note encoder: note_dim -> hidden.
            return nn.Sequential(
                nn.Linear(note_dim, hidden),
                nn.ReLU(),
                nn.LayerNorm(hidden),
                nn.Dropout(dropout),
            )

        def make_attn():
            # Scores each encoded note with a single scalar.
            return nn.Sequential(
                nn.Linear(hidden, 64),
                nn.Tanh(),
                nn.Dropout(dropout),
                nn.Linear(64, 1),
            )

        def make_rho():
            # Post-pooling transform of one group vector.
            return nn.Sequential(
                nn.Linear(hidden, hidden),
                nn.ReLU(),
                nn.Dropout(dropout),
            )

        # Attribute names and construction order are kept as-is so that
        # previously saved state dicts still load.
        self.phi_top = make_phi()
        self.phi_mid = make_phi()
        self.phi_base = make_phi()

        self.attn_top = make_attn()
        self.attn_mid = make_attn()
        self.attn_base = make_attn()

        self.rho_top = make_rho()
        self.rho_mid = make_rho()
        self.rho_base = make_rho()

        self.cross_attn = nn.MultiheadAttention(embed_dim=hidden, num_heads=4, batch_first=True, dropout=dropout)

        self.rho = nn.Sequential(
            nn.Linear(hidden * 6, z_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(z_dim, z_dim),
            nn.ReLU()
        )
        self.classifier = nn.Linear(z_dim, num_classes)

    def aggregate(self, phi, attn_net, rho, notes, mask):
        """Attention-pool one note group into a single vector per sample.

        Args:
            phi: per-note encoder module.
            attn_net: module producing one attention score per encoded note.
            rho: module applied to the pooled vector.
            notes: ``[B, N, note_dim]`` note embeddings (zero-padded).
            mask: ``[B, N]`` tensor, non-zero on real notes and 0 on padding.

        Returns:
            ``[B, H]`` tensor. A sample with no real notes pools to the zero
            vector before ``rho`` is applied.
        """
        h = phi(notes)  # [B, N, H]
        scores = attn_net(h).squeeze(-1)  # [B, N]
        is_padding = mask == 0
        scores = scores.masked_fill(is_padding, -1e9)
        attn_weights = torch.softmax(scores, dim=1)  # [B, N]
        # If every position of a row is padding, the softmax above is uniform
        # over padding and would pool phi(0) vectors. Zeroing the padded
        # weights makes such a row pool to 0, the same result an unpadded
        # empty group (N == 0) gives. Rows with at least one real note are
        # unaffected: their padded weights have already underflowed to 0.
        attn_weights = attn_weights.masked_fill(is_padding, 0.0).unsqueeze(-1)  # [B, N, 1]
        weighted_h = (h * attn_weights).sum(dim=1)  # [B, H]
        return rho(weighted_h)  # [B, H]

    def forward(self, notes_top, mask_top, notes_mid, mask_mid, notes_base, mask_base):
        """Compute class logits and the perfume embedding.

        Args:
            notes_top / notes_mid / notes_base: ``[B, N_g, note_dim]`` note
                embeddings of each group.
            mask_top / mask_mid / mask_base: matching ``[B, N_g]`` masks
                (0 marks padding).

        Returns:
            ``(logits, z)`` with ``logits`` of shape ``[B, num_classes]`` and
            ``z`` of shape ``[B, z_dim]``.
        """
        h_top = self.aggregate(self.phi_top, self.attn_top, self.rho_top, notes_top, mask_top)  # [B, H]
        h_mid = self.aggregate(self.phi_mid, self.attn_mid, self.rho_mid, notes_mid, mask_mid)  # [B, H]
        h_base = self.aggregate(self.phi_base, self.attn_base, self.rho_base, notes_base, mask_base)  # [B, H]

        # Let the three group vectors attend to each other.
        h_seq = torch.stack([h_top, h_mid, h_base], dim=1)  # [B, 3, H]
        h_interact, _ = self.cross_attn(h_seq, h_seq, h_seq)  # [B, 3, H]
        h_top_i, h_mid_i, h_base_i = h_interact.unbind(dim=1)  # split back into three [B, H]

        h_all = torch.cat([h_top, h_mid, h_base, h_top_i, h_mid_i, h_base_i], dim=-1)  # [B, 6*H]

        z = self.rho(h_all)  # [B, z_dim]
        logits = self.classifier(z)
        return logits, z

In [None]:
# Build note2vec dictionary: note name -> float32 embedding vector.
# Every column after the first one is treated as an embedding component.
df_note = pd.read_csv("data/note_embedding.csv")
note_vectors = df_note.iloc[:, 1:].to_numpy(dtype=np.float32)
note2vec_dict = dict(zip(df_note['note'], note_vectors))

In [8]:
# Load perfume data
df_perfume = pd.read_csv("data/1976_clean.csv")
names = df_perfume['name'].tolist()


def split_notes(cell, sep='、'):
    """Split one CSV cell of ``sep``-joined note names into a list.

    Empty CSV cells are read by pandas as NaN (a float), on which ``.split``
    would raise AttributeError; they are mapped to an empty note list instead.
    """
    if pd.isna(cell):
        return []
    return cell.split(sep)


top_notes_list = [split_notes(cell) for cell in df_perfume['top_notes']]
middle_notes_list = [split_notes(cell) for cell in df_perfume['middle_notes']]
base_notes_list = [split_notes(cell) for cell in df_perfume['base_notes']]

# Encode the fragrance names as integer class ids; `le` keeps the mapping.
le = LabelEncoder()
fragrance_labels = le.fit_transform(df_perfume['fragrance'])

In [None]:
# Wrap the parsed perfume columns in a Dataset and a shuffled, padded-batch loader.
dataset = PerfumeDataset(
    names=names,
    top_notes_list=top_notes_list,
    middle_notes_list=middle_notes_list,
    base_notes_list=base_notes_list,
    fragrance_labels=fragrance_labels,
    note2vec=note2vec_dict,
)
dataloader = DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=perfume_collate_fn,
)

In [None]:
# Build the model on the best available device and restore the saved weights.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = PerfumeEmbedding().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved on a GPU can still be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state_dict = torch.load("perfume_embedding_v2.pth", map_location=device)
model.load_state_dict(state_dict)

In [None]:
# Embed every perfume one sample at a time (batch of one, so no padding is
# involved) and write the embeddings, keyed by perfume name, to CSV.
model.eval()
embeddings = []


def _batch_of_one(notes):
    """Move one note group to `device`; return it and an all-ones mask, each with a leading batch axis."""
    notes = notes.to(device)
    mask = torch.ones(notes.shape[0], dtype=torch.float32, device=device)
    return notes.unsqueeze(0), mask.unsqueeze(0)


with torch.no_grad():
    for idx in tqdm(range(len(dataset)), desc="Evaluating"):
        perfume = dataset[idx]
        # Argument order expected by the model: (notes, mask) for top, middle, base.
        model_args = []
        for group in ('top_notes', 'middle_notes', 'base_notes'):
            model_args.extend(_batch_of_one(perfume[group]))
        logits, z = model(*model_args)
        embeddings.append(z.cpu().numpy().squeeze())

embeddings_df = pd.DataFrame(embeddings)
embeddings_df.insert(0, 'name', df_perfume['name'])
embeddings_df.to_csv("data/perfume_embedding_v2.csv", index=False)