In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from apex import amp

import os
import random
import timm
import wandb
from tqdm import trange
import multiprocessing
from functools import partial
import numpy as np
import PIL
from torchvision.transforms import InterpolationMode

from sampler import UniqueClassSempler
from proxy_anchor.utils import calc_recall_at_k
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset
from hyptorch.pmath import dist_matrix
import hyptorch.nn as hypnn

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from datetime import timedelta

import ipdb

In [2]:
def _seed_everything(value):
    """Seed Python's `random`, NumPy and PyTorch (CPU + every CUDA device)."""
    random.seed(value)
    np.random.seed(value)
    torch.manual_seed(value)
    torch.cuda.manual_seed_all(value)


# Single global seed shared by every RNG used in this notebook.
seed = 1
_seed_everything(seed)

# 超参数

In [3]:
# Run index; only used to build the checkpoint file name in `save_path`.
num = 2
# Dataset root handed to the dataset classes; also the folder where embeddings
# are written when `save_emb` is True.
path = '/data/xuyunhao/datasets'
# Dataset key: one of the keys of `ds_list` ("CUB", "SOP", "Cars", "Inshop").
ds = 'CUB'
# Images per class in a batch; batches are reshaped to (bs // num_samples, num_samples).
num_samples = 2
# Batch size of the training DataLoader.
bs = 200

# Learning rate passed to AdamW.
lr = 3e-5

# Temperature, passed as `tau` to `contrastive_loss`.
t = 0.2
# Embedding dimension (output features of the linear head).
emb = 128
# Number of leading `blocks` of the backbone frozen by `freezer`.
freeze = 0
# Number of training epochs.
ep = 50
# Curvature `c` used by the projection layer and the distance functions.
hyp_c = 0.2
# Evaluation-epoch spec in the form 'r(start,stop,step)'; it is expanded with
# `eval` in the main cell and currently only referenced by commented-out code.
eval_ep = 'r('+str(ep-100)+','+str(ep+10)+',10)'

# Backbone name. NOTE: the main cell later rebinds `model` to the network
# object, so that cell cannot be re-run without re-running this one first.
model = 'vit_small_patch16_224'
# model = 'dino_vits16'
#model = 'deit_small_distilled_patch16_224'

# When True, train/eval embeddings are dumped to `path` after training.
save_emb = False
# File-name prefix for the dumped embeddings.
emb_name = 'emb'
# Feature-norm clipping radius applied before the exponential map.
clip_r = 2.3
# Resize / centre-crop sizes for evaluation; `crop` is also the train crop size.
resize = 224
crop = 224
# Rank passed to the sampler; rank 0 also aggregates the per-epoch stats.
local_rank = 0
# Destination of the best checkpoint (model name, dataset, emb dim and run index).
save_path = "/data/xuyunhao/Mixed curvature/result/{}_{}_best_d_{}_{}_checkout.pth".format(model,ds,emb,num)

# 投影超球模型

In [4]:
def project(x, *, c=1.0):
    """Clip `x` so that every row norm stays below the bound enforced by `_project`.

    `c` (keyword-only) may be a Python number or a tensor; it is converted to a
    tensor with the dtype/device of `x` before delegating to `_project`.
    """
    return _project(x, torch.as_tensor(c).type_as(x))

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

def dexp0(u, *, c=1.0):
    """Apply `_dexp0` to `u` after casting curvature `c` to a tensor like `u`."""
    return _dexp0(u, torch.as_tensor(c).type_as(u))

def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)+1e-5
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    """Pairwise distances between rows of `x` and `y` for curvature `c`.

    `c` may be a Python number or a tensor; it is cast to the dtype/device of
    `x` and the computation is delegated to `_dist_matrix_d`.
    """
    return _dist_matrix_d(x, y, torch.as_tensor(c).type_as(x))

class ToProjection_hypersphere(nn.Module):
    """Final head layer: optional norm clipping, then `dexp0` followed by `project`.

    Args:
        c: curvature forwarded to `dexp0` and `project`.
        clip_r: if not None, input rows are rescaled so their norm does not
            exceed this radius before the exponential map is applied.
    """

    def __init__(self, c, clip_r=None):
        super().__init__()
        # Registered as an empty (None) parameter slot; never assigned in this file.
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r

    def forward(self, x):
        """Return the projected embedding of `x` (rows along the last dim)."""
        if self.clip_r is not None:
            # +1e-5 keeps the division below finite for all-zero rows.
            lengths = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            shrink = torch.minimum(torch.ones_like(lengths), self.clip_r / lengths)
            x = x * shrink
        mapped = dexp0(x, c=self.c)
        return project(mapped, c=self.c)

# 损失函数

In [5]:
def contrastive_loss(x0, x1, tau, hyp_c):
    """Cross-entropy contrastive loss over negated pairwise distances.

    Row i of `x0` and row i of `x1` form a positive pair; every other row of
    `x0` and `x1` acts as a negative for it.

    Args:
        x0, x1: embedding batches of identical shape (batch, dim).
        tau: temperature dividing the logits.
        hyp_c: curvature passed to `dist_matrix_d`. It must be > 0: the distance
            divides by sqrt(c), so there is no special "c = 0" mode here.

    Returns:
        (loss, stats): the scalar loss tensor and a dict of Python floats with
        min/mean/max of the x0-vs-x1 logits and the top-1 matching accuracy.
    """
    bsize = x0.shape[0]
    # Build helper tensors on the inputs' device instead of forcing CUDA.
    device = x0.device
    target = torch.arange(bsize, device=device)
    # Large penalty on the diagonal removes each sample's self-similarity.
    eye_mask = torch.eye(bsize, device=device) * 1e9
    logits00 = -dist_matrix_d(x0, x0, c=hyp_c) / tau - eye_mask
    logits01 = -dist_matrix_d(x0, x1, c=hyp_c) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    # Subtract the row max (treated as a constant) for numerical stability.
    logits = logits - logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型

In [6]:
def init_model(model = model, hyp_c = 0.1, emb = 128, clip_r = 2.3, freeze = 0):
    """Build the backbone + embedding head, move it to CUDA and set train mode.

    Args:
        model: backbone name; names starting with "dino" are loaded from the
            facebookresearch/dino hub, everything else through timm (pretrained).
        hyp_c: curvature; > 0 selects `ToProjection_hypersphere` as the last
            head layer, otherwise `NormLayer`.
        emb: embedding dimension.
        clip_r: clipping radius forwarded to the projection layer.
        freeze: number of leading backbone blocks to freeze (None skips freezing).

    Returns:
        A `HeadSwitch` module on the GPU in training mode.
    """
    if model.startswith("dino"):
        backbone = torch.hub.load("facebookresearch/dino:main", model)
    else:
        backbone = timm.create_model(model, pretrained=True)

    final_layer = (
        ToProjection_hypersphere(c=hyp_c, clip_r=clip_r) if hyp_c > 0 else NormLayer()
    )

    # Backbone feature width: 2048 for resnet50, 384 assumed for everything else.
    feature_dim = 384 if model != "resnet50" else 2048
    projection = nn.Linear(feature_dim, emb)
    nn.init.constant_(projection.bias.data, 0)
    nn.init.orthogonal_(projection.weight.data)
    head = nn.Sequential(projection, nn.BatchNorm1d(emb), final_layer)

    rm_head(backbone)
    if freeze is not None:
        freezer(backbone, freeze)

    net = HeadSwitch(backbone, head)
    net.cuda().train()
    return net
    
    
class HeadSwitch(nn.Module):
    """Backbone wrapper that outputs either head embeddings or normalised features."""

    def __init__(self, body, head):
        super().__init__()
        self.body = body
        self.head = head
        self.norm = NormLayer()

    def forward(self, x, skip_head=False):
        """Run the backbone, then `head` (default) or `norm` when `skip_head` is True.

        If the backbone returns a plain tuple, only its first element is used.
        """
        features = self.body(x)
        if type(features) is tuple:
            features = features[0]
        post = self.norm if skip_head else self.head
        return post(features)


class NormLayer(nn.Module):
    """Scale every row of the input to unit L2 norm (along dim 1)."""

    def forward(self, x):
        unit_rows = F.normalize(x, p=2, dim=1)
        return unit_rows


def freezer(model, num_block):
    """Disable gradients for `patch_embed`, `pos_drop` and the first `num_block` blocks.

    `model` must expose `patch_embed`, `pos_drop` and an indexable `blocks`.
    """
    def modules_to_freeze():
        # Lazily yielded so modules are frozen in the order they are looked up.
        yield model.patch_embed
        yield model.pos_drop
        for idx in range(num_block):
            yield model.blocks[idx]

    for module in modules_to_freeze():
        for weight in module.parameters():
            weight.requires_grad = False


def rm_head(m):
    """Replace any direct child named "head", "fc" or "head_dist" with `nn.Identity`."""
    child_names = {name for name, _ in m.named_children()}
    for name in child_names.intersection({"head", "fc", "head_dist"}):
        m.add_module(name, nn.Identity())

In [7]:
class MultiSample:
    """Callable that applies `transform` to the same input `n` times."""

    def __init__(self, transform, n=2):
        self.transform = transform
        self.num = n

    def __call__(self, x):
        """Return a tuple holding `self.num` independent results of `transform(x)`."""
        views = [self.transform(x) for _ in range(self.num)]
        return tuple(views)


def evaluate(get_emb_f, ds_name, hyp_c):
    """Compute recall for `ds_name` with the dataset-specific protocol.

    "CUB"/"Cars" use `get_recall` and "SOP" uses `get_recall_sop`, both on the
    "eval" split; any other name is evaluated query-vs-gallery with
    `get_recall_inshop`. `get_emb_f(ds_type=...)` must return `(x, y)`.
    """
    if ds_name in ("CUB", "Cars"):
        embeddings = get_emb_f(ds_type="eval")
        return get_recall(*embeddings, ds_name, hyp_c)
    if ds_name == "SOP":
        embeddings = get_emb_f(ds_type="eval")
        return get_recall_sop(*embeddings, ds_name, hyp_c)
    query = get_emb_f(ds_type="query")
    gallery = get_emb_f(ds_type="gallery")
    return get_recall_inshop(*query, *gallery, hyp_c)


def get_recall(x, y, ds_name, hyp_c):
    """Recall@{1,2,4,8,16,32} of embeddings `x` with labels `y` against themselves.

    Args:
        x: (N, D) embeddings; N must be at least 33 for the top-k lookup.
        y: (N,) labels, on the same device as `x`.
        ds_name: unused; kept for interface compatibility.
        hyp_c: curvature; > 0 uses `dist_matrix_d`, otherwise Euclidean distance.

    Returns:
        List with one recall value per k (as returned by `calc_recall_at_k`);
        the list is also printed.
    """
    k_list = [1, 2, 4, 8, 16, 32]

    # Allocate on the embeddings' device rather than a hard-coded "cuda".
    dist_m = torch.empty(len(x), len(x), device=x.device)
    # Filled one row at a time to bound the size of intermediate tensors.
    for i in range(len(x)):
        row = x[i : i + 1]
        if hyp_c > 0:
            dist_m[i : i + 1] = -dist_matrix_d(row, x, hyp_c)
        else:
            dist_m[i : i + 1] = -torch.cdist(row, x, p=2)

    # Column 0 of the top-k is dropped: it is taken to be the query itself.
    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_sop(x, y, ds_name, hyp_c):
    """Recall@{1,10,100,1000} of `x` against itself, processed in chunks of 1000 queries.

    Args:
        x: (N, D) embeddings; N must be at least 1001 for the top-k lookup.
        y: (N,) labels, on the same device as `x`.
        ds_name: unused; kept for interface compatibility.
        hyp_c: curvature; > 0 uses `dist_matrix_d`, otherwise Euclidean distance.

    Returns:
        List with one recall value per k (as returned by `calc_recall_at_k`);
        the list is also printed.
    """
    number = 1000  # queries per chunk
    k_list = [1, 10, 100, 1000]

    neighbour_labels = []
    # Stepping by `number` never produces an empty trailing chunk.
    for start in range(0, len(x), number):
        x_s = x[start : start + number]
        # Allocate on the embeddings' device rather than a hard-coded "cuda".
        dist = torch.empty(len(x_s), len(x), device=x.device)
        for row in range(len(x_s)):
            query = x_s[row : row + 1]
            if hyp_c > 0:
                dist[row : row + 1] = -dist_matrix_d(query, x, hyp_c)
            else:
                dist[row : row + 1] = -torch.cdist(query, x, p=2)
        # Column 0 of the top-k is dropped: it is taken to be the query itself.
        top_idx = dist.topk(1 + max(k_list), largest=True)[1][:, 1:]
        neighbour_labels.append(y[top_idx])

    y_cur = torch.cat(neighbour_labels)
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_inshop(xq, yq, xg, yg, hyp_c):
    """Query-vs-gallery Recall@{1,10,20,30,40,50}.

    Args:
        xq, yq: query embeddings (Q, D) and labels (Q,).
        xg, yg: gallery embeddings (G, D) and labels (G,).
        hyp_c: curvature; > 0 uses `dist_matrix_d`, otherwise Euclidean distance.

    Returns:
        List with one recall value per k; the list is also printed.
    """
    # Allocate on the embeddings' device rather than a hard-coded "cuda".
    dist_m = torch.empty(len(xq), len(xg), device=xq.device)
    for i in range(len(xq)):
        query = xq[i : i + 1]
        if hyp_c > 0:
            dist_m[i : i + 1] = -dist_matrix_d(query, xg, hyp_c)
        else:
            # Fixed: the third positional argument of cdist is the norm order
            # `p`; passing `hyp_c` (<= 0 here) did not give Euclidean distance.
            dist_m[i : i + 1] = -torch.cdist(query, xg, p=2)

    def recall_k(cos_sim, query_T, gallery_T, k):
        """Fraction of queries whose best positive is beaten by fewer than k negatives."""
        m = len(cos_sim)
        match_counter = 0
        for i in range(m):
            pos_sim = cos_sim[i][gallery_T == query_T[i]]
            neg_sim = cos_sim[i][gallery_T != query_T[i]]
            thresh = torch.max(pos_sim).item()
            if torch.sum(neg_sim > thresh) < k:
                match_counter += 1
        return match_counter / m

    recall = [recall_k(dist_m, yq, yg, k) for k in [1, 10, 20, 30, 40, 50]]
    print(recall)
    return recall


def get_emb(
    model,
    ds,
    path,
    mean_std,
    resize=224,
    crop=224,
    ds_type="eval",
    world_size=1,
    skip_head=False,
):
    """Embed one dataset split with `model` and return `(embeddings, labels)`.

    Args:
        model: network called as `model(x, skip_head=...)`.
        ds: dataset class, instantiated as `ds(path, ds_type, transform)`.
        path: dataset root.
        mean_std: `(mean, std)` pair unpacked into `T.Normalize`.
        resize, crop: sizes for the resize / centre-crop evaluation transform.
        ds_type: split name forwarded to the dataset class.
        world_size: number of distributed processes; > 1 shards the data with a
            `DistributedSampler` and all-gathers the results.
        skip_head: forwarded to the model.

    Returns:
        `(x, y)` with labels moved to the GPU. The model is switched to eval
        mode while embedding and is left in train mode afterwards.
    """
    eval_tr = T.Compose(
        [
            # InterpolationMode enum instead of the PIL int constant: same
            # bicubic resize, without torchvision's deprecation warning.
            T.Resize(resize, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(crop),
            T.ToTensor(),
            T.Normalize(*mean_std),
        ]
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=multiprocessing.cpu_count() // world_size,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    x, y = eval_dataset(model, dl_eval, skip_head)
    y = y.cuda()
    if world_size > 1:
        # all_gather needs pre-allocated buffers of the local tensors' shape.
        all_x = [torch.zeros_like(x) for _ in range(world_size)]
        all_y = [torch.zeros_like(y) for _ in range(world_size)]
        torch.distributed.all_gather(all_x, x)
        torch.distributed.all_gather(all_y, y)
        x, y = torch.cat(all_x), torch.cat(all_y)
    model.train()
    return x, y


def eval_dataset(model, dl, skip_head):
    """Run `model` over every batch of `dl` without gradients.

    Returns `(embeddings, labels)`, each concatenated along dim 0 in loader
    order. Inputs are moved to the GPU; labels are left as the loader yields them.
    """
    embeddings, labels = [], []
    for batch, batch_labels in dl:
        labels.append(batch_labels)
        with torch.no_grad():
            outputs = model(batch.cuda(non_blocking=True), skip_head=skip_head)
        embeddings.append(outputs)
    return torch.cat(embeddings), torch.cat(labels)


# 主函数

In [8]:
# Restrict the visible GPUs. This only takes effect if CUDA has not been
# initialised yet in this kernel.
os.environ["CUDA_VISIBLE_DEVICES"] =  "3,4"
# if local_rank == 0:
#     wandb.init(project="hyp_metric")

world_size = int(os.environ.get("WORLD_SIZE", 1))

# Normalisation statistics: 0.5/0.5 for "vit*" backbones, ImageNet stats otherwise.
if model.startswith("vit"):
    mean_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
else:
    mean_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
train_tr = T.Compose(
    [
        T.RandomResizedCrop(
            crop, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC
        ),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(*mean_std),
    ]
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)
assert len(ds_train.ys) * num_samples >= bs * world_size
sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

# NOTE: `model` is rebound from the backbone name (str) to the network here,
# so this cell needs the hyperparameter cell re-run before it can be re-run.
model = init_model(model = model, hyp_c = hyp_c, emb = emb, clip_r = clip_r, freeze = freeze)
optimizer = optim.AdamW(model.parameters(), lr=lr)
model = nn.DataParallel(model)

loss_f = partial(contrastive_loss, tau=t, hyp_c=hyp_c)
get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    mean_std=mean_std,
    world_size=world_size,
    resize=resize,
    crop=crop,
)
# Expands the 'r(a,b,c)' spec into list(range(a,b,c)). The string is built in
# this notebook; never feed externally supplied text to eval().
eval_ep = eval(eval_ep.replace("r", "list(range").replace(")", "))"))    

cudnn.benchmark = True
all_rh = []
best_rh = []
best_ep = 0
lower_cnt = 0

# A dedicated loop variable keeps the `ep` hyperparameter (epoch count) intact.
for epoch in trange(ep):
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        # Sanity checks: each group of `num_samples` consecutive items shares a
        # class, and no class appears twice in a batch.
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        batch_classes = y[:, 0].tolist()
        assert len(set(batch_classes)) == len(batch_classes)

        x = x.cuda(non_blocking=True)
        z = model(x)
        z = z.view(len(x) // num_samples, num_samples, emb)
        if world_size > 1:
            with torch.no_grad():
                all_z = [torch.zeros_like(z) for _ in range(world_size)]
                torch.distributed.all_gather(all_z, z)
            # Re-insert the local tensor so gradients flow through this rank's part.
            all_z[local_rank] = z
            z = torch.cat(all_z)
        loss = 0
        # Sum the loss over every ordered pair of distinct views.
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    pair_loss, pair_stats = loss_f(z[:, i], z[:, j])
                    loss += pair_loss
                    stats_ep.append({**pair_stats, "loss": pair_loss.item()})

        optimizer.zero_grad()
        loss.backward()
        # AMP is never initialised in this notebook, so clip the model's own
        # parameters directly (the same tensors the optimizer holds).
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
        optimizer.step()

#     if (epoch + 1) in eval_ep:
#         rh= evaluate(get_emb_f, ds, hyp_c)

    # Evaluate after the first epoch and then every 10th epoch.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        rh = evaluate(get_emb_f, ds, hyp_c)
        all_rh.append(rh)
        # Compare on the first recall value when a list is returned.
        score = rh[0] if isinstance(rh, list) else rh
        if epoch == 0:
            improved = True
        else:
            best_score = best_rh[0] if isinstance(best_rh, list) else best_rh
            improved = score >= best_score
        if improved:
            # The first evaluation is checkpointed too, so `best_rh`/`best_ep`
            # always correspond to a file that exists on disk.
            lower_cnt = 0
            best_rh = rh
            best_ep = epoch
            print("save model........")
            torch.save(model.state_dict(), save_path)
        else:
            lower_cnt += 1

    # Early stop after 10 evaluations without improvement.
    if lower_cnt >= 10:
        break

    if local_rank == 0 and stats_ep:
        stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
#         if (epoch + 1) in eval_ep:
#             stats_ep = {"recall": rh, **stats_ep}
#         wandb.log({**stats_ep, "ep": epoch})

print("best:", best_ep+1, best_rh)

if save_emb:
    ds_type = "gallery" if ds == "Inshop" else "eval"
    x, y = get_emb_f(ds_type=ds_type)
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), path + "/" + emb_name + "_eval.pt")

    x, y = get_emb_f(ds_type="train")
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), path + "/" + emb_name + "_train.pt")

  "Argument interpolation should be of type InterpolationMode instead of int. "
  2%|█▏                                                            | 1/50 [00:44<36:40, 44.91s/it]

[0.7923700202565834, 0.8762660364618501, 0.9272451046590142, 0.9559419311276165, 0.9728224172856178, 0.9846387575962188]


 18%|███████████▏                                                  | 9/50 [01:07<02:24,  3.52s/it]

[0.8151586765698852, 0.887407157326131, 0.9304523970290345, 0.9581363943281567, 0.974510465901418, 0.9832883187035787]
save model........


 38%|███████████████████████▏                                     | 19/50 [01:49<01:31,  2.94s/it]

[0.8230925050641459, 0.8894328156650911, 0.9301147873058744, 0.9571235651586766, 0.9738352464550979, 0.9822754895340986]
save model........


 60%|████████████████████████████████████▌                        | 30/50 [02:48<02:30,  7.51s/it]

[0.8214044564483457, 0.8863943281566509, 0.9301147873058744, 0.9574611748818366, 0.9734976367319379, 0.9821066846725186]


 80%|████████████████████████████████████████████████▊            | 40/50 [03:28<01:08,  6.85s/it]

[0.8178595543551654, 0.8838622552329507, 0.9285955435516543, 0.9589804186360568, 0.9723160027008778, 0.9822754895340986]


100%|█████████████████████████████████████████████████████████████| 50/50 [04:08<00:00,  4.98s/it]

[0.8176907494935854, 0.8863943281566509, 0.9331532748143146, 0.9584740040513167, 0.9721471978392978, 0.9816002700877785]
best: 20 [0.8230925050641459, 0.8894328156650911, 0.9301147873058744, 0.9571235651586766, 0.9738352464550979, 0.9822754895340986]



