<a href="https://colab.research.google.com/github/physicaone/loss_IG/blob/master/%5B210815%5DPT_IG_loss.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [42]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")
import random
import pickle as pkl
import pandas as pd
from scipy.stats import entropy
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0

# Resolve the data root: Google Drive when running on Colab, otherwise a local path.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    base = 'drive/MyDrive'
except ImportError:
    # Not on Colab (google.colab is unavailable). With more than one visible GPU the
    # current directory is used; otherwise a local "Google Drive" folder.
    if torch.cuda.device_count() > 1:
        base = '.'
    else:
        base = 'Google Drive'

device = 'cuda' if CUDA else 'cpu'
CUDA

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


False

In [43]:
def decimal_to_binary(integer, n_hid):
    """Encode a non-negative integer as a (1, width) float tensor of binary digits.

    The most significant bit comes first. The digit string is left-padded with
    zeros to at least ``n_hid`` digits; it is never truncated.
    """
    digits = bin(int(integer))[2:].rjust(n_hid, '0')
    return torch.tensor([[float(ch) for ch in digits]])

def binary_to_decimal(list0):
    """Decode an MSB-first sequence of 0/1 values (anything with ``.tolist()``) to an int."""
    bits = list0.tolist()
    total = 0
    # Walk from the least significant (last) bit upwards.
    for power, bit in enumerate(reversed(bits)):
        total += bit * 2**power
    return int(total)

def Ising_energy(v_list):
    """Energy of each spin configuration under the periodic 2-D Ising model (coupling 1, no field).

    Parameters
    ----------
    v_list : sequence of 2-D arrays/tensors indexable as ``v[i, j]``
        Spin lattices (the driver passes values in {-1, +1}). The lattice size is
        read from each configuration's shape, so lattices other than 3x3 work too.

    Returns
    -------
    numpy.ndarray
        For each configuration, half of the sum over sites of ``-s * (sum of its
        four periodic neighbours)``.
    """
    E_list = []
    for v in v_list:
        rows, cols = v.shape[0], v.shape[1]
        E = 0
        for i in range(rows):
            for j in range(cols):
                s = v[i, j]
                # Four neighbours with periodic wrap-around.
                neigh = v[(i+1) % rows, j] + v[i, (j+1) % cols] + v[(i-1) % rows, j] + v[i, (j-1) % cols]
                E += -neigh * s
        # Every bond was visited from both of its ends, hence the halving.
        E_list.append(E / 2)
    return np.array(E_list)
    
class RBM(nn.Module):
    """Bernoulli-Bernoulli restricted Boltzmann machine with an inverse-temperature knob.

    Conditionals are sigmoids of ``beta`` times the pre-activations, which lets the
    same model be sampled at different inverse temperatures (used by PT/AIS).

    Parameters
    ----------
    n_vis : int
        Number of visible units.
    n_hid : int
        Number of hidden units.
    k : int
        Stored as ``self.k``; no method of this class reads it (presumably the
        CD-k step count of the training code — TODO confirm).
    """

    def __init__(self, n_vis, n_hid, k):
        """Create a RBM."""
        super(RBM, self).__init__()
        # Parameters are created directly on the module-level ``device``.
        self.v_bias = nn.Parameter(torch.ones(1, n_vis).to(device))
        self.h_bias = nn.Parameter(torch.zeros(1, n_hid).to(device))
        self.Weight = nn.Parameter(torch.randn(n_hid, n_vis).to(device))
        self.k = k

    def v2h(self, v, beta):
        """Return P(h = 1 | v) at inverse temperature ``beta`` (detached)."""
        return torch.sigmoid(F.linear(v, self.Weight, self.h_bias)*beta).detach()

    def h2v(self, h, beta):
        """Return P(v = 1 | h) at inverse temperature ``beta`` (detached)."""
        return torch.sigmoid(F.linear(h, self.Weight.t(), self.v_bias)*beta).detach()

    def Fv(self, v):
        """Free energy F(v) = -v.a - sum_j softplus((W v + b)_j) for a batch ``v`` of shape (B, n_vis).

        Returns a detached tensor of shape (B,).
        """
        v_term = torch.matmul(v, self.v_bias.t()).view(len(v)).detach()
        h_term = torch.sum(F.softplus(F.linear(v, self.Weight, self.h_bias)), dim=1).detach()
        return -h_term - v_term

    def energy(self, v, h):
        """Joint energy E(v, h) = -v.a - v W^T h - h.b, one value per row.

        ``v`` and ``h`` are first passed through ``bernoulli()``, so probability
        inputs are sampled while exact 0/1 inputs pass through unchanged.
        For 2-D inputs with B rows the result has shape (B, 1); for 1-D inputs
        it has shape (1,). The result is detached from autograd.
        """
        v = v.bernoulli().detach()
        h = h.bernoulli().detach()
        # Row-wise v_n W^T h_n (the old matmul against h.t() built a (B, B) matrix).
        interaction = (torch.matmul(v, self.Weight.t()) * h).sum(dim=-1, keepdim=True)
        # Parenthesised so detach() applies to the whole energy, not just the last term.
        return (-torch.matmul(v, self.v_bias.t()) - interaction - torch.matmul(h, self.h_bias.t())).detach()

    def Energy_GPU2(self, v_list0, h_list0):
        """Per-sample energies E(v_n, h_n) for a batch, spread over all GPUs when on CUDA.

        Parameters
        ----------
        v_list0, h_list0 : tensor or sequence of tensors, same length N
            Visible / hidden states; each element is reshaped to one row of
            ``n_vis`` / ``n_hid`` values (sizes taken from ``self.Weight``).

        Returns
        -------
        torch.Tensor
            Detached tensor of shape (N,). It is on ``cuda:0`` when the model is on
            CUDA, otherwise on the inputs' device.
        """
        n_hid_, n_vis_ = self.Weight.shape
        on_gpu = self.Weight.is_cuda
        n_split = torch.cuda.device_count() if on_gpu else 1
        m_split = 4
        v_all = torch.stack(list(v_list0)).detach().reshape(len(v_list0), n_vis_).float()
        h_all = torch.stack(list(h_list0)).detach().reshape(len(h_list0), n_hid_).float()
        e_list = []
        # m_split sequential chunks, each split across n_split devices. torch.chunk
        # keeps the remainder, so N need not be divisible by m_split * n_split.
        for v_outer, h_outer in zip(torch.chunk(v_all, m_split), torch.chunk(h_all, m_split)):
            parts = zip(torch.chunk(v_outer, n_split), torch.chunk(h_outer, n_split))
            for i, (v_part, h_part) in enumerate(parts):
                dev = torch.device('cuda', i) if on_gpu else v_part.device
                v_part = v_part.to(dev)
                h_part = h_part.to(dev)
                a = self.v_bias.to(dev).view(n_vis_)
                b = self.h_bias.to(dev).view(n_hid_)
                W = self.Weight.to(dev).view(n_hid_, n_vis_)
                # Row-wise interaction term; avoids materialising an N x N matrix for its diagonal.
                e = -torch.matmul(v_part, a) - (torch.matmul(v_part, W.t()) * h_part).sum(dim=1) - torch.matmul(h_part, b)
                if on_gpu:
                    e = e.to('cuda:0')
                e_list.append(e)
        return torch.cat(e_list).view(len(v_list0)).detach()
    
# Metropolis acceptance probability for a parallel-tempering (PT) replica exchange.
def swap_prob(i, j, model, list00, list11, betas=None):
    """Acceptance ratio for exchanging the (v, h) states of replicas ``i`` and ``j``.

    Returns ``exp((beta_j - beta_i) * (E(v_j, h_j) - E(v_i, h_i)))`` as a tensor,
    with E given by ``model.energy``. Callers accept the swap when this is >= a
    uniform draw, so values above 1 always accept.

    Parameters
    ----------
    i, j : int
        Replica indices.
    model : object with an ``energy(v, h)`` method (e.g. ``RBM``).
    list00, list11 : sequences of per-replica visible / hidden states.
    betas : sequence of inverse temperatures, optional
        Defaults to the module-level ``beta_list``.
    """
    if betas is None:
        betas = beta_list
    # Each state becomes a single-row batch. The sizes come from the states themselves,
    # and as_tensor avoids the copy-construct warning of torch.tensor(tensor).
    v1 = torch.as_tensor(list00[i]).reshape(1, -1)
    v2 = torch.as_tensor(list00[j]).reshape(1, -1)
    h1 = torch.as_tensor(list11[i]).reshape(1, -1)
    h2 = torch.as_tensor(list11[j]).reshape(1, -1)
    beta1 = betas[i]
    beta2 = betas[j]
    return torch.exp((beta2 - beta1) * (model.energy(v2, h2) - model.energy(v1, h1)))

# Performs one replica-exchange attempt based on the transition probability.
def swap(list0, list1, model):
    """Try to exchange the states of one random adjacent pair of replicas, in place.

    A lower index ``lo`` is drawn uniformly from ``0 .. len(list0) - 2``. The
    visible (``list0``) and hidden (``list1``) states of replicas ``lo`` and
    ``lo + 1`` are exchanged when ``swap_prob`` is >= a uniform random number.
    Returns None.
    """
    lo = np.random.randint(0, len(list0) - 1)
    hi = lo + 1
    accepted = swap_prob(lo, hi, model, list0, list1) >= np.random.rand()
    if not accepted:
        return
    list0[lo], list0[hi] = list0[hi].clone(), list0[lo].clone()
    list1[lo], list1[hi] = list1[hi].clone(), list1[lo].clone()

def P_h(list0):
    """Count occurrences of each element of ``list0``, keyed by ``str(element)``.

    Keys appear in first-seen order.
    """
    counts = {}
    for item in list0:
        key = str(item)
        counts[key] = counts.get(key, 0) + 1
    return counts

def prod(L):
    """Product of the items of ``L``, starting from 1.

    Each new item is multiplied on the left: ``item * running_product``.
    """
    acc = 1
    for factor in L:
        acc = factor * acc
    return acc

# def Estimate_Z(model0, states):
#     Z=0
#     for i in range(len(states[0])):
#         Z+=torch.exp(-model0.energy2(states[0][i], states[1][i])).detach()
#     return float(Z.detach().numpy())

def get_hist(list00, color='red'):
    """Draw a step histogram of ``list00`` with unit-width bins and return (bin centres, counts).

    The bins span from 30 below the minimum to 30 above the maximum value.
    """
    # NOTE(review): `plt` (matplotlib.pyplot) is not imported in any visible cell;
    # add `import matplotlib.pyplot as plt` to the import cell before calling this.
    lo = int(min(list00) - 30)
    hi = int(max(list00) + 30)
    counts, edges, _ = plt.hist(list00, bins=range(lo, hi, 1), histtype='step', color=color)
    centres = (edges[:-1] + edges[1:]) / 2
    return centres, counts

def flatten_list(list0):
    """Concatenate the items of every sub-iterable of ``list0`` into one flat list."""
    flattened = []
    for sublist in list0:
        flattened.extend(sublist)
    return flattened


In [49]:

def Entropy(fullconfigs):
    """Plug-in Shannon entropy (in nats) of the empirical distribution of ``fullconfigs``.

    Configurations are identified by ``str(config)``. A distinct configuration
    seen k times has probability k/N. If m_k is the number of distinct
    configurations seen exactly k times, then
    ``H = -sum_k m_k * (k/N) * log(k/N)``. Returns 0 for empty input.
    """
    # Number of occurrences of each distinct configuration (k).
    config_count = {}
    for config in fullconfigs:
        key = str(config)
        config_count[key] = config_count.get(key, 0) + 1

    # Number of distinct configurations sharing each occurrence count (m_k).
    kcount = {}
    for k in config_count.values():
        kcount[k] = kcount.get(k, 0) + 1

    N = len(fullconfigs)
    H_s = 0
    # Sum in ascending order of k.
    for k in sorted(kcount):
        H_s -= (k * kcount[k] / N) * np.log(k / N)
    return H_s
def Energy(model0_dict, v_list, h_list):
    """Mean RBM energy over paired visible/hidden samples, using parameters from a state dict.

    For each pair, ``E(v, h) = -v.a - v W^T h - h.b`` with ``a = v_bias``,
    ``b = h_bias`` and ``W = Weight``.

    Parameters
    ----------
    model0_dict : mapping
        Must contain the tensors ``'v_bias'``, ``'h_bias'`` and ``'Weight'``.
    v_list, h_list : sequences (tensor, array or nested list)
        Element i of each is one visible / hidden vector; ``h_list`` must be at
        least as long as ``v_list``.

    Returns
    -------
    float
        Mean energy over the pairs (NaN for empty input).
    """
    a = model0_dict['v_bias'].detach()
    b = model0_dict['h_bias'].detach()
    W = model0_dict['Weight'].detach()
    values = []
    for i in range(len(v_list)):
        # Bring each sample onto the parameters' device/dtype. The former np.matmul
        # path only worked for CPU tensors.
        v = torch.as_tensor(v_list[i], dtype=W.dtype, device=W.device)
        h = torch.as_tensor(h_list[i], dtype=W.dtype, device=W.device)
        e = -torch.matmul(v, a.t()) - torch.matmul(torch.matmul(v, W.t()), h.t()) - torch.matmul(h, b.t())
        values.append(e.reshape(-1))
    if not values:
        return float('nan')
    return float(torch.cat(values).mean())
    
# Parallel-tempering sampler whose recorded states feed the AIS estimate (AISPT).
def sampling_with_PT(fullmodel, eq_step, n_step):
    """Gibbs-sample one RBM replica per inverse temperature in ``beta_list``, with PT swaps.

    Replica i is updated h -> v -> h at inverse temperature ``beta_list[i]``.
    After ``eq_step`` burn-in sweeps (no swaps, nothing recorded), each of the
    ``n_step`` recorded sweeps is followed by one ``swap`` attempt between a random
    adjacent pair of replicas.

    Parameters
    ----------
    fullmodel : RBM
        Provides ``h2v``/``v2h``; its ``Weight`` shape fixes the layer sizes.
    eq_step : int
        Number of burn-in sweeps.
    n_step : int
        Number of recorded sweeps.

    Returns
    -------
    visible_list, hidden_list : list[list[int]]
        ``n_step`` rows with one int per replica: the ``binary_to_decimal``
        encoding of that replica's state after the sweep, before the swap attempt.
    """
    # Layer sizes and device come from the model, not from the module-level
    # n_vis / n_hid / device globals (the driver loop rebinds n_hid).
    n_hid_model, n_vis_model = fullmodel.Weight.shape
    model_device = fullmodel.Weight.device
    n_rep = len(beta_list)

    # Every replica starts from the all-ones hidden state; the first sweep fills the visibles.
    hidden_states_now = [decimal_to_binary(2**n_hid_model - 1, n_hid_model)] * n_rep
    visible_states_now = [None] * n_rep

    def _sweep():
        # One h -> v -> h Gibbs update of every replica at its own inverse temperature.
        for i in range(n_rep):
            visible_states_now[i] = fullmodel.h2v(hidden_states_now[i].to(model_device), beta_list[i]).bernoulli().detach()
            hidden_states_now[i] = fullmodel.v2h(visible_states_now[i].to(model_device), beta_list[i]).bernoulli().detach()

    for _ in range(eq_step):
        _sweep()

    visible_list = []
    hidden_list = []
    for _ in tnrange(n_step):
        _sweep()
        hidden_list.append([binary_to_decimal(h.view(n_hid_model)) for h in hidden_states_now])
        visible_list.append([binary_to_decimal(v.view(n_vis_model)) for v in visible_states_now])
        swap(visible_states_now, hidden_states_now, fullmodel)
    return visible_list, hidden_list

def Curly_W(model_dict, v, h):
    """AIS work increments for one PT snapshot.

    For each adjacent rung pair (i, i+1) of ``beta_list``, returns
    ``beta_i * E(x) - beta_{i+1} * E(x)``. Here ``x`` is the state of replica i+1,
    decoded from the integer encodings ``v[i+1]`` / ``h[i+1]``, and E is ``Energy``.
    """
    works = []
    for i in range(len(beta_list) - 1):
        # Both terms use the same state, and Energy is deterministic: evaluate it once.
        e_next = Energy(model_dict, decimal_to_binary(v[i+1], n_vis), decimal_to_binary(h[i+1], n_hid))
        works.append(beta_list[i]*e_next - beta_list[i+1]*e_next)
    return works

def Curly_W_tilde(model_dict, v, h):
    """Companion of ``Curly_W`` evaluated on the lower-index replica of each rung pair.

    For i = 1 .. len(beta_list) - 1, returns ``beta_{i-1} * E(x) - beta_i * E(x)``,
    where ``x`` is decoded from ``v[i-1]`` / ``h[i-1]``. Not called anywhere in the
    visible notebook.
    """
    works = []
    for i in range(1, len(beta_list)):
        # Both terms use the same state, and Energy is deterministic: evaluate it once.
        e_prev = Energy(model_dict, decimal_to_binary(v[i-1], n_vis), decimal_to_binary(h[i-1], n_hid))
        works.append(beta_list[i-1]*e_prev - beta_list[i]*e_prev)
    return works

def AISPT(model_dict, v_list, h_list):
    """Estimate -log Z of an RBM by annealed importance sampling over PT snapshots.

    Each recorded snapshot ``(v_list[n], h_list[n])`` holds one integer-encoded
    state per rung of ``beta_list`` and contributes a total work
    ``W_n = sum(Curly_W(...))``. With ``log C = log mean_n exp(-W_n)``, the
    result is ``-log C - (n_vis + n_hid) * log 2``.

    NOTE(review): ``(n_vis + n_hid) * log 2`` is log Z of the uniform joint
    distribution at beta = 0. This assumes the ladder ends at beta = 0, as in the
    config cell's ``linspace(1, 0, n_beta)``.

    Parameters
    ----------
    model_dict : mapping
        RBM state dict (``'v_bias'``, ``'h_bias'``, ``'Weight'`` of shape (n_hid, n_vis)).
    v_list, h_list : sequences of per-step lists of ints, as returned by ``sampling_with_PT``.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If no snapshots are given.
    """
    r = len(v_list)
    if r == 0:
        raise ValueError('AISPT needs at least one PT snapshot')
    works = np.empty(r)
    for n in tnrange(r):
        # float() also copies CUDA scalars back to the host.
        works[n] = np.sum([float(w) for w in Curly_W(model_dict, v_list[n], h_list[n])])
    # log(mean(exp(-W))) with a max shift (log-sum-exp) so exp() cannot overflow or underflow.
    neg_works = -works
    shift = neg_works.max()
    log_C = shift + np.log(np.mean(np.exp(neg_works - shift)))
    n_hid_model, n_vis_model = model_dict['Weight'].shape
    return -log_C - (n_vis_model + n_hid_model) * np.log(2)

def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, elementwise for NumPy inputs."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


In [58]:
# Grid of trained models to evaluate: hidden-layer sizes and Ising temperatures.
n_hid_list = [1, 2, 4, 8, 16]
T_list = [1.47, 1.78, 2.3, 5.2, 16]
# PT / AIS ladder: n_beta inverse temperatures running from 1 down to 0 (float64).
n_beta = 101
n_step = 10000   # recorded PT sweeps per model
n_eq = 1000      # burn-in Gibbs sweeps
beta_list = torch.from_numpy(np.linspace(1.0, 0.0, n_beta)).to(device)
n_vis = 9        # visible units; the driver reshapes them to a 3x3 lattice
lr = 0.05        # selects the model/loss file names (presumably the training learning rate)
k = 1            # passed to RBM(...)

In [None]:
# IG loss per saved model: <F(v)>_P - AISPT(...) - S, where P is the exact Ising
# Boltzmann distribution and S its entropy. With AISPT estimating -log Z this is
# KL(P_Ising || P_RBM).
# NOTE(review): the 3x3 reshape below assumes n_vis == 9.

# Enumerate all 2**n_vis visible configurations once; none of this depends on n_hid or T.
v_list_ising = []   # 3x3 lattices with spins in {-1, +1}, for Ising_energy
v_list_ising2 = []  # flat {0, 1} vectors, for RBM.Fv
for s in range(2**n_vis):
    v = decimal_to_binary(s, n_vis)[0]
    v_list_ising.append(np.reshape(v, (3, 3))*2-1)
    v_list_ising2.append(v)
v_list_ising2 = torch.stack(v_list_ising2).to(device)
ising_E = Ising_energy(v_list_ising)

for n_hid in n_hid_list:
    for T in T_list:
        bf_list = np.exp(-ising_E/T)          # unnormalised Boltzmann weights at temperature T
        S = entropy(bf_list)                  # scipy normalises the weights itself
        Pv = torch.tensor(bf_list/sum(bf_list)).to(device)

        IG_loss = []
        # NOTE: unpickling executes code; only load trusted model files.
        models = pd.read_pickle('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_1.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr))
        for i in range(len(models)):
            rbm = RBM(n_vis, n_hid, k)
            rbm.load_state_dict(models[i])
            # Burn-in length is n_eq from the config cell.
            v_list, h_list = sampling_with_PT(rbm, n_eq, n_step)
            free_energy = rbm.Fv(v_list_ising2)
            # Cast Pv to the model's dtype so torch.dot never sees mixed dtypes.
            mean_F = torch.dot(free_energy, Pv.to(free_energy.dtype))
            IG_loss.append(float(mean_F) - float(AISPT(models[i], v_list, h_list)) - S)
        # Only the IG loss is computed in this notebook, so only it is saved.
        with open('{base}/loss_IG/3*3/loss/IG_loss_n_hid={n_hid}_T={T}_lr={lr}_1_PT.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr), 'wb') as f:
            pkl.dump(np.array(IG_loss), f)

  0%|          | 0/10000 [00:00<?, ?it/s]