In [2]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import torch
from torch.nn import Module, Parameter
from torch import FloatTensor
from scipy import signal
import numpy as np
from torchaudio import transforms
import matplotlib.pyplot as plt
import IPython.display as ipd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os
from torch.optim import Adam
from scipy import signal
from torchaudio.functional import lfilter

In [3]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("device=", device)

device= cpu


In [4]:
train_data_dir, val_data_dir = './data/train', './data/val'
input_name, target_name = 'muff-input.wav', 'muff-target.wav'


def _load_wav(split_dir, file_name):
    """Load one wav file from a data split; returns torchaudio's (waveform, sample_rate)."""
    return torchaudio.load(os.path.join(split_dir, file_name))


# `sr` is taken from the validation input file; the other sample rates are discarded.
train_sig, _ = _load_wav(train_data_dir, input_name)
val_sig, sr = _load_wav(val_data_dir, input_name)
train_target_sig, _ = _load_wav(train_data_dir, target_name)
val_target_sig, _ = _load_wav(val_data_dir, target_name)
print(sr)

44100


In [5]:
# Train on the first 60 s only (44100 samples/s * 60 s), cropping input and target identically.
train_crop = 44100 * 60
train_input, train_target = train_sig[:, :train_crop], train_target_sig[:, :train_crop]

In [6]:
train_input.size()


torch.Size([1, 2646000])

# For Reproducibility

In [7]:
torch.random.manual_seed(0)

<torch._C.Generator at 0x2ba38bda2290>

## Initialize Dataloader

In [8]:
class DIIRDataSet(Dataset):
    """Paired input/target audio cut into non-overlapping fixed-length sequences.

    ``input`` and ``target`` are indexed along their first axis as sample streams
    (the notebook passes 1-D tensors obtained with ``signal.squeeze(0)``).
    Each item is a dict ``{'input': ..., 'target': ...}`` of float32 numpy arrays of
    shape (1, sequence_length). Trailing samples that do not fill a whole sequence
    are dropped.
    """

    def __init__(self, input, target, sequence_length):
        self.input = input
        self.target = target
        self._sequence_length = sequence_length
        self.input_sequence = self.wrap_to_sequences(self.input, self._sequence_length)
        self.target_sequence = self.wrap_to_sequences(self.target, self._sequence_length)
        # Only expose indices that exist in BOTH signals: if the target were shorter,
        # counting input sequences alone would make __getitem__ raise IndexError.
        self._len = min(self.input_sequence.shape[0], self.target_sequence.shape[0])

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return {'input': self.input_sequence[index, :, :],
                'target': self.target_sequence[index, :, :]}

    def wrap_to_sequences(self, data, sequence_length):
        """Split ``data`` into ``len(data) // sequence_length`` sequences.

        Returns a float32 numpy array of shape (num_sequences, 1, sequence_length)
        and prints its shape.
        """
        num_sequences = data.shape[0] // sequence_length
        truncated_data = data[:num_sequences * sequence_length]
        # Row i holds samples [i*L, (i+1)*L); the singleton middle axis is the channel.
        wrapped_data = truncated_data.reshape(num_sequences, 1, sequence_length)
        print(wrapped_data.shape)
        return np.asarray(wrapped_data, dtype=np.float32)


In [9]:
train_input.squeeze(dim=0).size()

torch.Size([2646000])

In [10]:
batch_size = 512
sequence_length = 512
train_dataset = DIIRDataSet(
    train_input.squeeze(0),
    train_target.squeeze(0),
    sequence_length,
)
# shuffle=True permutes the ORDER of the fixed 512-sample sequences every epoch;
# samples inside a sequence keep their time order. drop_last=True discards the
# final partial batch, so a different remainder of sequences is skipped each epoch.
loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    drop_last=True,
)

torch.Size([5167, 1, 512])
torch.Size([5167, 1, 512])


In [11]:
num_batches = len(loader)  # full batches per epoch (drop_last=True)
num_batches

10

# Declare Model

In [12]:
class FIRNN(Module):
    """Causal conv -> BatchNorm + tanh -> pointwise MLP -> causal conv.

    Both convolutions are left-padded by ``kernel_size - 1`` so the output has the
    same time length as the input and, once BatchNorm uses running statistics
    (eval mode), output sample t depends only on input samples at times <= t.
    Input is expected as (batch, n_input, time) — the notebook feeds (B, 1, 512).
    """

    def __init__(self, n_input=1, n_output=1, kernel_size=80, n_channel=32):
        super().__init__()
        self.conv_kz = kernel_size
        self.input_len = 512  # not used by forward()
        # Submodules are created in exactly this order: it fixes the random-init
        # sequence under torch.manual_seed and the order of state_dict keys.
        self.conv1 = nn.Conv1d(n_input, n_channel, kernel_size=kernel_size, stride=1)
        self.nonlinear = nn.Tanh()
        self.bn1 = nn.BatchNorm1d(n_channel)
        hidden_channels = n_channel * 2
        self.fc1 = nn.Conv1d(n_channel, hidden_channels, kernel_size=1)
        self.fc2 = nn.Conv1d(hidden_channels, n_channel, kernel_size=1)
        self.conv2 = nn.Conv1d(n_channel, n_output, kernel_size=kernel_size, stride=1)
        # fc1/fc2 are registered directly above AND inside this Sequential, so the same
        # parameters are reachable under both the `fcN` and `mlp_layer` names; kept as-is
        # so existing checkpoints still load.
        self.mlp_layer = nn.Sequential(self.fc1, nn.Tanh(), self.fc2)

    def _causal_pad(self, signal):
        """Zero-pad ``kernel_size - 1`` samples on the left of the time axis only."""
        return F.pad(signal, (self.conv_kz - 1, 0))

    def forward(self, x):
        hidden = self.conv1(self._causal_pad(x))
        hidden = self.nonlinear(self.bn1(hidden))
        hidden = self.mlp_layer(hidden)
        return self.conv2(self._causal_pad(hidden))


In [13]:
# NOTE(review): the model is never moved to `device`; train() sends batches to the
# model's own device, so training stays on the CPU even when CUDA is available.
model = FIRNN(n_channel=32, kernel_size=80)

## Define optimizer and criterion

In [14]:
n_epochs = 100
lr = 1e-3

# Plain Adam: betas=(0.9, 0.999), eps=1e-8, no weight decay, no AMSGrad
# are all PyTorch's defaults, so only the learning rate is passed.
optimizer = Adam(model.parameters(), lr=lr)
criterion = nn.MSELoss()

# Define train loop

In [15]:
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    
    for ind, batch in enumerate(loader):
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)
        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)
        # premphasis filter
        target_seq_batch_filt = lfilter(target_seq_batch, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        predicted_output_filt = lfilter(predicted_output, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        
        #loss = criterion(target_seq_batch, predicted_output)
        loss = criterion(target_seq_batch_filt, predicted_output_filt)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    total_loss /= len(loader)
    return total_loss

## Train!

In [16]:
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    print("Epoch {} -- Loss {:3E}".format(epoch, loss))

# Create the checkpoint directory up front: torch.save fails if it is missing,
# which would otherwise discard the whole training run.
model_dir = './models'
os.makedirs(model_dir, exist_ok=True)
save_path = os.path.join(
    model_dir, 'muff_model_firnl_mlp_firnl_premphasis_ep{}.pth'.format(n_epochs - 1))
torch.save(model.state_dict(), save_path)
print("model saved!")

Epoch 0 -- Loss 2.956059E-04
Epoch 1 -- Loss 2.649233E-04
Epoch 2 -- Loss 2.604671E-04
Epoch 3 -- Loss 2.585816E-04
Epoch 4 -- Loss 2.575083E-04
Epoch 5 -- Loss 2.568084E-04
Epoch 6 -- Loss 2.562477E-04
Epoch 7 -- Loss 2.547379E-04
Epoch 8 -- Loss 2.541558E-04
Epoch 9 -- Loss 2.533623E-04
Epoch 10 -- Loss 2.526465E-04
Epoch 11 -- Loss 2.516496E-04
Epoch 12 -- Loss 2.511096E-04
Epoch 13 -- Loss 2.510554E-04
Epoch 14 -- Loss 2.503401E-04
Epoch 15 -- Loss 2.500704E-04
Epoch 16 -- Loss 2.495512E-04
Epoch 17 -- Loss 2.486954E-04
Epoch 18 -- Loss 2.476196E-04
Epoch 19 -- Loss 2.468737E-04
Epoch 20 -- Loss 2.449902E-04
Epoch 21 -- Loss 2.429432E-04
Epoch 22 -- Loss 2.410267E-04
Epoch 23 -- Loss 2.371715E-04
Epoch 24 -- Loss 2.341082E-04
Epoch 25 -- Loss 2.322392E-04
Epoch 26 -- Loss 2.304266E-04
Epoch 27 -- Loss 2.285643E-04
Epoch 28 -- Loss 2.265173E-04
Epoch 29 -- Loss 2.249473E-04
Epoch 30 -- Loss 2.240478E-04
Epoch 31 -- Loss 2.229256E-04
Epoch 32 -- Loss 2.223914E-04
Epoch 33 -- Loss 2.2

# Evaluate

In [17]:
val_batch_size = 128  # NOTE: not used by val_loader; the export loop feeds one sequence per batch
sequence_length = 512
# Evaluate on the held-out validation recordings. This cell previously rebuilt the
# dataset from train_input/train_target, so the "evaluation" output was in-sample.
# The validation signals get the same crop as training (first 44100*60 samples,
# i.e. 60 s at 44.1 kHz), which is a no-op if the file is shorter.
val_crop = 44100 * 60
val_dataset = DIIRDataSet(val_sig[:, :val_crop].squeeze(0),
                          val_target_sig[:, :val_crop].squeeze(0),
                          sequence_length)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, pin_memory=True, drop_last=True)

torch.Size([5167, 1, 512])
torch.Size([5167, 1, 512])


In [18]:
def inspect_file(path):
    """Print a short report for an audio file: a header, its size on disk, and torchaudio metadata."""
    rule = "-" * 10
    print(rule)
    print("Source:", path)
    print(rule)
    size_bytes = os.path.getsize(path)
    print(f" - File size: {size_bytes} bytes")
    metadata = torchaudio.info(path)
    print(f" - {metadata}")

In [19]:
def save_audio(batch):
    """Flatten a batch of predicted sequences into one detached 1-D CPU tensor.

    Despite the name, nothing is written to disk; the caller saves the result.
    A trailing singleton axis (e.g. (N, L, 1)) is dropped before flattening.
    Prints the resulting shape.
    """
    flat_audio = batch.detach().cpu().squeeze(-1).flatten()
    print(flat_audio.shape)
    return flat_audio

In [20]:
import soundfile as sf

out_path = './output/'
sample_rate = sr  # write at the rate the validation audio was loaded with
os.makedirs(out_path, exist_ok=True)

# FIRNN contains BatchNorm1d; train() leaves the model in training mode, where BN
# would normalise with each batch's own statistics and keep updating its running
# stats. Switch to inference mode before predicting.
model.eval()
model_device = next(model.parameters()).device  # follow the model, as train() does

chunks = []
with torch.no_grad():
    for val_batch in val_loader:
        input_seq_batch = val_batch['input'].to(model_device)
        predicted_output = model(input_seq_batch)  # (batch, 1, sequence_length)
        chunks.append(predicted_output.squeeze(1).cpu())

# One row per sequence, sized from the data instead of a hard-coded count.
save_tensor = torch.cat(chunks, dim=0)
print(save_tensor.shape)
out_audio = save_audio(save_tensor)
print(out_audio.shape)
path = os.path.join(out_path, "target_muff_firnl_mlp_firnl_premphasis.wav")
print("Exporting {}".format(path))
sf.write(path, out_audio.numpy(), sample_rate, 'PCM_24')
inspect_file(path)
    

torch.Size([5167, 512])
torch.Size([2645504])
torch.Size([2645504])
Exporting ./output/target_muff_firnl_mlp_firnl_premphasis.wav
----------
Source: ./output/target_muff_firnl_mlp_firnl_premphasis.wav
----------
 - File size: 7936556 bytes
 - AudioMetaData(sample_rate=44100, num_frames=2645504, num_channels=1, bits_per_sample=24, encoding=PCM_S)
