In [1]:
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix


def load_data():
    """Fetch the OpenML ``mnist_784`` dataset (version 1) as arrays.

    Returns:
        A ``(X, y)`` pair: the dataset's ``data`` array and its ``target``
        array cast to ``np.int32``.
    """
    dataset = datasets.fetch_openml("mnist_784", version=1, as_frame=False)
    # The targets are cast so that downstream code can treat them as integers.
    labels = dataset.target.astype(np.int32)
    return dataset.data, labels


def split_data(X, y, test_size=0.2, random_state=42):
    """Split features and labels into train and test parts.

    Args:
        X: Feature array forwarded to ``train_test_split``.
        y: Label array forwarded to ``train_test_split``.
        test_size: Forwarded as ``train_test_split``'s ``test_size``.
        random_state: Forwarded as ``train_test_split``'s ``random_state``.

    Returns:
        Whatever ``train_test_split`` returns for two inputs; callers in this
        notebook unpack it as ``X_train, X_test, y_train, y_test``.
    """
    split_options = {"test_size": test_size, "random_state": random_state}
    return train_test_split(X, y, **split_options)


def avg_pooling(image, pool_size):
    """Downsample a 2-D image by averaging non-overlapping windows.

    Args:
        image: 2-D array to pool.
        pool_size: Indexable pair ``(window_height, window_width)``.

    Returns:
        A float array of shape
        ``(image.shape[0] // pool_size[0], image.shape[1] // pool_size[1])``
        where each entry is the mean of the corresponding window. Trailing
        rows/columns that do not fill a whole window are ignored.
    """
    win_h, win_w = pool_size[0], pool_size[1]
    out_rows = image.shape[0] // win_h
    out_cols = image.shape[1] // win_w
    result = np.zeros((out_rows, out_cols))

    # Visit every output cell in row-major order and average its source window.
    for r, c in np.ndindex(out_rows, out_cols):
        window = image[r * win_h : (r + 1) * win_h, c * win_w : (c + 1) * win_w]
        result[r, c] = np.mean(window)

    return result


def preprocess_data(X_train, X_test):
    """Standardise both splits using statistics from the training split only.

    Args:
        X_train: Training features; the scaler is fitted on these.
        X_test: Test features; transformed with the already-fitted scaler.

    Returns:
        ``(X_train_scaled, X_test_scaled)``.
    """
    standardiser = StandardScaler()
    scaled_train = standardiser.fit_transform(X_train)
    # The test split is transformed only, never fitted on.
    return scaled_train, standardiser.transform(X_test)

In [2]:
def train_model(X_train, y_train):
    """Fit a logistic-regression classifier on the given training data.

    Args:
        X_train: Training feature matrix passed to ``LogisticRegression.fit``.
        y_train: Training labels passed to ``LogisticRegression.fit``.

    Returns:
        The fitted ``LogisticRegression`` estimator.
    """
    # ``multi_class`` is deliberately left at its default: scikit-learn 1.5
    # deprecated the argument (passing it raises a FutureWarning), and with
    # the lbfgs solver the default already fits a multinomial model whenever
    # there are more than two classes.
    model = LogisticRegression(max_iter=1000, solver="lbfgs", random_state=42)
    model.fit(X_train, y_train)
    return model

In [3]:
def evaluate_model(pred_func, X_test, y_test):
    """Predict on ``X_test`` and print accuracy, per-class report and confusion matrix.

    Args:
        pred_func: Callable that maps ``X_test`` to predicted labels.
        X_test: Features handed to ``pred_func``.
        y_test: True labels the predictions are scored against.
    """
    predicted_labels = pred_func(X_test)
    print("Accuracy:", accuracy_score(y_test, predicted_labels))
    report = classification_report(y_test, predicted_labels)
    print("Classification Report:\n", report)
    matrix = confusion_matrix(y_test, predicted_labels)
    print("Confusion Matrix:\n", matrix)


def predict(model, new_data):
    """Return ``model.predict(new_data)``.

    Args:
        model: Any object exposing a ``predict`` method.
        new_data: Input forwarded unchanged to ``model.predict``.

    Returns:
        Whatever ``model.predict`` returns for ``new_data``.
    """
    return model.predict(new_data)

In [4]:
import tqdm

# Fetch MNIST, then shrink every 28x28 digit by average pooling
X, y = load_data()

pool_size = 7
n_images = X.shape[0]
pooled_side = 28 // pool_size
X_pooled = np.zeros((n_images, pooled_side * pooled_side))
for i in tqdm.tqdm(range(n_images)):
    # Each row of X is a flattened image; restore the 2-D grid before pooling.
    image = np.reshape(X[i], (28, 28))
    pooled_image = avg_pooling(image, (pool_size, pool_size))
    X_pooled[i] = pooled_image.flatten()

# Hold out a test split of the pooled features
X_train, X_test, y_train, y_test = split_data(X_pooled, y)

# Standardise both splits (scaler is fitted on the training split)
X_train_scaled, X_test_scaled = preprocess_data(X_train, X_test)

100%|██████████| 70000/70000 [00:03<00:00, 17515.25it/s]


In [5]:
# Fit the logistic-regression baseline on the standardised training split
model = train_model(X_train=X_train_scaled, y_train=y_train)



In [6]:
# Print accuracy, per-class metrics and the confusion matrix for the test split
evaluate_model(pred_func=model.predict, X_test=X_test_scaled, y_test=y_test)

Accuracy: 0.7457142857142857
Classification Report:
               precision    recall  f1-score   support

           0       0.82      0.68      0.74      1343
           1       0.82      0.90      0.86      1600
           2       0.82      0.75      0.79      1380
           3       0.79      0.77      0.78      1433
           4       0.64      0.69      0.66      1295
           5       0.68      0.61      0.64      1273
           6       0.84      0.87      0.86      1396
           7       0.77      0.81      0.79      1503
           8       0.64      0.69      0.66      1357
           9       0.64      0.64      0.64      1420

    accuracy                           0.75     14000
   macro avg       0.74      0.74      0.74     14000
weighted avg       0.75      0.75      0.75     14000

Confusion Matrix:
 [[ 916   10   23   23   19   77   23    4  237   11]
 [   0 1446    6   36   13   50    3   10   24   12]
 [  11   33 1038   66   60   11   81   12   54   14]
 [  22   5

In [21]:
from neural_bandits.algorithms.linear_bandits import LinearTSBandit, LinearUCBBandit


# Set up the linear bandit: one arm per class.
n_features = X_train_scaled.shape[1]
# NOTE(review): alpha is not referenced by any cell visible here; possibly
# intended for the commented-out LinearUCBBandit below — confirm or remove.
alpha = 0.01
n_arms = 10  # Number of classes
# The second argument is n_features * n_arms; presumably the feature dimension
# of one contextualised action as produced by MultiClassContextualiser —
# TODO confirm against the neural_bandits API.
bandit = LinearTSBandit(n_arms, n_features * n_arms)
# bandit = LinearUCBBandit(n_arms, n_features * n_arms)

In [22]:
import torch
from neural_bandits.utils.multiclass import MultiClassContextualiser
from neural_bandits.trainers.linear_trainer import LinearTrainer

# Online training: present the training samples one at a time, let the bandit
# pick an arm, reward it when the arm equals the true label, and update the
# bandit after every single sample.
mc_contextualiser = MultiClassContextualiser(n_arms)
trainer = LinearTrainer()

for t in tqdm.tqdm(range(X_train_scaled.shape[0])):
    # Single sample as a one-row float32 tensor.
    x_tensor = torch.tensor(X_train_scaled[t], dtype=torch.float32).reshape(1, -1)
    # presumably yields one context vector per arm, indexed [sample, arm] — TODO confirm
    contextualised_actions = mc_contextualiser.contextualise(x_tensor)
    # presumably the index of the selected arm (int or single-element tensor) — TODO confirm
    chosen_arm = bandit(contextualised_actions)
    # Bandit feedback: 1 if the chosen arm matches the label, otherwise 0.
    reward = 1 if y_train[t] == chosen_arm else 0
    reward = torch.tensor([reward], dtype=torch.float32)

    # Only the context of the arm that was actually played is used for the
    # update; the return value of update() is not used here.
    trainer.update(bandit, reward, contextualised_actions[0, chosen_arm].reshape(1, -1))

100%|██████████| 56000/56000 [01:04<00:00, 866.88it/s]


In [14]:
# Batched training: play one arm per training sample, buffer the observed
# reward and the context of the played arm, and update the bandit once per
# `batch` samples instead of after every sample.
batch = 1000
buffer_reward = []
buffer_contextualised_actions = []
n_samples = X_train_scaled.shape[0]
for t in tqdm.tqdm(range(n_samples)):
    contextualised_actions = mc_contextualiser.contextualise(
        torch.tensor(X_train_scaled[t], dtype=torch.float32).reshape(1, -1)
    )
    chosen_arm = bandit(contextualised_actions)
    # Bandit feedback: 1 if the chosen arm matches the label, otherwise 0.
    reward = 1 if y_train[t] == chosen_arm else 0
    reward = torch.tensor([reward], dtype=torch.float32)
    buffer_reward.append(reward)
    buffer_contextualised_actions.append(contextualised_actions[:, chosen_arm])

    # The flush must live inside the loop: placed after it, it would run at
    # most once and the bandit would never learn from the buffered samples.
    # Flush on every full batch, and on the last sample so that a trailing
    # partial batch is not dropped.
    if (t + 1) % batch == 0 or t == n_samples - 1:
        updated_bandit = trainer.update(
            bandit,
            torch.cat(buffer_reward, dim=0),
            torch.cat(buffer_contextualised_actions, dim=0),
        )
        # NOTE(review): the per-sample training cell in this notebook calls
        # trainer.update without using its return value, so it is unclear
        # whether update returns the bandit. Rebind only when it does.
        if updated_bandit is not None:
            bandit = updated_bandit
        buffer_reward = []
        buffer_contextualised_actions = []

100%|██████████| 56000/56000 [00:24<00:00, 2274.29it/s]


In [None]:
# X_test, y_test
def evaluate_model_direct(y_pred, y_test):
    """Print accuracy, classification report and confusion matrix for given predictions.

    Args:
        y_pred: Predicted labels, already computed by the caller.
        y_test: True labels the predictions are scored against.
    """
    print("Accuracy:", accuracy_score(y_test, y_pred))
    remaining_sections = (
        ("Classification Report:\n", classification_report),
        ("Confusion Matrix:\n", confusion_matrix),
    )
    for heading, metric in remaining_sections:
        print(heading, metric(y_test, y_pred))


# Evaluate the bandit on the held-out split: the arm chosen for each test
# sample is treated as the predicted class.
predictions = []
for t in tqdm.tqdm(range(X_test_scaled.shape[0])):
    contextualised_actions = mc_contextualiser.contextualise(
        # Read the *test* features: t ranges over the test split and the
        # predictions are scored against y_test, so row t must come from
        # X_test_scaled, not from the training data.
        torch.tensor(X_test_scaled[t], dtype=torch.float32).reshape(1, -1)
    )
    chosen_arm = bandit(contextualised_actions)
    predictions.append(chosen_arm)

evaluate_model_direct(np.array(predictions), y_test)

tensor([-5.5911e-03, -1.1154e-02, -4.8883e-03,  1.2245e-02,  5.7645e-03,
        -4.4941e-02, -3.6960e-02, -6.4851e-04,  3.2460e-03, -4.5446e-02,
        -4.9883e-02, -8.1671e-02,  5.9538e-02,  8.7288e-02, -1.8538e-01,
        -7.6898e-02,  4.5177e-02, -8.4394e-02,  7.9709e-02, -1.9640e-01,
         8.8489e-02,  6.7506e-02,  3.3737e-02,  3.1615e-02, -2.2879e-01,
        -1.6920e-02,  8.1124e-02, -1.7138e-01,  1.3196e-01, -1.6653e-01,
        -4.8368e-02,  2.2461e-02, -5.2031e-02, -4.7352e-02,  2.9442e-02,
         7.2490e-02, -7.2030e-03, -3.7497e-02,  2.5162e-02, -1.3300e-01,
        -1.5405e-02, -5.1451e-02, -7.8030e-02, -6.0621e-03, -1.3564e-02,
        -5.3518e-02, -2.3202e-02,  4.3213e-02,  5.9318e-02, -3.2409e-02,
         7.1219e-02, -2.9018e-01, -2.9166e-01, -1.3476e-01, -4.5357e-02,
         1.6238e-01, -1.7955e-02, -5.7692e-02,  1.1849e-01,  3.1481e-02,
        -2.6842e-02,  1.6301e-01, -1.0856e-01,  1.1214e-01, -1.5431e-01,
        -2.1536e-01, -1.9540e-01,  1.1765e-01,  5.4

 24%|██▎       | 3297/14000 [00:02<00:08, 1209.85it/s]


KeyboardInterrupt: 