<a href="https://colab.research.google.com/github/ykitaguchi77/pytorch-models/blob/master/AdaCos_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

#Arcface-pytorch (AdaCos)
https://github.com/ronghuaiyang/arcface-pytorch/blob/master/models/metrics.py
Colabのスクリプト　https://cpp-learning.com/adacos/

In [0]:
import os
# import argparse
import numpy as np
import pandas as pd
import math
from tqdm import tqdm
import joblib
from collections import OrderedDict
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets, models, transforms
 
# from utils import *
# from mnist import archs
# import metrics

use_cuda = torch.cuda.is_available() and True
device = torch.device("cuda" if use_cuda else "cpu")

#MNISTのダウンロードとdataloader作成

In [0]:
train_set = datasets.MNIST(
        root='MNIST',
        train=True,
        download=True,
        transform=transforms.ToTensor())
 
val_set = datasets.MNIST(
        root='MNIST',
        train=False,
        download=True,
        transform=transforms.ToTensor())
 
train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=128,
        shuffle=True,
        num_workers=8)
 
val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=128,
        shuffle=False,
        num_workers=8)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting MNIST/MNIST/raw/train-images-idx3-ubyte.gz to MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting MNIST/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting MNIST/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))

Extracting MNIST/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/MNIST/raw
Processing...
Done!


#CNNの設定

In [0]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 4 * 4, 115)
        self.fc2 = nn.Linear(115, 84)
        self.fc3 = nn.Linear(84, 100)
 
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x







#AdaCos class

In [0]:
class AdaCos(nn.Module):
    def __init__(self, num_features, num_classes, m=0.50):
        super(AdaCos, self).__init__()
        self.num_features = num_features
        # self.n_classes = num_classes
        self.s = math.sqrt(2) * math.log(num_classes - 1)
        self.m = m
        self.W = Parameter(torch.FloatTensor(num_classes, num_features))
        nn.init.xavier_uniform_(self.W)
 
    def forward(self, input, label):
        # normalize features
        x = F.normalize(input)
        # normalize weights
        W = F.normalize(self.W)
        # dot product
        logits = F.linear(x, W)
        # add margin
        theta = torch.acos(torch.clamp(logits, -1.0 + 1e-7, 1.0 - 1e-7))
        target_logits = torch.cos(theta + self.m)
        one_hot = torch.zeros_like(logits)
        one_hot.scatter_(1, label.view(-1, 1).long(), 1)
        output = logits * (1 - one_hot) + target_logits * one_hot
        # feature re-scale
        with torch.no_grad():
            B_avg = torch.where(one_hot < 1, self.s * torch.exp(logits), torch.zeros_like(logits))
            B_avg = torch.sum(B_avg) / input.size(0)
            # print(B_avg)
            theta_med = torch.median(theta)
            self.s = torch.log(B_avg) / torch.cos(torch.min(math.pi/4 * torch.ones_like(theta_med), theta_med))
        # print(self.s)
        output *= self.s
 
        return output

#平均値・Accuracy計算

In [0]:

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res

#Train

In [0]:
def train(train_loader, model, metric_fc, criterion, optimizer):
    losses = AverageMeter()
    acc1s = AverageMeter()
 
    # switch to train mode
    model.train()
    metric_fc.train()
 
    for i, (input, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        input = input.to(device)
        target = target.to(device)
 
        feature = model(input)
        output = metric_fc(feature, target)
        loss = criterion(output, target)
 
        acc1, = accuracy(output, target, topk=(1,))
 
        losses.update(loss.item(), input.size(0))
        acc1s.update(acc1.item(), input.size(0))
 
        # compute gradient and do optimizing step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
    log = OrderedDict([
        ('loss', losses.avg),
        ('acc1', acc1s.avg),
    ])
 
    return log

#Val

In [0]:
def validate(val_loader, model, metric_fc, criterion):
    losses = AverageMeter()
    acc1s = AverageMeter()

    # switch to evaluate mode
    model.eval()
    metric_fc.eval()

    with torch.no_grad():
        for i, (input, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
            input = input.to(device)
            target = target.to(device)

            feature = model(input)
            output = metric_fc(feature, target)
            loss = criterion(output, target)

            acc1, = accuracy(output, target, topk=(1,))

            losses.update(loss.item(), input.size(0))
            acc1s.update(acc1.item(), input.size(0))

    log = OrderedDict([
        ('loss', losses.avg),
        ('acc1', acc1s.avg),
    ])

    return log

#Connect instances from Net to metric_fc

In [0]:
model = Net().to(device)
num_features = model.fc3.out_features
metric_fc = AdaCos(num_features, num_classes=10).to(device)

#Set parameters

In [0]:
epochs = 100
 
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9, weight_decay=1e-4)
# optimizer = optim.Adam(model.parameters(), lr=0.02)
 
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-3)
 
criterion = nn.CrossEntropyLoss().to(device)


#Start learning

In [0]:
log = pd.DataFrame(index=[],
                   columns=[ 'epoch', 'lr', 'loss', 'acc1', 'val_loss', 'val_acc1'])
best_loss = float('inf')
 
for epoch in range(epochs):
    print('Epoch [%d/%d]' %(epoch+1, epochs))
 
    scheduler.step()
 
    # train for one epoch
    train_log = train(train_loader, model, metric_fc, criterion, optimizer)
    # evaluate on validation set
    val_log = validate(val_loader, model, metric_fc, criterion)
 
    print('loss %.4f - acc1 %.4f - val_loss %.4f - val_acc %.4f'
            %(train_log['loss'], train_log['acc1'], val_log['loss'], val_log['acc1']))
 
    tmp = pd.Series([
            epoch,
            scheduler.get_lr()[0],
            train_log['loss'],
            train_log['acc1'],
            val_log['loss'],
            val_log['acc1'],
        ], index=['epoch', 'lr', 'loss', 'acc1', 'val_loss', 'val_acc1'])
 
    log = log.append(tmp, ignore_index=True)
    log.to_csv('models_log.csv', index=False)
 
    if val_log['loss'] < best_loss:
        torch.save(model.state_dict(), 'model.pth')
        best_loss = val_log['loss'] print("> saved best model")

Epoch [1/100]


100%|██████████| 469/469 [00:08<00:00, 53.79it/s]
100%|██████████| 79/79 [00:01<00:00, 60.76it/s]


loss 0.5507 - acc1 89.8783 - val_loss 0.4071 - val_acc 93.3000
Epoch [2/100]


100%|██████████| 469/469 [00:08<00:00, 54.02it/s]
100%|██████████| 79/79 [00:01<00:00, 60.97it/s]

loss 0.3795 - acc1 93.8933 - val_loss 0.3268 - val_acc 95.1300
Epoch [3/100]



100%|██████████| 469/469 [00:08<00:00, 54.11it/s]
100%|██████████| 79/79 [00:01<00:00, 61.22it/s]


loss 0.3146 - acc1 95.3417 - val_loss 0.2779 - val_acc 96.1500
Epoch [4/100]


100%|██████████| 469/469 [00:08<00:00, 53.74it/s]
100%|██████████| 79/79 [00:01<00:00, 61.73it/s]

loss 0.2744 - acc1 96.2150 - val_loss 0.2546 - val_acc 96.6300
Epoch [5/100]



100%|██████████| 469/469 [00:08<00:00, 53.38it/s]
100%|██████████| 79/79 [00:01<00:00, 61.43it/s]

loss 0.2538 - acc1 96.6717 - val_loss 0.2383 - val_acc 96.9000
Epoch [6/100]



100%|██████████| 469/469 [00:08<00:00, 54.07it/s]
100%|██████████| 79/79 [00:01<00:00, 62.24it/s]

loss 0.2332 - acc1 97.0783 - val_loss 0.2437 - val_acc 97.0700
Epoch [7/100]



100%|██████████| 469/469 [00:08<00:00, 53.61it/s]
100%|██████████| 79/79 [00:01<00:00, 63.11it/s]

loss 0.2210 - acc1 97.4200 - val_loss 0.2305 - val_acc 97.0500
Epoch [8/100]



100%|██████████| 469/469 [00:08<00:00, 53.02it/s]
100%|██████████| 79/79 [00:01<00:00, 60.81it/s]

loss 0.2079 - acc1 97.6750 - val_loss 0.2074 - val_acc 97.5900
Epoch [9/100]



100%|██████████| 469/469 [00:08<00:00, 54.09it/s]
100%|██████████| 79/79 [00:01<00:00, 59.63it/s]

loss 0.1958 - acc1 97.9067 - val_loss 0.1976 - val_acc 97.7900
Epoch [10/100]



100%|██████████| 469/469 [00:08<00:00, 54.09it/s]
100%|██████████| 79/79 [00:01<00:00, 62.16it/s]

loss 0.1889 - acc1 98.0400 - val_loss 0.2044 - val_acc 97.7500
Epoch [11/100]



100%|██████████| 469/469 [00:08<00:00, 53.80it/s]
100%|██████████| 79/79 [00:01<00:00, 61.76it/s]


loss 0.1817 - acc1 98.2250 - val_loss 0.1949 - val_acc 97.9200
Epoch [12/100]


100%|██████████| 469/469 [00:08<00:00, 53.85it/s]
100%|██████████| 79/79 [00:01<00:00, 63.46it/s]

loss 0.1758 - acc1 98.3467 - val_loss 0.1865 - val_acc 98.1100
Epoch [13/100]



100%|██████████| 469/469 [00:08<00:00, 54.10it/s]
100%|██████████| 79/79 [00:01<00:00, 60.62it/s]

loss 0.1687 - acc1 98.5017 - val_loss 0.1879 - val_acc 98.0300
Epoch [14/100]



100%|██████████| 469/469 [00:08<00:00, 53.92it/s]
100%|██████████| 79/79 [00:01<00:00, 60.80it/s]

loss 0.1639 - acc1 98.5617 - val_loss 0.1911 - val_acc 98.0400
Epoch [15/100]



100%|██████████| 469/469 [00:08<00:00, 52.73it/s]
100%|██████████| 79/79 [00:01<00:00, 62.23it/s]

loss 0.1611 - acc1 98.6450 - val_loss 0.1785 - val_acc 98.1800
Epoch [16/100]



100%|██████████| 469/469 [00:08<00:00, 52.93it/s]
100%|██████████| 79/79 [00:01<00:00, 61.87it/s]

loss 0.1574 - acc1 98.6933 - val_loss 0.1811 - val_acc 98.1200
Epoch [17/100]



100%|██████████| 469/469 [00:08<00:00, 53.03it/s]
100%|██████████| 79/79 [00:01<00:00, 61.72it/s]

loss 0.1508 - acc1 98.8383 - val_loss 0.1708 - val_acc 98.3600
Epoch [18/100]



100%|██████████| 469/469 [00:08<00:00, 53.16it/s]


KeyboardInterrupt: ignored