# StochNorm

In [1]:
import argparse
import os
from time import time

import numpy as np
import torch
import torch.nn as nn
import torchvision
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from module.backbone import ResNet50_F
from module.stoch_norm import StochNorm2d
from utils.transforms import get_transforms
from utils.tools import AccuracyMeter, TenCropsTest 

## read configs

In [2]:
def get_configs():
    parser = argparse.ArgumentParser(
        description='Pytorch Stochastic Normalization Training')

    # train
    parser.add_argument('--gpu', default=0, type=int,
                        help='GPU num for training')
    parser.add_argument('--seed', type=int, default=2020)

    parser.add_argument('--batch_size', default=48, type=int)
    parser.add_argument('--total_iter', default=9050, type=int)
    parser.add_argument('--eval_iter', default=1000, type=int)
    parser.add_argument('--save_iter', default=9000, type=int)
    parser.add_argument('--print_iter', default=500, type=int)

    # dataset
    parser.add_argument('--data_path', default="/path/to/dataset",
                        type=str, help='Path of dataset')
    parser.add_argument('--class_num', default=200,
                        type=int, help='number of classes')
    parser.add_argument('--num_workers', default=2, type=int,
                        help='Num of workers used in dataloading')

    # optimizer
    parser.add_argument('--lr', default=1e-3, type=float,
                        help='Learning rate for training')
    parser.add_argument('--gamma', default=0.1, type=float,
                        help='Gamma value for learning rate decay')
    parser.add_argument('--nesterov', default=True,
                        type=bool, help='nesterov momentum')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='Momentum value for optimizer')
    parser.add_argument('--weight_decay', default=5e-4,
                        type=float, help='Weight decay value for optimizer')

    # experiment
    parser.add_argument('--root', default='.', type=str,
                        help='Root of the experiment')
    parser.add_argument('--name', default='StochNorm', type=str,
                        help='Name of the experiment')
    parser.add_argument('--p', default=0.5, type=float,
                        help='Probability for StochNorm layers')
    parser.add_argument('--save_dir', default="model",
                        type=str, help='Path of saved models')
    parser.add_argument('--visual_dir', default="visual",
                        type=str, help='Path of tensorboard data for training')

    configs = parser.parse_args(args=['--gpu', '3', '--data_path', '/path', '--lr', '1e-4'])

    return configs


def str2list(v):
    return v.split(',')


def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")

# create tensorboard

In [3]:
def get_writer(log_dir):
    return SummaryWriter(log_dir)

# get dataset and dataloader

In [4]:
def get_data_loader(configs):
    # data augmentation
    data_transforms = get_transforms(resize_size=256, crop_size=224)

    # build dataset
    train_dataset = datasets.ImageFolder(
        os.path.join(configs.data_path, 'train'),
        transform=data_transforms['train'])
    val_dataset = datasets.ImageFolder(
        os.path.join(configs.data_path, 'val'),
        transform=data_transforms['val'])
    test_datasets = {
        'test' + str(i):
            datasets.ImageFolder(
                os.path.join(configs.data_path, 'test'),
                transform=data_transforms["test" + str(i)]
        )
        for i in range(10)
    }

    # build dataloader
    train_loader = DataLoader(train_dataset, batch_size=configs.batch_size, shuffle=True,
                              num_workers=configs.num_workers, pin_memory=True)

    val_loader = DataLoader(val_dataset, batch_size=configs.batch_size, shuffle=False,
                            num_workers=configs.num_workers, pin_memory=True)
    test_loaders = {
        'test' + str(i):
            DataLoader(
                test_datasets["test" + str(i)],
                batch_size=4, shuffle=False, num_workers=configs.num_workers
        )
        for i in range(10)
    }

    return train_loader, val_loader, test_loaders

# fix random seed

In [5]:
def set_seeds(seed):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# train validataion and test

In [6]:
def train(configs, train_loader, val_loader, test_loaders, net):
    train_len = len(train_loader) - 1
    train_iter = iter(train_loader)

    # different learning rates for different layers
    params_list = [{"params": filter(lambda p: p.requires_grad, net.f_net.parameters()),},
                   {"params": filter(lambda p: p.requires_grad, net.c_net.parameters()), "lr": configs.lr * 10}]

    # optimizer and scheduler
    optimizer = torch.optim.SGD(params_list, lr=configs.lr, weight_decay=configs.weight_decay,
                                momentum=configs.momentum, nesterov=configs.nesterov)
    milestones = [6000]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones, gamma=configs.gamma)

    # check visual path
    visual_path = os.path.join(configs.visual_dir, configs.name)
    if not os.path.exists(visual_path):
        os.makedirs(visual_path)
    writer = get_writer(visual_path)

    # check model save path
    save_path = os.path.join(configs.save_dir, configs.name)
    if not os.path.exists(save_path):
        os.makedirs(save_path)

    # start training
    for iter_num in range(configs.total_iter):
        net.train()

        if iter_num % train_len == 0:
            train_iter = iter(train_loader)

        # Data Stage
        data_start = time()

        train_inputs, train_labels = next(train_iter)
        train_inputs, train_labels = train_inputs.cuda(), train_labels.cuda()

        data_duration = time() - data_start

        # Calc Stage
        calc_start = time()

        train_outputs = net(train_inputs)

        loss = classifier_loss = nn.CrossEntropyLoss()(train_outputs, train_labels)
        writer.add_scalar('loss/classifier_loss', classifier_loss, iter_num)

        net.zero_grad()
        optimizer.zero_grad()

        loss.backward()
        optimizer.step()
        scheduler.step()

        calc_duration = time() - calc_start

        if iter_num % configs.eval_iter == 0:
            acc_meter = AccuracyMeter(topk=(1,))
            with torch.no_grad():
                net.eval()
                for val_inputs, val_labels in tqdm(val_loader):
                    val_inputs, val_labels = val_inputs.cuda(), val_labels.cuda()
                    val_outputs = net(val_inputs)
                    acc_meter.update(val_outputs, val_labels)
                writer.add_scalar('acc/val_acc', acc_meter.avg[1], iter_num)
                print(
                    "Iter: {}/{} Val_Acc: {:2f}".format(
                        iter_num, configs.total_iter, acc_meter.avg[1])
                )
            acc_meter.reset()

        if iter_num % configs.save_iter == 0 and iter_num > 0:
            test_acc = TenCropsTest(test_loaders, net)
            writer.add_scalar('acc/test_acc', test_acc, iter_num)
            print(
                "Iter: {}/{} Test_Acc: {:2f}".format(
                    iter_num, configs.total_iter, test_acc)
            )
            checkpoint = {
                'state_dict': net.state_dict(),
                'iter': iter_num,
                'acc': test_acc,
            }
            torch.save(checkpoint,
                       os.path.join(save_path, '{}.pkl'.format(iter_num)))
            print("Model Saved.")

        if iter_num % configs.print_iter == 0:
            print(
                "Iter: {}/{} Loss: {:2f}, d/c: {}/{}".format(iter_num, configs.total_iter, loss, data_duration, calc_duration))

In [7]:
def main():
    configs = get_configs()
    print(configs)
    torch.cuda.set_device(configs.gpu)
    set_seeds(configs.seed)

    train_loader, val_loader, test_loaders = get_data_loader(configs)

    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.f_net = ResNet50_F(pretrained=True, norm_layer=StochNorm2d)
            self.c_net = nn.Linear(self.f_net.output_dim, configs.class_num)
            self.c_net.weight.data.normal_(0, 0.01)
            self.c_net.bias.data.fill_(0.0)

        def forward(self, x):
            feature = self.f_net(x)
            out = self.c_net(feature)
            return out
            
    net = Net().cuda()

    # set StochNorm layers
    for module in net.f_net.modules():
        if isinstance(module, StochNorm2d):
            module.p = configs.p

    train(configs, train_loader, val_loader, test_loaders, net)

In [8]:
print("PyTorch {}".format(torch.__version__))
print("TorchVision {}".format(torchvision.__version__))
main()

PyTorch 1.1.0
TorchVision 0.3.0
Namespace(batch_size=48, class_num=200, data_path='/data/finetune/cub200_15', eval_iter=1000, gamma=0.1, gpu=3, lr=0.0001, momentum=0.9, name='StochNorm', nesterov=True, num_workers=2, p=0.5, print_iter=500, root='.', save_dir='model', save_iter=9000, seed=2020, total_iter=9050, visual_dir='visual', weight_decay=0.0005)


100%|██████████| 121/121 [00:23<00:00,  5.18it/s]


Iter: 0/9050 Val_Acc: 0.878099
Iter: 0/9050 Loss: 5.308679, d/c: 0.10743975639343262/0.9772005081176758
Iter: 500/9050 Loss: 3.478653, d/c: 0.3940012454986572/0.2276926040649414


100%|██████████| 121/121 [00:23<00:00,  5.98it/s]


Iter: 1000/9050 Val_Acc: 43.891796
Iter: 1000/9050 Loss: 2.467999, d/c: 0.3902451992034912/0.21657633781433105
Iter: 1500/9050 Loss: 1.655665, d/c: 0.3385763168334961/0.22804570198059082


100%|██████████| 121/121 [00:23<00:00,  5.19it/s]


Iter: 2000/9050 Val_Acc: 48.053398
Iter: 2000/9050 Loss: 1.159652, d/c: 0.4065871238708496/0.21879196166992188
Iter: 2500/9050 Loss: 0.989734, d/c: 0.3759171962738037/0.21810555458068848


100%|██████████| 121/121 [00:23<00:00,  5.14it/s]


Iter: 3000/9050 Val_Acc: 48.371403
Iter: 3000/9050 Loss: 0.693430, d/c: 0.41485095024108887/0.2221660614013672
Iter: 3500/9050 Loss: 0.938889, d/c: 0.3791813850402832/0.22136783599853516


100%|██████████| 121/121 [00:23<00:00,  5.19it/s]


Iter: 4000/9050 Val_Acc: 49.022640
Iter: 4000/9050 Loss: 0.683620, d/c: 0.3784644603729248/0.21190619468688965
Iter: 4500/9050 Loss: 0.544682, d/c: 0.3894190788269043/0.21595501899719238


100%|██████████| 121/121 [00:23<00:00,  5.62it/s]


Iter: 5000/9050 Val_Acc: 49.511829
Iter: 5000/9050 Loss: 0.608874, d/c: 0.3726048469543457/0.22105169296264648
Iter: 5500/9050 Loss: 0.484004, d/c: 0.3883345127105713/0.21877503395080566


100%|██████████| 121/121 [00:23<00:00,  5.14it/s]


Iter: 6000/9050 Val_Acc: 49.846058
Iter: 6000/9050 Loss: 0.640509, d/c: 0.3887596130371094/0.2111964225769043
Iter: 6500/9050 Loss: 0.694746, d/c: 0.4028656482696533/0.2297976016998291


100%|██████████| 121/121 [00:22<00:00,  5.46it/s]


Iter: 7000/9050 Val_Acc: 49.742744
Iter: 7000/9050 Loss: 0.392571, d/c: 0.36809611320495605/0.21572065353393555
Iter: 7500/9050 Loss: 0.324146, d/c: 0.3683156967163086/0.22026658058166504


100%|██████████| 121/121 [00:22<00:00,  5.38it/s]


Iter: 8000/9050 Val_Acc: 49.680977
Iter: 8000/9050 Loss: 0.307792, d/c: 0.3857300281524658/0.2185518741607666
Iter: 8500/9050 Loss: 0.244998, d/c: 0.3447427749633789/0.21872878074645996


100%|██████████| 121/121 [00:22<00:00,  5.45it/s]

Iter: 9000/9050 Val_Acc: 50.148884



100%|██████████| 1449/1449 [02:23<00:00, 10.27it/s]


Iter: 9000/9050 Test_Acc: 50.414223
Model Saved.
Iter: 9000/9050 Loss: 0.321427, d/c: 0.3952596187591553/0.21174001693725586
