In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torchsummary

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

from resnet_sw import *

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
resnet = resnet_sw([3, 4, 6, 3])
resnet = resnet.to(device)

# Check that a single random 3x32x32 input passes through the model and
# show the shape of the result.
x = torch.randn(1, 3, 32, 32)
x = x.to(device)
output = resnet(x)
print(output.shape)

# Model summary for a 3x32x32 input.
torchsummary.summary(resnet, (3, 32, 32), device=device.type)

In [None]:
# Training hyperparameters.
batch_size, num_epoch = 16, 30  # mini-batch size for the loaders, number of training epochs
learning_rate = 1e-4            # learning rate handed to the optimizer

In [None]:
# CIFAR-10 training split stored under ../DataSets/ (download=True asks
# torchvision to fetch it if needed). Images are converted to tensors;
# labels are left as they are.
CIFAR10_train = datasets.CIFAR10(
    "../DataSets/",
    train=True,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)

# Matching CIFAR-10 test split, prepared the same way.
CIFAR10_test = datasets.CIFAR10(
    "../DataSets/",
    train=False,
    transform=transforms.ToTensor(),
    target_transform=None,
    download=True,
)

In [None]:
# Mini-batch loaders over the two CIFAR-10 splits. Training batches are
# reshuffled every epoch. The evaluation loader keeps a fixed order, since
# shuffling adds nothing to an accuracy measurement and a fixed order makes
# the smoke-test subset below reproducible.
train_loader = DataLoader(CIFAR10_train, batch_size=batch_size, shuffle=True, num_workers=10, drop_last=True)
test_loader = DataLoader(CIFAR10_test, batch_size=batch_size, shuffle=False, num_workers=10, drop_last=True)

# Small in-memory subset (the first 20 test batches) used in place of the
# real loaders when test_model_flag is set, to smoke-test the pipeline quickly.
# range(20) comes first in zip() so no 21st batch is requested from the loader.
# Every `for` over a DataLoader starts a fresh pass, so test_loader can be
# reused afterwards without rebuilding it.
train_test = [[img, label] for _, (img, label) in zip(range(20), test_loader)]

In [None]:
# Histories appended to by the training loop at every logging step.
loss_array, accuracy_array = [], []

# When True, training and evaluation both run on the small `train_test`
# subset instead of the full loaders.
test_model_flag = False

# Fresh model on the first CUDA device when available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
resnet = resnet_sw([3, 4, 6, 3])
resnet = resnet.to(device)

# Cross-entropy objective, optimised with Adam at the configured learning rate.
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(resnet.parameters(), lr=learning_rate)

In [None]:
def get_accuracy():
    """Return the classification accuracy of the global ``resnet``.

    Evaluates on ``train_test`` when ``test_model_flag`` is set, otherwise on
    ``test_loader``. The model is put into evaluation mode for the duration
    of the pass, so that any mode-dependent layers inside ``resnet_sw`` (for
    example batch norm or dropout, if it has them) behave as at inference
    time. Its previous mode is restored afterwards, even if the pass raises.

    Returns:
        0-dim float tensor with the fraction of correctly classified samples
        (between 0 and 1); 0 when the evaluation set yields no samples.
    """
    val_datas = train_test if test_model_flag else test_loader
    # Accumulate on the model's device so no per-batch host sync is needed.
    correct = torch.zeros((), device=device)
    total = 0

    was_training = resnet.training
    resnet.eval()
    try:
        with torch.no_grad():
            for img, label in val_datas:
                x = img.to(device)
                y_ = label.to(device)

                # Call the module itself (not .forward) so registered hooks run.
                output = resnet(x)
                # Predicted class = index of the largest score along dim 1.
                output_index = output.argmax(dim=1)

                total += label.size(0)
                correct += (output_index == y_).sum()
    finally:
        # Hand the model back in whatever mode the caller had it in.
        resnet.train(was_training)

    # max(..., 1) avoids a division by zero on an empty evaluation set.
    return correct / max(total, 1)

In [None]:
# Train on the small smoke-test subset when test_model_flag is set,
# otherwise on the full training loader.
train_datas = train_test if test_model_flag else train_loader
for i in range(num_epoch):
    print("epoch", i, "is start")
    # Make sure the network is in training mode at the start of every epoch,
    # whatever mode an earlier cell or an evaluation pass left it in.
    resnet.train()
    for j, [img, label] in enumerate(train_datas):
        x = img.to(device)
        y_ = label.to(device)

        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        output = resnet(x)
        loss = loss_func(output, y_)
        loss.backward()
        optimizer.step()

        # Log on the first batch of each epoch in smoke-test mode,
        # otherwise every 100 batches.
        should_log = (j == 0) if test_model_flag else (j % 100 == 0)
        if should_log:
            loss_array.append(loss.detach().cpu().numpy())
            aa = get_accuracy()
            # .item() prints a plain number instead of a tensor repr.
            print("Accuracy of Test Data: {}, ".format(100 * aa.item()), end=" ")
            print("loss: {}".format(loss_array[-1]))
            accuracy_array.append(aa.detach().cpu().numpy())

In [None]:
import matplotlib.pyplot as plt

# Accuracy is a fraction while the loss is on its own scale, so each curve
# gets its own y-axis. Both histories are recorded at the same logging steps
# and therefore share the x-axis.
fig, ax_acc = plt.subplots()
(acc_line,) = ax_acc.plot(
    range(len(accuracy_array)), accuracy_array, color="tab:blue", label="accuracy"
)
ax_acc.set_xlabel("logging step")
ax_acc.set_ylabel("accuracy")

ax_loss = ax_acc.twinx()
(loss_line,) = ax_loss.plot(
    range(len(loss_array)), loss_array, color="tab:orange", label="training loss"
)
ax_loss.set_ylabel("training loss")

ax_acc.set_title("Accuracy and training loss during training")
# One combined legend for the curves drawn on the two axes.
ax_acc.legend([acc_line, loss_line], ["accuracy", "training loss"], loc="best")
plt.show()

In [None]:
# Final accuracy over the whole test loader.
correct = 0
total = 0

# Evaluate in inference mode so that any mode-dependent layers inside the
# model (e.g. batch norm or dropout, if resnet_sw has them) behave as at
# inference time; the previous mode is restored afterwards so that
# re-running the training cell still trains.
was_training = resnet.training
resnet.eval()
try:
    with torch.no_grad():
        for img, label in test_loader:
            x = img.to(device)
            y_ = label.to(device)

            # Call the module itself (not .forward) so registered hooks run.
            output = resnet(x)
            # Predicted class = index of the largest score along dim 1.
            _, output_index = torch.max(output, 1)

            total += label.size(0)
            correct += (output_index == y_).sum().float()
finally:
    resnet.train(was_training)

# float() prints a plain number whether `correct` ended up a tensor or not.
print("Accuracy of Test Data: {}".format(float(100 * correct / total)))