In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
from apex import amp

import os
import random
import timm
import wandb
from tqdm import trange
import multiprocessing
from functools import partial
import numpy as np
import PIL
from torchvision.transforms import InterpolationMode

from sampler import UniqueClassSempler
from proxy_anchor.utils import calc_recall_at_k
from proxy_anchor.dataset import CUBirds, SOP, Cars
from proxy_anchor.dataset.Inshop import Inshop_Dataset
from hyptorch.pmath import dist_matrix
import hyptorch.nn as hypnn

import torch.distributed as dist
from datetime import timedelta

# Hyperparameters (超参数)

In [2]:
# --- data ---
path = '/data/xuyunhao/datasets'   # dataset root handed to the dataset classes
ds = 'Inshop'                      # key into ds_list: CUB, SOP, Cars or Inshop
num_samples = 2                    # consecutive same-class samples per group in a batch
bs = 200                           # training batch size
resize = 224                       # eval-time resize size
crop = 224                         # crop size (train and eval)

# --- optimisation ---
lr = 3e-5
t = 0.2                            # temperature of the contrastive loss
epoch = 500
freeze = 0                         # passed to freezer(): number of leading blocks to freeze

# --- model ---
model = 'vit_small_patch16_224'
# model = 'dino_vits16'
# model = 'deit_small_distilled_patch16_224'
emb = 128                          # embedding dimension
hyp_c = 0                          # curvature; 0 selects Euclidean distances in evaluation
clip_r = 2.3                       # clip radius handed to hypnn.ToPoincare

# --- bookkeeping ---
num = 2                            # run tag used in the checkpoint file name
local_rank = 0
save_emb = False
emb_name = 'emb'
eval_ep = 'r({},{},10)'.format(epoch - 100, epoch + 10)
save_path = "/data/xuyunhao/Mixed curvature/result/{}_{}_best_e_{}_{}_checkout.pth".format(model, ds, emb, num)

# Loss function (损失函数)

In [3]:
def contrastive_loss(e0, e1, tau):
    """InfoNCE-style contrastive loss on negative Euclidean distances.

    Row ``i`` of ``e0`` and row ``i`` of ``e1`` form the positive pair. Every
    other row of ``e1`` and every other row of ``e0`` act as negatives.

    Args:
        e0: embeddings of the first view, one row per sample.
        e1: embeddings of the second view, same shape as ``e0``.
        tau: temperature dividing the logits.

    Returns:
        ``(loss, stats)``. ``loss`` is a scalar tensor. ``stats`` holds the
        min/mean/max of the cross-view logits and the fraction of rows whose
        highest cross-view logit is their own positive.
    """
    bsize = e0.shape[0]
    # Helper tensors follow the embeddings' device instead of hard-coded CUDA.
    target = torch.arange(bsize, device=e0.device)
    # A large penalty removes each anchor's trivial self-match from the e0-vs-e0 block.
    eye_mask = torch.eye(bsize, device=e0.device) * 1e9

    logits00 = -torch.cdist(e0, e0, p=2) / tau - eye_mask
    logits01 = -torch.cdist(e0, e1, p=2) / tau

    logits = torch.cat([logits01, logits00], dim=1)
    # Shift each row by its (gradient-free) max for numerical stability.
    logits -= logits.max(1, keepdim=True)[0].detach()
    loss = F.cross_entropy(logits, target)
    stats = {
        "logits/min": logits01.min().item(),
        "logits/mean": logits01.mean().item(),
        "logits/max": logits01.max().item(),
        "logits/acc": (logits01.argmax(-1) == target).float().mean().item(),
    }
    return loss, stats

# Model (模型)

In [4]:
def init_model(model = model, hyp_c = 0.1, emb = 128, clip_r = 2.3, freeze = 0):
    """Build the embedding network: a pretrained backbone plus a ``Linear -> BatchNorm1d`` head.

    Args:
        model: backbone name. Names starting with ``"dino"`` are loaded through
            ``torch.hub`` from ``facebookresearch/dino``. Any other name is loaded
            with ``timm.create_model(..., pretrained=True)``.
        hyp_c: curvature. NOTE(review): this is not used when building the network.
            The head ends in ``BatchNorm1d`` and no Poincaré / L2 projection is
            attached, so outputs are unnormalised whatever the value. It is kept
            for interface compatibility.
        emb: output embedding dimension.
        clip_r: clip radius of the (unattached) Poincaré projection. Unused here,
            for the same reason as ``hyp_c``.
        freeze: forwarded to ``freezer``. ``None`` skips freezing entirely.

    Returns:
        ``HeadSwitch(body, head)`` moved to CUDA and put in train mode.
    """
    if model.startswith("dino"):
        body = torch.hub.load("facebookresearch/dino:main", model)
    else:
        body = timm.create_model(model, pretrained=True)

    # Use the backbone's own feature width rather than assuming ViT-S (384).
    # Fall back to 2048 for resnet50 and 384 otherwise if the attribute is absent.
    bdim = getattr(body, "num_features", None)
    if not isinstance(bdim, int):
        bdim = 2048 if model == "resnet50" else 384
    head = nn.Sequential(nn.Linear(bdim, emb), nn.BatchNorm1d(emb))
    nn.init.constant_(head[0].bias.data, 0)
    nn.init.orthogonal_(head[0].weight.data)

    # Replace the classifier so the backbone returns features, then freeze early layers.
    rm_head(body)
    if freeze is not None:
        freezer(body, freeze)
    model = HeadSwitch(body, head)
    model.cuda().train()
    return model
    
    
class HeadSwitch(nn.Module):
    """Wrap a backbone and choose, per call, between its head and plain L2 normalisation.

    Submodules (registered in this order): ``body`` (feature extractor),
    ``head`` (projection head) and ``norm`` (a ``NormLayer``).
    """

    def __init__(self, body, head):
        super().__init__()
        self.body = body
        self.head = head
        self.norm = NormLayer()

    def forward(self, x, skip_head=False):
        """Return ``head(features)``, or ``norm(features)`` when ``skip_head`` is truthy.

        If the backbone returns a tuple, only its first element is used as the features.
        """
        feats = self.body(x)
        if type(feats) is tuple:
            feats = feats[0]
        return self.norm(feats) if skip_head else self.head(feats)


class NormLayer(nn.Module):
    """Parameter-free layer that L2-normalises its input along dimension 1."""

    def forward(self, x):
        return F.normalize(x, dim=1, p=2)


def freezer(model, num_block):
    """Disable gradients for ``patch_embed``, ``pos_drop`` and the first ``num_block`` blocks.

    Args:
        model: backbone exposing ``patch_embed``, ``pos_drop`` and an indexable ``blocks``.
        num_block: how many leading entries of ``blocks`` to freeze. With 0, only
            ``patch_embed`` and ``pos_drop`` are frozen.
    """
    model.patch_embed.requires_grad_(False)
    model.pos_drop.requires_grad_(False)
    for block_idx in range(num_block):
        model.blocks[block_idx].requires_grad_(False)


def rm_head(m):
    """Swap any direct child named ``head``, ``fc`` or ``head_dist`` of ``m`` for ``nn.Identity``.

    The module is modified in place, so its forward no longer applies those
    classifier layers. Children with other names are untouched.
    """
    classifier_names = ("head", "fc", "head_dist")
    present = {child_name for child_name, _ in m.named_children()}
    for child_name in present:
        if child_name in classifier_names:
            m.add_module(child_name, nn.Identity())

In [5]:
class MultiSample:
    """Callable that applies ``transform`` to the same input ``n`` times.

    Calling it returns a tuple of ``n`` results, one per ``transform`` call
    (for example, several random augmentations of one image).
    """

    def __init__(self, transform, n=2):
        self.transform = transform
        self.num = n

    def __call__(self, x):
        views = []
        for _ in range(self.num):
            views.append(self.transform(x))
        return tuple(views)


def evaluate(get_emb_f, ds_name, hyp_c):
    """Compute retrieval recall for ``ds_name`` from embeddings produced by ``get_emb_f``.

    - CUB / Cars: ``get_recall`` on the "eval" split (returns Recall@1 only).
    - SOP: ``get_recall_sop`` on the "eval" split (returns the recall list).
    - Any other name is treated as In-shop: "query" vs "gallery" splits through
      ``get_recall_inshop``.
    """
    if ds_name in ("CUB", "Cars"):
        eval_split = get_emb_f(ds_type="eval")
        return get_recall(*eval_split, ds_name, hyp_c)
    if ds_name == "SOP":
        eval_split = get_emb_f(ds_type="eval")
        return get_recall_sop(*eval_split, ds_name, hyp_c)
    query_split = get_emb_f(ds_type="query")
    gallery_split = get_emb_f(ds_type="gallery")
    return get_recall_inshop(*query_split, *gallery_split, hyp_c)

def get_recall(x, y, ds_name, hyp_c):
    """Recall@K on a single eval set in which every sample queries all the others.

    Args:
        x: embeddings, one row per sample.
        y: labels aligned with ``x``. They are indexed by neighbour indices, so
            they must live on the same device as ``x``.
        ds_name: "CUB" or "Cars" (K = 1..32) or "SOP" (K = 1..1000).
        hyp_c: curvature. Above 0, neighbours are ranked by hyptorch
            ``dist_matrix``; otherwise by Euclidean distance.

    Returns:
        Recall@1. The full list of recalls is printed.

    Raises:
        ValueError: if ``ds_name`` has no K list.
    """
    if ds_name in ("CUB", "Cars"):
        k_list = [1, 2, 4, 8, 16, 32]
    elif ds_name == "SOP":
        k_list = [1, 10, 100, 1000]
    else:
        raise ValueError("get_recall: unsupported dataset {!r}".format(ds_name))

    # The previously called `dist_matrix_d` does not exist; the imported helper is `dist_matrix`.
    if hyp_c > 0:
        def row_dist(q):
            return dist_matrix(q, x, hyp_c)
    else:
        def row_dist(q):
            return torch.cdist(q, x, p=2)

    # Fill row by row to avoid materialising large pairwise intermediates at once.
    dist_m = torch.empty(len(x), len(x), device=x.device)
    for i in range(len(x)):
        dist_m[i : i + 1] = -row_dist(x[i : i + 1])

    # Column 0 is dropped on the assumption that it is the query itself (distance 0).
    y_cur = y[dist_m.topk(1 + max(k_list), largest=True)[1][:, 1:]]
    y = y.cpu()
    y_cur = y_cur.float().cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall[0]


def get_recall_sop(x, y, ds_name, hyp_c, chunk_size=1000):
    """Recall@{1, 10, 100, 1000} for SOP, ranking neighbours one chunk of queries at a time.

    Only ``chunk_size`` rows of the distance matrix exist at once, so the buffer
    is ``chunk_size x len(x)`` rather than ``len(x) x len(x)``.

    Args:
        x: embeddings, one row per sample.
        y: labels aligned with ``x``, on the same device as ``x``.
        ds_name: unused; kept so the signature parallels ``get_recall``.
        hyp_c: curvature. Above 0, neighbours are ranked by hyptorch
            ``dist_matrix``; otherwise by Euclidean distance.
        chunk_size: number of query rows per chunk (default 1000).

    Returns:
        List of recalls for K = 1, 10, 100, 1000 (also printed).
    """
    k_list = [1, 10, 100, 1000]
    n_neighbours = 1 + max(k_list)

    # The previously called `dist_matrix_d` does not exist; the imported helper is `dist_matrix`.
    if hyp_c > 0:
        def row_dist(q):
            return dist_matrix(q, x, hyp_c)
    else:
        def row_dist(q):
            return torch.cdist(q, x, p=2)

    neighbour_labels = []
    for start in range(0, len(x), chunk_size):
        x_s = x[start : start + chunk_size]
        dist = torch.empty(len(x_s), len(x), device=x.device)
        for j in range(len(x_s)):
            dist[j : j + 1] = -row_dist(x_s[j : j + 1])
        # Column 0 is dropped on the assumption that it is the query itself (distance 0).
        neighbour_labels.append(y[dist.topk(n_neighbours, largest=True)[1][:, 1:]])

    y_cur = torch.cat(neighbour_labels).float().cpu()
    y = y.cpu()
    recall = [calc_recall_at_k(y, y_cur, k) for k in k_list]
    print(recall)
    return recall


def get_recall_inshop(xq, yq, xg, yg, hyp_c, k_list=(1, 10, 20, 30, 40, 50)):
    """Recall@K for In-shop: each query is ranked against a separate gallery.

    A query is a hit at K when fewer than K gallery items with a different label
    score strictly higher than its best-scoring same-label gallery item. A query
    with no same-label item in the gallery counts as a miss.

    Args:
        xq, yq: query embeddings and labels.
        xg, yg: gallery embeddings and labels. Labels must be on the same device
            as the embeddings.
        hyp_c: curvature. Above 0, items are scored by negative hyptorch
            ``dist_matrix``; otherwise by negative Euclidean distance.
        k_list: the K values to report (defaults to 1, 10, 20, 30, 40, 50).

    Returns:
        List with one recall per K (also printed).
    """
    # The previously called `dist_matrix_d` does not exist; the imported helper is `dist_matrix`.
    if hyp_c > 0:
        def row_dist(q):
            return dist_matrix(q, xg, hyp_c)
    else:
        def row_dist(q):
            return torch.cdist(q, xg, p=2)

    dist_m = torch.empty(len(xq), len(xg), device=xq.device)
    for i in range(len(xq)):
        dist_m[i : i + 1] = -row_dist(xq[i : i + 1])

    # Count wrong-label gallery items that outscore each query's best match.
    # The count does not depend on K, so it is computed once rather than once per K.
    n_better_neg = []
    for i in range(len(dist_m)):
        pos_sim = dist_m[i][yg == yq[i]]
        if pos_sim.numel() == 0:
            n_better_neg.append(None)  # no positive in the gallery: never a hit
            continue
        neg_sim = dist_m[i][yg != yq[i]]
        thresh = torch.max(pos_sim).item()
        n_better_neg.append(int(torch.sum(neg_sim > thresh)))

    m = len(n_better_neg)
    recall = []
    for k in k_list:
        hits = sum(1 for c in n_better_neg if c is not None and c < k)
        recall.append(hits / m if m else 0.0)
    print(recall)
    return recall


def get_emb(
    model,
    ds,
    path,
    mean_std,
    resize=224,
    crop=224,
    ds_type="eval",
    world_size=1,
    skip_head=False,
):
    """Embed one split of a dataset using deterministic evaluation preprocessing.

    Args:
        model: network called as ``model(x, skip_head=...)``.
        ds: dataset class, constructed as ``ds(path, ds_type, transform)``.
        path: dataset root.
        mean_std: ``(mean, std)`` pair passed to ``T.Normalize``.
        resize: size for the bicubic ``T.Resize``.
        crop: size for ``T.CenterCrop``.
        ds_type: split name forwarded to ``ds`` (this notebook uses "eval",
            "query", "gallery" and "train").
        world_size: if > 1, the split is sharded with ``DistributedSampler`` and
            the per-rank results are all-gathered.
        skip_head: forwarded to the model.

    Returns:
        ``(x, y)``: embeddings and labels, with labels moved to the embeddings'
        device. The model is left in train mode on return.
    """
    eval_tr = T.Compose(
        [
            # Use the InterpolationMode enum rather than the PIL integer constant;
            # torchvision warns about int arguments
            # ("Argument interpolation should be of type InterpolationMode").
            T.Resize(resize, interpolation=InterpolationMode.BICUBIC),
            T.CenterCrop(crop),
            T.ToTensor(),
            T.Normalize(*mean_std),
        ]
    )
    ds_eval = ds(path, ds_type, eval_tr)
    if world_size == 1:
        sampler = None
    else:
        sampler = torch.utils.data.distributed.DistributedSampler(ds_eval)
    dl_eval = DataLoader(
        dataset=ds_eval,
        batch_size=100,
        shuffle=False,
        num_workers=multiprocessing.cpu_count() // world_size,
        pin_memory=True,
        drop_last=False,
        sampler=sampler,
    )
    model.eval()
    x, y = eval_dataset(model, dl_eval, skip_head)
    # Keep labels next to the embeddings so later indexing/gathering needs no device hop.
    y = y.to(x.device)
    if world_size > 1:
        all_x = [torch.zeros_like(x) for _ in range(world_size)]
        all_y = [torch.zeros_like(y) for _ in range(world_size)]
        torch.distributed.all_gather(all_x, x)
        torch.distributed.all_gather(all_y, y)
        x, y = torch.cat(all_x), torch.cat(all_y)
    model.train()
    return x, y


def eval_dataset(model, dl, skip_head):
    """Run ``model`` over every batch of ``dl`` without gradients.

    Args:
        model: callable taking ``(x, skip_head=...)``. If it exposes at least one
            parameter, inputs are moved to that parameter's device. Otherwise
            they go to the current CUDA device.
        dl: iterable of ``(x, y)`` batches.
        skip_head: forwarded to ``model``.

    Returns:
        ``(embeddings, labels)``, each concatenated over all batches. Labels are
        left on the device the loader produced them on.
    """
    # Follow the model's device instead of hard-coding .cuda(), so CPU models work too.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda")

    all_x, all_y = [], []
    with torch.no_grad():
        for x, y in dl:
            x = x.to(device, non_blocking=True)
            all_x.append(model(x, skip_head=skip_head))
            all_y.append(y)
    return torch.cat(all_x), torch.cat(all_y)


# Main: training and evaluation (主函数)

In [None]:
# Pin the visible GPUs. CUDA reads this only when it is first initialised, so it
# takes effect only if nothing in this kernel has touched the GPU yet.
os.environ["CUDA_VISIBLE_DEVICES"] =  "3,4"


if local_rank == 0:
    wandb.init(project="hyp_metric")

world_size = int(os.environ.get("WORLD_SIZE", 1))

# Normalisation statistics: (0.5, 0.5, 0.5) for ViT backbones, ImageNet mean/std otherwise.
if model.startswith("vit"):
    mean_std = (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)
else:
    mean_std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
train_tr = T.Compose(
    [
        T.RandomResizedCrop(
            crop, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC
        ),
        T.RandomHorizontalFlip(),
        T.ToTensor(),
        T.Normalize(*mean_std),
    ]
)

ds_list = {"CUB": CUBirds, "SOP": SOP, "Cars": Cars, "Inshop": Inshop_Dataset}
ds_class = ds_list[ds]
ds_train = ds_class(path, "train", train_tr)
assert len(ds_train.ys) * num_samples >= bs * world_size
sampler = UniqueClassSempler(
    ds_train.ys, num_samples, local_rank, world_size
)
dl_train = DataLoader(
    dataset=ds_train,
    sampler=sampler,
    batch_size=bs,
    num_workers=8,
    pin_memory=True,
    drop_last=True,
)

model = init_model(model = model, hyp_c = hyp_c, emb = emb, clip_r = clip_r, freeze = freeze)
optimizer = optim.AdamW(model.parameters(), lr=lr)
model = nn.DataParallel(model)

loss_f = partial(contrastive_loss, tau = t)
get_emb_f = partial(
    get_emb,
    model=model,
    ds=ds_class,
    path=path,
    mean_std=mean_std,
    world_size=world_size,
    resize=resize,
    crop=crop,
)
# eval_ep is configured as a string such as 'r(400,510,10)', meaning range(400, 510, 10).
# Parse it without eval(), and only while it is still a string, so re-running this cell works.
if isinstance(eval_ep, str):
    eval_ep = list(range(*(int(v) for v in eval_ep.strip()[2:-1].split(","))))

cudnn.benchmark = True
all_rh = []
best_rh = []
best_ep = 0
lower_cnt = 0


def _primary_recall(recall):
    """Model-selection metric: Recall@1.

    evaluate() returns a scalar for CUB/Cars and a list for SOP/In-shop.
    """
    return recall[0] if isinstance(recall, list) else recall


# torch.save does not create missing directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)

for ep in trange(epoch):
    sampler.set_epoch(ep)
    stats_ep = []
    for x, y in dl_train:
        # Sanity checks on the sampler: each group of num_samples consecutive
        # items shares one label, and no label repeats within the batch.
        y = y.view(len(y) // num_samples, num_samples)
        assert (y[:, 0] == y[:, -1]).all()
        batch_classes = y[:, 0].tolist()
        assert len(set(batch_classes)) == len(batch_classes)

        x = x.cuda(non_blocking=True)
        z = model(x).view(len(x) // num_samples, num_samples, emb)
        if world_size > 1:
            with torch.no_grad():
                all_z = [torch.zeros_like(z) for _ in range(world_size)]
                torch.distributed.all_gather(all_z, z)
            # all_gather outputs carry no graph; put back the local shard that does.
            all_z[local_rank] = z
            z = torch.cat(all_z)
        loss = 0
        for i in range(num_samples):
            for j in range(num_samples):
                if i != j:
                    l, loss_stats = loss_f(z[:, i], z[:, j])
                    loss += l
                    stats_ep.append({**loss_stats, "loss": l.item()})

        optimizer.zero_grad()
        loss.backward()
        # amp.initialize is never called, so clip the model's parameters directly;
        # they are the same tensors the optimizer holds.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 3)
        optimizer.step()

#     if (ep + 1) in eval_ep:
#         rh= evaluate(get_emb_f, ds, hyp_c)

    if (ep + 1) % 10 == 0 or ep == 0:
        rh = evaluate(get_emb_f, ds, hyp_c)
        all_rh.append(rh)
        # The first evaluation is also checkpointed, so a file exists even if later
        # evaluations never improve on it.
        if ep == 0 or _primary_recall(rh) >= _primary_recall(best_rh):
            lower_cnt = 0
            best_rh = rh
            best_ep = ep
            print("save model........")
            torch.save(model.state_dict(), save_path)
        else:
            lower_cnt += 1

    # Early stopping after 10 consecutive evaluations without improvement.
    if lower_cnt >= 10:
        break

    #if local_rank == 0:
        #stats_ep = {k: np.mean([x[k] for x in stats_ep]) for k in stats_ep[0]}
        #if (epoch + 1) in eval_ep:
        #    stats_ep = {"recall": rh, **stats_ep}
        #wandb.log({**stats_ep, "ep": ep})

print("best:", best_ep+1, best_rh)
print("save_path:", save_path)

if save_emb:
    ds_type = "gallery" if ds == "Inshop" else "eval"
    x, y = get_emb_f(ds_type=ds_type)
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), os.path.join(path, emb_name + "_eval.pt"))

    x, y = get_emb_f(ds_type="train")
    x, y = x.float().cpu(), y.long().cpu()
    torch.save((x, y), os.path.join(path, emb_name + "_train.pt"))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mxuyunhao[0m. Use [1m`wandb login --relogin`[0m to force relogin


  "Argument interpolation should be of type InterpolationMode instead of int. "
  0%|                                                            | 1/500 [01:30<12:31:04, 90.31s/it]

[0.6818821212547475, 0.8969615979743987, 0.9291742861161908, 0.9438036292024194, 0.9542833028555352, 0.9608243072162048]


  2%|█▏                                                          | 10/500 [05:29<5:30:04, 40.42s/it]

[0.8465325643550429, 0.9676466450977634, 0.9791813194542129, 0.983330988887326, 0.9860036573357716, 0.9875509917006612]
save model........


  4%|██▎                                                         | 19/500 [08:34<2:52:38, 21.53s/it]

[0.8697425798283865, 0.9735546490364327, 0.9820649880433253, 0.9859333239555493, 0.9884653256435504, 0.9898719932479955]
save model........


  6%|███▍                                                        | 29/500 [13:07<2:53:13, 22.07s/it]

[0.8724152482768321, 0.9743283162188775, 0.983049655366437, 0.9871289914193276, 0.9891686594457729, 0.990856660571107]
save model........


  8%|████▋                                                       | 39/500 [17:37<2:46:47, 21.71s/it]

[0.8864115909410606, 0.9764383176255451, 0.9843859895906597, 0.9878323252215502, 0.9900829933886622, 0.9914896609931073]
save model........


 10%|██████                                                      | 50/500 [23:32<5:03:06, 40.42s/it]

[0.8833872555915038, 0.9781966521311014, 0.9853706569137712, 0.9887466591644395, 0.990293993529329, 0.9919819946546631]


 12%|███████                                                     | 59/500 [26:41<2:40:05, 21.78s/it]

[0.892882261921508, 0.9796033197355465, 0.9862849908566605, 0.9890279926853285, 0.9906456604304403, 0.9919819946546631]
save model........


 14%|████████▎                                                   | 69/500 [31:14<2:37:18, 21.90s/it]

[0.8949922633281755, 0.9801659867773245, 0.9868476578984386, 0.9893796595864397, 0.9912083274722183, 0.9922633281755521]
save model........


 16%|█████████▍                                                  | 79/500 [35:44<2:31:13, 21.55s/it]

[0.8959065972710648, 0.9793923195948797, 0.9864256576171051, 0.9893093262062175, 0.990575327050218, 0.9924039949359966]
save model........


 18%|██████████▊                                                 | 90/500 [41:38<4:33:44, 40.06s/it]

[0.8942185961457307, 0.9796736531157687, 0.9866366577577719, 0.9891686594457729, 0.9909973273315515, 0.9923336615557744]


 20%|███████████▉                                                | 99/500 [44:51<2:28:29, 22.22s/it]

[0.8989309326206217, 0.9788296525531017, 0.9868476578984386, 0.9898719932479955, 0.9911379940919961, 0.9925446616964412]
save model........


 22%|████████████▉                                              | 110/500 [50:52<4:31:28, 41.77s/it]

[0.8982979321986214, 0.9806583204388802, 0.9867773245182163, 0.990293993529329, 0.9916303277535519, 0.9920523280348854]


 24%|██████████████▏                                            | 120/500 [55:26<4:20:16, 41.10s/it]

[0.8943592629061753, 0.9806583204388802, 0.9867773245182163, 0.9891686594457729, 0.9906456604304403, 0.9919819946546631]


 26%|███████████████▏                                           | 129/500 [58:38<2:16:34, 22.09s/it]

[0.899563933042622, 0.9807989871993248, 0.9871289914193276, 0.9895203263468842, 0.9911379940919961, 0.9923336615557744]
save model........


 28%|███████████████▊                                         | 139/500 [1:03:10<2:14:24, 22.34s/it]

[0.903432268954846, 0.9811506541004361, 0.9875509917006612, 0.9906456604304403, 0.9920523280348854, 0.9929666619777746]
save model........


 30%|█████████████████                                        | 150/500 [1:09:04<3:56:29, 40.54s/it]

[0.9026586017724012, 0.9825573217048811, 0.9872696581797721, 0.9903643269095512, 0.9918413278942186, 0.9929666619777746]


 32%|██████████████████▏                                      | 160/500 [1:13:39<3:56:15, 41.69s/it]

[0.9008299338866226, 0.9807286538191026, 0.9877619918413278, 0.9906456604304403, 0.9919819946546631, 0.9924743283162188]


 34%|███████████████████▍                                     | 170/500 [1:18:09<3:40:08, 40.03s/it]

[0.9021662681108454, 0.9817133211422141, 0.9872696581797721, 0.9902236601491068, 0.9911379940919961, 0.9923336615557744]


 36%|████████████████████▍                                    | 179/500 [1:21:22<1:58:20, 22.12s/it]

[0.9061049374032916, 0.9816429877619919, 0.9876213250808834, 0.990575327050218, 0.9924743283162188, 0.9934589956393304]
save model........


 38%|█████████████████████▌                                   | 189/500 [1:25:53<1:53:44, 21.94s/it]

[0.9076522717681812, 0.9825573217048811, 0.9879729919819946, 0.9906456604304403, 0.9919819946546631, 0.9928963285975524]
save model........


 40%|██████████████████████▋                                  | 199/500 [1:30:24<1:51:46, 22.28s/it]

[0.9083556055704037, 0.9827683218455479, 0.9883949922633282, 0.9912786608524405, 0.9924743283162188, 0.9933886622591082]
save model........


 42%|███████████████████████▊                                 | 209/500 [1:35:01<1:49:11, 22.51s/it]wandb: Network error (ReadTimeout), entering retry loop.
 42%|███████████████████████▉                                 | 210/500 [1:36:24<3:17:11, 40.80s/it]

[0.9075116050077366, 0.9828386552257702, 0.9887466591644395, 0.9910676607117738, 0.9924743283162188, 0.9933183288788859]


 44%|████████████████████████▉                                | 219/500 [1:39:38<1:46:41, 22.78s/it]

[0.9091996061330707, 0.9829793219862146, 0.9888169925446617, 0.990293993529329, 0.9912786608524405, 0.9921929947953299]
save model........


 46%|██████████████████████████▏                              | 230/500 [1:45:34<3:05:04, 41.13s/it]

[0.9076522717681812, 0.983049655366437, 0.9881136587424392, 0.9904346602897736, 0.9916303277535519, 0.9927556618371078]


 48%|███████████████████████████▏                             | 239/500 [1:48:41<1:34:14, 21.66s/it]

[0.9101842734561824, 0.9834013222675482, 0.988043325362217, 0.9900829933886622, 0.9918413278942186, 0.9928259952173302]
save model........


 50%|████████████████████████████▌                            | 250/500 [1:54:44<2:57:15, 42.54s/it]

[0.908144605429737, 0.9824166549444366, 0.9871993247995499, 0.9898719932479955, 0.9914896609931073, 0.9921929947953299]


 52%|█████████████████████████████▌                           | 259/500 [1:57:51<1:30:15, 22.47s/it]

[0.9103249402166268, 0.9828386552257702, 0.9891686594457729, 0.9914896609931073, 0.9924743283162188, 0.9931776621184414]
save model........


 54%|██████████████████████████████▋                          | 269/500 [2:02:26<1:26:14, 22.40s/it]

[0.9127162751441834, 0.9817836545224363, 0.988043325362217, 0.9904346602897736, 0.9922633281755521, 0.9935996623997749]
save model........


 56%|███████████████████████████████▉                         | 280/500 [2:08:29<2:29:57, 40.90s/it]

[0.9087776058517373, 0.9820649880433253, 0.9874103249402166, 0.9898016598677732, 0.9917709945139963, 0.9931776621184414]


 58%|█████████████████████████████████                        | 290/500 [2:13:06<2:25:58, 41.71s/it]

[0.9112392741595161, 0.983612322408215, 0.9883246588831059, 0.9905049936699958, 0.9913489942326629, 0.9926149950766634]


 60%|██████████████████████████████████                       | 299/500 [2:16:20<1:16:12, 22.75s/it]

[0.9135602757068505, 0.9832606555071036, 0.9886763257842172, 0.9912786608524405, 0.9922633281755521, 0.9931073287382192]
save model........


 62%|███████████████████████████████████▎                     | 310/500 [2:22:21<2:10:50, 41.32s/it]

[0.9110986073990716, 0.9831903221268814, 0.9883949922633282, 0.9911379940919961, 0.9922633281755521, 0.9932479954986637]


 64%|████████████████████████████████████▍                    | 320/500 [2:26:56<2:01:42, 40.57s/it]

[0.9122239414826276, 0.9841749894499929, 0.9892389928259953, 0.9911379940919961, 0.9925446616964412, 0.9933886622591082]


 66%|█████████████████████████████████████▌                   | 329/500 [2:30:12<1:05:28, 22.98s/it]

[0.9138416092277395, 0.9834013222675482, 0.9892389928259953, 0.9913489942326629, 0.9925446616964412, 0.9931073287382192]
save model........


 68%|██████████████████████████████████████▊                  | 340/500 [2:36:16<1:49:33, 41.08s/it]

[0.9113799409199607, 0.9829793219862146, 0.9886059924039949, 0.990575327050218, 0.9921929947953299, 0.9928963285975524]


 70%|███████████████████████████████████████▉                 | 350/500 [2:40:51<1:41:35, 40.64s/it]

[0.910395273596849, 0.9828386552257702, 0.9888169925446617, 0.9915599943733295, 0.9928259952173302, 0.9937403291602195]


 72%|██████████████████████████████████████████▏                | 358/500 [2:43:40<52:53, 22.35s/it]wandb: Network error (ReadTimeout), entering retry loop.
 72%|█████████████████████████████████████████                | 360/500 [2:45:31<1:38:35, 42.25s/it]

[0.9135602757068505, 0.9838936559291039, 0.9890279926853285, 0.9912083274722183, 0.9921226614151076, 0.9931073287382192]


 74%|██████████████████████████████████████████▏              | 370/500 [2:50:05<1:28:39, 40.92s/it]

[0.9099029399352933, 0.9825573217048811, 0.9882543255028837, 0.9909269939513293, 0.9923336615557744, 0.9931776621184414]
