In [None]:
import copy
import os
import pickle as pkl
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data
import torchvision.datasets
import torchvision.models
import torchvision.transforms
from scipy.stats import entropy
from torchvision import datasets, transforms
from tqdm import tqdm, tnrange
# Runtime configuration: CUDA availability, storage root and compute device.
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0

# Pick the directory under which results are stored.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    base = 'drive/MyDrive'
except Exception:
    # Not running on Colab (or the mount failed): multi-GPU hosts write to the
    # working directory, everything else to a local "Google Drive" folder.
    # `Exception` keeps the best-effort fallback without swallowing
    # KeyboardInterrupt / SystemExit the way a bare `except:` does.
    if torch.cuda.device_count() > 1:
        base = '.'
    else:
        base = 'Google Drive'

if CUDA:
    # Prefer the second GPU on multi-GPU hosts; fall back to the only GPU
    # otherwise, since 'cuda:1' is an invalid ordinal on single-GPU machines.
    device = 'cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0'
else:
    device = 'cpu'
torch.cuda.is_available()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


False

In [None]:
class RBM(nn.Module):
    """Restricted Boltzmann machine with a visible and a hidden bias vector
    and one coupling matrix.

    Construction reads two module-level names: ``device`` (where the
    parameters are placed) and ``std`` (scale of the random initial weights).
    """

    def __init__(self, n_vis, n_hid):
        """Create an RBM with ``n_vis`` visible and ``n_hid`` hidden units.

        Biases start at zero with shapes ``(1, n_vis)`` and ``(1, n_hid)``;
        the ``(n_hid, n_vis)`` weight matrix is ``std`` times standard-normal
        noise.
        """
        super().__init__()
        # Registration order (v_bias, h_bias, Weight) fixes the state_dict order.
        visible_bias = torch.zeros(1, n_vis).to(device)
        hidden_bias = torch.zeros(1, n_hid).to(device)
        coupling = std * torch.randn(n_hid, n_vis).to(device)
        self.v_bias = nn.Parameter(visible_bias)
        self.h_bias = nn.Parameter(hidden_bias)
        self.Weight = nn.Parameter(coupling)

    def v2h(self, v):
        """Sample binary hidden units from sigmoid(v W^T + h_bias)."""
        activation = F.linear(v, self.Weight, self.h_bias)
        return torch.bernoulli(torch.sigmoid(activation))

    def h2v(self, h):
        """Sample binary visible units from sigmoid(h W + v_bias)."""
        activation = F.linear(h, self.Weight.t(), self.v_bias)
        return torch.bernoulli(torch.sigmoid(activation))

    def Fv(self, v):
        """Return the free energy of every row of ``v`` as a 1-D tensor.

        F(v) = -v . v_bias - sum_j softplus((v W^T + h_bias)_j)
        """
        visible_part = v.matmul(self.v_bias.t()).view(len(v))
        hidden_input = F.linear(v, self.Weight, self.h_bias)
        hidden_part = F.softplus(hidden_input).sum(dim=1)
        return -hidden_part - visible_part

    def energy(self, v, h):
        """Bernoulli-sample ``v`` and ``h`` and return the energy expression
        -v v_bias^T - (v W^T) h^T - h h_bias^T.

        The inputs are treated as probabilities and binarised first, so the
        result is stochastic.  The three terms are combined by broadcasting;
        for 2-D inputs the middle term pairs every row of ``v`` with every row
        of ``h``.
        """
        v_sample = v.bernoulli()
        h_sample = h.bernoulli()
        visible_term = torch.matmul(v_sample, self.v_bias.t())
        coupling_term = torch.matmul(torch.matmul(v_sample, self.Weight.t()), h_sample.t())
        hidden_term = torch.matmul(h_sample, self.h_bias.t())
        return -visible_term - coupling_term - hidden_term

    def forward(self, v):
        """Return a Bernoulli sample drawn with probabilities ``v``."""
        return torch.bernoulli(v)
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Input-only dataset: wraps an indexable collection of samples.

    No labels are stored; each item is returned as a ``torch.FloatTensor``.
    """

    def __init__(self, dataset):
        """Keep a reference to ``dataset`` as ``self.x_data``."""
        self.x_data = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return the sample at position ``idx`` converted to a FloatTensor."""
        return torch.FloatTensor(self.x_data[idx])

def data_to_loader(fullconfigs):
    """Wrap ``fullconfigs`` in a ``CustomDataset`` and return a DataLoader.

    The batch size is taken from the module-level name ``batch_size``, which
    must be defined before this function is called.
    """
    return torch.utils.data.DataLoader(CustomDataset(fullconfigs), batch_size)

def train_and_get_data(n_hid, model, lr):
    """Train one RBM by exact gradient descent and return periodic snapshots.

    The loss is evaluated over the full configuration list, not sampled:
    ``loss = sum_v Pv(v) * F(v) + log Z`` with ``log Z = logsumexp(-F)``.
    Subtracting the entropy ``S`` of ``Pv`` gives the KL divergence between
    ``Pv`` and the RBM distribution, which is the value recorded as "GE".

    Reads these module-level names: ``n_vis``, ``n_epochs``,
    ``epoch_to_save``, ``full_configs1`` (all visible configurations),
    ``Pv`` (target probabilities, same order) and ``S`` (entropy of ``Pv``).

    Args:
        n_hid: number of hidden units of the RBM to train.
        model: unused; kept so existing callers keep working.
        lr: learning rate for SGD (momentum 0.9).

    Returns:
        ``(model_list, train_loss_list)``: CPU copies of the state dict and
        the GE value as a float, one entry per epoch in ``epoch_to_save``.
        Each recorded GE is the pre-update loss of that epoch, while the
        snapshot holds the parameters after that epoch's update.
    """
    rbm = RBM(n_vis, n_hid)
    # train_op = optim.Adam(rbm.parameters(), lr)
    train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
    rbm.train()

    # Evaluate on the device the parameters live on; the notebook builds the
    # configuration tensor and Pv on the CPU.
    param_device = rbm.Weight.device
    configs = full_configs1.to(param_device)
    target = Pv.to(param_device)

    train_loss_list = []
    model_list = []
    for epoch in tnrange(n_epochs):
        free_energies = rbm.Fv(configs)
        Fv = torch.dot(free_energies.double(), target)
        # logsumexp avoids the overflow of a literal log(sum(exp(-F))).
        FE = -torch.logsumexp(-free_energies, dim=0)
        train_loss = Fv - FE
        train_op.zero_grad()
        train_loss.backward()
        train_op.step()
        if epoch in epoch_to_save:
            GE = train_loss - S
            train_loss_list.append(float(GE.detach().cpu().numpy()))
            # Copy the weights to the CPU without calling rbm.cpu(): that would
            # move the live model while the optimizer's momentum buffers stay
            # on the original device.
            snapshot = copy.deepcopy(rbm.state_dict())
            for name in list(snapshot):
                snapshot[name] = snapshot[name].cpu()
            model_list.append(snapshot)
            print('epoch={epoch}, GE={GE}'.format(epoch=epoch, GE=GE))
    return model_list, train_loss_list

def CM_model(models):
    """Average the final parameters of independently trained models.

    Args:
        models: mapping with keys ``'0'``, ``'1'``, ... ``str(len(models)-1)``.
            Each value is a sequence of state dicts; only the last entry of
            each sequence is used.

    Returns:
        A dict with keys ``'v_bias'``, ``'h_bias'`` and ``'Weight'`` holding
        the arithmetic mean over all models.

    Raises:
        ValueError: if ``models`` is empty.
    """
    n_models = len(models)
    if n_models == 0:
        raise ValueError("CM_model needs at least one model to average")

    new_v_bias = 0
    new_h_bias = 0
    new_Weight = 0
    # Keys are looked up in numeric order, so the summation order does not
    # depend on how the dict was filled.
    for m in range(n_models):
        final_state = models[str(m)][-1]
        new_v_bias += final_state['v_bias'] / n_models
        new_h_bias += final_state['h_bias'] / n_models
        new_Weight += final_state['Weight'] / n_models
    return {'v_bias': new_v_bias, 'h_bias': new_h_bias, 'Weight': new_Weight}

def mean_Fv(Q_m, v0):
    """Return the free energy of ``v0`` averaged over a set of stored RBMs.

    Args:
        Q_m: mapping with keys ``'0'`` ... ``str(len(Q_m)-1)``; each value is
            an RBM state dict (``v_bias``, ``h_bias``, ``Weight``).
        v0: batch of visible configurations passed to ``RBM.Fv``.

    Returns:
        The mean of ``RBM.Fv(v0)`` over all models in ``Q_m``.

    Raises:
        ValueError: if ``Q_m`` is empty.
    """
    n_models = len(Q_m)
    if n_models == 0:
        raise ValueError("mean_Fv needs at least one model")

    value = 0
    for m in range(n_models):
        state = Q_m[str(m)]
        # Size the RBM from the stored weights instead of relying on the
        # globals n_vis / n_hid matching the checkpoint.  RBM.__init__ takes
        # exactly (n_vis, n_hid) and has no `k` argument.
        hidden_units, visible_units = state['Weight'].shape
        rbm = RBM(visible_units, hidden_units)
        rbm.load_state_dict(state)
        value += rbm.Fv(v0) / n_models
    return value


def get_next_states(state0):
    """List the configurations reachable from ``state0`` in one move.

    ``state0`` is a string of '0'/'1' characters.  Positions are counted from
    the right end: position 1 is ``state0[-1]`` (where particles enter) and
    position ``len(state0)`` is ``state0[0]`` (where particles leave).  A
    particle moves one position to the left if that position is empty.

    Each reachable state gets a factor:
      * the module-level ``alpha`` when the move inserts a particle at the
        right end,
      * the module-level ``beta`` when the move removes the particle at the
        left end,
      * ``1 / (number of interior hops)`` otherwise.

    Args:
        state0: non-empty configuration string.

    Returns:
        ``(next_states, factor_next_states)``: the list of reachable
        configuration strings (hops first, ordered by position from the right;
        inflow last) and a dict mapping each of them to its factor.
    """
    n_sites = len(state0)
    occupied = [pos for pos in range(1, n_sites + 1) if state0[-pos] == '1']
    # A particle can move when the position to its left is free; the particle
    # on the leftmost site always qualifies (it leaves the lattice).
    movable = [pos for pos in occupied if pos + 1 not in occupied]

    use_alpha = False
    use_beta = False
    next_states = []
    for pos in movable:
        successor = list(state0)
        successor[-pos] = '0'
        if pos < n_sites:
            successor[-pos - 1] = '1'
        else:
            # Outflow of a particle: there is no site to hop onto.
            use_beta = True
        next_states.append(''.join(successor))

    # Inflow of a new particle.
    if state0[-1] == '0':
        use_alpha = True
        successor = list(state0)
        successor[-1] = '1'
        next_states.append(''.join(successor))

    # Number of moves that are neither inflow nor outflow.
    n_bulk = len(next_states) - use_beta - use_alpha

    factor_next_states = {}
    for successor in next_states:
        if n_bulk != 0:
            factor_next_states[str(successor)] = 1 / n_bulk
        # Boundary moves override the interior factor.
        if int(state0[-1]) - int(successor[-1]) == -1:
            factor_next_states[str(successor)] = alpha
        if int(state0[0]) - int(successor[0]) == 1:
            factor_next_states[str(successor)] = beta
    return next_states, factor_next_states

def check_boundary(state0):
    """Report which boundary moves are possible for ``state0``.

    Returns ``(use_alpha, use_beta)``: ``use_alpha`` is True when the last
    site is empty ('0'), ``use_beta`` is True when the first site is
    occupied ('1').
    """
    use_beta = True if state0[0] == '1' else False
    use_alpha = True if state0[-1] == '0' else False
    return use_alpha, use_beta

def dec2bin(integer, n_hid):
    """Return ``integer`` as a binary string left-padded with zeros.

    Args:
        integer: non-negative number; it is converted with ``int()`` first.
        n_hid: minimum width of the result.  Wider values are not truncated.

    Returns:
        A ``str`` of '0'/'1' characters with at least ``n_hid`` characters.

    Raises:
        ValueError: if ``integer`` is negative.
    """
    value = int(integer)
    if value < 0:
        raise ValueError("dec2bin expects a non-negative integer, got {}".format(integer))
    return format(value, 'b').zfill(n_hid)

def arr2str(array0):
    """Concatenate the elements of ``array0`` into one string.

    The input is passed through NumPy first, so each element is rendered with
    the ``str()`` of its NumPy scalar; a plain string comes back unchanged.
    """
    elements = np.array([array0])[0]
    return ''.join(map(str, elements))


def str2arr(str0):
    """Convert a string of digits into a NumPy array with one int per character."""
    digits = []
    for position in range(len(str0)):
        digits.append(int(str0[position]))
    return np.array(digits)




def get_ab(config0):
    """Return ``(a, b)``: how many 0s and how many 1s ``config0`` contains.

    ``config0`` must provide a ``count`` method (e.g. a list).
    """
    zeros, ones = (config0.count(symbol) for symbol in (0, 1))
    return zeros, ones


def seperate(config_list):
    """Split the first configuration containing an adjacent ``0, 1`` pair.

    The first entry of ``config_list`` whose string form contains '01' is
    replaced, in place, by two shorter configurations: one with the 0 of its
    first ``0, 1`` pair removed and one with the 1 removed.  Both are appended
    to the end of the list.  If no entry contains '01' the list is left
    untouched.

    Args:
        config_list: list of configurations (lists of 0/1), modified in place.

    Returns:
        None.
    """
    for candidate in config_list:
        if '01' in arr2str(candidate):
            config0 = candidate
            break
    else:
        # Nothing left to split.
        return

    for i in range(len(config0) - 1):
        if config0[i] == 0 and config0[i + 1] == 1:
            config_list.append(del_comp(config0, i))
            config_list.append(del_comp(config0, i + 1))
            config_list.remove(config0)
            break
def check(config_list):
    """Return True if any configuration still contains an adjacent ``0, 1`` pair.

    A configuration counts when its string form (via ``arr2str``) contains
    '01'.  An empty ``config_list`` yields False, so a loop driven by this
    function stops instead of spinning on an empty list.
    """
    return any('01' in arr2str(config) for config in config_list)

def del_comp(list0, i):
    """Return a copy of ``list0`` with the element at index ``i`` removed.

    ``list0`` itself is not modified.
    """
    shortened = list0.copy()
    del shortened[i]
    return shortened

def opS(list0):
    """Expand ``list0`` until no configuration contains an adjacent ``0, 1`` pair.

    Starting from ``[list0]``, ``seperate`` is applied for as long as
    ``check`` reports a remaining '01'; the resulting list of configurations
    is returned.
    """
    pending = [list0]
    while True:
        if not check(pending):
            return pending
        seperate(pending)

def get_ideal_dist():
    """Build the normalised weight of every binary configuration of length ``Nv``.

    For each configuration, ``opS`` expands it into a list of reduced
    configurations; each of those contributes
    ``(1/alpha)**a * (1/beta)**b`` where ``a`` and ``b`` are its counts of
    0s and 1s (``get_ab``).  The summed weights are divided by their total.

    Reads the module-level names ``Nv``, ``alpha`` and ``beta``.

    Returns:
        dict mapping each zero-padded binary string (in increasing numeric
        order) to its normalised weight.
    """
    labels = [dec2bin(index, Nv) for index in range(2**Nv)]

    weights = {}
    for label in labels:
        weight = 0
        for reduced in opS(list(str2arr(label))):
            a, b = get_ab(reduced)
            weight += ((1/alpha)**a)*((1/beta)**b)
        weights[label] = weight

    norm = np.sum(list(weights.values()))
    return {label: weight / norm for label, weight in weights.items()}
    

In [None]:
def TASEP_dynamics(state0, n_sample):
    """Simulate ``n_sample`` random update steps and return the visited states.

    Every step draws ``r0`` uniformly from ``0..N`` and attempts one move:
      * ``r0 == 0``: if the last site is empty, fill it with probability
        ``alpha``;
      * ``r0 == N``: if the first site is occupied, empty it with probability
        ``beta``;
      * otherwise: if the site ``r0`` positions from the right end is occupied
        and its left neighbour is empty, the particle hops one site left.
    The configuration after each step (changed or not) is recorded.

    Reads the module-level names ``N``, ``alpha`` and ``beta``.

    Args:
        state0: initial configuration as a '0'/'1' string (a sequence of 0/1
            integers is also accepted).
        n_sample: number of steps, i.e. number of recorded configurations.

    Returns:
        Float tensor with one row per step and one column per site.
    """
    # Work on one integer array instead of converting str <-> array each step.
    sites = str2arr(arr2str(state0))

    snapshots = []
    for _ in range(n_sample):
        r0 = np.random.randint(0, N + 1)
        if r0 == 0:
            # Short-circuit keeps the random draw conditional on an empty site.
            if sites[-1] == 0 and np.random.rand() < alpha:
                sites[-1] = 1
        elif r0 == N:
            if sites[0] == 1 and np.random.rand() < beta:
                sites[0] = 0
        elif sites[-r0] == 1 and sites[-r0 - 1] == 0:
            sites[-r0] = 0
            sites[-r0 - 1] = 1
        snapshots.append(torch.tensor(sites.copy()))

    return torch.stack(snapshots).float()

In [None]:
# Set the hyperparameters.
# NOTE: n_vis, lr and std are assigned again in the following cell before any
# training runs, so the lr and std values given here are superseded.
# NOTE(review): `k` is not read by any code visible in this section - confirm
# whether it is still needed.
n_vis=9
k=1
lr=0.1
std=0.5


In [None]:
# Experiment driver: for every (alpha, beta) pair and hidden-layer size, train
# ten RBMs against the exact target distribution and pickle the snapshots.
lr=1
n_vis=9
N=n_vis
Nv=n_vis
std=0.1

epoch_to_save=[2**i for i in range(17)]
n_epochs=epoch_to_save[-1]+1
torch.set_printoptions(precision=10)

# Create the output folders before training so the run cannot fail at save
# time after all the work is done.
os.makedirs('{base}/TASEP/exact_state_dict'.format(base=base), exist_ok=True)
os.makedirs('{base}/TASEP/exact_loss'.format(base=base), exist_ok=True)

# All 2**N visible configurations, as {0,1} (full_configs1) and {-1,+1}
# (full_configs2).  They do not depend on alpha/beta, so build them once.
full_configs1=[]
full_configs2=[]
for i in range(2**N):
    full_configs1.append(str2arr(dec2bin(i,N)))
    full_configs2.append(str2arr(dec2bin(i,N))*2-1)
# Going through one ndarray avoids the slow list-of-arrays tensor construction.
full_configs1=torch.tensor(np.array(full_configs1)).float()
full_configs2=torch.tensor(np.array(full_configs2)).float()

for alpha, beta in [[0.2,0.6],[0.3,0.5],[0.4,0.4]]:
    prob_dict=get_ideal_dist()

    # Renormalise defensively; `total` replaces the former name `sum`, which
    # shadowed the builtin for the rest of the session.
    total=np.sum(list(prob_dict.values()))
    for i in range(2**N):
        config0=dec2bin(i, N)
        prob_dict[config0]=prob_dict[config0]/total

    # Target distribution and its entropy are the same for every n_hid.
    Pv=torch.tensor(list(prob_dict.values()))
    S=entropy(Pv)

    for n_hid in [1,2,3,4,6,8,12]:
        Nh=n_hid
        # batch_size=int(vol/2)
        dict_model={}
        dict_GE={}
        train_loader_list=[]; val_loader_list=[]

        # Ten independent runs per setting, keyed '0'..'9'.
        for m in range(10):
            model0, loss=train_and_get_data(n_hid, 0, lr=lr)
            dict_model[str(m)]=model0
            dict_GE[str(m)]=loss
        with open('{base}/TASEP/exact_state_dict/model_Nh={Nh}_alpha={alpha}_beta={beta}.pkl'.format(base=base, Nh=Nh, alpha=alpha, beta=beta), 'wb') as f:
            pkl.dump(dict_model, f)
        with open('{base}/TASEP/exact_loss/GE_Nh={Nh}_alpha={alpha}_beta={beta}.pkl'.format(base=base, Nh=Nh, alpha=alpha, beta=beta), 'wb') as f:
            pkl.dump(dict_GE, f)