In [1]:
# Execute network.ipynb in this kernel before the training cell below.
# NOTE(review): DQN, used in main(), is not defined in this notebook and is
# presumably provided by network.ipynb — confirm.
%run network.ipynb

In [2]:
import time
import os
import torch
import random
import gymnasium
import flappy_bird_gymnasium
from datetime import datetime
import torch.optim as optim

def main():
    """Resume DQN training on FlappyBird-v0 and checkpoint whenever the score improves.

    Loads the saved model/optimizer state (load_w raises ValueError when no
    checkpoint exists), then repeats: collect a buffer and fit the network
    (train), evaluate the greedy policy over 20 episodes (average_score), and
    save the weights when the average beats the best value seen so far.

    The environment is closed even when the run is interrupted or fails.
    """
    # Parameters set up
    random.seed(datetime.now().timestamp())
    env = gymnasium.make("FlappyBird-v0")
    try:
        agent = DQN()
        optimizer = optim.Adam(agent.parameters(), lr=1e-4)
        buffer_size = 10000
        num_epochs = 60
        num_iterations = 2000
        # Checkpoints are only overwritten once the average score exceeds this value.
        max_performance = 85
        load_w(agent, optimizer)
        for i in range(num_iterations):
            print("")
            print("Iteration ", i)
            print("Max performance", max_performance)
            # Exploration rate passed to train(): decays linearly from 0.1 towards 0.
            epsilon = 0.1 * (num_iterations - i) / num_iterations
            train(env, agent, optimizer, buffer_size, num_epochs, epsilon)
            score = average_score(20, agent)
            if score > max_performance:
                max_performance = score
                save_w(agent, optimizer)
    finally:
        # Release the environment on normal exit, errors and KeyboardInterrupt alike.
        env.close()
    print("Max performance is", max_performance)
        
def save_w(model, optimizer):
    """Write the model and optimizer state dicts to saved_model/DQN.pkl.

    The "saved_model" directory is created relative to the current working
    directory when it does not exist yet. The file holds a dict with the keys
    "model" and "optimizer".
    """
    target_dir = "saved_model"
    os.makedirs(target_dir, exist_ok=True)
    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, os.path.join(target_dir, "DQN.pkl"))

def load_w(model, optimizer):
    """Restore model and optimizer state from saved_model/DQN.pkl.

    The path is resolved to an absolute path from the current working
    directory. Tensors are mapped to the CPU when CUDA is unavailable.

    Returns:
        True once both state dicts have been loaded.

    Raises:
        ValueError: if the checkpoint file does not exist.
    """
    checkpoint_dir = os.path.abspath(os.path.expanduser("saved_model"))
    checkpoint_path = os.path.join(checkpoint_dir, "DQN.pkl")

    # Guard clause: without a checkpoint there is nothing to restore.
    if not os.path.isfile(checkpoint_path):
        raise ValueError("Failed to load weights from {}! File does not exist!".format(checkpoint_path))

    if torch.cuda.is_available():
        map_location = None
    else:
        map_location = torch.device('cpu')
    checkpoint = torch.load(checkpoint_path, map_location)

    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    print("Successfully loaded weights from {}!".format(checkpoint_path))
    return True
    
def train(env, agent, optimizer, buffer_size, num_epochs, e):
    """Collect one buffer of transitions and fit the agent to its targets.

    Args:
        env: environment handed to create_buffer for the rollout.
        agent: network; called as agent(obs, action) to produce the values
            that are regressed onto the targets.
        optimizer: optimizer over the agent's parameters.
        buffer_size: number of transitions to collect.
        num_epochs: number of gradient steps taken on the collected buffer.
        e: exploration rate forwarded to create_buffer.

    The targets are computed once, without gradients, before any update, so
    every epoch regresses onto the same fixed targets.
    """
    with torch.no_grad():
        buffer_obs, buffer_action, buffer_t = create_buffer(agent, buffer_size, env, e)

    # Never request more rows per batch than the buffer actually holds.
    sample_size = min(5000, buffer_size)
    loss_fn = torch.nn.MSELoss()

    for _ in range(num_epochs):
        obs, action, t = create_samples(buffer_obs, buffer_action,
                                        buffer_t, buffer_size, sample_size)
        # Clear gradients first so nothing left over from earlier work leaks into this step.
        optimizer.zero_grad()
        qvalues = agent(obs, action)
        assert qvalues.requires_grad
        loss = loss_fn(input=qvalues.view(t.shape), target=t)
        loss.backward()
        optimizer.step()

def create_samples(obs, a, t, buffer_size, sample_size):
    """Draw a random minibatch, with replacement, from the collected buffer.

    Args:
        obs: buffer observations, one row per transition along dim 0.
        a: buffer actions, aligned with obs along dim 0.
        t: buffer targets, aligned with obs along dim 0.
        buffer_size: number of valid rows; indices are drawn uniformly from
            0 .. buffer_size - 1 with the `random` module.
        sample_size: number of rows to draw. May exceed buffer_size because
            sampling is with replacement.

    Returns:
        (sample_obs, sample_a, sample_t), each with sample_size rows taken
        from the same randomly chosen indices. The results are copies and keep
        the dtype of the inputs.
    """
    # One randint call per sample, in order, so a seeded run picks the same rows
    # as a row-by-row copy loop would.
    indices = torch.tensor(
        [random.randint(0, buffer_size - 1) for _ in range(sample_size)],
        dtype=torch.long,
    )
    # A single gather replaces buffer-sized scratch tensors and per-row copies.
    return obs[indices], a[indices], t[indices]
        

def create_buffer(agent, buffer_size, env, e):
    """Roll the agent out in env and record (observation, action, target) rows.

    Args:
        agent: must provide explore(obs, e) returning the action to take and
            compute_value(obs) returning the bootstrap value of a state
            (the original code calls this value "q_max").
        buffer_size: number of environment steps to record.
        env: environment with the Gymnasium reset()/step() interface whose
            observations have 12 entries.
        e: exploration rate forwarded to agent.explore.

    Returns:
        (obs_batch, action_batch, target_batch) with shapes
        (buffer_size, 12), (buffer_size,) and (buffer_size,).
        The target is `reward + agent.compute_value(next_obs)` for
        non-terminal steps (no discount factor is applied) and `reward`
        for terminal steps.

    Raises:
        RuntimeError: if the bird is not in front of the trailing edge of
            either of the first two pipes in a non-terminal observation.

    Side effects:
        Prints the number of recorded steps whose final reward exceeded 0.5.
    """
    # Geometry constants for the reward shaping below.
    # NOTE(review): these look like normalised screen fractions matching the
    # environment's rendering — confirm against flappy_bird_gymnasium.
    pipe_w = 1/6
    bird_w = 25/230
    bird_pos = 0.305556 - bird_w
    bird_height = 0.06

    # Create buffer
    obs_batch = torch.zeros(buffer_size, 12)
    action_batch = torch.zeros(buffer_size)
    target_batch = torch.zeros(buffer_size)

    # Collect Samples
    terminated = True  # forces a reset on the very first step
    num = 0
    for i in range(buffer_size):
        if terminated:
            # if this is a new episode
            raw_obs, _ = env.reset()
            obs = torch.tensor(raw_obs, dtype=torch.float32)
        # Otherwise obs is already the tensor built from the previous step's
        # observation; wrapping it in torch.tensor() again only triggers the
        # copy-construct UserWarning.
        obs_batch[i] = obs
        # Action
        action = agent.explore(obs, e)
        action_batch[i] = action
        raw_obs, reward, terminated, _, _ = env.step(action)
        # presumably index 9 is the bird's vertical position (the original code
        # named it bird_bottom); a negative value is treated as a crash.
        if raw_obs[9] < 0:
            reward = - 1
            terminated = True
        if not terminated:
            obs = torch.tensor(raw_obs, dtype=torch.float32)
            p_pos1 = obs[0] + pipe_w
            p_pos2 = obs[3] + pipe_w
            # Reward shaping: compare the bird's height (obs[9]) with the two
            # vertical values of the first pipe it has not cleared yet, and
            # overwrite the reward with 0.01 when it lies outside that range.
            if bird_pos < p_pos1:
                if obs[9] > obs[2]:
                    reward = 0.01
                elif obs[9] < obs[1] + bird_height:
                    reward = 0.01
            elif bird_pos < p_pos2:
                if obs[9] > obs[5]:
                    reward = 0.01
                elif obs[9] < obs[4] + bird_height:
                    reward = 0.01
            else:
                # Same exception type as the former bare `raise`, but with a message.
                raise RuntimeError(
                    "create_buffer: bird is past both of the first two pipes; "
                    "unexpected observation layout"
                )
            # Q max
            q_max = agent.compute_value(obs)
            target_batch[i] = reward + q_max
        else:
            target_batch[i] = reward
        if reward > 0.5:
            num = num + 1
    print(num)
    return obs_batch, action_batch, target_batch

def test_performance(agent):
    """Play one FlappyBird-v0 episode with the agent and return the score.

    Args:
        agent: must provide compute_action(obs_tensor) returning the action
            for a float32 observation tensor.

    Returns:
        info['score'] from the last environment step. The episode stops when
        the environment reports termination or truncation, or once the score
        exceeds 150.

    Note:
        Reseeds the global `random` generator from the current time and uses
        it to pick the environment reset seed in 1..100.
    """
    random.seed(datetime.now().timestamp())
    env = gymnasium.make("FlappyBird-v0")
    try:
        seed_number = random.randint(1, 100)
        obs, _ = env.reset(seed=seed_number)
        while True:
            # Next action:
            action = agent.compute_action(torch.tensor(obs, dtype=torch.float32))

            # Processing:
            obs, _, terminated, truncated, info = env.step(action)

            # Stop when the episode is over for any reason; an environment must
            # not be stepped again after it reports truncation.
            if terminated or truncated or info['score'] > 150:
                break
    finally:
        # Always release the environment, even if the agent or env raises.
        env.close()
    return info['score']

def average_score(runs, agent):
    """Evaluate the agent over `runs` episodes and return the mean score.

    Each episode is played with test_performance(agent). The mean is printed
    before it is returned.
    """
    score_sum = sum(test_performance(agent) for _ in range(runs))
    mean_score = score_sum / runs
    print("average score is ", mean_score)
    return mean_score
    
# Start the training run as soon as this cell is executed.
main()
    

Successfully loaded weights from C:\Users\shiyu\Bird-RL\saved_model\DQN.pkl!

Iteration  0
Max performance 70


  obs = torch.tensor(obs, dtype=torch.float32)


8
average score is  57.85

Iteration  1
Max performance 70
9
average score is  4.4

Iteration  2
Max performance 70
6
average score is  22.95

Iteration  3
Max performance 70
4
average score is  13.75

Iteration  4
Max performance 70
12
average score is  3.95

Iteration  5
Max performance 70
7
average score is  7.15

Iteration  6
Max performance 70
14
average score is  7.85

Iteration  7
Max performance 70
9
average score is  8.65

Iteration  8
Max performance 70
10
average score is  14.8

Iteration  9
Max performance 70
11
average score is  15.55

Iteration  10
Max performance 70
8
average score is  11.7

Iteration  11
Max performance 70
14
average score is  9.8

Iteration  12
Max performance 70
14
average score is  16.4

Iteration  13
Max performance 70
9
average score is  13.4

Iteration  14
Max performance 70
9
average score is  11.4

Iteration  15
Max performance 70
9
average score is  9.45

Iteration  16
Max performance 70
10
average score is  4.55

Iteration  17
Max performance 

16
average score is  0.8

Iteration  138
Max performance 70
22
average score is  0.5

Iteration  139
Max performance 70
30
average score is  3.35

Iteration  140
Max performance 70
21
average score is  0.6

Iteration  141
Max performance 70
22
average score is  0.15

Iteration  142
Max performance 70
23
average score is  0.65

Iteration  143
Max performance 70
18
average score is  0.75

Iteration  144
Max performance 70
16
average score is  2.1

Iteration  145
Max performance 70
21
average score is  1.9

Iteration  146
Max performance 70
27
average score is  1.75

Iteration  147
Max performance 70
21
average score is  2.4

Iteration  148
Max performance 70
24
average score is  2.05

Iteration  149
Max performance 70
30
average score is  10.35

Iteration  150
Max performance 70
26
average score is  7.5

Iteration  151
Max performance 70
39
average score is  25.4

Iteration  152
Max performance 70
25
average score is  8.45

Iteration  153
Max performance 70
27
average score is  7.3

Iter

31
average score is  69.95

Iteration  273
Max performance 70
31
average score is  34.65

Iteration  274
Max performance 70
17
average score is  14.65

Iteration  275
Max performance 70
33
average score is  8.15

Iteration  276
Max performance 70
22
average score is  19.45

Iteration  277
Max performance 70
21
average score is  23.4

Iteration  278
Max performance 70
24
average score is  36.95

Iteration  279
Max performance 70
18
average score is  18.7

Iteration  280
Max performance 70
27
average score is  41.4

Iteration  281
Max performance 70
22
average score is  27.95

Iteration  282
Max performance 70
24
average score is  4.85

Iteration  283
Max performance 70
28
average score is  11.35

Iteration  284
Max performance 70
28
average score is  9.35

Iteration  285
Max performance 70
24
average score is  5.75

Iteration  286
Max performance 70
18
average score is  6.3

Iteration  287
Max performance 70
27
average score is  42.75

Iteration  288
Max performance 70
14
average score 

20
average score is  24.35

Iteration  408
Max performance 70
31
average score is  25.0

Iteration  409
Max performance 70
29
average score is  33.75

Iteration  410
Max performance 70
21
average score is  26.8

Iteration  411
Max performance 70
25
average score is  38.0

Iteration  412
Max performance 70
27
average score is  31.2

Iteration  413
Max performance 70
19
average score is  9.15

Iteration  414
Max performance 70
40
average score is  13.25

Iteration  415
Max performance 70
27
average score is  18.2

Iteration  416
Max performance 70
26
average score is  7.55

Iteration  417
Max performance 70
23
average score is  9.0

Iteration  418
Max performance 70
22
average score is  12.35

Iteration  419
Max performance 70
20
average score is  12.75

Iteration  420
Max performance 70
27
average score is  29.95

Iteration  421
Max performance 70
14
average score is  14.3

Iteration  422
Max performance 70
22
average score is  15.8

Iteration  423
Max performance 70
15
average score is

22
average score is  10.25

Iteration  543
Max performance 70
21
average score is  14.1

Iteration  544
Max performance 70
14
average score is  13.15

Iteration  545
Max performance 70
17
average score is  16.05

Iteration  546
Max performance 70
8
average score is  18.35

Iteration  547
Max performance 70
15
average score is  15.15

Iteration  548
Max performance 70
7
average score is  14.1

Iteration  549
Max performance 70
17
average score is  35.35

Iteration  550
Max performance 70
27
average score is  43.15

Iteration  551
Max performance 70
11
average score is  27.9

Iteration  552
Max performance 70
25
average score is  34.55

Iteration  553
Max performance 70
21
average score is  30.2

Iteration  554
Max performance 70
23
average score is  23.4

Iteration  555
Max performance 70
27
average score is  15.15

Iteration  556
Max performance 70
35
average score is  11.8

Iteration  557
Max performance 70
29
average score is  15.1

Iteration  558
Max performance 70
32
average score 

36
average score is  18.7

Iteration  677
Max performance 70.45
34
average score is  39.95

Iteration  678
Max performance 70.45
44
average score is  28.1

Iteration  679
Max performance 70.45
42
average score is  29.05

Iteration  680
Max performance 70.45
40
average score is  31.5

Iteration  681
Max performance 70.45
37
average score is  41.9

Iteration  682
Max performance 70.45
28
average score is  20.1

Iteration  683
Max performance 70.45
36
average score is  28.4

Iteration  684
Max performance 70.45
30
average score is  28.2

Iteration  685
Max performance 70.45
27
average score is  31.15

Iteration  686
Max performance 70.45
31
average score is  16.35

Iteration  687
Max performance 70.45
30
average score is  17.85

Iteration  688
Max performance 70.45
40
average score is  15.85

Iteration  689
Max performance 70.45
34
average score is  16.85

Iteration  690
Max performance 70.45
37
average score is  23.75

Iteration  691
Max performance 70.45
30
average score is  31.95

Iter

43
average score is  13.6

Iteration  805
Max performance 74.3
45
average score is  9.5

Iteration  806
Max performance 74.3
60
average score is  5.35

Iteration  807
Max performance 74.3
59
average score is  3.55

Iteration  808
Max performance 74.3
40
average score is  2.65

Iteration  809
Max performance 74.3
38
average score is  3.65

Iteration  810
Max performance 74.3
41
average score is  3.25

Iteration  811
Max performance 74.3
42
average score is  6.4

Iteration  812
Max performance 74.3
30
average score is  22.25

Iteration  813
Max performance 74.3
25
average score is  26.75

Iteration  814
Max performance 74.3
28
average score is  20.35

Iteration  815
Max performance 74.3
27
average score is  31.95

Iteration  816
Max performance 74.3
35
average score is  29.05

Iteration  817
Max performance 74.3
27
average score is  35.8

Iteration  818
Max performance 74.3
20
average score is  37.8

Iteration  819
Max performance 74.3
33
average score is  37.85

Iteration  820
Max perfo

23
average score is  47.95

Iteration  935
Max performance 82.6
39
average score is  41.65

Iteration  936
Max performance 82.6
31
average score is  48.1

Iteration  937
Max performance 82.6
33
average score is  45.35

Iteration  938
Max performance 82.6
38
average score is  27.75

Iteration  939
Max performance 82.6
31
average score is  73.65

Iteration  940
Max performance 82.6
52
average score is  54.6

Iteration  941
Max performance 82.6
53
average score is  22.5

Iteration  942
Max performance 82.6
57
average score is  9.3

Iteration  943
Max performance 82.6
69
average score is  9.25

Iteration  944
Max performance 82.6
51
average score is  7.55

Iteration  945
Max performance 82.6
52
average score is  2.3

Iteration  946
Max performance 82.6
41
average score is  0.55

Iteration  947
Max performance 82.6
27
average score is  3.85

Iteration  948
Max performance 82.6
37
average score is  26.15

Iteration  949
Max performance 82.6
46
average score is  23.6

Iteration  950
Max perfo

KeyboardInterrupt: 