# Sound classification model with EfficientNet-PyTorch for the DreamAI Healthcare Hackathon
by Hyungfon Ryu | NVAITC


## Code included

config.json
- CSV generator
- dataloader (GPU version, CPU version)
- inference script
- train script



## Best configuration for urban sound
 - 2 sec clip
 - 16 kHz
 - nfft: 4096
 - win_size: 1024
 - hop_size: 256
 - nmel: 128

In [None]:
%%file config10.json
{
    "train_config": {
        "fp16_run": true,
        "output_directory": "checkpoints",
        "epochs": 100000,
        "learning_rate": 1e-4,
        "sigma": 1.0,
        "iters_per_checkpoint": 2000,
        "batch_size": 12,
        "seed": 1234,
        "checkpoint_path": "",
        "with_tensorboard": false
    },
    "data_config": {
        "training_files": "/home/hryu/cls.txt",
        "segment_length": 16000,
        "sampling_rate": 16000,
        "filter_length": 4096,
        "hop_length": 512,
        "win_length": 1024,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0
    },
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321"
    },

    "waveglow_config": {
        "n_mel_channels": 80,
        "n_flows": 12,
        "n_group": 8,
        "n_early_every": 4,
        "n_early_size": 2,
        "WN_config": {
            "n_layers": 8,
            "n_channels": 256,
            "kernel_size": 3
        }
    }
}

In [None]:
%%file get_csv_cls.py
# get file lists of audioDataset
import os
import glob
import csv

def datasetsWithCSV(filename='out.csv', info='cls_dir.txt', target_dir='./', output_dir='./'):
    """Index a class-per-directory audio dataset into a CSV file list.

    Every sub-directory of ``target_dir`` is treated as one class. Classes are
    numbered in sorted-name order and two files are written to ``output_dir``:

    * ``info``: one ``index,class_name`` line per class.
    * ``filename``: one CSV row ``[parent_dir/wav_name, class_index]`` for each
      ``*.wav`` file found recursively below the class directory, where
      ``parent_dir`` is the name of the directory that directly holds the file.

    Args:
        filename: name of the CSV file list to create inside ``output_dir``.
        info: name of the class-mapping text file to create inside ``output_dir``.
        target_dir: dataset root whose sub-directories are the classes.
        output_dir: directory that receives both output files.
    """
    import os

    # Only directories can be classes. Plain files are skipped so that stray
    # files (including the outputs of a previous run when output_dir is the
    # same as target_dir) do not consume class indices.
    class_names = sorted(
        entry for entry in os.listdir(target_dir)
        if os.path.isdir(os.path.join(target_dir, entry))
    )
    print ("file_list: {}".format(class_names))

    # save mapping of class
    info_filename = os.path.join(output_dir, info)
    with open(info_filename, 'w', encoding='utf-8') as info_file:
        for class_index, class_name in enumerate(class_names):
            print(class_index, class_name)
            info_file.write('{},{}\n'.format(class_index, class_name))

    csv_filename = os.path.join(output_dir, filename)
    # newline='' lets the csv module control line endings itself.
    with open(csv_filename, 'w', encoding='utf-8', newline='') as csv_file:
        writer = csv.writer(csv_file)
        for class_index, class_name in enumerate(class_names):
            pattern = os.path.join(target_dir, class_name, '**', '*.wav')
            # Sorted so the row order does not depend on filesystem order.
            wav_paths = sorted(glob.glob(pattern, recursive=True))

            for wav_path in wav_paths:
                wav_name = os.path.basename(wav_path)
                parent_dir = os.path.basename(os.path.dirname(wav_path))
                writer.writerow([os.path.join(parent_dir, wav_name), class_index])

            print(class_index, " number of files :", len(wav_paths))


def main(filename, info, target_dir, output_dir ):
    """Entry point used by the command line: forward all four options to datasetsWithCSV."""
    datasetsWithCSV(filename, info, target_dir, output_dir)



if __name__ =="__main__":
    import argparse

    # Every option is a plain string and maps one-to-one onto a main() argument.
    cli = argparse.ArgumentParser()
    for short_flag, long_flag, default_value in (
        ('-f', '--filename', 'cls.csv'),
        ('-i', '--info', 'cls_dir.txt'),
        ('-d', '--target_dir', './'),
        ('-w', '--output_dir', './'),
    ):
        cli.add_argument(short_flag, long_flag, type=str, default=default_value)
    options = cli.parse_args()

    main(options.filename, options.info, options.target_dir, options.output_dir)


In [None]:
!python /home/hryu/mel2samp.py -f /home/hryu/new.txt -c /home/hryu/config2.json -o /Protein/mel

In [None]:
# Load one saved mel-spectrogram tensor, move it to the CPU and draw it as a
# pseudo-colour mesh.
# NOTE(review): neither `torch` nor `plt` is imported in this cell — presumably
# they come from an earlier notebook cell; confirm before running standalone.
mel = torch.load('/Protein/mel2/choi01_00000_cut_file_s00001_e00008.wav.pt').cpu()
plt.pcolormesh( mel  )

In [None]:
%%file mel2samp.py
## mel2sample 

import os
import random
import sys
import argparse
import json

import time
import torch
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
from scipy.signal import get_window
from scipy.io.wavfile import read

from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0

def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """Compute the sum-of-squares envelope of a window overlapped across frames.

    PARAMS
    ------
    window: window specification passed to scipy.signal.get_window
    n_frames: number of overlapping frames to accumulate
    hop_length: offset in samples between consecutive frames
    win_length: window length; falls back to n_fft when None
    n_fft: frame length the squared window is centre-padded to
    dtype: dtype of the returned array
    norm: normalisation mode forwarded to librosa.util.normalize

    RETURNS
    -------
    x: 1-D array of length n_fft + hop_length * (n_frames - 1)
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword: it is keyword-only in librosa >= 0.10 and
    # the keyword form is also accepted by older releases.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Estimate a signal from spectrogram magnitudes by iterative phase
    re-estimation, starting from a random phase.

    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of transform/inverse rounds after the random start

    RETURNS
    -------
    signal: output of stft_fn.inverse with dimension 1 squeezed
    """

    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    # Keep the phase on the same device as the magnitudes so that
    # stft_fn.inverse does not mix CPU and GPU tensors.
    angles = torch.from_numpy(angles.astype(np.float32)).to(magnitudes.device)
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for _ in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress a tensor after flooring its values at clip_val.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor
    clip_val: lower bound applied before the factor and the logarithm
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)


def dynamic_range_decompression(x, C=1):
    """
    Undo the log compression: exponentiate, then divide by the factor.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Short-time Fourier transform written as a 1-D convolution against a fixed
    (optionally windowed) Fourier basis, with a transposed-convolution inverse.

    PARAMS
    ------
    filter_length: FFT size, also the length of every analysis frame
    hop_length: stride in samples between consecutive frames
    win_length: window length; must not exceed filter_length
    window: window name for scipy.signal.get_window, or None to leave the
        bases unwindowed
    """
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant half of the spectrum and stack real over
        # imaginary parts so one real-valued convolution yields both.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` by keyword: keyword-only in librosa >= 0.10, and the
            # keyword form is accepted by older releases as well.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return (magnitude, phase) for a (batch, samples) tensor.

        Both outputs have filter_length/2 + 1 channels along dimension 1.
        The phase is computed from detached values and carries no gradient.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        half_filter = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half_filter, half_filter, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        # Registered buffers never require grad, so no Variable wrapper is needed.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Rebuild a signal of shape (batch, 1, samples) from magnitude and phase."""
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            # Indices and envelope follow the input's device (not merely the
            # current CUDA device) so multi-GPU and CPU use both work.
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]).to(magnitude.device)
            window_sum = torch.from_numpy(window_sum).to(magnitude.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the reflect padding that transform() added on both sides.
        half_filter = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half_filter:]
        inverse_transform = inverse_transform[:, :, :-half_filter]

        return inverse_transform

    def forward(self, input_data):
        """Run transform followed by inverse; caches magnitude and phase on self."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(torch.nn.Module):
    """STFT front end that projects magnitudes onto a mel filter bank and
    log-compresses the result.

    PARAMS
    ------
    filter_length, hop_length, win_length: forwarded to STFT
    n_mel_channels: number of mel bands
    sampling_rate: sample rate used to build the mel filter bank
    mel_fmin, mel_fmax: frequency range of the mel filter bank
    """
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: librosa >= 0.10 made these keyword-only, and the
        # same keyword names are accepted by older releases.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
            fmin=mel_fmin, fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        """Apply dynamic_range_compression with its default settings."""
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Apply dynamic_range_decompression with its default settings."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        # .data detaches the magnitudes, so no gradient flows through the mel step.
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output

def files_to_list(filename, split="|"):
    """Read a delimiter-separated list file into a list of field lists.

    Each non-blank line is stripped and split on ``split``. Blank lines are
    skipped so that they do not produce ``['']`` entries with no class field.
    """
    with open(filename, encoding='utf-8') as f:
        filepaths_and_cls = [line.strip().split(split) for line in f if line.strip()]
    return filepaths_and_cls


def load_wav_to_torch(full_path, sr=16000):
    """
    Loads wavdata into torch array

    PARAMS
    ------
    full_path: path of the audio file, handed to librosa.load
    sr: sample rate requested from librosa.load (defaults to the previously
        hard-coded 16000)

    RETURNS
    -------
    (audio, sampling_rate): mono audio as a float tensor and the sample rate
        reported by librosa.load
    """
    import librosa  # imported lazily, as in the original

    data, sampling_rate = librosa.load(full_path, sr=sr, res_type='polyphase', mono=True)

    return torch.from_numpy(data).float(), sampling_rate

class Mel2Samp(torch.utils.data.Dataset):
    """
    Dataset that yields (mel-spectrogram, class index) pairs from a list file.

    PARAMS
    ------
    training_files: list file read with files_to_list; field 0 of each entry
        is the audio path and field 1 the class index
    segment_length: number of samples every item is cropped or zero-padded to
    filter_length, hop_length, win_length, mel_fmin, mel_fmax: forwarded to
        TacotronSTFT
    sampling_rate: sample rate every loaded file must report
    n_mel_channels: number of mel bands; defaults to 80, the value that was
        previously used implicitly
    """
    def __init__(self, training_files, segment_length, filter_length,
                 hop_length, win_length, sampling_rate, mel_fmin, mel_fmax,
                 n_mel_channels=80):
        self.audiopaths_and_cls = files_to_list(training_files)
        # Fixed seed so the shuffled order is the same on every run.
        # Note that this reseeds the global `random` module.
        random.seed(1234)
        random.shuffle(self.audiopaths_and_cls)
        self.stft = TacotronSTFT(filter_length=filter_length,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 n_mel_channels=n_mel_channels,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin, mel_fmax=mel_fmax)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate

    def get_mel(self, audio):
        """Return the mel-spectrogram of a 1-D audio tensor, batch dim removed."""
        # NOTE(review): audio comes from load_wav_to_torch, which goes through
        # librosa.load and presumably already yields floats rather than int16
        # samples — confirm that this extra 1/32768 scaling is intended.
        audio_norm = audio / MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def get_cls(self, cls):
        """Convert a class label (anything int() accepts) to a 0-d tensor."""
        num_cls = torch.tensor(int(cls))
        return num_cls

    def get_mel_cls_pair(self, audiopath_and_cls):
        """Return (mel, cls) for one list entry, using the whole file uncropped."""
        audiopath, cls = audiopath_and_cls[0], audiopath_and_cls[1]
        audio, sr = load_wav_to_torch(audiopath)
        mel = self.get_mel(audio)
        cls = self.get_cls(cls)
        return (mel, cls)

    def __getitem__(self, index):
        """Return (mel, cls) for entry `index`, cropped or padded to segment_length."""
        # Read audio
        filepath_cls = self.audiopaths_and_cls[index]
        filepath = filepath_cls[0]
        cls = filepath_cls[1]

        audio, sampling_rate = load_wav_to_torch(filepath)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))

        # Take segment: a random crop when the clip is long enough,
        # otherwise zero-pad on the right up to segment_length.
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start:audio_start+self.segment_length]
        else:
            audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data

        mel = self.get_mel(audio)
        cls = self.get_cls(cls)

        return (mel, cls)

    def __len__(self):
        """Number of entries read from the list file."""
        return len(self.audiopaths_and_cls)
 

# ===================================================================
# Takes directory of clean audio and makes directory of spectrograms
# Useful for making test sets
# ===================================================================
if __name__ == "__main__":
    # Command line: list file, JSON config and destination directory.
    cli = argparse.ArgumentParser()
    cli.add_argument('-f', "--filelist_path", required=True)
    cli.add_argument('-c', '--config', type=str,
                     help='JSON file for configuration')
    cli.add_argument('-o', '--output_dir', type=str,
                     help='Output directory')
    opts = cli.parse_args()

    # Only the "data_config" section of the JSON file is used here.
    with open(opts.config) as config_file:
        data_config = json.loads(config_file.read())["data_config"]
    mel2samp = Mel2Samp(**data_config)

    entries = files_to_list(opts.filelist_path)

    # Create the destination directory on first use.
    if not os.path.isdir(opts.output_dir):
        os.makedirs(opts.output_dir)
        os.chmod(opts.output_dir, 0o775)

    # One .pt file per listed audio file, with per-stage timings on one line.
    for i, entry in enumerate(entries):
        filepath, cls = entry[0], entry[1]

        t_start = time.time()
        audio, sr = load_wav_to_torch(filepath)
        t_loaded = time.time()
        print("{:d}  cls:{} load {:4.2f}sec".format(i, cls, t_loaded - t_start), end='')

        melspectrogram = mel2samp.get_mel(audio)
        t_mel = time.time()
        print("  mel {:4.2f}sec".format(t_mel - t_loaded), end='')

        new_filepath = opts.output_dir + '/' + os.path.basename(filepath) + '.pt'
        torch.save(melspectrogram, new_filepath)
        t_saved = time.time()
        print(" save {:4.2f}sec : {:4.2f}sec  ".format(t_saved - t_mel, t_saved - t_start))


In [None]:
%%file mel2samp_GPU_infer.py
## mel2sample 

import os
import random
import sys
import argparse
import json

import time
import torch
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable

import numpy as np
from scipy.signal import get_window
from scipy.io.wavfile import read

from librosa.util import pad_center, tiny
import librosa.util as librosa_util
from librosa.filters import mel as librosa_mel_fn

MAX_WAV_VALUE = 32768.0

def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """Compute the sum-of-squares envelope of a window overlapped across frames.

    PARAMS
    ------
    window: window specification passed to scipy.signal.get_window
    n_frames: number of overlapping frames to accumulate
    hop_length: offset in samples between consecutive frames
    win_length: window length; falls back to n_fft when None
    n_fft: frame length the squared window is centre-padded to
    dtype: dtype of the returned array
    norm: normalisation mode forwarded to librosa.util.normalize

    RETURNS
    -------
    x: 1-D array of length n_fft + hop_length * (n_frames - 1)
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    # `size` is passed by keyword: it is keyword-only in librosa >= 0.10 and
    # the keyword form is also accepted by older releases.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Estimate a signal from spectrogram magnitudes by iterative phase
    re-estimation, starting from a random phase.

    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of transform/inverse rounds after the random start

    RETURNS
    -------
    signal: output of stft_fn.inverse with dimension 1 squeezed
    """

    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    # Keep the phase on the same device as the magnitudes so that
    # stft_fn.inverse does not mix CPU and GPU tensors.
    angles = torch.from_numpy(angles.astype(np.float32)).to(magnitudes.device)
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for _ in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress a tensor after flooring its values at clip_val.

    PARAMS
    ------
    x: tensor to compress
    C: compression factor
    clip_val: lower bound applied before the factor and the logarithm
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)


def dynamic_range_decompression(x, C=1):
    """
    Undo the log compression: exponentiate, then divide by the factor.

    PARAMS
    ------
    x: log-compressed tensor
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C

class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Short-time Fourier transform written as a 1-D convolution against a fixed
    (optionally windowed) Fourier basis, with a transposed-convolution inverse.
    The bases are created on the GPU when CUDA is available, else on the CPU.

    PARAMS
    ------
    filter_length: FFT size, also the length of every analysis frame
    hop_length: stride in samples between consecutive frames
    win_length: window length; must not exceed filter_length
    window: window name for scipy.signal.get_window, or None to leave the
        bases unwindowed
    """
    def __init__(self, filter_length=800, hop_length=200, win_length=800,
                 window='hann'):
        super(STFT, self).__init__()
        # Prefer the GPU but fall back to the CPU, so the module can still be
        # built on hosts without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant half of the spectrum and stack real over
        # imaginary parts so one real-valued convolution yields both.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :]).to(self.device)
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]).to(self.device)

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            # `size` by keyword: keyword-only in librosa >= 0.10, and the
            # keyword form is accepted by older releases as well.
            fft_window = pad_center(fft_window, size=filter_length)
            fft_window = torch.from_numpy(fft_window).float().to(self.device)

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        """Return (magnitude, phase) for a (batch, samples) tensor.

        The input is moved to the device that currently holds the basis
        buffers. Both outputs have filter_length/2 + 1 channels along
        dimension 1; the phase carries no gradient.
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # Follow the buffer's device rather than a fixed one, so the module
        # keeps working after it has been moved with .to()/.cpu()/.cuda().
        target_device = self.forward_basis.device

        # similar to librosa, reflect-pad the input
        half_filter = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples).to(target_device)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half_filter, half_filter, 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        # Registered buffers never require grad, so no Variable wrapper is needed.
        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Rebuild a signal of shape (batch, 1, samples) from magnitude and phase."""
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            # Indices and envelope follow the input's device (not merely the
            # current CUDA device) so multi-GPU and CPU use both work.
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]).to(magnitude.device)
            window_sum = torch.from_numpy(window_sum).to(magnitude.device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the reflect padding that transform() added on both sides.
        half_filter = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half_filter:]
        inverse_transform = inverse_transform[:, :, :-half_filter]

        return inverse_transform

    def forward(self, input_data):
        """Run transform followed by inverse; caches magnitude and phase on self."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction

class TacotronSTFT(torch.nn.Module):
    """STFT front end that projects magnitudes onto a mel filter bank and
    log-compresses the result.

    PARAMS
    ------
    filter_length, hop_length, win_length: forwarded to STFT
    n_mel_channels: number of mel bands
    sampling_rate: sample rate used to build the mel filter bank
    mel_fmin, mel_fmax: frequency range of the mel filter bank
    """
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: librosa >= 0.10 made these keyword-only, and the
        # same keyword names are accepted by older releases.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate, n_fft=filter_length, n_mels=n_mel_channels,
            fmin=mel_fmin, fmax=mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        """Apply dynamic_range_compression with its default settings."""
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Apply dynamic_range_decompression with its default settings."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]
        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        # .data detaches the magnitudes, so no gradient flows through the mel step.
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output

def files_to_list(filename, split="|"):
    """Read a delimiter-separated list file into a list of field lists.

    Each non-blank line is stripped and split on ``split``. Blank lines are
    skipped so that they do not produce ``['']`` entries with no class field.
    """
    with open(filename, encoding='utf-8') as f:
        filepaths_and_cls = [line.strip().split(split) for line in f if line.strip()]
    return filepaths_and_cls


def load_wav_to_torch(full_path, sr=16000):
    """
    Loads wavdata into torch array

    PARAMS
    ------
    full_path: path of the audio file, handed to librosa.load
    sr: sample rate requested from librosa.load (defaults to the previously
        hard-coded 16000)

    RETURNS
    -------
    (audio, sampling_rate): mono audio as a float tensor and the sample rate
        reported by librosa.load
    """
    import librosa  # imported lazily, as in the original

    data, sampling_rate = librosa.load(full_path, sr=sr, res_type='polyphase', mono=True)

    return torch.from_numpy(data).float(), sampling_rate

class Mel2Samp(torch.utils.data.Dataset):
    """
    Inference-time dataset that turns audio into fixed-length mel-spectrograms.

    PARAMS
    ------
    segment_length: number of samples every item is cropped or zero-padded to
    filter_length, n_mel_channels, hop_length, win_length, mel_fmin, mel_fmax:
        forwarded to TacotronSTFT
    sampling_rate: sample rate every loaded file must report
    files_list: either a single audio path (legacy usage: every index returns
        a crop of that one file) or a list/tuple of paths indexed by item
    """
    def __init__(self,
                 segment_length,
                 filter_length,
                 n_mel_channels,
                 hop_length, win_length,
                 sampling_rate,
                 mel_fmin,
                 mel_fmax,
                 files_list):
        self.files_list = files_list
        # Prefer the GPU but fall back to the CPU on hosts without CUDA.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.stft = TacotronSTFT(filter_length=filter_length,
                                 n_mel_channels=n_mel_channels,
                                 hop_length=hop_length,
                                 win_length=win_length,
                                 sampling_rate=sampling_rate,
                                 mel_fmin=mel_fmin,
                                 mel_fmax=mel_fmax).to(self.device)
        self.segment_length = segment_length
        self.sampling_rate = sampling_rate

    def get_mel(self, audio):
        """Return the mel-spectrogram of a 1-D audio tensor, batch dim removed."""
        # NOTE(review): audio comes from load_wav_to_torch, which goes through
        # librosa.load and presumably already yields floats rather than int16
        # samples — confirm that this extra 1/32768 scaling is intended.
        audio_norm = audio / MAX_WAV_VALUE
        audio_norm = audio_norm.unsqueeze(0)
        melspec = self.stft.mel_spectrogram(audio_norm)
        melspec = torch.squeeze(melspec, 0)
        return melspec

    def __getitem__(self, index):
        """Return the mel-spectrogram of one segment_length crop."""
        # A list/tuple is indexed per item; anything else (the legacy single
        # path) is passed to the loader unchanged, exactly as before.
        if isinstance(self.files_list, (list, tuple)):
            filepath = self.files_list[index]
        else:
            filepath = self.files_list

        audio, sampling_rate = load_wav_to_torch(filepath)
        if sampling_rate != self.sampling_rate:
            raise ValueError("{} SR doesn't match target {} SR".format(
                sampling_rate, self.sampling_rate))

        # Take segment: a random crop when the clip is long enough,
        # otherwise zero-pad on the right up to segment_length.
        if audio.size(0) >= self.segment_length:
            max_audio_start = audio.size(0) - self.segment_length
            audio_start = random.randint(0, max_audio_start)
            audio = audio[audio_start:audio_start+self.segment_length]
        else:
            audio = torch.nn.functional.pad(audio, (0, self.segment_length - audio.size(0)), 'constant').data

        # Tensor.to() is not in-place; the result has to be kept.
        audio = audio.to(self.device)
        mel = self.get_mel(audio)

        return (mel)

    def __len__(self):
        """Length of files_list.

        NOTE(review): for the legacy single-path string this is the number of
        characters in the path, not a file count — kept unchanged for
        compatibility; confirm what callers expect.
        """
        return len(self.files_list)
 

# ===================================================================
# Takes directory of clean audio and makes directory of spectrograms
# Useful for making test sets
# ===================================================================
if __name__ == "__main__":
    # Command line: list file, JSON config and destination directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('-f', "--filelist_path", required=True)
    parser.add_argument('-c', '--config', type=str,
                        help='JSON file for configuration')
    parser.add_argument('-o', '--output_dir', type=str,
                        help='Output directory')
    args = parser.parse_args()

    # Only the "data_config" section of the JSON file is used; its keys must
    # match the Mel2Samp constructor parameters.
    with open(args.config) as f:
        data = f.read()
    data_config = json.loads(data)["data_config"]
    print(data_config)
    mel2samp = Mel2Samp(**data_config)

    filepaths_cls = files_to_list(args.filelist_path)

    # Make directory if it doesn't exist
    if not os.path.isdir(args.output_dir):
        os.makedirs(args.output_dir)
        os.chmod(args.output_dir, 0o775)

    for i, filepath_cls in enumerate(filepaths_cls):
        # Only the path column is needed; the class column, if present, is
        # not used because the output is grouped by the source directory name.
        filepath = filepath_cls[0]

        # Mirror the immediate parent directory of the source file under
        # output_dir so that spectrograms stay grouped the same way.
        filename = os.path.basename(filepath)
        cls_dir = os.path.basename(os.path.dirname(filepath))
        dir_cls = os.path.join(args.output_dir, cls_dir)
        if not os.path.isdir(dir_cls):
            os.makedirs(dir_cls)
            os.chmod(dir_cls, 0o775)

        audio, sr = load_wav_to_torch(filepath)
        # Tensor.to() is not in-place; keep the moved tensor. The dataset's
        # own device is used so that both sides agree.
        audio = audio.to(mel2samp.device)
        melspectrogram = mel2samp.get_mel(audio)

        # Saved on the CPU so the files can be loaded on hosts without a GPU.
        filename_body, ext = os.path.splitext(filename)
        new_filepath = os.path.join(dir_cls, filename_body + '.pt')
        torch.save(melspectrogram.cpu(), new_filepath)

        # Lightweight progress indicator.
        if i % 1000 == 0:
            print(i)


In [None]:
%%file train.py
import time
import torch
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
plt.ion()
import json
from PIL import Image
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler as lrs
import time
import os
import copy
import random


from mel2samp import Mel2Samp


def load_checkpoint(checkpoint_path, model, optimizer):
    """Restore model and optimizer state from a checkpoint file.

    The checkpoint is loaded onto the CPU and must contain the keys
    'state_dict', 'optimizer', 'learning_rate' and 'epoch'.

    Returns the same model and optimizer objects (updated in place) together
    with the stored learning rate and epoch.
    """
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))

    saved = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(saved['state_dict'])
    optimizer.load_state_dict(saved['optimizer'])
    restored_lr, restored_epoch = saved['learning_rate'], saved['epoch']

    print("Loaded checkpoint '{}' from epoch {}" .format(
        checkpoint_path, restored_epoch))
    return model, optimizer, restored_lr, restored_epoch

def save_checkpoint(model, optimizer, learning_rate, epoch, filepath):
    """Write model/optimizer state plus learning rate and epoch to `filepath`.

    The file holds a dict with the keys 'epoch', 'state_dict', 'optimizer'
    and 'learning_rate'.
    """
    print("Saving model and optimizer state at epoch {} to {}".format(
          epoch, filepath))
    snapshot = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'learning_rate': learning_rate,
    }
    torch.save(snapshot, filepath)


from efficientnet_pytorch import EfficientNet

class model_efficientNet(nn.Module):
    """EfficientNet feature extractor wrapped for single-channel 2-D input.

    Pipeline (see ``forward``):
      1. ``bw2col``: BatchNorm + two 1x1 convolutions mapping 1 channel to 3.
      2. ``enet_model.extract_features``: pretrained EfficientNet features.
      3. Global max over the last two (spatial) axes.
      4. ``final``: two-layer classifier head producing ``num_classes`` scores.
      5. ``soft``: softmax over the class axis.

    Submodule attribute names are part of the checkpoint format (they are
    the prefixes of the ``state_dict`` keys), so they must not be renamed.
    """

    def __init__(self, model_name, num_classes):
        """Build the network.

        Args:
            model_name: Name passed to ``EfficientNet.from_pretrained`` and
                ``get_image_size`` (e.g. ``'efficientnet-b1'``).
            num_classes: Width of the final linear layer / number of outputs.
        """
        super().__init__()

        # 1x1 convolutions turn the single input channel into the three
        # channels the EfficientNet stem consumes.
        self.bw2col = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 10, 1, padding=0), nn.ReLU(),
            nn.Conv2d(10, 3, 1, padding=0), nn.ReLU())

        self.enet_model = EfficientNet.from_pretrained(model_name)
        print('input_size : ', self.enet_model.get_image_size(model_name))

        # The head is hard-wired to 1280 input features, so the chosen
        # backbone's extract_features must emit 1280 channels.
        self.final = nn.Sequential(
            nn.Linear(1280, 512), nn.ReLU(), nn.BatchNorm1d(512),
            nn.Linear(512, num_classes))

        # `dim` is given explicitly: the implicit-dim form is deprecated and
        # warns on every call. The head output is 2-D (batch, num_classes),
        # so the class axis is dim=1.
        self.soft = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-class probabilities for a batch.

        Args:
            x: 4-D tensor with exactly one channel on axis 1 (required by
                ``BatchNorm2d(1)`` / ``Conv2d(1, ...)``).

        Returns:
            Tensor of shape ``(batch, num_classes)`` whose rows sum to 1.
        """
        x = self.bw2col(x)
        x = self.enet_model.extract_features(x)
        # Global max pooling: reduce the last axis, then the (new) last axis,
        # leaving (batch, channels).
        x = x.max(dim=-1)[0].max(dim=-1)[0]
        x = self.final(x)
        x = self.soft(x)
        return x
    

def _accuracy_percent(correct, total):
    """Return ``100 * correct / total``, or 0.0 when ``total`` is zero."""
    return 100 * correct / total if total else 0.0


def _evaluate(model, loader, device):
    """Count top-1 hits of ``model`` over ``loader``; returns (correct, total).

    Switches the model to eval mode and disables gradients. The caller is
    responsible for switching back to train mode afterwards.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for mel, cls in loader:
            mel = mel.to(device)
            cls = cls.to(device)
            # Insert the single channel axis expected by the model.
            pred = model(mel.unsqueeze(1))
            _, predicted = torch.max(pred, 1)
            total += cls.size(0)
            correct += (predicted == cls).sum().item()
    return correct, total


def train(config='/home/hryu/config10.json', ch_dir='/Protein/gc2020/en_0730_6pm'):
    """Train ``model_efficientNet`` on a ``Mel2Samp`` dataset and checkpoint it.

    Reads the JSON file at ``config``, builds the dataset from its
    ``"data_config"`` section, splits it 80/20 into train/validation sets and
    runs the training loop, printing per-iteration and per-epoch statistics.
    A checkpoint ``model_{epoch:03d}.pt`` is written to ``ch_dir`` after every
    epoch.

    Only ``"data_config"`` is read from the config file; batch size, learning
    rate, epoch count, class count and backbone name are fixed in this
    function.

    Side effect: sets the module-level global ``data_config``.

    Args:
        config: Path to the JSON configuration file.
        ch_dir: Directory that receives the checkpoints (created if missing).

    Raises:
        ValueError: If the training split holds fewer samples than one batch,
            so the training loader would yield no batches.
    """
    global data_config

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    ## 4096(nfft), 1024(win), 512(hop)
    with open(config) as f:
        config = json.load(f)
    data_config = config["data_config"]

    print(data_config)
    dataset = Mel2Samp(**data_config)

    # Fraction of the dataset used for training; the remainder is validation.
    val_split = 0.8
    train_set_len = int(len(dataset) * val_split)
    valid_set_len = len(dataset) - train_set_len

    train_set, valid_set = torch.utils.data.random_split(
        dataset, [train_set_len, valid_set_len])

    print('Total : ', len(dataset))
    print('Train : ', len(train_set))
    print('Valid : ', len(valid_set))

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else 'cpu')

    num_gpus = 1
    batch_size = 256
    train_sampler = DistributedSampler(train_set) if num_gpus > 1 else None
    valid_sampler = DistributedSampler(valid_set) if num_gpus > 1 else None

    # DataLoader rejects shuffle=True together with an explicit sampler, so
    # shuffling is only requested when no sampler is in use.
    train_loader = DataLoader(train_set,
                              num_workers=20,
                              shuffle=train_sampler is None,
                              sampler=train_sampler,
                              batch_size=batch_size,
                              pin_memory=use_cuda,
                              drop_last=True)

    # Validation keeps every sample (no drop_last) and needs no shuffling:
    # dropping the last partial batch skews the reported accuracy and leaves
    # zero batches when the split is smaller than batch_size.
    val_loader = DataLoader(valid_set,
                            num_workers=20,
                            shuffle=False,
                            sampler=valid_sampler,
                            batch_size=batch_size,
                            pin_memory=use_cuda,
                            drop_last=False)

    print(len(train_loader))
    print(len(val_loader))

    print(train_loader)

    # Fail early with a clear message instead of a ZeroDivisionError at the
    # end of the first (empty) epoch.
    if len(train_loader) == 0:
        raise ValueError(
            'Training split has {} samples, fewer than one batch of {}'.format(
                len(train_set), batch_size))

    # Make sure checkpoints can be written before spending an epoch training.
    os.makedirs(ch_dir, exist_ok=True)

    my_model = model_efficientNet('efficientnet-b1', 16).to(device)
    my_model = my_model.train()

    learning_rate = 0.001
    # NOTE(review): model_efficientNet.forward ends with a Softmax, while
    # CrossEntropyLoss is documented to expect unnormalized logits. The loss
    # is left unchanged here to keep training numerics as they were; confirm
    # whether the double normalization is intended.
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(my_model.parameters(), lr=learning_rate)

    restart = False
    epoch_offset = 0

    if restart:
        checkpoint_path = '/Protein/gc2020/en_0730_6pm/model_005.pt'
        my_model, optimizer, _learning_rate, epoch = load_checkpoint(
            checkpoint_path, my_model, optimizer)
        epoch_offset = epoch + 1

    print(len(train_loader))
    epochs = 400
    epoch_save = 1
    i_total = 0

    print("start")
    for e in range(epoch_offset, epochs):
        tic_epoch = time.time()
        total_iter = 0
        correct_iter = 0

        # Validation switches to eval mode; re-enable train mode each epoch.
        my_model.train()

        for i, batch in enumerate(train_loader):
            tic_iter = time.time()
            my_model.zero_grad()
            mel, cls = batch
            mel = mel.to(device)
            cls = cls.to(device)
            # Insert the single channel axis expected by the model.
            mymel = mel.unsqueeze(1)

            pred = my_model(mymel)
            loss = criterion(pred, cls)
            loss.backward()
            optimizer.step()

            _, predicted = torch.max(pred, 1)
            total_iter += cls.size(0)
            correct_iter += (predicted == cls).sum().item()
            accu_iter = _accuracy_percent(correct_iter, total_iter)

            toc_iter = time.time()
            dur_iter = toc_iter - tic_iter
            print("\n e{:03d}/{:03d} i {:06d}  {:03d}/{:03d} | loss: {:08.8f} acc_iter : {:4.2f}% {:5.2f}sec/iter".format(
                e, epochs, i_total, i, len(train_loader), loss.item(), accu_iter, dur_iter), end='')
            i_total += 1

        accu_iter = _accuracy_percent(correct_iter, total_iter)
        toc_epoch = time.time()
        dur_epoch = toc_epoch - tic_epoch
        remain_epoch = epochs - e - 1
        dur_remain = dur_epoch * remain_epoch
        print(' | acc_train : {:4.2f}%  | {:6.4f}sec/epoch  remain {:4.2f}min for {:d}epoch'.format(
            accu_iter, dur_epoch, dur_remain / 60, remain_epoch), end='')

        # Evaluate in eval mode so BatchNorm uses (and does not update) its
        # running statistics and dropout is disabled.
        tic_val = time.time()
        correct_val, total_val = _evaluate(my_model, val_loader, device)
        accu_val = _accuracy_percent(correct_val, total_val)
        toc_val = time.time()
        dur_val = toc_val - tic_val
        print(' | Accu_val : {:4.2f}% {:6.4f}sec'.format(accu_val, dur_val))

        if e % epoch_save == 0:
            PATH = os.path.join(ch_dir, 'model_{:03d}.pt'.format(e))

            save_checkpoint(my_model, optimizer, learning_rate, e, PATH)

            print('checkpoin saved {} with Accu_val  {:4.2f}% '.format(PATH, accu_val))

def main(config, ch_dir):
    """Entry point: run ``train`` with a config path and checkpoint directory.

    Args:
        config: Path to the JSON configuration file, forwarded to ``train``.
        ch_dir: Checkpoint output directory, forwarded to ``train``.
    """
    train(config, ch_dir)



if __name__ == "__main__":
    import argparse

    # Command-line interface:
    #   -c / --config : path to the JSON configuration file
    #   -d / --ch_dir : directory that receives the checkpoints
    cli = argparse.ArgumentParser()
    cli.add_argument('-c', '--config', default='/home/hryu/config10.json', type=str)
    cli.add_argument('-d', '--ch_dir', default='/Protein/2020/ch_01', type=str)
    options = cli.parse_args()

    # Echo the parsed options so the run log records what was used.
    print(options)

    main(options.config, options.ch_dir)
