# Deploy on any 2 songs
- Use model to calculate embeddings (using eval mode specifically and with no gradient updating)
- NOTE: Right now, we have to have the input data in the correct format: a spectrogram/chromagram/tempogram (generically called "gram"). So for any deployment, we'll have to do preprocessing in the streamlit app for example. OR we can have a set of say 10-15 sample songs you can compare where we've already done all of the calculations.


In [1]:
import pandas as pd
import numpy as np
import librosa
import librosa.display
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from dotenv import dotenv_values 
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
import pickle as pkl

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset

from scipy.spatial.distance import euclidean
from sklearn.model_selection import train_test_split
from transformers import Wav2Vec2Model

import torch.optim as optim
from pydub import AudioSegment
import io

### Extract Embeddings

In [2]:
# Define pretrained resnet from Torch Vision resnet 18
class ResNetEmbedding(nn.Module):
    def __init__(self, embedding_dim=128, dropout_rate=0.5):
        # get resnet super class
        super(ResNetEmbedding, self).__init__()
        self.resnet = models.resnet18(weights='DEFAULT')
        # Change structure of first layer to take non RGB images, rest of params same as default
        self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.dropout = nn.Dropout(p=dropout_rate)
        # Set the last fully connected to a set dimension "embedding_dim" instead of default 1000
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, embedding_dim)

    def forward(self, x):
        x = self.resnet(x)
        return F.normalize(x, p=2, dim=1)

In [4]:
# How to load the model later using just the state dictionary
model = ResNetEmbedding()  # Make sure this matches the architecture you used
model.load_state_dict(torch.load('../modeling/resnet18_model_weights.pth', map_location=torch.device('cpu')))

# If using a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

ResNetEmbedding(
  (resnet): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, tr

In [None]:
# Function to clip audio to 30s (10000 ms by default)
def clip_audio(audio_data, segment_length = 120000, sr=22050):
    audio = AudioSegment.from_file(audio_data)
    # check audio is at least segment_length
    if len(audio) < segment_length:
        raise ValueError("Audio is shorter than the segment length.")

    # Randomly generate start point at least 10s from the end of the song.
    start = np.random.randint(0, len(audio) - segment_length)
    audio_segment = audio[start:start + segment_length]

    # Return clip in format importable to Librosa
    audio_segment_io = io.BytesIO()
    audio_segment.export(audio_segment_io, format="wav")
    audio_segment_io.seek(0)
    return audio_segment_io # Export (with pointer at start of audio)

def extract_embedding(model, audio_data_clip, sr=22050, is_raw_audio = True, use_model=True):
    if is_raw_audio:
        
        clipped_audio = clip_audio(audio_data_clip)
        
        y, sr = librosa.load(clipped_audio, sr=sr)
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)

    # If its from the train/test set, it's already in librosa format
    else:
        y = audio_data_clip
        mel_spectrogram = librosa.feature.melspectrogram(y=y, sr=sr)
        mel_spectrogram_db = librosa.power_to_db(mel_spectrogram, ref=np.max)
    
    # Convert to tensor and move to the appropriate device
    mel_tensor = torch.tensor(mel_spectrogram_db, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    
    if use_model:
        # Get the embedding from the model
        with torch.no_grad():
            embedding = model(mel_tensor)
        
        # Normalize the embedding to we're not comparing magnitudes
        #embedding = F.normalize(embedding, p=2, dim=1)
        return embedding
    else:
        # Flatten to get rid of dimensions, then unqueeze to make 2D with batch dimension 
        no_model_embedding = mel_tensor.flatten().unsqueeze(0)
        # Normalize so that we just compare patterns not magnitudes
        no_model_embedding = F.normalize(no_model_embedding, p=2, dim=1)
        return no_model_embedding

def compute_cosine_similarity(embedding1, embedding2):
    cosine_sim = F.cosine_similarity(embedding1, embedding2).item()
    return cosine_sim 

def compute_euclidean_distance(embedding1, embedding2):
    euclidean_dist = torch.dist(embedding1, embedding2).item()
    return euclidean_dist

### Search Spotify for Track to Analyze

In [5]:
# Load stuff from .env file
env_vars = dotenv_values('.env')

client_credentials_manager = SpotifyClientCredentials(
    client_id=os.getenv("SPOTIFY_CLIENT_ID"),
    client_secret=os.getenv("SPOTIFY_CLIENT_SECRET"),
)
# Get Spotify api client and apply to df
sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)

In [6]:
# Function to get the preview URL of a song based on artist name and song title
#@retry(wait=wait_exponential(multiplier=1, min=4, max=60), stop=stop_after_attempt(10))
def search_track(artist_name, track_name, sp, rate_limit = 1.0):
    # Search for the track
    result = sp.search(q=f'artist:{artist_name} track:{track_name}', type='track', limit=1)
    if result['tracks']['items']:
        # Return the preview URL if found
        return result['tracks']['items'][0]['preview_url']
    print('Arist/Track not found...')
    return None    

In [9]:
track1_name = 'Under Pressure'
track1_artist = 'Bowie'
track2_name = 'Rosanna'
track2_artist = 'Toto'
track1 = search_track(track1_artist, track1_name, sp)
track2 = search_track(track1_artist, track1_name, sp)

In [10]:
track2

In [30]:
artist_name = 'Robin Thicke'
track_name = 'Surfin USA'
#result = sp.search(q=f'artist:{artist_name} track:{track_name}', type='track', limit=1)
result = sp.search(q=f'track:{track_name}', type='track', limit=1)
preview = None
for item in result['tracks']['items']:
    if item['preview_url'] is not None:
        preview = item['preview_url']
        break
preview   
    

In [42]:
for track in result['tracks']:
    print(f"{track['items']}")

TypeError: string indices must be integers, not 'str'

In [41]:
result['tracks']['items']

[{'album': {'album_type': 'album',
   'artists': [{'external_urls': {'spotify': 'https://open.spotify.com/artist/3oDbviiivRWhXwIE8hxkVV'},
     'href': 'https://api.spotify.com/v1/artists/3oDbviiivRWhXwIE8hxkVV',
     'id': '3oDbviiivRWhXwIE8hxkVV',
     'name': 'The Beach Boys',
     'type': 'artist',
     'uri': 'spotify:artist:3oDbviiivRWhXwIE8hxkVV'}],
   'available_markets': ['AR',
    'AU',
    'AT',
    'BE',
    'BO',
    'BR',
    'BG',
    'CA',
    'CL',
    'CO',
    'CR',
    'CY',
    'CZ',
    'DK',
    'DO',
    'DE',
    'EC',
    'EE',
    'SV',
    'FI',
    'FR',
    'GR',
    'GT',
    'HN',
    'HK',
    'HU',
    'IS',
    'IE',
    'IT',
    'LV',
    'LT',
    'LU',
    'MY',
    'MT',
    'MX',
    'NL',
    'NZ',
    'NI',
    'NO',
    'PA',
    'PY',
    'PE',
    'PH',
    'PL',
    'PT',
    'SG',
    'SK',
    'ES',
    'SE',
    'CH',
    'TW',
    'TR',
    'UY',
    'US',
    'GB',
    'AD',
    'LI',
    'MC',
    'ID',
    'JP',
    'TH',
    'VN',
