## Semantic Song Search: Fine-Tune Base Model

In [1]:
from sentence_transformers import SentenceTransformer
import pandas as pd
import numpy as np

### Data Prep

In [2]:
# Load the song dataset from a CSV path relative to the notebook's working directory.
sds = pd.read_csv("data/small_dataset.csv")

In [3]:
# Show each column's dtype to see which fields are text (object) and which are numeric.
sds.dtypes

Unnamed: 0     int64
index          int64
title         object
tag           object
artist        object
year           int64
views          int64
features      object
lyrics        object
id             int64
dtype: object

In [4]:
# List the distinct values of the `tag` column before filtering on it.
sds['tag'].unique()

array(['rap', 'pop', 'rock', 'rb', 'country', 'misc'], dtype=object)

#### Remove rap because we don't like the lyrics
- Too explicit / NSFW / NSFS
- Many similar lines, less cohesive narratives/plots
- Concern: Some rap songs are tagged as pop, so filtering on the `tag` column will not remove them

In [5]:
# Keep only the rows whose tag is not 'rap', then display the remaining (rows, columns).
not_rap_mask = sds['tag'].ne('rap')
sds = sds.loc[not_rap_mask]
sds.shape

(614, 10)

In [7]:
# Use the lyrics of every song left in `sds` (no further sub-sampling).
lyrics = sds['lyrics']

### Remove duplicate lyrics lines
- Remove duplicate lines because otherwise our model may learn to match song halves based on repeated lyrics within a song
- Define function below to dedupe and preserve order
- Concern: There can still be quite similar lines scattered through songs

In [11]:
# drop repeated items from a sequence while keeping the original order
def dedupe(seq):
    """Return the items of ``seq`` as a list with later repeats removed.

    The first occurrence of each item is kept and the relative order of the
    surviving items is unchanged. Items must be hashable, since membership is
    tracked in a set.
    """
    unique_items = []
    already_seen = set()
    for item in seq:
        if item in already_seen:
            continue
        already_seen.add(item)
        unique_items.append(item)
    return unique_items

In [13]:
# confirm dedupe line count: the deduped list must have exactly as many lines
# as there are distinct lines in the song
# `test` was previously undefined on a fresh kernel (its defining cell is gone),
# so take the first song's lyrics as the sample; `.iloc` is positional, because
# the index labels are no longer contiguous after filtering.
test = lyrics.iloc[0]
len(dedupe(test.split('\n'))) == len(set(test.split('\n')))

True

In [15]:
# process dataset: for every song, drop repeated lines (keeping first
# occurrences in order) and stitch the remaining lines back into one string

lyrics_deduped = [
    '\n'.join(dedupe(song_text.split('\n')))
    for song_text in lyrics
]

### Training

Reference: https://www.pinecone.io/learn/fine-tune-sentence-transformers-mnr/#fast-fine-tuning

In [16]:
# prep training data

from sentence_transformers import InputExample
from tqdm.auto import tqdm  # progress bar


def split_in_half(text):
    """Split ``text`` into two halves without cutting through a line or word.

    The text is stripped of surrounding whitespace, then split at the middle
    line break. If there is only one line, it is split at the middle space
    instead; if there is no space either, it falls back to the character
    midpoint.

    Returns:
        A ``(first_half, second_half)`` tuple of strings.
    """
    text = text.strip()
    for sep in ('\n', ' '):
        parts = text.split(sep)
        if len(parts) >= 2:
            mid = len(parts) // 2
            return sep.join(parts[:mid]), sep.join(parts[mid:])
    mid = len(text) // 2
    return text[:mid], text[mid:]


train_samples = []
for song in tqdm(lyrics_deduped):
    # split songs into two halves for positive pair training; splitting on
    # line boundaries avoids halves that start or end mid-word, which the
    # previous character-midpoint slice produced
    half_1, half_2 = split_in_half(song)
    # a blank half carries no signal for a positive pair, so skip the song
    if not half_1.strip() or not half_2.strip():
        continue
    train_samples.append(InputExample(
        texts=[half_1, half_2]
    ))

  0%|          | 0/614 [00:00<?, ?it/s]

In [18]:
from sentence_transformers import datasets

batch_size = 32

# batch the positive pairs with the library's no-duplicates loader
# (per its documentation it keeps duplicate texts out of a single batch)
loader = datasets.NoDuplicatesDataLoader(
    train_samples,
    batch_size=batch_size,
)

In [20]:
from sentence_transformers import models, SentenceTransformer

# pretrained checkpoint used as the starting point for fine-tuning
base_model_name = 'all-MiniLM-L12-v2'
model = SentenceTransformer(base_model_name)

# display the model's module stack and settings (e.g. max_seq_length)
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

In [21]:
# specify loss

from sentence_transformers import losses

# Multiple Negatives Ranking (MNR) loss, bound to the model being fine-tuned
loss = losses.MultipleNegativesRankingLoss(model=model)

In [22]:
# train and export model

epochs = 1
# warm up over the first 10% of all training steps (batches per epoch * epochs)
warmup_fraction = 0.1
warmup_steps = int(len(loader) * epochs * warmup_fraction)

# gather the fit arguments in one place; the fine-tuned model is written to output_path
fit_kwargs = dict(
    train_objectives=[(loader, loss)],
    epochs=epochs,
    warmup_steps=warmup_steps,
    output_path='./finetune_test_mnr',
    show_progress_bar=False,
)
model.fit(**fit_kwargs)