In [1]:
%load_ext autoreload
%autoreload 2
from boxes import *
from learner import *
import math
import matplotlib.pyplot as plt
import os
import wandb
import pickle

%matplotlib inline

In [2]:
def set_seed(seed):
    np.random.seed(seed)
    torch.manual_seed(seed) # cpu
    torch.cuda.manual_seed_all(seed)  # gpu
    
set_seed(42)

torch.set_printoptions(precision=16)

In [3]:
use_cuda = torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
use_cuda

False

In [273]:
PATH = 'data/ontologies/anatomy/'

# aligment training split
ats = 0.8

# Transitive closure
Transitive_Closure = False

if Transitive_Closure:
    tc = "tc_"
else:
    tc = ""

# Data in unary.tsv are probabilites separated by newlines. The probability on line n is P(n), where n is the id assigned to the nth element.
unary_prob = torch.from_numpy(np.loadtxt(f'{PATH}unary/unary.tsv')).float().to(device)
num_boxes = unary_prob.shape[0]

# We're going to use random negative sampling during training, so no need to include negatives in our training data itself
train = Probs.load_from_julia(PATH, f'tr_pos_{tc}{ats}.tsv', f'tr_neg_{ats}.tsv', ratio_neg = 0).to(device)

# The dev set will have a fixed set of negatives, however.
dev = Probs.load_from_julia(PATH, f'dev_align_pos_{ats}.tsv', f'dev_align_neg_{ats}.tsv', ratio_neg = 1).to(device)

# This set is used just for evaluation purposes after training
tr_align = Probs.load_from_julia(PATH, f'tr_align_pos_{ats}.tsv', f'tr_align_neg_{ats}.tsv', ratio_neg = 1).to(device)


In [274]:
with open('data/ontologies/anatomy/human.pickle', 'rb') as f:
    human_pickle = pickle.load(f)
    
with open('data/ontologies/anatomy/mouse.pickle', 'rb') as f:
    mouse_pickle = pickle.load(f)
    
with open('data/ontologies/anatomy/entities.pickle', 'rb') as f:
    entity_pickle = pickle.load(f)

In [275]:
mouse_eval = Probs.load_from_julia(PATH, 'human_dev_pos.tsv', 'human_dev_neg.tsv', ratio_neg = 1).to(device)
human_eval = Probs.load_from_julia(PATH, 'mouse_dev_pos.tsv', 'mouse_dev_neg.tsv', ratio_neg = 1).to(device)

In [276]:
import torch
from torch import Tensor # for type annotations
import torch.nn.functional as F
from typing import *


def intersection(A: Tensor, B: Tensor) -> Tensor:
    """
    :param A: Tensor(..., zZ, dim)
    :param B: Tensor(..., zZ, dim)
    :return: Tensor(..., zZ, dim), box embeddings for A intersect B
    """
    z = torch.max(A[...,0,:], B[...,0,:])
    Z = torch.min(A[...,1,:], B[...,1,:])
    return torch.stack((z, Z), dim=-2)

def neg_edge_adjustment(A: Tensor) -> Tensor:
    """
    :param A: Tensor(..., zZ, dim)
    
    Replace "negative" edges with their mean.
    
    (TODO: optimize this)
    """
    center_of_meet = torch.mean(A, dim=-2)
    neg_edges_mask = ((A[...,1,:] - A[...,0,:]) >= 0)
    neg_edges_mask_stack = torch.stack((neg_edges_mask, neg_edges_mask), dim=-2)
    center_of_meet_stack = torch.stack((center_of_meet, center_of_meet), dim=-2)
    return torch.where(neg_edges_mask_stack, A, center_of_meet_stack)

def join(A: Tensor, B: Tensor) -> Tensor:
    """
    :param A: Tensor(model, pair, zZ, dim)
    :param B: Tensor(model, pair, zZ, dim)
    :return: Tensor(model, pair, zZ, dim), box embeddings for the smallest box which contains A and B
    """
    z = torch.min(A[:,:,0], B[:,:,0])
    Z = torch.max(A[:,:,1], B[:,:,1])
    return torch.stack((z, Z), dim=2)


def clamp_volume(boxes: Tensor) -> Tensor:
    """
    :param boxes: Tensor(... zZ, dim)
    :return: Tensor(...) of volumes
    """
    return torch.prod((boxes[...,1,:] - boxes[...,0,:]).clamp_min(0), dim=-1)


def soft_volume(boxes: Tensor) -> Tensor:
    """
    :param sidelengths: Tensor(model, box, dim)
    :return: Tensor(model, box) of volumes
    """
    return torch.prod(F.softplus(boxes[:,:,1] - boxes[:,:,0]), dim=-1)


def log_clamp_volume(boxes: Tensor, eps:float = torch.finfo(torch.float32).tiny) -> Tensor:
    """
    :param boxes: Tensor(model, box, zZ, dim)
    :return: Tensor(model, box) of volumes
    """
    return torch.sum(torch.log((boxes[:,:,1] - boxes[:,:,0]).clamp_min(0) + eps), dim=-1)


def log_soft_volume(boxes: Tensor, eps:float = torch.finfo(torch.float32).tiny) -> Tensor:
    """
    :param sidelengths: Tensor(model, box, dim)
    :return: Tensor(model, box) of volumes
    """
    return torch.sum(torch.log(F.softplus(boxes[:,:,1] - boxes[:,:,0]) + eps), dim=-1)


def smallest_containing_box(boxes: Tensor) -> Tensor:
    """
    Returns the smallest box which contains all boxes in `boxes`.
    
    :param boxes: Box embedding of shape (model, box, zZ, dim)
    :return: Tensor of shape (model, 1, zZ, dim)
    """
    z = boxes[:,:,0]
    Z = boxes[:,:,1]
    min_z, _ = torch.min(z, dim=1, keepdim=True)
    max_Z, _ = torch.max(Z, dim=1, keepdim=True)
    return torch.stack((min_z, max_Z), dim=2)

def smallest_containing_box_outside_unit_cube(boxes: Tensor) -> Tensor:
    """
    Returns the smallest box which contains all boxes in `boxes` and the unit cube.

    :param boxes: Box embedding of shape (model, box, zZ, dim)
    :return: Tensor of shape (model, 1, zZ, dim)
    """
    z = boxes[:,:,0]
    Z = boxes[:,:,1]
    min_z, _ = torch.min(z, dim=1, keepdim=True)
    max_Z, _ = torch.max(Z, dim=1, keepdim=True)
    min_z = min_z.clamp_max(0)
    max_Z = max_Z.clamp_min(1)
    return torch.stack((min_z, max_Z), dim=2)


def detect_small_boxes(boxes: Tensor, vol_func: Callable = clamp_volume, min_vol: float = 1e-20) -> Tensor:
    """
    Returns the indices of boxes with volume smaller than eps.

    :param boxes: box parametrization as Tensor(model, box, z/Z, dim)
    :param vol_func: function taking in side lengths and returning volumes
    :param min_vol: minimum volume of boxes
    :return: masked tensor which selects boxes whose side lengths are less than min_vol
    """
    return vol_func(boxes) < min_vol


def replace_Z_by_cube(boxes: Tensor, indices: Tensor, cube_vol: float = 1e-20) -> Tensor:
    """
    Returns a new Z parameter for boxes for which those boxes[indices] are now replaced by cubes of size cube_vol

    :param boxes: box parametrization as Tensor(model, box, z/Z, dim)
    :param indices: box indices to replace by a cube
    :param cube_vol: volume of cube
    :return: tensor representing the Z parameter
    """
    return boxes[:, :, 0][indices] + cube_vol ** (1 / boxes.shape[-1])



def replace_Z_by_cube_(boxes: Tensor, indices: Tensor, cube_vol: float = 1e-20) -> Tensor:
    """
    Replaces the boxes indexed by `indices` by a cube of volume `min_vol` with the same z coordinate

    :param boxes: box parametrization as Tensor(model, box, z/Z, dim)
    :param indices: box indices to replace by a cube
    :param cube_vol: volume of cube
    :return: tensor representing the box parametrization with those boxes
    """
    boxes[:, :, 1][indices] = replace_Z_by_cube(boxes, indices, cube_vol)


def disjoint_boxes_mask(A: Tensor, B: Tensor) -> Tensor:
    """
    Returns a mask for when A and B are disjoint.
    Note: This is symmetric with respect to the arguments.
    """
    return ((B[:,:,1] <= A[:,:,0]) | (B[:,:,0] >= A[:,:,1])).any(dim=-1)


def overlapping_boxes_mask(A: Tensor, B: Tensor) -> Tensor:
    return disjoint_boxes_mask(A, B) ^ 1


def containing_boxes_mask(A: Tensor, B: Tensor) -> Tensor:
    """
    Returns a mask for when B contains A.
    Note: This is *not* symmetric with respect to it's arguments!
    """
    return ((B[:,:,1] >= A[:,:,1]) & (B[:,:,0] <= A[:,:,0])).all(dim=-1)


def needing_pull_mask(A: Tensor, B: Tensor, target_prob_B_given_A: Tensor) -> Tensor:
    return (target_prob_B_given_A != 0) & disjoint_boxes_mask(A, B)


def needing_push_mask(A: Tensor, B: Tensor, target_prob_B_given_A: Tensor) -> Tensor:
    return (target_prob_B_given_A != 1) & containing_boxes_mask(A, B)


In [277]:
import torch
from torch import Tensor
import scipy.stats as spstats # For Spearman r
from sklearn.metrics import roc_curve, precision_recall_curve  # for roc_curve


def metric_hard_accuracy(model, data_in, data_out):
    hard_pred = model(data_in, is_align=torch.tensor(0))["P(A|B)"] > 0.5
    return (data_out == hard_pred.float()).float().mean()


def metric_hard_f1(model, data_in, data_out):
    hard_pred = model(data_in, is_align=torch.tensor(0))["P(A|B)"] > 0.5
    true_pos = data_out[hard_pred==1].sum()
    total_pred_pos = (hard_pred==1).sum().float()
    total_actual_pos = data_out.sum().float()
    precision = true_pos / total_pred_pos
    recall = true_pos / total_actual_pos
    return 2 * (precision*recall) / (precision + recall)

def metric_hard_accuracy_align(model, data_in, data_out, threshold:float):
    A_given_B = data_in[::2]
    B_given_A = data_in[1::2,:]
    data_out = data_out[::2]

    align_probs = torch.stack((model(A_given_B, is_align=torch.tensor(1))["P(A|B)"], model(B_given_A, is_align=torch.tensor(1))["P(A|B)"]), dim=1)
    p = torch.min(align_probs, dim=1).values
    hard_pred = p > threshold

    return (data_out == hard_pred).float().mean()

def metric_hard_f1_align(model, data_in, data_out, threshold:float):
    A_given_B = data_in[::2]
    B_given_A = data_in[1::2,:]
    data_out = data_out[::2]

    align_probs = torch.stack((model(A_given_B, is_align=torch.tensor(1))["P(A|B)"], model(B_given_A, is_align=torch.tensor(1))["P(A|B)"]), dim=1)
    p = torch.min(align_probs, dim=1).values
    hard_pred = p > threshold

    true_pos = data_out[hard_pred==1].sum()
    total_pred_pos = (hard_pred==1).sum().float()
    total_actual_pos = data_out.sum().float()

    precision = true_pos / total_pred_pos
    recall = true_pos / total_actual_pos

    return 2 * (precision*recall) / (precision + recall)

def metric_hard_accuracy_align_mean(model, data_in, data_out, threshold):
    A_given_B = data_in[::2]
    B_given_A = data_in[1::2,:]
    data_out = data_out[::2]

    align_probs = torch.stack((model(A_given_B, is_align=torch.tensor(1))["P(A|B)"], model(B_given_A, is_align=torch.tensor(1))["P(A|B)"]), dim=1)
    p = torch.mean(align_probs, dim=1)
    hard_pred = p > threshold

    return (data_out == hard_pred).float().mean()

def metric_hard_f1_align_mean(model, data_in, data_out, threshold):
    A_given_B = data_in[::2]
    B_given_A = data_in[1::2,:]
    data_out = data_out[::2]

    align_probs = torch.stack((model(A_given_B, is_align=torch.tensor(1))["P(A|B)"], model(B_given_A, is_align=torch.tensor(1))["P(A|B)"]), dim=1)
    p = torch.mean(align_probs, dim=1)
    hard_pred = p > threshold

    true_pos = data_out[hard_pred==1].sum()
    total_pred_pos = (hard_pred==1).sum().float()
    total_actual_pos = data_out.sum().float()

    precision = true_pos / total_pred_pos
    recall = true_pos / total_actual_pos

    return 2 * (precision*recall) / (precision + recall)


In [278]:
import torch
from torch.utils.data import Dataset
from typing import *
from dataclasses import dataclass, field
import wandb

try:
    from IPython import get_ipython
    if 'IPKernelApp' not in get_ipython().config:  # pragma: no cover
        raise ImportError("console")
except:
    pass
else:
    import ipywidgets as widgets
    from IPython.core.display import HTML, display

if TYPE_CHECKING:
    from learner import Learner, Recorder


class Callback:
    def learner_post_init(self, learner: Learner):
        pass

    def train_begin(self, learner: Learner):
        pass

    def epoch_begin(self, learner: Learner):
        pass

    def batch_begin(self, learner: Learner):
        pass

    def backward_end(self, learner: Learner):
        pass

    def batch_end(self, learner: Learner):
        pass

    def epoch_end(self, learner: Learner):
        pass

    def train_end(self, learner: Learner):
        pass

    def eval_align(self, learner: Learner, threshold:float):
        pass

    def metric_plots(self, l: Learner):
        pass

    def eval_end(self, l: Learner):
        pass

    def bias_metric(self, l: Learner):
        pass

class CallbackCollection:

    def __init__(self, *callbacks: Callback):
        self._callbacks = callbacks

    def __call__(self, action: str, *args, **kwargs):
        for c in self._callbacks:
            getattr(c, action)(*args, **kwargs)

    def __getattr__(self, action: str):
        return lambda *args, **kwargs: self.__call__(action, *args, **kwargs)


@dataclass
class GradientClipping(Callback):
    min: float = None
    max: float = None

    def backward_end(self, learner: Learner):
        for param in learner.model.parameters():
            if param.grad is not None:
                param.grad = param.grad.clamp(self.min, self.max)


@dataclass
class LossCallback(Callback):
    recorder: Recorder
    ds: Dataset
    weighted: bool = True

    @torch.no_grad()
    def train_begin(self, learner: Learner):
        self.epoch_end(learner)

    @torch.no_grad()
    def epoch_end(self, l: Learner):
        data_in, data_out = self.ds[:]
        if l.categories:
            split_in, split_out = l.split_data(data_in, data_out, split=2737)
            
            model_pred = []
            count = 0
            for item in split_in:
                if len(item)>0:
                    if count<2:
                        model_pred.append(l.model(item, is_align=torch.tensor(0)))
                    else:
                        model_pred.append(l.model(item, is_align=torch.tensor(1)))
                else:
                    model_pred.append({'P(A|B)':l.TensorNaN(device=data_in.device)})
                count+=1

            #model_pred = [l.model(item) if len(item)>0 else {'P(A|B)':l.TensorNaN(device=data_in.device)} for item in split_in]
            l.loss_fn(model_pred, split_out, l, self.recorder, weighted=self.weighted, categories=True)  
        else:
            output = l.model(data_in, is_align=torch.tensor(0))
            l.loss_fn(output, data_out, l, self.recorder, weighted=self.weighted) # this logs the data to the recorder


@dataclass
class MetricCallback(Callback):
    recorder: Recorder
    ds: Dataset
    data_categories: str
    metric: Callable
    use_wandb: bool = False
    name: Union[str, None] = None

    def __post_init__(self):
        if self.name is None:
            self.name = self.metric.__name__
        self.name = self.recorder.get_unique_name(self.name)

    @torch.no_grad()
    def train_begin(self, learner: Learner):
        self.epoch_end(learner)

    @torch.no_grad()
    def epoch_end(self, l: Learner):
        data_in, data_out = self.ds[:]
        metric_val = self.metric(l.model, data_in, data_out)
        self.recorder.update_({self.name: metric_val}, l.progress.current_epoch_iter)
        
        print("evaluation_" + self.data_categories + "_" + self.name, str(metric_val))
        
        if self.use_wandb:
            metric_name = "evaluation_" + self.data_categories + "_" + self.name
            wandb.log({metric_name: metric_val})

@dataclass
class EvalAlignment(Callback):
    recorder: Recorder
    ds: Dataset
    data_categories: str
    metric: callable
    use_wandb: bool = False
    name: Union[str, None] = None
        

    def __post_init__(self):
        if self.name is None:
            self.name = self.metric.__name__
        self.name = self.recorder.get_unique_name(self.name)

    @torch.no_grad()
    def eval_align(self, l: Learner, threshold: float):
        data_in, data_out = self.ds[:]
        metric_val = self.metric(l.model, data_in, data_out, threshold)
        self.recorder.update_({self.name: metric_val}, threshold)
        
        print("align_evaluation_" + self.data_categories + "_" + str(threshold) + "_" + self.name, str(metric_val))
        
        if self.use_wandb:
            metric_name = "align_evaluation_" + self.data_categories + "_" + str(threshold) + "_" + self.name
            wandb.log({metric_name: metric_val})


In [286]:
import torch
from torch import Tensor
import torch.nn as nn
from torch.nn import Module, Parameter
import torch.nn.functional as F


################################################
# Box Parametrization Layers
################################################
default_init_min_vol = torch.finfo(torch.float32).tiny



class BoxParam(Module):
    """
    An example class for creating a box parametrization.
    Don't inherit from this, it is just an example which contains the methods for a class to be used as a BoxParam
    layer. Refer to the docstring of the functions when implementing your own BoxParam.

    Note: to avoid naming conflicts with min/max functions, we refer to the min coordinate for a box as `z`, and the
    max coordinate as `Z`.
    """

    def __init__(self, num_models:int, num_boxes:int, dim:int, **kwargs):
        """
        Creates the Parameters used for the representation of boxes.

        :param num_models: Number of models
        :param num_boxes: Number of boxes
        :param dim: Dimension
        :param kwargs: Unused for now, but include this for future possible parameters.
        """
        # Remember to call:
        super().__init__()
        raise NotImplemented


    def forward(self, box_indices = slice(None, None, None), **kwargs) -> Tensor:
        """
        Returns a Tensor representing the boxes specified by `box_indices` in the form they should be used for training.

        :param box_indices: Slice, List, or Tensor of the box indices
        :param kwargs: Unused for now, but include this for future possible parameters.
        :return: Tensor of shape (model, id, zZ, dim).
        """
        raise NotImplemented


class Boxes(Module):
    """
    Parametrize boxes using the min coordinate and max coordinate,
    initialized to be in the unit hypercube.

    self.boxes[model, box, min/max, dim] \in [0,1]

    In this parametrization, the min and max coordinates are explicitly stored
    in separate dimensions (as shown above), which means that care must be
    taken to preserve max > min while training. (See MinBoxSize Callback.)
    """

    def __init__(self, num_models: int, num_boxes: int, dims: int,
                 init_min_vol: float = default_init_min_vol, method = "gibbs", gibbs_iter: int = 2000, **kwargs):
        super().__init__()
        self.boxes = Parameter(initialize_boxes_in_unit_cube((num_models, num_boxes), dims, init_min_vol, method, gibbs_iter, **kwargs))

    def forward(self, box_indices = slice(None, None, None), **kwargs) -> Tensor:
        """
        Returns a Tensor representing the box embeddings specified by box_indices.

        :param box_indices: Slice, List, or Tensor of the box indices
        :param kwargs: Unused for now, but include this for future possible parameters.
        :return: NamedTensor of shape (model, id, zZ, dim).
        """
        return self.boxes[:, box_indices]



class MinMaxSigmoidBoxes(Module):
    """
    Parametrize boxes using sigmoid to make them always valid and contained within the unit cube.

    self.boxes[model, box, 2, dim] in Reals


    In this parametrization, we first convert to the unit cube:

    unit_cube_boxes = torch.sigmoid(self.boxes)  # shape: (model, box, 2, dim)

    We now select the z/Z coordinates by taking the min/max over axis 2, i.e.

    z, _ = torch.min(unit_cube_boxes, dim=2)
    Z, _ = torch.max(unit_cube_boxes, dim=2)
    """

    def __init__(self, num_models: int, num_boxes: int, dim: int, init_min_vol: float = default_init_min_vol,  **kwargs):
        super().__init__()
        unit_boxes = Boxes(num_models, num_boxes, dim, init_min_vol, **kwargs)
        self._from_UnitBoxes(unit_boxes)
        del unit_boxes


    def _from_UnitBoxes(self, unit_boxes:Boxes):
        boxes = unit_boxes().detach().clone()
        self.boxes = Parameter(torch.log(boxes / (1-boxes)))


    def forward(self, box_indices = slice(None, None, None), **kwargs) -> Tensor:
        """
        Returns a Tensor representing the box embeddings specified by box_indices.

        :param box_indices: A NamedTensor of the box indices
        :param kwargs: Unused for now, but include this for future possible parameters.
        :return: Tensor of shape (model, id, zZ, dim).
        """
        unit_cube_boxes = torch.sigmoid(self.boxes)
        z, _ = torch.min(unit_cube_boxes, dim=2)
        Z, _ = torch.max(unit_cube_boxes, dim=2)
        return torch.stack((z,Z), dim=2)
    
class AlignmentBoxes(Module):
    """
    Parametrize boxes using sigmoid to make them always valid and contained within the unit cube.

    self.boxes[model, box, 2, dim] in Reals


    In this parametrization, we first convert to the unit cube:

    unit_cube_boxes = torch.sigmoid(self.boxes)  # shape: (model, box, 2, dim)

    We now select the z/Z coordinates by taking the min/max over axis 2, i.e.

    z, _ = torch.min(unit_cube_boxes, dim=2)
    Z, _ = torch.max(unit_cube_boxes, dim=2)
    """

    def __init__(self, num_models: int, num_boxes: int, dim: int, init_min_vol: float = default_init_min_vol,  **kwargs):
        super().__init__()
        unit_boxes = Boxes(num_models, num_boxes, dim, init_min_vol, **kwargs)
        self._from_UnitBoxes(unit_boxes)
        self.fc = nn.Linear(dim, dim)
        del unit_boxes


    def _from_UnitBoxes(self, unit_boxes:Boxes):
        boxes = unit_boxes().detach().clone()
        self.boxes = Parameter(torch.log(boxes / (1-boxes)))


    def forward(self, box_indices = slice(None, None, None), is_align=torch.tensor(0), **kwargs) -> Tensor:
        """
        Returns a Tensor representing the box embeddings specified by box_indices.

        :param box_indices: A NamedTensor of the box indices
        :param kwargs: Unused for now, but include this for future possible parameters.
        :return: Tensor of shape (model, id, zZ, dim).
        """
        unit_cube_boxes = torch.sigmoid(self.boxes)
        z, _ = torch.min(unit_cube_boxes, dim=2)
        Z, _ = torch.max(unit_cube_boxes, dim=2)

        if is_align.tolist() == 1:
#             print("z before1:", z[0][10])
#             print("w:", self.fc.weight)
            z = self.fc(z)
            Z = self.fc(Z)

        return torch.stack((z,Z), dim=2)


###############################################
# Downstream Model
###############################################

class WeightedSum(Module):
    def __init__(self, num_models: int) -> None:
        super().__init__()
        self.weights = Parameter(torch.rand(num_models))

    def forward(self, box_vols: Tensor) -> Tensor:
        return (F.softmax(self.weights, dim=0).unsqueeze(0) @ box_vols).squeeze()


class LogWeightedSum(Module):
    def __init__(self, num_models: int) -> None:
        super().__init__()
        self.weights = Parameter(torch.rand(num_models))

    def forward(self, log_box_vols: Tensor) -> Tensor:
        return (torch.logsumexp(self.weights + log_box_vols, 0) - torch.logsumexp(self.weights, 0))


class BoxModel(Module):
    def __init__(self, BoxParamType: type, vol_func: Callable,
                 num_models:int, num_boxes:int, dims:int,
                 init_min_vol: float = default_init_min_vol, universe_box: Optional[Callable] = None, **kwargs):
        super().__init__()
        self.box_embedding = BoxParamType(num_models, num_boxes, dims, init_min_vol, **kwargs)
        self.vol_func = vol_func

        if universe_box is None:
            z = torch.zeros(dims)
            Z = torch.ones(dims)
            self.universe_box = lambda _: torch.stack((z,Z))[None, None]
            self.universe_vol = lambda _: self.vol_func(self.universe_box(None)).squeeze()
            self.clamp = True
        else:
            self.universe_box = universe_box
            self.universe_vol = lambda b: self.vol_func(self.universe_box(b))
            self.clamp = False

        self.weights = WeightedSum(num_models)

    def forward(self, box_indices: Tensor, is_align: Tensor) -> Dict:
        #print("is_align in BoxModel:", is_align)
        # Unary
        box_embeddings_orig = self.box_embedding(is_align = is_align)
        if self.clamp:
            box_embeddings = box_embeddings_orig.clamp(0,1)
        else:
            box_embeddings = box_embeddings_orig

        universe_vol = self.universe_vol(box_embeddings)

        unary_probs = self.weights(self.vol_func(box_embeddings) / universe_vol)

        # Conditional
        A = box_embeddings[:, box_indices[:,0]]
        B = box_embeddings[:, box_indices[:,1]]
        A_int_B_vol = self.weights(self.vol_func(intersection(A, B)) / universe_vol) + torch.finfo(torch.float32).tiny
        B_vol = unary_probs[box_indices[:,1]] + torch.finfo(torch.float32).tiny
        P_A_given_B = torch.exp(torch.log(A_int_B_vol) - torch.log(B_vol))
        
        # symmetric same
        # print("you are in right place!")
#         A = box_embeddings[:, box_indices[:,0]]
#         B = box_embeddings[:, box_indices[:,1]]
#         A_int_B_vol = self.weights(self.vol_func(intersection(A, B)) / universe_vol) + torch.finfo(torch.float32).tiny
#         A_join_B_vol = self.weights(self.vol_func(join(A, B)) / universe_vol) + torch.finfo(torch.float32).tiny
#         P_A_given_B = torch.exp(torch.log(A_int_B_vol) - torch.log(A_join_B_vol))

        return {
            "unary_probs": unary_probs,
            "box_embeddings_orig": box_embeddings_orig,
            "A": A,
            "B": B,
            "P(A|B)": P_A_given_B,
        }


In [287]:
dims = 16
lr = 0.07342406890949607
#lr = 0.2
rns_ratio = 5
#box_type = MinMaxSigmoidBoxes, AlignmentBoxes
use_unary = False
unary_weight = 1e-2

In [288]:
box_model = BoxModel(
    BoxParamType=AlignmentBoxes,
    vol_func=soft_volume,
    num_models=1,
    num_boxes=num_boxes,
    dims=dims,
    method="tree").to(device)


#### IF YOU ARE LOADING FROM JULIA WITH ratio_neg=0, train_dl WILL ONLY CONTAIN POSITIVE EXAMPLES
#### THIS MEANS YOUR MODEL SHOULD USE NEGATIVE SAMPLING DURING TRAINING
train_dl = TensorDataLoader(train, batch_size=2**6, shuffle=True)

mouse_dl = TensorDataLoader(mouse_eval, batch_size=2**6)
human_dl = TensorDataLoader(human_eval, batch_size=2**6)

eval_dl = [mouse_dl, human_dl]

opt = torch.optim.Adam(box_model.parameters(), lr=lr)

In [289]:
def mean_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

def human_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

def mouse_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

def align_cond_kl_loss(model_out: ModelOutput, target: Tensor, eps: float = torch.finfo(torch.float32).tiny) -> Tensor:
    return kl_div_sym(model_out["P(A|B)"], target, eps).mean()

# See boxes/loss_functions.py file for more options. Note that you may have to changed them to fit your use case.
# Also note that "kl_div_sym" is just binary cross-entropy.

In [290]:
global my_batch_in
global my_batch_out
my_batch_in = None
my_batch_out = None

@dataclass
class Learner:
    train_dl: DataLoader
    model: Module
    loss_fn: Callable
    opt: optim.Optimizer
    callbacks: CallbackCollection = field(default_factory=CallbackCollection)
    recorder: Recorder = field(default_factory=Recorder)
    categories: bool = False
    use_wandb: bool = False
    reraise_keyboard_interrupt: bool = False
    reraise_stop_training_exceptions: bool = False

    def __post_init__(self):
        self.progress = Progress(0,0,len(self.train_dl))
        self.callbacks.learner_post_init(self)

    #the split parameter will be used to find human/mouse/align data, so you need change it when using diff dataset(index)
    def split_data(self, batch_in, batch_out, split):
        category = torch.zeros(size=(batch_in.shape[0],), dtype=int)

        batch_class = batch_in > split

        for i, (a,b) in enumerate(batch_class):
            if not a and not b:
                category[i] = 0
            elif a and b:
                category[i] = 1
            else:
                category[i] = 2

        self.mouse_in = batch_in[category == 0]
        self.human_in = batch_in[category == 1]
        self.align_in = batch_in[category == 2]

        self.mouse_out = batch_out[category == 0]
        self.human_out = batch_out[category == 1]
        self.align_out = batch_out[category == 2]

        # INPUT TO THE MODEL:
        data_in = (self.mouse_in, self.human_in, self.align_in)
        # TARGET/LABEL:
        data_out = (self.mouse_out, self.human_out, self.align_out)

        return data_in, data_out

    def TensorNaN(self, size:Union[None,List[int], Tuple[int]]=None, device=None, requires_grad:bool=True):
        if size is None:    
            return torch.tensor(float('nan'), device=device, requires_grad=requires_grad)
        else:
            return float('nan') * torch.zeros(size=size, device=device, requires_grad=requires_grad)


    def train(self, epochs, progress_bar = True):
        global my_batch_in
        global my_batch_out
        try:
            self.callbacks.train_begin(self)
            for epoch in trange(epochs, desc="Overall Training:", disable=not progress_bar):
                self.callbacks.epoch_begin(self)
                for iteration, batch in enumerate(tqdm(self.train_dl, desc="Current Batch:", leave=False, disable=not progress_bar)):
                    if len(batch) == 2: # KLUDGE
                        self.batch_in, self.batch_out = batch
                    else:
                        self.batch_in = batch[0]
                        self.batch_out = None
                    self.progress.increment()
                    self.callbacks.batch_begin(self)
                    self.opt.zero_grad()
                    
                    if self.categories:
                        self.data_in, self.data_out = self.split_data(self.batch_in, self.batch_out, split=2737)
                        #2737 is max mouse index  
                        self.model_pred = []
                        count = 0
                        for item in self.data_in:
                            if len(item)>0:
                                if count<2:
                                    self.model_pred.append(self.model(item, is_align=torch.tensor(0)))
                                else:
                                    self.model_pred.append(self.model(item, is_align=torch.tensor(0)))
                            else:
                                self.model_pred.append({'P(A|B)':self.TensorNaN(device=self.batch_in.device)})
                            count+=1
                        #self.model_pred = [self.model(item, is_align=is_align) if len(item)>0 else {'P(A|B)':self.TensorNaN(device=self.batch_in.device)} for item in self.data_in]
                        self.loss = self.loss_fn(self.model_pred, self.data_out, self, self.recorder, categories=True, use_wandb=self.use_wandb)                        
                    else:
                        self.model_out = self.model(self.batch_in, is_align=False)
                        if self.batch_out is None:
                            self.loss = self.loss_fn(self.model_out, self, self.recorder, categories=True, use_wandb=self.use_wandb)
                        else:
                            self.loss = self.loss_fn(self.model_out, self.batch_out, self, self.recorder, categories=True, use_wandb=self.use_wandb)
                        
                    # Log metrics inside your training loop
                    if self.use_wandb:
                        metrics = {'epoch': epoch, 'loss': self.loss}
                        wandb.log(metrics)

                    # print(self.recorder.dataframe)
                    self.loss.backward()
                    self.callbacks.backward_end(self)
                    self.opt.step()
                    self.callbacks.batch_end(self)
                # print(self.recorder.dataframe)
                
                # run evaluating at the end of every epoch
                #self.evaluation(np.arange(0.1, 1, 0.1))
                self.evaluation([0.5])
                
                
                self.callbacks.epoch_end(self)
        except StopTrainingException as e:
            print(e)
            if self.reraise_stop_training_exceptions:
                raise e
        except KeyboardInterrupt:
            print(f"Stopped training at {self.progress.partial_epoch_progress()} epochs due to keyboard interrupt.")
            if self.reraise_keyboard_interrupt:
                raise KeyboardInterrupt
        finally:
            self.callbacks.train_end(self)


    def evaluation(self, trials, progress_bar=True):
        with torch.no_grad():
            # self.callbacks.eval_begin(self)
            for t in trials:
                self.callbacks.eval_align(self, t)
            self.callbacks.metric_plots(self)
            self.callbacks.bias_metric(self)
            self.callbacks.eval_end(self)

In [291]:
# For this dataset we had unary probabilities as well as conditional probabilities. Our loss function will be a sum of these, which is provided by the following loss function wrapper:

# if use_unary:
#     loss_func = LossPieces(mean_cond_kl_loss, (unary_weight, mean_unary_kl_loss(unary_prob)))
# else:
#     loss_func = LossPieces(mean_cond_kl_loss)

loss_func = LossPieces( (0.05, mouse_cond_kl_loss), (0.05 ,human_cond_kl_loss), (0.05, align_cond_kl_loss))

metrics = [metric_hard_accuracy, metric_hard_f1]
align_metrics = [metric_hard_accuracy_align, metric_hard_f1_align, metric_hard_accuracy_align_mean, metric_hard_f1_align_mean]

rec_col = RecorderCollection()

#threshold = np.arange(0.1, 1, 0.1)
threshold = [0.5]

callbacks = CallbackCollection(
    LossCallback(rec_col.train, train),
    LossCallback(rec_col.dev, dev),
#     *(MetricCallback(rec_col.dev, dev, "dev", m) for m in metrics),
#     *(MetricCallback(rec_col.train, train, "train", m) for m in metrics),
    *(MetricCallback(rec_col.onto, human_eval, "human", m) for m in metrics),
    *(MetricCallback(rec_col.onto, mouse_eval, "mouse", m) for m in metrics),
    *(EvalAlignment(rec_col.train_align, tr_align, "train_align", m) for m in align_metrics),
    *(EvalAlignment(rec_col.dev_align, dev, "dev_align", m) for m in align_metrics),
    #JustGiveMeTheData(rec_col.probs, dev, get_probabilities),
    #BiasMetric(rec_col.bias, dev, pct_of_align_cond_on_human_as_min),
    #PlotMetrics(rec_col.dev_roc_plot, dev, roc_plot),
    #PlotMetrics(rec_col.dev_pr_plot, dev, pr_plot),
    #PlotMetrics(rec_col.tr_roc_plot, tr_align, roc_plot),
    #PlotMetrics(rec_col.tr_pr_plot, tr_align, pr_plot),
    #MetricCallback(rec_col.train, train, 'train', metric_pearson_r),
    #MetricCallback(rec_col.train, train, 'train', metric_spearman_r),
    #MetricCallback(rec_col.dev, dev, 'dev', metric_pearson_r),
    #MetricCallback(rec_col.dev, dev, 'dev', metric_spearman_r),
#     MetricCallback(rec_col.train, train, 'train', mean_reciprocal_rank),
#     MetricCallback(rec_col.dev, dev, 'dev', mean_reciprocal_rank),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mean_cond_kl_loss", 0.5),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.25, 10),
#     PercentIncreaseEarlyStopping(rec_col.dev, "mouse_cond_kl_loss", 0.5),
#     GradientClipping(-1000,1000),
#     RandomNegativeSampling(num_boxes, rns_ratio),
    StopIfNaN(),
)

# l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn)
l = Learner(train_dl, box_model, loss_func, opt, callbacks, recorder = rec_col.learn, categories=True)

In [292]:
nEpochs = 100
l.train(nEpochs)

evaluation_human_metric_hard_accuracy tensor(0.5000000000000000)
evaluation_human_metric_hard_f1 tensor(nan)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.5000000000000000)
evaluation_mouse_metric_hard_f1_1 tensor(nan)


HBox(children=(HTML(value='Overall Training:'), FloatProgress(value=0.0), HTML(value='')))

HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6863149404525757)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.6793088912963867)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6434460282325745)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7219544053077698)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.4587458670139313)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.3109243512153625)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4224422574043274)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.4696969389915466)
evaluation_human_metric_hard_accuracy tensor(0.7182055711746216)
evaluation_human_metric_hard_f1 tensor(0.6280896663665771)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.7900608777999878)
evaluation_mouse_metric_hard_f1_1 tensor(0.7530636787414551)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6995053291320801)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7384284138679504)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.5869744420051575)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7075306177139282)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.4884488582611084)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.4599303007125854)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4372937381267548)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.5633803009986877)
evaluation_human_metric_hard_accuracy tensor(0.8630952239036560)
evaluation_human_metric_hard_f1 tensor(0.8665629029273987)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8766365647315979)
evaluation_mouse_metric_hard_f1_1 tensor(0.8825697898864746)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6953833699226379)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7354099154472351)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.5795547962188721)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7038327455520630)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5160256028175354)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4488448798656464)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.5824999809265137)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8813559412956238)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8743315339088440)
evaluation_mouse_metric_hard_f1_1 tensor(0.8860462903976440)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7102226018905640)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7563257813453674)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.5816158056259155)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7050275802612305)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5247524976730347)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5384615659713745)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4636963605880737)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.5992602109909058)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8835759162902832)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8708279728889465)
evaluation_mouse_metric_hard_f1_1 tensor(0.8845678567886353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7180544137954712)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7646248936653137)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.5906842350959778)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7095641493797302)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5264026522636414)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5363489985466003)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4620462059974670)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.5975308418273926)
evaluation_human_metric_hard_accuracy tensor(0.8713704943656921)
evaluation_human_metric_hard_f1 tensor(0.8852331638336182)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8699981570243835)
evaluation_mouse_metric_hard_f1_1 tensor(0.8843504190444946)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7551525235176086)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7948895096778870)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6079967021942139)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7183890938758850)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5181518197059631)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5394322276115417)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4702970385551453)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.5992509126663208)
evaluation_human_metric_hard_accuracy tensor(0.8697735071182251)
evaluation_human_metric_hard_f1 tensor(0.8843327164649963)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8699981570243835)
evaluation_mouse_metric_hard_f1_1 tensor(0.8847285509109497)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7642209529876709)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8027586936950684)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6125308871269226)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7207368016242981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5561056137084961)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5568369626998901)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4801980257034302)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6027743220329285)
evaluation_human_metric_hard_accuracy tensor(0.8667247295379639)
evaluation_human_metric_hard_f1 tensor(0.8821868896484375)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8664945363998413)
evaluation_mouse_metric_hard_f1_1 tensor(0.8819500803947449)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7769991755485535)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8129969239234924)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6236603260040283)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7265648841857910)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5610560774803162)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5667752027511597)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5033003091812134)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6185044050216675)
evaluation_human_metric_hard_accuracy tensor(0.8700639009475708)
evaluation_human_metric_hard_f1 tensor(0.8848281502723694)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8687073588371277)
evaluation_mouse_metric_hard_f1_1 tensor(0.8838309645652771)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7922506332397461)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8241451382637024)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6356141567230225)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7327690720558167)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5594059228897095)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5601317882537842)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5132012963294983)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6213093996047974)
evaluation_human_metric_hard_accuracy tensor(0.8702090382575989)
evaluation_human_metric_hard_f1 tensor(0.8850603699684143)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8665867447853088)
evaluation_mouse_metric_hard_f1_1 tensor(0.8821757435798645)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8075020313262939)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8373389244079590)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6516900062561035)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7416692376136780)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5792078971862793)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5728642940521240)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5165016651153564)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6219354867935181)
evaluation_human_metric_hard_accuracy tensor(0.8700639009475708)
evaluation_human_metric_hard_f1 tensor(0.8849466443061829)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8691683411598206)
evaluation_mouse_metric_hard_f1_1 tensor(0.8841916322708130)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8124485015869141)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8405187726020813)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6611706614494324)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7469211816787720)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5891088843345642)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5772495865821838)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5346534848213196)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6308900713920593)
evaluation_human_metric_hard_accuracy tensor(0.8694831728935242)
evaluation_human_metric_hard_f1 tensor(0.8845215439796448)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8677853345870972)
evaluation_mouse_metric_hard_f1_1 tensor(0.8831676840782166)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8264633417129517)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8519169688224792)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6685902476310730)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7510836124420166)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5957095623016357)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5840407013893127)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5561056137084961)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6408544182777405)
evaluation_human_metric_hard_accuracy tensor(0.8683217167854309)
evaluation_human_metric_hard_f1 tensor(0.8836135268211365)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8684307336807251)
evaluation_mouse_metric_hard_f1_1 tensor(0.8837095499038696)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8289365172386169)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8535121679306030)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6809563040733337)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7581250071525574)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5973597168922424)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5821917653083801)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5594059228897095)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6425702571868896)
evaluation_human_metric_hard_accuracy tensor(0.8697735071182251)
evaluation_human_metric_hard_f1 tensor(0.8847487568855286)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8696293830871582)
evaluation_mouse_metric_hard_f1_1 tensor(0.8846091032028198)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8322341442108154)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8559291362762451)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6941467523574829)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7657828330993652)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5940594077110291)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5802047848701477)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5676567554473877)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6478494405746460)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8842943310737610)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8692605495452881)
evaluation_mouse_metric_hard_f1_1 tensor(0.8843203783035278)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8408903479576111)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8626334667205811)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7019785642623901)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7704033255577087)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6072607040405273)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5882353186607361)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5742574334144592)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6494565606117249)
evaluation_human_metric_hard_accuracy tensor(0.8704994320869446)
evaluation_human_metric_hard_f1 tensor(0.8853470087051392)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8696293830871582)
evaluation_mouse_metric_hard_f1_1 tensor(0.8846279978752136)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8417147397994995)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8632478713989258)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7114592194557190)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7760717272758484)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6138613820075989)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5937500000000000)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5792078971862793)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6511628031730652)
evaluation_human_metric_hard_accuracy tensor(0.8702090382575989)
evaluation_human_metric_hard_f1 tensor(0.8850899338722229)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8696293830871582)
evaluation_mouse_metric_hard_f1_1 tensor(0.8846279978752136)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8462489843368530)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8667380809783936)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7188788056373596)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7805663347244263)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6023102402687073)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5851979255676270)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5924092531204224)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6574202775955200)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8848624825477600)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8700903654098511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8850077390670776)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8503710031509399)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8698458671569824)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7291838526725769)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7868959307670593)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6171616911888123)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5958187580108643)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6122112274169922)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6694796085357666)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8848624825477600)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8716577291488647)
evaluation_mouse_metric_hard_f1_1 tensor(0.8862373232841492)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8540807962417603)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8726618885993958)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7283594608306885)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7863857150077820)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6171616911888123)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5944056510925293)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6221122145652771)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6751772761344910)
evaluation_human_metric_hard_accuracy tensor(0.8709349632263184)
evaluation_human_metric_hard_f1 tensor(0.8856886029243469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8727641701698303)
evaluation_mouse_metric_hard_f1_1 tensor(0.8871073126792908)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8536685705184937)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8723480701446533)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7407254576683044)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7941079735755920)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6138613820075989)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5923345088958740)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6155115365982056)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6695035696029663)
evaluation_human_metric_hard_accuracy tensor(0.8707897663116455)
evaluation_human_metric_hard_f1 tensor(0.8855746984481812)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8713811635971069)
evaluation_mouse_metric_hard_f1_1 tensor(0.8860014677047729)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8544930219650269)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8729758858680725)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7464963197708130)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7977638840675354)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6155115365982056)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5933681726455688)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6237623691558838)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6733524203300476)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8717499375343323)
evaluation_mouse_metric_hard_f1_1 tensor(0.8863096833229065)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8586149811744690)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8760390281677246)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7510305047035217)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8006600737571716)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6270626783370972)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6007066369056702)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6254125237464905)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6743184924125671)
evaluation_human_metric_hard_accuracy tensor(0.8697735071182251)
evaluation_human_metric_hard_f1 tensor(0.8847784399986267)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8725797533988953)
evaluation_mouse_metric_hard_f1_1 tensor(0.8869622349739075)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8627369999885559)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8793040513992310)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7551525235176086)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8033112287521362)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6155115365982056)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5933681726455688)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6287128925323486)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6762589812278748)
evaluation_human_metric_hard_accuracy tensor(0.8702090382575989)
evaluation_human_metric_hard_f1 tensor(0.8851194977760315)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8726719617843628)
evaluation_mouse_metric_hard_f1_1 tensor(0.8870347142219543)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8631492257118225)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8796228766441345)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7576256990432739)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8049104213714600)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6287128925323486)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6017698645591736)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6821480393409729)
evaluation_human_metric_hard_accuracy tensor(0.8703542351722717)
evaluation_human_metric_hard_f1 tensor(0.8852332830429077)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8733173608779907)
evaluation_mouse_metric_hard_f1_1 tensor(0.8875429630279541)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8672712445259094)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8828238844871521)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7605111002922058)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8067841529846191)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6254125237464905)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5982300639152527)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6369637250900269)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6793003678321838)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8848921060562134)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8728563785552979)
evaluation_mouse_metric_hard_f1_1 tensor(0.8871799707412720)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8685078024864197)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8837887048721313)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7675185203552246)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8113712072372437)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6254125237464905)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5982300639152527)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6793558001518250)
evaluation_human_metric_hard_accuracy tensor(0.8706446290016174)
evaluation_human_metric_hard_f1 tensor(0.8854609131813049)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8733173608779907)
evaluation_mouse_metric_hard_f1_1 tensor(0.8875429630279541)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8693322539329529)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8844330906867981)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7716405391693115)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8140939474105835)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6024955511093140)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6402640342712402)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6794117689132690)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8848921060562134)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8709810376167297)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8857247233390808)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7753503918647766)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8165600895881653)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6014235019683838)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6468647122383118)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6834319233894348)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8718054294586182)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8863719105720520)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7827699780464172)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8215374350547791)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6287128925323486)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5989304780960083)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6452144980430603)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6824224591255188)
evaluation_human_metric_hard_accuracy tensor(0.8704994320869446)
evaluation_human_metric_hard_f1 tensor(0.8853470087051392)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8742393255233765)
evaluation_mouse_metric_hard_f1_1 tensor(0.8882700204849243)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8742786645889282)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8883193135261536)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7815333604812622)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8207036256790161)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6270626783370972)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5978648066520691)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6518151760101318)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6864784359931946)
evaluation_human_metric_hard_accuracy tensor(0.8702090382575989)
evaluation_human_metric_hard_f1 tensor(0.8851194977760315)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8779884576797485)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8912564516067505)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7856553792953491)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8234894275665283)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6287128925323486)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6003552079200745)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6485148668289185)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6844444870948792)
evaluation_human_metric_hard_accuracy tensor(0.8697735071182251)
evaluation_human_metric_hard_f1 tensor(0.8847784399986267)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8744237422943115)
evaluation_mouse_metric_hard_f1_1 tensor(0.8884155750274658)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8779884576797485)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8912564516067505)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7885407805442810)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8254508972167969)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6014235019683838)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6518151760101318)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6864784359931946)
evaluation_human_metric_hard_accuracy tensor(0.8699187040328979)
evaluation_human_metric_hard_f1 tensor(0.8848921060562134)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8744237422943115)
evaluation_mouse_metric_hard_f1_1 tensor(0.8884155750274658)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8788128495216370)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8919117450714111)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7897773981094360)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8262943029403687)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6014235019683838)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6567656993865967)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6895523071289062)
evaluation_human_metric_hard_accuracy tensor(0.8702090382575989)
evaluation_human_metric_hard_f1 tensor(0.8851194977760315)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8743315339088440)
evaluation_mouse_metric_hard_f1_1 tensor(0.8883427381515503)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8816982507705688)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8942130208015442)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7897773981094360)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8262943029403687)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6336633563041687)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6035714745521545)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6567656993865967)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6895523071289062)
evaluation_human_metric_hard_accuracy tensor(0.8697735071182251)
evaluation_human_metric_hard_f1 tensor(0.8847784399986267)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8808738589286804)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8935543298721313)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7922506332397461)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8279863595962524)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6014235019683838)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6584158539772034)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6905829906463623)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8741471767425537)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881971836090088)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8808738589286804)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8935543298721313)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7943116426467896)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8294017314910889)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6024955511093140)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6666666865348816)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6957831382751465)
evaluation_human_metric_hard_accuracy tensor(0.8693379759788513)
evaluation_human_metric_hard_f1 tensor(0.8844376206398010)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8743315339088440)
evaluation_mouse_metric_hard_f1_1 tensor(0.8883427381515503)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8808738589286804)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8935543298721313)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7947238087654114)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8296853303909302)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6024955511093140)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6666666865348816)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6957831382751465)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8745159506797791)
evaluation_mouse_metric_hard_f1_1 tensor(0.8884882926940918)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8816982507705688)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8942130208015442)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7947238087654114)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8296853303909302)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6024955511093140)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6650164723396301)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6947368383407593)
evaluation_human_metric_hard_accuracy tensor(0.8700639009475708)
evaluation_human_metric_hard_f1 tensor(0.8850058317184448)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8825226426124573)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8948727846145630)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7967848181724548)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8311065435409546)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6336633563041687)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6035714745521545)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6683168411254883)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6968325376510620)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8743315339088440)
evaluation_mouse_metric_hard_f1_1 tensor(0.8883427381515503)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8833470940589905)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8955333828926086)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7971970438957214)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8313913345336914)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6353135108947754)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6046512126922607)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6699669957160950)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6978852152824402)
evaluation_human_metric_hard_accuracy tensor(0.8693379759788513)
evaluation_human_metric_hard_f1 tensor(0.8844376206398010)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8743315339088440)
evaluation_mouse_metric_hard_f1_1 tensor(0.8883427381515503)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8837592601776123)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8958640694618225)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7984336614608765)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8322470188140869)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6353135108947754)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6046512126922607)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6683168411254883)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6968325376510620)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8745159506797791)
evaluation_mouse_metric_hard_f1_1 tensor(0.8884882926940918)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8837592601776123)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8958640694618225)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7996702194213867)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8331043720245361)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6336633563041687)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6035714745521545)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7000000476837158)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8741471767425537)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881971836090088)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8841714859008789)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8961950540542603)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8004946708679199)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8336769938468933)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6353135108947754)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6032316088676453)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8696283102035522)
evaluation_human_metric_hard_f1 tensor(0.8846647739410400)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8845836520195007)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8965262770652771)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8009068369865417)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8339635729789734)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5985662937164307)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8694831728935242)
evaluation_human_metric_hard_f1 tensor(0.8845511674880981)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8849958777427673)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8968576788902283)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8017312288284302)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8345373272895813)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6336633563041687)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6007194519042969)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8845836520195007)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8965262770652771)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8025556206703186)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8351119160652161)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5996409654617310)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8841714859008789)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8961950540542603)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8017312288284302)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8345373272895813)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5985662937164307)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8841714859008789)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8961950540542603)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8029678463935852)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8353994488716125)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5996409654617310)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8849958777427673)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8968576788902283)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8037922382354736)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8359752297401428)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6303630471229553)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5985662937164307)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8733173608779907)
evaluation_mouse_metric_hard_f1_1 tensor(0.8875429630279541)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8854081034660339)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8971893787384033)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8046166300773621)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8365517258644104)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6320132017135620)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.5996409654617310)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7010622024536133)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8862324953079224)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8978534340858459)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8042044639587402)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8362632989883423)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6336633563041687)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6021505594253540)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6990881562232971)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8858202695846558)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8975212574005127)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8042044639587402)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8362632989883423)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6353135108947754)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6032316088676453)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6990881562232971)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8858202695846558)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8975212574005127)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8050288558006287)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8368403315544128)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6369637250900269)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6057347655296326)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6990881562232971)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8858202695846558)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8975212574005127)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8054410815238953)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8371290564537048)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6765676736831665)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7012194991111755)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8858202695846558)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8975212574005127)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8054410815238953)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8371290564537048)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7001521587371826)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8862324953079224)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8978534340858459)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8062654733657837)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8377072215080261)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6765676736831665)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7012194991111755)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8866446614265442)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8981858491897583)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8058532476425171)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8374180197715759)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7001521587371826)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8866446614265442)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8981858491897583)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8058532476425171)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8374180197715759)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8866446614265442)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8981858491897583)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8058532476425171)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8374180197715759)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6981707811355591)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8733173608779907)
evaluation_mouse_metric_hard_f1_1 tensor(0.8875429630279541)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8862324953079224)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8978534340858459)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8062654733657837)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8377072215080261)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6981707811355591)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8866446614265442)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8981858491897583)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8070898652076721)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8382860422134399)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8070898652076721)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8382860422134399)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8734095692634583)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876156210899353)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8866446614265442)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8981858491897583)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8070898652076721)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8382860422134399)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6981707811355591)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8083264827728271)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8391559720039368)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6732673048973083)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6981707811355591)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735017776489258)
evaluation_mouse_metric_hard_f1_1 tensor(0.8876882791519165)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8083264827728271)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8391559720039368)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877609372138977)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8079142570495605)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8388658165931702)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8079142570495605)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8388658165931702)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8083264827728271)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8391559720039368)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8083264827728271)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8391559720039368)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8091508746147156)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8397369384765625)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8091508746147156)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8397369384765625)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878336548805237)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8095630407333374)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8400276899337769)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840970396995544)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8099752664566040)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8403186202049255)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8737783432006836)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879063129425049)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8690476417541504)
evaluation_human_metric_hard_f1 tensor(0.8842105269432068)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8738705515861511)
evaluation_mouse_metric_hard_f1_1 tensor(0.8879789710044861)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6765676736831665)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.7003058195114136)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8870568871498108)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8985185027122498)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8740549683570862)
evaluation_mouse_metric_hard_f1_1 tensor(0.8881244659423828)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8874691128730774)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8988513946533203)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8103874921798706)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8406098484992981)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.6386138796806335)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.6068222522735596)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.6749175190925598)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.6992366909980774)
evaluation_human_metric_hard_accuracy tensor(0.8691927790641785)
evaluation_human_metric_hard_f1 tensor(0.8843240141868591)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880518078804016)



In [232]:
nEpochs = 100
l.train(nEpochs)

evaluation_human_metric_hard_accuracy tensor(0.5000000000000000)
evaluation_human_metric_hard_f1 tensor(nan)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.5000000000000000)
evaluation_mouse_metric_hard_f1_1 tensor(nan)


HBox(children=(HTML(value='Overall Training:'), FloatProgress(value=0.0), HTML(value='')))

HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5032976269721985)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.0131040131673217)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.5173124670982361)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.0684168711304665)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5000000000000000)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(nan)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5000000000000000)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.0065573775209486)
evaluation_human_metric_hard_accuracy tensor(0.7224158048629761)
evaluation_human_metric_hard_f1 tensor(0.6358095407485962)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.7901530265808105)
evaluation_mouse_metric_hard_f1_1 tensor(0.7525548934936523)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5251442790031433)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.1111111044883728)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.6739488840103149)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.5598219633102417)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.0503144674003124)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5099009871482849)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.1633802801370621)
evaluation_human_metric_hard_accuracy tensor(0.8712252974510193)
evaluation_human_metric_hard_f1 tensor(0.8739877939224243)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8803245425224304)
evaluation_mouse_metric_hard_f1_1 tensor(0.8865186572074890)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5399835109710693)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.1889534890651703)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.7757625579833984)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.7536231875419617)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5132012963294983)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.1399416923522949)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5198019742965698)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.2595419585704803)
evaluation_human_metric_hard_accuracy tensor(0.8734030127525330)
evaluation_human_metric_hard_f1 tensor(0.8837953209877014)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8755301237106323)
evaluation_mouse_metric_hard_f1_1 tensor(0.8868967890739441)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5560593605041504)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.2628336548805237)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8050288558006287)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8010096549987793)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5297029614448547)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.1971831023693085)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5181518197059631)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3080568611621857)
evaluation_human_metric_hard_accuracy tensor(0.8735482096672058)
evaluation_human_metric_hard_f1 tensor(0.8861289024353027)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8717499375343323)
evaluation_mouse_metric_hard_f1_1 tensor(0.8852593302726746)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5605935454368591)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.2996057868003845)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8169826865196228)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8224000334739685)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5231022834777832)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.1949860751628876)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5099009871482849)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3203661143779755)
evaluation_human_metric_hard_accuracy tensor(0.8709349632263184)
evaluation_human_metric_hard_f1 tensor(0.8850975632667542)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8718421459197998)
evaluation_mouse_metric_hard_f1_1 tensor(0.8859907984733582)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5692498087882996)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.3296985328197479)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8359439373016357)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8453768491744995)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5264026522636414)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2222222238779068)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5115511417388916)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3451327383518219)
evaluation_human_metric_hard_accuracy tensor(0.8684669137001038)
evaluation_human_metric_hard_f1 tensor(0.8833676576614380)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8679697513580322)
evaluation_mouse_metric_hard_f1_1 tensor(0.8830639123916626)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.5985160470008850)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.4082624912261963)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8425391316413879)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8556311726570129)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5264026522636414)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2222222238779068)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5132012963294983)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3572984635829926)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8838286399841309)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8694449663162231)
evaluation_mouse_metric_hard_f1_1 tensor(0.8842759728431702)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6253091692924500)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.4687317609786987)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8458367586135864)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8611729741096497)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5313531160354614)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2324324399232864)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5099009871482849)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3585313260555267)
evaluation_human_metric_hard_accuracy tensor(0.8670151233673096)
evaluation_human_metric_hard_f1 tensor(0.8825038075447083)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8676009774208069)
evaluation_mouse_metric_hard_f1_1 tensor(0.8829665780067444)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6492168307304382)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.5259053111076355)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8557295799255371)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8711340427398682)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5363036394119263)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2466488033533096)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5115511417388916)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3675213456153870)
evaluation_human_metric_hard_accuracy tensor(0.8641114830970764)
evaluation_human_metric_hard_f1 tensor(0.8802762627601624)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8685229420661926)
evaluation_mouse_metric_hard_f1_1 tensor(0.8837246894836426)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6698268651962280)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.5681940913200378)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8557295799255371)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8717008829116821)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5379537940025330)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2553191483020782)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5033003091812134)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3582089245319366)
evaluation_human_metric_hard_accuracy tensor(0.8681765198707581)
evaluation_human_metric_hard_f1 tensor(0.8835001587867737)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8692605495452881)
evaluation_mouse_metric_hard_f1_1 tensor(0.8843016028404236)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.6887881159782410)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.6040901541709900)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8516075611114502)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8689956068992615)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5445544719696045)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2736842036247253)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5066006779670715)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3678646981716156)
evaluation_human_metric_hard_accuracy tensor(0.8664343953132629)
evaluation_human_metric_hard_f1 tensor(0.8820815086364746)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8711045384407043)
evaluation_mouse_metric_hard_f1_1 tensor(0.8858029842376709)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7135201692581177)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.6477445363998413)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8565539717674255)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8740954995155334)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5346534848213196)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2578947544097900)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5049505233764648)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3644067943096161)
evaluation_human_metric_hard_accuracy tensor(0.8668699264526367)
evaluation_human_metric_hard_f1 tensor(0.8824811577796936)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8703669309616089)
evaluation_mouse_metric_hard_f1_1 tensor(0.8851682543754578)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7320692539215088)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.6823069453239441)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8586149811744690)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8758595585823059)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5346534848213196)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2617801129817963)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5066006779670715)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3705263137817383)
evaluation_human_metric_hard_accuracy tensor(0.8683217167854309)
evaluation_human_metric_hard_f1 tensor(0.8835836052894592)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8708279728889465)
evaluation_mouse_metric_hard_f1_1 tensor(0.8855859637260437)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7629843354225159)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7271001338958740)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8557295799255371)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8736461997032166)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5330032706260681)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2610965967178345)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4983498454093933)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3666666448116302)
evaluation_human_metric_hard_accuracy tensor(0.8675957918167114)
evaluation_human_metric_hard_f1 tensor(0.8830469250679016)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8704591393470764)
evaluation_mouse_metric_hard_f1_1 tensor(0.8853154778480530)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7745259404182434)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7456997036933899)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8561418056488037)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8740527033805847)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5363036394119263)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2701298594474792)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5033003091812134)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3742203712463379)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8728563785552979)
evaluation_mouse_metric_hard_f1_1 tensor(0.8871983885765076)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.7906017899513245)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.7676120996475220)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8544930219650269)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8729758858680725)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5330032706260681)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2687338590621948)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3734439611434937)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839538097381592)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8722109794616699)
evaluation_mouse_metric_hard_f1_1 tensor(0.8866536021232605)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8309975266456604)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8215840458869934)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8495465517044067)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8690348267555237)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5346534848213196)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2806122303009033)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5000000000000000)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3752577304840088)
evaluation_human_metric_hard_accuracy tensor(0.8689024448394775)
evaluation_human_metric_hard_f1 tensor(0.8840672969818115)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8728563785552979)
evaluation_mouse_metric_hard_f1_1 tensor(0.8871983885765076)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8483099937438965)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8430033922195435)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8524320125579834)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8713155984878540)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5363036394119263)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2776349782943726)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3760330677032471)
evaluation_human_metric_hard_accuracy tensor(0.8700639009475708)
evaluation_human_metric_hard_f1 tensor(0.8850058317184448)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8729485273361206)
evaluation_mouse_metric_hard_f1_1 tensor(0.8872709274291992)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8561418056488037)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8530526161193848)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8495465517044067)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8692224621772766)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5313531160354614)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2717948555946350)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4966996610164642)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3737166225910187)
evaluation_human_metric_hard_accuracy tensor(0.8684669137001038)
evaluation_human_metric_hard_f1 tensor(0.8837566971778870)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880700469017029)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8771640658378601)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8771640658378601)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8474856019020081)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8676680922508240)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5379537940025330)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2893401086330414)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3786008059978485)
evaluation_human_metric_hard_accuracy tensor(0.8694831728935242)
evaluation_human_metric_hard_f1 tensor(0.8845511674880981)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8735939264297485)
evaluation_mouse_metric_hard_f1_1 tensor(0.8877793550491333)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8878812789916992)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8891605734825134)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8441879749298096)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8651925921440125)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5379537940025330)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2893401086330414)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5016501545906067)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3786008059978485)
evaluation_human_metric_hard_accuracy tensor(0.8709349632263184)
evaluation_human_metric_hard_f1 tensor(0.8856886029243469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8736861348152161)
evaluation_mouse_metric_hard_f1_1 tensor(0.8878520131111145)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.8973619341850281)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.8994751572608948)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8450123667716980)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8658101558685303)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5313531160354614)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2791878283023834)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.5000000000000000)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3778234124183655)
evaluation_human_metric_hard_accuracy tensor(0.8687572479248047)
evaluation_human_metric_hard_f1 tensor(0.8839836120605469)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8747925758361816)
evaluation_mouse_metric_hard_f1_1 tensor(0.8887249827384949)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.9097279310226440)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.9132672548294067)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8441879749298096)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8651925921440125)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5313531160354614)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2864321768283844)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4966996610164642)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3762781023979187)
evaluation_human_metric_hard_accuracy tensor(0.8694831728935242)
evaluation_human_metric_hard_f1 tensor(0.8845511674880981)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8739627599716187)
evaluation_mouse_metric_hard_f1_1 tensor(0.8880700469017029)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))

align_evaluation_train_align_0.5_metric_hard_accuracy_align tensor(0.9163231849670410)
align_evaluation_train_align_0.5_metric_hard_f1_align tensor(0.9196676015853882)
align_evaluation_train_align_0.5_metric_hard_accuracy_align_mean tensor(0.8421269655227661)
align_evaluation_train_align_0.5_metric_hard_f1_align_mean tensor(0.8636525869369507)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align tensor(0.5330032706260681)
align_evaluation_dev_align_0.5_metric_hard_f1_align tensor(0.2871536612510681)
align_evaluation_dev_align_0.5_metric_hard_accuracy_align_mean tensor(0.4966996610164642)
align_evaluation_dev_align_0.5_metric_hard_f1_align_mean tensor(0.3762781023979187)
evaluation_human_metric_hard_accuracy tensor(0.8700639009475708)
evaluation_human_metric_hard_f1 tensor(0.8850058317184448)
evaluation_mouse_metric_hard_accuracy_1 tensor(0.8753457665443420)
evaluation_mouse_metric_hard_f1_1 tensor(0.8891621828079224)


HBox(children=(HTML(value='Current Batch:'), FloatProgress(value=0.0, max=177.0), HTML(value='')))


Stopped training at 24.548022598870055 epochs due to keyboard interrupt.
