In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from hyptorch.pmath import dist_matrix
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5'
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
path = '/data/xuyunhao/datasets'
ds = 'CUB'
num_samples = 2
bs = 200
lr = 1e-5
t = 0.2
emb = 384
ep = 100
local_rank = 0
workers = 8
optimizer = 'adamw'
lr_decay_step = 10
lr_decay_gamma = 0.5

model =  'resnet50'
hyp_c = 0.1
clip_r  = 2.3
resize = 224
crop = 224
gpu_id = -1
bn_freeze = 1

# 庞加莱圆盘

In [4]:
def _tensor_dot(x, y):
    res = torch.einsum("ij,kj->ik", (x, y))
    return res

class Artanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.clamp(-1 + 1e-5, 1 - 1e-5)
        ctx.save_for_backward(x)
        res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return grad_output / (1 - input ** 2)
    
def artanh(x):
    return Artanh.apply(x)

def _mobius_addition_batch(x, y, c):
    xy = _tensor_dot(x, y)  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    num = 1 + 2 * c * xy + c * y2.permute(1, 0)  # B x C
    num = num.unsqueeze(2) * x.unsqueeze(1)
    num = num + (1 - c * x2).unsqueeze(2) * y  # B x C x D
    denom_part1 = 1 + 2 * c * xy  # B x C
    denom_part2 = c ** 2 * x2 * y2.permute(1, 0)
    denom = denom_part1 + denom_part2
    res = num / (denom.unsqueeze(2) + 1e-5)
    return res

def _dist_matrix(x, y, c):
    sqrt_c = c ** 0.5
    return (
        2
        / sqrt_c
        * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1))
    )


def dist_matrix(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix(x, y, c)

def tanh(x, clamp=15):
    return x.clamp(-clamp, clamp).tanh()

def expmap0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _expmap0(u, c)

def _expmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

class ToPoincare(nn.Module):
    r"""
    Module which maps points in n-dim Euclidean space
    to n-dim Poincare ball
    Also implements clipping from https://arxiv.org/pdf/2107.11472.pdf
    """

    def __init__(self, c, clip_r=None):
        super(ToPoincare, self).__init__()
        self.register_parameter("xp", None)

        self.c = c
        
        self.clip_r = clip_r
        self.grad_fix = lambda x: x

    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return self.grad_fix(project(expmap0(x, c=self.c), c=self.c))

# 损失函数

In [5]:
def contrastive_loss(x0, x1, tau, hyp_c):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode

    dist_f = lambda x, y: -dist_matrix(x, y, c=hyp_c)
    bsize = x0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask
    logits01 = dist_f(x0, x1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型

In [6]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchvision.models import resnet50
from torchvision.models import resnet101

## Resnet50

In [7]:
class Resnet50(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet50, self).__init__()

        self.model = resnet50(pretrained)
        self.embedding_size = embedding_size
        self.hyp_c = hyp_c
        self.clip_r = clip_r
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        self.Player = ToPoincare(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Player)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        x = x.view(x.size(0), -1)
        x_p = self.model.embedding(x)
        return x_p

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)


## Resnet101

In [8]:
class Resnet101(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet101, self).__init__()

        self.model = resnet101(pretrained)
        self.embedding_size = embedding_size
        self.hyp_c = hyp_c
        self.clip_r = clip_r
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        self.Player = ToPoincare(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Player)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        x = x.view(x.size(0), -1)
        x_p = self.model.embedding(x)
        return x_p

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)

In [9]:
def evaluate(get_emb_f, ds_name, hyp_c):
    emb_head = get_emb_f(ds_type="eval")
    recall_head = get_recall(*emb_head, ds_name, hyp_c)
    return recall_head

def get_recall(p, y, ds_name, hyp_c):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(p), len(p), device="cuda")
    for i in range(len(p)):
        dist_m[i : i + 1] = -dist_matrix(p[i : i + 1], p, hyp_c)

    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    p, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return p, y

def eval_dataset(model, dl):
    all_xp, all_y = [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            p= model(x)
            all_xp.append(p)
        all_y.append(y)
    return torch.cat(all_xp), torch.cat(all_y)

In [10]:
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=workers,
    pin_memory=True,
)

if model.find('resnet18')+1:
    model = Resnet18(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
elif model.find('resnet34')+1:
    model = Resnet34(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
elif model.find('resnet50')+1:
    model = Resnet50(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r) 
elif model.find('resnet101')+1:
    model = Resnet101(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r)

if gpu_id == -1:
    model = nn.DataParallel(model)
    
model.cuda().train()

loss_f = partial(contrastive_loss, tau=t, hyp_c= hyp_c)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    num_workers=workers,
    world_size=world_size,
)
optimizer = optim.AdamW(model.module.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
print("Training for {} epochs.".format(ep))

r0= evaluate(get_emb_f, ds, hyp_c=hyp_c)
print("The recall before train: ", r0)

losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, ep):
    model.train()
    if bn_freeze:
        modules = model.module.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        p = model(x)
        p = p.view(len(x) // num_samples, num_samples, emb)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l, st = loss_f(p[:, i], p[:, j])
                    loss += l
                    stats_ep.append({**st, "loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        
    scheduler.step()        
    rh= evaluate(get_emb_f, ds, hyp_c = hyp_c)
    stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
    stats_ep = {"recall": rh, **stats_ep}
    if rh > best_recall :
        best_recall = rh
        best_epoch = epoch
    print("epoch:",epoch,"recall: ", rh)
    print("best epoch:",best_epoch,"best recall: ", best_recall)

Training for 100 epochs.
[0.40698852126941254, 0.5379810938555031, 0.6617150573936529, 0.7670492910195814, 0.8583727211343687, 0.9225185685347738]
The recall before train:  0.40698852126941254
[0.4159351789331533, 0.5418636056718433, 0.6613774476704929, 0.7722822417285617, 0.8597231600270088, 0.9193112761647535]
epoch: 0 recall:  0.4159351789331533
best epoch: 0 best recall:  0.4159351789331533
[0.4299459824442944, 0.5557056043214045, 0.6760634706279541, 0.7835921674544227, 0.8676569885212694, 0.924206617150574]
epoch: 1 recall:  0.4299459824442944
best epoch: 1 best recall:  0.4299459824442944
[0.42943956785955434, 0.5604321404456448, 0.6772451046590142, 0.7780216070222823, 0.8652937204591492, 0.9230249831195139]
epoch: 2 recall:  0.42943956785955434
best epoch: 1 best recall:  0.4299459824442944
[0.437373396353815, 0.5668467251856854, 0.6831532748143146, 0.786799459824443, 0.8664753544902093, 0.924037812288994]
epoch: 3 recall:  0.437373396353815
best epoch: 3 best recall:  0.4373733

[0.46809588116137746, 0.5909858203916273, 0.6986833220796759, 0.7898379473328832, 0.8593855503038488, 0.9135719108710331]
epoch: 39 recall:  0.46809588116137746
best epoch: 33 best recall:  0.47400405131667794
[0.47012153950033764, 0.5909858203916273, 0.7037474679270763, 0.7972653612424038, 0.8727211343686698, 0.9247130317353139]
epoch: 40 recall:  0.47012153950033764
best epoch: 33 best recall:  0.47400405131667794
[0.475016880486158, 0.5962187711006077, 0.7042538825118163, 0.7918636056718433, 0.8700202565833896, 0.9237002025658338]
epoch: 41 recall:  0.475016880486158
best epoch: 41 best recall:  0.475016880486158
[0.45661715057393654, 0.5823767724510466, 0.6949696151249156, 0.7920324105334233, 0.8693450371370696, 0.9196488858879136]
epoch: 42 recall:  0.45661715057393654
best epoch: 41 best recall:  0.475016880486158
[0.4716407832545577, 0.5952059419311276, 0.7018906144496961, 0.7964213369345037, 0.8671505739365294, 0.9194800810263336]
epoch: 43 recall:  0.4716407832545577
best epoc

[0.48311951384199864, 0.6024645509790681, 0.7113436866981769, 0.8013166779203241, 0.8730587440918298, 0.9203241053342336]
epoch: 79 recall:  0.48311951384199864
best epoch: 79 best recall:  0.48311951384199864
[0.45948683322079675, 0.5835584064821067, 0.6942943956785955, 0.7950708980418636, 0.8728899392302498, 0.9211681296421337]
epoch: 80 recall:  0.45948683322079675
best epoch: 79 best recall:  0.48311951384199864
[0.4598244429439568, 0.5784942606347063, 0.6954760297096556, 0.7969277515192438, 0.8705266711681297, 0.9194800810263336]
epoch: 81 recall:  0.4598244429439568
best epoch: 79 best recall:  0.48311951384199864
[0.474848075624578, 0.5930114787305875, 0.7045914922349764, 0.7964213369345037, 0.8658001350438893, 0.9203241053342336]
epoch: 82 recall:  0.474848075624578
best epoch: 79 best recall:  0.48311951384199864
[0.47704253882511816, 0.5980756245779878, 0.7067859554355166, 0.7984469952734639, 0.8651249155975692, 0.9176232275489534]
epoch: 83 recall:  0.47704253882511816
best 