In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from apex import amp

import os
import random
import timm
import wandb
from tqdm import trange
import multiprocessing
from functools import partial
import numpy as np
import PIL
from torchvision.transforms import InterpolationMode

from sampler import UniqueClassSempler
from proxy_anchor.utils import calc_recall_at_k
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset
from hyptorch.pmath import dist_matrix
import hyptorch.nn as hypnn

from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from datetime import timedelta

In [2]:
# Seed every RNG used in this notebook (Python, NumPy, PyTorch CPU and all
# CUDA devices) with the same value for reproducibility.
seed = 1
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(seed)

# 超参数 (Hyperparameters)

In [3]:
# --- run identity and data ---
num = 1                              # run index; appears in the checkpoint file name
path = '/data/xuyunhao/datasets'     # dataset root directory
ds = 'Cars'                          # one of the keys of ds_list: CUB, SOP, Cars, Inshop
num_samples = 2                      # images per class in each batch
bs = 196                             # batch size

# --- optimisation / loss ---
lr = 1e-5
t = 0.2                              # contrastive temperature (tau)
emb = 128                            # embedding dimension
freeze = 0                           # number of leading backbone blocks to freeze
ep = 500                             # number of training epochs
hyp_c = 0.1                          # Poincare-ball curvature; 0 switches to sphere mode
eval_ep = f"r({ep - 100},{ep + 10},10)"  # "r(start,stop,step)" spec for list(range(...))

# --- backbone ---
# Alternatives: 'vit_small_patch16_224', 'deit_small_distilled_patch16_224'
model = 'dino_vits16'

# --- misc ---
save_emb = False                     # dump final embeddings to disk after training
emb_name = 'emb'                     # file-name prefix for the dumped embeddings
clip_r = 2.3                         # clipping radius for hypnn.ToPoincare
resize = 224
crop = 224
local_rank = 0
save_path = f"/data/xuyunhao/Mixed curvature/result/{model}_{ds}_best_p_{emb}_{num}_checkout.pth"

# 损失函数 (Loss function)

In [4]:
def contrastive_loss(x0, x1, tau, hyp_c):
    """Contrastive (InfoNCE-style) loss for a batch of positive pairs.

    Row ``i`` of ``x0`` and row ``i`` of ``x1`` form a positive pair. For each
    anchor ``x0[i]`` the candidates are all rows of ``x1`` plus all *other*
    rows of ``x0`` (the anchor itself is masked out). Cross-entropy then pulls
    the anchor towards ``x1[i]``.

    Args:
        x0, x1: row-aligned embeddings of the two views (same number of rows).
        tau: temperature; similarity logits are divided by it.
        hyp_c: hyperbolic curvature. ``0`` uses the dot product ("sphere
            mode"); ``> 0`` uses the negative ``dist_matrix`` distance with
            curvature ``hyp_c``.

    Returns:
        ``(loss, stats)``: the scalar loss tensor and a dict of Python floats
        describing the cross-view logits ``x0`` vs ``x1`` (min / mean / max
        and top-1 accuracy of the positive).
    """
    if hyp_c == 0:
        dist_f = lambda x, y: x @ y.t()
    else:
        dist_f = lambda x, y: -dist_matrix(x, y, c=hyp_c)
    bsize = x0.shape[0]
    # Build helpers on the inputs' device rather than the default CUDA device,
    # so the loss also works on CPU or on a non-default GPU.
    target = torch.arange(bsize, device=x0.device)
    # A huge negative offset on the diagonal removes each anchor's similarity
    # with itself from the x0-vs-x0 block.
    eye_mask = torch.eye(bsize, device=x0.device) * 1e9
    logits00 = dist_f(x0, x0) / tau - eye_mask
    logits01 = dist_f(x0, x1) / tau
    # Columns [0, bsize) are the cross-view candidates, so the positive for
    # row i sits at column i == target[i].
    logits = torch.cat([logits01, logits00], dim=1)
    # Row-wise max shift for numerical stability (does not change the loss).
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# 模型 (Model)

In [5]:
def init_model(model = model, hyp_c = 0.1, emb = 128, clip_r = 2.3, freeze = 0):
    """Build the embedding network: pretrained backbone plus projection head.

    Args:
        model: backbone name. Names starting with "dino" are loaded from the
            ``facebookresearch/dino:main`` torch.hub repo; anything else comes
            from ``timm.create_model(model, pretrained=True)``. The default is
            bound to the notebook's ``model`` global at definition time.
        hyp_c: curvature. If > 0 the head ends with ``hypnn.ToPoincare``;
            otherwise it ends with ``NormLayer`` (L2 normalization).
        emb: output embedding dimension.
        clip_r: clipping radius passed to ``ToPoincare``.
        freeze: passed to ``freezer``. It freezes ``patch_embed`` and
            ``pos_drop`` plus the first ``freeze`` blocks. ``None`` skips
            freezing entirely.

    Returns:
        A ``HeadSwitch(body, head)`` moved to CUDA and set to train mode.
    """
    if model.startswith("dino"):
        body = torch.hub.load("facebookresearch/dino:main", model)
    else:
        body = timm.create_model(model, pretrained=True)
    if hyp_c > 0:
        last = hypnn.ToPoincare(
            c=hyp_c,
            ball_dim=emb,
            riemannian=False,
            clip_r=clip_r,
        )
    else:
        last = NormLayer()
    # Width of the backbone's output features. Previously hard-coded to 384
    # for every non-resnet50 model, which breaks wider backbones. timm models
    # expose it as `num_features` (the DINO hub ViTs presumably do too -- confirm
    # when adding backbones); otherwise fall back to the old hard-coded rule.
    bdim = getattr(body, "num_features", None)
    if not isinstance(bdim, int):
        bdim = 2048 if model == "resnet50" else 384
    head = nn.Sequential(nn.Linear(bdim, emb), nn.BatchNorm1d(emb), last)
    nn.init.constant_(head[0].bias.data, 0)
    nn.init.orthogonal_(head[0].weight.data)
    # Strip the backbone's own classifier(s) so it returns features.
    rm_head(body)
    if freeze is not None:
        freezer(body, freeze)
    model = HeadSwitch(body, head)
    model.cuda().train()
    return model
    
    
class HeadSwitch(nn.Module):
    """Backbone followed by either the projection head or plain L2 normalization.

    ``forward(x)`` returns ``head(body(x))``. ``forward(x, skip_head=True)``
    returns the L2-normalized backbone output instead. When the backbone
    returns a tuple, only its first element is used.
    """

    def __init__(self, body, head):
        super().__init__()
        self.body = body
        self.head = head
        self.norm = NormLayer()

    def forward(self, x, skip_head=False):
        feats = self.body(x)
        if type(feats) is tuple:
            feats = feats[0]
        return self.norm(feats) if skip_head else self.head(feats)


class NormLayer(nn.Module):
    """Scale each row (dim 1) of the input to unit L2 norm via ``F.normalize``."""

    def forward(self, x):
        unit_rows = F.normalize(x, dim=1, p=2)
        return unit_rows


def freezer(model, num_block):
    """Turn off gradients for the patch embedding, ``pos_drop`` and the first ``num_block`` blocks.

    ``model`` must expose ``patch_embed``, ``pos_drop`` and an indexable ``blocks``.
    Modules are frozen in that order.
    """
    def _no_grad(module):
        for weight in module.parameters():
            weight.requires_grad = False

    _no_grad(model.patch_embed)
    _no_grad(model.pos_drop)
    block_idx = 0
    while block_idx < num_block:
        _no_grad(model.blocks[block_idx])
        block_idx += 1


def rm_head(m):
    """Replace any direct child named ``head``, ``fc`` or ``head_dist`` with ``nn.Identity`` (in place)."""
    child_names = {name for name, _ in m.named_children()}
    for name in child_names.intersection({"head", "fc", "head_dist"}):
        m.add_module(name, nn.Identity())

In [6]:
class MultiSample:
    """Callable wrapper that applies ``transform`` to one input ``n`` times and returns a tuple."""

    def __init__(self, transform, n=2):
        self.transform = transform
        self.num = n

    def __call__(self, x):
        views = []
        for _ in range(self.num):
            views.append(self.transform(x))
        return tuple(views)


def evaluate(get_emb_f, ds_name, hyp_c):
    """Compute retrieval recall for the configured dataset.

    - "CUB"/"Cars": embeds the "eval" split; ``get_recall`` returns Recall@1.
    - "SOP": embeds the "eval" split; ``get_recall_sop`` returns a list.
    - anything else (In-Shop): embeds "query" and "gallery" splits;
      ``get_recall_inshop`` returns a list.
    """
    if ds_name in ("CUB", "Cars"):
        return get_recall(*get_emb_f(ds_type="eval"), ds_name, hyp_c)
    if ds_name == "SOP":
        return get_recall_sop(*get_emb_f(ds_type="eval"), ds_name, hyp_c)
    query_emb = get_emb_f(ds_type="query")
    gallery_emb = get_emb_f(ds_type="gallery")
    return get_recall_inshop(*query_emb, *gallery_emb, hyp_c)


def get_recall(x, y, ds_name, hyp_c):
    """Recall@K where every sample is queried against all samples of the same set.

    Args:
        x: embeddings, one row per sample.
        y: labels aligned with ``x``.
        ds_name: "CUB"/"Cars" (K = 1, 2, 4, 8, 16, 32) or "SOP" (K = 1, 10, 100, 1000).
        hyp_c: curvature. ``> 0`` ranks neighbours by ``dist_matrix``;
            otherwise by Euclidean distance.

    Returns:
        Recall@1 (the first value from ``calc_recall_at_k``). The full list is printed.

    Raises:
        ValueError: if ``ds_name`` has no K list. Previously this surfaced as an
            ``UnboundLocalError`` only after the whole distance matrix was built.
    """
    if ds_name in ("CUB", "Cars"):
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]
    else:
        raise ValueError(f"get_recall: no recall@K list for dataset {ds_name!r}")

    # Negative distances, so "larger" means "closer". Filled one query row at a
    # time, presumably to bound the memory of the pairwise computation.
    dist_m = torch.empty(len(x), len(x), device=x.device)
    for i in range(len(x)):
        if hyp_c > 0:
            dist_m[i : i + 1] = -dist_matrix(x[i : i + 1], x, hyp_c)
        else:
            dist_m[i : i + 1] = -torch.cdist(x[i : i + 1], x, p=2)

    # Labels of the nearest neighbours. Column 0 of topk is dropped because it
    # is normally the query itself (distance 0).
    nn_idx = dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]
    y_cur = y.to(nn_idx.device)[nn_idx]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]


def get_recall_sop(x, y, ds_name, hyp_c, chunk_size=1000):
    """Recall@{1, 10, 100, 1000} for SOP, searching neighbours one chunk of queries at a time.

    Args:
        x: embeddings (non-empty), one row per sample.
        y: labels aligned with ``x``.
        ds_name: unused; kept for signature parity with ``get_recall``.
        hyp_c: curvature. ``> 0`` ranks by ``dist_matrix``; otherwise by
            Euclidean distance.
        chunk_size: number of query rows whose distance rows are materialised
            at once (the original fixed value was 1000).

    Returns:
        List of recall values, one per K (also printed).
    """
    k_list = [1, 10, 100, 1000]
    # +1 because each query's top match is normally itself and is dropped.
    # Clamp so sets smaller than 1001 samples do not make topk fail.
    top = min(1 + max(k_list), len(x))
    y_dev = y.to(x.device)

    neighbour_labels = []
    for start in range(0, len(x), chunk_size):
        x_s = x[start : start + chunk_size]
        # Negative distances for this chunk of queries (larger == closer).
        dist = torch.empty(len(x_s), len(x), device=x.device)
        for row in range(len(x_s)):
            if hyp_c > 0:
                dist[row : row + 1] = -dist_matrix(x_s[row : row + 1], x, hyp_c)
            else:
                dist[row : row + 1] = -torch.cdist(x_s[row : row + 1], x, p=2)
        neighbour_labels.append(y_dev[dist.topk(top, largest=True)[1][:, 1:]])

    y = y.cpu()
    # Concatenate once instead of growing a tensor inside the loop.
    y_cur = torch.cat(neighbour_labels).float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_inshop(xq, yq, xg, yg, hyp_c):
    """Recall@{1, 10, 20, 30, 40, 50} for In-Shop: queries searched against a separate gallery.

    A query is a hit at K when fewer than K gallery items with a *different*
    label are strictly closer than its closest same-label gallery item.
    A query with no same-label gallery item counts as closer-than-nothing, so
    every negative outranks it (the original crashed in ``torch.max`` here).

    Args:
        xq, yq: query embeddings and labels.
        xg, yg: gallery embeddings and labels.
        hyp_c: curvature. ``> 0`` ranks by ``dist_matrix``; otherwise by
            Euclidean (L2) distance.

    Returns:
        List of recall values (floats), one per K (also printed).
    """
    # Negative distances: larger means closer. Filled one query row at a time.
    sim = torch.empty(len(xq), len(xg), device=xq.device)
    for i in range(len(xq)):
        if hyp_c > 0:
            sim[i : i + 1] = -dist_matrix(xq[i : i + 1], xg, hyp_c)
        else:
            # Bug fix: the original passed `hyp_c` (== 0 in this branch) as the
            # cdist norm `p`, i.e. counted differing coordinates instead of L2.
            sim[i : i + 1] = -torch.cdist(xq[i : i + 1], xg, p=2)

    yq = yq.to(sim.device)
    yg = yg.to(sim.device)

    # For each query: number of negatives strictly closer than its best
    # positive. Processed in row chunks to bound the size of the masks.
    n_closer_neg = torch.empty(len(xq), dtype=torch.long, device=sim.device)
    chunk = 1024
    for start in range(0, len(xq), chunk):
        s = sim[start : start + chunk]
        same = yg.unsqueeze(0) == yq[start : start + chunk].unsqueeze(1)
        best_pos = s.masked_fill(~same, float("-inf")).max(dim=1, keepdim=True).values
        n_closer_neg[start : start + chunk] = ((s > best_pos) & ~same).sum(dim=1)

    m = len(xq)
    recall = [(n_closer_neg < k).sum().item() / m for k in [1, 10, 20, 30, 40, 50]]
    print(recall)
    return recall


def get_emb(
    model,
    ds,
    path,
    mean_std,
    resize=224,
    crop=224,
    ds_type="eval",
    world_size=1,
    skip_head=False,
):
    """Embed one split of a dataset with ``model`` using a deterministic eval transform.

    Args:
        model: network, called as ``model(x, skip_head=skip_head)``.
        ds: dataset class, constructed as ``ds(path, ds_type, transform)``.
        path: dataset root directory.
        mean_std: ``(mean, std)`` passed to ``T.Normalize``.
        resize: target size for the bicubic ``T.Resize``.
        crop: size of the subsequent center crop.
        ds_type: split to load; this notebook uses "eval", "train", "query", "gallery".
        world_size: when > 1, the split is sharded with ``DistributedSampler``
            and the per-process results are all-gathered.
        skip_head: forwarded to ``model``.

    Returns:
        ``(x, y)``: embeddings and labels, with ``y`` moved to CUDA.

    Side effects: switches ``model`` to eval mode and back to train mode.
    """
    eval_tr = T.Compose(
        [
            # Use the InterpolationMode enum (as train_tr does). Passing the PIL
            # int constant triggers torchvision's "should be of type
            # InterpolationMode instead of int" deprecation warning.
            T.Resize(resize, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(crop),
            T.ToTensor(),
            T.Normalize(*mean_std),
        ]
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        # NOTE(review): DistributedSampler may pad the split so every rank gets
        # the same count, and it interleaves indices across ranks. The gathered
        # (x, y) can therefore contain duplicates and is not in dataset order.
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=multiprocessing.cpu_count() // world_size,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    x, y = eval_dataset(model, dl_eval, skip_head)
    y = y.cuda()
    if world_size > 1:
        all_x = [torch.zeros_like(x) for _ in range(world_size)]
        all_y = [torch.zeros_like(y) for _ in range(world_size)]
        torch.distributed.all_gather(all_x, x)
        torch.distributed.all_gather(all_y, y)
        x, y = torch.cat(all_x), torch.cat(all_y)
    model.train()
    return x, y


def eval_dataset(model, dl, skip_head):
    """Run ``model`` over every ``(x, y)`` batch of ``dl`` without tracking gradients.

    Args:
        model: called as ``model(x, skip_head=skip_head)``.
        dl: iterable of ``(x, y)`` batches.
        skip_head: forwarded to ``model``.

    Returns:
        ``(embeddings, labels)``, each concatenated along dim 0. Labels stay on
        the device ``dl`` produced them on.
    """
    # Send inputs to the device holding the model's parameters. The original
    # hard-coded `.cuda()` (the current CUDA device, where init_model places the
    # model). A parameter-less model keeps that old behaviour.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else torch.device("cuda")
    all_x, all_y = [], []
    with torch.no_grad():
        for x, y in dl:
            all_x.append(model(x.to(device, non_blocking=True), skip_head=skip_head))
            all_y.append(y)
    return torch.cat(all_x), torch.cat(all_y)

# 主函数 (Main training script)

In [7]:
# NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
# initialised in this kernel yet. It is safest to set it before anything
# touches the GPU (ideally before importing torch).
os.environ["CUDA_VISIBLE_DEVICES"] =  "1,2"
if local_rank == 0:
    # Store the run's hyperparameters with its metrics so runs can be compared
    # and reproduced from the wandb dashboard alone.
    wandb.init(
        project="hyp_metric",
        config={
            "ds": ds, "model": model, "num": num, "seed": seed,
            "num_samples": num_samples, "bs": bs, "lr": lr, "t": t,
            "emb": emb, "freeze": freeze, "ep": ep, "hyp_c": hyp_c,
            "clip_r": clip_r, "resize": resize, "crop": crop,
        },
    )

# Number of processes from the WORLD_SIZE environment variable; 1 when unset.
world_size = int(os.environ.get("WORLD_SIZE", 1))

# Input normalization: (0.5, 0.5) for "vit*" timm backbones, ImageNet
# mean/std otherwise (presumably matching each checkpoint's pretraining).
if model.startswith("vit"):
    mean_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
else:
    mean_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
# Training-time augmentation: bicubic random resized crop (scale 0.2-1.0),
# random horizontal flip, tensor conversion, then normalization.
train_tr = T.Compose([
    T.RandomResizedCrop(crop, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(*mean_std),
])

# Dataset name -> dataset class; each is constructed as cls(root, split, transform).
ds_list = {
    "CUB": CUBirds,
    "SOP": SOP,
    "Cars": Cars,
    "Inshop": Inshop_Dataset,
}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)

# Sanity check relating the training-set size to the global batch size.
assert len(ds_train.ys) * num_samples >= bs * world_size

# Class-aware sampler; the training loop asserts that each batch holds
# `num_samples` images per class with no class repeated.
sampler = UniqueClassSempler(ds_train.ys, num_samples, local_rank, world_size)
dl_train = DataLoader(
    dataset=ds_train, sampler=sampler, batch_size=bs,
    num_workers=8, pin_memory=True, drop_last=True,
)

# Build the model. The optimizer holds the bare model's parameters, which are
# the same tensors DataParallel wraps.
model = init_model(model = model, hyp_c = hyp_c, emb = emb, clip_r = clip_r, freeze = freeze)
optimizer = optim.AdamW(model.parameters(), lr=lr)
model = nn.DataParallel(model)

loss_f = partial(contrastive_loss, tau=t, hyp_c=hyp_c)
get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    mean_std=mean_std,
    world_size=world_size,
    resize=resize,
    crop=crop,
)

# eval_ep is configured as "r(start,stop,step)", meaning list(range(start, stop, step)).
# Parse it explicitly instead of eval()-ing a string rewrite. Skip parsing if a
# previous run of this cell already converted it: re-running used to crash on
# `list.replace`.
if isinstance(eval_ep, str):
    _eval_spec = eval_ep.strip()
    if not (_eval_spec.startswith("r(") and _eval_spec.endswith(")")):
        raise ValueError(f"eval_ep must look like 'r(start,stop,step)', got {eval_ep!r}")
    eval_ep = list(range(*(int(v) for v in _eval_spec[2:-1].split(","))))

cudnn.benchmark = True
all_rh = []     # every evaluation result, in order
best_rh = []    # best evaluation result so far
best_ep = 0     # 0-based epoch index of best_rh
lower_cnt = 0   # consecutive evaluations without improvement (early stopping)

# Contrastive training with periodic evaluation, best-checkpoint saving and
# early stopping.
eval_every = 10  # evaluate after epoch 0 and then every `eval_every` epochs
patience = 10    # stop after this many consecutive evaluations without improvement


def _primary_score(recall):
    """Scalar used for model selection: Recall@1.

    get_recall returns R@1 directly; the SOP / In-Shop variants return a list.
    """
    return recall[0] if isinstance(recall, list) else recall


# The loop variable no longer shadows the configured epoch count `ep`, which
# made re-running this cell train a different number of epochs.
for epoch in trange(ep):
    sampler.set_epoch(epoch)
    stats_ep = []
    for x, y in dl_train:
        # Each batch must hold `num_samples` consecutive images per class, and
        # no class may appear twice.
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        classes = y[:, 0].tolist()
        assert len(set(classes)) == len(classes)

        x = x.cuda(non_blocking=True)
        z = model(x)
        z = z.view(len(x) // num_samples, num_samples, emb)
        if world_size > 1:
            with torch.no_grad():
                all_z = [torch.zeros_like(z) for _ in range(world_size)]
                torch.distributed.all_gather(all_z, z)
            # Gathered copies carry no gradient; put the local, differentiable z back.
            all_z[local_rank] = z
            z = torch.cat(all_z)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    pair_loss, pair_stats = loss_f(z[:, i], z[:, j])
                    loss += pair_loss
                    stats_ep.append({**pair_stats, "loss": pair_loss.item()})

        optimizer.zero_grad()
        loss.backward()
        # Clip exactly the parameters the optimizer updates. This replaces
        # apex's amp.master_params(optimizer); amp is never initialised here.
        torch.nn.utils.clip_grad_norm_(
            [p for group in optimizer.param_groups for p in group["params"]], 3
        )
        optimizer.step()

    if (epoch + 1) % eval_every == 0 or epoch == 0:
        rh = evaluate(get_emb_f, ds, hyp_c)
        all_rh.append(rh)
        # The first evaluation also saves, so the checkpoint on disk always
        # matches best_rh / best_ep (previously epoch 0 was never saved).
        if epoch == 0 or _primary_score(rh) >= _primary_score(best_rh):
            lower_cnt = 0
            best_rh = rh
            best_ep = epoch
            print("save model........")
            # NOTE: `model` is DataParallel, so state_dict keys carry a "module." prefix.
            torch.save(model.state_dict(), save_path)
        else:
            lower_cnt += 1

    if lower_cnt >= patience:
        break

    if local_rank == 0 and stats_ep:
        stats_ep = {k: np.mean([row[k] for row in stats_ep]) for k in stats_ep[0]}
        # "recall" is the most recent evaluation result (it may be from an earlier epoch).
        stats_ep = {"recall": rh, **stats_ep}
        wandb.log({**stats_ep, "ep": epoch})

print("best:", best_ep+1, best_rh)
        
# Optionally dump final embeddings and labels (float32 / int64, on CPU) for the
# evaluation split and the training split next to the datasets.
if save_emb:
    ds_type = "gallery" if ds == "Inshop" else "eval"
    for split, suffix in ((ds_type, "_eval.pt"), ("train", "_train.pt")):
        x, y = get_emb_f(ds_type=split)
        x, y = x.float().cpu(), y.long().cpu()
        torch.save((x, y), f"{path}/{emb_name}{suffix}")

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxuyunhao[0m. Use [1m`wandb login --relogin`[0m to force relogin


Using cache found in /data/xuyunhao/.cache/torch/hub/facebookresearch_dino_main
  "Argument interpolation should be of type InterpolationMode instead of int. "
  0%|▏                                                                    | 1/500 [00:33<4:41:43, 33.87s/it]

[0.40523920796949947, 0.5058418398720944, 0.6177591932111671, 0.72672488008855, 0.8205632763497724, 0.9028409789693764]


  2%|█▎                                                                  | 10/500 [01:33<1:31:54, 11.25s/it]

[0.50448899274382, 0.6221866929037019, 0.7290616160373877, 0.8268355675808634, 0.8989054236871233, 0.9509285450744066]
save model........


  4%|██▋                                                                   | 19/500 [02:11<33:48,  4.22s/it]

[0.5911941950559587, 0.7002828680359119, 0.7888328618866068, 0.8676669536342393, 0.9274381994834584, 0.9644570163571516]
save model........


  6%|████                                                                  | 29/500 [03:16<33:37,  4.28s/it]

[0.6556389128028532, 0.7559955725003075, 0.8359365391710737, 0.9001352847128274, 0.9445332677407453, 0.972082154716517]
save model........


  8%|█████▍                                                                | 39/500 [04:21<33:23,  4.35s/it]

[0.6881072438814414, 0.7846513343992129, 0.8574591071208953, 0.9115729922518755, 0.951051531176977, 0.9754027794859181]
save model........


 10%|██████▊                                                               | 49/500 [05:28<32:53,  4.37s/it]

[0.7054482843438692, 0.8034682080924855, 0.8734473004550486, 0.9206739638420859, 0.9553560447669414, 0.9792153486656008]
save model........


 12%|████████▎                                                             | 59/500 [06:29<30:05,  4.09s/it]

[0.7334891157299225, 0.822531053990899, 0.8867297995326529, 0.9301438937400074, 0.9599065305620464, 0.9793383347681712]
save model........


 14%|█████████▌                                                          | 70/500 [08:00<1:19:17, 11.06s/it]

[0.7274627967039724, 0.8237609150166031, 0.8875907022506456, 0.9319886852785635, 0.9602754888697577, 0.981675070717009]


 16%|███████████                                                           | 79/500 [08:34<28:51,  4.11s/it]

[0.7389005042430206, 0.8344607059402288, 0.8939859795843069, 0.9377690320993728, 0.9665477801008486, 0.9815520846144385]
save model........


 18%|████████████▍                                                         | 89/500 [09:38<28:38,  4.18s/it]

[0.7543967531668921, 0.8424548026073053, 0.9027179928668061, 0.9414586151764851, 0.966670766203419, 0.9827819456401427]
save model........


 20%|█████████████▊                                                        | 99/500 [10:43<28:25,  4.25s/it]

[0.7717377936293199, 0.8546304267617759, 0.9093592424056082, 0.951051531176977, 0.9710982658959537, 0.9859795843069733]
save model........


 22%|███████████████                                                      | 109/500 [11:46<27:43,  4.25s/it]

[0.7814536957323822, 0.8631164678391342, 0.9175993112778256, 0.9527733366129627, 0.9744188906653548, 0.986471528717255]
save model........


 24%|████████████████▍                                                    | 119/500 [12:49<26:12,  4.13s/it]

[0.7872340425531915, 0.865453203787972, 0.9211659082523674, 0.9544951420489485, 0.9751568072807772, 0.9877013897429591]
save model........


 26%|█████████████████▍                                                 | 130/500 [14:17<1:08:11, 11.06s/it]

[0.7828065428606568, 0.866068134300824, 0.9196900750215226, 0.9514204894846883, 0.9742959045627844, 0.987086459230107]


 28%|███████████████████▏                                                 | 139/500 [14:52<25:20,  4.21s/it]

[0.7951051531176977, 0.8724634116344853, 0.9235026442012053, 0.9551100725618005, 0.9736809740499324, 0.9872094453326774]
save model........


 30%|████████████████████▌                                                | 149/500 [15:53<23:48,  4.07s/it]

[0.7999016111179437, 0.8761529947115976, 0.925224449637191, 0.9578157668183496, 0.9766326405116222, 0.9880703480506703]
save model........


 32%|█████████████████████▍                                             | 160/500 [17:21<1:02:43, 11.07s/it]

[0.7964580002459722, 0.8754150780961751, 0.9241175747140573, 0.956462919690075, 0.9761406961013406, 0.9881933341532407]


 34%|███████████████████████▎                                             | 169/500 [17:56<22:26,  4.07s/it]

[0.8027302914770631, 0.8784897306604353, 0.927315213380888, 0.9589226417414832, 0.9773705571270447, 0.9872094453326774]
save model........


 36%|████████████████████████▋                                            | 179/500 [18:58<21:50,  4.08s/it]

[0.8103554298364285, 0.8830402164555405, 0.9317427130734227, 0.9596605583569057, 0.977493543229615, 0.9875784036403886]
save model........


 38%|██████████████████████████                                           | 189/500 [19:59<20:51,  4.02s/it]

[0.8131841101955479, 0.888574591071209, 0.9344484073299717, 0.9605214610748986, 0.9778625015373262, 0.9877013897429591]
save model........


 40%|███████████████████████████▍                                         | 199/500 [21:00<20:04,  4.00s/it]

[0.8156438322469561, 0.8927561185586029, 0.9356782683556758, 0.963350141434018, 0.9783544459476079, 0.9879473619480998]
save model........


 42%|████████████████████████████▊                                        | 209/500 [22:05<20:21,  4.20s/it]

[0.81982535973435, 0.8922641741483213, 0.9355552822531054, 0.9621202804083139, 0.9786004181527488, 0.9886852785635223]
save model........


 44%|██████████████████████████████▏                                      | 219/500 [23:09<18:24,  3.93s/it]

[0.8199483458369204, 0.8900504243020538, 0.9346943795351126, 0.9632271553314475, 0.9794613208707416, 0.9888082646660927]
save model........


 46%|███████████████████████████████▋                                     | 230/500 [24:37<45:06, 10.02s/it]

[0.8182265404009347, 0.89189521584061, 0.935432296150535, 0.9602754888697577, 0.9771245849219038, 0.988562292460952]


 48%|████████████████████████████████▉                                    | 239/500 [25:10<16:24,  3.77s/it]

[0.8210552207600541, 0.8934940351740254, 0.9371541015865207, 0.9632271553314475, 0.9781084737424671, 0.988562292460952]
save model........


 50%|██████████████████████████████████▌                                  | 250/500 [26:33<43:27, 10.43s/it]

[0.8204402902472021, 0.8912802853277579, 0.9375230598942319, 0.9631041692288771, 0.97896937646046, 0.9877013897429591]


 52%|███████████████████████████████████▉                                 | 260/500 [27:31<41:35, 10.40s/it]

[0.8206862624523429, 0.8938629934817366, 0.9367851432788095, 0.9642110441520109, 0.9797072930758824, 0.9890542368712335]


 54%|█████████████████████████████████████                                | 269/500 [28:05<15:44,  4.09s/it]

[0.8292952896322715, 0.8964457016357151, 0.9376460459968023, 0.9626122248185955, 0.9782314598450376, 0.9886852785635223]
save model........


 56%|██████████████████████████████████████▌                              | 279/500 [29:16<16:23,  4.45s/it]

[0.8312630672733982, 0.8953388267125815, 0.9377690320993728, 0.961874308203173, 0.9784774320501783, 0.9893002090763744]
save model........


 58%|████████████████████████████████████████                             | 290/500 [30:52<41:33, 11.87s/it]

[0.8268355675808634, 0.8954618128151519, 0.9391218792276472, 0.9629811831263068, 0.9809371541015866, 0.9894231951789447]


 60%|█████████████████████████████████████████▍                           | 300/500 [32:01<39:35, 11.88s/it]

[0.8299102201451236, 0.8965686877382856, 0.9376460459968023, 0.9624892387160251, 0.9808141679990161, 0.9895461812815152]


 62%|██████████████████████████████████████████▊                          | 310/500 [33:05<36:17, 11.46s/it]

[0.8269585536834337, 0.8943549378920183, 0.9358012544582462, 0.9637190997417292, 0.9804452096913049, 0.9890542368712335]


 64%|████████████████████████████████████████████▏                        | 320/500 [34:11<34:19, 11.44s/it]

[0.8264666092731522, 0.8937400073791661, 0.9389988931250769, 0.9634731275365883, 0.9805681957938753, 0.9895461812815152]


 66%|█████████████████████████████████████████████▌                       | 330/500 [35:16<33:19, 11.76s/it]

[0.829418275734842, 0.8995203541999754, 0.9398597958430698, 0.9649489607674333, 0.9801992374861641, 0.9883163202558111]


 68%|██████████████████████████████████████████████▊                      | 339/500 [35:53<12:19,  4.59s/it]

[0.8313860533759685, 0.9012421596359611, 0.9446562538433158, 0.9686385438445456, 0.9809371541015866, 0.9893002090763744]
save model........


 70%|████████████████████████████████████████████████▏                    | 349/500 [36:57<10:46,  4.28s/it]

[0.8331078588119543, 0.9033329233796581, 0.9451481982535973, 0.9677776411265527, 0.9819210429221498, 0.9899151395892264]
save model........


 72%|█████████████████████████████████████████████████▌                   | 359/500 [38:03<10:06,  4.30s/it]

[0.8363054974787849, 0.9017341040462428, 0.9419505595867667, 0.9647029885622924, 0.9810601402041569, 0.9895461812815152]
save model........


 74%|██████████████████████████████████████████████████▉                  | 369/500 [39:12<09:48,  4.49s/it]

[0.8424548026073053, 0.9092362563030377, 0.9424425039970483, 0.9683925716394047, 0.9825359734350019, 0.9900381256917968]
save model........


 76%|████████████████████████████████████████████████████▎                | 379/500 [40:16<08:15,  4.09s/it]

[0.8445455663510023, 0.9080063952773336, 0.9442872955356045, 0.9677776411265527, 0.9820440290247202, 0.9900381256917968]
save model........


 78%|█████████████████████████████████████████████████████▊               | 390/500 [41:42<19:05, 10.42s/it]

[0.8434386914278686, 0.9043168122002214, 0.9399827819456401, 0.9661788217931374, 0.9824129873324314, 0.9894231951789447]


 80%|███████████████████████████████████████████████████████▏             | 400/500 [42:39<17:07, 10.28s/it]

[0.844422580248432, 0.9046857705079326, 0.9439183372278932, 0.9691304882548272, 0.9820440290247202, 0.9916369450252122]


 82%|████████████████████████████████████████████████████████▍            | 409/500 [43:13<05:57,  3.93s/it]

[0.8460213995818473, 0.9059156315336367, 0.9446562538433158, 0.968146599434264, 0.981675070717009, 0.9912679867175009]
save model........


 84%|█████████████████████████████████████████████████████████▊           | 419/500 [44:10<05:06,  3.79s/it]

[0.8497109826589595, 0.910220145123601, 0.9439183372278932, 0.9685155577419752, 0.9821670151272907, 0.9901611117943673]
save model........


 86%|███████████████████████████████████████████████████████████▎         | 430/500 [45:34<12:00, 10.30s/it]

[0.8488500799409666, 0.9099741729184602, 0.9466240314844423, 0.9693764604599681, 0.9820440290247202, 0.9902840978969376]


 88%|████████████████████████████████████████████████████████████▌        | 439/500 [46:05<03:54,  3.85s/it]

[0.8518017464026565, 0.9091132702004674, 0.9449022260484565, 0.9674086828188415, 0.9835198622555651, 0.9907760423072193]
save model........


 88%|████████████████████████████████████████████████████████████▊        | 441/500 [46:34<08:00,  8.15s/it]wandb: Network error (ConnectTimeout), entering retry loop.
 90%|██████████████████████████████████████████████████████████████       | 450/500 [47:29<08:43, 10.47s/it]

[0.8497109826589595, 0.908990284097897, 0.9458861148690199, 0.9699913909728201, 0.9820440290247202, 0.9896691673840856]


 92%|███████████████████████████████████████████████████████████████▍     | 460/500 [48:27<06:51, 10.29s/it]

[0.8451604968638544, 0.9057926454310663, 0.9423195178944779, 0.9675316689214118, 0.9815520846144385, 0.9897921534866561]


 94%|████████████████████████████████████████████████████████████████▊    | 470/500 [49:24<05:03, 10.12s/it]

[0.8484811216332554, 0.9068995203542, 0.9446562538433158, 0.9676546550239823, 0.9804452096913049, 0.988562292460952]


 96%|██████████████████████████████████████████████████████████████████▏  | 480/500 [50:21<03:22, 10.12s/it]

[0.8513098019923748, 0.9110810478415939, 0.947730906407576, 0.965440905177715, 0.981675070717009, 0.9896691673840856]


 98%|███████████████████████████████████████████████████████████████████▍ | 489/500 [50:53<00:42,  3.86s/it]

[0.8522936908129382, 0.9097282007133194, 0.9457631287664494, 0.9685155577419752, 0.9813061124092978, 0.9905300701020785]
save model........


 99%|████████████████████████████████████████████████████████████████████▏| 494/500 [51:32<00:30,  5.05s/it]wandb: Network error (ConnectTimeout), entering retry loop.
100%|█████████████████████████████████████████████████████████████████████| 500/500 [52:17<00:00,  6.27s/it]

[0.8515557741975157, 0.912310908867298, 0.9453941704587382, 0.9682695855368343, 0.9817980568195794, 0.9901611117943673]
best: 490 0.8522936908129382



