# In [None]:
import torch
from torchvision import datasets, transforms
import torchvision
import time
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import glob
from torchvision import transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader


class cDataset(Dataset):
    """Binary image dataset for a single eye condition, read from a folder tree.

    Images are discovered with the pattern
    ``<path>/<split>/<condition>/<무|유>/*.jpg`` where ``<split>`` is
    ``Training`` or ``Validation``. Files in the ``무`` ("absent") folder get
    label 0 and files in the ``유`` ("present") folder get label 1. Negative
    samples come first in ``img_list``, followed by positive samples.

    Args:
        path: Root directory that contains the ``Training`` and ``Validation``
            folders.
        type: Condition selector. 0 = 각막궤양 (corneal ulcer),
            1 = 각막부골편 (corneal sequestrum), 2 = 결막염 (conjunctivitis),
            3 = 비궤양성각막염 (non-ulcerative keratitis). Any other value
            selects 안검염 (blepharitis).
        train: Use the ``Training`` split when true, otherwise ``Validation``.
        transform: Optional callable applied to the opened image before it is
            returned from ``__getitem__``.
    """

    # Folder name for each explicit selector value; runtime paths depend on
    # these exact strings.
    _CONDITION_DIRS = {
        0: '각막궤양',
        1: '각막부골편',
        2: '결막염',
        3: '비궤양성각막염',
    }
    # Folder used for every selector value that is not listed above.
    _FALLBACK_CONDITION_DIR = '안검염'

    def __init__(self, path, type, train=True, transform=None):
        self.path = path + ('/Training' if train else '/Validation')

        condition = self._CONDITION_DIRS.get(type, self._FALLBACK_CONDITION_DIR)
        self.path_n = self.path + '/' + condition + '/무'
        self.path_y = self.path + '/' + condition + '/유'

        # glob returns entries in filesystem-dependent order; sorting keeps the
        # index -> sample mapping reproducible across machines and runs.
        self.n_img_list = sorted(glob.glob(self.path_n + '/*.jpg'))
        self.y_img_list = sorted(glob.glob(self.path_y + '/*.jpg'))

        self.transform = transform

        self.img_list = self.n_img_list + self.y_img_list
        self.class_list = [0] * len(self.n_img_list) + [1] * len(self.y_img_list)

    def __len__(self):
        """Return the total number of images (negative plus positive)."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``.

        The image is passed through ``transform`` when one was given; the
        label is 0 or 1.
        """
        img_path = self.img_list[idx]
        label = self.class_list[idx]
        # NOTE(review): the image is returned in whatever mode the file is
        # stored in (no RGB conversion here) - confirm every file shares one
        # mode, otherwise batches may mix channel counts.
        img = Image.open(img_path)

        if self.transform is not None:
            img = self.transform(img)

        return img, label
# Use the first CUDA GPU when PyTorch reports one as available; otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


def train(net, epoch, optimizer, criterion, train_dataloader):
    """Train ``net`` for one pass over ``train_dataloader`` and report metrics.

    Args:
        net: Model to optimise; switched to training mode here.
        epoch: Epoch number, used only in the printed header.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        train_dataloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(accuracy_percent, loss_sum / sample_count)``.
        NOTE(review): the loss figure is the sum of per-batch criterion values
        divided by the number of samples; if the criterion already averages
        over the batch this is not a per-sample mean - confirm intent.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    print('[ Train epoch: %d ]' % epoch)
    net.train()

    running_loss = 0
    num_correct = 0
    num_seen = 0
    for inputs, targets in train_dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        outputs = net(inputs)
        batch_loss = criterion(outputs, targets)

        # Backpropagate, then update the weights with the new gradients.
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

        # Predicted class = index of the largest value along dimension 1.
        predicted = outputs.max(1)[1]
        num_seen += targets.size(0)
        num_correct += (predicted == targets).sum().item()

    accuracy = 100. * num_correct / num_seen
    print('Train accuarcy:', accuracy)
    scaled_loss = running_loss / num_seen
    print('Train average loss:', scaled_loss)
    return (accuracy, scaled_loss)


def validate(net, epoch, criterion, val_dataloader):
    """Evaluate ``net`` on ``val_dataloader`` and report accuracy and loss.

    The model is switched to evaluation mode and left in that mode.

    Args:
        net: Model to evaluate.
        epoch: Epoch number, used only in the printed header.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        val_dataloader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        ``(accuracy_percent, loss_sum / sample_count)``.
        NOTE(review): the loss figure is the sum of per-batch criterion values
        divided by the number of samples; if the criterion already averages
        over the batch this is not a per-sample mean - confirm intent.

    Raises:
        ZeroDivisionError: If the dataloader yields no samples.
    """
    print('[ Validation epoch: %d ]' % epoch)
    net.eval()  # put the model in evaluation mode
    val_loss = 0
    correct = 0
    total = 0

    # Nothing here calls backward(), so disable autograd: the forward passes
    # then skip recording a graph, which saves memory and time.
    with torch.no_grad():
        for inputs, targets in val_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = net(inputs)  # forward pass
            val_loss += criterion(outputs, targets).item()
            # Predicted class = index of the largest value along dimension 1.
            _, predicted = outputs.max(1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    accuracy = 100. * correct / total
    print('Accuarcy:', accuracy)
    scaled_loss = val_loss / total
    print('Average loss:', scaled_loss)
    return (accuracy, scaled_loss)
# Feed a dataset through the network and compute its confusion matrix.
def get_confusion_matrix(net, num_classes, data_loader):
    """Count (true label, predicted label) pairs over ``data_loader``.

    The model is switched to evaluation mode and left in that mode.

    Args:
        net: Model to evaluate.
        num_classes: Side length of the square matrix that is returned.
        data_loader: Iterable yielding ``(inputs, targets)`` batches.

    Returns:
        An int32 tensor of shape ``(num_classes, num_classes)`` where entry
        ``[t, p]`` is the number of samples with target ``t`` that were
        predicted as ``p`` (rows are targets, columns are predictions).
    """
    net.eval()  # put the model in evaluation mode
    confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.int32)

    # Inference only: disable autograd so no graph is recorded per batch.
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = net(inputs)
            # Predicted class = index of the largest value along dimension 1.
            _, predicted = outputs.max(1)

            # Convert each batch to plain Python ints once, instead of indexing
            # the matrix with one device scalar tensor per sample.
            true_labels = targets.view(-1).long().tolist()
            pred_labels = predicted.view(-1).long().tolist()
            for t, p in zip(true_labels, pred_labels):
                confusion_matrix[t, p] += 1

    return confusion_matrix


# Training pipeline: resize to 250x250, take a random 224x224 crop, apply
# random horizontal/vertical flips (probability 0.5 each) and a random rotation
# configured with degrees=5, then convert the image to a tensor.
train_transforms = transforms.Compose(
    [
        transforms.Resize(size=(250, 250)),
        transforms.RandomCrop(size=(224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(degrees=5),
        transforms.ToTensor(),
    ]
)

# Validation pipeline: no augmentation, only a resize straight to 224x224
# followed by conversion to a tensor.
valid_transforms = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
    ]
)