In [None]:
import numpy as np
import scipy.io as sio
import os
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from collections import defaultdict
import torch.nn as nn

In [2]:
def aggregate_feature(mat_path, index, time_steps):
    """Reduce a frame-level feature matrix to one row per time step.

    The variable ``index`` is read from the MATLAB file at ``mat_path``,
    squeezed and transposed; the frame numbers stored in ``time_steps`` then
    index axis 0 of the result.

    Args:
        mat_path: Path of the ``.mat`` file, read with ``scipy.io.loadmat``.
        index: Name of the variable to read from the file (e.g. ``'XMFCC'``).
        time_steps: Iterable of dicts with integer ``"frame_start"`` and
            ``"frame_end"`` entries.

    Returns:
        ``np.ndarray`` holding one row per time step. When ``time_steps`` is
        empty the result has zero rows (and the feature width as its column
        count) instead of failing inside ``np.vstack``.

    Raises:
        KeyError: If ``index`` is not a variable stored in the file.
        IndexError: If a zero-length step starts beyond the last frame.
    """
    mat = sio.loadmat(mat_path)
    feature = mat[index].squeeze().T

    steps = list(time_steps)
    if not steps:
        # np.vstack([]) raises ValueError; return an explicit empty matrix so
        # a track without any beat interval simply contributes no rows.
        width = feature.shape[-1] if feature.ndim >= 2 else 1
        return np.empty((0, width), dtype=np.float64)

    aggregated = []
    for step in steps:
        start = step["frame_start"]
        end = step["frame_end"]

        if end > start:
            # Half-open interval [start, end): average all frames inside it.
            agg = feature[start:end].mean(axis=0)
        else:
            # Zero-length (or inverted) interval: fall back to the single
            # frame at `start`.
            agg = feature[start]

        aggregated.append(agg)

    return np.vstack(aggregated)

In [3]:
def create_file_indices(root_dir):
    """Build the list of per-track file paths found under ``root_dir``.

    The ``MFCCs`` directory drives the enumeration: each of its
    sub-directories is scanned in sorted order, and for every file in it the
    text before the first ``_`` of the file name is used as the track id to
    derive the matching CENS, Beats and Metadata paths. The derived paths are
    not checked for existence.

    Entries that cannot be track files are skipped instead of crashing the
    scan or producing bogus records: non-directories directly inside
    ``MFCCs``, hidden files (such as ``.DS_Store``) and names without ``_``.

    Args:
        root_dir: Directory containing the ``MFCCs``, ``CENS``, ``Beats`` and
            ``Metadata`` folders.

    Returns:
        List of dicts with the keys ``"mfcc"``, ``"cens"``, ``"beats"`` and
        ``"meta"``, ordered by folder name and then by file name.

    Raises:
        FileNotFoundError: If ``root_dir/MFCCs`` does not exist.
    """
    mfcc_dir = os.path.join(root_dir, "MFCCs")
    cens_dir = os.path.join(root_dir, "CENS")
    beats_dir = os.path.join(root_dir, "Beats")
    meta_dir = os.path.join(root_dir, "Metadata")

    indices = []

    for folder in sorted(os.listdir(mfcc_dir)):
        folder_path = os.path.join(mfcc_dir, folder)
        if not os.path.isdir(folder_path):
            # A stray file next to the folders would make listdir() raise.
            continue

        for file in sorted(os.listdir(folder_path)):
            if file.startswith(".") or "_" not in file:
                # Hidden/OS files and names without a "<id>_" prefix carry no
                # track id.
                continue

            file_name = file.partition("_")[0]

            indices.append({
                "mfcc": os.path.join(folder_path, file),
                "cens": os.path.join(cens_dir, folder, f"{file_name}_CENS.mat"),
                "beats": os.path.join(beats_dir, folder, f"{file_name}_Beats.mat"),
                "meta": os.path.join(meta_dir, folder, f"{file_name}.txt"),
            })

    return indices

In [4]:
def find_meta_data(filepath):
    """Return the content of a text file as a list of lines.

    The file is opened in text mode with the platform default encoding, read
    in one go and split on ``"\\n"``. Because a plain split is used, a file
    ending with a newline produces a trailing empty string.

    Args:
        filepath: Path of the text file to read.

    Returns:
        List of strings, one per ``"\\n"``-separated segment.
    """
    with open(filepath) as handle:
        contents = handle.read()

    lines = contents.split("\n")
    return lines

In [5]:
def inference_collate(batch):
    """Flatten a batch of per-song samples into beat-level tensors.

    Each non-``None`` sample is a dict with ``"X"`` (per-beat features),
    ``"durations"`` (per-beat durations) and ``"meta_data"`` (song
    identifier). The features and durations of all samples are concatenated
    along dim 0, and the identifier of every sample is repeated once per row
    of its ``"X"`` so the three outputs stay aligned.

    Args:
        batch: List of sample dicts; ``None`` entries are dropped.

    Returns:
        Tuple ``(X, durations, song_ids)``.

    Note:
        If every entry is ``None``, ``torch.cat`` receives an empty list and
        raises.
    """
    samples = [sample for sample in batch if sample is not None]

    features = torch.cat([sample["X"] for sample in samples], dim=0)
    beat_lengths = torch.cat([sample["durations"] for sample in samples], dim=0)

    # One identifier per beat row, in the same order as the concatenation.
    song_ids = [
        sample["meta_data"]
        for sample in samples
        for _ in range(len(sample["X"]))
    ]

    return features, beat_lengths, song_ids

In [6]:
def determine_time_steps(beats, hop_size, fs):
    """Turn a sequence of beat positions into consecutive beat intervals.

    Beat positions are given in frames; they are converted to seconds with
    ``beats * hop_size / fs`` and every pair of neighbouring beats becomes one
    interval.

    Args:
        beats: Beat positions in frames. May be an array, a list or a scalar /
            0-d array (which is what ``.squeeze()`` yields for a track with a
            single beat).
        hop_size: Number of samples between consecutive frames.
        fs: Sampling rate in samples per second.

    Returns:
        List with ``len(beats) - 1`` dicts, each holding ``"frame_start"``
        and ``"frame_end"`` (ints) and ``"duration"`` (seconds). The list is
        empty when fewer than two beats are supplied.
    """
    # np.diff() and len() reject 0-d input and list * int is repetition, not
    # scaling, so coerce to an array with at least one dimension first.
    beats = np.atleast_1d(np.asarray(beats))

    beat_times = beats * hop_size / fs
    durations = np.diff(beat_times)

    return [
        {
            "frame_start": int(beats[t]),
            "frame_end": int(beats[t + 1]),
            "duration": durations[t],
        }
        for t in range(len(beats) - 1)
    ]

In [7]:
def compute_normalization(loader):
    """Compute per-feature mean and standard deviation over a whole loader.

    Only the first element of every batch (the feature tensor) is used, so
    the function works with any collate function that puts the features
    first, regardless of how many further items the batch tuple carries
    (e.g. the 3-tuples returned by ``inference_collate``).

    Args:
        loader: Iterable of batches. A batch may be ``None`` or have ``None``
            as its first element; such batches are skipped.

    Returns:
        Tuple ``(mean, std)`` of tensors reduced over dim 0 of the
        concatenated features. ``1e-6`` is added to ``std`` so that dividing
        by it never hits an exact zero.

    Raises:
        RuntimeError: If no batch contributed any features.
    """
    X_all = []

    for batch in loader:
        if batch is None:
            continue

        # Index instead of unpacking a fixed tuple length: the number of
        # trailing items differs between collate functions.
        X = batch[0]
        if X is None:
            continue
        X_all.append(X)

    if len(X_all) == 0:
        raise RuntimeError("No valid samples found")

    X_all = torch.cat(X_all, dim=0)
    mean = X_all.mean(dim=0)
    std = X_all.std(dim=0) + 1e-6
    return mean, std

In [8]:
class BeatRegressor(nn.Module):
    """Feed-forward network mapping one feature vector to two outputs.

    Architecture: ``input_dim -> 128 -> 64 -> 2`` with ReLU activations and a
    dropout of 0.2 after the first hidden layer.

    Args:
        input_dim: Size of the last dimension of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()

        layers = [
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 2),
        ]

        # The attribute name and the layer order define the state_dict keys
        # ("next.0.weight", ...), so both must stay as they are for saved
        # checkpoints to keep loading.
        self.next = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to ``x`` and return a tensor whose last dim is 2."""
        output = self.next(x)
        return output

In [None]:
class Covers1000Dataset(Dataset):
    """Dataset yielding beat-aggregated MFCC + CENS features, one song per item.

    Args:
        file_indices: Sequence of dicts with the keys ``"mfcc"``, ``"cens"``,
            ``"beats"`` and ``"meta"`` (file paths), as built by
            ``create_file_indices``.
    """

    def __init__(
            self,
            file_indices
    ):

        self.file_indices = file_indices

    def __len__(self):
        """Number of songs in the index."""
        return len(self.file_indices)

    @staticmethod
    def _song_key(meta_path):
        """Return ``meta_path`` without a leading current-directory prefix.

        Only a real ``./`` (or ``.`` + ``os.sep``) prefix is removed; any
        other path is returned unchanged instead of losing its first two
        characters.
        """
        if meta_path.startswith(("./", os.curdir + os.sep)):
            return meta_path[2:]
        return meta_path

    def __getitem__(self, idx):
        """Load one song.

        Returns:
            Dict with ``"X"`` (float32 tensor, one row per beat interval with
            the MFCC and CENS aggregates stacked side by side),
            ``"durations"`` (float32 tensor, seconds per beat interval) and
            ``"meta_data"`` (metadata path used as the song identifier).
        """
        entry = self.file_indices[idx]

        mat = sio.loadmat(entry["beats"])

        fs = mat['Fs'].item()
        hop_size = mat['hopSize'].item()
        beats0 = mat['beats0'].squeeze()

        time_steps = determine_time_steps(beats0, hop_size, fs)

        mfcc_beat = aggregate_feature(entry["mfcc"], 'XMFCC', time_steps)
        cens_beat = aggregate_feature(entry["cens"], 'XCENS', time_steps)

        X = np.hstack([
            mfcc_beat,
            cens_beat
        ])

        # Take the durations from the very same intervals that were used for
        # aggregation so that rows of X and durations cannot get out of sync.
        beat_durations = np.asarray(
            [step["duration"] for step in time_steps], dtype=np.float32
        )

        X = torch.tensor(X, dtype=torch.float32)
        durations = torch.tensor(beat_durations, dtype=torch.float32)

        return {
            "X": X,                    # (T, D)
            "durations": durations,    # (T,)
            "meta_data": self._song_key(entry["meta"])
        }
    

In [None]:
# Index every track below the current working directory and wrap it in a
# dataset that yields one whole song per item.
mat_files = create_file_indices(".")
dataset = Covers1000Dataset(mat_files)

# inference_collate concatenates the beats of all songs in a batch, so the
# batch size counts songs, not beats.
loader_options = {
    "batch_size": 8,                  # small batches = safer memory
    "shuffle": False,                 # VERY IMPORTANT
    "collate_fn": inference_collate,
    "num_workers": 0,                 # increase later if needed
}
covers_loader = DataLoader(dataset, **loader_options)



./Metadata/1/174702.txt


In [43]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Checkpoint layout used below: "input_dim", "model_state_dict", "mean", "std".
checkpoint = torch.load("deam_beat_regressor.pt", map_location=device)

model = BeatRegressor(checkpoint["input_dim"]).to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

mean = checkpoint["mean"].to(device)
std = checkpoint["std"].to(device)

# NOTE(security): weights_only=False unpickles arbitrary Python objects and
# can execute code while loading. Only load this file from a trusted source.
kalman = torch.load("emotion_ssm.pkl", map_location="cpu", weights_only=False)

# song id -> per-beat predictions and the matching beat durations
song_preds = defaultdict(lambda: {"preds": [], "durs": []})

with torch.no_grad():

    for X, durations, song_ids in covers_loader:
        # Move the batch to `device` BEFORE normalizing: `mean` and `std`
        # already live there, and mixing CPU and CUDA tensors in one
        # expression raises a RuntimeError.
        X = (X.to(device) - mean) / std
        y_hat = model(X).cpu().numpy()

        durations = durations.cpu().numpy()

        # `song_ids` has one entry per row of X (see inference_collate).
        for i, song_id in enumerate(song_ids):
            song_preds[song_id]["preds"].append(y_hat[i])
            song_preds[song_id]["durs"].append(durations[i])

smoothed_song_level_results = {}
unsmoothed_song_level_results = {}

for meta_data, data in song_preds.items():
    preds = np.vstack(data["preds"])
    # NOTE(review): assumes `kalman.smooth` returns a pair whose first item
    # has one smoothed row per input row - confirm against the object stored
    # in emotion_ssm.pkl.
    smoothed, _ = kalman.smooth(preds)

    durs = np.array(data["durs"])

    # Song-level value = duration-weighted average over the song's beats.
    unsmoothed_song_pred = np.average(preds, axis=0, weights=durs)
    smoothed_song_pred = np.average(smoothed, axis=0, weights=durs)

    # Output column 0 is reported as valence, column 1 as arousal.
    smoothed_song_level_results[meta_data] = {
        "valence": float(smoothed_song_pred[0]),
        "arousal": float(smoothed_song_pred[1])
    }

    unsmoothed_song_level_results[meta_data] = {
        "valence": float(unsmoothed_song_pred[0]),
        "arousal": float(unsmoothed_song_pred[1])
    }



./Metadata/1/174702.txt
./Metadata/1/464894.txt
./Metadata/10/60765.txt
./Metadata/10/60766.txt
./Metadata/100/451390.txt
./Metadata/100/451394.txt
./Metadata/101/301656.txt
./Metadata/101/301658.txt
./Metadata/101/301664.txt
./Metadata/101/447209.txt
./Metadata/102/18240.txt
./Metadata/102/23802.txt
./Metadata/102/393732.txt
./Metadata/102/45550.txt
./Metadata/103/443056.txt
./Metadata/103/443058.txt
./Metadata/104/14977.txt
./Metadata/104/14978.txt
./Metadata/105/337137.txt
./Metadata/105/337139.txt
./Metadata/106/257334.txt
./Metadata/106/474292.txt
./Metadata/106/66311.txt
./Metadata/107/313.txt
./Metadata/107/314.txt
./Metadata/107/436546.txt
./Metadata/108/145985.txt
./Metadata/108/388577.txt
./Metadata/109/282571.txt
./Metadata/109/58941.txt
./Metadata/109/58942.txt
./Metadata/11/13134.txt
./Metadata/11/368163.txt
./Metadata/110/157446.txt
./Metadata/110/468031.txt
./Metadata/110/61955.txt
./Metadata/111/21330.txt
./Metadata/111/303341.txt
./Metadata/112/366248.txt
./Metadata/11

In [45]:
# Group the smoothed song-level results by cover set.
# Keys of smoothed_song_level_results are metadata paths shaped like
# ".../<cover_id>/<song_id>.txt".
smoothed_covers_dict = defaultdict(list)

for key in smoothed_song_level_results:
    # Take the parent folder and the file stem via os.path instead of
    # searching for the first "/" and ".": those break as soon as the key has
    # another path level or an earlier dot (e.g. a leading "./").
    key_slice = os.path.basename(os.path.dirname(key))
    song_slice = os.path.splitext(os.path.basename(key))[0]

    # defaultdict(list) creates the list on first access, so no
    # membership check is needed before appending.
    smoothed_covers_dict[key_slice].append({song_slice: smoothed_song_level_results[key]})

In [46]:
# Group the unsmoothed song-level results by cover set.
# Keys of unsmoothed_song_level_results are metadata paths shaped like
# ".../<cover_id>/<song_id>.txt".
unsmoothed_covers_dict = defaultdict(list)

for key in unsmoothed_song_level_results:
    # Take the parent folder and the file stem via os.path instead of
    # searching for the first "/" and ".": those break as soon as the key has
    # another path level or an earlier dot (e.g. a leading "./").
    key_slice = os.path.basename(os.path.dirname(key))
    song_slice = os.path.splitext(os.path.basename(key))[0]

    # defaultdict(list) creates the list on first access, so no
    # membership check is needed before appending.
    unsmoothed_covers_dict[key_slice].append({song_slice: unsmoothed_song_level_results[key]})

In [47]:
# Spot check: show the unsmoothed, then the smoothed, results of cover set
# "234". The membership test avoids creating an empty entry in the
# defaultdict when the cover set is missing.
for covers in (unsmoothed_covers_dict, smoothed_covers_dict):
    if "234" in covers:
        print(covers["234"])

[{'336806': {'valence': 0.1661280393600464, 'arousal': 0.39293724298477173}}, {'336807': {'valence': 0.16462668776512146, 'arousal': 0.3290117681026459}}]
[{'336806': {'valence': 0.16589160926382465, 'arousal': 0.39166576759965144}}, {'336807': {'valence': 0.16447033351879517, 'arousal': 0.32773354300812196}}]
