In [1]:
!pip install -qq wandb
import wandb
import pickle
import numpy as np
from tqdm import tqdm
from sklearn.metrics import log_loss
from keras.datasets import fashion_mnist # Used only for loading the data
from sklearn.linear_model import LogisticRegression

[K     |████████████████████████████████| 1.8MB 5.0MB/s 
[K     |████████████████████████████████| 102kB 8.1MB/s 
[K     |████████████████████████████████| 133kB 25.8MB/s 
[K     |████████████████████████████████| 174kB 18.6MB/s 
[K     |████████████████████████████████| 71kB 7.2MB/s 
[?25h  Building wheel for subprocess32 (setup.py) ... [?25l[?25hdone
  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [2]:
# Enable autoreload *before* importing project modules, so that edits to
# RBM_model.py are picked up without restarting the kernel.
get_ipython().run_line_magic('load_ext', 'autoreload')
get_ipython().run_line_magic('autoreload', '2')

import os
from google.colab import drive

# Mount Google Drive and work from the assignment folder. RBM_model.py is
# imported from there, and relative output paths (e.g. the pickled models)
# resolve there as well.
DRIVE_MOUNT_POINT = "/content/drive/"
PROJECT_DIR = "/content/drive/MyDrive/DL_A4"
drive.mount(DRIVE_MOUNT_POINT)
os.chdir(PROJECT_DIR)

from RBM_model import RBM

Mounted at /content/drive/


In [3]:
###############################
# Preparing the data
###############################
# Subset sizes taken from the full Fashion-MNIST split (60000 train / 10000 test).
N_TRAIN = 10000
N_VAL = 2500
N_TEST = 2500
PIXEL_THRESHOLD = 127  # grey levels strictly above this become 1.0, the rest 0.0
NUM_PIXELS = 784       # 28 x 28 images, flattened; the RBM uses this many visible units


def binarize_images(images, threshold=PIXEL_THRESHOLD):
    """Flatten a stack of images to (N, pixels) and binarize it to {0.0, 1.0}.

    Args:
        images: array whose first axis indexes samples (Keras returns
            uint8 arrays of shape (N, 28, 28) for Fashion-MNIST).
        threshold: pixels with value > threshold map to 1.0, others to 0.0.

    Returns:
        float array of shape (N, prod(image_shape)).
    """
    flat = images.reshape(len(images), -1)
    return (flat > threshold).astype(float)


# Loading the Fashion-MNIST dataset (Keras' fixed train/test split)
(X_fashion_train, y_fashion_train), (X_fashion_test, y_fashion_test) = fashion_mnist.load_data()
# Using only a part of the training data, split into disjoint consecutive
# slices: the first N_TRAIN samples for training, the next N_VAL for validation.
X_train = binarize_images(X_fashion_train[:N_TRAIN])
y_train = y_fashion_train[:N_TRAIN]
X_val = binarize_images(X_fashion_train[N_TRAIN:N_TRAIN + N_VAL])
y_val = y_fashion_train[N_TRAIN:N_TRAIN + N_VAL]
# Using only a part of the test data
X_test = binarize_images(X_fashion_test[:N_TEST])
y_test = y_fashion_test[:N_TEST]

# Downstream cells hard-code 784 visible units; fail early if that ever drifts.
assert X_train.shape == (N_TRAIN, NUM_PIXELS)
assert X_val.shape == (N_VAL, NUM_PIXELS)
assert X_test.shape == (N_TEST, NUM_PIXELS)
# Checking the shapes
print(X_train.shape, y_train.shape, X_val.shape, y_val.shape, X_test.shape, y_test.shape)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
(10000, 784) (10000,) (2500, 784) (2500,) (2500, 784) (2500,)


In [4]:
# -------------------------------------------------------------------
# wandb sweep definition: exhaustive grid over the RBM hyperparameters.
# The sweep ranks runs by the "val_acc" value that each run logs.
# -------------------------------------------------------------------
parameters_dict = {
    'num_hidden_vars': {'values': [64, 128, 256]},          # n
    'num_steps_converge': {'values': [10]},                 # k
    'CD_etas': {'values': [0.001, 0.005, 0.01, 0.1]},       # eta
}
sweep_config = {
    'name': 'k_10_complete',
    'method': 'grid',
    'metric': {'name': 'val_acc', 'goal': 'maximize'},
    'parameters': parameters_dict,
}

In [5]:
import pickle
#####################################
# Defining the per-run sweep function (invoked by wandb.agent)
#####################################
def _evaluate_split(classifier, X_hidden, y_true):
    """Return (accuracy, log-loss) of a fitted classifier on one data split.

    Args:
        classifier: fitted scikit-learn classifier exposing ``predict``,
            ``predict_proba`` and ``classes_``.
        X_hidden: feature matrix for the split (here: RBM hidden samples).
        y_true: labels for the split.

    Returns:
        Tuple ``(accuracy, loss)`` of scalar floats.
    """
    # np.mean over a boolean mask yields a scalar; dividing by ``y.shape``
    # (a tuple) would instead produce a 1-element array.
    accuracy = np.mean(classifier.predict(X_hidden) == y_true)
    # ``labels`` keeps log_loss well-defined even if this split happens not
    # to contain every class the classifier was fitted on.
    loss = log_loss(y_true, classifier.predict_proba(X_hidden),
                    labels=classifier.classes_)
    return accuracy, loss


def RBM_wandb_logs(config=sweep_config, epochs=10):
    """Train one RBM inside a wandb run and log downstream-classifier metrics.

    The RBM is trained once with ``train_type="CD"`` on ``X_train``. Then, for
    each epoch's parameters recorded in ``model.param_hist``, the train, val
    and test images are mapped through ``model.sample_h``. A logistic-regression
    classifier is fitted on the training representation, and accuracy and
    log-loss for every split are logged to wandb, together with
    ``model.overall_loss`` for that epoch. At the end, the trained model is
    pickled to ``<run name>.pickle`` in the working directory.

    Uses the notebook globals ``X_train``, ``y_train``, ``X_val``, ``y_val``,
    ``X_test``, ``y_test`` and the ``RBM`` class.

    Args:
        config: configuration passed to ``wandb.init``. Under ``wandb.agent``
            the sweep supplies ``num_hidden_vars``, ``num_steps_converge`` and
            ``CD_etas``.
        epochs: number of RBM training epochs, which is also the number of
            evaluation checkpoints.
    """
    with wandb.init(config=config) as run:
        # Read the hyperparameters from this run rather than calling
        # wandb.init() again, which could finish or replace the sweep run.
        config = run.config
        run.name = 'nh_{}_k_{}_CD_{}'.format(config.num_hidden_vars,
                                             config.num_steps_converge,
                                             config.CD_etas)

        ###########################################
        # Training a classifier using RBM
        ###########################################
        num_visible_vars = 784  # flattened 28 x 28 binarized images

        model = RBM(num_visible=num_visible_vars, num_hidden=config.num_hidden_vars)
        model.train(input_data=X_train, train_type="CD", epochs=epochs,
                    k=config.num_steps_converge, eta=config.CD_etas)

        for epoch in tqdm(range(epochs)):
            # Training is already finished; replay the parameters recorded
            # after this epoch to get hidden representations of every split.
            # NOTE(review): assumes param_hist entries are indexed by epoch
            # and hold at least ``epochs`` items -- confirm in RBM_model.
            W = model.param_hist["W"][epoch]
            c = model.param_hist["c"][epoch]

            X_train_hidden = model.sample_h(W, c, X_train.T).T
            X_val_hidden = model.sample_h(W, c, X_val.T).T
            X_test_hidden = model.sample_h(W, c, X_test.T).T

            classifier = LogisticRegression(max_iter=500)
            classifier.fit(X_train_hidden, y_train)

            train_acc, train_loss = _evaluate_split(classifier, X_train_hidden, y_train)
            val_acc, val_loss = _evaluate_split(classifier, X_val_hidden, y_val)
            test_acc, test_loss = _evaluate_split(classifier, X_test_hidden, y_test)

            wandb.log({"train_acc": train_acc, "val_acc": val_acc,
                       "test_acc": test_acc, "train_loss": train_loss,
                       "val_loss": val_loss, "test_loss": test_loss,
                       "RBM_train_loss": model.overall_loss[epoch]})

        # Persist the trained model once, with the file closed deterministically.
        with open(run.name + ".pickle", "wb") as pickle_file:
            pickle.dump(model, pickle_file)

In [6]:
# ---------------------------------------------------------------
# Quick single-epoch RBM training run, without any wandb logging
# ---------------------------------------------------------------
num_hidden_vars = 256      # n
num_steps_converge = 30    # k, forwarded to RBM.train
CD_etas = 0.1              # eta, forwarded to RBM.train

model = RBM(num_visible=784, num_hidden=num_hidden_vars)
model.train(
    input_data=X_train,
    train_type="CD",
    epochs=1,
    k=num_steps_converge,
    eta=CD_etas,
)

In [None]:
# ----------------------------------------------------------------
# Launch the grid sweep: wandb.agent calls RBM_wandb_logs once for
# each hyperparameter combination listed in sweep_config.
# ----------------------------------------------------------------
project_name = 'DL-Assignment4-Q4'
sweep_id = wandb.sweep(sweep_config, project=project_name)
wandb.agent(sweep_id, function=RBM_wandb_logs)