In [1]:
import torch 
import torch.nn as nn
import numpy as np
import librosa
from torchsummary import summary
from torch.utils.data import DataLoader, Dataset, TensorDataset
from scipy.io import wavfile
import wandb

In [14]:

in_path = "D:\\Documents\\CMU_SUBJECTS\\BlackBoxAudioFx\\PedalNetRT\\data\\ts9_test1_in_FP32.wav"
out_path = "D:\\Documents\\CMU_SUBJECTS\\BlackBoxAudioFx\\PedalNetRT\\data\\ts9_test1_out_FP32.wav"  


# Audio pre-processing and book keeping 

in_rate, in_data = wavfile.read(in_path)
out_rate, out_data = wavfile.read(out_path)
assert in_rate == out_rate, "in_file and out_file must have same sample rate"

# Load audio files  
def normalize(data):
    data_max = max(data)
    data_min = min(data)
    data_norm = max(data_max,abs(data_min))
    return data / data_norm

# Trim the length of audio to equal the smaller wav file
if len(in_data) > len(out_data):
    print("Trimming input audio to match output audio")
in_data = in_data[0:len(out_data)]
if len(out_data) > len(in_data): 
    print("Trimming output audio to match input audio")
out_data = out_data[0:len(in_data)]

# If stereo data, use channel 0
if len(in_data.shape) > 1:
    print("[WARNING] Stereo data detected for in_data, only using first channel (left channel)")
    in_data = in_data[:,0]
if len(out_data.shape) > 1:
    print("[WARNING] Stereo data detected for out_data, only using first channel (left channel)")
    out_data = out_data[:,0]

#normalize data
if normalize == True:
    in_data = normalize(in_data)
    out_data = normalize(out_data)

# Convert PCM16 to FP32
if in_data.dtype == "int16":
    in_data = in_data/32767
    print("In data converted from PCM16 to FP32")
if out_data.dtype == "int16":
    out_data = out_data/32767
    print("Out data converted from PCM16 to FP32")

sample_time = 100e-3
sample_size = int(in_rate * sample_time)
length = len(in_data) - len(in_data) % sample_size

x = in_data[:length].reshape((-1, 1, sample_size)).astype(np.float32)
y = out_data[:length].reshape((-1, 1, sample_size)).astype(np.float32)

split = lambda d: np.split(d, [int(len(d) * 0.6), int(len(d) * 0.8)])

d = {}
d["x_train"], d["x_valid"], d["x_test"] = split(x)
d["y_train"], d["y_valid"], d["y_test"] = split(y)
d["mean"], d["std"] = d["x_train"].mean(), d["x_train"].std()
for key in "x_train", "x_valid", "x_test":
    d[key] = (d[key] - d["mean"]) / d["std"] 

x_test, y_test = d["x_test"], d["y_test"]
valid_data = TensorDataset(torch.from_numpy(d["x_valid"]), torch.from_numpy(d["y_valid"]))
test_data  = TensorDataset(torch.from_numpy(d["x_test"]), torch.from_numpy(d["y_test"]))
train_data = TensorDataset(torch.from_numpy(d["x_train"]), torch.from_numpy(d["y_train"]))


num_workers, batch_size = 4, 64
train_loader = DataLoader(train_data, batch_size = 64, num_workers= num_workers, shuffle= True)
val_loader = DataLoader(valid_data, batch_size = 64, num_workers= num_workers)
test_loader = DataLoader(test_data, batch_size = 64, num_workers= num_workers)



In [3]:
"""
Error function of choice is the error to signal ratio 
"""
def pre_emphasis_filter(x, coeff=0.95):
   
    """ 
    y[n] = x[n] - coeff * x[n-1] coefficient adapted from paper : 0.95
    """ 
    return torch.cat((x[:, :, 0:1], x[:, :, 1:] - coeff * x[:, :, :-1]), dim=2)

def error_to_signal(y, y_pred):
    """
    Error to signal ratio with pre-emphasis filter:
    https://www.mdpi.com/2076-3417/10/3/766/html
    
    """
    y, y_pred = pre_emphasis_filter(y), pre_emphasis_filter(y_pred)
    return (y - y_pred).pow(2).sum(dim=2) / (y.pow(2).sum(dim=2) + 1e-10)





In [4]:
"""
Dilated causal convolutions in WaveNet

Causal convolutions 
Causal comes from causality, which means if we have a canonical 'direction' we are reading our data, then data that is ahead of the current position cannot factor 
into the calculation. This is most obvious in time series, so only previous timesteps factor into the current and not something 'future' relative to the current. 
But note it can also be applied to other forms of data like 2D images (like in PixelCNN for e.g.)

The causal convolution concept comes about because when you do convolution, the kernel may overlap with the data from the 'future' points thus breaking causality. 
We don't want this so usually we introduce some kind of zero masking onto these points. This masking procedure is what sets apart causal convolution from standard 
convolution.

"""


class CausalConv1d(torch.nn.Conv1d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True):
        self.__padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.__padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            return result[:, :, : -self.__padding]
        return result


def _conv_stack(dilations, in_channels, out_channels, kernel_size):
    """
    Create stack of dilated convolutional layers, outlined in WaveNet paper:
    https://arxiv.org/pdf/1609.03499.pdf
    """
    return nn.ModuleList(
        [
            CausalConv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                dilation=d,
                kernel_size=kernel_size,
            )
            for i, d in enumerate(dilations)
        ]
    )


class WaveNet(nn.Module):
    def __init__(self, num_channels, dilation_depth, num_repeat, kernel_size=2):
        super(WaveNet, self).__init__()
        dilations = [2 ** d for d in range(dilation_depth)] * num_repeat
        internal_channels = int(num_channels * 2)
        self.hidden = _conv_stack(dilations, num_channels, internal_channels, kernel_size)
        self.residuals = _conv_stack(dilations, num_channels, num_channels, 1)
        self.input_layer = CausalConv1d(
            in_channels=1,
            out_channels=num_channels,
            kernel_size=1,
        )

        self.linear_mix = nn.Conv1d(
            in_channels=num_channels * dilation_depth * num_repeat,
            out_channels=1,
            kernel_size=1,
        )
        self.num_channels = num_channels

    def forward(self, x):
        out = x
        skips = []
        out = self.input_layer(out)

        for hidden, residual in zip(self.hidden, self.residuals):
            x = out
            out_hidden = hidden(x)

            # gated activation
            # split (32,16,3) into two (16,16,3) for tanh and sigm calculations
            out_hidden_split = torch.split(out_hidden, self.num_channels, dim=1)
            out = torch.tanh(out_hidden_split[0]) * torch.sigmoid(out_hidden_split[1])

            skips.append(out)

            out = residual(out)
            out = out + x[:, :, -out.size(2) :]

        # modified "postprocess" step:
        out = torch.cat([s[:, :, -out.size(2) :] for s in skips], dim=1)
        out = self.linear_mix(out)
        return out




   



        

In [5]:
"""
hyperparameters 

"""

num_channels = 4
dilation_depth = 9
num_repeat = 2 
kernel_size =3
learning_rate, batch_size = 3e-3, 64
wavenet_model = WaveNet(
            num_channels,
            dilation_depth,
            num_repeat,
            kernel_size
        )
summary(wavenet_model) 

device = "cpu"
if torch.cuda.is_available():
    device == "cuda"
    wavenet_model.to(device)

optimizer = torch.optim.Adam(wavenet_model.parameters(), lr= learning_rate)


Layer (type:depth-idx)                   Param #
├─ModuleList: 1-1                        --
|    └─CausalConv1d: 2-1                 104
|    └─CausalConv1d: 2-2                 104
|    └─CausalConv1d: 2-3                 104
|    └─CausalConv1d: 2-4                 104
|    └─CausalConv1d: 2-5                 104
|    └─CausalConv1d: 2-6                 104
|    └─CausalConv1d: 2-7                 104
|    └─CausalConv1d: 2-8                 104
|    └─CausalConv1d: 2-9                 104
|    └─CausalConv1d: 2-10                104
|    └─CausalConv1d: 2-11                104
|    └─CausalConv1d: 2-12                104
|    └─CausalConv1d: 2-13                104
|    └─CausalConv1d: 2-14                104
|    └─CausalConv1d: 2-15                104
|    └─CausalConv1d: 2-16                104
|    └─CausalConv1d: 2-17                104
|    └─CausalConv1d: 2-18                104
├─ModuleList: 1-2                        --
|    └─CausalConv1d: 2-19                20
|    └─Ca

In [6]:
"""
Simple class for early stopping 
[stack overflow : https://stackoverflow.com/questions/71998978/early-stopping-in-pytorch]
"""

class EarlyStopper:
    def __init__(self, patience=1, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.min_validation_loss = np.inf

    def early_stop(self, validation_loss):
        if validation_loss < self.min_validation_loss:
            self.min_validation_loss = validation_loss
            self.counter = 0
        elif validation_loss > (self.min_validation_loss + self.min_delta):
            self.counter += 1
            if self.counter >= self.patience:
                return True
        return False
stopEARLY = EarlyStopper(patience= 20, min_delta = 0)

## Backprop with a custom divergence function error to signal ratio 

In [7]:
"""
Train loop with early stopping 
"""
EPOCHS = 1000

for epoch in range(EPOCHS):
    wavenet_model.train()
    for batch, (train_in, train_out) in enumerate(train_loader): 

        optimizer.zero_grad()
        train_in, train_out = train_in.to(device), train_out.to(device)
        out = wavenet_model(train_in)
        train_loss = error_to_signal(out[:,:,-out.size(2):], train_out).mean()
        train_loss.backward()
        train_loss = train_loss.detach().numpy()
        optimizer.step()

        del train_in

    wavenet_model.eval()

    for batch, (valid_in, valid_out) in enumerate(val_loader): 
        valid_in, valid_out = valid_in.to(device), valid_out.to(device)
        out = wavenet_model(valid_in)
        valid_loss = error_to_signal(out[:,:,-out.size(2):], valid_out).mean()
        valid_loss = valid_loss.detach().numpy()
        del valid_in
    
    if stopEARLY.early_stop(validation_loss= valid_loss):
        break 

    torch.save({
            'epoch': epoch,
            'model_state_dict': wavenet_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,'valid_loss': valid_loss,
            }, "D:\\Documents\\CMU_SUBJECTS\\BlackBoxAudioFx\\NeuralAudioModelling\\ckpt\\epoch_{}.pt".format(epoch))
    print("Epoch: [{}/{}] ---- Train Loss: {}, Valid Loss: {} ".format(epoch, EPOCHS, train_loss, valid_loss))
        


Epoch: [0/1000] ---- Train Loss: 0.8094963431358337, Valid Loss: 0.8480744361877441 
Epoch: [1/1000] ---- Train Loss: 0.6749414801597595, Valid Loss: 0.5772479176521301 
Epoch: [2/1000] ---- Train Loss: 0.6664125323295593, Valid Loss: 0.5072468519210815 
Epoch: [3/1000] ---- Train Loss: 0.5741808414459229, Valid Loss: 0.43390342593193054 
Epoch: [4/1000] ---- Train Loss: 0.5099707841873169, Valid Loss: 0.4239423871040344 
Epoch: [5/1000] ---- Train Loss: 0.5278087854385376, Valid Loss: 0.401146799325943 
Epoch: [6/1000] ---- Train Loss: 0.3982592821121216, Valid Loss: 0.36667245626449585 
Epoch: [7/1000] ---- Train Loss: 0.33165380358695984, Valid Loss: 0.3450269401073456 
Epoch: [8/1000] ---- Train Loss: 0.24998284876346588, Valid Loss: 0.33827003836631775 
Epoch: [9/1000] ---- Train Loss: 0.2462148219347, Valid Loss: 0.3360155522823334 
Epoch: [10/1000] ---- Train Loss: 0.25515618920326233, Valid Loss: 0.3237763047218323 
Epoch: [11/1000] ---- Train Loss: 0.2669546604156494, Valid Lo

In [18]:
## Testing the model 

def save(name, data): 
    wavfile.write(name, 44100, data.flatten().astype(np.float32))

load_ckpt = torch.load("D:\\Documents\\CMU_SUBJECTS\\BlackBoxAudioFx\\NeuralAudioModelling\\ckpt\\epoch_112.pt")
wavenet_model.load_state_dict(load_ckpt['model_state_dict'])
wavenet_model.eval()

prev_sample = np.concatenate((np.zeros_like(x_test[0:1]), x_test[:-1]), axis=0)
pad_x_test = np.concatenate((prev_sample, x_test), axis=2)

y_pred = []
for x in np.array_split(pad_x_test, 10):
    y_pred.append(wavenet_model(torch.from_numpy(x)).detach().numpy())

y_pred = np.concatenate(y_pred)
y_pred = y_pred[:, :, -x_test.shape[2] :]
save_path = "D:\\Documents\\CMU_SUBJECTS\\BlackBoxAudioFx\\NeuralAudioModelling\\"
save(save_path + "y_pred.wav", y_pred)
save(save_path + "x_test.wav", d["x_test"] * d["std"] + d["mean"])
save(save_path + "y_test.wav", d["y_test"])

