<a href="https://colab.research.google.com/github/tanakak11/DNN_WienerFilter/blob/main/DNN_WienerFilter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Quote the requirement: unquoted, the shell reads `>` as a redirect and the
# version constraint is silently dropped. %pip targets the kernel's environment.
# NOTE(review): the imports cell uses torchaudio.datasets.utils.walk_files, which
# newer torchaudio releases may no longer provide -- pin a compatible version if
# that import fails.
%pip install "torch>=1.2.0"
%pip install torchaudio

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch.utils.data import Dataset

import torchaudio
from torchaudio import datasets, transforms
from torchaudio.datasets.utils import(
    download_url,
    extract_archive,
    walk_files,
)

import os
import glob
import pickle
from typing import Tuple
from sklearn.metrics import confusion_matrix

import numpy as np
import scipy.signal as signal
import wave as wave

from google.colab import drive

In [None]:
# Mount Google Drive; the dataset and output paths used below all live under /content/drive.
drive.mount('/content/drive', force_remount=True)

In [4]:
# Model

class Net(nn.Module):
  """Three-layer fully connected network used to predict Wiener-filter values.

  Architecture: Linear(n_in -> n_h) -> sigmoid -> Linear(n_h -> n_h) -> sigmoid
  -> Dropout(p=0.5) -> Linear(n_h -> n_out). The output layer has no
  activation, so predictions are unbounded.

  The attribute names (l1, l2, dropout2, l3) become the state_dict keys that the
  notebook saves and reloads, so they must not be renamed.

  Args:
    n_in: size of each input feature vector.
    n_h: width of both hidden layers.
    n_out: size of each output vector.
  """

  def __init__(self, n_in, n_h, n_out):
    super().__init__()
    # Layers are created in this fixed order so seeded initialisation is reproducible.
    self.l1 = nn.Linear(n_in, n_h)
    self.l2 = nn.Linear(n_h, n_h)
    self.dropout2 = nn.Dropout(0.5)
    self.l3 = nn.Linear(n_h, n_out)

  def forward(self, x):
    """Map a (..., n_in) tensor to a (..., n_out) tensor (layers act on the last dim)."""
    hidden = self.l1(x).sigmoid()
    hidden = self.dropout2(self.l2(hidden).sigmoid())
    return self.l3(hidden)

In [5]:
# Output location: models and predicted filters are written under My Drive/<output_filename>.
output_filename = "Output_WF_-5-10dB_20ep"
base_path = "/content/drive/My Drive/"
output_base_path = os.path.join(base_path, output_filename)

In [6]:
# Dataset: pre-computed Wiener-filter (WF) input/target pairs stored on Google Drive.

URL = "train_16000"


def load_timit_item(fileid: str,
                    path: str) -> Tuple[Tensor, Tensor]:
    """Load one (input, Wiener-filter target) pair from disk.

    Reads ``input_data.cpickle`` and ``filter.cpickle`` from the directory
    ``os.path.join(path, fileid)``. ``TIMIT16000`` passes the absolute directory
    returned by ``glob`` as ``fileid``; in that case ``os.path.join`` ignores
    ``path``.

    The returned objects are whatever was pickled (despite the annotation they
    need not be tensors). pickle.load can execute arbitrary code, so only use
    trusted data files.

    Returns:
        ``(input_data, filter)``: the network input and its target.
    """
    item_dir = os.path.join(path, fileid)

    # Target: Wiener-filter coefficients for this item.
    with open(os.path.join(item_dir, "filter.cpickle"), mode='rb') as f:
        wf_target = pickle.load(f)

    # Network input features for this item.
    with open(os.path.join(item_dir, "input_data.cpickle"), mode='rb') as f:
        input_data = pickle.load(f)

    return input_data, wf_target


class TIMIT16000(Dataset):
  """Dataset whose items are the sub-directories of ``My Drive/<url>``.

  Each item directory must contain ``input_data.cpickle`` and ``filter.cpickle``
  (see ``load_timit_item``). Items are sorted by path so the order, and
  therefore the evaluation subset and output numbering, is reproducible.

  For evaluation sets (``_TEST_URLS``) only the first ``_MAX_TEST_ITEMS`` items
  are kept. Their paths are written to ``<base_path>/<output_filename>/list.txt``,
  which relies on the notebook globals ``base_path`` and ``output_filename``.

  Args:
    root: unused; kept for signature compatibility.
    url: dataset directory name under ``My Drive``; must be in ``_VALID_URLS``.
    download: unused; kept for signature compatibility.

  Raises:
    ValueError: if ``url`` is not a known dataset directory name.
  """

  _BASE_URL = "/content/drive/My Drive/"

  _TEST_URLS = (
      "test_16000",
      "test_16000_0.0dB",
      "test_16000_10.0dB",
      "test_16000_0-5dB",
      "test_16000_-1-4dB",
      "test_16000_-5-10dB",
      "demo_data",
  )

  _VALID_URLS = (
      "train_16000",
      "train_16000_0.0dB",
      "train_16000_10.0dB",
      "train_16000_0-5dB",
      "train_16000_-1-4dB",
      "train_16000_-5-10dB",
  ) + _TEST_URLS

  _MAX_TEST_ITEMS = 200

  def __init__(self,
               root: str,
               url: str = URL,
               download: bool = False) -> None:
    # Previously an unknown url left the path variable unbound (UnboundLocalError).
    if url not in self._VALID_URLS:
      raise ValueError("unknown dataset url {!r}; expected one of: {}".format(
          url, ", ".join(self._VALID_URLS)))

    self._path = os.path.join(self._BASE_URL, url)

    # One item per sub-directory (pattern ends with a separator, so only dirs match).
    self._dir_list = sorted(glob.glob(os.path.join(self._path, "*" + os.sep)))

    is_eval_set = url in self._TEST_URLS
    if is_eval_set:
      self._dir_list = self._dir_list[:self._MAX_TEST_ITEMS]

    print(len(self._dir_list))
    print(self._dir_list)

    if is_eval_set:
      # Record the item order so saved predictions can be mapped back to inputs.
      list_path = os.path.join(base_path, output_filename, "list.txt")
      os.makedirs(os.path.dirname(list_path), exist_ok=True)
      with open(list_path, mode='w') as f:
        f.write('\n'.join(self._dir_list))
      print("done")

  def __getitem__(self, n: int) -> Tuple[Tensor, Tensor]:
    """Return ``(input_data, filter)`` for the n-th item directory."""
    return load_timit_item(self._dir_list[n], self._path)

  def __len__(self) -> int:
    return len(self._dir_list)


In [None]:
# Loss functions (used as the training / evaluation objectives, not regularisers).

def L1_loss(y, label):
  """Mean absolute error between ``y`` and ``label`` (broadcast like ``y - label``).

  Uses ``torch.mean``; ``reduce_mean`` is a TensorFlow name and does not exist in PyTorch.
  """
  return torch.mean(torch.abs(y - label))


def L2_loss(y, label):
  """Mean squared error between ``y`` and ``label`` (broadcast like ``y - label``).

  ``abs`` before squaring keeps the result real for complex inputs.
  """
  return torch.mean(torch.square(torch.abs(y - label)))

In [None]:
# Build the datasets, data loaders, model and optimiser.

# Hyper-parameters.
epoch = 20      # number of training epochs
batchsize = 8   # items per DataLoader batch
# Extra divisor applied to the reported training loss; presumably a frame
# count -- TODO confirm its origin.
min_len = 59

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Seed before the model is built so its weight initialisation is reproducible.
torch.manual_seed(0)

wf_train_data = TIMIT16000('data', url="train_16000_-5-10dB", download=True)
wf_test_data = TIMIT16000('data', url="test_16000_0.0dB", download=True)

train_loader = torch.utils.data.DataLoader(wf_train_data, batch_size=batchsize, shuffle=True)
test_loader = torch.utils.data.DataLoader(wf_test_data, batch_size=batchsize, shuffle=False)

# 1026 inputs -> 513 outputs; presumably two 513-bin spectral features per
# frame (513 = 1024-point FFT / 2 + 1) -- TODO confirm against the data prep.
model = Net(n_in=1026, n_h=2052, n_out=513).to(device)
optimizer = optim.Adam(model.parameters())

criterion = nn.MSELoss()  # not used: the loops below call L2_loss instead

In [None]:
# Load a previously trained model, if one has been saved (optional resume).

model_path = os.path.join(output_base_path, "learned_model_20ep.pth")
if os.path.exists(model_path):
  # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
  model.load_state_dict(torch.load(model_path, map_location=device))
  print("loaded", model_path)
else:
  print("no saved model at", model_path, "- training from scratch")

In [None]:
# Train. Each batch is flattened into one matrix of frames:
# (batch, features, frames) -> (batch * frames, features), item by item and
# frame by frame, so every frame is one sample for the MLP.
model.train()
for ep in range(1, epoch + 1):
  tot_loss = 0.0

  for data, sn_ratio in train_loader:
    # Reshape uses the actual batch size, so a smaller final batch works
    # (indexing up to `batchsize` raised IndexError on it).
    features = data.transpose(1, 2).reshape(-1, data.shape[1])
    target = sn_ratio.transpose(1, 2).reshape(-1, sn_ratio.shape[1])
    features, target = features.to(device), target.to(device)

    optimizer.zero_grad()
    output = model(features)

    loss = L2_loss(output, target)
    loss.backward()
    optimizer.step()

    tot_loss += loss.item()

  # NOTE(review): L2_loss is already a mean; the extra division by min_len is
  # kept so logged values stay comparable with earlier runs.
  print('Epoch: {:2d}, Average loss: {:.6f}'.format(ep, tot_loss / ((len(train_loader)) * min_len)))

In [None]:
# Save the trained model's weights (state_dict).

model_path = os.path.join(output_base_path, "learned_model_20ep.pth")
os.makedirs(output_base_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(model.state_dict(), model_path)
print("done")

In [None]:
# Evaluate on the test set and save the predicted filters, one pickle per batch.
model.eval()
test_loss = 0.0

with torch.no_grad():
  for i, (data, sn_ratio) in enumerate(test_loader):
    # (batch, features, frames) -> (batch * frames, features), as in training;
    # works for a smaller final batch too.
    features = data.transpose(1, 2).reshape(-1, data.shape[1]).to(device)
    target = sn_ratio.transpose(1, 2).reshape(-1, sn_ratio.shape[1]).to(device)

    output = model(features)

    # Compare with the flattened, on-device target; the raw batched `sn_ratio`
    # has a different shape and may be on another device.
    test_loss += L2_loss(output, target).item()

    # Directory i holds all frames of batch i (test_loader is not shuffled, so
    # this follows the item order written to list.txt when the set was built).
    output_path = os.path.join(output_base_path, "{:04d}".format(i))
    os.makedirs(output_path, exist_ok=True)  # allow re-running the cell

    output_filter_path = os.path.join(output_path, "filter.cpickle")
    with open(output_filter_path, mode='wb') as f:
      # .cpu() so the pickle can be loaded on a machine without a GPU.
      pickle.dump(output.cpu(), f)

print('Average test loss: {:.6f}'.format(test_loss / max(len(test_loader), 1)))
print("done")

In [None]:
# Load the demo data (same on-disk layout as the test sets; one item per batch).
# Note: "demo_data" counts as an evaluation set, so building it rewrites list.txt.

wf_demo_data = TIMIT16000('data', url="demo_data", download=True)
demo_loader = torch.utils.data.DataLoader(wf_demo_data, batch_size=1, shuffle=False)

In [None]:
# Demo: predict a filter for each demo item and save it under
# <output_base_path>/0100, 0101, ... (offset by 100 to keep them apart from the
# evaluation cell's directories, which are numbered from 0000).

model.eval()

with torch.no_grad():
  for i, (data, _) in enumerate(demo_loader):
    # (1, features, frames) -> (1, frames, features): nn.Linear acts on the last dim.
    features = torch.transpose(data, 1, 2).to(device)
    output = model(features)

    output_path = os.path.join(output_base_path, "{:04d}".format(i + 100))
    os.makedirs(output_path, exist_ok=True)  # allow re-running the cell

    output_filter_path = os.path.join(output_path, "filter.cpickle")
    with open(output_filter_path, mode='wb') as f:
      # .cpu() so the pickle can be loaded on a machine without a GPU.
      pickle.dump(output.cpu(), f)