In [1]:
import sys
import timm
import time
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
import utils.DataSet_Aug as DataSet
from utils.LabelSmooth import LabelSmoothCELoss


###### Function: Train the samples after augmentation ######

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


In [3]:
SAVE_PATH = './vit_cub.pth'
EPOCHS = 50
BATCH_SIZE = 4
LEARING_RATE_CLASSIFIER = 0.001
LEARING_RATE_FEATURES = 0.0002
MOMENTUM = 0.8
WEIGHT_DECAY = 5e-4
STEP_SIZE = 30
STEP_GAMMA = 0.3
SMOOTHING =0.4

In [4]:
train_loader = torch.utils.data.DataLoader(
    DataSet.load_datasets(dataset='CUB_200_2011', root='data/CUB_200_2011',
                           train=True, transform=DataSet.data_transform['train']),
    batch_size=BATCH_SIZE, shuffle=True,
    num_workers=2)  # load the trainset

test_loader = torch.utils.data.DataLoader(
    DataSet.load_datasets(dataset='CUB_200_2011', root='data/CUB_200_2011',
                          train=False, transform=DataSet.data_transform['test']),
    batch_size=BATCH_SIZE, shuffle=False,
    num_workers=2)  # load the testset


net = timm.create_model('vit_base_patch16_384',
                        pretrained=True)  # define the model
inchannel = net.head.in_features
net.head = nn.Linear(inchannel, 200)
net.to(device)

labelSmoothCELoss = LabelSmoothCELoss()

# using differential learning rate strategy
high_rate_params = []
low_rate_params = []
for name, params in net.named_parameters():
    if 'head' in name:
        high_rate_params += [params]
    else:
        low_rate_params += [params]

# define the optimizer
optimizer = optim.SGD(
    params=[
        {"params": high_rate_params, 'lr': LEARING_RATE_CLASSIFIER},
        {"params": low_rate_params},
    ],
    lr=LEARING_RATE_FEATURES, momentum=MOMENTUM, weight_decay=WEIGHT_DECAY
)

scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=STEP_SIZE,
                                      gamma=STEP_GAMMA, last_epoch=-1)

In [5]:
print(EPOCHS)
best_accuracy = 0.0

for epoch in range(EPOCHS):

    print(device)
    epoch_start = time.time()
    print('Epoch:{}'.format(epoch + 1))

    # train
    net.train()
    train_loss_list = []  # record the loss of every batch
    train_accuracy_list = []  # record the accuracy of every batch

    for step, data in enumerate(train_loader, start=0):

        images, labels = data
        images, labels = images.to(device), labels.to(device)
        logits = net(images)
        loss = labelSmoothCELoss(logits, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss = loss.item()
        train_loss_list.append(train_loss)

        prediction = torch.max(logits, dim=1)[-1]
        train_accuracy = prediction.eq(labels).cpu().float().mean()
        train_accuracy_list.append(train_accuracy)

        rate = (step + 1) / len(train_loader)
        a = "*" * int(rate * 50)
        b = "." * int((1 - rate) * 50)
        print("\rtrain loss: {:^3.0f}%[{}->{}]{:.3f}".format(
            int(rate * 100), a, b, loss), end="")  # draw the progress bar
    print()
    print('train_loss:{:.3f},train_accuracy:{:.3f}'.format(
        np.mean(train_loss_list), np.mean(train_accuracy_list)))

    # test
    net.eval()
    test_loss_list = []
    test_accuracy_list = []

    with torch.no_grad():
        for step, data in enumerate(test_loader, start=0):

            images, labels = data
            images, labels = images.to(device), labels.to(device)
            logits = net(images)
            loss = labelSmoothCELoss(logits, labels)

            test_loss = loss.item()
            test_loss_list.append(test_loss)

            prediction = torch.max(logits, dim=1)[-1]
            test_accuracy = prediction.eq(labels).cpu().float().mean()
            test_accuracy_list.append(test_accuracy)

            rate = (step + 1) / len(test_loader)
            a = "*" * int(rate * 50)
            b = "." * int((1 - rate) * 50)
            print("\rtest loss: {:^3.0f}%[{}->{}]{:.3f}".format(
                int(rate * 100), a, b, loss), end="")
        print()

        test_accuracy = np.mean(test_accuracy_list)
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            torch.save(net.state_dict(), SAVE_PATH)

        epoch_end = time.time()
        print('test_loss:{:.3f},test_accuracy:{:.3f},epoch_time:{:.3f}'.format(
            np.mean(test_loss_list), np.mean(test_accuracy_list), (epoch_end-epoch_start)))
    scheduler.step()

print('Finished Training')
print('The best accuracy : %.3f' % best_accuracy)

50
cuda:0
Epoch:1
train loss: 100%[**************************************************->]2.999
train_loss:3.997,train_accuracy:0.533
test loss: 100%[**************************************************->]3.002
test_loss:3.326,test_accuracy:0.791,epoch_time:204.812
cuda:0
Epoch:2
train loss: 100%[**************************************************->]3.441
train_loss:3.386,train_accuracy:0.772
test loss: 100%[**************************************************->]3.045
test_loss:3.200,test_accuracy:0.858,epoch_time:204.356
cuda:0
Epoch:3
train loss: 100%[**************************************************->]3.252
train_loss:3.291,train_accuracy:0.819
test loss: 100%[**************************************************->]2.916
test_loss:3.155,test_accuracy:0.879,epoch_time:204.331
cuda:0
Epoch:4
train loss: 100%[**************************************************->]3.234
train_loss:3.240,train_accuracy:0.846
test loss: 100%[**************************************************->]2.936
test_loss:3.127,