In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from hyptorch.pmath import dist_matrix
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
os.environ['CUDA_VISIBLE_DEVICES'] = '2,3,4'
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
path = '/data/xuyunhao/datasets'
ds = 'CUB'
num_samples = 2
bs = 200
lr = 1e-5
t = 0.2
emb = 128
ep = 100
local_rank = 0
workers = 8
optimizer = 'adamw'
lr_decay_step = 10
lr_decay_gamma = 0.5

model =  'resnet50'
hyp_c = 0.1
clip_r  = 2.3
resize = 224
crop = 224
gpu_id = -1
bn_freeze = 1

# 庞加莱模型

In [4]:
def _tensor_dot(x, y):
    res = torch.einsum("ij,kj->ik", (x, y))
    return res

class Artanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.clamp(-1 + 1e-5, 1 - 1e-5)
        ctx.save_for_backward(x)
        res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return grad_output / (1 - input ** 2)
    
def artanh(x):
    return Artanh.apply(x)

def _mobius_addition_batch(x, y, c):
    xy = _tensor_dot(x, y)  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    num = 1 + 2 * c * xy + c * y2.permute(1, 0)  # B x C
    num = num.unsqueeze(2) * x.unsqueeze(1)
    num = num + (1 - c * x2).unsqueeze(2) * y  # B x C x D
    denom_part1 = 1 + 2 * c * xy  # B x C
    denom_part2 = c ** 2 * x2 * y2.permute(1, 0)
    denom = denom_part1 + denom_part2
    res = num / (denom.unsqueeze(2) + 1e-5)
    return res

def _dist_matrix(x, y, c):
    sqrt_c = c ** 0.5
    return (
        2
        / sqrt_c
        * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1))
    )


def dist_matrix(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix(x, y, c)

def tanh(x, clamp=15):
    return x.clamp(-clamp, clamp).tanh()

def expmap0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _expmap0(u, c)

def _expmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

class ToPoincare(nn.Module):
    r"""
    Module which maps points in n-dim Euclidean space
    to n-dim Poincare ball
    Also implements clipping from https://arxiv.org/pdf/2107.11472.pdf
    """

    def __init__(self, c, clip_r=None):
        super(ToPoincare, self).__init__()
        self.register_parameter("xp", None)

        self.c = c
        
        self.clip_r = clip_r
        self.grad_fix = lambda x: x

    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return self.grad_fix(project(expmap0(x, c=self.c), c=self.c))

# 投影超球模型

In [5]:
def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

def dexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _dexp0(u, c)

def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)+1e-5
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_d(x, y, c)

class ToProjection_hypersphere(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToProjection_hypersphere, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return project(dexp0(x, c=self.c), c=self.c)

# 损失函数

In [6]:
def contrastive_loss_e(e0, e1, tau):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode
    dist_e = lambda x, y: -torch.cdist(x, y, p=2)
    
    dist_e0 = dist_e(e0, e0)   
    dist_e1 = dist_e(e0, e1)

    bsize = e0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    
    logits00 = dist_e0 / tau - eye_mask
    logits01 = dist_e1 / tau
    
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    return loss

In [7]:
def contrastive_loss_p(x0, x1, tau, hyp_c):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode

    dist_f = lambda x, y: -dist_matrix(x, y, c=hyp_c)
    bsize = x0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask
    logits01 = dist_f(x0, x1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    return loss

In [8]:
def contrastive_loss_d(x0, x1, tau, hyp_c):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode

    dist_f = lambda x, y: -dist_matrix_d(x, y, c=hyp_c)
    bsize = x0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask
    logits01 = dist_f(x0, x1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    return loss

# 模型

In [9]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18
from torchvision.models import resnet34
from torchvision.models import resnet50
from torchvision.models import resnet101

## Resnet50

In [10]:
class Resnet50(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet50, self).__init__()

        self.model = resnet50(pretrained)
        self.embedding_size = embedding_size
        self.hyp_c = hyp_c
        self.clip_r = clip_r
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)

        self.Elayer = NormLayer()
        self.model.embedding_e = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Elayer)
        
        self.Player = ToPoincare(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        self.model.embedding_p = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Player)
        
        self.Dlayer = ToProjection_hypersphere(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        self.model.embedding_d = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Dlayer)

        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        x = x.view(x.size(0), -1)
        x_e = self.model.embedding_e(x)
        x_p = self.model.embedding_p(x)
        x_d = self.model.embedding_d(x)
        return x_e, x_p, x_d

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding_e[0].weight, mode='fan_out')
        init.constant_(self.model.embedding_e[0].bias, 0)
        init.kaiming_normal_(self.model.embedding_p[0].weight, mode='fan_out')
        init.constant_(self.model.embedding_p[0].bias, 0)
        init.kaiming_normal_(self.model.embedding_d[0].weight, mode='fan_out')
        init.constant_(self.model.embedding_d[0].bias, 0)

class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

In [11]:
def evaluate(get_emb_f, ds_name, hyp_c):
    emb_head = get_emb_f(ds_type="eval")
    recall_head = get_recall(*emb_head, ds_name, hyp_c)
    return recall_head

def get_recall(e, p, d, y, ds_name, hyp_c):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(e), len(e), device="cuda")
    for i in range(len(e)):
        dist_m[i : i + 1] = -torch.cdist(e[i : i + 1], e, p=2)- dist_matrix_d(d[i : i + 1], d, hyp_c)- dist_matrix(p[i : i + 1], p, hyp_c)
        
    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    e, p, d, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return e, p, d, y

def eval_dataset(model, dl):
    all_xe, all_xp, all_xd, all_y = [], [], [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            e, p, d= model(x)
            all_xe.append(e)
            all_xp.append(p)
            all_xd.append(d)
        all_y.append(y)
    return torch.cat(all_xe), torch.cat(all_xp), torch.cat(all_xd), torch.cat(all_y)

In [12]:
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=workers,
    pin_memory=True,
)

if model.find('resnet18')+1:
    model = Resnet18(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
elif model.find('resnet34')+1:
    model = Resnet34(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
elif model.find('resnet50')+1:
    model = Resnet50(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r) 
elif model.find('resnet101')+1:
    model = Resnet101(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r)

if gpu_id == -1:
    model = nn.DataParallel(model)
    
model.cuda().train()

loss_e = partial(contrastive_loss_e, tau=t)
loss_p = partial(contrastive_loss_p, tau=t, hyp_c= hyp_c)
loss_d = partial(contrastive_loss_d, tau=t, hyp_c= hyp_c)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    num_workers=workers,
    world_size=world_size,
)
optimizer = optim.AdamW(model.module.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
print("Training for {} epochs.".format(ep))

r0= evaluate(get_emb_f, ds, hyp_c=hyp_c)
print("The recall before train: ", r0)

losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, ep):
    model.train()
    if bn_freeze:
        modules = model.module.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        e, p, d = model(x)
        e = e.view(len(x) // num_samples, num_samples, emb)
        p = p.view(len(x) // num_samples, num_samples, emb)
        d = d.view(len(x) // num_samples, num_samples, emb)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l= loss_e(e[:, i], e[:, j])+loss_p(p[:, i], p[:, j])+loss_d(d[:, i], d[:, j])
                    loss += l
                    stats_ep.append({"loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        
    scheduler.step()        
    rh= evaluate(get_emb_f, ds, hyp_c = hyp_c)
    stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
    stats_ep = {"recall": rh, **stats_ep}
    if rh > best_recall :
        best_recall = rh
        best_epoch = epoch
    print("epoch:",epoch,"recall: ", rh)
    print("best epoch:",best_epoch,"best recall: ", best_recall)

Training for 100 epochs.
[0.3938217420661715, 0.5271775827143822, 0.6487170830519919, 0.7596218771100608, 0.8509453072248481, 0.9182984469952734]
The recall before train:  0.3938217420661715
[0.40597569209993245, 0.5305536799459825, 0.6522619851451722, 0.7678933153274814, 0.8543214044564483, 0.9162727886563133]
epoch: 0 recall:  0.40597569209993245
best epoch: 0 best recall:  0.40597569209993245
[0.42319378798109386, 0.5497974341661039, 0.6716745442268738, 0.7739702903443619, 0.8603983794733289, 0.9199864956110736]
epoch: 1 recall:  0.42319378798109386
best epoch: 1 best recall:  0.42319378798109386
[0.4324780553679946, 0.5545239702903444, 0.6760634706279541, 0.7765023632680621, 0.8548278190411884, 0.9201553004726536]
epoch: 2 recall:  0.4324780553679946
best epoch: 2 best recall:  0.4324780553679946
[0.4350101282916948, 0.5600945307224848, 0.6816340310600946, 0.7849426063470628, 0.8652937204591492, 0.9221809588116138]
epoch: 3 recall:  0.4350101282916948
best epoch: 3 best recall:  0.

[0.46168129642133693, 0.5818703578663066, 0.6954760297096556, 0.7834233625928426, 0.8587103308575287, 0.9152599594868333]
epoch: 39 recall:  0.46168129642133693
best epoch: 27 best recall:  0.4724848075624578
[0.4686022957461175, 0.5867656988521269, 0.700540175557056, 0.7976029709655638, 0.8637744767049291, 0.9172856178257934]
epoch: 40 recall:  0.4686022957461175
best epoch: 27 best recall:  0.4724848075624578
[0.4638757596218771, 0.5855840648210668, 0.6931127616475354, 0.7879810938555031, 0.8652937204591492, 0.9176232275489534]
epoch: 41 recall:  0.4638757596218771
best epoch: 27 best recall:  0.4724848075624578
[0.4550979068197164, 0.5828831870357867, 0.6958136394328157, 0.7893315327481432, 0.8651249155975692, 0.9134031060094531]
epoch: 42 recall:  0.4550979068197164
best epoch: 27 best recall:  0.4724848075624578
[0.4707967589466577, 0.5921674544226874, 0.7022282241728561, 0.7893315327481432, 0.862424037812289, 0.9186360567184335]
epoch: 43 recall:  0.4707967589466577
best epoch: 2

[0.47704253882511816, 0.5987508440243079, 0.7037474679270763, 0.7955773126266037, 0.8663065496286293, 0.9186360567184335]
epoch: 79 recall:  0.47704253882511816
best epoch: 70 best recall:  0.4812626603646185
[0.4559419311276165, 0.5806887238352465, 0.6909182984469953, 0.7937204591492235, 0.8656313301823092, 0.9172856178257934]
epoch: 80 recall:  0.4559419311276165
best epoch: 70 best recall:  0.4812626603646185
[0.45543551654287645, 0.5766374071573261, 0.6915935178933154, 0.787643484132343, 0.8641120864280891, 0.9162727886563133]
epoch: 81 recall:  0.45543551654287645
best epoch: 70 best recall:  0.4812626603646185
[0.4736664415935179, 0.5914922349763673, 0.6968264686022958, 0.7920324105334233, 0.862592842673869, 0.9145847400405132]
epoch: 82 recall:  0.4736664415935179
best epoch: 70 best recall:  0.4812626603646185
[0.46455097906819715, 0.5887913571910871, 0.6983457123565159, 0.7928764348413234, 0.8639432815665091, 0.9137407157326131]
epoch: 83 recall:  0.46455097906819715
best epoc