In [None]:
%run network.ipynb

In [None]:
import time
import os
import torch
import random
import gymnasium
import flappy_bird_gymnasium
from datetime import datetime
import torch.optim as optim

def main(num_iterations=2000, eval_runs=20):
    """Train a DQN on FlappyBird-v0 and checkpoint the best-scoring weights.

    Each iteration collects a fresh buffer and trains on it with an
    exploration rate that decays linearly from 0.1 towards 0. It then
    evaluates the agent over ``eval_runs`` episodes. Weights are saved
    whenever the average score beats the best seen so far (initially 85).

    Args:
        num_iterations: number of collect/train/evaluate iterations.
        eval_runs: number of episodes averaged per evaluation.
    """
    random.seed(datetime.now().timestamp())
    env = gymnasium.make("FlappyBird-v0")
    agent = DQN()
    optimizer = optim.Adam(agent.parameters(), lr=1e-4)
    buffer_size = 10000
    num_epochs = 60
    # Only checkpoints whose average score exceeds this threshold are saved.
    max_performance = 85

    # load_w raises when no checkpoint exists; on a first run, start from scratch
    # instead of crashing. This is the same path that save_w writes to.
    checkpoint = os.path.join("saved_model", "DQN.pkl")
    if os.path.isfile(checkpoint):
        load_w(agent, optimizer)
    else:
        print("No checkpoint at {}; training from scratch.".format(checkpoint))

    try:
        for i in range(num_iterations):
            print("")
            print("Iteration ", i)
            print("Max performance", max_performance)
            # Linearly decaying exploration rate: 0.1 at i == 0, approaching 0.
            e = 0.1 * (num_iterations - i) / num_iterations
            train(env, agent, optimizer, buffer_size, num_epochs, e)
            score = average_score(eval_runs, agent)
            if score > max_performance:
                max_performance = score
                save_w(agent, optimizer)
    finally:
        env.close()
    print("Max performance is", max_performance)
        
def save_w(model, optimizer):
    """Write model and optimizer state dicts to ``saved_model/DQN.pkl``.

    The ``saved_model`` directory is created relative to the current working
    directory if needed. The checkpoint is a dict with the keys ``"model"``
    and ``"optimizer"``.
    """
    directory = "saved_model"
    os.makedirs(directory, exist_ok=True)
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(directory, "DQN.pkl"))

def load_w(model, optimizer):
    """Restore model and optimizer state from ``saved_model/DQN.pkl``.

    Tensors are mapped onto the CPU when CUDA is unavailable.

    Returns:
        True once both state dicts have been loaded.

    Raises:
        ValueError: if the checkpoint file does not exist.
    """
    checkpoint_dir = os.path.abspath(os.path.expanduser("saved_model"))
    checkpoint_path = os.path.join(checkpoint_dir, "DQN.pkl")
    if not os.path.isfile(checkpoint_path):
        raise ValueError("Failed to load weights from {}! File does not exist!".format(checkpoint_path))
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(checkpoint_path, map_location)
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    print("Successfully loaded weights from {}!".format(checkpoint_path))
    return True
    
def train(env, agent, optimizer, buffer_size, num_epochs, e, sample_size=5000):
    """Collect one fresh buffer of transitions and run ``num_epochs`` updates on it.

    Args:
        env: environment passed through to ``create_buffer``.
        agent: Q-network; called as ``agent(obs, action)`` to get Q-values.
        optimizer: optimizer over ``agent``'s parameters.
        buffer_size: number of transitions collected per call.
        num_epochs: number of gradient steps, each on a fresh random sample.
        e: exploration parameter forwarded to ``create_buffer``.
        sample_size: rows drawn per gradient step; capped at ``buffer_size``.
    """
    # Targets are fixed regression labels, so the rollout needs no autograd graph.
    with torch.no_grad():
        buffer_obs, buffer_action, buffer_t = create_buffer(agent, buffer_size, env, e)
    # create_samples fills at most buffer_size rows, so larger samples would overflow.
    sample_size = min(sample_size, buffer_size)
    loss_fn = torch.nn.MSELoss()
    for _ in range(num_epochs):
        obs, action, t = create_samples(buffer_obs, buffer_action,
                                        buffer_t, buffer_size, sample_size)
        qvalues = agent(obs, action)
        assert qvalues.requires_grad
        loss = loss_fn(qvalues.view(t.shape), t)
        # Clear any stale gradients first so each step reflects only this batch.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def create_samples(obs, a, t, buffer_size, sample_size):
    """Draw ``sample_size`` rows uniformly, with replacement, from the buffer.

    Args:
        obs, a, t: buffer tensors indexed along dim 0.
        buffer_size: number of valid rows; indices are drawn from
            ``[0, buffer_size - 1]``.
        sample_size: number of rows to draw. It may exceed the number of
            buffer rows because sampling is with replacement.

    Returns:
        ``(obs[idx], a[idx], t[idx])`` cast to torch's default floating dtype,
        each with ``sample_size`` rows along dim 0.
    """
    # Indices come from the ``random`` module (not torch), so ``random.seed``
    # still controls which rows are drawn.
    idx = torch.tensor(
        [random.randint(0, buffer_size - 1) for _ in range(sample_size)],
        dtype=torch.long,
    )
    dtype = torch.get_default_dtype()
    return obs[idx].to(dtype), a[idx].to(dtype), t[idx].to(dtype)
        

def create_buffer(agent, buffer_size, env, e):
    """Roll out ``agent`` in ``env`` and build ``buffer_size`` regression targets.

    Args:
        agent: object providing ``explore(obs, e)`` (returns an action) and
            ``compute_value(obs)`` (returns the bootstrap value of a state).
        buffer_size: number of transitions to collect.
        env: Gymnasium-style environment (``reset()`` / ``step(action)``)
            whose observations fit a 12-wide row.
        e: exploration parameter forwarded to ``agent.explore``.

    Returns:
        ``(obs_batch, action_batch, target_batch)`` with shapes
        ``(buffer_size, 12)``, ``(buffer_size,)`` and ``(buffer_size,)``.
        The target is ``reward`` for terminal transitions and
        ``reward + agent.compute_value(next_obs)`` otherwise (undiscounted).

    Raises:
        RuntimeError: if the reward-shaping check finds ``bird_pos`` beyond
            both ``obs[0] + pipe_w`` and ``obs[3] + pipe_w``.
    """
    obs_batch = torch.zeros(buffer_size, 12)
    action_batch = torch.zeros(buffer_size)
    target_batch = torch.zeros(buffer_size)

    # Constants for the reward-shaping check, in the same units as the observation.
    pipe_w = 1 / 6
    bird_w = 25 / 230
    bird_pos = 0.305556 - bird_w
    bird_height = 0.06

    need_reset = True
    num = 0  # number of transitions whose (possibly shaped) reward exceeds 0.5
    for i in range(buffer_size):
        if need_reset:
            obs, _ = env.reset()
        # as_tensor is a no-op for tensors. This avoids torch.tensor's
        # copy-construct warning when obs was already converted below.
        obs = torch.as_tensor(obs, dtype=torch.float32)
        obs_batch[i] = obs
        action = agent.explore(obs, e)
        action_batch[i] = action
        obs, reward, terminated, truncated, _ = env.step(action)
        if obs[9] < 0:
            # A negative obs[9] is treated as failure: penalise and stop bootstrapping.
            reward = -1
            terminated = True
        # A truncated episode is not terminal (we still bootstrap from it below),
        # but the environment must be reset before it is stepped again.
        need_reset = terminated or truncated
        if terminated:
            target_batch[i] = reward
        else:
            obs = torch.as_tensor(obs, dtype=torch.float32)
            # Choose the first of the two slots (obs[0:3] or obs[3:6]) whose
            # position plus pipe_w is still beyond bird_pos.
            if bird_pos < obs[0] + pipe_w:
                slot = 0
            elif bird_pos < obs[3] + pipe_w:
                slot = 3
            else:
                raise RuntimeError(
                    "create_buffer: bird_pos is beyond both pipe slots obs[0] and obs[3]"
                )
            # Shaped reward when obs[9] lies outside the slot's [obs[slot+1] + bird_height, obs[slot+2]] band.
            if obs[9] > obs[slot + 2] or obs[9] < obs[slot + 1] + bird_height:
                reward = 0.01
            target_batch[i] = reward + agent.compute_value(obs)
        if reward > 0.5:
            num = num + 1
    print(num)
    return obs_batch, action_batch, target_batch

def test_performance(agent, max_score=150):
    """Play one episode of FlappyBird-v0 with ``agent.compute_action`` and return the score.

    The global ``random`` module is reseeded from the wall clock. A fresh
    environment is then reset with a random seed in [1, 100]. The episode
    ends when the environment terminates or truncates, or when
    ``info['score']`` exceeds ``max_score``. The environment is always closed,
    even if the agent or the environment raises.

    Args:
        agent: object providing ``compute_action(obs_tensor)``.
        max_score: score above which the episode is cut short.

    Returns:
        ``info['score']`` from the final step.
    """
    random.seed(datetime.now().timestamp())
    env = gymnasium.make("FlappyBird-v0")
    try:
        seed_number = random.randint(1, 100)
        obs, _ = env.reset(seed=seed_number)
        while True:
            action = agent.compute_action(torch.tensor(obs, dtype=torch.float32))
            obs, _, terminated, truncated, info = env.step(action)
            if terminated or truncated or info['score'] > max_score:
                break
    finally:
        env.close()
    return info['score']

def average_score(runs, agent):
    """Return and print the mean ``test_performance`` score over ``runs`` episodes."""
    scores = [test_performance(agent) for _ in range(runs)]
    mean = sum(scores) / runs
    print("average score is ", mean)
    return mean
    
# Guarded so that importing this code as a module does not start training.
if __name__ == "__main__":
    main()
    