<a href="https://colab.research.google.com/github/sangmin213/DialectClassification/blob/main/model_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive at /content/drive so the dataset folders below are readable (Colab-only API).
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Load Data

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
import os
import re
from matplotlib import pyplot as plt
from glob import glob
import numpy as np
import pickle
from tqdm import tqdm
import time
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.model_selection import KFold
from torch.utils.data import Subset

In [None]:
# Class index <-> region name; region_shortening holds the matching short axis labels for plots.
index2region = {0: 'gangwon', 1: 'gyeongsang', 2: 'jeonla', 3: 'chungcheong', 4: 'jeju'}
region2index = {v: k for k, v in index2region.items()}
region_shortening = ['GW', 'GS', 'JL', 'CC', 'JJ']
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed NumPy and PyTorch (torch.manual_seed covers all devices) so weight initialisation and
# any shuffling are reproducible across Restart & Run All. NOTE: some cuDNN kernels may still
# be nondeterministic on GPU.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Dataset root on the mounted Google Drive (full dataset); region folders live directly under it.
dataset_dir = '/content/drive/MyDrive/' + '인공지능_프로젝트_Team_12/' + '데이터/'

In [None]:
# Sanity check: globbing the literal dataset_dir returns [dataset_dir] when the folder exists, [] otherwise.
region_dir = glob(pathname=dataset_dir)
region_dir

['/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/']

In [None]:
# Collect each region's recording folders (<dataset_dir>/*_<region>/*) without exec().
# glob() returns paths in arbitrary order; sort them so the folder (and hence sample)
# selection done later with max_num / [:500] slicing is reproducible between runs.
region_dirs = {
    region: sorted(glob(dataset_dir + f'*_{region}/*'))
    for region in index2region.values()
}
gangwon_dirs = region_dirs['gangwon']
gyeongsang_dirs = region_dirs['gyeongsang']
jeonla_dirs = region_dirs['jeonla']
chungcheong_dirs = region_dirs['chungcheong']
jeju_dirs = region_dirs['jeju']
jeonla_dirs

['/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000018',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000027',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000012',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000032',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000005',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000014',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000006',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000019',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000024',
 '/content/drive/MyDrive/인공지능_프로젝트_Team_12/데이터/small_dataset/preprocessed_jeonla/DJDD20000015']

In [None]:
def _load_feature_pickle(sample_dir, feature):
    """Unpickle the first ``*_{feature}.pickle`` file found in ``sample_dir``.

    NOTE: pickle.load can execute arbitrary code; only use on trusted, self-produced files.
    Raises IndexError if no matching file exists.
    """
    path = glob(sample_dir + f'/*_{feature}.pickle')[0]
    with open(path, "rb") as f:
        return pickle.load(f)


def make_tuple_data(dirs, max_num):
    """Load (spectrogram, MFCC, chroma) samples from the first ``max_num`` folders of ``dirs``.

    Each folder must contain ``*_spectro.pickle``, ``*_mfcc.pickle`` and ``*_chroma.pickle``.
    The arrays from all folders are concatenated along axis 0, and row ``i`` of the three
    stacked arrays forms sample ``i``.

    Args:
        dirs: iterable of folder paths.
        max_num: maximum number of folders to read.

    Returns:
        list of ``(spec, mfcc, chroma)`` tuples; ``[]`` when no folder is read
        (``max_num <= 0`` or ``dirs`` empty).
    """
    spectro_parts, mfcc_parts, chroma_parts = [], [], []
    for i, sample_dir in enumerate(dirs):
        if i >= max_num:
            break
        spectro_parts.append(_load_feature_pickle(sample_dir, 'spectro'))
        mfcc_parts.append(_load_feature_pickle(sample_dir, 'mfcc'))
        chroma_parts.append(_load_feature_pickle(sample_dir, 'chroma'))

    if not spectro_parts:
        return []

    # Concatenate once at the end instead of re-copying the growing array on every folder.
    spectro_data = np.concatenate(spectro_parts, axis=0)
    mfcc_data = np.concatenate(mfcc_parts, axis=0)
    chroma_data = np.concatenate(chroma_parts, axis=0)

    return list(zip(spectro_data, mfcc_data, chroma_data))

def make_tuple(max_num=2):
    """Load up to ``max_num`` folders per region via make_tuple_data.

    Returns the five per-region sample lists ordered
    (jeonla, chungcheong, gyeongsang, jeju, gangwon).
    """
    ordered_region_dirs = (jeonla_dirs, chungcheong_dirs, gyeongsang_dirs, jeju_dirs, gangwon_dirs)
    return tuple(make_tuple_data(region_dirs_list, max_num) for region_dirs_list in ordered_region_dirs)

def print_data(r_data, region):
    """Print the sample count, tuple arity and first spectrogram shape for ``region``.

    Prints nothing for an empty ``r_data``.
    """
    if not len(r_data):
        return
    first_sample = r_data[0]
    print(f"{region} data num: ", len(r_data))
    print(f"{region} tuple size", len(first_sample))
    print(f"{region} spec shape", first_sample[0].shape)
    print(f"{region} spec shape", r_data[0][0].shape)

# Load up to 1000 folders per region, then report each region's sample count and shapes.
jeonla_data, chungcheong_data, gyeongsang_data, jeju_data, gangwon_data = make_tuple(1000)
for region_samples, region_name in [
    (jeonla_data, 'jeonla'),
    (chungcheong_data, 'chungcheong'),
    (gyeongsang_data, 'gyeongsang'),
    (jeju_data, 'jeju'),
    (gangwon_data, 'gangwon'),
]:
    print_data(region_samples, region_name)

jeonla data num:  913
jeonla tuple size 3
jeonla spec shape (201, 501)
chungcheong data num:  1108
chungcheong tuple size 3
chungcheong spec shape (201, 501)
gyeongsang data num:  842
gyeongsang tuple size 3
gyeongsang spec shape (201, 501)
jeju data num:  756
jeju tuple size 3
jeju spec shape (201, 501)
gangwon data num:  1183
gangwon tuple size 3
gangwon spec shape (201, 501)


In [None]:
def attach_one_hot(samples, region):
    """Pair every sample with a one-hot label list for ``region``.

    The label has ``len(index2region)`` entries with a 1 at ``region2index[region]``.
    Returns a list of ``(sample, label)`` tuples.
    """
    label_index = region2index[region]
    labeled = []
    for sample in samples:
        # Build a fresh list per sample so labels never alias one shared object.
        one_hot = [0] * len(index2region)
        one_hot[label_index] = 1
        labeled.append((sample, one_hot))
    return labeled


jeonla_data_l = attach_one_hot(jeonla_data, 'jeonla')
chungcheong_data_l = attach_one_hot(chungcheong_data, 'chungcheong')
gyeongsang_data_l = attach_one_hot(gyeongsang_data, 'gyeongsang')
jeju_data_l = attach_one_hot(jeju_data, 'jeju')
gangwon_data_l = attach_one_hot(gangwon_data, 'gangwon')

In [None]:
# Balance the classes: take the first SAMPLES_PER_REGION labeled samples from each region.
# Plain list concatenation is used because np.concatenate would try to build an ndarray from
# ragged ((spec, mfcc, chroma), label) tuples, which NumPy >= 1.24 rejects with ValueError.
SAMPLES_PER_REGION = 500
datasumup = (
    jeonla_data_l[:SAMPLES_PER_REGION]
    + chungcheong_data_l[:SAMPLES_PER_REGION]
    + gangwon_data_l[:SAMPLES_PER_REGION]
    + jeju_data_l[:SAMPLES_PER_REGION]
    + gyeongsang_data_l[:SAMPLES_PER_REGION]
)



In [None]:
from sklearn.preprocessing import normalize
class MultiModalDataset(Dataset):
    """Map-style dataset over ``((spec, mfcc, chroma), label)`` pairs.

    Each item comes back as ``((spec, mfcc, chroma), label)``. Every feature matrix is passed
    through sklearn ``normalize``, converted to a float32 tensor and given a leading channel
    axis of size 1. The label becomes a float32 tensor.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        features, target = self.data[idx]
        feature_tensors = []
        for feature in features:
            scaled = normalize(feature)
            feature_tensors.append(torch.tensor(scaled, dtype=torch.float32).unsqueeze(0))
        spec_t, mfcc_t, chroma_t = feature_tensors
        return (spec_t, mfcc_t, chroma_t), torch.tensor(target, dtype=torch.float32)

    def __len__(self):
        return len(self.data)

In [None]:
# Wrap the balanced sample list in the Dataset; the bare expression displays the total sample count.
dataset = MultiModalDataset(data=datasumup)
len(dataset)

2500

# ResNet

In [None]:
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 conv+BN layers plus a shortcut connection.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the block.
        stride: stride of the first 3x3 conv (2 halves H and W).
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BasicBlock, self).__init__()

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.seq1 = nn.Sequential(self.conv1, self.bn1, self.relu)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.seq2 = nn.Sequential(self.conv2, self.bn2)

        # The shortcut needs a projection whenever the main path changes the channel count
        # OR the spatial size; otherwise the identity is added.
        self.down_flag = in_channels != out_channels or stride != 1

        # 1x1 projection using the same stride as conv1, so both paths end at the same H x W.
        # Created unconditionally (as before) so state_dict keys stay stable.
        self.downsample = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(1, 1), stride=stride, padding=0, bias=False)

    def forward(self, x):
        shortcut = self.downsample(x) if self.down_flag else x
        y = self.seq2(self.seq1(x))
        # Standard ResNet ordering: add the shortcut first, then apply the ReLU.
        return self.relu(y + shortcut)
        

In [None]:
class ResNet18(nn.Module):
    """ResNet-style CNN encoder for a single feature stream ('spec', 'mfcc' or 'chroma').

    forward() maps (batch, in_channels, H, W) to (batch, output_dim) ReLU-activated features.
    AdaptiveAvgPool2d makes the head independent of H and W. The class also carries
    train_/predict/plot helpers for training on one stream in isolation.

    Args:
        in_channels: channels of the input tensor.
        output_dim: size of the output feature vector.
        model_type: which element of the (spec, mfcc, chroma) batch triple train_/predict consume.
        best_model_save_path: where train_() saves the best-validation checkpoint.
    """

    def __init__(self, in_channels, output_dim=256, model_type='spec', best_model_save_path="./ResNet_best_model.pt"):
        super(ResNet18, self).__init__()

        self.best_model_save_path = best_model_save_path
        self.data_type = model_type

        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=(7, 7), stride=2, padding=3)
        self.BN1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)

        # Stem: conv -> BN -> ReLU -> max-pool. The ReLU was previously missing, leaving the stem
        # purely linear. ReLU has no parameters, so state_dict keys are unchanged.
        self.seq1 = nn.Sequential(self.conv1, self.BN1, self.relu, self.pool1)

        self.seq2 = nn.Sequential(BasicBlock(64, 64), BasicBlock(64, 64))
        self.seq3 = nn.Sequential(BasicBlock(64, 64), BasicBlock(64, 128, stride=2))
        self.seq4 = nn.Sequential(BasicBlock(128, 128), BasicBlock(128, 128))
        self.seq5 = nn.Sequential(BasicBlock(128, 128), BasicBlock(128, 256, stride=2))

        # Global average pooling -> (batch, 256, 1, 1) whatever the input H and W.
        self.avg_pool1 = nn.AdaptiveAvgPool2d((1, 1))
        self.fc1 = nn.Linear(256, output_dim)
        self.lastlayer = nn.Sequential(self.fc1, self.relu)

        # Per-epoch history, filled by train_() and read by plot().
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []

    def forward(self, x):
        y = self.seq1(x)
        y = self.seq2(y)
        y = self.seq3(y)
        y = self.seq4(y)
        y = self.seq5(y)
        y = self.avg_pool1(y)
        y = y.view(y.shape[0], -1)
        y = self.lastlayer(y)

        return y

    def _select_input(self, batch_data, device):
        """Return the ``self.data_type`` tensor from a ``(spec, mfcc, chroma)`` batch, moved to ``device``.

        Raises:
            ValueError: if ``self.data_type`` is not 'spec', 'mfcc' or 'chroma'.
        """
        spec, mfcc, chroma = batch_data
        streams = {"spec": spec, "mfcc": mfcc, "chroma": chroma}
        if self.data_type not in streams:
            raise ValueError(f"unknown data_type {self.data_type!r}; expected 'spec', 'mfcc' or 'chroma'")
        return streams[self.data_type].to(device)

    def train_(self, train_loader, val_loader, learning_rate, epochs, device):
        """Train with Adam and cross-entropy, validating after every epoch.

        Labels from the loaders are one-hot rows (as built by MultiModalDataset). They are fed to
        CrossEntropyLoss as-is and converted to class indices for accuracy. Each time validation
        accuracy beats the best seen so far, the weights are saved to ``self.best_model_save_path``.
        Last-epoch predictions are kept (as numpy arrays) for plot().

        Returns:
            (best_acc, best_epoch): best validation accuracy and its 1-based epoch,
            or (-1, -1) if no epoch ran.
        """
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []
        self.pred_labels_train = []
        self.real_labels_train = []
        self.pred_labels_val = None
        self.real_labels_val = None

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.loss_f = nn.CrossEntropyLoss()

        best_epoch = -1
        best_acc = -1

        for epoch in range(1, epochs + 1):
            is_last_epoch = epoch == epochs
            start_time = time.time()
            epoch_loss = 0.0
            correct = 0
            total = 0
            n_batches = 0
            self.train()

            for batch_data, batch_label in tqdm(train_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                self.optimizer.zero_grad()
                pred = self.forward(inputs)  # (batch_size, n_classes)
                loss = self.loss_f(pred, batch_label)
                loss.backward()
                self.optimizer.step()

                n_batches += 1
                epoch_loss += loss.item()
                pred_indices = pred.argmax(dim=1)
                true_indices = batch_label.argmax(dim=1)  # one-hot -> class index
                total += inputs.shape[0]
                correct += pred_indices.eq(true_indices).sum().item()

                if is_last_epoch:
                    # Store on CPU so the collected labels don't hold GPU memory.
                    self.pred_labels_train.append(pred_indices.cpu())
                    self.real_labels_train.append(true_indices.cpu())

            print(f"epoch {epoch} time: {time.time() - start_time}sec(s).")

            epoch_loss /= max(n_batches, 1)
            epoch_acc = correct / total if total else 0.0
            self.train_loss.append(epoch_loss)
            self.train_accuracy.append(epoch_acc)
            print(f"epoch {epoch} train accuracy: {epoch_acc}")
            print(f"epoch {epoch} loss: {epoch_loss}")

            predicted, labels, val_loss = self.predict(val_loader, device)
            if is_last_epoch:
                self.pred_labels_val = predicted.cpu().numpy()
                self.real_labels_val = labels.cpu().numpy()
            val_acc = predicted.eq(labels).sum().item() / len(predicted) if len(predicted) else 0.0
            print(f"epoch {epoch} val accuracy: {val_acc}")
            print(f"epoch {epoch} val loss: {val_loss}")

            # Checkpoint on a new best validation accuracy (this used to compare against train accuracy).
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(self.state_dict(), self.best_model_save_path)

            self.val_accuracy.append(val_acc)
            self.val_loss.append(val_loss)

        self.pred_labels_train = torch.cat(self.pred_labels_train).numpy() if self.pred_labels_train else np.array([], dtype=np.int64)
        self.real_labels_train = torch.cat(self.real_labels_train).numpy() if self.real_labels_train else np.array([], dtype=np.int64)

        print("Finish!")

        return best_acc, best_epoch

    def predict(self, test_loader, device):
        """Evaluate on ``test_loader`` without tracking gradients.

        Returns:
            (predicted, labels, val_loss): 1-D tensors of predicted and true class indices on
            ``device``, and the mean per-batch cross-entropy (0.0 for an empty loader).
        """
        if getattr(self, "loss_f", None) is None:
            # loss_f is normally created in train_(); also allow predict() on a freshly built/loaded model.
            self.loss_f = nn.CrossEntropyLoss()
        self.eval()
        labels = []
        predicted = []
        val_loss = 0.0
        with torch.no_grad():
            for batch_data, batch_label in tqdm(test_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                pred = self.forward(inputs)
                val_loss += self.loss_f(pred, batch_label).item()

                predicted.append(pred.argmax(dim=1))
                labels.append(batch_label.argmax(dim=1))  # one-hot -> class index

        if not predicted:
            empty = torch.empty(0, dtype=torch.long, device=device)
            return empty, empty.clone(), 0.0
        return torch.cat(predicted), torch.cat(labels), val_loss / len(predicted)

    def plot(self, which):
        """Plot a training curve or a confusion matrix and save it as a PNG under ./result/.

        Args:
            which: 'train_loss', 'train_acc', 'val_acc' or 'val_loss' for per-epoch curves, or
                'confusion_train', 'confusion_normalize_train', 'confusion_val',
                'confusion_normalize_val' for last-epoch confusion matrices ('normalize' = row-normalised).

        Raises:
            ValueError: for any other ``which``.
        """
        curves = {
            'train_loss': self.train_loss,
            'train_acc': self.train_accuracy,
            'val_acc': self.val_accuracy,
            'val_loss': self.val_loss,
        }
        # which -> (split, ConfusionMatrixDisplay `normalize` argument)
        confusions = {
            'confusion_train': ('train', None),
            'confusion_normalize_train': ('train', 'true'),
            'confusion_val': ('val', None),
            'confusion_normalize_val': ('val', 'true'),
        }
        if which not in curves and which not in confusions:
            raise ValueError(f"unknown plot type {which!r}")

        os.makedirs("./result", exist_ok=True)
        # Prefix with the class name so different models don't overwrite each other's figures.
        save_path = f"./result/{type(self).__name__}_{which}_{self.data_type}.png"

        if which in curves:
            values = curves[which]
            plt.xlabel("epoch")
            plt.ylabel(which)
            plt.title(which)
            plt.plot(range(1, len(values) + 1), values, label=which)
        else:
            split, norm_mode = confusions[which]
            if split == 'train':
                y_true, y_pred = self.real_labels_train, self.pred_labels_train
            else:
                y_true, y_pred = self.real_labels_val, self.pred_labels_val
            # Pin all class indices so display_labels still line up when a class is absent.
            ConfusionMatrixDisplay.from_predictions(
                y_true,
                y_pred,
                labels=list(range(len(region_shortening))),
                display_labels=region_shortening,
                normalize=norm_mode,
            )
            plt.title(f'{split} confusion matrix')

        plt.savefig(save_path)
        plt.show()

# LeNet

In [None]:
class Block(nn.Module):
    """Two-conv residual block used by LeNet.

    Main path: conv1 (stride 1) -> BN -> ReLU -> conv2 (``stride``) -> BN.
    When ``stride == 1``, a 1x1-conv shortcut is added. Downsampling blocks (stride != 1) skip
    the shortcut. A final ReLU is applied in both cases.

    Args:
        input_channel, output_channel: channel counts.
        kernel_size: kernel of both convs.
        stride: stride of conv2 only.
        padding: padding of both convs.
    """
    def __init__(self, input_channel, output_channel, kernel_size, stride=1, padding=0):
        super(Block, self).__init__()

        self.conv1 = nn.Conv2d(input_channel, output_channel, kernel_size=kernel_size, padding=padding)  # no stride
        self.conv2 = nn.Conv2d(output_channel, output_channel, kernel_size=kernel_size, stride=stride, padding=padding)  # strided only when downsampling
        # One BatchNorm per conv. Previously a single BN instance sat after both convs, so the two
        # positions shared running statistics and affine parameters.
        self.bn = nn.BatchNorm2d(output_channel)
        self.bn2 = nn.BatchNorm2d(output_channel)
        self.relu = nn.ReLU()

        self.layer = nn.Sequential(self.conv1, self.bn, self.relu, self.conv2, self.bn2)

        self.stride = stride
        # 1x1 projection shortcut (stride 1), used even when channel counts match.
        self.iden = nn.Conv2d(input_channel, output_channel, kernel_size=(1, 1), stride=1)

    def forward(self, x):
        y = self.layer(x)
        if self.stride == 1:
            # Strided blocks are downsampling stages, so the residual is skipped there:
            # the stride-1 shortcut would not match the main path's reduced H x W.
            y = y + self.iden(x)
        return self.relu(y)

In [None]:
class LeNet(nn.Module):
    """Residual CNN for one feature stream (called "LeNet" in this project).

    Args:
        data_type: which element of the (spec, mfcc, chroma) batch triple to consume. It also
            selects the stem conv and the last stage-3 kernel.
        method: "origin" gives a 5-way classifier head (with its own loss/optimizer);
            "multimodal" gives a 256-d feature head for fusion.
        best_model_save_path: where train_() saves the best-validation checkpoint.

    Raises:
        ValueError: for an unknown ``data_type`` or ``method``.
    """
    def __init__(self, data_type="mfcc", method="origin", best_model_save_path="./LeNet_best_model.pt"):
        super(LeNet, self).__init__()

        self.best_model_save_path = best_model_save_path
        self.data_type = data_type
        self.method = method

        # Stem conv per feature type. Input shapes noted in the original code:
        # mfcc (1,100,501), spec (1,201,501), chroma (1,12,501).
        if data_type == "mfcc":
            self.conv1 = nn.Conv2d(1, 64, kernel_size=(6, 7), stride=2, padding=3)
        elif data_type == "spec":
            self.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=2, padding=3)
        elif data_type == "chroma":
            # Stride 1 here (vs. 2 for mfcc/spec) on the short 12-bin chroma axis.
            self.conv1 = nn.Conv2d(1, 64, kernel_size=(6, 7), stride=1, padding=3)
        else:
            raise ValueError(f"unknown data_type {data_type!r}; expected 'mfcc', 'spec' or 'chroma'")
        self.maxpool = nn.MaxPool2d(kernel_size=(3, 3), stride=2, padding=1)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # global average pool
        self.relu = nn.ReLU()

        self.seq1 = nn.Sequential(self.conv1, self.relu, self.maxpool)
        # Two shape-preserving residual blocks, then a stride-2 block (no residual) that doubles channels.
        self.seq2 = nn.Sequential(Block(64, 64, (3, 3), padding=1), Block(64, 64, (3, 3), padding=1), Block(64, 128, (3, 4), stride=2, padding=1))
        # Only the kernel of the final (downsampling) stage-3 block differs per feature type.
        if data_type == "mfcc":
            last_kernel = (3, 3)
        elif data_type == "spec":
            last_kernel = (4, 3)
        else:  # chroma
            last_kernel = (2, 3)
        self.seq3 = nn.Sequential(Block(128, 128, (3, 3), padding=1), Block(128, 128, (3, 3), padding=1), Block(128, 256, last_kernel, stride=2, padding=1))
        # Ends in global average pooling -> (batch, 256, 1, 1), so the head is independent of input H x W.
        self.seq4 = nn.Sequential(Block(256, 256, (3, 3), padding=1), Block(256, 256, (3, 3), padding=1), self.avgpool)

        if self.method == "multimodal":
            self.fc = nn.Linear(256, 256)
        elif self.method == "origin":
            self.fc = nn.Linear(256, 5)
            self.loss = nn.CrossEntropyLoss()
            self.optimizer = optim.Adam(self.parameters(), lr=0.0001)
        else:
            raise ValueError(f"unknown method {method!r}; expected 'origin' or 'multimodal'")

        # Per-epoch history, filled by train_() and read by plot().
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []

    def forward(self, x):
        y = self.seq1(x)
        y = self.seq2(y)
        y = self.seq3(y)
        y = self.seq4(y)
        y = y.view(y.shape[0], -1)
        y = self.fc(y)
        return y

    def _select_input(self, batch_data, device):
        """Return the ``self.data_type`` tensor from a ``(spec, mfcc, chroma)`` batch, moved to ``device``.

        Raises:
            ValueError: if ``self.data_type`` is not 'spec', 'mfcc' or 'chroma'.
        """
        spec, mfcc, chroma = batch_data
        streams = {"spec": spec, "mfcc": mfcc, "chroma": chroma}
        if self.data_type not in streams:
            raise ValueError(f"unknown data_type {self.data_type!r}; expected 'spec', 'mfcc' or 'chroma'")
        return streams[self.data_type].to(device)

    def train_(self, train_loader, val_loader, learning_rate, epochs, device):
        """Train with Adam and cross-entropy, validating after every epoch.

        Labels from the loaders are one-hot rows (as built by MultiModalDataset). They are fed to
        CrossEntropyLoss as-is and converted to class indices for accuracy. Each time validation
        accuracy beats the best seen so far, the weights are saved to ``self.best_model_save_path``.
        Last-epoch predictions are kept (as numpy arrays) for plot().

        Returns:
            (best_acc, best_epoch): best validation accuracy and its 1-based epoch,
            or (-1, -1) if no epoch ran.
        """
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []
        self.pred_labels_train = []
        self.real_labels_train = []
        self.pred_labels_val = None
        self.real_labels_val = None

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.loss_f = nn.CrossEntropyLoss()

        best_epoch = -1
        best_acc = -1

        for epoch in range(1, epochs + 1):
            is_last_epoch = epoch == epochs
            start_time = time.time()
            epoch_loss = 0.0
            correct = 0
            total = 0
            n_batches = 0
            self.train()

            for batch_data, batch_label in tqdm(train_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                self.optimizer.zero_grad()
                pred = self.forward(inputs)  # (batch_size, n_classes)
                loss = self.loss_f(pred, batch_label)
                loss.backward()
                self.optimizer.step()

                n_batches += 1
                epoch_loss += loss.item()
                pred_indices = pred.argmax(dim=1)
                true_indices = batch_label.argmax(dim=1)  # one-hot -> class index
                total += inputs.shape[0]
                correct += pred_indices.eq(true_indices).sum().item()

                if is_last_epoch:
                    # Store on CPU so the collected labels don't hold GPU memory.
                    self.pred_labels_train.append(pred_indices.cpu())
                    self.real_labels_train.append(true_indices.cpu())

            print(f"epoch {epoch} time: {time.time() - start_time}sec(s).")

            epoch_loss /= max(n_batches, 1)
            epoch_acc = correct / total if total else 0.0
            self.train_loss.append(epoch_loss)
            self.train_accuracy.append(epoch_acc)
            print(f"epoch {epoch} train accuracy: {epoch_acc}")
            print(f"epoch {epoch} loss: {epoch_loss}")

            predicted, labels, val_loss = self.predict(val_loader, device)
            if is_last_epoch:
                self.pred_labels_val = predicted.cpu().numpy()
                self.real_labels_val = labels.cpu().numpy()
            val_acc = predicted.eq(labels).sum().item() / len(predicted) if len(predicted) else 0.0
            print(f"epoch {epoch} val accuracy: {val_acc}")
            print(f"epoch {epoch} val loss: {val_loss}")

            # Checkpoint on a new best validation accuracy (this used to compare against train accuracy).
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(self.state_dict(), self.best_model_save_path)

            self.val_accuracy.append(val_acc)
            self.val_loss.append(val_loss)

        self.pred_labels_train = torch.cat(self.pred_labels_train).numpy() if self.pred_labels_train else np.array([], dtype=np.int64)
        self.real_labels_train = torch.cat(self.real_labels_train).numpy() if self.real_labels_train else np.array([], dtype=np.int64)

        print("Finish!")

        return best_acc, best_epoch

    def predict(self, test_loader, device):
        """Evaluate on ``test_loader`` without tracking gradients.

        Returns:
            (predicted, labels, val_loss): 1-D tensors of predicted and true class indices on
            ``device``, and the mean per-batch cross-entropy (0.0 for an empty loader).
        """
        if getattr(self, "loss_f", None) is None:
            # loss_f is normally created in train_(); also allow predict() on a freshly built/loaded model.
            self.loss_f = nn.CrossEntropyLoss()
        self.eval()
        labels = []
        predicted = []
        val_loss = 0.0
        with torch.no_grad():
            for batch_data, batch_label in tqdm(test_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                pred = self.forward(inputs)
                val_loss += self.loss_f(pred, batch_label).item()

                predicted.append(pred.argmax(dim=1))
                labels.append(batch_label.argmax(dim=1))  # one-hot -> class index

        if not predicted:
            empty = torch.empty(0, dtype=torch.long, device=device)
            return empty, empty.clone(), 0.0
        return torch.cat(predicted), torch.cat(labels), val_loss / len(predicted)

    def plot(self, which):
        """Plot a training curve or a confusion matrix and save it as a PNG under ./result/.

        Args:
            which: 'train_loss', 'train_acc', 'val_acc' or 'val_loss' for per-epoch curves, or
                'confusion_train', 'confusion_normalize_train', 'confusion_val',
                'confusion_normalize_val' for last-epoch confusion matrices ('normalize' = row-normalised).

        Raises:
            ValueError: for any other ``which``.
        """
        curves = {
            'train_loss': self.train_loss,
            'train_acc': self.train_accuracy,
            'val_acc': self.val_accuracy,
            'val_loss': self.val_loss,
        }
        # which -> (split, ConfusionMatrixDisplay `normalize` argument)
        confusions = {
            'confusion_train': ('train', None),
            'confusion_normalize_train': ('train', 'true'),
            'confusion_val': ('val', None),
            'confusion_normalize_val': ('val', 'true'),
        }
        if which not in curves and which not in confusions:
            raise ValueError(f"unknown plot type {which!r}")

        os.makedirs("./result", exist_ok=True)
        # Prefix with the class name so different models don't overwrite each other's figures.
        save_path = f"./result/{type(self).__name__}_{which}_{self.data_type}.png"

        if which in curves:
            values = curves[which]
            plt.xlabel("epoch")
            plt.ylabel(which)
            plt.title(which)
            plt.plot(range(1, len(values) + 1), values, label=which)
        else:
            split, norm_mode = confusions[which]
            if split == 'train':
                y_true, y_pred = self.real_labels_train, self.pred_labels_train
            else:
                y_true, y_pred = self.real_labels_val, self.pred_labels_val
            # Pin all class indices so display_labels still line up when a class is absent.
            ConfusionMatrixDisplay.from_predictions(
                y_true,
                y_pred,
                labels=list(range(len(region_shortening))),
                display_labels=region_shortening,
                normalize=norm_mode,
            )
            plt.title(f'{split} confusion matrix')

        plt.savefig(save_path)
        plt.show()

# LSTM

In [None]:
class LSTM(nn.Module):
    """Single-layer LSTM over the time axis of one feature stream.

    Each time step's hidden state is flattened and fed to a linear head.

    Args:
        data_type: 'mfcc', 'spec' or 'chroma'. Selects the per-step input size and which element
            of the (spec, mfcc, chroma) batch triple train_/predict consume.
        method: "origin" gives 5 region logits; "multimodal" gives a 256-d feature vector.
        best_model_save_path: where train_() saves the best-validation checkpoint.
        seq_len: number of time steps the linear head expects (501 reproduces the former
            hard-coded 501 * 64 = 32064 input width).

    Raises:
        ValueError: for an unknown ``data_type`` or ``method``.
    """
    def __init__(self, data_type="mfcc", method="origin", best_model_save_path="./LSTM_best_model.pt", seq_len=501):
        super(LSTM, self).__init__()

        self.best_model_save_path = best_model_save_path
        self.data_type = data_type
        self.method = method

        self.hidden_size = 64
        self.num_layers = 1
        self.seq_len = seq_len
        # Feature rows per time step for each stream.
        input_sizes = {"mfcc": 100, "spec": 201, "chroma": 12}
        if data_type not in input_sizes:
            raise ValueError(f"unknown data_type {data_type!r}; expected 'mfcc', 'spec' or 'chroma'")
        self.input_size = input_sizes[data_type]

        self.lstm = nn.LSTM(input_size=self.input_size, hidden_size=self.hidden_size, num_layers=self.num_layers, batch_first=True)

        # forward() flattens all time steps' hidden states: seq_len * hidden_size features.
        flat_dim = seq_len * self.hidden_size
        if self.method == "multimodal":
            self.linear = nn.Linear(flat_dim, 256)  # 256-d feature vector for multimodal fusion
        elif self.method == "origin":
            self.linear = nn.Linear(flat_dim, 5)  # one logit per region
        else:
            raise ValueError(f"unknown method {method!r}; expected 'origin' or 'multimodal'")

        # Per-epoch history, filled by train_() and read by plot().
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []

    def forward(self, x):
        """Map x of shape (batch, C, F, T) with C*F == input_size and T == seq_len to (batch, out)."""
        # (batch, C, F, T) -> (batch, C*F, T)
        x = x.reshape(x.shape[0], -1, x.shape[3])
        # (batch, F, T) -> (batch, T, F): time becomes the LSTM sequence axis. This must be a
        # transpose; the former .view() reinterpreted memory and mixed feature and time values.
        x = x.transpose(1, 2).contiguous()

        # Zero initial hidden/cell states, created on the input's device (no global `device` needed).
        h_0 = torch.zeros(self.num_layers, x.shape[0], self.hidden_size, device=x.device, dtype=x.dtype)
        c_0 = torch.zeros_like(h_0)

        y, _ = self.lstm(x, (h_0, c_0))
        y = y.reshape(y.shape[0], -1)  # (batch, seq_len * hidden_size)
        return self.linear(y)

    def _select_input(self, batch_data, device):
        """Return the ``self.data_type`` tensor from a ``(spec, mfcc, chroma)`` batch, moved to ``device``.

        Raises:
            ValueError: if ``self.data_type`` is not 'spec', 'mfcc' or 'chroma'.
        """
        spec, mfcc, chroma = batch_data
        streams = {"spec": spec, "mfcc": mfcc, "chroma": chroma}
        if self.data_type not in streams:
            raise ValueError(f"unknown data_type {self.data_type!r}; expected 'spec', 'mfcc' or 'chroma'")
        return streams[self.data_type].to(device)

    def train_(self, train_loader, val_loader, learning_rate, epochs, device):
        """Train with Adam and cross-entropy, validating after every epoch.

        Labels from the loaders are one-hot rows (as built by MultiModalDataset). They are fed to
        CrossEntropyLoss as-is and converted to class indices for accuracy. Each time validation
        accuracy beats the best seen so far, the weights are saved to ``self.best_model_save_path``.
        Last-epoch predictions are kept (as numpy arrays) for plot().

        Returns:
            (best_acc, best_epoch): best validation accuracy and its 1-based epoch,
            or (-1, -1) if no epoch ran.
        """
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []
        self.pred_labels_train = []
        self.real_labels_train = []
        self.pred_labels_val = None
        self.real_labels_val = None

        self.optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.loss_f = nn.CrossEntropyLoss()

        best_epoch = -1
        best_acc = -1

        for epoch in range(1, epochs + 1):
            is_last_epoch = epoch == epochs
            start_time = time.time()
            epoch_loss = 0.0
            correct = 0
            total = 0
            n_batches = 0
            self.train()

            for batch_data, batch_label in tqdm(train_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                self.optimizer.zero_grad()
                pred = self.forward(inputs)  # (batch_size, n_classes)
                loss = self.loss_f(pred, batch_label)
                loss.backward()
                self.optimizer.step()

                n_batches += 1
                epoch_loss += loss.item()
                pred_indices = pred.argmax(dim=1)
                true_indices = batch_label.argmax(dim=1)  # one-hot -> class index
                total += inputs.shape[0]
                correct += pred_indices.eq(true_indices).sum().item()

                if is_last_epoch:
                    # Store on CPU so the collected labels don't hold GPU memory.
                    self.pred_labels_train.append(pred_indices.cpu())
                    self.real_labels_train.append(true_indices.cpu())

            print(f"epoch {epoch} time: {time.time() - start_time}sec(s).")

            epoch_loss /= max(n_batches, 1)
            epoch_acc = correct / total if total else 0.0
            self.train_loss.append(epoch_loss)
            self.train_accuracy.append(epoch_acc)
            print(f"epoch {epoch} train accuracy: {epoch_acc}")
            print(f"epoch {epoch} loss: {epoch_loss}")

            predicted, labels, val_loss = self.predict(val_loader, device)
            if is_last_epoch:
                self.pred_labels_val = predicted.cpu().numpy()
                self.real_labels_val = labels.cpu().numpy()
            val_acc = predicted.eq(labels).sum().item() / len(predicted) if len(predicted) else 0.0
            print(f"epoch {epoch} val accuracy: {val_acc}")
            print(f"epoch {epoch} val loss: {val_loss}")

            # Checkpoint on a new best validation accuracy (this used to compare against train accuracy).
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(self.state_dict(), self.best_model_save_path)

            self.val_accuracy.append(val_acc)
            self.val_loss.append(val_loss)

        self.pred_labels_train = torch.cat(self.pred_labels_train).numpy() if self.pred_labels_train else np.array([], dtype=np.int64)
        self.real_labels_train = torch.cat(self.real_labels_train).numpy() if self.real_labels_train else np.array([], dtype=np.int64)

        print("Finish!")

        return best_acc, best_epoch

    def predict(self, test_loader, device):
        """Evaluate on ``test_loader`` without tracking gradients.

        Returns:
            (predicted, labels, val_loss): 1-D tensors of predicted and true class indices on
            ``device``, and the mean per-batch cross-entropy (0.0 for an empty loader).
        """
        if getattr(self, "loss_f", None) is None:
            # loss_f is normally created in train_(); also allow predict() on a freshly built/loaded model.
            self.loss_f = nn.CrossEntropyLoss()
        self.eval()
        labels = []
        predicted = []
        val_loss = 0.0
        with torch.no_grad():
            for batch_data, batch_label in tqdm(test_loader):
                inputs = self._select_input(batch_data, device)
                batch_label = batch_label.to(device)

                pred = self.forward(inputs)
                val_loss += self.loss_f(pred, batch_label).item()

                predicted.append(pred.argmax(dim=1))
                labels.append(batch_label.argmax(dim=1))  # one-hot -> class index

        if not predicted:
            empty = torch.empty(0, dtype=torch.long, device=device)
            return empty, empty.clone(), 0.0
        return torch.cat(predicted), torch.cat(labels), val_loss / len(predicted)

    def plot(self, which):
        """Plot a training curve or a confusion matrix and save it as a PNG under ./result/.

        Args:
            which: 'train_loss', 'train_acc', 'val_acc' or 'val_loss' for per-epoch curves, or
                'confusion_train', 'confusion_normalize_train', 'confusion_val',
                'confusion_normalize_val' for last-epoch confusion matrices ('normalize' = row-normalised).

        Raises:
            ValueError: for any other ``which``.
        """
        curves = {
            'train_loss': self.train_loss,
            'train_acc': self.train_accuracy,
            'val_acc': self.val_accuracy,
            'val_loss': self.val_loss,
        }
        # which -> (split, ConfusionMatrixDisplay `normalize` argument)
        confusions = {
            'confusion_train': ('train', None),
            'confusion_normalize_train': ('train', 'true'),
            'confusion_val': ('val', None),
            'confusion_normalize_val': ('val', 'true'),
        }
        if which not in curves and which not in confusions:
            raise ValueError(f"unknown plot type {which!r}")

        os.makedirs("./result", exist_ok=True)
        # Prefix with the class name so different models don't overwrite each other's figures.
        save_path = f"./result/{type(self).__name__}_{which}_{self.data_type}.png"

        if which in curves:
            values = curves[which]
            plt.xlabel("epoch")
            plt.ylabel(which)
            plt.title(which)
            plt.plot(range(1, len(values) + 1), values, label=which)
        else:
            split, norm_mode = confusions[which]
            if split == 'train':
                y_true, y_pred = self.real_labels_train, self.pred_labels_train
            else:
                y_true, y_pred = self.real_labels_val, self.pred_labels_val
            # Pin all class indices so display_labels still line up when a class is absent.
            ConfusionMatrixDisplay.from_predictions(
                y_true,
                y_pred,
                labels=list(range(len(region_shortening))),
                display_labels=region_shortening,
                normalize=norm_mode,
            )
            plt.title(f'{split} confusion matrix')

        plt.savefig(save_path)
        plt.show()



# Multimodal

In [None]:
class MultiModalDialectClassifier(nn.Module):
    """Late-fusion dialect classifier over spectrogram, MFCC and chroma inputs.

    One backbone per feature type (ResNet18 or LeNet, selected by ``method``)
    maps its input to a feature vector. The three vectors are concatenated
    and fed to a two-layer MLP head that outputs ``out_dim`` logits.

    Args:
        hidden_dim: not used by this class (kept for interface compatibility).
        out_dim: number of output classes.
        method: ``"ResNet"`` or ``"LeNet"``; selects the backbones and the
            matching fusion head.
        best_model_save_path: file that ``train_`` writes the state_dict of
            the best-validation epoch to.

    Raises:
        ValueError: if ``method`` is not ``"ResNet"`` or ``"LeNet"``.
    """

    def __init__(self, hidden_dim=1024, out_dim=5, method = "ResNet", best_model_save_path="./best_model.pt"):
        super(MultiModalDialectClassifier, self).__init__()

        # Fail fast: forward() only has code paths for these two methods.
        if method not in ("ResNet", "LeNet"):
            raise ValueError(f"method must be 'ResNet' or 'LeNet', got {method!r}")

        self.method = method
        self.best_model_save_path = best_model_save_path

        if self.method == "ResNet":
            # Attribute names must match the ones forward() and extractConvLayer() use.
            self.spec_resnet = ResNet18(1, model_type='spec')
            self.mfcc_resnet = ResNet18(1, model_type='mfcc')
            self.chroma_resnet = ResNet18(1, model_type='chroma')

        self.relu = nn.ReLU()

        # ResNet fusion head. Its input size assumes each ResNet18 backbone
        # emits 128 features (128*3 after concatenation).
        self.fc1_resnet = nn.Linear(128*3, 128)
        self.fc2_resnet = nn.Linear(128, out_dim)
        self.lastlayer_resnet = nn.Sequential(self.fc1_resnet, self.relu, self.fc2_resnet)

        # LeNet backbones are built regardless of `method` (unused in ResNet mode).
        self.spec_lenet = LeNet(method="multimodal", data_type="spec")
        self.mfcc_lenet = LeNet(method="multimodal", data_type="mfcc")
        self.chroma_lenet = LeNet(method="multimodal", data_type="chroma")

        # LeNet fusion head. Its input size assumes each multimodal LeNet emits 5 features.
        self.fc1_lenet = nn.Linear(5*3, 10)
        self.fc2_lenet = nn.Linear(10, out_dim)
        self.lastlayer_lenet = nn.Sequential(self.fc1_lenet, self.relu, self.fc2_lenet)

        self.loss_f = nn.CrossEntropyLoss()
        self.optimizer = None  # created in train_()

    def forward(self, x):
        """Return ``(batch, out_dim)`` logits for ``x = (spec_x, mfcc_x, chroma_x)``."""
        spec_x, mfcc_x, chroma_x = x

        if self.method == "ResNet":
            backbones = (self.spec_resnet, self.mfcc_resnet, self.chroma_resnet)
            head = self.lastlayer_resnet
        elif self.method == "LeNet":
            backbones = (self.spec_lenet, self.mfcc_lenet, self.chroma_lenet)
            head = self.lastlayer_lenet
        else:
            raise ValueError(f"unsupported method {self.method!r}")

        features = [net(inp) for net, inp in zip(backbones, (spec_x, mfcc_x, chroma_x))]
        y = torch.cat(features, dim=1)
        y = y.view(y.shape[0], -1)  # flatten everything but the batch dim
        return head(y)

    def train_(self, train_loader, val_loader, learning_rate, epochs, device):
        """Train with AdamW and save the checkpoint with the best validation accuracy.

        Each batch is ``((spec, mfcc, chroma), label)``. The label is passed
        directly to CrossEntropyLoss and reduced with argmax over dim 1 for
        accuracy. It is therefore presumably one-hot / class-probability
        encoded (TODO confirm against the Dataset).

        Per-epoch metrics are appended to ``self.train_loss``,
        ``self.train_accuracy``, ``self.val_loss`` and ``self.val_accuracy``.
        Predictions and labels from the last epoch are kept as numpy arrays
        in ``self.{pred,real}_labels_{train,val}`` for the confusion plots.

        Returns:
            ``(best_acc, best_epoch)``: the highest validation accuracy seen
            and its 1-based epoch. The weights from that epoch are saved to
            ``self.best_model_save_path``.
        """
        self.train_accuracy = []
        self.train_loss = []
        self.val_accuracy = []
        self.val_loss = []
        self.pred_labels_train = []
        self.real_labels_train = []
        self.pred_labels_val = None
        self.real_labels_val = None

        self.optimizer = optim.AdamW(self.parameters(), lr=learning_rate)

        best_epoch = -1
        best_acc = -1

        for epoch in range(1, epochs + 1):
            total = 0
            correct = 0
            epoch_loss = 0.0
            start_time = time.time()
            self.train()

            for batch_data, batch_label in tqdm(train_loader):
                spec, mfcc, chroma = batch_data
                batch_data = (spec.to(device), mfcc.to(device), chroma.to(device))
                batch_label = batch_label.to(device)

                self.optimizer.zero_grad()
                pred = self(batch_data)  # (batch_size, out_dim) logits
                loss = self.loss_f(pred, batch_label)
                loss.backward()
                self.optimizer.step()

                epoch_loss += loss.item()

                _, pred_indices = torch.max(pred, axis=1)
                true_indices = torch.max(batch_label, axis=1)[1]
                total += batch_data[0].shape[0]
                correct += pred_indices.eq(true_indices).sum().item()

                if epoch == epochs:  # keep last-epoch labels for the confusion matrix
                    self.pred_labels_train.append(pred_indices)
                    self.real_labels_train.append(true_indices)

            end_time = time.time()
            print(f"epoch {epoch} time: {end_time-start_time}sec(s).")

            epoch_loss /= len(train_loader)
            self.train_loss.append(epoch_loss)
            epoch_acc = correct / total
            self.train_accuracy.append(epoch_acc)
            print(f"epoch {epoch} train accuracy: {epoch_acc}")
            print(f"epoch {epoch} loss: {epoch_loss}")

            predicted, labels, val_loss = self.predict(val_loader, device)
            if epoch == epochs:
                self.pred_labels_val = predicted.cpu().numpy()
                self.real_labels_val = labels.cpu().numpy()
            val_acc = predicted.eq(labels).sum().item() / len(predicted)
            print(f"epoch {epoch} val accuracy: {val_acc}")
            print(f"epoch {epoch} val loss: {val_loss}")

            # Checkpoint only when validation accuracy beats the best seen so far.
            if val_acc > best_acc:
                best_acc = val_acc
                best_epoch = epoch
                torch.save(self.state_dict(), self.best_model_save_path)

            self.val_accuracy.append(val_acc)
            self.val_loss.append(val_loss)

        # torch.cat([]) raises, so only concatenate when at least one epoch ran.
        if self.pred_labels_train:
            self.pred_labels_train = torch.cat(self.pred_labels_train, dim=0).cpu().numpy()
            self.real_labels_train = torch.cat(self.real_labels_train, dim=0).cpu().numpy()

        print("Finish!")

        return best_acc, best_epoch

    def predict(self, test_loader, device):
        """Run the model over ``test_loader`` without gradients.

        Returns:
            ``(predicted, labels, val_loss)``. ``predicted`` and ``labels`` are
            1-D class-index tensors (argmax of logits / of the label tensor)
            concatenated across batches. ``val_loss`` is the mean per-batch loss.
        """
        self.eval()
        labels = []
        predicted = []
        val_loss = 0.0
        with torch.no_grad():
            for batch_data, batch_label in tqdm(test_loader):
                spec, mfcc, chroma = batch_data
                batch_data = (spec.to(device), mfcc.to(device), chroma.to(device))
                batch_label = batch_label.to(device)

                pred = self(batch_data)

                _, pred_indices = torch.max(pred, axis=1)
                loss = self.loss_f(pred, batch_label)
                val_loss += loss.item()

                predicted.append(pred_indices)
                labels.append(torch.max(batch_label, axis=1)[1])
        val_loss /= len(test_loader)
        predicted = torch.cat(predicted, dim=0)
        labels = torch.cat(labels, dim=0)

        return predicted, labels, val_loss

    @staticmethod
    def _save_and_show(path):
        """Save the current pyplot figure to ``path`` (creating its directory), show it, then close it."""
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        plt.savefig(path)
        plt.show()
        plt.close()

    def plot(self, which):
        """Plot a per-epoch curve or a confusion matrix and save it under ./result/.

        Args:
            which: a curve name ('train_loss', 'train_acc', 'val_acc',
                'val_loss') or a confusion-matrix name ('confusion_train',
                'confusion_normalize_train', 'confusion_val',
                'confusion_normalize_val'). Requires ``train_`` to have run.

        Raises:
            ValueError: if ``which`` is not one of the names above.
        """
        # curve name -> attribute holding the per-epoch values
        curve_attrs = {
            'train_loss': 'train_loss',
            'train_acc': 'train_accuracy',
            'val_acc': 'val_accuracy',
            'val_loss': 'val_loss',
        }
        # confusion-matrix name -> (split, ConfusionMatrixDisplay `normalize` argument)
        confusion_specs = {
            'confusion_train': ('train', None),
            'confusion_normalize_train': ('train', 'true'),
            'confusion_val': ('val', None),
            'confusion_normalize_val': ('val', 'true'),
        }
        if which not in curve_attrs and which not in confusion_specs:
            raise ValueError(
                f"unknown plot {which!r}; expected one of "
                f"{sorted([*curve_attrs, *confusion_specs])}"
            )

        if which in confusion_specs:
            split, normalize = confusion_specs[which]
            ConfusionMatrixDisplay.from_predictions(
                getattr(self, f"real_labels_{split}"),
                getattr(self, f"pred_labels_{split}"),
                display_labels=region_shortening,
                normalize=normalize,
            )
            plt.title(f"{split} confusion matrix")
        else:
            y = getattr(self, curve_attrs[which])
            plt.xlabel("epoch")
            plt.ylabel(which)
            plt.title(which)
            plt.plot(range(1, len(y) + 1), y, label=which)

        self._save_and_show(f"./result/model_{which}.png")

    def getConvLayers(self, lenet):
        """Collect Conv2d layers, and their weights, from a backbone's direct children.

        The search covers:
          * top-level Conv2d children;
          * Conv2d modules that are direct members of a top-level nn.Sequential;
          * Conv2d children of the blocks inside such a Sequential;
          * Conv2d modules inside an nn.Sequential within such a block
            (presumably a ResNet shortcut/downsample path).

        Args:
            lenet: list of a backbone's direct child modules. The name is
                historical; ResNet children are accepted too.

        Returns:
            ``(weights, conv_layers)``: parallel lists in discovery order.
        """
        weights = []
        conv_layers = []

        def collect(conv):
            weights.append(conv.weight)
            conv_layers.append(conv)

        for child in lenet:
            if type(child) == nn.Conv2d:
                collect(child)
            elif type(child) == nn.Sequential:
                for block in child.children():  # e.g. a basic block
                    if type(block) == nn.Conv2d:
                        collect(block)
                    for sub in block.children():
                        if type(sub) == nn.Conv2d:
                            collect(sub)
                        if type(sub) == nn.Sequential:
                            for inner in sub:
                                if type(inner) == nn.Conv2d:
                                    collect(inner)
        return weights, conv_layers

    def extractConvLayer(self):
        """Gather the conv layers/weights of the three backbones of the active ``method``.

        Also stores them on ``self`` as ``{spec,mfcc,chroma}_{weights,layers}``.
        plotFilter / plotFeatureMap read those attributes.

        Returns:
            ``(spec_weights, spec_layers, mfcc_weights, mfcc_layers,
            chroma_weights, chroma_layers)``.
        """
        # Select backbones by name rather than by registration order. In
        # LeNet mode the first children are relu / fc layers, not backbones.
        if self.method == "ResNet":
            backbones = (self.spec_resnet, self.mfcc_resnet, self.chroma_resnet)
        else:
            backbones = (self.spec_lenet, self.mfcc_lenet, self.chroma_lenet)
        spec_children, mfcc_children, chroma_children = (list(b.children()) for b in backbones)

        spec_weights, spec_layers = self.getConvLayers(spec_children)
        mfcc_weights, mfcc_layers = self.getConvLayers(mfcc_children)
        chroma_weights, chroma_layers = self.getConvLayers(chroma_children)

        self.spec_weights = spec_weights
        self.spec_layers = spec_layers
        self.mfcc_weights = mfcc_weights
        self.mfcc_layers = mfcc_layers
        self.chroma_weights = chroma_weights
        self.chroma_layers = chroma_layers

        return spec_weights, spec_layers, mfcc_weights, mfcc_layers, chroma_weights, chroma_layers

    def plotFilter(self, where='first', data_type='spec', when='before_train'):
        """Show input-channel 0 of every kernel of one conv layer and save the figure.

        Requires extractConvLayer() to have been called first.

        Args:
            where: 'first' (conv index 0), 'middle' (conv index 16) or
                'last' (second-to-last conv).
            data_type: 'spec', 'mfcc' or 'chroma'.
            when: tag used in the output file name (e.g. before/after training).

        Raises:
            ValueError: on an unknown ``where`` or ``data_type``.
        """
        if data_type not in ('spec', 'mfcc', 'chroma'):
            raise ValueError(f"data_type must be 'spec', 'mfcc' or 'chroma', got {data_type!r}")
        weights = getattr(self, f"{data_type}_weights")

        # where -> (index into the conv weight list, figure size, subplot rows, subplot cols).
        # The shape notes are the original author's, for the ResNet backbone.
        layouts = {
            'first': (0, (20, 17), 8, 8),                  # 64x1x7x7
            'middle': (16, (30, 25), 16, 8),               # 128x64x3x3
            'last': (len(weights) - 2, (40, 32), 16, 16),  # 256x256x3x3
        }
        if where not in layouts:
            raise ValueError(f"where must be 'first', 'middle' or 'last', got {where!r}")
        index, figsize, rows, cols = layouts[where]

        kernels = weights[index][:, 0, :, :]  # only input channel 0 of each filter
        plt.figure(figsize=figsize)
        for i, kernel in enumerate(kernels):
            plt.subplot(rows, cols, i + 1)
            plt.imshow(kernel.detach().cpu(), cmap='gray')
            plt.axis('off')
        self._save_and_show(f"./result/{data_type}_filter_{where}_{when}.png")

    def plotOriginalImage(self, data):
        """Display channel 0 of ``data`` (indexed as ``data[0, :, :]``)."""
        plt.imshow(data[0,:,:])
        plt.show()

    def plotFeatureMap(self, data, where='first', data_type='spec', when='before_train'):
        """Apply the first conv layer of a backbone to ``data`` and plot each output channel.

        Requires extractConvLayer() to have been called first. ``data`` is
        presumably a single unbatched (C, H, W) input, since each element of
        the layer output is plotted as one map (TODO confirm).

        Raises:
            ValueError: on an unknown ``data_type`` or ``where``.
            NotImplementedError: for ``where`` = 'middle' or 'last', which were
                never implemented.
        """
        # data_type -> subplot grid (rows, cols)
        grids = {'spec': (8, 8), 'mfcc': (16, 4), 'chroma': (32, 2)}
        if data_type not in grids:
            raise ValueError(f"data_type must be 'spec', 'mfcc' or 'chroma', got {data_type!r}")
        if where in ('middle', 'last'):
            raise NotImplementedError(f"feature maps for where={where!r} are not implemented")
        if where != 'first':
            raise ValueError(f"where must be 'first', 'middle' or 'last', got {where!r}")

        layers = getattr(self, f"{data_type}_layers")
        x_len, y_len = grids[data_type]
        layer = layers[0]

        plt.figure(figsize=(20,17))
        with torch.no_grad():
            # Run on whatever device the layer lives on.
            results = layer(data.to(layer.weight.device))
        for i, result in enumerate(results):
            plt.subplot(x_len, y_len, i + 1)
            plt.imshow(result.detach().cpu())
            plt.axis('off')
        self._save_and_show(f"./result/{data_type}_feature_map_{where}_{when}.png")

        
        
        

# Cross Validation

In [None]:
def CV_Plot(title, arg, y):
    """Plot a per-epoch metric and save it to ./result/model_{title}_{arg}.png.

    Args:
        title: metric name. Used for the y-axis label, plot title, line
            label and file name.
        arg: extra tag appended to the file name.
        y: sequence of per-epoch values, plotted against epochs 1..len(y).
    """
    epoch_axis = list(range(1, len(y) + 1))

    # Axis text first, then the curve, then persist and display.
    plt.xlabel("epoch")
    plt.ylabel(title)
    plt.title(title)
    plt.plot(epoch_axis, y, label=title)

    plt.savefig(f"./result/model_{title}_{arg}.png")
    plt.show()

In [None]:
def CrossValidation(dataset, learning_rate, epochs, device, method = "ResNet",data_type="mfcc", n_splits=5, batch_size=32):
    """Grid-search (learning rate, epochs) with K-fold cross validation.

    For every pair in ``learning_rate x epochs``, a fresh model is trained on
    each of the ``n_splits`` folds of ``dataset``. A pair's score is the mean
    of its final-epoch validation accuracies across the folds.

    Args:
        dataset: dataset to split into folds (e.g. ``train_dataset``).
        learning_rate: list of candidate learning rates.
        epochs: list of candidate epoch counts.
        device: torch device that models are moved to.
        method: one of "ResNet", "LeNet", "LSTM", "Multimodal_ResNet",
            "Multimodal_LeNet".
        data_type: "spec", "mfcc" or "chroma"; ignored by the multimodal methods.
        n_splits: number of folds.
        batch_size: DataLoader batch size for training and validation.

    Returns:
        ``(best_lr, best_ep)`` of the best-scoring pair (first one on ties).

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    # A fresh, untrained model must be built for every fold; hence factories.
    model_factories = {
        "Multimodal_ResNet": lambda: MultiModalDialectClassifier(method="ResNet"),
        "Multimodal_LeNet": lambda: MultiModalDialectClassifier(method="LeNet"),
        "ResNet": lambda: ResNet18(1, output_dim=5, model_type=data_type),
        "LeNet": lambda: LeNet(data_type=data_type),
        "LSTM": lambda: LSTM(data_type=data_type),
    }
    if method not in model_factories:
        raise ValueError(f"unknown method {method!r}; expected one of {sorted(model_factories)}")

    # Same order as nested loops: learning rate outer, epochs inner.
    hparams = [(lr, ep) for lr in learning_rate for ep in epochs]
    print(hparams)

    # Seeded so every hyper-parameter pair sees identical folds.
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=50)
    folds = [(Subset(dataset, train_index), Subset(dataset, test_index))
             for train_index, test_index in kf.split(dataset)]

    result = []
    for lr, e in hparams:
        print(f"Learning rate : {lr}, Epochs : {e}")

        last_val_acc = []
        for j, (train_subset, validation_subset) in enumerate(folds):
            print(f"#{j+1} validation")
            model = model_factories[method]().to(device)
            train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
            # Validation order does not affect accuracy; no need to shuffle.
            validation_loader = DataLoader(validation_subset, batch_size=batch_size, shuffle=False)

            model.train_(train_loader, validation_loader, lr, e, device)
            last_val_acc.append(model.val_accuracy[-1])
        result.append(np.mean(last_val_acc))

    idx = int(np.argmax(result))  # first maximum, like list.index(max(...))
    best_lr, best_ep = hparams[idx]
    print(f"Best Learning Rate : {best_lr}, Best Epoch : {best_ep}")

    return best_lr, best_ep

In [None]:
# Hold out 20% of the data as the test set. The seeded generator makes the
# split reproducible across kernel restarts (the CV folds are seeded too).
SPLIT_SEED = 50
train_size = int(len(dataset)*0.8)
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(
    dataset,
    [train_size, test_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

사용법

1. CrossValidation() 함수를 사용하여 테스트한다. 이 때 결정할 인자는 아래와 같다.
2. method를 다음 중 결정한다. {"ResNet","LeNet","LSTM","Multimodal_ResNet","Multimodal_LeNet"}

3. data_type을 다음 중 결정한다. {"spec","mfcc","chroma"} -> multimodal 방식은 모든 데이터 타입을 사용하도록 설계되어 있으므로 생략해도 되며, 어떤 데이터 타입을 적어도 상관없다.

In [None]:
# Hyper-parameter grid searched by K-fold cross validation on the training split.
lr = [0.0001, 0.0005, 0.001]  # candidate learning rates
ep = [30]                     # candidate epoch counts

result = CrossValidation(
    train_dataset,
    learning_rate=lr,
    epochs=ep,
    device=device,
    method="ResNet",
    data_type="mfcc",
)

[(0.001, 15), (0.0005, 15), (0.0001, 15)]
Learning rate : 0.0001, Epochs : 15
#1 validation


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 1 time: 19.515663146972656sec(s).
epoch 1 train accuracy: 0.69875
epoch 1 loss: 0.951231119632721


100%|██████████| 13/13 [00:01<00:00,  7.55it/s]


epoch 1 val accuracy: 0.255
epoch 1 val loss: 2.683886326276339


100%|██████████| 50/50 [00:19<00:00,  2.53it/s]


epoch 2 time: 19.79457950592041sec(s).
epoch 2 train accuracy: 0.95
epoch 2 loss: 0.24728955939412117


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 2 val accuracy: 0.67
epoch 2 val loss: 0.9466624122399551


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 3 time: 19.365737676620483sec(s).
epoch 3 train accuracy: 0.966875
epoch 3 loss: 0.11268262706696987


100%|██████████| 13/13 [00:01<00:00,  7.63it/s]


epoch 3 val accuracy: 0.985
epoch 3 val loss: 0.07424391920749958


100%|██████████| 50/50 [00:19<00:00,  2.59it/s]


epoch 4 time: 19.32786202430725sec(s).
epoch 4 train accuracy: 0.989375
epoch 4 loss: 0.05574384331703186


100%|██████████| 13/13 [00:01<00:00,  7.65it/s]


epoch 4 val accuracy: 0.9425
epoch 4 val loss: 0.14519856870174408


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 5 time: 19.498624324798584sec(s).
epoch 5 train accuracy: 0.99
epoch 5 loss: 0.04978871935978532


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 5 val accuracy: 0.9975
epoch 5 val loss: 0.017396053335127924


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 6 time: 19.54596757888794sec(s).
epoch 6 train accuracy: 0.996875
epoch 6 loss: 0.01780555442906916


100%|██████████| 13/13 [00:01<00:00,  7.50it/s]


epoch 6 val accuracy: 0.95
epoch 6 val loss: 0.1442633907382305


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 7 time: 19.4309139251709sec(s).
epoch 7 train accuracy: 0.9975
epoch 7 loss: 0.013311601793393493


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 7 val accuracy: 1.0
epoch 7 val loss: 0.011773067705619793


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 8 time: 19.364396810531616sec(s).
epoch 8 train accuracy: 0.996875
epoch 8 loss: 0.012168203915935009


100%|██████████| 13/13 [00:01<00:00,  7.64it/s]


epoch 8 val accuracy: 0.985
epoch 8 val loss: 0.039212265553382725


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 9 time: 19.394173622131348sec(s).
epoch 9 train accuracy: 1.0
epoch 9 loss: 0.005739534478634596


100%|██████████| 13/13 [00:01<00:00,  7.57it/s]


epoch 9 val accuracy: 0.995
epoch 9 val loss: 0.012025574831148753


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 10 time: 19.440505504608154sec(s).
epoch 10 train accuracy: 0.995625
epoch 10 loss: 0.013106236876337789


100%|██████████| 13/13 [00:01<00:00,  7.58it/s]


epoch 10 val accuracy: 0.855
epoch 10 val loss: 0.5937711344315455


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 11 time: 19.45619535446167sec(s).
epoch 11 train accuracy: 0.991875
epoch 11 loss: 0.030395570260006933


100%|██████████| 13/13 [00:01<00:00,  7.61it/s]


epoch 11 val accuracy: 0.8775
epoch 11 val loss: 0.40211995862997496


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 12 time: 19.420964002609253sec(s).
epoch 12 train accuracy: 0.985625
epoch 12 loss: 0.04589604350039735


100%|██████████| 13/13 [00:01<00:00,  7.58it/s]


epoch 12 val accuracy: 0.895
epoch 12 val loss: 0.2489442928479268


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 13 time: 19.38740634918213sec(s).
epoch 13 train accuracy: 0.990625
epoch 13 loss: 0.029977972162887456


100%|██████████| 13/13 [00:01<00:00,  7.61it/s]


epoch 13 val accuracy: 0.9825
epoch 13 val loss: 0.04078947800175788


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 14 time: 19.413176774978638sec(s).
epoch 14 train accuracy: 0.995625
epoch 14 loss: 0.020159632824361326


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 14 val accuracy: 0.99
epoch 14 val loss: 0.014155479519663809


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 15 time: 19.393656015396118sec(s).
epoch 15 train accuracy: 0.999375
epoch 15 loss: 0.0042466492101084444


100%|██████████| 13/13 [00:01<00:00,  7.57it/s]


epoch 15 val accuracy: 0.99
epoch 15 val loss: 0.036103223978828355
Finish!
#2 validation


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 1 time: 19.450356483459473sec(s).
epoch 1 train accuracy: 0.675
epoch 1 loss: 1.0093259364366531


100%|██████████| 13/13 [00:01<00:00,  7.57it/s]


epoch 1 val accuracy: 0.2225
epoch 1 val loss: 3.185189962387085


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 2 time: 19.434871673583984sec(s).
epoch 2 train accuracy: 0.925625
epoch 2 loss: 0.3000485087931156


100%|██████████| 13/13 [00:01<00:00,  7.63it/s]


epoch 2 val accuracy: 0.36
epoch 2 val loss: 1.9831223579553456


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 3 time: 19.525851488113403sec(s).
epoch 3 train accuracy: 0.97375
epoch 3 loss: 0.11317611016333103


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 3 val accuracy: 0.78
epoch 3 val loss: 0.4944761578853314


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 4 time: 19.50551462173462sec(s).
epoch 4 train accuracy: 0.98625
epoch 4 loss: 0.057984664048999546


100%|██████████| 13/13 [00:01<00:00,  7.58it/s]


epoch 4 val accuracy: 0.9775
epoch 4 val loss: 0.08347513755926719


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 5 time: 19.405630826950073sec(s).
epoch 5 train accuracy: 0.9925
epoch 5 loss: 0.03440340627916157


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 5 val accuracy: 0.97
epoch 5 val loss: 0.06734184684375158


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 6 time: 19.43099617958069sec(s).
epoch 6 train accuracy: 0.99875
epoch 6 loss: 0.017043355819769204


100%|██████████| 13/13 [00:01<00:00,  7.59it/s]


epoch 6 val accuracy: 0.91
epoch 6 val loss: 0.31492750203380215


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 7 time: 19.428962469100952sec(s).
epoch 7 train accuracy: 0.996875
epoch 7 loss: 0.013389155869372189


100%|██████████| 13/13 [00:01<00:00,  7.58it/s]


epoch 7 val accuracy: 0.8525
epoch 7 val loss: 0.4616151343171413


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 8 time: 19.444117546081543sec(s).
epoch 8 train accuracy: 0.998125
epoch 8 loss: 0.015185547019355


100%|██████████| 13/13 [00:01<00:00,  7.65it/s]


epoch 8 val accuracy: 0.9825
epoch 8 val loss: 0.05381891611390389


100%|██████████| 50/50 [00:19<00:00,  2.58it/s]


epoch 9 time: 19.407727003097534sec(s).
epoch 9 train accuracy: 0.99
epoch 9 loss: 0.03661787353921682


100%|██████████| 13/13 [00:01<00:00,  7.55it/s]


epoch 9 val accuracy: 0.8925
epoch 9 val loss: 0.3577280227954571


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 10 time: 19.502204418182373sec(s).
epoch 10 train accuracy: 0.98875
epoch 10 loss: 0.028555810645921154


100%|██████████| 13/13 [00:01<00:00,  7.47it/s]


epoch 10 val accuracy: 0.855
epoch 10 val loss: 0.40170835646299213


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 11 time: 19.488821744918823sec(s).
epoch 11 train accuracy: 0.9975
epoch 11 loss: 0.014727623106446118


100%|██████████| 13/13 [00:01<00:00,  7.54it/s]


epoch 11 val accuracy: 0.93
epoch 11 val loss: 0.22570356554709947


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 12 time: 19.488775491714478sec(s).
epoch 12 train accuracy: 0.99625
epoch 12 loss: 0.01146024110261351


100%|██████████| 13/13 [00:01<00:00,  7.26it/s]


epoch 12 val accuracy: 0.985
epoch 12 val loss: 0.06251388702255029


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 13 time: 19.484333515167236sec(s).
epoch 13 train accuracy: 0.996875
epoch 13 loss: 0.01111790528986603


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 13 val accuracy: 0.995
epoch 13 val loss: 0.018005488001598187


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 14 time: 19.49621272087097sec(s).
epoch 14 train accuracy: 0.999375
epoch 14 loss: 0.00313763161044335


100%|██████████| 13/13 [00:01<00:00,  7.54it/s]


epoch 14 val accuracy: 0.97
epoch 14 val loss: 0.06336939818440722


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 15 time: 19.463099718093872sec(s).
epoch 15 train accuracy: 0.998125
epoch 15 loss: 0.00860650738584809


100%|██████████| 13/13 [00:01<00:00,  7.54it/s]


epoch 15 val accuracy: 0.97
epoch 15 val loss: 0.0743848088895902
Finish!
#3 validation


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 1 time: 19.446795225143433sec(s).
epoch 1 train accuracy: 0.710625
epoch 1 loss: 0.9431397819519043


100%|██████████| 13/13 [00:01<00:00,  7.60it/s]


epoch 1 val accuracy: 0.195
epoch 1 val loss: 2.4360956962292013


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 2 time: 19.436033248901367sec(s).
epoch 2 train accuracy: 0.951875
epoch 2 loss: 0.21216485455632209


100%|██████████| 13/13 [00:01<00:00,  7.58it/s]


epoch 2 val accuracy: 0.7225
epoch 2 val loss: 0.5619961092105279


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 3 time: 19.430379152297974sec(s).
epoch 3 train accuracy: 0.98
epoch 3 loss: 0.09175901487469673


100%|██████████| 13/13 [00:01<00:00,  7.57it/s]


epoch 3 val accuracy: 0.9575
epoch 3 val loss: 0.12197371572256088


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 4 time: 19.43911075592041sec(s).
epoch 4 train accuracy: 0.983125
epoch 4 loss: 0.05866363633424044


100%|██████████| 13/13 [00:01<00:00,  7.56it/s]


epoch 4 val accuracy: 0.775
epoch 4 val loss: 0.6397013710095332


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 5 time: 19.478731155395508sec(s).
epoch 5 train accuracy: 0.985
epoch 5 loss: 0.051233448199927804


100%|██████████| 13/13 [00:01<00:00,  7.55it/s]


epoch 5 val accuracy: 0.815
epoch 5 val loss: 0.5300365308156381


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 6 time: 19.449608087539673sec(s).
epoch 6 train accuracy: 0.991875
epoch 6 loss: 0.030627917954698203


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 6 val accuracy: 0.985
epoch 6 val loss: 0.037869154618909724


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 7 time: 19.44995427131653sec(s).
epoch 7 train accuracy: 0.99875
epoch 7 loss: 0.010987821351736784


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 7 val accuracy: 0.9925
epoch 7 val loss: 0.023768670713672273


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 8 time: 19.477851390838623sec(s).
epoch 8 train accuracy: 0.999375
epoch 8 loss: 0.00761506368406117


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 8 val accuracy: 0.9675
epoch 8 val loss: 0.09650607066802107


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 9 time: 19.482287645339966sec(s).
epoch 9 train accuracy: 0.9975
epoch 9 loss: 0.007999860974960029


100%|██████████| 13/13 [00:01<00:00,  7.55it/s]


epoch 9 val accuracy: 0.9175
epoch 9 val loss: 0.2675776252379784


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 10 time: 19.468300819396973sec(s).
epoch 10 train accuracy: 0.99625
epoch 10 loss: 0.01614112496608868


100%|██████████| 13/13 [00:01<00:00,  7.50it/s]


epoch 10 val accuracy: 0.6875
epoch 10 val loss: 2.2033044008108287


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 11 time: 19.464776515960693sec(s).
epoch 11 train accuracy: 0.996875
epoch 11 loss: 0.01352254286641255


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 11 val accuracy: 0.9775
epoch 11 val loss: 0.05709294461681006


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 12 time: 19.463207960128784sec(s).
epoch 12 train accuracy: 0.990625
epoch 12 loss: 0.03416285324376076


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 12 val accuracy: 0.97
epoch 12 val loss: 0.11268248996482445


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 13 time: 19.465895891189575sec(s).
epoch 13 train accuracy: 0.98625
epoch 13 loss: 0.03623605300672352


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 13 val accuracy: 0.8525
epoch 13 val loss: 0.42463822720142513


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 14 time: 19.507241249084473sec(s).
epoch 14 train accuracy: 0.996875
epoch 14 loss: 0.012412288782652468


100%|██████████| 13/13 [00:01<00:00,  7.48it/s]


epoch 14 val accuracy: 0.8825
epoch 14 val loss: 0.3404146541769688


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 15 time: 19.51024580001831sec(s).
epoch 15 train accuracy: 0.998125
epoch 15 loss: 0.008958808579482138


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 15 val accuracy: 0.9925
epoch 15 val loss: 0.016544462504008643
Finish!
#4 validation


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 1 time: 19.476829528808594sec(s).
epoch 1 train accuracy: 0.698125
epoch 1 loss: 0.9730232465267181


100%|██████████| 13/13 [00:01<00:00,  7.50it/s]


epoch 1 val accuracy: 0.1925
epoch 1 val loss: 2.0417134119914127


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 2 time: 19.491068601608276sec(s).
epoch 2 train accuracy: 0.92375
epoch 2 loss: 0.3062076485157013


100%|██████████| 13/13 [00:01<00:00,  7.50it/s]


epoch 2 val accuracy: 0.395
epoch 2 val loss: 2.2259867649811964


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 3 time: 19.467320203781128sec(s).
epoch 3 train accuracy: 0.97375
epoch 3 loss: 0.12364490918815135


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 3 val accuracy: 0.7675
epoch 3 val loss: 0.6557296147713294


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 4 time: 19.54058027267456sec(s).
epoch 4 train accuracy: 0.98875
epoch 4 loss: 0.06814380574971438


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 4 val accuracy: 0.9875
epoch 4 val loss: 0.07887862923626716


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 5 time: 19.495813608169556sec(s).
epoch 5 train accuracy: 0.9975
epoch 5 loss: 0.023115125428885223


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 5 val accuracy: 0.9725
epoch 5 val loss: 0.09365201426240113


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 6 time: 19.498871564865112sec(s).
epoch 6 train accuracy: 0.995625
epoch 6 loss: 0.025707438522949815


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 6 val accuracy: 0.9725
epoch 6 val loss: 0.1128024196682068


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 7 time: 19.55402159690857sec(s).
epoch 7 train accuracy: 0.994375
epoch 7 loss: 0.026400144239887596


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 7 val accuracy: 0.935
epoch 7 val loss: 0.1857290669129445


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 8 time: 19.477907419204712sec(s).
epoch 8 train accuracy: 0.989375
epoch 8 loss: 0.03682161566335708


100%|██████████| 13/13 [00:01<00:00,  7.54it/s]


epoch 8 val accuracy: 0.9325
epoch 8 val loss: 0.1838283659173892


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 9 time: 19.46212387084961sec(s).
epoch 9 train accuracy: 0.99625
epoch 9 loss: 0.012664621453732252


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 9 val accuracy: 0.9925
epoch 9 val loss: 0.02574358200833488


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 10 time: 19.498080730438232sec(s).
epoch 10 train accuracy: 1.0
epoch 10 loss: 0.004653859648387879


100%|██████████| 13/13 [00:01<00:00,  7.47it/s]


epoch 10 val accuracy: 0.995
epoch 10 val loss: 0.011228352717947788


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 11 time: 19.52420473098755sec(s).
epoch 11 train accuracy: 1.0
epoch 11 loss: 0.002149222237057984


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 11 val accuracy: 0.995
epoch 11 val loss: 0.012208959114594528


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 12 time: 19.525471210479736sec(s).
epoch 12 train accuracy: 1.0
epoch 12 loss: 0.002778114425600506


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 12 val accuracy: 0.995
epoch 12 val loss: 0.023688171027550615


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 13 time: 19.50728988647461sec(s).
epoch 13 train accuracy: 1.0
epoch 13 loss: 0.0014601725738612003


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 13 val accuracy: 0.995
epoch 13 val loss: 0.014020775344061594


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 14 time: 19.52264094352722sec(s).
epoch 14 train accuracy: 1.0
epoch 14 loss: 0.0018448783172061666


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 14 val accuracy: 0.9975
epoch 14 val loss: 0.008192094223117098


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 15 time: 19.543914794921875sec(s).
epoch 15 train accuracy: 1.0
epoch 15 loss: 0.0013585687466547824


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 15 val accuracy: 0.995
epoch 15 val loss: 0.014065027844760781
Finish!
#5 validation


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 1 time: 19.510453701019287sec(s).
epoch 1 train accuracy: 0.715
epoch 1 loss: 0.9377399128675461


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 1 val accuracy: 0.165
epoch 1 val loss: 2.4449292696439304


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 2 time: 19.500438690185547sec(s).
epoch 2 train accuracy: 0.950625
epoch 2 loss: 0.22301914371550083


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 2 val accuracy: 0.91
epoch 2 val loss: 0.27079260234649366


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 3 time: 19.499646186828613sec(s).
epoch 3 train accuracy: 0.971875
epoch 3 loss: 0.10079849779605865


100%|██████████| 13/13 [00:01<00:00,  7.49it/s]


epoch 3 val accuracy: 0.665
epoch 3 val loss: 0.8858812772310697


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 4 time: 19.50833821296692sec(s).
epoch 4 train accuracy: 0.98
epoch 4 loss: 0.07827692616730929


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 4 val accuracy: 0.54
epoch 4 val loss: 1.3360215058693519


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 5 time: 19.501092433929443sec(s).
epoch 5 train accuracy: 0.989375
epoch 5 loss: 0.04045665206387639


100%|██████████| 13/13 [00:01<00:00,  7.51it/s]


epoch 5 val accuracy: 0.8825
epoch 5 val loss: 0.3332167204756003


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 6 time: 19.529765367507935sec(s).
epoch 6 train accuracy: 0.986875
epoch 6 loss: 0.04351671876851469


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 6 val accuracy: 0.9125
epoch 6 val loss: 0.24780980153725699


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 7 time: 19.507490396499634sec(s).
epoch 7 train accuracy: 0.991875
epoch 7 loss: 0.024242606991901994


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 7 val accuracy: 0.77
epoch 7 val loss: 0.6446399436547205


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 8 time: 19.514423608779907sec(s).
epoch 8 train accuracy: 0.99625
epoch 8 loss: 0.01517063362058252


100%|██████████| 13/13 [00:01<00:00,  7.52it/s]


epoch 8 val accuracy: 0.9525
epoch 8 val loss: 0.12932994016087973


100%|██████████| 50/50 [00:19<00:00,  2.56it/s]


epoch 9 time: 19.52389907836914sec(s).
epoch 9 train accuracy: 1.0
epoch 9 loss: 0.0050661747192498295


100%|██████████| 13/13 [00:01<00:00,  7.55it/s]


epoch 9 val accuracy: 0.9875
epoch 9 val loss: 0.04159154535199587


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 10 time: 19.509153842926025sec(s).
epoch 10 train accuracy: 0.99875
epoch 10 loss: 0.005035793014103547


100%|██████████| 13/13 [00:01<00:00,  7.48it/s]


epoch 10 val accuracy: 0.9925
epoch 10 val loss: 0.013889737928716036


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 11 time: 19.5012047290802sec(s).
epoch 11 train accuracy: 0.99625
epoch 11 loss: 0.011244865437038242


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 11 val accuracy: 0.775
epoch 11 val loss: 0.8066369157571059


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 12 time: 19.50662922859192sec(s).
epoch 12 train accuracy: 0.996875
epoch 12 loss: 0.012485704782884568


100%|██████████| 13/13 [00:01<00:00,  7.41it/s]


epoch 12 val accuracy: 0.8875
epoch 12 val loss: 0.3929000955361586


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 13 time: 19.472636699676514sec(s).
epoch 13 train accuracy: 0.9975
epoch 13 loss: 0.009730043294257484


100%|██████████| 13/13 [00:01<00:00,  7.54it/s]


epoch 13 val accuracy: 0.91
epoch 13 val loss: 0.2579081259094752


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 14 time: 19.49197506904602sec(s).
epoch 14 train accuracy: 0.978125
epoch 14 loss: 0.06573819097829983


100%|██████████| 13/13 [00:01<00:00,  7.46it/s]


epoch 14 val accuracy: 0.705
epoch 14 val loss: 1.2387986091467051


100%|██████████| 50/50 [00:19<00:00,  2.57it/s]


epoch 15 time: 19.50408148765564sec(s).
epoch 15 train accuracy: 0.995
epoch 15 loss: 0.02490442062728107


100%|██████████| 13/13 [00:01<00:00,  7.53it/s]

epoch 15 val accuracy: 0.405
epoch 15 val loss: 3.84400252195505
Finish!





Cross-validation results

In [None]:
# Show the cross-validation summary collected by the k-fold loop above
# (the exact structure of `result` is defined outside this cell).
print(str(result))

# Model Training & Testing (with plots)

사용법

1. model_choose() 함수를 사용하여 model을 결정한다. 이 때 고려할 인자는 아래와 같다.
2. method를 다음 중 결정한다. {"ResNet","LeNet","LSTM","Multimodal_ResNet","Multimodal_LeNet"}

3. data_type을 다음 중 결정한다. {"spec","mfcc","chroma"} -> multimodal 모델은 모든 데이터 타입을 사용하도록 설계되어 있으므로 이 값은 사용되지 않는다. 다만 인자 자체는 필요하므로 아무 데이터 타입이나 적으면 된다.

In [None]:
def model_choose(model, method, data_type):
    """Build the requested dialect classifier, move it to `device`, and return it.

    Args:
        model: Ignored. Kept only so existing call sites keep working.
            Rebinding a parameter never updates the caller's variable,
            so callers must use the return value instead.
        method: One of "ResNet", "LeNet", "LSTM", "Multimodal_ResNet" or
            "Multimodal_LeNet". "Multimodal" is still accepted as an alias
            for "Multimodal_LeNet".
        data_type: Input feature type for the single-modality models
            ("spec", "mfcc" or "chroma"). The multimodal models do not use it.

    Returns:
        The newly constructed model, already moved to `device`.

    Raises:
        ValueError: If `method` is not one of the supported names.
    """
    if method == "Multimodal_ResNet":
        chosen = MultiModalDialectClassifier(method="ResNet")
    elif method in ("Multimodal_LeNet", "Multimodal"):
        chosen = MultiModalDialectClassifier(method="LeNet")
    elif method == "ResNet":
        # 1 input channel, 5 output classes (one per region).
        chosen = ResNet18(1, output_dim=5, model_type=data_type)
    elif method == "LeNet":
        chosen = LeNet(data_type=data_type)
    elif method == "LSTM":
        chosen = LSTM(data_type=data_type)
    else:
        raise ValueError(
            f"Unknown method {method!r}; expected one of "
            "'ResNet', 'LeNet', 'LSTM', 'Multimodal_ResNet', 'Multimodal_LeNet'."
        )
    return chosen.to(device)

# Bind the returned model: the function cannot update the global `model` itself.
model = model_choose(None, "Multimodal_ResNet", "mfcc")

spec_weights, spec_layers, mfcc_weights, mfcc_layers, chroma_weights, chroma_layers = model.extractConvLayer()

In [None]:
# Plot the raw input of the first sample. Index 0 of its input tuple is
# the spectrogram; the feature-map cells pass it with data_type='spec'.
first_sample_inputs = dataset[0][0]
model.plotOriginalImage(first_sample_inputs[0])

In [None]:
# Conv filters of the untrained model: first/middle/last layer for each modality.
for modality in ('spec', 'mfcc', 'chroma'):
    for depth in ('first', 'middle', 'last'):
        model.plotFilter(where=depth, data_type=modality, when='before_train')

In [None]:
# Feature maps of the first sample for each modality, before training.
# Position in dataset[0][0] -> modality: 0 spec, 1 mfcc, 2 chroma.
# Inputs are moved to `device` (where the model lives), matching the
# after-training cell; without this the call can hit a CPU/GPU mismatch.
for modality_idx, modality in enumerate(('spec', 'mfcc', 'chroma')):
    model.plotFeatureMap(dataset[0][0][modality_idx].to(device), data_type=modality, when='before_train')

In [None]:
BATCH_SIZE = 32

# Shuffle only the training data. The evaluation loader keeps the dataset's
# order, so evaluation is deterministic across runs and per-sample
# predictions stay aligned with test_dataset's indices.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Final training run. best_lr is presumably the learning rate picked during
# cross-validation; confirm where it is set.
NUM_EPOCHS = 100
model.train_(train_loader, test_loader, best_lr, NUM_EPOCHS, device)

In [None]:
# Learning curves first, then raw and normalized confusion matrices for
# the training and validation splits.
for plot_kind in (
    'train_acc',
    'train_loss',
    'val_acc',
    'val_loss',
    'confusion_train',
    'confusion_normalize_train',
    'confusion_val',
    'confusion_normalize_val',
):
    model.plot(plot_kind)

In [None]:
# Re-extract the conv weights/layers after training, presumably so they
# reflect the trained parameters (confirm against extractConvLayer).
(
    spec_weights, spec_layers,
    mfcc_weights, mfcc_layers,
    chroma_weights, chroma_layers,
) = model.extractConvLayer()

In [None]:
# Conv filters of the trained model, same layout as the before-training cell.
for modality in ('spec', 'mfcc', 'chroma'):
    for depth in ('first', 'middle', 'last'):
        model.plotFilter(where=depth, data_type=modality, when='after_train')

In [None]:
# Feature maps of the first sample for each modality, after training.
# Position in dataset[0][0] -> modality: 0 spec, 1 mfcc, 2 chroma.
for modality_idx, modality in enumerate(('spec', 'mfcc', 'chroma')):
    sample_input = dataset[0][0][modality_idx].to(device)
    model.plotFeatureMap(sample_input, data_type=modality, when='after_train')