In [1]:
# -*- coding: latin -*-
import sys, math
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as np
from torch.utils.data import DataLoader, Dataset

from typing import List

class TNode:
	def __init__(self):
		self.child = [None] * 6  # Liste pour les nœuds enfants, initialisee avec None
		self.ind_child = 0        # Indice du nœud enfant
		self.parent = None        # Pointeur vers le noeud parent
		self.ucb = 0.0            # Valeur UCB
		self.n = 0.0              # Nombre de visites
		self.w = 0.0              # Recompense cumulative
		self.num = 0              # Compteur
		self.score = 0            # Score
		self.expand = False        # Indicateur d'expansion

	def __del__(self):
		# Destructor (lorsque l'objet est detruit)
		pass  # Pas besoin de liberer manuellement la memoire en Python

	def copy(self, sim):
		# Methode pour copier un autre nœud dans celui-ci
		self.child = sim.child.copy()  # Copie la liste des enfants
		self.ind_child = sim.ind_child
		self.parent = sim.parent
		self.ucb = sim.ucb
		self.n = sim.n
		self.w = sim.w
		self.num = sim.num
		self.score = sim.score
		self.expand = sim.expand


class Vector2:
	def __init__(self, x=0.0, y=0.0):
		self.x = x
		self.y = y

def distance(vec1: Vector2, vec2: Vector2) -> float:
	return math.sqrt((vec1.x - vec2.x) ** 2 + (vec1.y - vec2.y) ** 2)

sorties = [
	(-18, 0),   # sortie 1
	(0, 0),     # sortie 2
	(18, 0),    # sortie 3
	(-18, 200), # sortie 4
	(0, 200),   # sortie 5
	(18, 200)   # sortie 6
]


class Checkpoint:
	def __init__(self, x, y):
		self.x = x
		self.y = y
		self.pos = Vector2(x, y)

MAX_DOUBLE = sys.float_info.max  # La plus grande valeur flottante possible

class Sim:
	def __init__(self):
		self.pos = Vector2()           # Position, initialisee avec un vecteur (0.0, 0.0)
		self.speed = Vector2()         # Vitesse, initialisee avec un vecteur (0.0, 0.0)
		self.direction = Vector2()      # Direction, initialisee avec un vecteur (0.0, 0.0)
		self.angle = 0.0               # Angle, initialise a 0.0
		self.angletot = 0.0             # Angle total, initialise a 0.0
		self.next_checkpoint = 0        # Prochain checkpoint, initialise a 0
		self.thrust = 0.0               # Poussee, initialisee a 0.0
		self.check_point = 0            # Compteur de checkpoints, initialise a 0
		self.check_pass = 0             # Compteur de checkpoints passes, initialise a 0


	def clone(self):
		result = Sim()  # Cree une nouvelle instance de Sim
		result.pos.x = self.pos.x
		result.pos.y = self.pos.y
		result.speed.x = self.speed.x
		result.speed.y = self.speed.y
		result.angle = self.angle
		result.angletot = self.angletot
		result.next_checkpoint = self.next_checkpoint
		result.thrust = self.thrust
		result.check_point = self.check_point
		result.check_pass = self.check_pass
		result.direction.x = self.direction.x
		result.direction.y = self.direction.y
		return result

	def simulate(self):
		anglef = int(self.angletot + self.angle + 360) % 360  # Calcul de l'angle
		angle_rad = math.radians(anglef)  # Conversion en radians
		dir = Vector2(
			math.cos(angle_rad) * self.thrust,
			math.sin(angle_rad) * self.thrust
		)
		self.speed.x += dir.x  # Mise a jour de la vitesse
		self.speed.y += dir.y
		self.pos.x += self.speed.x  # Mise a jour de la position
		self.pos.y += self.speed.y

	def end_simulate(self):
		self.pos.x = int(self.pos.x)  # Arrondi de la position
		self.pos.y = int(self.pos.y)
		self.speed.x = int(self.speed.x * 0.85)  # Reduction de la vitesse
		self.speed.y = int(self.speed.y * 0.85)

	
	def selection(self, node, leaf):
		
		for i in range(6):
			if int(node.child[i].n) != 0:
				ad = math.sqrt(2.0 * math.log(node.n) / node.child[i].n)
				node.child[i].ucb = (node.child[i].score / node.child[i].n) + ad
			else:
				node.child[i].ucb = MAX_DOUBLE  # Utilise la constante pour la plus grande valeur possible
		

		# Selectionner l'enfant avec le UCB le plus eleve
		max_ucb = -MAX_DOUBLE  # Initialise a la plus petite valeur possible
		indi = -1
		for i in range(6):
			if node.child[i].ucb > max_ucb:
				max_ucb = node.child[i].ucb
				indi = i
			

		#leaf[0] = node.child[indi]  # On utilise une liste pour permettre la modification par reference
		return indi

	def expand(self, node, depth):
		node.expand = True
		
		# Boucle pour creer jusqu'a 6 nœuds enfants
		for i in range(6):
			n = TNode()  # Cree un nouvel objet Node
			n.parent = node
			n.num = i
			n.ind_child = 0
			n.score = 0
						
			node.child[i] = n
			node.ind_child += 1

	def backpropagation(self, parent, sc: int):
		par = parent
		
		while par is not None:
			
			par.n += 1
			par.score += sc
			par.w += 1
			par = par.parent

	
	def rmcts(self, depth: int, checkp, maxc: int) -> int:
		root = TNode()
		root.parent = None
		root.score = 0
		root.ind_child = 0

		
		time = 0

		while time < 500:
			_node_p1 = root
			node_p1 = None

			pod1 = self.clone()

			for depth2 in range(depth):
				if _node_p1.ind_child == 0:
					self.expand(_node_p1, depth2)

				
				ind = self.selection(_node_p1, node_p1)
				node_p1 = _node_p1.child[ind]

				pod1.angle = sorties[node_p1.num][0]
				pod1.thrust = sorties[node_p1.num][1]
				pod1.simulate()
				pod1.end_simulate()
				pod1.angletot = (pod1.angletot + pod1.angle + 360) % 360

				
				d = distance(checkp[6][pod1.check_point].pos, pod1.pos)

				if d <= 600:
					pod1.check_pass += 1
					pod1.check_point = (pod1.check_point + 1) % maxc

				_node_p1 = node_p1

			
			d = distance(checkp[6][pod1.check_point].pos, pod1.pos)

			score = 50000.0 * pod1.check_pass - d

			self.backpropagation(node_p1, int(score))

			time += 1

		indc = -1
		maxscore = -2000000000
		for i in range(6):
			score = 0

			if root.child[i].n == 0:
				continue

			score = int(root.child[i].n)

			if score > maxscore:
				maxscore = score
				indc = i

		res = root.child[indc].num
				

		return res

class Superpod:
	def __init__(self):
		self.ind_map = 6
		self.simp = Sim()  # Assuming Sim has a no-argument constructor
		self.pass_check = False
		self.counter_check = 20000
		self.indch = 0
		self.previous_speed = 0.0
		self.nb_reach = 100
		
		# Initialize maps and lengths
		self.maps = [[] for _ in range(14)]  # List of lists for checkpoints
		self.len_map = [0] * 14  # Lengths of each map (0 to 13)

		# Map 1
		self.maps[0].append(Checkpoint(12460, 1350))
		self.maps[0].append(Checkpoint(10540, 5980))
		self.maps[0].append(Checkpoint(3580, 5180))
		self.maps[0].append(Checkpoint(13580, 7600))
		self.len_map[0] = 4

		# Map 2
		self.maps[1].append(Checkpoint(3600, 5280))
		self.maps[1].append(Checkpoint(13840, 5080))
		self.maps[1].append(Checkpoint(10680, 2280))
		self.maps[1].append(Checkpoint(8700, 7460))
		self.maps[1].append(Checkpoint(7200, 2160))
		self.len_map[1] = 5

		# Map 3
		self.maps[2].append(Checkpoint(4560, 2180))
		self.maps[2].append(Checkpoint(7350, 4940))
		self.maps[2].append(Checkpoint(3320, 7230))
		self.maps[2].append(Checkpoint(14580, 7700))
		self.maps[2].append(Checkpoint(10560, 5060))
		self.maps[2].append(Checkpoint(13100, 2320))
		self.len_map[2] = 6

		# Map 4
		self.maps[3].append(Checkpoint(5010, 5260))
		self.maps[3].append(Checkpoint(11480, 6080))
		self.maps[3].append(Checkpoint(9100, 1840))
		self.len_map[3] = 3

		# Map 5
		self.maps[4].append(Checkpoint(14660, 1410))
		self.maps[4].append(Checkpoint(3450, 7220))
		self.maps[4].append(Checkpoint(9420, 7240))
		self.maps[4].append(Checkpoint(5970, 4240))
		self.len_map[4] = 4

		# Map 6
		self.maps[5].append(Checkpoint(3640, 4420))
		self.maps[5].append(Checkpoint(8000, 7900))
		self.maps[5].append(Checkpoint(13300, 5540))
		self.maps[5].append(Checkpoint(9560, 1400))
		self.len_map[5] = 4

		# Map 7
		self.maps[6].append(Checkpoint(4100, 7420))
		self.maps[6].append(Checkpoint(13500, 2340))
		self.maps[6].append(Checkpoint(12940, 7220))
		self.maps[6].append(Checkpoint(5640, 2580))
		self.len_map[6] = 4

		# Map 8
		self.maps[7].append(Checkpoint(14520, 7780))
		self.maps[7].append(Checkpoint(6320, 4290))
		self.maps[7].append(Checkpoint(7800, 860))
		self.maps[7].append(Checkpoint(7660, 5970))
		self.maps[7].append(Checkpoint(3140, 7540))
		self.maps[7].append(Checkpoint(9520, 4380))
		self.len_map[7] = 6

		# Map 9
		self.maps[8].append(Checkpoint(10040, 5970))
		self.maps[8].append(Checkpoint(13920, 1940))
		self.maps[8].append(Checkpoint(8020, 3260))
		self.maps[8].append(Checkpoint(2670, 7020))
		self.len_map[8] = 4

		# Map 10
		self.maps[9].append(Checkpoint(7500, 6940))
		self.maps[9].append(Checkpoint(6000, 5360))
		self.maps[9].append(Checkpoint(11300, 2820))
		self.len_map[9] = 3

		# Map 11
		self.maps[10].append(Checkpoint(4060, 4660))
		self.maps[10].append(Checkpoint(13040, 1900))
		self.maps[10].append(Checkpoint(6560, 7840))
		self.maps[10].append(Checkpoint(7480, 1360))
		self.maps[10].append(Checkpoint(12700, 7100))
		self.len_map[10] = 5

		# Map 12
		self.maps[11].append(Checkpoint(3020, 5190))
		self.maps[11].append(Checkpoint(6280, 7760))
		self.maps[11].append(Checkpoint(14100, 7760))
		self.maps[11].append(Checkpoint(13880, 1220))
		self.maps[11].append(Checkpoint(10240, 4920))
		self.maps[11].append(Checkpoint(6100, 2200))
		self.len_map[11] = 6

		# Map 13
		self.maps[12].append(Checkpoint(10323, 3366))
		self.maps[12].append(Checkpoint(11203, 5425))
		self.maps[12].append(Checkpoint(7259, 6656))
		self.maps[12].append(Checkpoint(5425, 2838))
		self.len_map[12] = 4

	def reset(self):
		# Randomly choose the map (this line is commented out in Pascal)
		# self.ind_map = random.randint(0, 12)

		# Resetting checkpoint index
		self.indch = 0
		self.simp.check_point = self.indch + 1
		if self.simp.check_point == self.len_map[self.ind_map]:
			self.simp.check_point = 0

		# Set position based on the current map and checkpoint
		self.simp.pos.x = self.maps[self.ind_map][self.indch].x
		self.simp.pos.y = self.maps[self.ind_map][self.indch].y

		# Initialize speed to zero
		self.simp.speed.x = 0
		self.simp.speed.y = 0

		# Direction vector (from simp.pos to the next checkpoint)
		dir_x = self.maps[self.ind_map][self.simp.check_point].x - self.simp.pos.x
		dir_y = self.maps[self.ind_map][self.simp.check_point].y - self.simp.pos.y

		# Calculate angle in radians
		angle_radians = math.atan2(dir_y, dir_x)

		# Convert to degrees
		angle_degrees = angle_radians * (180.0 / math.pi)

		# Adjust angle to be within 0-360 degrees
		if angle_degrees < 0:
			angle_degrees += 360.0

		self.simp.angletot = -1  # This line corresponds to angle_degrees; change as needed

		# Set check_pass and random thrust and angle
		self.simp.check_pass = 1
		self.simp.thrust = random.randint(0, 200)
		self.simp.angle = random.randint(-18, 18)

		self.nb_reach = 100


	def step(self, action: int, done, turn: int) -> float:
		reward = 0.0
		done[0] = 0

		self.simp.thrust = float(sorties[action][1])
		self.simp.angle = float(sorties[action][0])

		check = Vector2()
		check.x = self.maps[self.ind_map][self.simp.check_point].x
		check.y = self.maps[self.ind_map][self.simp.check_point].y
		d = distance(self.simp.pos, check)

		# Calculate final angle
		anglef = int(self.simp.angletot + self.simp.angle + 360) % 360
		angle_rad = float(anglef) * math.pi / 180.0

		# Calculate direction vectors
		dir_x = math.cos(angle_rad) * self.simp.thrust
		dir_y = math.sin(angle_rad) * self.simp.thrust

		self.simp.direction.x = math.cos(angle_rad) * 1000
		self.simp.direction.y = math.sin(angle_rad) * 1000

		# Update speed
		self.simp.speed.x += dir_x
		self.simp.speed.y += dir_y

		# Update position
		self.simp.pos.x = round(self.simp.pos.x + self.simp.speed.x)
		self.simp.pos.y = round(self.simp.pos.y + self.simp.speed.y)

		# Reduce speed (friction coefficient)
		self.simp.speed.x = int(self.simp.speed.x * 0.85)
		self.simp.speed.y = int(self.simp.speed.y * 0.85)

		speed = math.sqrt(self.simp.speed.x ** 2 + self.simp.speed.y ** 2)

		# Update total angle
		self.simp.angletot = int(self.simp.angletot + self.simp.angle + 360) % 360

		check.x = self.maps[self.ind_map][self.simp.check_point].x
		check.y = self.maps[self.ind_map][self.simp.check_point].y
		d2 = distance(self.simp.pos, check)

		col = (self.simp.speed.x * (self.maps[self.ind_map][self.simp.check_point].x - self.simp.pos.x) +
				self.simp.speed.y * (self.maps[self.ind_map][self.simp.check_point].y - self.simp.pos.y)) / \
			   (math.sqrt(self.simp.speed.x ** 2 + self.simp.speed.y ** 2) *
				math.sqrt((self.maps[self.ind_map][self.simp.check_point].x - self.simp.pos.x) ** 2 +
						  (self.maps[self.ind_map][self.simp.check_point].y - self.simp.pos.y) ** 2) + 0.000001)

		if d2 <= 600:
			self.simp.check_point = (self.simp.check_point + 1) % self.len_map[self.ind_map]
			self.simp.check_pass += 1
			reward = 1.0

			self.nb_reach = 100
			done[0] = 2

			#if self.simp.check_pass > 1:
			#	print(f'Pass checkpoint={self.simp.check_pass} {self.simp.check_point}')
			#	print(f'thrust={self.simp.thrust} angle={self.simp.angle} action={action} pos={self.simp.pos.x} {self.simp.pos.y} ind_map={self.ind_map}')
		else:
			angle_radians = math.atan2(self.maps[self.ind_map][self.simp.check_point].y - self.simp.pos.y,
										self.maps[self.ind_map][self.simp.check_point].x - self.simp.pos.x)
			angle_to_checkpoint = angle_radians * (180.0 / math.pi)

			if angle_to_checkpoint < 0:
				angle_to_checkpoint += 360.0

			diff_angle = abs(angle_to_checkpoint - self.simp.angletot)
			if diff_angle > 180.0:
				diff_angle = 360.0 - diff_angle

			next_check = (self.simp.check_point + 1) % self.len_map[self.ind_map]
			check2 = Vector2()
			check2.x = self.maps[self.ind_map][next_check].x
			check2.y = self.maps[self.ind_map][next_check].y
			d3 = distance(check2, self.simp.pos)

			#reward = ((15000.0 - d2) / 15000.0) * 0.1
			#reward += ((180.0 - diff_angle) / 180.0) * 0.1

			self.pass_check = False
			self.counter_check -= 1

		self.nb_reach -= 1

		if self.nb_reach < 0:
			done[0] = 1
			if reward != 1.0:
				reward = -1.0

		if self.simp.check_pass >= 6:
			reward = 1.0
			done[0] = 3

		self.previous_speed = speed  # Update previous speed

		return reward


def init_state(state, simp, maps, ind_map, len_map):
	next_check = (simp.check_point + 1) % len_map[ind_map]
	next_check2 = (simp.check_point + 2) % len_map[ind_map]
	last_check = (simp.check_point - 1 + len_map[ind_map]) % len_map[ind_map]

	# Recuperer les coordonnees
	x1, y1 = simp.pos.x, simp.pos.y
	x2, y2 = maps[ind_map][simp.check_point].x, maps[ind_map][simp.check_point].y
	x3, y3 = maps[ind_map][next_check].x, maps[ind_map][next_check].y

	# Calcul de l'angle
	angle = np.arctan2(y1 - y2, x1 - x2) - np.arctan2(y3 - y2, x3 - x2)
	angle = np.degrees(angle)
	angle = (int(angle) + 180) % 360

	if angle < 0.0:
		angle += 360.0
	
	angle -= 180.0

	# Calcul de l'angle du checkpoint
	anglech = np.arctan2(y2 - y1, x2 - x1)
	anglech = np.degrees(anglech)
	anglech = float((int(anglech - simp.angletot + 540) % 360) - 180)

	# Calcul de la colinearite
	col = (simp.speed.x * (maps[ind_map][simp.check_point].x - simp.pos.x) +
		   simp.speed.y * (maps[ind_map][simp.check_point].y - simp.pos.y)) / \
		  (np.sqrt(simp.speed.x ** 2 + simp.speed.y ** 2) *
		   np.sqrt((maps[ind_map][simp.check_point].x - simp.pos.x) ** 2 +
					(maps[ind_map][simp.check_point].y - simp.pos.y) ** 2) + 1e-6)

	# Calcul des distances
	check = Vector2(maps[ind_map][simp.check_point].x, maps[ind_map][simp.check_point].y)
	dist_check = distance(check, simp.pos)

	check2 = Vector2(maps[ind_map][last_check].x, maps[ind_map][last_check].y)
	ndist_check = distance(check, check2)

	# Calcul de la vitesse
	speed = np.sqrt(simp.speed.x ** 2 + simp.speed.y ** 2)

	# Normalisation des parametres
	angle_radians = np.arctan2(maps[ind_map][simp.check_point].y - simp.pos.y,
								maps[ind_map][simp.check_point].x - simp.pos.x)
	angleToCheckpoint = np.degrees(angle_radians)

	# Si l'angle est negatif, on ajoute 360 pour le ramener entre 0 et 360 degres
	if angleToCheckpoint < 0:
		angleToCheckpoint += 360.0

	diffAngle = abs(angleToCheckpoint - simp.angletot)
	if diffAngle > 180.0:
		diffAngle = 360.0 - diffAngle

	weight = 10000.0 if col < 0.8 else 0.0

	# Mise a jour de l'etat
	x = simp.pos.x
	y = simp.pos.y

	state[0] = x / 16000.0
	state[1] = y / 9000.0
	state[2] = simp.speed.x / 1000.0
	state[3] = simp.speed.y / 1000.0
	state[4] = simp.angletot / 360.0
	state[5] = dist_check / 15000.0
	state[6] = diffAngle / 180.0
	state[7] = (50000 * simp.check_pass - dist_check) / 300000.0
	state[8] = maps[ind_map][simp.check_point].x / 16000.0
	state[9] = maps[ind_map][simp.check_point].y / 16000.0


class PolicyNet(nn.Module):
	def __init__(self):
		super(PolicyNet, self).__init__()
		self.input_norm = nn.LayerNorm(10)  # Normalisation des entrées
		self.fc1 = nn.Linear(10, 64)
		self.fc2 = nn.Linear(64, 64)
		self.fc3 = nn.Linear(64, 6)

	def forward(self, x):
		x = self.input_norm(x)  # Normalisation des entrées
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = self.fc3(x)
		return x

class ValueNet(nn.Module):
	def __init__(self):
		super(ValueNet, self).__init__()
		self.input_norm = nn.LayerNorm(10)  # Normalisation des entrées
		self.fc1 = nn.Linear(10, 64)
		self.fc2 = nn.Linear(64, 64)
		self.fc3 = nn.Linear(64, 1)

	def forward(self, x):
		x = self.input_norm(x)  # Normalisation des entrées
		x = F.relu(self.fc1(x))
		x = F.relu(self.fc2(x))
		x = self.fc3(x)
		return x

class CustomDataset(Dataset):
	def __init__(self, states, rewards, actions, values, advantages):
		# Assurez-vous que les états sont des tensors 2D
		self.states = [torch.tensor(state, dtype=torch.float32) for state in states]  # Conversion en tensors
		self.rewards = torch.tensor(rewards, dtype=torch.float32)
		self.actions = torch.tensor(actions, dtype=torch.long)  # Si actions sont des entiers
		self.values = torch.tensor(values, dtype=torch.float32)
		self.advantages = torch.tensor(advantages, dtype=torch.float32)

	def __len__(self):
		return len(self.rewards)

	def __getitem__(self, idx):
		return (self.states[idx], self.rewards[idx], 
				self.actions[idx], self.values[idx], 
				self.advantages[idx])

def save_best_nn_py(model, filename):
	with open(filename, 'w') as f:
		# Sauvegarder les poids
		f.write('netw=[')
		first_layer = True
		for i, layer in enumerate(model.children()):
			if isinstance(layer, nn.Linear):  # Si c'est une couche dense
				weight = layer.weight.detach().numpy().T  # Transposer les poids
				if not first_layer:
					f.write(',' + '\n')
				f.write('[')
				for j, row in enumerate(weight):
					f.write('[' + ','.join([f'{val:.6f}' for val in row]) + ']')
					if j < len(weight) - 1:
						f.write(',')
				f.write(']')
				first_layer = False
		f.write(']\n')

		# Sauvegarder les biais
		f.write('netb=[')
		first_bias = True
		for i, layer in enumerate(model.children()):
			if isinstance(layer, nn.Linear):  # Si c'est une couche dense
				bias = layer.bias.detach().numpy()  # Récupérer les biais
				if not first_bias:
					f.write(',' + '\n')
				f.write('[' + ','.join([f'{val:.6f}' for val in bias]) + ']')
				first_bias = False
		f.write(']\n')

def calculate_A_t(rewards, values, gamma, lambd):
	T = len(rewards)
	A_t = np.zeros(T)  # Initialiser un tableau pour stocker A_t

	# Calculer les delta
	deltas = np.zeros(T)
	for t in range(T - 1):
		deltas[t] = rewards[t] + gamma * values[t + 1] - values[t]

	# Calculer A_t en utilisant les deltas
	for t in range(T):
		for k in range(t, T):
			A_t[t] += (gamma * lambd) ** (k - t) * deltas[k]

	return A_t

def compute_advantages(rewards, values, gamma=0.99, tau=0.95):
	advantages = []
	gae = 0
	for t in reversed(range(len(rewards))):
		delta = rewards[t] + gamma * (values[t + 1] if t + 1 < len(values) else 0) - values[t]
		gae = delta + gamma * tau * gae
		advantages.insert(0, gae)
	return advantages

def TrainPPO():

	policy_net = PolicyNet()
	policy_net_old = PolicyNet()
	value_net = ValueNet()

	

	# Créer des optimisateurs Adam pour les deux réseaux
	policy_optimizer = torch.optim.Adam(policy_net.parameters(), lr=0.0003)  # Taux d'apprentissage de 0.001
	value_optimizer = torch.optim.Adam(value_net.parameters(), lr=0.001)      # Taux d'apprentissage de 0.001


	for episode in range(5):
		
		done = [0]
		turn = 0

		rewards = []
		action = []
		values = []

		state_tab =[]


		super = Superpod()

		super.reset()

		# Boucle principale
		while done[0] == 0 and turn < 2000:
		
			state = [0] * 10
			init_state(state, super.simp, super.maps, super.ind_map, super.len_map)

			state_tab += [state]

			actions = super.simp.rmcts(4, super.maps, super.len_map[super.ind_map])
			action += [actions]
			rewards += [super.step(actions, done, turn)]

			# Convertir l'état en tenseur pour le réseau
			state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)  # Ajouter une dimension pour le batch

			# Obtenir la sortie du réseau de valeur
			value_output = value_net(state_tensor)
			values += [value_output[0]]

			# En supposant que tu as une cible pour la valeur, par exemple `target_value` 
			# (que tu dois définir selon ta logique de récompense)
			target_value = rewards[-1]  # Remplacer cela par ta logique pour obtenir la valeur cible

			# Calculer la perte MSE
			loss_value = F.mse_loss(value_output, torch.tensor([[target_value]], dtype=torch.float32))  # Assurez-vous que les dimensions correspondent

			# (Optionnel) Rétropropagation si tu souhaites mettre à jour value_net
			value_optimizer.zero_grad()  # Réinitialiser les gradients de value_net
			loss_value.backward()              # Calculer les gradients
			value_optimizer.step()         # Mettre à jour les poids
			#print(f'Loss Value: {loss_value.item()}')  # Affiche la perte MSE

			# Vérifier si le jeu est terminé
			if done[0] == 2:
				done[0] = 0

			turn += 1


		
		advantages = compute_advantages(rewards, values)


		total_loss = 0.0
		batch_count = 0
		for nbatch in range(500):
		
			# Créer le dataset
			dataset = CustomDataset(state_tab, rewards, action, values, advantages)

			# Créer un DataLoader
			dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

			

			# Boucle d'entraînement ou de test
			for batch in dataloader:
			
				state_batch, reward_batch, action_batch, value_batch, adv_batch = batch  # Déballer le batch

				# Les tenseurs sont déjà en format tensor grâce à CustomDataset
				# Aucune conversion supplémentaire n'est nécessaire ici

				# Calculer les logits et les probabilités
				action_logits = policy_net(state_batch) 
				action_probs = F.softmax(action_logits, dim=-1)

				with torch.no_grad():
					action_logits_old = policy_net_old(state_batch)
					action_probs_old = F.softmax(action_logits_old, dim=-1)

				# Récupérer les probabilités pour les actions prises
				# Assurez-vous que `action_batch` est au bon format (long)
				action_probs_taken = action_probs[range(len(action_batch)), action_batch]
				action_probs_old_taken = action_probs_old[range(len(action_batch)), action_batch]

				# Calculer r_t
				rt = action_probs_taken / (action_probs_old_taken + 1e-10)

				# Initialisation de rtc
				rtc = rt.clone()  # Cloner rt pour éviter la modification de la valeur d'origine
				
				
				clipping_ratio = 0.2
				# Calculer rtc avec les conditions de clipping
				rtc = torch.where(rt > (1 + clipping_ratio), 
								  torch.tensor(1 + clipping_ratio, dtype=rt.dtype, device=rt.device), 
								  rtc)
				rtc = torch.where(rt < (1 - clipping_ratio), 
								  torch.tensor(1 - clipping_ratio, dtype=rt.dtype, device=rt.device), 
								  rtc)

				adv_batch = adv_batch.clone().detach().requires_grad_(True)
				# Calculer la loss
				# Calculer LCLIP
				LCLIP = -torch.where((rt * adv_batch) < (rtc * adv_batch), 
                                 rt * adv_batch, 
                                 rtc * adv_batch).mean()

				
				loss = LCLIP 


				# Calculer la divergence KL
				#kl_divergence = F.kl_div(action_probs_old_taken.log(), action_probs_taken, reduction='batchmean')

				# Vérifier si la divergence KL dépasse la limite
				#if kl_divergence > 0.5:
				#	return  # Sortir si la limite est dépassée

				# Rétropropagation de l'erreur
				policy_optimizer.zero_grad()  # Réinitialiser les gradients
				loss.backward()        # Calculer les gradients
				policy_optimizer.step()       # Mettre à jour les poids

				#print(f'Loss: {loss.item()}')  # Affiche la perte pour le batch actuel
				total_loss += loss.item()
				batch_count += 1

		policy_net_old.load_state_dict(policy_net.state_dict())
		
		average_loss = total_loss / batch_count if batch_count > 0 else 0.0
		#if ((episode+1) % 10 == 0):
		print(f'episode={episode} Loss: {average_loss}')

	save_best_nn_py(policy_net, '/kaggle/working/best_wpod.py')

	

TrainPPO()


episode=0 Loss: -1.3280191165208817
episode=1 Loss: -1.3230233375344957
episode=2 Loss: -1.3116335142510278
episode=3 Loss: -1.314304367184639
episode=4 Loss: -1.2877177483183997
