<a href="https://colab.research.google.com/github/younguk072023/Pytorch_study/blob/main/Alexnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler


In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Data loading

def get_train_valid_loader(data_dir, batch_size, augment, random_seed, valid_size=0.1, shuffle=True):
  """Build training and validation loaders from the CIFAR-10 training split.

  Both loaders read the same underlying training split (``train=True``) but
  sample from disjoint index subsets, and only the training loader may be
  augmented.

  Args:
    data_dir: Root directory passed to ``datasets.CIFAR10`` (``download=True``).
    batch_size: Batch size used for both loaders.
    augment: If True, training images get a padded random crop and a random
      horizontal flip before being resized; validation images never do.
    random_seed: Seed given to ``np.random.seed`` before shuffling indices.
      Note that this seeds NumPy's global random state.
    valid_size: Fraction of the training split held out for validation.
      Must lie in [0, 1].
    shuffle: If True, shuffle the indices before splitting them.

  Returns:
    A tuple ``(train_loader, valid_loader)``.

  Raises:
    ValueError: If ``valid_size`` is outside [0, 1].
  """
  if not 0 <= valid_size <= 1:
    raise ValueError('valid_size should be in the range [0, 1], got {}'.format(valid_size))

  normalize = transforms.Normalize(
      mean=[0.4914, 0.4822, 0.4465],  # per-channel (RGB) mean
      std=[0.2023, 0.1994, 0.2010],  # per-channel (RGB) standard deviation
  )

  # Steps shared by every pipeline: the network's first fully connected layer
  # is sized for 227x227 inputs, and Normalize operates on tensors, so both
  # Resize and ToTensor must precede it.
  to_model_input = [
      transforms.Resize((227, 227)),
      transforms.ToTensor(),
      normalize,
  ]

  valid_transform = transforms.Compose(to_model_input)

  if augment:
    # Augment at the native 32x32 resolution, then convert exactly like the
    # non-augmented path. (Previously this branch skipped Resize/ToTensor and
    # handed PIL images straight to Normalize.)
    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ] + to_model_input)
  else:
    train_transform = transforms.Compose(to_model_input)

  # Two dataset views over the same training split, differing only in the
  # transform that is applied.
  train_dataset = datasets.CIFAR10(
      root=data_dir, train=True,
      download=True, transform=train_transform,
  )
  valid_dataset = datasets.CIFAR10(
      root=data_dir, train=True,
      download=True, transform=valid_transform,
  )

  num_train = len(train_dataset)
  indices = list(range(num_train))
  split = int(np.floor(valid_size * num_train))

  if shuffle:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

  # The first `split` indices go to validation, the rest to training.
  train_idx, valid_idx = indices[split:], indices[:split]
  train_sampler = SubsetRandomSampler(train_idx)
  valid_sampler = SubsetRandomSampler(valid_idx)

  train_loader = torch.utils.data.DataLoader(
      train_dataset, batch_size=batch_size, sampler=train_sampler
  )
  valid_loader = torch.utils.data.DataLoader(
      valid_dataset, batch_size=batch_size, sampler=valid_sampler
  )

  return train_loader, valid_loader

def get_test_loader(data_dir, batch_size, shuffler=True):
  """Build a loader over the CIFAR-10 test split.

  Args:
    data_dir: Root directory passed to ``datasets.CIFAR10`` (``download=True``).
    batch_size: Number of samples per batch.
    shuffler: Whether the loader shuffles the test set. Previously this
      argument was accepted but ignored (shuffling was always on); the
      default of True keeps that behavior.

  Returns:
    A ``torch.utils.data.DataLoader`` over the test split.
  """
  # Use the same per-channel statistics as get_train_valid_loader so the
  # model sees test inputs on the same scale it was trained on.
  normalize = transforms.Normalize(
      mean=[0.4914, 0.4822, 0.4465],
      std=[0.2023, 0.1994, 0.2010],
  )
  transform = transforms.Compose([
      transforms.Resize((227, 227)),
      transforms.ToTensor(),
      normalize,
  ])

  dataset = datasets.CIFAR10(
      root=data_dir, train=False,
      download=True, transform=transform,
  )

  data_loader = torch.utils.data.DataLoader(
      dataset, batch_size=batch_size, shuffle=shuffler
  )

  return data_loader

# Training/validation loaders come from the CIFAR-10 training split
# (no augmentation, fixed seed); the test loader covers the test split.
train_loader, valid_loader = get_train_valid_loader(
    data_dir='./data',
    batch_size=64,
    augment=False,
    random_seed=1,
)

test_loader = get_test_loader(data_dir='./data', batch_size=64)







In [None]:
class AlexNet(nn.Module):
  """AlexNet-style convolutional classifier with batch normalization.

  Five convolutional stages followed by three fully connected stages. The
  first fully connected layer expects 9216 (= 256 * 6 * 6) features, which
  is what the convolutional stack produces for 3x227x227 inputs.

  Args:
    num_classes: Size of the final output layer (number of classes).
  """

  def __init__(self, num_classes=10):
    super().__init__()
    # 3 input channels -> 96 output channels, 11x11 kernel, stride 4.
    self.layer1 = nn.Sequential(
        nn.Conv2d(3, 96, kernel_size=11, stride=4, padding=0),
        nn.BatchNorm2d(96),  # batch normalization over the 96 channels
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2))
    self.layer2 = nn.Sequential(
        nn.Conv2d(96, 256, kernel_size=5, stride=1, padding=2),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2))
    self.layer3 = nn.Sequential(
        nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(384),
        nn.ReLU())
    self.layer4 = nn.Sequential(
        nn.Conv2d(384, 384, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(384),
        nn.ReLU())
    self.layer5 = nn.Sequential(
        nn.Conv2d(384, 256, kernel_size=3, stride=1, padding=1),
        nn.BatchNorm2d(256),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2))
    self.fc = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(9216, 4096),
        nn.ReLU())
    self.fc1 = nn.Sequential(
        nn.Dropout(0.5),
        nn.Linear(4096, 4096),
        nn.ReLU())
    self.fc2 = nn.Sequential(
        nn.Linear(4096, num_classes))

  def forward(self, x):
    """Return class scores of shape (batch, num_classes) for input ``x``."""
    out = self.layer1(x)
    out = self.layer2(out)
    out = self.layer3(out)
    out = self.layer4(out)
    out = self.layer5(out)
    # Flatten everything except the batch dimension for the linear layers.
    out = out.reshape(out.size(0), -1)
    out = self.fc(out)
    out = self.fc1(out)
    out = self.fc2(out)
    return out




In [None]:
# Training hyperparameters.
num_classes = 10
num_epochs = 20
batch_size = 64
learning_rate = 0.005

# Build the network and move it to the selected device.
model = AlexNet(num_classes)
model = model.to(device)

# Cross-entropy loss, optimized with SGD (momentum and weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=learning_rate,
    momentum=0.9,
    weight_decay=0.005,
)

# Number of batches per training epoch.
total_step = len(train_loader)

In [None]:
def _evaluate(model, loader, device):
  """Return ``(accuracy_percent, num_samples)`` for ``model`` over ``loader``.

  The model is put in eval mode for the pass (so Dropout is disabled and
  BatchNorm uses its running statistics rather than updating them) and its
  previous train/eval mode is restored afterwards. An empty loader yields
  ``(0.0, 0)``.
  """
  was_training = model.training
  model.eval()
  correct = 0
  total = 0
  with torch.no_grad():
    for images, labels in loader:
      images = images.to(device)
      labels = labels.to(device)
      outputs = model(images)
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()
  model.train(was_training)
  accuracy = 100 * correct / total if total else 0.0
  return accuracy, total


total_step = len(train_loader)

for epoch in range(num_epochs):
  model.train()  # enable Dropout / BatchNorm batch statistics for training
  loss = None
  for i, (images, labels) in enumerate(train_loader):  # enumerate yields (index, value)
    images = images.to(device)
    labels = labels.to(device)

    outputs = model(images)
    loss = criterion(outputs, labels)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  # Report the loss of the last batch of the epoch (skipped if the loader
  # produced no batches at all).
  if loss is not None:
    print('Epoch [{}/{}], step[{}/{}], Loss : {:.4f}'.format(epoch+1,num_epochs,i+1,total_step,loss.item()))

  # Validation accuracy after every epoch.
  valid_accuracy, valid_total = _evaluate(model, valid_loader, device)
  print('Accuracy of the network on the {} validation images: {} %'.format(valid_total, valid_accuracy))

# Final evaluation on the held-out test set, run once after training rather
# than inside the epoch loop.
test_accuracy, test_total = _evaluate(model, test_loader, device)
print('Accuracy of the network on the {} test images: {} %'.format(test_total, test_accuracy))




