In [1]:
import torch
import torchaudio
import torchaudio.functional as func
import torch.nn as nn

In [2]:
# --- Audio / spectrogram settings ------------------------------------------
# The lower-case counterparts of these names are documented on
# Analyzer.__init__; presumably these constants are the values passed there
# (the call site is not in the visible cells -- confirm).

# Samples per second of the input audio.
# NOTE(review): 44000, not the common 44100 -- confirm this is intentional.
SAMPLE_RATE = 44_000
# Spectrogram temporal resolution, in slices per second.
TEMP_RES_SPEC = 1_000
# Spectrogram frequency resolution, in total number of slices.
FREQ_RES_SPEC = 800

# --- Encoder settings ------------------------------------------------------
# Number of output channels of each convolution.
CONV_FEATURES = 25
# Temporal resolution (slices per second) of the encoded signal.
TEMP_RES_ENC = 80
# Width of the embedding produced per encoded frame.
EMB_SIZE = 128

# --- Discriminator ---------------------------------------------------------
# Not referenced by any visible cell; presumably the (height, width) of a
# discriminator input -- confirm against its user.
DISCRIM_DIM = (128, 128)

In [3]:
class Analyzer(nn.Module):
    """
    Encode an audio clip into per-frame tag probabilities and word vectors.

    Pipeline (see forward): magnitude spectrogram -> two parallel
    convolutions over the (time, frequency) plane -> two linear layers ->
    bidirectional GRU -> an LSTM cell stepped frame by frame, whose state
    is classified into NUM_TAGS tags and sampled at detected word starts.
    """

    # Number of tag classes predicted per encoded frame. forward() compares
    # the arg-max tag against the values 0, 1 and 3.
    NUM_TAGS = 4

    def __init__(self, parser, sample_rate, temp_res_spec, freq_res_spec, 
                conv_features, temp_res_enc, emb_size):
        """
        parser: stored on the instance as `self.parser`; it is not used
            anywhere else in this class

        sample_rate: sample rate of the audio input
        
        temp_res_spec: temporal resolution (in slices per second) of the 
            spectrogram that is performed on the audio clip
        
        freq_res_spec: frequency resolution (in total number of slices)
            of the spectrogram that is performed on the audio clip

        conv_features: number of output channels of each convolution
            
        temp_res_enc: the temporal resolution (again in slices / sec) 
            of the encoded signal that is returned from this object

        emb_size: width of the per-frame embedding. Half of it is used as
            the hidden size of the GRU (per direction) and of the LSTM
            cell, so it should be even.
            
        """
        
        super().__init__()
        
        self.sample_rate = sample_rate
        self.temp_res_spec = temp_res_spec
        self.freq_res_spec = freq_res_spec
        self.temp_res_enc = temp_res_enc
        self.conv_features = conv_features
        self.emb_size = emb_size
        
        # Only this many frequency bins of the spectrogram are fed to the
        # convolutions (35% of the nominal frequency resolution).
        self.freq_res_trimmed = round(self.freq_res_spec * 0.35)
        
        # Number of spectrogram slices merged into one encoded slice. Used
        # as both kernel height and stride along time, so the windows do
        # not overlap. Clamped to 1 because a zero-sized kernel is invalid.
        self.conv_width = max(1, round(temp_res_spec / temp_res_enc))
        
        # Coarse convolution: kernel spans about a third of the kept bins.
        conv_height1 = max(1, round(self.freq_res_trimmed / 3))
        kernel_size1 = (self.conv_width, conv_height1)
        padding_size1 = (0, conv_height1 - 1)
        # max(1, ...) guards against a rounded-down stride of 0.
        stride_size1 = (self.conv_width, max(1, round(conv_height1 / 3)))
        
        self.conv1 = nn.Conv2d(in_channels = 1, out_channels=conv_features,
                                kernel_size=kernel_size1, 
                                padding=padding_size1,
                                stride=stride_size1)
        
        # Fine convolution: kernel spans about a thirteenth of the kept bins.
        conv_height2 = max(1, round(self.freq_res_trimmed / 13))
        kernel_size2 = (self.conv_width, conv_height2)
        padding_size2 = (0, conv_height2 - 1)
        # conv_height2 / 13 rounds to 0 for small kernels; a stride of 0 is
        # rejected by the convolution, so clamp it to 1.
        stride_size2 = (self.conv_width, max(1, round(conv_height2 / 13)))
        
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=conv_features,
                                kernel_size=kernel_size2,
                                padding=padding_size2,
                                stride=stride_size2)
                
        # Push a dummy one-second input through both convolutions to learn
        # their output widths instead of deriving them by hand.
        with torch.no_grad():
            self.out_size1 = self.conv1(torch.ones(1, 1, self.temp_res_spec,
                                            self.freq_res_trimmed)).size()
            
            self.out_size2 = self.conv2(torch.ones(1, 1, self.temp_res_spec, 
                                              self.freq_res_trimmed)).size()
        
        # forward() concatenates both conv outputs along the frequency axis.
        linear_input_size = self.out_size1[3] + self.out_size2[3]
        self.linear1 = nn.Linear(linear_input_size, emb_size)
        
        # Collapses the conv_features channels into a single value.
        self.linear2 = nn.Linear(self.conv_features, 1)

        # Classifies the concatenated LSTM cell/hidden state (2 * emb_size/2
        # values) into NUM_TAGS tags. forward() has always called this layer,
        # but it was never created.
        self.linear3 = nn.Linear(2 * int(emb_size / 2), self.NUM_TAGS)
        
        self.gru = nn.GRU(input_size=emb_size, hidden_size=int(emb_size / 2), bidirectional=True)
        
        self.cel = nn.LSTMCell(input_size=emb_size, hidden_size=int(emb_size / 2))
        
        self.parser = parser


    def _spectrogram(self, x):
        """
        Magnitude spectrogram of `x`, laid out as (channel, time, frequency)
        and cut down to the first `freq_res_trimmed` frequency bins.
        """
        window_size = self.freq_res_spec
        # Samples between consecutive slices. The original derived this via
        # the clip length, which cancels out (and divided by zero for an
        # empty clip).
        hop = max(1, round(self.sample_rate / self.temp_res_spec))

        spec = func.spectrogram(
                waveform=x,
                pad=0,
                window=torch.bartlett_window(window_size, device=x.device),
                n_fft=2 * window_size, hop_length=hop,
                win_length=window_size, power=2.0, normalized=False)

        # Power spectrogram -> magnitude.
        spec = torch.sqrt(spec)

        # torchaudio returns (..., freq, time); the convolutions were sized
        # in __init__ for (time, freq).
        spec = spec.transpose(1, 2)

        # linear1 was sized for freq_res_trimmed bins, so the untrimmed
        # spectrogram cannot pass through it.
        # NOTE(review): keeps the lowest-frequency bins -- confirm that this
        # is the intended 35%.
        return spec[..., :self.freq_res_trimmed]


    def forward(self, x, single_word=False):
        """
        x: audio samples; assumed to be shaped (1, num_samples), since the
            convolutions take a single input channel -- confirm with caller

        single_word: when True no word boundaries are cut; the state after
            the last frame is returned as the only word

        Returns (seq_tags, words):
            seq_tags: (NUM_TAGS, num_frames) softmax tag probabilities,
                one column per encoded frame
            words: (1, emb_size, num_words) concatenated LSTM cell/hidden
                state captured at each word boundary
        """
        spec = self._spectrogram(x).unsqueeze(0)
        
        out1 = self.conv1(spec)
        out2 = self.conv2(spec)
        
        # (features, time, width1 + width2) -> (time, features, width)
        feats = torch.cat((out1.squeeze(0), out2.squeeze(0)), dim=2)
        feats = torch.transpose(feats, 0, 1)
        
        emb = self.linear1(feats)
        
        # Move the feature channels last so linear2 can reduce them to one,
        # then restore (time, 1, emb_size) for the GRU.
        emb = torch.transpose(emb, 1, 2)
        emb = self.linear2(emb)
        emb = torch.transpose(emb, 1, 2)
        
        frames, _ = self.gru(emb)
        
        half = int(self.emb_size / 2)

        # new_empty keeps the accumulators and states on the same device
        # and dtype as the encoded frames.
        words = frames.new_empty((1, self.emb_size, 0))
        # Starts with zero columns; an uninitialised first column used to
        # leak into the result.
        seq_tags = frames.new_empty((self.NUM_TAGS, 0))
        
        cel_state = frames.new_empty((1, half))
        nn.init.xavier_uniform_(cel_state)
        
        hid_state = frames.new_empty((1, half))
        nn.init.xavier_normal_(hid_state)
        
        bit = 0
        last_bit = 0
        new_word = False
        cur_state = None

        for s in frames:
            hid_state, cel_state = self.cel(s, (hid_state, cel_state))
            
            cur_state = torch.cat((cel_state, hid_state), dim=1)
            
            res = self.linear3(cur_state)
            res = nn.functional.softmax(res, 1)
            
            seq_tags = torch.cat((seq_tags, res.transpose(0, 1)), 1)
            
            last_bit = bit
            bit = int(torch.argmax(res))
            
            # A word starts when the tag leaves 0, or goes from 3 to 1.
            # NOTE(review): the meaning of the individual tag values is not
            # documented anywhere in this file.
            if bit != 0 and last_bit == 0:
                new_word = True
                
            if bit == 1 and last_bit == 3:
                new_word = True
            
            if new_word and not single_word:
                new_word = False
                words = torch.cat((words, cur_state.unsqueeze(2)), dim=2)
                # Restart the cell state for the next word.
                cel_state = frames.new_empty((1, half))
                nn.init.xavier_normal_(cel_state)
        
        if single_word and cur_state is not None:
            words = torch.cat((words, cur_state.unsqueeze(2)), dim=2)
        
        return (seq_tags, words)

In [4]:
class SpeechNet(nn.Module):
    """
    Chain of four stages between spectrograms and text.

    The `spec_to_*` methods apply high_grit, low_grit, parser and
    recognizer in that order and stop at the named stage; the `text_to_*`
    methods apply the same four callables in the opposite order.

    Subclasses nn.Module so that `self.parameters()` (needed by run_train)
    collects the parameters of all four stages.

    NOTE(review): run_train() and forward() call `sound_to_spec`,
    `labels_to_text`, `listen` and `speak`, none of which is defined in
    this class; they must be supplied elsewhere (for example by a subclass).
    """

    def __init__(self, high_grit, low_grit, parser, recognizer):
        """
        high_grit: reversible conv layer to go from feature-scale
                    to audio-scale resolution and vice versa
                    
        low_grit: reversible conv layer to go from feature-scale
                    to low feature-scale resolution for RNN 
                    processing
        
        parser: reversible module to go from low feature-scale 
                    resolution to discrete word encodings at much
                    lower temporal resolution (output still time-dep.)
                    
        recognizer: reversible module to go from word encodings
                    of variable time dimension to words in the
                    vocabulary by way of linear layers.
        """

        # Must run before the stages are assigned so they are registered
        # as sub-modules.
        super().__init__()

        self.high_grit = high_grit
        self.low_grit = low_grit
        self.parser = parser
        self.recognizer = recognizer
    
    
    def text_to_spec(self, x):
        """Run text through all four stages: recognizer -> high_grit."""
        x = self.recognizer(x)
        x = self.parser(x)
        x = self.low_grit(x)
        x = self.high_grit(x)
        return x
    
    
    def spec_to_text(self, x):
        """Run a spectrogram through all four stages: high_grit -> recognizer."""
        x = self.high_grit(x)
        x = self.low_grit(x)
        x = self.parser(x)
        x = self.recognizer(x)
        return x
    
    
    def text_to_rec(self, x):
        """Text side of the recognizer stage: recognizer only."""
        x = self.recognizer(x)
        return x
    
    
    def spec_to_rec(self, x):
        """Spectrogram side of the recognizer stage (same chain as spec_to_text)."""
        x = self.high_grit(x)
        x = self.low_grit(x)
        x = self.parser(x)
        x = self.recognizer(x)
        return x
    
    
    def text_to_seq(self, x):
        """Text side of the parser stage: recognizer -> parser."""
        x = self.recognizer(x)
        x = self.parser(x)
        return x
    
    
    def spec_to_seq(self, x):
        """Spectrogram side of the parser stage: high_grit -> low_grit -> parser."""
        x = self.high_grit(x)
        x = self.low_grit(x)
        x = self.parser(x)
        return x
    
    
    def text_to_low(self, x):
        """Text side of the low_grit stage: recognizer -> parser -> low_grit."""
        x = self.recognizer(x)
        x = self.parser(x)
        x = self.low_grit(x)
        return x
    
    
    def spec_to_low(self, x):
        """Spectrogram side of the low_grit stage: high_grit -> low_grit."""
        x = self.high_grit(x)
        x = self.low_grit(x)
        return x


    def text_to_high(self, x):
        """
        Text side of the high_grit stage.

        run_train() calls this method but it was missing. It extends
        text_to_low by one stage, which makes it the same chain as
        text_to_spec.
        NOTE(review): derived from the naming pattern -- confirm.
        """
        return self.text_to_spec(x)


    def spec_to_high(self, x):
        """
        Spectrogram side of the high_grit stage: high_grit only.

        run_train() calls this method but it was missing.
        NOTE(review): derived from the naming pattern -- confirm.
        """
        return self.high_grit(x)


    @staticmethod
    def _fit_step(optimizer, first, second):
        """Take one optimiser step on the MSE between two tensors; return the loss value."""
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(first, second)
        loss.backward()
        optimizer.step()
        return loss.item()
    
    
    def run_train(self, dataset, num_epochs=15, mon_interval=100):
        """
        dataset: an iterable dataset yielding (sounds, labels) pairs; it
            is iterated once per epoch, so it must be re-iterable
        num_epochs: how many epochs to do the training
        mon_interval: how often (in batches) to output monitoring info;
            0 or None disables the output

        Each batch runs four optimiser steps, one per stage, each pulling
        the text-side and spectrogram-side outputs of that stage together.
        """

        # Built once: re-creating Adam for every batch would throw away its
        # running moment estimates.
        optimizer = torch.optim.Adam(self.parameters())

        for epoch in range(num_epochs):
            for i, (sounds, labels) in enumerate(dataset):

                # Set up this batch of data
                spec = self.sound_to_spec(sounds)
                text = self.labels_to_text(labels)

                # Train the parser
                # NOTE(review): the original named a non-existent
                # nn.MarginLoss here; MSE is used to match the other three
                # stages -- confirm the intended loss.
                self.parser.train(True)
                loss1 = self._fit_step(optimizer,
                                       self.text_to_seq(text),
                                       self.spec_to_seq(spec))
                self.parser.train(False)

                # Train the recognizer
                self.recognizer.train(True)
                loss2 = self._fit_step(optimizer,
                                       self.text_to_rec(text),
                                       self.spec_to_rec(spec))
                self.recognizer.train(False)

                # Train the low-grit conv layer
                loss3 = self._fit_step(optimizer,
                                       self.spec_to_low(spec),
                                       self.text_to_low(text))

                # Train the high-grit conv layer
                loss4 = self._fit_step(optimizer,
                                       self.spec_to_high(spec),
                                       self.text_to_high(text))

                if mon_interval and i % mon_interval == 0:
                    print("epoch {} batch {}: parser={:.4f} recognizer={:.4f} "
                          "low_grit={:.4f} high_grit={:.4f}".format(
                              epoch, i, loss1, loss2, loss3, loss4))

        return
    
    
    def forward(self, x, c = None):
        """
        Dispatch to `listen` when `c` is given, otherwise to `speak`.

        NOTE(review): neither `listen` nor `speak` is defined in this class.
        """
        if c is not None:
            return self.listen(x, c)

        return self.speak(x, c)