<a href="https://colab.research.google.com/github/xinyuezhang-shirley/cs229FinalProject/blob/main/CS229_ProjectionLayer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# MPNet embeddings (raw, not yet filtered)
poem_vecs = np.load("data/processed/mpnet_embeddings_poems.npy")
song_vecs = np.load("data/processed/mpnet_embeddings_songs.npy")

# Load all features from full_features.npz.
# allow_pickle=True lets np.load unpickle object arrays (presumably needed for
# some entries in this archive); only load files produced by this project.
full = np.load("data/processed/full_features.npz", allow_pickle=True)

# Structural + lexical features (concatenated)
poem_struct = full["poem_struct"]  # (3413, 3)
poem_lexical = full["poem_lexical"]  # (3413, 3)
poem_feats = np.concatenate([poem_struct, poem_lexical], axis=1)  # (3413, 6)

song_struct = full["song_struct"]  # (2995, 4)
song_lexical = full["song_lexical"]  # (2995, 3)
# For songs, only use first 3 structural features to match poems (exclude WPM).
# NOTE(review): assumes song_struct[:, :3] holds the same features, in the same
# order, as poem_struct — confirm against the preprocessing code.
song_feats = np.concatenate([song_struct[:, :3], song_lexical], axis=1)  # (2995, 6)

# Semantic features
poem_sem_all = full["poem_semantic"]  # (3413, 36)
song_sem_all = full["song_semantic"]  # (2995, 36)

# The group split below hard-codes 36 columns; fail loudly instead of letting
# the slices silently truncate or come up short if the layout changes.
for _name, _sem in (("poem_semantic", poem_sem_all), ("song_semantic", song_sem_all)):
    if _sem.shape[1] != 36:
        raise ValueError(f"{_name} has {_sem.shape[1]} columns; expected 36")

# Split semantic features by group (half-open column ranges):
# emotions (9): [0, 9), themes (10): [9, 19), other (17): [19, 36)
poem_sem_emo   = poem_sem_all[:, 0:9]
poem_sem_theme = poem_sem_all[:, 9:19]
poem_sem_other = poem_sem_all[:, 19:36]
song_sem_emo   = song_sem_all[:, 0:9]
song_sem_theme = song_sem_all[:, 9:19]
song_sem_other = song_sem_all[:, 19:36]

# Align song embeddings to match cleaned features
idx_map = full["song_source_indexes"]  # (2995,) maps cleaned songs -> raw embedding indices
song_vecs = song_vecs[idx_map]  # reorder raw embeddings to match cleaned data

# Every modality must have exactly one row per item. poem_vecs is used as-is
# (no index map), so this also catches poem features that were filtered without
# a matching remap. A row-count check cannot detect a pure reordering, though.
for _kind, _arrays in (
    ("poem", (poem_vecs, poem_feats, poem_sem_all)),
    ("song", (song_vecs, song_feats, song_sem_all)),
):
    _counts = [a.shape[0] for a in _arrays]
    if len(set(_counts)) != 1:
        raise ValueError(f"{_kind} arrays disagree on row count (vecs, feats, sem): {_counts}")

print(f"Poems: {poem_vecs.shape[0]} items")
print(f"Songs: {song_vecs.shape[0]} items")
print(f"poem_vecs: {poem_vecs.shape}, song_vecs: {song_vecs.shape}")
print(f"poem_feats: {poem_feats.shape}, song_feats: {song_feats.shape}")
print(f"poem_sem (emo/theme/other): {poem_sem_emo.shape}, {poem_sem_theme.shape}, {poem_sem_other.shape}")
print(f"song_sem (emo/theme/other): {song_sem_emo.shape}, {song_sem_theme.shape}, {song_sem_other.shape}")


Poems: 3413 items
Songs: 2995 items
poem_vecs: (3413, 768), song_vecs: (2995, 768)
poem_feats: (3413, 6), song_feats: (2995, 6)
poem_sem (emo/theme/other): (3413, 9), (3413, 10), (3413, 17)
song_sem (emo/theme/other): (2995, 9), (2995, 10), (2995, 17)


In [5]:
class PairDataset(Dataset):
    """Epoch-sized dataset that yields a random positive (poem, song) index pair.

    Each ``__getitem__`` call ignores ``idx`` and draws one pair uniformly at
    random from ``pos_pairs``. Only indices are returned; the training loop is
    responsible for indexing the actual feature tensors.

    Args:
        pos_pairs: non-empty sequence of ``(poem_idx, song_idx)`` positive pairs.
        neg_pairs: stored on the instance; not used by ``__getitem__``.
        hard_pairs: stored on the instance; not used by ``__getitem__``.
        size: value reported by ``len()`` (number of samples per epoch).

    Raises:
        ValueError: if ``pos_pairs`` is empty, since nothing could be sampled.
    """

    def __init__(self, pos_pairs, neg_pairs, hard_pairs, size):
        if len(pos_pairs) == 0:
            raise ValueError("PairDataset needs at least one positive pair")
        self.pos_pairs  = pos_pairs
        self.neg_pairs  = neg_pairs
        self.hard_pairs = hard_pairs
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, idx):
        # Draw from torch's RNG rather than NumPy's global RNG: DataLoader
        # re-seeds torch in every worker, but forked workers inherit an
        # identical NumPy RNG state and would emit the same "random" pairs.
        k = int(torch.randint(len(self.pos_pairs), (1,)).item())
        i_poem, j_song = self.pos_pairs[k]

        # Return indices only (training loop will index the actual data)
        return i_poem, j_song


In [7]:
# Branch sizes: the input width (column count) of each feature group, in the
# order mpnet, sem_emo, sem_theme, sem_other, feat — first for poems, then songs.
p_dim_mp, p_dim_emo, p_dim_theme, p_dim_other, p_dim_ft = (
    poem_in[key].shape[1]
    for key in ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")
)
s_dim_mp, s_dim_emo, s_dim_theme, s_dim_other, s_dim_ft = (
    song_in[key].shape[1]
    for key in ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")
)
# Width of the shared embedding space both towers project into.
proj_dim = 128

class ProjectionModel(nn.Module):
    """Two-tower model mapping poems and songs into a shared unit-norm space.

    Each tower encodes five feature groups with its own small MLP: the MPNet
    embedding, three semantic groups (emotion, theme, other), and the
    structural/lexical features. Each branch output is scaled by the
    module-level weights ALPHA, BETA_EMO, BETA_THEME, BETA_OTHER and GAMMA,
    which are read at call time. The scaled outputs are concatenated and
    projected through LayerNorm + Linear to ``proj_dim``. Outputs are
    L2-normalized, so the dot product of a poem and a song embedding is their
    cosine similarity.

    A semantic group of width 0 gets a (batch, 1) zero input, so its branch
    still produces a fixed-width output.

    Args:
        p_dims: poem input widths ``(mpnet, sem_emo, sem_theme, sem_other, feat)``.
        s_dims: song input widths, in the same order.
        proj_dim: width of the shared output embedding.
    """

    # Concatenated branch width: 128 (MPNet branch) + 4 * 64 (other branches).
    _CAT_DIM = 128 + 64 + 64 + 64 + 64

    def __init__(self, p_dims, s_dims, proj_dim):
        super().__init__()
        p_mp, p_emo, p_theme, p_other, p_ft = p_dims
        s_mp, s_emo, s_theme, s_other, s_ft = s_dims
        cat_dim = self._CAT_DIM
        # Keep the semantic widths on the instance. The forward passes consult
        # these, not module-level globals, to decide whether a group is empty.
        self.p_sem_dims = (p_emo, p_theme, p_other)
        self.s_sem_dims = (s_emo, s_theme, s_other)
        # poem branches (attribute names and registration order define the
        # state_dict keys and init order — keep them stable for old checkpoints)
        self.poem_mp = nn.Sequential(nn.Linear(p_mp, 256), nn.ReLU(), nn.Linear(256, 128))
        self.poem_emo = nn.Sequential(nn.Linear(max(p_emo, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.poem_theme = nn.Sequential(nn.Linear(max(p_theme, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.poem_other = nn.Sequential(nn.Linear(max(p_other, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.poem_ft = nn.Sequential(nn.Linear(p_ft, 64), nn.ReLU(), nn.Linear(64, 64))
        self.poem_proj = nn.Sequential(nn.LayerNorm(cat_dim), nn.Linear(cat_dim, proj_dim))
        # song branches
        self.song_mp = nn.Sequential(nn.Linear(s_mp, 256), nn.ReLU(), nn.Linear(256, 128))
        self.song_emo = nn.Sequential(nn.Linear(max(s_emo, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.song_theme = nn.Sequential(nn.Linear(max(s_theme, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.song_other = nn.Sequential(nn.Linear(max(s_other, 1), 64), nn.ReLU(), nn.Linear(64, 64))
        self.song_ft = nn.Sequential(nn.Linear(s_ft, 64), nn.ReLU(), nn.Linear(64, 64))
        self.song_proj = nn.Sequential(nn.LayerNorm(cat_dim), nn.Linear(cat_dim, proj_dim))

    @staticmethod
    def _sem_input(x, key, dim):
        """Return ``x[key]``, or a (batch, 1) zero placeholder when ``dim`` is 0."""
        if dim > 0:
            return x[key]
        ref = x["mpnet"]
        return torch.zeros(ref.shape[0], 1, device=ref.device)

    def _encode(self, x, sem_dims, mp, emo, theme, other, ft, proj):
        """Run one tower on feature dict ``x`` and return L2-normalized rows."""
        emo_dim, theme_dim, other_dim = sem_dims
        comb = torch.cat([
            ALPHA * mp(x["mpnet"]),
            BETA_EMO * emo(self._sem_input(x, "sem_emo", emo_dim)),
            BETA_THEME * theme(self._sem_input(x, "sem_theme", theme_dim)),
            BETA_OTHER * other(self._sem_input(x, "sem_other", other_dim)),
            GAMMA * ft(x["feat"]),
        ], dim=1)
        return F.normalize(proj(comb), dim=1)

    def forward_poem(self, p):
        """Embed a poem feature dict (keys: mpnet, sem_emo, sem_theme, sem_other, feat)."""
        return self._encode(p, self.p_sem_dims, self.poem_mp, self.poem_emo,
                            self.poem_theme, self.poem_other, self.poem_ft, self.poem_proj)

    def forward_song(self, s):
        """Embed a song feature dict (keys: mpnet, sem_emo, sem_theme, sem_other, feat)."""
        return self._encode(s, self.s_sem_dims, self.song_mp, self.song_emo,
                            self.song_theme, self.song_other, self.song_ft, self.song_proj)

In [9]:
# Build model filename from hyperparameters.
# NOTE(review): the name omits proj_dim, warmup/cosine settings and MIN_LR, so
# changing only those would silently reuse an older checkpoint.
model_name = (f"model_bs{BATCH_SIZE}_ep{EPOCHS}_lr{LR}_temp{TEMP}_"
             f"posK{POS_TOPK}_hardK{HARD_TOPK}_"
             f"a{ALPHA}_bemo{BETA_EMO}_bthm{BETA_THEME}_both{BETA_OTHER}_g{GAMMA}.pt")

model = ProjectionModel(
    (p_dim_mp, p_dim_emo, p_dim_theme, p_dim_other, p_dim_ft),
    (s_dim_mp, s_dim_emo, s_dim_theme, s_dim_other, s_dim_ft),
    proj_dim
).to(DEVICE)

# Reuse a checkpoint saved under the same filename, if one exists.
if os.path.exists(model_name):
    print(f"Loading existing model from {model_name}...")
    # A state_dict holds only tensors/containers, so weights_only=True is
    # sufficient and avoids unpickling arbitrary objects from the file.
    model.load_state_dict(torch.load(model_name, map_location=DEVICE, weights_only=True))
    print("Model loaded. Skipping training.")
    SKIP_TRAINING = True
else:
    print(f"No existing model found. Will train and save to {model_name}")
    SKIP_TRAINING = False

opt = torch.optim.Adam(model.parameters(), lr=LR)

# Schedulers are built only when the cosine schedule is enabled AND training
# will run; otherwise both are set to None.
if USE_COSINE_SCHEDULE and not SKIP_TRAINING:
    def lr_lambda(epoch):
        """Linear warmup multiplier: (epoch + 1) / WARMUP_EPOCHS, then 1.0."""
        if epoch < WARMUP_EPOCHS:
            return (epoch + 1) / WARMUP_EPOCHS
        return 1.0
    warmup_scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
    # Cosine anneal after warmup. T_max is clamped to >= 1: EPOCHS <= WARMUP_EPOCHS
    # would otherwise give a zero/negative period (division by zero on step).
    scheduler_main = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=max(1, EPOCHS - WARMUP_EPOCHS), eta_min=MIN_LR)
else:
    warmup_scheduler = None
    scheduler_main = None

Loading existing model from model_bs256_ep500_lr0.002_temp0.1_posK10_hardK20_a0.6_bemo0.1_bthm0.15_both0.1_g0.15.pt...
Model loaded. Skipping training.


In [11]:
# Persist freshly trained weights; a checkpoint that was just loaded is left untouched.
if SKIP_TRAINING:
    print("Model already loaded; no new save needed.")
else:
    weights = model.state_dict()
    torch.save(weights, model_name)
    print(f"Model saved to {model_name}")

Model already loaded; no new save needed.


In [None]:
# Evaluation on Human-Labeled Test Set
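Evaluate the unsupervised model on human-labeled triplets. Each triplet is (poem, song 1, song 2) plus the song annotators judged closer. The model counts as correct when it gives that song the higher cosine similarity to the poem.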

Loaded 46 valid human-labeled triplets for supervised training
Example: poem=231, song1=550, song2=2335, label=1


In [None]:
# Evaluate unsupervised model on all human triplets.
# Each triplet is (poem_idx, song1_idx, song2_idx, label) with label 1 or 2
# naming the song judged closer; the model predicts whichever song has the
# higher similarity to the poem.
_INPUT_KEYS = ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")


def _row_inputs(src, i):
    """Return the model-input dict for item ``i`` of ``src`` as a batch of one row."""
    return {k: src[k][i:i+1] for k in _INPUT_KEYS}


model.eval()
correct = 0
total = len(human_triplets)
if total == 0:
    raise ValueError("human_triplets is empty; nothing to evaluate")

with torch.no_grad():
    for p_idx, s1_idx, s2_idx, label in human_triplets:
        # Forward pass for the poem and both candidate songs
        p_z = model.forward_poem(_row_inputs(poem_gpu, p_idx))
        s1_z = model.forward_song(_row_inputs(song_gpu, s1_idx))
        s2_z = model.forward_song(_row_inputs(song_gpu, s2_idx))

        # The model L2-normalizes its outputs, so the dot product is the cosine.
        sim1 = (p_z * s1_z).sum().item()
        sim2 = (p_z * s2_z).sum().item()

        # Predict which song is closer (1 or 2); an exact tie goes to song 2.
        pred = 1 if sim1 > sim2 else 2

        if pred == label:
            correct += 1

accuracy = correct / total
print(f"\n{'='*60}")
print("Unsupervised Model Evaluation on Human Triplets")
print(f"{'='*60}")
print(f"Total triplets: {total}")
print(f"Correct predictions: {correct}")
print(f"Accuracy: {accuracy*100:.2f}%")
print(f"{'='*60}")

In [None]:
# Baseline: Test raw MPNet performance (no training).
# Ranks each triplet's two songs by raw-MPNet cosine similarity to the poem and
# compares against the model accuracy computed by the evaluation cell above
# (that cell must have been run first: it defines accuracy / correct / total).
print("\n" + "="*60)
print("BASELINE: Raw MPNet Performance (Before Training)")
print("="*60)

n_triplets = len(human_triplets)
if n_triplets == 0:
    raise ValueError("human_triplets is empty; nothing to evaluate")

mpnet_correct = 0
with torch.no_grad():
    for p_idx, s1_idx, s2_idx, label in human_triplets:
        p_mpnet = poem_gpu["mpnet"][p_idx:p_idx+1]
        s1_mpnet = song_gpu["mpnet"][s1_idx:s1_idx+1]
        s2_mpnet = song_gpu["mpnet"][s2_idx:s2_idx+1]

        # Explicit cosine: these raw inputs are not guaranteed to be unit-norm,
        # and a plain dot product would favor songs with longer vectors.
        sim1 = F.cosine_similarity(p_mpnet, s1_mpnet, dim=1).item()
        sim2 = F.cosine_similarity(p_mpnet, s2_mpnet, dim=1).item()

        pred = 1 if sim1 > sim2 else 2
        if pred == label:
            mpnet_correct += 1

mpnet_acc = mpnet_correct / n_triplets
print(f"Raw MPNet Accuracy: {mpnet_acc*100:.2f}% ({mpnet_correct}/{n_triplets})")
# Report the computed model accuracy rather than a hard-coded figure that goes
# stale whenever the model or the triplet set changes.
print(f"Your Model Accuracy: {accuracy*100:.2f}% ({correct}/{total})")
print(f"\nPerformance Loss: {(mpnet_acc - accuracy)*100:.2f}%")
print("="*60)

In [None]:
# Check label encoding and predictions for first few examples.
# Prints raw-MPNet and model similarities side by side for the first 5
# triplets, to spot an inverted/off-by-one label encoding or a model that
# underperforms the raw-embedding baseline.
_INPUT_KEYS = ("mpnet", "sem_emo", "sem_theme", "sem_other", "feat")


def _row_inputs(src, i):
    """Return the model-input dict for item ``i`` of ``src`` as a batch of one row."""
    return {k: src[k][i:i+1] for k in _INPUT_KEYS}


print("\n" + "="*60)
print("DEBUG: First 5 triplet predictions (check for off-by-one)")
print("="*60)

model.eval()
with torch.no_grad():
    for idx, (p_idx, s1_idx, s2_idx, label) in enumerate(human_triplets[:5]):
        # Raw MPNet: explicit cosine, since these inputs are not guaranteed unit-norm
        p_mpnet = poem_gpu["mpnet"][p_idx:p_idx+1]
        s1_mpnet = song_gpu["mpnet"][s1_idx:s1_idx+1]
        s2_mpnet = song_gpu["mpnet"][s2_idx:s2_idx+1]

        mpnet_sim1 = F.cosine_similarity(p_mpnet, s1_mpnet, dim=1).item()
        mpnet_sim2 = F.cosine_similarity(p_mpnet, s2_mpnet, dim=1).item()
        mpnet_pred = 1 if mpnet_sim1 > mpnet_sim2 else 2

        # Model: outputs are L2-normalized, so the dot product is the cosine
        p_z = model.forward_poem(_row_inputs(poem_gpu, p_idx))
        s1_z = model.forward_song(_row_inputs(song_gpu, s1_idx))
        s2_z = model.forward_song(_row_inputs(song_gpu, s2_idx))

        model_sim1 = (p_z * s1_z).sum().item()
        model_sim2 = (p_z * s2_z).sum().item()
        model_pred = 1 if model_sim1 > model_sim2 else 2

        print(f"\nTriplet {idx}: poem={p_idx}, song1={s1_idx}, song2={s2_idx}")
        print(f"  Label says: Song {label} is closer")
        print(f"  MPNet: sim1={mpnet_sim1:.3f}, sim2={mpnet_sim2:.3f} → pred={mpnet_pred} {'✓' if mpnet_pred==label else '✗'}")
        print(f"  Model: sim1={model_sim1:.3f}, sim2={model_sim2:.3f} → pred={model_pred} {'✓' if model_pred==label else '✗'}")

print("\n" + "="*60)
print("If MPNet is mostly ✓ but Model is mostly ✗, training degraded performance")
print("If both are ✗, check if labels are backwards (flip 1↔2)")
print("="*60)