In [None]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")
import copy

import pickle as pkl
import pandas as pd
from scipy.stats import entropy
# --- Compute device -------------------------------------------------------
# The original cell hard-coded 'cuda:1', which does not exist on a
# single-GPU machine. Keep 'cuda:1' on multi-GPU hosts (previous behavior)
# and fall back to CUDA_DEVICE otherwise.
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0
if CUDA:
    if torch.cuda.device_count() > 1:
        device = 'cuda:1'
    else:
        device = 'cuda:{}'.format(CUDA_DEVICE)
else:
    device = 'cpu'

# --- Data root ------------------------------------------------------------
# On Colab, mount Google Drive; elsewhere fall back to a local path.
# `except Exception` (not a bare except) so Ctrl-C / SystemExit still work.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    base = 'drive/MyDrive'
except Exception:
    if torch.cuda.device_count() > 1:
        base = '.'
    else:
        base = 'Google Drive'

torch.cuda.is_available()


Mounted at /content/drive


False

In [None]:
class RBM(nn.Module):
    """Bernoulli-Bernoulli restricted Boltzmann machine.

    Learnable parameters (placed on the module-level ``device``):
      * ``v_bias``: visible bias, shape (1, n_vis), zero-initialised;
      * ``h_bias``: hidden bias, shape (1, n_hid), zero-initialised;
      * ``Weight``: couplings, shape (n_hid, n_vis), initialised as
        ``std * randn`` using the module-level ``std``.
    """

    def __init__(self, n_vis, n_hid):
        """Create an RBM with ``n_vis`` visible and ``n_hid`` hidden units."""
        super().__init__()
        self.v_bias = nn.Parameter(torch.zeros(1, n_vis).to(device))
        self.h_bias = nn.Parameter(torch.zeros(1, n_hid).to(device))
        # Draw on the default (CPU) generator first, then move to `device`.
        gaussian = torch.randn(n_hid, n_vis).to(device)
        self.Weight = nn.Parameter(std * gaussian)

    def v2h(self, v):
        """Sample binary hidden units conditioned on visible units ``v``."""
        p_h = torch.sigmoid(F.linear(v, self.Weight, self.h_bias))
        return torch.bernoulli(p_h)

    def h2v(self, h):
        """Sample binary visible units conditioned on hidden units ``h``."""
        p_v = torch.sigmoid(F.linear(h, self.Weight.t(), self.v_bias))
        return torch.bernoulli(p_v)

    def Fv(self, v):
        """Free energy per row of ``v``.

        F(v) = -v . v_bias - sum_j softplus(W v + h_bias)_j, returned as a
        1-D tensor of length ``len(v)``.
        """
        hidden = F.softplus(F.linear(v, self.Weight, self.h_bias)).sum(dim=1)
        visible = (v @ self.v_bias.t()).view(len(v))
        return -hidden - visible

    def energy(self, v, h):
        """Joint energy evaluated on Bernoulli samples of ``v`` and ``h``.

        NOTE(review): ``v`` and ``h`` are resampled with ``bernoulli()``
        before the energy is taken, and the coupling term is the full product
        (v W^T) h^T between every row of ``v`` and every row of ``h`` rather
        than a per-sample pairing. Confirm this is intended.
        """
        v_s = torch.bernoulli(v)
        h_s = torch.bernoulli(h)
        visible = v_s @ self.v_bias.t()
        coupling = (v_s @ self.Weight.t()) @ h_s.t()
        hidden = h_s @ self.h_bias.t()
        return -visible - coupling - hidden

    def forward(self, v):
        """Return a Bernoulli sample using the entries of ``v`` as probabilities."""
        return torch.bernoulli(v)
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over a sequence of unlabeled samples.

    Each item is converted to a float tensor with ``torch.FloatTensor``.
    """

    def __init__(self, dataset):
        # Only inputs are stored; this dataset carries no targets.
        self.x_data = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a float tensor."""
        sample = self.x_data[idx]
        return torch.FloatTensor(sample)

def data_to_loader(fullconfigs):
    """Wrap ``fullconfigs`` in a DataLoader.

    Uses the module-level ``batch_size`` (read at call time) and the
    DataLoader defaults otherwise (so batches are not shuffled).
    """
    return torch.utils.data.DataLoader(CustomDataset(fullconfigs), batch_size)


def train_and_get_data(n_hid, model, lr, train_loader):
    """Train a fresh RBM by exact-gradient maximum likelihood and track its GE.

    The negative phase is computed exactly over every visible configuration
    in the module-level ``full_configs1``: with Q(v) = exp(-F(v)) / Z, the
    loss ``mean_batch F(v) - sum_v Q(v) F(v)`` (Q held constant) has the
    gradient of the negative log-likelihood.

    At every epoch listed in ``epoch_to_save`` a CPU copy of the state_dict
    is stored and the generalization error
        GE = sum_v P(v) F(v) - FE - S,  with FE = -log Z,
    i.e. KL(P || Q), is computed against the module-level reference
    distribution ``Pv`` with entropy ``S``.

    Args:
        n_hid: number of hidden units.
        model: unused; kept for call-site compatibility.
        lr: SGD learning rate (momentum 0.9).
        train_loader: iterable of batches of visible configurations.

    Also reads module globals: n_vis, n_epochs, epoch_to_save,
    full_configs1, Pv, S, device.

    Returns:
        (model_list, val_loss_list): the saved state_dicts and the matching
        GE values as detached 0-d CPU tensors.
    """
    rbm = RBM(n_vis, n_hid)
    train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
    rbm.train()

    # Move the enumeration/reference tensors once. They used to stay on the
    # CPU, which fails as soon as the RBM lives on a GPU.
    all_configs = full_configs1.to(device)
    target_probs = Pv.to(device).double()

    val_loss_list = []
    model_list = []
    for epoch in tnrange(n_epochs):
        for data in train_loader:
            v = data.view(-1, n_vis).to(device)
            F_all = rbm.Fv(all_configs)
            # softmax(-F) == exp(-F) / sum(exp(-F)) but cannot overflow to
            # inf/inf = NaN. Detached so Q acts as a constant weight.
            Qv = torch.softmax(-F_all, dim=0).detach()
            Fv_Q = torch.dot(F_all, Qv)

            train_loss = torch.mean(rbm.Fv(v)) - Fv_Q
            train_op.zero_grad()
            train_loss.backward()
            train_op.step()

        if epoch in epoch_to_save:
            # Snapshot on the CPU without moving the live model back and forth.
            snapshot = copy.deepcopy(rbm.state_dict())
            for key in snapshot:
                snapshot[key] = snapshot[key].cpu()
            model_list.append(snapshot)

            with torch.no_grad():
                F_all = rbm.Fv(all_configs).double()
                # -log sum exp(-F) via logsumexp: finite even when exp underflows.
                FE = -torch.logsumexp(-F_all, dim=0)
                Fv_P = torch.dot(F_all, target_probs)
                GE = (Fv_P - FE - S).cpu()
            print('epoch={epoch}, GE={GE}'.format(epoch=epoch, GE=GE))
            val_loss_list.append(GE)
    return model_list, val_loss_list

def CM_model(models, n_models=10):
    """Average the final snapshots of an ensemble of RBMs parameter-wise.

    Args:
        models: mapping ``str(m) -> sequence of state_dicts`` (one sequence
            per ensemble member). Only the last entry of each sequence is used.
        n_models: number of members, keys ``'0'`` .. ``str(n_models - 1)``.
            Defaults to 10, the previously hard-coded ensemble size.

    Returns:
        dict with ``'v_bias'``, ``'h_bias'`` and ``'Weight'`` set to the
        equal-weight mean of the members' final parameters.
    """
    new_v_bias = 0
    new_h_bias = 0
    new_Weight = 0
    for m in range(n_models):
        last = models[str(m)][-1]
        new_v_bias += last['v_bias'] / n_models
        new_h_bias += last['h_bias'] / n_models
        new_Weight += last['Weight'] / n_models
    return {'v_bias': new_v_bias, 'h_bias': new_h_bias, 'Weight': new_Weight}

def mean_Fv(Q_m, v0, n_models=10):
    """Average the RBM free energy of ``v0`` over an ensemble of saved models.

    The previous version constructed ``RBM(n_vis, n_hid, k=1)``. ``RBM``
    accepts no ``k`` argument, so it always raised TypeError.

    Args:
        Q_m: mapping ``str(m) -> RBM state_dict`` for m in range(n_models).
        v0: visible configurations passed to ``RBM.Fv``. They are moved to
            the module-level ``device`` so they match the model parameters.
        n_models: ensemble size (default 10, the previous hard-coded value).

    Also reads the module globals ``n_vis`` and ``n_hid``.

    Returns:
        The equal-weight mean of ``rbm.Fv(v0)`` across the ensemble.
    """
    v0 = v0.to(device)
    value = 0
    for m in range(n_models):
        rbm = RBM(n_vis, n_hid)
        rbm.load_state_dict(Q_m[str(m)])
        value += rbm.Fv(v0) / n_models
    return value


def get_next_states(state0):
    """Enumerate configurations reachable from ``state0`` in one TASEP move.

    ``state0`` is a string of '0'/'1' characters. Positions are counted from
    the right end (position 1 is ``state0[-1]``).
      * A particle at position i hops to position i+1 (one character to the
        left) when that site is empty.
      * A particle at the leftmost site leaves the lattice (outflow).
      * A particle enters at the rightmost site when it is empty (inflow).

    Returns:
        next_states: list of successor strings.
        factor_next_states: dict mapping each successor string to a weight.
            The weight is ``alpha`` for the inflow and ``beta`` for the
            outflow. Bulk hops get ``1/(len(next_states)-use_beta-use_alpha)``,
            i.e. 1 / (number of bulk hops). ``alpha`` and ``beta`` are read
            from module-level globals; ``str2arr`` is a module helper.
    """
    occupied_spot=[]
    use_alpha=False
    use_beta=False

    # Occupied positions, counted from the right end (1 == last character).
    for i in range(1,len(state0)+1):
        if state0[-i]=='1':
            occupied_spot.append(i)
    # print(occupied_spot)
    # A particle is movable when the next position to its left is not
    # occupied (or lies beyond the lattice, which becomes an outflow below).
    movable_spot=[]
    movable_count=0  # NOTE(review): counted but never used.
    for j in range(len(occupied_spot)):
        if occupied_spot[j]+1 not in occupied_spot:
            movable_spot.append(occupied_spot[j])
            movable_count+=1
    # print(movable_count)

    next_states=[]
    for k in range(len(movable_spot)):
        state1=state0
        state1=list(state1)
        state1[-movable_spot[k]]='0'
        try:
            state1[-movable_spot[k]-1]='1'
            state1=''.join(state1[e] for e in range(len(state1)))
        except:
            # An outflow of a particle: it sat at the leftmost site, so the
            # index above is out of range (IndexError) and the particle is
            # simply removed. NOTE(review): the bare except would also hide
            # any other error.
            use_beta=True
        # Joining again is a no-op when the try branch already produced a string.
        state1=''.join(state1[e] for e in range(len(state1)))
        next_states.append(state1)
    # An inflow of a new particle at the rightmost site, if it is empty.
    if state0[-1]=='0':
        use_alpha=True
        state1=state0
        state1=list(state1)
        state1[-1]='1'
        state1=''.join(state1[e] for e in range(len(state1)))
        next_states.append(state1)
    use_alpha_beta=[use_alpha, use_beta]  # NOTE(review): built but never used.

    factor_next_states={}
    for l in range(len(next_states)):
        # Default weight for bulk hops: the bools count 1 per boundary move
        # present, so the denominator is the number of bulk hops.
        try:
            factor_next_states[str(next_states[l])]=1/(len(next_states)-use_beta-use_alpha)
            # print('aa')
            # print(1/(len(next_states-use_alpha-use_beta)))
        except:
            # ZeroDivisionError when every successor is a boundary move;
            # those entries are filled by the alpha/beta branches below.
            None

        # Last site went 0 -> 1: inflow, weight alpha.
        if int(list(str2arr(state0)-str2arr(next_states[l]))[-1])==-1:
            factor_next_states[str(next_states[l])]=alpha
        # First site went 1 -> 0: outflow, weight beta.
        if int(list(str2arr(state0)-str2arr(next_states[l]))[0])==1:
            factor_next_states[str(next_states[l])]=beta
    return next_states, factor_next_states

def check_boundary(state0):
    """Report which boundary moves are possible for configuration ``state0``.

    Returns ``(use_alpha, use_beta)``:
      * ``use_alpha`` is True when the last site is empty ('0'), so a particle can enter;
      * ``use_beta`` is True when the first site is occupied ('1'), so a particle can leave.
    """
    use_beta = bool(state0[0] == '1')
    use_alpha = bool(state0[-1] == '0')
    return use_alpha, use_beta

def dec2bin(integer, n_hid):
    """Binary representation of ``integer``, left-padded with '0'.

    ``n_hid`` is simply the minimum width of the result (the name is
    historical); longer representations are not truncated.
    """
    digits = [int(ch) for ch in bin(int(integer))[2:]]
    return ''.join(map(str, digits)).rjust(n_hid, '0')

def arr2str(array0):
    """Concatenate the elements of ``array0`` into a string.

    ``np.array([array0])[0]`` is used (rather than ``np.asarray``) so that a
    plain string input is returned unchanged as well as lists/arrays of digits.
    """
    items = np.array([array0])[0]
    return ''.join(map(str, items))


def str2arr(str0):
    """Convert a digit string such as '0110' into a numpy integer array."""
    return np.array(list(map(int, str0)))

def get_ab(config0):
    """Return ``(number of 0 entries, number of 1 entries)`` in ``config0``."""
    return config0.count(0), config0.count(1)


def opD(list0, N=0, factor=1.0):
    """Strip a trailing ``0`` from ``list0`` (in place) and update the counters.

    The previous version assigned to ``N`` and ``factor`` without defining
    them locally. It therefore raised UnboundLocalError whenever the last
    entry was 0, after already deleting it, and the updates were never
    visible to the caller. The counters are now explicit arguments and the
    updated values are returned.

    NOTE(review): the intended meaning of ``N`` (remaining length?) and
    ``factor`` (accumulated weight?) is not established by this notebook;
    confirm before relying on them. Reads the module-level ``alpha``.

    Returns:
        (N, factor): ``N - 1`` and ``factor / alpha`` if a trailing 0 was
        removed, otherwise the inputs unchanged. An empty list is a no-op.
    """
    if len(list0) > 0 and list0[-1] == 0:
        del list0[-1]
        N -= 1
        factor = factor / alpha
    return N, factor

def opE(list0, N=0, factor=1.0):
    """Strip a leading ``1`` from ``list0`` (in place) and update the counters.

    The previous version assigned to ``N`` and ``factor`` without defining
    them locally. It therefore raised UnboundLocalError whenever the first
    entry was 1, after already deleting it, and the updates were never
    visible to the caller. The counters are now explicit arguments and the
    updated values are returned.

    NOTE(review): the intended meaning of ``N`` and ``factor`` is not
    established by this notebook; confirm before relying on them. Reads the
    module-level ``beta``.

    Returns:
        (N, factor): ``N - 1`` and ``factor / beta`` if a leading 1 was
        removed, otherwise the inputs unchanged. An empty list is a no-op.
    """
    if len(list0) > 0 and list0[0] == 1:
        del list0[0]
        N -= 1
        factor = factor / beta
    return N, factor






def seperate(config_list):
    """Apply one '01' reduction step to ``config_list`` in place.

    The first configuration whose digit string contains an adjacent 0-then-1
    pair is removed from the list. Two shorter copies are appended in its
    place: one without that 0 and one without that 1. If no configuration
    contains '01', the list is left unchanged.

    The previous version signalled "nothing to reduce" through a NameError
    that a bare ``except:`` swallowed, which also hid any unrelated error.

    Returns None; ``config_list`` is mutated.
    """
    target = next((config for config in config_list if '01' in arr2str(config)), None)
    if target is None:
        return
    for i in range(len(target) - 1):
        if target[i] == 0 and target[i + 1] == 1:
            config_list.append(del_comp(target, i))
            config_list.append(del_comp(target, i + 1))
            config_list.remove(target)
            break
def check(config_list):
    """Return True if any configuration still contains an adjacent '01' pair.

    Previously an empty list returned the truthy sentinel string ``'01'``,
    which would make ``opS``-style ``while check(...)`` loops spin forever.
    It now returns False.
    """
    return any('01' in arr2str(config) for config in config_list)

def del_comp(list0, i):
    """Return a copy of ``list0`` with the element at index ``i`` removed.

    The input is left untouched.
    """
    trimmed = list0.copy()
    del trimmed[i]
    return trimmed

def opS(list0):
    """Fully reduce the word ``list0`` by repeatedly splitting '01' pairs.

    Starts from ``[list0]`` and applies ``seperate`` (one in-place reduction
    step) while ``check`` reports a remaining '01' pair.

    Returns the resulting list of reduced words.
    """
    words = [list0]
    while check(words):
        seperate(words)
    return words

def get_ideal_dist():
    """Normalized weight of every ``Nv``-site configuration.

    For each integer 0 .. 2**Nv - 1, its ``Nv``-bit word is reduced with
    ``opS``. Each reduced word contributes ``(1/alpha)**(#0) * (1/beta)**(#1)``.
    The summed weights are normalized to 1. Presumably this is the
    matrix-product stationary distribution of the open TASEP (TODO confirm).
    Reads the module-level ``Nv``, ``alpha`` and ``beta``.

    Returns:
        dict mapping each configuration string to its probability, in
        increasing order of the underlying integer.
    """
    weights = {}
    for index in range(2**Nv):
        key = dec2bin(index, Nv)
        terms = opS(list(str2arr(key)))
        weight = 0
        for term in terms:
            n0, n1 = get_ab(term)
            weight += ((1/alpha)**n0)*((1/beta)**n1)
        weights[key] = weight
    total = sum(weights.values())
    return {key: w / total for key, w in weights.items()}
    

In [None]:
def TASEP_dynamics(state0, n_sample, interval, p_in=None, p_out=None):
    """Simulate a random-sequential open TASEP and return sampled states.

    The lattice is a string of '0'/'1' characters. Particles enter at the
    rightmost site, hop one site to the left, and leave from the leftmost
    site. Each step draws r0 uniformly from {0, ..., L} with L = len(state0):
      * r0 == 0: if the last site is empty, fill it with probability ``p_in``;
      * r0 == L: if the first site is occupied, empty it with probability ``p_out``;
      * otherwise: the particle at position r0 (counted from the right) hops
        one site to the left if that site is empty.
    A snapshot is taken after every step whose index is a multiple of
    ``interval``, giving ``n_sample`` snapshots.

    Compared with the previous version:
      * the lattice length comes from ``state0`` instead of the global ``N``;
      * the state is kept as a list of chars instead of converting
        string <-> array every step;
      * the unused occupancy histogram was removed;
      * an empty run returns an empty tensor instead of failing in torch.stack.
    The random-number sequence consumed is unchanged.

    Args:
        state0: initial configuration string.
        n_sample: number of snapshots to return.
        interval: number of update steps between snapshots.
        p_in, p_out: inflow/outflow probabilities. They default to the
            module-level ``alpha`` and ``beta``.

    Returns:
        float32 tensor of shape (n_sample, L) with 0/1 entries.
    """
    if p_in is None:
        p_in = alpha
    if p_out is None:
        p_out = beta

    n_sites = len(state0)
    sites = list(state0)
    snapshots = []
    for step in range(n_sample * interval):
        r0 = np.random.randint(0, n_sites + 1)
        if r0 == 0:
            # Short-circuit keeps the RNG draw conditional on an empty site.
            if sites[-1] == '0' and np.random.rand() < p_in:
                sites[-1] = '1'
        elif r0 == n_sites:
            if sites[0] == '1' and np.random.rand() < p_out:
                sites[0] = '0'
        elif sites[-r0] == '1' and sites[-r0 - 1] == '0':
            sites[-r0] = '0'
            sites[-r0 - 1] = '1'
        if step % interval == 0:
            snapshots.append([int(c) for c in sites])

    if not snapshots:
        return torch.empty((0, n_sites))
    return torch.tensor(np.array(snapshots)).float()

In [None]:
import os

# Generate held-out TASEP samples for each (alpha, beta) and pickle them.
vol = 4096
Nv = 9
N = Nv
# Initial lattice state drawn uniformly from all 2**Nv configurations.
# (Previously randint(0, Nv), which only reached the states 0 .. Nv-1.)
seed = dec2bin(np.random.randint(0, 2**Nv), Nv)

os.makedirs('{base}/TASEP'.format(base=base), exist_ok=True)
for alpha in [0.3, 0.5, 0.7]:
    for beta in [0.7]:
        # TASEP_dynamics reads alpha/beta from these module-level loop variables.
        lat = TASEP_dynamics(seed, vol, 10*Nv)
        with open('{base}/TASEP/test_alpha={alpha}_beta={beta}.pkl'.format(base=base, alpha=alpha, beta=beta), 'wb') as f:
            pkl.dump(lat, f)


In [None]:
import os

# Hyper-parameters. Several of these names are read as module globals by
# RBM, train_and_get_data, data_to_loader, get_ideal_dist and
# TASEP_dynamics, so the names must stay as they are.
lr = 0.1
n_vis = 9
N = n_vis
Nv = n_vis
std = 0.1  # standard deviation of the RBM weight initialisation

# Snapshot / evaluate at epochs 1, 2, 4, ..., 2**16.
epoch_to_save = [2**i for i in range(17)]
n_epochs = epoch_to_save[-1] + 1
torch.set_printoptions(precision=10)

N_MODELS = 10       # independently sampled datasets / RBMs per setting
MAX_ATTEMPTS = 100  # NaN retries per model before giving up (was unbounded)

# Create the output folders before hours of training, not after.
os.makedirs('{base}/TASEP/state_dict'.format(base=base), exist_ok=True)
os.makedirs('{base}/TASEP/loss'.format(base=base), exist_ok=True)

for alpha in [0.3, 0.5, 0.7]:
    for beta in [0.7]:
        # Exact target distribution; get_ideal_dist reads alpha/beta as globals.
        prob_dict = get_ideal_dist()
        # All 2**N configurations in {0,1} (full_configs1) and {-1,+1}
        # (full_configs2) encodings, ordered like prob_dict.
        all_bits = np.array([str2arr(dec2bin(i, N)) for i in range(2**N)])
        full_configs1 = torch.tensor(all_bits).float()
        full_configs2 = torch.tensor(all_bits*2 - 1).float()
        for n_hid in [1]:
            Nh = n_hid
            Pv = torch.tensor(list(prob_dict.values()))
            S = entropy(Pv)
            for vol in [4096]:
                batch_size = int(vol/2)
                dict_model = {}
                dict_GE = {}
                train_loader_list = []; val_loader_list = []
                for m in range(N_MODELS):
                    # Random initial lattice state over all 2**N configurations
                    # (previously randint(0, N), i.e. only states 0 .. N-1).
                    seed = dec2bin(np.random.randint(0, 2**N), N)
                    configs1 = TASEP_dynamics(seed, vol, 10*Nv)
                    train_loader_list.append(data_to_loader(configs1))
                m = 0
                attempts = 0
                while m < N_MODELS:
                    model0, loss = train_and_get_data(n_hid, 0, lr=lr, train_loader=train_loader_list[m])
                    if np.isnan(float(loss[-1])):
                        # Diverged: retrain on the same data with a fresh random init,
                        # but do not loop forever.
                        attempts += 1
                        if attempts >= MAX_ATTEMPTS:
                            raise RuntimeError('model {m}: GE was NaN in {n} consecutive attempts'.format(m=m, n=attempts))
                        continue
                    dict_model[str(m)] = model0
                    dict_GE[str(m)] = loss
                    attempts = 0
                    m = m + 1
                    print(m)

            with open('{base}/TASEP/state_dict/model_Nh={Nh}_alpha={alpha}_beta={beta}_vol={vol}.pkl'.format(base=base, Nh=Nh, alpha=alpha, beta=beta, vol=vol), 'wb') as f:
                pkl.dump(dict_model, f)
            with open('{base}/TASEP/loss/GE_Nh={Nh}_alpha={alpha}_beta={beta}_vol={vol}.pkl'.format(base=base, Nh=Nh, alpha=alpha, beta=beta, vol=vol), 'wb') as f:
                pkl.dump(dict_GE, f)


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.41851506523982795
epoch=2, GE=0.22351656106388695
epoch=4, GE=0.016742587861767966
epoch=8, GE=0.0626593293928508
epoch=16, GE=0.0029160885171535256
epoch=32, GE=0.002065972864997434
epoch=64, GE=0.001860344813197301
epoch=128, GE=0.0016868002399217374
epoch=256, GE=0.0015614143731328411
epoch=512, GE=0.0015421653348264286
epoch=1024, GE=0.001816407622744265
epoch=2048, GE=0.002090371961831927
epoch=4096, GE=0.00217790573504395
epoch=8192, GE=0.0022333459218844
epoch=16384, GE=0.002270337328928562
epoch=32768, GE=0.002297680584382178
epoch=65536, GE=0.0023171154644430203
1


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4235783432497495
epoch=2, GE=0.2330275421027883
epoch=4, GE=0.02165847268081933
epoch=8, GE=0.050779498556018154
epoch=16, GE=0.003824143808262903
epoch=32, GE=0.0015853541143284744
epoch=64, GE=0.002368733149614499
epoch=128, GE=0.0021686006386447687
epoch=256, GE=0.0020534281280468747
epoch=512, GE=0.0022502704851472544
epoch=1024, GE=0.0030031065039990423
epoch=2048, GE=0.003100326813341603
epoch=4096, GE=0.003170683881053904
epoch=8192, GE=0.0032084671281182864
epoch=16384, GE=0.0032237936658390254
epoch=32768, GE=0.0032365852248998905
epoch=65536, GE=0.0032594143763891026
2


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4120006396318967
epoch=2, GE=0.21746618972438192
epoch=4, GE=0.01360544000523678
epoch=8, GE=0.0723629070875953
epoch=16, GE=0.0010646525490436431
epoch=32, GE=0.0020084155876674004
epoch=64, GE=0.0009216740202280249
epoch=128, GE=0.0009149293524082225
epoch=256, GE=0.0009962755923842437
epoch=512, GE=0.0012697278092801412
epoch=1024, GE=0.0014654284663420114
epoch=2048, GE=0.0015284085859175178
epoch=4096, GE=0.0015731218756700471
epoch=8192, GE=0.0016220473448207429
epoch=16384, GE=0.0016801229408578422
epoch=32768, GE=0.0017492777286056693
epoch=65536, GE=0.0017924132605982024
3


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4139433729375588
epoch=2, GE=0.21766185008165984
epoch=4, GE=0.013760012694602253
epoch=8, GE=0.0713818627473577
epoch=16, GE=0.0010721110160956115
epoch=32, GE=0.0017840764967074207
epoch=64, GE=0.0007367764035768332
epoch=128, GE=0.0005857095691679248
epoch=256, GE=0.0004719726036785943
epoch=512, GE=0.0006728984598822763
epoch=1024, GE=0.001372212223820668
epoch=2048, GE=0.0015551190127069958
epoch=4096, GE=0.0016682443276998526
epoch=8192, GE=0.001744876699045328
epoch=16384, GE=0.0017973981874881417
epoch=32768, GE=0.0018412309982132058
epoch=65536, GE=0.001884107215936126
4


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4155471152697485
epoch=2, GE=0.22392987047154467
epoch=4, GE=0.01687281571977106
epoch=8, GE=0.06158228732362225
epoch=16, GE=0.0015967119306381505
epoch=32, GE=0.001221539768843094
epoch=64, GE=0.0010096039800284728
epoch=128, GE=0.0009193099430166995
epoch=256, GE=0.0009703956751181408
epoch=512, GE=0.0014182659329424752
epoch=1024, GE=0.0018246147158382797
epoch=2048, GE=0.0018592672719499603
epoch=4096, GE=0.0019647680913204013
epoch=8192, GE=0.002123098181091798
epoch=16384, GE=0.0022287889686172235
epoch=32768, GE=0.0022717003617565013
epoch=65536, GE=0.002289102683256772
5


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4222646586239627
epoch=2, GE=0.22074176827916858
epoch=4, GE=0.014282735223899223
epoch=8, GE=0.06911434393244065
epoch=16, GE=0.002040923126068961
epoch=32, GE=0.0021084716562675965
epoch=64, GE=0.0014314562928010588
epoch=128, GE=0.001318311579101561
epoch=256, GE=0.0012422335813919716
epoch=512, GE=0.0012819447728578126
epoch=1024, GE=0.0014676707472505157
epoch=2048, GE=0.001661540474713341
epoch=4096, GE=0.002181589434075093
epoch=8192, GE=0.0027090164036849274
epoch=16384, GE=0.0028762800566957125
epoch=32768, GE=0.002950117170195554
epoch=65536, GE=0.0029899546343532535
6


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.4184785421792645
epoch=2, GE=0.2245565813103978
epoch=4, GE=0.01734185180945591
epoch=8, GE=0.06035817306323388
epoch=16, GE=0.0021488085369849586
epoch=32, GE=0.001294542685984723
epoch=64, GE=0.0012026860829390529
epoch=128, GE=0.0010302869510887547
epoch=256, GE=0.001050379445769245
epoch=512, GE=0.0015553697667733957
epoch=1024, GE=0.0018438481983089616
epoch=2048, GE=0.002008208125838351
epoch=4096, GE=0.0021539281881102212
epoch=8192, GE=0.0022886368549697167
epoch=16384, GE=0.002558681836905663
epoch=32768, GE=0.0028086305191141747
epoch=65536, GE=0.0028882051302465683
7


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.405556846089584
epoch=2, GE=0.22142697717566762
epoch=4, GE=0.017641350303470027
epoch=8, GE=0.06649460085170489
epoch=16, GE=0.0025496734209617955
epoch=32, GE=0.002597891807473829
epoch=64, GE=0.0019120419252267595
epoch=128, GE=0.0019082720999978875
epoch=256, GE=0.0019969933935346518
epoch=512, GE=0.0023373897643494956
epoch=1024, GE=0.003308685258149424
epoch=2048, GE=0.003494554952794715
epoch=4096, GE=0.003543310761507712
epoch=8192, GE=0.0035917531567219996
epoch=16384, GE=0.0036228571561967016
epoch=32768, GE=0.0036405569212920597
epoch=65536, GE=0.0036506993262133136
8


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.42267627842739763
epoch=2, GE=0.2150163154633269
epoch=4, GE=0.011080524919655943
epoch=8, GE=0.07724300397892758
epoch=16, GE=0.0009171977005433263
epoch=32, GE=0.002148287492618195
epoch=64, GE=0.0008271351151494244
epoch=128, GE=0.0007790859181833198
epoch=256, GE=0.0009107481638945814
epoch=512, GE=0.001774300548547103
epoch=1024, GE=0.0024842615040547855
epoch=2048, GE=0.002655427984623593
epoch=4096, GE=0.002726053696384234
epoch=8192, GE=0.0027534412079317505
epoch=16384, GE=0.0027606136893920663
epoch=32768, GE=0.002763329527336822
epoch=65536, GE=0.002766535873054643
9


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.41192601245390215
epoch=2, GE=0.22857721257071262
epoch=4, GE=0.02077427175341473
epoch=8, GE=0.05362507913746217
epoch=16, GE=0.0023889020808036676
epoch=32, GE=0.001078474127594653
epoch=64, GE=0.001406747786464102
epoch=128, GE=0.0012367577830527665
epoch=256, GE=0.0012210150036846557
epoch=512, GE=0.0019655676591128213
epoch=1024, GE=0.0029505057759831033
epoch=2048, GE=0.0030258138202770013
epoch=4096, GE=0.003108884342227114
epoch=8192, GE=0.003244437711618531
epoch=16384, GE=0.00334799632579319
epoch=32768, GE=0.0033961422278938613
epoch=65536, GE=0.0034183327109476025
10


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.057501323988651265
epoch=2, GE=0.03460242414084114
epoch=4, GE=0.011042362301301623
epoch=8, GE=0.020225693376817233
epoch=16, GE=0.010498433356022296
epoch=32, GE=0.009761943226605396
epoch=64, GE=0.00940538403606439
epoch=128, GE=0.00917087744856282
epoch=256, GE=0.009250911135836759
epoch=512, GE=0.010053371288655022
epoch=1024, GE=0.010162384027220916
epoch=2048, GE=0.01022075694011626
epoch=4096, GE=0.010293020599196012
epoch=8192, GE=0.010386189403977397
epoch=16384, GE=0.010482764128919797
epoch=32768, GE=0.010558977510982004
epoch=65536, GE=0.010607003871414555
1


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06221168756612805
epoch=2, GE=0.033175901381157225
epoch=4, GE=0.01101724890536282
epoch=8, GE=0.029740917413324297
epoch=16, GE=0.01099847080022176
epoch=32, GE=0.011186885931559587
epoch=64, GE=0.0107292662917029
epoch=128, GE=0.010587113654719538
epoch=256, GE=0.010490766771971138
epoch=512, GE=0.011861252653798893
epoch=1024, GE=0.012112197844574624
epoch=2048, GE=0.012061308089744927
epoch=4096, GE=0.011559623332593638
epoch=8192, GE=0.009779886628511036
epoch=16384, GE=0.009576734520893027
epoch=32768, GE=0.009515044195214983
epoch=65536, GE=0.009486935493402626
2


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.05958415529122263
epoch=2, GE=0.03295890641343213
epoch=4, GE=0.009462952658460644
epoch=8, GE=0.024691739710243965
epoch=16, GE=0.010019443074739343
epoch=32, GE=0.009780797475518632
epoch=64, GE=0.009239819916324166
epoch=128, GE=0.009013329060829633
epoch=256, GE=0.0087438818389165
epoch=512, GE=0.009308167559093583
epoch=1024, GE=0.009455216746837358
epoch=2048, GE=0.009281685525048822
epoch=4096, GE=0.009052284799723864
epoch=8192, GE=0.008887803663702343
epoch=16384, GE=0.008784534440517966
epoch=32768, GE=0.008723306589451418
epoch=65536, GE=0.008687333042019851
3


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06361065092438611
epoch=2, GE=0.03603808620415716
epoch=4, GE=0.01043571684739586
epoch=8, GE=0.02192284052227933
epoch=16, GE=0.010917409093951669
epoch=32, GE=0.009839183514741556
epoch=64, GE=0.009615722354757494
epoch=128, GE=0.00951248757075085
epoch=256, GE=0.009371377440908546
epoch=512, GE=0.01009947963927349
epoch=1024, GE=0.010314515860465612
epoch=2048, GE=0.010198945473908871
epoch=4096, GE=0.010216384088697517
epoch=8192, GE=0.010295039141059448
epoch=16384, GE=0.010407219385196953
epoch=32768, GE=0.01049799586734057
epoch=65536, GE=0.010547383066814575
4


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06609140085445553
epoch=2, GE=0.03516961678994601
epoch=4, GE=0.009877803704412713
epoch=8, GE=0.027848399275113067
epoch=16, GE=0.01032558759585367
epoch=32, GE=0.0102025650042874
epoch=64, GE=0.009779523182960048
epoch=128, GE=0.009568695679599948
epoch=256, GE=0.009310150034060882
epoch=512, GE=0.009646379360010116
epoch=1024, GE=0.009751730422570049
epoch=2048, GE=0.00956387215945842
epoch=4096, GE=0.00941979228456269
epoch=8192, GE=0.009331468230547202
epoch=16384, GE=0.009289955840062092
epoch=32768, GE=0.009306839268511169
epoch=65536, GE=0.009663857066986026
5


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06696623459365103
epoch=2, GE=0.035186324862435114
epoch=4, GE=0.010066920988391459
epoch=8, GE=0.02634582141249009
epoch=16, GE=0.010877610315159991
epoch=32, GE=0.010008780965106823
epoch=64, GE=0.009575774058630415
epoch=128, GE=0.009384710342017044
epoch=256, GE=0.009218545139678014
epoch=512, GE=0.010571883376099933
epoch=1024, GE=0.01057148860105972
epoch=2048, GE=0.009007094949882166
epoch=4096, GE=0.008727310906313512
epoch=8192, GE=0.008655746091843497
epoch=16384, GE=0.00862990187874324
epoch=32768, GE=0.008617918038829053
epoch=65536, GE=0.008613041165922652
6


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.056311840324815954
epoch=2, GE=0.03309549460457806
epoch=4, GE=0.010571901093127778
epoch=8, GE=0.02066251937051966
epoch=16, GE=0.01046284762681271
epoch=32, GE=0.009694371630773801
epoch=64, GE=0.009382879712504177
epoch=128, GE=0.009241939231407592
epoch=256, GE=0.009066804629791037
epoch=512, GE=0.008781174652757961
epoch=1024, GE=0.008583729434632481
epoch=2048, GE=0.00825000939358489
epoch=4096, GE=0.008096744374745057
epoch=8192, GE=0.008014128518810892
epoch=16384, GE=0.00796880528243804
epoch=32768, GE=0.007943138045979126
epoch=65536, GE=0.007928652469563424
7


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06390141656462145
epoch=2, GE=0.036393452878111354
epoch=4, GE=0.009934357664370452
epoch=8, GE=0.019970629628481618
epoch=16, GE=0.010625306934328727
epoch=32, GE=0.009263308115476931
epoch=64, GE=0.009043523834540323
epoch=128, GE=0.008947895668240946
epoch=256, GE=0.00885084880836029
epoch=512, GE=0.00903882780881382
epoch=1024, GE=0.009355591969918642
epoch=2048, GE=0.009509496613177681
epoch=4096, GE=0.009273645746656456
epoch=8192, GE=0.009061005582482018
epoch=16384, GE=0.008961590186931545
epoch=32768, GE=0.008912767708956437
epoch=65536, GE=0.008887390449848098
8


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06180654457780044
epoch=2, GE=0.03388219467046927
epoch=4, GE=0.009126778485605946
epoch=8, GE=0.023652977375807183
epoch=16, GE=0.00961604657905113
epoch=32, GE=0.00914444109521817
epoch=64, GE=0.00869280466460598
epoch=128, GE=0.008410202570752645
epoch=256, GE=0.008424558374954039
epoch=512, GE=0.009811102139468453
epoch=1024, GE=0.009930805905291606
epoch=2048, GE=0.009584399670605315
epoch=4096, GE=0.00938270878144909
epoch=8192, GE=0.00924892133214339
epoch=16384, GE=0.009157598249337262
epoch=32768, GE=0.009094388769242734
epoch=65536, GE=0.009048380242315801
9


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06266852343541895
epoch=2, GE=0.035422126592077774
epoch=4, GE=0.010076760635006288
epoch=8, GE=0.021968072590055954
epoch=16, GE=0.010527640867036503
epoch=32, GE=0.009656335188538812
epoch=64, GE=0.009383646785706468
epoch=128, GE=0.009252346380216103
epoch=256, GE=0.009149498801423128
epoch=512, GE=0.009854893137374532
epoch=1024, GE=0.009902856271215121
epoch=2048, GE=0.009857602712966873
epoch=4096, GE=0.009674090600213958
epoch=8192, GE=0.008191318781010537
epoch=16384, GE=0.007971901783458968
epoch=32768, GE=0.007895126303608002
epoch=65536, GE=0.00785750738252311
10


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.07001147146549336
epoch=2, GE=0.049012776656438106
epoch=4, GE=0.029548584080655083
epoch=8, GE=0.040729181587447094
epoch=16, GE=0.03061545428391632
epoch=32, GE=0.0294358583843346
epoch=64, GE=0.029206970522246856
epoch=128, GE=0.028978685475968824
epoch=256, GE=0.028191400851337534
epoch=512, GE=0.028455362112308258
epoch=1024, GE=0.027944850819294764
epoch=2048, GE=0.025543393458121066
epoch=4096, GE=0.02481761877748312
epoch=8192, GE=0.024504613340582182
epoch=16384, GE=0.02434595789689986
epoch=32768, GE=0.024262581646679493
epoch=65536, GE=0.02421802713075305
1


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06382312810131108
epoch=2, GE=0.04559220140111009
epoch=4, GE=0.030195005658726792
epoch=8, GE=0.04009437473587063
epoch=16, GE=0.031115143481653895
epoch=32, GE=0.02985184252066375
epoch=64, GE=0.029795625148927307
epoch=128, GE=0.029528051934188504
epoch=256, GE=0.029414396761326955
epoch=512, GE=0.030550300000417252
epoch=1024, GE=0.029616562762030618
epoch=2048, GE=0.026968275392754926
epoch=4096, GE=0.026043459967051596
epoch=8192, GE=0.02565834196374972
epoch=16384, GE=0.025476373440440625
epoch=32768, GE=0.02538502013637789
epoch=65536, GE=0.02533788658174352
2


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06627153480119397
epoch=2, GE=0.04553417987477015
epoch=4, GE=0.029173602295606393
epoch=8, GE=0.04366973093717341
epoch=16, GE=0.029535335466424506
epoch=32, GE=0.02898197944657621
epoch=64, GE=0.02894324130932624
epoch=128, GE=0.028609619311454004
epoch=256, GE=0.02822221048250384
epoch=512, GE=0.028104400782934746
epoch=1024, GE=0.025946593409095442
epoch=2048, GE=0.02490788755300244
epoch=4096, GE=0.024442107454969886
epoch=8192, GE=0.02419041320563231
epoch=16384, GE=0.024050565420631997
epoch=32768, GE=0.023974848741502797
epoch=65536, GE=0.023933951687942212
3


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06752717548794962
epoch=2, GE=0.046861389399877496
epoch=4, GE=0.02884202694522653
epoch=8, GE=0.04008574604396742
epoch=16, GE=0.03008432318701626
epoch=32, GE=0.02881371084461115
epoch=64, GE=0.028630171756526934
epoch=128, GE=0.02841742306525674
epoch=256, GE=0.027723530132169216
epoch=512, GE=0.02819058207024039
epoch=1024, GE=0.025764514964853547
epoch=2048, GE=0.024464121778192904
epoch=4096, GE=0.023998458838985215
epoch=8192, GE=0.023779413535176097
epoch=16384, GE=0.023667513057542777
epoch=32768, GE=0.023608742672012184
epoch=65536, GE=0.02357797650187532
4


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.07112363974750302
epoch=2, GE=0.04906272926494015
epoch=4, GE=0.029141446485433598
epoch=8, GE=0.041400844805311365
epoch=16, GE=0.03020587436266009
epoch=32, GE=0.0290886376222792
epoch=64, GE=0.028925944098957324
epoch=128, GE=0.028847197217246823
epoch=256, GE=0.028409182959897805
epoch=512, GE=0.028131797582154228
epoch=1024, GE=0.025752162650967136
epoch=2048, GE=0.0245770223188293
epoch=4096, GE=0.024106375814609393
epoch=8192, GE=0.023877488925200474
epoch=16384, GE=0.023759183029479658
epoch=32768, GE=0.02369806193416668
epoch=65536, GE=0.023666252247610764
5


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06965630042061655
epoch=2, GE=0.046793849565688106
epoch=4, GE=0.02837563676696142
epoch=8, GE=0.043503191974907374
epoch=16, GE=0.029267365292820813
epoch=32, GE=0.02850865341777009
epoch=64, GE=0.028416593157237102
epoch=128, GE=0.028181907805238993
epoch=256, GE=0.028102366414577418
epoch=512, GE=0.029644115564028795
epoch=1024, GE=0.029164827589099218
epoch=2048, GE=0.02620308233411528
epoch=4096, GE=0.02500612690751769
epoch=8192, GE=0.024545911874345627
epoch=16384, GE=0.02433528919973149
epoch=32768, GE=0.024231566217627787
epoch=65536, GE=0.024177924765757197
6


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.06784911923855752
epoch=2, GE=0.047558251731823376
epoch=4, GE=0.02854795842405533
epoch=8, GE=0.03839190637200307
epoch=16, GE=0.02968942740477143
epoch=32, GE=0.028188739867723633
epoch=64, GE=0.027948843031262882
epoch=128, GE=0.0274139921615264
epoch=256, GE=0.027430655221584388
epoch=512, GE=0.027865374600538395
epoch=1024, GE=0.026551044953976977
epoch=2048, GE=0.02580141136615932
epoch=4096, GE=0.025569648241561715
epoch=8192, GE=0.0254701887822133
epoch=16384, GE=0.02542109357822575
epoch=32768, GE=0.02539582190874423
epoch=65536, GE=0.025382719670464127
7


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.07441015002314444
epoch=2, GE=0.052612199048716946
epoch=4, GE=0.029919918544818458
epoch=8, GE=0.036680709483193574
epoch=16, GE=0.031226876381927227
epoch=32, GE=0.02908672314172378
epoch=64, GE=0.028943601710864186
epoch=128, GE=0.028584336638096453
epoch=256, GE=0.02788537751924025
epoch=512, GE=0.02807429059702926
epoch=1024, GE=0.026940837479505042
epoch=2048, GE=0.02553605843115303
epoch=4096, GE=0.024769199651751173
epoch=8192, GE=0.02439082485702926
epoch=16384, GE=0.024201593077394
epoch=32768, GE=0.02410408866383129
epoch=65536, GE=0.02405317967687992
8


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.08039481216680411
epoch=2, GE=0.05352431144132552
epoch=4, GE=0.028582811519798135
epoch=8, GE=0.03964250176652229
epoch=16, GE=0.030477091081290375
epoch=32, GE=0.028372635961553705
epoch=64, GE=0.028322053987066553
epoch=128, GE=0.028195579463419662
epoch=256, GE=0.027333399614507314
epoch=512, GE=0.027019304591830107
epoch=1024, GE=0.025908049482742612
epoch=2048, GE=0.023898856104836774
epoch=4096, GE=0.023214013256822597
epoch=8192, GE=0.022910957362062945
epoch=16384, GE=0.022759101092543865
epoch=32768, GE=0.02267966928594145
epoch=65536, GE=0.02263782921180635
9


  0%|          | 0/65537 [00:00<?, ?it/s]

epoch=1, GE=0.08328119092841213
epoch=2, GE=0.05388651341671391
epoch=4, GE=0.028587179799118623
epoch=8, GE=0.04201398573232229
epoch=16, GE=0.030417615855393265
epoch=32, GE=0.028353115450422628
epoch=64, GE=0.02828901376131121
epoch=128, GE=0.02817512669573219
epoch=256, GE=0.02751380243617252
epoch=512, GE=0.027574174655200068
epoch=1024, GE=0.027881869616801502
epoch=2048, GE=0.025567079809603044
epoch=4096, GE=0.024412190240963305
epoch=8192, GE=0.024064713826969175
epoch=16384, GE=0.02391088454950019
epoch=32768, GE=0.023835431713158606
epoch=65536, GE=0.02379735245273551
10
