<a href="https://colab.research.google.com/github/rushilgowda/AGA-lab-USN-1BM22AI111/blob/main/Lab_6.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np


class RBM(nn.Module):
    """Restricted Boltzmann machine with sigmoid (Bernoulli) units, trained by CD-k.

    Args:
        n_visible: Number of visible units.
        n_hidden: Number of hidden units.
        k: Number of Gibbs sampling steps per contrastive-divergence update.
    """

    def __init__(self, n_visible, n_hidden, k=1):
        super(RBM, self).__init__()
        # Weight matrix is stored as (n_hidden, n_visible), the same layout
        # F.linear expects for a visible -> hidden projection.
        self.W = nn.Parameter(torch.randn(n_hidden, n_visible) * 0.01)
        self.h_bias = nn.Parameter(torch.zeros(n_hidden))
        self.v_bias = nn.Parameter(torch.zeros(n_visible))
        self.k = k

    def sample_h(self, v):
        """Return ``(p(h=1|v), Bernoulli sample of h)`` for visible batch ``v``."""
        h_prob = torch.sigmoid(F.linear(v, self.W, self.h_bias))
        return h_prob, torch.bernoulli(h_prob)

    def sample_v(self, h):
        """Return ``(p(v=1|h), Bernoulli sample of v)`` for hidden batch ``h``."""
        v_prob = torch.sigmoid(F.linear(h, self.W.t(), self.v_bias))
        return v_prob, torch.bernoulli(v_prob)

    def contrastive_divergence(self, v, lr=0.01):
        """Apply one in-place CD-k update to ``W``, ``v_bias`` and ``h_bias``.

        Args:
            v: Batch of visible vectors, one row per example.
            lr: Learning rate of the manual update.

        Raises:
            ValueError: If ``self.k`` is smaller than 1 (no negative phase
                could be computed).
        """
        if self.k < 1:
            raise ValueError(f"k must be >= 1 for contrastive divergence, got {self.k}")

        # The update is applied by hand, so no autograd graph is needed.
        with torch.no_grad():
            v0 = v
            h0_prob, h0 = self.sample_h(v0)

            # Run the Gibbs chain: every step continues from the hidden sample
            # produced by the previous step.  (Restarting each step from h0
            # would make k > 1 equivalent to k = 1.)
            h_sample = h0
            for _ in range(self.k):
                # Mean-field visible reconstruction feeds the next hidden step.
                v_prob, _ = self.sample_v(h_sample)
                h_prob, h_sample = self.sample_h(v_prob)

            batch_size = v0.size(0)
            # Positive statistics from the data, negative from the chain end.
            self.W += lr * (torch.matmul(h0.t(), v0) - torch.matmul(h_prob.t(), v_prob)) / batch_size
            self.v_bias += lr * torch.mean(v0 - v_prob, dim=0)
            self.h_bias += lr * torch.mean(h0 - h_prob, dim=0)


class DBN(nn.Module):
    """Container for two stacked RBMs plus a fully connected classifier.

    The module owns two ``RBM`` submodules (``rbm1``, ``rbm2``) and a
    three-layer feed-forward network ``fc`` with ReLU activations.
    ``forward`` evaluates ``fc`` only; the RBM submodules are not used by
    the forward pass.

    Args:
        input_size: Width of the input vectors.
        hidden_sizes: Sequence whose first two entries are the widths of
            the two hidden layers.
        output_size: Number of outputs produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_sizes, output_size):
        super().__init__()
        first_width = hidden_sizes[0]
        second_width = hidden_sizes[1]

        self.rbm1 = RBM(input_size, first_width)
        self.rbm2 = RBM(first_width, second_width)

        # Linear layers sit at indices 0, 2 and 4 of the Sequential.
        layers = [
            nn.Linear(input_size, first_width),
            nn.ReLU(),
            nn.Linear(first_width, second_width),
            nn.ReLU(),
            nn.Linear(second_width, output_size),
        ]
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        """Return the output of the classifier network ``fc`` for ``x``."""
        logits = self.fc(x)
        return logits


# Preprocessing: grayscale -> tensor -> flat vector.
# ToTensor yields intensities in [0, 1]; they are deliberately NOT re-centred
# to [-1, 1], because the RBMs below use sigmoid (Bernoulli) visible units
# whose reconstructions lie in (0, 1) and cannot represent negative inputs.
transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
    lambda x: x.view(-1)
])

# CIFAR-10 train/test splits; downloaded into ./data on first use.
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(test_data, batch_size=1000)


# Model dimensions: flattened 32x32 grayscale image -> two hidden layers -> 10 outputs.
input_size = 32 * 32
hidden_sizes = [512, 256]
output_size = 10

rbm1 = RBM(input_size, hidden_sizes[0])
rbm2 = RBM(hidden_sizes[0], hidden_sizes[1])

# Greedy layer-wise pretraining, stage 1: train the first RBM on the raw inputs.
print("Pretraining RBM 1...")
for epoch in range(5):
    for x, _ in train_loader:
        rbm1.contrastive_divergence(x)

# Stage 2: push the training set through the first RBM and train the second
# RBM on the resulting hidden activation probabilities.
print("Pretraining RBM 2...")
with torch.no_grad():
    h1_batches = []
    label_batches = []
    for x, y in train_loader:
        # sample_h returns (probabilities, sample); the probabilities are kept.
        h1, _ = rbm1.sample_h(x)
        h1_batches.append(h1)
        # train_loader shuffles, so labels must be collected from the same
        # batches; train_data.targets is in dataset order and would not line
        # up with the rows of h1_all.
        label_batches.append(y)
    h1_all = torch.cat(h1_batches)
    labels_tensor = torch.cat(label_batches)

h1_loader = DataLoader(TensorDataset(h1_all, labels_tensor), batch_size=64, shuffle=True)

for epoch in range(5):
    for h1, _ in h1_loader:
        rbm2.contrastive_divergence(h1)


dbn = DBN(input_size, hidden_sizes, output_size)

# Transfer the pretrained RBM parameters into the classifier's hidden layers.
# Without this step the network would start from random weights and the
# pretraining phase would have no effect on the result.  RBM.W is stored as
# (n_hidden, n_visible), the same layout as nn.Linear.weight, so it can be
# copied directly.  fc[0] and fc[2] are the first and second Linear layers.
# NOTE(review): the RBMs use sigmoid units while the classifier uses ReLU, so
# the copied weights serve only as an initialisation — confirm this is intended.
with torch.no_grad():
    dbn.fc[0].weight.copy_(rbm1.W)
    dbn.fc[0].bias.copy_(rbm1.h_bias)
    dbn.fc[2].weight.copy_(rbm2.W)
    dbn.fc[2].bias.copy_(rbm2.h_bias)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(dbn.parameters(), lr=0.001)

# Supervised fine-tuning of the whole classifier with cross-entropy loss.
print("Fine-tuning DBN...")
for epoch in range(5):
    dbn.train()
    total_loss = 0
    for x, y in train_loader:
        output = dbn(x)
        loss = criterion(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Report the mean per-batch loss for this epoch.
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader):.4f}")


# Evaluate the fine-tuned classifier: count top-1 matches over the test loader.
dbn.eval()
correct, total = 0, 0
with torch.no_grad():
    for x, y in test_loader:
        logits = dbn(x)
        # Predicted class is the index of the largest output per row.
        pred = logits.argmax(dim=1)
        matches = pred == y
        correct += matches.sum().item()
        total += y.size(0)

# Accuracy as a percentage of all evaluated examples.
accuracy = 100. * correct / total
print(f"\n✅ DBN Accuracy on CIFAR-10: {accuracy:.2f}%")


100%|██████████| 170M/170M [00:03<00:00, 45.2MB/s]


Pretraining RBM 1...
Pretraining RBM 2...
Fine-tuning DBN...
Epoch 1, Loss: 1.8508
Epoch 2, Loss: 1.6549
Epoch 3, Loss: 1.5462
Epoch 4, Loss: 1.4538
Epoch 5, Loss: 1.3676

✅ DBN Accuracy on CIFAR-10: 45.27%
