In [1]:
import numpy as np
import torch
import torchvision.datasets
import torchvision.models
import torchvision.transforms
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.utils.data
import torch.nn as nn
from tqdm import tqdm, tnrange
import warnings
warnings.filterwarnings("ignore")

import pickle as pkl
import pandas as pd


In [2]:
# Select the compute device once; later cells read the global `device`.
CUDA_DEVICE = 0  # presumably a GPU index; not referenced in the cells shown here
CUDA = torch.cuda.is_available()
device = 'cuda' if CUDA else 'cpu'

# Bare last expression so the notebook displays whether a GPU was detected.
torch.cuda.is_available()

False

# Define the RBM (Restricted Boltzmann Machine) model

In [3]:
class RBM(nn.Module):
    """Bernoulli-Bernoulli Restricted Boltzmann Machine with a k-step Gibbs sampler.

    Parameters are created on the module-level ``device``.

    Attributes:
        v_bias: Visible bias, shape ``(1, n_vis)``, initialised to ones.
        h_bias: Hidden bias, shape ``(1, n_hid)``, initialised to zeros.
        Weight: Coupling matrix, shape ``(n_hid, n_vis)``, standard-normal init.
        k: Number of Gibbs sampling steps performed by ``forward``.
    """

    def __init__(self, n_vis, n_hid, k):
        """Create a RBM.

        Args:
            n_vis: Number of visible units.
            n_hid: Number of hidden units.
            k: Number of Gibbs sampling steps used in ``forward``.
        """
        super(RBM, self).__init__()

        self.v_bias = nn.Parameter(torch.ones(1, n_vis).to(device))
        self.h_bias = nn.Parameter(torch.zeros(1, n_hid).to(device))
        self.Weight = nn.Parameter(torch.randn(n_hid, n_vis).to(device))
        self.k = k

    def v2h(self, v):
        """Return the hidden activation probabilities ``sigmoid(v W^T + h_bias)``."""
        return torch.sigmoid(F.linear(v, self.Weight, self.h_bias))

    def h2v(self, h):
        """Return the visible activation probabilities ``sigmoid(h W + v_bias)``."""
        return torch.sigmoid(F.linear(h, self.Weight.t(), self.v_bias))

    def Fv(self, v):
        """Return the batch-mean free energy of the visible configurations.

        Per sample: ``F(v) = -v . v_bias - sum_j softplus(W_j . v + h_bias_j)``.

        Args:
            v: Visible configurations of shape ``(batch, n_vis)``.

        Returns:
            A scalar tensor: the mean of the per-sample free energies.
        """
        v_term = torch.matmul(v, self.v_bias.t()).view(len(v))
        h_term = torch.sum(F.softplus(F.linear(v, self.Weight, self.h_bias)), dim=1)
        return torch.mean(-h_term - v_term)

    def energy(self, v, h):
        """Return the joint energy ``E(v, h)`` of each (visible, hidden) pair.

        Both inputs are first binarised by Bernoulli sampling, so they may be
        given as probabilities; exact 0/1 inputs pass through unchanged.

        Args:
            v: Visible states/probabilities, shape ``(batch, n_vis)`` or ``(n_vis,)``.
            h: Hidden states/probabilities, shape ``(batch, n_hid)`` or ``(n_hid,)``.

        Returns:
            Tensor of shape ``(batch, 1)`` (``(1,)`` for 1-D inputs) holding
            ``-v . v_bias - h^T W v - h . h_bias`` for each sample.
        """
        v = v.bernoulli()
        h = h.bernoulli()
        visible_term = torch.matmul(v, self.v_bias.t())
        hidden_term = torch.matmul(h, self.h_bias.t())
        # Pair sample i of v with sample i of h only. A plain matmul against
        # h.t() would produce a (batch, batch) matrix of cross-sample terms.
        interaction = torch.sum(F.linear(v, self.Weight) * h, dim=-1, keepdim=True)
        return -visible_term - interaction - hidden_term

    def forward(self, v):
        """Run ``k`` steps of block Gibbs sampling starting from ``v``.

        Args:
            v: Visible configurations of shape ``(batch, n_vis)``.

        Returns:
            ``(v, v_gibbs)``: the input and the binary visible sample reached
            after ``k`` Gibbs steps. With ``k == 0`` the chain never leaves the
            data, so ``v_gibbs`` is ``v`` itself.
        """
        h = self.v2h(v).bernoulli()
        # Start the chain at the data so the return value is defined for k == 0.
        v_gibbs = v
        for _ in range(self.k):
            v_gibbs = self.h2v(h).bernoulli()
            h = self.v2h(v_gibbs).bernoulli()
        return v, v_gibbs

# Build the training data loader

In [4]:
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    """Map-style dataset over unlabeled samples.

    Args:
        dataset: Indexable collection of samples (e.g. a 2-D tensor, array or
            list of rows). It must support ``len()`` and integer indexing.
    """

    def __init__(self, dataset):
        self.x_data = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a float32 tensor."""
        # torch.as_tensor converts from any numeric dtype. The legacy
        # torch.FloatTensor constructor rejects non-float32 tensors and
        # interprets a bare int as a size rather than a value.
        return torch.as_tensor(self.x_data[idx], dtype=torch.float32)

def data_to_loader(fullconfigs):
    """Wrap ``fullconfigs`` in a ``CustomDataset`` and return a ``DataLoader`` over it.

    The batch size is read from the module-level ``batch_size`` at call time;
    no other ``DataLoader`` options are set.
    """
    return torch.utils.data.DataLoader(CustomDataset(fullconfigs), batch_size=batch_size)



In [5]:
def train_and_get_data(n_hid, model, lr, train_loader):
    """Train an RBM with contrastive divergence and return its parameters.

    Reads the notebook-level settings ``n_vis``, ``k``, ``n_epochs`` and
    ``device``.

    Args:
        n_hid: Number of hidden units for a freshly created RBM; not used when
            an existing model is supplied.
        model: ``0`` (or ``None``) to train a new RBM from scratch, otherwise
            an already-built RBM whose training is continued.
        lr: Learning rate for SGD (momentum 0.9).
        train_loader: Iterable yielding batches of visible configurations.

    Returns:
        The ``state_dict`` of the trained RBM, taken after the model has been
        moved to the CPU.
    """
    if model is None or model == 0:
        # Train a new model from scratch.
        rbm = RBM(n_vis, n_hid, k)
    else:
        # Continue training an existing model. Move it to the device the
        # batches are sent to, since a previously returned model sits on CPU.
        rbm = model.to(device)

    train_op = optim.SGD(rbm.parameters(), lr, momentum=0.9)
    rbm.train()
    for _ in range(n_epochs):
        for data in train_loader:
            data = data.to(device)
            v, v_gibbs = rbm(data.view(-1, n_vis))
            # Contrastive-divergence loss: free-energy gap between the data
            # and the sample reached by the Gibbs chain.
            train_loss = rbm.Fv(v) - rbm.Fv(v_gibbs)
            train_op.zero_grad()
            train_loss.backward()
            train_op.step()

    rbm = rbm.cpu()
    return rbm.state_dict()


In [6]:
# Hyperparameters shared by the cells below.
n_vis = 9         # number of visible units
n_hid = 2         # number of hidden units
k = 1             # Gibbs sampling steps per update
n_epochs = 300    # passes over the training data
batch_size = 512  # samples per mini-batch
lr = 0.03         # SGD learning rate


In [7]:
# Seed PyTorch's RNG so the random weight initialisation and the Bernoulli
# sampling are repeatable every time this cell is run.
torch.manual_seed(0)

# Example training data: two binary configurations of 9 visible units.
fullconfigs = torch.tensor([
    [0., 0., 0., 0., 0., 0., 0., 0., 0.],
    [0., 0., 0., 0., 0., 0., 1., 0., 1.],
])
# Wrap the training data in a data loader.
train_loader = data_to_loader(fullconfigs)
# Train a fresh RBM (model=0); the learned parameters are stored in `result`.
result = train_and_get_data(n_hid, 0, lr, train_loader)


In [8]:
# Display the trained parameters: the state_dict returned by train_and_get_data.
result

OrderedDict([('v_bias',
              tensor([[-3.4698, -2.9000, -3.0499, -3.8000, -3.2000, -2.8743, -0.1600, -2.7500,
                        0.2785]])),
             ('h_bias', tensor([[3.6490, 2.6603]])),
             ('Weight',
              tensor([[ 0.1134, -0.6493, -1.0528, -1.7047, -1.3565, -1.9235, -4.3780, -0.4043,
                       -4.2897],
                      [-3.0264, -2.1430, -2.3730, -0.9813, -1.6098, -2.1605,  1.9023, -2.3826,
                        1.3500]]))])