<a href="https://colab.research.google.com/github/soohyunme/Foreigner_speech/blob/main/Code/ASR_DeepSpeech2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch
!pip install torchaudio
!pip install hangul_utils

In [None]:
# Mount Google Drive (Colab only) so the dataset under /content/drive/MyDrive is readable.
from google.colab import drive
drive.mount('/content/drive')

# 모듈 설치 및 로드

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
from glob import glob
import os
import re
import math
from scipy.io import wavfile
from collections import defaultdict, Counter
from scipy import signal
import numpy as np
import librosa
import random
import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
import torchaudio
import tarfile
import tensorflow as tf
from hangul_utils import join_jamos

from typing import Tuple, Union

from torch import Tensor
from torch.utils.data import Dataset

In [None]:
# try:
#   resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='grpc://' + os.environ['COLAB_TPU_ADDR'])

#   tf.config.experimental_connect_to_cluster(resolver)
#   tf.tpu.experimental.initialize_tpu_system(resolver)
#   strategy = tf.distribute.TPUStrategy(resolver)
# except:
#   print('TPU device not found')

In [None]:
# Colab magic selecting the TensorFlow 2.x runtime.
# NOTE(review): newer Colab runtimes may no longer support this magic -- confirm.
# TensorFlow is only used here to report whether a GPU is attached; the model runs on PyTorch.
%tensorflow_version 2.x
import tensorflow as tf
device_name = tf.test.gpu_device_name()
if device_name != '/device:GPU:0':
  print('GPU device not found')
else:
  print('Found GPU at: {}'.format(device_name))

# 경로 설정

In [None]:
# Root of the dataset on the mounted Google Drive; every other path in the notebook is built from it.
data_path = '/content/drive/MyDrive/03.foreigner_speech/dataset'

# 데이터 로드

In [None]:
# Training audio, test audio and transcripts.
# glob() returns paths in filesystem-dependent order. Sort them so the run is
# reproducible: test predictions are later written to the submission sheet by
# position, so the test order must be stable.
# NOTE(review): confirm that submission.xlsx rows follow sorted file-name order.
train_paths = sorted(glob(data_path+"/train/wav/*.wav"))
# json_paths = glob(data_path+"/Train/JSON/*.json")
test_paths = sorted(glob(data_path+"/test/*.wav"))
# Keep only the file name and the transcript column used as the training label.
text_df = pd.read_csv(data_path+'/final_text.csv')[['fileName','new_ReadingLabelText']]

## 레이블 df에서 존재하는 데이터만 선택

In [None]:
# Bare file names ("xxx.wav") of every training wav actually present on disk,
# used to filter the transcript table down to files that exist.
train_dir_prefix = data_path + "/train/wav/"
file_list = [wav_path.replace(train_dir_prefix, '') for wav_path in train_paths]
file_df = pd.DataFrame(file_list, columns=['fileName'])

In [None]:
# Inner join on the shared fileName column: keep only transcripts whose wav exists.
text_df = text_df.merge(file_df, how='inner', on='fileName')
len(text_df)

In [None]:
def train_to_dataset(df):
  """Load every labelled training wav into memory.

  Args:
    df: DataFrame whose first column is the wav file name (relative to
      ``data_path + '/train/wav/'``) and whose second column is the transcript.

  Returns:
    list of 6-tuples ``(waveform, sample_rate, utterance, 0, 0, 0)``.
    ``waveform`` keeps only the first channel, shaped ``(1, num_samples)``.
    The trailing zeros are placeholders so every sample unpacks as the 6-tuple
    that ``data_processing`` expects. An empty ``df`` yields ``[]``.
  """
  tuple_list = []
  # itertuples avoids per-row df.iloc[i, ...] lookups; slicing the first two
  # columns keeps the original positional contract (file name, transcript).
  rows = df.iloc[:, :2].itertuples(index=False, name=None)
  for file_name, utterance in tqdm(rows, total=len(df)):
    waveform, sr = torchaudio.load(data_path + '/train/wav/' + file_name)
    waveform = waveform[0].reshape(1, -1)  # channel 0 only, as (1, samples)
    tuple_list.append((waveform, sr, utterance, 0, 0, 0))
  # Build the result once after the loop. Copying inside the loop was O(n^2)
  # and left the result unbound (UnboundLocalError) for an empty df.
  return tuple_list

In [None]:
# def test_to_dataset(df):
#   tuple_list = []
#   for i in tqdm(range(len(df))):
#     file_name = df.iloc[i,0]
#     waveform, sr = torchaudio.load(data_path+'./test/'+file_name)
#     waveform = waveform[0].reshape(1,-1)
#     tuple_data = (waveform, sr,'',0 , 0, 0)
#     tuple_list.append(tuple_data)
#     result = list(tuple_list)
#   return result

In [None]:
def test_to_dataset(df):
  """Load unlabelled wav files (e.g. the test set) into memory.

  Args:
    df: sequence of full wav paths -- a list, or a pandas Series such as
      ``test_df['fileName']``.

  Returns:
    list of 6-tuples ``(waveform, sample_rate, '', 0, 0, 0)``. The empty
    transcript lets the training collate function be reused unchanged.
    Empty input yields ``[]``.
  """
  tuple_list = []
  # Iterating yields values in positional order for both lists and Series.
  # df[i] is a *label* lookup on a Series and breaks with a non-default index.
  for file_name in tqdm(list(df)):
    waveform, sr = torchaudio.load(file_name)
    waveform = waveform[0].reshape(1, -1)  # channel 0 only, as (1, samples)
    tuple_list.append((waveform, sr, '', 0, 0, 0))
  # Returned after the loop: no per-iteration copy, and no unbound result for empty input.
  return tuple_list


In [None]:
# 80/20 train/validation split by row position. text_df is not shuffled first,
# so the split follows the row order produced by the CSV + merge.
len_train = int(len(text_df)*0.8)

In [None]:
# Training split: the first len_train labelled rows, loaded into memory.
train_dataset = train_to_dataset(text_df.iloc[:len_train]) 
# train_dataset = train_to_dataset(text_df) 
train_dataset[0]

In [None]:
# Validation split: the remaining labelled rows (same loader as training data).
valid_dataset = train_to_dataset(text_df.iloc[len_train:]) 
# valid_dataset = test_to_dataset(test_paths) 
valid_dataset[0]

In [None]:
# Wrap the full test wav paths in a one-column frame.
test_df = pd.DataFrame(test_paths,columns=['fileName'])

In [None]:
# Load the full test set (no .iloc subset) with empty transcripts.
test_dataset = test_to_dataset(test_df['fileName']) 
test_dataset[0]

# 텍스트 처리

In [None]:
# Hangul compatibility jamo in Unicode composition order. The order matters:
# split_syllable() indexes into these lists arithmetically.

# Initial consonants (choseong), 19 entries.
INITIALS = list("ㄱㄲㄴㄷㄸㄹㅁㅂㅃㅅㅆㅇㅈㅉㅊㅋㅌㅍㅎ")

# Medial vowels (jungseong), 21 entries.
MEDIALS = list("ㅏㅐㅑㅒㅓㅔㅕㅖㅗㅘㅙㅚㅛㅜㅝㅞㅟㅠㅡㅢㅣ")

# Final consonants (jongseong), 28 entries; "∅" marks "no final consonant".
FINALS = list("∅ㄱㄲㄳㄴㄵㄶㄷㄹㄺㄻㄼㄽㄾㄿㅀㅁㅂㅄㅅㅆㅇㅈㅊㅋㅌㅍㅎ")

SPACE_TOKEN = " "

# Every symbol preprocess() can emit (space plus all jamo), deduplicated and sorted.
LABELS = sorted(set(INITIALS) | set(MEDIALS) | set(FINALS) | {SPACE_TOKEN})

def check_syllable(char):
  """Return True if ``char`` is a precomposed Hangul syllable (U+AC00 '가' .. U+D7A3 '힣')."""
  code_point = ord(char)
  return 0xAC00 <= code_point <= 0xD7A3

def split_syllable(char):
  """Decompose one precomposed Hangul syllable into (initial, medial, final) jamo.

  Inverts the Unicode composition formula
  ``code = 0xAC00 + (initial * 21 + medial) * 28 + final``.
  The final is ``"∅"`` when the syllable has no final consonant.
  Asserts that ``char`` is a Hangul syllable.
  """
  assert check_syllable(char)
  offset = ord(char) - 0xAC00
  initial_medial, final_idx = divmod(offset, 28)
  initial_idx, medial_idx = divmod(initial_medial, 21)
  return (INITIALS[initial_idx], MEDIALS[medial_idx], FINALS[final_idx])

def preprocess(str):
  """Turn a transcript into the jamo string used as the CTC target.

  Leading and trailing whitespace is stripped, and internal whitespace runs
  collapse to a single SPACE_TOKEN. Each Hangul syllable expands into its
  three jamo (initial, medial, final; "∅" when there is no final). All other
  characters -- digits, Latin letters, punctuation -- are dropped.
  """
  pieces = []
  normalized = re.sub("\\s+", SPACE_TOKEN, str.strip())
  for ch in normalized:
    if ch == SPACE_TOKEN:
      pieces.append(SPACE_TOKEN)
    elif check_syllable(ch):
      pieces.extend(split_syllable(ch))
  return "".join(pieces)

def join_text(text):
  """Recompose a jamo string (as produced by preprocess) into Hangul syllables.

  The "∅" no-final placeholders are removed first; hangul_utils.join_jamos
  then composes the remaining jamo.
  """
  without_placeholders = text.replace('∅', '')
  return join_jamos(without_placeholders)

# 오류 함수 정의

In [None]:
def avg_wer(wer_scores, combined_ref_len):
    """Corpus-level error rate: total edit distance divided by total reference length."""
    total_errors = float(sum(wer_scores))
    return total_errors / float(combined_ref_len)

def _levenshtein_distance(ref, hyp):
    """"Levenshtein distance"는 두 시퀀스 간의 차이를 측정하기위한 문자열 메트릭입니다. 
    "Levenshtein distanc"는 한 단어를 다른 단어로 변경하는 데 필요한 최소 한 문자 편집 (대체, 삽입 또는 삭제) 수로 정의됩니다. 
    """
    m = len(ref)
    n = len(hyp)

    # special case
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # use O(min(m, n)) space
    distance = np.zeros((2, n + 1), dtype=np.int32)

    # initialize distance matrix
    for j in range(0,n + 1):
        distance[0][j] = j

    # calculate levenshtein distance
    for i in range(1, m + 1):
        prev_row_idx = (i - 1) % 2
        cur_row_idx = i % 2
        distance[cur_row_idx][0] = i
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
            else:
                s_num = distance[prev_row_idx][j - 1] + 1
                i_num = distance[cur_row_idx][j - 1] + 1
                d_num = distance[prev_row_idx][j] + 1
                distance[cur_row_idx][j] = min(s_num, i_num, d_num)

    return distance[m % 2][n]


In [None]:
def word_errors(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Word-level edit distance between a reference and a hypothesis sentence.

    :param reference: reference (ground-truth) sentence.
    :param hypothesis: hypothesis (predicted) sentence.
    :param ignore_case: when True, compare case-insensitively.
    :param delimiter: separator used to split sentences into words.
    :return: ``(edit_distance, number_of_reference_words)``; the distance is a float.
    """
    normalize = (lambda sentence: sentence.lower()) if ignore_case == True else (lambda sentence: sentence)
    ref_words = normalize(reference).split(delimiter)
    hyp_words = normalize(hypothesis).split(delimiter)
    return float(_levenshtein_distance(ref_words, hyp_words)), len(ref_words)

In [None]:
def char_errors(reference, hypothesis, ignore_case=False, remove_space=False):
    """Character-level edit distance between a reference and a hypothesis.

    Runs of spaces (and leading/trailing spaces) are collapsed to a single
    space, or removed entirely when ``remove_space`` is True, before comparing.

    :return: ``(edit_distance, reference_length)``; the distance is a float and
        the length is measured after the space normalisation.
    """
    if ignore_case == True:
        reference, hypothesis = reference.lower(), hypothesis.lower()

    separator = '' if remove_space == True else ' '

    def _normalize(sentence):
        return separator.join(word for word in sentence.split(' ') if word)

    reference = _normalize(reference)
    hypothesis = _normalize(hypothesis)
    return float(_levenshtein_distance(reference, hypothesis)), len(reference)

In [None]:
def wer(reference, hypothesis, ignore_case=False, delimiter=' '):
    """Calculate word error rate (WER).

    WER = (Sw + Dw + Iw) / Nw, where
    Sw is the number of substituted words,
    Dw is the number of deleted words,
    Iw is the number of inserted words,
    Nw is the number of words in the reference.

    Raises ValueError when word_errors reports zero reference words. Note that
    ``str.split(delimiter)`` never returns an empty list (``''.split(' ')`` is
    ``['']``), so an empty reference counts as one word instead of raising.
    """
    errors, n_ref_words = word_errors(reference, hypothesis, ignore_case, delimiter)
    if not n_ref_words:
        raise ValueError("Reference's word number should be greater than 0.")
    return float(errors) / n_ref_words

In [None]:
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate character error rate (CER).

    CER = (Sc + Dc + Ic) / Nc, where
    Sc is the number of substituted characters,
    Dc is the number of deleted characters,
    Ic is the number of inserted characters,
    Nc is the number of characters in the (space-normalised) reference.

    Raises ValueError if the reference is empty after space normalisation.
    """
    errors, n_ref_chars = char_errors(reference, hypothesis, ignore_case, remove_space)
    if not n_ref_chars:
        raise ValueError("Length of reference should be greater than 0.")
    return float(errors) / n_ref_chars

# Text map

In [None]:
class TextTransform:
    """Maps characters to integers and vice versa.

    Index layout (54 symbols):
    - 0 is the apostrophe.
    - 1 is the word separator, written ``<SPACE>`` in the table and decoded
      as ``' '``.
    - 2-52 are Hangul compatibility jamo.
    - 53 is ``∅``, the "no final consonant" placeholder that ``preprocess`` emits.

    Index 54 is deliberately absent: the model uses it as the CTC blank
    (``n_class=55``, ``blank=54``).
    """
    def __init__(self):
        char_map_str = """
        ' 0
        <SPACE> 1
        ㅎ 2
        ㅍ 3
        ㅌ 4
        ㅋ 5
        ㅊ 6
        ㅈ 7
        ㅇ 8
        ㅆ 9
        ㅅ 10
        ㅂ 11
        ㅁ 12
        ㄹ 13
        ㄷ 14
        ㄴ 15
        ㄲ 16
        ㄱ 17
        ㅣ 18
        ㅢ 19
        ㅡ 20
        ㅠ 21
        ㅟ 22
        ㅞ 23
        ㅝ 24
        ㅜ 25
        ㅛ 26
        ㅚ 27
        ㅙ 28
        ㅘ 29
        ㅗ 30
        ㅖ 31
        ㅕ 32
        ㅔ 33
        ㅓ 34
        ㅒ 35
        ㅑ 36
        ㅐ 37
        ㅏ 38
        ㅉ 39
        ㅄ 40
        ㅃ 41
        ㅀ 42
        ㄿ 43
        ㄾ 44
        ㄽ 45
        ㄼ 46
        ㄻ 47
        ㄺ 48
        ㄸ 49
        ㄶ 50
        ㄵ 51
        ㄳ 52
        ∅ 53
        """
        pairs = [line.split() for line in char_map_str.strip().split('\n')]
        self.char_map = {ch: int(index) for ch, index in pairs}
        self.index_map = {int(index): ch for ch, index in pairs}
        # Decode the separator as a literal space rather than "<SPACE>".
        self.index_map[1] = ' '

    def text_to_int(self, text):
        """Use a character map and convert text to an integer sequence.

        Raises KeyError for characters outside the table.
        """
        space_index = self.char_map['<SPACE>']
        return [space_index if c == ' ' else self.char_map[c] for c in text]

    def int_to_text(self, labels):
        """Use a character map and convert integer labels to a text sequence.

        Integral floats (e.g. 17.0 from a float tensor's ``tolist()``) hash like
        ints, so they look up the same entries.
        """
        return ''.join(self.index_map[i] for i in labels).replace('<SPACE>', ' ')

    @staticmethod
    def join_text(text):
        """Drop "∅" placeholders and recompose jamo into Hangul syllables.

        Declared static because the original method had no ``self``: calling
        it on an instance raised TypeError. Class-level calls still work.
        """
        result = text.replace('∅', '')
        result = join_jamos(result)
        return result


# Transform

In [None]:
# Audio front-end: 128-bin mel spectrogram at 16 kHz. Training adds
# SpecAugment-style frequency and time masking; evaluation uses the plain spectrogram.
# NOTE(review): the sample rate returned by torchaudio.load is never checked
# against 16000 -- confirm the wavs are 16 kHz.
train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Same settings as the training spectrogram (they are also MelSpectrogram's
# defaults), spelled out so the two front-ends visibly match.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)

# Shared character <-> index mapper used by the collate function and the decoder.
text_transform = TextTransform()

# Data processing

In [None]:
def data_processing(data, data_type="train"):
    """Collate a batch of 6-tuple samples into padded model inputs.

    Args:
        data: iterable of ``(waveform, sample_rate, utterance, _, _, _)`` tuples,
            as built by ``train_to_dataset`` / ``test_to_dataset``.
        data_type: ``'train'`` applies the augmenting ``train_audio_transforms``;
            ``'valid'`` and ``'test'`` use the plain ``valid_audio_transforms``.

    Returns:
        spectrograms: padded tensor ``(batch, 1, n_mels, max_time)``.
        labels: padded int64 tensor ``(batch, max_label_len)`` of jamo indices
            (padding value 0).
        input_lengths: per-sample time steps, halved to match the model's
            stride-2 first convolution.
        label_lengths: per-sample number of target indices.

    Raises:
        ValueError: for any other ``data_type`` (a subclass of the Exception
            raised before).
    """
    if data_type == 'train':
        audio_transform = train_audio_transforms
    elif data_type in ('valid', 'test'):
        audio_transform = valid_audio_transforms
    else:
        raise ValueError('data_type should be train, valid or test')

    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    for (waveform, _, utterance, _, _, _) in data:
        spec = audio_transform(waveform).squeeze(0).transpose(0, 1)  # (time, n_mels)
        spectrograms.append(spec)
        # CTC targets are class indices, so build them as int64.
        # torch.Tensor(...) would have produced float32.
        label = torch.tensor(text_transform.text_to_int(preprocess(utterance)), dtype=torch.long)
        labels.append(label)
        input_lengths.append(spec.shape[0] // 2)
        label_lengths.append(len(label))
    # (batch, max_time, n_mels) -> (batch, 1, n_mels, max_time)
    spectrograms = nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
    return spectrograms, labels, input_lengths, label_lengths


# Decoder

In [None]:
def GreedyDecoder(output, labels, label_lengths, blank_label=54, collapse_repeated=True):
    """Best-path (greedy) CTC decoding.

    Args:
        output: scores shaped ``(batch, time, n_class)``; argmax is over dim 2.
        labels: padded target indices ``(batch, max_label_len)``.
        label_lengths: true length of each target row.
        blank_label: index of the CTC blank, dropped from predictions.
        collapse_repeated: merge consecutive identical indices. A blank between
            two equal symbols keeps both.

    Returns:
        ``(decodes, targets)``: decoded prediction strings and reference
        strings, one per batch item.
    """
    best_paths = torch.argmax(output, dim=2).tolist()
    decodes, targets = [], []
    for row, path in enumerate(best_paths):
        targets.append(text_transform.int_to_text(labels[row][:label_lengths[row]].tolist()))
        kept = [
            idx for pos, idx in enumerate(path)
            if idx != blank_label
            and not (collapse_repeated and pos != 0 and idx == path[pos - 1])
        ]
        decodes.append(text_transform.int_to_text(kept))
    return decodes, targets
# main(learning_rate, batch_size, epochs, libri_train_set, libri_test_set)


# Model

In [None]:
class CNNLayerNorm(nn.Module):
    """LayerNorm over the feature axis of a CNN activation.

    Input and output are ``(batch, channel, feature, time)``. The feature axis
    is moved last so ``nn.LayerNorm(n_feats)`` normalises over it, then moved back.
    """
    def __init__(self, n_feats):
        super().__init__()
        # The attribute name is part of the state_dict keys; keep it stable.
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, x):
        features_last = x.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
        normed = self.layer_norm(features_last)
        return normed.transpose(2, 3).contiguous()      # (batch, channel, feature, time)

class ResidualCNN(nn.Module):
    """Pre-activation residual block with a skip connection.

    Runs two (LayerNorm -> GELU -> Dropout -> Conv2d) stages on
    ``(batch, channel, feature, time)`` tensors and adds the input back.
    Because the input is added unchanged, the convolutions must preserve shape
    (``in_channels == out_channels`` and ``stride=1``, which is how
    SpeechRecognitionModel builds it).
    """
    def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
        super().__init__()
        same_padding = kernel // 2
        # The attribute names below are state_dict keys; keep them stable so
        # saved checkpoints still load.
        self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=same_padding)
        self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=same_padding)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, x):
        shortcut = x  # (batch, channel, feature, time)
        stages = ((self.layer_norm1, self.dropout1, self.cnn1),
                  (self.layer_norm2, self.dropout2, self.cnn2))
        for norm, drop, conv in stages:
            x = conv(drop(F.gelu(norm(x))))
        x += shortcut
        return x  # (batch, channel, feature, time)

class BidirectionalGRU(nn.Module):
    """Pre-norm bidirectional GRU block: LayerNorm -> GELU -> BiGRU -> Dropout.

    ``nn.GRU(bidirectional=True)`` concatenates the forward and backward
    hidden states, so the output feature size is ``2 * hidden_size``.
    ``batch_first`` is passed straight to ``nn.GRU``. It decides whether the
    input is read as ``(batch, seq, feature)`` or ``(seq, batch, feature)``.
    """
    def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
        super().__init__()
        # The attribute names below are state_dict keys; keep them stable.
        self.BiGRU = nn.GRU(input_size=rnn_dim, hidden_size=hidden_size,
                            num_layers=1, batch_first=batch_first, bidirectional=True)
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        activated = F.gelu(self.layer_norm(x))
        recurrent_out, _final_hidden = self.BiGRU(activated)
        return self.dropout(recurrent_out)

class SpeechRecognitionModel(nn.Module):
    """DeepSpeech2-style acoustic model.

    Pipeline: strided CNN -> residual CNNs -> linear -> BiGRU stack -> classifier.
    Input is a spectrogram batch ``(batch, 1, n_feats, time)``. Output is
    unnormalised class scores ``(batch, time', n_class)``, where ``time'`` is
    the time axis after the strided first convolution.
    """
    def __init__(self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1):
        super(SpeechRecognitionModel, self).__init__()
        # Feature-axis size after the 3x3, padding-1 first conv:
        # floor((n_feats + 2*1 - 3) / stride) + 1. This equals n_feats // 2 for the
        # default stride=2 with even n_feats, and stays correct for other strides
        # and for odd n_feats.
        n_feats = (n_feats - 1) // stride + 1
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3//2)  # cnn for extracting hierarchical features

        # n residual cnn layers with 32 filters each
        self.rescnn_layers = nn.Sequential(*[
            ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
            for _ in range(n_cnn_layers)
        ])
        self.fully_connected = nn.Linear(n_feats*32, rnn_dim)
        # Every BiGRU receives (batch, time, feature) tensors, so all of them must
        # be batch_first. The previous wiring (batch_first=i==0) made layers 2..n
        # read the batch axis as time, so recurrence ran across utterances instead
        # of frames.
        # NOTE: checkpoints trained with that wiring have identical parameter
        # shapes and still load, but they were trained on the mis-wired graph and
        # should be retrained.
        self.birnn_layers = nn.Sequential(*[
            BidirectionalGRU(rnn_dim=rnn_dim if i == 0 else rnn_dim*2,
                             hidden_size=rnn_dim, dropout=dropout, batch_first=True)
            for i in range(n_rnn_layers)
        ])
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim*2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class)
        )

    def forward(self, x):
        x = self.cnn(x)
        x = self.rescnn_layers(x)  # (batch, channel, feature, time)
        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, channel*feature, time)
        x = x.transpose(1, 2)  # (batch, time, feature)
        x = self.fully_connected(x)
        x = self.birnn_layers(x)
        x = self.classifier(x)
        return x  # (batch, time, n_class)


# Data loader

In [None]:
class Dataset(Dataset):
    """Minimal map-style dataset that indexes into an in-memory sequence ``x``.

    NOTE(review): this class shadows ``torch.utils.data.Dataset`` for the rest
    of the notebook, and re-running the cell makes it subclass itself. A
    distinct name would be safer; the name is kept so existing callers still work.
    """
    def __init__(self, x):
        self.x = x

    def __getitem__(self, index):
        return self.x[index]

    def __len__(self):
        # len() works for lists as well as numpy arrays / tensors, where it equals
        # shape[0]. x.shape[0] raised AttributeError on plain lists, which is what
        # load_dataset() returns.
        return len(self.x)

def load_dataset(data, data_type="train"):
    """Compute a spectrogram for every sample in ``data``; transcripts are ignored.

    Args:
        data: iterable of 6-tuples whose first element is a waveform tensor
            (``(1, samples)`` as built by train_to_dataset / test_to_dataset).
        data_type: ``'train'`` uses the augmenting ``train_audio_transforms``;
            any other value uses ``valid_audio_transforms``.

    Returns:
        list of ``(time, n_mels)`` tensors, one per sample. Lengths differ.
    """
    audio_transform = train_audio_transforms if data_type == 'train' else valid_audio_transforms
    # The unused label/length accumulators of the original were removed.
    return [audio_transform(waveform).squeeze(0).transpose(0, 1)
            for (waveform, _, _, _, _, _) in data]

def get_dataloader(data, valid_data=None):
    """Build normalised train/validation DataLoaders of spectrograms (no labels).

    NOTE(review): nothing in this notebook calls this; main() uses
    data_processing() as the collate function instead.

    Args:
        data: training samples (6-tuples, see load_dataset).
        valid_data: optional validation samples. Defaults to ``data`` itself,
            as the original did; "validation" is then just the training audio
            without augmentation.

    Returns:
        ``(train_loader, valid_loader)``. Each batch is a zero-padded
        ``(batch, max_time, n_mels)`` tensor.

    Raises:
        RuntimeError: from ``torch.cat`` when ``data`` is empty.
    """
    x_train = load_dataset(data, "train")
    x_valid = load_dataset(data if valid_data is None else valid_data, 'valid')

    # Spectrograms have different lengths, so np.mean/np.std on the ragged list
    # fail, and ``list - scalar`` raises TypeError. Compute global statistics over
    # all training frames instead (population std, like np.std).
    train_frames = torch.cat(x_train, dim=0)
    mean = train_frames.mean()
    std = ((train_frames - mean) ** 2).mean().sqrt()
    x_train = [(spec - mean) / std for spec in x_train]
    x_valid = [(spec - mean) / std for spec in x_valid]

    def _pad_batch(batch):
        # Default collation stacks tensors and fails on ragged time lengths.
        return nn.utils.rnn.pad_sequence(batch, batch_first=True)

    # A plain list is a valid map-style dataset for DataLoader.
    train_loader = DataLoader(x_train, batch_size=4, shuffle=True, drop_last=False, collate_fn=_pad_batch)
    valid_loader = DataLoader(x_valid, batch_size=4, shuffle=False, drop_last=False, collate_fn=_pad_batch)

    return train_loader, valid_loader


# Train / evaluate

In [None]:
def train(model, device, train_loader, criterion, optimizer, epoch, iter_meter):
    """Run one training epoch with a CTC criterion.

    Each batch is ``(spectrograms, labels, input_lengths, label_lengths)`` in the
    data_processing layout. The model's ``(batch, time, n_class)`` scores are
    log-softmaxed and transposed to ``(time, batch, n_class)`` for ``criterion``.
    ``iter_meter`` is stepped once per batch. Progress is printed every 100
    batches and for the final batch.
    """
    # NOTE(review): the latest batch is also published as the global
    # ``spectrograms`` (handy for interactive inspection); nothing in this
    # notebook reads it.
    global spectrograms
    model.train()
    data_len = len(train_loader.dataset)
    n_batches = len(train_loader)
    for batch_idx, _data in enumerate(train_loader):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)
        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)
        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()
        optimizer.step()
        iter_meter.step()
        # batch_idx counts batches, so it never reached data_len (a sample count)
        # and the last batch was never logged. Compare against the last batch index.
        if batch_idx % 100 == 0 or batch_idx == n_batches - 1:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(spectrograms), data_len,
                100. * batch_idx / n_batches, loss.item()))


In [None]:
def test(model, device, test_loader, criterion, epoch, iter_meter):
    """Evaluate on ``test_loader`` and print mean CTC loss, CER and WER.

    Batches use the data_processing layout. Predictions come from greedy CTC
    decoding. CER/WER are computed on the decoded jamo strings, including the
    "∅" placeholders, not on recomposed syllables. ``epoch`` and ``iter_meter``
    are accepted for symmetry with ``train`` but unused.
    """
    print('\nevaluating...')
    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with torch.no_grad():
        for _data in test_loader:
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            # Log-probabilities over characters, as nn.CTCLoss expects.
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)
            loss = criterion(output, labels, input_lengths, label_lengths)
            test_loss += loss.item() / len(test_loader)
            decoded_preds, decoded_targets = GreedyDecoder(output.transpose(0, 1), labels, label_lengths)
            for pred, target in zip(decoded_preds, decoded_targets):
                test_cer.append(cer(target, pred))
                test_wer.append(wer(target, pred))

    # An empty loader would otherwise divide by zero below.
    if not test_cer:
        print('Test set: no samples to evaluate\n')
        return
    avg_cer = sum(test_cer) / len(test_cer)
    avg_wer = sum(test_wer) / len(test_wer)

    # '{:.4f}' for CER: the old '{:4f}' meant width 4 / six decimals, a typo.
    print('Test set: Average loss: {:.4f}, Average CER: {:.4f} Average WER: {:.4f}\n'.format(test_loss, avg_cer, avg_wer))


In [None]:
def predict(model, device, test_loader, criterion):
    """Greedy-decode every test sample and write the texts into the submission sheet.

    Reads ``data_path + '/test/submission.xlsx'`` and returns it with the
    ``'ReadingLableText'`` column set to the predicted, recomposed Hangul text.
    The column spelling is kept as-is; presumably it matches the sheet header --
    confirm. ``criterion`` is accepted for symmetry with ``test`` but unused.

    NOTE(review): predictions are assigned to sheet rows by position only, so
    test_loader must yield samples in the sheet's row order -- confirm.
    """
    print('\npredicting...')
    result_text = []

    model.eval()
    with torch.no_grad():
        for _data in test_loader:
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms = spectrograms.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            # GreedyDecoder takes (batch, time, n_class) directly. The (empty)
            # labels are only turned into strings, so they can stay on the CPU.
            decoded_preds, _ = GreedyDecoder(output, labels, label_lengths)
            # Keep every item in the batch. The old code kept only [0] and
            # silently dropped predictions whenever batch_size > 1.
            result_text.extend(join_text(pred) for pred in decoded_preds)

    result = pd.DataFrame(result_text, columns=['ReadingLableText'])
    submission = pd.read_excel(data_path + '/test/submission.xlsx')
    if len(result) != len(submission):
        print('warning: {} predictions for {} submission rows'.format(len(result), len(submission)))
    submission['ReadingLableText'] = result['ReadingLableText']
    return submission


# main

In [None]:
class IterMeter(object):
    """Counter of optimisation steps taken across epochs (stored in ``val``)."""
    def __init__(self):
        self.val = 0

    def step(self):
        """Advance the counter by one."""
        self.val = self.val + 1

    def get(self):
        """Return the number of steps counted so far."""
        return self.val

def main(learning_rate=5e-4, batch_size=20, epochs=10, train_model=False):
    """Build the model, optionally train it, then write test-set predictions.

    Loads ``data_path + '/model/model_state_dict.pth'`` if it exists.

    When ``train_model`` is True (the loop that used to be commented out), it:
    - trains for ``epochs`` epochs on ``train_dataset``,
    - evaluates on ``valid_dataset`` after each epoch,
    - saves the weights back to that path.

    Finally it predicts ``test_dataset`` and writes
    ``data_path + '/submission_sample.xlsx'``.
    """
    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 55,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs
    }
    checkpoint_path = data_path + '/model/model_state_dict.pth'

    use_cuda = torch.cuda.is_available()
    torch.cuda.empty_cache()

    torch.manual_seed(7)
    device = torch.device("cuda" if use_cuda else "cpu")
    print(device)

    kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}

    def _make_loader(dataset, data_type, shuffle):
        # One helper for all loaders; data_type picks the collate transform.
        return data.DataLoader(dataset=dataset,
                               batch_size=hparams['batch_size'],
                               shuffle=shuffle,
                               collate_fn=lambda x: data_processing(x, data_type),
                               **kwargs)

    model = SpeechRecognitionModel(
        hparams['n_cnn_layers'], hparams['n_rnn_layers'], hparams['rnn_dim'],
        hparams['n_class'], hparams['n_feats'], hparams['stride'], hparams['dropout']
        ).to(device)
    if os.path.isfile(checkpoint_path):
        # map_location lets a checkpoint saved on GPU load on a CPU-only runtime.
        # torch.load unpickles the file: only load trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        print('model load')
    print('Num Model Parameters', sum([param.nelement() for param in model.parameters()]))

    optimizer = optim.Adam(model.parameters(), hparams['learning_rate'])
    criterion = nn.CTCLoss(blank=54).to(device)  # index 54 = CTC blank

    iter_meter = IterMeter()
    if train_model:
        train_loader = _make_loader(train_dataset, 'train', shuffle=True)
        valid_loader = _make_loader(valid_dataset, 'valid', shuffle=False)
        for epoch in range(1, hparams['epochs'] + 1):
            train(model, device, train_loader, criterion, optimizer, epoch, iter_meter)
            test(model, device, valid_loader, criterion, epoch, iter_meter)
        os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
        torch.save(model.state_dict(), checkpoint_path)  # save the model's state_dict
        print('model save')

    test_loader = _make_loader(test_dataset, 'valid', shuffle=False)
    submission = predict(model, device, test_loader, criterion)
    submission.to_excel(data_path+'/submission_sample.xlsx',index=False)
    print('submission save')

    torch.cuda.empty_cache()

# Run settings. main() loads the saved checkpoint if present and writes
# predictions for the test set; learning_rate and epochs only take effect if
# the training loop inside main() runs.
learning_rate = 5e-4
batch_size = 1
epochs = 1

main(learning_rate, batch_size, epochs)
