<a href="https://colab.research.google.com/github/physicaone/loss_IG/blob/master/%5B211105%5DTrain_and_get_data_sym3_1D.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")
import random
import pickle as pkl
import pandas as pd
from scipy.stats import entropy
import copy
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0

# Choose the data root directory.
# On Colab: mount Google Drive and read data from it.
try:
    from google.colab import drive
    drive.mount('/content/drive')
    base = 'drive/MyDrive'
except Exception:
    # Not on Colab (google.colab is missing) or mounting failed: use a local root.
    # Machines with more than one GPU use '.', others 'Google Drive'.
    # `Exception` rather than a bare except so Ctrl-C / SystemExit still propagate.
    base = '.' if torch.cuda.device_count() > 1 else 'Google Drive'

device = 'cuda' if CUDA else 'cpu'
torch.cuda.is_available()

In [None]:
def decimal_to_binary(integer, n_hid):
    """Return the binary digits of ``integer`` as a float tensor of shape (1, width).

    Digits are most-significant first and left-padded with zeros up to ``n_hid``.
    A number needing more than ``n_hid`` bits is returned untruncated, so
    width = max(n_hid, number of bits).
    """
    digits = format(int(integer), 'b').zfill(n_hid)
    return torch.tensor([[float(bit) for bit in digits]])

def Ising_energy(v_list, L=None):
    """Energy of each spin configuration on a periodic 1-D Ising chain.

    E(v) = -sum_i v_i * v_{(i+1) mod L}  (nearest-neighbour coupling J = 1, no field),
    evaluated on the first ``L`` spins of each configuration.

    Args:
        v_list: sequence of equal-length +/-1 spin configurations. Each may be
            a list, a NumPy array or a CPU tensor.
        L: chain length. ``None`` (default) uses the full length of each
            configuration. The previous hard-coded value 3 ignored all but the
            first three spins of the 9-spin chain.

    Returns:
        NumPy float64 array with one energy per configuration.
    """
    if len(v_list) == 0:
        return np.array([])
    spins = np.stack([np.asarray(v, dtype=np.float64).reshape(-1) for v in v_list])
    if L is None:
        L = spins.shape[1]
    chain = spins[:, :L]
    # Column i of the rolled array holds spin (i+1) mod L, so each bond is counted once.
    # This equals the original "sum both neighbours, then halve".
    return -np.sum(chain * np.roll(chain, -1, axis=1), axis=1)
    
class RBM(nn.Module):
    """Bias-free RBM whose visible and hidden layers interact only through ``Weight``.

    Reads three module-level globals:
    - ``std``: scale of the random weight initialisation.
    - ``device``: where the weights are placed.
    - ``Low_T_correction``: constant added to the free energy in ``Fv``.
    """

    def __init__(self, n_vis, n_hid, k):
        """Create a RBM.

        Args:
            n_vis: number of visible units.
            n_hid: number of hidden units.
            k: stored as ``self.k``; no method of this class reads it.
        """
        super(RBM, self).__init__()
        # Weight has shape (n_hid, n_vis) and is initialised as std * N(0, 1).
        self.Weight = nn.Parameter(std*torch.randn(n_hid, n_vis).to(device))
        self.k = k


    def v2h(self, v):
        """Return tanh(v @ W.T), one value per hidden unit.

        This matches the marginalisation over +/-1 hidden units used in ``Fv``.
        """
        return torch.tanh(F.linear(v, self.Weight))

    def h2v(self, h):
        """Return tanh(h @ W), one value per visible unit."""
        return torch.tanh(F.linear(h, self.Weight.t()))

    def Fv(self, v):
        """Free energy of each row of ``v``.

        F(v) = -sum_j log(exp(-x_j) + exp(x_j)) + Low_T_correction, with x = v @ W.T.
        ``logaddexp`` gives the same value without overflowing to inf when |x|
        is large. The naive exp() form overflows in float32 once |x| > ~88.
        """
        field = torch.matmul(v, self.Weight.t())
        h_term = torch.sum(torch.logaddexp(-field, field), dim=1)
        return -h_term + Low_T_correction

    def energy(self, v, h):
        """Return -(v' @ W.T) @ h'.T, where v' and h' are Bernoulli samples of v and h.

        NOTE(review): ``bernoulli()`` treats its input as probabilities in [0, 1].
        The +/-1 configurations used with ``Fv`` would be rejected here.
        Confirm the intended inputs before using this method.
        """
        v=v.bernoulli()
        h=h.bernoulli()
        return -torch.matmul(torch.matmul(v, self.Weight.t()),h.t())


    def forward(self, v):
        """Identity: returns ``v`` unchanged."""
        return v
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over an indexable collection of samples (inputs only, no labels)."""

    def __init__(self, dataset):
        self.x_data = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return sample ``idx`` converted to a ``torch.FloatTensor``."""
        sample = self.x_data[idx]
        return torch.FloatTensor(sample)

def data_to_loader(fullconfigs):
    """Wrap ``fullconfigs`` in a ``DataLoader`` with the global ``batch_size``.

    The loader uses DataLoader's default ``shuffle=False``.
    """
    return torch.utils.data.DataLoader(CustomDataset(fullconfigs), batch_size)


def train_and_get_data(n_hid, model, lr, train_loader):
    """Train a fresh RBM on ``train_loader`` and record checkpoints.

    The model distribution is Q(v) = softmax(-F(v)), taken over the configurations
    in the global ``v_list_ising2``. In the main cell that global holds all
    2**n_vis visible states, so the partition function is computed exactly.

    Args:
        n_hid: number of hidden units of the new RBM.
        model: unused; kept for call compatibility.
        lr: SGD learning rate (momentum 0.9).
        train_loader: iterable yielding batches of visible configurations.

    Returns:
        (model_list, val_loss_list): one entry per epoch in ``epoch_to_save``.
        model_list holds CPU copies of the state dict. val_loss_list holds the
        detached GE = <F>_P - F - S, where F = -log Z; this equals KL(P || Q)
        when S is the entropy of P.

    Reads globals: n_vis, k, n_epochs, epoch_to_save, v_list_ising2, Pv, S, device.
    """
    rbm = RBM(n_vis, n_hid, k)
    train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
    rbm.train()
    val_loss_list = []
    model_list = []
    save_epochs = set(epoch_to_save)
    for epoch in tnrange(n_epochs):
        for data in train_loader:
            v = data.view(-1, n_vis).to(device)
            # Free energy of every configuration, computed once per step.
            F_all = rbm.Fv(v_list_ising2)
            # Q is treated as a constant (detached), so the loss gradient is the
            # exact log-likelihood gradient <dF>_data - <dF>_Q.
            # softmax avoids overflow in exp(-F).
            Qv = torch.softmax(-F_all, dim=0).detach()
            Fv_Q = torch.dot(F_all, Qv)

            train_loss = torch.mean(rbm.Fv(v)) - Fv_Q
            train_op.zero_grad()
            train_loss.backward()
            train_op.step()
        if epoch in save_epochs:
            model_list.append(copy.deepcopy(rbm.cpu().state_dict()))
            rbm.to(device)
            # Evaluation only: no autograd graph is built.
            # The stored GE therefore does not keep graphs alive.
            with torch.no_grad():
                F_all = rbm.Fv(v_list_ising2)
                FE = -torch.logsumexp(-F_all, dim=0)
                Fv_P = torch.dot(F_all.double(), Pv)
                GE = Fv_P - FE - S
            print('epoch={epoch}, GE={GE}'.format(epoch=epoch, GE=GE))
            val_loss_list.append(GE)
    return model_list, val_loss_list

def CM_model(models, n_models=10):
    """Average the final checkpoint of several runs, parameter by parameter.

    The original hard-coded the keys 'v_bias', 'h_bias' and 'Weight'. The RBM
    in this file only has 'Weight', so that version always raised KeyError.
    Every key present in the checkpoints is now averaged.

    Args:
        models: mapping from run index as a string ('0', '1', ...) to a list of
            state dicts. The last entry of each list is used.
        n_models: number of runs to average, '0' through str(n_models - 1).

    Returns:
        dict with the keys of ``models['0'][-1]``, each the mean over the runs.

    Raises:
        ValueError: if ``n_models`` < 1.
    """
    if n_models < 1:
        raise ValueError('n_models must be at least 1')
    final_states = [models[str(m)][-1] for m in range(n_models)]
    # Accumulate x / n_models term by term, in the same order as the original.
    return {
        name: sum(state[name] / n_models for state in final_states)
        for name in final_states[0]
    }

def mean_Fv(Q_m, v0):
    """Average free energy of ``v0`` over the ten RBMs stored in Q_m['0'] .. Q_m['9'].

    Each entry is loaded into an ``RBM(n_vis, n_hid, k=1)`` built from the
    global ``n_vis`` and ``n_hid``.
    """
    n_models = 10
    total = 0
    for key in map(str, range(n_models)):
        machine = RBM(n_vis, n_hid, k=1)
        machine.load_state_dict(Q_m[key])
        total += machine.Fv(v0) / n_models
    return total


In [None]:
# Hyperparameters
n_vis = 9                                    # number of visible spins
k = 5                                        # passed to RBM(...) as k
lr = 0.1                                     # SGD learning rate
std = 0.5                                    # scale of the initial RBM weights
epoch_to_save = [1 << i for i in range(18)]  # checkpoint epochs 1, 2, 4, ..., 2**17
n_epochs = epoch_to_save[-1] + 1             # just long enough to reach the last checkpoint
Low_T_correction = 0                         # constant added to the RBM free energy

In [None]:

torch.set_printoptions(precision=10)

# Enumerate all 2**n_vis visible configurations once.
# They do not depend on n_hid, so this is hoisted out of the n_hid loop.
v_list_ising = []   # +/-1 configurations as CPU tensors, for Ising_energy
v_list_ising2 = []  # the same configurations, stacked on `device` for the RBM
for s in range(2**n_vis):
    v = decimal_to_binary(s, n_vis)[0]
    v_list_ising.append(v.reshape(n_vis)*2-1)
    v_list_ising2.append(v)
v_list_ising2 = torch.stack(v_list_ising2).to(device)*2-1

n_models = 10  # independent trainings per (n_hid, T, vol) setting
# Remember the configured correction so a T='zero' run cannot leak its value into later runs.
base_low_T_correction = Low_T_correction

for n_hid in [1,2,3,4,6,8,12,16]:
    for T in [16,'inf']:
        # Target distribution P(v) over all configurations and its entropy S.
        if T in ('inf', 'zero'):
            bf_list = np.ones(2**n_vis)
        else:
            bf_list = np.exp(-Ising_energy(v_list_ising)/T)
        S = entropy(bf_list)
        Pv = torch.tensor(bf_list/bf_list.sum()).to(device).double()
        Low_T_correction = 100 if T == 'zero' else base_low_T_correction

        for vol in [512]:
            batch_size = vol//2
            dict_model = {}
            dict_GE = {}
            if T == 'zero':
                trainset = [torch.tensor([[[-1.]*n_vis],[[1.]*n_vis]]*(vol//2))]*n_models
            elif T == 'inf':
                # Alternative subsampling: [v_list_ising2[::int(2**n_vis/vol)]]*n_models
                trainset = [v_list_ising2]*n_models
            else:
                # NOTE: pickle-based load; only use with trusted local files.
                trainset = torch.tensor(pd.read_pickle('{base}/loss_IG/9*1/9*1_full_T={T}.pkl'.format(base=base, T=T)))*2-1

            train_loader_list = [data_to_loader(trainset[m][:vol]) for m in range(n_models)]

            for m in range(n_models):
                model0, loss = train_and_get_data(n_hid, 0, lr=lr, train_loader=train_loader_list[m])
                dict_model[str(m)] = model0
                dict_GE[str(m)] = loss
            with open('{base}/loss_IG/9*1/state_dict/model_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}_sym.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std), 'wb') as f:
                pkl.dump(dict_model, f)
            with open('{base}/loss_IG/9*1/loss/GE_n_hid={n_hid}_T={T}_lr={lr}_vol={vol}_std={std}_sym.pkl'.format(base=base, n_hid=n_hid, T=T, lr=lr, vol=vol, std=std), 'wb') as f:
                pkl.dump(dict_GE, f)
