In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from torchvision import datasets, transforms
import pandas as pd
import seaborn as sns

The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.


0it [00:00, ?it/s]

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [3]:
class RBM(nn.Module):
    def __init__(self, n_visible, n_hidden, n_cond=0, conditional=False):
        super(RBM, self).__init__()
        self.n_visible = n_visible
        self.n_hidden = n_hidden
        self.n_cond = n_cond
        self.conditional = conditional

        self.W = nn.Parameter(torch.randn(n_visible, n_hidden).to(device) * 0.01)
        self.h_bias = nn.Parameter(torch.zeros(n_hidden).to(device))
        self.v_bias = nn.Parameter(torch.zeros(n_visible).to(device))

        if self.conditional:
            self.U = nn.Parameter(torch.randn(n_cond, n_hidden).to(device) * 0.01)

    def sample_h(self, v, c=None):
        if self.conditional:
            activation = torch.matmul(v, self.W) + self.h_bias + torch.matmul(c, self.U)
        else:
            activation = torch.matmul(v, self.W) + self.h_bias
        p_h = torch.sigmoid(activation)
        return p_h, torch.bernoulli(p_h)

    def sample_v(self, h):
        activation = torch.matmul(h, self.W.t()) + self.v_bias
        p_v = torch.sigmoid(activation)
        return p_v, torch.bernoulli(p_v)

    def contrastive_divergence(self, v0, c0=None, k=1):
        p_h0, h0 = self.sample_h(v0, c0) if self.conditional else self.sample_h(v0)
        v_k = v0
        for _ in range(k):
            p_h, h_k = self.sample_h(v_k, c0) if self.conditional else self.sample_h(v_k)
            p_v, v_k = self.sample_v(h_k)
        p_hk, h_k = self.sample_h(v_k, c0) if self.conditional else self.sample_h(v_k)

        positive_grad = torch.matmul(v0.t(), p_h0)
        negative_grad = torch.matmul(v_k.t(), p_hk)

        if self.conditional:
            self.U.grad = -(torch.matmul(c0.t(), (p_h0 - p_hk))) / v0.size(0)

        self.W.grad = -(positive_grad - negative_grad) / v0.size(0)
        self.v_bias.grad = -(torch.sum(v0 - v_k, dim=0)) / v0.size(0)
        self.h_bias.grad = -(torch.sum(p_h0 - p_hk, dim=0)) / v0.size(0)

    def generate(self, c=None, n_samples=10, gibbs_steps=1000):
        samples = []
        v = torch.bernoulli(torch.rand(n_samples, self.n_visible).to(device))
        for _ in range(gibbs_steps):
            p_h, h = self.sample_h(v, c) if self.conditional else self.sample_h(v)
            p_v, v = self.sample_v(h)
        samples.append(v.detach().cpu())
        return torch.cat(samples)

In [4]:
def load_and_prepare_tabular_data(file_path, target_col):
    df = pd.read_csv(file_path)
    X = df.drop(target_col, axis=1).values
    y = df[target_col].values.reshape(-1, 1)
    scaler = StandardScaler()
    X = scaler.fit_transform(X)
    encoder = OneHotEncoder(sparse=False)
    y_encoded = encoder.fit_transform(y)
    return torch.tensor(X, dtype=torch.float32).to(device), torch.tensor(y_encoded, dtype=torch.float32).to(device)

In [5]:
def visualize_distributions(real_data, synthetic_data):
    plt.figure(figsize=(10, 5))
    for i in range(min(real_data.shape[1], 5)):
        sns.kdeplot(real_data[:, i], color="blue", label="Real" if i == 0 else "")
        sns.kdeplot(synthetic_data[:, i], color="red", label="Synthetic" if i == 0 else "")
    plt.legend()
    plt.title("Distribution comparison between Real and Synthetic data")
    plt.show()

In [6]:
def train_pipeline(dataset='mnist', custom_path=None, target_col=None, conditional=False, epochs=5):
    if dataset == 'mnist':
        transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.view(-1))])
        data = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
        data_loader = torch.utils.data.DataLoader(data, batch_size=64, shuffle=True)
        n_visible, n_hidden, n_cond = 784, 256, 10
    else:
        X, y = load_and_prepare_tabular_data(custom_path, target_col)
        n_visible, n_hidden, n_cond = X.shape[1], 128, y.shape[1]
        dataset = torch.utils.data.TensorDataset(X, y)
        data_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

    rbm = RBM(n_visible, n_hidden, n_cond, conditional=conditional).to(device)
    optimizer = torch.optim.SGD(rbm.parameters(), lr=0.05)

    for epoch in range(epochs):
        for batch in data_loader:
            if dataset == 'mnist':
                images, labels = batch
                images, labels = images.to(device), labels.to(device)
                c = F.one_hot(labels, num_classes=10).float() if conditional else None
                v0 = images
            else:
                v0, c = batch

            optimizer.zero_grad()
            rbm.contrastive_divergence(v0, c, k=1)
            optimizer.step()
        print(f'Epoch {epoch + 1}/{epochs} completed.')

    print("Training completed. Generating synthetic data...")

    c_gen = torch.eye(n_cond).to(device) if conditional else None
    synthetic_data = rbm.generate(c_gen, n_samples=1000, gibbs_steps=500).cpu().numpy()

    if dataset == 'mnist':
        fig, axes = plt.subplots(1, 10, figsize=(15, 2))
        synthetic_data_reshaped = synthetic_data.reshape(-1, 28, 28)
        for i in range(10):
            axes[i].imshow(synthetic_data_reshaped[i], cmap='gray')
            axes[i].axis('off')
        plt.show()
    else:
        visualize_distributions(X.cpu().numpy(), synthetic_data)

In [7]:
train_pipeline(dataset='mnist', conditional=True, epochs=5)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data\MNIST\raw\train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:41<00:00, 238821.34it/s]


Extracting ./data\MNIST\raw\train-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data\MNIST\raw\train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 98853.60it/s]


Extracting ./data\MNIST\raw\train-labels-idx1-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data\MNIST\raw\t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:09<00:00, 175393.15it/s]


Extracting ./data\MNIST\raw\t10k-images-idx3-ubyte.gz to ./data\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<?, ?it/s]


Extracting ./data\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./data\MNIST\raw

Epoch 1/5 completed.
Epoch 2/5 completed.
Epoch 3/5 completed.
Epoch 4/5 completed.
Epoch 5/5 completed.
Training completed. Generating synthetic data...


RuntimeError: The size of tensor a (1000) must match the size of tensor b (10) at non-singleton dimension 0