In [1]:
import os
import time
import warnings

import numpy as np
import torch
import librosa
import speech_recognition
import speech_model

from utils.tokenization import BertTokenizer
from utils.classifier_utils import KorNLIProcessor
from preprocessing import preprocessing

warnings.filterwarnings(action='ignore')
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [2]:
model_speech = speech_model.Classifier()
model_speech.load_state_dict(torch.load('./output/m_speech.pt'))
model_speech.to(device)
model_speech.eval()

Classifier(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1))
    (1): ReLU()
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1))
    (4): ReLU()
    (5): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(5, 4), stride=(1, 1))
    (1): ReLU()
    (2): Batc

In [3]:
tokenizer = BertTokenizer('./data/large_v2_32k_vocab.txt', max_len=128)

model_text = torch.load('./output/m_text.pt')
model_text = model_text.module
model_text.to(device)
model_text.eval()

SequenceClassification(
  (bert): Model(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(32000, 1024, padding_idx=0)
      (position_embeddings): Embedding(384, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1)
    )
    (encoder): Encoder(
      (layer): ModuleList(
        (0): Block(
          (attention_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (ffn_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
          (ffn): PositionWiseFeedForward(
            (fc1): Linear(in_features=1024, out_features=4048, bias=True)
            (fc2): Linear(in_features=4048, out_features=1024, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (attn): Attention(
            (query): Linear(in_features=1024, out_features=1024, bias=True)
            (key): Linear(in_features=1024, out_features=1024, bias=True)
    

In [4]:
processor = KorNLIProcessor()
output_mode = "classification"

label_list = processor.get_labels()
num_labels = len(label_list)

In [5]:
recognizer = speech_recognition.Recognizer()
recognizer.energy_threshold = 300

In [6]:
def speech_to_text(path):
    audio = speech_recognition.AudioFile(path)
    with audio as source:
        ex = recognizer.record(source)
        text = recognizer.recognize_google(audio_data=ex, language='ko-KR')
    return text

In [7]:
class InputFeatures(object):
    """A single set of features of data."""

    def __init__(self, input_ids, input_mask, segment_ids):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids

In [8]:
def convert_example_to_feature(example, label_list, max_seq_length,
                                 tokenizer, output_mode):
    """Loads a data file into a list of `InputBatch`s."""

    label_map = {label : i for i, label in enumerate(label_list)}

    tokens_a = tokenizer.tokenize(example)

    if len(tokens_a) > max_seq_length - 2:
        tokens_a = tokens_a[:(max_seq_length - 2)]

    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    segment_ids = [0] * len(tokens)

    input_ids = tokenizer.convert_tokens_to_ids(tokens)

    input_mask = [1] * len(input_ids)

    padding = [0] * (max_seq_length - len(input_ids))
    input_ids += padding
    input_mask += padding
    segment_ids += padding

    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    feature = InputFeatures(input_ids=input_ids,
                            input_mask=input_mask,
                            segment_ids=segment_ids)
    return feature

In [9]:
def predict_speech(x):
    return torch.mean((torch.sigmoid(model_speech(x)) >= 0.5).float()).item()

In [10]:
def predict_text(input_id, segment_id, input_mask):
    logit = model_text(input_id, segment_id, input_mask, labels=None)
    return torch.nn.Softmax(dim=-1)(logit)[0, 1].item()

In [11]:
def predict(x, input_id, segment_id, input_mask, threshold=0.5):
    pred_speech = predict_speech(x)
    pred_text = predict_text(input_id, segment_id, input_mask)
    pred = 0.5 * pred_speech + 0.5 * pred_text
    pred = 1 if pred >= threshold else 0
    return pred

In [12]:
for i in range(1, 10):
    start = time.time()
    
    try:
        spectrogram = preprocessing('./wav_data/sample_0%d.wav'%i, method='mfcc', sr=22050)
        spectrogram = torch.tensor(spectrogram, device=device).float()
        spectrogram = spectrogram.permute(0, 3, 1, 2)

        text = speech_to_text('./wav_data/sample_0%d.wav'%i)
        feature = convert_example_to_feature(text, label_list, 128, tokenizer, output_mode)

        input_id = torch.tensor([feature.input_ids], dtype=torch.long).to(device)
        input_mask = torch.tensor([feature.input_mask], dtype=torch.long).to(device)
        segment_id = torch.tensor([feature.segment_ids], dtype=torch.long).to(device)

        pred = predict(spectrogram, input_id, segment_id, input_mask)

        print('sample_%d:'%i, pred, '\t run time:', time.time()-start)
    except:
        pass

sample_1: 1 	 run time: 3.04257869720459
sample_2: 1 	 run time: 1.1287448406219482
sample_4: 1 	 run time: 0.9967217445373535
sample_5: 1 	 run time: 1.0154814720153809
sample_7: 1 	 run time: 1.170450210571289
sample_8: 1 	 run time: 1.2942752838134766
sample_9: 1 	 run time: 1.1795034408569336
