<a href="https://colab.research.google.com/github/yuu399/AffinityLoss/blob/main/Untitled0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn 
import numpy as np
import torch.nn.functional as F 

affinity lossの定義

In [2]:
class AffinityLoss(nn.Module):
  def __init__(self, n_classes, dim, sigma, xi):
    super(AffinityLoss, self).__init__()
    self.n_classes = n_classes
    self.dim = dim 
    self.sigma = sigma
    self.xi = xi
    self.w = nn.parameter.Parameter(torch.empty(self.n_classes, self.dim))
    nn.init.xavier_normal_(self.w)
  
  def forward(self, f, label):
    # calcurate gaussian similarity
    # d(f_i, w_j) = sim[i + n_classes * j]
    f_expand = torch.cat([f] * self.n_classes, dim=0)
    w_expand = torch.cat([self.w] * f.size(0), dim=0)
    l2_norm = torch.sum((f_expand - w_expand)**2, 1)
    d = torch.exp(-l2_norm / self.sigma)

    label_one_hot = F.one_hot(label, num_classes=self.n_classes)
    label_flatten = torch.flatten(label_one_hot)
    L_hat = self.xi + (1 - label_flatten) * d - label_flatten * d 
    L_pos = torch.where(L_hat > 0, L_hat, torch.zeros_like(L_hat))
    L_mm = torch.sum(L_pos, dim=0)

    # diversity legularizar
    W = torch.matmul(self.w, self.w.transpose(1, 0))
    W_upper_triangle = torch.triu(W, 1)
    mu = 2 / (self.n_classes**2 - self.n_classes) * torch.sum(W_upper_triangle)
    X = W - mu
    X_upper_triangle = torch.triu(X, 1)
    Rw = torch.sum(X_upper_triangle**2)

    return L_mm + Rw 

モデルの定義

In [3]:
from torchvision import models

class FeatureExtracter(nn.Module):
  def __init__(self):
    super(FeatureExtracter, self).__init__()
    self.base_model = models.mobilenet_v2(pretrained=True).features

    self.gap = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(1280, 512)
    self.act = nn.ReLU()

  def forward(self, input):
    x = self.base_model(input)
    x = self.gap(x)
    x = x.view(-1, self.num_flat_features(x))
    x = self.fc(x)
    output = self.act(x)
    return output 
  
  def num_flat_features(self, x):
    size = x.size()[1:]  # all dimensions except the batch dimension
    num_features = 1
    for s in size:
      num_features *= s
    return num_features


model = FeatureExtracter()

学習データのダウンロード

In [4]:
import torchvision
from torchvision import transforms

batch_size=64

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

transform = transforms.Compose([transforms.ToTensor(), normalize])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=0)

Files already downloaded and verified
Files already downloaded and verified


損失関数、オプティマイザの定義

In [5]:
from torch import optim

classes = trainset.classes
n_classes = len(classes)

criterion = AffinityLoss(n_classes, 512, 10, 0.5)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

学習

In [6]:
device = 'cuda'

model.to(device)
criterion.to(device)

for epoch in range(20):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        # backpropagation
        loss.backward()
        optimizer.step()
        train_loss = loss.item()
        running_loss += loss.item()
        if i % batch_size == batch_size - 1:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / batch_size))
            running_loss = 0.0
print('Finished Training')

[1,    64] loss: 320.618
[1,   128] loss: 320.348
[1,   192] loss: 320.348
[1,   256] loss: 320.348
[1,   320] loss: 320.348
[1,   384] loss: 320.348
[1,   448] loss: 320.348
[1,   512] loss: 320.348
[1,   576] loss: 320.348
[1,   640] loss: 320.348
[1,   704] loss: 320.348
[1,   768] loss: 320.348
[2,    64] loss: 320.348
[2,   128] loss: 320.348
[2,   192] loss: 320.348
[2,   256] loss: 320.348
[2,   320] loss: 320.348
[2,   384] loss: 320.348
[2,   448] loss: 320.348
[2,   512] loss: 320.348
[2,   576] loss: 320.348
[2,   640] loss: 320.348
[2,   704] loss: 320.348
[2,   768] loss: 320.348
[3,    64] loss: 320.348
[3,   128] loss: 320.348
[3,   192] loss: 320.348
[3,   256] loss: 320.348
[3,   320] loss: 320.348
[3,   384] loss: 320.348
[3,   448] loss: 320.348
[3,   512] loss: 320.348
[3,   576] loss: 320.348
[3,   640] loss: 320.348
[3,   704] loss: 320.348
[3,   768] loss: 320.348
[4,    64] loss: 320.348
[4,   128] loss: 320.348
[4,   192] loss: 320.348
[4,   256] loss: 320.348


KeyboardInterrupt: ignored