## **Contextual Multi-Armed Bandit**

**$N$** - number of bandits (ads to show on a website)\
**$k$** - context size: the number of customer types visiting the online store's website

**$M$** = $\text{matrix}_{N \times k}$ (one row per bandit, one column per context type)

In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Fix PyTorch's global RNG seed so the random tensors drawn afterwards are repeatable.
torch.manual_seed(3407)

<torch._C.Generator at 0x7f6b4c14e510>

In [3]:
# Problem dimensions used by the cells that follow.
n_bandits = 10    # N: number of arms
context_size = 5  # k: length of the context vector

In [4]:
# One uniform [0, 1) sample per (bandit, context) pair: rows index bandits,
# columns index context entries.
matrix = torch.rand((n_bandits, context_size))
matrix

tensor([[0.8823, 0.9150, 0.3829, 0.9593, 0.3904],
        [0.6009, 0.2566, 0.7936, 0.9408, 0.1332],
        [0.9346, 0.5936, 0.8694, 0.5677, 0.7411],
        [0.4294, 0.8854, 0.5739, 0.2666, 0.6274],
        [0.2696, 0.4414, 0.2969, 0.8317, 0.1053],
        [0.2695, 0.3588, 0.1994, 0.5472, 0.0062],
        [0.9516, 0.0753, 0.8860, 0.5832, 0.3376],
        [0.8090, 0.5779, 0.9040, 0.5547, 0.3423],
        [0.6343, 0.3644, 0.7104, 0.9464, 0.7890],
        [0.2814, 0.7886, 0.5895, 0.7539, 0.1952]])

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
# Re-seed PyTorch's global RNG (42 here, replacing the earlier 3407).
# NOTE(review): only torch is seeded; Python's `random` module (imported
# above and used by Agent.select_action) is left unseeded - confirm intent.
torch.manual_seed(42)

<torch._C.Generator at 0x7f03140b0530>

In [14]:
class DQN(nn.Module):
    """Single-hidden-layer network producing one score per arm.

    The input goes through a 64-unit ReLU layer and then a linear layer with
    ``num_arms`` outputs. A softmax over the arm axis is applied to the
    result, so the scores for each input are non-negative and sum to 1.

    Args:
        num_arms: Number of outputs (one per bandit arm).
        context_size: Number of input features.
    """

    def __init__(self, num_arms, context_size):
        super().__init__()
        self.num_arms = num_arms
        self.context_size = context_size
        self.fc1 = nn.Linear(context_size, 64)
        self.fc2 = nn.Linear(64, num_arms)

    def forward(self, x):
        """Return softmax-normalised per-arm scores for ``x``.

        Args:
            x: Float tensor whose last dimension has ``context_size`` entries,
                either a single vector or a ``(batch, context_size)`` batch.

        Returns:
            Tensor shaped like ``x`` with the last dimension replaced by
            ``num_arms``.
        """
        hidden = torch.relu(self.fc1(x))
        # Normalise over the last (arm) axis instead of a hard-coded dim=1:
        # identical for (batch, features) input, and an unbatched 1-D vector
        # no longer raises "Dimension out of range".
        return torch.softmax(self.fc2(hidden), dim=-1)

# Build the network for the configured problem size, then sanity-check it on a
# single one-hot context (third entry active) shaped as a batch of one.
model = DQN(context_size=context_size, num_arms=n_bandits)

y_hat = model(torch.tensor([0, 0, 1, 0, 0], dtype=torch.float32).unsqueeze(0))
torch.argmax(y_hat)

tensor(6)

In [None]:
class Agent:
    def __init__(self, num_arms, context_size, epsilon=0.1, learning_rate=0.01, gamma=0.99):
        """Store the hyper-parameters and build the Q-network and optimizer.

        Args:
            num_arms: Number of bandit arms (network outputs).
            context_size: Length of the context vector.
            epsilon: Probability of picking a random arm in ``select_action``.
            learning_rate: Step size passed to the Adam optimizer.
            gamma: Weight of the bootstrapped term in ``update_network``.
        """
        self.num_arms = num_arms
        self.context_size = context_size
        self.epsilon = epsilon
        self.learning_rate = learning_rate
        self.gamma = gamma
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # select_action and update_network feed the network the per-arm
        # rewards concatenated with the context, so its input layer has to be
        # num_arms + context_size wide; sizing it with context_size alone
        # makes the first linear layer reject every input.
        input_size = num_arms + context_size
        self.q_network = DQN(num_arms, input_size).to(self.device)
        self.optimizer = optim.Adam(self.q_network.parameters(), lr=learning_rate)

    def select_action(self, context, rewards):
        input_tensor = torch.cat((torch.tensor(rewards), torch.tensor(context)), dim=0).float().to(self.device)
        if random.random() < self.epsilon:
            # Explore: Select a random action
            action = random.randint(0, self.num_arms - 1)
        else:
            # Exploit: Select action with highest Q-value
            with torch.no_grad():
                q_values = self.q_network(input_tensor)
                action = torch.argmax(q_values).item()

        return action

    def update_network(self, context, rewards, action, next_rewards):
        input_tensor = torch.cat((torch.tensor(rewards), torch.tensor(context)), dim=0).float().to(self.device)
        next_input_tensor = torch.cat((torch.tensor(next_rewards), torch.tensor(context)), dim=0).float().to(self.device)

        q_values = self.q_network(input_tensor)
        next_q_values = self.q_network(next_input_tensor)

        # Compute target Q-value using the Bellman equation
        target_q = rewards[action] + self.gamma * torch.max(next_q_values).item()
        q_values[action] = target_q

        # Update the Q-network using backpropagation
        loss = nn.MSELoss()(q_values, self.q_network(input_tensor))
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def train(self, episodes):
        """Run ``episodes`` rounds of select-arm / observe-reward / update.

        Each round fetches a context via ``self.get_context()``, starts from
        an all-zero per-arm reward list, picks an arm with ``select_action``,
        obtains its reward from ``self.pull_arm(action, context)`` and then
        calls ``update_network``.

        NOTE(review): ``get_context`` and ``pull_arm`` are not defined in the
        visible part of this class - confirm they exist before calling this.

        Args:
            episodes: Number of rounds to run.
        """
        for episode in range(episodes):
            context = self.get_context()  # Determine the context for this episode
            rewards = [0] * self.num_arms  # Reset to all zeros: nothing is carried over between episodes
            action = self.select_action(context, rewards)  # Select action based on epsilon-greedy policy

            # Simulate pulling the selected arm and observe the reward
            reward = self.pull_arm(action, context)
            rewards[action] = reward

            # Update the Q-network based on the observed reward.
            # NOTE(review): the same list is passed as both `rewards` and
            # `next_rewards` - confirm that is the intended "next" input.
            self.update_network(context, rewards, action, rewards)