# Qwen-assisted Poem → Song Matching (indexes aligned)

This notebook loads the existing poem/song data, shortlists candidates with MPNet cosine, then asks a local Qwen model to score emotional/thematic fit. It saves JSONL with `poem_index`, `song_index_aligned`, `label`, titles/artists, and recursively replaces any flagged bad pairs.

**Setup requirements**
- `transformers`, `torch` installed.
- A local Qwen model (e.g., `Qwen/Qwen2.5-7B-Instruct`) available at `QWEN_MODEL_PATH` or `QWEN_MODEL_ID`. No external calls are made; you must have the weights locally.
- Existing embeddings/files in this repo: `data/raw/poetrydb_poems.json`, `data/processed/combined_songs_large_fixed.json`, `data/processed/additional_features.npz`, `data/processed/mpnet_embeddings_{poems,songs}.npy`.


In [1]:
import json, os, re, math, random
from pathlib import Path
from collections import Counter
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

PROJECT_ROOT = Path.cwd()
DATA_DIR = PROJECT_ROOT / "data"

# Read the JSON inputs through context managers so the file handles are closed
# deterministically, and decode them as UTF-8 explicitly instead of relying on
# the platform's default encoding (the notebook also writes its output as UTF-8).
with (DATA_DIR / "raw" / "poetrydb_poems.json").open(encoding="utf-8") as poems_file:
    poems = json.load(poems_file)['items']
# Each poem's lines are flattened into one space-separated string.
poem_texts = [" ".join(p.get("lines", [])) for p in poems]
poem_titles = [p.get("title", "") for p in poems]

with (DATA_DIR / "processed" / "combined_songs_large_fixed.json").open(encoding="utf-8") as songs_file:
    raw_songs = json.load(songs_file)['items']
song_titles = [s.get("title", "") for s in raw_songs]
song_artists = [s.get("artist", "") for s in raw_songs]
# `or ""` turns an explicit null lyrics field into an empty string.
song_lyrics = [s.get("lyrics", "") or "" for s in raw_songs]

# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only use this with a trusted file. Drop the flag if
# the archive turns out not to contain object arrays.
feats = np.load(DATA_DIR / "processed" / "additional_features.npz", allow_pickle=True)
songs_source_indexes = feats["songs_source_indexes"]

# Load embeddings and align songs to filtered order:
# row i of song_vecs is the raw embedding at position songs_source_indexes[i].
poem_vecs = np.load(DATA_DIR / "processed" / "mpnet_embeddings_poems.npy")
song_vecs_raw = np.load(DATA_DIR / "processed" / "mpnet_embeddings_songs.npy")
song_vecs = song_vecs_raw[songs_source_indexes]

# Normalize embeddings to unit length so a dot product equals cosine similarity
poem_norm = poem_vecs / np.linalg.norm(poem_vecs, axis=1, keepdims=True)
song_norm = song_vecs / np.linalg.norm(song_vecs, axis=1, keepdims=True)

print(f"Poems: {len(poems)} | Songs aligned: {len(song_vecs)}")


Poems: 3413 | Songs aligned: 2934


In [2]:
# Keyword lexicon (one bin per emotion) plus a helper that turns text into a
# small affect vector with one component per bin, in the dict's key order.
emotion_bins = {
    'joy': [
        'joy', 'happy', 'bright', 'delight', 'hope', 'sun', 'light', 'laugh', 'smile',
    ],
    'sadness': [
        'sad', 'sorrow', 'tears', 'weep', 'cry', 'grief', 'lonely', 'loss', 'blue', 'pain', 'ache',
    ],
    'love': [
        'love', 'heart', 'dear', 'beloved', 'kiss', 'tender', 'desire', 'romance', 'affection', 'darling',
    ],
    'anger': [
        'rage', 'anger', 'wrath', 'fury', 'hate', 'storm', 'fight', 'shout',
    ],
    'fear': [
        'fear', 'dread', 'dark', 'haunt', 'ghost', 'terror', 'nightmare', 'afraid', 'alone',
    ],
    'wonder': [
        'wonder', 'dream', 'stars', 'sky', 'sea', 'mystery', 'unknown', 'infinite', 'astral', 'moon',
    ],
    'longing': [
        'long', 'yearn', 'miss', 'distance', 'far', 'wait', 'absence', 'hunger', 'crave',
    ],
}
# A token is a run of ASCII letters and apostrophes.
word_pat = re.compile(r"[a-zA-Z']+")


def emotion_vector(text: str):
    """Count lexicon hits per emotion bin and scale the counts to (near) unit length.

    Args:
        text: Free text; matching is case-insensitive on whole tokens.

    Returns:
        A float array with one entry per key of ``emotion_bins``. It is all
        zeros when the text has no tokens or no lexicon hits; otherwise the
        counts are divided by their L2 norm plus a 1e-8 stabiliser.
    """
    lowered = [match.lower() for match in word_pat.findall(text)]
    if len(lowered) == 0:
        return np.zeros(len(emotion_bins))

    tally = Counter(lowered)
    hits_per_bin = []
    for lexicon in emotion_bins.values():
        # Counter returns 0 for words that never occurred.
        hits_per_bin.append(sum(tally[word] for word in lexicon))
    raw = np.array(hits_per_bin, dtype=float)

    if raw.sum() == 0:
        return raw
    return raw / (np.linalg.norm(raw) + 1e-8)

# Affect vectors for every poem, stacked row-wise in poem order.
poem_emotions = np.stack(list(map(emotion_vector, poem_texts)))
# Song rows follow the aligned order: each entry of songs_source_indexes is
# used as a position in the raw song list.
song_emotions = np.stack(
    [emotion_vector(song_lyrics[int(source_pos)]) for source_pos in songs_source_indexes]
)

def shortlist_by_cosine(p_idx, top_k=40):
    """Return the aligned song indexes most similar to one poem, best first.

    Similarity is the dot product between every row of ``song_norm`` and
    ``poem_norm[p_idx]`` (both are unit-normalized where they are built, so
    this is cosine similarity).

    Args:
        p_idx: Row index of the poem in ``poem_norm``.
        top_k: Maximum number of songs to return. Values <= 0 yield empty
            arrays; values larger than the song count return every song.

    Returns:
        ``(idxs, scores)``: indexes into the aligned song arrays and their
        similarity scores, both ordered from highest to lowest score.
    """
    scores = song_norm @ poem_norm[p_idx]
    # Reverse first, then slice. The previous `argsort()[-top_k:]` returned
    # *every* song for top_k == 0 because `-0` is `0`, and negative values
    # produced a nonsensical slice.
    limit = max(int(top_k), 0)
    idxs = scores.argsort()[::-1][:limit]
    return idxs, scores[idxs]


In [5]:
# Load Qwen model (local). Set QWEN_MODEL_PATH or QWEN_MODEL_ID to your local model.
# QWEN_MODEL_PATH takes precedence over QWEN_MODEL_ID; the hub id is the last resort.
model_id = os.getenv('QWEN_MODEL_PATH') or os.getenv('QWEN_MODEL_ID') or 'Qwen/Qwen2.5-7B-Instruct'
print('Attempting to load Qwen model from', model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Half precision on GPU, full precision on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Prefer accelerate's device_map='auto' placement when accelerate is installed.
# The import probe and the load attempt are handled separately so that a
# failed load is reported instead of being silently swallowed.
model = None
try:
    import accelerate  # noqa: F401
except ImportError:
    print('accelerate is not installed; loading without device_map')
else:
    try:
        model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, device_map='auto')
    except Exception as exc:
        # Keep the best-effort fallback, but say why it is happening.
        print(f"device_map='auto' load failed ({type(exc).__name__}: {exc}); retrying without device_map")

if model is None:
    # Fallback: load on CPU (or default) then move to chosen device
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype)
    model.to(device)

# Inference only: switch off training-time behaviour such as dropout.
model.eval()

def qwen_score(poem_text: str, song_text: str, poem_title: str = '', song_title: str = '', song_artist: str = '') -> float:
    """Ask the loaded Qwen model how well a song matches a poem.

    Builds a plain-text prompt from the poem and song, generates up to four
    new tokens with the module-level ``model``/``tokenizer``, and parses a
    score from the model's reply.

    Args:
        poem_text: Full poem text.
        song_text: Song lyrics.
        poem_title: Optional poem title included in the prompt.
        song_title: Optional song title included in the prompt.
        song_artist: Optional artist name included in the prompt.

    Returns:
        The last number within [0, 1] found in the model's reply, or 0.0 when
        the reply contains no such number.
    """
    prompt = (
        "You are evaluating how well a song matches a poem in tone, mood, and theme.\n"
        f"Poem title: {poem_title}\nPoem: {poem_text}\n"
        f"Song title: {song_title}\nArtist: {song_artist}\nLyrics: {song_text}\n"
        "Respond with a single number between 0 and 1 indicating match quality (higher is better)."
    )
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=4)
    # generate() returns the prompt tokens followed by the continuation. Decode
    # only the continuation: the prompt itself ends with "between 0 and 1", so
    # parsing the full sequence scored every number-less reply as 1.0.
    prompt_len = inputs['input_ids'].shape[1]
    reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    # Walk the numeric tokens from last to first and accept the first one that
    # is a valid score; out-of-range numbers such as "85" are skipped.
    for candidate in reversed(re.findall(r"\d*\.\d+|\d+", reply)):
        value = float(candidate)
        if 0.0 <= value <= 1.0:
            return value
    return 0.0


Attempting to load Qwen model from Qwen/Qwen2.5-7B-Instruct


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/3.95G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/3.86G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/3.56G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/243 [00:00<?, ?B/s]

In [None]:
random.seed(42)
np.random.seed(42)
# Text generation draws from torch's RNG whenever the model's generation
# config enables sampling, so seed torch as well; otherwise the Qwen scores
# (and therefore the saved labels) can differ between runs.
torch.manual_seed(42)
num_poems = 50
# Evenly spaced poem indexes across the whole collection.
poem_indices = list(np.linspace(0, len(poems)-1, num_poems, dtype=int))

results = []
for p_idx in poem_indices:
    print(f'Poem {p_idx} / {len(poems)}: {poem_titles[p_idx][:50]}...')
    shortlist, _ = shortlist_by_cosine(p_idx, top_k=40)
    scored = []
    for j, s_idx in enumerate(shortlist):
        if j % 10 == 0:
            print(f'  scoring song {j+1}/{len(shortlist)}')
        # s_idx indexes the aligned song arrays; raw_idx indexes the raw song lists.
        raw_idx = int(songs_source_indexes[s_idx])
        score = qwen_score(
            poem_text=poem_texts[p_idx],
            song_text=song_lyrics[raw_idx],
            poem_title=poem_titles[p_idx],
            song_title=song_titles[raw_idx],
            song_artist=song_artists[raw_idx],
        )
        scored.append((score, s_idx, raw_idx))
    # list.sort is stable, so songs with equal Qwen scores keep their
    # cosine-shortlist order.
    scored.sort(reverse=True, key=lambda x: x[0])
    top5 = []
    seen = set()
    for sc, s_idx, raw_idx in scored:
        # Compare and store plain ints so membership does not depend on the
        # numpy integer type of s_idx.
        aligned_idx = int(s_idx)
        if aligned_idx in seen:
            continue
        seen.add(aligned_idx)
        top5.append((sc, s_idx, raw_idx))
        if len(top5) >= 5:
            break
    for sc, s_idx, raw_idx in top5:
        # Cast numpy scalars to builtins so each record is JSON-serialisable.
        results.append({
            'poem_index': int(p_idx),
            'song_index_aligned': int(s_idx),
            'label': 1,
            'poem_title': poem_titles[p_idx],
            'song_title': song_titles[raw_idx],
            'song_artist': song_artists[raw_idx],
            'qwen_score': float(sc),
        })
print(f'Collected {len(results)} pairs')


In [None]:
# Group the collected pairs by poem, then keep each poem's best-scoring pairs.
# NOTE(review): despite its name, `refill` only sorts and truncates; nothing
# here pulls replacement songs from the shortlist.
pairs_by_poem = {}
for r in results:
    poem_key = r['poem_index']
    if poem_key not in pairs_by_poem:
        pairs_by_poem[poem_key] = []
    pairs_by_poem[poem_key].append(r)


def refill(poem_idx, keep_n=5):
    """Keep only the ``keep_n`` highest-``qwen_score`` pairs for one poem.

    The poem's entry in ``pairs_by_poem`` is replaced by a new list ordered
    from highest to lowest score; an unknown ``poem_idx`` gets an empty list.
    """
    ranked = sorted(
        pairs_by_poem.get(poem_idx, []),
        key=lambda pair: pair['qwen_score'],
        reverse=True,
    )
    pairs_by_poem[poem_idx] = ranked[:keep_n]


# Snapshot the keys first because refill() assigns into the dict.
for p_idx in list(pairs_by_poem):
    refill(p_idx, keep_n=5)

# Save JSONL: one JSON object per line, poems in ascending index order.
out_path = DATA_DIR / 'processed' / 'qwen_labels.jsonl'
with out_path.open('w', encoding='utf-8') as f:
    for p_idx in sorted(pairs_by_poem):
        for entry in pairs_by_poem[p_idx]:
            f.write(f"{json.dumps(entry)}\n")
total_pairs = sum(map(len, pairs_by_poem.values()))
print(f"Saved {total_pairs} pairs to {out_path}")
