In [1]:
import torch, librosa
from muq import MuQ

device = 'cuda:2'
wav, sr = librosa.load("/data2/zhouyz/rec/MSD/MSD_2/save/AA/TRAAJEN144D1BB07C0.mp3", sr = 24000)
wavs = torch.tensor(wav).unsqueeze(0).to(device) 

# This will automatically fetch the checkpoint from huggingface
muq = MuQ.from_pretrained("OpenMuQ/MuQ-large-msd-iter")
muq = muq.to(device).eval()

with torch.no_grad():
    output = muq(wavs, output_hidden_states=True)

print('Total number of layers: ', len(output.hidden_states))
print('Feature shape: ', output.hidden_states[0].shape) # L: 13

  from .autonotebook import tqdm as notebook_tqdm
  WeightNorm.apply(module, name, dim)


Total number of layers:  13
Feature shape:  torch.Size([1, 3993, 1024])


In [2]:
import numpy as np
all_features = []
for l in range(1, len(output.hidden_states)):
    feature_i = output.hidden_states[0].detach().cpu().numpy().squeeze().mean(axis=0)
    all_features.append(feature_i)

all_features = np.array(all_features)
print('All features shape: ', all_features.shape) # (13, 1024)

All features shape:  (12, 1024)


In [7]:
a = None 
a is None

True

In [5]:
import numpy as np
# 定义切片参数
slice_duration_sec = 30
slice_samples = slice_duration_sec * sr

# 存储每个切片的特征
# layer_features_list = [[] for _ in range(12)]
layer_chunk_means_list = [[] for _ in range(12)]

# 对音频进行切片并处理
for i in range(0, len(wav), slice_samples):
    chunk = wav[i:i + slice_samples]
    
    # 如果最后一个切片太短，可以跳过或填充，这里我们直接使用
    if len(chunk) < 1 * sr: # 跳过短于1秒的片段
        continue

    wavs_chunk = torch.tensor(chunk).unsqueeze(0).to(device)
    with torch.no_grad():
        output = muq(wavs_chunk, output_hidden_states=True)
    
    # 提取并保存每个层的隐藏状态
    for l_idx in range(1, len(output.hidden_states)):
        # feature_l_chunk shape: (1, T_chunk, 1024)
        feature_l_chunk = output.hidden_states[l_idx].detach().cpu()
        # mean_feature_l_chunk shape: (1, 1, 1024)
        mean_feature_l_chunk = feature_l_chunk.mean(dim=1, keepdim=True)
        layer_chunk_means_list[l_idx-1].append(mean_feature_l_chunk)
        
all_features = []
for l_idx in range(len(layer_chunk_means_list)):
    # layer_chunk_means_list[l_idx] 是一个列表，包含多个 (1, 1, 1024) 的张量
    # 沿时间维度(dim=1)拼接，得到 (1, N_chunks, 1024) 的张量
    concatenated_features_l = torch.cat(layer_chunk_means_list[l_idx], dim=1)
    all_features.append(concatenated_features_l)

# 将所有层的特征堆叠起来，形成 (12, N_chunks, 1024) 的张量
final_features = torch.cat(all_features, dim=0).numpy()

# (12, N_chunks, 1024)
# print('Final features shape: ', final_features.shape)

all_features = np.array(all_features).squeeze(1)

In [6]:
all_features.shape

(12, 6, 1024)

### 提取特征之后打包

In [10]:
import pickle
import os
import numpy as np

folder_path = "/data2/zhouyz/rec/MSD/muq-alllayers-mean-m4afil/"
data_dict = {}
files = os.listdir(folder_path)
for file in files:
    if file.endswith('.npy'):
        file_path = os.path.join(folder_path, file)
        file_prefix = file.split('.')[0]
        data = np.load(file_path)
        data_dict[file_prefix] = data

In [14]:
data_dict

{'TRFFFMN144D1C707E2': array([[ 0.35309762,  0.8321519 , -0.32937583, ..., -0.5965435 ,
         -2.1171436 , -1.2733753 ],
        [ 1.0116181 ,  0.6193582 ,  0.30295214, ..., -1.1658264 ,
         -2.8138828 , -0.5444085 ],
        [ 0.70444393,  0.01311614,  0.37231183, ..., -0.62723255,
         -1.2533677 , -0.19998284],
        ...,
        [ 1.0535073 , -0.00669639,  0.2890007 , ..., -0.735588  ,
         -0.14805481, -0.1313042 ],
        [ 0.53300834, -0.08967499, -0.05280351, ..., -0.79564136,
         -0.19787066, -0.33490992],
        [-0.09246495, -0.11420657,  0.0768149 , ..., -0.08545823,
         -0.07636913,  0.05735036]], dtype=float32),
 'TRFLDCY144D1DAAED1': array([[ 3.8823149e-01,  6.7558728e-02, -2.4637751e-01, ...,
         -1.4275805e+00,  2.4812395e-02, -7.5608844e-01],
        [ 5.5321729e-01,  1.6889389e-01,  4.4103211e-01, ...,
         -6.2607676e-01, -1.0993235e+00,  4.8033986e-02],
        [ 4.1547072e-01, -1.8973979e-01,  4.4631159e-01, ...,
          1.

In [13]:
with open('/user/zhouyz/rec/recbole_v2/dataset/m4a-fil/muq-alllayers-mean-m4afil.pkl', 'wb') as f:
    pickle.dump(data_dict, f)

In [15]:
# map
with open('/user/zhouyz/rec/recbole_v2/dataset/m4a-fil/id2spot.pkl', 'rb') as f:
    id2spot = pickle.load(f)
id2spot

{10: '0sObxZUKqdPfgjG9ahSlP5',
 74: '27fZwisXkxSJQ9ZmsfCLWq',
 87: '2XJNDakEmbxWqWsv7qAhu1',
 92: '53ErZun9BFfUwC6UNKdxiE',
 102: '6oEnxDyOjfChPHtIqcVVUA',
 104: '1Yuh9jaQMP00GCI6Moj33T',
 138: '6dC7G1YUnQz5hAYzZw7akw',
 156: '6OzNUyileilgx2xvUs0d4H',
 192: '7kCdrLnIs38ldALtOZoqzc',
 286: '34tr7aHsxhZ9RVTxQEAyYb',
 304: '2LJ30RzKcTBdjJLheeBJyA',
 336: '1BxZLTTQRbpaOZjtAOsrWN',
 362: '2t31y6OVkOkd860UWyegxA',
 372: '2z98iXXCkumhZomXtWxD8G',
 373: '36FfPdAPO1izvIEZ8OEZYz',
 395: '5clhOLEdyvjIhksr7b9p4A',
 450: '6DPCSLA7hJlae1rsravTuY',
 463: '5nIVgOc9WZ3IOsZlxBC256',
 468: '3Yus7xIaoc40QR30Dmw0HK',
 488: '3oIQPJNQB9fuaB8uB3omE9',
 580: '20UXPbBPmW8wiEpsZEb5Vb',
 626: '3ls7yzKBW98swfxxcqXtmZ',
 677: '7GMuol3AtrV2M4ugE8kYYy',
 700: '3EQPwPh40gSn7eqN5d6kta',
 707: '4JIfgJkxaWu9I2gdNR3wi5',
 708: '07WaEpXvWVAxnYLUPx5Bpz',
 713: '1ZTJaAKefjClF5AT19ddV9',
 740: '7A7SyddvytQczZca3rMQ5d',
 745: '1uhtcTuQ51EvD6zKBSCoZD',
 776: '6wL9MnOgXNpKWz8LLKkgOo',
 779: '7MPhjJG61hTJ3ajELZNRFe',
 856: '5vbPx