In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from hyptorch.pmath import dist_matrix
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
path = '/data/xuyunhao/datasets'
ds = 'CUB'
num_samples = 2
bs = 200
lr = 1e-5
t = 0.2
emb = 512
ep = 100
local_rank = 0
workers = 4
optimizer = 'adamw'
lr_decay_step = 10
lr_decay_gamma = 0.5

model =  'resnet18'
hyp_c = 0.1
clip_r  = 2.3
resize = 224
crop = 224
gpu_id = 5
bn_freeze = 1
freezer = True

# 庞加莱模型

In [4]:
def _tensor_dot(x, y):
    res = torch.einsum("ij,kj->ik", (x, y))
    return res

class Artanh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        x = x.clamp(-1 + 1e-5, 1 - 1e-5)
        ctx.save_for_backward(x)
        res = (torch.log_(1 + x).sub_(torch.log_(1 - x))).mul_(0.5)
        return res

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        return grad_output / (1 - input ** 2)
    
def artanh(x):
    return Artanh.apply(x)

def _mobius_addition_batch(x, y, c):
    xy = _tensor_dot(x, y)  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    num = 1 + 2 * c * xy + c * y2.permute(1, 0)  # B x C
    num = num.unsqueeze(2) * x.unsqueeze(1)
    num = num + (1 - c * x2).unsqueeze(2) * y  # B x C x D
    denom_part1 = 1 + 2 * c * xy  # B x C
    denom_part2 = c ** 2 * x2 * y2.permute(1, 0)
    denom = denom_part1 + denom_part2
    res = num / (denom.unsqueeze(2) + 1e-5)
    return res

def _dist_matrix(x, y, c):
    sqrt_c = c ** 0.5
    return (
        2
        / sqrt_c
        * artanh(sqrt_c * torch.norm(_mobius_addition_batch(-x, y, c=c), dim=-1))
    )


def dist_matrix(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix(x, y, c)

def tanh(x, clamp=15):
    return x.clamp(-clamp, clamp).tanh()

def expmap0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _expmap0(u, c)

def _expmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = tanh(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

class ToPoincare(nn.Module):
    r"""
    Module which maps points in n-dim Euclidean space
    to n-dim Poincare ball
    Also implements clipping from https://arxiv.org/pdf/2107.11472.pdf
    """

    def __init__(self, c, clip_r=None):
        super(ToPoincare, self).__init__()
        self.register_parameter("xp", None)

        self.c = c
        
        self.clip_r = clip_r
        self.grad_fix = lambda x: x

    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return self.grad_fix(project(expmap0(x, c=self.c), c=self.c))

# 投影超球

In [5]:
def project(x, *, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _project(x, c)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

def dexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _dexp0(u, c)

def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)+1e-5
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_d(x, y, c)

class ToProjection_hypersphere(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToProjection_hypersphere, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return project(dexp0(x, c=self.c), c=self.c)

# 损失函数

In [6]:
def contrastive_loss(e0, e1, p0, p1, d0, d1, tau, hyp_c):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode
    dist_e = lambda x, y: -torch.cdist(x, y, p=2)
    dist_p = lambda x, y: -dist_matrix(x, y, c=hyp_c)
    dist_d = lambda x, y: -dist_matrix_d(x, y, c=hyp_c)
    
    dist_x0 = dist_e(e0, e0)+dist_p(p0, p0)+dist_d( d0, d0)   
    dist_x1 = dist_e(e0, e1)+dist_p(p0, p1)+dist_d( d0, d1)   

    bsize = e0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    
    logits00 = dist_x0 / tau - eye_mask
    logits01 = dist_x1 / tau
    
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    return loss

# 模型

In [7]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18

class Resnet18(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet18, self).__init__()

        self.model = resnet18(pretrained)
        self.embedding_size = embedding_size
        self.hyp_c = hyp_c
        self.clip_r = clip_r
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        
        self.model.embedding = nn.Linear(self.num_ftrs, self.embedding_size)
        
        self.Elayer = NormLayer()
        
        self.Player = ToPoincare(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
         
        self.Dlayer = ToProjection_hypersphere(
            c=self.hyp_c,
            clip_r=self.clip_r,
        )
        
        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        
        x = x.view(x.size(0), -1)
        x1 = self.model.embedding(x)
        x_e = self.Elayer(x1)
        x_p = self.Player(x1)
        x_d = self.Dlayer(x1)
        return x_e, x_p, x_d

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding.weight, mode='fan_out')
        init.constant_(self.model.embedding.bias, 0)

class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

In [8]:
def evaluate(get_emb_f, ds_name, hyp_c):
    emb_head = get_emb_f(ds_type="eval")
    recall_head = get_recall(*emb_head, ds_name, hyp_c)
    return recall_head

def get_recall(e, p, d, y, ds_name, hyp_c):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(e), len(e), device="cuda")
    for i in range(len(e)):
        dist_m[i : i + 1] = -torch.cdist(e[i : i + 1], e, p=2)- dist_matrix_d(d[i : i + 1], d, hyp_c)- dist_matrix(p[i : i + 1], p, hyp_c)
        
    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    e, p, d, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return e, p, d, y

def eval_dataset(model, dl):
    all_xe, all_xp, all_xd, all_y = [], [], [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            e, p, d= model(x)
            all_xe.append(e)
            all_xp.append(p)
            all_xd.append(d)
        all_y.append(y)
    return torch.cat(all_xe), torch.cat(all_xp), torch.cat(all_xd), torch.cat(all_y)

In [9]:
torch.cuda.set_device(gpu_id)
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=workers,
    pin_memory=True,
)

if model.find('resnet18')+1:
    model = Resnet18(embedding_size=emb, pretrained=True, bn_freeze = bn_freeze,hyp_c = hyp_c, clip_r = clip_r).cuda().train() 
loss_f = partial(contrastive_loss, tau=t, hyp_c= hyp_c)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    num_workers=workers,
    world_size=world_size,
)
if freezer == True:
    embedding_param = list(model.model.embedding.parameters())
    for param in list(set(model.parameters()).difference(set(embedding_param))):
        param.requires_grad = False
    optimizer = optim.AdamW(embedding_param, lr=lr)
else:
    optimizer = optim.AdamW(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma = lr_decay_gamma)
print("Training for {} epochs.".format(ep))

r0= evaluate(get_emb_f, ds, hyp_c=hyp_c)
print("The recall before train: ", r0)

losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, ep):
    model.train()
    if bn_freeze:
        modules = model.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        e, p, d = model(x)
        e = e.view(len(x) // num_samples, num_samples, emb)
        p = p.view(len(x) // num_samples, num_samples, emb)
        d = d.view(len(x) // num_samples, num_samples, emb)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l= loss_f(e[:, i], e[:, j], p[:, i], p[:, j], d[:, i], d[:, j])
                    loss += l
                    stats_ep.append({"loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
        
    scheduler.step()        
    rh= evaluate(get_emb_f, ds, hyp_c = hyp_c)
    stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
    stats_ep = {"recall": rh, **stats_ep}
    if rh > best_recall :
        best_recall = rh
        best_epoch = epoch
    print("epoch:",epoch,"recall: ", rh)
    print("best epoch:",best_epoch,"best recall: ", best_recall)

Training for 100 epochs.
[0.38436866981769074, 0.5043889264010804, 0.6282916948008103, 0.7466239027683997, 0.8394665766374072, 0.9076637407157326]
The recall before train:  0.38436866981769074
[0.3899392302498312, 0.511985145172181, 0.6338622552329507, 0.7410533423362593, 0.8369345037137069, 0.9037812288993923]
epoch: 0 recall:  0.3899392302498312
best epoch: 0 best recall:  0.3899392302498312
[0.3760972316002701, 0.5023632680621202, 0.6269412559081702, 0.7398717083051992, 0.8315327481431465, 0.9017555705604321]
epoch: 1 recall:  0.3760972316002701
best epoch: 0 best recall:  0.3899392302498312
[0.37440918298446996, 0.4952734638757596, 0.6282916948008103, 0.7493247805536799, 0.8387913571910871, 0.9086765698852127]
epoch: 2 recall:  0.37440918298446996
best epoch: 0 best recall:  0.3899392302498312
[0.3659689399054693, 0.4912221471978393, 0.6169817690749494, 0.7343011478730588, 0.8264686022957461, 0.9012491559756921]
epoch: 3 recall:  0.3659689399054693
best epoch: 0 best recall:  0.389

[0.387744767049291, 0.513673193787981, 0.6387575962187712, 0.7491559756920999, 0.8418298446995274, 0.9103646185010128]
epoch: 39 recall:  0.387744767049291
best epoch: 15 best recall:  0.3946657663740716
[0.37677245104659013, 0.5054017555705604, 0.6343686698176908, 0.74949358541526, 0.8403106009453072, 0.9053004726536125]
epoch: 40 recall:  0.37677245104659013
best epoch: 15 best recall:  0.3946657663740716
[0.3811613774476705, 0.5052329507089804, 0.6287981093855503, 0.7395340985820391, 0.8364280891289669, 0.9085077650236327]
epoch: 41 recall:  0.3811613774476705
best epoch: 15 best recall:  0.3946657663740716
[0.387744767049291, 0.5106347062795409, 0.6338622552329507, 0.7479743416610398, 0.8394665766374072, 0.9064821066846726]
epoch: 42 recall:  0.387744767049291
best epoch: 15 best recall:  0.3946657663740716
[0.3811613774476705, 0.5052329507089804, 0.6331870357866306, 0.7469615124915597, 0.8352464550979068, 0.9025995948683322]
epoch: 43 recall:  0.3811613774476705
best epoch: 15 bes

[0.3838622552329507, 0.5104659014179609, 0.6289669142471304, 0.7408845374746793, 0.8340648210668468, 0.9069885212694125]
epoch: 79 recall:  0.3838622552329507
best epoch: 15 best recall:  0.3946657663740716
[0.3759284267386901, 0.5016880486158002, 0.6309925725860904, 0.7461174881836596, 0.8435178933153274, 0.9113774476704929]
epoch: 80 recall:  0.3759284267386901
best epoch: 15 best recall:  0.3946657663740716
[0.38791357191087106, 0.5091154625253207, 0.6296421336934503, 0.7402093180283592, 0.8333896016205267, 0.9037812288993923]
epoch: 81 recall:  0.38791357191087106
best epoch: 15 best recall:  0.3946657663740716
[0.3820054017555706, 0.50033760972316, 0.6293045239702904, 0.7462862930452397, 0.8406482106684673, 0.9083389601620526]
epoch: 82 recall:  0.3820054017555706
best epoch: 15 best recall:  0.3946657663740716
[0.3835246455097907, 0.5091154625253207, 0.6357191087103309, 0.7461174881836596, 0.8328831870357867, 0.8985482781904118]
epoch: 83 recall:  0.3835246455097907
best epoch: 1