In [1]:
#!/usr/bin/env python
# coding: utf-8

# In[1]:


from __future__ import print_function, division
import os

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="2,3"

from itertools import chain
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
from IPython.core.debugger import set_trace
import itertools
import seaborn as sns
from tqdm import tqdm
import random
import cv2
from natsort import natsorted
import collections
import skimage
from IPython import display
import pylab as pl
from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.metrics.regression import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
import numpy as np
from torch import nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF


from skorch import NeuralNetRegressor
from skorch.helper import predefined_split
from sklearn.metrics.regression import mean_absolute_error, mean_squared_error, r2_score, explained_variance_score
from skorch import callbacks

from dask.distributed import Client
from sklearn.externals import joblib
from sklearn.model_selection import GridSearchCV


plt.ion()
%matplotlib inline

# # Check cuda.is_available ?

# In[2]:


cuda_available = torch.cuda.is_available()
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print("cuda_available : {}, device : {}".format(cuda_available, device))


# # Define Dataset & DataLoader

# In[3]:

MAXLEN = 300
FPS = 24.0

def drop_huge_seq(input_df, save_path="./preprocess/example_data/person_detection_and_tracking_results_drop.pkl"):
    if os.path.exists(save_path):
        print('Already dropped! Return...')
        return
    
    vids = list(set(input_df.vids))

    for i in tqdm(range(len(vids)), desc='DropInputSeq '):
        slice_df = input_df.loc[input_df.vids==vids[i]]
        if slice_df.values.shape[0] > MAXLEN:
            input_df.iloc[slice_df.index] = np.nan * np.ones_like(slice_df.values)

    # drop Nans !
    res_df = input_df.dropna()
    res_df.to_pickle("./preprocess/example_data/person_detection_and_tracking_results_drop.pkl")

input_df = pd.read_csv('./preprocess/example_data/person_detection_and_tracking_results.csv',
                       sep='\t', names=['vids', 'idx', 'pos'])
    
# drop huge seq
drop_huge_seq(input_df, save_path="./preprocess/example_data/person_detection_and_tracking_results_drop.pkl")


# In[4]:


df = pd.read_pickle('./preprocess/example_data/person_detection_and_tracking_results_drop.pkl')
len(list(set(df.vids)))


# In[5]:


df = pd.read_pickle("./preprocess/data/targets_dataframe.pkl")
target_columns = df.columns.values[:-2]
# target_columns = ['Toe In / Out/L', 'Toe In / Out/R']


# In[6]:


def pid2vid(pid):
    num, test_id, trial_id = pid.split('_')
    return '_'.join([num, 'test', test_id, 'trial', trial_id])
    

def vid2pid(vid):
    split = vid.split('_')
    return '_'.join([split[0], split[2], split[4]])



class GAITDataset(Dataset):
    def __init__(self,
                 X, y, scaler, frame_home="/data/GaitData/CroppedFrameArrays", maxlen=300):
        
        self.frame_home = frame_home
        
        self.X = X
        self.y = y
        self.vids = [ pid2vid(pid) for pid in self.y.index ]
        
        self.maxlen = maxlen
        
        if scaler:
            scaled_values = scaler.transform(y)
            self.y.loc[:,:] = scaled_values
            
    def __len__(self):
        return len(self.vids)
    
    def __getitem__(self, idx):
        
        vid = self.vids[idx]
        positions = [ eval(val) for val in self.X.loc[self.X.vids==vid].pos.values ]
        
        
        stacked_arr = np.load(os.path.join(self.frame_home, vid) + '.npy')
        
        inputs = []        

        for cropped in stacked_arr:  
            pic = cv2.resize(cv2.cvtColor(cropped, cv2.COLOR_BGR2GRAY), (64,64))[:,:,None]
            
            pic = TF.to_tensor(pic) # scale to [0.0, 1.0]
            pic = TF.normalize(pic, (0.5,), (0.5,)).permute(1,2,0).numpy()   # scale to [-1.0, 1.0]
            inputs.append(pic)
            
        targets = self.y.loc[vid2pid(vid)].values

        # zero padding
        inputs = np.pad(inputs, ((0,self.maxlen-len(inputs)),(0,0),(0,0),(0,0)),
                                               'constant', constant_values=0).transpose(3,0,1,2)
        
        return torch.tensor(inputs, dtype=torch.float32), torch.tensor(inputs, dtype=torch.float32)



cuda_available : True, device : cuda:0
Already dropped! Return...


In [2]:
class Conv3d_with_same_padding(nn.Conv3d):
    def __init__(self, in_channels,
                       out_channels,
                       kernel_size,
                       stride=1,
                       padding=0,
                       dilation=1,
                       groups=1,
                       bias=True,
                       padding_type='same'):
        
        super(Conv3d_with_same_padding, self).__init__(in_channels,
                                     out_channels,
                                     kernel_size,
                                     stride,
                                     padding,
                                     dilation,
                                     groups,
                                     bias)
        
        self.padding_type = padding_type
    
    def forward(self, x, debug=False):
        n,c,d,h,w = x.size()
        if self.padding_type == 'same':
            padding_need = []
            for i,e in enumerate([d,h,w]):
                bias = 0.5 if self.stride[i] % 2 == 0 else 0.0
                padding_need.append(round((e * (self.stride[i]-1) + self.kernel_size[i] - self.stride[i]) / 2 + bias))
            
            padding_need = tuple(padding_need)
            
        if debug:
            set_trace()

        return F.conv3d(x, self.weight, self.bias, self.stride, 
                        padding_need, self.dilation, self.groups)


# In[ ]:


class ResidualBlock(nn.Module):
    def __init__(self, C_in, C_out, pool, highway=True):
        super(ResidualBlock, self).__init__()
        self.pool = pool
        self.highway = highway
                
        stride = 1
        
        if C_in != C_out:
            C = C_out
        else:
            C = C_in = C_out
            
        if pool:
            # input dimension matchig
            self.conv_matching = Conv3d_with_same_padding(C_in, C, kernel_size=1, stride=1, padding_type='same')
            self.bn_matching = nn.BatchNorm3d(C)

            # for pooling of residual path
            stride = 2
            self.conv_pool = Conv3d_with_same_padding(C_in, C, kernel_size=1, stride=2, padding_type='same')
            self.bn_pool= nn.BatchNorm3d(C)
                
        # conv_a : reduce number of channels by factor of 4 (output_channel = C/4)
        self.conv_a = Conv3d_with_same_padding(C, int(C/4), kernel_size=1, stride=stride, padding_type='same')
        self.bn_a = nn.BatchNorm3d(int(C/4))
        
        # conv_b : more wide receptive field (output_channel = C/4)
        self.conv_b = Conv3d_with_same_padding(int(C/4), int(C/4), kernel_size=3, stride=1, padding_type='same')
        self.bn_b = nn.BatchNorm3d(int(C/4))
        
        # conv_c : recover org channel C (output_channel = C)
        self.conv_c = Conv3d_with_same_padding(int(C/4), C, kernel_size=1, stride=1, padding_type='same')
        self.bn_c = nn.BatchNorm3d(C)
        
        if highway:
            # conv_g : gating for highway network
            self.conv_g = Conv3d_with_same_padding(C, C, kernel_size=1, stride=1, padding_type='same')
        
    
    def forward(self, x):
        '''
            x : size = (batch, channels, maxlen, height, width)
        '''
        
        res = x
        
        if self.pool:
            # input dimension matching with 1x1 conv
            x = self.conv_matching(x)
            x = self.bn_matching(x)
            
            # pooling of residual path
            res = self.conv_pool(res)
            res = self.bn_pool(res)
        
        # conv_a (C/4)
        x = self.conv_a(x)
        x = self.bn_a(x)
        x = F.relu(x)
        
        # conv_b (C/4)
        x = self.conv_b(x)
        x = self.bn_b(x)
        x = F.relu(x)
        
        # conv_c (C)
        x = self.conv_c(x)
        x = self.bn_c(x)
        
        if self.highway:
            # gating mechanism from "highway network"
            
            # gating factors controll intensity between x and f(x)
            # gating = 1.0 (short circuit) --> output is identity (same as initial input)
            # gating = 0.0 (open circuit)--> output is f(x) (case of non-residual network)
            gating = torch.sigmoid(self.conv_g(x))
            
            # apply gating mechanism
            x = gating * res + (1.0 - gating) * F.relu(x)

            
        else:
            # normal residual ops (addition)
            x = F.relu(x) + res
        
        return x


# In[ ]:


class View(nn.Module):
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape
    def forward(self, x):
        return x.view(*self.shape)

class GAP(nn.Module):
    def __init__(self):
        super(GAP, self).__init__()
    def forward(self, x):
        '''
        
            x : size = (N,C,D,H,W)
        '''
        set_trace()
        return torch.mean(x, (2,3,4))

class HighWay(nn.Module):
    def __init__(self, input_channel=1, 
                 num_layers = [3,4,6], num_filters = [64,128,256]):
        
        super(HighWay, self).__init__()
        
        self.num_layers = num_layers
        self.num_filters = num_filters

        def res_blocks(residual_blocks, num_layers, num_filters, block_ix, pool_first_layer=True):
            block_layers = num_layers[block_ix]

            for i in range(block_layers):
                # default values
                pool = False
                block_filters = num_filters[block_ix]
                
                C_in = C_out = block_filters
                
                if pool_first_layer and i==0:
                    pool = True
                if i==0 and block_ix > 0:
                    C_in = num_filters[block_ix-1]
                    
                print(f"layer : {i}, block : {block_ix}, C_in/C_out : {C_in}/{C_out}")
                residual_blocks.append(ResidualBlock(C_in=C_in, C_out=C_out,pool=pool, highway=True))
                
        residual_blocks = []

        for i in range(len(num_layers)):
            pool_first_layer = True
            if i == 0:
                pool_first_layer = False
            res_blocks(residual_blocks, num_layers=num_layers, num_filters=num_filters, block_ix=i,
                       pool_first_layer=pool_first_layer)
        
        
        nn.ReLU
        self.model = nn.Sequential(nn.Conv3d(input_channel, num_filters[0], kernel_size=(3,7,7), stride=(1,1,1)),
                                   nn.BatchNorm3d(num_filters[0]), 
                                   nn.ReLU(),
                                   nn.Conv3d(num_filters[0], num_filters[1], kernel_size=3, stride=2),
                                   nn.BatchNorm3d(num_filters[1]), 
                                   nn.ReLU(),
                                   nn.MaxPool3d(kernel_size=(3,), stride=2),
                                   nn.Conv3d(num_filters[1], num_filters[2], kernel_size=3, stride=2),
                                   nn.BatchNorm3d(num_filters[2]), 
                                   nn.ReLU(),
                                   #*residual_blocks,
                                   GAP(),
                                   nn.Linear(num_filters[-1], 17)
                                   )

    def forward(self, img):
        '''
            img : size = (batch, input_channel, maxlen, height, width)
        '''

        return self.model(img)

    
class Concat(nn.Module):
    def __init__(self, val, dim):
        super(Concat, self).__init__()
        
        self.val = val
        self.dim = dim
        
    def forward(self, x):
        return torch.cat([self.val, x], self.dim)

    
class Encoder(nn.Module):
    def __init__(self, num_filters = [1,32,64,128,256]):
        
        super(Encoder, self).__init__()
        
        self.encode = nn.Sequential(
            # b, 32, 150, 32, 32
            nn.Conv3d(num_filters[0], num_filters[1], kernel_size=4, stride=2, padding=1), # ()
            nn.BatchNorm3d(num_filters[1]), 
            nn.ReLU(True),
            
            # b, 64, 75, 16, 16
            nn.Conv3d(num_filters[1], num_filters[2], kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_filters[2]), 
            nn.ReLU(True),
            
            # b, 128, 25, 8, 8
            nn.Conv3d(num_filters[2], num_filters[3], kernel_size=(5,4,4), stride=(3,2,2), padding=1),
            nn.BatchNorm3d(num_filters[3]), 
            nn.ReLU(True),
            
            # b, 256, 8, 4, 4
            nn.Conv3d(num_filters[3], num_filters[4], kernel_size=(6,4,4), stride=(3,2,2), padding=1),
            nn.BatchNorm3d(num_filters[4]), 
            nn.ReLU(True),            
            
         )

    def forward(self, x):
        '''
            x : size = (B, C, D, H, W)
        '''
        return self.encode(x)

class Decoder(nn.Module):
    def __init__(self, enc_feats, num_filters = [256,128,64,32,1]):
        
        super(Decoder, self).__init__()
        
        self.decode = nn.Sequential(
            # b, 128, 25, 8, 8
            nn.ConvTranspose3d(num_filters[0], num_filters[1], kernel_size=(6,4,4), stride=(3,2,2), padding=1),
            nn.BatchNorm3d(num_filters[1]), 
            nn.ReLU(True),
            
            Concat(enc_feats[-2], 1),
            
            # b, 64, 75, 16, 16
            nn.ConvTranspose3d(2*num_filters[1], num_filters[2], kernel_size=(5,4,4), stride=(3,2,2), padding=1),
            nn.BatchNorm3d(num_filters[2]), 
            nn.ReLU(True),
            
            Concat(enc_feats[-3], 1),

            # b, 32, 150, 32, 32
            nn.ConvTranspose3d(2*num_filters[2], num_filters[3], kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_filters[3]), 
            nn.ReLU(True),
            
            
            Concat(enc_feats[-4], 1),
            
            # b, 1, 300, 64, 64
            nn.ConvTranspose3d(2*num_filters[3], num_filters[4], kernel_size=4, stride=2, padding=1),
            nn.Tanh()            
            
        )

    def forward(self, x):
        '''
            x : size = (B, C, D, H, W)
        '''
        return self.decode(x)
    

    
    
class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()
        
        self.encoder = Encoder()
    
    def encode(self, X):
        enc_feats = []
        for layer in self.encoder.encode:
            if type(layer).__name__ == 'ReLU':
                enc_feats.append(X)
            X = layer(X)
        
        encoded = enc_feats[-1]
        
        return Decoder(enc_feats).to(device), encoded
    
            
    def forward(self, X):
        decoder, encoded = self.encode(X)
        decoded = decoder(encoded)
        return decoded, encoded  # <- return a tuple of two values
    
    
class AutoEncoderNet(NeuralNetRegressor):
    def get_loss(self, y_pred, y_true, *args, **kwargs):
        decoded, encoded = y_pred  # <- unpack the tuple that was returned by `forward`
        loss_reconstruction = super(AutoEncoderNet, self).get_loss(decoded, y_true, *args, **kwargs)
        loss_l1 = 1e-3 * torch.abs(encoded).sum()
        
        return loss_reconstruction + loss_l1        

In [3]:
from sklearn.model_selection import train_test_split

def filter_input_df_with_vids(df, vids):
    return df[df['vids'].isin(vids)]

def filter_target_df_with_vids(df, vids):
    target_ids = [ vid2pid(vid) for vid in vids ]
    return df.loc[target_ids]

def split_dataset_with_vids(input_df, target_df, vids, test_size=0.3, random_state=42):
    train_vids, test_vids = train_test_split(vids, test_size=test_size, random_state=random_state)

    train_X, train_y = filter_input_df_with_vids(input_df,train_vids), filter_target_df_with_vids(target_df,train_vids)
    test_X, test_y = filter_input_df_with_vids(input_df,test_vids), filter_target_df_with_vids(target_df, test_vids)
        
    return train_X, train_y, train_vids, test_X, test_y, test_vids


def mape(y_true, y_pred):
    ape = []
    zero_cnt = 0
    for true,pred in zip(y_true, y_pred):
        if 0.0 not in true:
            ape.append(np.abs((true-pred)/true))  
        else:
            zero_cnt += 1
    
    return np.mean(ape), zero_cnt

In [4]:
from skorch.callbacks import Callback
from torchvision.utils import save_image

def to_img(x):
    x = 0.5 * (x + 1)
    x = np.clip(x, 0.0, 1.0)
    x = 255*x
    return x.astype(np.uint8)

def to_tensor_img(x):
    x = 0.5 * (x + 1)
    x = x.clamp(0, 1)
    x = x.view(x.size(0), 1, 64, 64)
    return x

class SaveResults(Callback):
    def __init__(self, path):
        self.path = path
        
    def on_epoch_end(self, net, **kwargs):        
        for name in ['train', 'valid']:
            dataset = kwargs['dataset_'+name]
            rand_ix = np.random.randint(len(dataset))
            X,y = dataset[rand_ix]
            
            save_dir = os.path.join(self.path, name)
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            
            # target img
            y = y.numpy().transpose(1,2,3,0)  # (maxlen,h,w,3)
            
            # predicted img
            pred = net.predict(X[None,:])[0].transpose(1,2,3,0) # (maxlen,h,w,3)
            
            for sub_name,pic in zip(['target', 'pred'], [y,pred]):
                pic = to_tensor_img(torch.from_numpy(pic))
                save_image(pic, os.path.join(save_dir,sub_name+'.png'))

In [5]:


# dataset path
input_file = "./preprocess/example_data/person_detection_and_tracking_results_drop.pkl"
target_file = "./preprocess/data/targets_dataframe.pkl"

input_df = pd.read_pickle(input_file)
target_df = pd.read_pickle(target_file)[target_columns]

possible_vids = list(set(input_df.vids))[:100]
train_X, train_y, train_vids, test_X, test_y, test_vids = split_dataset_with_vids(input_df, target_df, possible_vids, test_size=0.3, random_state=42)

# target scaler
# scaler = StandardScaler()
# train_y.loc[:,:] = scaler.fit_transform(train_y.values)

scaler = None

# holdouf test set for final evaluation
test_dataset = GAITDataset(test_X, test_y, scaler)
test_batcher = DataLoader(test_dataset,batch_size=10, shuffle=False, num_workers=16)

from sklearn.model_selection import KFold
kf = KFold(n_splits=5)

train_vids = np.array(train_vids)


from torch.nn.modules.loss import _Loss

class MyCriterion(_Loss):
    def __init__(self):
        super(MyCriterion, self).__init__()
    
    def forward(self, x, y):
        valid_mask = ~(y.view(y.size(0),MAXLEN,-1)==0).all(dim=2)
        valid_mask = valid_mask.float()
        return torch.mean(torch.sum((valid_mask * ((x-y)**2).mean((1,3,4))),1)/torch.sum(valid_mask,1))
        
    
# for parallelism
#client = Client('127.0.0.1:8786')

# cross validation loop
scores = {'MAPE': [], 'MAE': [], 'RMSE': [], 'R2': [], 'Explained variation': []}

for train, valid in kf.split(train_vids):
    # split trainset with train/valid
    train_split, valid_split = train_vids[train], train_vids[valid]
    
    train_X, train_y = filter_input_df_with_vids(input_df,train_split), filter_target_df_with_vids(target_df,train_split)
    valid_X, valid_y = filter_input_df_with_vids(input_df,valid_split), filter_target_df_with_vids(target_df,valid_split)


    # dsataset !!
    train_dataset = GAITDataset(train_X, train_y, scaler)
    valid_dataset = GAITDataset(valid_X, valid_y, scaler)
        
    # Init net !
    net = AutoEncoderNet(
        AutoEncoder,
        batch_size=10,
        max_epochs=100,
        lr=1e-3,
        optimizer=torch.optim.Adam,
        #optimizer__weight_decay=1e-5,
        #optimizer__momentum=0.9,
        #optimizer__nesterov=True,
        criterion=MyCriterion,
        device='cuda',
        train_split=predefined_split(valid_dataset),
        # Shuffle training data on each epoch
        iterator_train__shuffle=True,
        callbacks=[#('ealy_stop', callbacks.EarlyStopping()),
                   #('lr_scheduler', callbacks.LRScheduler(policy='WarmRestartLR', base_period=2)),
                   ('prog_bar', callbacks.ProgressBar()),
                   ('save_results', SaveResults(path='./results'))
                   ],
    
    )
    
    #with joblib.parallel_backend('dask'):
    # fit with train set
    net.fit(train_dataset, y=None)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m51.4635[0m       [32m79.4709[0m  9.3011


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      2       [36m42.1256[0m       [32m42.9796[0m  7.2749


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      3       [36m36.1266[0m       [32m40.9593[0m  7.2143


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      4       [36m32.2057[0m       [32m28.1359[0m  7.2062


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      5       [36m29.3680[0m       [32m24.8921[0m  7.2725


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      6       [36m29.0785[0m       [32m20.8529[0m  7.1456


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      7       [36m25.2865[0m       26.8626  7.2413


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      8       [36m24.3800[0m       21.1166  7.2105


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      9       [36m22.5533[0m       [32m20.1500[0m  7.1727


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     10       [36m20.8410[0m       [32m13.3072[0m  7.2371


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     11       [36m19.4518[0m       13.8138  7.1825


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     12       [36m18.0650[0m        [32m5.1682[0m  7.1554


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     13       18.2805       19.1771  7.2094


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     14       [36m17.9166[0m       17.5556  7.2341


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     15       [36m16.6159[0m       11.5311  7.1808


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     16       [36m14.8015[0m        8.2643  7.1999


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     17       [36m13.6367[0m        8.8760  7.1913


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     18       [36m12.2334[0m        6.2714  7.2077


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     19       [36m11.1095[0m        [32m5.0977[0m  7.2404


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     20       [36m10.0670[0m        6.4513  7.2116


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     21       10.2528        7.7167  7.2446


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     22        [36m9.3213[0m        8.9658  7.2707


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     23        [36m8.6126[0m        6.2115  7.2681


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     24        [36m7.3253[0m        [32m2.4168[0m  7.2003


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     25        [36m6.6912[0m        2.8634  7.1898


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     26        [36m5.2654[0m        3.1111  7.1823


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     27        [36m4.0814[0m        [32m2.2238[0m  7.2192


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     28        4.2111        2.4991  7.2705


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     29        [36m3.4842[0m        2.6264  7.2384


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     30        3.5892        3.2910  7.1883


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     31        4.3204        3.1900  7.2169


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     32        3.7240        [32m1.6860[0m  7.1690


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     33        [36m3.0891[0m        [32m1.3253[0m  7.1266


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     34        [36m2.5714[0m        1.5417  7.1270


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     35        [36m2.0281[0m        [32m1.2837[0m  7.1544


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     36        [36m1.9148[0m        [32m1.0265[0m  7.2363


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     37        [36m1.4374[0m        [32m0.8756[0m  7.2065


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     38        [36m1.1512[0m        [32m0.7577[0m  7.2325


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     39        [36m1.0384[0m        0.8153  7.1437


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     40        [36m0.9861[0m        [32m0.7274[0m  7.2035


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     41        [36m0.8940[0m        0.7627  7.2601


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     42        [36m0.7906[0m        [32m0.5706[0m  7.2326


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     43        [36m0.7386[0m        0.7751  7.2315


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     44        0.8429        0.7364  7.1896


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     45        1.6694        0.8498  7.2516


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     46        0.9950        0.7646  7.1800


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     47        0.8168        0.6164  7.2107


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     48        0.8914        0.9542  7.2181


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     49        1.4511        1.0992  7.2054


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     50        0.9010        0.6867  7.2704


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     51        0.8653        [32m0.5581[0m  7.2452


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     52        [36m0.5911[0m        0.7131  7.3445


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     53        0.6969        0.5674  7.2956


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     54        [36m0.5339[0m        0.8846  7.3014


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     55        0.6213        0.7547  7.2264


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     56        0.6521        0.6100  7.2436


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     57        0.5811        0.7127  7.1951


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     58        0.6090        0.6224  7.1328


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     59        0.6392        0.6354  7.1268


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     60        0.5801        0.6230  7.2142


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     61        0.6879        0.5728  7.1916


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     62        0.5998        0.6282  7.1874


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     63        0.7112        0.6925  7.3629


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     64        0.6765        0.6085  7.2315


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     65        0.6627        0.5820  7.1518


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     66        0.6338        [32m0.4800[0m  7.2126


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     67        0.6058        0.7589  7.1469


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     68        0.5748        0.7512  7.1564


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     69        [36m0.5263[0m        0.7586  7.1537


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     70        0.6442        [32m0.4348[0m  7.1706


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     71        0.6318        1.0094  7.1223


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     72        0.6023        0.7115  7.1704


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     73        0.5690        0.4944  7.2464


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     74        0.6221        0.6361  7.2466


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     75        0.6667        0.5676  7.2764


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     76        0.5495        0.8251  7.2116


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     77        0.5537        0.5845  7.2802


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     78        0.5922        0.5591  7.2197


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     79        0.5793        0.4607  7.2306


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     80        0.6128        0.7318  7.3695


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     81        0.6511        0.5850  7.3256


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     82        0.6736        0.7504  7.2524


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     83        0.5461        0.4903  7.3038


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     84        0.7455        0.5733  7.2591


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     85        0.6075        0.4985  7.2666


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     86        0.5403        0.6275  7.2693


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     87        0.5902        0.6554  7.2285


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     88        0.6990        0.6011  7.1925


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     89        0.6174        0.7282  7.3692


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     90        0.8033        0.5350  7.2361


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     91        1.6251        0.5295  7.3473


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     92        3.5158        4.9609  7.3057


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     93        2.0842        1.2195  7.3062


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     94        1.4653        0.9742  7.2279


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     95        1.0233        0.9162  7.3044


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     96        1.0708        0.5862  7.2212


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     97        0.7454        0.7523  7.2285


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     98        0.7054        0.4637  7.2787


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     99        0.7749        0.7126  7.2228


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

    100        0.6811        0.5368  7.3049


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m47.1216[0m       [32m54.4366[0m  7.2794


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      2       [36m37.7613[0m       [32m39.3703[0m  7.2648


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      3       [36m32.6165[0m       47.5861  7.2547


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      4       [36m28.6231[0m       [32m28.4956[0m  7.2995


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      5       [36m26.5550[0m       46.2599  7.2994


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      6       [36m25.1473[0m       32.9585  7.3104


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      7       [36m22.4791[0m       [32m19.9777[0m  7.2546


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      8       [36m21.0498[0m       [32m18.2215[0m  7.2551


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      9       [36m18.6103[0m       [32m13.3838[0m  7.2235


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     10       [36m18.2012[0m       16.8187  7.3040


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     11       [36m16.8980[0m       13.4892  7.2368


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     12       [36m15.2911[0m       [32m10.7468[0m  7.2826


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     13       [36m14.9816[0m       14.8144  7.2964


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     14       [36m14.0800[0m       [32m10.2340[0m  7.2435


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     15       [36m13.0612[0m        [32m7.6512[0m  7.3034


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     16       [36m11.4237[0m        8.0651  7.2558


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     17       [36m10.2709[0m        [32m5.8555[0m  7.2500


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     18       [36m10.0155[0m       17.0879  7.2576


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     19       11.1747       15.3225  7.2695


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     20       10.0229        9.2368  7.3255


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     21        [36m8.0391[0m        [32m5.6273[0m  7.2574


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     22        [36m6.9049[0m        [32m4.0221[0m  7.2685


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     23        [36m5.8068[0m        [32m3.9223[0m  7.2604


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     24        [36m5.2675[0m        [32m2.0977[0m  7.2398


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     25        5.4441        4.1335  7.2554


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     26        [36m3.9376[0m        2.1094  7.3322


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     27        [36m3.4767[0m        2.5557  7.2720


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     28        [36m3.2913[0m        2.6586  7.2865


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     29        [36m2.6953[0m        [32m1.7985[0m  7.2393


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     30        [36m2.5534[0m        1.8409  7.2614


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     31        [36m2.2013[0m        [32m1.2684[0m  7.2144


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     32        2.4920        2.7118  7.2500


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     33        2.4296        1.7690  7.3189


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     34        [36m1.7975[0m        1.2719  7.1910


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     35        [36m1.2417[0m        [32m1.0104[0m  7.2494


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     36        [36m1.2407[0m        [32m0.9664[0m  7.2358


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     37        [36m0.9880[0m        [32m0.8332[0m  7.2265


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     38        1.0050        1.1919  7.2533


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     39        1.1531        0.8463  7.2743


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     40        [36m0.8569[0m        [32m0.6804[0m  7.2131


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     41        0.9332        0.8790  7.2482


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     42        0.8864        [32m0.6613[0m  7.2759


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     43        0.9565        0.8631  7.2097


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     44        2.0958        2.4050  7.3208


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     45        3.7543        7.8743  7.2181


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     46        2.3404        2.6233  7.2071


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     47        1.9635        2.0846  7.2224


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     48        1.6180        1.7319  7.2597


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     49        1.5247        1.4687  7.2386


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     50        1.2308        1.2177  7.2143


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     51        0.9895        0.8391  7.1897


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     52        1.0781        1.0161  7.2651


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     53        0.8769        0.8417  7.2859


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     54        [36m0.8461[0m        0.8285  7.1888


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     55        0.9119        0.9523  7.2597


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     56        1.0067        0.8477  7.2078


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     57        [36m0.7729[0m        0.8427  7.2685


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     58        [36m0.7539[0m        0.7560  7.2629


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     59        1.1149        0.8848  7.2536


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     60        0.8061        0.9885  7.2072


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     61        0.9250        1.0301  7.1724


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     62        0.7914        0.9783  7.2293


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     63        [36m0.6958[0m        0.7923  7.2362


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     64        0.7172        0.7256  7.2067


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     65        0.7409        0.7080  7.2515


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     66        [36m0.6876[0m        0.7389  7.2296


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     67        [36m0.6589[0m        0.8404  7.2029


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     68        0.7679        0.7089  7.2219


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     69        0.7019        0.7476  7.3142


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     70        [36m0.5938[0m        [32m0.6112[0m  7.2693


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     71        0.7016        0.7306  7.2215


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     72        0.7235        [32m0.5213[0m  7.2157


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     73        0.6378        0.9812  7.1931


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     74        0.7129        0.7273  7.2362


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     75        0.7019        0.7280  7.2522


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     76        0.6922        0.5694  7.3137


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     77        0.7826        0.8118  7.2400


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     78        0.6313        0.7302  7.2577


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     79        0.6888        0.5911  7.2421


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     80        0.6626        0.7826  7.2625


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     81        0.7440        0.8160  7.3036


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     82        [36m0.5866[0m        0.7971  7.2854


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     83        0.6287        0.7938  7.2755


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     84        0.6477        0.7014  7.3061


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     85        [36m0.5798[0m        0.8325  7.2071


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     86        0.6255        0.7154  7.2153


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     87        0.6046        0.6629  7.2234


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     88        0.6975        0.8503  7.2061


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     89        0.6178        0.7461  7.2559


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     90        0.6530        0.6011  7.2345


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     91        0.6509        0.7103  7.1663


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     92        0.7792        0.7463  7.2518


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     93        0.6695        0.8352  7.2637


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     94        0.6961        0.6046  7.2784


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     95        0.7094        0.5311  7.3020


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     96        0.6591        0.5666  7.1809


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     97        0.6281        0.6804  7.3025


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     98        0.6271        0.6773  7.2092


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     99        0.5845        0.6012  7.2647


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

    100        0.7673        0.9323  7.1497


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m56.4121[0m       [32m64.5840[0m  7.2200


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      2       [36m45.7383[0m       [32m55.4118[0m  7.2185


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      3       [36m40.6574[0m       56.6505  7.2053


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      4       [36m36.3536[0m       [32m50.2174[0m  7.2428


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      5       [36m31.3265[0m       [32m39.7214[0m  7.1437


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      6       [36m28.7758[0m       [32m37.0428[0m  7.2104


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      7       [36m27.0082[0m       [32m27.6156[0m  7.1757


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      8       [36m24.3753[0m       31.4109  7.1810


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      9       [36m23.0411[0m       [32m17.3130[0m  7.3181


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     10       [36m21.4849[0m       21.8469  7.3176


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     11       [36m20.4663[0m       [32m14.7905[0m  7.2021


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     12       [36m18.2429[0m       16.5328  7.2883


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     13       [36m16.4597[0m       [32m12.6893[0m  7.2547


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     14       [36m16.1958[0m       17.1088  7.2203


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     15       [36m14.5218[0m        [32m9.2242[0m  7.2765


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     16       [36m13.6734[0m       11.4527  7.2784


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     17       [36m12.5996[0m        [32m7.0081[0m  7.2590


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     18       [36m11.4199[0m        8.3983  7.2857


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     19       [36m10.2331[0m        [32m4.3033[0m  7.2416


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     20        [36m9.3842[0m       10.0855  7.2414


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     21        9.7528        8.4328  7.3298


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     22        [36m8.7955[0m        [32m4.0982[0m  7.2396


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     23        [36m7.6487[0m        [32m3.8243[0m  7.2486


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     24        [36m6.2035[0m        [32m3.4132[0m  7.2302


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     25        [36m5.6327[0m        4.7894  7.2633


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     26        [36m5.0977[0m        [32m2.3954[0m  7.2484


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     27        [36m3.9907[0m        [32m2.3102[0m  7.2081


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     28        [36m3.1480[0m        [32m1.3543[0m  7.2550


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     29        [36m3.1207[0m        2.9141  7.2936


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     30        [36m2.9477[0m        1.5388  7.2147


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     31        [36m2.4331[0m        2.0761  7.2486


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     32        [36m2.0763[0m        [32m1.2274[0m  7.2372


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     33        2.1838        1.8471  7.2521


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     34        [36m2.0277[0m        [32m1.0006[0m  7.1867


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     35        [36m1.6163[0m        [32m0.8400[0m  7.3910


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     36        [36m1.1764[0m        0.9727  7.3280


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     37        [36m1.0625[0m        0.8454  7.3425


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     38        [36m0.9277[0m        [32m0.7272[0m  7.2311


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     39        2.3404        1.3906  7.2581


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     40        2.4740        2.2611  7.2484


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     41        1.7815        1.0328  7.2780


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     42        1.2277        0.8911  7.2690


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     43        [36m0.8703[0m        0.8383  7.2548


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     44        [36m0.8498[0m        0.8000  7.2234


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     45        0.9212        [32m0.5051[0m  7.2655


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     46        [36m0.7822[0m        0.9044  7.2113


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     47        [36m0.7105[0m        0.6787  7.2219


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     48        0.7182        0.7524  7.2908


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     49        [36m0.6440[0m        0.6699  7.1783


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     50        [36m0.6254[0m        0.6700  7.2425


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     51        [36m0.5873[0m        0.8632  7.2153


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     52        [36m0.5487[0m        0.7338  7.2552


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     53        [36m0.4887[0m        0.5442  7.2235


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     54        0.6901        0.6078  7.2536


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     55        0.7052        0.7061  7.2504


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     56        0.6666        0.5620  7.2116


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     57        0.6211        0.6177  7.2109


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     58        0.6040        0.6017  7.2678


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     59        0.6623        0.5249  7.2448


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     60        0.6354        0.8357  7.2421


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     61        0.5837        0.7453  7.2218


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     62        0.6362        0.6426  7.2149


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     63        0.5786        0.5860  7.2195


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     64        0.6475        [32m0.4964[0m  7.2321


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     65        0.5779        0.5908  7.2070


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     66        0.6789        0.7028  7.2113


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     67        0.5987        0.6099  7.2129


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     68        0.5647        0.5556  7.3169


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     69        0.6365        0.6365  7.2618


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     70        0.5616        0.7268  7.2445


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     71        0.5992        0.6733  7.2236


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     72        0.6050        0.5121  7.4015


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     73        0.6848        [32m0.4749[0m  7.3497


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     74        0.6236        0.6939  7.2562


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     75        0.6184        [32m0.4509[0m  7.3510


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     76        0.5268        [32m0.4305[0m  7.3342


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     77        0.6423        0.6683  7.2919


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     78        0.6977        0.5303  7.2190


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     79        0.6552        0.5963  7.2452


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     80        0.5533        0.6882  7.2535


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     81        0.5406        0.5066  7.3818


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     82        0.5957        0.6057  7.2734


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     83        0.6554        0.6563  7.2342


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     84        0.5256        1.0914  7.3964


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     85        0.5847        0.6363  7.3370


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     86        0.6172        0.6183  7.2808


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     87        [36m0.4738[0m        0.6474  7.2844


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     88        0.6246        0.7766  7.2858


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     89        0.5118        0.6138  7.3663


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     90        0.5130        0.7302  7.2489


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     91        0.6326        0.4639  7.2444


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     92        0.5235        0.6477  7.2709


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     93        0.5746        0.5722  7.2668


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     94        0.5352        0.6683  7.2889


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     95        0.6184        0.6687  7.1999


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     96        0.6388        0.7352  7.3669


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     97        0.5746        0.7931  7.2481


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     98        0.5655        0.7218  7.2951


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     99        0.6012        0.4797  7.2732


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

    100        0.5590        0.5762  7.2726


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m48.2549[0m       [32m35.4472[0m  7.2965


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      2       [36m38.4581[0m       [32m26.2959[0m  7.2301


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      3       [36m31.9958[0m       32.9108  7.2490


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      4       [36m29.2488[0m       [32m21.9090[0m  7.2555


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      5       [36m26.5876[0m       27.2362  7.3505


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      6       [36m24.7930[0m       26.7539  7.2701


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      7       [36m23.7966[0m       [32m17.3222[0m  7.2421


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      8       [36m21.5059[0m       [32m12.1002[0m  7.2144


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      9       [36m19.1836[0m        [32m9.7983[0m  7.2407


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     10       [36m17.8231[0m       10.2751  7.3662


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     11       [36m16.3417[0m        [32m8.1663[0m  7.1918


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     12       17.0575       11.7031  7.2538


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     13       [36m14.5394[0m        [32m2.1806[0m  7.2767


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     14       [36m13.5545[0m        3.7679  7.2406


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     15       [36m11.8146[0m        2.8290  7.3019


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     16       [36m10.6066[0m        3.0555  7.2168


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     17        [36m9.2432[0m        3.7878  7.2819


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     18        [36m8.4624[0m        4.2802  7.2768


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     19        [36m7.9862[0m        3.5959  7.3612


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     20        [36m6.3240[0m        [32m1.3445[0m  7.2463


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     21        [36m5.6439[0m        2.8621  7.2851


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     22        6.6214       17.1146  7.5025


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     23       10.9557       20.6146  7.2582


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     24        8.7672        7.4080  7.2675


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     25        6.7401        4.1576  7.2761


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     26        [36m5.2813[0m        3.0360  7.2179


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     27        [36m4.3287[0m        2.2749  7.2997


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     28        [36m3.4138[0m        1.6277  7.1872


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     29        [36m2.8650[0m        1.4897  7.2282


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     30        [36m2.2383[0m        [32m1.2174[0m  7.2235


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     31        [36m2.0011[0m        [32m1.1413[0m  7.2901


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     32        [36m1.6196[0m        [32m1.0058[0m  7.3066


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     33        [36m1.2846[0m        [32m0.9191[0m  7.4221


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     34        [36m1.1435[0m        0.9226  7.1585


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     35        1.3058        1.0850  7.2149


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     36        1.1620        [32m0.8971[0m  7.2380


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     37        [36m0.9862[0m        [32m0.7426[0m  7.1898


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     38        1.0090        0.9456  7.2096


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     39        [36m0.8388[0m        [32m0.6161[0m  7.2703


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     40        0.9138        [32m0.5158[0m  7.2810


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     41        [36m0.7756[0m        0.7071  7.2566


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     42        [36m0.7187[0m        0.7374  7.1996


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     43        0.7814        0.6874  7.1731


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     44        1.4210        1.0716  7.1484


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     45        1.4493        1.1033  7.2132


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     46        1.2179        0.7304  7.2087


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     47        0.8257        0.9280  7.2293


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     48        0.8109        0.7536  7.2298


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     49        [36m0.6922[0m        0.7069  7.2474


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     50        0.9497        0.5643  7.2852


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     51        [36m0.6187[0m        0.8967  7.2329


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     52        0.8543        0.7037  7.2974


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     53        [36m0.5996[0m        0.6957  7.2484


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     54        0.6458        0.6333  7.2458


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     55        0.7140        0.7265  7.1719


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     56        0.7440        0.5260  7.3805


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     57        0.6489        0.6388  7.2542


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     58        0.6097        0.6231  7.3563


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     59        0.6880        0.6885  7.2888


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     60        0.7208        0.5827  7.3055


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     61        0.6425        0.6940  7.2551


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     62        0.6152        0.6754  7.2221


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     63        [36m0.5993[0m        [32m0.4200[0m  7.1980


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     64        0.6294        0.4974  7.1955


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     65        0.8211        0.5796  7.2356


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     66        0.6496        0.6872  7.1819


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     67        [36m0.5806[0m        0.5978  7.1837


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     68        0.6775        0.6243  7.2430


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     69        0.6800        0.6303  7.2182


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     70        0.6693        0.7925  7.3036


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     71        0.6941        0.5961  7.1928


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     72        0.6516        0.5383  7.2754


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     73        0.6173        0.6629  7.2955


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     74        [36m0.5403[0m        0.7530  7.1805


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     75        0.6891        0.5665  7.1608


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     76        0.7704        0.7092  7.2867


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     77        0.6502        0.4470  7.2940


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     78        0.6088        0.5499  7.1982


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     79        0.7214        0.4404  7.2601


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     80        0.6062        0.6498  7.4006


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     81        0.6590        0.7364  7.2702


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     82        0.6282        0.6428  7.3107


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     83        0.7450        0.5929  7.3221


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     84        0.5722        0.7913  7.2476


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     85        0.7823        0.7337  7.2465


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     86        0.6951        0.5480  7.2174


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     87        0.7937        0.5646  7.2032


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     88        0.5866        0.6898  7.3065


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     89        0.5931        0.7501  7.3663


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     90        0.6337        0.6481  7.3429


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     91        0.6618        0.7901  7.2784


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     92        0.6855        0.6159  7.2446


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     93        0.5911        0.7413  7.3666


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     94        0.6733        0.5730  7.3229


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     95        0.6988        0.7446  7.2908


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     96        0.6695        0.7799  7.2371


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     97        0.7239        0.6353  7.2611


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     98        0.5503        0.5114  7.2825


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     99        0.7035        0.8012  7.3274


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

    100        0.6350        0.6187  7.2279


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

  epoch    train_loss    valid_loss     dur
-------  ------------  ------------  ------
      1       [36m49.6087[0m       [32m28.4417[0m  7.2878


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      2       [36m45.4255[0m       [32m25.0229[0m  7.1751


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      3       [36m41.6058[0m       [32m19.7842[0m  7.1515


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      4       [36m38.6741[0m       23.9012  7.1965


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      5       [36m34.8131[0m       [32m18.8599[0m  7.1523


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      6       [36m30.8040[0m       [32m14.0837[0m  7.1543


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      7       [36m27.6090[0m       26.6148  7.1681


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      8       [36m25.7638[0m       25.6177  7.2595


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

      9       [36m24.5669[0m       17.9231  7.2184


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     10       [36m22.0372[0m       [32m13.4462[0m  7.1725


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     11       [36m20.3016[0m       13.7617  7.1874


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     12       [36m18.4175[0m        [32m7.9950[0m  7.2202


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     13       [36m16.9537[0m        [32m4.6034[0m  7.1610


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     14       [36m15.6766[0m        9.7916  7.1242


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     15       [36m15.2298[0m        9.7937  7.2364


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     16       [36m13.4450[0m        5.4160  7.2347


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     17       [36m11.5341[0m        [32m3.8888[0m  7.1943


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     18       [36m10.2254[0m        [32m3.0392[0m  7.1551


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     19        [36m9.0610[0m        3.1125  7.1792


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     20        9.6830        7.8730  7.2432


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     21        9.1973        4.4578  7.2180


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     22        [36m7.8292[0m        3.0833  7.2840


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     23        [36m6.8903[0m        3.5156  7.2132


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     24        [36m6.3694[0m        4.7482  7.2427


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     25        [36m5.7814[0m        3.3186  7.3607


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     26        [36m4.7346[0m        [32m1.5729[0m  7.3407


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     27        [36m3.7368[0m        [32m1.4012[0m  7.2405


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     28        [36m3.4658[0m        [32m1.2592[0m  7.1944


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     29        [36m2.8463[0m        [32m1.1712[0m  7.1495


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     30        2.9281        1.2116  7.1894


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     31        [36m2.3987[0m        [32m1.0332[0m  7.2048


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     32        [36m1.9982[0m        [32m0.7600[0m  7.1670


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     33        [36m1.5384[0m        0.9418  7.2144


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     34        [36m1.3561[0m        0.8193  7.2361


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     35        [36m1.0590[0m        0.8253  7.2142


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     36        [36m1.0074[0m        [32m0.6154[0m  7.2227


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     37        [36m0.7448[0m        [32m0.6152[0m  7.1331


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     38        0.8134        [32m0.5479[0m  7.1739


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     39        1.1494        0.8372  7.2697


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     40        1.3502        0.7000  7.2158


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     41        0.9057        [32m0.5400[0m  7.2359


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     42        0.9154        0.6266  7.2534


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     43        0.7713        0.7009  7.3246


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     44        [36m0.6680[0m        0.6210  7.2694


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     45        [36m0.6670[0m        [32m0.3987[0m  7.3300


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     46        0.6715        0.5474  7.2519


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     47        [36m0.6435[0m        0.5181  7.2968


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     48        [36m0.6352[0m        0.5506  7.3119


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     49        0.7276        0.6934  7.2540


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     50        0.7755        0.5919  7.2640


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     51        [36m0.5892[0m        0.7161  7.4563


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     52        0.6540        0.4957  7.2584


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     53        0.5975        0.5137  7.2727


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     54        0.6049        0.5663  7.2876


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     55        0.6043        0.5379  7.2187


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     56        0.6380        0.5532  7.3525


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     57        [36m0.5601[0m        [32m0.3517[0m  7.2900


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     58        0.5995        0.3966  7.2582


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     59        0.5942        0.6927  7.2919


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     60        0.6367        0.7356  7.2084


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     61        0.7084        0.7119  7.2872


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     62        0.6094        0.6342  7.2406


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     63        0.6004        0.5665  7.2470


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     64        0.6340        0.5375  7.2505


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     65        0.6833        0.5335  7.2072


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     66        0.6399        0.6518  7.2247


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     67        0.6660        0.6098  7.1930


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     68        0.6617        0.6475  7.2406


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     69        0.6657        0.4549  7.2575


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     70        [36m0.5569[0m        0.5531  7.1899


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     71        0.5733        0.4266  7.2394


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     72        0.6919        0.4598  7.2473


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     73        0.6422        0.8191  7.2640


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     74        0.5750        0.4698  7.3071


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     75        0.6441        0.5078  7.2645


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     76        0.5751        0.5801  7.2703


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     77        0.6652        0.6349  7.2645


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     78        [36m0.5345[0m        0.5943  7.2845


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     79        0.6733        0.4487  7.2558


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     80        0.6777        0.5186  7.2489


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     81        0.5831        0.4658  7.2189


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     82        0.6247        0.6578  7.2297


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     83        0.6288        0.5238  7.2444


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     84        0.5620        0.5010  7.2464


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     85        0.7229        0.5136  7.2414


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     86        1.5343        2.0646  7.3099


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     87        1.8098        0.8179  7.2561


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     88        1.2855        0.7011  7.2769


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     89        0.9829        0.5318  7.2472


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     90        0.7654        0.6120  7.2442


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     91        0.7449        0.5709  7.2846


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     92        0.6838        0.5091  7.2397


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     93        0.6085        0.6112  7.2959


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     94        0.6140        0.5972  7.2324


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     95        [36m0.4928[0m        0.5242  7.2746


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     96        0.5238        0.4770  7.2574


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     97        0.6233        0.6102  7.2462


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     98        0.6102        0.4027  7.2518


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

     99        0.6024        0.6293  7.2698


HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

    100        0.5858        0.5365  7.2128


In [6]:
print('MAE :', mean_absolute_error(y_true, y_pred, multioutput='raw_values'))
print('MSE :', mean_squared_error(y_true, y_pred, multioutput='raw_values'))
print('RMSE :', np.sqrt(mean_squared_error(y_true, y_pred, multioutput='raw_values')))
print('R^2 : ', r2_score(y_true, y_pred, multioutput='variance_weighted'))
print('Explained variation : ',explained_variance_score(y_true, y_pred, multioutput='variance_weighted'))

NameError: name 'y_true' is not defined

In [None]:
y_pred = []
y_true = []

train_batcher = DataLoader(train_dataset,batch_size=10, shuffle=False, num_workers=16)

for b in iter(train_batcher):
    X_test, y_test = b
    y_pred.append(net.predict(X_test.numpy()))
    y_true.append(y_test.numpy())

y_pred = np.concatenate(y_pred, axis=0)
y_true = np.concatenate(y_true, axis=0)

if scaler:
    y_pred = scaler.inverse_transform(y_pred)
    y_true = scaler.inverse_transform(y_true)

In [None]:
y_pred[0]

In [None]:
y_true[0]

In [None]:
for ii in range(17):
    xx = np.linspace(min(y_true[:,ii]), max(y_true[:,ii]))
    plt.scatter(y_pred[:,ii], y_true[:,ii])
    plt.plot(xx,xx)
    plt.show()

In [None]:
y_true.mean(axis=0)

In [None]:
print(scaler.inverse_transform(train_y.values).mean(axis=0))
print(scaler.inverse_transform(train_y.values).std(axis=0))

In [None]:
print(scaler.inverse_transform(test_y.values).mean(axis=0))
print(scaler.inverse_transform(test_y.values).std(axis=0))

In [None]:
from sklearn.datasets import fetch_olivetti_faces

In [None]:
olivetti = fetch_olivetti_faces()

In [None]:
olivetti['data'].shape

In [None]:
olivetti['target'].shape

In [None]:
from sklearn.datasets import load_linnerud
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.multioutput import MultiOutputRegressor

linnerud = load_linnerud()

X = linnerud.data
Y = linnerud.target


In [None]:
Y

In [None]:
model = MultiOutputRegressor(GradientBoostingRegressor(), n_jobs=-1)
model.fit(X, Y)

In [None]:
model.predict(X)

In [None]:
Y