https://cryptosalamander.tistory.com/156

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
import os
import torchvision.models as models

In [2]:
#Simple Learning Rate Scheduler
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

# Xavier
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [3]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='/home/sldev1/Project/hyeongeun_test/data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='/home/sldev1/Project/hyeongeun_test/data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=8)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
device = 'cuda'
model = ResNet50()
# ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 중에 택일하여 사용

In [5]:
model.apply(init_weights)
model = model.to(device)

  


In [6]:
learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

In [7]:
# Train
for epoch in range(num_epoch):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0
    # Train Phase
    for step, batch in enumerate(train_loader):
        #  input and target
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        
        logits = model(batch[0])
        loss = loss_fn(logits, batch[1])
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += batch[1].size(0)
        correct +=  predict.eq(batch[1]).sum().item()
        
        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Acc : { correct / total_cnt }")
            print(f"Train Loss : { loss.item() / batch[1].size(0) }")
            
    correct = 0
    total_cnt = 0
    
# Test Phase
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(test_loader):
            # input and target
            batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
            total_cnt += batch[1].size(0)
            logits = model(batch[0])
            valid_loss += loss_fn(logits, batch[1])
            _, predict = logits.max(1)
            correct += predict.eq(batch[1]).sum().item()
        valid_acc = correct / total_cnt
        print(f"\nValid Acc : { valid_acc }")    
        print(f"Valid Loss : { valid_loss / total_cnt }")

        if(valid_acc > best_acc):
            best_acc = valid_acc
            torch.save(model, model_name)
            print("Model Saved!")


Train Acc : 0.1233756188118812
Train Loss : 0.00889238528907299

Valid Acc : 0.1791
Valid Loss : 0.025907427072525024
Model Saved!

Train Acc : 0.17531714108910892
Train Loss : 0.009605797939002514

Valid Acc : 0.1769
Valid Loss : 0.010647803544998169

Train Acc : 0.21283261138613863
Train Loss : 0.008961260318756104

Valid Acc : 0.2357
Valid Loss : 0.011087639257311821
Model Saved!

Train Acc : 0.2382038985148515
Train Loss : 0.008227312006056309

Valid Acc : 0.2795
Valid Loss : 0.008371998555958271
Model Saved!

Train Acc : 0.27223855198019803
Train Loss : 0.007088575512170792

Valid Acc : 0.2925
Valid Loss : 0.007736568339169025
Model Saved!

Train Acc : 0.3040300123762376
Train Loss : 0.007244982291013002

Valid Acc : 0.3321
Valid Loss : 0.007469364907592535
Model Saved!

Train Acc : 0.32267172029702973
Train Loss : 0.007363359443843365

Valid Acc : 0.33
Valid Loss : 0.007634988985955715

Train Acc : 0.34316986386138615
Train Loss : 0.006815379485487938

Valid Acc : 0.3625
Valid L


Valid Acc : 0.5967
Valid Loss : 0.004571280442178249

Train Acc : 0.6088722153465347
Train Loss : 0.004512564279139042

Valid Acc : 0.6051
Valid Loss : 0.004573273006826639

Train Acc : 0.6128171410891089
Train Loss : 0.004382946528494358

Valid Acc : 0.5947
Valid Loss : 0.004585244692862034

Train Acc : 0.6160659034653465
Train Loss : 0.0039051787462085485

Valid Acc : 0.5811
Valid Loss : 0.004826956894248724

Train Acc : 0.6227181311881188
Train Loss : 0.0037035252898931503

Valid Acc : 0.6291
Valid Loss : 0.004172794986516237
Model Saved!

Train Acc : 0.6296410891089109
Train Loss : 0.0038057276979088783

Valid Acc : 0.6389
Valid Loss : 0.004066056106239557
Model Saved!

Train Acc : 0.6354424504950495
Train Loss : 0.005127256736159325

Valid Acc : 0.6295
Valid Loss : 0.004506598226726055

Train Acc : 0.6589959777227723
Train Loss : 0.003448112867772579

Valid Acc : 0.6676
Valid Loss : 0.0037310225889086723
Model Saved!

Train Acc : 0.6719910272277227
Train Loss : 0.0035965810529887


Train Acc : 0.7143796410891089
Train Loss : 0.003115569707006216

Valid Acc : 0.6997
Valid Loss : 0.0034831701777875423

Train Acc : 0.7157332920792079
Train Loss : 0.0031657929066568613

Valid Acc : 0.7003
Valid Loss : 0.003462959313765168

Train Acc : 0.71484375
Train Loss : 0.0032194380182772875

Valid Acc : 0.6981
Valid Loss : 0.003477667924016714

Train Acc : 0.7189820544554455
Train Loss : 0.002933535259217024

Valid Acc : 0.7033
Valid Loss : 0.0034693111665546894
Model Saved!

Train Acc : 0.7192914603960396
Train Loss : 0.0034519610926508904

Valid Acc : 0.7011
Valid Loss : 0.0035395727027207613

Train Acc : 0.7231203589108911
Train Loss : 0.002863251604139805

Valid Acc : 0.7009
Valid Loss : 0.0034828956704586744

Train Acc : 0.7227722772277227
Train Loss : 0.0028730034828186035

Valid Acc : 0.7032
Valid Loss : 0.003453131066635251

Train Acc : 0.7243579826732673
Train Loss : 0.003317043650895357

Valid Acc : 0.6995
Valid Loss : 0.0037217368371784687

Train Acc : 0.72478341584


Train Acc : 0.752707301980198
Train Loss : 0.0026408142875880003

Valid Acc : 0.7121
Valid Loss : 0.0034969858825206757

Train Acc : 0.7486850247524752
Train Loss : 0.002759059891104698

Valid Acc : 0.7123
Valid Loss : 0.003387891221791506

Train Acc : 0.7499226485148515
Train Loss : 0.00251149688847363

Valid Acc : 0.7092
Valid Loss : 0.003415689803659916

Train Acc : 0.75
Train Loss : 0.002358286641538143

Valid Acc : 0.7135
Valid Loss : 0.003388724522665143

Train Acc : 0.7465578589108911
Train Loss : 0.003010728396475315

Valid Acc : 0.7081
Valid Loss : 0.0036071299109607935

Train Acc : 0.7487237004950495
Train Loss : 0.0031356976833194494

Valid Acc : 0.7131
Valid Loss : 0.0034654124174267054

Train Acc : 0.7522818688118812
Train Loss : 0.0029960665851831436

Valid Acc : 0.7103
Valid Loss : 0.003396150190383196

Train Acc : 0.7519337871287128
Train Loss : 0.0026506572030484676

Valid Acc : 0.7122
Valid Loss : 0.0033728897105902433

Train Acc : 0.7537515470297029
Train Loss : 0.0