In [None]:
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn



In [None]:
import os, sys
# Directory that is appended to sys.path so the project-local `rbm` imports
# that follow in this cell can be resolved.
# The location can be overridden with the SATURN_SRC environment variable on
# machines where the checkout lives elsewhere; the original developer path is
# kept as the fallback so existing setups behave exactly as before.
project_root = os.path.abspath(
    os.environ.get('SATURN_SRC', '/Users/subhojit/workspace/saturn/src')
)
# Guard against duplicate entries when the cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

from rbm.rbm import RBM
from rbm.auto_encoder import DeepAutoEncoder
from rbm.auto_encoder_random_init_weight import DeepAutoEncoderWithRandomInitWeight

In [None]:

# Per-image preprocessing: convert the image to a tensor, then draw a 0/1
# sample for every pixel with torch.bernoulli (the pixel value is used as the
# probability of a 1), giving binary inputs.
#
# The transform classes are reached through `torchvision.transforms` rather
# than the `transforms` alias on purpose: this assignment rebinds the name
# `transforms` to the Compose object, so going through the alias would make
# the cell fail with AttributeError the second time it is executed.
# The variable name is kept because the dataset cell passes `transforms`.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Lambda(torch.bernoulli),
])

In [None]:
# MNIST training split stored under ./dataset (downloaded when requested via
# download=True); every sample goes through the `transforms` pipeline.
mnist = torchvision.datasets.MNIST(
    root='./dataset',
    train=True,
    download=True,
    transform=transforms,
)

In [None]:
# --- Hyper-parameters shared by the training cells --------------------------
n_visible = 784      # width each batch is flattened to (batch.view(-1, n_visible))
n_hidden = 256       # only referenced by the commented-out single-RBM experiment
epochs = 5           # passes over `loader` in the autoencoder training loops
# NOTE(review): 0.1 is passed straight to AdamW; the recorded output of the
# last cell shows the loss rising after epoch 1 — consider a smaller value.
learning_rate = 0.1

# Shuffled mini-batches of 64 images used by the autoencoder training loops.
loader = DataLoader(dataset=mnist, batch_size=64, shuffle=True)

In [None]:
# rbm = RBM(n_visible=n_visible, n_hidden=n_hidden)
# optimizer = torch.optim.AdamW(rbm.parameters(), lr=learning_rate)
#
# for epoch in range(epochs):
#     total_loss = 0.0
#
#     for batch, _ in loader:
#         v = batch.view(-1, n_visible)
#         optimizer.zero_grad()
#         loss = rbm.cd_step(v)
#         optimizer.step()
#         total_loss += loss
#
#     print(f"Epoch {epoch+1}: Loss = {total_loss:.4f}")

In [6]:
# pretrain RBM
def pretrain_rbm(rbm, data_loader, epoch=5, lr=0.1):
    """Pretrain a single RBM for ``epoch`` passes over ``data_loader``.

    Args:
        rbm: object exposing ``W`` (``W.shape[0]`` is the width each batch is
            flattened to), ``parameters()`` and ``cd_step(v)``.
        data_loader: iterable of ``(batch, label)`` pairs; the label is ignored.
        epoch: number of passes over ``data_loader``.
        lr: learning rate handed to ``torch.optim.AdamW``.

    Returns:
        None. The summed per-batch loss is printed after every pass.
    """
    optimizer = torch.optim.AdamW(rbm.parameters(), lr=lr)
    visible_dim = rbm.W.shape[0]

    # Iterate over the `epoch` argument. The previous version read the
    # notebook-level global `epochs`, so the argument was silently ignored
    # (and then shadowed by the loop variable).
    for epoch_idx in range(epoch):
        total_loss = 0.0
        for batch, _ in data_loader:
            v = batch.view(-1, visible_dim)
            optimizer.zero_grad()
            # NOTE(review): no loss.backward() is called here, so cd_step is
            # presumably responsible for populating the gradients itself —
            # confirm against rbm.rbm.RBM.cd_step.
            loss = rbm.cd_step(v)
            optimizer.step()
            total_loss += loss.item()

        print(f"RBM {rbm.W.shape[0]}->{rbm.W.shape[1]} Epoch {epoch_idx+1}, Loss: {total_loss:.4f}")


def transform_data(data_loader, rbm_stack):
    """Propagate every batch through ``rbm_stack`` and concatenate the results.

    Args:
        data_loader: iterable of ``(batch, label)`` pairs; the label is ignored.
        rbm_stack: sequence of callables exposing ``W``; they are applied in
            order, each one receiving the previous one's output reshaped to
            ``(-1, rbm.W.shape[0])``. With an empty stack the batches are
            concatenated unchanged.

    Returns:
        One tensor with the outputs of all batches concatenated along dim 0,
        in the order the loader yielded them.
    """
    outputs = []
    with torch.no_grad():
        for batch, _ in data_loader:
            x = batch
            for rbm in rbm_stack:
                # Chain the layers: feed the previous layer's output, not the
                # raw batch (the earlier code re-read `batch` for every layer,
                # which breaks any stack deeper than one).
                x = rbm(x.view(-1, rbm.W.shape[0]))
            outputs.append(x)
    return torch.cat(outputs, dim=0)

In [None]:
# Greedy layer-wise pretraining of the stack 784 -> 500 -> 250 -> 125 -> 2.
# Each RBM is trained on the features produced by the RBM below it.
train_loader = DataLoader(mnist, batch_size=64, shuffle=True)

# Layer 1: trained directly on the (binarised) images.
rbm1 = RBM(784, 500)
pretrain_rbm(rbm1, train_loader)
h1 = transform_data(train_loader, [rbm1])
# The constant 0 is a dummy label: pretrain_rbm / transform_data unpack
# every item of a loader as (batch, label).
loader_h1 = DataLoader([(x, 0) for x in h1], batch_size=64, shuffle=True)

# Layer 2: trained on layer-1 features.
rbm2 = RBM(500, 250)
pretrain_rbm(rbm2, loader_h1)
h2 = transform_data(loader_h1, [rbm2])
loader_h2 = DataLoader([(x, 0) for x in h2], batch_size=64, shuffle=True)

# Layer 3: trained on layer-2 features.
rbm3 = RBM(250, 125)
pretrain_rbm(rbm3, loader_h2)
h3 = transform_data(loader_h2, [rbm3])
loader_h3 = DataLoader([(x, 0) for x in h3], batch_size=64, shuffle=True)

# Layer 4: trained on layer-3 features. This call used to pass rbm2, which
# re-trained the 500->250 layer on 125-wide data (the view(-1, 500) reshape
# happens to succeed, so nothing crashed) and left rbm4 at its initial weights.
rbm4 = RBM(125, 2)
pretrain_rbm(rbm4, loader_h3)

In [None]:
# Build the autoencoder from the four RBMs and train it end-to-end with
# back-propagation on a binary cross-entropy reconstruction loss.
model = DeepAutoEncoder(rbm1, rbm2, rbm3, rbm4)
loss_fn = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    total_loss = 0.0  # sum of the per-batch losses for this epoch
    for images, _labels in loader:
        # Flatten every image to a vector of n_visible values; the input is
        # also the reconstruction target.
        flat_images = images.view(-1, n_visible)

        optimizer.zero_grad()
        reconstruction = model(flat_images)
        batch_loss = loss_fn(reconstruction, flat_images)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
    print(f"Fine-tuning Epoch {epoch+1}, Loss: {total_loss:.4f}")



In [8]:
# Baseline: the autoencoder variant with randomly initialised weights, trained
# with plain back-propagation and no RBM pretraining.
model = DeepAutoEncoderWithRandomInitWeight()
loss_fn = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for epoch in range(epochs):
    total_loss = 0.0  # sum of the per-batch losses for this epoch
    for images, _labels in loader:
        # Flattened images serve as both the input and the target.
        x = images.view(-1, n_visible)

        optimizer.zero_grad()
        recon = model(x)
        batch_loss = loss_fn(recon, x)
        batch_loss.backward()
        optimizer.step()

        total_loss += batch_loss.item()
    print(f"Fine-tuning Epoch {epoch + 1}, Loss: {total_loss:.4f}")



Fine-tuning Epoch 1, Loss: 3114.6280
Fine-tuning Epoch 2, Loss: 6003.9148
Fine-tuning Epoch 3, Loss: 5791.3560
Fine-tuning Epoch 4, Loss: 4815.9262
Fine-tuning Epoch 5, Loss: 3115.8646
