In [1]:
import math
import random
import copy
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [2]:
def read_data(filename):
    """Load a headerless CSV of labelled 28x28 images into NumPy arrays.

    Every non-blank line must look like ``label,p0,p1,...,p783`` where all
    fields parse as integers. Blank lines (for example a trailing empty line
    at the end of the file) are skipped instead of crashing the integer
    parsing.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A ``(data, labels)`` tuple. ``data`` has shape ``(num_points, 784)``
        and ``labels`` has shape ``(num_points,)``; both are NumPy float
        arrays holding the parsed integer values.

    Raises:
        ValueError: If a field is not an integer or a row does not contain
            exactly 784 pixel values.
    """
    dim_points = 28 * 28
    with open(filename, 'r') as f:
        # Keep only lines that carry data; whitespace-only lines are ignored.
        rows = [line.split(',') for line in f if line.strip()]

    num_points = len(rows)
    data = np.empty((num_points, dim_points))
    labels = np.empty(num_points)

    for ind, fields in enumerate(rows):
        # First field is the class label, the rest are the pixel values.
        labels[ind] = int(fields[0])
        data[ind] = [int(x) for x in fields[1:]]

    return (data, labels)

# Read both splits once; the later cells use these module-level arrays directly.
train_data, train_labels = read_data(filename="sample_train.csv")
test_data, test_labels = read_data(filename="sample_test.csv")

# Sanity check: feature matrices first, then the matching label vectors.
print(np.shape(train_data), np.shape(test_data))
print(np.shape(train_labels), np.shape(test_labels))

(6000, 784) (1000, 784)
(6000,) (1000,)


In [3]:
class DQN(nn.Module):
    """Three-layer fully connected Q-network.

    Maps a state vector of ``in_features`` values through hidden layers of
    1024 and 32 units (ReLU) to 2 output values, one per action.
    """

    def __init__(self, in_features):
        super().__init__()
        self.fc1 = nn.Linear(in_features, 1024)
        self.fc2 = nn.Linear(1024, 32)
        self.fc3 = nn.Linear(32, 2)

    def forward(self, t):
        """Return the raw (un-activated) outputs of the last layer for ``t``."""
        for hidden_layer in (self.fc1, self.fc2):
            t = F.relu(hidden_layer(t))
        return self.fc3(t)

# One replay-buffer record. The field order matches how this notebook stores
# and reads transitions: they are pushed positionally as
# (state, action, reward, next_state) and dqn_train reads index 2 as the
# reward and index 3 as the next state. With the previous field order the
# attribute names .next_state and .reward pointed at each other's data.
Transition = namedtuple('Transition', ('state', 'action', 'reward', 'next_state'))

class ReplayMemory(object):
    """Fixed-capacity buffer of ``Transition`` records.

    Backed by a ``deque`` with ``maxlen=capacity``, so appending to a full
    buffer drops the oldest record.
    """

    def __init__(self, capacity):
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Wrap ``args`` in a ``Transition`` and append it to the buffer."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored records chosen at random."""
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        """Return True when at least ``batch_size`` records are stored."""
        return batch_size <= len(self.memory)

    def __len__(self):
        """Number of records currently stored."""
        return len(self.memory)

class EpsilonGreedyStrategy():
    """Exploration-rate schedule that decays exponentially from ``start`` to ``end``."""

    def __init__(self, start, end, decay):
        self.start, self.end, self.decay = start, end, decay

    def get_exploration_rate(self, current_step):
        """Return ``end + (start - end) * exp(-current_step * decay)``."""
        span = self.start - self.end
        decay_factor = math.exp(-1. * current_step * self.decay)
        return self.end + span * decay_factor

In [4]:
class Net(nn.Module):
    """Fully connected classifier: ``n_features`` -> 128 -> 64 -> 10 classes.

    ``forward`` returns log-probabilities over the 10 classes, which is the
    input format expected by ``F.nll_loss`` / ``nn.NLLLoss``.
    """

    def __init__(self, n_features):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return log-softmax class scores for ``x``.

        ``x`` may be a single feature vector or a batch; the class axis is
        always the last one.
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # Normalise over the class axis explicitly. Omitting ``dim`` is
        # deprecated, emits a UserWarning on every call, and silently picks
        # dim 0 for 3-D input.
        return F.log_softmax(self.fc3(x), dim=-1)

def train(model, optimizer, X, criterion=nn.NLLLoss()):
    """Run one epoch of mini-batch training on the labelled subset ``X``.

    Args:
        model: Network to train; called on a float tensor of image rows.
        optimizer: Optimizer bound to ``model``'s parameters.
        X: Sequence of integer row indices into the module-level
            ``train_data`` / ``train_labels`` arrays. An empty sequence is a
            no-op.
        criterion: Loss callable invoked as ``criterion(output, labels)``.
            Defaults to negative log-likelihood. This argument was previously
            accepted but ignored in favour of a hard-coded ``F.nll_loss``.

    Returns:
        None. ``model`` is updated in place through ``optimizer``.
    """
    epochs = 1
    batch_size = 64
    for _ in range(epochs):
        # Shuffle a private copy so the caller's index list is left untouched.
        # The explicit integer dtype keeps an empty X usable as an index array.
        order = np.array(X, dtype=np.int64)
        np.random.shuffle(order)
        for start in range(0, len(order), batch_size):
            batch = order[start:start + batch_size]
            images = train_data[batch, :]
            labels = torch.as_tensor(train_labels[batch]).long()
            optimizer.zero_grad()
            output = model(torch.from_numpy(images).float())
            loss = criterion(output, labels)
            loss.backward()
            optimizer.step()

def predict(model):
    """Return the accuracy of ``model`` on the module-level test split.

    The model output is exponentiated (it is treated as log-probabilities)
    and the arg-max class per row is compared with ``test_labels``.
    """
    features = torch.from_numpy(test_data).float()
    with torch.no_grad():
        log_probs = model(features)
    class_probs = torch.exp(torch.Tensor(log_probs)).numpy()
    return accuracy_score(test_labels, np.argmax(class_probs, axis=1))

In [5]:
# --- Replay / Q-learning hyperparameters ---
batch_size = 64        # transitions drawn from the replay memory per update
memory_size = 1000     # replay memory capacity
gamma = 0.999          # discount applied to the target network's value
target_update = 8      # target net is re-synced when the sample index is a multiple of this
num_actions = 2        # size of the action set used for random exploration

# --- Exploration schedule ---
eps_start, eps_end, eps_decay = 1, 0.01, 0.0001
current_step = 0       # running count of policy decisions, drives the epsilon decay

# --- Episode layout ---
num_episodes = 5
budget = 1000          # an episode stops once this many samples have been queried
state_len = 10 + 784   # 784 pixel values followed by the 10 class probabilities

# --- Objects built from the settings above ---
strategy = EpsilonGreedyStrategy(start=eps_start, end=eps_end, decay=eps_decay)
memory = ReplayMemory(capacity=memory_size)
dqnet = DQN(in_features=state_len)
target_dqnet = DQN(in_features=state_len)
# Start the target network as an exact copy of the online network.
target_dqnet.load_state_dict(dqnet.state_dict())
dqnet_optimizer = optim.Adam(dqnet.parameters(), lr=5e-3)

In [6]:
def dqn_train(model, target_model, optimizer, mini_batch):
    """Perform one Q-learning gradient step on a batch of transitions.

    Args:
        model: Online Q-network being optimised.
        target_model: Network used (without gradients) to value next states.
        optimizer: Optimizer bound to ``model``'s parameters.
        mini_batch: Non-empty sequence of transitions, read positionally as
            ``(state, action, reward, next_state)`` where the states are 1-D
            NumPy arrays and ``action`` is an integer column index into the
            network output.

    The regression target is ``reward + gamma * max_a target_model(next_state)``
    with the module-level ``gamma``; the loss is the mean squared error
    against ``model``'s value for the action actually taken.
    """
    criterion = nn.MSELoss()

    # Stack the per-transition arrays into one ndarray before converting.
    # Building a tensor straight from a list of ndarrays is very slow and
    # makes PyTorch emit a UserWarning.
    states = torch.from_numpy(np.array([exp[0] for exp in mini_batch])).float()
    actions = torch.tensor([[exp[1]] for exp in mini_batch], dtype=torch.long)
    rewards = torch.tensor([exp[2] for exp in mini_batch], dtype=torch.float32)
    next_states = torch.from_numpy(np.array([exp[3] for exp in mini_batch])).float()

    optimizer.zero_grad()
    # squeeze(1) drops only the gathered column. A bare squeeze() would also
    # collapse a batch of one to a 0-d tensor and mismatch the target shape.
    predicted = model(states).gather(1, actions).squeeze(1)
    with torch.no_grad():
        labels_next = target_model(next_states).max(1).values
    labels = rewards + gamma * labels_next
    loss = criterion(predicted, labels)
    loss.backward()
    optimizer.step()

def select_action(state, model, current_step):
    """Pick an action for ``state`` with an epsilon-greedy rule.

    The exploration probability comes from the module-level ``strategy`` at
    ``current_step``. When exploring, a uniformly random action index below
    ``num_actions`` is returned; otherwise the index of the largest output
    of ``model`` for ``state`` is returned.
    """
    explore_prob = strategy.get_exploration_rate(current_step)
    if not random.random() < explore_prob:
        with torch.no_grad():
            q_values = model(torch.from_numpy(state).float())
        return q_values.argmax().item()
    return random.randrange(num_actions)

In [7]:
def _build_state(model, sample):
    """Return the policy state for ``sample``: its pixels followed by class probabilities.

    ``model``'s output for the sample is exponentiated and renormalised to
    sum to one before being appended to the raw feature vector.
    """
    with torch.no_grad():
        log_probs = model(torch.from_numpy(sample).float()).detach().numpy()
    probs = np.exp(log_probs)
    probs = probs / np.sum(probs)
    return np.concatenate((sample, probs))


# Train the query policy (dqnet). Each episode streams the shuffled training
# set past a freshly initialised classifier; for every sample the policy
# decides whether to query its label (action 1) or skip it (action 0).
order = list(range(0, train_data.shape[0]))
X_labelled = []
random.shuffle(order)
num_samples = len(order)
for episode in range(num_episodes):
    memory.memory.clear()
    i = 0  # number of samples queried so far in this episode
    print('Episode {} started'.format(episode + 1))
    model = Net(28 * 28)
    model_optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
    prev_acc = 0
    for j in range(num_samples):
        print('episode ', episode, ', id ', j, ', label ', train_labels[order[j]])
        state = _build_state(model, train_data[order[j]])

        a = select_action(state, dqnet, current_step)
        current_step += 1
        if a == 1:
            X_labelled.append(order[j])
            i += 1
            # The classifier is retrained only on every 16th query.
            if i % 16 == 0:
                train(model, model_optimizer, X_labelled)
            # Reward for a query is the change in test accuracy since the last query.
            acc = predict(model)
            r = acc - prev_acc
            prev_acc = acc
        else:
            r = 0
        if i == budget:
            print('Budget over')
            X_labelled = []
            random.shuffle(order)
            # Final transition of the episode: its successor is the first
            # sample of the reshuffled stream.
            new_state = _build_state(model, train_data[order[0]])
            memory.push(state, a, r, new_state)
            break

        # Wrap around on the last sample. Previously order[j + 1] raised
        # IndexError whenever the stream ran out before the budget was spent.
        next_index = order[(j + 1) % num_samples]
        new_state = _build_state(model, train_data[next_index])
        memory.push(state, a, r, new_state)
        if memory.can_provide_sample(batch_size):
            mini_batch = memory.sample(batch_size)
            dqn_train(dqnet, target_dqnet, dqnet_optimizer, mini_batch)
        if j % target_update == 0:
            target_dqnet.load_state_dict(dqnet.state_dict())
    if i < budget:
        # The stream ended before the budget was used up: reset for the next episode.
        X_labelled = []
        random.shuffle(order)

Episode 1 started
episode  0 , id  0 , label  6.0
episode  0 , id  1 , label  5.0
episode  0 , id  2 , label  9.0
episode  0 , id  3 , label  3.0
episode  0 , id  4 , label  4.0
episode  0 , id  5 , label  7.0
episode  0 , id  6 , label  8.0
episode  0 , id  7 , label  7.0
episode  0 , id  8 , label  3.0
episode  0 , id  9 , label  2.0
episode  0 , id  10 , label  4.0
episode  0 , id  11 , label  3.0
episode  0 , id  12 , label  0.0
episode  0 , id  13 , label  7.0
episode  0 , id  14 , label  2.0
episode  0 , id  15 , label  8.0
episode  0 , id  16 , label  0.0
episode  0 , id  17 , label  6.0
episode  0 , id  18 , label  1.0
episode  0 , id  19 , label  9.0
episode  0 , id  20 , label  1.0
episode  0 , id  21 , label  2.0
episode  0 , id  22 , label  6.0
episode  0 , id  23 , label  0.0
episode  0 , id  24 , label  3.0
episode  0 , id  25 , label  5.0
episode  0 , id  26 , label  3.0
episode  0 , id  27 , label  4.0
episode  0 , id  28 , label  8.0
episode  0 , id  29 , label  0.0
ep

  # This is added back by InteractiveShellApp.init_path()


episode  0 , id  66 , label  5.0
episode  0 , id  67 , label  1.0
episode  0 , id  68 , label  5.0
episode  0 , id  69 , label  4.0
episode  0 , id  70 , label  9.0
episode  0 , id  71 , label  9.0
episode  0 , id  72 , label  6.0
episode  0 , id  73 , label  5.0
episode  0 , id  74 , label  8.0
episode  0 , id  75 , label  8.0
episode  0 , id  76 , label  5.0
episode  0 , id  77 , label  7.0
episode  0 , id  78 , label  7.0
episode  0 , id  79 , label  9.0
episode  0 , id  80 , label  1.0
episode  0 , id  81 , label  1.0
episode  0 , id  82 , label  5.0
episode  0 , id  83 , label  1.0
episode  0 , id  84 , label  2.0
episode  0 , id  85 , label  1.0
episode  0 , id  86 , label  3.0
episode  0 , id  87 , label  3.0
episode  0 , id  88 , label  8.0
episode  0 , id  89 , label  9.0
episode  0 , id  90 , label  1.0
episode  0 , id  91 , label  6.0
episode  0 , id  92 , label  2.0
episode  0 , id  93 , label  0.0
episode  0 , id  94 , label  0.0
episode  0 , id  95 , label  5.0
episode  0

episode  0 , id  308 , label  5.0
episode  0 , id  309 , label  9.0
episode  0 , id  310 , label  9.0
episode  0 , id  311 , label  7.0
episode  0 , id  312 , label  6.0
episode  0 , id  313 , label  4.0
episode  0 , id  314 , label  9.0
episode  0 , id  315 , label  4.0
episode  0 , id  316 , label  1.0
episode  0 , id  317 , label  0.0
episode  0 , id  318 , label  3.0
episode  0 , id  319 , label  0.0
episode  0 , id  320 , label  4.0
episode  0 , id  321 , label  2.0
episode  0 , id  322 , label  5.0
episode  0 , id  323 , label  4.0
episode  0 , id  324 , label  7.0
episode  0 , id  325 , label  4.0
episode  0 , id  326 , label  4.0
episode  0 , id  327 , label  5.0
episode  0 , id  328 , label  8.0
episode  0 , id  329 , label  7.0
episode  0 , id  330 , label  5.0
episode  0 , id  331 , label  1.0
episode  0 , id  332 , label  9.0
episode  0 , id  333 , label  7.0
episode  0 , id  334 , label  9.0
episode  0 , id  335 , label  3.0
episode  0 , id  336 , label  0.0
episode  0 , i

episode  0 , id  556 , label  3.0
episode  0 , id  557 , label  4.0
episode  0 , id  558 , label  4.0
episode  0 , id  559 , label  1.0
episode  0 , id  560 , label  4.0
episode  0 , id  561 , label  1.0
episode  0 , id  562 , label  8.0
episode  0 , id  563 , label  2.0
episode  0 , id  564 , label  3.0
episode  0 , id  565 , label  0.0
episode  0 , id  566 , label  1.0
episode  0 , id  567 , label  1.0
episode  0 , id  568 , label  6.0
episode  0 , id  569 , label  2.0
episode  0 , id  570 , label  6.0
episode  0 , id  571 , label  1.0
episode  0 , id  572 , label  1.0
episode  0 , id  573 , label  2.0
episode  0 , id  574 , label  4.0
episode  0 , id  575 , label  8.0
episode  0 , id  576 , label  8.0
episode  0 , id  577 , label  5.0
episode  0 , id  578 , label  3.0
episode  0 , id  579 , label  4.0
episode  0 , id  580 , label  2.0
episode  0 , id  581 , label  6.0
episode  0 , id  582 , label  8.0
episode  0 , id  583 , label  3.0
episode  0 , id  584 , label  7.0
episode  0 , i

episode  0 , id  801 , label  3.0
episode  0 , id  802 , label  6.0
episode  0 , id  803 , label  9.0
episode  0 , id  804 , label  8.0
episode  0 , id  805 , label  3.0
episode  0 , id  806 , label  6.0
episode  0 , id  807 , label  5.0
episode  0 , id  808 , label  0.0
episode  0 , id  809 , label  0.0
episode  0 , id  810 , label  3.0
episode  0 , id  811 , label  1.0
episode  0 , id  812 , label  1.0
episode  0 , id  813 , label  5.0
episode  0 , id  814 , label  4.0
episode  0 , id  815 , label  3.0
episode  0 , id  816 , label  4.0
episode  0 , id  817 , label  8.0
episode  0 , id  818 , label  3.0
episode  0 , id  819 , label  6.0
episode  0 , id  820 , label  5.0
episode  0 , id  821 , label  9.0
episode  0 , id  822 , label  7.0
episode  0 , id  823 , label  4.0
episode  0 , id  824 , label  6.0
episode  0 , id  825 , label  1.0
episode  0 , id  826 , label  1.0
episode  0 , id  827 , label  6.0
episode  0 , id  828 , label  8.0
episode  0 , id  829 , label  2.0
episode  0 , i

episode  0 , id  1042 , label  9.0
episode  0 , id  1043 , label  9.0
episode  0 , id  1044 , label  7.0
episode  0 , id  1045 , label  2.0
episode  0 , id  1046 , label  7.0
episode  0 , id  1047 , label  8.0
episode  0 , id  1048 , label  7.0
episode  0 , id  1049 , label  2.0
episode  0 , id  1050 , label  1.0
episode  0 , id  1051 , label  0.0
episode  0 , id  1052 , label  2.0
episode  0 , id  1053 , label  3.0
episode  0 , id  1054 , label  7.0
episode  0 , id  1055 , label  6.0
episode  0 , id  1056 , label  7.0
episode  0 , id  1057 , label  6.0
episode  0 , id  1058 , label  9.0
episode  0 , id  1059 , label  7.0
episode  0 , id  1060 , label  0.0
episode  0 , id  1061 , label  7.0
episode  0 , id  1062 , label  8.0
episode  0 , id  1063 , label  2.0
episode  0 , id  1064 , label  5.0
episode  0 , id  1065 , label  0.0
episode  0 , id  1066 , label  3.0
episode  0 , id  1067 , label  0.0
episode  0 , id  1068 , label  3.0
episode  0 , id  1069 , label  2.0
episode  0 , id  107

episode  0 , id  1278 , label  4.0
episode  0 , id  1279 , label  7.0
episode  0 , id  1280 , label  9.0
episode  0 , id  1281 , label  0.0
episode  0 , id  1282 , label  5.0
episode  0 , id  1283 , label  0.0
episode  0 , id  1284 , label  0.0
episode  0 , id  1285 , label  2.0
episode  0 , id  1286 , label  6.0
episode  0 , id  1287 , label  8.0
episode  0 , id  1288 , label  1.0
episode  0 , id  1289 , label  6.0
episode  0 , id  1290 , label  9.0
episode  0 , id  1291 , label  2.0
episode  0 , id  1292 , label  9.0
episode  0 , id  1293 , label  1.0
episode  0 , id  1294 , label  0.0
episode  0 , id  1295 , label  3.0
episode  0 , id  1296 , label  3.0
episode  0 , id  1297 , label  0.0
episode  0 , id  1298 , label  4.0
episode  0 , id  1299 , label  9.0
episode  0 , id  1300 , label  7.0
episode  0 , id  1301 , label  1.0
episode  0 , id  1302 , label  5.0
episode  0 , id  1303 , label  4.0
episode  0 , id  1304 , label  2.0
episode  0 , id  1305 , label  3.0
episode  0 , id  130

episode  0 , id  1514 , label  6.0
episode  0 , id  1515 , label  4.0
episode  0 , id  1516 , label  9.0
episode  0 , id  1517 , label  5.0
episode  0 , id  1518 , label  8.0
episode  0 , id  1519 , label  3.0
episode  0 , id  1520 , label  5.0
episode  0 , id  1521 , label  3.0
episode  0 , id  1522 , label  4.0
episode  0 , id  1523 , label  2.0
episode  0 , id  1524 , label  9.0
episode  0 , id  1525 , label  0.0
episode  0 , id  1526 , label  7.0
episode  0 , id  1527 , label  1.0
episode  0 , id  1528 , label  7.0
episode  0 , id  1529 , label  4.0
episode  0 , id  1530 , label  6.0
episode  0 , id  1531 , label  9.0
episode  0 , id  1532 , label  0.0
episode  0 , id  1533 , label  7.0
episode  0 , id  1534 , label  4.0
episode  0 , id  1535 , label  4.0
episode  0 , id  1536 , label  1.0
episode  0 , id  1537 , label  6.0
episode  0 , id  1538 , label  6.0
episode  0 , id  1539 , label  8.0
episode  0 , id  1540 , label  1.0
episode  0 , id  1541 , label  8.0
episode  0 , id  154

episode  0 , id  1756 , label  4.0
episode  0 , id  1757 , label  4.0
episode  0 , id  1758 , label  2.0
episode  0 , id  1759 , label  5.0
episode  0 , id  1760 , label  2.0
episode  0 , id  1761 , label  4.0
episode  0 , id  1762 , label  7.0
episode  0 , id  1763 , label  9.0
episode  0 , id  1764 , label  8.0
episode  0 , id  1765 , label  5.0
episode  0 , id  1766 , label  3.0
episode  0 , id  1767 , label  3.0
episode  0 , id  1768 , label  4.0
episode  0 , id  1769 , label  6.0
episode  0 , id  1770 , label  1.0
episode  0 , id  1771 , label  3.0
episode  0 , id  1772 , label  6.0
episode  0 , id  1773 , label  3.0
episode  0 , id  1774 , label  4.0
episode  0 , id  1775 , label  8.0
episode  0 , id  1776 , label  6.0
episode  0 , id  1777 , label  4.0
episode  0 , id  1778 , label  8.0
episode  0 , id  1779 , label  3.0
episode  0 , id  1780 , label  9.0
episode  0 , id  1781 , label  1.0
episode  0 , id  1782 , label  2.0
episode  0 , id  1783 , label  7.0
episode  0 , id  178

episode  0 , id  2003 , label  9.0
episode  0 , id  2004 , label  9.0
episode  0 , id  2005 , label  9.0
episode  0 , id  2006 , label  4.0
episode  0 , id  2007 , label  8.0
episode  0 , id  2008 , label  8.0
episode  0 , id  2009 , label  6.0
episode  0 , id  2010 , label  1.0
episode  0 , id  2011 , label  3.0
episode  0 , id  2012 , label  5.0
episode  0 , id  2013 , label  2.0
episode  0 , id  2014 , label  3.0
episode  0 , id  2015 , label  2.0
episode  0 , id  2016 , label  9.0
episode  0 , id  2017 , label  1.0
episode  0 , id  2018 , label  9.0
episode  0 , id  2019 , label  3.0
episode  0 , id  2020 , label  1.0
episode  0 , id  2021 , label  4.0
episode  0 , id  2022 , label  3.0
episode  0 , id  2023 , label  3.0
episode  0 , id  2024 , label  8.0
episode  0 , id  2025 , label  4.0
episode  0 , id  2026 , label  5.0
Budget over
final acc 0.65
Episode 2 started
episode  1 , id  0 , label  7.0
episode  1 , id  1 , label  1.0
episode  1 , id  2 , label  7.0
episode  1 , id  3 

episode  1 , id  228 , label  5.0
episode  1 , id  229 , label  3.0
episode  1 , id  230 , label  3.0
episode  1 , id  231 , label  6.0
episode  1 , id  232 , label  8.0
episode  1 , id  233 , label  4.0
episode  1 , id  234 , label  2.0
episode  1 , id  235 , label  7.0
episode  1 , id  236 , label  9.0
episode  1 , id  237 , label  4.0
episode  1 , id  238 , label  0.0
episode  1 , id  239 , label  2.0
episode  1 , id  240 , label  0.0
episode  1 , id  241 , label  3.0
episode  1 , id  242 , label  5.0
episode  1 , id  243 , label  8.0
episode  1 , id  244 , label  7.0
episode  1 , id  245 , label  4.0
episode  1 , id  246 , label  5.0
episode  1 , id  247 , label  6.0
episode  1 , id  248 , label  1.0
episode  1 , id  249 , label  7.0
episode  1 , id  250 , label  6.0
episode  1 , id  251 , label  3.0
episode  1 , id  252 , label  7.0
episode  1 , id  253 , label  0.0
episode  1 , id  254 , label  2.0
episode  1 , id  255 , label  7.0
episode  1 , id  256 , label  4.0
episode  1 , i

episode  1 , id  471 , label  3.0
episode  1 , id  472 , label  0.0
episode  1 , id  473 , label  4.0
episode  1 , id  474 , label  3.0
episode  1 , id  475 , label  3.0
episode  1 , id  476 , label  7.0
episode  1 , id  477 , label  3.0
episode  1 , id  478 , label  2.0
episode  1 , id  479 , label  2.0
episode  1 , id  480 , label  0.0
episode  1 , id  481 , label  4.0
episode  1 , id  482 , label  3.0
episode  1 , id  483 , label  7.0
episode  1 , id  484 , label  9.0
episode  1 , id  485 , label  9.0
episode  1 , id  486 , label  8.0
episode  1 , id  487 , label  9.0
episode  1 , id  488 , label  2.0
episode  1 , id  489 , label  4.0
episode  1 , id  490 , label  6.0
episode  1 , id  491 , label  3.0
episode  1 , id  492 , label  4.0
episode  1 , id  493 , label  5.0
episode  1 , id  494 , label  5.0
episode  1 , id  495 , label  6.0
episode  1 , id  496 , label  3.0
episode  1 , id  497 , label  6.0
episode  1 , id  498 , label  2.0
episode  1 , id  499 , label  6.0
episode  1 , i

episode  1 , id  714 , label  1.0
episode  1 , id  715 , label  8.0
episode  1 , id  716 , label  7.0
episode  1 , id  717 , label  9.0
episode  1 , id  718 , label  6.0
episode  1 , id  719 , label  2.0
episode  1 , id  720 , label  2.0
episode  1 , id  721 , label  9.0
episode  1 , id  722 , label  9.0
episode  1 , id  723 , label  3.0
episode  1 , id  724 , label  8.0
episode  1 , id  725 , label  5.0
episode  1 , id  726 , label  9.0
episode  1 , id  727 , label  1.0
episode  1 , id  728 , label  0.0
episode  1 , id  729 , label  0.0
episode  1 , id  730 , label  8.0
episode  1 , id  731 , label  0.0
episode  1 , id  732 , label  9.0
episode  1 , id  733 , label  7.0
episode  1 , id  734 , label  6.0
episode  1 , id  735 , label  7.0
episode  1 , id  736 , label  4.0
episode  1 , id  737 , label  3.0
episode  1 , id  738 , label  8.0
episode  1 , id  739 , label  7.0
episode  1 , id  740 , label  5.0
episode  1 , id  741 , label  4.0
episode  1 , id  742 , label  6.0
episode  1 , i

episode  1 , id  962 , label  7.0
episode  1 , id  963 , label  6.0
episode  1 , id  964 , label  9.0
episode  1 , id  965 , label  6.0
episode  1 , id  966 , label  3.0
episode  1 , id  967 , label  1.0
episode  1 , id  968 , label  7.0
episode  1 , id  969 , label  9.0
episode  1 , id  970 , label  3.0
episode  1 , id  971 , label  5.0
episode  1 , id  972 , label  5.0
episode  1 , id  973 , label  6.0
episode  1 , id  974 , label  7.0
episode  1 , id  975 , label  1.0
episode  1 , id  976 , label  7.0
episode  1 , id  977 , label  7.0
episode  1 , id  978 , label  0.0
episode  1 , id  979 , label  7.0
episode  1 , id  980 , label  2.0
episode  1 , id  981 , label  7.0
episode  1 , id  982 , label  7.0
episode  1 , id  983 , label  4.0
episode  1 , id  984 , label  5.0
episode  1 , id  985 , label  1.0
episode  1 , id  986 , label  6.0
episode  1 , id  987 , label  3.0
episode  1 , id  988 , label  0.0
episode  1 , id  989 , label  4.0
episode  1 , id  990 , label  2.0
episode  1 , i

episode  1 , id  1204 , label  3.0
episode  1 , id  1205 , label  5.0
episode  1 , id  1206 , label  5.0
episode  1 , id  1207 , label  3.0
episode  1 , id  1208 , label  0.0
episode  1 , id  1209 , label  9.0
episode  1 , id  1210 , label  7.0
episode  1 , id  1211 , label  7.0
episode  1 , id  1212 , label  6.0
episode  1 , id  1213 , label  5.0
episode  1 , id  1214 , label  0.0
episode  1 , id  1215 , label  2.0
episode  1 , id  1216 , label  1.0
episode  1 , id  1217 , label  3.0
episode  1 , id  1218 , label  5.0
episode  1 , id  1219 , label  8.0
episode  1 , id  1220 , label  0.0
episode  1 , id  1221 , label  2.0
episode  1 , id  1222 , label  7.0
episode  1 , id  1223 , label  3.0
episode  1 , id  1224 , label  6.0
episode  1 , id  1225 , label  4.0
episode  1 , id  1226 , label  0.0
episode  1 , id  1227 , label  0.0
episode  1 , id  1228 , label  6.0
episode  1 , id  1229 , label  7.0
episode  1 , id  1230 , label  0.0
episode  1 , id  1231 , label  4.0
episode  1 , id  123

episode  1 , id  1448 , label  0.0
episode  1 , id  1449 , label  5.0
episode  1 , id  1450 , label  2.0
episode  1 , id  1451 , label  3.0
episode  1 , id  1452 , label  4.0
episode  1 , id  1453 , label  4.0
episode  1 , id  1454 , label  0.0
episode  1 , id  1455 , label  8.0
episode  1 , id  1456 , label  0.0
episode  1 , id  1457 , label  8.0
episode  1 , id  1458 , label  9.0
episode  1 , id  1459 , label  7.0
episode  1 , id  1460 , label  9.0
episode  1 , id  1461 , label  0.0
episode  1 , id  1462 , label  4.0
episode  1 , id  1463 , label  6.0
episode  1 , id  1464 , label  4.0
episode  1 , id  1465 , label  0.0
episode  1 , id  1466 , label  8.0
episode  1 , id  1467 , label  3.0
episode  1 , id  1468 , label  4.0
episode  1 , id  1469 , label  6.0
episode  1 , id  1470 , label  8.0
episode  1 , id  1471 , label  9.0
episode  1 , id  1472 , label  1.0
episode  1 , id  1473 , label  2.0
episode  1 , id  1474 , label  1.0
episode  1 , id  1475 , label  0.0
episode  1 , id  147

episode  1 , id  1690 , label  7.0
episode  1 , id  1691 , label  3.0
episode  1 , id  1692 , label  7.0
episode  1 , id  1693 , label  4.0
episode  1 , id  1694 , label  1.0
episode  1 , id  1695 , label  8.0
episode  1 , id  1696 , label  4.0
episode  1 , id  1697 , label  6.0
episode  1 , id  1698 , label  9.0
episode  1 , id  1699 , label  6.0
episode  1 , id  1700 , label  8.0
episode  1 , id  1701 , label  1.0
episode  1 , id  1702 , label  8.0
episode  1 , id  1703 , label  1.0
episode  1 , id  1704 , label  1.0
episode  1 , id  1705 , label  2.0
episode  1 , id  1706 , label  6.0
episode  1 , id  1707 , label  0.0
episode  1 , id  1708 , label  5.0
episode  1 , id  1709 , label  9.0
episode  1 , id  1710 , label  2.0
episode  1 , id  1711 , label  3.0
episode  1 , id  1712 , label  4.0
episode  1 , id  1713 , label  4.0
episode  1 , id  1714 , label  6.0
episode  1 , id  1715 , label  0.0
episode  1 , id  1716 , label  6.0
episode  1 , id  1717 , label  6.0
episode  1 , id  171

episode  1 , id  1933 , label  2.0
episode  1 , id  1934 , label  0.0
episode  1 , id  1935 , label  3.0
episode  1 , id  1936 , label  3.0
episode  1 , id  1937 , label  3.0
episode  1 , id  1938 , label  1.0
episode  1 , id  1939 , label  5.0
episode  1 , id  1940 , label  3.0
episode  1 , id  1941 , label  2.0
episode  1 , id  1942 , label  3.0
episode  1 , id  1943 , label  8.0
episode  1 , id  1944 , label  6.0
episode  1 , id  1945 , label  4.0
episode  1 , id  1946 , label  0.0
episode  1 , id  1947 , label  2.0
episode  1 , id  1948 , label  9.0
episode  1 , id  1949 , label  3.0
episode  1 , id  1950 , label  3.0
episode  1 , id  1951 , label  0.0
episode  1 , id  1952 , label  2.0
episode  1 , id  1953 , label  1.0
episode  1 , id  1954 , label  0.0
episode  1 , id  1955 , label  5.0
episode  1 , id  1956 , label  4.0
episode  1 , id  1957 , label  3.0
episode  1 , id  1958 , label  0.0
episode  1 , id  1959 , label  2.0
episode  1 , id  1960 , label  5.0
episode  1 , id  196

episode  2 , id  130 , label  8.0
episode  2 , id  131 , label  5.0
episode  2 , id  132 , label  6.0
episode  2 , id  133 , label  9.0
episode  2 , id  134 , label  7.0
episode  2 , id  135 , label  7.0
episode  2 , id  136 , label  7.0
episode  2 , id  137 , label  8.0
episode  2 , id  138 , label  9.0
episode  2 , id  139 , label  2.0
episode  2 , id  140 , label  8.0
episode  2 , id  141 , label  9.0
episode  2 , id  142 , label  9.0
episode  2 , id  143 , label  0.0
episode  2 , id  144 , label  7.0
episode  2 , id  145 , label  1.0
episode  2 , id  146 , label  6.0
episode  2 , id  147 , label  2.0
episode  2 , id  148 , label  1.0
episode  2 , id  149 , label  8.0
episode  2 , id  150 , label  9.0
episode  2 , id  151 , label  4.0
episode  2 , id  152 , label  7.0
episode  2 , id  153 , label  5.0
episode  2 , id  154 , label  8.0
episode  2 , id  155 , label  7.0
episode  2 , id  156 , label  5.0
episode  2 , id  157 , label  2.0
episode  2 , id  158 , label  1.0
episode  2 , i

episode  2 , id  377 , label  4.0
episode  2 , id  378 , label  6.0
episode  2 , id  379 , label  5.0
episode  2 , id  380 , label  5.0
episode  2 , id  381 , label  8.0
episode  2 , id  382 , label  8.0
episode  2 , id  383 , label  4.0
episode  2 , id  384 , label  9.0
episode  2 , id  385 , label  8.0
episode  2 , id  386 , label  1.0
episode  2 , id  387 , label  6.0
episode  2 , id  388 , label  6.0
episode  2 , id  389 , label  5.0
episode  2 , id  390 , label  2.0
episode  2 , id  391 , label  3.0
episode  2 , id  392 , label  9.0
episode  2 , id  393 , label  2.0
episode  2 , id  394 , label  6.0
episode  2 , id  395 , label  9.0
episode  2 , id  396 , label  4.0
episode  2 , id  397 , label  7.0
episode  2 , id  398 , label  3.0
episode  2 , id  399 , label  5.0
episode  2 , id  400 , label  2.0
episode  2 , id  401 , label  9.0
episode  2 , id  402 , label  5.0
episode  2 , id  403 , label  5.0
episode  2 , id  404 , label  3.0
episode  2 , id  405 , label  0.0
episode  2 , i

KeyboardInterrupt: 

In [None]:
# Separate copies of both splits for the passive-vs-active comparison below.
train_data1, train_labels1 = read_data(filename="sample_train.csv")
test_data1, test_labels1 = read_data(filename="sample_test.csv")

In [None]:
def train1(model, optimizer, train_data1, train_labels1, plotdata):
    """Passive-learning baseline: train on every sample and log test accuracy.

    Runs one epoch over a random permutation of ``train_data1`` in batches of
    four. After each optimisation step the model is scored on the
    module-level ``test_data1`` / ``test_labels1`` and a ``[x, accuracy]``
    pair is appended to ``plotdata``.

    Args:
        model: Classifier returning log-probabilities.
        optimizer: Optimizer bound to ``model``'s parameters.
        train_data1: 2-D array of training rows.
        train_labels1: 1-D array of labels aligned with ``train_data1``.
        plotdata: List that receives the ``[x, accuracy]`` pairs (mutated in place).

    Returns:
        None.
    """
    epochs = 1
    batch_size = 4
    num_samples = train_data1.shape[0]
    # The test tensor never changes, so build it once instead of per step.
    test_inputs = torch.from_numpy(test_data1).float()
    for e in range(epochs):
        print(e)
        order = np.arange(num_samples)
        np.random.shuffle(order)
        i = 0
        while i < num_samples:
            batch = order[i:i + batch_size]
            images = train_data1[batch, :]
            labels = torch.Tensor(train_labels1[batch]).long()
            optimizer.zero_grad()
            output = model(torch.from_numpy(images).float())
            loss = F.nll_loss(output, labels)
            loss.backward()
            optimizer.step()
            i += batch_size
            with torch.no_grad():
                test_output = model(test_inputs)
            predictions = np.argmax(test_output.numpy(), axis=1)
            # x is recorded as i + 1 after i was advanced by batch_size; kept
            # unchanged so existing curves stay comparable.
            # NOTE(review): confirm whether the +1 offset is intended.
            plotdata.append([i + 1, accuracy_score(test_labels1, predictions)])
            
def train_AL(model, optimizer, train_data1, train_labels1, plotdata):
    """Active-learning run: train only on samples the query policy selects.

    Streams one random permutation of ``train_data1`` one sample at a time.
    For each sample the module-level ``dqnet`` policy is consulted through
    ``select_action``; when it returns 1 the sample's label is used for a
    single optimisation step. After every sample (queried or not) ``model``
    is scored on the module-level ``test_data1`` / ``test_labels1`` and a
    ``[x, accuracy]`` pair is appended to ``plotdata``.

    Args:
        model: Classifier returning log-probabilities.
        optimizer: Optimizer bound to ``model``'s parameters.
        train_data1: 2-D array of training rows.
        train_labels1: 1-D array of labels aligned with ``train_data1``.
        plotdata: List that receives the ``[x, accuracy]`` pairs (mutated in place).

    Returns:
        The number of samples whose label was queried.
    """
    epochs = 1
    num_samples = train_data1.shape[0]
    # The test tensor never changes, so build it once instead of per sample.
    test_inputs = torch.from_numpy(test_data1).float()
    # Previously this name was returned without ever being assigned (NameError).
    taken = 0
    for e in range(epochs):
        print(e)
        order = np.arange(num_samples)
        np.random.shuffle(order)
        i = 0
        while i < num_samples:
            image = train_data1[order[i:i+1], :]
            inputs = torch.from_numpy(image).float()
            model.eval()
            with torch.no_grad():
                log_probs = model(inputs).squeeze().numpy()
            model.train()
            # Build the state the same way the policy-training loop does:
            # pixels followed by normalised class probabilities, not raw
            # log-probabilities.
            probs = np.exp(log_probs)
            probs = probs / np.sum(probs)
            state = np.concatenate((image[0], probs))
            # A very large step count drives the exploration rate down to its floor.
            a = select_action(state, dqnet, 1e9)
            if a == 1:
                taken += 1
                label = torch.Tensor(train_labels1[order[i:i+1]]).long()
                optimizer.zero_grad()
                output = model(inputs)
                loss = F.nll_loss(output, label)
                loss.backward()
                optimizer.step()
            i += 1
            with torch.no_grad():
                # Score the model being trained, not the global ``model2``.
                test_output = model(test_inputs)
            predictions = np.argmax(test_output.numpy(), axis=1)
            plotdata.append([i + 1, accuracy_score(test_labels1, predictions)])
    return taken

def read_data(filename):
    """Load a headerless CSV of labelled 28x28 images into NumPy arrays.

    Every non-blank line must look like ``label,p0,p1,...,p783`` where all
    fields parse as integers. Blank lines (for example a trailing empty line
    at the end of the file) are skipped instead of crashing the integer
    parsing.

    Args:
        filename: Path of the CSV file to read.

    Returns:
        A ``(data, labels)`` tuple. ``data`` has shape ``(num_points, 784)``
        and ``labels`` has shape ``(num_points,)``; both are NumPy float
        arrays holding the parsed integer values.

    Raises:
        ValueError: If a field is not an integer or a row does not contain
            exactly 784 pixel values.
    """
    dim_points = 28 * 28
    with open(filename, 'r') as f:
        # Keep only lines that carry data; whitespace-only lines are ignored.
        rows = [line.split(',') for line in f if line.strip()]

    num_points = len(rows)
    data = np.empty((num_points, dim_points))
    labels = np.empty(num_points)

    for ind, fields in enumerate(rows):
        # First field is the class label, the rest are the pixel values.
        labels[ind] = int(fields[0])
        data[ind] = [int(x) for x in fields[1:]]

    return (data, labels)

In [None]:
# Passive baseline: D1 collects the [x, accuracy] pairs appended by train1.
D1 = []
model1 = Net(n_features=28 * 28)
model1_optimizer = optim.SGD(model1.parameters(), momentum=0.9, lr=1e-5)
train1(model1, model1_optimizer, train_data1, train_labels1, plotdata=D1)

In [None]:
# Active-learning run: D2 collects the [x, accuracy] pairs appended by train_AL.
D2 = []
model2 = Net(n_features=28 * 28)
model2_optimizer = optim.SGD(model2.parameters(), momentum=0.9, lr=1e-5)
train_AL(model2, model2_optimizer, train_data1, train_labels1, plotdata=D2)

In [None]:
# Split each list of [x, accuracy] pairs into its two coordinate series.
x1, y1 = [point[0] for point in D1], [point[1] for point in D1]
x2, y2 = [point[0] for point in D2], [point[1] for point in D2]

# Draw both learning curves on the same axes, passive first.
for xs, ys, colour, name in (
    (x1, y1, 'springgreen', 'Passive Learning'),
    (x2, y2, 'mediumpurple', 'Active Learning'),
):
    plt.plot(xs, ys, c=colour, label=name)

plt.title('Model: DQN        Dataset: MNIST')
plt.xlabel('Number of labelled instances')
plt.ylabel('Accuracy score')
plt.legend()
plt.show()