In [36]:
from torchvision import datasets, transforms
from my_resnet import CifarResNet20
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from tensorboardX import SummaryWriter
from pytorch_model_summary import summary

In [81]:
def train(model, data_loader, criterion, optimizer, epoch, device, log_step=50):
    """Run one training epoch over ``data_loader``, updating ``model`` in place.

    Args:
        model: network to train; switched to train mode here.
        data_loader: iterable yielding ``(data, target)`` batches.
        criterion: loss called as ``criterion(output, target)``; assumed to
            return the batch-mean loss (``nn.CrossEntropyLoss`` default
            ``reduction='mean'``) -- confirm if a different criterion is used.
        optimizer: optimizer that steps ``model``'s parameters.
        epoch: current epoch number (currently unused; kept for callers).
        device: device each batch is moved to.
        log_step: currently unused; kept for callers.

    Returns:
        ``(mean_loss, accuracy_percent)``: the loss averaged over all samples
        and the percentage of correctly classified samples. Both are ``0.0``
        when ``data_loader`` yields no samples.
    """
    model.train()

    loss_sum = 0.0
    correct = 0
    total = 0

    for data, target in tqdm(data_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)

        loss = criterion(output, target)
        n = target.size(0)
        # .item() turns the loss into a plain float; summing the tensor itself
        # would keep every batch's autograd graph alive for the whole epoch.
        # Weighting by the batch size keeps a short final batch from being
        # over-counted and makes the result independent of the batch count.
        loss_sum += loss.item() * n

        correct += (target == output.argmax(dim=1)).sum().item()
        total += n

        loss.backward()
        optimizer.step()

    if total == 0:
        return 0.0, 0.0
    return loss_sum / total, 100 * correct / total

def test(model, data_loader, criterion, device):
    model.eval()
    
    train_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data, target in tqdm(data_loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            train_loss += loss

            correct += (target == output.argmax(dim=1)).sum().item()
            total += target.size(0)
    
    return train_loss.item(), 100 * correct / total

In [82]:
batch_size = 512
DATA_ROOT = '../notebooks/data'

# NOTE(review): these mean/std values are the ImageNet channel statistics, not
# CIFAR-10's own; left unchanged here because changing them alters model inputs.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
loader_kwargs = dict(batch_size=batch_size, num_workers=4, pin_memory=True)

# Training split: random horizontal flip and 4-pixel-padded random crop as augmentation.
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, 4),
    transforms.ToTensor(),
    normalize,
])
cifar10_train = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(cifar10_train, shuffle=True, **loader_kwargs)

# Test split: no augmentation, same normalization as training.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])
cifar10_test = datasets.CIFAR10(root=DATA_ROOT, train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(cifar10_test, shuffle=False, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [83]:
model = CifarResNet20()
# A single all-zeros 3x32x32 input, used only to trace layer shapes for the summary table.
example_input = torch.zeros((1, 3, 32, 32))
print(summary(model, example_input, show_input=True))

-----------------------------------------------------------------------
      Layer (type)         Input Shape         Param #     Tr. Param #
          Conv2d-1      [1, 3, 32, 32]             432             432
     BatchNorm2d-2     [1, 16, 32, 32]              32              32
            ReLU-3     [1, 16, 32, 32]               0               0
      BasicBlock-4     [1, 16, 32, 32]           4,672           4,672
      BasicBlock-5     [1, 16, 32, 32]           4,672           4,672
      BasicBlock-6     [1, 16, 32, 32]           4,672           4,672
      BasicBlock-7     [1, 16, 32, 32]          14,528          14,528
      BasicBlock-8     [1, 32, 16, 16]          18,560          18,560
      BasicBlock-9     [1, 32, 16, 16]          18,560          18,560
     BasicBlock-10     [1, 32, 16, 16]          57,728          57,728
     BasicBlock-11       [1, 64, 8, 8]          73,984          73,984
     BasicBlock-12       [1, 64, 8, 8]          73,984          73,984
     

In [84]:
# Number of epochs; the loop below runs epochs 1..NUM_EPOCHS inclusive
# (the previous range(1, 100) only ran 99).
NUM_EPOCHS = 100

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)
model.to(device)

# Place the loss on the same device as the model; a hard-coded .cuda() raised
# on CPU-only machines despite the CPU fallback above.
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(model.parameters(), 0.1, momentum=0.9, weight_decay=1e-4)
# TODO: the learning rate is constant. A step schedule
# (torch.optim.lr_scheduler.MultiStepLR) was sketched here but never enabled.

writer = SummaryWriter("logs")
try:
    for epoch in range(1, NUM_EPOCHS + 1):
        print("start epoch: {0:d}\n".format(epoch))
        loss_train, accuracy_train = train(model, train_loader, criterion, optimizer, epoch, device)
        loss_test, accuracy_test = test(model, test_loader, criterion, device)
        print("loss at epoch {0:d}: {1:f}, accuracy: {2:f}".format(epoch, loss_test, accuracy_test))

        writer.add_scalars('loss', {'train_metric': loss_train, 'test_metric': loss_test}, epoch)
        writer.add_scalars('accuracy', {'train_accuracy': accuracy_train, 'test_accuracy': accuracy_test}, epoch)
finally:
    # Flush and close the event file even if training is interrupted
    # (e.g. a KeyboardInterrupt from the notebook's stop button).
    writer.close()

cuda:0
start epoch: 1



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 1: 26.093813, accuracy: 53.580000
start epoch: 2



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 2: 22.695240, accuracy: 60.480000
start epoch: 3



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 3: 21.943359, accuracy: 62.220000
start epoch: 4



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 4: 17.226107, accuracy: 70.340000
start epoch: 5



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 5: 19.913029, accuracy: 67.830000
start epoch: 6



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 6: 16.490837, accuracy: 72.790000
start epoch: 7



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 7: 14.560717, accuracy: 75.530000
start epoch: 8



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 8: 12.669332, accuracy: 78.320000
start epoch: 9



HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=98.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))


loss at epoch 9: 12.761278, accuracy: 78.270000
