# VF (VoiceFilter): model summary and MACs

In [60]:
import os
import time
import logging
import argparse

from utils.hparams import HParam
from utils.writer import MyWriter
from datasets.dataloader import create_dataloader

import os
import IPython.display
from torchinfo import summary
from thop import profile, clever_format

from utils.audio import Audio
from utils.hparams import HParam

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so any explicit value would
    switch the option on.

    Args:
        value: Raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Same options as before; the boolean ones still take an explicit value
# (e.g. ``--resume true``) but now parse it correctly.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/best.yaml',
                    help="folder contain yaml files for configuration")
parser.add_argument('--clean_rerun', type=_str2bool, default=False,
                    help="remove old checkpoint and log. Default: false")
parser.add_argument('-r', '--resume', type=_str2bool, default=False,
                    help="resume from checkpoint. Default: false")
# The argument list is passed explicitly because a notebook has no real
# command line to parse.
args = parser.parse_args(["-c", "config/super_converge/powlaw_loss.yaml"])

# Parse the YAML configuration and pick out its two top-level sections.
config = HParam(args.config)
exp, env = config["experiment"], config["env"]

# Keep the raw hyper-parameter text as one string as well.
with open(args.config, 'r') as config_file:
    hp_str = config_file.read()

# Checkpoints and logs each live in a per-experiment folder under the base dir.
chkpt_dir = os.path.join(env.base_dir, env.log.chkpt_dir, exp.name)
log_dir = os.path.join(env.base_dir, env.log.log_dir, exp.name)
for folder in (chkpt_dir, log_dir):
    os.makedirs(folder, exist_ok=True)
    

# logging.basicConfig() silently does nothing when the root logger already has
# handlers, which is the case whenever this cell is re-run in the same kernel
# or logging was configured earlier. Detach and close the old handlers first
# so the new per-experiment log file is really used and the previous file
# handle is released.
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)
    stale_handler.close()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # One log file per run: <log_dir>/<experiment name>-<unix time>.log
        logging.FileHandler(os.path.join(log_dir,
            '%s-%d.log' % (exp.name, time.time()))),
        # Also echo every record to the notebook output.
        logging.StreamHandler()
    ]
)

logger = logging.getLogger()
# Project-specific writer, built from the audio settings and the log directory.
writer = MyWriter(exp.audio, log_dir)

In [61]:
import torch

from utils.audio import Audio

from model.get_model import get_vfmodel, get_embedder, get_forward
from loss.get_criterion import get_criterion
from trainer.optimizer.get_optimizer import get_optimizer
from datasets.dataloader import create_dataloader

# Build the training dataloader from the full configuration; this has to
# happen before `config` is narrowed to the experiment section below.
trainloader = create_dataloader(config, scheme="train")

# From here on `config` refers only to the experiment section (for simplicity).
# NOTE(review): because of this rebinding the cell presumably cannot be re-run
# on its own; the previous cell has to reload `config` first. Confirm.
config = config.experiment

it = iter(trainloader) # pull batches manually with next() instead of a for loop
device = "cuda" if config.use_cuda else "cpu"


# Init model, embedder, criterion (the optimizer is created further below)
audio = Audio(config)
embedder = get_embedder(config, train=False, device=device)
# `chkpt` is compared against None before it is used further down.
model, chkpt = get_vfmodel(config, train=True, device=device)
train_forward, _ = get_forward(config)
criterion = get_criterion(config)

# Decide whether the optimizer/scheduler are built from the checkpoint.
resume_requested = config.train.get("resume_from_chkpt") is True
if not resume_requested:
    print("New optimizer")
    optimizer, scheduler = get_optimizer(config, model, None)
else:
    print("Resuming optimizer and scheduler from checkpoint: %s" % config.model.pretrained_chkpt)
    optimizer, scheduler = get_optimizer(config, model, chkpt)
    # `step` is only defined when a checkpoint object is actually available.
    if chkpt is not None:
        step = chkpt['step']


New optimizer


In [62]:
# Pull one batch from the training iterator; the following cells inspect it.
batch = next(it)

In [63]:
# Shapes of the two batch entries that are later passed to the model.
[batch[key].shape for key in ("mixed_mag", "dvec_tensor")]

[torch.Size([16, 301, 601]), torch.Size([16, 256])]

In [64]:
# Layer-by-layer summary for a single example. The per-example input shapes
# are read from the loaded batch instead of being hard-coded, so the summary
# keeps matching the data when the feature configuration changes.
summary(
    model,
    [(1, *batch["mixed_mag"].shape[1:]), (1, *batch["dvec_tensor"].shape[1:])],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape              Output Shape              Param #                   Mult-Adds
VoiceFilter (VoiceFilter)                --                        [1, 301, 601]             --                        --
├─Sequential (conv)                      --                        [1, 8, 301, 601]          --                        --
│    └─ZeroPad2d (0)                     --                        [1, 1, 301, 607]          --                        --
│    └─Conv2d (1)                        [1, 7]                    [1, 64, 301, 601]         512                       92,621,312
│    └─BatchNorm2d (2)                   --                        [1, 64, 301, 601]         128                       128
│    └─ReLU (3)                          --                        [1, 64, 301, 601]         --                        --
│    └─ZeroPad2d (4)                     --                        [1, 64, 307, 601]         --                        --
│    └─C

In [65]:
# Move the model inputs to `device`, the same device string that was passed
# to get_vfmodel, instead of hard-coding CUDA. A plain .cuda() ignores
# config.use_cuda and fails or mismatches when the run is configured for CPU.
batch["mixed_mag"] = batch["mixed_mag"].to(device)
batch["dvec_tensor"] = batch["dvec_tensor"].to(device)

In [67]:
# Profile the model on the whole batch and show the two figures in
# human-readable units.
profile_inputs = (batch["mixed_mag"], batch["dvec_tensor"])
macs, params = profile(model, inputs=profile_inputs)
display(clever_format([macs, params], "%.3f"))

[INFO] Register zero_ops() for <class 'torch.nn.modules.padding.ZeroPad2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


('1.661T', '18.876M')

In [66]:
# Same measurement for one example only; indexing with the list [0] keeps
# the leading batch dimension.
single_example = (batch["mixed_mag"][[0]], batch["dvec_tensor"][[0]])
macs, params = profile(model, inputs=single_example)
display(clever_format([macs, params], "%.3f"))

[INFO] Register zero_ops() for <class 'torch.nn.modules.padding.ZeroPad2d'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


('103.832G', '18.876M')

# PSE-DCCRN: model summary and MACs

In [68]:
import os
import time
import logging
import argparse

from utils.hparams import HParam
from utils.writer import MyWriter
from datasets.dataloader import create_dataloader
from thop import profile, clever_format

import os
import IPython.display
from torchinfo import summary


from utils.audio import Audio
from utils.hparams import HParam

def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``type=bool`` does not work with argparse: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so any explicit value would
    switch the option on.

    Args:
        value: Raw token from the command line (or an actual ``bool``).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1", "on"):
        return True
    # The empty string stays falsy, exactly as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Same options as before; the boolean ones still take an explicit value
# (e.g. ``--resume true``) but now parse it correctly.
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', type=str, default='config/best.yaml',
                    help="folder contain yaml files for configuration")
parser.add_argument('--clean_rerun', type=_str2bool, default=False,
                    help="remove old checkpoint and log. Default: false")
parser.add_argument('-r', '--resume', type=_str2bool, default=False,
                    help="resume from checkpoint. Default: false")
# The argument list is passed explicitly because a notebook has no real
# command line to parse.
args = parser.parse_args(["-c", "config/super_converge/pse_dccrn_stft_L_asym.yaml"])

# Parse the YAML configuration and pick out its two top-level sections.
config = HParam(args.config)
exp, env = config["experiment"], config["env"]

# Keep the raw hyper-parameter text as one string as well.
with open(args.config, 'r') as config_file:
    hp_str = config_file.read()

# Checkpoints and logs each live in a per-experiment folder under the base dir.
chkpt_dir = os.path.join(env.base_dir, env.log.chkpt_dir, exp.name)
log_dir = os.path.join(env.base_dir, env.log.log_dir, exp.name)
for folder in (chkpt_dir, log_dir):
    os.makedirs(folder, exist_ok=True)
    

# logging.basicConfig() silently does nothing when the root logger already has
# handlers. Logging was already configured earlier in this notebook, so
# without a reset the records of this experiment would keep going to the
# previous log file. Detach and close the old handlers first so the new
# per-experiment log file is really used and the old file handle is released.
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)
    stale_handler.close()

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # One log file per run: <log_dir>/<experiment name>-<unix time>.log
        logging.FileHandler(os.path.join(log_dir,
            '%s-%d.log' % (exp.name, time.time()))),
        # Also echo every record to the notebook output.
        logging.StreamHandler()
    ]
)

logger = logging.getLogger()
# Project-specific writer, built from the audio settings and the log directory.
writer = MyWriter(exp.audio, log_dir)

In [69]:
import torch

from utils.audio import Audio

from model.get_model import get_vfmodel, get_embedder, get_forward
from loss.get_criterion import get_criterion
from trainer.optimizer.get_optimizer import get_optimizer
from datasets.dataloader import create_dataloader

# Build the training dataloader from the full configuration; this has to
# happen before `config` is narrowed to the experiment section below.
trainloader = create_dataloader(config, scheme="train")

# From here on `config` refers only to the experiment section (for simplicity).
# NOTE(review): because of this rebinding the cell presumably cannot be re-run
# on its own; the previous cell has to reload `config` first. Confirm.
config = config.experiment

it = iter(trainloader) # pull batches manually with next() instead of a for loop
device = "cuda" if config.use_cuda else "cpu"


# Init model, embedder, criterion (the optimizer is created further below)
audio = Audio(config)
embedder = get_embedder(config, train=False, device=device)
# `chkpt` is compared against None before it is used further down.
model, chkpt = get_vfmodel(config, train=True, device=device)
train_forward, _ = get_forward(config)
criterion = get_criterion(config)

# Decide whether the optimizer/scheduler are built from the checkpoint.
resume_requested = config.train.get("resume_from_chkpt") is True
if not resume_requested:
    print("New optimizer")
    optimizer, scheduler = get_optimizer(config, model, None)
else:
    print("Resuming optimizer and scheduler from checkpoint: %s" % config.model.pretrained_chkpt)
    optimizer, scheduler = get_optimizer(config, model, chkpt)
    # `step` is only defined when a checkpoint object is actually available.
    if chkpt is not None:
        step = chkpt['step']


New optimizer


In [70]:
# Pull one batch from the training iterator; the following cells inspect it.
batch = next(it)

In [71]:
# Shapes of the magnitude and d-vector entries of the batch.
[batch[key].shape for key in ("mixed_mag", "dvec_tensor")]

[torch.Size([16, 301, 641]), torch.Size([16, 256])]

In [72]:
# Turn the complex STFT into a real tensor with real and imaginary parts
# stacked along one axis, then swap the last two axes.
mixed_stft = batch["mixed_stft"]
b, t = mixed_stft.shape[:2]
# view_as_real appends a trailing axis of size 2 (real, imag); move that axis
# in front of the last data axis so the flatten below keeps all real values
# together, followed by all imaginary values.
stacked = torch.view_as_real(mixed_stft).transpose(-2, -1)
stacked = stacked.reshape(b, t, -1)
# Result here is (16, 1282, 301), as the next cell shows.
mixed_stft_ = stacked.transpose(-1, -2)

In [73]:
# Check the shape of the reshaped STFT tensor before feeding it to the model.
mixed_stft_.shape

torch.Size([16, 1282, 301])

In [74]:
# Layer-by-layer summary for the full batch. The input shapes are read from
# the tensors themselves instead of being hard-coded, so the summary keeps
# matching the data when the STFT or batch configuration changes.
summary(
    model,
    [tuple(mixed_stft_.shape), tuple(batch["dvec_tensor"].shape)],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                            Kernel Shape              Output Shape              Param #                   Mult-Adds
PSE_DCCRN (PSE_DCCRN)                              5                         [16, 1282, 301]           1,743,655                 --
├─ModuleList (encoder)                             --                        --                        --                        --
│    └─Sequential (0)                              --                        [16, 32, 320, 301]        --                        --
│    │    └─ComplexConv2d (0)                      [5, 2]                    [16, 32, 320, 301]        352                       1,084,948,480
│    │    └─BatchNorm2d (1)                        --                        [16, 32, 320, 301]        64                        1,024
│    │    └─PReLU (2)                              --                        [16, 32, 320, 301]        1                         16
│    └─Sequential (1)                              --  

In [75]:
# Layer-by-layer summary for a single example. The per-example input shapes
# are read from the tensors instead of being hard-coded; only the batch
# dimension is forced to 1.
summary(
    model,
    [(1, *mixed_stft_.shape[1:]), (1, *batch["dvec_tensor"].shape[1:])],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                            Kernel Shape              Output Shape              Param #                   Mult-Adds
PSE_DCCRN (PSE_DCCRN)                              5                         [1, 1282, 301]            1,743,655                 --
├─ModuleList (encoder)                             --                        --                        --                        --
│    └─Sequential (0)                              --                        [1, 32, 320, 301]         --                        --
│    │    └─ComplexConv2d (0)                      [5, 2]                    [1, 32, 320, 301]         352                       67,809,280
│    │    └─BatchNorm2d (1)                        --                        [1, 32, 320, 301]         64                        64
│    │    └─PReLU (2)                              --                        [1, 32, 320, 301]         1                         1
│    └─Sequential (1)                              --         

In [76]:
# Move the model inputs to `device`, the same device string that was passed
# to get_vfmodel, instead of hard-coding CUDA. A plain .cuda() ignores
# config.use_cuda and fails or mismatches when the run is configured for CPU.
# NOTE: this replaces the complex-valued batch["mixed_stft"] with the
# real-valued, reshaped tensor. The cell that builds `mixed_stft_` reads the
# complex entry, so fetch a fresh batch before re-running that cell.
batch["mixed_stft"] = mixed_stft_.to(device)
batch["dvec_tensor"] = batch["dvec_tensor"].to(device)

In [77]:
# Profile the model on the whole batch and show the two figures in
# human-readable units.
profile_inputs = (batch["mixed_stft"], batch["dvec_tensor"])
macs, params = profile(model, inputs=profile_inputs)
display(clever_format([macs, params], "%.3f"))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


('1.056T', '4.918M')

In [78]:
# Same measurement for one example only; indexing with the list [0] keeps
# the leading batch dimension.
single_example = (batch["mixed_stft"][[0]], batch["dvec_tensor"][[0]])
macs, params = profile(model, inputs=single_example)
display(clever_format([macs, params], "%.3f"))

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_prelu() for <class 'torch.nn.modules.activation.PReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.ConvTranspose2d'>.
[INFO] Register count_lstm() for <class 'torch.nn.modules.rnn.LSTM'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.


('66.020G', '4.918M')