Notebook này được lập ra với mục đích:
- Kiểm tra (inspect) lượng memory mà model hiện tại sử dụng
- Đánh giá những phần có thể cắt giảm để giảm lượng memory mà model sử dụng

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import IPython.display
from torchinfo import summary

from utils.audio import Audio
from utils.hparams import HParam

from model.embedder import SpeechEmbedder
from datasets.dataloader import create_dataloader

# Hyper-parameters for every model below (presumably parsed from config.yaml by
# utils.hparams.HParam — TODO confirm). Fields read later include hp.audio.*,
# hp.embedder.* and hp.model.*.
hp = HParam("config.yaml")
# Audio helper built from the same hyper-parameters. NOTE(review): not referenced
# by any cell shown here — confirm whether it is still needed.
audio = Audio(hp)

  for doc in docs:


In [2]:
# testloader = create_dataloader(hp, "generate", dataset_detail=["zalo-train", "zalo-test"], scheme="test")
# it = iter(testloader)

In [3]:
# _, _, _, dvec_mel, target_wav, mixed_wav, target_mag, _, mixed_mag, _, target_stft, mixed_stft, *rest = next(it)[0]

# Profiling embedder

In [5]:
# Layer-by-layer report of the speaker embedder for a (40, 600) input.
# Left as the cell's final expression so Jupyter renders the returned summary.
summary(
    SpeechEmbedder(hp),
    input_size=(40, 600),
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape              Output Shape              Param #                   Mult-Adds
SpeechEmbedder                           --                        --                        --                        --
├─LSTM (lstm)                            --                        [14, 80, 768]             11,937,792                13,370,327,040
├─LinearNorm (proj)                      --                        [14, 256]                 --                        --
│    └─Linear (linear_layer)             [768, 256]                [14, 256]                 196,864                   2,756,096
Total params: 12,134,656
Trainable params: 12,134,656
Non-trainable params: 0
Total mult-adds (G): 13.37
Input size (MB): 0.10
Forward/backward pass size (MB): 6.91
Params size (MB): 48.54
Estimated Total Size (MB): 55.54

# Profiling VoiceFilter

In [2]:
class VoiceFilter(nn.Module):
    """VoiceFilter mask estimator: dilated CNN -> (B)LSTM -> two linear layers.

    Args:
        hp: hyper-parameter object. Reads ``hp.audio.n_fft``,
            ``hp.audio.num_freq``, ``hp.embedder.emb_dim`` and
            ``hp.model.{lstm_dim, bidirection, fc1_dim, fc2_dim}``.

    ``forward(x, dvec)`` takes a spectrogram ``x`` of shape [B, T, num_freq]
    and a speaker embedding ``dvec`` of shape [B, emb_dim], and returns a mask
    of shape [B, T, fc2_dim] (fc2_dim == num_freq) squashed into (0, 1).
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        # The predicted mask (fc2 output) must have exactly one value per STFT bin.
        assert hp.audio.n_fft // 2 + 1 == hp.audio.num_freq == hp.model.fc2_dim, \
            "stft-related dimension mismatch"

        # Each ZeroPad2d(left, right, top, bottom) pads by
        # dilation * (kernel - 1) / 2 per side, so every conv keeps the
        # [T, num_freq] size. The pads stay as separate modules (rather than
        # Conv2d(padding=...)) because dropping them would renumber the
        # Sequential children and therefore the state_dict keys.
        self.conv = nn.Sequential(
            # cnn1: 1x7 kernel, along frequency only
            nn.ZeroPad2d((3, 3, 0, 0)),
            nn.Conv2d(1, 64, kernel_size=(1, 7), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn2: 7x1 kernel, along time only
            nn.ZeroPad2d((0, 0, 3, 3)),
            nn.Conv2d(64, 64, kernel_size=(7, 1), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn3
            nn.ZeroPad2d(2),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn4: effective receptive field (9, 5)
            nn.ZeroPad2d((2, 2, 4, 4)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(2, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn5: effective receptive field (17, 5)
            nn.ZeroPad2d((2, 2, 8, 8)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(4, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn6: effective receptive field (33, 5)
            nn.ZeroPad2d((2, 2, 16, 16)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(8, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn7: effective receptive field (65, 5)
            nn.ZeroPad2d((2, 2, 32, 32)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(16, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn8: 1x1 projection down to 8 channels
            nn.Conv2d(64, 8, kernel_size=(1, 1), dilation=(1, 1)),
            nn.BatchNorm2d(8), nn.ReLU(),
        )

        self.lstm = nn.LSTM(
            8 * hp.audio.num_freq + hp.embedder.emb_dim,
            hp.model.lstm_dim,
            batch_first=True,
            bidirectional=hp.model.bidirection)

        # A bidirectional LSTM concatenates both directions' hidden states.
        lstm_dim = 2 * hp.model.lstm_dim if hp.model.bidirection else hp.model.lstm_dim
        self.fc1 = nn.Linear(lstm_dim, hp.model.fc1_dim)
        self.fc2 = nn.Linear(hp.model.fc1_dim, hp.model.fc2_dim)

    def forward(self, x, dvec):
        """Predict a mask for ``x`` conditioned on the speaker embedding ``dvec``.

        Args:
            x: spectrogram, [B, T, num_freq].
            dvec: speaker embedding, [B, emb_dim].

        Returns:
            Tensor of shape [B, T, fc2_dim] with values in (0, 1).
        """
        x = self.conv(x.unsqueeze(1))               # [B, 8, T, num_freq]
        x = x.transpose(1, 2)                       # [B, T, 8, num_freq]
        x = x.reshape(x.size(0), x.size(1), -1)     # [B, T, 8*num_freq]

        # Broadcast the per-utterance embedding to every frame. expand() is a
        # zero-copy view; repeat() would allocate a [B, T, emb_dim] tensor that
        # torch.cat immediately copies again. Values and gradients are identical.
        dvec = dvec.unsqueeze(1).expand(-1, x.size(1), -1)  # [B, T, emb_dim]
        x = torch.cat((x, dvec), dim=2)             # [B, T, 8*num_freq + emb_dim]

        x, _ = self.lstm(x)  # [B, T, lstm_dim * (2 if bidirection else 1)]
        x = F.relu(x)
        x = F.relu(self.fc1(x))                     # [B, T, fc1_dim]
        return torch.sigmoid(self.fc2(x))           # [B, T, fc2_dim], fc2_dim == num_freq

In [3]:
# Full-model report. The two input shapes follow forward():
# a [B, T, num_freq] spectrogram and a [B, emb_dim] d-vector, with B = 1.
# Left as the cell's final expression so Jupyter renders the returned summary.
summary(
    VoiceFilter(hp),
    input_size=[(1, 301, 601), (1, 256)],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape              Output Shape              Param #                   Mult-Adds
VoiceFilter                              --                        --                        --                        --
├─Sequential (conv)                      --                        [1, 8, 301, 601]          --                        --
│    └─ZeroPad2d (0)                     --                        [1, 1, 301, 607]          --                        --
│    └─Conv2d (1)                        [1, 64, 1, 7]             [1, 64, 301, 601]         512                       92,621,312
│    └─BatchNorm2d (2)                   [64]                      [1, 64, 301, 601]         128                       128
│    └─ReLU (3)                          --                        [1, 64, 301, 601]         --                        --
│    └─ZeroPad2d (4)                     --                        [1, 64, 307, 601]         --                        --
│    └─C

# Profiling CNN part

In [4]:
class VoiceFilter(nn.Module):
    """CNN-only slice of VoiceFilter, used to profile the convolutional front end.

    Builds the same ``conv`` stack as the full model, but has no LSTM or linear
    layers. ``forward`` returns the features the full model would feed into
    its LSTM.

    Args:
        hp: hyper-parameter object. Reads ``hp.audio.n_fft``,
            ``hp.audio.num_freq`` and ``hp.model.fc2_dim`` (shape check only).
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        # Kept identical to the full model so both are built from the same config.
        assert hp.audio.n_fft // 2 + 1 == hp.audio.num_freq == hp.model.fc2_dim, \
            "stft-related dimension mismatch"

        # Each ZeroPad2d(left, right, top, bottom) pads by
        # dilation * (kernel - 1) / 2 per side, so every conv keeps the
        # [T, num_freq] size. The module order matches the full model, so the
        # state_dict keys line up with it.
        self.conv = nn.Sequential(
            # cnn1: 1x7 kernel, along frequency only
            nn.ZeroPad2d((3, 3, 0, 0)),
            nn.Conv2d(1, 64, kernel_size=(1, 7), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn2: 7x1 kernel, along time only
            nn.ZeroPad2d((0, 0, 3, 3)),
            nn.Conv2d(64, 64, kernel_size=(7, 1), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn3
            nn.ZeroPad2d(2),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(1, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn4: effective receptive field (9, 5)
            nn.ZeroPad2d((2, 2, 4, 4)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(2, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn5: effective receptive field (17, 5)
            nn.ZeroPad2d((2, 2, 8, 8)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(4, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn6: effective receptive field (33, 5)
            nn.ZeroPad2d((2, 2, 16, 16)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(8, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn7: effective receptive field (65, 5)
            nn.ZeroPad2d((2, 2, 32, 32)),
            nn.Conv2d(64, 64, kernel_size=(5, 5), dilation=(16, 1)),
            nn.BatchNorm2d(64), nn.ReLU(),

            # cnn8: 1x1 projection down to 8 channels
            nn.Conv2d(64, 8, kernel_size=(1, 1), dilation=(1, 1)),
            nn.BatchNorm2d(8), nn.ReLU(),
        )

    def forward(self, x, dvec):
        """Run the CNN and append the speaker embedding to every frame.

        Args:
            x: spectrogram, [B, T, num_freq].
            dvec: speaker embedding, [B, emb_dim].

        Returns:
            Tensor of shape [B, T, 8*num_freq + emb_dim].
        """
        x = self.conv(x.unsqueeze(1))               # [B, 8, T, num_freq]
        x = x.transpose(1, 2)                       # [B, T, 8, num_freq]
        x = x.reshape(x.size(0), x.size(1), -1)     # [B, T, 8*num_freq]

        # expand() is a zero-copy view; repeat() would allocate a
        # [B, T, emb_dim] tensor that torch.cat immediately copies again.
        dvec = dvec.unsqueeze(1).expand(-1, x.size(1), -1)  # [B, T, emb_dim]
        return torch.cat((x, dvec), dim=2)          # [B, T, 8*num_freq + emb_dim]

In [5]:
# Report for the CNN-only slice. The inputs are the same as for the full model:
# a [B, T, num_freq] spectrogram and a [B, emb_dim] d-vector, with B = 1.
# Left as the cell's final expression so Jupyter renders the returned summary.
summary(
    VoiceFilter(hp),
    input_size=[(1, 301, 601), (1, 256)],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape              Output Shape              Param #                   Mult-Adds
VoiceFilter                              --                        --                        --                        --
├─Sequential (conv)                      --                        [1, 8, 301, 601]          --                        --
│    └─ZeroPad2d (0)                     --                        [1, 1, 301, 607]          --                        --
│    └─Conv2d (1)                        [1, 64, 1, 7]             [1, 64, 301, 601]         512                       92,621,312
│    └─BatchNorm2d (2)                   [64]                      [1, 64, 301, 601]         128                       128
│    └─ReLU (3)                          --                        [1, 64, 301, 601]         --                        --
│    └─ZeroPad2d (4)                     --                        [1, 64, 307, 601]         --                        --
│    └─C

# Profiling LSTM part

In [6]:
class VoiceFilter(nn.Module):
    """LSTM/FC-only slice of VoiceFilter, used to profile the recurrent back end.

    ``forward`` expects the CNN output ([B, 8, T, num_freq]) rather than a raw
    spectrogram, so the convolutional stack is excluded from the profile.

    Args:
        hp: hyper-parameter object. Reads ``hp.audio.n_fft``,
            ``hp.audio.num_freq``, ``hp.embedder.emb_dim`` and
            ``hp.model.{lstm_dim, bidirection, fc1_dim, fc2_dim}``.
    """

    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        # The predicted mask (fc2 output) must have exactly one value per STFT bin.
        assert hp.audio.n_fft // 2 + 1 == hp.audio.num_freq == hp.model.fc2_dim, \
            "stft-related dimension mismatch"

        self.lstm = nn.LSTM(
            8 * hp.audio.num_freq + hp.embedder.emb_dim,
            hp.model.lstm_dim,
            batch_first=True,
            bidirectional=hp.model.bidirection)

        # A bidirectional LSTM concatenates both directions' hidden states.
        lstm_dim = 2 * hp.model.lstm_dim if hp.model.bidirection else hp.model.lstm_dim
        self.fc1 = nn.Linear(lstm_dim, hp.model.fc1_dim)
        self.fc2 = nn.Linear(hp.model.fc1_dim, hp.model.fc2_dim)

    def forward(self, x, dvec):
        """Predict a mask from CNN features and a speaker embedding.

        Args:
            x: CNN output, [B, 8, T, num_freq].
            dvec: speaker embedding, [B, emb_dim].

        Returns:
            Tensor of shape [B, T, fc2_dim] with values in (0, 1).
        """
        x = x.transpose(1, 2)                       # [B, T, 8, num_freq]
        x = x.reshape(x.size(0), x.size(1), -1)     # [B, T, 8*num_freq]

        # expand() is a zero-copy view; repeat() would allocate a
        # [B, T, emb_dim] tensor that torch.cat immediately copies again.
        dvec = dvec.unsqueeze(1).expand(-1, x.size(1), -1)  # [B, T, emb_dim]
        x = torch.cat((x, dvec), dim=2)             # [B, T, 8*num_freq + emb_dim]

        x, _ = self.lstm(x)  # [B, T, lstm_dim * (2 if bidirection else 1)]
        x = F.relu(x)
        x = F.relu(self.fc1(x))                     # [B, T, fc1_dim]
        return torch.sigmoid(self.fc2(x))           # [B, T, fc2_dim], fc2_dim == num_freq

In [7]:
# Report for the LSTM/FC slice. The first input mimics the CNN output
# ([1, 8, 301, 601] in the CNN summary above); the second is the [B, emb_dim] d-vector.
# Left as the cell's final expression so Jupyter renders the returned summary.
summary(
    VoiceFilter(hp),
    input_size=[(1, 8, 301, 601), (1, 256)],
    col_names=["kernel_size", "output_size", "num_params", "mult_adds"],
    row_settings=["var_names"],
)

Layer (type (var_name))                  Kernel Shape              Output Shape              Param #                   Mult-Adds
VoiceFilter                              --                        --                        --                        --
├─LSTM (lstm)                            --                        [1, 301, 400]             8,745,600                 2,632,425,600
├─Linear (fc1)                           [400, 600]                [1, 301, 600]             240,600                   240,600
├─Linear (fc2)                           [600, 601]                [1, 301, 601]             361,201                   361,201
Total params: 9,347,401
Trainable params: 9,347,401
Non-trainable params: 0
Total mult-adds (G): 2.63
Input size (MB): 5.79
Forward/backward pass size (MB): 3.86
Params size (MB): 37.39
Estimated Total Size (MB): 47.03