<a href="https://colab.research.google.com/github/tasakama/media/blob/main/media_report.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import sklearn.metrics.pairwise as F

from torchvision import datasets, transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import cosine_similarity
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from collections import Counter

In [None]:
class Projector(nn.Module):
  """Three-stage MLP head applied to flattened backbone features.

  The first two stages are Linear -> BatchNorm1d -> ReLU and keep the width
  at ``in_dim``; the last stage is Linear -> BatchNorm1d with width
  ``out_dim`` and no activation. Every Linear layer is created without a
  bias term.

  Args:
    in_dim: size of the incoming feature vector (also the hidden width).
    out_dim: size of the produced embedding.
  """

  def __init__(self, in_dim, out_dim=2048):
    super().__init__()

    stages = []
    # Two hidden stages that keep the input width.
    for _ in range(2):
      stages.extend([
          nn.Linear(in_dim, in_dim, bias=False),
          nn.BatchNorm1d(in_dim),
          nn.ReLU(inplace=True),
      ])
    # Output stage: project to out_dim, batch-normalised but not activated.
    stages.extend([
        nn.Linear(in_dim, out_dim, bias=False),
        nn.BatchNorm1d(out_dim),
    ])
    self.layers = nn.Sequential(*stages)

  def forward(self, x):
    """Run ``x`` through the MLP and return the result."""
    embedding = self.layers(x)
    return embedding


class Predictor(nn.Module):
  """Bottleneck MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

  The first Linear layer (``in_dim`` -> ``pred_dim``) has no bias; the final
  Linear layer (``pred_dim`` -> ``out_dim``) keeps the default bias and is
  followed by neither normalisation nor activation.

  Args:
    in_dim: size of the incoming embedding.
    pred_dim: width of the hidden bottleneck.
    out_dim: size of the produced prediction.
  """

  def __init__(self, in_dim=2048, pred_dim=512, out_dim=2048):
    super().__init__()

    squeeze = nn.Linear(in_dim, pred_dim, bias=False)
    norm = nn.BatchNorm1d(pred_dim)
    activation = nn.ReLU(inplace=True)
    expand = nn.Linear(pred_dim, out_dim)
    self.layers = nn.Sequential(squeeze, norm, activation, expand)

  def forward(self, x):
    """Run ``x`` through the bottleneck MLP and return the result."""
    prediction = self.layers(x)
    return prediction


class SimSiam(nn.Module):
  """Siamese model made of a shared backbone, projector and predictor.

  Args:
    backbone: module applied to each input view; its output is flattened
      from dimension 1 onward before being projected.
    projector: module applied to the flattened backbone output.
    predictor: module applied to the projector output.
  """

  def __init__(self, backbone, projector, predictor):
    super().__init__()

    self.backbone, self.projector, self.predictor = backbone, projector, predictor

  def _embed(self, view):
    """Return the projector output for one input view."""
    features = self.backbone(view)
    return self.projector(features.flatten(start_dim=1))

  def forward(self, x1, x2):
    """Return ``(p1, p2, z1, z2)`` for two transformed views of a batch.

    ``p1``/``p2`` are the predictor outputs; ``z1``/``z2`` are the projector
    outputs, returned detached so no gradient flows back through them.
    """
    # x1 and x2 are the transformed (augmented) images.
    z1, z2 = self._embed(x1), self._embed(x2)
    p1, p2 = self.predictor(z1), self.predictor(z2)

    # Stop-gradient: .detach() cuts the projections out of the graph.
    return p1, p2, z1.detach(), z2.detach()


class SimSiamDataset(Dataset):
  """CIFAR-10 wrapper that yields two transformed views of each image.

  Args:
    root: directory handed to ``datasets.CIFAR10`` (downloaded if needed).
    transform1: callable producing the first view from the raw image.
    transform2: callable producing the second view from the raw image.
    train: selects the training split when True, the test split otherwise.
  """

  def __init__(self, root, transform1, transform2, train=True):
    self.transform1, self.transform2 = transform1, transform2
    # The underlying dataset is built without a transform; views are
    # produced per item in __getitem__.
    self.dataset = datasets.CIFAR10(root=root, train=train, download=True)

  def __len__(self):
    """Return the number of samples in the wrapped dataset."""
    n_samples = len(self.dataset)
    return n_samples

  def __getitem__(self, idx):
    """Return ``(view1, view2, label)`` for the sample at ``idx``."""
    sample = self.dataset[idx]
    image, label = sample
    first_view = self.transform1(image)
    second_view = self.transform2(image)
    return first_view, second_view, label


def negative_cosine(p, z):
  """Return the negated batch mean of the cosine similarity of p and z.

  The similarity is computed along dimension 1 (one value per row), averaged
  over all rows, and the sign of that mean is flipped.
  """
  per_sample = cosine_similarity(p, z, dim=1)
  return -per_sample.mean()


def feature_for_knn(model, data_loader, culculate_type):
  """Extract flattened backbone features and labels for k-NN evaluation.

  Args:
    model: module exposing a ``backbone`` submodule. It is switched to
      evaluation mode here and is still in evaluation mode on return.
    data_loader: iterable of batches. With ``culculate_type == "train"``
      each batch must unpack as ``(x, _, y)``; with ``"val"`` as ``(x, y)``.
    culculate_type: batch layout selector, either ``"train"`` or ``"val"``.

  Returns:
    ``(features, labels)`` as numpy arrays, concatenated over all batches
    along axis 0.

  Raises:
    ValueError: if ``culculate_type`` is not ``"train"`` or ``"val"``, or
      (raised by ``np.concatenate``) if ``data_loader`` yields no batches.
  """
  if culculate_type not in ("train", "val"):
    # Fail early with a clear message instead of collecting nothing and
    # failing later inside np.concatenate.
    raise ValueError(
        f'culculate_type must be "train" or "val", got {culculate_type!r}')

  # Put the SimSiam model in evaluation mode.
  model.eval()

  # Move inputs to the device the model lives on, so the function does not
  # rely on a module-level ``device`` variable. A model without parameters
  # leaves the inputs where they are.
  first_param = next(model.parameters(), None)
  target_device = None if first_param is None else first_param.device

  features = []
  labels = []

  with torch.no_grad():
    for batch in data_loader:
      if culculate_type == "train":
        # Training batches carry two views; only the first one is used.
        x, _, y = batch
      else:
        x, y = batch

      if target_device is not None:
        x = x.to(target_device)

      feature = model.backbone(x).flatten(start_dim=1)
      features.append(feature.cpu().numpy())
      labels.append(y.cpu().numpy())

  features = np.concatenate(features, axis=0)
  labels = np.concatenate(labels, axis=0)

  return features, labels


def knn_cosine(train_features, train_labels, val_features, k):
  """Predict validation labels by a cosine-similarity k-NN majority vote.

  Args:
    train_features: reference feature matrix, one row per training sample.
    train_labels: numpy array of labels aligned with ``train_features``.
    val_features: feature matrix of the samples to classify.
    k: number of nearest neighbours taking part in each vote.

  Returns:
    numpy array with one predicted label per row of ``val_features``.
  """
  # Negated cosine similarity: an ascending sort then lists the most similar
  # training samples (values closest to -1) first.
  neg_similarity = -F.cosine_similarity(val_features, train_features)

  def vote(scores):
    """Return the most frequent label among the k best-scoring neighbours."""
    nearest = np.argsort(scores)[:k]
    ballot = Counter(train_labels[nearest])
    winner, _count = ballot.most_common(1)[0]
    return winner

  n_val = val_features.shape[0]
  return np.array([vote(neg_similarity[row]) for row in range(n_val)])

In [None]:
# Use the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Backbone: ResNet-18 with only its final fully connected layer removed
backbone = nn.Sequential(*list(models.resnet18(weights=None).children())[:-1]).to(device)

# Build the SimSiam model
projector = Projector(in_dim=512).to(device)
predictor = Predictor().to(device)
model = SimSiam(backbone, projector, predictor).to(device)

# Normalisation shared by the training and the evaluation pipeline
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Random augmentation used to create the two training views
transform = transforms.Compose([
    transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
    transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)
    ], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

# Deterministic preprocessing for evaluation: validation images must not be
# randomly cropped / colour-jittered, otherwise the validation features (and
# the reported k-NN accuracy) change from run to run.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

train_dataset = SimSiamDataset(
    root='./data',
    transform1=transform,
    transform2=transform,
    train=True
)
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=2)

val_dataset = datasets.CIFAR10(root='./data', train=False, transform=eval_transform, download=True)
val_loader = DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=2)

# Optimiser settings
optimizer = optim.SGD(model.parameters(), lr=0.03, momentum=0.9, weight_decay=0.0005)

In [None]:
# Training loop
num_epochs = 100

# Reference set for the k-NN monitor. The features the neighbours are drawn
# from must come from deterministically preprocessed training images;
# extracting them through train_loader would use a different random
# augmentation of every image at every evaluation.
knn_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
knn_bank_dataset = datasets.CIFAR10(root='./data', train=True, transform=knn_transform, download=True)
knn_bank_loader = DataLoader(knn_bank_dataset, batch_size=512, shuffle=False, num_workers=2)

for epoch in range(num_epochs):
  model.train()
  total_loss = 0

  for x1, x2, _ in train_loader:
    x1, x2 = x1.to(device), x2.to(device)
    optimizer.zero_grad()
    p1, p2, z1, z2 = model(x1, x2)
    # Symmetrised loss: each prediction is compared with the (detached)
    # projection of the other view, and the two terms are averaged.
    loss = negative_cosine(p1, z2)/2 + negative_cosine(p2, z1)/2
    loss.backward()
    optimizer.step()

    total_loss += loss.item()

  print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss/len(train_loader):.4f}')

  # Monitor representation quality after the first epoch and every 10 epochs.
  if (epoch + 1) % 10 == 0 or epoch == 0:
    # knn_bank_loader yields (image, label) pairs, i.e. the "val" batch layout.
    train_features, train_labels = feature_for_knn(model, knn_bank_loader, "val")
    val_features, val_labels = feature_for_knn(model, val_loader, "val")

    # Predict validation labels with cosine k-NN
    predictions = knn_cosine(train_features, train_labels, val_features, k=200)

    accuracy = accuracy_score(val_labels, predictions)

    print(f'Epoch [{epoch+1}/{num_epochs}], k-NN Accuracy: {accuracy:.2f}')