# In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import random
import torch.nn as nn

from PIL import Image
from torch.utils.data import  Dataset, DataLoader, random_split
from torchvision import transforms
from torchinfo import summary
from module.train import fit , train , test_binary_classification
from module.utils import plot_fit_result 



class driverstatusModel(nn.Module):
    """Binary image classifier: four conv stages followed by an MLP head.

    Every conv stage is
    Conv2d(3x3, stride 1, 'same' padding) -> BatchNorm2d -> ReLU -> Dropout
    -> MaxPool2d(2, 2), so each stage floor-halves the spatial size. The head
    flattens the final 64-channel feature map and produces one sigmoid
    probability per sample (shape ``(batch, 1)``), which pairs with
    ``nn.BCELoss``.

    Args:
        dropout_rate: Dropout probability used after every conv stage and in
            the classifier head.
        input_size: Height/width of the square input images. The default of
            150 reproduces the original 64*9*9 flattened feature size.
        in_channels: Number of input image channels (1 for grayscale).

    Raises:
        ValueError: If ``input_size`` is too small to survive four 2x poolings.
    """

    def __init__(self, dropout_rate, input_size=150, in_channels=1):
        super().__init__()
        # Attribute names and per-stage Sequential layout are kept identical
        # to the original so existing state_dict checkpoints still load.
        self.conv1 = self._conv_block(in_channels, 32, dropout_rate)
        self.conv2 = self._conv_block(32, 64, dropout_rate)
        self.conv3 = self._conv_block(64, 128, dropout_rate)
        self.conv4 = self._conv_block(128, 64, dropout_rate)

        # 'same' convs keep H/W; each MaxPool2d(2, 2) floor-halves it:
        # 150 -> 75 -> 37 -> 18 -> 9 for the default input size.
        spatial = input_size
        for _ in range(4):
            spatial //= 2
        if spatial < 1:
            raise ValueError(
                f"input_size must be at least 16, got {input_size}"
            )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(in_features=64 * spatial * spatial, out_features=128),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.Linear(128, 1),
        )

    @staticmethod
    def _conv_block(in_channels, out_channels, dropout_rate):
        """Return one conv -> BN -> ReLU -> Dropout -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding='same'),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Dropout(p=dropout_rate),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

    def forward(self, X):
        """Map a ``(batch, in_channels, input_size, input_size)`` tensor to
        per-sample probabilities of shape ``(batch, 1)``."""
        out = self.conv1(X)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.classifier(out)
        return torch.sigmoid(out)

def train_set_loader(root, batch_size, seed=42, image_size=150):
    """Build train/valid/test datasets and DataLoaders from an ImageFolder tree.

    Preprocessing applied to every image:
        Resize((image_size, image_size))
        -> Grayscale(num_output_channels=1)
        -> ToTensor()
        -> Normalize(mean=[0.5], std=[0.5])   # [0, 1] -> [-1, 1]

    The dataset is split train : valid : test = 8 : 1 : 1. The test split
    receives the rounding remainder. The split uses a generator seeded with
    ``seed``, so it is reproducible across runs.

    Args:
        root: Root directory laid out as expected by
            ``torchvision.datasets.ImageFolder`` (one sub-folder per class).
        batch_size: Batch size for all three DataLoaders.
        seed: Seed for the random train/valid/test split.
        image_size: Side length images are resized to. It must match the
            input size the model expects.

    Returns:
        Tuple ``(train_set, valid_set, test_set,
        train_loader, valid_loader, test_loader)``. Only the train loader
        shuffles.
    """
    # Preprocessing: single-channel grayscale, normalized with one mean/std.
    trans = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    # Split the full dataset 8:1:1 into train / valid / test.
    full_set = torchvision.datasets.ImageFolder(root=root, transform=trans)
    dataset_size = len(full_set)
    train_size = int(dataset_size * 0.8)
    validation_size = int(dataset_size * 0.1)
    test_size = dataset_size - train_size - validation_size
    train_set, valid_set, test_set = random_split(
        full_set,
        [train_size, validation_size, test_size],
        generator=torch.Generator().manual_seed(seed),
    )

    # Only training data is shuffled; evaluation order stays deterministic.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=batch_size)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    return train_set, valid_set, test_set, train_loader, valid_loader, test_loader




if __name__ == '__main__':
    from pathlib import Path

    # Pick the best available accelerator: CUDA, then Apple MPS, else CPU.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(device)

    root = "datasets/archive"   # dataset root directory (ImageFolder layout)
    # Build the datasets and DataLoaders.
    batch_size = 64
    train_set, valid_set, test_set, train_loader, valid_loader, test_loader = train_set_loader(root, batch_size)

    epochs = 50
    lr = 0.001

    # Model
    model = driverstatusModel(0.3).to(device)
    # Loss: the model already applies sigmoid, so plain BCELoss is used.
    loss_fn = nn.BCELoss()
    # Optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Pass the device explicitly so the dummy summary input lands where the
    # model's parameters live.
    print(summary(model, (batch_size, 1, 150, 150), device=device))

    save_path = "models/driverstatusmodels.pth"
    # Make sure the checkpoint directory exists before training writes to it.
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)

    # Train; fit returns (train_loss_list, train_accuracy_list,
    # valid_loss_list, valid_accuracy_list).
    train_loss_list, train_accuracy_list, valid_loss_list, valid_accuracy_list = fit(
        train_loader, valid_loader, model, loss_fn, optimizer, epochs,
        save_model_path=save_path, device=device, mode='binary',
    )

    # Plot loss / accuracy curves.
    plot_fit_result(train_loss_list, train_accuracy_list, valid_loss_list, valid_accuracy_list)