# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms

# Prepare the CIFAR-10 training and test sets.
# Per-channel normalisation statistics shared by both pipelines.
# NOTE(review): these std values are the widely copied ones; per-channel
# CIFAR-10 stds are often quoted as about (0.2470, 0.2435, 0.2616). Confirm
# which set is intended.
_cifar_mean = (0.4914, 0.4822, 0.4465)
_cifar_std = (0.2023, 0.1994, 0.2010)

# Training-time augmentation: pad by 4 pixels and take a random 32x32 crop,
# then apply a random horizontal flip, before tensor conversion and
# normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

# Evaluation pipeline: no augmentation, only tensor conversion + normalisation.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_cifar_mean, _cifar_std),
])

# Both splits live under ./data and are downloaded on first use.
_dataset_kwargs = dict(root='./data', download=True)

trainSet = torchvision.datasets.CIFAR10(
    train=True, transform=transform_train, **_dataset_kwargs)
trainLoader = torch.utils.data.DataLoader(
    trainSet, batch_size=128, shuffle=True, num_workers=2)

testSet = torchvision.datasets.CIFAR10(
    train=False, transform=transform_test, **_dataset_kwargs)
testLoader = torch.utils.data.DataLoader(
    testSet, batch_size=100, shuffle=False, num_workers=2)

# BasicBlock implementation
class BasicBlock(nn.Module):
    """Residual block: 3x3 conv -> BN -> ReLU -> 3x3 conv -> BN, plus shortcut.

    The shortcut output is added to the second BN output, and a final ReLU is
    applied.

    Args:
        inputCh: number of channels of the input feature map.
        outputCh: number of channels produced by the block.
        stride: stride of the first 3x3 conv and of the projection shortcut.

    When ``stride != 1`` or the channel count changes, the shortcut is a
    1x1 conv + BN projection. Otherwise it is an empty ``nn.Sequential``,
    i.e. the identity.
    """

    # Multiplier applied to outputCh to get the block's output channel count.
    expansion = 1

    def __init__(self, inputCh, outputCh, stride=1):
        super().__init__()
        # Each conv is immediately followed by BatchNorm, whose learnable shift
        # makes a conv bias redundant. Hence bias=False, matching the stem conv.
        self.firstConv = nn.Conv2d(inputCh, outputCh, kernel_size=3,
                                   stride=stride, padding=1, bias=False)
        self.firstBN = nn.BatchNorm2d(outputCh)

        self.secondConv = nn.Conv2d(outputCh, outputCh, kernel_size=3,
                                    stride=1, padding=1, bias=False)
        self.secondBN = nn.BatchNorm2d(outputCh)

        self.shortcut = nn.Sequential()
        if stride != 1 or inputCh != outputCh:
            # Projection shortcut so the residual sum has matching shape.
            self.shortcut = nn.Sequential(
                nn.Conv2d(inputCh, outputCh, kernel_size=1, stride=stride,
                          bias=False),
                nn.BatchNorm2d(outputCh))

    def forward(self, x):
        out = F.relu(self.firstBN(self.firstConv(x)))
        out = self.secondBN(self.secondConv(out))
        # Residual connection: identity or projection of the block input.
        out += self.shortcut(x)
        return F.relu(out)

# ResNet20 implementation
class ResNet(nn.Module):
    """CIFAR-style ResNet with a 3x3 stem followed by three residual stages.

    The stages use 16, 32 and 64 base channels. The second and third stages
    start with a stride-2 block. The final feature map is average-pooled to
    1x1, flattened and passed to a linear classifier.

    Args:
        block: residual block class, called as
            ``block(in_channels, out_channels, stride)`` and exposing an
            integer ``expansion`` class attribute.
        num_blocks: sequence of three ints, the block count of each stage.
        num_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count, advanced by makeLayer as stages are built.
        self.inputCh = 16

        # Stem: bias-free conv since BatchNorm follows it.
        self.conv = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.BN = nn.BatchNorm2d(16)

        # Build order is deliberate: it fixes parameter order and the sequence
        # of random initialisations.
        self.firstLayer = self.makeLayer(block, 16, num_blocks[0], stride=1)
        self.secondLayer = self.makeLayer(block, 32, num_blocks[1], stride=2)
        self.thirdLayer = self.makeLayer(block, 64, num_blocks[2], stride=2)

        self.averagePooling = nn.AdaptiveAvgPool2d((1, 1))
        self.fullyConnectedLayer = nn.Linear(64 * block.expansion, num_classes)

    def makeLayer(self, block, outputCh, blocks, stride):
        """Build one stage of residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride``; the others use stride 1.
        At least one block is always created, so ``blocks <= 1`` yields a
        single block. ``self.inputCh`` is advanced to the stage's output
        channel count.
        """
        stage = []
        for position in range(max(blocks, 1)):
            block_stride = stride if position == 0 else 1
            stage.append(block(self.inputCh, outputCh, block_stride))
            self.inputCh = outputCh * block.expansion
        return nn.Sequential(*stage)

    def forward(self, x):
        features = F.relu(self.BN(self.conv(x)))
        for stage in (self.firstLayer, self.secondLayer, self.thirdLayer):
            features = stage(features)
        pooled = self.averagePooling(features)
        # (N, C, 1, 1) -> (N, C) before the classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.fullyConnectedLayer(flat)

def ResNet20(num_classes=10):
    """Build a ResNet-20: three stages of three BasicBlocks each (6n+2, n=3).

    Args:
        num_classes: size of the classifier output. The default of 10 keeps
            the original CIFAR-10 behaviour.

    Returns:
        A freshly initialised ``ResNet`` instance.
    """
    return ResNet(BasicBlock, [3, 3, 3], num_classes=num_classes)

# Build the model and place it on the GPU when CUDA is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = ResNet20().to(device)

# Loss, optimizer and learning-rate schedule.
# SGD with momentum and L2 weight decay. StepLR multiplies the learning rate
# by 0.1 every 50 calls to scheduler.step().
# NOTE(review): the epoch loop steps once per epoch for 200 epochs, so the
# final 50 epochs run at lr=1e-5. Confirm the schedule is intended.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=0.01,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.1)

# Training and evaluation
def train(epoch):
    """Run one training epoch of ``net`` over ``trainLoader``.

    Switches ``net`` to training mode and takes one optimizer step per
    mini-batch. Every 100 batches it prints the running mean batch loss and
    the running accuracy.

    Args:
        epoch: epoch index, used only in the progress message.

    Returns:
        ``(avg_loss, accuracy)`` for the whole epoch, where:
        - ``avg_loss`` is the mean of the per-batch losses.
        - ``accuracy`` is the percentage of correctly classified samples.
        Both are 0.0 when ``trainLoader`` yields no batches.
    """
    net.train()
    trainLoss = 0
    correct = 0
    total = 0
    numBatches = 0
    for batchIdx, (inputs, targets) in enumerate(trainLoader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        trainLoss += loss.item()
        _, predicted = outputs.max(1)  # predicted class = index of largest logit
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        numBatches += 1

        if batchIdx % 100 == 0:
            print(f'Epoch: {epoch} | Batch: {batchIdx} | Loss: {trainLoss/(batchIdx+1)} | Acc: {100.*correct/total}')

    # max(..., 1) keeps an empty loader from dividing by zero.
    return trainLoss / max(numBatches, 1), 100. * correct / max(total, 1)


def test(epoch):
    """Evaluate ``net`` on ``testLoader`` with gradient tracking disabled.

    Prints the mean per-batch loss and the accuracy.

    Args:
        epoch: epoch index. Currently unused; kept so the signature parallels
            ``train``.

    Returns:
        ``(avg_loss, accuracy)``, where:
        - ``avg_loss`` is the mean of the per-batch losses.
        - ``accuracy`` is the percentage of correctly classified samples.
        Both are 0.0 when ``testLoader`` yields no batches.
    """
    net.eval()
    testLoss = 0
    correct = 0
    total = 0
    numBatches = 0
    with torch.no_grad():
        for inputs, targets in testLoader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            testLoss += loss.item()
            _, predicted = outputs.max(1)  # predicted class = index of largest logit
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            numBatches += 1

    # Batches are counted explicitly rather than read from the loop index, so
    # an empty loader neither references an unbound name nor divides by zero.
    avgLoss = testLoss / max(numBatches, 1)
    accuracy = 100. * correct / max(total, 1)
    print(f'Test Loss: {avgLoss} | Acc: {accuracy}')
    return avgLoss, accuracy


# Main loop: train, evaluate, then advance the LR schedule once per epoch.
_NUM_EPOCHS = 200
for epoch in range(_NUM_EPOCHS):
    train(epoch)
    test(epoch)
    scheduler.step()