In [3]:
import torch
import torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torchaudio.functional as AF
import numpy as np
import torch
from torch.nn import Module, Parameter
from torch import FloatTensor
from scipy import signal
import numpy as np
from torchaudio import transforms
import matplotlib.pyplot as plt
import IPython.display as ipd
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import os
from scipy import signal
from torchaudio.functional import lfilter


In [4]:
# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("device=", device) 

device= cpu


In [5]:
train_data_dir = './data/train'
val_data_dir = './data/val'
input_name = 'ht1-input.wav'
target_name = 'ht1-target.wav'
train_sig, _ = torchaudio.load(os.path.join(train_data_dir, input_name))
val_sig, sr = torchaudio.load(os.path.join(val_data_dir, input_name))
train_target_sig, _ = torchaudio.load(os.path.join(train_data_dir, target_name))
val_target_sig, _ = torchaudio.load(os.path.join(val_data_dir, target_name))
print(sr)

44100


In [6]:
train_input_seg = train_sig#[:,:44100*120]
train_target_seg = train_target_sig#[:,:44100*120]

# For Reproducibility

In [7]:
torch.manual_seed(0)

<torch._C.Generator at 0x2b766b82a3d0>

## Initialize Dataloader

In [8]:
class DIIRDataSet(Dataset):
    def __init__(self, input, target, sequence_length):
        self.input = input
        self.target = target
        self._sequence_length = sequence_length
        self.input_sequence = self.wrap_to_sequences(self.input, self._sequence_length)
        self.target_sequence = self.wrap_to_sequences(self.target, self._sequence_length)
        self._len = self.input_sequence.shape[0]

    def __len__(self):
        return self._len

    def __getitem__(self, index):
        return {'input': self.input_sequence[index, :, :]
               ,'target': self.target_sequence[index, :, :]}

    def wrap_to_sequences(self, data, sequence_length):
        num_sequences = int(np.floor(data.shape[0] / sequence_length))
        print(num_sequences)
        truncated_data = data[0:(num_sequences * sequence_length)]
        wrapped_data = truncated_data.reshape((num_sequences, sequence_length, 1))
        wrapped_data = wrapped_data.permute(0,2,1)
        print(wrapped_data.shape)
        return np.float32(wrapped_data)


In [9]:
train_input = train_input_seg#hp_train_sig#[:,:44100*60]
train_target = train_target_seg#[:,:44100*60]

In [10]:
train_input.squeeze(0).shape

torch.Size([14994001])

In [11]:
batch_size = 512
sequence_length = 20000
train_dataset=DIIRDataSet(train_input.squeeze(0), train_target.squeeze(0), sequence_length)
loader = DataLoader(train_dataset, batch_size=batch_size, shuffle = False, pin_memory=True, drop_last=True)

749
torch.Size([749, 1, 20000])
749
torch.Size([749, 1, 20000])


In [12]:
len(loader)


1

# Declare Model

In [13]:
class FIRNN(Module):
    def __init__(self, n_input=1, n_output=1, kernel_size=1000, n_channel=32):
        super(FIRNN, self).__init__()        
        self.conv_kz = kernel_size
        self.conv1 = nn.Conv1d(n_input, 1, kernel_size=kernel_size, stride=1)

        self.fc1 = nn.Conv1d(1, n_channel*2, kernel_size=1) 
        self.fc2 = nn.Conv1d(n_channel*2, 1, kernel_size=1)
        
        self.conv2 = nn.Conv1d(1, 1, kernel_size=kernel_size, stride=1)
        
        self.mlp_layer = nn.Sequential(
            self.fc1 ,
            nn.Tanh(),
            self.fc2,
        )


    def forward(self, x):
        #print(x.shape)
        bs = x.shape[0]
        
        x = F.pad(x, (self.conv_kz-1, 0)) #pad on the left side
        x = self.conv1(x) 

        x = self.mlp_layer(x)
 
        x = F.pad(x, (self.conv_kz-1, 0))
        x = self.conv2(x)

        return x


In [14]:
model = FIRNN(kernel_size=1000, n_channel=32)

## Define optimizer and criterion

In [15]:
import torch.nn as nn
from torch.optim import Adam

n_epochs = 100
lr = 1e-3

optimizer = Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

criterion = nn.MSELoss()

# Define train loop

In [16]:
def train(criterion, model, loader, optimizer):
    model.train()
    device = next(model.parameters()).device
    total_loss = 0
    
    for ind, batch in enumerate(loader):
        input_seq_batch = batch['input'].to(device)
        target_seq_batch = batch['target'].to(device)
        optimizer.zero_grad()
        predicted_output = model(input_seq_batch)
        
        # premphasis filter
        target_seq_batch_filt = lfilter(target_seq_batch, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        predicted_output_filt = lfilter(predicted_output, torch.Tensor([1,0]), torch.Tensor([1, -0.95]))
        
        loss = criterion(target_seq_batch_filt, predicted_output_filt)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    total_loss /= len(loader)
    return total_loss

## Train!

In [None]:
for epoch in range(n_epochs):
    loss = train(criterion, model, loader, optimizer)
    print("Epoch {} -- Loss {:3E}".format(epoch, loss))

save_path = os.path.join('./models/model_wh_whole_audio_ep{}.pth'.format(n_epochs-1))
torch.save(model.state_dict(), save_path)
print("model saved!")

# Evaluate

In [None]:
#val_batch_size = 128
sequence_length = 20000
val_dataset=DIIRDataSet(train_input.squeeze(0), train_target.squeeze(0), sequence_length)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle = False, pin_memory=True, drop_last=True)

In [None]:
def inspect_file(path):
    print("-" * 10)
    print("Source:", path)
    print("-" * 10)
    print(f" - File size: {os.path.getsize(path)} bytes")
    print(f" - {torchaudio.info(path)}")

In [None]:
def save_audio(batch):
    #1024,512,1
    out_batch = batch.detach().cpu()
    out_batch = out_batch.squeeze(-1).flatten()
    print(out_batch.shape)
    return out_batch

In [None]:
import soundfile as sf

out_path = './output/'
sample_rate = 44100
save_tensor = torch.zeros(680,22050)
with torch.no_grad():
    for i, val_batch in enumerate(val_loader):
        input_seq_batch = val_batch['input'].to(device)
        #target_seq_batch = val_batch['target'].to(device)
        predicted_output = model(input_seq_batch)
        output_tmp = predicted_output.squeeze().detach().cpu()
        #print(output_tmp.shape)
        save_tensor[i,:] = output_tmp
    
    print(save_tensor.shape)
    out_audio = save_audio(save_tensor)
    print(out_audio.shape)
    path = os.path.join(out_path, "target_fir_pure.wav")
    print("Exporting {}".format(path))
    sf.write(path, out_audio, sample_rate,'PCM_24')
    #torchaudio.save(path, out_audio, sample_rate, encoding="PCM_S", bits_per_sample=16)
    inspect_file(path)
    