<a href="https://colab.research.google.com/github/repairedserver/ResNet/blob/main/ResNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [4]:
import torchvision
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from main import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
import os
import torchvision.models as models

In [5]:
def lr_scheduler(optimizer, epoch):
    lr = learning_rate
    if epoch >= 50:
        lr /= 10
    if epoch >= 100:
        lr /= 10
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
        
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform(m.weight)
        m.bias.data.fill_(0.01)

In [6]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True, num_workers=8)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=256, shuffle=False, num_workers=8)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified




In [7]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = ResNet50()

In [8]:
model.apply(init_weights)
model = model.to(device)

  torch.nn.init.xavier_uniform(m.weight)


In [9]:
learning_rate = 0.1
num_epoch = 150
model_name = 'model.pth'

loss_fn = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0.0001)

train_loss = 0
valid_loss = 0
correct = 0
total_cnt = 0
best_acc = 0

In [10]:
for epoch in range(num_epoch):
    print(f"====== { epoch+1} epoch of { num_epoch } ======")
    model.train()
    lr_scheduler(optimizer, epoch)
    train_loss = 0
    valid_loss = 0
    correct = 0
    total_cnt = 0

    for step, batch in enumerate(train_loader):
        batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        
        logits = model(batch[0])
        loss = loss_fn(logits, batch[1])
        loss.backward()
        
        optimizer.step()
        train_loss += loss.item()
        _, predict = logits.max(1)
        
        total_cnt += batch[1].size(0)
        correct +=  predict.eq(batch[1]).sum().item()
        
        if step % 100 == 0 and step != 0:
            print(f"\n====== { step } Step of { len(train_loader) } ======")
            print(f"Train Accuracy : { correct / total_cnt }")
            print(f"Train Loss : { loss.item() / batch[1].size(0) }")
            
    correct = 0
    total_cnt = 0
    
    with torch.no_grad():
        model.eval()
        for step, batch in enumerate(test_loader):
            batch[0], batch[1] = batch[0].to(device), batch[1].to(device)
            total_cnt += batch[1].size(0)
            logits = model(batch[0])
            valid_loss += loss_fn(logits, batch[1])
            _, predict = logits.max(1)
            correct += predict.eq(batch[1]).sum().item()
        valid_acc = correct / total_cnt
        print(f"\nValid Accuracy : { valid_acc }")    
        print(f"Valid Loss : { valid_loss / total_cnt }")

        if(valid_acc > best_acc):
            best_acc = valid_acc
            torch.save(model, model_name)
            print("Model Saved!")






Train Accuracy : 0.13613861386138615
Train Loss : 0.0091056227684021

Valid Accuracy : 0.1587
Valid Loss : 0.011182684451341629
Model Saved!

Train Accuracy : 0.18193069306930693
Train Loss : 0.008575672283768654

Valid Accuracy : 0.2157
Valid Loss : 0.05970612168312073
Model Saved!

Train Accuracy : 0.23352413366336633
Train Loss : 0.007505644112825394

Valid Accuracy : 0.2869
Valid Loss : 0.00885414145886898
Model Saved!

Train Accuracy : 0.2863165222772277
Train Loss : 0.007537904195487499

Valid Accuracy : 0.3368
Valid Loss : 0.008145949803292751
Model Saved!

Train Accuracy : 0.31764387376237624
Train Loss : 0.007209503557533026

Valid Accuracy : 0.3571
Valid Loss : 0.007333721965551376
Model Saved!

Train Accuracy : 0.3370590965346535
Train Loss : 0.006919067353010178

Valid Accuracy : 0.3606
Valid Loss : 0.006981063634157181
Model Saved!

Train Accuracy : 0.34924195544554454
Train Loss : 0.006870843470096588

Valid Accuracy : 0.3977
Valid Loss : 0.006675952114164829
Model Saved