In [None]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [None]:
def read_data(filename, dim_points=28 * 28):
    """Load a label-first CSV of flattened samples.

    Each non-blank line is ``label,v1,...,v{dim_points}``. The first field is
    the integer class label and the remaining fields are integer feature values
    (28*28 by default). No header row is expected.

    Args:
        filename: path to the CSV file.
        dim_points: number of feature columns after the label.

    Returns:
        (data, labels): ``data`` is a float array of shape
        (num_points, dim_points), and ``labels`` is a float array of shape
        (num_points,).

    Raises:
        ValueError: if a field is not an integer or a row does not have
            exactly ``dim_points`` features.
    """
    with open(filename, 'r') as f:
        # Skip blank lines (e.g. an extra empty line at the end of the file),
        # which would otherwise crash in int('').
        rows = [line.split(',') for line in f if line.strip()]

    num_points = len(rows)
    data = np.empty((num_points, dim_points))
    labels = np.empty(num_points)

    for ind, fields in enumerate(rows):
        labels[ind] = int(fields[0])
        data[ind] = [int(x) for x in fields[1:]]

    return (data, labels)

# Load both splits from disk, then print the array shapes as a quick sanity check.
train_data, train_labels = read_data("sample_train.csv")
test_data, test_labels = read_data("sample_test.csv")
print(*(split.shape for split in (train_data, test_data)))
print(*(split.shape for split in (train_labels, test_labels)))

In [None]:
class DQN(nn.Module):
    """Two-layer MLP mapping a state vector of length `in_features` to 2 outputs.

    There is one output per action; callers pick the action with the largest
    output via argmax.
    """

    def __init__(self, in_features):
        super().__init__()
        hidden_width = 32
        self.fc1 = nn.Linear(in_features, hidden_width)
        self.fc2 = nn.Linear(hidden_width, 2)

    def forward(self, t):
        hidden_activations = F.relu(self.fc1(t))
        return self.fc2(hidden_activations)

# One replay-buffer entry. Field order matters: ReplayMemory.push builds it
# positionally from its arguments.
Transition = namedtuple('Transition', ['state', 'action', 'next_state', 'reward'])


class ReplayMemory:
    """Bounded FIFO buffer of Transitions; once full, the oldest entry is dropped."""

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Store one Transition built positionally from `args`."""
        entry = Transition(*args)
        self.memory.append(entry)

    def sample(self, batch_size):
        """Return `batch_size` stored transitions drawn at random without replacement."""
        return random.sample(self.memory, k=batch_size)

    def can_provide_sample(self, batch_size):
        """True when at least `batch_size` transitions are stored."""
        return batch_size <= len(self.memory)

    def __len__(self):
        return len(self.memory)

class EpsilonGreedyStrategy:
    """Exponentially decaying exploration rate.

    rate(step) = end + (start - end) * exp(-decay * step). It equals `start`
    at step 0 and approaches `end` as the step grows (for positive `decay`).
    """

    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        decay_factor = math.exp(-current_step * self.decay)
        return self.end + (self.start - self.end) * decay_factor

class Agent:
    """Chooses actions epsilon-greedily, with epsilon supplied by `strategy`.

    `strategy` must provide ``get_exploration_rate(step)``.
    """

    def __init__(self, strategy, num_actions=2):
        self.current_step = 0
        self.strategy = strategy
        self.num_actions = num_actions

    def select_action(self, state, policy_net):
        """Return an action index in ``range(self.num_actions)``.

        With probability equal to the strategy's exploration rate for the
        current step, a uniformly random action is returned (explore).
        Otherwise the argmax of ``policy_net(state)`` along dim 1 is returned
        (exploit). Because of ``.item()``, the network output must hold a single
        row, so `state` presumably has a leading batch dimension of 1 (TODO:
        confirm at call sites). Every call advances `current_step`.
        """
        rate = self.strategy.get_exploration_rate(self.current_step)
        self.current_step += 1

        # Bug fix: the original compared against an undefined name `eps`.
        if random.random() < rate:
            return random.randrange(self.num_actions)  # explore
        with torch.no_grad():
            return policy_net(state).argmax(dim=1).item()  # exploit

In [None]:
class Net(nn.Module):
    """MLP classifier: n_features -> 128 -> 64 -> n_classes, returning log-probabilities."""

    def __init__(self, n_features, n_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, n_classes)

    def forward(self, x):
        """Return per-class log-probabilities for `x` of shape (..., n_features).

        Normalises over the last dimension, so both a single sample
        (n_features,) and a batch (batch, n_features) work. The output pairs
        with ``nn.NLLLoss``.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # `torch.LogSoftmax` does not exist (the module class is nn.LogSoftmax),
        # so use the functional form. dim=-1 equals dim=1 for 2-D batches and
        # also accepts 1-D input.
        return F.log_softmax(self.fc3(x), dim=-1)

def train(model, loader=None, opt=None, loss_fn=None, epochs=10):
    """Run `epochs` passes of mini-batch gradient descent and print each epoch's mean loss.

    The loader, optimiser and loss default to the notebook-level globals
    ``trainloader``, ``optimizer`` and ``criterion``. A global is looked up
    only when the corresponding argument is omitted, so ``train(model)`` still
    works.

    Args:
        model: module mapping a batch of inputs to something `loss_fn` accepts.
        loader: iterable of ``(inputs, targets)`` batches, re-iterated each epoch.
        opt: optimiser over `model`'s parameters.
        loss_fn: callable ``loss_fn(output, targets)`` returning a scalar tensor.
        epochs: number of passes over `loader`.

    Returns:
        Mean batch loss of the final epoch. Returns 0.0 when the loader yielded
        no batches or `epochs` is 0.
    """
    loader = trainloader if loader is None else loader
    opt = optimizer if opt is None else opt
    loss_fn = criterion if loss_fn is None else loss_fn

    model.train()  # training mode, in case the model has dropout / batch-norm layers
    avg_loss = 0.0
    for e in range(epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in loader:
            opt.zero_grad()
            loss = loss_fn(model(images), labels)
            loss.backward()
            opt.step()
            running_loss += loss.item()
            num_batches += 1
        # Dividing by the counted batches avoids ZeroDivisionError on an empty
        # loader and does not require the loader to support len().
        avg_loss = running_loss / num_batches if num_batches else 0.0
        print("Epoch {} - Training loss: {}".format(e, avg_loss))
    return avg_loss

def predict(model, data=None, labels=None):
    """Return `model`'s classification accuracy on a labelled dataset.

    Args:
        model: module whose output for a (num_samples, num_features) float32
            batch has one score per class along dim 1 (e.g. Net's
            log-probabilities). The predicted class is the argmax.
        data: array-like of shape (num_samples, num_features). Defaults to the
            notebook-level ``test_data``.
        labels: array-like of num_samples class labels. Defaults to
            ``test_labels``.

    Returns:
        Fraction of samples whose predicted class equals the label, as a float.
        The model's train/eval mode is restored before returning.

    Raises:
        ValueError: if `data` contains no samples.
    """
    data = test_data if data is None else data
    labels = test_labels if labels is None else labels
    # nn.Linear needs a float32 tensor, not a float64 NumPy array.
    inputs = torch.as_tensor(np.asarray(data), dtype=torch.float32)
    targets = np.asarray(labels)
    if inputs.shape[0] == 0:
        raise ValueError("predict: no samples to evaluate")

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            output = model(inputs)
    finally:
        model.train(was_training)  # restore the caller's mode
    # argmax of log-probabilities equals argmax of probabilities, so no exp()
    # is needed. Plain NumPy replaces sklearn's accuracy_score, which this
    # notebook never imports.
    predictions = output.argmax(dim=1).numpy()
    return float(np.mean(predictions == targets))

In [None]:
# Reproducibility: seed every RNG the experiment may draw from.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

batch_size = 32         # replay mini-batch size
gamma = 0.999           # presumably the discount factor for the (not yet written) DQN update
eps_start = 1           # epsilon-greedy schedule: start value, floor, decay rate
eps_end = 0.01
eps_decay = 0.001
memory_size = 10000     # replay-memory capacity
num_episodes = 10
num_classes = 10        # width of Net's output layer
# The DQN state is a raw training sample followed by the classifier's
# per-class output on it, so its length is n_features + num_classes
# (the previous hard-coded 784 omitted the classifier part).
state_len = train_data.shape[1] + num_classes
budget = train_data.shape[0] // 4   # samples the agent may label per episode
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
agent = Agent(strategy)
memory = ReplayMemory(memory_size)
dqnet = DQN(state_len)

# Each episode trains a fresh classifier from scratch. The DQN streams through
# the shuffled pool and, per sample, chooses action 1 (label it and retrain)
# or 0 (skip). The reward is the change in test accuracy, and every step is
# stored in replay memory.
order = list(range(train_data.shape[0]))  # a list: random.shuffle cannot shuffle a range
criterion = nn.NLLLoss()  # pairs with Net's log-probabilities; read by `train`


def build_state(model, sample_idx):
    """Return the DQN state for pool sample `sample_idx` as one 1-D NumPy array.

    The state is the sample's raw feature vector followed by `model`'s output
    for that sample.
    """
    features = np.copy(train_data[sample_idx])
    with torch.no_grad():
        # Add (then drop) a batch dimension so the classifier sees a (1, n_features) batch.
        scores = model(torch.as_tensor(features, dtype=torch.float32).unsqueeze(0))
    return np.concatenate((features, scores.squeeze(0).numpy()))


step = 0  # transitions pushed so far, across all episodes
for _ in range(num_episodes):
    random.shuffle(order)
    X_labelled = []
    y_labelled = []
    model = Net(train_data.shape[1])
    model_optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
    optimizer = model_optimizer  # `train(model)` reads the global `optimizer`
    prev_acc = predict(model)
    # Position in the shuffled stream. It advances every step, labelled or not;
    # previously a "skip" action never advanced it and the loop never ended.
    i = 0
    state = build_state(model, order[i])
    while True:
        # NOTE: actions are greedy; `agent` (epsilon-greedy) is not used here yet.
        with torch.no_grad():
            a = dqnet(torch.as_tensor(state, dtype=torch.float32)).argmax().item()
        if a == 1:
            X_labelled.append(train_data[order[i]])
            y_labelled.append(train_labels[order[i]])
            # `train(model)` reads the global `trainloader`: here one full batch
            # of everything labelled so far (NLLLoss needs integer targets).
            trainloader = [(
                torch.as_tensor(np.array(X_labelled), dtype=torch.float32),
                torch.as_tensor(np.array(y_labelled), dtype=torch.long),
            )]
            train(model)
        acc = predict(model)
        r = acc - prev_acc
        prev_acc = acc
        i += 1
        # The episode ends when the labelling budget is spent or the pool runs
        # out. A terminal transition is stored with next_state=None.
        done = len(X_labelled) >= budget or i >= len(order)
        new_state = None if done else build_state(model, order[i])
        # Argument order must match Transition: state, action, next_state, reward.
        memory.push(state, a, new_state, r)
        step += 1
        if done:
            break
        state = new_state
        if step % 10 == 0 and memory.can_provide_sample(batch_size):
            mini_batch = memory.sample(batch_size)
            # update DQN