In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

from lifelines import CoxPHFitter
from lifelines import KaplanMeierFitter

from pycox.evaluation import EvalSurv
from sklearn.preprocessing import StandardScaler

In [2]:
# Reproducibility: seed torch (all devices) and numpy's legacy global RNG.
random_seed = 137
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(random_seed)

# Run on the first GPU when CUDA is available, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

print(device)

cuda:0


In [3]:
# Early stopping class from https://github.com/Bjarten/early-stopping-pytorch
from SurvNODE.EarlyStopping import EarlyStopping
from SurvNODE.SurvNODE_x_ranking import *

In [4]:
def measures(odesurv,initial,x,Tstart,Tstop,From,To,trans,status, multiplier=1.,points=500):
    """Evaluate a SurvNODE model's survival predictions with pycox metrics.

    The model is queried on ``points`` evenly spaced times in ``[0, multiplier]``.
    The predicted matrices are contracted with ``initial`` over the third axis.
    The component at index 0 of the last axis is passed to ``EvalSurv`` as the
    survival curve, with one row per time and one column per sample.

    Args:
        odesurv: model exposing ``predict(x, times)`` (project SurvNODE class).
        initial: 1-D tensor with the initial state weights, e.g. ``[1., 0.]``.
        x: covariate tensor; the time grid is moved to its device.
        Tstart, From, To, trans: unused. They are accepted so a ``TensorDataset``
            batch ``(X, Tstart, T, From, To, trans, E)`` can be splatted in directly.
        Tstop: observed (scaled) durations passed to ``EvalSurv``.
        status: event indicators passed to ``EvalSurv``.
        multiplier: upper end of the evaluation time grid.
        points: number of time-grid nodes.

    Returns:
        ``(concordance_td_antolini, integrated_brier_score, integrated_nbll)``.
    """
    with torch.no_grad():
        # Single time grid shared by the prediction, the DataFrame index and the integrals.
        time_grid = np.linspace(0, multiplier, points)
        times = torch.from_numpy(time_grid).float().to(x.device)
        surv_ode = odesurv.predict(x, times)
        # Weight by the initial state distribution (axis k) and keep component j=0.
        pvec = torch.einsum("ilkj,k->ilj", surv_ode, initial)[:, :, 0]
        pvec = pvec.detach().cpu().numpy()
        surv_ode_df = pd.DataFrame(pvec)
        surv_ode_df.loc[:, "time"] = time_grid
        surv_ode_df = surv_ode_df.set_index(["time"])
        ev_ode = EvalSurv(surv_ode_df, np.array(Tstop.cpu()), np.array(status.cpu()), censor_surv='km')
        conc = ev_ode.concordance_td('antolini')
        ibs = ev_ode.integrated_brier_score(time_grid)
        inbll = ev_ode.integrated_nbll(time_grid)
    return conc, ibs, inbll

In [5]:
from sklearn_pandas import DataFrameMapper
import pandas as pd

def make_dataloader(df,Tmax,batchsize, feature_cols=None, shuffle=True, target_device=None):
    """Build a DataLoader of single-transition (state 1 -> state 2) survival data.

    Args:
        df: frame containing the covariate columns, ``duration`` and ``event``.
        Tmax: durations are divided by this value. Durations that become exactly 0
            are replaced with 1e-8.
        batchsize: DataLoader batch size.
        feature_cols: covariate columns, in order. Defaults to ``x0`` .. ``x8``.
        shuffle: whether the loader reshuffles every epoch (default True, as before).
        target_device: device for every tensor. Defaults to the notebook-level ``device``.

    Returns:
        A DataLoader over ``TensorDataset(X, Tstart, T, From, To, trans, E)``, where:
        - ``Tstart`` is all zeros;
        - ``From``, ``To`` and ``trans`` are long tensors filled with 1, 2 and 1;
        - ``E`` is the float event indicator.
    """
    if feature_cols is None:
        feature_cols = [f"x{i}" for i in range(9)]
    if target_device is None:
        target_device = device

    # NOTE(review): covariates are used unscaled. If standardization is wanted,
    # fit the scaler on the training split only and apply it to every split.
    X = torch.from_numpy(df[feature_cols].values).float().to(target_device)
    T = torch.from_numpy(df[["duration"]].values).float().flatten().to(target_device)
    T = T / Tmax
    # Zero durations are replaced by a tiny positive value (kept from the original code).
    T[T == 0] = 1e-8
    E = torch.from_numpy(df[["event"]].values).float().flatten().to(target_device)

    # Every subject enters at time 0 and makes the single transition 1 -> 2 (transition id 1).
    Tstart = torch.zeros_like(T)
    From = torch.full_like(T, 1, dtype=torch.long)
    To = torch.full_like(T, 2, dtype=torch.long)
    trans = torch.full_like(T, 1, dtype=torch.long)

    dataset = TensorDataset(X, Tstart, T, From, To, trans, E)
    return DataLoader(dataset, batch_size=batchsize, shuffle=shuffle)

In [6]:
from sklearn.model_selection import train_test_split

def odesurv_manual_benchmark(df_train, df_test,config,name):
    """Train one SurvNODE configuration and evaluate it on a held-out test set.

    Data handling:
    - ``df_train`` is split 80/20 (stratified on ``event``) into a training
      part and a validation part.
    - Durations are divided by ``max(training duration) / config["multiplier"]``.

    Training uses Adam with ReduceLROnPlateau. The mean validation loss drives
    ``EarlyStopping``. The weights from the epoch with the lowest validation loss
    are restored before the test metrics are computed.

    Args:
        df_train: frame with the covariate columns, ``duration`` and ``event``.
        df_test: held-out frame with the same columns.
        config: hyper-parameter dict (keys read below). ``batch_size`` is a
            *fraction* of the training-split size. ``max_epochs`` is optional
            (default 200).
        name: name forwarded to ``EarlyStopping``.

    Returns:
        ``(best_val_loss, test_concordance, test_ibs, test_ibnll)``. The test
        metrics are averaged over the test-loader batches.
    """
    torch.cuda.empty_cache()
    df_train, df_val = train_test_split(df_train, test_size=0.2, stratify=df_train.loc[:,"event"])

    # Dividing by this maps the longest training duration to config["multiplier"].
    time_scale = df_train["duration"].max() / config["multiplier"]
    # batch_size is a fraction of the training split; never let it round down to 0.
    train_batch = max(1, int(len(df_train) * config["batch_size"]))
    train_loader = make_dataloader(df_train, time_scale, train_batch)
    val_loader = make_dataloader(df_val, time_scale, len(df_val))
    test_loader = make_dataloader(df_test, time_scale, len(df_test))

    # Input width follows the covariate tensor (first tensor of the TensorDataset).
    num_in = train_loader.dataset.tensors[0].shape[1]
    num_latent = config["num_latent"]
    layers_encoder = [config["encoder_neurons"]] * config["num_encoder_layers"]
    dropout_encoder = [config["encoder_dropout"]] * config["num_encoder_layers"]
    layers_odefunc1 = [config["odefunc_neurons1"]] * config["num_odefunc_layers1"]

    # Presumably NaN marks a disallowed transition and only entry [0, 1] is active — confirm in SurvNODE.
    trans_matrix = torch.tensor([[np.nan,1],[np.nan,np.nan]]).to(device)

    encoder = Encoder(num_in,num_latent,layers_encoder, dropout_encoder).to(device)
    odefunc = ODEFunc(trans_matrix,num_in,num_latent,layers_odefunc1,config["softplus_beta"]).to(device)
    block = ODEBlock(odefunc).to(device)
    odesurv = SurvNODE(block,encoder).to(device)

    optimizer = torch.optim.Adam(odesurv.parameters(), weight_decay=config["weight_decay"], lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=config["scheduler_gamma"], patience=config["scheduler_epoch"])
    early_stopping = EarlyStopping(name=name, patience=config["patience"], verbose=True)

    # Initial state weights used by `measures` (all mass in the first state).
    initial = torch.tensor([1., 0.], device=device)
    best_val_loss = float("inf")
    best_state = None

    epochs = tqdm(range(config.get("max_epochs", 200)))
    for i in epochs:
        odesurv.train()
        train_loss_sum = 0.
        for ds in train_loader:
            myloss, _, _ = loss(odesurv, *ds, mu=config["mu"])
            optimizer.zero_grad()
            myloss.backward()
            optimizer.step()
            # .item(): summing the tensor itself would chain every batch's autograd history.
            train_loss_sum += myloss.item()
        # Mean per batch, consistent with how the validation loss is reported.
        train_loss = train_loss_sum / len(train_loader)

        odesurv.eval()
        with torch.no_grad():
            lossval, conc, ibs, ibnll = 0., 0., 0., 0.
            for ds in val_loader:
                val_batch_loss, _, _ = loss(odesurv, *ds, mu=config["mu"])
                lossval += val_batch_loss.item()
                c_b, ibs_b, ibnll_b = measures(odesurv, initial, *ds, multiplier=config["multiplier"])
                conc += c_b
                ibs += ibs_b
                ibnll += ibnll_b
            n_val = len(val_loader)
            val_loss = lossval / n_val
            early_stopping(val_loss, odesurv)
            scheduler.step(val_loss)
            if val_loss < best_val_loss:
                # Keep the best weights in memory so the test evaluation uses them.
                best_val_loss = val_loss
                best_state = {k: v.detach().clone() for k, v in odesurv.state_dict().items()}
            print("it: "+str(i)+", train loss="+str(train_loss)+", validation loss="+str(val_loss)+", c="+str(conc/n_val)+", ibs="+str(ibs/n_val)+", ibnll="+str(ibnll/n_val))

        epochs.set_postfix({"train loss": train_loss, "val loss": val_loss})
        if early_stopping.early_stop:
            print("Early stopping")
            break

    if best_state is not None:
        odesurv.load_state_dict(best_state)

    odesurv.eval()
    with torch.no_grad():
        conc, ibs, ibnll = 0., 0., 0.
        for ds in test_loader:
            c_b, ibs_b, ibnll_b = measures(odesurv, initial, *ds, multiplier=config["multiplier"])
            conc += c_b
            ibs += ibs_b
            ibnll += ibnll_b
    n_test = len(test_loader)
    return best_val_loss, conc/n_test, ibs/n_test, ibnll/n_test

In [7]:
from sklearn.model_selection import StratifiedKFold
from pycox import datasets

# Set to True to run the 5-fold cross-validated benchmark of the fixed `config` below.
RUN_CV = False

# Explicit random_state so the fold assignment does not depend on global RNG consumption.
kfold = StratifiedKFold(5, shuffle=True, random_state=random_seed)
df_all = datasets.metabric.read_df()
gen = kfold.split(df_all.iloc[:,df_all.columns.values!="event"],df_all.loc[:,"event"])

# The *2 keys are not read by odesurv_manual_benchmark (second ODE function is disabled).
config = {'batch_size': 0.1, 
          'encoder_dropout': 0.0, 
          'encoder_neurons': 223, 
          'lr': 0.00005, 
          'mu': 0.0001, 
          'multiplier': 1.0, 
          'num_encoder_layers': 3, 
          'num_latent': 145, 
          'num_odefunc_layers1': 4, 
          'num_odefunc_layers2': 3, 
          'odefunc_neurons1': 1421, 
          'odefunc_neurons2': 1231, 
          'patience': 20, 
          'scheduler_epoch': 5, 
          'scheduler_gamma': 0.1, 
          'softplus_beta': 1.0, 
          'weight_decay': 0.0001}

# One [concordance, ibs, ibnll] row per fold.
odesurv_bench_vals = []
if RUN_CV:
    for train_idx, test_idx in gen:
        # Local fold frames: do not overwrite the notebook-level df_train / df_test.
        df_fold_train = df_all.iloc[train_idx]
        df_fold_test = df_all.iloc[test_idx]
        _, conc, ibs, ibnll = odesurv_manual_benchmark(df_fold_train, df_fold_test, config, "metabric_test")
        odesurv_bench_vals.append([conc, ibs, ibnll])
        print(torch.tensor(odesurv_bench_vals))
    scores = torch.tensor(odesurv_bench_vals)
    print(scores)
    print(torch.mean(scores, dim=0))
    print(torch.std(scores, dim=0))

In [8]:
# Cross-validation summary (mean +- std over folds). Prints nothing until the CV loop has filled odesurv_bench_vals.
if odesurv_bench_vals:
    bench_scores = np.array(odesurv_bench_vals)
    for col, metric in enumerate(["c", "ibs", "ibnll"]):
        print(metric+"="+str(np.mean(bench_scores[:, col]))+"+-"+str(np.std(bench_scores[:, col])))

In [9]:
from pycox import datasets

# Fixed 80/20 train/test split of METABRIC, stratified on the event indicator.
df = datasets.metabric.read_df()
df_train, df_test = train_test_split(df, test_size=0.2, stratify=df["event"])

In [10]:
torch.cuda.empty_cache()
from hyperopt import hp

# Earlier search space, kept for reference only; it is not passed to fmin.
# NOTE: odesurv_manual_benchmark treats "batch_size" as a fraction of the training
# split, so 128 here means full-batch training.
args_initial = {
    "lr": hp.choice("lr", [1e-4, 5e-4]),
    "weight_decay": hp.choice("weight_decay", [1e-3, 1e-4]),
    "num_latent": hp.randint('num_latent', 20, 200),
    "encoder_neurons": hp.randint('encoder_neurons', 100, 800),
    "num_encoder_layers": hp.randint("num_encoder_layers", 2, 5),
    "encoder_dropout": 0.,
    "odefunc_neurons1": hp.randint('odefunc_neurons1', 100, 1500),
    "num_odefunc_layers1": hp.randint("num_odefunc_layers1", 2, 5),
    "batch_size": 128,
    "multiplier": 1.,
    "mu": 1e-4,
    "softplus_beta": hp.choice("softplus_beta", [1., 0.1]),
    "scheduler_epoch": 5,
    "scheduler_gamma": 0.1,
    "patience": 20
}

# Search space passed to fmin.
# Integer-valued sizes use hp.randint, whose upper bound is exclusive, hence +1.
# hp.quniform requires a `q` step and yields floats, but these values are used
# as layer widths.
args = {'batch_size': hp.uniform('batch_size', 0.01, 0.5), 
        'encoder_dropout': hp.choice('encoder_dropout', [0.0,0.1,0.2]), 
        'encoder_neurons': hp.randint('encoder_neurons', 20, 301), 
        'lr': 0.00005, 
        'mu': 0.0001, 
        'multiplier': hp.choice('multiplier', [1.0, 2.0]), 
        'num_encoder_layers': 3, 
        'num_latent': hp.randint('num_latent', 20, 201), 
        'num_odefunc_layers1': hp.randint('num_odefunc_layers1', 2, 5), 
        'odefunc_neurons1': hp.randint('odefunc_neurons1', 1000, 2001), 
        'patience': 20, 
        'scheduler_epoch': 5, 
        'scheduler_gamma': 0.1, 
        'softplus_beta': 1.0, 
        'weight_decay': 0.0001}

In [15]:
# define an objective function
def objective(args):
    """hyperopt objective: train and evaluate one sampled configuration.

    Uses the notebook-level ``df_train`` / ``df_test`` split. Appends the
    configuration and its test metrics (concordance, IBS, IBNLL) to
    ``hp_sep_log.txt`` as two newline-terminated lines. Returns the validation
    score from ``odesurv_manual_benchmark``, which ``fmin`` minimizes.
    """
    print(args)
    score, conc, ibs, ibnll = odesurv_manual_benchmark(df_train,df_test,args,"metabrick_test")
    results = "Results: {}, {}, {}".format(conc, ibs, ibnll)
    print("\n" + results)
    with open("hp_sep_log.txt", "a") as f:
        # Terminate both lines so consecutive trials do not run together in the log.
        f.write(str(args) + "\n")
        f.write(results + "\n")
    return score

In [16]:
# # minimize the objective over the space
from hyperopt import fmin, tpe, space_eval
best = fmin(objective, args, algo=tpe.suggest, max_evals=100)

{'batch_size': 128, 'encoder_dropout': 0.0, 'encoder_neurons': 730, 'lr': 0.0005, 'mu': 0.0001, 'multiplier': 1.0, 'num_encoder_layers': 3, 'num_latent': 129, 'num_odefunc_layers1': 2, 'odefunc_neurons1': 556, 'patience': 20, 'scheduler_epoch': 5, 'scheduler_gamma': 0.1, 'softplus_beta': 0.1, 'weight_decay': 0.0001}
  0%|                                                                          | 0/100 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.408804).  Saving model ...                                                        
it: 0, train loss=0.43715882301330566, validation loss=0.4088039994239807, c=0.5318760811448341, ibs=0.28881434103571474, ibnll=1.0297291634181962
Validation loss decreased (0.408804 --> 0.401936).  Saving model ...                                                   
it: 1, train loss=0.4290658235549927, validation loss=0.40193551778793335, c=0.5025947476018242, ibs=0.2871722229746581, ibnll=1.0169307097269589
Validation loss decreased (0.401936 --> 0.393154).  Saving model ...                                                   
it: 2, train loss=0.421892374753952, validation loss=0.39315417408943176, c=0.5269067463437648, ibs=0.2850657185358277, ibnll=1.000712252099843
Validation loss decreased (0.393154 --> 0.384757).  Saving model ...                                                   
it: 3, train loss=0.41279610991477966, validation loss=0.3847571909427643, c=0.5802

it: 30, train loss=0.224263533949852, validation loss=0.21048656105995178, c=0.40622739424437804, ibs=0.21577546153722205, ibnll=0.6358479137522306
Validation loss decreased (0.210487 --> 0.207347).  Saving model ...                                                   
it: 31, train loss=0.22070057690143585, validation loss=0.20734748244285583, c=0.4651360276773078, ibs=0.2137503811140742, ibnll=0.6287078515854274
Validation loss decreased (0.207347 --> 0.204393).  Saving model ...                                                   
it: 32, train loss=0.2173394113779068, validation loss=0.2043931931257248, c=0.5709073753734863, ibs=0.21178570043044265, ibnll=0.621923310872244
Validation loss decreased (0.204393 --> 0.201618).  Saving model ...                                                   
it: 33, train loss=0.21417075395584106, validation loss=0.20161780714988708, c=0.4938826859569115, ibs=0.2098826812735606, ibnll=0.6154839329448738
Validation loss decreased (0.201618 --> 0.199014).

Validation loss decreased (0.170263 --> 0.170001).  Saving model ...                                                   
it: 61, train loss=0.1764790266752243, validation loss=0.1700010597705841, c=0.515301147979242, ibs=0.18011765520098783, ibnll=0.5311793426095048
Validation loss decreased (0.170001 --> 0.169756).  Saving model ...                                                   
it: 62, train loss=0.17615236341953278, validation loss=0.1697555035352707, c=0.6110709231011165, ibs=0.1796914734476861, ibnll=0.5302094810620012
Validation loss decreased (0.169756 --> 0.169544).  Saving model ...                                                   
it: 63, train loss=0.17584918439388275, validation loss=0.1695440709590912, c=0.6051580437175657, ibs=0.17929978944836866, ibnll=0.5293374160920684
Validation loss decreased (0.169544 --> 0.169359).  Saving model ...                                                   
it: 64, train loss=0.17558042705059052, validation loss=0.16935916244983673, c=

it: 91, train loss=0.17335879802703857, validation loss=0.16813994944095612, c=0.4327724484981915, ibs=0.17489889557566587, ibnll=0.5204152300447697
EarlyStopping counter: 4 out of 20                                                                                     
it: 92, train loss=0.17334237694740295, validation loss=0.16814233362674713, c=0.38521780154112284, ibs=0.1748820115691329, ibnll=0.5204245243222355
EarlyStopping counter: 5 out of 20                                                                                     
it: 93, train loss=0.17332902550697327, validation loss=0.16815394163131714, c=0.44022645069979555, ibs=0.17490259907465572, ibnll=0.5205435782445875
EarlyStopping counter: 6 out of 20                                                                                     
it: 94, train loss=0.173318549990654, validation loss=0.1681586503982544, c=0.4350684069822299, ibs=0.1748474434216072, ibnll=0.5203771766878585
EarlyStopping counter: 7 out of 20             

Validation loss decreased (0.168043 --> 0.168043).  Saving model ...                                                   
it: 122, train loss=0.17314445972442627, validation loss=0.16804301738739014, c=0.5927976096870577, ibs=0.17492212591670378, ibnll=0.5206433320206882
Validation loss decreased (0.168043 --> 0.168043).  Saving model ...                                                   
it: 123, train loss=0.1731443554162979, validation loss=0.1680428683757782, c=0.5929234156313886, ibs=0.17492229564483133, ibnll=0.5206435032733865
Validation loss decreased (0.168043 --> 0.168043).  Saving model ...                                                   
it: 124, train loss=0.17314423620700836, validation loss=0.16804258525371552, c=0.5929548671174713, ibs=0.17492239413207797, ibnll=0.5206435245512827
Validation loss decreased (0.168043 --> 0.168042).  Saving model ...                                                   
it: 125, train loss=0.17314404249191284, validation loss=0.1680422872304

EarlyStopping counter: 1 out of 20                                                                                     
it: 152, train loss=0.17314159870147705, validation loss=0.16803964972496033, c=0.5939613146721182, ibs=0.1749210189360962, ibnll=0.5206397185614865
EarlyStopping counter: 2 out of 20                                                                                     
it: 153, train loss=0.17314156889915466, validation loss=0.16803963482379913, c=0.5937411542695392, ibs=0.17492101523477727, ibnll=0.5206397129810643
Validation loss decreased (0.168040 --> 0.168040).  Saving model ...                                                   
it: 154, train loss=0.17314158380031586, validation loss=0.16803961992263794, c=0.593992766158201, ibs=0.1749210087165389, ibnll=0.5206396976058102
Validation loss decreased (0.168040 --> 0.168040).  Saving model ...                                                   
it: 155, train loss=0.17314155399799347, validation loss=0.16803961992263

it: 182, train loss=0.1731414943933487, validation loss=0.1680395007133484, c=0.5939613146721182, ibs=0.17492090018836803, ibnll=0.520639472187227
EarlyStopping counter: 3 out of 20                                                                                     
it: 183, train loss=0.1731414943933487, validation loss=0.16803953051567078, c=0.593992766158201, ibs=0.1749208931891914, ibnll=0.5206394525980829
Validation loss decreased (0.168039 --> 0.168039).  Saving model ...                                                   
it: 184, train loss=0.1731414496898651, validation loss=0.168039470911026, c=0.5939613146721182, ibs=0.17492089303653066, ibnll=0.5206394625943891
Validation loss decreased (0.168039 --> 0.168039).  Saving model ...                                                   
it: 185, train loss=0.1731414496898651, validation loss=0.168039470911026, c=0.5940242176442837, ibs=0.17492088279594056, ibnll=0.5206394387938489
EarlyStopping counter: 1 out of 20                  

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.268016).  Saving model ...                                                        
it: 0, train loss=0.28365153074264526, validation loss=0.26801562309265137, c=0.4778272998364703, ibs=0.2433038298878813, ibnll=0.6741336855645694
Validation loss decreased (0.268016 --> 0.258524).  Saving model ...                                                   
it: 1, train loss=0.27343636751174927, validation loss=0.25852420926094055, c=0.4649693782665854, ibs=0.2342575429789681, ibnll=0.6543262784818654
Validation loss decreased (0.258524 --> 0.245797).  Saving model ...                                                   
it: 2, train loss=0.26387307047843933, validation loss=0.24579688906669617, c=0.41735338442299674, ibs=0.22220770136867987, ibnll=0.6283975421581239
Validation loss decreased (0.245797 --> 0.233190).  Saving model ...                                                   
it: 3, train loss=0.2510251998901367, validation loss=0.2331899255514145, c=0

it: 30, train loss=0.17572376132011414, validation loss=0.1725836992263794, c=0.5799531856222144, ibs=0.17000963837349176, ibnll=0.5055993596490447
EarlyStopping counter: 16 out of 20                                                                                    
it: 31, train loss=0.17571961879730225, validation loss=0.17257873713970184, c=0.5791195049219225, ibs=0.17000372578339226, ibnll=0.5055874094939448
EarlyStopping counter: 17 out of 20                                                                                    
it: 32, train loss=0.17571468651294708, validation loss=0.1725730448961258, c=0.5788309231410523, ibs=0.16999695635970657, ibnll=0.5055737147982483
EarlyStopping counter: 18 out of 20                                                                                    
it: 33, train loss=0.17570903897285461, validation loss=0.17256669700145721, c=0.5780293070830795, ibs=0.16998937816603527, ibnll=0.5055583443338194
EarlyStopping counter: 19 out of 20           

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.456188).  Saving model ...                                                        
it: 0, train loss=0.42666083574295044, validation loss=0.4561879336833954, c=0.48797483966452887, ibs=0.3133387881183362, ibnll=1.1618860207373418
Validation loss decreased (0.456188 --> 0.453868).  Saving model ...                                                   
it: 1, train loss=0.42451268434524536, validation loss=0.45386800169944763, c=0.5230019733596448, ibs=0.3128141345765937, ibnll=1.157485980767316
Validation loss decreased (0.453868 --> 0.451506).  Saving model ...                                                   
it: 2, train loss=0.42230093479156494, validation loss=0.45150643587112427, c=0.5094351258016774, ibs=0.31227372612519544, ibnll=1.1529746262468736
Validation loss decreased (0.451506 --> 0.449064).  Saving model ...                                                   
it: 3, train loss=0.42006081342697144, validation loss=0.44906413555145264, c=0

it: 30, train loss=0.3502001464366913, validation loss=0.37451595067977905, c=0.5561174148988653, ibs=0.2929032457222501, ibnll=1.0026092615474407
Validation loss decreased (0.374516 --> 0.371927).  Saving model ...                                                   
it: 31, train loss=0.34775999188423157, validation loss=0.3719266653060913, c=0.5444314257523434, ibs=0.29216166951175476, ibnll=0.9974120085728424
Validation loss decreased (0.371927 --> 0.369368).  Saving model ...                                                   
it: 32, train loss=0.345347136259079, validation loss=0.3693682551383972, c=0.5321287617168229, ibs=0.29142129917706266, ibnll=0.9922638720246855
Validation loss decreased (0.369368 --> 0.366844).  Saving model ...                                                   
it: 33, train loss=0.34296169877052307, validation loss=0.3668440580368042, c=0.5321287617168229, ibs=0.2906823416244322, ibnll=0.9871670379238431
Validation loss decreased (0.366844 --> 0.364354).  

Validation loss decreased (0.309303 --> 0.307517).  Saving model ...                                                   
it: 61, train loss=0.28682637214660645, validation loss=0.30751746892929077, c=0.5476689689195856, ibs=0.27076904894061116, ibnll=0.8648271271522524
Validation loss decreased (0.307517 --> 0.305753).  Saving model ...                                                   
it: 62, train loss=0.2851641774177551, validation loss=0.30575305223464966, c=0.5503206709422792, ibs=0.270091883159939, ibnll=0.8611123882480596
Validation loss decreased (0.305753 --> 0.304011).  Saving model ...                                                   
it: 63, train loss=0.2835230827331543, validation loss=0.30401092767715454, c=0.5526640355204736, ibs=0.26941733783666993, ibnll=0.8574376593773838
Validation loss decreased (0.304011 --> 0.302292).  Saving model ...                                                   
it: 64, train loss=0.2819029688835144, validation loss=0.30229151248931885, c

it: 91, train loss=0.24523931741714478, validation loss=0.2632361650466919, c=0.5790885545140602, ibs=0.25170528314464596, ibnll=0.7692974628899933
Validation loss decreased (0.263236 --> 0.262030).  Saving model ...                                                   
it: 92, train loss=0.24412114918231964, validation loss=0.2620304524898529, c=0.5796743956586088, ibs=0.25111782985919673, ibnll=0.7666277933556799
Validation loss decreased (0.262030 --> 0.260851).  Saving model ...                                                   
it: 93, train loss=0.24301038682460785, validation loss=0.26085054874420166, c=0.5792118894918599, ibs=0.25053377193928494, ibnll=0.7639885156857601
Validation loss decreased (0.260851 --> 0.259680).  Saving model ...                                                   
it: 94, train loss=0.24192386865615845, validation loss=0.2596798539161682, c=0.581277750370005, ibs=0.24995312073465525, ibnll=0.7613791159115851
Validation loss decreased (0.259680 --> 0.258520

Validation loss decreased (0.233328 --> 0.232526).  Saving model ...                                                   
it: 122, train loss=0.21671262383460999, validation loss=0.2325255125761032, c=0.6114331524420326, ibs=0.23509304827104266, ibnll=0.6992692318149881
Validation loss decreased (0.232526 --> 0.231737).  Saving model ...                                                   
it: 123, train loss=0.21598199009895325, validation loss=0.2317366749048233, c=0.612913172175629, ibs=0.23461472936108768, ibnll=0.6974109312752477
Validation loss decreased (0.231737 --> 0.230957).  Saving model ...                                                   
it: 124, train loss=0.21526268124580383, validation loss=0.23095713555812836, c=0.6135915145535273, ibs=0.2341409298699778, ibnll=0.6955775809477944
Validation loss decreased (0.230957 --> 0.230188).  Saving model ...                                                   
it: 125, train loss=0.2145516276359558, validation loss=0.2301877737045288

it: 152, train loss=0.19857542216777802, validation loss=0.21293939650058746, c=0.6173223976319684, ibs=0.22226405268477567, ibnll=0.6520982244285818
Validation loss decreased (0.212939 --> 0.212484).  Saving model ...                                                   
it: 153, train loss=0.19807714223861694, validation loss=0.21248356997966766, c=0.6172298963986187, ibs=0.22206317672275813, ibnll=0.6513256854067784
Validation loss decreased (0.212484 --> 0.211909).  Saving model ...                                                   
it: 154, train loss=0.19759540259838104, validation loss=0.21190889179706573, c=0.622903305377405, ibs=0.22160312209194633, ibnll=0.6497566418322646
Validation loss decreased (0.211909 --> 0.211362).  Saving model ...                                                   
it: 155, train loss=0.19709312915802002, validation loss=0.21136176586151123, c=0.6261408485446472, ibs=0.2211470236706265, ibnll=0.6482194994720126
Validation loss decreased (0.211362 --> 0.

it: 182, train loss=0.18587929010391235, validation loss=0.19956772029399872, c=0.6212999506660088, ibs=0.2125085787875206, ibnll=0.6193875673272193
Validation loss decreased (0.199568 --> 0.199165).  Saving model ...                                                   
it: 183, train loss=0.18551091849803925, validation loss=0.1991645097732544, c=0.6211457819437592, ibs=0.21214414880315527, ibnll=0.6182555790460995
Validation loss decreased (0.199165 --> 0.198851).  Saving model ...                                                   
it: 184, train loss=0.18514060974121094, validation loss=0.1988510936498642, c=0.6172298963986187, ibs=0.21205253438806348, ibnll=0.6179525703489156
Validation loss decreased (0.198851 --> 0.198429).  Saving model ...                                                   
it: 185, train loss=0.18479330837726593, validation loss=0.1984289437532425, c=0.6201282683769117, ibs=0.211498705488886, ibnll=0.6162323806225175
Validation loss decreased (0.198429 --> 0.1982

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.276821).  Saving model ...                                                        
it: 0, train loss=0.28474366664886475, validation loss=0.2768207788467407, c=0.572818662198737, ibs=0.2451448393932713, ibnll=0.675386258954228
Validation loss decreased (0.276821 --> 0.275750).  Saving model ...                                                   
it: 1, train loss=0.2837564945220947, validation loss=0.2757495045661926, c=0.5465588349014048, ibs=0.2442546309719321, ibnll=0.6733873245808549
Validation loss decreased (0.275750 --> 0.274663).  Saving model ...                                                   
it: 2, train loss=0.28269270062446594, validation loss=0.27466341853141785, c=0.5263242685913133, ibs=0.24336566428274387, ibnll=0.6713956123273848
Validation loss decreased (0.274663 --> 0.273545).  Saving model ...                                                   
it: 3, train loss=0.2816111445426941, validation loss=0.27354496717453003, c=0.5409

it: 30, train loss=0.24342672526836395, validation loss=0.23488834500312805, c=0.4820208789792499, ibs=0.21124520233323862, ibnll=0.6018038968467042
Validation loss decreased (0.234888 --> 0.233358).  Saving model ...                                                   
it: 31, train loss=0.2418995201587677, validation loss=0.23335768282413483, c=0.5251321046526615, ibs=0.21005314510432563, ibnll=0.5992904032009642
Validation loss decreased (0.233358 --> 0.231843).  Saving model ...                                                   
it: 32, train loss=0.24037589132785797, validation loss=0.23184259235858917, c=0.5061541435752029, ibs=0.2088783113026842, ibnll=0.5968166203706528
Validation loss decreased (0.231843 --> 0.230337).  Saving model ...                                                   
it: 33, train loss=0.23886917531490326, validation loss=0.23033717274665833, c=0.4921703827812863, ibs=0.20771640974134956, ibnll=0.5943730768807797
Validation loss decreased (0.230337 --> 0.2288

Validation loss decreased (0.198248 --> 0.197425).  Saving model ...                                                   
it: 61, train loss=0.205694779753685, validation loss=0.1974247545003891, c=0.502448769171285, ibs=0.1847612858678079, ibnll=0.5461410920310156
Validation loss decreased (0.197425 --> 0.196625).  Saving model ...                                                   
it: 62, train loss=0.20489083230495453, validation loss=0.19662532210350037, c=0.514434849851785, ibs=0.1842994056635077, ibnll=0.5451532092260485
Validation loss decreased (0.196625 --> 0.195852).  Saving model ...                                                   
it: 63, train loss=0.2041102796792984, validation loss=0.1958521008491516, c=0.5138226575589638, ibs=0.1838602008618406, ibnll=0.5442117764235673
Validation loss decreased (0.195852 --> 0.195103).  Saving model ...                                                   
it: 64, train loss=0.20335568487644196, validation loss=0.19510260224342346, c=0.51

it: 91, train loss=0.19040557742118835, validation loss=0.18210755288600922, c=0.5440456244361387, ibs=0.17790169000426515, ibnll=0.530974927819525
Validation loss decreased (0.182108 --> 0.181830).  Saving model ...                                                   
it: 92, train loss=0.190133199095726, validation loss=0.18182975053787231, c=0.5489753834256992, ibs=0.1778409289902771, ibnll=0.530823963383644
Validation loss decreased (0.181830 --> 0.181561).  Saving model ...                                                   
it: 93, train loss=0.18987222015857697, validation loss=0.18156149983406067, c=0.5461721871375177, ibs=0.17778652004444018, ibnll=0.5306869148319683
Validation loss decreased (0.181561 --> 0.181304).  Saving model ...                                                   
it: 94, train loss=0.18962058424949646, validation loss=0.18130375444889069, c=0.5458177600206212, ibs=0.17773808655043494, ibnll=0.5305630029237981
Validation loss decreased (0.181304 --> 0.181056)

Validation loss decreased (0.177070 --> 0.176984).  Saving model ...                                                   
it: 122, train loss=0.18551449477672577, validation loss=0.17698420584201813, c=0.581421574945225, ibs=0.17784480534939665, ibnll=0.5304097553008725
Validation loss decreased (0.176984 --> 0.176902).  Saving model ...                                                   
it: 123, train loss=0.18543925881385803, validation loss=0.17690160870552063, c=0.5842892125273875, ibs=0.1778761351883181, ibnll=0.5304697068859607
Validation loss decreased (0.176902 --> 0.176822).  Saving model ...                                                   
it: 124, train loss=0.18536745011806488, validation loss=0.17682193219661713, c=0.5832581518236886, ibs=0.1779081901926998, ibnll=0.5305316037109944
Validation loss decreased (0.176822 --> 0.176746).  Saving model ...                                                   
it: 125, train loss=0.1852983683347702, validation loss=0.176745787262916

it: 152, train loss=0.18415379524230957, validation loss=0.17543204128742218, c=0.5959530867379818, ibs=0.17873476457664356, ibnll=0.5321892200490758
Validation loss decreased (0.175432 --> 0.175415).  Saving model ...                                                   
it: 153, train loss=0.18412259221076965, validation loss=0.17541486024856567, c=0.5987562830261631, ibs=0.1787529491339194, ibnll=0.5322274932467298
Validation loss decreased (0.175415 --> 0.175390).  Saving model ...                                                   
it: 154, train loss=0.1841035932302475, validation loss=0.1753903180360794, c=0.598530738497229, ibs=0.17877787659070826, ibnll=0.5322788878767427
Validation loss decreased (0.175390 --> 0.175352).  Saving model ...                                                   
it: 155, train loss=0.1840839982032776, validation loss=0.17535223066806793, c=0.5982085320273232, ibs=0.17880453503547852, ibnll=0.532331456530579
Validation loss decreased (0.175352 --> 0.1753

Validation loss decreased (0.174400 --> 0.174346).  Saving model ...                                                   
it: 183, train loss=0.18310748040676117, validation loss=0.17434588074684143, c=0.6045559994844697, ibs=0.17839229546682167, ibnll=0.531181691479687
Validation loss decreased (0.174346 --> 0.174269).  Saving model ...                                                   
it: 184, train loss=0.1830366849899292, validation loss=0.17426873743534088, c=0.6058448253640933, ibs=0.1783121303464875, ibnll=0.5309838113442898
Validation loss decreased (0.174269 --> 0.174237).  Saving model ...                                                   
it: 185, train loss=0.18295246362686157, validation loss=0.17423690855503082, c=0.6077458435365383, ibs=0.1781573074401097, ibnll=0.5306309777256222
Validation loss decreased (0.174237 --> 0.174113).  Saving model ...                                                   
it: 186, train loss=0.18287187814712524, validation loss=0.174113184213638

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.457116).  Saving model ...                                                        
it: 0, train loss=0.47485142946243286, validation loss=0.4571162164211273, c=0.5072620185636313, ibs=0.3254776264875189, ibnll=1.2458792524841242
Validation loss decreased (0.457116 --> 0.456174).  Saving model ...                                                   
it: 1, train loss=0.47392356395721436, validation loss=0.45617443323135376, c=0.5040550124888218, ibs=0.3252760907626198, ibnll=1.2439282439607586
Validation loss decreased (0.456174 --> 0.455063).  Saving model ...                                                   
it: 2, train loss=0.47295302152633667, validation loss=0.4550629258155823, c=0.392118165839218, ibs=0.32503133631149966, ibnll=1.2416294926213296
Validation loss decreased (0.455063 --> 0.453747).  Saving model ...                                                   
it: 3, train loss=0.47180652618408203, validation loss=0.4537470042705536, c=0.48

it: 30, train loss=0.40788838267326355, validation loss=0.390828013420105, c=9.250979061950723e-05, ibs=0.30945984395985254, ibnll=1.106773119003031
Validation loss decreased (0.390828 --> 0.388414).  Saving model ...                                                   
it: 31, train loss=0.40536507964134216, validation loss=0.3884141743183136, c=9.250979061950723e-05, ibs=0.3088048864895481, ibnll=1.1016171020675534
Validation loss decreased (0.388414 --> 0.386028).  Saving model ...                                                   
it: 32, train loss=0.40286552906036377, validation loss=0.386027991771698, c=0.022942428073637795, ibs=0.3081511739101885, ibnll=1.0965082277277292
Validation loss decreased (0.386028 --> 0.383666).  Saving model ...                                                   
it: 33, train loss=0.4003942608833313, validation loss=0.38366588950157166, c=0.07696814579543002, ibs=0.30749873694999974, ibnll=1.0914467903397318
Validation loss decreased (0.383666 --> 0.38

Validation loss decreased (0.328727 --> 0.326996).  Saving model ...                                                   
it: 61, train loss=0.34095439314842224, validation loss=0.3269955813884735, c=0.1537821086064942, ibs=0.28986960723875677, ibnll=0.9674917427357528
Validation loss decreased (0.326996 --> 0.325284).  Saving model ...                                                   
it: 62, train loss=0.3391551375389099, validation loss=0.32528379559516907, c=0.18292269265163896, ibs=0.28926851276842636, ibnll=0.9636616062874855
Validation loss decreased (0.325284 --> 0.323592).  Saving model ...                                                   
it: 63, train loss=0.33737602829933167, validation loss=0.32359206676483154, c=0.1441919269789386, ibs=0.28866947263209836, ibnll=0.9598684833316281
Validation loss decreased (0.323592 --> 0.321918).  Saving model ...                                                   
it: 64, train loss=0.33561739325523376, validation loss=0.3219179511070251

it: 91, train loss=0.2947465181350708, validation loss=0.28310710191726685, c=0.10302507015325789, ibs=0.2727593360198529, ibnll=0.8671337115933505
Validation loss decreased (0.283107 --> 0.281881).  Saving model ...                                                   
it: 92, train loss=0.2934500575065613, validation loss=0.28188079595565796, c=0.28992568380153566, ibs=0.27222206885854316, ibnll=0.8642517741414146
Validation loss decreased (0.281881 --> 0.280665).  Saving model ...                                                   
it: 93, train loss=0.29216986894607544, validation loss=0.2806648313999176, c=0.2937802584106818, ibs=0.27168700484376224, ibnll=0.8613967012360277
Validation loss decreased (0.280665 --> 0.279466).  Saving model ...                                                   
it: 94, train loss=0.29090023040771484, validation loss=0.2794656753540039, c=0.11634648000246693, ibs=0.2711540765728264, ibnll=0.8585678012380427
Validation loss decreased (0.279466 --> 0.27827

Validation loss decreased (0.251537 --> 0.250654).  Saving model ...                                                   
it: 122, train loss=0.26041358709335327, validation loss=0.2506539821624756, c=0.43572111381787904, ibs=0.257127095979917, ibnll=0.7890999297335682
Validation loss decreased (0.250654 --> 0.249780).  Saving model ...                                                   
it: 123, train loss=0.25948670506477356, validation loss=0.24977964162826538, c=0.4907336026396127, ibs=0.25665869137794006, ibnll=0.7869361541636912
Validation loss decreased (0.249780 --> 0.248916).  Saving model ...                                                   
it: 124, train loss=0.25856876373291016, validation loss=0.2489161491394043, c=0.5055660057356071, ibs=0.2561925533552937, ibnll=0.7847922959583358
Validation loss decreased (0.248916 --> 0.248061).  Saving model ...                                                   
it: 125, train loss=0.25766199827194214, validation loss=0.248060956597328

it: 152, train loss=0.23657405376434326, validation loss=0.22826065123081207, c=0.4926763082426223, ibs=0.2440742211868528, ibnll=0.7321989246182227
Validation loss decreased (0.228261 --> 0.227639).  Saving model ...                                                   
it: 153, train loss=0.23590673506259918, validation loss=0.22763873636722565, c=0.48191433593388633, ibs=0.24367487713375466, ibnll=0.7305638836669488
Validation loss decreased (0.227639 --> 0.227021).  Saving model ...                                                   
it: 154, train loss=0.235249325633049, validation loss=0.22702129185199738, c=0.45598075796355114, ibs=0.2432778393279812, ibnll=0.7289442478958744
Validation loss decreased (0.227021 --> 0.226414).  Saving model ...                                                   
it: 155, train loss=0.23459652066230774, validation loss=0.22641350328922272, c=0.49831940547041226, ibs=0.2428830818695977, ibnll=0.7273397129348909
Validation loss decreased (0.226414 --> 0.

it: 182, train loss=0.21952581405639648, validation loss=0.212376669049263, c=0.4436461192142835, ibs=0.23308311867793527, ibnll=0.6893048660740435
Validation loss decreased (0.212377 --> 0.211937).  Saving model ...                                                   
it: 183, train loss=0.21905352175235748, validation loss=0.21193687617778778, c=0.4604520645101607, ibs=0.23275139359663857, ibnll=0.6880758209437104
Validation loss decreased (0.211937 --> 0.211503).  Saving model ...                                                   
it: 184, train loss=0.21858470141887665, validation loss=0.21150335669517517, c=0.45298960806685373, ibs=0.2324218461191329, ibnll=0.6868585177232979
Validation loss decreased (0.211503 --> 0.211074).  Saving model ...                                                   
it: 185, train loss=0.21812234818935394, validation loss=0.2110743522644043, c=0.4564433069166487, ibs=0.23209449723797262, ibnll=0.6856529232437679
Validation loss decreased (0.211074 --> 0.2

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.377158).  Saving model ...                                                        
it: 0, train loss=0.44525110721588135, validation loss=0.37715792655944824, c=0.5650535987748851, ibs=0.26822108439146763, ibnll=0.9397420776279374
Validation loss decreased (0.377158 --> 0.372290).  Saving model ...                                                   
it: 1, train loss=0.437107115983963, validation loss=0.37229031324386597, c=0.5400306278713629, ibs=0.2670418367939295, ibnll=0.9306955431601874
Validation loss decreased (0.372290 --> 0.363857).  Saving model ...                                                   
it: 2, train loss=0.43149322271347046, validation loss=0.36385729908943176, c=0.43972434915773356, ibs=0.2649645206085912, ibnll=0.9156029673323108
Validation loss decreased (0.363857 --> 0.354520).  Saving model ...                                                   
it: 3, train loss=0.421893835067749, validation loss=0.35452038049697876, c=0.5

it: 30, train loss=0.21683675050735474, validation loss=0.18625947833061218, c=0.4407350689127106, ibs=0.19388077837738368, ibnll=0.5706155661227817
Validation loss decreased (0.186259 --> 0.183718).  Saving model ...                                                   
it: 31, train loss=0.21334217488765717, validation loss=0.18371783196926117, c=0.45479326186830016, ibs=0.1920127609506281, ibnll=0.5647922597853136
Validation loss decreased (0.183718 --> 0.181381).  Saving model ...                                                   
it: 32, train loss=0.21007610857486725, validation loss=0.1813809871673584, c=0.4356508422664625, ibs=0.19022540601577595, ibnll=0.5593457618637397
Validation loss decreased (0.181381 --> 0.179222).  Saving model ...                                                   
it: 33, train loss=0.20704114437103271, validation loss=0.17922185361385345, c=0.46624808575803983, ibs=0.188519340596495, ibnll=0.5542613673757983
Validation loss decreased (0.179222 --> 0.1772

EarlyStopping counter: 1 out of 20                                                                                     
it: 61, train loss=0.17598943412303925, validation loss=0.16137553751468658, c=0.5449923430321593, ibs=0.16718186044401182, ibnll=0.501637752649523
EarlyStopping counter: 2 out of 20                                                                                     
it: 62, train loss=0.17583151161670685, validation loss=0.16138124465942383, c=0.554854517611026, ibs=0.16700342536489124, ibnll=0.5013401120524973
EarlyStopping counter: 3 out of 20                                                                                     
it: 63, train loss=0.17567786574363708, validation loss=0.16139459609985352, c=0.5624196018376723, ibs=0.16683668857708103, ibnll=0.5010707018715873
EarlyStopping counter: 4 out of 20                                                                                     
it: 64, train loss=0.1755441427230835, validation loss=0.16142193973064423,

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.269541).  Saving model ...                                                        
it: 0, train loss=0.2823184132575989, validation loss=0.2695411145687103, c=0.5288565681572025, ibs=0.2442519818426777, ibnll=0.6764783039726875
Validation loss decreased (0.269541 --> 0.256272).  Saving model ...                                                   
it: 1, train loss=0.26869118213653564, validation loss=0.2562720477581024, c=0.5233792356490794, ibs=0.2326021301166712, ibnll=0.6508122522173532
Validation loss decreased (0.256272 --> 0.243734).  Saving model ...                                                   
it: 2, train loss=0.2557070553302765, validation loss=0.24373376369476318, c=0.5178709577595544, ibs=0.22185780287409357, ibnll=0.6275154769918667
Validation loss decreased (0.243734 --> 0.232167).  Saving model ...                                                   
it: 3, train loss=0.24344375729560852, validation loss=0.23216743767261505, c=0.52

it: 30, train loss=0.16762438416481018, validation loss=0.16379088163375854, c=0.6194336995203465, ibs=0.17325243649160807, ibnll=0.515641133797427
Validation loss decreased (0.163791 --> 0.163286).  Saving model ...                                                   
it: 31, train loss=0.16725394129753113, validation loss=0.16328638792037964, c=0.6192789726133374, ibs=0.17274807123872357, ibnll=0.514349290936896
Validation loss decreased (0.163286 --> 0.162899).  Saving model ...                                                   
it: 32, train loss=0.1667647659778595, validation loss=0.16289937496185303, c=0.6190004641807211, ibs=0.17216210878349278, ibnll=0.5129999417394722
Validation loss decreased (0.162899 --> 0.162527).  Saving model ...                                                   
it: 33, train loss=0.1662665456533432, validation loss=0.16252654790878296, c=0.6185053380782918, ibs=0.1714846491442083, ibnll=0.5114779602392434
Validation loss decreased (0.162527 --> 0.161936)

EarlyStopping counter: 1 out of 20                                                                                     
it: 61, train loss=0.157700315117836, validation loss=0.15509314835071564, c=0.6276342255918305, ibs=0.1623064416310488, ibnll=0.48934941498639256
EarlyStopping counter: 2 out of 20                                                                                     
it: 62, train loss=0.15731199085712433, validation loss=0.15461310744285583, c=0.6276032802104285, ibs=0.16225899648688658, ibnll=0.4889419849984559
Validation loss decreased (0.154488 --> 0.153938).  Saving model ...                                                   
it: 63, train loss=0.1569836586713791, validation loss=0.15393781661987305, c=0.6261797926659446, ibs=0.16251857027713618, ibnll=0.48896398554458625
EarlyStopping counter: 1 out of 20                                                                                     
it: 64, train loss=0.15678803622722626, validation loss=0.15475299954414368

it: 91, train loss=0.14985036849975586, validation loss=0.1499488651752472, c=0.6439733869719945, ibs=0.15714029989195888, ibnll=0.4756476491845156
Validation loss decreased (0.148322 --> 0.148081).  Saving model ...                                                   
it: 92, train loss=0.14955003559589386, validation loss=0.14808060228824615, c=0.6399814327711589, ibs=0.1580054395452832, ibnll=0.47644220710013513
EarlyStopping counter: 1 out of 20                                                                                     
it: 93, train loss=0.14913277328014374, validation loss=0.14855699241161346, c=0.6421785548506885, ibs=0.1572116713724054, ibnll=0.4750642920431713
EarlyStopping counter: 2 out of 20                                                                                     
it: 94, train loss=0.1487048864364624, validation loss=0.14874133467674255, c=0.6430759709113415, ibs=0.1569312479812738, ibnll=0.4745641540816151
Validation loss decreased (0.148081 --> 0.147495

EarlyStopping counter: 1 out of 20                                                                                     
it: 122, train loss=0.14254355430603027, validation loss=0.14518830180168152, c=0.6499767909639487, ibs=0.1534874290787786, ibnll=0.4654798987813106
Validation loss decreased (0.143307 --> 0.142779).  Saving model ...                                                   
it: 123, train loss=0.142656609416008, validation loss=0.14277924597263336, c=0.645489710660684, ibs=0.15577090068245777, ibnll=0.4693005502708991
EarlyStopping counter: 1 out of 20                                                                                     
it: 124, train loss=0.14285698533058167, validation loss=0.14505945146083832, c=0.6509360977874052, ibs=0.15320819059565832, ibnll=0.4648327544135913
Validation loss decreased (0.142779 --> 0.142516).  Saving model ...                                                   
it: 125, train loss=0.14241454005241394, validation loss=0.142515942454338

EarlyStopping counter: 1 out of 20                                                                                     
it: 152, train loss=0.13790154457092285, validation loss=0.14006514847278595, c=0.6563205941513229, ibs=0.1509263277601421, ibnll=0.45789844795612444
EarlyStopping counter: 2 out of 20                                                                                     
it: 153, train loss=0.137063130736351, validation loss=0.1389235556125641, c=0.6557016865232864, ibs=0.15140503519633627, ibnll=0.458519939821558
Validation loss decreased (0.138652 --> 0.138330).  Saving model ...                                                   
it: 154, train loss=0.13660873472690582, validation loss=0.13832981884479523, c=0.6548042704626335, ibs=0.15202669808360472, ibnll=0.4596694397718606
EarlyStopping counter: 1 out of 20                                                                                     
it: 155, train loss=0.1366717368364334, validation loss=0.1404442638158798

Validation loss decreased (0.134890 --> 0.134850).  Saving model ...                                                   
it: 182, train loss=0.13358767330646515, validation loss=0.13485030829906464, c=0.6607767290731859, ibs=0.15209873729812798, ibnll=0.45883972326306677
EarlyStopping counter: 1 out of 20                                                                                     
it: 183, train loss=0.1342737078666687, validation loss=0.1379466950893402, c=0.6587033885192635, ibs=0.14839681365121407, ibnll=0.4516146485459465
Validation loss decreased (0.134850 --> 0.134505).  Saving model ...                                                   
it: 184, train loss=0.13420478999614716, validation loss=0.13450495898723602, c=0.6613028005570168, ibs=0.15172987159205556, ibnll=0.45793919044936476
EarlyStopping counter: 1 out of 20                                                                                     
it: 185, train loss=0.13374052941799164, validation loss=0.13555437326

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.409047).  Saving model ...                                                        
it: 0, train loss=0.43692028522491455, validation loss=0.40904706716537476, c=0.5964375173949346, ibs=0.2770400978379276, ibnll=0.9668574094433864
Validation loss decreased (0.409047 --> 0.401583).  Saving model ...                                                   
it: 1, train loss=0.42815762758255005, validation loss=0.401582807302475, c=0.5909639113090268, ibs=0.2752466839926786, ibnll=0.9539170047498091
Validation loss decreased (0.401583 --> 0.393503).  Saving model ...                                                   
it: 2, train loss=0.42032039165496826, validation loss=0.39350301027297974, c=0.5491542196245787, ibs=0.27324692941505413, ibnll=0.9401194319860201
Validation loss decreased (0.393503 --> 0.383875).  Saving model ...                                                   
it: 3, train loss=0.41212019324302673, validation loss=0.3838752508163452, c=0.4

it: 30, train loss=0.20329967141151428, validation loss=0.19220390915870667, c=0.6123016977456165, ibs=0.1914262074144969, ibnll=0.5607546935275544
Validation loss decreased (0.192204 --> 0.189601).  Saving model ...                                                   
it: 31, train loss=0.20020122826099396, validation loss=0.18960119783878326, c=0.639051241611776, ibs=0.18942258433987263, ibnll=0.5549970121895845
Validation loss decreased (0.189601 --> 0.187235).  Saving model ...                                                   
it: 32, train loss=0.19737215340137482, validation loss=0.18723466992378235, c=0.6355567925286824, ibs=0.18754172459750357, ibnll=0.5497184715951567
Validation loss decreased (0.187235 --> 0.185082).  Saving model ...                                                   
it: 33, train loss=0.1947798877954483, validation loss=0.18508240580558777, c=0.6476791291709187, ibs=0.1857371999535702, ibnll=0.5447735046393306
Validation loss decreased (0.185082 --> 0.183193

Validation loss decreased (0.169183 --> 0.169068).  Saving model ...                                                   
it: 61, train loss=0.17330853641033173, validation loss=0.16906790435314178, c=0.6010761666202802, ibs=0.16684969866942137, ibnll=0.5009931758630704
Validation loss decreased (0.169068 --> 0.168988).  Saving model ...                                                   
it: 62, train loss=0.1731421798467636, validation loss=0.168988436460495, c=0.6216408448526456, ibs=0.1669092427973906, ibnll=0.5010262463342114
Validation loss decreased (0.168988 --> 0.168777).  Saving model ...                                                   
it: 63, train loss=0.173108771443367, validation loss=0.16877660155296326, c=0.6329591489624888, ibs=0.16676662814054735, ibnll=0.5007214088484543
EarlyStopping counter: 1 out of 20                                                                                     
it: 64, train loss=0.17289744317531586, validation loss=0.168927863240242, c=0.

it: 91, train loss=0.1703696846961975, validation loss=0.16363149881362915, c=0.6300213377864365, ibs=0.16226540502681946, ibnll=0.49045338150160683
Validation loss decreased (0.163631 --> 0.162842).  Saving model ...                                                   
it: 92, train loss=0.16884015500545502, validation loss=0.16284199059009552, c=0.6218573151498283, ibs=0.16441567911475416, ibnll=0.49270675806103365
EarlyStopping counter: 1 out of 20                                                                                     
it: 93, train loss=0.1705198884010315, validation loss=0.16421718895435333, c=0.6251971425920771, ibs=0.16289496453639785, ibnll=0.4916255693652527
EarlyStopping counter: 2 out of 20                                                                                     
it: 94, train loss=0.16941657662391663, validation loss=0.16516435146331787, c=0.6198163094906763, ibs=0.1636022980835212, ibnll=0.4930589880131397
EarlyStopping counter: 3 out of 20           

Validation loss decreased (0.161064 --> 0.160971).  Saving model ...                                                   
it: 122, train loss=0.16648393869400024, validation loss=0.16097141802310944, c=0.6328973003061509, ibs=0.16237180743711768, ibnll=0.4892133927875638
Validation loss decreased (0.160971 --> 0.160892).  Saving model ...                                                   
it: 123, train loss=0.1664150506258011, validation loss=0.1608920395374298, c=0.6339487274638959, ibs=0.1623716580448848, ibnll=0.48918452953354163
Validation loss decreased (0.160892 --> 0.160823).  Saving model ...                                                   
it: 124, train loss=0.1663525253534317, validation loss=0.1608230322599411, c=0.6344125923864304, ibs=0.1623560474931089, ibnll=0.4891355339101855
Validation loss decreased (0.160823 --> 0.160764).  Saving model ...                                                   
it: 125, train loss=0.16629159450531006, validation loss=0.1607643514871597

it: 152, train loss=0.1649111807346344, validation loss=0.15933604538440704, c=0.636886538639948, ibs=0.16193868796522515, ibnll=0.48779446749244315
Validation loss decreased (0.159336 --> 0.159283).  Saving model ...                                                   
it: 153, train loss=0.16486848890781403, validation loss=0.15928299725055695, c=0.637010235952624, ibs=0.16192880151713723, ibnll=0.48775532221377566
Validation loss decreased (0.159283 --> 0.159231).  Saving model ...                                                   
it: 154, train loss=0.16482660174369812, validation loss=0.15923123061656952, c=0.636948387296286, ibs=0.16191234358351123, ibnll=0.48770674152134746
Validation loss decreased (0.159231 --> 0.159181).  Saving model ...                                                   
it: 155, train loss=0.1647852659225464, validation loss=0.15918098390102386, c=0.6371339332652998, ibs=0.16189035211827868, ibnll=0.48765041164710043
Validation loss decreased (0.159181 --> 0

it: 182, train loss=0.1637614369392395, validation loss=0.15793956816196442, c=0.6388966199709312, ibs=0.16146772782630622, ibnll=0.4864291243767927
Validation loss decreased (0.157940 --> 0.157896).  Saving model ...                                                   
it: 183, train loss=0.16372747719287872, validation loss=0.15789617598056793, c=0.6388038469864242, ibs=0.16145028630011718, ibnll=0.4863818807122172
Validation loss decreased (0.157896 --> 0.157853).  Saving model ...                                                   
it: 184, train loss=0.1636936068534851, validation loss=0.15785303711891174, c=0.6390821659399449, ibs=0.1614330679817255, ibnll=0.48633571559978117
Validation loss decreased (0.157853 --> 0.157810).  Saving model ...                                                   
it: 185, train loss=0.16365978121757507, validation loss=0.15781007707118988, c=0.639051241611776, ibs=0.1614168353589616, ibnll=0.4862917522625187
Validation loss decreased (0.157810 --> 0.15

  0%|          | 0/200 [00:00<?, ?it/s]

Validation loss decreased (inf --> 0.396619).  Saving model ...                                                        
it: 0, train loss=0.43976372480392456, validation loss=0.396618515253067, c=0.5666913443148868, ibs=0.27651250300533387, ibnll=0.9736606082173378
Validation loss decreased (0.396619 --> 0.387305).  Saving model ...                                                   
it: 1, train loss=0.42959922552108765, validation loss=0.3873048722743988, c=0.5675859090628663, ibs=0.27430296005034693, ibnll=0.9571125897169593
Validation loss decreased (0.387305 --> 0.374410).  Saving model ...                                                   
it: 2, train loss=0.41960060596466064, validation loss=0.37441006302833557, c=0.6021963106915911, ibs=0.27105456340179407, ibnll=0.9338510640541148
Validation loss decreased (0.374410 --> 0.361035).  Saving model ...                                                   
it: 3, train loss=0.40586796402931213, validation loss=0.36103519797325134, c=0

it: 30, train loss=0.19045336544513702, validation loss=0.17680463194847107, c=0.6308532296872108, ibs=0.18422778455007277, ibnll=0.5408612971417154
Validation loss decreased (0.176805 --> 0.175138).  Saving model ...                                                   
it: 31, train loss=0.18817032873630524, validation loss=0.17513830959796906, c=0.6303288296625331, ibs=0.18256203359222423, ibnll=0.5365129869037455
Validation loss decreased (0.175138 --> 0.173683).  Saving model ...                                                   
it: 32, train loss=0.18612416088581085, validation loss=0.1736830621957779, c=0.6298044296378555, ibs=0.18102446320845242, ibnll=0.5325968842076642
Validation loss decreased (0.173683 --> 0.172434).  Saving model ...                                                   
it: 33, train loss=0.18430913984775543, validation loss=0.17243418097496033, c=0.6240668764266766, ibs=0.17961078299848363, ibnll=0.5290823258516967
Validation loss decreased (0.172434 --> 0.171

Validation loss decreased (0.162746 --> 0.162468).  Saving model ...                                                   
it: 61, train loss=0.170539990067482, validation loss=0.16246813535690308, c=0.6110802640508359, ibs=0.16728499509017702, ibnll=0.50067520782583
Validation loss decreased (0.162468 --> 0.162128).  Saving model ...                                                   
it: 62, train loss=0.17023339867591858, validation loss=0.16212797164916992, c=0.6135788759331235, ibs=0.16699392399023527, ibnll=0.5000208752420051
Validation loss decreased (0.162128 --> 0.161508).  Saving model ...                                                   
it: 63, train loss=0.16988718509674072, validation loss=0.16150790452957153, c=0.6126226170645938, ibs=0.16671177812903776, ibnll=0.49912727231688964
Validation loss decreased (0.161508 --> 0.160763).  Saving model ...                                                   
it: 64, train loss=0.169437438249588, validation loss=0.16076251864433289, c

it: 91, train loss=0.16110706329345703, validation loss=0.1525839865207672, c=0.6254858411993337, ibs=0.1618250214679401, ibnll=0.48165837992958016
Validation loss decreased (0.152584 --> 0.152295).  Saving model ...                                                   
it: 92, train loss=0.16002731025218964, validation loss=0.15229454636573792, c=0.62474551175273, ibs=0.16196401004573946, ibnll=0.4814552568918095
EarlyStopping counter: 1 out of 20                                                                                     
it: 93, train loss=0.16016116738319397, validation loss=0.15381596982479095, c=0.6282929236843728, ibs=0.1614447836039798, ibnll=0.4819950715039992
EarlyStopping counter: 2 out of 20                                                                                     
it: 94, train loss=0.15968750417232513, validation loss=0.1536223292350769, c=0.6280153001418964, ibs=0.16132549463798768, ibnll=0.4815453871816267
Validation loss decreased (0.152295 --> 0.151935)

it: 121, train loss=0.15417367219924927, validation loss=0.14972524344921112, c=0.6277068295391449, ibs=0.1595242786618832, ibnll=0.4742511024992986
EarlyStopping counter: 2 out of 20                                                                                     
it: 122, train loss=0.1540268510580063, validation loss=0.14912784099578857, c=0.6282003825035474, ibs=0.15953965774124437, ibnll=0.47374092574184623
Validation loss decreased (0.149099 --> 0.148885).  Saving model ...                                                   
it: 123, train loss=0.15377511084079742, validation loss=0.14888527989387512, c=0.6282312295638226, ibs=0.1595425126165235, ibnll=0.4735164909614056
EarlyStopping counter: 1 out of 20                                                                                     
it: 124, train loss=0.15363933145999908, validation loss=0.1494952142238617, c=0.6282929236843728, ibs=0.1594289276465864, ibnll=0.4738722307427349
Validation loss decreased (0.148885 --> 0.14

it: 151, train loss=0.14918789267539978, validation loss=0.14782275259494781, c=0.6225553704731939, ibs=0.1586143842031684, ibnll=0.4711341765510447
Validation loss decreased (0.146882 --> 0.146602).  Saving model ...                                                   
it: 152, train loss=0.1492001712322235, validation loss=0.14660190045833588, c=0.6246838176321796, ibs=0.15894757534690626, ibnll=0.47071134268931225
EarlyStopping counter: 1 out of 20                                                                                     
it: 153, train loss=0.1491970270872116, validation loss=0.14798963069915771, c=0.6220001233882411, ibs=0.15857005908651442, ibnll=0.47114315329778467
EarlyStopping counter: 2 out of 20                                                                                     
it: 154, train loss=0.14903947710990906, validation loss=0.14667101204395294, c=0.6240051823061262, ibs=0.15875647833098777, ibnll=0.47039928278947296
EarlyStopping counter: 3 out of 20      

job exception: CUDA out of memory. Tried to allocate 862.00 MiB (GPU 0; 8.00 GiB total capacity; 5.17 GiB already allocated; 342.27 MiB free; 6.11 GiB reserved in total by PyTorch)



  8%|███                                   | 8/100 [16:35:00<190:42:37, 7462.58s/trial, best loss: 0.13281087577342987]


RuntimeError: CUDA out of memory. Tried to allocate 862.00 MiB (GPU 0; 8.00 GiB total capacity; 5.17 GiB already allocated; 342.27 MiB free; 6.11 GiB reserved in total by PyTorch)

In [None]:
# Report the outcome of the hyperparameter search: first the raw `best` object,
# then the same point mapped back through the search space by `space_eval`
# (presumably hyperopt's decoder for hp.choice indices etc. — confirm against the search cell).
# NOTE(review): the search cell above aborted with "CUDA out of memory" at trial 8/100,
# so `best` was not bound by that run. On a fresh kernel this cell raises NameError;
# in a reused kernel it may print a stale value. Re-run the search successfully first.
print(best)
decoded_best = space_eval(space, best)
print(decoded_best)