In [None]:
import json
import os
from dataset import ScaperLoader
from networks import DeepAttractor
import torch
from utils import *
import matplotlib.pyplot as plt
from audio_embed import utilities
import librosa
import pprint
# Notebook presentation setup: apply audio_embed's plotting style, render
# matplotlib figures inline on a dark background, and create a pretty-printer
# (used below to dump the run's `params`).
utilities.apply_style()
%matplotlib inline
plt.style.use('dark_background')
pp = pprint.PrettyPrinter(indent=4)

run_directory = 'runs/lv_like_music_spherical_15//////'
#run_directory = 'gmm_km_runs/music_unfold_kmeans_0_it_no_unfold_20/'
#run_directory = 'runs/music_l1_psa_cce_2048_sigmoid_incoherent_attr_use_means_fixed_proj_300_norm_t_1_unfold_k_means_sigmoid_20//'
#run_directory = 'runs/urbansound_l1_psa_cce_2048_sigmoid_attr_close_norm_curr_vae_proj//'

checkpoint_dir = os.path.join(run_directory, 'checkpoints')
checkpoints = sorted(os.listdir(checkpoint_dir))
print (checkpoints)
if not checkpoints:
    raise FileNotFoundError('No checkpoints found in %s' % checkpoint_dir)
# NOTE(review): sorted() is lexicographic, so [-1] is the newest checkpoint
# only if the file names sort chronologically (e.g. zero-padded epochs).
saved_model_path = os.path.join(run_directory, 'checkpoints', checkpoints[-1])
print(saved_model_path)
with open(os.path.join(run_directory, 'params.json'), 'r') as f:
    params = json.load(f)

# Fill in defaults for hyper-parameters missing from params.json.
# A run whose directory name mentions 'thresh' is expected to store its own
# threshold; if it does not, the DeepAttractor(...) call below fails loudly.
if 'thresh' not in run_directory:
    params.setdefault('threshold', None)
# Default for embedding normalization is inferred from the run name.
params.setdefault('normalize_embeddings', 'norm' in run_directory)
params.setdefault('attractor_function_type', 'ae')
params.setdefault('embedding_activation', 'none')
# Translate the k-means-specific iteration count into the generic clustering
# parameters the model constructor expects.
if 'num_k_means_iterations' in params:
    params['num_clustering_iterations'] = params['num_k_means_iterations']
    params['clustering_type'] = 'kmeans'
params.setdefault('num_gaussians_per_source', 1)

test_dset = ScaperLoader(folder='/mm1/seetharaman/generated/music_44k/testing/',
                         length=params['initial_length'],
                         n_fft=params['n_fft'],
                         hop_length=params['hop_length'])
# Smoke-test that the dataset can produce an item before building the model.
test_dset[0]

# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DeepAttractor(input_size=int(params['n_fft']/2 + 1),
                       sample_rate=params['sample_rate'],
                       hidden_size=params['hidden_size'],
                       num_layers=params['num_layers'],
                       dropout=params['dropout'],
                       num_attractors=params['num_attractors'],
                       embedding_size=params['embedding_size'],
                       activation_type=params['activation_type'],
                       projection_size=params['projection_size'],
                       num_clustering_iterations=params['num_clustering_iterations'],
                       clustering_type=params['clustering_type'],
                       attractor_function_type=params['attractor_function_type'],
                       normalize_embeddings=params['normalize_embeddings'],
                       embedding_activation=params['embedding_activation'],
                       covariance_type=params['covariance_type'],
                       num_gaussians_per_source=params['num_gaussians_per_source'],
                       threshold=params['threshold']).to(device)


model.use_likelihoods = False
model.eval()
# The checkpoint is deserialized onto CPU storage; load_state_dict copies the
# values into the model's existing parameters (already on `device`).
model.load_state_dict(torch.load(saved_model_path, map_location=lambda storage, loc: storage))
show_model(model)

pp.pprint(params)

In [None]:
import librosa
from resampy import resample
from torch import nn
# Input track for the separation demo below. Only one assignment is live; the
# alternatives are kept as comments so switching tracks is a one-line change.
#audio_file = '/mm1/seetharaman/scaper_data/music_separation/validation/mixture/Music Delta - Reggae - mixture.wav'
#audio_file = '/mm1/seetharaman/scaper_data/music_separation/test/mixture/Al James - Schoolboy Facination - mixture.wav'
#audio_file = '../audio/Foo Fighters - Everlong-eBG7P-K-r1Y.opus'
#audio_file = '/mm1/seetharaman/scaper_data/music_separation/test/drums/Al James - Schoolboy Facination - drums.wav'
#audio_file = '/mm1/seetharaman/generated/urbansound/testing/000080.wav'
#audio_file = '/mm1/seetharaman/generated/urbansound/testing/000035.wav'
#audio_file = '/mm1/seetharaman/generated/urbansound/testing/000045.wav'
audio_file = '../audio/heartofgold.mp3'

def mask_mixture(source_mask, mix):
    """Apply a time-frequency mask to a mixture and resynthesize the result.

    The mixture is zero-padded by ``test_dset.n_fft // 2`` samples, transformed
    with the test dataset's STFT settings, multiplied by ``source_mask`` and
    inverted back to exactly ``len(mix)`` samples.

    Parameters
    ----------
    source_mask : scalar or array
        Mask multiplied into the complex STFT of the padded mixture; it must
        broadcast against librosa's (freq, frames) STFT layout. ``1`` leaves
        the mixture unchanged.
    mix : np.ndarray
        Mixture signal (presumably mono 1-D, as returned by ``librosa.load``).

    Returns
    -------
    source : np.ndarray
        Masked signal, ``len(mix)`` samples long.
    mix : np.ndarray
        The unmasked mixture after the same STFT/ISTFT round trip, so it is
        directly comparable with ``source``.
    """
    n = len(mix)
    # `size` is keyword-only in librosa >= 0.10; naming it also works on older releases.
    padded = librosa.util.fix_length(mix, size=n + test_dset.n_fft // 2)
    mix_stft = librosa.stft(padded, n_fft=test_dset.n_fft, hop_length=test_dset.hop_length)
    reconstructed = librosa.istft(mix_stft, hop_length=test_dset.hop_length, length=n)
    masked_mix = mix_stft * source_mask
    source = librosa.istft(masked_mix, hop_length=test_dset.hop_length, length=n)
    return source, reconstructed


# Load 30 s of the track (starting at 45 s) at the dataset's sample rate, then
# pass it through a unity-mask STFT/ISTFT round trip so it matches what the
# separated sources will be reconstructed against.
mix, sr = librosa.load(audio_file, sr=test_dset.sr, duration=30, offset=45)
mix = mask_mixture(1, mix)[1]

log_spec, stft = test_dset.transform(mix, test_dset.n_fft, test_dset.hop_length)

# Dataset-normalized log-spectrogram with a leading batch axis.
input_data = (
    torch.from_numpy(test_dset.whiten(log_spec))
    .unsqueeze(0)
    .requires_grad_()
    .to(device)
)
# One one-hot row per attractor, also batched.
one_hot = (
    torch.from_numpy(np.eye(params['num_attractors'], params['num_attractors']))
    .unsqueeze(0)
    .float()
    .requires_grad_()
    .to(device)
)

with torch.no_grad():
    masks, attractors, embedding = model(input_data, one_hots=one_hot)
print(masks.shape)

sources = masks.squeeze(0).cpu().data.numpy()
print('Mixture')
utilities.audio(mix, sr, ext='.wav')


def fix_length(x):
    """Zero-pad ``x`` at the end by ``test_dset.n_fft // 2`` samples."""
    # `size` is keyword-only in librosa >= 0.10; naming it also works on older releases.
    return librosa.util.fix_length(x, size=len(x) + test_dset.n_fft // 2)

def transform(x):
    """Magnitude STFT of ``x`` (padded via ``fix_length``) with the test-set FFT settings."""
    complex_spec = librosa.stft(
        fix_length(x),
        n_fft=test_dset.n_fft,
        hop_length=test_dset.hop_length,
    )
    return np.abs(complex_spec)

separations = []
acc = 0  # running sum of every separated source except index 0

# Residual: the part of the mixture not explained by the separated sources.
# Copy so the in-place `res -= isolated` below never writes into a `mix` array.
res = mix.copy()

for j in range(sources.shape[-1]):
    # Transpose to librosa's (freq, frames) STFT layout; presumably `sources`
    # is (frames, freq, n_sources) here.
    mask = sources[:, :, j].T
    print (mask.min(), mask.max())
    isolated, mix = mask_mixture(mask, mix)
    separations.append(isolated)
    plt.figure(figsize=(20, 4))
    plt.subplot(121)
    # 20*log10 converts magnitude to dB; the epsilon avoids log(0).
    plt.imshow(20 * np.log10(np.abs(librosa.stft(isolated)) + 1e-7), origin='lower', aspect='auto', cmap='magma')
    plt.subplot(122)
    plt.imshow(mask, origin='lower', aspect='auto', cmap='magma')
    plt.colorbar()
    plt.show()
    utilities.audio(isolated, sr, ext='.wav')
    res -= isolated

    if j != 0:
        acc += isolated
    else:
        # Mixture with the first source removed.
        utilities.audio(mix - isolated, sr, ext='.wav')
utilities.audio(res, sr, ext='.wav')

In [None]:
# Element-wise reciprocal of the second element of the model's `attractors` output.
# NOTE(review): presumably a variance/scale term — confirm what the model returns there.
1/attractors[1]

In [None]:
# Project the high-weight embedding points together with the attractors into
# 2-D and 3-D, and label each attractor point.
attractors_ = attractors[0].detach().squeeze(0).cpu().numpy()
one_hot = torch.from_numpy(np.eye(params['num_attractors'], params['num_attractors'])).unsqueeze(0).float().requires_grad_().to(device)
embedding = model.embedding.detach().squeeze(0).cpu().numpy()
weights = model.weights.detach().squeeze(0).cpu().numpy()
fig = plt.figure(figsize=(30, 10))
subplot = 121

for j, num_dim in enumerate([2, 3]):
    projection = '3d' if num_dim == 3 else None
    # Only embedding points whose weight exceeds this threshold are projected.
    threshold = .9 if num_dim == 3 else 0.1
    ax = fig.add_subplot(subplot + j, projection=projection)
    output_transform, ax = project_embeddings(np.vstack([embedding[weights>threshold], attractors_]),
                                              t=0.0,
                                              num_dimensions=num_dim,
                                              fig=fig, ax=ax, bins=None, gridsize=200)
    # Attractors were stacked last, so they are the last rows of the projection.
    attractor_points = output_transform[-attractors_.shape[0]:]
    labels = ['vocals', 'drums', 'bass', 'other'] #test_dset.classes
    # Any attractors beyond the named stems are labelled 'extra_<index>'.
    labels += ['extra_%d' % k for k in range(len(labels), len(attractor_points))]
    for i, x in enumerate(attractor_points):
        if num_dim == 3:
            ax.text(x[0], x[1], x[2], labels[i], size=20, zorder=1, color='w')
        else:
            # Annotate the axes returned by project_embeddings explicitly,
            # rather than whatever pyplot considers the current axes.
            ax.annotate(labels[i], xy=x, size=16)
plt.show()

In [None]:
# Shape of the network input tensor from the most recent forward-pass cell.
input_data.shape

In [None]:
# Attractors produced by the model from the one-hot matrix.
# NOTE(review): `generate_attractors` lives in networks.DeepAttractor; its exact semantics are not visible here.
model.generate_attractors(one_hot)

In [None]:
# Overall L2 (Frobenius) norm of the embedding. `embedding` is a torch tensor
# right after a forward pass but a NumPy array once the projection cell has
# reassigned it, and NumPy arrays have no `.norm`, so handle both.
embedding.norm(p=2) if torch.is_tensor(embedding) else np.linalg.norm(embedding)

In [None]:
# Show each embedding dimension as an image with projection_size rows.
# NOTE(review): assumes embedding rows are flattened frame-major with
# params['projection_size'] bins per frame — confirm against the model.
x = embedding.reshape(-1, params['projection_size'], embedding.shape[-1])
num_embedding_dims = x.shape[-1]
for dim in range(num_embedding_dims):
    plt.figure(figsize=(30, 4))
    dim_image = x[..., dim].T
    plt.imshow(dim_image, aspect='auto', origin='lower', cmap='seismic', vmin=-1.0, vmax=1.0)
    plt.colorbar()
    plt.show()

In [None]:
# Largest value in the model's `assignments` attribute (presumably set during the last forward pass).
model.assignments.max()

In [None]:
# Third element of whatever was last assigned to `attractors`.
attractors[2]

In [None]:
# Split the last axis of `attractors` into [embedding_size, remainder - 1, 1]
# chunks and show the shape of the final width-1 chunk.
# NOTE(review): needs `attractors` to be a tensor (e.g. after the
# `model.attractor_function(one_hot)` cell below); the forward-pass output is
# indexed as a sequence elsewhere and has no `.shape` if it is a tuple.
torch.split(attractors, [model.embedding_size, attractors.shape[-1] - model.embedding_size  - 1, 1], dim=-1)[2].shape

In [None]:
# Regenerate attractors directly from the one-hot matrix, then display
# (a) the squared values of channels embedding_size .. -2 and (b) a softmax of
# the last channel across dim -1.
# NOTE(review): this overwrites the model-output `attractors` used by other
# cells; assumes a (batch, attractors, features) layout — confirm.
attractors = model.attractor_function(one_hot)
attractors[:, : ,model.embedding_size:-1] **2, nn.functional.softmax(attractors[:, : , -1], dim=-1)

In [None]:
# Shape of the second element of `attractors`.
attractors[1].shape

In [None]:
# Shape of the diagonal matrix built from the singular values `s`.
# NOTE(review): `s` is first defined by the SVD cell further down, so this
# cell fails with NameError on a fresh kernel run top-to-bottom.
np.diag(s).shape

In [None]:
# Low-rank reconstruction u @ diag(s_truncated) @ v from the SVD of the Gram
# matrix (u, s, v come from the SVD cell further down — run that first).
# A truncated copy is used so `s` itself keeps the full spectrum and the cell
# can be re-run safely.
rank = 1
s_truncated = s.copy()
s_truncated[rank:] = 0
x = np.dot(u * s_truncated, v)
plt.figure(figsize=(10, 7))
plt.imshow(x, aspect='auto', cmap='seismic')
plt.colorbar()
plt.show()

In [None]:
# Gram matrix of the embedding dimensions (V^T V) as a heat map, followed by
# its singular-value spectrum; u, s, v remain available for reuse.
vtv = np.dot(embedding.T, embedding)
u, s, v = np.linalg.svd(vtv)

gram_fig, gram_ax = plt.subplots(figsize=(10, 7))
gram_img = gram_ax.imshow(vtv, aspect='auto', cmap='seismic')
gram_fig.colorbar(gram_img)
plt.show()

plt.plot(s)
plt.show()

In [None]:
# NOTE(review): this cell repeats the previous one (same Gram matrix, same SVD,
# same two plots); consider deleting it.
vtv = np.dot(embedding.T, embedding)
u, s, v = np.linalg.svd(vtv)

plt.figure(figsize=(10, 7))
plt.imshow(vtv, aspect='auto', cmap='seismic')
plt.colorbar()
plt.show()

plt.plot(s)
plt.show()

In [None]:
from loss import *

In [None]:
# Pairwise dot products between the attractor vectors.
attractor_gram = np.dot(attractors_, attractors_.T)
plt.imshow(attractor_gram)
plt.colorbar()
plt.show()

In [None]:
# Inspect the second element of `attractors`.
var = attractors[1]
var
# Alternative kept for reference: broadcast it to the embedding size.
#var.expand(-1, -1, model.embedding_size)

In [None]:
# Plot each attractor vector as a line, then all of them as one image.
plt.figure(figsize=(20, 5))
labels = ['vocals', 'drums', 'bass', 'other', 'attractor', 'embedding']
data = np.vstack([attractors_])
print (data.shape)
# Exactly one label per row: extra rows get 'extra_<index>', unused names are
# dropped, so the legend and y-tick labels always match the row count.
row_labels = [labels[k] if k < len(labels) else 'extra_%d' % k for k in range(data.shape[0])]
for label, attractor in zip(row_labels, data):
    plt.plot(attractor, label=label)
plt.legend()
plt.show()

plt.figure(figsize=(20, 4))
plt.imshow(data)
plt.yticks(range(data.shape[0]), row_labels)
plt.show()

In [None]:
# Third element of whatever was last assigned to `attractors`.
attractors[2]

In [None]:
import librosa
#mm1/seetharaman/scaper_data/music_separation/validation/mixture/Music Delta - Reggae - mixture.wav
#mix, sr = librosa.load('/mm1/seetharaman/scaper_data/music_separation/test/mixture/Al James - Schoolboy Facination - mixture.wav', duration=30, offset=30,  sr=16000)
mix, sr = librosa.load('../audio/Hotel California Solo - The Eagles - Acoustic Guitar Cover-r3ebOxltJ1w.opus', duration=10, offset=10,  sr=44100)
mix2, sr = librosa.load('../audio/He\'s Back! Snare Solo - Flamnambulous--U.S. Army All-American Marching Band Audition-1MF8I-XgBq0.m4a', duration=10, offset=14,  sr=44100)
mix3, sr = librosa.load('../audio/Bass Drum solo Battle-55ISUhExonc.m4a', duration=5, offset=14,  sr=44100)

# Clip used for this experiment (swap in `mix` or `mix3` to try the others).
# NOTE(review): audio is loaded at a fixed 44100 Hz rather than test_dset.sr.
mix = mix2
utilities.audio(mix, sr, ext='.wav')

log_spec, stft = test_dset.transform(mix, test_dset.n_fft, test_dset.hop_length)
print(log_spec.shape)

# Standardize to zero mean / unit variance as a new array, leaving `log_spec`
# untouched (aliasing it and normalizing in place rewrote `log_spec` too).
input_data = (log_spec - log_spec.mean()) / (log_spec.std() + 1e-7)

plt.figure(figsize=(20, 4))
plt.imshow(input_data.T, origin='lower', aspect='auto')
plt.show()

input_data = torch.from_numpy(input_data).unsqueeze(0).requires_grad_().to(device)
# NOTE(review): unused by the forward call below (one_hots=None); hard-coded to 4 attractors.
one_hot = torch.from_numpy(np.eye(4, 4)).unsqueeze(0).float().requires_grad_().to(device)
# These attribute changes persist on `model` and affect every later forward
# pass, including re-runs of earlier cells.
model.clusterer.n_clusters = 1
model.num_attractors = 1
model.clusterer.n_iterations = 10
sources, attractors, embedding = model(input_data, one_hots=None) #[0]

# Mean of model.embedding over dim 1 with a leading axis added — presumably the
# clip's average embedding, reused as a query attractor in the next step.
#attractors_ = nn.functional.normalize(attractors[:, 1, :] + attractors[:, 2, :], dim=-1, p=2).unsqueeze(0) #attractors[:, :, :].sum(keepdim=True, dim=1)
attractors_ = model.embedding.mean(dim=1, keepdim=True)[0].unsqueeze(0)
print (attractors_.shape)


audio_file = '/mm1/seetharaman/scaper_data/music_separation/test/drums/Al James - Schoolboy Facination - drums.wav'
#audio_file = '../audio/heartofgold.wav'
mix, sr = librosa.load(audio_file, duration=30, offset=45,  sr=44100)

log_spec, stft = test_dset.transform(mix, test_dset.n_fft, test_dset.hop_length)

# Standardize to zero mean / unit variance as a new array, leaving `log_spec`
# untouched (aliasing it and normalizing in place rewrote `log_spec` too).
input_data = (log_spec - log_spec.mean()) / (log_spec.std() + 1e-7)

input_data = torch.from_numpy(input_data).unsqueeze(0).requires_grad_().to(device)
# NOTE(review): hard-coded to 4 attractors rather than params['num_attractors'].
one_hot = torch.from_numpy(np.eye(4, 4)).unsqueeze(0).float().requires_grad_().to(device)

# Persistent model changes: no clustering iterations, likelihood-based output.
model.clusterer.n_iterations = 0
model.use_likelihoods = True
sources, _, embedding = model(input_data, one_hot) #[0]
num_batch, sequence_length, num_frequencies, embedding_size = sources.size()
# Re-project this clip's embeddings onto the attractors from the previous cell.
# NOTE(review): this rebinds `attractors` to a 3-tuple, so later cells that
# use `attractors.shape` no longer work.
attractors = (attractors[0], attractors[1], attractors[2])
sources, _, _ = model.project_embedding_onto_attractors(model.embedding, attractors, model.weights)
sources = model.invert_projection(sources)
sources = sources.clamp(0.0, 1.0)
sources = sources.view(num_batch, sequence_length, num_frequencies, -1)
sources = sources.squeeze(0).cpu().data.numpy()
print('Mixture')

utilities.audio(mix, sr, ext='.wav')
def mask_mixture(source, mix):
    """Apply the time-frequency mask ``source`` to ``mix`` and resynthesize it.

    NOTE: this redefinition shadows the two-output ``mask_mixture`` defined
    earlier in the notebook — it returns only the masked signal.

    Parameters
    ----------
    source : array
        Mask multiplied into the complex STFT of the padded mixture; it must
        broadcast against librosa's (freq, frames) STFT layout.
    mix : np.ndarray
        Mixture signal (presumably mono 1-D, as returned by ``librosa.load``).

    Returns
    -------
    np.ndarray
        Masked signal, ``len(mix)`` samples long.
    """
    n = len(mix)
    # `size` is keyword-only in librosa >= 0.10; naming it also works on older releases.
    padded = librosa.util.fix_length(mix, size=n + test_dset.n_fft // 2)
    mix_stft = librosa.stft(padded, n_fft=test_dset.n_fft, hop_length=test_dset.hop_length)
    # Use the mask that was passed in, not a notebook-global `mask`.
    masked_mix = mix_stft * source
    return librosa.istft(masked_mix, hop_length=test_dset.hop_length, length=n)

# Mask the mixture with each estimated source, plot it, and play the source
# and the remainder of the mixture.
for j in range(sources.shape[-1]):
    # Transpose to librosa's (freq, frames) STFT layout; presumably `sources`
    # is (frames, freq, n_sources) here.
    mask = sources[:, :, j].T
    isolated = mask_mixture(mask, mix)
    print(mask.min(), mask.max())
    plt.figure(figsize=(20, 4))
    plt.subplot(121)
    # 20*log10 converts magnitude to dB; the epsilon avoids log(0).
    plt.imshow(20 * np.log10(np.abs(librosa.stft(isolated)) + 1e-7), origin='lower', aspect='auto')
    plt.subplot(122)
    plt.imshow(mask, origin='lower', aspect='auto', vmin=0.0, vmax=1.0)
    plt.show()
    utilities.audio(isolated, sr, ext='.mp3')
    utilities.audio(mix - isolated, sr, ext='.mp3')

In [None]:
# Current value of `attractors` (after the drums-clip cell this is the 3-tuple built there).
attractors

In [None]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array, detaching and moving torch tensors to CPU first."""
    if torch.is_tensor(x):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# For each attractor, show the embedding dimension it weights most strongly as
# an image. The loop is bounded by the rows of `attractors_` (the array that is
# actually indexed); `attractors` may be a tuple at this point, with no `.shape`.
embedding_np = _to_numpy(embedding)
attractors_np = _to_numpy(attractors_)
top_dims = np.argmax(attractors_np, axis=-1)
for i in range(attractors_np.shape[0]):
    plt.figure(figsize=(20, 4))
    # NOTE(review): assumes embedding rows are flattened frame-major with
    # params['projection_size'] bins per frame — confirm against the model.
    plt.imshow(embedding_np[:, top_dims[i]].reshape(-1, int(params['projection_size'])).T, aspect='auto', origin='lower')
    plt.show()