In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
path = '/data/xuyunhao/datasets'
ds = 'CUB'
num_samples = 2
bs = 200
lr = 1e-5
t = 0.2
emb = 512
ep = 100
local_rank = 0
workers = 4
optimizer = 'adamw'
lr_decay_step = 10
lr_decay_gamma = 0.5

model =  'resnet18'
hyp_c = 0.1
clip_r  = 2.3
resize = 224
crop = 224
gpu_id = 5
bn_freeze = 1
freezer = True

# 投影超球模型

In [4]:
def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

def dexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _dexp0(u, c)

def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)+1e-5
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_d(x, y, c)

class ToProjection_hypersphere(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToProjection_hypersphere, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return project(dexp0(x, c=self.c), c=self.c)

# 损失函数

In [5]:
def contrastive_loss(x0, x1, tau, hyp_c):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode

    dist_f = lambda x, y: -dist_matrix_d(x, y, c=hyp_c)
    bsize = x0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask
    logits01 = dist_f(x0, x1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型

In [6]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18

class Resnet18(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet18, self).__init__()

        self.model = resnet18(pretrained)
        self.embedding_size = embedding_size
        self.hyp_c = hyp_c
        self.clip_r = clip_r
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        
        self.Dlayer = ToProjection_hypersphere(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Dlayer)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        
        x = x.view(x.size(0), -1)
        x_d = self.model.embedding(x)
        return x_d

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)

In [7]:
def evaluate(get_emb_f, ds_name, hyp_c):
    emb_head = get_emb_f(ds_type="eval")
    recall_head = get_recall(*emb_head, ds_name, hyp_c)
    return recall_head

def get_recall(d, y, ds_name, hyp_c):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(d), len(d), device="cuda")
    for i in range(len(d)):
        dist_m[i : i + 1] = -dist_matrix_d(d[i : i + 1], d, hyp_c)

    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    d, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return d, y

def eval_dataset(model, dl):
    all_xd, all_y = [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            d= model(x)
            all_xd.append(d)
        all_y.append(y)
    return torch.cat(all_xd), torch.cat(all_y)

In [8]:
torch.cuda.set_device(gpu_id)
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=workers,
    pin_memory=True,
)

if model.find('resnet18')+1:
    model = Resnet18(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
loss_f = partial(contrastive_loss, tau=t, hyp_c= hyp_c)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    num_workers=workers,
    world_size=world_size,
)
if freezer == True:
    embedding_param = list(model.model.embedding.parameters())
    for param in list(set(model.parameters()).difference(set(embedding_param))):
        param.requires_grad = False
    optimizer = optim.AdamW(embedding_param, lr=lr)
else:
    optimizer = optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
print("Training for {} epochs.".format(ep))

r0= evaluate(get_emb_f, ds, hyp_c=hyp_c)
print("The recall before train: ", r0)

losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, ep):
    model.train()
    if bn_freeze:
        modules = model.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        d = model(x)
        d = d.view(len(x) // num_samples, num_samples, emb)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l, st = loss_f(d[:, i], d[:, j])
                    loss += l
                    stats_ep.append({**st, "loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        
    scheduler.step()        
    rh= evaluate(get_emb_f, ds, hyp_c = hyp_c)
    stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
    stats_ep = {"recall": rh, **stats_ep}
    if rh > best_recall :
        best_recall = rh
        best_epoch = epoch
    print("epoch:",epoch,"recall: ", rh)
    print("best epoch:",best_epoch,"best recall: ", best_recall)

Training for 100 epochs.
[0.38436866981769074, 0.5043889264010804, 0.6282916948008103, 0.7466239027683997, 0.8394665766374072, 0.9076637407157326]
The recall before train:  0.38436866981769074
[0.3896016205266712, 0.512322754895341, 0.6341998649561107, 0.7410533423362593, 0.8369345037137069, 0.9036124240378123]
epoch: 0 recall:  0.3896016205266712
best epoch: 0 best recall:  0.3896016205266712
[0.3759284267386901, 0.5023632680621202, 0.6266036461850101, 0.7400405131667792, 0.8311951384199865, 0.9015867656988521]
epoch: 1 recall:  0.3759284267386901
best epoch: 0 best recall:  0.3896016205266712
[0.37474679270763, 0.4951046590141796, 0.6284604996623903, 0.7493247805536799, 0.8389601620526671, 0.9090141796083727]
epoch: 2 recall:  0.37474679270763
best epoch: 0 best recall:  0.3896016205266712
[0.3661377447670493, 0.4913909520594193, 0.6169817690749494, 0.7343011478730588, 0.8264686022957461, 0.9009115462525321]
epoch: 3 recall:  0.3661377447670493
best epoch: 0 best recall:  0.389601620

[0.387407157326131, 0.513335584064821, 0.6377447670492911, 0.7488183659689399, 0.8419986495611074, 0.9117150573936529]
epoch: 39 recall:  0.387407157326131
best epoch: 15 best recall:  0.3944969615124916
[0.37947332883187035, 0.5050641458474004, 0.6340310600945307, 0.7506752194463201, 0.8392977717758271, 0.9058068872383525]
epoch: 40 recall:  0.37947332883187035
best epoch: 15 best recall:  0.3944969615124916
[0.38301823092505066, 0.5054017555705604, 0.6291357191087104, 0.7413909520594193, 0.8364280891289669, 0.9076637407157326]
epoch: 41 recall:  0.38301823092505066
best epoch: 15 best recall:  0.3944969615124916
[0.38825118163403105, 0.512491559756921, 0.6323430114787306, 0.74966239027684, 0.8398041863605672, 0.9068197164078325]
epoch: 42 recall:  0.38825118163403105
best epoch: 15 best recall:  0.3944969615124916
[0.3818365968939905, 0.5060769750168805, 0.6330182309250506, 0.7464550979068197, 0.8350776502363269, 0.9022619851451722]
epoch: 43 recall:  0.3818365968939905
best epoch: 1

[0.38538149898717083, 0.5106347062795409, 0.6291357191087104, 0.7418973666441594, 0.8352464550979068, 0.9063133018230926]
epoch: 79 recall:  0.38538149898717083
best epoch: 15 best recall:  0.3944969615124916
[0.37508440243079, 0.5016880486158002, 0.6326806212018906, 0.7467927076299797, 0.8419986495611074, 0.9113774476704929]
epoch: 80 recall:  0.37508440243079
best epoch: 15 best recall:  0.3944969615124916
[0.38841998649561105, 0.5079338284942606, 0.6298109385550303, 0.7424037812288994, 0.8322079675894666, 0.9036124240378123]
epoch: 81 recall:  0.38841998649561105
best epoch: 15 best recall:  0.3944969615124916
[0.3818365968939905, 0.5018568534773802, 0.6303173531397704, 0.7461174881836596, 0.8408170155300473, 0.9088453747467927]
epoch: 82 recall:  0.3818365968939905
best epoch: 15 best recall:  0.3944969615124916
[0.38419986495611075, 0.5101282916948008, 0.6362255232950709, 0.7451046590141797, 0.8322079675894666, 0.8987170830519919]
epoch: 83 recall:  0.38419986495611075
best epoch: