In [None]:
from matplotlib import pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from math import sqrt
import random
from tqdm import tqdm_notebook as tqdm

class RBM(nn.Module):
    """Bernoulli-Bernoulli restricted Boltzmann machine.

    Tensors use a column-per-sample layout: a batch of visible vectors is
    shaped ``(visible, batch)`` and a batch of hidden vectors
    ``(hidden, batch)``. 1-D tensors are treated as a single sample.

    Args:
        visible: number of visible units.
        hidden: number of hidden units.
    """

    # NOTE(review): not referenced anywhere in this class; `dream` uses its own
    # `iterations` default instead.
    DREAM_ITERATIONS = 1

    def __init__(self, visible, hidden):
        super().__init__()
        # Small symmetric uniform initialisation for the biases.
        self.visible_bias = nn.Parameter(torch.empty(visible).uniform_(-2 / visible, 2 / visible))
        self.hidden_bias = nn.Parameter(torch.empty(hidden).uniform_(-2 / hidden, 2 / hidden))
        self.visible_size = visible
        self.hidden_size = hidden
        r = min(1 / visible, 1 / hidden)
        # W maps visible -> hidden, so it is stored as (hidden, visible).
        self.W = nn.Parameter(torch.empty((hidden, visible)).uniform_(-r, r))

    @staticmethod
    def to_binary_sample(v):
        """Draw Bernoulli(sigmoid(v)) element-wise; returns a 0/1 float tensor.

        Raises:
            ValueError: if any probability is NaN. Bernoulli sampling cannot
                proceed then, and the error surfaces in the notebook instead
                of terminating the process.
        """
        p = torch.sigmoid(v)
        if torch.isnan(p).any():
            raise ValueError(f"to_binary_sample received NaN probabilities: {p}")
        return torch.bernoulli(p)

    @staticmethod
    def to_binary_optimal(v):
        """Deterministic binarisation: 1 where the pre-activation is > 0 (sigmoid > 0.5)."""
        return (v > 0).float()

    @staticmethod
    def linear(W, bias, x):
        """Return ``W @ x`` with ``bias`` added along the first (feature) axis.

        For 2-D ``x`` of shape ``(in, batch)`` the result is ``(out, batch)``.
        The transpose round-trip lets ``bias`` broadcast over the last axis.
        """
        return torch.t(torch.t(W @ x) + bias)

    def visible_to_hidden_sample(self, visible):
        """Sample hidden units given visible units."""
        return RBM.to_binary_sample(RBM.linear(self.W, self.hidden_bias, visible))

    def visible_to_hidden_optimal(self, visible):
        """Thresholded (most probable per unit) hidden units given visible units."""
        return RBM.to_binary_optimal(RBM.linear(self.W, self.hidden_bias, visible))

    def hidden_to_visible_sample(self, hidden):
        """Sample visible units given hidden units (uses W transposed)."""
        return RBM.to_binary_sample(RBM.linear(torch.t(self.W), self.visible_bias, hidden))

    def hidden_to_visible_optimal(self, hidden):
        """Thresholded visible units given hidden units (uses W transposed)."""
        return RBM.to_binary_optimal(RBM.linear(torch.t(self.W), self.visible_bias, hidden))

    def dream(self, start, iterations=2):
        """Run a random number (1..iterations) of block-Gibbs steps from ``start``.

        Returns:
            (v, h): the final visible and hidden samples.
        """
        v = start
        h = None
        iterations = random.randint(1, iterations)
        for _ in range(iterations):
            h = self.visible_to_hidden_sample(v)
            v = self.hidden_to_visible_sample(h)
        return (v, h)

    def forward(self, v, h):
        """Mean energy E(v, h) = -h^T W v - b_v^T v - b_h^T h over the batch.

        Args:
            v: visible states, ``(visible, batch)`` or ``(visible,)``.
            h: hidden states, ``(hidden, batch)`` or ``(hidden,)``.

        Returns:
            0-d tensor: the summed energy divided by the number of samples.
        """
        # The element-wise product avoids materialising the (batch, batch)
        # matrix that trace(h^T W v) would build.
        interaction = (h * (self.W @ v)).sum()
        visible_term = (self.visible_bias @ v).sum()
        hidden_term = (self.hidden_bias @ h).sum()
        # Samples are columns, so the batch size is the last dim (1 for a single
        # vector), not v.shape[0], which is the number of visible units.
        batch_size = v.shape[-1] if v.dim() > 1 else 1
        return -(interaction + visible_term + hidden_term) / batch_size

    def get_likelihood(self, v):
        """Mean energy of ``v`` paired with its thresholded hidden code.

        Despite the name this returns an energy (lower = more probable), not a
        likelihood.
        """
        h_opt = self.visible_to_hidden_optimal(v)
        return self.forward(v, h_opt)

In [None]:
# Smoke test: energy of a random binary batch (4 visible units, 3 hidden units, 2 samples as columns).
m = RBM(4, 3)
x = torch.empty(4, 2).random_(0, 2)
h = torch.empty(3, 2).random_(0, 2)
energy = m(x, h)
print(energy)

In [None]:
import torchvision
from torchvision import transforms

IMAGE_SIZE = (28, 28)
VISIBLE_SIZE = IMAGE_SIZE[0] * IMAGE_SIZE[1]  # one visible unit per pixel
HIDDEN_SIZE = 20

def transform(x):
    """Flatten an image tensor to VISIBLE_SIZE values and binarise it at 0.5, in place.

    Pixels strictly above 0.5 become 1 and the rest become 0. The returned
    tensor is a view sharing storage with ``x``.
    """
    flat = x.view(VISIBLE_SIZE)
    flat.masked_fill_(flat > 0.5, 1)
    flat.masked_fill_(flat <= 0.5, 0)
    return flat

# One preprocessing pipeline shared by both splits: image -> float tensor ->
# flattened, binarised vector (see `transform`).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transform,
])

ds_train = torchvision.datasets.MNIST(
    './files/', train=True, download=True, transform=mnist_transform
)
ds_test = torchvision.datasets.MNIST(
    './files/', train=False, download=True, transform=mnist_transform
)

In [None]:
batch_size_train = 128
batch_size_test = 128

# Both loaders reshuffle on every pass, so the "first batch" taken from
# test_loader by later cells differs each time.
train_loader = torch.utils.data.DataLoader(
    ds_train,
    batch_size=batch_size_train,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    ds_test,
    batch_size=batch_size_test,
    shuffle=True,
)

In [None]:
SEED = 0
# Seed every RNG the notebook draws from: torch (weight init, Bernoulli
# sampling, DataLoader shuffling) and `random` (dream length, plotted samples).
torch.manual_seed(SEED)
random.seed(SEED)

learning_rate = 1e-4

model = RBM(VISIBLE_SIZE, HIDDEN_SIZE)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# Multiply the learning rate by 0.1 every 2 scheduler steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

def test_loss():
    """Print and return the mean energy of one test batch under the current model."""
    z, _ = next(iter(test_loader))
    with torch.no_grad():  # evaluation only: no autograd graph needed
        energy = model.get_likelihood(z.t())  # RBM expects (VISIBLE_SIZE, batch)
    print(energy)
    return energy

def loss(positive_phase, negative_phase):
    """Contrastive-divergence objective: data energy minus model-sample energy."""
    return positive_phase - negative_phase

test_loss()
train_loss = []

In [None]:
EPOCHS = 1

for epoch in tqdm(range(EPOCHS)):
    for x, _ in tqdm(train_loader, leave=None):
        x = x.t()  # (VISIBLE_SIZE, batch): the RBM expects one sample per column

        # Gibbs sampling is not differentiated through; only the energies below are.
        with torch.no_grad():
            h_opt = model.visible_to_hidden_sample(x)
            x_bad, h_bad = model.dream(x)

        # Contrastive divergence: lower the energy of the data (positive phase)
        # and raise the energy of the model's own Gibbs sample (negative phase).
        # The negative phase must use the dreamed visibles x_bad; using x here
        # cancels the visible-bias gradient entirely.
        positive_phase = model(x, h_opt)
        negative_phase = model(x_bad, h_bad)

        l = loss(positive_phase, negative_phase)
        train_loss.append(l.item())

        optimizer.zero_grad()
        l.backward()
        optimizer.step()
    # Step once per epoch; passing the epoch index to step() is deprecated.
    scheduler.step()

In [None]:
# Smooth the noisy per-batch training loss with a moving average before plotting.
assert train_loss, "train_loss is empty: run the training cell first"
train_loss_tensor = torch.tensor(train_loss).view(1, 1, -1)
print(train_loss_tensor.shape)
KERNEL_SIZE = 20
# conv1d needs kernel <= sequence length, so shrink the window for short runs.
kernel_size = min(KERNEL_SIZE, train_loss_tensor.shape[-1])
conv_loss = torch.nn.functional.conv1d(train_loss_tensor, torch.ones(1, 1, kernel_size) / kernel_size)
conv_loss = conv_loss.view(-1)  # stays 1-D even when only one window fits

fig, ax = plt.subplots()
ax.plot(conv_loss.numpy())
ax.set(title=f"Training loss ({kernel_size}-batch moving average)", xlabel="batch", ylabel="loss");

In [None]:
# W is stored as (hidden units, visible units).
w_shape = model.W.shape
print(w_shape)

In [None]:
# Visible bias rendered as an image.
fig_bias, ax_bias = plt.subplots()
ax_bias.imshow(model.visible_bias.detach().view(*IMAGE_SIZE).numpy())
ax_bias.set_title("visible bias")

# One panel per hidden unit showing its row of W as an image. The grid is
# sized from model.hidden_size, so changing HIDDEN_SIZE does not break it.
N_COLS = 5
n_rows = (model.hidden_size + N_COLS - 1) // N_COLS
f, axarr = plt.subplots(n_rows, N_COLS, figsize=(10, 2.5 * n_rows), squeeze=False)
for unit, ax in enumerate(axarr.flat):
    if unit < model.hidden_size:
        ax.imshow(model.W[unit].detach().view(*IMAGE_SIZE).numpy())
        ax.set_title(f"hidden {unit}")
    else:
        ax.axis("off")  # padding panels in the last row

In [None]:
# Deterministic reconstruction of one test batch: v -> thresholded h -> v'.
# Top row: binarised input digits; bottom row: their reconstructions.
N_SHOW = 5
x, _ = next(iter(test_loader))
x = x.t()  # (VISIBLE_SIZE, batch)
print(x.shape)
with torch.no_grad():  # inspection only, no gradients needed
    h = model.visible_to_hidden_optimal(x)
    v = model.hidden_to_visible_optimal(h)
print(v.shape)

n_show = min(N_SHOW, v.shape[1])
f, axarr = plt.subplots(2, n_show, figsize=(10, 5), squeeze=False)
# Sample columns without replacement so no digit is shown twice.
for col, ind in enumerate(random.sample(range(v.shape[1]), n_show)):
    axarr[0, col].imshow(x[:, ind].reshape(*IMAGE_SIZE))
    axarr[1, col].imshow(v[:, ind].reshape(*IMAGE_SIZE))
axarr[0, 0].set_ylabel("input")
axarr[1, 0].set_ylabel("reconstruction");