In [54]:
import torch
import torch.nn as nn
import random
import numpy as np

# Number of squares on the 4x4 board (not referenced by the other visible cells).
TOTAL_SQUARES = 16
# One-hot width per square used by the "old" state encoding.
TOTAL_OPTIONS = 20
# Number of actions; width of the one-hot action mask and of the network output.
TOTAL_ACTIONS = 4
# Seed every random source the notebook draws from so a fresh run is repeatable.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
# Both calls below seed torch's default generator; the second one is redundant but harmless.
torch.random.manual_seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x7fb7798c08f0>

In [231]:
class State2048:
    """Rules engine for 2048 on a 4x4 board stored as a numpy array (0 = empty square).

    Two tile encodings are supported, selected by ``version``:
      * "Reg": squares hold the face value (2, 4, 8, ...); a merge doubles the tile.
      * "Log": squares hold the exponent (1, 2, 3, ...); a merge adds one.

    Actions are integers: 0 slides tiles toward row 0, 1 toward column 3,
    2 toward row 3 and 3 toward column 0.
    """

    # Value of the most common freshly spawned tile for each encoding.
    _BASE_TILE = {"Log": 1, "Reg": 2}

    def __init__(self, version):
        """Select the tile encoding.

        Raises:
            ValueError: if ``version`` is neither "Reg" nor "Log" (case-sensitive).
                Previously this only printed, leaving ``self.base`` undefined.
        """
        if version not in self._BASE_TILE:
            raise ValueError(f"version must be 'Reg' or 'Log', got {version!r}")
        self.version = version
        self.base = self._BASE_TILE[version]

    def initialize_state(self):
        """Return a fresh 4x4 board containing exactly two random tiles."""
        state = np.zeros((4, 4))
        check = self.sample_n_locations_in_place(state, 2)
        assert np.sum(state != 0) == 2
        assert check
        return state

    def gen_new_number(self):
        """Draw a spawn value: the base tile 90% of the time, else one merge step above it."""
        z = random.uniform(0, 1)
        if z > 0.1:
            return self.base
        return self.increment_number(self.base)

    def increment_number(self, num):
        """Return the value a tile of value ``num`` takes after merging with its twin."""
        if self.version == "Reg":
            return num * 2
        if self.version == "Log":
            return num + 1
        # Only reachable if ``version`` was changed after construction.
        raise ValueError(f"version must be 'Reg' or 'Log', got {self.version!r}")

    def _slide_line(self, line):
        """Slide and merge one row/column toward index 0, in place.

        ``line`` must be a writable 1-D view into the board. Each square can
        absorb at most one merge per move. Returns ``(combines, combined_amt)``:
        the number of merges and the sum of the resulting tile values.
        """
        merged = [False] * len(line)
        combines, combined_amt = 0, 0
        for i in range(1, len(line)):
            num = line[i]
            if num == 0:
                continue
            # Nearest occupied square between this tile and the wall it moves toward.
            found = -1
            for j in range(i - 1, -1, -1):
                if line[j] != 0:
                    found = j
                    break
            line[i] = 0
            if found < 0 or line[found] != num or merged[found]:
                # No merge partner: park right behind the blocker (or at the wall).
                line[found + 1] = num
            else:
                line[found] = self.increment_number(line[found])
                combined_amt += line[found]
                combines += 1
                merged[found] = True
        return combines, combined_amt

    def move_by_action(self, input_state: np.ndarray, action: int):
        """Apply ``action`` to a copy of ``input_state`` without spawning a new tile.

        Returns ``(new_state, combines, combined_amt)``. An unknown action prints
        an error and returns an unchanged copy with zero merges.

        All four directions share ``_slide_line``; the per-direction copies used
        to read the merged value from ``[found, c]`` instead of ``[r, found]``
        for left/right moves, which reported a wrong ``combined_amt``.
        """
        curr_state = np.copy(input_state)
        # Each entry is a view oriented so that index 0 is the wall tiles move toward.
        if action == 0:
            lines = [curr_state[:, c] for c in range(4)]
        elif action == 1:
            lines = [curr_state[r, ::-1] for r in range(4)]
        elif action == 2:
            lines = [curr_state[::-1, c] for c in range(4)]
        elif action == 3:
            lines = [curr_state[r, :] for r in range(4)]
        else:
            print("ERROR INCORRECT INPUT ACTION")
            lines = []

        combines, combined_amt = 0, 0
        for line in lines:
            line_combines, line_amt = self._slide_line(line)
            combines += line_combines
            combined_amt += line_amt
        return curr_state, combines, combined_amt

    def transition(self, input_state, action: int):
        """Play ``action`` and spawn one tile.

        Returns ``(next_state, non_term, valid_move)``:
          * illegal action -> ``(input_state, True, False)`` (the same object, untouched);
          * legal action   -> a new board with one extra tile, ``valid_move=True`` and
            ``non_term=False`` when that board has no move left.

        The game-over check runs after the spawn. Checking before it could never
        fire, because a legal move always leaves at least one empty square.
        """
        if action not in self.possible_actions(input_state):
            return input_state, True, False

        new_state, _, _ = self.move_by_action(input_state, action)
        prev = np.sum(new_state != 0)
        check = self.sample_n_locations_in_place(new_state, 1)
        assert check
        assert prev + 1 == np.sum(new_state != 0)
        if not self.check_move_is_possible(new_state):
            return new_state, False, True
        return new_state, True, True

    def check_move_is_possible(self, state):
        """Return True if the board has an empty square or two equal adjacent tiles."""
        for r in range(4):
            for c in range(4):
                if state[r][c] == 0:
                    return True
                if r + 1 < 4 and state[r + 1][c] == state[r][c]:
                    return True
                if c + 1 < 4 and state[r][c + 1] == state[r][c]:
                    return True
        return False

    def used_spots(self, state):
        """Return the number of occupied squares."""
        return np.sum(state != 0)

    def state_score(self, state, not_term, valid_move):
        """Reward for reaching ``state``.

        -20 for an illegal move, -10 for a terminal board, otherwise the mean
        value of the occupied squares (0.0 for an empty board instead of 0/0).
        """
        if not valid_move:
            return -20
        if not not_term:
            return -10
        count = np.sum(state != 0)
        if count == 0:
            return 0.0
        return np.sum(state) / count

    def check_action_is_possible(self, state, action):
        """Return True if some tile can slide or merge one step in the action's direction.

        An unknown action prints an error and returns False.
        """
        if action == 0:
            for r in range(1, 4):
                for c in range(4):
                    if state[r, c] != 0 and (state[r - 1, c] == state[r, c] or state[r - 1, c] == 0):
                        return True
        elif action == 1:
            for r in range(4):
                for c in range(3):
                    if state[r, c] != 0 and (state[r, c + 1] == state[r, c] or state[r, c + 1] == 0):
                        return True
        elif action == 2:
            for r in range(3):
                for c in range(4):
                    if state[r, c] != 0 and (state[r + 1, c] == state[r, c] or state[r + 1, c] == 0):
                        return True
        elif action == 3:
            for r in range(4):
                for c in range(1, 4):
                    if state[r, c] != 0 and (state[r, c - 1] == state[r, c] or state[r, c - 1] == 0):
                        return True
        else:
            print("ERROR INCORRECT INPUT ACTION")
        return False

    def sample_n_locations_in_place(self, input_state, sample_count: int):
        """Place ``sample_count`` new tiles on random empty squares of ``input_state``.

        Returns False and leaves the board untouched when there are no empty
        squares or fewer than ``sample_count`` of them (the latter used to raise
        ValueError from ``random.sample``); returns True otherwise.
        """
        locations = [(i, j) for i in range(4) for j in range(4) if input_state[i, j] == 0]
        if len(locations) == 0 or len(locations) < sample_count:
            return False
        for sample_r, sample_c in random.sample(locations, sample_count):
            input_state[sample_r, sample_c] = self.gen_new_number()
        return True

    def print_state(self, state):
        """Print the board as four rows of space-separated integers."""
        for r in range(4):
            for c in range(4):
                print(int(state[r][c]), end=" ")
            print()

    def possible_actions(self, state):
        """Return the list of actions (subset of 0..3) that would change ``state``."""
        return [i for i in range(4) if self.check_action_is_possible(state, i)]
            

In [232]:
#def conv_state(state_np):
#    return torch.flatten(torch.nn.functional.one_hot(torch.tensor(state_np).to(torch.int64), TOTAL_OPTIONS)).float()

class CONV_STATE:
    """Encode a 4x4 numpy board as the flat float32 tensor fed to the network.

    ``version`` picks the encoding:
      * "old":  one-hot of every square (16 * TOTAL_OPTIONS values);
      * "base": pairwise relation of each square to the other squares in its
                column and row (4 * 4 * 2 * 3 values);
      * "each": the board followed by the successor board of each action
                (5 * 16 values, -1 blocks for illegal/terminal actions);
      * "alt":  +1/-1 per action depending on whether it is playable (4 values).

    "each" and "alt" call ``ENV.transition``; with ``State2048`` that spawns a
    random tile, so those encodings are stochastic and consume random numbers.
    """

    def __init__(self, version, ENV):
        """Store the encoding name and the environment used by "each"/"alt".

        Raises:
            ValueError: if ``version`` is not one of the supported encodings.
        """
        allowed = ["base", "each", "alt", "old"]
        if version not in allowed:
            raise ValueError(f"illegal CONV_STATE version {version!r}; expected one of {allowed}")
        self.version = version
        self.ENV = ENV

    def conv_state_old(self, state_np):
        """One-hot every square into TOTAL_OPTIONS classes and flatten."""
        squares = torch.tensor(state_np).to(torch.int64)
        return torch.flatten(torch.nn.functional.one_hot(squares, TOTAL_OPTIONS)).float()

    @staticmethod
    def _pair_feature(curr, other):
        """Relation of two squares: the shared value if equal, 0 if either is empty,
        otherwise minus the larger value."""
        if other == curr:
            return curr
        if curr == 0 or other == 0:
            return 0
        return -max(curr, other)

    def conv_state_base(self, state_np):
        """For each square, compare it with the 3 other squares of its column
        (slot 0) and of its row (slot 1) using ``_pair_feature``."""
        res = np.zeros((4, 4, 2, 3))
        for r in range(4):
            for c in range(4):
                curr = state_np[r][c]
                res[r, c, 0] = [self._pair_feature(curr, state_np[nr][c]) for nr in range(4) if nr != r]
                res[r, c, 1] = [self._pair_feature(curr, state_np[r][nc]) for nc in range(4) if nc != c]
        return torch.tensor(np.reshape(res, (-1,))).float()

    def conv_state_each(self, state_np):
        """Concatenate the board with the successor board of each of the 4 actions.

        An action that is illegal or ends the game contributes a block of -1.
        The successor returned by ``ENV.transition`` is what gets encoded; the
        previous code appended the current board again, discarding it.
        """
        parts = [np.reshape(state_np, (-1,))]
        for a in range(4):
            state_next_np, non_term, valid_move = self.ENV.transition(state_np, a)
            if not non_term or not valid_move:
                parts.append(np.full(np.shape(state_np), -1.0).reshape(-1))
            else:
                parts.append(np.reshape(state_next_np, (-1,)))
        return torch.tensor(np.concatenate(parts)).float()

    def conv_state_alt(self, state_np):
        """Return one value per action: 1 if it is legal and non-terminal, else -1."""
        res = np.zeros(4)
        for a in range(4):
            _, non_term, valid_move = self.ENV.transition(state_np, a)
            res[a] = 1 if (non_term and valid_move) else -1
        return torch.tensor(res).float()

    def convert(self, state_np):
        """Encode ``state_np`` with the encoding chosen at construction.

        Raises:
            ValueError: if ``self.version`` was changed to an unsupported value.
        """
        encoders = {
            "base": self.conv_state_base,
            "each": self.conv_state_each,
            "alt": self.conv_state_alt,
            "old": self.conv_state_old,
        }
        if self.version not in encoders:
            raise ValueError(f"illegal CONV_STATE version {self.version!r}")
        return encoders[self.version](state_np)
        

In [233]:
class DQN(nn.Module):
    """Sequential Q-network that owns its Adam optimizer and exposes one-step updates."""

    def __init__(self, layers, learning_rate):
        """Chain ``layers`` into a single model and create an Adam optimizer over it."""
        super().__init__()
        self.model = nn.Sequential(*layers)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)

    def forward(self, input_state):
        """Return the model output (one value per action) for ``input_state``."""
        q_values = self.model(input_state)
        return q_values

    def backprop_single(self, state, action, true_value):
        """Take one optimizer step on the squared error of a single (state, action) value.

        Only the output selected by ``action`` contributes to the loss.
        Returns the loss as a Python float.
        """
        action_mask = torch.nn.functional.one_hot(torch.tensor(action).long(), TOTAL_ACTIONS)
        self.optimizer.zero_grad()
        predicted = (self.model(state) * action_mask).sum()
        squared_error = (predicted - true_value).pow(2)
        squared_error.backward()
        self.optimizer.step()
        return squared_error.item()

    def backprop_batch(self, states, actions, true_values):
        """Take one optimizer step on the mean squared error over a batch.

        ``actions`` selects, per row, which output is compared with ``true_values``.
        Returns the loss as a Python float.
        """
        action_mask = torch.nn.functional.one_hot(actions.long(), TOTAL_ACTIONS)
        self.optimizer.zero_grad()
        predicted = (self.model(states) * action_mask).sum(dim=1)
        batch_loss = torch.nn.functional.mse_loss(true_values.float(), predicted.float())
        batch_loss.backward()
        self.optimizer.step()
        return batch_loss.item()

In [1]:
# --- Hyperparameters and model construction for the training loop below ---
# Probability of taking a random action; decayed once per episode by the training loop.
epsilon = 1
# Discount factor for the online one-step target; the replay updates pass their own
# discount (0.9) to buffer_train.
GAMMA = 1
# Number of training episodes.
iterations = 1000
# The next two settings are not referenced by the visible training loop.
minibatch_size = 10
copy_model_freq = 10
# Learning rate handed to the DQN's Adam optimizer.
start_lr = 1e-4
# The next three settings are not referenced by the visible training loop.
decrement_lr_freq = 100
model_save_freq = 100
decrement_eps_freq = 1
# Multiplicative epsilon decay; matches the 0.8 applied per episode in the training loop.
eps_decay_rate = 0.8

# Which CONV_STATE encoding to use, and the input width each encoding produces.
state_version = "old"
state_conv = {"old" : 4 * 4 * TOTAL_OPTIONS, "base" : 4 * 4 * 2 * 3, "each" : 4 * 4 * 5, "alt" : 4}

# Environment uses the "Log" tile encoding (squares hold exponents).
ENV = State2048("Log")
ENCODER = CONV_STATE(state_version, ENV)
IN_DIM = int(state_conv[state_version])
OUT_DIM = TOTAL_ACTIONS

# Fully connected stack narrowing from IN_DIM down to one output per action.
# NOTE: these layer objects hold the weights; any other DQN built from this same
# list shares parameters with DQNetwork.
layers = [nn.Linear(IN_DIM,128), nn.ReLU(), nn.Linear(128,64), nn.ReLU(), nn.Linear(64,32), nn.ReLU(), nn.Linear(32,16), nn.ReLU(), nn.Linear(16,8), nn.ReLU(), nn.Linear(8,OUT_DIM)]
DQNetwork = DQN(layers, start_lr)

# History arrays; not referenced by the visible training loop.
score_history = np.array([])
best_square_history = np.array([])
move_history = np.array([])

# Replay buffer; it is re-created as an empty list just before the training loop.
D = []

def greedy_action(model, state):
    """Evaluate ``model`` on ``state`` and pick the highest-valued action.

    Returns a pair: the winning index as a Python int, and the element of the
    model output at that index (still a tensor).
    """
    q_values = model(state)
    best_index = int(torch.argmax(q_values))
    return best_index, q_values[best_index]

def greedy_action_full(all_models, state):
    """Pick the best action when each action has its own single-output model.

    ``all_models[a](state)`` must yield one scalar value for action ``a``
    (a one-element tensor or a plain number). Returns ``(action, value)`` with
    ``action`` a Python int and ``value`` a numpy float.
    """
    action_values = np.zeros(4)
    # Only the numbers are needed, so skip autograd and convert explicitly:
    # storing a grad-tracking tensor straight into a numpy element relies on
    # implicit conversion that depends on the numpy version and can raise.
    with torch.no_grad():
        for action in range(4):
            action_values[action] = float(all_models[action](state))
    action = np.argmax(action_values)
    return int(action), action_values[action]

def buffer_train(D, sample_size, g, model):
    """Run one pass of single-sample updates on a random draw from the replay buffer.

    Args:
        D: list of ``(state, action, reward, next_state, non_term, valid_move)`` tuples.
        sample_size: maximum number of transitions to draw (without replacement).
        g: discount applied to the bootstrapped value of ``next_state``.
        model: network exposing ``backprop_single``; it is also used to evaluate
            ``next_state``.

    Returns:
        The mean loss over the drawn transitions, or 0.0 when nothing could be
        drawn (empty buffer or non-positive ``sample_size``) instead of dividing
        by zero.
    """
    sample = random.sample(D, max(0, min(len(D), sample_size)))
    if not sample:
        return 0.0
    loss = 0
    for s, a, r, s_next, n_t, v_m in sample:
        tv = r
        # Bootstrap only from legal, non-terminal transitions.
        if n_t and v_m:
            # The target is used as a constant, so no autograd graph is needed.
            with torch.no_grad():
                tv += g * greedy_action(model, s_next)[1].item()
        loss += model.backprop_single(s, a, tv)
    return loss / len(sample)

# Per-episode diagnostics; every list is rewritten to disk after each episode.
valid_move_rate = []
action_freq = np.zeros(4)
avg_losses = []
total_moves = []
total_score = []
highest_tile = []

# Replay buffer of (state, action, reward, next_state, non_term, valid_move) tuples.
D = []

# Hard cap on moves per episode.
MAX_MOVES_PER_EPISODE = 100
# Number of replayed transitions per environment step.
REPLAY_SAMPLE_SIZE = 64
# NOTE(review): replay updates discount with 0.9 while the online update uses GAMMA;
# kept exactly as in the original run - confirm which value is intended.
REPLAY_GAMMA = 0.9

for iteration in range(1, iterations + 1):
    moves = 0
    real_moves = 0      # moves chosen greedily by the network
    valid_moves = 0     # greedy moves that were legal
    non_term = True
    state_np = ENV.initialize_state()

    epsilon *= eps_decay_rate

    avg_loss = 0
    avg_loss_count = 0

    while non_term and moves < MAX_MOVES_PER_EPISODE:
        state = ENCODER.convert(state_np)

        # Epsilon-greedy action selection.
        rand = random.uniform(0, 1) <= epsilon
        if rand:
            action = random.choice([0, 1, 2, 3])
        else:
            action, _ = greedy_action(DQNetwork, state)

        state_next_np, non_term, valid_move = ENV.transition(state_np, action)
        state_next = ENCODER.convert(state_next_np)
        reward = ENV.state_score(state_next_np, non_term, valid_move)

        moves += 1
        if not rand:
            real_moves += 1
            valid_moves += valid_move
            action_freq[action] += 1

        D.append((state, action, reward, state_next, non_term, valid_move))

        # Online one-step update; bootstrap only from legal, non-terminal transitions.
        target_val = reward
        if valid_move and non_term:
            with torch.no_grad():
                _, target_action_value = greedy_action(DQNetwork, state_next)
            target_val = reward + GAMMA * target_action_value.item()
        DQNetwork.backprop_single(state, action, target_val)

        # When some action is illegal in this state, record the outcome of all four
        # actions so the buffer contains examples of the illegal-move penalty.
        poss = ENV.possible_actions(state_np)
        if len(poss) < 4:
            for act in range(4):
                t1, t2, t3 = ENV.transition(state_np, act)
                rew = ENV.state_score(t1, t2, t3)
                D.append((state, act, rew, ENCODER.convert(t1), t2, t3))

        avg_loss += buffer_train(D, REPLAY_SAMPLE_SIZE, REPLAY_GAMMA, DQNetwork)
        avg_loss_count += 1

        state_np = state_next_np

    if iteration % 100 == 0:
        print(iteration)
    avg_loss /= avg_loss_count
    avg_losses.append(avg_loss)
    # An episode can consist of random moves only; record NaN rather than dividing by zero.
    valid_move_rate.append(valid_moves / real_moves if real_moves else float("nan"))
    total_moves.append(moves)
    total_score.append(np.sum(state_np))
    # "Log" boards store exponents, so convert back to the face value of the tile.
    top_square = np.amax(state_np)
    highest_tile.append(int(2 ** top_square) if ENV.version == "Log" else int(top_square))
    np.savetxt("valid_move_rate.txt", np.array(valid_move_rate))
    np.savetxt("action_freq.txt", action_freq)
    np.savetxt("avg_losses.txt", np.array(avg_losses))
    np.savetxt("total_moves.txt", np.array(total_moves))
    np.savetxt("total_score.txt", np.array(total_score))
    np.savetxt("highest_tile.txt", np.array(highest_tile))

NameError: name 'TOTAL_OPTIONS' is not defined

In [229]:
# Tally buffer entries by reward: zero counts as "bad", anything else as "good".
# NOTE(review): this unpacks 4-tuples, while the training cell appends 6-tuples to D;
# the cell presumably ran against a different layout of D - confirm before re-running.
good, bad = 0, 0
for state, act, rew, v in D:
    if rew == 0: bad += 1
    else: good += 1
        
print(good, bad)

70170 29830


In [230]:
# Fit a DQN directly on buffer samples and log the mean loss and the fraction of
# greedy actions that fall inside each sample's `poss` collection.
# NOTE(review): entries are unpacked as (s, a, r, poss) 4-tuples, unlike the 6-tuples
# appended by the training cell - this depends on a D built elsewhere; confirm.
avg_loss_pure = []
avg_acc_pure = []
# NOTE(review): DQN(layers, ...) wraps the same layer objects as the earlier DQNetwork,
# so this network starts from (and shares) those weights rather than a fresh init.
DQNetwork = DQN(layers, 0.0001)
for iteration in range(1000):
    # random.sample raises ValueError if D holds fewer than 64 entries.
    sample = random.sample(D, 64)
    avg_loss = 0
    acc = 0
    ppp = 0  # number of samples that have at least one entry in `poss`
    for s, a, r, poss in sample:
        # Rescale the reward; presumably maps a 0/1 label to -1/+1 - TODO confirm.
        r = r * 2 - 1
        action_values = DQNetwork(s)
        action = int(torch.argmax(action_values))
        acc += action in poss
        ppp += len(poss) > 0
        loss = DQNetwork.backprop_single(s,a,r)
        avg_loss += loss
    avg_loss /= len(sample)
    # NOTE(review): divides by zero when no sampled entry has a non-empty `poss`.
    acc /= ppp
    avg_loss_pure.append(avg_loss)
    avg_acc_pure.append(acc)
    # Both logs are rewritten in full on every iteration.
    np.savetxt("avg_loss_pure.txt", np.array(avg_loss_pure))
    np.savetxt("avg_acc_pure.txt", np.array(avg_acc_pure))

KeyboardInterrupt: 

In [163]:
# Fit each per-action model on its own buffer and report the last average loss.
# NOTE(review): DD and DQNFull are not defined in any visible cell (DD only appears
# in a commented-out line of the training cell); this cell needs them in the kernel.
for action in range(4):
    print("ACTION: ", action)
    for iteration in range(100):
        sample = random.sample(DD[action], 1000)
        avg_loss = 0
        for s, a, r in sample:
            loss = DQNFull[action].backprop_single(s,a,r)
            # Very verbose: prints one line per sample (target vs. current prediction).
            print(r, DQNFull[action](s))
            avg_loss += loss
        avg_loss /= len(sample)
    # `return` is a SyntaxError outside a function; print the value instead.
    print(avg_loss)

ACTION:  0
0 tensor([-0.1967], grad_fn=<AddBackward0>)
1 tensor([-0.0881], grad_fn=<AddBackward0>)
1 tensor([0.1707], grad_fn=<AddBackward0>)
0 tensor([0.3784], grad_fn=<AddBackward0>)
1 tensor([0.6575], grad_fn=<AddBackward0>)
0 tensor([0.8113], grad_fn=<AddBackward0>)
1 tensor([0.9777], grad_fn=<AddBackward0>)
1 tensor([1.1309], grad_fn=<AddBackward0>)
0 tensor([1.1010], grad_fn=<AddBackward0>)
1 tensor([1.0592], grad_fn=<AddBackward0>)
0 tensor([0.8646], grad_fn=<AddBackward0>)
0 tensor([0.5615], grad_fn=<AddBackward0>)
1 tensor([0.3536], grad_fn=<AddBackward0>)
0 tensor([0.1140], grad_fn=<AddBackward0>)
1 tensor([0.0297], grad_fn=<AddBackward0>)
0 tensor([-0.0506], grad_fn=<AddBackward0>)
0 tensor([-0.1154], grad_fn=<AddBackward0>)
0 tensor([-0.1566], grad_fn=<AddBackward0>)
1 tensor([-0.0223], grad_fn=<AddBackward0>)
1 tensor([0.2500], grad_fn=<AddBackward0>)
0 tensor([0.4581], grad_fn=<AddBackward0>)
0 tensor([0.5775], grad_fn=<AddBackward0>)
0 tensor([0.5994], grad_fn=<AddBackwa

KeyboardInterrupt: 

In [161]:
# NOTE(review): `counts` is only initialised in a commented-out line of the training
# cell, so this fails on a fresh kernel unless it is defined elsewhere.
print(counts)

[[74969, 74974], [74777, 74786], [74851, 74861], [75247, 75257]]


In [158]:
# Repeatedly fit DQNetwork on up to 10000 buffer entries, printing the mean loss per pass.
# NOTE(review): entries are unpacked as (s, a, r) 3-tuples, unlike the 6-tuples appended
# by the training cell - this depends on a D built elsewhere; confirm before re-running.
for iteration in range(100):
    sample = random.sample(D, min(len(D),10000))
    avg_loss = 0
    for s, a, r in sample:
        loss = DQNetwork.backprop_single(s,a,r)
        avg_loss += loss
    # NOTE(review): divides by zero when D is empty (sample is then empty).
    avg_loss /= len(sample)
    print(avg_loss)

1.0381510297510896
1.0080166441151295
1.009635012982856
1.0246107665729234
0.9912681942549865
1.0149562809613804
1.011351292843352
1.0231753498534022
0.9990903886910532
1.0956049033678235
0.9750692650138489
0.9610300931951852
0.9940751549784985
0.9631001674925337
1.0131517821022187
1.0895837585776291
0.9856074868755768
1.0294584598070455
1.0069715013987546


KeyboardInterrupt: 