# In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import torch.nn as nn

from PIL import Image
from torch.utils.data import  Dataset, DataLoader, random_split
from torchvision import transforms
from torchinfo import summary
from module.train import fit , train , test_binary_classification
from module.utils import plot_fit_result 

# Model definition
class driverstatusModel(nn.Module):
    """Binary-classification CNN for single-channel images.

    Four convolutional stages (Conv3x3 -> BatchNorm -> ReLU -> Dropout ->
    MaxPool2x2) each halve the spatial size; an MLP head then maps the
    flattened features to one sigmoid probability per sample.

    Args:
        dropout_rate: Dropout probability applied after every conv stage and
            inside the classifier head.
        input_size: Height/width of the input images (square inputs assumed).
            The four 2x2 max-pools reduce it to ``input_size // 16``, which
            determines the in_features of the first Linear layer. The default
            of 144 reproduces the original hard-coded ``64 * 9 * 9``.

    Raises:
        ValueError: If ``input_size`` is smaller than 16 (no spatial extent
            would survive the four pooling stages).

    Input shape is ``(N, 1, input_size, input_size)``; output shape is
    ``(N, 1)`` with values in ``[0, 1]`` (pair with ``nn.BCELoss``).
    """

    def __init__(self, dropout_rate, input_size=144):
        super().__init__()
        # Attribute names and Sequential layouts are kept stable so existing
        # checkpoints (state_dict keys such as "conv1.0.weight") still match.
        self.conv1 = self._conv_block(1, 32, dropout_rate)
        self.conv2 = self._conv_block(32, 64, dropout_rate)
        self.conv3 = self._conv_block(64, 128, dropout_rate)
        self.conv4 = self._conv_block(128, 64, dropout_rate)

        # Each MaxPool2d(2, 2) floors the size by half; four in a row == // 16.
        spatial = input_size // 16
        if spatial < 1:
            raise ValueError(
                f"input_size must be at least 16, got {input_size}"
            )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=64 * spatial * spatial, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(128, 1),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout_rate):
        """Build one Conv3x3 -> BatchNorm -> ReLU -> Dropout -> MaxPool2x2 stage."""
        return nn.Sequential(
            # 'same' padding keeps H/W unchanged (== padding=1 for a 3x3 kernel).
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding="same"),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, X):
        """Return per-sample probabilities of shape ``(N, 1)``."""
        out = self.conv1(X)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.classifier(out)
        # Sigmoid output (not logits) because callers pair it with nn.BCELoss.
        return torch.sigmoid(out)

# Load a previously saved model from disk.
def load_model(load_model_path, device=None):
    """Load a model saved with ``torch.save(model, path)`` and move it to *device*.

    Args:
        load_model_path: Path to the ``.pth`` file to load.
        device: Target device (``torch.device`` or a string such as ``"cpu"``).
            When ``None``, CUDA is used if available, otherwise the CPU.

    Returns:
        The deserialized model (expected to be an ``nn.Module``) on *device*.
    """
    import inspect

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Deserialize onto the CPU first so GPU-saved checkpoints load anywhere.
    load_kwargs = {"map_location": torch.device("cpu")}
    # The file holds a whole pickled module rather than a state_dict. Since
    # torch 2.6, torch.load defaults to weights_only=True and rejects such
    # files, so opt out explicitly where the argument exists (torch >= 1.13).
    # SECURITY: full unpickling can execute arbitrary code; only load trusted files.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    status_model = torch.load(load_model_path, **load_kwargs)

    return status_model.to(device)

if __name__ == '__main__':

    # Nothing else in this file defines `device`; pick the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the trained model.
    load_model_path = "models/driverstatusmodels.pth"

    status_model = load_model(load_model_path)

    # Inference: evaluate the trained model on the test set.
    # NOTE(review): `test_loader` is not defined anywhere in this file; it must
    # be built (presumably a DataLoader over the held-out test split) before
    # this line, otherwise this raises NameError.
    loss, acc = test_binary_classification(test_loader, status_model, nn.BCELoss(), device=device)

    print(f"Loss : {loss} , Accuracy: {acc}")