In [1]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import argparse
import torch.nn.functional as F
from torch.utils.data import DataLoader
from functools import partial

import net
from hyptorch.pmath import dist_matrix
from hyptorch import pmath
from proxy_anchor import dataset
from proxy_anchor.utils import calc_recall_at_k
from sampler import UniqueClassSempler
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset

In [2]:
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [3]:
parser = argparse.ArgumentParser(description='PyTorch Training')

parser.add_argument('--LOG_DIR', 
    default='/data/xuyunhao/Mixed curvature',
    help = 'Path to log folder'
)
parser.add_argument('--path', default='/data/xuyunhao/datasets', type=str,
                    help='path to datasets')
parser.add_argument('--ds', default='CUB', type=str,
                    help='')
parser.add_argument('--num_samples', default=2, type=int,
                    help='how many samples per each category in batch')
parser.add_argument('--bs', default=200, type=int,
                    help='batch size per GPU, e.g. --num_samples 3 --bs 900 means each iteration we sample 300 categories with 3 samples')
parser.add_argument('--lr', default=1e-5, type=float,
                    help='learning rate')
parser.add_argument('--t', default=0.2, type=float,
                    help='cross-entropy temperature')
parser.add_argument('--emb', default=128, type=int,
                    help='output embedding size')
parser.add_argument('--freeze', default=0, type=int,
                    help='number of blocks in transformer to freeze, None - freeze nothing, 0 - freeze only patch_embed')
parser.add_argument('--ep', default=100, type=int,
                    help='number of epochs')
parser.add_argument('--hyp_c', default=0.1, type=float,
                    help='hyperbolic c, "0" enables sphere mode')
parser.add_argument('--model', default='resnet18', type=str,
                    help='model name from timm or torch.hub, i.e. deit_small_distilled_patch16_224, vit_small_patch16_224, dino_vits16')
parser.add_argument('--save_emb', default=False, type=bool,
                    help='save embeddings of the dataset after training')
parser.add_argument('--emb_name', default='emb', type=str,
                    help='filename for embeddings')
parser.add_argument('--clip_r', default=2.3, type=float,
                    help='')
parser.add_argument('--resize', default=224, type=int,
                    help='image resize')
parser.add_argument('--crop', default=224, type=int,
                    help='center crop after resize')
parser.add_argument('--local_rank', default=0, type=int,
                    help='set automatically for distributed training')
parser.add_argument('--workers', default = 4, type = int,
    dest = 'nb_workers',
    help = 'Number of workers for dataloader.'
)
parser.add_argument('--optimizer', default = 'adamw',
    help = 'Optimizer setting'
)
parser.add_argument('--gpu-id', default = 4, type = int,
    help = 'ID of GPU that is used for training.'
)
parser.add_argument('--bn-freeze', default = 1, type = int,
    help = 'Batch normalization parameter freeze'
)
parser.add_argument('--l2-norm', default = 1, type = int,
    help = 'L2 normlization'
)
parser.add_argument('--lr-decay-step', default = 10, type =int,
    help = 'Learning decay step setting'
)
parser.add_argument('--lr-decay-gamma', default = 0.5, type =float,
    help = 'Learning decay gamma setting'
)
parser.add_argument('--warm', default = 1, type = int,
    help = 'Warmup training epochs'
)

_StoreAction(option_strings=['--warm'], dest='warm', nargs=None, const=None, default=1, type=<class 'int'>, choices=None, help='Warmup training epochs', metavar=None)

# 投影空间

## 双曲面

In [4]:
def hexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _hexpmap0(u, c)

def _hexpmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamm_1 = torch.cosh(sqrt_c*u_norm)/sqrt_c
    gamm_2 = torch.sinh(sqrt_c*u_norm)*u/(sqrt_c * u_norm)
    gamma = torch.cat([gamm_1, gamm_2],dim = 1)
    return gamma
    
def _dist_matrix_h(x, y, c):
    sqrt_c = c ** 0.5
    b = torch.ones_like(x).cuda()
    b[:, 0] = b[:, 0] * (-1)
    x2 = x * b
    xy_l = torch.einsum("ij,kj->ik", (x2, y))
    return (1/sqrt_c * pmath.arcosh(c*xy_l))


def dist_matrix_h(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_h(x, y, c)

class ToHyperbolic(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToHyperbolic, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return pmath.project(hexp0(x, c=self.c), c=self.c)

## 超球面

In [5]:
def sexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _sexpmap0(u, c)

def _sexpmap0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamm_1 = torch.cos(sqrt_c*u_norm)/sqrt_c
    gamm_2 = torch.sin(sqrt_c*u_norm)*u/(sqrt_c * u_norm)
    gamma = torch.cat([gamm_1, gamm_2],dim = 1)
    return gamma

def _dist_matrix_s(x, y, c):
    sqrt_c = c ** 0.5
    xy_l = torch.einsum("ij,kj->ik", (x, y))
    return (1/sqrt_c * torch.acos(c*xy_l))


def dist_matrix_s(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_s(x, y, c)

class ToHypersphere(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToHypersphere, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return pmath.project(sexp0(x, c=self.c), c=self.c)

## 投影超球

In [6]:
def dexp0(u, *, c=1.0):
    c = torch.as_tensor(c).type_as(u)
    return _dexp0(u, c)


def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    c = torch.as_tensor(c).type_as(x)
    return _dist_matrix_d(x, y, c)

class ToProjection_hypersphere(nn.Module):
    def __init__(self, c, clip_r=None):
        super(ToProjection_hypersphere, self).__init__()
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r
        
    def forward(self, x):
        if self.clip_r is not None:
            x_norm = torch.norm(x, dim=-1, keepdim=True) + 1e-5
            fac =  torch.minimum(
                torch.ones_like(x_norm), 
                self.clip_r / x_norm
            )
            x = x * fac
        return pmath.project(dexp0(x, c=self.c), c=self.c)

## 计算权重

In [7]:
class weight_conculate(nn.Module):
    def __init__(self, sigma, topk):
        super(weight_conculate, self).__init__()
        self.sigma = sigma
        self.topk = topk

    def forward(self, dist_matrix):

        N = len(dist_matrix)        
        
        with torch.no_grad():
            W_P = torch.exp(-dist_matrix.pow(2) / self.sigma)

            topk_index = torch.topk(W_P, self.topk)[1]
            topk_half_index = topk_index[:, :int(np.around(self.topk/2))]

            W_NN = torch.zeros_like(W_P).scatter_(1, topk_index, torch.ones_like(W_P))
            V = ((W_NN + W_NN.t())/2 == 1).float()

            W_C_tilda = torch.zeros_like(W_P)
            for i in range(N):
                indNonzero = torch.where(V[i, :]!=0)[0]
                W_C_tilda[i, indNonzero] = (V[:,indNonzero].sum(1) / len(indNonzero))[indNonzero]
                
            W_C_hat = W_C_tilda[topk_half_index].mean(1)
            W_C = (W_C_hat + W_C_hat.t())/2
            W = (W_P + W_C)/2

        return W

# 损失函数

In [8]:
def contrastive_loss(e0, e1, p0, p1, h0, h1, d0, d1, s0, s1, tau, hyp_c, weightconculate):
    # x0 and x1 - positive pair
    # tau - temperature
    # hyp_c - hyperbolic curvature, "0" enables sphere mode
    dist_e = lambda x, y: x @ y.t()
    dist_p = lambda x, y: -dist_matrix(x, y, c=hyp_c)
    dist_h = lambda x, y: -dist_matrix_h(x, y, c=hyp_c)
    dist_d = lambda x, y: -dist_matrix_d(x, y, c=hyp_c)
    dist_s = lambda x, y: -dist_matrix_s(x, y, c=hyp_c)
    
    dist_e0 = dist_e(e0, e0)
    dist_p0 = dist_p(p0, p0)
    dist_h0 = dist_h(h0, h0)
    dist_d0 = dist_d(d0, d0)
    dist_s0 = dist_s(s0, s0)
    w_e0 = weightconculate(-dist_e0).cuda()
    w_p0 = weightconculate(-dist_p0).cuda()
    w_h0 = weightconculate(-dist_h0).cuda()
    w_d0 = weightconculate(-dist_d0).cuda()
    w_s0 = weightconculate(-dist_s0).cuda()
    
    dist_e1 = dist_e(e0, e1)
    dist_p1 = dist_p(p0, p1)
    dist_h1 = dist_h(h0, h1)
    dist_d1 = dist_d(d0, d1)
    dist_s1 = dist_s(s0, s1)
    w_e1 = weightconculate(-dist_e1).cuda()
    w_p1 = weightconculate(-dist_p1).cuda()
    w_h1 = weightconculate(-dist_h1).cuda()
    w_d1 = weightconculate(-dist_d1).cuda()
    w_s1 = weightconculate(-dist_s1).cuda()
    
    bsize = e0.shape[0]
    target = torch.arange(bsize).cuda()
    eye_mask = torch.eye(bsize).cuda() * 1e9
    logits00 = (w_e0*dist_e0+w_p0*dist_p0+w_h0*dist_h0+w_d0*dist_d0+w_s0*dist_s0) / tau - eye_mask
    logits01 = (w_e1*dist_e1+w_p1*dist_p1+w_h1*dist_h1+w_d1*dist_d1+w_s1*dist_s1) / tau
    logits = torch.cat([logits01, logits00], dim=1)
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型

In [9]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
import torch.nn.init as init
import hyptorch.nn as hypnn
from torchvision.models import resnet18

class Resnet18(nn.Module):
    def __init__(self,embedding_size, pretrained=True, bn_freeze = True, hyp_c = 0, clip_r = 0):
        super(Resnet18, self).__init__()

        self.model = resnet18(pretrained)
        self.embedding_size = embedding_size
        self.num_ftrs = self.model.fc.in_features
        self.model.gap = nn.AdaptiveAvgPool2d(1)
        self.model.gmp = nn.AdaptiveMaxPool2d(1)
        
        self.Elayer = NormLayer()
        self.Player = hypnn.ToPoincare(
            c=hyp_c,
            ball_dim=embedding_size,
            riemannian=False,
            clip_r=clip_r,
        )
        self.Hlayer = ToHyperbolic(c=hyp_c, clip_r=clip_r)
        self.Dlayer = ToProjection_hypersphere(c=hyp_c, clip_r=clip_r)
        self.Slayer = ToHypersphere(c=hyp_c, clip_r=clip_r)

        self.model.embedding = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Elayer)
        self.model.embeddingP = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Player)
        self.model.embeddingH = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Hlayer)
        self.model.embeddingD = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Dlayer)
        self.model.embeddingS = nn.Sequential(nn.Linear(self.num_ftrs, self.embedding_size), self.Slayer)
        
        self._initialize_weights()

        if bn_freeze:
            for m in self.model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
                    m.weight.requires_grad_(False)
                    m.bias.requires_grad_(False)


    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)

        avg_x = self.model.gap(x)
        max_x = self.model.gmp(x)

        x = max_x + avg_x
        
        x = x.view(x.size(0), -1)
        x_e = self.model.embedding(x)
        x_p = self.model.embeddingP(x)
        x_h = self.model.embeddingH(x)
        x_d = self.model.embeddingD(x)
        x_s = self.model.embeddingS(x)
        return x_e, x_p, x_h, x_d, x_s

    def _initialize_weights(self):
        init.kaiming_normal_(self.model.embedding[0].weight, mode='fan_out')
        init.constant_(self.model.embedding[0].bias, 0)
        init.kaiming_normal_(self.model.embeddingP[0].weight, mode='fan_out')
        init.constant_(self.model.embeddingP[0].bias, 0)
        init.kaiming_normal_(self.model.embeddingH[0].weight, mode='fan_out')
        init.constant_(self.model.embeddingH[0].bias, 0)
        init.kaiming_normal_(self.model.embeddingD[0].weight, mode='fan_out')
        init.constant_(self.model.embeddingD[0].bias, 0)
        init.kaiming_normal_(self.model.embeddingS[0].weight, mode='fan_out')
        init.constant_(self.model.embeddingS[0].bias, 0)

class NormLayer(nn.Module):
    def forward(self, x):
        return F.normalize(x, p=2, dim=1)

# 验证集

In [10]:
def evaluate(get_emb_f, ds_name, hyp_c):
    if ds_name != "Inshop":
        emb_head = get_emb_f(ds_type="eval")
        recall_head = get_recall(*emb_head, ds_name, hyp_c)
    else:
        emb_head_query = get_emb_f(ds_type="query")
        emb_head_gal = get_emb_f(ds_type="gallery")
        emb_body_query = get_emb_f(ds_type="query", skip_head=True)
        emb_body_gal = get_emb_f(ds_type="gallery", skip_head=True)
        recall_head = get_recall_inshop(*emb_head_query, *emb_head_gal, hyp_c)
        recall_body = get_recall_inshop(*emb_body_query, *emb_body_gal, 0)
    return recall_head

def get_recall(e, p, h, d, s, y, ds_name, hyp_c):
    if ds_name == "CUB" or ds_name == "Cars":
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]

    dist_m = torch.empty(len(e), len(e), device="cuda")
    for i in range(len(x)):
        dist_m[i : i + 1] = -dist_matrix(p[i : i + 1], p, hyp_c)-dist_matrix_h(h[i : i + 1], h, hyp_c)-dist_matrix_d(d[i : i + 1], d, hyp_c)-dist_matrix_s(s[i : i + 1], s, hyp_c)
    dist_m = dist_m + e @ e.t()

    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]

def get_emb(
    model,
    ds,
    path,
    ds_type="eval",
    world_size=1,
    num_workers=8,
):
    eval_tr = dataset.utils.make_transform(
        is_train = True, 
        is_inception = (args.model == 'bn_inception')
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    e, p, h, d, s, y = eval_dataset(model, dl_eval)
    y = y.cuda()
    model.train()
    return e, p, h, d, s, y

def eval_dataset(model, dl):
    all_xe, all_xp, all_xh, all_xd, all_xs, all_y = [], [], [], [], [], []
    for x, y in dl:
        with torch.no_grad():
            x = x.cuda(non_blocking=True)
            e, p, h, d, s = model(x)
            all_xe.append(e)
            all_xp.append(p)
            all_xh.append(h)
            all_xs.append(d)
            all_xd.append(s)
        all_y.append(y)
    return torch.cat(all_xe), torch.cat(all_xp), torch.cat(all_xh), torch.cat(all_xd), torch.cat(all_xs), torch.cat(all_y)



# 主函数

In [11]:
args,_ = parser.parse_known_args()
if args.gpu_id != -1:
    torch.cuda.set_device(args.gpu_id)
LOG_DIR = args.LOG_DIR + '/logs_{}/{}_embedding{}_{}_lr{}_batch{}'.format(args.ds, args.model, args.emb, args.optimizer, args.lr, args.bs)
world_size = int(os.environ.get("WORLD_SIZE", 1))

train_tr = dataset.utils.make_transform(
    is_train = True, 
    is_inception = (args.model == 'bn_inception')
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[args.ds]
ds_train = ds_class(args.path, "train", train_tr)

sampler = UniqueClassSempler(
    ds_train.ys, args.num_samples, args.local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=args.bs,
    num_workers=args.nb_workers,
    pin_memory=True,
)

if args.model.find('resnet18')+1:
    model = Resnet18(embedding_size=args.emb, pretrained=True, bn_freeze = args.bn_freeze, hyp_c =args.hyp_c, clip_r = args.clip_r).cuda().train()
if args.gpu_id != -1:
    unfreeze_model_param = list(model.model.embedding.parameters())+list(model.model.embeddingP.parameters())+list(model.model.embeddingH.parameters())+list(model.model.embeddingD.parameters())+list(model.model.embeddingS.parameters())
    for param in list(set(model.parameters()).difference(set(unfreeze_model_param))):
        param.requires_grad = False 
else:
    unfreeze_model_param = list(model.module.model.embedding.parameters())+list(model.module.model.embeddingP.parameters())+list(model.module.model.embeddingH.parameters())+list(model.module.model.embeddingD.parameters())+list(model.module.model.embeddingS.parameters())
    for param in list(set(model.module.parameters()).difference(set(unfreeze_model_param))):
        param.requires_grad = False    
    
if args.gpu_id == -1:
    model = nn.DataParallel(model).cuda().train()
wt = weight_conculate(sigma = 1, topk = 9)
loss_f = partial(contrastive_loss, tau=args.t, hyp_c=args.hyp_c, weightconculate= wt)

get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=args.path,
    num_workers=args.nb_workers,
    world_size=world_size,
)

# Train Parameters
param_groups = [
    {'params': list(set(model.parameters()).difference(set(unfreeze_model_param))) if args.gpu_id != -1 else 
                 list(set(model.module.parameters()).difference(set(unfreeze_model_param)))},
    {'params': unfreeze_model_param, 'lr':float(args.lr) * 1},
]

optimizer = optim.AdamW(param_groups, lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma = args.lr_decay_gamma)
print("Training parameters: {}".format(vars(args)))
print("Training for {} epochs.".format(args.ep))
losses_list = []
best_recall= 0
best_epoch = 0

for epoch in range(0, args.ep):
    model.train()
    bn_freeze = args.bn_freeze
    if bn_freeze:
        modules = model.model.modules() if args.gpu_id != -1 else model.module.model.modules()
        for m in modules: 
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    losses_per_epoch = []
    
     #Warmup: Train only new params, helps stabilize learning.
#    if args.warm > 0:
#        if args.gpu_id != -1:
#            unfreeze_model_param = list(model.model.embedding.parameters())
#        else:
#            unfreeze_model_param = list(model.module.model.embedding.parameters())

#        if epoch == 0:
#            for param in list(set(model.parameters()).difference(set(unfreeze_model_param))):
#                param.requires_grad = False
#        if epoch == args.warm:
#            for param in list(set(model.parameters()).difference(set(unfreeze_model_param))):
#                param.requires_grad = True

    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        y = y.view(len(y) // args.num_samples, args.num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        s1 = y[:, 0].tolist()
        assert len(set(s1)) == len(s1)

        x = x.cuda(non_blocking=True)
        e, p, h, d, s = model(x)
        e = e.view(len(x) // args.num_samples, args.num_samples, args.emb)
        p = p.view(len(x) // args.num_samples, args.num_samples, args.emb)
        h = h.view(len(x) // args.num_samples, args.num_samples, args.emb+1)
        d = d.view(len(x) // args.num_samples, args.num_samples, args.emb)
        s = s.view(len(x) // args.num_samples, args.num_samples, args.emb+1)
        loss = 0
        for i in range(args.num_samples):
            for j in range(args.num_samples):
                if i != j:
                    l, st = loss_f(e[:, i], e[:, j], p[:, i], p[:, j], h[:, i], h[:, j],d[:, i], d[:, j],s[:, i], s[:, j])
                    loss += l
                    stats_ep.append({**st, "loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 10)
        optimizer.step()
            
    scheduler.step()
    rh= evaluate(get_emb_f, args.ds, args.hyp_c)
        
    if args.local_rank == 0:
        stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
        stats_ep = {"recall": rh, **stats_ep}
        if rh > best_recall :
            best_recall = rh
            best_epoch = epoch
        print("epoch:",epoch,"recall: ", rh)
        print("best epoch:",best_epoch,"best recall: ", best_recall)

Training parameters: {'LOG_DIR': '/data/xuyunhao/Mixed curvature', 'path': '/data/xuyunhao/datasets', 'ds': 'CUB', 'num_samples': 2, 'bs': 200, 'lr': 1e-05, 't': 0.2, 'emb': 128, 'freeze': 0, 'ep': 100, 'hyp_c': 0.1, 'eval_ep': '[100]', 'model': 'resnet18', 'save_emb': False, 'emb_name': 'emb', 'clip_r': 2.3, 'resize': 224, 'crop': 224, 'local_rank': 0, 'nb_workers': 4, 'optimizer': 'adamw', 'gpu_id': 4, 'bn_freeze': 1, 'l2_norm': 1, 'lr_decay_step': 10, 'lr_decay_gamma': 0.5, 'warm': 1}
Training for 100 epochs.
number:  0
[0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676]
epoch: 0 recall:  0.008440243079000676
best epoch: 0 best recall:  0.008440243079000676
number:  0
[0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676, 0.008440243079000676]
epoch: 1 recall:  0.008440243079000676
best epoch: 0 best recall:  0.008440243079000676
number:  0


KeyboardInterrupt: 