In [1]:
import copy

import pandas as pd
from tqdm import tqdm

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms


# Fix every RNG seed so the notebook can be re-run with reproducible results.
import random
import numpy as np

SEED = 0

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# manual_seed_all seeds every visible GPU, including the current one.
torch.cuda.manual_seed_all(SEED)

# Seeding alone does not make GPU runs repeatable: cuDNN's autotuner may pick
# different (nondeterministic) convolution algorithms between runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [7]:
class MyModel(nn.Module):
    """LeNet-5-style CNN baseline.

    Two conv -> ReLU -> 2x2 max-pool stages followed by three fully connected
    layers. ``fc1`` expects a 16 x 5 x 5 feature map, which is what a 32 x 32
    input produces (32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5), so inputs
    must be 32 x 32 images (e.g. CIFAR).

    Args:
        in_channels: number of channels in the input image (3 for RGB).
        num_classes: number of output logits.
    """

    def __init__(self, in_channels, num_classes):
        super().__init__()
        # Use the constructor arguments instead of hard-coded 3 / 100 so the
        # model can be built for other channel counts and label sets.
        self.conv1 = nn.Conv2d(in_channels, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Map a (batch, in_channels, 32, 32) tensor to (batch, num_classes) raw logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # No softmax: nn.CrossEntropyLoss expects unnormalized logits.
        return self.fc3(x)

In [8]:
# Baseline for CIFAR-100: RGB input (3 channels), 100 output classes, on the chosen device.
model = MyModel(in_channels=3, num_classes=100).to(device)
model

MyModel(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=100, bias=True)
)

In [10]:
# Per-channel normalization (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
CHANNEL_MEAN = (0.5, 0.5, 0.5)
CHANNEL_STD = (0.5, 0.5, 0.5)
BATCH_SIZE = 256
NUM_WORKERS = 2

# Train and test share the same preprocessing; no augmentation is applied to the train split.
train_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD)]
)
test_transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD)]
)

# download=True fetches CIFAR-100 into the working directory if it is not already there.
train = torchvision.datasets.CIFAR100(root="./", train=True, download=True, transform=train_transform)
test = torchvision.datasets.CIFAR100(root="./", train=False, download=True, transform=test_transform)

# Only the training split is shuffled; evaluation order does not affect accuracy.
train_loader = torch.utils.data.DataLoader(
    train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS
)
test_loader = torch.utils.data.DataLoader(
    test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS
)

# NOTE(review): lr=0.1 with momentum=0.9 looks aggressive for a LeNet without batch norm;
# the logged val_acc peaks early and then drifts down — consider a smaller lr or a scheduler.
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

Files already downloaded and verified
Files already downloaded and verified


In [11]:
def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one training pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0
    for img, label in tqdm(loader):
        img, label = img.to(device), label.to(device)
        optimizer.zero_grad()
        loss = criterion(model(img), label)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _evaluate(model, loader, device):
    """Return top-1 accuracy of ``model`` over ``loader`` (eval mode, no gradients)."""
    model.eval()
    correct, total = 0, 0
    # One no_grad context around the whole loop instead of re-entering it per batch.
    with torch.no_grad():
        for img, label in loader:
            img, label = img.to(device), label.to(device)
            correct += (model(img).argmax(dim=1) == label).sum().item()
            total += label.size(0)
    return correct / total


NUM_EPOCHS = 100

# NOTE(review): "validation" here is the CIFAR-100 *test* split, so choosing the best
# epoch on it makes best_acc optimistic. Hold out part of the train set for an unbiased score.

# Best-so-far tracking must be initialised ONCE, outside the epoch loop; resetting it
# every epoch made every epoch count as a new best.
best_acc = 0.0
best_epoch = 0
# state_dict() returns references to the live parameter tensors; deep-copy it so the
# snapshot does not keep changing as training continues.
best_model_wts = copy.deepcopy(model.state_dict())

try:
    for epoch in range(NUM_EPOCHS):
        print(f"train epoch: {epoch+1}----------------")
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
        print("\ntrain_loss : ", train_loss)

        val_acc = _evaluate(model, test_loader, device)
        print("val_acc : ", val_acc)

        if val_acc > best_acc:
            best_acc = val_acc
            best_epoch = epoch + 1
            best_model_wts = copy.deepcopy(model.state_dict())
finally:
    # Runs on normal completion and on KeyboardInterrupt alike (the exception still
    # propagates), so an interrupted run leaves `model` holding the best weights seen.
    model.load_state_dict(best_model_wts)
    print(f"best val_acc : {best_acc} (epoch {best_epoch})")

train epoch: 1----------------


100%|██████████| 196/196 [00:13<00:00, 14.07it/s]


train_loss :  4.2152082542983855





val_acc :  0.0995
train epoch: 2----------------


100%|██████████| 196/196 [00:14<00:00, 13.86it/s]


train_loss :  3.6894896225053437





val_acc :  0.1461
train epoch: 3----------------


100%|██████████| 196/196 [00:14<00:00, 13.91it/s]


train_loss :  3.4683473791394914





val_acc :  0.1783
train epoch: 4----------------


100%|██████████| 196/196 [00:14<00:00, 13.88it/s]


train_loss :  3.3525774637047125





val_acc :  0.1888
train epoch: 5----------------


100%|██████████| 196/196 [00:14<00:00, 13.99it/s]


train_loss :  3.240210170648536





val_acc :  0.2031
train epoch: 6----------------


100%|██████████| 196/196 [00:14<00:00, 13.81it/s]


train_loss :  3.193386661763094





val_acc :  0.2084
train epoch: 7----------------


100%|██████████| 196/196 [00:14<00:00, 13.99it/s]


train_loss :  3.1144248770207774





val_acc :  0.2047
train epoch: 8----------------


100%|██████████| 196/196 [00:14<00:00, 13.77it/s]


train_loss :  3.093144809713169





val_acc :  0.2103
train epoch: 9----------------


100%|██████████| 196/196 [00:14<00:00, 13.82it/s]


train_loss :  3.0423762664502982





val_acc :  0.2356
train epoch: 10----------------


100%|██████████| 196/196 [00:14<00:00, 13.96it/s]


train_loss :  3.0029651680771186





val_acc :  0.225
train epoch: 11----------------


100%|██████████| 196/196 [00:14<00:00, 13.45it/s]


train_loss :  2.965327183811032





val_acc :  0.2203
train epoch: 12----------------


100%|██████████| 196/196 [00:14<00:00, 13.80it/s]


train_loss :  2.9229143067282073





val_acc :  0.227
train epoch: 13----------------


100%|██████████| 196/196 [00:14<00:00, 13.67it/s]


train_loss :  2.9119382257364235





val_acc :  0.2297
train epoch: 14----------------


100%|██████████| 196/196 [00:14<00:00, 13.65it/s]


train_loss :  2.885849627913261





val_acc :  0.2171
train epoch: 15----------------


100%|██████████| 196/196 [00:14<00:00, 13.94it/s]


train_loss :  2.8872552672211005





val_acc :  0.2236
train epoch: 16----------------


100%|██████████| 196/196 [00:13<00:00, 14.05it/s]


train_loss :  2.874925446753599





val_acc :  0.2216
train epoch: 17----------------


100%|██████████| 196/196 [00:14<00:00, 13.37it/s]


train_loss :  2.860319525611644





val_acc :  0.2215
train epoch: 18----------------


100%|██████████| 196/196 [00:13<00:00, 14.09it/s]


train_loss :  2.859324586634733





val_acc :  0.2268
train epoch: 19----------------


100%|██████████| 196/196 [00:13<00:00, 14.09it/s]


train_loss :  2.8240440232413158





val_acc :  0.2074
train epoch: 20----------------


100%|██████████| 196/196 [00:14<00:00, 13.52it/s]


train_loss :  2.8436548697705173





val_acc :  0.2156
train epoch: 21----------------


100%|██████████| 196/196 [00:14<00:00, 13.65it/s]


train_loss :  2.836069291951705





val_acc :  0.2162
train epoch: 22----------------


100%|██████████| 196/196 [00:14<00:00, 13.86it/s]


train_loss :  2.8333207806762384





val_acc :  0.2197
train epoch: 23----------------


 10%|▉         | 19/196 [00:01<00:17, 10.02it/s]


KeyboardInterrupt: ignored