In [5]:
from __future__ import annotations

import pathlib
import sys
import tomllib

import torch

# Make the `src` package importable *before* importing from it. The parent of the
# working directory is assumed to contain `src/` (notebook run from a subfolder —
# verify). sys.path entries must be strings: non-str entries such as a
# pathlib.Path are ignored during import, so the Path is converted with str().
# The membership check keeps re-runs of this cell from appending duplicates.
_SRC_PARENT_DIR = str(pathlib.Path.cwd().parent)
if _SRC_PARENT_DIR not in sys.path:
    sys.path.append(_SRC_PARENT_DIR)

from src.model.autoencoder import Autoencoder
from src.model.hyperparameters import Hyperparameters
from src.utils.utils import dataclass_from_dict

def load_model_hyperparameters(
    config_path: str | pathlib.Path = "/home/paolo/git/spotify-playlist-generator/config/model.toml",
) -> Hyperparameters:
    """Load the model hyperparameters from a TOML file.

    Args:
        config_path: Path to the TOML config. Defaults to the previously
            hard-coded location so existing zero-argument calls keep working;
            pass another path to run the notebook from a different checkout.

    Returns:
        A ``Hyperparameters`` instance built from the parsed TOML via
        ``dataclass_from_dict``.
    """
    # tomllib.load requires a binary file handle, hence "rb".
    with pathlib.Path(config_path).open("rb") as config_file:
        return dataclass_from_dict(Hyperparameters, tomllib.load(config_file))
cfg = load_model_hyperparameters()

# torch.device("") raises ("Device string must not be empty"), so fall back to
# the CPU explicitly when CUDA is unavailable.
device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
# eval() switches any dropout / batch-norm layers to inference behaviour so the
# embeddings are deterministic; to(device) keeps the model on the same device that
# `preprocess` moves its inputs to.
# NOTE(review): `preprocess` casts inputs to float16 but the model dtype is left as
# loaded (the previously commented-out `.half()` is not applied) — confirm whether
# the encoder expects float16 weights.
model = Autoencoder.load_from_checkpoint(
    "/home/paolo/git/spotify-playlist-generator/logs/mlruns/749985226233759206/0e707c5e421c47098798cc5c91d9532b/artifacts/model/checkpoints/model_checkpoint/model_checkpoint.ckpt",
    map_location=device,
    hyperparam=cfg,
).eval().to(device)

In [3]:
from torchvision.transforms import v2
from src.model.dataset.transforms import MinMaxNorm

from src.db.schemas.song_embedding import SongEmbedding
import torchaudio.transforms as T


# Waveform -> model input pipeline, applied to each audio slice by `preprocess`.
transforms: v2.Compose = v2.Compose(
    transforms=[
        # Mel spectrogram with the STFT / mel parameters taken from the model config.
        T.MelSpectrogram(
            sample_rate=cfg.SAMPLE_RATE,
            n_fft=cfg.N_FFT,
            win_length=cfg.WIN_LENGTH,
            hop_length=cfg.HOP_LENGTH,
            n_mels=cfg.N_MELS,
            normalized=False,
        ),
        MinMaxNorm(),
        # Pure dtype cast to float16: scale=False means no value rescaling.
        v2.ToDtype(dtype=torch.float16, scale=False),
    ]
)

def preprocess(tensor: torch.Tensor) -> torch.Tensor:
    """Turn one waveform slice into a single-item batch for the encoder.

    Applies the module-level ``transforms`` pipeline, moves the result to
    ``device`` and prepends a batch dimension of size 1. The resulting shape is
    printed for debugging.
    """
    spectrogram = transforms(tensor)
    # [None] prepends a size-1 axis, same as unsqueeze(dim=0).
    batch = spectrogram.to(device)[None]
    print(batch.shape)
    return batch

def wrap_prediction_to_song_embedding(tensor: torch.Tensor, song_id: str) -> SongEmbedding:
    """Package an embedding tensor as a ``SongEmbedding`` record.

    The tensor's shape is printed for debugging; its values are converted to
    plain Python (nested) lists with ``Tensor.tolist()`` for the ``embedding``
    field.
    """
    print(tensor.shape)
    embedding_values = tensor.tolist()
    return SongEmbedding(id=song_id, embedding=embedding_values)

def predict_audio(audio: torch.Tensor, verbose: bool = True) -> torch.Tensor:
    """Embed a waveform by averaging encoder outputs over fixed-length slices.

    The waveform is cut into slices of ``cfg.SAMPLE_RATE * cfg.CROP_SIZE_SECONDS``
    frames. A trailing partial slice is zero-padded to a full slice when it is
    longer than half a slice, or when the clip is shorter than one slice (so
    there is always at least one slice); otherwise it is dropped. Each slice is
    passed through ``preprocess`` and ``model.encoder`` and the per-slice
    outputs are averaged.

    Args:
        audio: Waveform with frames on dim 1 (e.g. the ``(channels, frames)``
            tensor returned by ``torchaudio.load``).
        verbose: Print the slicing diagnostics (the default, as before).

    Returns:
        The mean of the per-slice encoder outputs with size-1 dimensions
        squeezed out. The result has the same shape whether there is one slice
        or many.

    Raises:
        ValueError: If ``audio`` has no frames.
    """
    crop_frames: int = cfg.SAMPLE_RATE * cfg.CROP_SIZE_SECONDS
    num_frames: int = audio.shape[1]
    if num_frames == 0:
        raise ValueError("predict_audio received audio with 0 frames")
    last_audio_slice_length: int = num_frames % crop_frames
    if verbose:
        print(f"crop_frames: {crop_frames} - num_frames: {num_frames} - last_audio_slice_length: {last_audio_slice_length}")

    if last_audio_slice_length == 0:
        if verbose:
            print("length of last audio is equal to the length of a slice")
    elif last_audio_slice_length > crop_frames / 2 or num_frames < crop_frames:
        # Pad the partial slice up to a full slice. Clips shorter than one slice
        # are always padded: trimming them would leave zero slices to encode.
        frames_to_add: int = crop_frames - last_audio_slice_length
        if verbose:
            print(f"frames_to_add: {frames_to_add}")
        # new_zeros matches the dtype and device of `audio`.
        audio = torch.cat([audio, audio.new_zeros((audio.shape[0], frames_to_add))], dim=1)
    else:
        # Drop a trailing partial slice of at most half a slice.
        frames_to_remove: int = last_audio_slice_length
        if verbose:
            print(f"frames_to_remove: {frames_to_remove}")
        audio = audio[:, :num_frames - frames_to_remove]

    num_slices: int = audio.shape[1] // crop_frames
    if verbose:
        print(f"num_frames: {audio.shape[1]} - num_slices: {num_slices}")
    # audio.shape[1] is now an exact multiple of crop_frames, so every piece is a full slice.
    slices = torch.split(audio, crop_frames, dim=1)
    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        slice_embeddings = [model.encoder(preprocess(audio_slice)) for audio_slice in slices]
    # Average over slices first, then squeeze. Squeezing first would also drop the
    # slice axis when there is a single slice, and mean(dim=0) would then collapse
    # the embedding itself to a scalar.
    return torch.stack(slice_embeddings).mean(dim=0).squeeze()


def get_song_embedding(track: dict[str, str]) -> SongEmbedding:
    """Load a track's audio from disk and return its ``SongEmbedding``.

    Args:
        track: Mapping with ``"song_id"`` and ``"audio_path"`` entries.

    Returns:
        The ``predict_audio`` embedding of the track, wrapped by
        ``wrap_prediction_to_song_embedding``.
    """
    # Imported locally: the notebook's `import torchaudio` lives in a later cell,
    # so without this the function only works if that cell has already run.
    import torchaudio

    song_id: str | None = track.get("song_id")
    audio_path: str | None = track.get("audio_path")

    # NOTE(review): the file's sample rate is discarded; the crop length and the
    # mel transform assume cfg.SAMPLE_RATE. Resample here if files can differ — TODO confirm.
    audio, _ = torchaudio.load(audio_path)
    num_frames: int = audio.shape[1]
    crop_frames: int = cfg.CROP_SIZE_SECONDS * cfg.SAMPLE_RATE
    if num_frames < crop_frames:
        # Pad clips shorter than one slice up to exactly one slice. Strictly `<`:
        # a clip of exactly one slice needs no padding (with `<=` it would get a
        # second, all-silent slice averaged into its embedding).
        frames_to_add: int = crop_frames - num_frames
        audio = torch.cat([audio, audio.new_zeros((audio.shape[0], frames_to_add))], dim=1)

    prediction: torch.Tensor = predict_audio(audio)
    return wrap_prediction_to_song_embedding(prediction, song_id)

* 'orm_mode' has been renamed to 'from_attributes'


In [7]:
import torchaudio
# get_song_embedding({"song_id": "test", "audio_path": "/home/paolo/git/spotify-playlist-generator/data/raw/songs/1960/4Hhv2vrOTy89HFRcjU3QOx.mp3"})

In [8]:
# Exercise the slicing logic of `predict_audio` on three inputs cut from one song:
# a clip whose trailing partial slice is just under half a slice, a clip of exactly
# two slices, and the whole song.
audio, _ = torchaudio.load(
    "/home/paolo/git/spotify-playlist-generator/data/raw/songs/1960/4Hhv2vrOTy89HFRcjU3QOx.mp3"
)
crop_frames = cfg.SAMPLE_RATE * cfg.CROP_SIZE_SECONDS

audio_equal = audio[:, : crop_frames * 2]
audio_less = audio[:, : crop_frames * 2 - (crop_frames // 2 + 10)]

predict_audio(audio_less)
print("\n")
predict_audio(audio_equal)
print("\n")
predict_audio(audio)

crop_frames: 1440000 - num_frames: 2159990 - last_audio_slice_length: 719990
frames_to_remove: 719990
num_frames: 1440000 - num_slices: 1
torch.Size([1, 2, 256, 5626])


crop_frames: 1440000 - num_frames: 2880000 - last_audio_slice_length: 0
length of last audio is equal to the length of a slice
num_frames: 2880000 - num_slices: 2
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])


crop_frames: 1440000 - num_frames: 8625280 - last_audio_slice_length: 1425280
frames_to_add: 14720
num_frames: 8640000 - num_slices: 6
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])
torch.Size([1, 2, 256, 5626])


tensor([ 7.1430e-04,  4.8685e-04, -4.8351e-04, -2.7504e-03,  2.1877e-03,
         2.8210e-03,  2.0397e-04, -7.9966e-04,  2.0278e-04,  2.0924e-03,
         1.9245e-03,  1.9341e-03, -2.3251e-03, -2.1839e-03, -2.5826e-03,
         3.0231e-04, -1.1587e-03, -1.1702e-03, -8.2350e-04, -6.7186e-04,
        -2.0084e-03, -3.0136e-03, -3.2616e-04,  1.1148e-03,  3.5801e-03,
         3.7730e-05,  1.3485e-03, -1.4267e-03, -1.3933e-03, -1.8377e-03,
         5.7697e-04, -8.5878e-04, -1.8902e-03,  2.2106e-03, -6.4898e-04,
        -3.5310e-04, -2.3727e-03,  5.2834e-04, -8.5163e-04,  1.4954e-03,
         5.4502e-04,  8.3876e-04,  9.3269e-04,  5.2452e-04,  7.6199e-04,
        -7.8678e-04, -7.8583e-04, -3.5167e-04, -2.3174e-04,  3.0875e-04,
        -4.7660e-04,  1.7416e-04, -1.1187e-03,  3.1877e-04,  1.0748e-03,
         1.3752e-03, -4.2439e-04,  2.7676e-03,  9.4414e-04,  4.9114e-04,
         3.8648e-04, -8.1587e-04,  2.9621e-03,  1.5621e-03,  5.8270e-04,
        -1.7328e-03, -6.2990e-04, -9.9087e-04,  7.5