In [1]:
import numpy as np
import pandas as pd
from src.latin_square import LatinSquare
from ngboost import NGBRegressor

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error


from scipy.special import softmax

In [2]:
def utility(y):
    """Map sampled outcomes to their utility.

    The utility is currently the identity: the argument is handed back
    unchanged (same object, no copy), so it works for scalars and arrays alike.
    """
    outcome = y
    return outcome

In [3]:
# Seed NumPy's global RNG before anything stochastic runs. The sampler cells
# below draw through np.random.uniform / np.random.choice, so without a seed a
# Restart & Run All does not reproduce the recorded numbers.
# NOTE(review): library code that falls back to the global NumPy RNG is
# presumably covered as well — confirm for LatinSquare and NGBRegressor.
SEED = 0
np.random.seed(SEED)

# Problem instance; its .X / .y attributes are the training data for the model.
co_problem = LatinSquare()

# Surrogate model fitted on the problem's data; later cells query it through
# ngb.pred_dist(...) to obtain predictive distributions.
ngb = NGBRegressor()
ngb.fit(co_problem.X, co_problem.y)

[iter 0] loss=-1.4863 val_loss=0.0000 scale=1.0000 norm=0.4762
[iter 100] loss=-1.8591 val_loss=0.0000 scale=1.0000 norm=0.2911
[iter 200] loss=-2.1151 val_loss=0.0000 scale=1.0000 norm=0.2852
[iter 300] loss=-2.3637 val_loss=0.0000 scale=1.0000 norm=0.2876
[iter 400] loss=-2.6138 val_loss=0.0000 scale=1.0000 norm=0.2797


In [40]:
## Update z
def update_z(idx, z_init, samples=None):
    """Resample position ``idx`` of ``z_init`` conditional on the current y-samples.

    Every candidate returned by ``co_problem.generate_candidates_idx(z_init, idx)``
    is scored by the summed log-density of the samples under the model's
    predictive distribution for that candidate. One candidate is then drawn
    with probability equal to the softmax of those scores.

    Parameters
    ----------
    idx
        Position forwarded to ``co_problem.generate_candidates_idx``.
    z_init
        Current configuration forwarded to ``co_problem.generate_candidates_idx``.
    samples : array-like, optional
        Outcome samples to condition on. When omitted, the notebook-level
        ``y_sample`` is used, which is what two-argument calls rely on.

    Returns
    -------
    The selected row of the candidate array.
    """
    if samples is None:
        # Backward-compatible fallback to the notebook-global sample array.
        samples = y_sample

    Z_candidates = co_problem.generate_candidates_idx(z_init, idx)
    Z_candidates_d = co_problem.dummify(Z_candidates)
    preds = ngb.pred_dist(Z_candidates_d)

    n_candidates = Z_candidates.shape[0]
    # One score per candidate: total log-density of the conditioning samples.
    energies = np.array(
        [np.sum(preds[k].dist.logpdf(samples)) for k in range(n_candidates)],
        dtype=float,
    )

    # Softmax of summed log-densities = normalised product of the densities.
    p = softmax(energies)
    candidate_idx = np.random.choice(n_candidates, p=p)
    return Z_candidates[candidate_idx]

In [41]:
def update_y(z_init, y_sample, value):
    """Element-wise accept/reject refresh of the y-samples for ``z_init``.

    Draws ``len(y_sample)`` proposals from the first predictive distribution
    that ``ngb.pred_dist`` returns for the dummified ``z_init``. Proposal ``i``
    replaces ``y_sample[i]`` when a uniform draw falls below
    ``utility(proposal_i) / value[i]``.

    Parameters
    ----------
    z_init
        Configuration handed to ``co_problem.dummify``.
    y_sample
        Current samples; overwritten in place at the accepted positions.
    value
        Utilities the proposal utilities are divided by.

    Returns
    -------
    tuple
        ``(y_sample, utility(y_sample))`` after the replacement.
    """
    features = co_problem.dummify(z_init)
    predictive = ngb.pred_dist(features)

    n_draws = len(y_sample)
    proposal = predictive[0].sample(n_draws)
    proposal_value = utility(proposal)

    # Each proposal is judged independently against its own uniform draw.
    thresholds = np.random.uniform(size=len(proposal_value))
    accepted = thresholds < proposal_value / value

    # In-place write: the caller's array is modified, not copied.
    y_sample[accepted] = proposal[accepted]

    refreshed_value = utility(y_sample)
    return y_sample, refreshed_value

In [42]:
def update_y_alt(z_init, y_sample, value):
    """Propose a whole fresh batch of y-samples and accept or reject it as a unit.

    A batch of ``len(y_sample)`` values is drawn from the first predictive
    distribution that ``ngb.pred_dist`` returns for the dummified ``z_init``.
    Its score is the mean log-utility of the batch. The batch replaces the
    current one when a uniform draw is below
    ``exp(H * new_score - H * value)``, with ``H`` the batch size.

    Parameters
    ----------
    z_init
        Configuration handed to ``co_problem.dummify``.
    y_sample
        Current batch of samples (returned untouched on rejection).
    value
        Mean log-utility score of the current batch.

    Returns
    -------
    tuple
        ``(batch, score)``: the proposal and its score on acceptance,
        otherwise the inputs ``(y_sample, value)``.
    """
    H = len(y_sample)
    features = co_problem.dummify(z_init)
    predictive = ngb.pred_dist(features)

    proposal = predictive[0].sample(H)
    proposal_value = np.log(utility(proposal)).sum() / H

    # Guard clause: keep the current batch unless the uniform draw is below
    # the acceptance ratio.
    draw = np.random.uniform()
    if not (draw < np.exp(H*proposal_value - H*value)):
        return y_sample, value

    print("Accept!")
    return proposal, proposal_value
    

In [52]:
## Init
# Sample-size schedule: starts at 1000 and grows by `step` per sweep.
step = 100
sch = np.arange(1000, 100000, step)

# Starting configuration and the model's predictive distribution for it.
z_init = co_problem.generate_candidate()
print(z_init)
z_init_d = co_problem.dummify(z_init)
preds = ngb.pred_dist(z_init_d)

# Initial batch of outcomes and its mean log-utility.
y_sample = preds[0].sample(sch[0])
value = np.log(utility(y_sample)).sum() / sch[0]

# One sweep per schedule entry: refresh the y batch, grow it, then resample
# every position of z in turn.
for i in sch:
    y_sample, value = update_y_alt(z_init, y_sample, value)

    # Grow the batch by `step` values resampled from the batch itself.
    y_sample = np.concatenate([y_sample, np.random.choice(y_sample, step)])
    value = np.log(utility(y_sample)).sum() / len(y_sample)

    for idx in range(len(z_init[0])):
        # update_z returns a single row; restore the (1, n) layout.
        z_init = update_z(idx, z_init).reshape(1, -1)

    print(y_sample.mean())

[[4. 0. 4. 4. 4. 4. 4. 2. 0. 0. 1. 2. 4. 3. 2. 3. 3. 0. 3. 4. 3. 0. 4. 3.
  2.]]
0.5625346695451642
Accept!
0.5636543062674817
Accept!
0.5647784553151827
0.5648714467616948
0.5648785530448528
Accept!
0.5649461657136345
Accept!
0.5649062111570319
0.5650186654971403
Accept!
0.5644517785361627
Accept!
0.5643298556478084
Accept!
0.5643175955812518
Accept!
0.5639610182145097
0.5638485620012058
0.5638914153112536
Accept!
0.5640657935671385
0.5640780141960253
Accept!
0.5640513240838682
Accept!
0.564332945198597
Accept!
0.564407636671347
Accept!
0.56411167648269
Accept!
0.564465710845459
Accept!
0.564475041532097
Accept!
0.564788096155187
0.5647748263268352
Accept!
0.5649949886821883
0.5650239338498794
0.5650467451167023
0.5650499630065451
Accept!
0.5652155405694649
0.5652040221996332
Accept!
0.565449598728811
0.5654465814900221
0.5654710961185843
0.5654719080796887
0.5654892691440578
0.5654874770293191
0.5654545518334454
Accept!
0.5656593554605793
Accept!
0.5657725128560063
0.5657969047357324

KeyboardInterrupt: 

In [56]:
## Init
# Start from a single sample and add `step` resampled values per sweep.
step = 10
sch = np.arange(1, 100000)

# Starting configuration and the model's predictive distribution for it.
z_init = co_problem.generate_candidate()
print(z_init)
z_init_d = co_problem.dummify(z_init)
preds = ngb.pred_dist(z_init_d)

# Initial samples and their per-sample utilities.
y_sample = preds[0].sample(sch[0])
value = utility(y_sample)

# One sweep per schedule entry: element-wise y refresh, grow the sample set,
# then resample every position of z in turn.
for i in sch:
    y_sample, value = update_y(z_init, y_sample, value)

    # Grow the sample set by `step` values resampled from itself.
    y_sample = np.concatenate([y_sample, np.random.choice(y_sample, step)])
    value = utility(y_sample)

    for idx in range(len(z_init[0])):
        # update_z returns a single row; restore the (1, n) layout.
        z_init = update_z(idx, z_init).reshape(1, -1)

    print(y_sample.mean())
    
        

[[0. 1. 4. 1. 3. 3. 0. 2. 0. 2. 2. 0. 0. 4. 2. 2. 4. 3. 4. 0. 3. 3. 1. 2.
  1.]]
0.5783608164473517
0.5762659621576682
0.5764533929851441
0.5747380435638588
0.572181478821821
0.5716926951940425
0.5733131788823579
0.5733192507593247
0.5716410186011726
0.569156884364039
0.5729988339509832
0.5723042469327577
0.574207017718909
0.5748999889843109
0.574577373609394
0.5770140332099688
0.5785322871087908
0.5823095609794624
0.583381834754458
0.585097018424002
0.5864204799927848
0.5874493608446345
0.5884575786777524
0.5902292260148851
0.5905000509073596
0.5898892566440433
0.5895423251862882
0.5911002778924302
0.5908443011612725
0.5928201240468864
0.5892521394006927
0.5908586359881501
0.5889958503594961
0.5908329149074409
0.5898024248187911
0.5882549332744454
0.5908570841754449
0.5912153788148153
0.5914217519921415
0.5917288512371217
0.591870651719188
0.5928998379908725
0.5942800044671579
0.5947040545392426
0.5936350924154188
0.5941161924546786
0.5941780513470493
0.5962269913253035
0.598366451500

0.6310571541217477
0.6305794886484783
0.6308702912932344
0.6318410369495061
0.6310074543278531
0.6316981386340975
0.6307760232798065
0.6319033829145657
0.6316559398463364
0.6312617991708604
0.6317599505872931
0.6310135000265384
0.6311027316605965
0.6311899522639415
0.6317320957861101
0.6313458945569719
0.6312615974677162
0.6309929947487999
0.6311297290803003
0.6312646919878442
0.6315340180739076
0.6311558267815386
0.6310834676977903
0.6312879013530286
0.6312666564881882
0.6313201301481726
0.6311739456514349
0.6317041596853653
0.6311222549977595
0.6312305109112127
0.6311177203118973
0.6310526185475264
0.6314498024173157
0.6309633410895172
0.631256444999829
0.6318173694050758
0.6309321790744628
0.6318424662279077
0.6311503843768586
0.6306751341143768
0.6313430623987862
0.6308474146008726
0.6316546990967773
0.631370421023673
0.6315453544422994
0.6308082076386882
0.6306827338471104
0.6313285139267798
0.6311373467797577
0.6312082336650473
0.6315114051629559
0.6314737827785877
0.631244472923

0.6310284281649908
0.6308515605334315
0.6308624204055718
0.6309978216697272
0.6312533806965042
0.6309007159935965
0.6311837252219877
0.6312828188361046
0.631216032999801
0.6312932536133384
0.6310506306560121
0.6310929637455479
0.6313700486561937
0.631029521049779
0.6308347217145974
0.6311673775705628
0.6309928619568337
0.6311157721575869
0.6307559180269287
0.6310898557466933
0.6310950667941457
0.6305391256749977
0.6311199633651893
0.6305564110255223
0.6308351748706416
0.6315373123975507
0.6313750624940914
0.6311082728473731
0.6312127780442124
0.6308940079778607
0.6309785210302651
0.6310686778204574
0.6310825649156121
0.6309976484284543
0.6312107140594042
0.6313426881667473
0.631299774636828
0.6309335636753978
0.6312584645700736
0.6308335253986406
0.6310594110343324
0.6310485928139926
0.6312374755784063
0.6314026055751571
0.6309828962092341
0.630901352955558
0.6313532406142397
0.6310489207785216
0.6309625625304334
0.6309611545115298
0.6311898564240486
0.630959850126092
0.631477761492202

0.631045702601391
0.6309660354925082
0.6309201788684646
0.6308745392116225
0.6306646082855171
0.6309729732623058
0.6311923453524196
0.6310936240942581
0.631057271027761
0.6309342251948719
0.6309615377668653
0.6312296759651737
0.6309363125197889
0.6309993670232215
0.6311861134546969
0.630656503283984
0.6310847312316845
0.6309878498396813
0.6310611598625945
0.6308729574249354
0.631095650533284
0.6309636476656609
0.6311365030859607
0.6311627253144038
0.6313656591500972
0.6309714987845245
0.630936413224178
0.6310020705127904
0.6312507799723482
0.6311974218515432
0.6312045390296619
0.6308701615124688


KeyboardInterrupt: 

array([0.56866911, 0.58154484, 0.56866911])

In [252]:
# Manual single y-update for inspection. update_y is defined in this notebook
# with three parameters (z_init, y_sample, value); the earlier two-argument
# call no longer matches that signature, so the current utilities are passed.
update_y(z_init, y_sample, value)

(array([0.56866911, 0.58154484]), array([0.56866911, 0.58154484]))

In [146]:
# Inspect the notebook-level sample array maintained by the sampler cells.
y_sample

array([0.63857364, 0.64941767])

In [145]:
# NOTE(review): in the visible cells `y_new` is only assigned inside
# update_y / update_y_alt, so this lookup presumably relies on leftover kernel
# state from an interactive session — confirm or remove.
y_new

array([0.61567806, 0.65347749])

In [149]:
# Ratio that update_y compares against its uniform draws.
# NOTE(review): `value_new` is only assigned inside update_y / update_y_alt in
# the visible cells; presumably leftover kernel state — confirm or remove.
value_new / value

array([0.96414575, 1.00625148])

In [147]:
# Inspect the notebook-level `value` set by the sampler cells.
value

array([0.63857364, 0.64941767])

In [108]:
# Manual single z-update at position 1; displays the candidate row that
# update_z selected.
update_z(1, z_init)

0.5832697384786996
0.5832697384786996
0.5773392478001587
0.5832697384786996
0.5836683050090883


array([2., 1., 4., 1., 3., 3., 2., 1., 0., 0., 1., 3., 3., 0., 0., 0., 0.,
       0., 1., 3., 3., 3., 2., 2., 1.])

0.5866557136878559
0.5865371990502762
0.5824550691277889
0.5870951964612465
0.5865371990502762


In [40]:
# NOTE(review): in the visible cells `Z_candidates` is only assigned inside
# update_z, so this lookup presumably relies on leftover kernel state —
# confirm or remove.
Z_candidates

array([[0., 0., 0., 0., 3., 4., 4., 2., 0., 1., 0., 3., 1., 2., 4., 3.,
        3., 4., 3., 2., 0., 3., 0., 4., 0.],
       [1., 0., 0., 0., 3., 4., 4., 2., 0., 1., 0., 3., 1., 2., 4., 3.,
        3., 4., 3., 2., 0., 3., 0., 4., 0.],
       [2., 0., 0., 0., 3., 4., 4., 2., 0., 1., 0., 3., 1., 2., 4., 3.,
        3., 4., 3., 2., 0., 3., 0., 4., 0.],
       [3., 0., 0., 0., 3., 4., 4., 2., 0., 1., 0., 3., 1., 2., 4., 3.,
        3., 4., 3., 2., 0., 3., 0., 4., 0.],
       [4., 0., 0., 0., 3., 4., 4., 2., 0., 1., 0., 3., 1., 2., 4., 3.,
        3., 4., 3., 2., 0., 3., 0., 4., 0.]])

In [31]:
# NOTE(review): in the visible cells `Z_candidates_d` is only assigned inside
# update_z, so this lookup presumably relies on leftover kernel state —
# confirm or remove.
Z_candidates_d

array([[1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0.,
        0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0., 0., 1., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 1., 0.,
   

In [19]:
# Inspect the notebook-level sample array maintained by the sampler cells.
y_sample

array([0.55978141, 0.57531186])

In [13]:
# Inspect the dummified form of the starting configuration, as produced by
# co_problem.dummify(z_init) in the sampler cells.
z_init_d

array([[0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
        0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 1.,
        0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0.,
        1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.]])