In [1]:
import torch
import torch.nn as nn

<img src = "https://images.velog.io/images/junyoung9696/post/3137e50c-b52f-4cdd-8ae8-2faf497efe84/r10.png" width = 400>

<br/>

<img src = "https://editor.analyticsvidhya.com/uploads/38371XTo6Q.png" width = 600, height=800>

In [12]:
class BasicBlock(nn.Module):
    def __init__(self,in_channels,out_channels,kernel_size=3):
        super(BasicBlock,self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,padding=1)#spatial size유지
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels,out_channels ,kernel_size=kernel_size, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()
        
        self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1) #? stride= 2 case? pooling case?
        
    def forward(self,x):
        identity = x #shorcut을 위한 identity저장
        
        #F(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        
        identity = self.shortcut(identity)
        
        x = x+identity
        x = self.relu(x) #합치고 렐루
        
        return x
        

In [16]:
class ResNet(nn.Module):  #Resent18? 34?50? ... Vannila ResNet3?
    def __init__(self,num_classes=10):
        super().__init__()
        
        #기본 블록
        self.block1 = BasicBlock(in_channels=3,out_channels=64)
        self.block2 = BasicBlock(in_channels=64, out_channels=128)
        self.block3 = BasicBlock(in_channels=128, out_channels = 256)
        
        self.pool = nn.AvgPool2d(kernel_size=2, stride=2)#spatial size 32->16
        
        #Classifier
        self.fc1 = nn.Linear(in_features=4096,out_features=2048) #4*4*256 = 4096  ? CIFAR 32*32...
        self.fc2 = nn.Linear(in_features=2048, out_features=512)
        self.fc3 = nn.Linear(in_features=512, out_features=num_classes)
        
        self.relu = nn.ReLU()
        
    def forward(self,x):
        x = self.block1(x) # spatial s = 32 (kerenl3 padding1)
        x = self.pool(x)  # 16
        x = self.block2(x) # 16
        x = self.pool(x) # 8
        x = self.block3(x) # 8
        x = self.pool(x)  #  4 -> 256*4*4  -> 4096
        
        x = torch.flatten(x, start_dim=1) # dim=0 batch는 무시하고 chw 를 flatten (Batch, 4096) input
        
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.fc3(x)
        
        return x
        
    
        


In [17]:
from tqdm import tqdm

from torchvision.datasets.cifar import CIFAR10
from torchvision.transforms import Compose,ToTensor
from torchvision.transforms import RandomHorizontalFlip,RandomCrop 
from torchvision.transforms import Normalize
from torch.utils.data.dataloader import DataLoader

from torch.optim.adam import Adam

transforms = Compose([
    RandomCrop((32,32),padding=4),
    RandomHorizontalFlip(p=0.5),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225])
])

trainset = CIFAR10(root='./', train=True, download=True, transform=transforms)
testset = CIFAR10(root='./', train =False, download =True, transform= transforms)

train_loader = DataLoader(trainset , batch_size=32, shuffle=True)
test_loader = DataLoader(testset, batch_size=32,shuffle =False)

device = "cuda" if torch.cuda.is_available() else "cpu"

model = ResNet(num_classes=10)
model.to(device)


Files already downloaded and verified
Files already downloaded and verified


ResNet(
  (block1): BasicBlock(
    (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Conv2d(3, 64, kernel_size=(1, 1), stride=(1, 1))
  )
  (block2): BasicBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU()
    (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Conv2d(64, 128, kernel_size=(1, 1), stride=(1, 1))
  )
  (block3): BasicBlock(
    (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(

In [19]:
lr = 1e-3
optim = Adam(model.parameters(),lr=lr)

for epoch in range(5):
    iterator = tqdm(train_loader)
    for data,label in iterator:
        optim.zero_grad()
        
        preds = model(data.to(device))
        
        loss = nn.CrossEntropyLoss()(preds,label.to(device))
        loss.backward()
        optim.step()
        
        iterator.set_description(f'epoch:{epoch+1} loss:{loss.item()}')
        

torch.save(model.state_dict(),"ResNet.pth")
    

epoch:1 loss:0.5612490177154541: 100%|██████████| 1563/1563 [06:23<00:00,  4.07it/s]
epoch:2 loss:0.4626854956150055: 100%|██████████| 1563/1563 [06:34<00:00,  3.96it/s] 
epoch:3 loss:0.23807354271411896: 100%|██████████| 1563/1563 [06:21<00:00,  4.09it/s]
epoch:4 loss:0.2558000683784485: 100%|██████████| 1563/1563 [06:20<00:00,  4.11it/s] 
epoch:5 loss:0.45090895891189575: 100%|██████████| 1563/1563 [06:16<00:00,  4.15it/s]


In [20]:
model.load_state_dict(torch.load("ResNet.pth",map_location=device))

num_corr=0

with torch.no_grad():
    for data,label in tqdm(test_loader):
        
        output = model(data.to(device))
        preds = output.data.max(1)[1]
        corr = preds.eq(label.to(device).data).sum().item()
        num_corr += corr
        
    print(f'Accuracy"{num_corr/len(testset)}')
    
        

100%|██████████| 313/313 [00:26<00:00, 11.65it/s]

Accuracy"0.8197



