<a href="https://colab.research.google.com/github/physicaone/loss_IG/blob/master/%5B210901%5DTrain_and_get_data2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")
import random
import pickle as pkl
import pandas as pd
from scipy.stats import entropy
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0

try:
    from google.colab import drive
    drive.mount('/content/drive')
    base='drive/MyDrive'
except:
    if torch.cuda.device_count()>1:
        base='.'
    else:
        base='Google Drive'

if CUDA:
    device='cuda'
else:
    device='cpu'
torch.cuda.is_available()

Mounted at /content/drive


True

In [6]:
def decimal_to_binary(integer, n_hid):
    string=bin(int(integer))[2:]
    list0=[float(d) for d in string]
    while len(list0)<n_hid:
        list0=[0.]+list0
    return torch.tensor([list0])

def Ising_energy(v_list):
    L = 3
    E_list=[]
    for n in range(len(v_list)):
        v=v_list[n]
        E = 0
        for i in range(L):
            for j in range(L):
                s = v[i,j]
                neigh = v[(i+1)%L, j] + v[i,(j+1)%L] + v[(i-1)%L,j] + v[i,(j-1)%L] 
                E += -neigh * s
        E_list.append(E/2)
    return np.array(E_list)
    
class RBM(nn.Module):

    def __init__(self, n_vis, n_hid, k):
        """Create a RBM."""
        super(RBM, self).__init__()
        
        self.v_bias = nn.Parameter(torch.zeros(1, n_vis).to(device))
        self.h_bias = nn.Parameter(torch.zeros(1, n_hid).to(device))
        self.Weight = nn.Parameter(torch.randn(n_hid, n_vis).to(device)*0.5)
        self.k = k


    def v2h(self, v):
        return torch.sigmoid(F.linear(v, self.Weight, self.h_bias))

    def h2v(self, h):
        return torch.sigmoid(F.linear(h, self.Weight.t(), self.v_bias))
    
    def Fv(self, v):
        v_term = torch.matmul(v, self.v_bias.t()).view(len(v))
        h_term = torch.sum(F.softplus(F.linear(v, self.Weight, self.h_bias)), dim=1)
        return -h_term -v_term

    def energy(self, v, h):
        v=v.bernoulli()
        h=h.bernoulli()
        return -torch.matmul(v, self.v_bias.t())-torch.matmul(torch.matmul(v, self.Weight.t()),h.t())-torch.matmul(h, self.h_bias.t())
    
    def Energy_GPU2(self, v_list0, h_list0):
        if CUDA:
            n_split=torch.cuda.device_count()
        else:
            n_split=1
        e_list=[]
        m_split=2**6
        for j in range(m_split):
            v_list1=torch.stack(list(v_list0[j*int(len(v_list0)/m_split):(j+1)*int(len(v_list0)/m_split)]))
            h_list1=torch.stack(list(h_list0[j*int(len(h_list0)/m_split):(j+1)*int(len(h_list0)/m_split)]))
            vs=[]
            hs=[]
            for i in range(n_split):
                v_list2=torch.stack(list(v_list1[i*int(len(v_list1)/n_split):(i+1)*int(len(v_list1)/n_split)]))
                h_list2=torch.stack(list(h_list1[i*int(len(h_list1)/n_split):(i+1)*int(len(h_list1)/n_split)]))
                if CUDA:
                    v_list2=v_list2.to(device='cuda:' + str(i)).view(len(v_list2), n_vis)
                    h_list2=h_list2.to(device='cuda:' + str(i)).view(len(h_list2), n_hid)
                else:
                    None
                vs.append(v_list2)
                hs.append(h_list2)
            for i in range(n_split): 
                if CUDA:
                    a=self.v_bias.to(device='cuda:' + str(i)).view(n_vis)
                    b=self.h_bias.to(device='cuda:' + str(i)).view(n_hid)
                    W=self.Weight.to(device='cuda:' + str(i)).view(n_hid, n_vis)
                    e=(-torch.matmul(vs[i].float(), a)-torch.diagonal(torch.matmul(torch.matmul(vs[i].float(), W.t()), hs[i].float().t()))-torch.matmul(hs[i].float(), b)).to('cuda:0')
                    e_list.append(e)
                else:
                    a=self.v_bias.view(n_vis)
                    b=self.h_bias.view(n_hid)
                    W=self.Weight.view(n_hid, n_vis)
                    e=(-torch.matmul(vs[i].float(), a)-torch.diagonal(torch.matmul(torch.matmul(vs[i].float(), W.t()), hs[i].float().t()))-torch.matmul(hs[i].float(), b))
                    e_list.append(e)
        return torch.stack(e_list).view(len(v_list0))
    
    def forward(self, v):
        h = self.v2h(v)
        h = h.bernoulli()
        for _ in range(self.k):
            v_gibbs = self.h2v(h).to(device)
            v_gibbs = v_gibbs.bernoulli()
            h = self.v2h(v_gibbs).to(device)
            h = h.bernoulli()
        return v, v_gibbs
        
from torch.utils.data import Dataset

class CustomDataset(Dataset): 
    def __init__(self, dataset):
        data_x = dataset
        self.x_data = data_x
#         self.y_data = data_y

    # 총 데이터의 개수를 리턴
    def __len__(self): 
        return len(self.x_data)
    # 인덱스를 입력받아 그에 맵핑되는 입출력 데이터를 파이토치의 Tensor 형태로 리턴
    def __getitem__(self, idx): 
        x = torch.FloatTensor(self.x_data[idx])
#         y = torch.FloatTensor([self.y_data[idx]])
        return x

def data_to_loader(fullconfigs):
    fulldata=CustomDataset(fullconfigs)
    full_dataset = fulldata
    full_loader = torch.utils.data.DataLoader(full_dataset, batch_size)
    return full_loader

def train_and_get_data(n_hid, model, lr, train_loader):
    # Train and get the new result
    rbm=RBM(n_vis, n_hid, k)
    train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
    rbm.train()
    train_loss_list=[]
    IG_loss_list=[]
    model_list=[]
    decay_count=0
    # for epoch in tnrange(n_epochs):
    epoch=0
    while epoch<=n_epochs:
        epoch+=1
        train_loss_epoch = []
        for _, (data) in enumerate(train_loader):
            data=data.to(device)
            v, v_gibbs = rbm(data.view(-1, n_vis))
            train_loss = torch.mean(rbm.Fv(v)) - torch.mean(rbm.Fv(v_gibbs))
            train_loss_epoch.append(train_loss.item())
            train_op.zero_grad()
            train_loss.backward()
            train_op.step()
        if epoch in epoch_to_save:
#             decay_count+=1
#             train_op = optim.SGD(rbm.parameters(), 0.5**decay_count, momentum=0.9)
#             print(0.9**decay_count)
            model_list.append(rbm.cpu().state_dict())
            rbm.to(device)
            train_loss_list.append(np.mean(train_loss_epoch))
            Fv=torch.dot(rbm.Fv(v_list_ising2), Pv)
            F=-torch.log(torch.sum(torch.exp(-rbm.Energy_GPU2(v_list_rbm, h_list_rbm))))
            IG_loss_list.append(float((Fv-F-S).detach().cpu().numpy()))
            print(epoch, train_loss_list[-1], IG_loss_list[-1])
            if IG_loss_list[-1]==float('inf'):
                epoch=int(epoch/2)
                lr=lr*0.1    
                train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
                rbm=RBM(n_vis, n_hid, k)
                rbm.load_state_dict(model_list[-2])
                rbm.train()
                rbm.to(device)
                IG_loss_list.pop()
                train_loss_list.pop()
                model_list.pop()

        # if min_loss>train_loss_list[-1]:
        #     min_loss=train_loss_list[-1]
        #     model_to_save=rbm
        # if epoch>100 and train_loss_list[-1]>=train_loss_list[-30] and train_loss_list[-1]>=train_loss_list[-100]:
        #     model_to_save=rbm
        #     break

    return model_list, train_loss_list, IG_loss_list

In [7]:
# Hyper parameter들을 설정
n_vis=9
k=5
lr=0.1
vol=512
batch_size=int(vol/2)
epoch_to_save=[2**i for i in range(21)]
n_epochs=epoch_to_save[-1]+1

In [None]:

torch.set_printoptions(precision=10)
for n_hid in [1,2,4,8]:
    v_list_rbm=[]; h_list_rbm=[]
    for s in tqdm(range(2**(n_vis+n_hid))):
        full=decimal_to_binary(s, n_hid+n_vis)[0]
        v=full[:n_vis]; h=full[-n_hid:]
        v_list_rbm.append(v); h_list_rbm.append(h)
    v_list_rbm=torch.stack(v_list_rbm).to(device)
    h_list_rbm=torch.stack(h_list_rbm).to(device)

    v_list_ising=[]
    v_list_ising2=[]
    for s in range(2**n_vis):
        v=decimal_to_binary(s, n_vis)[0]
        v_list_ising.append(np.reshape(v,(3,3))*2-1)
        v_list_ising2.append(v)
    v_list_ising2=torch.stack(v_list_ising2).to(device)        
    for T in [1.47,1.78,2.3,5.2,16]:
        fullconfigs=pd.read_pickle('{base}/loss_IG/3*3/3*3_full_T={T}.pkl'.format(base=base, T=T))
        loader_list=[]
        for i in range(10):
            loader_list.append(data_to_loader(fullconfigs[i][:vol]))

        bf_list=np.exp(-Ising_energy(v_list_ising)/T)
        S=entropy(bf_list)
        Pv=torch.tensor(bf_list/sum(bf_list)).to(device)
        # try:
        #     old_model=pd.read_pickle('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}.pkl'.format(base=base, n_hid=n_hid, T=T))
        # except:
        #     old_model=0
        # dict_model={}
        # dict_train_loss={}
        # dict_IG_loss={}
        # for m in range(10):
        model0, train_loss_list, IG_loss_list=train_and_get_data(n_hid, 0, lr=lr, train_loader=loader_list[0])
            # dict_model[str(m)]=model0
            # dict_train_loss[str(m)]=train_loss_list
            # dict_IG_loss[str(m)]=IG_loss_list
        with open('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std_half.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol), 'wb') as f:
            pkl.dump(model0, f)
        with open('{base}/loss_IG/3*3/loss/train_loss_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std_half.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol), 'wb') as f:
            pkl.dump(np.array(train_loss_list), f)
        with open('{base}/loss_IG/3*3/loss/IG_loss_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std_half.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol), 'wb') as f:
            pkl.dump(np.array(IG_loss_list), f)

100%|██████████| 1024/1024 [00:00<00:00, 54297.26it/s]


1 0.10287454724311829 5.295716285705566
2 0.0735967755317688 5.281715393066406
4 0.009208440780639648 5.271400451660156
8 0.07106983661651611 5.341963768005371
16 -0.12391269207000732 5.113858222961426
32 -1.6575413942337036 3.2452828884124756
64 -1.7976221442222595 2.6308233737945557
128 -0.8268611431121826 5.235495567321777
256 -0.27458810806274414 8.120588302612305
512 -0.14913034439086914 9.674600601196289
1024 -0.13956308364868164 10.50554370880127
2048 0.02631664276123047 10.705501556396484
4096 0.05428314208984375 10.604898452758789
8192 0.013948440551757812 10.387551307678223
16384 -0.05788469314575195 10.020417213439941
32768 -0.013788700103759766 9.451355934143066
65536 0.012848854064941406 9.045248031616211


# Move dicts from GPU to CPU

In [None]:
lr=0.9
for T in [1.47,1.78,2.3,5.2,16]:
    for n_hid in [1,2,4,8,16]:
        models=pd.read_pickle('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr))
        new_models=[]
        for i in range(len(models)):
            rbm=RBM(n_vis, n_hid, k)
            rbm.load_state_dict(models[i])
            new_models.append(rbm.cpu().state_dict())
        with open('{base}/loss_IG/3*3/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_1.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr), 'wb') as f:
            pkl.dump(new_models, f)    

# $D_{\text{KL}}(P_{\text{trainset}}\|P_{\text{realIsing}})$

In [None]:
def Entropy(fullconfigs):
    config_count={} # 각 hidden layer state 갯수 파악 (k)
    for i in range(len(fullconfigs)):
        config_count[str(fullconfigs[i])]=0
    for i in range(len(fullconfigs)):
        config_count[str(fullconfigs[i])]+=1

    listk=[]
    for i in range(len(list(config_count.values()))):
        listk.append(int(list(config_count.values())[i]))
    listmk=[]
    kcount={} # 갯수의 갯수 파악 (m_k)
    for i in range(len(listk)):
        kcount[listk[i]]=0
    for i in range(len(listk)):
        kcount[listk[i]]+=1
    for i in range(len(kcount)):
        listmk.append(kcount[sorted(list(kcount))[i]])
    x,y= sorted(list(kcount)), listmk

    N=len(fullconfigs)
    H_s=0
    for i in range(len(x)):
        H_s-=(x[i]*y[i]/N)*np.log(x[i]/N)
    return H_s

torch.set_printoptions(precision=10)
v_list_ising=[]
v_list_ising2=[]
for s in range(2**n_vis):
    v=decimal_to_binary(s, n_vis)[0]
    v_list_ising.append(np.reshape(v,(3,3))*2-1)
    v_list_ising2.append(v)
v_list_ising2=torch.stack(v_list_ising2).to(device)        
for T in [1.47, 1.78, 2.3, 5.2, 16]:
    fullconfigs=pd.read_pickle('{base}/loss_IG/3*3/3*3_full_T={T}.pkl'.format(base=base, T=T))
    S=Entropy(fullconfigs[0])
    bf_list_exact=np.exp(-Ising_energy(v_list_ising)/T)
    F_exact=-np.log(sum(bf_list_exact))
    bf_list_train=np.exp(-Ising_energy(np.reshape(fullconfigs[0],(len(fullconfigs[0]),3,3))*2-1))
    F_train=np.dot(bf_list_train, -Ising_energy(np.reshape(fullconfigs[0],(len(fullconfigs[0]),3,3))*2-1))/sum(bf_list_train)
    print(T, F_train-F_exact-S)
    # S=entropy(bf_list)
    # Pv=torch.tensor(bf_list/sum(bf_list)).to(device)    

1.47 29.978602375811814
1.78 27.4414906741254
2.3 24.244981963649273
5.2 18.891115054608626
16 18.064142139269855
