In [1]:
import math
import random
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from collections import namedtuple, deque
from itertools import count
from sklearn.metrics import accuracy_score

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as T

In [2]:
def read_data(filename):
    with open(filename, 'r') as f:
        lines = f.readlines()
    
    num_points = len(lines)
    dim_points = 28 * 28
    data = np.empty((num_points, dim_points))
    labels = np.empty(num_points)
    
    for ind, line in enumerate(lines):
        num = line.split(',')
        labels[ind] = int(num[0])
        data[ind] = [ int(x) for x in num[1:] ]
        
    return (data, labels)

train_data, train_labels = read_data("sample_train.csv")
test_data, test_labels = read_data("sample_test.csv")
print(train_data.shape, test_data.shape)
print(train_labels.shape, test_labels.shape)
print(type(test_data[0,0]))

(6000, 784) (1000, 784)
(6000,) (1000,)
<class 'numpy.float64'>


In [3]:
class DQN(nn.Module):
    def __init__(self, in_features):
        super(DQN, self).__init__()
        self.fc1 = nn.Linear(in_features=in_features, out_features=32)
        self.fc2 = nn.Linear(in_features=32, out_features=2)

    def forward(self, t):
        t = F.relu(self.fc1(t))
        t = self.fc2(t)
        return t

Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    def __init__(self, capacity):
        self.memory = deque([],maxlen=capacity)

    def push(self, *args):
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def can_provide_sample(self, batch_size):
        return len(self.memory) >= batch_size

    def __len__(self):
        return len(self.memory)

class EpsilonGreedyStrategy():
    def __init__(self, start, end, decay):
        self.start = start
        self.end = end
        self.decay = decay

    def get_exploration_rate(self, current_step):
        return self.end + (self.start - self.end) * math.exp(-1. * current_step * self.decay)

In [4]:
class Net(nn.Module):
    def __init__(self, n_features):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return F.log_softmax(self.fc3(x))

def train(model, optimizer, X, criterion=nn.NLLLoss()):
    epochs = 1
    batch_size = 8
    for e in range(epochs):
        running_loss = 0
        num_batches = 0
        order = np.copy(X)
        np.random.shuffle(order)
        i = 0
        while i < len(X):
            j = min(i + batch_size, len(X))
#             print(order[i:j])
            images = train_data[order[i:j], :]
            labels = torch.Tensor(train_labels[order[i:j]]).long()
            optimizer.zero_grad()
            output = model(torch.from_numpy(images).float())
            loss = F.nll_loss(output, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
            i += batch_size
        else:
            print("Epoch {} - Training loss: {}".format(e, running_loss / num_batches))

def predict(model):
    model.eval()
    with torch.no_grad():
        output = model(torch.from_numpy(test_data).float())
    model.train()
    softmax = torch.exp(torch.Tensor(output))
    prob = list(softmax.numpy())
    predictions = np.argmax(prob, axis=1)
    return accuracy_score(test_labels, predictions)

In [7]:
batch_size = 32
gamma = 0.999
eps_start = 1
eps_end = 0.01
eps_decay = 0.001
memory_size = 1000
num_episodes = 10
state_len = 794
budget = 1500
strategy = EpsilonGreedyStrategy(eps_start, eps_end, eps_decay)
memory = ReplayMemory(memory_size)
dqnet = DQN(state_len)
target_dqnet = DQN(state_len)
dqnet_optimizer = optim.Adam(dqnet.parameters(), lr=1e-3)
num_actions = 2

def dqn_train(model, target_model, optimizer, mini_batch):
    criterion = nn.MSELoss()
    optimizer.zero_grad()
    states = np.array([exp[0] for exp in mini_batch])
    actions = torch.Tensor([[exp[1]] for exp in mini_batch]).long()
    rewards = torch.Tensor([exp[2] for exp in mini_batch])
    next_states = torch.Tensor([exp[3] for exp in mini_batch])
    model.train()
    target_model.eval()
    output = model(torch.from_numpy(states).float())
    predicted = torch.gather(output, 1, actions).squeeze()
    with torch.no_grad():
        labels_next = target_model(next_states).detach().max(1).values
    labels = rewards + gamma * labels_next
    loss = criterion(predicted, labels)
    loss.backward()
    optimizer.step()

def update_target_net(model, target_model, tau=1e-3):
    for target_param, local_param in zip(target_model.parameters(), model.parameters()):
        target_param.data.copy_(tau * local_param.data + (1 - tau) * target_param.data)

def select_action(state, model, current_step):
    eps = strategy.get_exploration_rate(current_step)
    if random.random() < eps:
        return random.randrange(num_actions) # explore
    else:
        model.eval()
        with torch.no_grad():
            a = model(torch.from_numpy(state).float()).argmax().item() # exploit
        model.train()
        return a

order = list(range(0, train_data.shape[0]))

X_labelled = []
random.shuffle(order)
for _ in range(num_episodes):
    i = 0
    j = 0
    current_step = 0
    print('Episode {} started'.format(_ + 1))
    model = Net(28*28)
    model_optimizer = optim.SGD(model.parameters(), lr=1e-5, momentum=0.9)
    prev_acc = predict(model)
    while j < train_data.shape[0]:
        part1 = np.copy(train_data[order[i]])
        part2 = model(torch.from_numpy(train_data[order[i]]).float()).detach().numpy()
        state = np.concatenate((part1, part2))
        a = select_action(state, dqnet, current_step)
        current_step += 1
        if a == 1:
            X_labelled.append(order[i])
            i += 1
            if i % 16 == 0:
                train(model, model_optimizer, X_labelled)
                print('Acc', acc)
        acc = predict(model)
        r = acc - prev_acc
        prev_acc = acc
        if i == budget:
            X_labelled = []
            random.shuffle(order)
            i = 0
            part1 = np.copy(train_data[order[i]])
            part2 = model(torch.from_numpy(train_data[order[i]]).float()).detach().numpy()
            new_state = np.concatenate((part1, part2))
            memory.push(state, a, r, new_state)
            break
        part1 = np.copy(train_data[order[i + 1]])
        part2 = model(torch.from_numpy(train_data[i + 1]).float()).detach().numpy()
        new_state = np.concatenate((part1, part2))
        memory.push(state, a, r, new_state)
        if i % 4 == 0 and memory.can_provide_sample(batch_size):
            mini_batch = memory.sample(batch_size)
            dqn_train(dqnet, target_dqnet, dqnet_optimizer, mini_batch)
        if i % 8 == 0:
            update_target_net(dqnet, target_dqnet)
        j += 1
#         print('Rem budget: ', budget - i, j)

Episode 1 started
Epoch 0 - Training loss: 9.494784355163574
Acc 0.108


  # This is added back by InteractiveShellApp.init_path()


Epoch 0 - Training loss: 8.746197819709778
Acc 0.115
Epoch 0 - Training loss: 8.639788707097372
Acc 0.132
Epoch 0 - Training loss: 5.517461180686951
Acc 0.158
Epoch 0 - Training loss: 3.1371501326560973
Acc 0.191
Epoch 0 - Training loss: 2.3503908862670264
Acc 0.227
Epoch 0 - Training loss: 1.9444488840443748
Acc 0.272
Epoch 0 - Training loss: 1.4046201333403587
Acc 0.304
Epoch 0 - Training loss: 1.2241539087974362
Acc 0.338
Epoch 0 - Training loss: 1.0489060312509537
Acc 0.372
Epoch 0 - Training loss: 0.8908535706048663
Acc 0.415
Epoch 0 - Training loss: 1.0498574102918308
Acc 0.447
Epoch 0 - Training loss: 0.9281893007170695
Acc 0.474
Epoch 0 - Training loss: 0.8283872519220624
Acc 0.509
Epoch 0 - Training loss: 0.6446093983948231
Acc 0.523
Epoch 0 - Training loss: 0.5949507185723633
Acc 0.539
Epoch 0 - Training loss: 0.5778817737146336
Acc 0.559
Epoch 0 - Training loss: 0.48229599506076837
Acc 0.567
Epoch 0 - Training loss: 0.42292896589558376
Acc 0.578
Epoch 0 - Training loss: 0.44

Epoch 0 - Training loss: 0.0055691109911234025
Acc 0.846
Epoch 0 - Training loss: 0.005456203099657003
Acc 0.845
Epoch 0 - Training loss: 0.00530580824911473
Acc 0.845
Epoch 0 - Training loss: 0.00517362817260286
Acc 0.846
Epoch 0 - Training loss: 0.00502726187401431
Acc 0.846
Epoch 0 - Training loss: 0.004914949005333924
Acc 0.843
Epoch 0 - Training loss: 0.00492353061983724
Acc 0.846
Epoch 0 - Training loss: 0.004765787162162123
Acc 0.845
Epoch 0 - Training loss: 0.004655432134188294
Acc 0.845
Epoch 0 - Training loss: 0.004562424260310936
Acc 0.844
Epoch 0 - Training loss: 0.004509183179029198
Acc 0.845
Epoch 0 - Training loss: 0.004382258348252016
Acc 0.845
Epoch 0 - Training loss: 0.004282799856159185
Acc 0.845
Epoch 0 - Training loss: 0.004189135210657486
Acc 0.844
Epoch 0 - Training loss: 0.004128686549729619
Acc 0.845
Epoch 0 - Training loss: 0.00402754014475476
Acc 0.846
Epoch 0 - Training loss: 0.003974972442035744
Acc 0.846
Epoch 0 - Training loss: 0.00388046164556651
Acc 0.8

Epoch 0 - Training loss: 0.12988266221157485
Acc 0.656
Epoch 0 - Training loss: 0.09328959511913126
Acc 0.671
Epoch 0 - Training loss: 0.0727057622773855
Acc 0.676
Epoch 0 - Training loss: 0.09553789312485605
Acc 0.667
Epoch 0 - Training loss: 0.12112596593298285
Acc 0.677
Epoch 0 - Training loss: 0.1298252023261739
Acc 0.678
Epoch 0 - Training loss: 0.10814676158337828
Acc 0.701
Epoch 0 - Training loss: 0.10058898659190163
Acc 0.707
Epoch 0 - Training loss: 0.07248938423939623
Acc 0.706
Epoch 0 - Training loss: 0.0937031322187977
Acc 0.706
Epoch 0 - Training loss: 0.08924281145105274
Acc 0.704
Epoch 0 - Training loss: 0.07158516933088235
Acc 0.717
Epoch 0 - Training loss: 0.06144044606751481
Acc 0.707
Epoch 0 - Training loss: 0.07432387099701869
Acc 0.723
Epoch 0 - Training loss: 0.06619941940168954
Acc 0.725
Epoch 0 - Training loss: 0.10167007788297032
Acc 0.732
Epoch 0 - Training loss: 0.08626681005463321
Acc 0.734
Epoch 0 - Training loss: 0.0864190070466672
Acc 0.741
Epoch 0 - Trai

Epoch 0 - Training loss: 0.03821429338987317
Acc 0.791
Epoch 0 - Training loss: 0.03167023020371366
Acc 0.795
Epoch 0 - Training loss: 0.030227575930702084
Acc 0.802
Epoch 0 - Training loss: 0.038697431849349026
Acc 0.805
Epoch 0 - Training loss: 0.036500569382753936
Acc 0.803
Epoch 0 - Training loss: 0.033333590905033725
Acc 0.801
Episode 6 started
Epoch 0 - Training loss: 9.177235126495361
Acc 0.116
Epoch 0 - Training loss: 8.19506049156189
Acc 0.123
Epoch 0 - Training loss: 6.009455839792888
Acc 0.132
Epoch 0 - Training loss: 3.9757786840200424
Acc 0.169
Epoch 0 - Training loss: 2.799916994571686
Acc 0.225
Epoch 0 - Training loss: 2.069724455475807
Acc 0.278
Epoch 0 - Training loss: 1.6163665098803384
Acc 0.314
Epoch 0 - Training loss: 1.2384975859895349
Acc 0.346
Epoch 0 - Training loss: 1.0289650426970587
Acc 0.383
Epoch 0 - Training loss: 0.8425578653812409
Acc 0.408
Epoch 0 - Training loss: 0.6808893545107408
Acc 0.444
Epoch 0 - Training loss: 0.5422363799686233
Acc 0.465
Epoch 

Epoch 0 - Training loss: 0.08540839135409596
Acc 0.783
Epoch 0 - Training loss: 0.08121890386273325
Acc 0.78
Epoch 0 - Training loss: 0.08884869673797353
Acc 0.785
Epoch 0 - Training loss: 0.08109782421420927
Acc 0.788
Epoch 0 - Training loss: 0.07731937758395807
Acc 0.781
Epoch 0 - Training loss: 0.07016924804614472
Acc 0.781
Epoch 0 - Training loss: 0.06008161128680143
Acc 0.784
Epoch 0 - Training loss: 0.0678739440588591
Acc 0.784
Epoch 0 - Training loss: 0.07407372971790553
Acc 0.782
Epoch 0 - Training loss: 0.05968116483023961
Acc 0.787
Epoch 0 - Training loss: 0.05752157914038333
Acc 0.79
Epoch 0 - Training loss: 0.07185040799595299
Acc 0.792
Epoch 0 - Training loss: 0.06970820580161391
Acc 0.793
Epoch 0 - Training loss: 0.054737982574310576
Acc 0.791
Epoch 0 - Training loss: 0.057602543265683884
Acc 0.791
Epoch 0 - Training loss: 0.053688392942764465
Acc 0.797
Epoch 0 - Training loss: 0.052171321474420634
Acc 0.798
Epoch 0 - Training loss: 0.04635170839776817
Acc 0.803
Epoch 0 -

Epoch 0 - Training loss: 0.25912864930513835
Acc 0.564
Epoch 0 - Training loss: 0.34368440445120396
Acc 0.561
Epoch 0 - Training loss: 0.3541848157066852
Acc 0.568
Epoch 0 - Training loss: 0.3464086518312494
Acc 0.571
Epoch 0 - Training loss: 0.25898941487751226
Acc 0.585
Epoch 0 - Training loss: 0.244648572543393
Acc 0.59
Epoch 0 - Training loss: 0.2529952134937048
Acc 0.6
Epoch 0 - Training loss: 0.22228206194937228
Acc 0.615
Epoch 0 - Training loss: 0.18359173429556763
Acc 0.617
Epoch 0 - Training loss: 0.1777164591131387
Acc 0.624
Epoch 0 - Training loss: 0.12442268660691168
Acc 0.644
Epoch 0 - Training loss: 0.14655067238571315
Acc 0.646
Epoch 0 - Training loss: 0.1353764933689187
Acc 0.65
Epoch 0 - Training loss: 0.12201697317763202
Acc 0.66
Epoch 0 - Training loss: 0.1199311423843028
Acc 0.662
Epoch 0 - Training loss: 0.1404277610056328
Acc 0.666
Epoch 0 - Training loss: 0.14628014632719843
Acc 0.677
Epoch 0 - Training loss: 0.139905808756261
Acc 0.676
Epoch 0 - Training loss: 0

Epoch 0 - Training loss: 0.05260126403491564
Acc 0.807
Epoch 0 - Training loss: 0.04574132944508637
Acc 0.808
Epoch 0 - Training loss: 0.05668191375876538
Acc 0.805
Epoch 0 - Training loss: 0.06339989554157623
Acc 0.807
Epoch 0 - Training loss: 0.05802470530816208
Acc 0.802
Epoch 0 - Training loss: 0.06245970173363781
Acc 0.816
Epoch 0 - Training loss: 0.05955495409526942
Acc 0.821
Epoch 0 - Training loss: 0.045004311448989816
Acc 0.816
Epoch 0 - Training loss: 0.04526184238282093
Acc 0.819
Epoch 0 - Training loss: 0.052810858041287725
Acc 0.826
Epoch 0 - Training loss: 0.07609858773012515
Acc 0.822
Epoch 0 - Training loss: 0.07659533979764593
Acc 0.828
Epoch 0 - Training loss: 0.06355818933033416
Acc 0.831
Epoch 0 - Training loss: 0.06549926146601975
Acc 0.828
Epoch 0 - Training loss: 0.05591238141597488
Acc 0.828
Epoch 0 - Training loss: 0.0597434039908886
Acc 0.827
Epoch 0 - Training loss: 0.051250534783795716
Acc 0.829
Epoch 0 - Training loss: 0.05410155643376031
Acc 0.831
