# 1. 데이터 가져오기

In [1]:
import torchvision.transforms as T
import torchvision
import torch
from torch.utils.data import DataLoader

download_root = './MNIST_DATASET'

mnist_transform = T.Compose([
    T.ToTensor(),
    T.Normalize((0.5, ), (0.5, ))
])

train_dataset = torchvision.datasets.MNIST(download_root, transform=mnist_transform, train=True, download=False)
test_dataset = torchvision.datasets.MNIST(download_root, transform=mnist_transform, train=False, download=False)

total_size = len(train_dataset)
train_num, valid_num = int(total_size * 0.8), int(total_size * 0.2)
train_dataset,valid_dataset = torch.utils.data.random_split(train_dataset, [train_num, valid_num])

batch_size = 32

train_dataloader = DataLoader(train_dataset, batch_size = batch_size, shuffle = True)
valid_dataloader = DataLoader(valid_dataset, batch_size = batch_size, shuffle = False)
test_dataloader = DataLoader(test_dataset, batch_size = batch_size, shuffle = False)

In [2]:
for data, label in train_dataloader:
    print(data.shape)
    print(len(train_dataloader))
    break

torch.Size([32, 1, 28, 28])
1500


# 2.모델 만들기

In [3]:
from pytorch_lightning import LightningModule, Trainer
import torch.optim as optim
import torchmetrics
import torch.nn as nn

from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_lightning.loggers import WandbLogger

import wandb

In [14]:
class GRUClassifier(LightningModule):
    def __init__(self, input_size, hidden_size, num_layers, num_classes, lr, dropout_prob):
        super().__init__()
        
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.learning_rate = lr
        
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = torchmetrics.Accuracy(task='multiclass', num_classes = num_classes)

        self.gru = nn.GRU(input_size, hidden_size, num_layers, batch_first=True, dropout = dropout_prob)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        '''
        INPUT
            x : [32, 1, 28, 28]
        OUTPUT
            out : [32, 10]
        '''
        x = x.view(x.size(0), x.size(2), x.size(3))
        h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        out = out[:, -1, :]
        out = self.fc(out)
        return out

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr = self.learning_rate)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = 1, gamma = 0.5)
        return [optimizer], [scheduler]
        
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        loss = self.criterion(y_hat, y)

        _, predict = torch.max(y_hat, dim = 1)
        acc = self.accuracy(predict, y)

        self.log(f"train_loss", loss, on_step = False, on_epoch = True, logger = True)
        self.log(f"train_acc", acc, on_step = False, on_epoch = True, logger = True)
        
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        loss = self.criterion(y_hat, y)

        _, predict = torch.max(y_hat, dim = 1)
        acc = self.accuracy(predict, y)

        self.log(f"valid_loss", loss, on_step = False, on_epoch = True, logger = True)
        self.log(f"valid_acc", acc, on_step = False, on_epoch = True, logger = True)
        
    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        loss = self.criterion(y_hat, y)

        _, predict = torch.max(y_hat, dim = 1)
        acc = self.accuracy(predict, y)

        self.log(f"test_loss", loss, on_step = False, on_epoch = True, logger = True)
        self.log(f"test_acc", acc, on_step = False, on_epoch = True, logger = True)
        

    def predict_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        _, predict = torch.max(y_hat, dim = 1)

        return predict

# 3. 모델 실행 및 평가

In [15]:
# configs

configs = {
    'hidden_size' : 64,
    'num_layers': 2,
    'lr' : 0.001,
    'dropout' : 0.2,
    'max_eppoch' : 100,
    'patence' : 5
}

In [17]:
model = GRUClassifier(input_size = 28, hidden_size = 128, num_layers = 2, num_classes = 10, lr = 0.001, dropout_prob=0.2)

early_stopping = EarlyStopping(monitor = 'valid_loss', mode = 'min', patience=5)
lr_mointor = LearningRateMonitor(logging_interval = 'epoch')

wandb_logger = WandbLogger(project = 'MNIST_GRU')

trainer = Trainer(
    max_epochs = 100,
    accelerator = 'auto',
    callbacks = [early_stopping, lr_mointor],
    logger = wandb_logger
)

trainer.fit(
    model,
    train_dataloader,
    valid_dataloader
)

trainer.test(model, test_dataloader)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | criterion | CrossEntropyLoss   | 0      | train
1 | accuracy  | MulticlassAccuracy | 0      | train
2 | gru       | GRU                | 159 K  | train
3 | fc        | Linear             | 1.3 K  | train
---------------------------------------------------------
161 K     Trainable params
0         Non-trainable params
161 K     Total params
0.644     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.03542693331837654, 'test_acc': 0.9897000193595886}]