In [1]:
import numpy as np
import copy
import torch

from torch import nn, optim
from PIL import ImageFile

from resnet import seresnet50

ImageFile.LOAD_TRUNCATED_IMAGES = True

import warnings
warnings.filterwarnings('ignore')

In [2]:
multi_gpus = False
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    multi_gpus = True
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
model = seresnet50(101)

In [4]:
import torchvision
from torchvision import datasets, transforms

def get_transform(random_crop=True):
    normalize = transforms.Normalize(
        #mean=[x / 255.0 for x in [125.3, 123.0, 113.9]],
        #std=[x / 255.0 for x in [63.0, 62.1, 66.7]]
        [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
        )
    transform = []
    transform.append(transforms.Resize(256))
    if random_crop:
        #transform.append(transforms.RandomRotation(30))
        transform.append(transforms.RandomResizedCrop(224))
        transform.append(transforms.RandomHorizontalFlip())
        transform.append(transforms.ColorJitter(hue=.05, saturation=.05),)
    else:
        transform.append(transforms.CenterCrop(224))
    transform.append(transforms.ToTensor())
    transform.append(normalize)
    return transforms.Compose(transform)

In [5]:
from torch.utils import data

data_dir = './data/food101/'
train_data = datasets.ImageFolder(data_dir + 'train', transform=get_transform(random_crop=True))
test_data = datasets.ImageFolder(data_dir + 'test', transform=get_transform(random_crop=False))
tr_loader = data.DataLoader(dataset=train_data,
                            batch_size=256,
                            #sampler = RandomIdentitySampler(train_set, 4),
                            shuffle=True,
                            #pin_memory=True,
                            num_workers=16)

val_loader = data.DataLoader(dataset=test_data,
                             batch_size=256,
                             shuffle=False,
                             #pin_memory=True,                             
                             num_workers=16)

In [6]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [7]:
count_parameters(model)

26245973

In [8]:
import apex
print("using apex synced BN")
model = apex.parallel.convert_syncbn_model(model)

using apex synced BN


In [9]:
optimizer = optim.SGD(model.parameters(), lr=1., momentum=0.9, weight_decay=5e-4, nesterov=True)

In [10]:
from apex import amp, optimizers

model, optimizer = amp.initialize(model.cuda(), optimizer, opt_level='O3',keep_batchnorm_fp32=True)

Selected optimization level O3:  Pure FP16 training.
Defaults for this optimization level are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : False
master_weights         : False
loss_scale             : 1.0
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O3
cast_model_type        : torch.float16
patch_torch_functions  : False
keep_batchnorm_fp32    : True
master_weights         : False
loss_scale             : 1.0


In [11]:
criterion = nn.CrossEntropyLoss().cuda()

In [12]:
import datetime
test_time = datetime.datetime.now()

torch.cuda.synchronize()
model.train()
for _ in range(2):
    inputs, labels = next(iter(tr_loader))
    print(1)
    inputs = inputs.cuda(non_blocking=True)        
    labels = labels.cuda(non_blocking=True)    
    print(2)    
    logits = model(inputs)
    print(3)                       
    loss = criterion(logits, labels)                   
    print(4)                   
    loss.backward()
    print(5)                            
    model.zero_grad()
    print(10)                                
torch.cuda.synchronize()
test_end = datetime.datetime.now() - test_time
print('test {}'.format(test_end))

1
2
3
4
5
10
1
2
3
4
5
10
test 0:00:30.963236


In [13]:
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

In [14]:
def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

In [15]:
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=len(tr_loader)
                                                , epochs=30, pct_start=0.2)

In [16]:
import datetime
import time
high = 0.0
epoch_time = AverageMeter('Epoch', ':6.3f')
batch_time = AverageMeter('Batch', ':6.3f')
data_time = AverageMeter('Data', ':6.3f')
losses = AverageMeter('Loss', ':.5f')
learning_rates = AverageMeter('LearningRate', ':.5f')
top1 = AverageMeter('Acc@1', ':6.2f')
top5 = AverageMeter('Acc@5', ':6.2f')

for epoch in range(30):  # loop over the dataset multiple times
    time_ = datetime.datetime.now()    
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total = 0
    progress = ProgressMeter(
        len(tr_loader),
        [batch_time, data_time, losses, top1, top5, learning_rates],
        prefix="Epoch: [{}]".format(epoch))
    
    end = time.time()    
    for i, (inputs, labels) in enumerate(tr_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        #print(inputs.shape)
        #print(labels.shape)
        data_time.update(time.time() - end)
        inputs = inputs.cuda(non_blocking=True)
        labels = labels.cuda(non_blocking=True)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        #_, preds = torch.max(outputs, 1)
        #loss.backward()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
            
        optimizer.step()
        scheduler.step()
        # print statistics
        acc1, acc5 = accuracy(outputs, labels, topk=(1, 5))
        losses.update(loss.item(), inputs.size(0))
        learning_rates.update(scheduler.get_lr()[0])        
        top1.update(acc1[0], inputs.size(0))
        top5.update(acc5[0], inputs.size(0))

        
        batch_time.update(time.time() - end)
        if i % 100 == 99:    # print every 2000 mini-batches
            progress.display(i)
            #running_loss = 0.0
    elapsed = datetime.datetime.now() - time_
    print('{} elapsed for {}'.format(elapsed, epoch+1))

    
print('Finished Training')

Epoch: [0][ 99/296]	Batch 64.594 (37.115)	Data 64.110 (36.630)	Loss 4.59936 (4.63962)	Acc@1   1.56 (  1.49)	Acc@5   6.25 (  6.20)	LearningRate 0.00475 (0.00425)
Epoch: [0][199/296]	Batch 120.943 (64.899)	Data 120.456 (64.412)	Loss 4.48013 (4.59019)	Acc@1   3.52 (  2.15)	Acc@5  16.41 (  8.41)	LearningRate 0.00698 (0.00500)
0:02:52.021192 elapsed for 1
Epoch: [1][ 99/296]	Batch 64.711 (77.975)	Data 64.286 (77.492)	Loss 4.20553 (4.46757)	Acc@1   7.42 (  3.68)	Acc@5  23.83 ( 12.90)	LearningRate 0.01531 (0.00785)
Epoch: [1][199/296]	Batch 121.505 (81.079)	Data 121.000 (80.594)	Loss 4.16423 (4.40946)	Acc@1  10.16 (  4.58)	Acc@5  21.88 ( 14.96)	LearningRate 0.02134 (0.00995)
0:02:52.228428 elapsed for 2
Epoch: [2][ 99/296]	Batch 65.952 (84.094)	Data 65.424 (83.612)	Loss 3.96431 (4.29958)	Acc@1  11.33 (  6.15)	Acc@5  28.12 ( 18.61)	LearningRate 0.03572 (0.01515)
Epoch: [2][199/296]	Batch 123.054 (85.433)	Data 122.620 (84.949)	Loss 3.86217 (4.24325)	Acc@1   9.38 (  6.98)	Acc@5  29.69 ( 20.39)	L

In [19]:
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': loss,    
    
}, './checkpoint/resnet50_fp16_sconv_ep030.b0.pth')

In [17]:
def classification_val(model, val_loader):
    correct = 0
    total = 0    
    
    model.eval()
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    return correct/total

In [18]:
classification_val(model, val_loader)

0.8205148514851485

In [18]:
cls_results = [classification_val(model, val_loader) for i in range(10)]

In [19]:
np.mean(cls_results)

0.7856330699556754

In [24]:
cls_results

[0.7881094623241472,
 0.7854114472923492,
 0.7869531701676623,
 0.7803044902678743,
 0.7797263441896319,
 0.7850260165735209,
 0.7908074773559453,
 0.7862786664097129,
 0.7879167469647331,
 0.7857968780111775]

In [22]:
np.mean(cls_results)

0.7725380612834843

In [23]:
np.var(cls_results)

6.425454679560811e-06

In [24]:
np.std(cls_results)

0.0025348480584762496

In [19]:
def val_retrieval(model, val_loader):
    feats = None
    data_ids = None

    model.eval()
    with torch.no_grad():
        for idx, (images, labels) in enumerate(val_loader):
            images = images.to(device)
            #labels = labels.to(device)

            feat = model(images, feature=True)
            feat = feat.detach().cpu().numpy()

            feat = feat/np.linalg.norm(feat, axis=1)[:, np.newaxis]

            if feats is None:
                feats = feat
            else:
                feats = np.append(feats, feat, axis=0)

            if data_ids is None:
                data_ids = labels
            else:
                data_ids = np.append(data_ids, labels, axis=0)

        score_matrix = feats.dot(feats.T)
        np.fill_diagonal(score_matrix, -np.inf)
        top1_reference_indices = np.argmax(score_matrix, axis=1)

        top1_reference_ids = [
            [data_ids[idx], data_ids[top1_reference_indices[idx]]] for idx in
            range(len(data_ids))]

    total_count = len(top1_reference_ids)
    correct = 0
    for ids in top1_reference_ids:
        if ids[0] == ids[1]:
            correct += 1        
    return correct/total_count

In [20]:
val_retrieval(model, val_loader)

0.7493465346534653

In [21]:
retrieval_result = [val_retrieval(model, val_loader) for i in range(10)]

In [22]:
np.mean(retrieval_result)

0.6885912507226826

In [23]:
retrieval_result

[0.684428598959337,
 0.6872229716708421,
 0.693582578531509,
 0.6871266139911351,
 0.6907882058200039,
 0.6866448255925998,
 0.6872229716708421,
 0.690980921179418,
 0.6823087300057814,
 0.6956060898053575]

In [27]:
np.mean(retrieval_result)

0.6178358065137791

In [28]:
np.std(retrieval_result)

0.0038908581994208

In [29]:
retrieval_result

[0.6151474272499519,
 0.6104259009443053,
 0.615436500289073,
 0.6237232607438813,
 0.6200616689150126,
 0.6146656388514165,
 0.6214106764309115,
 0.6160146463673155,
 0.6200616689150126,
 0.6214106764309115]