# 추가 훈련을 위한 fine-tuning

## Import Module

In [1]:
from keras.utils import to_categorical
from keras.models import Sequential
from tensorflow.keras.layers import Dense, Embedding, LSTM, Conv2D, Flatten, concatenate, Input, Dropout, MaxPooling2D
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import pandas as pd
from TextAudioMultimodalClassifier import TextAudioMultimodalClassifier
from MultimodalClassifier import MultimodalClassifier
import pickle

## 훈련 데이터 나누기

In [2]:
# 저장된 변수를 다시 불러오는 코드
with open("D:\\features_concate\\X1.pkl", 'rb') as file:
    X1 = pickle.load(file)

    # 저장된 변수를 다시 불러오는 코드
with open("D:\\features_concate\\y1.pkl", 'rb') as file:
    y1 = pickle.load(file)

# 저장된 변수를 다시 불러오는 코드
with open("D:\\features_concate\\X2.pkl", 'rb') as file:
    X2 = pickle.load(file)

# 저장된 변수를 다시 불러오는 코드
with open("D:\\features_concate\\y2.pkl", 'rb') as file:
    y2 = pickle.load(file)
    
X = np.concatenate([X1, X2])
y = np.concatenate([y1, y2])

X = X.astype('float32')

데이터 개수 10350개

In [3]:
text = X[:, :768]
image = X[:, 768:768+400]
audio = X[:, 768+400:]

print(len(text))
print(len(image))
print(len(audio))

10350
10350
10350


In [4]:
train_text, val_text, train_image, val_image, train_audio, val_audio, train_y, val_y = train_test_split(
    text, image, audio, y, test_size=0.2, random_state=42)

# 나눈 데이터를 각각 텐서로 변환
text_train_tensor = torch.tensor(train_text)
image_train_tensor = torch.tensor(train_image)
audio_train_tensor = torch.tensor(train_audio)
y_train_tensor = torch.tensor(train_y)
text_val_tensor = torch.tensor(val_text)
image_val_tensor = torch.tensor(val_image)
audio_val_tensor = torch.tensor(val_audio)
y_val_tensor = torch.tensor(val_y)

train_dataset = TensorDataset(text_train_tensor, image_train_tensor, audio_train_tensor, y_train_tensor)
val_dataset = TensorDataset(text_val_tensor, image_val_tensor, audio_val_tensor, y_val_tensor)

In [5]:
batch_size = 512
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

## 모델 불러오기 및 추가 훈련

In [2]:
# 이미지 레이어를 포함한 모델 정의
input_dim_text = 768  # 텍스트 특징 개수
input_dim_audio = 87  # 오디오 특징 개수
input_dim_image = 400  # 이미지 특징 개수
num_classes = 7  # 감정 클래스 개수
new_model = MultimodalClassifier(input_dim_text, input_dim_image, input_dim_audio, num_classes)

In [21]:
# TextAudioMultimodalClassifier 모델 가중치 로드
text_audio_model = TextAudioMultimodalClassifier(input_dim_text, input_dim_audio, num_classes)
text_audio_model.load_state_dict(torch.load("model\\text_audio_model_52.pth", map_location=torch.device('cpu')))

# 필요한 가중치 추출
text_weights = text_audio_model.text_layer.state_dict()
audio_weights = text_audio_model.audio_layer.state_dict()

# MultimodalClassifier 모델의 텍스트와 오디오 레이어에 가중치 적용
new_model.text_layer.load_state_dict(text_weights)
new_model.audio_layer.load_state_dict(audio_weights)

<All keys matched successfully>

In [22]:
# 텍스트, 오디오 레이어 고정
new_model.text_layer.requires_grad_(False)
new_model.audio_layer.requires_grad_(False)

Linear(in_features=87, out_features=512, bias=True)

In [23]:
# 손실 함수 및 옵티마이저 정의
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(new_model.parameters(), lr=1e-5)

# 조기 종료를 위한 변수 초기화
best_val_loss = float('inf')
patience = 5
early_stopping_counter = 0

# 모델 훈련
num_epochs = 10000

# 모델을 장치로 이동 (CPU 또는 GPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
new_model.to(device)

MultimodalClassifier(
  (text_layer): Linear(in_features=768, out_features=512, bias=True)
  (image_layer): Linear(in_features=400, out_features=512, bias=True)
  (audio_layer): Linear(in_features=87, out_features=512, bias=True)
  (pretrained_model): TextAudioMultimodalClassifier(
    (text_layer): Linear(in_features=768, out_features=512, bias=True)
    (audio_layer): Linear(in_features=87, out_features=512, bias=True)
    (fc): Linear(in_features=1024, out_features=7, bias=True)
  )
  (fc): Linear(in_features=1543, out_features=7, bias=True)
)

## 훈련

In [24]:
for epoch in range(num_epochs):
    new_model.train()  # 모델을 훈련 모드로 설정
    
    for text_input, image_input, audio_input, labels in train_dataloader:  # train_dataloader에 대한 반복
        text_input = text_input.to(device)
        image_input = image_input.to(device)
        audio_input = audio_input.to(device)
        labels = labels.to(device).argmax(dim=1)
        
        optimizer.zero_grad()  # 그레이디언트 초기화
        
        # 모델 예측
        outputs = new_model(text_input, image_input, audio_input)
        
        # 손실 계산
        loss = criterion(outputs, labels)
        
        # 역전파 및 가중치 업데이트
        loss.backward()
        optimizer.step()
    
    # 검증 데이터셋을 사용하여 모델 평가
    new_model.eval()  # 모델을 평가 모드로 설정
    total_correct = 0
    total_samples = 0
    val_loss = 0
    
    with torch.no_grad():
        val_loss = 0.0
        for text_input, image_input, audio_input, labels in val_dataloader:  # val_dataloader에 대한 반복
            text_input = text_input.to(device)
            image_input = image_input.to(device)
            audio_input = audio_input.to(device)
            labels = labels.to(device).argmax(dim=1)
            
            outputs = new_model(text_input, image_input, audio_input)
            probabilities = F.softmax(outputs, dim=1)  # 소프트맥스 함수를 적용
            _, predicted = torch.max(probabilities, dim=1)  # 확률값 중에서 가장 큰 값의 인덱스를 선택
            total_correct += (predicted == labels).sum().item()
            total_samples += labels.size(0)
            
            val_loss += criterion(outputs, labels).item()
            val_loss += criterion(outputs, labels).item() * text_input.size(0)

    accuracy = total_correct / total_samples
    val_loss /= len(val_dataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Accuracy: {accuracy:.4f}, Validation Loss: {val_loss:.4f}")

    # 조기 종료 확인
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stopping_counter = 0
    else:
        early_stopping_counter += 1
        print("up")
        if early_stopping_counter >= patience:
            print("Early stopping triggered!")
            break

Epoch [1/10000], Accuracy: 0.1454, Validation Loss: 1.9963
Epoch [2/10000], Accuracy: 0.1541, Validation Loss: 1.9854
Epoch [3/10000], Accuracy: 0.1589, Validation Loss: 1.9759
Epoch [4/10000], Accuracy: 0.1662, Validation Loss: 1.9680
Epoch [5/10000], Accuracy: 0.1778, Validation Loss: 1.9606
Epoch [6/10000], Accuracy: 0.1860, Validation Loss: 1.9543
Epoch [7/10000], Accuracy: 0.1923, Validation Loss: 1.9488
Epoch [8/10000], Accuracy: 0.1995, Validation Loss: 1.9436
Epoch [9/10000], Accuracy: 0.2005, Validation Loss: 1.9389
Epoch [10/10000], Accuracy: 0.2010, Validation Loss: 1.9345
Epoch [11/10000], Accuracy: 0.2014, Validation Loss: 1.9303
Epoch [12/10000], Accuracy: 0.2053, Validation Loss: 1.9263
Epoch [13/10000], Accuracy: 0.2082, Validation Loss: 1.9224
Epoch [14/10000], Accuracy: 0.2130, Validation Loss: 1.9184
Epoch [15/10000], Accuracy: 0.2179, Validation Loss: 1.9146
Epoch [16/10000], Accuracy: 0.2217, Validation Loss: 1.9107
Epoch [17/10000], Accuracy: 0.2271, Validation Lo

Epoch [138/10000], Accuracy: 0.3816, Validation Loss: 1.5899
Epoch [139/10000], Accuracy: 0.3831, Validation Loss: 1.5884
Epoch [140/10000], Accuracy: 0.3841, Validation Loss: 1.5872
Epoch [141/10000], Accuracy: 0.3841, Validation Loss: 1.5859
Epoch [142/10000], Accuracy: 0.3850, Validation Loss: 1.5849
Epoch [143/10000], Accuracy: 0.3865, Validation Loss: 1.5834
Epoch [144/10000], Accuracy: 0.3879, Validation Loss: 1.5822
Epoch [145/10000], Accuracy: 0.3870, Validation Loss: 1.5810
Epoch [146/10000], Accuracy: 0.3889, Validation Loss: 1.5795
Epoch [147/10000], Accuracy: 0.3879, Validation Loss: 1.5783
Epoch [148/10000], Accuracy: 0.3874, Validation Loss: 1.5772
Epoch [149/10000], Accuracy: 0.3894, Validation Loss: 1.5758
Epoch [150/10000], Accuracy: 0.3903, Validation Loss: 1.5745
Epoch [151/10000], Accuracy: 0.3899, Validation Loss: 1.5731
Epoch [152/10000], Accuracy: 0.3932, Validation Loss: 1.5720
Epoch [153/10000], Accuracy: 0.3937, Validation Loss: 1.5708
Epoch [154/10000], Accur

Epoch [273/10000], Accuracy: 0.4560, Validation Loss: 1.4356
Epoch [274/10000], Accuracy: 0.4575, Validation Loss: 1.4345
Epoch [275/10000], Accuracy: 0.4585, Validation Loss: 1.4333
Epoch [276/10000], Accuracy: 0.4580, Validation Loss: 1.4325
Epoch [277/10000], Accuracy: 0.4575, Validation Loss: 1.4313
Epoch [278/10000], Accuracy: 0.4609, Validation Loss: 1.4303
Epoch [279/10000], Accuracy: 0.4599, Validation Loss: 1.4293
Epoch [280/10000], Accuracy: 0.4594, Validation Loss: 1.4283
Epoch [281/10000], Accuracy: 0.4604, Validation Loss: 1.4274
Epoch [282/10000], Accuracy: 0.4599, Validation Loss: 1.4260
Epoch [283/10000], Accuracy: 0.4614, Validation Loss: 1.4250
Epoch [284/10000], Accuracy: 0.4618, Validation Loss: 1.4237
Epoch [285/10000], Accuracy: 0.4628, Validation Loss: 1.4230
Epoch [286/10000], Accuracy: 0.4643, Validation Loss: 1.4220
Epoch [287/10000], Accuracy: 0.4662, Validation Loss: 1.4208
Epoch [288/10000], Accuracy: 0.4657, Validation Loss: 1.4200
Epoch [289/10000], Accur

Epoch [408/10000], Accuracy: 0.5106, Validation Loss: 1.2999
Epoch [409/10000], Accuracy: 0.5116, Validation Loss: 1.2991
Epoch [410/10000], Accuracy: 0.5097, Validation Loss: 1.2981
Epoch [411/10000], Accuracy: 0.5082, Validation Loss: 1.2970
Epoch [412/10000], Accuracy: 0.5106, Validation Loss: 1.2962
Epoch [413/10000], Accuracy: 0.5106, Validation Loss: 1.2956
Epoch [414/10000], Accuracy: 0.5097, Validation Loss: 1.2943
Epoch [415/10000], Accuracy: 0.5121, Validation Loss: 1.2935
Epoch [416/10000], Accuracy: 0.5116, Validation Loss: 1.2922
Epoch [417/10000], Accuracy: 0.5116, Validation Loss: 1.2910
Epoch [418/10000], Accuracy: 0.5135, Validation Loss: 1.2900
Epoch [419/10000], Accuracy: 0.5130, Validation Loss: 1.2891
Epoch [420/10000], Accuracy: 0.5121, Validation Loss: 1.2882
Epoch [421/10000], Accuracy: 0.5130, Validation Loss: 1.2872
Epoch [422/10000], Accuracy: 0.5121, Validation Loss: 1.2864
Epoch [423/10000], Accuracy: 0.5126, Validation Loss: 1.2854
Epoch [424/10000], Accur

Epoch [543/10000], Accuracy: 0.5522, Validation Loss: 1.1683
Epoch [544/10000], Accuracy: 0.5536, Validation Loss: 1.1676
Epoch [545/10000], Accuracy: 0.5527, Validation Loss: 1.1667
Epoch [546/10000], Accuracy: 0.5541, Validation Loss: 1.1656
Epoch [547/10000], Accuracy: 0.5522, Validation Loss: 1.1646
Epoch [548/10000], Accuracy: 0.5560, Validation Loss: 1.1640
Epoch [549/10000], Accuracy: 0.5541, Validation Loss: 1.1631
Epoch [550/10000], Accuracy: 0.5551, Validation Loss: 1.1619
Epoch [551/10000], Accuracy: 0.5560, Validation Loss: 1.1611
Epoch [552/10000], Accuracy: 0.5546, Validation Loss: 1.1600
Epoch [553/10000], Accuracy: 0.5556, Validation Loss: 1.1591
Epoch [554/10000], Accuracy: 0.5556, Validation Loss: 1.1582
Epoch [555/10000], Accuracy: 0.5575, Validation Loss: 1.1575
Epoch [556/10000], Accuracy: 0.5565, Validation Loss: 1.1565
Epoch [557/10000], Accuracy: 0.5604, Validation Loss: 1.1563
Epoch [558/10000], Accuracy: 0.5546, Validation Loss: 1.1547
Epoch [559/10000], Accur

Epoch [678/10000], Accuracy: 0.6082, Validation Loss: 1.0512
Epoch [679/10000], Accuracy: 0.6063, Validation Loss: 1.0507
Epoch [680/10000], Accuracy: 0.6048, Validation Loss: 1.0500
Epoch [681/10000], Accuracy: 0.6048, Validation Loss: 1.0488
Epoch [682/10000], Accuracy: 0.6058, Validation Loss: 1.0482
Epoch [683/10000], Accuracy: 0.6068, Validation Loss: 1.0473
Epoch [684/10000], Accuracy: 0.6058, Validation Loss: 1.0467
Epoch [685/10000], Accuracy: 0.6053, Validation Loss: 1.0464
Epoch [686/10000], Accuracy: 0.6072, Validation Loss: 1.0456
Epoch [687/10000], Accuracy: 0.6077, Validation Loss: 1.0445
Epoch [688/10000], Accuracy: 0.6111, Validation Loss: 1.0438
Epoch [689/10000], Accuracy: 0.6106, Validation Loss: 1.0427
Epoch [690/10000], Accuracy: 0.6101, Validation Loss: 1.0418
Epoch [691/10000], Accuracy: 0.6106, Validation Loss: 1.0410
Epoch [692/10000], Accuracy: 0.6087, Validation Loss: 1.0406
Epoch [693/10000], Accuracy: 0.6063, Validation Loss: 1.0401
Epoch [694/10000], Accur

Epoch [813/10000], Accuracy: 0.6425, Validation Loss: 0.9585
Epoch [814/10000], Accuracy: 0.6444, Validation Loss: 0.9580
Epoch [815/10000], Accuracy: 0.6454, Validation Loss: 0.9567
Epoch [816/10000], Accuracy: 0.6430, Validation Loss: 0.9564
Epoch [817/10000], Accuracy: 0.6459, Validation Loss: 0.9557
Epoch [818/10000], Accuracy: 0.6425, Validation Loss: 0.9556
Epoch [819/10000], Accuracy: 0.6469, Validation Loss: 0.9548
Epoch [820/10000], Accuracy: 0.6454, Validation Loss: 0.9538
Epoch [821/10000], Accuracy: 0.6459, Validation Loss: 0.9535
Epoch [822/10000], Accuracy: 0.6464, Validation Loss: 0.9528
Epoch [823/10000], Accuracy: 0.6483, Validation Loss: 0.9522
Epoch [824/10000], Accuracy: 0.6478, Validation Loss: 0.9514
Epoch [825/10000], Accuracy: 0.6488, Validation Loss: 0.9509
Epoch [826/10000], Accuracy: 0.6483, Validation Loss: 0.9506
Epoch [827/10000], Accuracy: 0.6483, Validation Loss: 0.9499
Epoch [828/10000], Accuracy: 0.6498, Validation Loss: 0.9490
Epoch [829/10000], Accur

Epoch [947/10000], Accuracy: 0.6754, Validation Loss: 0.8849
Epoch [948/10000], Accuracy: 0.6758, Validation Loss: 0.8843
Epoch [949/10000], Accuracy: 0.6749, Validation Loss: 0.8841
Epoch [950/10000], Accuracy: 0.6778, Validation Loss: 0.8837
Epoch [951/10000], Accuracy: 0.6739, Validation Loss: 0.8829
Epoch [952/10000], Accuracy: 0.6763, Validation Loss: 0.8827
Epoch [953/10000], Accuracy: 0.6768, Validation Loss: 0.8817
Epoch [954/10000], Accuracy: 0.6744, Validation Loss: 0.8815
Epoch [955/10000], Accuracy: 0.6783, Validation Loss: 0.8809
Epoch [956/10000], Accuracy: 0.6768, Validation Loss: 0.8803
Epoch [957/10000], Accuracy: 0.6787, Validation Loss: 0.8798
Epoch [958/10000], Accuracy: 0.6768, Validation Loss: 0.8795
Epoch [959/10000], Accuracy: 0.6754, Validation Loss: 0.8791
Epoch [960/10000], Accuracy: 0.6778, Validation Loss: 0.8788
Epoch [961/10000], Accuracy: 0.6778, Validation Loss: 0.8778
Epoch [962/10000], Accuracy: 0.6783, Validation Loss: 0.8777
Epoch [963/10000], Accur

Epoch [1080/10000], Accuracy: 0.6923, Validation Loss: 0.8266
up
Epoch [1081/10000], Accuracy: 0.6937, Validation Loss: 0.8259
Epoch [1082/10000], Accuracy: 0.6932, Validation Loss: 0.8258
Epoch [1083/10000], Accuracy: 0.6932, Validation Loss: 0.8254
Epoch [1084/10000], Accuracy: 0.6937, Validation Loss: 0.8249
Epoch [1085/10000], Accuracy: 0.6918, Validation Loss: 0.8243
Epoch [1086/10000], Accuracy: 0.6942, Validation Loss: 0.8241
Epoch [1087/10000], Accuracy: 0.6947, Validation Loss: 0.8242
up
Epoch [1088/10000], Accuracy: 0.6947, Validation Loss: 0.8230
Epoch [1089/10000], Accuracy: 0.6942, Validation Loss: 0.8234
up
Epoch [1090/10000], Accuracy: 0.6952, Validation Loss: 0.8223
Epoch [1091/10000], Accuracy: 0.6966, Validation Loss: 0.8222
Epoch [1092/10000], Accuracy: 0.6952, Validation Loss: 0.8217
Epoch [1093/10000], Accuracy: 0.6961, Validation Loss: 0.8214
Epoch [1094/10000], Accuracy: 0.6947, Validation Loss: 0.8213
Epoch [1095/10000], Accuracy: 0.6961, Validation Loss: 0.8208

Epoch [1211/10000], Accuracy: 0.7097, Validation Loss: 0.7795
Epoch [1212/10000], Accuracy: 0.7097, Validation Loss: 0.7795
Epoch [1213/10000], Accuracy: 0.7106, Validation Loss: 0.7791
Epoch [1214/10000], Accuracy: 0.7101, Validation Loss: 0.7790
Epoch [1215/10000], Accuracy: 0.7116, Validation Loss: 0.7780
Epoch [1216/10000], Accuracy: 0.7126, Validation Loss: 0.7783
up
Epoch [1217/10000], Accuracy: 0.7121, Validation Loss: 0.7781
up
Epoch [1218/10000], Accuracy: 0.7106, Validation Loss: 0.7781
up
Epoch [1219/10000], Accuracy: 0.7121, Validation Loss: 0.7772
Epoch [1220/10000], Accuracy: 0.7116, Validation Loss: 0.7772
Epoch [1221/10000], Accuracy: 0.7126, Validation Loss: 0.7769
Epoch [1222/10000], Accuracy: 0.7106, Validation Loss: 0.7765
Epoch [1223/10000], Accuracy: 0.7111, Validation Loss: 0.7758
Epoch [1224/10000], Accuracy: 0.7116, Validation Loss: 0.7759
up
Epoch [1225/10000], Accuracy: 0.7121, Validation Loss: 0.7753
Epoch [1226/10000], Accuracy: 0.7111, Validation Loss: 0.7

Epoch [1342/10000], Accuracy: 0.7285, Validation Loss: 0.7421
Epoch [1343/10000], Accuracy: 0.7242, Validation Loss: 0.7421
up
Epoch [1344/10000], Accuracy: 0.7261, Validation Loss: 0.7415
Epoch [1345/10000], Accuracy: 0.7266, Validation Loss: 0.7418
up
Epoch [1346/10000], Accuracy: 0.7246, Validation Loss: 0.7411
Epoch [1347/10000], Accuracy: 0.7271, Validation Loss: 0.7409
Epoch [1348/10000], Accuracy: 0.7266, Validation Loss: 0.7407
Epoch [1349/10000], Accuracy: 0.7285, Validation Loss: 0.7405
Epoch [1350/10000], Accuracy: 0.7304, Validation Loss: 0.7399
Epoch [1351/10000], Accuracy: 0.7290, Validation Loss: 0.7398
Epoch [1352/10000], Accuracy: 0.7275, Validation Loss: 0.7397
Epoch [1353/10000], Accuracy: 0.7271, Validation Loss: 0.7397
up
Epoch [1354/10000], Accuracy: 0.7290, Validation Loss: 0.7395
Epoch [1355/10000], Accuracy: 0.7275, Validation Loss: 0.7393
Epoch [1356/10000], Accuracy: 0.7295, Validation Loss: 0.7385
Epoch [1357/10000], Accuracy: 0.7314, Validation Loss: 0.7381

Epoch [1473/10000], Accuracy: 0.7386, Validation Loss: 0.7123
Epoch [1474/10000], Accuracy: 0.7396, Validation Loss: 0.7114
Epoch [1475/10000], Accuracy: 0.7382, Validation Loss: 0.7118
up
Epoch [1476/10000], Accuracy: 0.7386, Validation Loss: 0.7109
Epoch [1477/10000], Accuracy: 0.7391, Validation Loss: 0.7115
up
Epoch [1478/10000], Accuracy: 0.7386, Validation Loss: 0.7106
Epoch [1479/10000], Accuracy: 0.7396, Validation Loss: 0.7105
Epoch [1480/10000], Accuracy: 0.7401, Validation Loss: 0.7108
up
Epoch [1481/10000], Accuracy: 0.7386, Validation Loss: 0.7106
up
Epoch [1482/10000], Accuracy: 0.7391, Validation Loss: 0.7101
Epoch [1483/10000], Accuracy: 0.7391, Validation Loss: 0.7102
up
Epoch [1484/10000], Accuracy: 0.7406, Validation Loss: 0.7096
Epoch [1485/10000], Accuracy: 0.7401, Validation Loss: 0.7100
up
Epoch [1486/10000], Accuracy: 0.7391, Validation Loss: 0.7092
Epoch [1487/10000], Accuracy: 0.7377, Validation Loss: 0.7096
up
Epoch [1488/10000], Accuracy: 0.7401, Validation 

Epoch [1603/10000], Accuracy: 0.7435, Validation Loss: 0.6880
Epoch [1604/10000], Accuracy: 0.7444, Validation Loss: 0.6883
up
Epoch [1605/10000], Accuracy: 0.7440, Validation Loss: 0.6881
up
Epoch [1606/10000], Accuracy: 0.7440, Validation Loss: 0.6875
Epoch [1607/10000], Accuracy: 0.7435, Validation Loss: 0.6882
up
Epoch [1608/10000], Accuracy: 0.7435, Validation Loss: 0.6872
Epoch [1609/10000], Accuracy: 0.7440, Validation Loss: 0.6876
up
Epoch [1610/10000], Accuracy: 0.7440, Validation Loss: 0.6871
Epoch [1611/10000], Accuracy: 0.7440, Validation Loss: 0.6869
Epoch [1612/10000], Accuracy: 0.7449, Validation Loss: 0.6866
Epoch [1613/10000], Accuracy: 0.7444, Validation Loss: 0.6867
up
Epoch [1614/10000], Accuracy: 0.7454, Validation Loss: 0.6864
Epoch [1615/10000], Accuracy: 0.7449, Validation Loss: 0.6860
Epoch [1616/10000], Accuracy: 0.7420, Validation Loss: 0.6864
up
Epoch [1617/10000], Accuracy: 0.7440, Validation Loss: 0.6860
up
Epoch [1618/10000], Accuracy: 0.7440, Validation 

Epoch [1733/10000], Accuracy: 0.7483, Validation Loss: 0.6701
Epoch [1734/10000], Accuracy: 0.7469, Validation Loss: 0.6701
up
Epoch [1735/10000], Accuracy: 0.7478, Validation Loss: 0.6702
up
Epoch [1736/10000], Accuracy: 0.7488, Validation Loss: 0.6695
Epoch [1737/10000], Accuracy: 0.7464, Validation Loss: 0.6697
up
Epoch [1738/10000], Accuracy: 0.7478, Validation Loss: 0.6693
Epoch [1739/10000], Accuracy: 0.7478, Validation Loss: 0.6692
Epoch [1740/10000], Accuracy: 0.7488, Validation Loss: 0.6693
up
Epoch [1741/10000], Accuracy: 0.7483, Validation Loss: 0.6688
Epoch [1742/10000], Accuracy: 0.7473, Validation Loss: 0.6689
up
Epoch [1743/10000], Accuracy: 0.7488, Validation Loss: 0.6688
up
Epoch [1744/10000], Accuracy: 0.7488, Validation Loss: 0.6688
up
Epoch [1745/10000], Accuracy: 0.7464, Validation Loss: 0.6686
Epoch [1746/10000], Accuracy: 0.7469, Validation Loss: 0.6682
Epoch [1747/10000], Accuracy: 0.7498, Validation Loss: 0.6687
up
Epoch [1748/10000], Accuracy: 0.7483, Validati

## 모델 저장

In [25]:
torch.save(new_model.state_dict(), "model\\text_image_audio_model.pth")

In [26]:
labels.size(0)

22

In [28]:
ai_model = MultimodalClassifier(input_dim_text, input_dim_image, input_dim_audio, num_classes)
ai_model.load_state_dict(torch.load("model\\text_image_audio_model.pth", map_location=torch.device('cpu')))
ai_model.eval()

MultimodalClassifier(
  (text_layer): Linear(in_features=768, out_features=512, bias=True)
  (image_layer): Linear(in_features=400, out_features=512, bias=True)
  (audio_layer): Linear(in_features=87, out_features=512, bias=True)
  (pretrained_model): TextAudioMultimodalClassifier(
    (text_layer): Linear(in_features=768, out_features=512, bias=True)
    (audio_layer): Linear(in_features=87, out_features=512, bias=True)
    (fc): Linear(in_features=1024, out_features=7, bias=True)
  )
  (fc): Linear(in_features=1543, out_features=7, bias=True)
)