## Import and set args

In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from data.data import prepare_data
from collections import namedtuple
from utils import set_seed, Logger, CSVBatchLogger, log_args, add_args
import argparse
import os
import pandas as pd
import torch
from tqdm.auto import tqdm
from pprint import pprint

sampling_methods = ['subsample', 'reweight']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Using {} device'.format(device))

parser = argparse.ArgumentParser()
parser.add_argument('--resnet_width', type=int, default=None)
parser.add_argument('--model_path', type=str)
parser.add_argument('--lr', type=float, default='1e-3')
parser.add_argument('--weight_decay', type=float, default='1e-6')
parser.add_argument('--sampling_method', choices=sampling_methods, default=None)
parser.add_argument('--n_epochs', type=int, default=25)
parser.add_argument('--log_dir', type=str, default='./phase2_log')
parser.add_argument('--seed', type=int, default=0)



# '--dataset', 'CelebA', '-s', 'confounder', '-t', 'Blond_Hair', '-c', 'Male', '--log_dir', 'twophaselog_w32', '--seed', '0', 
toparse = ['--resnet_width', '16',
          '--model_path', 'log_w16_seed0/last_model_50.pth',
          '--sampling_method', 'reweight']
args = parser.parse_args(toparse)

# setup logging
mode = 'w'
if not os.path.exists(args.log_dir):
    os.makedirs(args.log_dir)
logger = Logger(os.path.join(args.log_dir, 'log.txt'), mode)
# Record args
log_args(args, logger)
set_seed(args.seed)

Using cuda device
Resnet width: 16
Model path: log_w16_seed0/last_model_50.pth
Lr: 0.001
Weight decay: 1e-06
Sampling method: reweight
N epochs: 25
Log dir: ./phase2_log
Seed: 0



## Loads data
This is where we adjust the sampling strategy. 

In [2]:
from data.celebA_dataset import CelebADataset
from models import model_attributes
from data.dro_dataset import DRODataset

    
root_dir = '/home/thiennguyen/research/datasets/celebA/'  # dir that contains data
target_name= 'Blond_Hair'  # we are classifying whether the input image is blond or not
confounder_names= ['Male']  # we aim to avoid learning spurious features... here it's the gender
model_type= 'resnet10vw'  # what model we are using to process --> this is to determine the input size to rescale the image
augment_data= False
fraction=1.0
splits = ['train', 'val', 'test']
n_classes = 4

full_dataset = CelebADataset(root_dir=root_dir,
        target_name=target_name,
        confounder_names=confounder_names,
        model_type=model_type,  # this string is to get the model's input size (for resizing) and input type (image or precomputed)
        augment_data=augment_data)  # augment data adds random resized crop and random flip.

subsets = full_dataset.get_splits(       # basically return the Subsets object with the appropriate indices for train/val/test
        splits,                          # also implements subsampling --> just remove random indices of the appropriate groups in train
        train_frac=fraction,   # fraction means how much of the train data to use --> randomly remove if less than 1
        subsample_to_minority=(args.sampling_method == 'subsample'))

dro_subsets = [  
    DRODataset(
        subsets[split],  # process each subset separately --> applying the transform parameter.
        process_item_fn=None,
        n_groups=full_dataset.n_groups,
        n_classes=full_dataset.n_classes,
        group_str_fn=full_dataset.group_str) \
    for split in splits]

train_data, val_data, test_data = dro_subsets
train_loader = train_data.get_loader(train=True, reweight_groups=(args.sampling_method == 'reweight'), batch_size=128)
val_loader = val_data.get_loader(train=False, reweight_groups=None, batch_size=5)
test_loader = test_data.get_loader(train=False, reweight_groups=None, batch_size=128)

## Loads Models 

In [5]:
from variable_width_resnet import resnet10vw

# load saved model
model = resnet10vw(args.resnet_width, num_classes=n_classes)
# model.load_state_dict(torch.load(args.model_path))
model.to(device)

# freeze everything except the last layer
for name, param in model.named_parameters():
    if name not in ['fc.weight', 'fc.bias']:
        param.requires_grad = False

# make sure that really worked
parameters = list(filter(lambda p: p.requires_grad, model.parameters()))
assert len(parameters) == 2  # fc.weight, fc.bias


# initialize the criterion and optimizer
criterion = torch.nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.SGD(
            filter(lambda p: p.requires_grad, model.parameters()),
            lr=args.lr,
            momentum=0.9,
            weight_decay=args.weight_decay)

In [21]:
from variable_width_resnet import resnet10vw

model = resnet10vw(16, num_classes=4)
model2 = resnet10vw(16, num_classes=2)

model.to(device)
model2.to(device)

criteriontest = torch.nn.CrossEntropyLoss(reduction='none')

In [22]:


def pnorm(weights, p):
    normB = torch.norm(weights, 2, 1)
    ws = weights.clone()
    for i in range(weights.size(0)):
        ws[i] = ws[i] / torch.pow(normB[i], p)
    return ws


for batch_idx, batch in enumerate((val_loader)):
    batch = tuple(t.to(device) for t in batch)
    x,y,g = batch
    outputs4 = model(x)
    outputs2 = model2(x)
    loss4 = criteriontest(outputs4, g)
    loss2 = criteriontest(outputs2, y)
    break
    

In [23]:
group_idx = g 
n_groups = 4
losses = loss2

group_map = (group_idx == torch.arange(n_groups).unsqueeze(1).long().cuda()).float()
group_count = group_map.sum(1)
group_denom = group_count + (group_count==0).float() # avoid nans

group_loss = (group_map @ losses.view(-1))/group_denom
group_loss4 = (group_map @ loss4.view(-1))/group_denom

In [35]:
step_size = 0.01

adv_probs2 = torch.ones(n_groups).cuda()/n_groups
adv_probs4 = torch.ones(n_groups).cuda()/n_groups

adv_probs2 = adv_probs2 * torch.exp(step_size*group_loss)
print(adv_probs2.sum())
adv_probs2 = adv_probs2/adv_probs2.sum()

adv_probs4 = adv_probs4 * torch.exp(step_size*group_loss4)
print(adv_probs4.sum())
adv_probs4 = adv_probs4/adv_probs4.sum()

print(adv_probs2)

print(adv_probs4)

tensor(1.0032, device='cuda:0', grad_fn=<SumBackward0>)
tensor(1.0069, device='cuda:0', grad_fn=<SumBackward0>)
tensor([0.2507, 0.2509, 0.2492, 0.2492], device='cuda:0',
       grad_fn=<DivBackward0>)
tensor([0.2518, 0.2517, 0.2483, 0.2483], device='cuda:0',
       grad_fn=<DivBackward0>)


In [25]:
print(f'x: {x.shape}\ny: {y}\ng: {g}')
print('group_map\n',group_map)
print(f'group_count\n{group_count}')
print(f'group_denom\n{group_denom}')
print(f'loss2\n{loss2}')
print(f'group_loss2\n{group_loss}')

print(f'loss4\n{loss4}')
print(f'group_loss4\n{group_loss4}')




x: torch.Size([5, 3, 224, 224])
y: tensor([0, 0, 0, 0, 0], device='cuda:0')
g: tensor([0, 1, 0, 1, 0], device='cuda:0')
group_map
 tensor([[1., 0., 1., 0., 1.],
        [0., 1., 0., 1., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]], device='cuda:0')
group_count
tensor([3., 2., 0., 0.], device='cuda:0')
group_denom
tensor([3., 2., 1., 1.], device='cuda:0')
loss2
tensor([0.6003, 0.6972, 0.6142, 0.6508, 0.6186], device='cuda:0',
       grad_fn=<NllLossBackward>)
group_loss2
tensor([0.6110, 0.6740, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<DivBackward0>)
loss4
tensor([1.3283, 1.2866, 1.3695, 1.4120, 1.5109], device='cuda:0',
       grad_fn=<NllLossBackward>)
group_loss4
tensor([1.4029, 1.3493, 0.0000, 0.0000], device='cuda:0',
       grad_fn=<DivBackward0>)


## Train classifier

In [4]:
import csv
from pprint import pprint


def write_to_writer(writer, content):
    writer.writerow(content)
    pprint(content)

mode = 'w'
    
train_path = open(os.path.join(args.log_dir, 'train.csv'), mode)
val_path = open(os.path.join(args.log_dir, 'val.csv'), mode)
test_path = open(os.path.join(args.log_dir, 'test.csv'), mode)

train_columns = ['epoch','total_acc', 'group0_acc', 'group1_acc', 'group2_acc', 'group3_acc', 'split_acc', 'loss', 'avg_margin', 'group0_margin', 
                'group1_margin', 'group2_margin', 'group3_margin']

valtest_columns = ['epoch', 'total_acc', 'group0_acc', 'group1_acc', 'group2_acc', 'group3_acc', 'split_acc',
                  'avg_margin', 'group0_margin', 'group1_margin', 'group2_margin', 'group3_margin']

train_writer = csv.DictWriter(train_path, fieldnames=train_columns)
train_writer.writeheader()
val_writer = csv.DictWriter(val_path, fieldnames=valtest_columns)
val_writer.writeheader()
test_writer = csv.DictWriter(test_path, fieldnames=valtest_columns)
test_writer.writeheader()


def run_epoch(epoch, model, device, optimizer, loader, loss_computer, writer, is_training):
    """
    Train the classifier
    """
    model.train() if is_training else model.eval()
    
    running_loss, total_margin = 0, 0   # keep track of avg loss, margin in train
    l_correct, g_correct, total = 0, 0, 0  # for validation
    group_track = {f'g{i}':{'correct':0, 'total':0, 'margin': 0} for i in range(4)}  # keeps track of counts of #correct, #total for each group g0-g4
    log_train_every = 200    
    
    with (torch.enable_grad() if is_training else torch.no_grad()):
        for batch_idx, batch in enumerate(tqdm(loader)):
            batch = tuple(t.to(device) for t in batch)
            x,y,g = batch
            outputs = model(x)
            
            if is_training:
                optimizer.zero_grad()
                loss = loss_computer(outputs, g)  # we are classifying groups
                loss.backward()
                optimizer.step()
                # print statistics
                running_loss += loss.item()
                if (batch_idx % log_train_every) == log_train_every -1:    # print every 200 mini-batches
                    print('[%d, %5d] loss: %.3f, avg_margin: %.3f' %
                          (epoch, batch_idx + 1, running_loss / log_train_every, total_margin/total))
                    running_loss = 0.0
            
            # extra validation
            with torch.no_grad():
                # accuracies and margins
                total += g.size(0)  # total
                _, predicted = torch.max(outputs.data, 1)  # get predictions
                maskg = torch.zeros((len(g), 4), dtype=bool, device=device)
                maskg[np.arange(len(g)),g] = 1
                margins = outputs[maskg]-torch.max(outputs*(~maskg),dim=1)[0]
                total_margin += margins.sum().item()
                for g_idx in range(4):
                    group_track[f'g{g_idx}']['correct'] += ((predicted == g)[g==g_idx]).sum().item()  # correctly predict
                    group_track[f'g{g_idx}']['total'] += (g==g_idx).sum().item()  # total number of group's instance encountered
                    group_track[f'g{g_idx}']['margin'] += margins[g==g_idx].sum().item()
                g_correct += (predicted == g).sum().item()
                l_correct += (predicted//2 == y).sum().item()

        # write stats in dict and csv
        stats_dict = {'epoch':epoch, 'total_acc':f"{l_correct/total:.4f}", 'split_acc':f'{g_correct/total:.4f}', 'avg_margin': f"{total_margin/total:.4f}"}
        for g in range(4):
            stats_dict[f'group{g}_acc'] = f"{group_track[f'g{g}']['correct']/group_track[f'g{g}']['total']:.4f}"
            stats_dict[f'group{g}_margin'] = f"{group_track[f'g{g}']['margin']/group_track[f'g{g}']['total']:.4f}"
        if is_training:
            stats_dict['loss'] = f'{running_loss/log_train_every:.4f}'
            
        write_to_writer(writer, stats_dict)

In [5]:
for epoch in range(args.n_epochs):
    # train
    print(f'Train epoch {epoch}')
    run_epoch(epoch+1, model, device, optimizer, train_loader, criterion, train_writer, is_training=True)
    
    # validate
    print(f'Validate epoch {epoch}')
    run_epoch(epoch+1, model, device, optimizer, val_loader, criterion, val_writer, is_training=False)
    

Train epoch 0


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[1,   200] loss: 0.280, avg_margin: 3.858
[1,   400] loss: 0.188, avg_margin: 3.879
[1,   600] loss: 0.189, avg_margin: 3.886
[1,   800] loss: 0.184, avg_margin: 3.898
[1,  1000] loss: 0.181, avg_margin: 3.908
[1,  1200] loss: 0.187, avg_margin: 3.914

{'avg_margin': '3.9181',
 'epoch': 1,
 'group0_acc': '0.9403',
 'group0_margin': '5.0920',
 'group1_acc': '0.9267',
 'group1_margin': '3.2736',
 'group2_acc': '0.9397',
 'group2_margin': '4.6347',
 'group3_acc': '0.8971',
 'group3_margin': '2.6794',
 'loss': '0.0651',
 'split_acc': '0.9259',
 'total_acc': '0.9342'}
Validate epoch 0


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '3.9809',
 'epoch': 1,
 'group0_acc': '0.8901',
 'group0_margin': '4.7389',
 'group1_acc': '0.9033',
 'group1_margin': '3.1593',
 'group2_acc': '0.8744',
 'group2_margin': '4.2288',
 'group3_acc': '0.7637',
 'group3_margin': '1.8769',
 'split_acc': '0.8922',
 'total_acc': '0.9161'}
Train epoch 1


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[2,   200] loss: 0.183, avg_margin: 3.970
[2,   400] loss: 0.184, avg_margin: 3.985
[2,   600] loss: 0.184, avg_margin: 3.985
[2,   800] loss: 0.185, avg_margin: 3.987
[2,  1000] loss: 0.184, avg_margin: 3.994
[2,  1200] loss: 0.182, avg_margin: 4.002

{'avg_margin': '4.0036',
 'epoch': 2,
 'group0_acc': '0.9373',
 'group0_margin': '5.0494',
 'group1_acc': '0.9268',
 'group1_margin': '3.3377',
 'group2_acc': '0.9487',
 'group2_margin': '4.7137',
 'group3_acc': '0.9179',
 'group3_margin': '2.9146',
 'loss': '0.0670',
 'split_acc': '0.9327',
 'total_acc': '0.9406'}
Validate epoch 1


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.1527',
 'epoch': 2,
 'group0_acc': '0.8887',
 'group0_margin': '4.8049',
 'group1_acc': '0.9083',
 'group1_margin': '3.4658',
 'group2_acc': '0.8824',
 'group2_margin': '4.3447',
 'group3_acc': '0.7582',
 'group3_margin': '1.7690',
 'split_acc': '0.8948',
 'total_acc': '0.9191'}
Train epoch 2


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[3,   200] loss: 0.186, avg_margin: 4.043
[3,   400] loss: 0.191, avg_margin: 4.038
[3,   600] loss: 0.177, avg_margin: 4.047
[3,   800] loss: 0.184, avg_margin: 4.046
[3,  1000] loss: 0.178, avg_margin: 4.059
[3,  1200] loss: 0.183, avg_margin: 4.061

{'avg_margin': '4.0636',
 'epoch': 3,
 'group0_acc': '0.9381',
 'group0_margin': '5.1213',
 'group1_acc': '0.9263',
 'group1_margin': '3.4337',
 'group2_acc': '0.9487',
 'group2_margin': '4.7191',
 'group3_acc': '0.9171',
 'group3_margin': '2.9812',
 'loss': '0.0652',
 'split_acc': '0.9325',
 'total_acc': '0.9408'}
Validate epoch 2


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.0803',
 'epoch': 3,
 'group0_acc': '0.8799',
 'group0_margin': '4.5705',
 'group1_acc': '0.9147',
 'group1_margin': '3.5347',
 'group2_acc': '0.8834',
 'group2_margin': '4.3323',
 'group3_acc': '0.7857',
 'group3_margin': '1.9229',
 'split_acc': '0.8940',
 'total_acc': '0.9178'}
Train epoch 3


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[4,   200] loss: 0.186, avg_margin: 4.097
[4,   400] loss: 0.180, avg_margin: 4.096
[4,   600] loss: 0.178, avg_margin: 4.109
[4,   800] loss: 0.183, avg_margin: 4.109
[4,  1000] loss: 0.179, avg_margin: 4.115
[4,  1200] loss: 0.183, avg_margin: 4.119

{'avg_margin': '4.1186',
 'epoch': 4,
 'group0_acc': '0.9376',
 'group0_margin': '5.1717',
 'group1_acc': '0.9246',
 'group1_margin': '3.5187',
 'group2_acc': '0.9487',
 'group2_margin': '4.7441',
 'group3_acc': '0.9183',
 'group3_margin': '3.0291',
 'loss': '0.0676',
 'split_acc': '0.9323',
 'total_acc': '0.9406'}
Validate epoch 3


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.1790',
 'epoch': 4,
 'group0_acc': '0.8847',
 'group0_margin': '4.7823',
 'group1_acc': '0.9099',
 'group1_margin': '3.5499',
 'group2_acc': '0.8803',
 'group2_margin': '4.3405',
 'group3_acc': '0.7747',
 'group3_margin': '1.9472',
 'split_acc': '0.8935',
 'total_acc': '0.9176'}
Train epoch 4


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[5,   200] loss: 0.180, avg_margin: 4.156
[5,   400] loss: 0.175, avg_margin: 4.169
[5,   600] loss: 0.179, avg_margin: 4.168
[5,   800] loss: 0.181, avg_margin: 4.164
[5,  1000] loss: 0.181, avg_margin: 4.164
[5,  1200] loss: 0.181, avg_margin: 4.164

{'avg_margin': '4.1672',
 'epoch': 5,
 'group0_acc': '0.9386',
 'group0_margin': '5.2655',
 'group1_acc': '0.9249',
 'group1_margin': '3.5399',
 'group2_acc': '0.9516',
 'group2_margin': '4.7918',
 'group3_acc': '0.9206',
 'group3_margin': '3.0715',
 'loss': '0.0634',
 'split_acc': '0.9339',
 'total_acc': '0.9422'}
Validate epoch 4


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.2267',
 'epoch': 5,
 'group0_acc': '0.8845',
 'group0_margin': '4.8250',
 'group1_acc': '0.9085',
 'group1_margin': '3.6009',
 'group2_acc': '0.8838',
 'group2_margin': '4.3947',
 'group3_acc': '0.7692',
 'group3_margin': '1.9808',
 'split_acc': '0.8933',
 'total_acc': '0.9171'}
Train epoch 5


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[6,   200] loss: 0.180, avg_margin: 4.201
[6,   400] loss: 0.176, avg_margin: 4.205
[6,   600] loss: 0.177, avg_margin: 4.205
[6,   800] loss: 0.182, avg_margin: 4.204
[6,  1000] loss: 0.183, avg_margin: 4.205
[6,  1200] loss: 0.180, avg_margin: 4.205

{'avg_margin': '4.2015',
 'epoch': 6,
 'group0_acc': '0.9375',
 'group0_margin': '5.2808',
 'group1_acc': '0.9279',
 'group1_margin': '3.6077',
 'group2_acc': '0.9493',
 'group2_margin': '4.8247',
 'group3_acc': '0.9189',
 'group3_margin': '3.0929',
 'loss': '0.0658',
 'split_acc': '0.9334',
 'total_acc': '0.9416'}
Validate epoch 5


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.2416',
 'epoch': 6,
 'group0_acc': '0.8873',
 'group0_margin': '4.8903',
 'group1_acc': '0.9068',
 'group1_margin': '3.5655',
 'group2_acc': '0.8803',
 'group2_margin': '4.4042',
 'group3_acc': '0.7747',
 'group3_margin': '1.9988',
 'split_acc': '0.8934',
 'total_acc': '0.9171'}
Train epoch 6


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[7,   200] loss: 0.180, avg_margin: 4.210
[7,   400] loss: 0.183, avg_margin: 4.203
[7,   600] loss: 0.178, avg_margin: 4.206
[7,   800] loss: 0.175, avg_margin: 4.221
[7,  1000] loss: 0.179, avg_margin: 4.227
[7,  1200] loss: 0.180, avg_margin: 4.227

{'avg_margin': '4.2285',
 'epoch': 7,
 'group0_acc': '0.9381',
 'group0_margin': '5.2803',
 'group1_acc': '0.9250',
 'group1_margin': '3.6535',
 'group2_acc': '0.9506',
 'group2_margin': '4.8480',
 'group3_acc': '0.9201',
 'group3_margin': '3.1325',
 'loss': '0.0642',
 'split_acc': '0.9334',
 'total_acc': '0.9413'}
Validate epoch 6


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.1672',
 'epoch': 7,
 'group0_acc': '0.8752',
 'group0_margin': '4.6443',
 'group1_acc': '0.9051',
 'group1_margin': '3.5673',
 'group2_acc': '0.8907',
 'group2_margin': '4.6048',
 'group3_acc': '0.8022',
 'group3_margin': '2.1592',
 'split_acc': '0.8893',
 'total_acc': '0.9127'}
Train epoch 7


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[8,   200] loss: 0.181, avg_margin: 4.235
[8,   400] loss: 0.180, avg_margin: 4.246
[8,   600] loss: 0.181, avg_margin: 4.242
[8,   800] loss: 0.172, avg_margin: 4.254
[8,  1000] loss: 0.179, avg_margin: 4.256
[8,  1200] loss: 0.177, avg_margin: 4.261

{'avg_margin': '4.2639',
 'epoch': 8,
 'group0_acc': '0.9397',
 'group0_margin': '5.3558',
 'group1_acc': '0.9260',
 'group1_margin': '3.6501',
 'group2_acc': '0.9506',
 'group2_margin': '4.8915',
 'group3_acc': '0.9204',
 'group3_margin': '3.1588',
 'loss': '0.0618',
 'split_acc': '0.9342',
 'total_acc': '0.9422'}
Validate epoch 7


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.2624',
 'epoch': 8,
 'group0_acc': '0.8819',
 'group0_margin': '4.8391',
 'group1_acc': '0.9032',
 'group1_margin': '3.5772',
 'group2_acc': '0.8921',
 'group2_margin': '4.6641',
 'group3_acc': '0.7692',
 'group3_margin': '2.0349',
 'split_acc': '0.8912',
 'total_acc': '0.9150'}
Train epoch 8


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[9,   200] loss: 0.179, avg_margin: 4.263
[9,   400] loss: 0.176, avg_margin: 4.267
[9,   600] loss: 0.177, avg_margin: 4.277
[9,   800] loss: 0.179, avg_margin: 4.275
[9,  1000] loss: 0.177, avg_margin: 4.276
[9,  1200] loss: 0.173, avg_margin: 4.284

{'avg_margin': '4.2828',
 'epoch': 9,
 'group0_acc': '0.9401',
 'group0_margin': '5.3733',
 'group1_acc': '0.9253',
 'group1_margin': '3.6732',
 'group2_acc': '0.9513',
 'group2_margin': '4.9217',
 'group3_acc': '0.9208',
 'group3_margin': '3.1673',
 'loss': '0.0658',
 'split_acc': '0.9343',
 'total_acc': '0.9424'}
Validate epoch 8


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.3007',
 'epoch': 9,
 'group0_acc': '0.8806',
 'group0_margin': '4.8726',
 'group1_acc': '0.9039',
 'group1_margin': '3.6294',
 'group2_acc': '0.8914',
 'group2_margin': '4.6752',
 'group3_acc': '0.7802',
 'group3_margin': '2.0989',
 'split_acc': '0.8910',
 'total_acc': '0.9148'}
Train epoch 9


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[10,   200] loss: 0.180, avg_margin: 4.279
[10,   400] loss: 0.179, avg_margin: 4.286
[10,   600] loss: 0.178, avg_margin: 4.305
[10,   800] loss: 0.177, avg_margin: 4.306
[10,  1000] loss: 0.183, avg_margin: 4.303
[10,  1200] loss: 0.180, avg_margin: 4.309

{'avg_margin': '4.3085',
 'epoch': 10,
 'group0_acc': '0.9388',
 'group0_margin': '5.4176',
 'group1_acc': '0.9271',
 'group1_margin': '3.7023',
 'group2_acc': '0.9498',
 'group2_margin': '4.9606',
 'group3_acc': '0.9220',
 'group3_margin': '3.1704',
 'loss': '0.0621',
 'split_acc': '0.9344',
 'total_acc': '0.9425'}
Validate epoch 9


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.3935',
 'epoch': 10,
 'group0_acc': '0.8855',
 'group0_margin': '4.9804',
 'group1_acc': '0.9141',
 'group1_margin': '3.8001',
 'group2_acc': '0.8820',
 'group2_margin': '4.5150',
 'group3_acc': '0.7747',
 'group3_margin': '1.9369',
 'split_acc': '0.8959',
 'total_acc': '0.9198'}
Train epoch 10


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[11,   200] loss: 0.175, avg_margin: 4.329
[11,   400] loss: 0.181, avg_margin: 4.323
[11,   600] loss: 0.180, avg_margin: 4.318
[11,   800] loss: 0.179, avg_margin: 4.313
[11,  1000] loss: 0.176, avg_margin: 4.313
[11,  1200] loss: 0.175, avg_margin: 4.318

{'avg_margin': '4.3197',
 'epoch': 11,
 'group0_acc': '0.9388',
 'group0_margin': '5.4399',
 'group1_acc': '0.9285',
 'group1_margin': '3.7178',
 'group2_acc': '0.9515',
 'group2_margin': '4.9369',
 'group3_acc': '0.9212',
 'group3_margin': '3.1794',
 'loss': '0.0642',
 'split_acc': '0.9350',
 'total_acc': '0.9428'}
Validate epoch 10


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.3625',
 'epoch': 11,
 'group0_acc': '0.8862',
 'group0_margin': '5.0449',
 'group1_acc': '0.9056',
 'group1_margin': '3.6386',
 'group2_acc': '0.8796',
 'group2_margin': '4.5652',
 'group3_acc': '0.7692',
 'group3_margin': '2.0715',
 'split_acc': '0.8923',
 'total_acc': '0.9163'}
Train epoch 11


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[12,   200] loss: 0.181, avg_margin: 4.319
[12,   400] loss: 0.173, avg_margin: 4.334
[12,   600] loss: 0.177, avg_margin: 4.339
[12,   800] loss: 0.174, avg_margin: 4.347
[12,  1000] loss: 0.174, avg_margin: 4.352
[12,  1200] loss: 0.176, avg_margin: 4.352

{'avg_margin': '4.3515',
 'epoch': 12,
 'group0_acc': '0.9398',
 'group0_margin': '5.5099',
 'group1_acc': '0.9257',
 'group1_margin': '3.6988',
 'group2_acc': '0.9521',
 'group2_margin': '4.9856',
 'group3_acc': '0.9222',
 'group3_margin': '3.2114',
 'loss': '0.0631',
 'split_acc': '0.9349',
 'total_acc': '0.9432'}
Validate epoch 11


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.5287',
 'epoch': 12,
 'group0_acc': '0.8963',
 'group0_margin': '5.2836',
 'group1_acc': '0.9143',
 'group1_margin': '3.8587',
 'group2_acc': '0.8671',
 'group2_margin': '4.3892',
 'group3_acc': '0.7473',
 'group3_margin': '1.7993',
 'split_acc': '0.8982',
 'total_acc': '0.9226'}
Train epoch 12


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[13,   200] loss: 0.176, avg_margin: 4.353
[13,   400] loss: 0.180, avg_margin: 4.344
[13,   600] loss: 0.178, avg_margin: 4.340
[13,   800] loss: 0.178, avg_margin: 4.348
[13,  1000] loss: 0.179, avg_margin: 4.350
[13,  1200] loss: 0.172, avg_margin: 4.354

{'avg_margin': '4.3562',
 'epoch': 13,
 'group0_acc': '0.9391',
 'group0_margin': '5.5165',
 'group1_acc': '0.9262',
 'group1_margin': '3.7193',
 'group2_acc': '0.9513',
 'group2_margin': '4.9724',
 'group3_acc': '0.9214',
 'group3_margin': '3.2149',
 'loss': '0.0633',
 'split_acc': '0.9345',
 'total_acc': '0.9427'}
Validate epoch 12


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.3105',
 'epoch': 13,
 'group0_acc': '0.8808',
 'group0_margin': '4.9283',
 'group1_acc': '0.9056',
 'group1_margin': '3.6284',
 'group2_acc': '0.8827',
 'group2_margin': '4.5751',
 'group3_acc': '0.8022',
 'group3_margin': '2.1849',
 'split_acc': '0.8907',
 'total_acc': '0.9142'}
Train epoch 13


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[14,   200] loss: 0.170, avg_margin: 4.381
[14,   400] loss: 0.173, avg_margin: 4.391
[14,   600] loss: 0.184, avg_margin: 4.384
[14,   800] loss: 0.174, avg_margin: 4.386
[14,  1000] loss: 0.179, avg_margin: 4.390
[14,  1200] loss: 0.175, avg_margin: 4.391

{'avg_margin': '4.3896',
 'epoch': 14,
 'group0_acc': '0.9391',
 'group0_margin': '5.5750',
 'group1_acc': '0.9261',
 'group1_margin': '3.7335',
 'group2_acc': '0.9521',
 'group2_margin': '5.0160',
 'group3_acc': '0.9209',
 'group3_margin': '3.2324',
 'loss': '0.0624',
 'split_acc': '0.9346',
 'total_acc': '0.9425'}
Validate epoch 13


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.3767',
 'epoch': 14,
 'group0_acc': '0.8819',
 'group0_margin': '4.9593',
 'group1_acc': '0.9106',
 'group1_margin': '3.7226',
 'group2_acc': '0.8866',
 'group2_margin': '4.6753',
 'group3_acc': '0.7912',
 'group3_margin': '2.0881',
 'split_acc': '0.8937',
 'total_acc': '0.9173'}
Train epoch 14


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[15,   200] loss: 0.178, avg_margin: 4.409
[15,   400] loss: 0.178, avg_margin: 4.400
[15,   600] loss: 0.177, avg_margin: 4.402
[15,   800] loss: 0.171, avg_margin: 4.412
[15,  1000] loss: 0.178, avg_margin: 4.411
[15,  1200] loss: 0.173, avg_margin: 4.412

{'avg_margin': '4.4148',
 'epoch': 15,
 'group0_acc': '0.9385',
 'group0_margin': '5.5799',
 'group1_acc': '0.9271',
 'group1_margin': '3.7787',
 'group2_acc': '0.9519',
 'group2_margin': '5.0677',
 'group3_acc': '0.9217',
 'group3_margin': '3.2212',
 'loss': '0.0609',
 'split_acc': '0.9348',
 'total_acc': '0.9429'}
Validate epoch 14


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4111',
 'epoch': 15,
 'group0_acc': '0.8859',
 'group0_margin': '5.1140',
 'group1_acc': '0.8993',
 'group1_margin': '3.5910',
 'group2_acc': '0.8914',
 'group2_margin': '4.8310',
 'group3_acc': '0.7802',
 'group3_margin': '2.1085',
 'split_acc': '0.8913',
 'total_acc': '0.9150'}
Train epoch 15


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[16,   200] loss: 0.178, avg_margin: 4.414
[16,   400] loss: 0.171, avg_margin: 4.429
[16,   600] loss: 0.175, avg_margin: 4.429
[16,   800] loss: 0.176, avg_margin: 4.421
[16,  1000] loss: 0.179, avg_margin: 4.419
[16,  1200] loss: 0.175, avg_margin: 4.420

{'avg_margin': '4.4213',
 'epoch': 16,
 'group0_acc': '0.9384',
 'group0_margin': '5.6005',
 'group1_acc': '0.9259',
 'group1_margin': '3.7782',
 'group2_acc': '0.9514',
 'group2_margin': '5.0522',
 'group3_acc': '0.9210',
 'group3_margin': '3.2531',
 'loss': '0.0629',
 'split_acc': '0.9342',
 'total_acc': '0.9423'}
Validate epoch 15


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4998',
 'epoch': 16,
 'group0_acc': '0.8955',
 'group0_margin': '5.3839',
 'group1_acc': '0.9042',
 'group1_margin': '3.6573',
 'group2_acc': '0.8713',
 'group2_margin': '4.4558',
 'group3_acc': '0.7747',
 'group3_margin': '2.0405',
 'split_acc': '0.8945',
 'total_acc': '0.9187'}
Train epoch 16


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[17,   200] loss: 0.177, avg_margin: 4.420
[17,   400] loss: 0.173, avg_margin: 4.424
[17,   600] loss: 0.172, avg_margin: 4.432
[17,   800] loss: 0.172, avg_margin: 4.433
[17,  1000] loss: 0.179, avg_margin: 4.434
[17,  1200] loss: 0.172, avg_margin: 4.435

{'avg_margin': '4.4354',
 'epoch': 17,
 'group0_acc': '0.9397',
 'group0_margin': '5.6228',
 'group1_acc': '0.9278',
 'group1_margin': '3.7622',
 'group2_acc': '0.9514',
 'group2_margin': '5.0789',
 'group3_acc': '0.9241',
 'group3_margin': '3.2683',
 'loss': '0.0625',
 'split_acc': '0.9358',
 'total_acc': '0.9444'}
Validate epoch 16


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4471',
 'epoch': 17,
 'group0_acc': '0.8841',
 'group0_margin': '5.0925',
 'group1_acc': '0.9070',
 'group1_margin': '3.7282',
 'group2_acc': '0.8876',
 'group2_margin': '4.7528',
 'group3_acc': '0.7747',
 'group3_margin': '2.0448',
 'split_acc': '0.8931',
 'total_acc': '0.9171'}
Train epoch 17


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[18,   200] loss: 0.183, avg_margin: 4.422
[18,   400] loss: 0.174, avg_margin: 4.433
[18,   600] loss: 0.173, avg_margin: 4.442
[18,   800] loss: 0.178, avg_margin: 4.438
[18,  1000] loss: 0.177, avg_margin: 4.444
[18,  1200] loss: 0.176, avg_margin: 4.441

{'avg_margin': '4.4414',
 'epoch': 18,
 'group0_acc': '0.9376',
 'group0_margin': '5.6373',
 'group1_acc': '0.9279',
 'group1_margin': '3.8015',
 'group2_acc': '0.9508',
 'group2_margin': '5.0807',
 'group3_acc': '0.9207',
 'group3_margin': '3.2557',
 'loss': '0.0645',
 'split_acc': '0.9343',
 'total_acc': '0.9431'}
Validate epoch 17


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4622',
 'epoch': 18,
 'group0_acc': '0.8830',
 'group0_margin': '5.0875',
 'group1_acc': '0.9077',
 'group1_margin': '3.7639',
 'group2_acc': '0.8866',
 'group2_margin': '4.7680',
 'group3_acc': '0.7692',
 'group3_margin': '2.0610',
 'split_acc': '0.8927',
 'total_acc': '0.9167'}
Train epoch 18


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[19,   200] loss: 0.181, avg_margin: 4.425
[19,   400] loss: 0.180, avg_margin: 4.427
[19,   600] loss: 0.174, avg_margin: 4.434
[19,   800] loss: 0.181, avg_margin: 4.429
[19,  1000] loss: 0.174, avg_margin: 4.438
[19,  1200] loss: 0.175, avg_margin: 4.443

{'avg_margin': '4.4419',
 'epoch': 19,
 'group0_acc': '0.9374',
 'group0_margin': '5.6559',
 'group1_acc': '0.9270',
 'group1_margin': '3.7701',
 'group2_acc': '0.9522',
 'group2_margin': '5.0931',
 'group3_acc': '0.9212',
 'group3_margin': '3.2750',
 'loss': '0.0650',
 'split_acc': '0.9344',
 'total_acc': '0.9421'}
Validate epoch 18


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4062',
 'epoch': 19,
 'group0_acc': '0.8862',
 'group0_margin': '5.2014',
 'group1_acc': '0.8923',
 'group1_margin': '3.4845',
 'group2_acc': '0.8911',
 'group2_margin': '4.8382',
 'group3_acc': '0.8077',
 'group3_margin': '2.2088',
 'split_acc': '0.8888',
 'total_acc': '0.9127'}
Train epoch 19


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[20,   200] loss: 0.172, avg_margin: 4.462
[20,   400] loss: 0.172, avg_margin: 4.476
[20,   600] loss: 0.175, avg_margin: 4.471
[20,   800] loss: 0.176, avg_margin: 4.468
[20,  1000] loss: 0.171, avg_margin: 4.471
[20,  1200] loss: 0.178, avg_margin: 4.464

{'avg_margin': '4.4621',
 'epoch': 20,
 'group0_acc': '0.9416',
 'group0_margin': '5.6703',
 'group1_acc': '0.9266',
 'group1_margin': '3.7918',
 'group2_acc': '0.9520',
 'group2_margin': '5.1026',
 'group3_acc': '0.9217',
 'group3_margin': '3.2807',
 'loss': '0.0660',
 'split_acc': '0.9355',
 'total_acc': '0.9434'}
Validate epoch 19


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.5688',
 'epoch': 20,
 'group0_acc': '0.8942',
 'group0_margin': '5.4475',
 'group1_acc': '0.9060',
 'group1_margin': '3.7456',
 'group2_acc': '0.8699',
 'group2_margin': '4.4852',
 'group3_acc': '0.7747',
 'group3_margin': '2.1112',
 'split_acc': '0.8945',
 'total_acc': '0.9187'}
Train epoch 20


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[21,   200] loss: 0.176, avg_margin: 4.464
[21,   400] loss: 0.178, avg_margin: 4.467
[21,   600] loss: 0.178, avg_margin: 4.466
[21,   800] loss: 0.179, avg_margin: 4.466
[21,  1000] loss: 0.174, avg_margin: 4.463
[21,  1200] loss: 0.174, avg_margin: 4.468

{'avg_margin': '4.4677',
 'epoch': 21,
 'group0_acc': '0.9404',
 'group0_margin': '5.7220',
 'group1_acc': '0.9271',
 'group1_margin': '3.7589',
 'group2_acc': '0.9498',
 'group2_margin': '5.1061',
 'group3_acc': '0.9237',
 'group3_margin': '3.2909',
 'loss': '0.0615',
 'split_acc': '0.9352',
 'total_acc': '0.9436'}
Validate epoch 20


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4714',
 'epoch': 21,
 'group0_acc': '0.8894',
 'group0_margin': '5.2247',
 'group1_acc': '0.9058',
 'group1_margin': '3.7011',
 'group2_acc': '0.8758',
 'group2_margin': '4.6037',
 'group3_acc': '0.7747',
 'group3_margin': '2.0851',
 'split_acc': '0.8932',
 'total_acc': '0.9168'}
Train epoch 21


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[22,   200] loss: 0.177, avg_margin: 4.433
[22,   400] loss: 0.175, avg_margin: 4.446
[22,   600] loss: 0.179, avg_margin: 4.459
[22,   800] loss: 0.182, avg_margin: 4.459
[22,  1000] loss: 0.172, avg_margin: 4.464
[22,  1200] loss: 0.176, avg_margin: 4.462

{'avg_margin': '4.4603',
 'epoch': 22,
 'group0_acc': '0.9391',
 'group0_margin': '5.6665',
 'group1_acc': '0.9245',
 'group1_margin': '3.7954',
 'group2_acc': '0.9510',
 'group2_margin': '5.0990',
 'group3_acc': '0.9237',
 'group3_margin': '3.2913',
 'loss': '0.0658',
 'split_acc': '0.9345',
 'total_acc': '0.9427'}
Validate epoch 21


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.5209',
 'epoch': 22,
 'group0_acc': '0.8825',
 'group0_margin': '5.1286',
 'group1_acc': '0.9143',
 'group1_margin': '3.8862',
 'group2_acc': '0.8848',
 'group2_margin': '4.7008',
 'group3_acc': '0.7802',
 'group3_margin': '2.0405',
 'split_acc': '0.8952',
 'total_acc': '0.9189'}
Train epoch 22


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[23,   200] loss: 0.176, avg_margin: 4.463
[23,   400] loss: 0.178, avg_margin: 4.455
[23,   600] loss: 0.180, avg_margin: 4.458
[23,   800] loss: 0.175, avg_margin: 4.461
[23,  1000] loss: 0.168, avg_margin: 4.470
[23,  1200] loss: 0.174, avg_margin: 4.471

{'avg_margin': '4.4745',
 'epoch': 23,
 'group0_acc': '0.9405',
 'group0_margin': '5.6507',
 'group1_acc': '0.9267',
 'group1_margin': '3.7983',
 'group2_acc': '0.9531',
 'group2_margin': '5.1524',
 'group3_acc': '0.9220',
 'group3_margin': '3.2994',
 'loss': '0.0603',
 'split_acc': '0.9356',
 'total_acc': '0.9438'}
Validate epoch 22


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.4776',
 'epoch': 23,
 'group0_acc': '0.8838',
 'group0_margin': '5.1639',
 'group1_acc': '0.9097',
 'group1_margin': '3.7621',
 'group2_acc': '0.8817',
 'group2_margin': '4.6493',
 'group3_acc': '0.7912',
 'group3_margin': '2.1138',
 'split_acc': '0.8934',
 'total_acc': '0.9173'}
Train epoch 23


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[24,   200] loss: 0.170, avg_margin: 4.503
[24,   400] loss: 0.177, avg_margin: 4.484
[24,   600] loss: 0.176, avg_margin: 4.489
[24,   800] loss: 0.171, avg_margin: 4.494
[24,  1000] loss: 0.170, avg_margin: 4.497
[24,  1200] loss: 0.174, avg_margin: 4.497

{'avg_margin': '4.4964',
 'epoch': 24,
 'group0_acc': '0.9389',
 'group0_margin': '5.7422',
 'group1_acc': '0.9297',
 'group1_margin': '3.8181',
 'group2_acc': '0.9529',
 'group2_margin': '5.1346',
 'group3_acc': '0.9229',
 'group3_margin': '3.2812',
 'loss': '0.0634',
 'split_acc': '0.9361',
 'total_acc': '0.9440'}
Validate epoch 23


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.5605',
 'epoch': 24,
 'group0_acc': '0.8890',
 'group0_margin': '5.3041',
 'group1_acc': '0.9122',
 'group1_margin': '3.8280',
 'group2_acc': '0.8747',
 'group2_margin': '4.6214',
 'group3_acc': '0.7692',
 'group3_margin': '2.0421',
 'split_acc': '0.8955',
 'total_acc': '0.9196'}
Train epoch 24


HBox(children=(IntProgress(value=0, max=1272), HTML(value='')))

[25,   200] loss: 0.166, avg_margin: 4.544
[25,   400] loss: 0.174, avg_margin: 4.536
[25,   600] loss: 0.178, avg_margin: 4.521
[25,   800] loss: 0.177, avg_margin: 4.515
[25,  1000] loss: 0.172, avg_margin: 4.516
[25,  1200] loss: 0.173, avg_margin: 4.519

{'avg_margin': '4.5196',
 'epoch': 25,
 'group0_acc': '0.9405',
 'group0_margin': '5.7445',
 'group1_acc': '0.9271',
 'group1_margin': '3.8385',
 'group2_acc': '0.9536',
 'group2_margin': '5.1837',
 'group3_acc': '0.9242',
 'group3_margin': '3.3095',
 'loss': '0.0610',
 'split_acc': '0.9364',
 'total_acc': '0.9445'}
Validate epoch 24


HBox(children=(IntProgress(value=0, max=156), HTML(value='')))


{'avg_margin': '4.6103',
 'epoch': 25,
 'group0_acc': '0.8909',
 'group0_margin': '5.3639',
 'group1_acc': '0.9142',
 'group1_margin': '3.9133',
 'group2_acc': '0.8716',
 'group2_margin': '4.5473',
 'group3_acc': '0.7637',
 'group3_margin': '1.9607',
 'split_acc': '0.8967',
 'total_acc': '0.9208'}
