In [8]:
import math
import random
import copy
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [2]:
def read_data(filename):
    with open(filename, 'r') as f:
        lines = f.readlines()
    
    num_points = len(lines)
    dim_points = 28 * 28
    data = np.empty((num_points, dim_points))
    labels = np.empty(num_points)
    
    for ind, line in enumerate(lines):
        num = line.split(',')
        labels[ind] = int(num[0])
        data[ind] = [ int(x) for x in num[1:] ]
        
    return (data, labels)

train_data, train_labels = read_data("sample_train.csv")
test_data, test_labels = read_data("sample_test.csv")
print(train_data.shape, test_data.shape)
print(train_labels.shape, test_labels.shape)
print(type(test_data[0,0]))

(6000, 784) (1000, 784)
(6000,) (1000,)
<class 'numpy.float64'>


In [3]:
class DQN(nn.Module):
    def __init__(self, in_features):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=2)

    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = self.fc2(t)
        return t

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

    def __len__(self):
        return len(self.memory)

class EpsilonGreedyStrategy():
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * math.exp(-1. * current_step * self.decay)

In [17]:
class Net(nn.Module):
    def __init__(self, n_features):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x))

def train(model, optimizer, X, criterion=nn.NLLLoss()):
    epochs = 1
    batch_size = 8
    for e in range(epochs):
        running_loss = 0
        num_batches = 0
        order = np.copy(X)
        np.random.shuffle(order)
        i = 0
        while i < len(X):
            j = min(i + batch_size, len(X))
#             print(order[i:j])
            images = train_data[order[i:j], :]
            labels = torch.Tensor(train_labels[order[i:j]]).long()
            optimizer.zero_grad()
            output = model(torch.from_numpy(images).float())
            loss = F.nll_loss(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
            i += batch_size
#         else:
#             print("Epoch {} - Training loss: {}".format(e, running_loss / num_batches))

def predict(model):
    model.eval()
    with torch.no_grad():
        output = model(torch.from_numpy(test_data).float())
    model.train()
    softmax = torch.exp(torch.Tensor(output))
    prob = list(softmax.numpy())
    predictions = np.argmax(prob, axis=1)
    return accuracy_score(test_labels, predictions)

In [19]:
batch_size = 64
gamma = 0.999
eps_start = 1
eps_end = 0.01
eps_decay = 0.001
memory_size = 5000
num_episodes = 10
state_len = 794
budget = 1500
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
memory = ReplayMemory(memory_size)
dqnet = DQN(state_len)
target_dqnet = DQN(state_len)
dqnet_optimizer = optim.Adam(dqnet.parameters(), lr=1e-4)
num_actions = 2
current_step = 0

def dqn_train(model, target_model, optimizer, mini_batch):
    criterion = nn.MSELoss()
    optimizer.zero_grad()
    states = np.array([exp[0] for exp in mini_batch])
    actions = torch.Tensor([[exp[1]] for exp in mini_batch]).long()
    rewards = torch.Tensor([exp[2] for exp in mini_batch])
    next_states = torch.Tensor([exp[3] for exp in mini_batch])
    target_model.eval()
    output = model(torch.from_numpy(states).float())
    predicted = torch.gather(output, 1, actions).squeeze()
    with torch.no_grad():
        labels_next = target_model(next_states).detach().max(1).values
    labels = rewards + gamma * labels_next
    loss = criterion(predicted, labels)
    loss.backward()
    optimizer.step()

def update_target_net(model, target_model, tau=1e-3):
    for target_param, local_param in zip(target_model.parameters(), model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

def select_action(state, model, current_step):
    eps = strategy.get_exploration_rate(current_step)
    if random.random() < eps:
        return random.randrange(num_actions)
    else:
        model.eval()
        with torch.no_grad():
            a = model(torch.from_numpy(state).float()).argmax().item()
        model.train()
        return a

order = list(range(0, train_data.shape[0]))

plots = []
X_labelled = []
random.shuffle(order)
for _ in range(num_episodes):
    plot_data = []
    i = 0
    j = 0
    print('Episode {} started'.format(_ + 1), len(X_labelled))
    model = Net(28*28)
    model_optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
    prev_acc = predict(model)
    while j < train_data.shape[0]:
        part1 = np.copy(train_data[order[i]])
        part2 = model(torch.from_numpy(train_data[order[i]]).float()).detach().numpy()
        state = np.concatenate((part1, part2))
        a = select_action(state, dqnet, current_step)
        current_step += 1
        if a == 1:
            X_labelled.append(order[i])
            i += 1
            if i % 16 == 0:
                train(model, model_optimizer, X_labelled)
                print('Acc', acc)
                plot_data.append([len(X_labelled), acc])
        acc = predict(model)
        r = acc - prev_acc
        prev_acc = acc
        if i == budget:
            print('Budget over')
            X_labelled = []
            random.shuffle(order)
            part1 = np.copy(train_data[order[0]])
            part2 = model(torch.from_numpy(train_data[order[0]]).float()).detach().numpy()
            new_state = np.concatenate((part1, part2))
            memory.push(state, a, r, new_state)
            break
        part1 = np.copy(train_data[order[i + 1]])
        part2 = model(torch.from_numpy(train_data[i + 1]).float()).detach().numpy()
        new_state = np.concatenate((part1, part2))
        memory.push(state, a, r, new_state)
        if memory.can_provide_sample(batch_size):
            mini_batch = memory.sample(batch_size)
            dqn_train(dqnet, target_dqnet, dqnet_optimizer, mini_batch)
        if i % 8 == 0:
            update_target_net(dqnet, target_dqnet)
        j += 1
#         print('Rem budget: ', budget - i, j)
    plots.append(plot_data)
    if i < budget:
        print('Budget left', len(X_labelled))
        X_labelled = []
        random.shuffle(order)

Episode 1 started 0
Acc 0.139


  # This is added back by InteractiveShellApp.init_path()


Acc 0.138
Acc 0.143
Acc 0.141
Acc 0.158
Acc 0.197
Acc 0.247
Acc 0.299
Acc 0.344
Acc 0.353
Acc 0.387
Acc 0.422
Acc 0.437
Acc 0.454
Acc 0.462
Acc 0.477
Acc 0.494
Acc 0.501
Acc 0.517
Acc 0.529
Acc 0.558
Acc 0.567
Acc 0.585
Acc 0.591
Acc 0.613
Acc 0.617
Acc 0.627
Acc 0.632
Acc 0.645
Acc 0.647
Acc 0.656
Acc 0.654
Acc 0.668
Acc 0.664
Acc 0.677
Acc 0.681
Acc 0.691
Acc 0.693
Acc 0.701
Acc 0.704
Acc 0.708
Acc 0.713
Acc 0.723
Acc 0.725
Acc 0.718
Acc 0.727
Acc 0.73
Acc 0.72
Acc 0.73
Acc 0.737
Acc 0.734
Acc 0.738
Acc 0.737
Acc 0.738
Acc 0.744
Acc 0.746
Acc 0.745
Budget left
Episode 2 started 0
Acc 0.097
Acc 0.092
Acc 0.084
Acc 0.102
Acc 0.139
Acc 0.163
Acc 0.202
Acc 0.246
Acc 0.293
Acc 0.315
Acc 0.341
Acc 0.365
Acc 0.391
Acc 0.389
Acc 0.422
Acc 0.433
Acc 0.46
Acc 0.481
Acc 0.498
Acc 0.515
Acc 0.546
Acc 0.566
Acc 0.575
Acc 0.593
Acc 0.599
Acc 0.623
Acc 0.628
Acc 0.636
Acc 0.652
Acc 0.666
Acc 0.681
Acc 0.69
Acc 0.697
Acc 0.695
Acc 0.703
Acc 0.711
Acc 0.719
Acc 0.72
Acc 0.73
Acc 0.727
Acc 0.741
Acc 0

KeyboardInterrupt: 

In [None]:
for p in plots:
    print()
    print(p[:15])

In [9]:
temp = copy.deepcopy(plots)