In [5]:
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.neural_network import BernoulliRBM
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load Breast Cancer Dataset
data = load_breast_cancer()
X, y = data.data, data.target

# Normalize the dataset
X = X / np.max(X, axis=0)

# Train-Test Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Train RBM for feature extraction
rbm = BernoulliRBM(n_components=10, learning_rate=0.1, n_iter=10, random_state=42)
rbm.fit(X_train)

# Transform data using trained RBM
X_train_transformed = rbm.transform(X_train)
X_test_transformed = rbm.transform(X_test)

# Train a classifier on extracted features
classifier = LogisticRegression(max_iter=1000, solver='lbfgs', random_state=42)
classifier.fit(X_train_transformed, y_train)

# Predict and evaluate
y_pred = classifier.predict(X_test_transformed)
accuracy = accuracy_score(y_test, y_pred)

# Print sample results
print("Original Sample Input (First 3 X_test Samples):")
print(X_test[:3])  # Show first 3 samples

print("\nExtracted Features from RBM (First 3 Samples):")
print(X_test_transformed[:3])  # Show transformed features

print("\nPredicted Labels (First 3 Samples):", y_pred[:3])
print("True Labels (First 3 Samples):", y_test[:3])

print(f"\nClassification Accuracy: {accuracy:.4f}")


Original Sample Input (First 3 X_test Samples):
[[0.44361437 0.47352342 0.43018568 0.19268293 0.60985312 0.30631152
  0.18755858 0.18991054 0.63322368 0.65404351 0.13786982 0.21371546
  0.11360328 0.05586499 0.22335368 0.14113737 0.06820707 0.19643872
  0.22571248 0.12017426 0.41537181 0.49737586 0.38236465 0.1593559
  0.64061096 0.22476371 0.21333866 0.34879725 0.45405243 0.42168675]
 [0.67378157 0.54251527 0.65570292 0.45181927 0.55134639 0.29791546
  0.25304592 0.39517893 0.52039474 0.56044745 0.27455621 0.16325486
  0.24959054 0.17714865 0.14275618 0.12200886 0.05729798 0.25951885
  0.17555415 0.05690349 0.68978912 0.53653613 0.66042994 0.43864598
  0.5359389  0.22079395 0.21461661 0.61477663 0.3843025  0.31754217]
 [0.54998221 0.49592668 0.53952255 0.29944022 0.66829865 0.35408222
  0.34348641 0.40193837 0.63519737 0.59482759 0.16508876 0.16088025
  0.14076433 0.08909996 0.20044973 0.10960118 0.07103535 0.20704679
  0.17694744 0.08247319 0.53440622 0.52482842 0.49721338 0.27174424