# Replicating Audio Texture Synthesis (by Dmitry Ulyanov)
- Original source: [Ulyanov's GitHub](https://github.com/DmitryUlyanov/neural-style-audio-torch)
- Article: [Ulyanov's blog](https://dmitryulyanov.github.io/audio-texture-synthesis-and-style-transfer/)

In [1]:
import importlib
from pathlib import Path

import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

import utils

In [2]:
# Root of the FMA audio dataset. The FMA_DATASET environment variable overrides the
# author's local default so the notebook can run on other machines.
import os

DATASET = Path(os.environ.get("FMA_DATASET", "/home/slas3r/courses/dl/project/music_st/fma"))

In [None]:
# Test clip used by the exploratory STFT / mel cells.
tfile = DATASET.joinpath("000", "000002.mp3")

In [None]:
# Decode the test clip; `wave` holds the samples and `sr` the sample rate librosa.load returned.
wave, sr = librosa.load(tfile)

In [None]:
# Waveform of the test clip. Passing sr (as the specshow cells do) keeps the time axis
# correct even if the clip is loaded at a non-default sample rate.
fig = plt.figure(figsize=(10, 5))
librosa.display.waveshow(wave, sr=sr)
plt.show()

### STFT

In [None]:
# Forward transform via the project helper; returns a (spectrum, phase) pair.
spectrum, phase = utils.stft_wave_to_spectrum(wave)

In [None]:
# Inspect the spectrum's dimensions.
np.shape(spectrum)

In [None]:
# Peak-normalised dB view of the STFT spectrum.
# NOTE(review): power_to_db assumes `spectrum` holds power values — confirm what
# utils.stft_wave_to_spectrum actually returns (magnitude vs. power).
spectrum_db = librosa.power_to_db(spectrum, ref=np.max)
title_font, label_font = dict(size=18), dict(size=15)

librosa.display.specshow(spectrum_db, sr=sr, x_axis="time", cmap="magma")
plt.colorbar(label="dB")
plt.title('STFT-Spectrogram', fontdict=title_font)
plt.xlabel('Time', fontdict=label_font)
plt.ylabel('Frequency', fontdict=label_font)
plt.show()

In [None]:
# Invert the spectrum back to audio using only `spectrum` (the `phase` from the forward pass is not passed in).
reconstructed_wave = utils.stft_spectrum_to_wave(spectrum)

In [None]:
# Compare reconstructed vs. original length (the inversion may not return the same sample count).
tuple(map(len, (reconstructed_wave, wave)))

In [None]:
# Raw sample values of the STFT round-trip output, plotted against sample index.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(reconstructed_wave)
plt.show()

# Mel Spectrogram

In [None]:
# Re-import the project helpers so edits to utils.py are picked up without restarting the kernel.
importlib.reload(utils)

In [None]:
# Mel spectrogram of the test clip, converted to dB relative to its peak.
mel_specgram = utils.mel_wave_to_spectrum(wave, sr)
power_to_db = librosa.power_to_db(mel_specgram, ref=np.max)

title_font, label_font = dict(size=18), dict(size=15)
plt.figure(figsize=(8, 7))
librosa.display.specshow(power_to_db, sr=sr, x_axis="time", y_axis="mel", cmap="magma")
plt.colorbar(label="dB")
plt.title('Mel-Spectrogram (dB)', fontdict=title_font)
plt.xlabel('Time', fontdict=label_font)
plt.ylabel('Frequency', fontdict=label_font)
plt.show()

In [None]:
# Invert the mel spectrogram back to audio via the project helper (needs the sample rate).
reconstructed_wave_mel = utils.mel_spectrum_to_wave(mel_specgram, sr)

In [None]:
# Waveform of the mel round-trip output. This must plot reconstructed_wave_mel, not the
# original `wave`, or the cell shows nothing about the reconstruction.
fig = plt.figure(figsize=(10, 5))
librosa.display.waveshow(reconstructed_wave_mel, sr=sr)
plt.show()

## Training

In [3]:
import model as model_m
import torch
from torch.autograd import Variable

import time

In [4]:
# Train on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [5]:
# Track that supplies the content, and track that supplies the style (texture).
content_file = DATASET.joinpath("056", "056030.mp3")
style_file = DATASET.joinpath("045", "045101.mp3")

In [6]:
# Decode both tracks; librosa.load is called with its default sample-rate handling for each.
(content_wave, content_sr), (style_wave, style_sr) = [
    librosa.load(track) for track in (content_file, style_file)
]

In [None]:
# Spectral front-end tag; it also appears in the output file name.
method: str = "stft"

In [None]:
# (spectrum, phase) for both tracks; the *_phase outputs are not used later in this notebook.
(content_mat, content_phase), (style_mat, style_phase) = map(
    utils.stft_wave_to_spectrum, (content_wave, style_wave)
)

In [None]:
# Add two leading singleton axes (batch, channel) and move the tensors to the training device.
content_mat, style_mat = (
    torch.from_numpy(mat)[None, None, :, :].to(device) for mat in (content_mat, style_mat)
)

In [None]:
# Sanity check: the content tensor should now live on `device`.
content_mat.device

In [None]:
# Feature extractor from the project's model module (presumably randomly initialised, per its name — confirm in model.py).
model = model_m.RandomCNN()

In [None]:
# Inference mode: the network is only used as a fixed feature extractor here.
model.eval()

In [None]:
# Move the network's parameters/buffers onto the training device.
model.to(device=device)

In [None]:
# Float32 inputs for the network (the tensors are already on `device`).
# torch.autograd.Variable is a deprecated no-op wrapper; plain tensors already have
# requires_grad=False, so the wrapper is dropped.
a_C_var = content_mat.float().to(device)
a_S_var = style_mat.float().to(device)

In [None]:
# Target feature maps for the content and style spectrograms. They are constants during
# the optimisation, so compute them without an autograd graph: this saves memory and
# prevents loss.backward() from reaching back through the (already freed) target graph.
with torch.no_grad():
    a_C = model(a_C_var)
    a_S = model(a_S_var)

In [None]:
# Optimisation settings for the STFT run. content_weight stays a float (1e2), which is
# how it is formatted in the output file name.
lr, n_epochs = 0.002, 20_000
content_weight, style_weight = 1e2, 10

In [None]:
# The tensor being optimised: small Gaussian noise shaped like the content spectrogram.
# Sampled on the CPU (as before, so seeded runs start from the same point), then moved to
# `device` and marked as requiring grad. Replaces the deprecated Variable wrapper.
a_G_var = (torch.randn(content_mat.shape) * 1e-3).to(device).requires_grad_()

In [None]:
# Only the generated input is optimised; the network weights are not in the parameter list.
# Pass the learning rate from the hyper-parameter cell: without it Adam silently uses its
# default of 1e-3 and `lr` has no effect.
optimizer = torch.optim.Adam([a_G_var], lr=lr)

In [None]:
# Per-iteration total loss history, appended to by the training loop.
# (The unused `c_loss` accumulator was dropped.)
losses = []

In [None]:
# Gradient descent on the generated input a_G_var (the only tensor given to the optimiser).
for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad()
    a_G = model(a_G_var)

    # Weighted content and style terms; their sum is what gets minimised.
    content_loss = content_weight * utils.compute_content_loss(a_C, a_G)
    style_loss = style_weight * utils.compute_layer_style_loss(a_S, a_G)
    total_loss = content_loss + style_loss

    total_loss.backward()
    optimizer.step()
    losses.append(total_loss.item())

    # Progress report every 1000 iterations.
    if not epoch % 1000:
        print(f"[Epoch {epoch}/{n_epochs}] content_loss: {content_loss:.4f} style_loss: {style_loss:.4f}")

In [None]:
# Detach the optimised tensor, bring it to host memory as NumPy, drop size-1 axes (the
# batch/channel axes added before optimisation), and invert it back to audio.
gen_spec = a_G_var.detach().cpu().numpy().squeeze()
gen_audio = utils.stft_spectrum_to_wave(gen_spec)

In [None]:
# Save the generated audio. The name encodes method, loss weights, content/style track ids
# and a timestamp. The timestamp is built outside the f-string: reusing double quotes inside
# a double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
# content_sr is used because `sr` is only set by the exploratory cells at the top.
timestamp = time.strftime('%m_%d_%H_%M_%S')
out_name = (
    f"./gen_c_{method}_{content_weight}_{style_weight}"
    f"_{content_file.stem}_s{style_file.stem}_{timestamp}.wav"
)
utils.writefile(out_name, gen_audio, content_sr)

In [None]:
# Generated STFT spectrogram in dB, relative to its peak. Uses content_sr (the rate of the
# training audio) instead of `sr`, which only the exploratory cells define.
librosa.display.specshow(librosa.power_to_db(gen_spec, ref=np.max), sr=content_sr, x_axis="time", cmap="magma")
plt.colorbar(label="dB")
plt.title('STFT-Spectrogram', fontdict=dict(size=18))
plt.xlabel('Time', fontdict=dict(size=15))
plt.ylabel('Frequency', fontdict=dict(size=15))
plt.show()

In [None]:
# Total-loss curve of the STFT run, labelled like the other plots in this notebook.
plt.plot(losses)
plt.title('Training loss', fontdict=dict(size=18))
plt.xlabel('Iteration', fontdict=dict(size=15))
plt.ylabel('Loss', fontdict=dict(size=15))
plt.show()

### MEL

In [None]:
# Spectral front-end tag for the second run; it also appears in the output file name.
method: str = "mel"

In [None]:
# Mel spectrograms of both tracks (each at its own sample rate), then add two leading
# singleton axes (batch, channel) and move them to the training device.
mels = [
    utils.mel_wave_to_spectrum(signal, rate)
    for signal, rate in ((content_wave, content_sr), (style_wave, style_sr))
]
content_mat, style_mat = [torch.from_numpy(mel)[None, None, :, :].to(device) for mel in mels]

In [None]:
# A fresh network instance for the mel run, put in inference mode and moved to the device.
model = model_m.RandomCNN()
model.eval()
model.to(device=device)

In [None]:
# Float32 network inputs (tensors are already on `device`); the deprecated no-op Variable
# wrapper is dropped, since plain tensors already have requires_grad=False.
a_C_var = content_mat.float().to(device)
a_S_var = style_mat.float().to(device)

# Target feature maps are constants during optimisation, so compute them without an
# autograd graph. This saves memory and keeps loss.backward() from reaching the target graph.
with torch.no_grad():
    a_C = model(a_C_var)
    a_S = model(a_S_var)

In [None]:
# Optimisation settings for the mel run. content_weight stays a float (1e2), matching how it
# is formatted in the output file name.
# NOTE(review): n_epochs is 10x the STFT run's 20_000 — confirm this is intentional.
lr, n_epochs = 0.002, 200_000
content_weight, style_weight = 1e2, 10

In [None]:
# The tensor being optimised: small Gaussian noise shaped like the content mel spectrogram.
# Sampled on the CPU (as before, so seeded runs start identically), then moved to `device`
# and marked as requiring grad. Replaces the deprecated Variable wrapper.
a_G_var = (torch.randn(content_mat.shape) * 1e-3).to(device).requires_grad_()

In [None]:
# Only the generated input is optimised. Pass the learning rate from the hyper-parameter
# cell: without it Adam silently uses its default of 1e-3 and `lr` has no effect.
optimizer = torch.optim.Adam([a_G_var], lr=lr)

In [None]:
# Per-iteration total loss history for the mel run, appended to by the training loop.
# (The unused `c_loss` accumulator was dropped.)
losses = []

In [None]:
# Gradient descent on the generated mel input a_G_var (the only tensor given to the optimiser).
for epoch in range(1, n_epochs + 1):
    optimizer.zero_grad()
    a_G = model(a_G_var)

    # Weighted content and style terms; their sum is what gets minimised.
    content_loss = content_weight * utils.compute_content_loss(a_C, a_G)
    style_loss = style_weight * utils.compute_layer_style_loss(a_S, a_G)
    total_loss = content_loss + style_loss

    total_loss.backward()
    optimizer.step()
    losses.append(total_loss.item())

    # Progress report every 1000 iterations.
    if not epoch % 1000:
        print(f"[Epoch {epoch}/{n_epochs}] content_loss: {content_loss:.4f} style_loss: {style_loss:.4f}")

In [None]:
# Detach the optimised tensor, bring it to host memory as NumPy, drop size-1 axes (the
# batch/channel axes added before optimisation), and invert the mel spectrogram to audio.
gen_spec = a_G_var.detach().cpu().numpy().squeeze()
gen_audio = utils.mel_spectrum_to_wave(gen_spec, content_sr)

In [None]:
# Save the generated audio. The name encodes method, loss weights, content/style track ids
# and a timestamp. The timestamp is built outside the f-string: reusing double quotes inside
# a double-quoted f-string is a SyntaxError before Python 3.12 (PEP 701).
# content_sr is used because `sr` is only set by the exploratory cells at the top.
timestamp = time.strftime('%m_%d_%H_%M_%S')
out_name = (
    f"./gen_c_{method}_{content_weight}_{style_weight}"
    f"_{content_file.stem}_s{style_file.stem}_{timestamp}.wav"
)
utils.writefile(out_name, gen_audio, content_sr)

In [None]:
# Generated mel spectrogram in dB, relative to its peak, on a mel frequency axis like the
# exploratory mel plot. content_sr is used because `sr` is only set by the exploratory cells.
librosa.display.specshow(librosa.power_to_db(gen_spec, ref=np.max), sr=content_sr, x_axis="time", y_axis="mel", cmap="magma")
plt.colorbar(label="dB")
plt.title('MEL-Spectrogram', fontdict=dict(size=18))
plt.xlabel('Time', fontdict=dict(size=15))
plt.ylabel('Frequency', fontdict=dict(size=15))
plt.show()

In [None]:
# Total-loss curve of the mel run, labelled like the other plots in this notebook.
plt.plot(losses)
plt.title('Training loss', fontdict=dict(size=18))
plt.xlabel('Iteration', fontdict=dict(size=15))
plt.ylabel('Loss', fontdict=dict(size=15))
plt.show()