Commit c8d0abb

added svelte skill
1 parent dee77a4 commit c8d0abb

10 files changed: +903 -2 lines changed
content/posts/machine learning/RL/ppo.ipynb

+320
Large diffs are not rendered by default.
+116

@@ -0,0 +1,116 @@
import gymnasium as gym
import torch
import torch.nn as nn
import torch.optim as optim


# Define the PPO agent: a policy (actor) network and a value (critic) network
class PPOAgent(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        self.policy_network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, action_dim),
        )
        self.value_network = nn.Sequential(
            nn.Linear(state_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, state):
        policy_output = self.policy_network(state)  # action logits
        value_output = self.value_network(state)    # state-value estimate
        return policy_output, value_output


# Define the priority network: maps a flattened experience to a scalar priority
class PriorityNetwork(nn.Module):
    def __init__(self, state_dim, action_dim):
        super().__init__()
        # Input layout: state + one-hot action + reward + next_state + done
        self.priority_network = nn.Sequential(
            nn.Linear(state_dim + action_dim + 1 + state_dim + 1, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

    def forward(self, experience):
        return self.priority_network(experience)


# Define the PPO trainer
class PPOTrainer:

    def __init__(self, agent, priority_network, gamma, lambda_, epsilon, c1, c2):
        self.agent = agent
        self.priority_network = priority_network
        self.gamma = gamma
        self.lambda_ = lambda_  # GAE parameter (reserved, unused in this simplified update)
        self.epsilon = epsilon  # PPO clip range (reserved, unused in this simplified update)
        self.c1 = c1            # value-loss coefficient
        self.c2 = c2            # entropy coefficient (reserved, unused in this simplified update)
        # Infer dimensions from the agent's networks
        self.state_dim = agent.policy_network[0].in_features
        self.action_dim = agent.policy_network[-1].out_features
        # Create the optimizers once, outside the training loop
        self.priority_optimizer = optim.Adam(priority_network.parameters(), lr=0.001)
        self.policy_optimizer = optim.Adam(agent.policy_network.parameters(), lr=0.001)
        self.value_optimizer = optim.Adam(agent.value_network.parameters(), lr=0.001)

    def encode_experience(self, experience):
        # Flatten (state, action, reward, next_state, done) into the vector
        # layout the priority network expects
        state, action, reward, next_state, done = experience
        action_one_hot = torch.zeros(self.action_dim)
        action_one_hot[action] = 1.0
        return torch.cat(
            [state, action_one_hot, torch.tensor([reward]), next_state, torch.tensor([done])]
        )

    def train(self, batch_size, epochs):
        priority_loss_fn = nn.MSELoss()
        for epoch in range(epochs):
            # Sample a batch of experiences from the replay buffer
            batch_experiences = self.sample_batch(batch_size)

            # Compute the TD-error for each experience in the batch
            # (no_grad: the TD-errors serve as fixed targets below)
            td_errors = []
            with torch.no_grad():
                for state, action, reward, next_state, done in batch_experiences:
                    td_error = (
                        reward
                        + self.gamma * (1.0 - done) * self.agent.value_network(next_state)
                        - self.agent.value_network(state)
                    )
                    td_errors.append(td_error)

            # Train the priority network to predict the TD-error
            self.priority_network.train()
            for experience, td_error in zip(batch_experiences, td_errors):
                self.priority_optimizer.zero_grad()
                priority_output = self.priority_network(self.encode_experience(experience))
                loss = priority_loss_fn(priority_output, td_error)
                loss.backward()
                self.priority_optimizer.step()

            # Train the PPO agent. Note: this is a simplified policy-gradient
            # surrogate with the TD-error as advantage; the full clipped PPO
            # objective would also need the old-policy log-probabilities
            # stored alongside each experience.
            self.agent.train()
            for (state, action, reward, next_state, done), td_error in zip(
                batch_experiences, td_errors
            ):
                self.policy_optimizer.zero_grad()
                self.value_optimizer.zero_grad()
                policy_output, value_output = self.agent(state)
                log_prob = torch.log_softmax(policy_output, dim=-1)[action]
                policy_loss = -log_prob * td_error
                td_target = reward + self.gamma * (1.0 - done) * self.agent.value_network(next_state).detach()
                value_loss = (value_output - td_target) ** 2
                loss = policy_loss + self.c1 * value_loss
                loss.backward()
                self.policy_optimizer.step()
                self.value_optimizer.step()

    def sample_batch(self, batch_size):
        # Placeholder for the actual replay-buffer sampling logic: returns
        # random (state, action, reward, next_state, done) tuples of the
        # right shapes so the training loop runs end to end
        batch_experiences = []
        for _ in range(batch_size):
            state = torch.rand(self.state_dim)
            action = int(torch.randint(self.action_dim, (1,)))
            reward = float(torch.rand(1))
            next_state = torch.rand(self.state_dim)
            done = float(torch.rand(1) < 0.05)
            batch_experiences.append((state, action, reward, next_state, done))
        return batch_experiences


# Create the Gym CarRacing environment
env = gym.make("CarRacing-v2")

# Create the PPO agent and priority network.
# Note: CarRacing-v2 observations are 96x96x3 images, so using shape[0] as a
# flat state size is a simplification; a real agent would use a CNN encoder.
agent = PPOAgent(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0])
priority_network = PriorityNetwork(state_dim=env.observation_space.shape[0], action_dim=env.action_space.shape[0])

# Create the PPO trainer
trainer = PPOTrainer(agent, priority_network, gamma=0.99, lambda_=0.95, epsilon=0.1, c1=0.5, c2=0.01)

# Train the PPO agent
trainer.train(batch_size=32, epochs=1000)
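
The sample_batch above is only a placeholder. As a minimal sketch of how the learned priorities could actually drive sampling, assuming a plain list-backed replay buffer and a softmax-proportional scheme (both are assumptions, not something this commit implements):

import torch

# Hypothetical prioritized sampling driven by the priority network.
# buffer is assumed to be a plain list of (state, action, reward,
# next_state, done) tuples; the softmax-proportional scheme is illustrative.
def sample_prioritized(buffer, priority_network, encode_experience, batch_size):
    # Score every stored experience with the priority network
    with torch.no_grad():
        scores = torch.stack(
            [priority_network(encode_experience(exp)).squeeze() for exp in buffer]
        )
    # Turn scores into a sampling distribution and draw indices from it
    probs = torch.softmax(scores, dim=0)
    idx = torch.multinomial(probs, batch_size, replacement=True)
    return [buffer[int(i)] for i in idx]

A PPOTrainer.sample_batch override could then return sample_prioritized(buffer, self.priority_network, self.encode_experience, batch_size) once a real buffer is filled from environment rollouts.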

data/en/author.yaml

+1 -1

@@ -16,6 +16,6 @@ contactInfo:
 summary:
 - I am a Data Scientist
 - I am a Machine Learning Engineer
-- I love Physics and Maths
+- I studied Physics and Maths
 - I like open-source projects
 - I like challenges

data/en/sections/skills.yaml

+6

@@ -68,3 +68,9 @@ skills:
 - name: C++
   logo: /images/sections/skills/c++.png
   summary: "Know basic C/C++ programming. I used often to accelerate python computational time"
+
+- name: Svelte
+  logo: /images/sections/skills/svelte.png
+  summary: "Learned Svelte in order to build fast and robust server-side webapps."
+
+

public/index.html

+24 -1

@@ -324,7 +324,7 @@ <h1 class="greeting"> Hi, I am Stefano</h1>
 
 <li>I am a Machine Learning Engineer</li>
 
-<li>I love Physics and Maths</li>
+<li>I studied Physics and Maths</li>
 
 <li>I like open-source projects</li>

@@ -889,6 +889,29 @@ <h5 class="card-title">C&#43;&#43;</h5>
 </div>
 
+<div class="col-xs-12 col-sm-6 col-lg-4 pt-2">
+  <a class="text-decoration-none" >
+    <div class="card">
+      <div class="card-head d-flex">
+        <img class="card-img-xs" src="/images/sections/skills/svelte_hub16989d5432f747a495481e59dd94a22_16448_24x24_fit_box_3.png" alt="Svelte" />
+        <h5 class="card-title">Svelte</h5>
+      </div>
+      <div class="card-body">
+        <p class="card-text">Learned Svelte in order to build fast and robust server-side webapps.</p>
+      </div>
+    </div>
+  </a>
+</div>
+
 </div>
 </div>
 </div>
