In [49]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split

import sys

import matplotlib.pyplot as plt
import torch.nn.functional as F
import pytorch_lightning as pl

import torchmetrics
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from dataloader import LoadData

from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

from models.CNN import CNN
from models.DNN import DNN

import yaml


0. Setting

In [50]:
config = yaml.load(open("./config.yaml", "r"), Loader=yaml.FullLoader)
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

1. dataset

In [51]:
load_data = LoadData(config)
    
train_dataloader = load_data.train_dataloader()
val_dataloader = load_data.val_dataloader()
test_dataloader =load_data.test_dataloader()
predict_dataloader = load_data.predict_dataloader()

accuracy=torchmetrics.classification.MulticlassAccuracy(num_classes = 10).to(DEVICE)


2. Model

In [52]:
model_name = config["model"]["model_name"]
if model_name=="CNN":
    model=CNN(config)
elif model_name=='DNN':
    model=DNN(config)
elif model_name=="RESNET":
    model=RESNET()
else:
    print("No model")
    sys.exit()
    
model = model.to(DEVICE)

3. Loss & 4. Optimizer

In [53]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

5. Train

In [54]:
def train(model, train_dataloader, optimizer, log_interval = 5):
    model.train()
    for batch_idx, (x, y) in enumerate(train_dataloader):
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        y_hat = model(x)
        loss = criterion(y_hat, y)
    
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print(f"Train Epoch: {Epochs} [{batch_idx*len(x)}/{len(train_dataloader.dataset)}({100*batch_idx/len(train_dataloader):.0f}%)]\tTrain Loss: {loss.item():.6f}")
            

In [62]:
def eval(model, test_dataloader):
    model.eval()
    loss = 0
    
    with torch.no_grad():
        for x, y in test_dataloader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y_hat = model(x)
            
            loss +=criterion(y_hat,y).item()
            accuracy(y_hat,y)
           
    loss /= len(test_dataloader.dataset)
    
    acc = accuracy.compute()
    accuracy.reset()
    return loss, acc

In [56]:
max_epochs = config['train']['max_epochs']

for Epochs in range(1, max_epochs+1):
    train(model, train_dataloader, optimizer, log_interval = 200)
    val_loss, val_acc = eval(model, val_dataloader)
    print(f"\n[Epoch: {Epochs}], \tval Loss: {val_loss:.4f}\tval Accuracy: {val_acc:.4f}")

model_save_path = config["train"]["model_save_dir"] + "/model_state_dict.pt"
torch.save(model.state_dict(), model_save_path)
        


[Epoch: 1], 	val Loss: 0.0024	val Accuracy: 0.9764

[Epoch: 2], 	val Loss: 0.0013	val Accuracy: 0.9864

[Epoch: 3], 	val Loss: 0.0014	val Accuracy: 0.9889


Test

In [64]:
model_name = config["model"]["model_name"]
if model_name=="CNN":
    model=CNN(config)
elif model_name=='DNN':
    model=DNN(config)
elif model_name=="RESNET":
    model=RESNET()
else:
    print("No model")
    sys.exit()
    
model = model.to(DEVICE)

model.load_state_dict(torch.load(model_save_path))
test_loss, test_acc = eval(model, test_dataloader)
print(f"Test Loss: {test_loss:.4f}\tTest Accuracy: {test_acc:.4f}")


0.00859987735748291
0.011519250925630331
0.01368005876429379
0.018931448692455888
0.019067008266574703
0.0213104209251469
0.0225982981355628
0.05101049678341951
0.05787212435097899
0.05875138942792546
0.14035630975558888
0.14523381930484902
0.14696136424026918
0.20175264710269403
0.26559511506638955
0.2661702695040731
0.2663410475652199
0.269482822594
0.36910089987213723
0.37003596129943617
0.39011713679064997
0.4423534367524553
0.4744872730516363
0.5838699993037153
0.5851518053386826
0.5863291772839148
0.6092384311195929
0.6353109947231133
0.7291340849187691
0.9949747643258888
1.030600650497945
1.1938142015424091
1.2226635387924034
1.2228919051995035
1.2254705469531473
1.2338967168179806
1.3013141372648533
1.3023082420404535
1.49952140005189
1.5726629153068643
1.6272334002132993
1.6300430011178833
1.6312518475751858
1.6434085780929308
1.648410976602463
1.7163199491042178
1.718535733380122
1.8294323356531095
1.8372714426077437
1.8376512564136647
1.8738291374756955
1.8854574933066033
1.