In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from apex import amp

import os
import random
import timm
import wandb
from tqdm import trange
import multiprocessing
from functools import partial
import numpy as np
import PIL
from torchvision.transforms import InterpolationMode

from sampler import UniqueClassSempler
from proxy_anchor.utils import calc_recall_at_k
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset
from hyptorch.pmath import dist_matrix
import hyptorch.nn as hypnn

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from datetime import timedelta

import ipdb

In [2]:
# Seed every random number generator used by the run (Python, NumPy, torch CPU
# and all CUDA devices) with the same value.
seed = 1
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)

In [3]:
# Dataset root directory; passed as the first argument to the dataset classes.
path = '/data/xuyunhao/datasets'
# Dataset key, looked up in `ds_list` in the training cell ("CUB", "SOP", "Cars", "Inshop").
ds = 'Cars'
# Number of images per class inside a batch; batches are viewed as (-1, num_samples).
num_samples = 2
# Batch size handed to the training DataLoader.
bs = 196

# Learning rate for AdamW.
lr = 1e-5

# Temperature (`tau`) of the contrastive loss.
t = 0.2
# Output dimension of the embedding head (nn.Linear(bdim, emb)).
emb = 128
# Number of leading backbone blocks passed to `freezer`.
freeze = 0
# Number of training epochs.
ep = 500
# Curvature `c` used by the projection layer and the distance function.
hyp_c = 0.1
# Compact "r(start,stop,step)" string; the training cell expands it into a list with eval().
eval_ep = 'r('+str(ep-100)+','+str(ep+10)+',10)'

# Backbone name: names starting with "dino" are loaded via torch.hub, all others via timm.
# model = 'vit_small_patch16_224'
model = 'dino_vits16'
#model = 'deit_small_distilled_patch16_224'

# When True, eval/train embeddings are written to `path` after training.
save_emb = False
# File-name prefix for the saved embeddings.
emb_name = 'emb'
# Feature-norm clipping radius applied before the projection layer.
clip_r = 2.3
# Resize and center-crop sizes for evaluation; `crop` is also the training crop size.
resize = 224
crop = 224
# Rank of this process; only rank 0 aggregates the per-epoch statistics.
local_rank = 0

In [4]:
def project(x, *, c=1.0):
    """Cap the norm of ``x`` for curvature ``c`` by delegating to ``_project``.

    ``c`` may be a Python number or a tensor; it is converted to a tensor of
    the same type as ``x`` first.
    """
    curvature = torch.as_tensor(c).type_as(x)
    return _project(x, curvature)

def _project(x, c):
    norm = torch.clamp_min(x.norm(dim=-1, keepdim=True, p=2), 1e-5)
    maxnorm = (1 - 1e-3) / (c ** 0.5)
    cond = norm > maxnorm
    projected = x / norm * maxnorm
    return torch.where(cond, projected, x)

def dexp0(u, *, c=1.0):
    """Apply ``_dexp0`` to ``u`` after converting ``c`` to a tensor like ``u``."""
    curvature = torch.as_tensor(c).type_as(u)
    return _dexp0(u, curvature)

def _dexp0(u, c):
    sqrt_c = c ** 0.5
    u_norm = torch.clamp_min(u.norm(dim=-1, p=2, keepdim=True), 1e-5)
    gamma_1 = torch.tan(sqrt_c * u_norm) * u / (sqrt_c * u_norm)
    return gamma_1

def _dist_matrix_d(x, y, c):
    xy =torch.einsum("ij,kj->ik", (x, y))  # B x C
    x2 = x.pow(2).sum(-1, keepdim=True)  # B x 1
    y2 = y.pow(2).sum(-1, keepdim=True)  # C x 1
    sqrt_c = c ** 0.5
    num1 = 2*c*(x2+y2.permute(1, 0)-2*xy)+1e-5
    num2 = torch.mul((1+c*x2),(1+c*y2.permute(1, 0)))
    return (1/sqrt_c * torch.acos(1-num1/num2))


def dist_matrix_d(x, y, c=1.0):
    """Pairwise distance matrix between rows of ``x`` and ``y`` (see ``_dist_matrix_d``).

    ``c`` may be a Python number or a tensor; it is converted to a tensor of
    the same type as ``x`` before the computation.
    """
    curvature = torch.as_tensor(c).type_as(x)
    return _dist_matrix_d(x, y, curvature)

class ToProjection_hypersphere(nn.Module):
    """Final embedding layer: optional norm clipping, then ``dexp0`` and ``project``.

    Args:
        c: curvature forwarded to ``dexp0`` and ``project``.
        clip_r: if not None, inputs are rescaled so that their L2 norm over
            the last dimension does not exceed this radius.
    """

    def __init__(self, c, clip_r=None):
        super().__init__()
        # Registered as an empty parameter slot; never assigned in this class.
        self.register_parameter("xp", None)
        self.c = c
        self.clip_r = clip_r

    def _clip_norm(self, x):
        """Scale ``x`` down so its last-dim norm is at most ``self.clip_r``."""
        norms = torch.norm(x, dim=-1, keepdim=True) + 1e-5
        scale = torch.minimum(torch.ones_like(norms), self.clip_r / norms)
        return x * scale

    def forward(self, x):
        clipped = x if self.clip_r is None else self._clip_norm(x)
        return project(dexp0(clipped, c=self.c), c=self.c)

In [5]:
def contrastive_loss(x0, x1, tau, hyp_c):
    """Cross-entropy contrastive loss over negative pairwise distances.

    Row ``i`` of ``x0`` and row ``i`` of ``x1`` form a positive pair; every
    other row of ``x1`` and of ``x0`` acts as a negative.

    Args:
        x0, x1: embedding batches with the same number of rows.
        tau: temperature dividing the negated distances.
        hyp_c: curvature passed to ``dist_matrix_d``. It must be positive:
            the distance divides by ``sqrt(c)``, so there is no ``c = 0`` mode.

    Returns:
        ``(loss, stats)`` where ``stats`` holds min/mean/max of the
        cross-view logits and the top-1 matching accuracy as Python floats.
    """
    # Build helper tensors on the device of the inputs instead of hard-coding CUDA.
    device = x0.device
    bsize = x0.shape[0]
    target = torch.arange(bsize, device=device)
    # Large value on the diagonal removes each sample's similarity to itself.
    eye_mask = torch.eye(bsize, device=device) * 1e9
    logits00 = -dist_matrix_d(x0, x0, c=hyp_c) / tau - eye_mask
    logits01 = -dist_matrix_d(x0, x1, c=hyp_c) / tau
    # Cross-view logits come first so that column ``i`` is the positive of row ``i``.
    logits = torch.cat([logits01, logits00], dim=1)
    # Subtract the row maximum for numerical stability (no gradient through it).
    logits = logits - logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

In [6]:
def init_model(model = model, hyp_c = 0.1, emb = 128, clip_r = 2.3, freeze = 0):
    """Build the backbone plus embedding head, move it to CUDA and set train mode.

    Args:
        model: backbone name. Names starting with "dino" are loaded from the
            ``facebookresearch/dino`` torch.hub repo, others through timm.
            The default is whatever the global ``model`` held when this
            function was defined.
        hyp_c: curvature; if > 0 the head ends with ``ToProjection_hypersphere``,
            otherwise with ``NormLayer``.
        emb: output dimension of the embedding head.
        clip_r: clipping radius forwarded to ``ToProjection_hypersphere``.
        freeze: number of blocks passed to ``freezer``; ``None`` skips freezing.

    Returns:
        A ``HeadSwitch`` module on the GPU in training mode.
    """
    backbone_name = model
    if backbone_name.startswith("dino"):
        body = torch.hub.load("facebookresearch/dino:main", backbone_name)
    else:
        body = timm.create_model(backbone_name, pretrained=True)

    last = (
        ToProjection_hypersphere(c=hyp_c, clip_r=clip_r)
        if hyp_c > 0
        else NormLayer()
    )

    # Width of the backbone features fed into the linear projection.
    bdim = 2048 if backbone_name == "resnet50" else 384
    projection = nn.Linear(bdim, emb)
    head = nn.Sequential(projection, nn.BatchNorm1d(emb), last)
    nn.init.constant_(projection.bias.data, 0)
    nn.init.orthogonal_(projection.weight.data)

    # Replace the backbone's own classifier layers with identities.
    rm_head(body)
    if freeze is not None:
        freezer(body, freeze)

    net = HeadSwitch(body, head)
    net.cuda().train()
    return net
    
    
class HeadSwitch(nn.Module):
    def __init__(self, body, head):
        super(HeadSwitch, self).__init__()
        self.body = body
        self.head = head
        self.norm = NormLayer()

    def forward(self, x, skip_head=False):
        x = self.body(x)
        if type(x) == tuple:
            x = x[0]
        if not skip_head:
            x = self.head(x)
        else:
            x = self.norm(x)
        return x


class NormLayer(nn.Module):
    """Scale each row of the input to unit L2 norm (along dim 1)."""

    def forward(self, x):
        unit_rows = F.normalize(x, dim=1, p=2)
        return unit_rows


def freezer(model, num_block):
    """Disable gradients for ``patch_embed``, ``pos_drop`` and the first ``num_block`` blocks."""

    def _frozen_parts():
        # Yield lazily so sub-modules are looked up one at a time, in order.
        yield model.patch_embed
        yield model.pos_drop
        for idx in range(num_block):
            yield model.blocks[idx]

    for part in _frozen_parts():
        for weight in part.parameters():
            weight.requires_grad = False


def rm_head(m):
    """Replace direct children named "head", "fc" or "head_dist" with ``nn.Identity``."""
    removable = ("head", "fc", "head_dist")
    # Collect the names first so the module is not modified while iterating.
    present = [name for name, _ in m.named_children() if name in removable]
    for name in present:
        m.add_module(name, nn.Identity())

In [7]:
class MultiSample:
    """Callable that applies ``transform`` to the same input ``n`` times."""

    def __init__(self, transform, n=2):
        self.transform = transform
        self.num = n

    def __call__(self, x):
        """Return a tuple with ``self.num`` independently transformed copies of ``x``."""
        views = []
        for _ in range(self.num):
            views.append(self.transform(x))
        return tuple(views)


def evaluate(get_emb_f, ds_name, hyp_c):
    """Compute recall for ``ds_name`` using embeddings produced by ``get_emb_f``.

    "CUB" and "Cars" use ``get_recall``, "SOP" uses ``get_recall_sop``; any
    other name is evaluated as a query/gallery split with ``get_recall_inshop``.
    """
    if ds_name in ("CUB", "Cars"):
        return get_recall(*get_emb_f(ds_type="eval"), ds_name, hyp_c)
    if ds_name == "SOP":
        return get_recall_sop(*get_emb_f(ds_type="eval"), ds_name, hyp_c)
    query = get_emb_f(ds_type="query")
    gallery = get_emb_f(ds_type="gallery")
    return get_recall_inshop(*query, *gallery, hyp_c)


def get_recall(x, y, ds_name, hyp_c):
    """Print and return recall@{1, 2, 4, 8, 16, 32} of embeddings ``x`` with labels ``y``.

    Similarity is the negated ``dist_matrix_d`` distance when ``hyp_c > 0``
    and the negated Euclidean distance otherwise. ``ds_name`` is unused.
    """
    k_list = [1, 2, 4, 8, 16, 32]

    if hyp_c > 0:
        def row_similarity(row):
            return -dist_matrix_d(row, x, hyp_c)
    else:
        def row_similarity(row):
            return -torch.cdist(row, x, p=2)

    # Fill the similarity matrix one row at a time.
    dist_m = torch.empty(len(x), len(x), device="cuda")
    for idx in range(len(x)):
        dist_m[idx : idx + 1] = row_similarity(x[idx : idx + 1])

    # Take the top (1 + max k) matches and drop the first column.
    neighbour_idx = dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]
    y_cur = y[neighbour_idx].float().cpu()
    y = y.cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_sop(x, y, ds_name, hyp_c):
    """Print and return recall@{1, 10, 100, 1000}, processing queries in chunks of 1000.

    Similarity is the negated ``dist_matrix_d`` distance when ``hyp_c > 0``
    and the negated Euclidean distance otherwise. ``ds_name`` is unused.
    """
    y_cur = torch.tensor([]).cuda().int()
    number = 1000
    k_list = [1, 10, 100, 1000]

    if hyp_c > 0:
        def row_similarity(row):
            return -dist_matrix_d(row, x, hyp_c)
    else:
        def row_similarity(row):
            return -torch.cdist(row, x, p=2)

    # Chunk starts 0, number, 2*number, ... up to and including the last
    # multiple of `number` that is <= len(x) (the final chunk may be empty).
    for start in range(0, len(x) + 1, number):
        x_s = x[start : start + number]
        sims = torch.empty(len(x_s), len(x), device="cuda")
        for row in range(len(x_s)):
            sims[row : row + 1] = row_similarity(x_s[row : row + 1])
        # Labels of the top matches, with the first column dropped.
        neighbour_labels = y[sims.topk(1 + max(k_list), largest=True)[1][:, 1:]]
        y_cur = torch.cat([y_cur, neighbour_labels])

    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_inshop(xq, yq, xg, yg, hyp_c):
    """Print and return recall@{1, 10, 20, 30, 40, 50} for a query/gallery split.

    Args:
        xq, yq: query embeddings and labels.
        xg, yg: gallery embeddings and labels.
        hyp_c: curvature. If > 0 similarity is the negated ``dist_matrix_d``
            distance, otherwise the negated Euclidean (p=2) distance.

    Returns:
        List of six recall values, one per k.
    """
    if hyp_c > 0:
        def row_similarity(row):
            return -dist_matrix_d(row, xg, hyp_c)
    else:
        def row_similarity(row):
            # The norm order must be given explicitly: the third positional
            # argument of cdist is `p`, not a curvature.
            return -torch.cdist(row, xg, p=2)

    # Allocate on the queries' device and fill one row at a time.
    dist_m = torch.empty(len(xq), len(xg), device=xq.device)
    for idx in range(len(xq)):
        dist_m[idx : idx + 1] = row_similarity(xq[idx : idx + 1])

    def recall_k(cos_sim, query_T, gallery_T, k):
        """Fraction of queries whose best positive is beaten by fewer than k negatives."""
        m = len(cos_sim)
        match_counter = 0
        for i in range(m):
            pos_sim = cos_sim[i][gallery_T == query_T[i]]
            neg_sim = cos_sim[i][gallery_T != query_T[i]]
            thresh = torch.max(pos_sim).item()
            if torch.sum(neg_sim > thresh) < k:
                match_counter += 1
        return match_counter / m

    recall = [recall_k(dist_m, yq, yg, k) for k in [1, 10, 20, 30, 40, 50]]
    print(recall)
    return recall


def get_emb(
    model,
    ds,
    path,
    mean_std,
    resize=224,
    crop=224,
    ds_type="eval",
    world_size=1,
    skip_head=False,
):
    """Embed a whole dataset split with ``model`` in eval mode.

    Args:
        model: network called as ``model(x, skip_head=...)``.
        ds: dataset factory called as ``ds(path, ds_type, transform)``.
        path: dataset root passed to ``ds``.
        mean_std: ``(mean, std)`` pair unpacked into ``T.Normalize``.
        resize: target size for ``T.Resize``.
        crop: target size for ``T.CenterCrop``.
        ds_type: split name passed to ``ds``.
        world_size: number of processes; when > 1 a ``DistributedSampler`` is
            used and results are all-gathered across processes.
        skip_head: forwarded to the model.

    Returns:
        ``(x, y)``: concatenated embeddings and labels (labels moved to CUDA).
    """
    eval_tr = T.Compose(
        [
            # Use the InterpolationMode enum (as the training transform does);
            # passing the PIL constant triggers the torchvision
            # "should be of type InterpolationMode instead of int" warning.
            T.Resize(resize, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(crop),
            T.ToTensor(),
            T.Normalize(*mean_std),
        ]
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        # NOTE(review): DistributedSampler may shuffle and pad the split by
        # default — confirm this is acceptable for evaluation.
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=multiprocessing.cpu_count() // world_size,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    try:
        x, y = eval_dataset(model, dl_eval, skip_head)
        y = y.cuda()
        if world_size > 1:
            all_x = [torch.zeros_like(x) for _ in range(world_size)]
            all_y = [torch.zeros_like(y) for _ in range(world_size)]
            torch.distributed.all_gather(all_x, x)
            torch.distributed.all_gather(all_y, y)
            x, y = torch.cat(all_x), torch.cat(all_y)
    finally:
        # Always hand the model back in training mode, even if embedding failed.
        model.train()
    return x, y


def eval_dataset(model, dl, skip_head):
    """Run ``model`` over every batch of ``dl`` without gradients.

    Returns the concatenated model outputs and the concatenated labels.
    """

    @torch.no_grad()
    def _embed(batch):
        # Inputs are moved to the GPU right before the forward pass.
        return model(batch.cuda(non_blocking=True), skip_head=skip_head)

    embeddings, labels = [], []
    for images, targets in dl:
        embeddings.append(_embed(images))
        labels.append(targets)
    return torch.cat(embeddings), torch.cat(labels)


In [8]:
# Expose only physical GPUs 5 and 6 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "5,6"
# if local_rank == 0:
#     wandb.init(project="hyp_metric")

world_size = int(os.environ.get("WORLD_SIZE", 1))

# NOTE: `model` is still the backbone *name* here; it is rebound to the network
# further down, so this cell needs a fresh config cell run before re-running.
if model.startswith("vit"):
    mean_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
else:
    mean_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
train_tr = T.Compose(
    [
        T.RandomResizedCrop(
            crop, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC
        ),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(*mean_std),
    ]
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)
# There must be enough images to fill at least one batch on every process.
assert len(ds_train.ys) * num_samples >= bs * world_size
sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

model = init_model(model = model, hyp_c = hyp_c, emb = emb, clip_r = clip_r, freeze = freeze)
optimizer = optim.AdamW(model.parameters(), lr=lr)
model = nn.DataParallel(model)

loss_f = partial(contrastive_loss, tau=t, hyp_c=hyp_c)
get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    mean_std=mean_std,
    world_size=world_size,
    resize=resize,
    crop=crop,
)
# Expand the compact "r(a,b,c)" config string into a list of epochs. Only done
# while it is still a string, so re-running the cell does not fail on the list.
# NOTE(review): this uses eval() on a config string; keep it to trusted values.
if isinstance(eval_ep, str):
    eval_ep = eval(eval_ep.replace("r", "list(range").replace(")", "))"))

cudnn.benchmark = True
all_rh = []
best_rh = []
best_ep = 0
# torch.save below pickles the whole DataParallel-wrapped module.
best_ckpt_path = "/data/xuyunhao/Mixed curvature/result/dino_cars_d_best_checkout.pth"

# Use a separate loop variable so the configured epoch count `ep` is not clobbered.
num_epochs = ep
for epoch in trange(num_epochs):
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        # Each group of `num_samples` consecutive items must share one class,
        # and every class may appear only once per batch.
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        batch_classes = y[:, 0].tolist()
        assert len(set(batch_classes)) == len(batch_classes)

        x = x.cuda(non_blocking=True)
        z = model(x)
        z = z.view(len(x) // num_samples, num_samples, emb)
        if world_size > 1:
            with torch.no_grad():
                all_z = [torch.zeros_like(z) for _ in range(world_size)]
                torch.distributed.all_gather(all_z, z)
            # Keep the local embeddings in the list so gradients flow through them.
            all_z[local_rank] = z
            z = torch.cat(all_z)

        # Sum the loss over every ordered pair of distinct views.
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    pair_loss, pair_stats = loss_f(z[:, i], z[:, j])
                    loss += pair_loss
                    stats_ep.append({**pair_stats, "loss": pair_loss.item()})

        optimizer.zero_grad()
        loss.backward()
        # apex amp is never initialised in this notebook, so clip the model's
        # own parameters (the ones the optimizer was built from) directly.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
        optimizer.step()

#     if (epoch + 1) in eval_ep:
#         rh= evaluate(get_emb_f, ds, hyp_c)

    # Evaluate after the first epoch and then every 10th epoch.
    if (epoch + 1) % 10 == 0 or epoch == 0:
        rh = evaluate(get_emb_f, ds, hyp_c)
        all_rh.append(rh)
        if epoch == 0:
            # The first evaluation only sets the baseline; nothing is saved.
            best_rh = rh
        else:
            # Compare on recall@1 when a list of recalls is returned.
            if isinstance(rh, list):
                improved = rh[0] > best_rh[0]
            else:
                improved = rh > best_rh
            if improved:
                best_rh = rh
                best_ep = epoch
                print("save model........")
                torch.save(model, best_ckpt_path)

    # Average the per-pair statistics; skipped if the epoch produced no batches.
    if local_rank == 0 and stats_ep:
        stats_ep = {k: np.mean([s[k] for s in stats_ep]) for k in stats_ep[0]}
#         if (epoch + 1) in eval_ep:
#             stats_ep = {"recall": rh, **stats_ep}
#         wandb.log({**stats_ep, "ep": epoch})

print("best:", best_ep, best_rh)

if save_emb:
    ds_type = "gallery" if ds == "Inshop" else "eval"
    x, y = get_emb_f(ds_type=ds_type)
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), path + "/" + emb_name + "_eval.pt")

    x, y = get_emb_f(ds_type="train")
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), path + "/" + emb_name + "_train.pt")

Using cache found in /data/xuyunhao/.cache/torch/hub/facebookresearch_dino_main
  "Argument interpolation should be of type InterpolationMode instead of int. "
  0%|▎                                                                                                                                            | 1/500 [00:35<4:55:14, 35.50s/it]

[0.405116221866929, 0.5081785758209323, 0.619235026442012, 0.7286926577296765, 0.8222850817857582, 0.9028409789693764]


  2%|██▊                                                                                                                                         | 10/500 [01:32<1:28:42, 10.86s/it]

[0.5143278809494527, 0.6332554421350387, 0.7401303652687247, 0.8296642479399828, 0.9018570901488132, 0.9535112532283851]
save model........


  4%|█████▍                                                                                                                                        | 19/500 [02:09<34:20,  4.28s/it]

[0.5989423195178945, 0.7088918952158406, 0.796089041938261, 0.8727093838396262, 0.929405977124585, 0.9656868773828557]
save model........


  6%|████████▏                                                                                                                                     | 29/500 [03:13<34:34,  4.41s/it]

[0.6654778010084861, 0.7633747386545321, 0.8418398720944533, 0.9019800762513835, 0.9467470175870126, 0.974172918460214]
save model........


  8%|███████████                                                                                                                                   | 39/500 [04:15<33:03,  4.30s/it]

[0.6959783544459476, 0.7911695978354446, 0.8602877874800148, 0.9152625753289878, 0.9512975033821178, 0.9782314598450376]
save model........


 10%|█████████████▉                                                                                                                                | 49/500 [05:17<32:36,  4.34s/it]

[0.7139343254212274, 0.8110933464518509, 0.8755380641987456, 0.9223957692780715, 0.9584306973312016, 0.9810601402041569]
save model........


 12%|████████████████▊                                                                                                                             | 59/500 [06:16<31:10,  4.24s/it]

[0.7341040462427746, 0.8274504980937154, 0.8886975771737794, 0.9316197269708523, 0.961874308203173, 0.9824129873324314]
save model........


 14%|███████████████████▌                                                                                                                          | 69/500 [07:18<28:47,  4.01s/it]

[0.7387775181404501, 0.8312630672733982, 0.8926331324560325, 0.9349403517402534, 0.9640880580494404, 0.9840118066658468]
save model........


 16%|██████████████████████▍                                                                                                                       | 79/500 [08:20<30:26,  4.34s/it]

[0.7483704341409421, 0.8407329971713197, 0.9017341040462428, 0.9403517402533513, 0.9682695855368343, 0.9842577788709875]
save model........


 18%|█████████████████████████▎                                                                                                                    | 89/500 [09:25<30:48,  4.50s/it]

[0.7631287664493912, 0.8493420243512483, 0.9081293813799041, 0.9458861148690199, 0.9685155577419752, 0.9846267371786988]
save model........


 20%|████████████████████████████                                                                                                                  | 99/500 [10:35<33:40,  5.04s/it]

[0.7796089041938261, 0.8631164678391342, 0.9163694502521215, 0.9549870864592301, 0.9733120157422211, 0.9858565982044029]
save model........


 22%|██████████████████████████████▋                                                                                                              | 109/500 [11:37<28:38,  4.40s/it]

[0.7884639035788956, 0.8668060509162465, 0.9226417414832124, 0.9548641003566597, 0.9766326405116222, 0.987086459230107]
save model........


 24%|█████████████████████████████████▌                                                                                                           | 119/500 [12:38<26:52,  4.23s/it]

[0.7933833476817119, 0.8723404255319149, 0.9259623662526134, 0.958553683433772, 0.9760177099987701, 0.9861025704095437]
save model........


 26%|████████████████████████████████████▏                                                                                                      | 130/500 [14:06<1:09:00, 11.19s/it]

[0.7889558479891772, 0.8691427868650843, 0.9200590333292338, 0.9563399335875046, 0.9756487516910589, 0.9874554175378182]


 28%|███████████████████████████████████████▏                                                                                                     | 139/500 [14:43<27:13,  4.52s/it]

[0.804206124707908, 0.8766449391218792, 0.9236256303037756, 0.9560939613823638, 0.974172918460214, 0.9880703480506703]
save model........


 30%|██████████████████████████████████████████                                                                                                   | 149/500 [15:42<24:28,  4.18s/it]

[0.805681957938753, 0.8779977862501537, 0.9284220883040216, 0.9579387529209199, 0.9758947238961997, 0.9867175009223957]
save model........


 32%|████████████████████████████████████████████▊                                                                                                | 159/500 [16:44<24:57,  4.39s/it]

[0.8085106382978723, 0.8815643832246957, 0.925224449637191, 0.9573238224080679, 0.9773705571270447, 0.9878243758455294]
save model........


 34%|███████████████████████████████████████████████▋                                                                                             | 169/500 [17:49<24:18,  4.41s/it]

[0.8110933464518509, 0.8829172303529701, 0.9298979215348666, 0.9603984749723281, 0.9779854876398967, 0.9879473619480998]
save model........


 36%|██████████████████████████████████████████████████▍                                                                                          | 179/500 [18:57<24:32,  4.59s/it]

[0.815028901734104, 0.886483827327512, 0.9324806296888452, 0.9607674332800393, 0.9782314598450376, 0.9875784036403886]
save model........


 38%|█████████████████████████████████████████████████████▎                                                                                       | 189/500 [20:00<22:41,  4.38s/it]

[0.8243758455294552, 0.8932480629688845, 0.9370311154839503, 0.96396507194687, 0.9777395154347559, 0.9893002090763744]
save model........


 40%|████████████████████████████████████████████████████████                                                                                     | 199/500 [21:01<21:14,  4.23s/it]

[0.8259746648628705, 0.8965686877382856, 0.9377690320993728, 0.9643340302545812, 0.9804452096913049, 0.9906530562046488]
save model........


 42%|███████████████████████████████████████████████████████████▏                                                                                 | 210/500 [22:34<55:15, 11.43s/it]

[0.8229000122986102, 0.8911572992251876, 0.9326036157914156, 0.961874308203173, 0.9783544459476079, 0.9895461812815152]


 44%|█████████████████████████████████████████████████████████████▊                                                                               | 219/500 [23:09<19:35,  4.18s/it]

[0.8270815397860042, 0.8939859795843069, 0.9355552822531054, 0.9631041692288771, 0.9809371541015866, 0.9896691673840856]
save model........


 46%|████████████████████████████████████████████████████████████████▌                                                                            | 229/500 [24:14<20:01,  4.43s/it]

[0.8280654286065675, 0.8944779239945886, 0.9375230598942319, 0.9637190997417292, 0.9803222235887344, 0.9889312507686632]
save model........


 48%|███████████████████████████████████████████████████████████████████▋                                                                         | 240/500 [25:50<50:55, 11.75s/it]

[0.8231459845037511, 0.895215840610011, 0.9376460459968023, 0.9648259746648629, 0.9790923625630303, 0.9883163202558111]


 50%|██████████████████████████████████████████████████████████████████████▌                                                                      | 250/500 [26:53<46:36, 11.18s/it]

[0.8214241790677653, 0.8957077850202927, 0.9386299348173657, 0.9650719468700036, 0.9787234042553191, 0.9872094453326774]


 52%|█████████████████████████████████████████████████████████████████████████                                                                    | 259/500 [27:31<18:30,  4.61s/it]

[0.8311400811708277, 0.8991513958922642, 0.9398597958430698, 0.9661788217931374, 0.9790923625630303, 0.9884393063583815]
save model........


 54%|████████████████████████████████████████████████████████████████████████████▏                                                                | 270/500 [29:03<41:28, 10.82s/it]

[0.8305251506579756, 0.8996433403025458, 0.9412126429713442, 0.9658098634854262, 0.9811831263067273, 0.9901611117943673]


 56%|██████████████████████████████████████████████████████████████████████████████▉                                                              | 280/500 [30:05<38:50, 10.59s/it]

[0.8304021645554053, 0.8944779239945886, 0.9382609765096545, 0.9634731275365883, 0.9798302791784529, 0.989177222973804]


 58%|█████████████████████████████████████████████████████████████████████████████████▍                                                           | 289/500 [30:40<14:42,  4.18s/it]

[0.8342147337350879, 0.9013651457385315, 0.9398597958430698, 0.96396507194687, 0.9803222235887344, 0.9896691673840856]
save model........


 60%|████████████████████████████████████████████████████████████████████████████████████▌                                                        | 300/500 [32:08<36:52, 11.06s/it]

[0.8321239699913909, 0.9003812569179682, 0.9389988931250769, 0.9638420858442996, 0.9804452096913049, 0.9886852785635223]


 62%|███████████████████████████████████████████████████████████████████████████████████████▏                                                     | 309/500 [32:45<14:02,  4.41s/it]

[0.835444594760792, 0.9013651457385315, 0.9391218792276472, 0.9649489607674333, 0.9814290985118682, 0.9888082646660927]
save model........


 64%|██████████████████████████████████████████████████████████████████████████████████████████▏                                                  | 320/500 [34:14<32:33, 10.86s/it]

[0.8349526503505104, 0.8986594514819826, 0.9386299348173657, 0.9650719468700036, 0.9797072930758824, 0.9896691673840856]


 66%|█████████████████████████████████████████████████████████████████████████████████████████████                                                | 330/500 [35:23<33:53, 11.96s/it]

[0.8328618866068135, 0.8971836182511376, 0.9424425039970483, 0.9661788217931374, 0.9798302791784529, 0.9894231951789447]


 68%|███████████████████████████████████████████████████████████████████████████████████████████████▌                                             | 339/500 [35:59<11:45,  4.38s/it]

[0.8360595252736441, 0.9033329233796581, 0.9417045873816259, 0.9698684048702496, 0.9821670151272907, 0.9897921534866561]
save model........


 70%|██████████████████████████████████████████████████████████████████████████████████████████████████▍                                          | 349/500 [37:06<11:37,  4.62s/it]

[0.8361825113762145, 0.9005042430205387, 0.9434263928176116, 0.9675316689214118, 0.9820440290247202, 0.9895461812815152]
save model........


 72%|█████████████████████████████████████████████████████████████████████████████████████████████████████▏                                       | 359/500 [38:10<10:34,  4.50s/it]

[0.8374123724019186, 0.900012298610257, 0.9398597958430698, 0.9655638912802853, 0.9811831263067273, 0.9890542368712335]
save model........


 74%|████████████████████████████████████████████████████████████████████████████████████████████████████████                                     | 369/500 [39:14<09:50,  4.51s/it]

[0.8460213995818473, 0.9073914647644816, 0.943549378920182, 0.9685155577419752, 0.9830279178452834, 0.9902840978969376]
save model........


 76%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                  | 379/500 [40:16<08:48,  4.37s/it]

[0.8465133439921289, 0.9051777149182142, 0.9434263928176116, 0.9690075021522568, 0.9826589595375722, 0.9897921534866561]
save model........


 78%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                               | 390/500 [41:42<19:31, 10.65s/it]

[0.845652441274136, 0.9057926454310663, 0.9426884762021892, 0.968146599434264, 0.9819210429221498, 0.9905300701020785]


 80%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                            | 400/500 [42:45<18:29, 11.10s/it]

[0.8422088304021645, 0.9043168122002214, 0.9426884762021892, 0.968146599434264, 0.9820440290247202, 0.9908990284097897]


 82%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                         | 409/500 [43:24<06:54,  4.55s/it]

[0.8504488992743819, 0.910835075636453, 0.9460091009715902, 0.9688845160496864, 0.9831509039478539, 0.990407083999508]
save model........


 84%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                      | 420/500 [44:55<14:41, 11.01s/it]

[0.8436846636330094, 0.9083753535850448, 0.945640142663879, 0.9685155577419752, 0.9840118066658468, 0.9915139589226417]


 86%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                   | 430/500 [46:00<12:32, 10.76s/it]

[0.8457754273767064, 0.906284589841348, 0.9426884762021892, 0.9679006272291231, 0.9805681957938753, 0.989177222973804]


 88%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████                 | 440/500 [47:03<11:09, 11.17s/it]

[0.8470052884024105, 0.9083753535850448, 0.9428114623047595, 0.9669167384085599, 0.9817980568195794, 0.9899151395892264]


 90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉              | 450/500 [48:14<10:19, 12.40s/it]

[0.847743205017833, 0.9083753535850448, 0.9460091009715902, 0.9694994465625384, 0.9825359734350019, 0.9897921534866561]


 92%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▋           | 460/500 [49:15<07:00, 10.52s/it]

[0.8445455663510023, 0.9067765342516295, 0.9445332677407453, 0.9685155577419752, 0.9821670151272907, 0.9897921534866561]


 94%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌        | 470/500 [50:14<05:17, 10.57s/it]

[0.8482351494281146, 0.9073914647644816, 0.9458861148690199, 0.9682695855368343, 0.9817980568195794, 0.9897921534866561]


 96%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████      | 479/500 [50:50<01:28,  4.20s/it]

[0.851063829787234, 0.9128028532775796, 0.9482228508178576, 0.9682695855368343, 0.9810601402041569, 0.9888082646660927]
save model........


 98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏  | 490/500 [52:19<01:52, 11.23s/it]

[0.8508178575820933, 0.9103431312261715, 0.9460091009715902, 0.9683925716394047, 0.9824129873324314, 0.9897921534866561]


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 500/500 [53:23<00:00,  6.41s/it]

[0.8460213995818473, 0.9077604230721928, 0.9473619480998647, 0.9691304882548272, 0.9826589595375722, 0.9899151395892264]
best: 479 [0.851063829787234, 0.9128028532775796, 0.9482228508178576, 0.9682695855368343, 0.9810601402041569, 0.9888082646660927]



