<a href="https://colab.research.google.com/github/physicaone/loss_IG/blob/master/%5B210516%5DTrain_and_get_data4.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
# Resolve the data root: mount Google Drive when running on Colab; otherwise fall
# back to a relative "Google Drive" directory (presumably a locally synced copy).
try:
    from google.colab import drive
except ImportError:
    # Not running on Colab. Only a missing module triggers the fallback; a failed
    # mount on Colab now raises instead of silently using the wrong path.
    base = 'Google Drive'
else:
    drive.mount('/content/drive')
    base = 'drive/MyDrive'

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [15]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchvision.utils import make_grid
import torch.utils.data
import torch.nn as nn
from datetime import datetime
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")

import random
import pickle as pkl
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline
# np.seterr(divide='ignore', invalid='ignore')
import itertools
import re
from time import sleep

In [16]:
class RBM(nn.Module):
    """Bernoulli-Bernoulli restricted Boltzmann machine sampled with k Gibbs steps.

    Args:
        n_vis: number of visible units.
        n_hid: number of hidden units.
        k: number of Gibbs steps taken by ``forward``. With k == 0 the
            negative sample is the input itself.
        use_cuda: if truthy, parameters are allocated on the default CUDA device.

    Attributes:
        v: visible bias, shape (1, n_vis), initialised to ones.
        h: hidden bias, shape (1, n_hid), initialised to zeros.
        W: weights, shape (n_hid, n_vis), initialised from a standard normal.
        k: number of Gibbs steps.
    """

    def __init__(self, n_vis, n_hid, k, use_cuda):
        """Create a RBM."""
        super(RBM, self).__init__()
        target = 'cuda' if use_cuda else 'cpu'
        # Tensors are created on the CPU and then moved, so the CPU RNG drives the
        # initial weights on either device (same as the previous `.cuda()` calls).
        self.v = nn.Parameter(torch.ones(1, n_vis).to(target))
        self.h = nn.Parameter(torch.zeros(1, n_hid).to(target))
        self.W = nn.Parameter(torch.randn(n_hid, n_vis).to(target))
        self.k = k

    def visible_to_hidden(self, v, beta):
        """Return P(h=1 | v) = sigmoid(beta * (v W^T + h_bias)).

        `beta` scales the pre-activation (presumably an inverse temperature).
        """
        return torch.sigmoid(F.linear(v, self.W, self.h) * beta)

    def hidden_to_visible(self, h, beta):
        """Return P(v=1 | h) = sigmoid(beta * (h W + v_bias))."""
        return torch.sigmoid(F.linear(h, self.W.t(), self.v) * beta)

    def free_energy(self, v):
        """Return the batch-mean free energy -v.a - sum_j softplus((v W^T + b)_j).

        Both terms are kept per-sample with shape (B,). Previously the visible
        term was (B, 1) and broadcast into a B x B matrix before averaging.
        """
        v_term = torch.matmul(v, self.v.t()).squeeze(1)
        w_x_h = F.linear(v, self.W, self.h)
        h_term = torch.sum(F.softplus(w_x_h), dim=1)
        return torch.mean(-h_term - v_term)

    def energy(self, v):
        """Sample binary v (from v as probabilities) and h ~ P(h|v), then return the joint energy.

        NOTE(review): for a batch of B > 1 rows the interaction term is (B, B), so the
        result is a (B, B) matrix whose diagonal holds the per-sample energies.
        """
        v = v.bernoulli()
        h = torch.sigmoid(F.linear(v, self.W, self.h))
        h = h.bernoulli()
        return -torch.matmul(v, self.v.t()) - torch.matmul(torch.matmul(v, self.W.t()), h.t()) - torch.matmul(h, self.h.t())

    def energy2(self, v, h):
        """Like `energy`, but binarise a caller-supplied `h` instead of sampling it from v.

        NOTE(review): same (B, B) result shape as `energy` for batches with B > 1.
        """
        v = v.bernoulli()
        h = h.bernoulli()
        return -torch.matmul(v, self.v.t()) - torch.matmul(torch.matmul(v, self.W.t()), h.t()) - torch.matmul(h, self.h.t())

    def forward(self, v):
        """Run k steps of block Gibbs sampling from `v`.

        Returns:
            (v, v_gibb): the input unchanged, and the binary visible sample after k
            steps (equal to `v` when k == 0).
        """
        h = self.visible_to_hidden(v, 1).bernoulli()
        v_gibb = v  # k == 0: no Gibbs step, the negative sample is the data itself
        for _ in range(self.k):
            # Samples stay on the parameters' device; no global `device` is needed.
            v_gibb = self.hidden_to_visible(h, 1).bernoulli()
            h = self.visible_to_hidden(v_gibb, 1).bernoulli()
        return v, v_gibb

In [17]:

def decimal_to_binary(integer, width=None):
    """Return the binary digits of a non-negative integer as a float tensor of shape (1, L).

    Args:
        integer: non-negative integer to encode, most significant bit first.
        width: minimum number of bits; shorter encodings are left-padded with 0.
            Longer encodings are not truncated. Defaults to the notebook-global
            `n_hid`, for backward compatibility with existing callers.

    Raises:
        ValueError: if `integer` is negative (its `bin()` form has a sign and
            cannot be parsed as bits).
    """
    if width is None:
        width = n_hid
    if integer < 0:
        raise ValueError("decimal_to_binary expects a non-negative integer, got {}".format(integer))
    bits = bin(integer)[2:].zfill(width)
    return torch.tensor([[float(b) for b in bits]])

def binary_to_decimal(list0):
    """Interpret `list0` as bits, most significant first, and return their integer value.

    Each element is multiplied by its power of two and accumulated, so float
    elements give a float result. An empty sequence returns 0.
    """
    total = 0
    power = 0
    while power < len(list0):
        # list0[-(power + 1)] walks from the least significant (last) element leftwards.
        total += list0[-(power + 1)] * 2 ** power
        power += 1
    return total


In [18]:
# #Create infinite and zero T
# n_hid=9
# fullconfigs=[]
# for i in range(512):
#     for j in range(10):
#         fullconfigs.append(decimal_to_binary(i).detach().numpy())
# random.shuffle(fullconfigs)
# with open('{base}/loss_IG/3*3/3*3_inf_full.pkl'.format(base=base), 'wb') as f:
#     pkl.dump(fullconfigs, f)


# fullconfigs=[]
# for i in range(int(2**16/2)-16):
#     fullconfigs.append(decimal_to_binary(0).detach().numpy())
#     fullconfigs.append(decimal_to_binary(2**16-1).detach().numpy())
# for i in range(9):
#     fullconfigs.append(decimal_to_binary(2**i).detach().numpy())
#     fullconfigs.append(decimal_to_binary(2**16-1-2**i).detach().numpy())
# subconfigs=[]
# for i in range(2984):
#     subconfigs.append(decimal_to_binary(0).detach().numpy())
#     subconfigs.append(decimal_to_binary(2**16-1).detach().numpy())
# for i in range(9):
#     subconfigs.append(decimal_to_binary(2**i).detach().numpy())
#     subconfigs.append(decimal_to_binary(2**16-1-2**i).detach().numpy())
# random.shuffle(fullconfigs)
# random.shuffle(subconfigs)
# with open('drive/MyDrive/loss_IG/3*3/3*3_zero_full.pkl', 'wb') as f:
#     pkl.dump(fullconfigs, f)
# with open('drive/MyDrive/loss_IG/3*3/3*3_zero_sub.pkl', 'wb') as f:
#     pkl.dump(subconfigs, f)

In [19]:
# Choose the compute device once; later cells read `CUDA` and `device`.
CUDA = torch.cuda.is_available()
CUDA_DEVICE = 0  # NOTE(review): not referenced in the visible cells
device = 'cuda' if CUDA else 'cpu'
CUDA

True

In [20]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over an in-memory sequence of samples (inputs only, no labels)."""

    def __init__(self, dataset):
        # Keep a reference to the caller's sequence; items are converted on access.
        self.x_data = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return sample `idx` converted to a float32 tensor."""
        sample = self.x_data[idx]
        return torch.FloatTensor(sample)

In [122]:
def train_and_get_data(n_hid, model, lr, train_loader):
    """Train an RBM for `n_epochs` epochs with Adam and return its state dict on the CPU.

    Reads the notebook-level settings `n_vis`, `k`, `n_epochs`, `device` and `CUDA`.

    Args:
        n_hid: number of hidden units, used only when a fresh model is created.
        model: an existing `RBM` (any `nn.Module`) to keep training, or `0`
            (any non-module value) to start from a freshly initialised
            `RBM(n_vis, n_hid, k)`.
        lr: Adam learning rate.
        train_loader: iterable of batches that reshape to (-1, n_vis).

    Returns:
        The trained model's `state_dict()`. The model itself, including a
        caller-supplied one, is left on the CPU.
    """
    if isinstance(model, nn.Module):
        rbm = model
    else:
        rbm = RBM(n_vis, n_hid, k, use_cuda=CUDA)
    # A caller-supplied model may live on a different device than the batches below.
    rbm = rbm.to(device)
    train_op = optim.Adam(rbm.parameters(), lr)
    rbm.train()
    for _epoch in range(n_epochs):
        for data in train_loader:
            batch = data.to(device).view(-1, n_vis)
            v, v_gibbs = rbm(batch)
            # CD-k surrogate: its gradient is the data term minus the model-sample term.
            # The sample comes from bernoulli() and contributes no gradient of its own.
            train_loss = rbm.free_energy(v) - rbm.free_energy(v_gibbs)
            train_op.zero_grad()
            train_loss.backward()
            train_op.step()
    rbm = rbm.cpu()
    return rbm.state_dict()

def CM_model(models, n_models=10):
    """Average the 'v', 'h' and 'W' entries of models '0' .. str(n_models - 1).

    Args:
        models: mapping from the string keys '0', '1', ... to state dicts that
            contain 'v', 'h' and 'W'.
        n_models: how many consecutive keys, starting at '0', to average.
            Defaults to 10, the previous hard-coded count.

    Returns:
        A new dict {'v', 'h', 'W'} holding the element-wise means. The inputs
        are not modified.

    Raises:
        ValueError: if `n_models` < 1.
    """
    if n_models < 1:
        raise ValueError("n_models must be at least 1, got {}".format(n_models))
    new_v_bias = 0; new_h_bias = 0; new_Weight = 0
    for i in range(n_models):
        params = models[str(i)]
        # Accumulate pre-divided terms (same arithmetic order as before); the first
        # `0 + tensor` creates a new tensor, so the inputs are never mutated.
        new_v_bias += params['v'] / n_models
        new_h_bias += params['h'] / n_models
        new_Weight += params['W'] / n_models
    return {'v': new_v_bias, 'h': new_h_bias, 'W': new_Weight}

def Rearrange_parameters(model0):
    """Reorder hidden units by the mean of their weight row, in ascending order.

    Args:
        model0: dict with 'W' (n_hid, n_vis), 'h' (row 0 holds the n_hid hidden
            biases) and 'v'.

    Returns:
        {'W': rows of W in sorted order,
         'v': model0['v'] (same object),
         'h': row 0 of h permuted the same way, shape (1, n_hid)}.
        Units with equal mean keep their original relative order. An empty W
        yields empty outputs.
    """
    W = model0['W']
    # Per-row means computed exactly as before (torch.mean on each row). Python's
    # stable sort keeps ties in row order without relying on tensors-as-dict-keys.
    row_means = [torch.mean(row).item() for row in W]
    order = sorted(range(len(row_means)), key=lambda row: row_means[row])
    index = torch.tensor(order, dtype=torch.long)
    sorted_weight = W.index_select(0, index)
    sorted_b = model0['h'][0].index_select(0, index).unsqueeze(0)

    return {'W': sorted_weight, 'v': model0['v'], 'h': sorted_b}

In [128]:
# Experiment settings shared by the training and post-processing cells.
n_vis = 3 * 3          # 9 visible units, matching the '3*3' data files
k = 5                  # Gibbs steps per contrastive-divergence update
n_epochs = 300
batch_size = 512
lr = 1e-3
T_list = [1.47, 1.78, 2.3, 5.2, 16]           # one '3*3_full_T={T}.pkl' data file per value
n_hid_list = [1, 2, 4, 8, 12, 16, 24, 32]     # hidden-layer sizes to train

In [None]:
def data_to_loader(fullconfigs):
    """Wrap `fullconfigs` in a DataLoader that yields batches of `batch_size` in stored order.

    The loader does not shuffle; the caller shuffles the list beforehand.
    """
    return torch.utils.data.DataLoader(CustomDataset(fullconfigs), batch_size)


for T in T_list:
    # NOTE(review): pd.read_pickle unpickles arbitrary objects; only load trusted files.
    # Presumably a sequence of 10 independent configuration lists for temperature T (TODO confirm).
    fullconfigs = pd.read_pickle('{base}/loss_IG/3*3/3*3_full_T={T}.pkl'.format(base=base, T=T))
    loader_list = []
    for i in range(10):
        # Shuffled in place once; every epoch then sees this same fixed order.
        random.shuffle(fullconfigs[i])
        loader_list.append(data_to_loader(fullconfigs[i]))

    for n_hid in n_hid_list:
        # dicts[str(m)][str(n)]: state dict of the n-th freshly initialised RBM trained on dataset m.
        dicts = {}
        for m in tnrange(10):
            dicts[str(m)] = {}
            for n in range(10):
                dicts[str(m)][str(n)] = train_and_get_data(n_hid, 0, lr, train_loader=loader_list[m])
        with open('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_mn.pkl'.format(base=base, n_hid=n_hid, T=T), 'wb') as f:
            pkl.dump(dicts, f)

        # Build the CM_m models: for each dataset m, the parameter average of its 10 runs.
        # Uses the in-memory `dicts` rather than re-reading the pickle just written.
        dict0 = {}
        for m in range(10):
            dict0[str(m)] = CM_model(dicts[str(m)])
        with open('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM_m.pkl'.format(base=base, n_hid=n_hid, T=T), 'wb') as f:
            pkl.dump(dict0, f)

        # Build the CM model: the parameter average of the 10 CM_m models.
        with open('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM.pkl'.format(base=base, n_hid=n_hid, T=T), 'wb') as f:
            pkl.dump(CM_model(dict0), f)

# Rearrange weights and hidden biases (sort hidden units by the mean of their weight row)

In [129]:
# For each (T, n_hid), reload the three saved model families, sort their hidden
# units with Rearrange_parameters, and write "_rearranged" copies.
# NOTE(review): this list includes 64, which n_hid_list (used for training) does not;
# the n_hid=64 files must already exist from an earlier run or the first read fails.
for T in T_list:
    for n_hid in [1, 2, 4, 8, 12, 16, 24, 32, 64]:
        model_dicts_mn = pd.read_pickle('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_mn.pkl'.format(base=base, n_hid=n_hid, T=T))
        model_dicts_CM_m = pd.read_pickle('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM_m.pkl'.format(base=base, n_hid=n_hid, T=T))
        model_dicts_CM = pd.read_pickle('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM.pkl'.format(base=base, n_hid=n_hid, T=T))

        dict_CM = Rearrange_parameters(model_dicts_CM)
        dict_CM_m = {}
        dict_mn = {}
        for m in range(10):
            dict_CM_m[str(m)] = Rearrange_parameters(model_dicts_CM_m[str(m)])
            dict_mn[str(m)] = {}
            for n in range(10):
                dict_mn[str(m)][str(n)] = Rearrange_parameters(model_dicts_mn[str(m)][str(n)])

        # Written in the same order as before: mn, CM_m, CM.
        outputs = (
            ('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_mn_rearranged.pkl', dict_mn),
            ('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM_m_rearranged.pkl', dict_CM_m),
            ('{base}/loss_IG/3*3/state_dict/n_hid={n_hid}_T={T}_CM_rearranged.pkl', dict_CM),
        )
        for path_template, payload in outputs:
            with open(path_template.format(base=base, n_hid=n_hid, T=T), 'wb') as f:
                pkl.dump(payload, f)