<a href="https://colab.research.google.com/github/physicaone/loss_IG/blob/master/%5B210911%5Dmarginalized_GE2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")
import random
import pickle as pkl
import pandas as pd
from scipy.stats import entropy
from itertools import combinations

import copy
# Runtime configuration: detect CUDA and choose the base directory that the
# analysis cells use when building pickle paths.
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0

try:
    # On Google Colab: mount Drive and keep data under MyDrive.
    from google.colab import drive
    drive.mount('/content/drive')
    base = 'drive/MyDrive'
except Exception:
    # Not on Colab, or the mount failed: fall back to a local directory.
    # `Exception` (instead of a bare `except`) lets KeyboardInterrupt and
    # SystemExit propagate instead of being silently swallowed.
    if torch.cuda.device_count() > 1:
        base = '.'
    else:
        base = 'Google Drive'

# Device string used when creating parameters and state tensors.
device = 'cuda' if CUDA else 'cpu'
# Bare expression kept so the notebook cell still displays the CUDA status.
torch.cuda.is_available()

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


True

In [None]:
# The two functions below have nothing to do with PT. The first converts a decimal number to binary and returns it in list form, e.g. 15 -> [1,1,1,1].
# The one after it is the inverse function.
def decimal_to_binary(integer, n_hid):
    """Return the binary digits of ``integer`` as a (1, width) float tensor.

    The digits are most-significant first and left-padded with zeros up to
    ``n_hid`` entries; a number that needs more than ``n_hid`` digits is not
    truncated.  Example: ``decimal_to_binary(15, 4)`` -> ``[[1., 1., 1., 1.]]``.

    Args:
        integer: value to convert; it is passed through ``int()`` first.
        n_hid: minimum number of digits in the result.

    Returns:
        torch.Tensor holding one row of 0./1. values.
    """
    digits = bin(int(integer))[2:].rjust(n_hid, '0')
    return torch.tensor([[float(ch) for ch in digits]])

def binary_to_decimal(list0):
    """Convert a 1-D sequence of binary digits back to an integer.

    Inverse of ``decimal_to_binary`` for a single row of digits.

    Args:
        list0: object with a ``tolist()`` method (e.g. a tensor or array)
            holding the digits, most-significant first.

    Returns:
        int: the value represented by the digits (0 for an empty input).
    """
    bits = list0.tolist()
    total = 0
    # Walk from the least-significant digit so the exponent is the position.
    for power, bit in enumerate(reversed(bits)):
        total += bit * 2 ** power
    return int(total)

def get_hist(list00, color='red'):
    """Draw a step histogram of ``list00`` and return bin centres and counts.

    Bins are unit-width integer bins spanning from ``min - 30`` to
    ``max + 30`` (both truncated with ``int``).

    NOTE(review): ``plt`` is not imported in the visible part of this file;
    presumably ``matplotlib.pyplot`` is expected to be in scope - confirm.

    Args:
        list00: values to histogram.
        color: line colour forwarded to ``plt.hist``.

    Returns:
        tuple: (bin centres, counts per bin).
    """
    lower = int(min(list00) - 30)
    upper = int(max(list00) + 30)
    counts, edges, _ = plt.hist(list00, bins=range(lower, upper, 1), histtype='step', color=color)
    # Midpoint of each pair of consecutive bin edges.
    centres = 0.5 * (edges[1:] + edges[:-1])
    return centres, counts

def flatten_list(list0):
    """Flatten one level of nesting: an iterable of iterables -> flat list."""
    flattened = []
    for sublist in list0:
        flattened.extend(sublist)
    return flattened

def Energy(model0_dict, v_list, h_list):
    """Return the mean energy over paired (visible, hidden) configurations.

    For every index ``i`` the value ``-v a^T - (v W^T) h^T - h b^T`` is
    computed with ``v = v_list[i]``, ``h = h_list[i]``,
    ``a = model0_dict['v']``, ``b = model0_dict['h']`` and
    ``W = model0_dict['W']``; the values are averaged with ``np.mean`` and
    returned as a Python float.

    Args:
        model0_dict: mapping with keys 'v', 'h' and 'W'; the values must
            support ``.detach()`` and ``.t()`` (presumably torch tensors -
            confirm against callers).
        v_list: indexable collection of visible configurations.
        h_list: hidden configurations, indexed in step with ``v_list``.

    NOTE(review): the keys used here ('v', 'h', 'W') differ from the
    parameter names of the RBM class in this file ('v_bias', 'h_bias',
    'Weight'), so an RBM state dict cannot be passed directly.
    NOTE(review): ``np.matmul`` is applied to operands that are then used
    with torch-style methods (``.t()``, ``.detach()``); this relies on
    NumPy/torch interoperability - verify before changing.
    """
    a=model0_dict['v'].detach()
    b=model0_dict['h'].detach()
    W=model0_dict['W'].detach()
    values=[]
    for i in range(len(v_list)):
        # Energy of the i-th pair: visible-bias, coupling and hidden-bias terms.
        e=-np.matmul(v_list[i], a.t())-np.matmul(np.matmul(v_list[i], W.t()), h_list[i].t())-np.matmul(h_list[i], b.t())
        values.append(e.detach())
    return float(np.mean(values))

# Energy_GPU returns an array of values
def Energy_GPU(model0_dict, v_list0, h_list0):
    """Compute the energy of each (visible, hidden) pair on the available GPUs.

    The pairs are split into contiguous pieces, one per CUDA device, and for
    every pair the energy ``-v.a - (v W^T).h - h.b`` is evaluated with
    ``a = model0_dict['v']``, ``b = model0_dict['h']``,
    ``W = model0_dict['W']``.

    Changes relative to the previous implementation:
      * the coupling term is a row-wise product-and-sum instead of the
        diagonal of an N x N matrix, so memory is O(N * n_hid), not O(N^2);
      * ``torch.chunk`` is used for splitting, so trailing samples are no
        longer dropped when N is not divisible by the number of devices.

    Relies on the module-level globals ``n_vis`` and ``n_hid`` for reshaping.

    Args:
        model0_dict: mapping with tensor values under 'v', 'h' and 'W'.
        v_list0: tensor or sequence of tensors of visible configurations.
        h_list0: hidden configurations, aligned with ``v_list0``.

    Returns:
        numpy.ndarray: 1-D array with one energy per input pair, in input
        order.  Empty when no CUDA device is available (this function only
        runs on GPUs) or when the input is empty.
    """
    n_split = torch.cuda.device_count()
    if n_split == 0 or len(v_list0) == 0:
        return np.array([])

    v_all = torch.stack(list(v_list0)).detach()
    h_all = torch.stack(list(h_list0)).detach()

    values = []
    pieces = zip(torch.chunk(v_all, n_split), torch.chunk(h_all, n_split))
    for i, (v_part, h_part) in enumerate(pieces):
        dev = 'cuda:' + str(i)
        v_part = v_part.to(device=dev).reshape(len(v_part), n_vis).float()
        h_part = h_part.to(device=dev).reshape(len(h_part), n_hid).float()
        a = model0_dict['v'].detach().to(device=dev).view(n_vis)
        b = model0_dict['h'].detach().to(device=dev).view(n_hid)
        W = model0_dict['W'].detach().to(device=dev).view(n_hid, n_vis)
        # Row-wise (v W^T) . h: identical to the diagonal of (v W^T) h^T.
        coupling = (torch.matmul(v_part, W.t()) * h_part).sum(dim=1)
        e = -torch.matmul(v_part, a) - coupling - torch.matmul(h_part, b)
        values.append(e.cpu().numpy())

    return np.concatenate(values)


def sigmoid(x):
    """Logistic function ``1 / (1 + exp(-x))``, evaluated with NumPy."""
    denominator = 1 + np.exp(-x)
    return 1 / denominator


In [None]:
def Ising_energy(v_list):
    """Nearest-neighbour Ising energy of each 3x3 lattice in ``v_list``.

    Each site contributes ``-s * (sum of its four neighbours)`` with periodic
    wrap-around; the total is halved because every bond is visited from both
    of its end points.

    Args:
        v_list: sequence of lattices indexable as ``v[i, j]`` for
            ``i, j`` in ``0..2``.

    Returns:
        numpy.ndarray with one energy per lattice.
    """
    L = 3
    energies = []
    for lattice in v_list:
        total = 0
        for row in range(L):
            below = (row + 1) % L
            above = (row - 1) % L
            for col in range(L):
                right = (col + 1) % L
                left = (col - 1) % L
                neighbours = lattice[below, col] + lattice[row, right] + lattice[above, col] + lattice[row, left]
                total += -neighbours * lattice[row, col]
        energies.append(total / 2)
    return np.array(energies)
    
class RBM(nn.Module):
    """Restricted Boltzmann machine with sigmoid conditional probabilities.

    Parameters (all initialised to zero on the module-level ``device``):
        v_bias: visible bias, shape (1, n_vis).
        h_bias: hidden bias, shape (1, n_hid).
        Weight: coupling matrix, shape (n_hid, n_vis).
    """

    def __init__(self, n_vis, n_hid, k):
        """Create a RBM.

        Args:
            n_vis: number of visible units.
            n_hid: number of hidden units.
            k: number of Gibbs steps performed by ``forward``.
        """
        super().__init__()

        self.v_bias = nn.Parameter(torch.zeros(1, n_vis).to(device))
        self.h_bias = nn.Parameter(torch.zeros(1, n_hid).to(device))
        self.Weight = nn.Parameter(torch.zeros(n_hid, n_vis).to(device))
        self.k = k

    def v2h(self, v):
        """Return sigmoid(v W^T + h_bias) for visible states ``v``."""
        return torch.sigmoid(F.linear(v, self.Weight, self.h_bias))

    def h2v(self, h):
        """Return sigmoid(h W + v_bias) for hidden states ``h``."""
        return torch.sigmoid(F.linear(h, self.Weight.t(), self.v_bias))

    def Fv(self, v):
        """Free energy of each visible row of ``v`` (hidden units summed out).

        Computes ``-v . v_bias - sum_j softplus((v W^T + h_bias)_j)`` and
        returns a 1-D tensor with ``len(v)`` entries.
        """
        v_term = torch.matmul(v, self.v_bias.t()).view(len(v))
        h_term = torch.sum(F.softplus(F.linear(v, self.Weight, self.h_bias)), dim=1)
        return -h_term - v_term

    def energy(self, v, h):
        """Energy of Bernoulli samples drawn from ``v`` and ``h``.

        Both arguments are first replaced by ``.bernoulli()`` samples, so the
        result is stochastic unless the inputs are already exactly 0/1.
        The coupling term is the full ``(v W^T) h^T`` matrix product.
        """
        v = v.bernoulli()
        h = h.bernoulli()
        return (-torch.matmul(v, self.v_bias.t())
                - torch.matmul(torch.matmul(v, self.Weight.t()), h.t())
                - torch.matmul(h, self.h_bias.t()))

    @staticmethod
    def _as_matrix(states, width):
        """Stack ``states`` (tensor or sequence of tensors) into (N, width)."""
        if not torch.is_tensor(states):
            states = torch.stack(list(states))
        return states.reshape(len(states), width)

    def Energy_GPU(self, v_list0, h_list0):
        """Energy ``-v.a - (v W^T).h - h.b`` of each (visible, hidden) pair.

        When the parameters live on a CUDA device the pairs are split into
        contiguous pieces, one per available GPU, and the results are
        gathered on ``cuda:0``; otherwise everything is evaluated on the
        parameters' own device.

        Changes relative to the previous implementation:
          * no fixed split into 2**13 chunks, so inputs with fewer than
            2**13 * n_devices rows no longer crash in ``torch.stack([])``
            and trailing rows are never dropped;
          * the coupling term is a row-wise product-and-sum instead of the
            diagonal of a chunk x chunk matrix;
          * sizes and device come from the model parameters rather than the
            module-level globals ``n_vis``, ``n_hid`` and ``CUDA``.

        Args:
            v_list0: tensor or sequence of tensors of visible configurations.
            h_list0: hidden configurations, aligned with ``v_list0``.

        Returns:
            1-D tensor with ``len(v_list0)`` energies, in input order.
            Gradients with respect to the parameters are kept.
        """
        n_hid, n_vis = self.Weight.shape
        if len(v_list0) == 0:
            return self.Weight.new_zeros(0)

        v_all = self._as_matrix(v_list0, n_vis)
        h_all = self._as_matrix(h_list0, n_hid)

        if self.Weight.is_cuda:
            targets = ['cuda:' + str(i) for i in range(torch.cuda.device_count())]
            gather_on = 'cuda:0'
        else:
            targets = [self.Weight.device]
            gather_on = self.Weight.device

        pieces = []
        v_parts = torch.chunk(v_all, len(targets))
        h_parts = torch.chunk(h_all, len(targets))
        for dev, v_part, h_part in zip(targets, v_parts, h_parts):
            v_part = v_part.to(dev).float()
            h_part = h_part.to(dev).float()
            a = self.v_bias.to(dev).view(n_vis)
            b = self.h_bias.to(dev).view(n_hid)
            W = self.Weight.to(dev)
            # Row-wise (v W^T) . h: identical to the diagonal of (v W^T) h^T.
            coupling = (torch.matmul(v_part, W.t()) * h_part).sum(dim=1)
            e = -torch.matmul(v_part, a) - coupling - torch.matmul(h_part, b)
            pieces.append(e.to(gather_on))
        return torch.cat(pieces)

    def forward(self, v):
        """Run ``self.k`` Gibbs steps starting from ``v``.

        Returns:
            tuple ``(v, v_gibbs)``: the input and the sampled visible state
            after the last step.

        Note: ``self.k`` must be at least 1; with ``k == 0`` the loop body
        never runs and ``v_gibbs`` is unbound.
        """
        h = self.v2h(v)
        h = h.bernoulli()
        for _ in range(self.k):
            v_gibbs = self.h2v(h).to(device)
            v_gibbs = v_gibbs.bernoulli()
            h = self.v2h(v_gibbs).to(device)
            h = h.bernoulli()
        return v, v_gibbs
        
def CM_model(models):
    """Average the final parameters of the ten trained models.

    Args:
        models: mapping from ``'0'`` .. ``'9'`` to a sequence whose last
            element is a dict with keys 'v_bias', 'h_bias' and 'Weight'.

    Returns:
        dict with the same three keys holding the element-wise means.
    """
    def _mean_of(key):
        # Each model contributes one tenth of its last-saved value.
        return sum(models[str(m)][-1][key] / 10 for m in range(10))

    return {
        'v_bias': _mean_of('v_bias'),
        'h_bias': _mean_of('h_bias'),
        'Weight': _mean_of('Weight'),
    }

def mean_Fv(Q_m, v0):
    """Average of ``RBM.Fv(v0)`` over the ten models stored in ``Q_m``.

    A fresh RBM (sized by the module-level ``n_vis`` and ``n_hid``) is built
    and loaded for each of the keys ``'0'`` .. ``'9'``.

    Args:
        Q_m: mapping from ``'0'`` .. ``'9'`` to RBM state dicts.
        v0: visible configurations passed to ``RBM.Fv``.

    Returns:
        The mean free energy, with the same shape as ``RBM.Fv(v0)``.
    """
    total = 0
    for key in map(str, range(10)):
        machine = RBM(n_vis, n_hid, k=1)
        machine.load_state_dict(Q_m[key])
        total = total + machine.Fv(v0) / 10
    return total

def JS_mn(Q_m, Q_n):
    """Symmetrised divergence between two RBMs over the visible states.

    With ``F_m(v)`` the free energy of model m and
    ``Q_m(v) = exp(-F_m(v)) / sum_v exp(-F_m(v))`` its marginal distribution,
    returns ``0.5 * (<F_n - F_m>_{Q_m} + <F_m - F_n>_{Q_n})``, evaluated over
    every state in the module-level ``v_list_ising2``.

    Fix: the previous version normalised the free energies themselves
    (``F / sum(F)``) instead of the Boltzmann weights ``exp(-F)``, so the
    averages were not taken under the model distributions; this now mirrors
    ``JS_mn_vh``.  ``torch.softmax`` is used so the normalisation cannot
    overflow.

    Relies on the module-level globals ``n_vis``, ``n_hid`` and
    ``v_list_ising2``.

    Args:
        Q_m: state dict of the first RBM.
        Q_n: state dict of the second RBM.

    Returns:
        float: the symmetrised divergence.
    """
    rbm_m = RBM(n_vis, n_hid, k=1)
    rbm_m.load_state_dict(Q_m)
    rbm_n = RBM(n_vis, n_hid, k=1)
    rbm_n.load_state_dict(Q_n)

    with torch.no_grad():
        Fvm = rbm_m.Fv(v_list_ising2)
        Fvn = rbm_n.Fv(v_list_ising2)
        # Normalised Boltzmann weights of every visible state.
        Qm = torch.softmax(-Fvm, dim=0)
        Qn = torch.softmax(-Fvn, dim=0)
        value = 0.5 * (torch.dot(Fvn - Fvm, Qm) + torch.dot(Fvm - Fvn, Qn))
    return float(value)

def JS_mn_vh(Q_m, Q_n):
    """Symmetrised divergence between two RBMs over joint (v, h) states.

    With ``E_m(v, h)`` the energy of model m and ``Q_m`` the normalised
    Boltzmann weights ``exp(-E_m)``, returns
    ``0.5 * (<E_n - E_m>_{Q_m} + <E_m - E_n>_{Q_n})`` over every pair in the
    module-level ``v_list_rbm`` / ``h_list_rbm``.

    ``torch.softmax(-E)`` replaces ``exp(-E) / sum(exp(-E))``: it gives the
    same distribution, but cannot overflow for strongly negative energies
    and avoids the builtin ``sum`` iterating over the tensor one element at
    a time.

    Relies on the module-level globals ``n_vis``, ``n_hid``, ``v_list_rbm``
    and ``h_list_rbm``.

    Args:
        Q_m: state dict of the first RBM.
        Q_n: state dict of the second RBM.

    Returns:
        float: the symmetrised divergence.
    """
    rbm_m = RBM(n_vis, n_hid, k=1)
    rbm_m.load_state_dict(Q_m)
    rbm_n = RBM(n_vis, n_hid, k=1)
    rbm_n.load_state_dict(Q_n)

    with torch.no_grad():
        Em_list = rbm_m.Energy_GPU(v_list_rbm, h_list_rbm).cpu()
        En_list = rbm_n.Energy_GPU(v_list_rbm, h_list_rbm).cpu()
        Qm = torch.softmax(-Em_list, dim=0)
        Qn = torch.softmax(-En_list, dim=0)
        value = 0.5 * (torch.dot(En_list - Em_list, Qm) + torch.dot(Em_list - En_list, Qn))
    return float(value)


In [None]:
# Settings shared by the analysis cells below.
# Number of visible units; the visible states are reshaped to a 3x3 lattice.
n_vis=9
# NOTE(review): k is not read by any code visible here (RBMs are built with k=1).
k=5
# lr, vol and std are only used to build the pickle file names in the code
# visible here; presumably they identify the training run - confirm.
lr=0.1
vol=1024
std=0.5
# Temperatures looped over; T enters the Ising Boltzmann factor exp(-E/T).
T_list=[1.9,2.3,3.0]
# Hidden-layer sizes looped over by the first analysis cell.
n_hid_list=[1,2,3,4,6,8,12]

In [None]:
# Bias/variance analysis: for every hidden size and temperature, load the ten
# saved models and write [Bias, Variance] to a pickle file.
# The loop variable n_hid is a module-level global that RBM.Energy_GPU and
# mean_Fv read, so it must be set before they are called.
for n_hid in n_hid_list:
    # Enumerate all 2**(n_vis+n_hid) joint binary states; the first n_vis
    # digits are the visible part, the last n_hid digits the hidden part.
    v_list_rbm=[]; h_list_rbm=[]
    for s in range(2**(n_vis+n_hid)):
        full=decimal_to_binary(s, n_hid+n_vis)[0]
        v=full[:n_vis]; h=full[-n_hid:]
        v_list_rbm.append(v); h_list_rbm.append(h)
    v_list_rbm=torch.stack(v_list_rbm).to(device)
    h_list_rbm=torch.stack(h_list_rbm).to(device)

    # Enumerate all 2**n_vis visible states twice: as 3x3 lattices mapped from
    # {0,1} to {-1,+1} (for Ising_energy) and as flat 0/1 rows (for the RBMs).
    v_list_ising=[]
    v_list_ising2=[]
    for s in range(2**n_vis):
        v=decimal_to_binary(s, n_vis)[0]
        v_list_ising.append(np.reshape(v,(3,3))*2-1)
        v_list_ising2.append(v)
    v_list_ising2=torch.stack(v_list_ising2).to(device)       
    for T in T_list:
        # Unnormalised Boltzmann factors of the Ising model at temperature T.
        bf_list=np.exp(-Ising_energy(v_list_ising)/T)
        # scipy.stats.entropy normalises its input before computing the entropy.
        S=entropy(bf_list)
        # Normalised target distribution over the visible states.
        Pv=torch.tensor(bf_list/sum(bf_list)).to(device)
        models=pd.read_pickle('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std))

        # quantities of mean model
        # NOTE(review): this RBM holds the parameter-averaged model, but it is
        # not used below because both F_bar lines that read it are commented out.
        Q_bar=CM_model(models)
        rbm=RBM(n_vis, n_hid, k=1)
        rbm.load_state_dict(Q_bar)
        # F_bar=float(-torch.log(torch.sum(torch.exp(-rbm.Energy_GPU(v_list_rbm, h_list_rbm)))).detach().cpu().numpy())
        # F_bar=float(-torch.log(torch.sum(torch.exp(-rbm.Fv(v_list_rbm)))).detach().cpu().numpy())

        # quantities of model's mean
        # Keep the last saved entry of each of the ten models (presumably the
        # final state dict of each training run - confirm).
        Q_m={}
        for m in range(10):
            Q_m[str(m)]=models[str(m)][-1]
        
        # F_bar = -log sum_v exp(-mean_Fv(v)).  The exponent is shifted by -100
        # and the shift is undone after the log, which leaves the value
        # unchanged (presumably to keep exp() from overflowing).
        F_bar=0
        for i in range(len(v_list_ising2)):
            F_bar+=torch.exp(-mean_Fv(Q_m, v_list_ising2[i].view(1,9))-100)
        F_bar=-float(torch.log(F_bar).detach().cpu().numpy())-100

        F_mean=[]
        Fv_mean=[]
        for m in range(10):
            rbm=RBM(n_vis, n_hid, k=1)
            rbm.load_state_dict(Q_m[str(m)])
            # F_m = -log sum_{v,h} exp(-E_m(v,h)), using the same -100 shift.
            F_m=-float(torch.log(torch.sum(torch.exp(-rbm.Energy_GPU(v_list_rbm, h_list_rbm)-100))).detach().cpu().numpy())-100
            # Average of model m's free energy Fv under the target distribution Pv.
            Fv_m=float(torch.dot(rbm.Fv(v_list_ising2), Pv).detach().cpu().numpy())
            F_mean.append(F_m)
            Fv_mean.append(Fv_m)
        # One Bias and one Variance value per model (arrays of length 10).
        Bias=-S+np.array(Fv_mean)-F_bar
        Variance=-np.array(F_mean)+F_bar
        with open('{base}/loss_IG/3*3/data/mar_BV_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std), 'wb') as f:
            pkl.dump([Bias,Variance], f)
            

In [None]:
# Pairwise divergence analysis: for every hidden size and temperature, compute
# JS_mn and JS_mn_vh for each pair of the ten saved models and pickle the lists.
# n_hid_list is rebound here, so this cell only covers n_hid = 8 and 12.
n_hid_list=[8,12]
for n_hid in n_hid_list:
    # Enumerate all 2**(n_vis+n_hid) joint binary states; the first n_vis
    # digits are the visible part, the last n_hid digits the hidden part.
    # JS_mn_vh reads v_list_rbm / h_list_rbm as module-level globals.
    v_list_rbm=[]; h_list_rbm=[]
    for s in range(2**(n_vis+n_hid)):
        full=decimal_to_binary(s, n_hid+n_vis)[0]
        v=full[:n_vis]; h=full[-n_hid:]
        v_list_rbm.append(v); h_list_rbm.append(h)
    v_list_rbm=torch.stack(v_list_rbm).to(device)
    h_list_rbm=torch.stack(h_list_rbm).to(device)

    # Enumerate all 2**n_vis visible states; JS_mn reads v_list_ising and
    # v_list_ising2 as module-level globals.
    v_list_ising=[]
    v_list_ising2=[]
    for s in range(2**n_vis):
        v=decimal_to_binary(s, n_vis)[0]
        v_list_ising.append(np.reshape(v,(3,3))*2-1)
        v_list_ising2.append(v)
    v_list_ising2=torch.stack(v_list_ising2).to(device)       
    for T in T_list:
        models=pd.read_pickle('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std))
        # Keep the last saved entry of each of the ten models (presumably the
        # final state dict of each training run - confirm).
        Q_m={}
        for m in range(10):
            Q_m[str(m)]=models[str(m)][-1]
        JS_v=[]
        JS_vh=[]
        # All 45 unordered pairs of the ten model indices.
        com=list(combinations(list(range(10)), 2))
        for c in com:
            JS_v.append(JS_mn(Q_m[str(c[0])], Q_m[str(c[1])]))
            JS_vh.append(JS_mn_vh(Q_m[str(c[0])], Q_m[str(c[1])]))
        with open('{base}/loss_IG/3*3/data/JS_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std), 'wb') as f:
            pkl.dump([JS_v, JS_vh], f)