# 🚀 Training the Chess NN on a GPU (Google Colab)

This notebook trains your chess position-evaluation neural network on a GPU.

## ⚡ GPU setup
**IMPORTANT**: enable the GPU before you start!
- Menu: **Runtime** → **Change runtime type**
- Hardware accelerator: **GPU (T4)**
- Click **Save**

In [10]:
# Check that the GPU is actually enabled
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️ NO GPU DETECTED! Go to Runtime > Change runtime type > GPU")

PyTorch version: 2.8.0+cu126
CUDA available: False
⚠️ NO GPU DETECTED! Go to Runtime > Change runtime type > GPU


## 📦 Installing dependencies

In [9]:
!pip install pandas tqdm -q

## 📂 Uploading the code and dataset

**Option 1: direct upload** (for small files)

In [None]:
from google.colab import files

# Upload your Python files
print("Upload Chess.py, torch_nn_evaluator.py, train_torch.py")
uploaded = files.upload()

# IMPORTANT: the dataset (13M rows) is a large file,
# so use Google Drive for it instead (see Option 2 below)
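With about 13M rows, loading the whole dataset with a single `pd.read_csv` call can be memory-hungry. One alternative is to stream the file in batches. Below is a minimal sketch that uses only the standard `csv` module. pandas supports the same idea through `pd.read_csv(path, chunksize=...)`. The function name `iter_batches`, the batch size and the toy rows are illustrative assumptions. Only the `FEN` and `Evaluation` column names come from this notebook.

```python
import csv
import io
from typing import Iterator, List, Tuple


def iter_batches(file_obj, batch_size: int = 4096) -> Iterator[List[Tuple[str, str]]]:
    """Yield lists of (FEN, Evaluation) string pairs, at most batch_size rows each.

    Assumes the CSV header has 'FEN' and 'Evaluation' columns, as in the
    chessData.csv preview; any other column (e.g. the unnamed index) is ignored.
    """
    reader = csv.DictReader(file_obj)
    batch: List[Tuple[str, str]] = []
    for row in reader:
        batch.append((row["FEN"], row["Evaluation"]))
        if len(batch) == batch_size:
            yield batch
            batch = []
    if batch:  # last, possibly smaller, batch
        yield batch


# Toy rows with the same layout as the dataset (illustrative values); in Colab
# you would pass open(dataset_path, newline="") instead of this in-memory file.
sample = io.StringIO(
    ",FEN,Evaluation\n"
    "0,rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1,20\n"
    "1,rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR b KQkq - 0 1,-10\n"
    "2,rnbqkbnr/pppp1ppp/4p3/8/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2,56\n"
)
batches = list(iter_batches(sample, batch_size=2))
print([len(b) for b in batches])  # [2, 1]
```

You can then encode each batch and turn it into a tensor. That way, only one batch of raw rows sits in memory at a time.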

**Option 2: Google Drive** (RECOMMENDED for a large dataset): clone the code from GitHub, then read the dataset directly from your Google Drive.

In [15]:
# Clone the repository and copy the Python modules from its ai/ folder into the working directory
!git clone https://github.com/promaaa/smart-chess.git
!cp smart-chess/ai/*.py .

Cloning into 'smart-chess'...
remote: Enumerating objects: 483, done.
remote: Counting objects: 100% (483/483), done.
remote: Compressing objects: 100% (306/306), done.
remote: Total 483 (delta 189), reused 437 (delta 169), pack-reused 0 (from 0)
Receiving objects: 100% (483/483), 5.57 MiB | 18.93 MiB/s, done.
Resolving deltas: 100% (189/189), done.


In [13]:
import pandas as pd
from IPython.display import display

# Google Drive must be mounted first, e.g. with
#   from google.colab import drive; drive.mount('/content/drive')
# or with the "Mount Drive" button in the Colab file browser.
# Make sure the path below is correct for your file in Google Drive
dataset_path = "/content/drive/MyDrive/smart_chess_ia/chessData.csv"

try:
    df = pd.read_csv(dataset_path)
    print("Dataset loaded successfully!")
    display(df.head())
except FileNotFoundError:
    print(f"Error: the file was not found at {dataset_path}. Check the path and make sure the file exists in your Google Drive.")
except Exception as e:
    print(f"An error occurred while loading the dataset: {e}")

Dataset loaded successfully!


| Unnamed: 0 | FEN | Evaluation |
|---|---|---|
| 0 | rnbqkbnr/pppppppp/8/8/4P3/8/PPPP1PPP/RNBQKBNR ... | -10 |
| 1 | rnbqkbnr/pppp1ppp/4p3/8/4P3/8/PPPP1PPP/RNBQKBN... | 56 |
| 2 | rnbqkbnr/pppp1ppp/4p3/8/3PP3/8/PPP2PPP/RNBQKBN... | -9 |
| 3 | rnbqkbnr/ppp2ppp/4p3/3p4/3PP3/8/PPP2PPP/RNBQKB... | 52 |
| 4 | rnbqkbnr/ppp2ppp/4p3/3p4/3PP3/8/PPPN1PPP/R1BQK... | -26 |
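The preview shows `Evaluation` as plain integer centipawn scores. A column like this may also hold mate scores written as strings such as `#+3` / `#-2`, or values with a leading `+`. That format is an assumption, and you can check it with `df['Evaluation'].dtype`. If so, pandas reads the whole column as text. The sketch below assumes that format and converts each value to a bounded training target. `parse_evaluation`, `MATE_SCORE` and the `tanh` scaling by 1000 are hypothetical choices, not taken from `train_torch.py`. Also note that in the five rows shown, the sign flips as the moves alternate. This may mean the scores are relative to the side to move rather than to White, so it is worth checking before training.

```python
import math

MATE_SCORE = 10_000  # hypothetical centipawn value standing in for a forced mate


def parse_evaluation(raw) -> float:
    """Convert one Evaluation value to centipawns.

    Accepts plain integers ('-10', 56), an optional leading '+' ('+56'),
    and the assumed mate notation '#+N' / '#-N'.
    """
    s = str(raw).strip()
    if s.startswith("#"):
        # mate: only the sign is kept here, not the number of moves
        return -float(MATE_SCORE) if s[1:2] == "-" else float(MATE_SCORE)
    return float(int(s))  # int() accepts a leading '+' or '-'


def to_target(cp: float, scale: float = 1000.0) -> float:
    """Squash centipawns into (-1, 1); the scale is a hypothetical choice."""
    return math.tanh(cp / scale)


print(parse_evaluation("+56"), parse_evaluation("#-2"))  # 56.0 -10000.0
print(round(to_target(parse_evaluation("-10")), 4))      # -0.01
```

To apply it to the DataFrame loaded above, use `df['Evaluation'].map(parse_evaluation).map(to_target)`.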


## 🔧 Code: Chess.py
Paste your Chess.py code here (or get it with one of the options above). Note that running this cell overwrites any Chess.py already present.
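Chess.py stores the position as twelve 64-bit integers, one per piece type. Bit `rank * 8 + file` is set when that piece occupies the square (a1 = 0, h1 = 7, a8 = 56, h8 = 63), as the constants in `__init__` and `load_fen` show. The sketch below uses plain Python ints, as Chess.py does, to show the idioms the class relies on:

- building a mask with `1 << sq`;
- iterating over set bits with the lowest-set-bit trick `(bb & -bb).bit_length() - 1` from `is_square_attacked`;
- the file-distance guard in `compute_knight_moves` that stops moves from wrapping around the board edge.

The helper names `square_name`, `iter_squares` and `knight_targets` are hypothetical.

```python
def square_name(sq: int) -> str:
    """Return the algebraic name of square index sq (0 = a1, 63 = h8)."""
    return "abcdefgh"[sq % 8] + str(sq // 8 + 1)


def iter_squares(bb: int):
    """Yield the indices of the set bits of a bitboard, lowest first."""
    while bb:
        lsb = bb & -bb               # isolate the lowest set bit
        yield lsb.bit_length() - 1   # its index is the square number
        bb &= bb - 1                 # clear the lowest set bit


def knight_targets(sq: int) -> int:
    """Knight attack mask, with the same wrap-around guard as compute_knight_moves."""
    mask = 0
    for d in (15, 17, 10, 6, -15, -17, -10, -6):
        t = sq + d
        # a file jump larger than 2 means the move wrapped around an edge of the board
        if 0 <= t < 64 and abs(sq % 8 - t % 8) <= 2:
            mask |= 1 << t
    return mask


WHITE_KNIGHTS = 0x0000000000000042  # same constant as bitboards['N'] in Chess.__init__
WHITE_PAWNS = 0x000000000000FF00    # same constant as bitboards['P']

print([square_name(sq) for sq in iter_squares(WHITE_KNIGHTS)])            # ['b1', 'g1']
print(len(list(iter_squares(WHITE_PAWNS))))                               # 8
print(bool(WHITE_PAWNS & (1 << 12)))                                      # True (pawn on e2)
print(sorted(square_name(s) for s in iter_squares(knight_targets(0))))    # ['b3', 'c2']
```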

In [18]:
# @title
%%writefile Chess.py
import numpy as np

class Chess:
    def __init__(self):
        # Use Python ints for bitboards to avoid numpy overhead in hot paths
        self.bitboards = {
            'P': 0x000000000000FF00,
            'N': 0x0000000000000042,
            'B': 0x0000000000000024,
            'R': 0x0000000000000081,
            'Q': 0x0000000000000008,
            'K': 0x0000000000000010,
            'p': 0x00FF000000000000,
            'n': 0x4200000000000000,
            'b': 0x2400000000000000,
            'r': 0x8100000000000000,
            'q': 0x0800000000000000,
            'k': 0x1000000000000000
        }
        self.white_to_move = True
        self.castling_rights = {'K': True, 'Q': True, 'k': True, 'q': True}
        self.en_passant_target = None
        self.check = False
        self.history = []

    # safe square_mask: return Python int mask (fast)
    def square_mask(self, sq):
        return 1 << int(sq)

    def occupancy(self):
        occ = 0
        for bb in self.bitboards.values():
            occ |= int(bb)
        return occ

    def pieces_of_color(self, white):
        mask = 0
        for p, bb in self.bitboards.items():
            if (p.isupper() and white) or (p.islower() and not white):
                mask |= int(bb)
        return mask

    def color_of_piece_char(self, p):
        return p.isupper()

    def print_board(self):
        board = ['.' for _ in range(64)]
        for piece, bitboard in self.bitboards.items():
            for i in range(64):
                if bool(bitboard & self.square_mask(i)):
                    board[i] = piece
        for rank in range(7, -1, -1):
            print(' '.join(board[rank*8:(rank+1)*8]))
        print()

    def compute_king_moves_basic(self, square, piece=None):
        """
        Basic king moves (no castling); used by is_square_attacked to avoid recursion.
        """
        king_moves = 0
        directions = [1, -1, 8, -8, 7, -7, 9, -9]
        own = 0
        if piece is not None:
            own = self.pieces_of_color(self.color_of_piece_char(piece))

        for direction in directions:
            target_square = square + direction
            if 0 <= target_square < 64:
                if abs((square % 8) - (target_square % 8)) <= 1:
                    if not (own & self.square_mask(target_square)):
                        king_moves |= self.square_mask(target_square)

        return king_moves

    # --- compute_* and helpers (logic unchanged) ---
    def compute_king_moves(self, square, piece=None):
        """
        Full king moves (including castling).
        """
        # Start from the basic one-square moves
        king_moves = self.compute_king_moves_basic(square, piece)

        # Castling is only considered when the piece is known
        if piece is None:
            return king_moves

        # Castling: the squares between king and rook must be empty, and the
        # king may not start on, pass through, or land on an attacked square
        if piece == 'K':
            if self.castling_rights.get('K', False):
                f1 = 5; g1 = 6; e1 = 4
                if not (self.occupancy() & (self.square_mask(f1) | self.square_mask(g1))):
                    if (not self.is_square_attacked(e1, by_white=False) and
                        not self.is_square_attacked(f1, by_white=False) and
                        not self.is_square_attacked(g1, by_white=False)):
                        king_moves |= self.square_mask(g1)
            if self.castling_rights.get('Q', False):
                d1 = 3; c1 = 2; b1 = 1; e1 = 4
                if not (self.occupancy() & (self.square_mask(d1) | self.square_mask(c1) | self.square_mask(b1))):
                    if (not self.is_square_attacked(e1, by_white=False) and
                        not self.is_square_attacked(d1, by_white=False) and
                        not self.is_square_attacked(c1, by_white=False)):
                        king_moves |= self.square_mask(c1)
        elif piece == 'k':
            if self.castling_rights.get('k', False):
                f8 = 61; g8 = 62; e8 = 60
                if not (self.occupancy() & (self.square_mask(f8) | self.square_mask(g8))):
                    if (not self.is_square_attacked(e8, by_white=True) and
                        not self.is_square_attacked(f8, by_white=True) and
                        not self.is_square_attacked(g8, by_white=True)):
                        king_moves |= self.square_mask(g8)
            if self.castling_rights.get('q', False):
                d8 = 59; c8 = 58; b8 = 57; e8 = 60
                if not (self.occupancy() & (self.square_mask(d8) | self.square_mask(c8) | self.square_mask(b8))):
                    if (not self.is_square_attacked(e8, by_white=True) and
                        not self.is_square_attacked(d8, by_white=True) and
                        not self.is_square_attacked(c8, by_white=True)):
                        king_moves |= self.square_mask(c8)

        return king_moves

    def compute_knight_moves(self, square, piece=None):
        knight_moves = 0
        directions = [15, 17, 10, 6, -15, -17, -10, -6]
        own = 0
        if piece is not None:
            own = self.pieces_of_color(self.color_of_piece_char(piece))
        for direction in directions:
            target_square = square + direction
            if 0 <= target_square < 64:
                if abs((square % 8) - (target_square % 8)) <= 2:
                    if not (own & self.square_mask(target_square)):
                        knight_moves |= self.square_mask(target_square)
        return knight_moves

    def compute_pawn_moves(self, square, is_white):
        pawn_moves = 0
        occ = self.occupancy()
        own = self.pieces_of_color(is_white)
        enemy = occ & ~own

        if is_white:
            one_forward = square + 8
            two_forward = square + 16
            if one_forward < 64 and not (occ & self.square_mask(one_forward)):
                pawn_moves |= self.square_mask(one_forward)
                if (square // 8) == 1 and two_forward < 64 and not (occ & self.square_mask(two_forward)):
                    pawn_moves |= self.square_mask(two_forward)
            if square % 8 > 0:
                left = square + 7
                if left < 64 and (enemy & self.square_mask(left)):
                    pawn_moves |= self.square_mask(left)
            if square % 8 < 7:
                right = square + 9
                if right < 64 and (enemy & self.square_mask(right)):
                    pawn_moves |= self.square_mask(right)
            if self.en_passant_target is not None and (square // 8) == 4:
                if square % 8 > 0 and self.en_passant_target == square + 7:
                    pawn_moves |= self.square_mask(self.en_passant_target)
                if square % 8 < 7 and self.en_passant_target == square + 9:
                    pawn_moves |= self.square_mask(self.en_passant_target)
        else:
            one_forward = square - 8
            two_forward = square - 16
            if one_forward >= 0 and not (occ & self.square_mask(one_forward)):
                pawn_moves |= self.square_mask(one_forward)
                if (square // 8) == 6 and two_forward >= 0 and not (occ & self.square_mask(two_forward)):
                    pawn_moves |= self.square_mask(two_forward)
            if square % 8 > 0:
                left = square - 9
                if left >= 0 and (enemy & self.square_mask(left)):
                    pawn_moves |= self.square_mask(left)
            if square % 8 < 7:
                right = square - 7
                if right >= 0 and (enemy & self.square_mask(right)):
                    pawn_moves |= self.square_mask(right)
            if self.en_passant_target is not None and (square // 8) == 3:
                if square % 8 > 0 and self.en_passant_target == square - 9:
                    pawn_moves |= self.square_mask(self.en_passant_target)
                if square % 8 < 7 and self.en_passant_target == square - 7:
                    pawn_moves |= self.square_mask(self.en_passant_target)

        return pawn_moves

    def compute_rook_moves(self, square, piece=None):
        rook_moves = 0
        occ = self.occupancy()
        own = 0
        if piece is not None:
            own = self.pieces_of_color(self.color_of_piece_char(piece))
        directions = [1, -1, 8, -8]
        for direction in directions:
            target_square = square
            while True:
                target_square += direction
                if not (0 <= target_square < 64):
                    break
                if direction in [1, -1] and (target_square // 8) != (square // 8):
                    break
                mask = self.square_mask(target_square)
                if own & mask:
                    break
                rook_moves |= mask
                if occ & mask:
                    break
        return rook_moves

    def compute_bishop_moves(self, square, piece=None):
        bishop_moves = 0
        occ = self.occupancy()
        own = 0
        if piece is not None:
            own = self.pieces_of_color(self.color_of_piece_char(piece))
        directions = [7, -7, 9, -9]
        for direction in directions:
            target_square = square
            while True:
                target_square += direction
                if not (0 <= target_square < 64):
                    break
                if abs((target_square % 8) - (square % 8)) != abs((target_square // 8) - (square // 8)):
                    break
                mask = self.square_mask(target_square)
                if own & mask:
                    break
                bishop_moves |= mask
                if occ & mask:
                    break
        return bishop_moves

    def compute_queen_moves(self, square, piece=None):
        return self.compute_rook_moves(square, piece) | self.compute_bishop_moves(square, piece)

    def get_all_moves(self, square):
        piece = None
        from_mask = self.square_mask(square)
        for p, bitboard in self.bitboards.items():
            if int(bitboard) & from_mask:
                piece = p
                break
        if piece is None:
            return 0

        if piece in ['K', 'k']:
            return self.compute_king_moves(square, piece)
        elif piece in ['N', 'n']:
            return self.compute_knight_moves(square, piece)
        elif piece in ['P']:
            return self.compute_pawn_moves(square, True)
        elif piece in ['p']:
            return self.compute_pawn_moves(square, False)
        elif piece in ['R', 'r']:
            return self.compute_rook_moves(square, piece)
        elif piece in ['B', 'b']:
            return self.compute_bishop_moves(square, piece)
        elif piece in ['Q', 'q']:
            return self.compute_queen_moves(square, piece)
        else:
            return 0

    def ray_attacks_from(self, square, directions):
        occ = self.occupancy()
        attacks = 0
        for direction in directions:
            target = square
            while True:
                target += direction
                if not (0 <= target < 64):
                    break
                if direction in [1, -1] and (target // 8) != (square // 8):
                    break
                if direction in [7, -7, 9, -9]:
                    if abs((target % 8) - (square % 8)) != abs((target // 8) - (square // 8)):
                        break
                attacks |= self.square_mask(target)
                if occ & self.square_mask(target):
                    break
        return attacks

    def compute_pawn_attacks(self, square, piece):
        """
        Pawn attacks (used by is_square_attacked).
        """
        attacks = 0
        is_white = piece.isupper()

        if is_white:
            # White pawns attack upward
            if square % 8 > 0 and square + 7 < 64:  # left diagonal attack
                attacks |= self.square_mask(square + 7)
            if square % 8 < 7 and square + 9 < 64:  # right diagonal attack
                attacks |= self.square_mask(square + 9)
        else:
            # Black pawns attack downward
            if square % 8 < 7 and square - 7 >= 0:  # right diagonal attack
                attacks |= self.square_mask(square - 7)
            if square % 8 > 0 and square - 9 >= 0:  # left diagonal attack
                attacks |= self.square_mask(square - 9)

        return attacks

    def is_square_attacked(self, square, by_white):
        """
        Return True if `square` is attacked by the given side. Uses
        compute_king_moves_basic to avoid recursing into the castling logic.
        """
        mask = self.square_mask(square)

        # Pawn attacks (iterate only over set bits)
        pawn_bb = int(self.bitboards['P'] if by_white else self.bitboards['p'])
        pawn_piece = 'P' if by_white else 'p'
        temp = pawn_bb
        while temp:
            i = (temp & -temp).bit_length() - 1
            if self.compute_pawn_attacks(i, pawn_piece) & mask:
                return True
            temp &= temp - 1

        # Knight attacks
        knight_bb = int(self.bitboards['N'] if by_white else self.bitboards['n'])
        knight_piece = 'N' if by_white else 'n'
        temp = knight_bb
        while temp:
            i = (temp & -temp).bit_length() - 1
            if self.compute_knight_moves(i, knight_piece) & mask:
                return True
            temp &= temp - 1

        # Bishop and queen attacks (diagonals)
        bishop_bb = int(self.bitboards['B'] if by_white else self.bitboards['b'])
        queen_bb = int(self.bitboards['Q'] if by_white else self.bitboards['q'])
        temp = (bishop_bb | queen_bb)
        while temp:
            i = (temp & -temp).bit_length() - 1
            if self.ray_attacks_from(i, [7, -7, 9, -9]) & mask:
                return True
            temp &= temp - 1

        # Rook and queen attacks (ranks and files)
        rook_bb = int(self.bitboards['R'] if by_white else self.bitboards['r'])
        temp = (rook_bb | queen_bb)
        while temp:
            i = (temp & -temp).bit_length() - 1
            if self.ray_attacks_from(i, [1, -1, 8, -8]) & mask:
                return True
            temp &= temp - 1

        # King attacks (USE THE BASIC VERSION, without castling, to avoid recursion)
        king_bb = int(self.bitboards['K'] if by_white else self.bitboards['k'])
        king_piece = 'K' if by_white else 'k'
        temp = king_bb
        while temp:
            i = (temp & -temp).bit_length() - 1
            if self.compute_king_moves_basic(i, king_piece) & mask:
                return True
            temp &= temp - 1

        return False

    def is_in_check(self, white_color):
        king_piece = 'K' if white_color else 'k'
        king_bb = int(self.bitboards.get(king_piece, 0))
        if king_bb == 0:
            return False
        # find king square via bit-scan
        king_square = (king_bb & -king_bb).bit_length() - 1
        return self.is_square_attacked(king_square, by_white=not white_color)

    def move_piece(self, from_sq: int, to_sq: int, promotion: str = None):
        """
        Move a piece from from_sq to to_sq, with an optional promotion.
        Only the deltas needed for a fast undo are recorded.
        """
        from_mask = self.square_mask(from_sq)
        to_mask = self.square_mask(to_sq)

        # Save the minimal previous state (no full snapshot)
        prev_castling = dict(self.castling_rights)
        prev_en_passant = self.en_passant_target
        prev_white_to_move = self.white_to_move

        moving_piece = None
        captured_piece = None
        captured_square = None

        # Find the moving piece
        for piece, bb in self.bitboards.items():
            if bb & from_mask:
                moving_piece = piece
                break
        if moving_piece is None:
            raise RuntimeError(f"No piece found on square {from_sq}")

        # Detect a normal capture on the destination square (before mutating)
        for piece, bb in self.bitboards.items():
            if bb & to_mask:
                captured_piece = piece
                captured_square = to_sq
                break

        # Detect an en passant capture (before mutating): the captured pawn is
        # not on the destination square but one rank behind it
        if moving_piece in ('P', 'p') and self.en_passant_target is not None and to_sq == self.en_passant_target:
            if moving_piece == 'P':  # White captures upward
                captured_square = to_sq - 8
                captured_piece = 'p'
            else:  # Black
                captured_square = to_sq + 8
                captured_piece = 'P'

        # Before mutating, push a minimal record for undo_move
        self.history.append({
            'from': from_sq,
            'to': to_sq,
            'moving_piece': moving_piece,
            'captured_piece': captured_piece,
            'captured_square': captured_square,
            'promotion': None,
            'prev_castling': prev_castling,
            'prev_en_passant': prev_en_passant,
            'prev_white_to_move': prev_white_to_move,
        })

        # Apply the capture, if any (normal or en passant)
        if captured_piece is not None and captured_square is not None:
            self.bitboards[captured_piece] &= ~self.square_mask(int(captured_square))

        # Move the piece (in-place mutation)
        self.bitboards[moving_piece] &= ~from_mask
        self.bitboards[moving_piece] |= to_mask

        # Promotion: a pawn reaching the last rank is promoted (to a queen when
        # no `promotion` letter is given); the new piece keeps the pawn's color
        promotion_needed = ((moving_piece == 'P' and to_sq // 8 == 7) or
                            (moving_piece == 'p' and to_sq // 8 == 0))

        if promotion_needed:
            promo = promotion if promotion else 'Q'
            promoted_piece = promo.upper() if moving_piece == 'P' else promo.lower()
            # remove the promoted pawn
            self.bitboards[moving_piece] &= ~to_mask
            # add the promoted piece
            self.bitboards[promoted_piece] |= to_mask
            # record the exact piece placed, for undo_move
            self.history[-1]['promotion'] = promoted_piece
        # A `promotion` argument on a move that does not reach the last rank is
        # ignored: promoting anywhere else is illegal.

        # Castling: also move the rook
        if moving_piece in ('K', 'k') and abs(from_sq - to_sq) == 2:
            if moving_piece == 'K':
                if to_sq == 6:   # kingside (O-O): rook h1 -> f1
                    self.bitboards['R'] &= ~self.square_mask(7)
                    self.bitboards['R'] |= self.square_mask(5)
                elif to_sq == 2:  # queenside (O-O-O): rook a1 -> d1
                    self.bitboards['R'] &= ~self.square_mask(0)
                    self.bitboards['R'] |= self.square_mask(3)
            else:
                if to_sq == 62:
                    self.bitboards['r'] &= ~self.square_mask(63)
                    self.bitboards['r'] |= self.square_mask(61)
                elif to_sq == 58:
                    self.bitboards['r'] &= ~self.square_mask(56)
                    self.bitboards['r'] |= self.square_mask(59)

        # Update castling rights and the en passant target (uses the current state)
        self.update_castling_rights(moving_piece, from_sq)
        self.update_en_passant(moving_piece, from_sq, to_sq)

        # Legality check: the side that just moved must not be left in check
        mover_is_white = moving_piece.isupper()
        try:
            illegal = self.is_in_check(mover_is_white)
        except Exception:
            # if the check test itself fails, treat the move as illegal
            illegal = True

        if illegal:
            # Restore the state from the minimal record with undo_move()
            # (undo_move() pops the history entry added above).
            self.undo_move()
            raise ValueError("Illegal move: leaves king in check")

        # Switch the side to move
        self.white_to_move = not self.white_to_move
        # Recompute the check flag for the side to move
        self.check = self.is_in_check(self.white_to_move)

    def undo_move(self):
        """
        Undo the last move recorded in self.history.

        Compatibility:
        - If the history entry contains 'prev_bitboards', the old behavior is used (full copy restore).
        - Otherwise the entry is assumed to be a minimal (delta) record created by move_piece,
          with the expected keys: 'from', 'to', 'moving_piece', 'captured_piece' (or None),
                        'captured_square' (capture square, needed for en passant),
                        'promotion' (None or the promoted piece letter: uppercase for White, lowercase for Black),
                        'prev_castling', 'prev_en_passant', 'prev_white_to_move'
        """
        if not hasattr(self, "history") or not self.history:
            # No history
            # (Same message as before, kept for compatibility; can be removed in production.)
            print("No move to undo.")
            return

        record = self.history.pop()

        # --- Old format: full bitboards saved ---
        if isinstance(record, dict) and "prev_bitboards" in record:
            prev_bitboards = record['prev_bitboards']
            # restore each bitboard as a plain Python int (the type used everywhere else)
            for k, v in prev_bitboards.items():
                self.bitboards[k] = int(v)
            self.en_passant_target = record.get('prev_en_passant')
            self.castling_rights = dict(record.get('prev_castling', {}))
            self.white_to_move = record.get('prev_white_to_move', self.white_to_move)
            # recompute check for side to move
            self.check = self.is_in_check(self.white_to_move)
            return

        # --- New / minimal format: apply the inverse of the delta ---
        # Safety: check that the minimal keys are present
        required = {'from', 'to', 'moving_piece', 'prev_castling', 'prev_en_passant', 'prev_white_to_move'}
        if not required.issubset(record.keys()):
            # Unexpected format: fail loudly (a hint for the developer)
            # rather than leave the state corrupted.
            raise RuntimeError("undo_move: unexpected history format: {}".format(record.keys()))

        from_sq = int(record['from'])
        to_sq = int(record['to'])
        from_mask = self.square_mask(from_sq)
        to_mask = self.square_mask(to_sq)

        moving_piece = record['moving_piece']            # piece that moved (e.g. 'P' or 'k')
        captured_piece = record.get('captured_piece')    # None or the captured piece letter
        captured_square = record.get('captured_square')  # capture square (differs from 'to' for en passant)
        promotion = record.get('promotion')              # None or the promoted piece letter

        # If a promotion happened, the piece on 'to' is the promoted piece (e.g. 'Q' or 'q')
        if promotion:
            # move_piece records the exact piece letter it placed on 'to', so it is
            # used as-is (this stays correct even if the mover was not the side to move)
            promoted_piece = promotion
            # remove the promoted piece from 'to'
            self.bitboards[promoted_piece] &= ~to_mask
            # put the original pawn (the piece that moved) back on 'from'
            self.bitboards[moving_piece] |= from_mask
        else:
            # regular move: take the piece off 'to'
            self.bitboards[moving_piece] &= ~to_mask
            # and put it back on 'from'
            self.bitboards[moving_piece] |= from_mask

        # Restore the captured piece, if any (normal capture or en passant)
        if captured_piece:
            if captured_square is None:
                # unlikely case: fall back to 'to' as the capture square
                cap_sq = to_sq
            else:
                cap_sq = int(captured_square)
            self.bitboards[captured_piece] |= self.square_mask(cap_sq)

        # Special case, castling: move_piece also moved the rook, so put it back
        # on its original square. Castling is detected by the king moving two
        # squares horizontally.
        if moving_piece in ('K', 'k') and abs(from_sq - to_sq) == 2:
            # standard castling squares:
            if moving_piece == 'K':  # White
                if to_sq == 6:   # O-O (e1 -> g1)
                    # the rook was moved h1 (7) -> f1 (5); move it back
                    self.bitboards['R'] &= ~self.square_mask(5)
                    self.bitboards['R'] |= self.square_mask(7)
                elif to_sq == 2: # O-O-O (e1 -> c1)
                    self.bitboards['R'] &= ~self.square_mask(3)
                    self.bitboards['R'] |= self.square_mask(0)
            else:  # 'k', Black
                if to_sq == 62:  # e8 -> g8
                    self.bitboards['r'] &= ~self.square_mask(61)
                    self.bitboards['r'] |= self.square_mask(63)
                elif to_sq == 58: # e8 -> c8
                    self.bitboards['r'] &= ~self.square_mask(59)
                    self.bitboards['r'] |= self.square_mask(56)

        # Restore the state flags
        self.castling_rights = dict(record.get('prev_castling', self.castling_rights))
        self.en_passant_target = record.get('prev_en_passant', self.en_passant_target)
        self.white_to_move = bool(record.get('prev_white_to_move', self.white_to_move))

        # Recompute the check flag for the side to move
        self.check = self.is_in_check(self.white_to_move)

        return

    def update_castling_rights(self, moving_piece: str, from_sq: int):
        """
        Update the castling rights after a move or a capture.
        """
        # Moving the king removes both castling rights for that side
        if moving_piece == 'K':
            self.castling_rights['K'] = False
            self.castling_rights['Q'] = False
        elif moving_piece == 'k':
            self.castling_rights['k'] = False
            self.castling_rights['q'] = False

        # Moving a rook from its starting square removes that side's right
        elif moving_piece == 'R':
            if from_sq == 0:
                self.castling_rights['Q'] = False  # rook a1
            elif from_sq == 7:
                self.castling_rights['K'] = False  # rook h1
        elif moving_piece == 'r':
            if from_sq == 56:
                self.castling_rights['q'] = False  # rook a8
            elif from_sq == 63:
                self.castling_rights['k'] = False  # rook h8

        # Remove the right when a starting-square rook is gone (e.g. captured);
        # this works because update_castling_rights is called after the capture
        if self.bitboards['R'] & self.square_mask(0) == 0:
            self.castling_rights['Q'] = False
        if self.bitboards['R'] & self.square_mask(7) == 0:
            self.castling_rights['K'] = False
        if self.bitboards['r'] & self.square_mask(56) == 0:
            self.castling_rights['q'] = False
        if self.bitboards['r'] & self.square_mask(63) == 0:
            self.castling_rights['k'] = False

    def update_en_passant(self, moving_piece: str, from_sq: int, to_sq: int):
        """
        Update the en passant target square (or clear it).
        """
        self.en_passant_target = None  # cleared by default

        # White pawn advances two squares
        if moving_piece == 'P' and from_sq // 8 == 1 and to_sq // 8 == 3:
            self.en_passant_target = from_sq + 8

        # Black pawn advances two squares
        elif moving_piece == 'p' and from_sq // 8 == 6 and to_sq // 8 == 4:
            self.en_passant_target = from_sq - 8

    def load_fen(self, fen_string: str):
        """
        Load a chess position from a FEN string and update the internal bitboards.
        (Simplified version: reads the piece placement and the side to move only.)
        """
        # Clear all bitboards
        for piece in self.bitboards.keys():
            self.bitboards[piece] = 0
        # Moves and the en passant target of the previous position no longer apply
        self.history = []
        self.en_passant_target = None

        parts = fen_string.split()
        piece_placement = parts[0]

        rank = 7
        file = 0
        for char in piece_placement:
            if char.isalpha():
                self.bitboards[char] |= (1 << (rank * 8 + file))
                file += 1
            elif char.isdigit():
                file += int(char)
            elif char == '/':
                rank -= 1
                file = 0

        # Vous pouvez étendre cette fonction pour gérer aussi le trait,
        # les roques, etc., si votre évaluation en a besoin.
        if len(parts) > 1:
            self.white_to_move = (parts[1] == 'w')

Overwriting Chess.py
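`load_fen` walks the piece-placement field from rank 8 down to rank 1, so a square index is `rank * 8 + file` with a1 = 0 and h8 = 63, the same convention the castling code uses (a1 = 0, h1 = 7, a8 = 56, h8 = 63). As a sketch, the same walk can directly list the positions set to 1 in the 768-dimensional input (`piece_index * 64 + square`) that `encode_board` builds from the bitboards. The helper name `fen_to_active_indices` is hypothetical and not part of `Chess.py`.

```python
# Piece order copied from piece_to_index in torch_nn_evaluator.py
PIECE_TO_INDEX = {
    'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
    'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11,
}


def fen_to_active_indices(fen: str) -> list[int]:
    """Hypothetical helper: sorted one-hot positions that are set for a FEN."""
    placement = fen.split()[0]          # first FEN field: piece placement
    active = []
    rank, file = 7, 0                   # FEN starts at a8 (rank 8, file a)
    for char in placement:
        if char.isalpha():
            square = rank * 8 + file    # a1 = 0, h1 = 7, a8 = 56, h8 = 63
            active.append(PIECE_TO_INDEX[char] * 64 + square)
            file += 1
        elif char.isdigit():
            file += int(char)           # run of empty squares
        elif char == '/':
            rank -= 1                   # move down one rank
            file = 0
    return sorted(active)


start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
indices = fen_to_active_indices(start)
print(len(indices))             # one active input per piece
print(5 * 64 + 4 in indices)    # white king ('K', index 5) on e1 (square 4)
```

Since every piece sets exactly one input, the starting position yields 32 active inputs out of 768.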


## 🔧 Code: torch_nn_evaluator.py

In [20]:
# @title
%%writefile torch_nn_evaluator.py
import numpy as np
import torch
import torch.nn as nn
from Chess import Chess


class TorchNNEvaluator(nn.Module):
    """PyTorch equivalent of the NumPy `NeuralNetworkEvaluator`.

    - architecture: Linear(input -> hidden) -> LeakyReLU -> Dropout -> Linear(hidden -> hidden) -> LeakyReLU -> Dropout -> Linear(hidden -> out)
    - the module provides helpers to load/save the .npz format (compatible with the older NumPy code)
    - the module provides helpers for PyTorch checkpoint/restore (including optimizer.state_dict)
    - runs on GPU when a CUDA `device` is passed to `encode_board` / `evaluate_position`
    """

    def __init__(self, input_size=768, hidden_size=256, output_size=1, dropout=0.3, leaky_alpha=0.01):
        super().__init__()
        self.l1 = nn.Linear(input_size, hidden_size)
        self.l2 = nn.Linear(hidden_size, hidden_size)
        self.l3 = nn.Linear(hidden_size, output_size)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2 = nn.Dropout(p=dropout)
        self.leaky_relu = nn.LeakyReLU(negative_slope=leaky_alpha)

        self.piece_to_index = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
        self.input_size = input_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.leaky_relu(self.l1(x))
        x = self.dropout1(x)
        x = self.leaky_relu(self.l2(x))
        x = self.dropout2(x)
        return self.l3(x)

    def encode_board(self, chess_instance: Chess, device='cpu') -> torch.Tensor:
        vec = np.zeros(self.input_size, dtype=np.float32)
        for piece_char, bitboard in chess_instance.bitboards.items():
            if bitboard == 0:
                continue
            piece_index = self.piece_to_index[piece_char]
            temp_bb = int(bitboard)
            while temp_bb:
                square = (temp_bb & -temp_bb).bit_length() - 1
                vector_position = piece_index * 64 + square
                vec[vector_position] = 1.0
                temp_bb &= temp_bb - 1
        t = torch.from_numpy(vec).to(torch.float32).to(device)
        return t.unsqueeze(0)  # shape (1, input_size)

    def evaluate_position(self, chess_instance: Chess, device='cpu') -> float:
        x = self.encode_board(chess_instance, device=device)
        self.to(device)
        self.eval()
        with torch.no_grad():
            out = self.forward(x)
        normalized_score = out[0, 0].item()
        EVAL_SCALE_FACTOR = 1000.0
        return normalized_score * EVAL_SCALE_FACTOR


def save_weights_npz(model: TorchNNEvaluator, filename: str, adam_moments: dict = None):
    """Save the model weights to an .npz compatible with the legacy NumPy format.

    The keys match those expected by `nn_evaluator.load_evaluator_from_file`:
    - w1: shape (input, hidden)
    - b1: shape (1, hidden)
    - w2, b2, w3, b3
    PyTorch weights (shape (out, in)) are transposed to (in, out).
    """
    sd = model.state_dict()
    save_dict = {
        'w1': sd['l1.weight'].cpu().numpy().T,
        'b1': sd['l1.bias'].cpu().numpy().reshape(1, -1),
        'w2': sd['l2.weight'].cpu().numpy().T,
        'b2': sd['l2.bias'].cpu().numpy().reshape(1, -1),
        'w3': sd['l3.weight'].cpu().numpy().T,
        'b3': sd['l3.bias'].cpu().numpy().reshape(1, -1),
    }
    if adam_moments is not None:
        # convert to NumPy arrays (tensors may live on the GPU)
        for k, v in dict(adam_moments).items():
            save_dict[k] = v.detach().cpu().numpy() if torch.is_tensor(v) else np.array(v)

    np.savez(filename, **save_dict)
    if adam_moments is not None:
        print(f"Poids et moments Adam sauvegardés (npz) dans {filename}")
    else:
        print(f"Poids sauvegardés (npz) dans {filename}")


def load_from_npz(filename: str, device='cpu'):
    """Load an .npz written by the NumPy version and return (model, adam_moments).

    - adam_moments, when present, is a dict of torch.float32 tensors
      (plus 'adam_step' as a Python int)
    - if any Adam moment is missing, adam_moments is None
    """
    data = np.load(filename)
    # infer sizes
    w1 = data['w1']
    b1 = data['b1']
    w2 = data['w2']
    b2 = data['b2']
    w3 = data['w3']
    b3 = data['b3']
    input_size = int(w1.shape[0])
    hidden_size = int(w1.shape[1])
    output_size = int(w3.shape[1]) if w3.ndim == 2 else 1

    model = TorchNNEvaluator(input_size=input_size, hidden_size=hidden_size, output_size=output_size)
    # copy weights (transpose to torch linear layout)
    model.l1.weight.data.copy_(torch.from_numpy(w1.T).to(torch.float32))
    model.l1.bias.data.copy_(torch.from_numpy(b1.reshape(-1)).to(torch.float32))
    model.l2.weight.data.copy_(torch.from_numpy(w2.T).to(torch.float32))
    model.l2.bias.data.copy_(torch.from_numpy(b2.reshape(-1)).to(torch.float32))
    model.l3.weight.data.copy_(torch.from_numpy(w3.T).to(torch.float32))
    model.l3.bias.data.copy_(torch.from_numpy(b3.reshape(-1)).to(torch.float32))

    # collect adam moments if all present
    adam_moments = None
    adam_keys = ['m_w1', 'v_w1', 'm_b1', 'v_b1', 'm_w2', 'v_w2',
                 'm_b2', 'v_b2', 'm_w3', 'v_w3', 'm_b3', 'v_b3', 'adam_step']
    if all(key in data for key in adam_keys):
        adam_moments = {}
        for k in adam_keys:
            val = data[k]
            # convert scalar 0-d to Python int for adam_step
            if k == 'adam_step':
                adam_moments[k] = int(val)
            else:
                adam_moments[k] = torch.from_numpy(np.array(val)).to(torch.float32)

    model.to(device)
    return model, adam_moments


def torch_save_checkpoint(path: str, model: TorchNNEvaluator, optimizer=None, step: int = None):
    ckpt = {'model': model.state_dict()}
    if optimizer is not None:
        ckpt['optim'] = optimizer.state_dict()
    if step is not None:
        ckpt['step'] = int(step)
    torch.save(ckpt, path)
    print(f"Checkpoint PyTorch sauvegardé dans {path}")


def torch_load_checkpoint(path: str, model: TorchNNEvaluator = None, optimizer=None, device='cpu'):
    ckpt = torch.load(path, map_location=device)
    if model is not None:
        model.load_state_dict(ckpt['model'])
        model.to(device)
    optim_state = ckpt.get('optim')
    if optimizer is not None and optim_state is not None:
        optimizer.load_state_dict(optim_state)
    step = ckpt.get('step')
    return model, optim_state, step


if __name__ == '__main__':
    # Usage example, mirroring the original NumPy file
    WEIGHTS_FILE = 'chess_nn_weights_from_torch.npz'
    game = Chess()

    print("--- Création d'un réseau PyTorch non entraîné ---")
    model = TorchNNEvaluator(input_size=768, hidden_size=256, output_size=1)
    # evaluate on CPU
    score1 = model.evaluate_position(game)
    print(f'Score du réseau PyTorch vierge : {score1:.2f}')

    # save as .npz (for compatibility with the NumPy code)
    save_weights_npz(model, WEIGHTS_FILE)

    # reload from .npz
    print(f"\n--- Chargement du réseau depuis le fichier '{WEIGHTS_FILE}' ---")
    loaded_model, adam_moms = load_from_npz(WEIGHTS_FILE)
    score2 = loaded_model.evaluate_position(game)
    print(f"Score du réseau chargé : {score2:.2f}")
    if adam_moms is not None:
        print(f"Moments Adam chargés (step={adam_moms['adam_step']})")
    # float32 round-off may change the score slightly, but it should stay close
    try:
        assert abs(score1 - score2) < 1e-3
    except AssertionError:
        print('Attention: score initial et score chargé diffèrent; vérifie dtypes/precision')



Overwriting torch_nn_evaluator.py
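`save_weights_npz` and `load_from_npz` transpose every weight matrix because `nn.Linear` stores its weight as (out_features, in_features) and computes `x @ W.T + b`, while the NumPy format stores `w1` as (input, hidden) and computes `x @ w1 + b1` (biases are kept as a (1, hidden) row). A small pure-Python check, with made-up numbers and no NumPy or PyTorch, that both layouts give the same layer output once the weight is transposed:

```python
def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(col) for col in zip(*matrix)]


def linear_torch_layout(x, weight, bias):
    """y[j] = sum_i weight[j][i] * x[i] + bias[j]; weight is (out, in) like nn.Linear."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def linear_numpy_layout(x, w, b):
    """y[j] = sum_i x[i] * w[i][j] + b[j]; w is (in, out) like 'w1' in the .npz."""
    out_features = len(w[0])
    return [sum(x[i] * w[i][j] for i in range(len(x))) + b[j] for j in range(out_features)]


# Toy layer: 3 inputs -> 2 outputs (illustrative values only)
weight_torch = [[1.0, 0.0, 2.0],
                [0.5, -1.0, 0.0]]    # shape (out=2, in=3)
bias = [0.1, -0.2]
x = [1.0, 2.0, 3.0]

w_npz = transpose(weight_torch)      # shape (in=3, out=2), as stored in the .npz
y_torch = linear_torch_layout(x, weight_torch, bias)
y_numpy = linear_numpy_layout(x, w_npz, bias)
print(y_torch, y_numpy)
```

The same reasoning applies to `l2` and `l3`; only the weight matrices need the transpose, the bias just changes shape.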


## 🔧 Code: train_torch.py

In [26]:
# @title
%%writefile train_torch.py
"""
PyTorch training script optimized for GPU.
Runs on Google Colab or on a local machine; uses the GPU when available, otherwise the CPU.
"""
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os

from Chess import Chess
from torch_nn_evaluator import TorchNNEvaluator, save_weights_npz, load_from_npz, torch_save_checkpoint, torch_load_checkpoint

# --- CONFIGURATION DE L'ENTRAÎNEMENT ---
DATASET_PATH = "/content/drive/MyDrive/smart_chess_ia/chessData.csv"  # Colab path: Google Drive must be mounted at /content/drive
WEIGHTS_FILE = "chess_nn_weights.npz"
CHECKPOINT_FILE = "chess_model_checkpoint.pt"

# Architecture
HIDDEN_SIZE = 256
DROPOUT = 0.3
LEAKY_ALPHA = 0.01

# Hyperparameters
LEARNING_RATE = 0.001
WEIGHT_DECAY = 1e-4  # decoupled weight decay (AdamW)
EPOCHS = 20
BATCH_SIZE = 128  # larger batches suit a GPU
MAX_SAMPLES = 500_000  # positions sampled per epoch (more is affordable on a GPU)
EVAL_MAX_SAMPLES = 5000

# Options
USE_SAMPLING = True
RESET_WEIGHTS = False
DEBUG_STATS = True

# LR Scheduler
USE_LR_SCHEDULER = True
LR_PATIENCE = 2
LR_FACTOR = 0.5

# LR Warmup
USE_LR_WARMUP = True
WARMUP_EPOCHS = 3
WARMUP_START_LR = 0.0001

# Device (GPU auto-detection)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Device: {DEVICE}")
if torch.cuda.is_available():
    print(f"🚀 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")


class ChessDataset(Dataset):
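    """PyTorch Dataset of chess positions (FEN strings with scaled evaluations)."""
    def __init__(self, fens, evaluations):
        self.fens = fens
        self.evaluations = evaluations
        self.chess = Chess()

        # Optional: pre-encode every position once (faster epochs, but much more RAM).
        # There is no _encode_fen method yet: it would have to load each FEN into
        # self.chess and call _encode_board, as __getitem__ does.
        # self.encoded = [self._encode_fen(fen) for fen in tqdm(fens, desc="Encoding positions")]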

    def __len__(self):
        return len(self.fens)

    def __getitem__(self, idx):
        fen = self.fens[idx]
        target = self.evaluations[idx]

        # Encode the position
        self.chess.load_fen(fen)
        encoded = self._encode_board(self.chess)

        return torch.from_numpy(encoded).float(), torch.tensor([target], dtype=torch.float32)

    def _encode_board(self, chess_instance):
        """Encode the board as a 768-dim one-hot vector (same encoding as nn_evaluator.py)."""
        piece_to_index = {
            'P': 0, 'N': 1, 'B': 2, 'R': 3, 'Q': 4, 'K': 5,
            'p': 6, 'n': 7, 'b': 8, 'r': 9, 'q': 10, 'k': 11
        }
        vec = np.zeros(768, dtype=np.float32)
        for piece_char, bitboard in chess_instance.bitboards.items():
            if bitboard == 0:
                continue
            piece_index = piece_to_index[piece_char]
            temp_bb = int(bitboard)
            while temp_bb:
                square = (temp_bb & -temp_bb).bit_length() - 1
                vector_position = piece_index * 64 + square
                vec[vector_position] = 1.0
                temp_bb &= temp_bb - 1
        return vec


def load_data(filepath: str):
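    """Load the FEN,Evaluation dataset and drop incomplete rows."""
    print(f"📂 Chargement du dataset depuis {filepath}...")

    # comment='#' makes pandas ignore everything after a '#' on a line: any
    # evaluation written with '#' (e.g. a mate score, if the file uses that
    # notation) becomes NaN and is removed by dropna() below, together with
    # genuinely corrupted rows.
    df = pd.read_csv(
        filepath,
        names=['FEN', 'Evaluation'],
        skiprows=1,
        comment='#'
    )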

    initial_count = len(df)
    df.dropna(inplace=True)
    cleaned_count = len(df)

    if initial_count > cleaned_count:
        print(f"🧹 Nettoyage : {initial_count - cleaned_count} lignes corrompues supprimées.")

    fens = df['FEN'].values
    EVAL_SCALE_FACTOR = 1000.0
    evaluations = (df['Evaluation'].astype(int).values) / EVAL_SCALE_FACTOR

    print(f"✅ {len(fens):,} positions valides chargées.")
    return fens, evaluations


def evaluate_model(model, dataloader, device):
    """Evaluate the model on a dataset; return RMSE, MAE, Pearson correlation, predictions and targets."""
    model.eval()
    predictions = []
    targets = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            predictions.extend(outputs.cpu().numpy().flatten())
            targets.extend(labels.numpy().flatten())

    predictions = np.array(predictions)
    targets = np.array(targets)

    rmse = float(np.sqrt(np.mean((predictions - targets) ** 2)))
    mae = float(np.mean(np.abs(predictions - targets)))
    corr = float(np.corrcoef(predictions, targets)[0, 1]) if len(predictions) > 1 else 0.0

    return rmse, mae, corr, predictions, targets


def main():
    # 1. Load the data
    all_fens, all_evaluations = load_data(DATASET_PATH)

    print(f"\n📊 Dataset complet: {len(all_fens):,} positions")

    # Mean target, used below to warm-start the output bias of a new network
    eval_mean = float(np.mean(all_evaluations))

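    # 2. Initialize the model
    # RESET_WEIGHTS must delete both files: the checkpoint is loaded first when it exists
    if RESET_WEIGHTS:
        for path in (CHECKPOINT_FILE, WEIGHTS_FILE):
            if os.path.exists(path):
                print(f"🗑️  Suppression des anciens poids: {path}")
                os.remove(path)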

    if os.path.exists(CHECKPOINT_FILE):
        print(f"📥 Chargement du checkpoint PyTorch: {CHECKPOINT_FILE}")
        model = TorchNNEvaluator(hidden_size=HIDDEN_SIZE, dropout=DROPOUT, leaky_alpha=LEAKY_ALPHA)
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        model, _, start_step = torch_load_checkpoint(CHECKPOINT_FILE, model, optimizer, device=DEVICE)
        print(f"✅ Checkpoint chargé (step {start_step})")
    elif os.path.exists(WEIGHTS_FILE):
        print(f"📥 Chargement des poids NumPy: {WEIGHTS_FILE}")
        model, adam_moments = load_from_npz(WEIGHTS_FILE, device=DEVICE)
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
        # TODO: restore the Adam moments when present (adam_moments is currently ignored)
        print(f"✅ Poids chargés depuis NumPy")
    else:
        print("🆕 Création d'un nouveau réseau...")
        model = TorchNNEvaluator(hidden_size=HIDDEN_SIZE, dropout=DROPOUT, leaky_alpha=LEAKY_ALPHA)
        # No explicit He init: nn.Linear keeps PyTorch's default (Kaiming-uniform with a=sqrt(5))
        optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

        # Warm-start the output bias at the mean target
        with torch.no_grad():
            model.l3.bias[0] = eval_mean

    model.to(DEVICE)

    # 3. Training setup
    criterion = nn.MSELoss()

    # LR Scheduler
    if USE_LR_SCHEDULER:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode='min', factor=LR_FACTOR,
            patience=LR_PATIENCE
        )

    print(f"\n{'='*70}")
    print(f"Configuration:")
    print(f"  Dataset complet: {len(all_fens):,} positions")
    print(f"  Échantillon/epoch: {MAX_SAMPLES if USE_SAMPLING else len(all_fens):,} positions")
    print(f"  Architecture: 768 → {HIDDEN_SIZE} → {HIDDEN_SIZE} → 1")
    print(f"  Dropout: {DROPOUT}")
    print(f"  LeakyReLU alpha: {LEAKY_ALPHA}")
    print(f"  Learning rate: {LEARNING_RATE} (AdamW, weight decay: {WEIGHT_DECAY})")
    print(f"  LR Warmup: {USE_LR_WARMUP} ({WARMUP_START_LR if USE_LR_WARMUP else 'N/A'} → {LEARNING_RATE})")
    print(f"  LR Scheduler: {USE_LR_SCHEDULER} (patience: {LR_PATIENCE if USE_LR_SCHEDULER else 'N/A'})")
    print(f"  Batch size: {BATCH_SIZE}")
    print(f"  Epochs: {EPOCHS}")
    print(f"  Device: {DEVICE}")
    print(f"{'='*70}\n")

    # 4. Training loop
    best_rmse = float('inf')

    for epoch in range(EPOCHS):
        # Draw a new random subset of MAX_SAMPLES positions every epoch
        if USE_SAMPLING and len(all_fens) > MAX_SAMPLES:
            print(f"\n[Epoch {epoch+1}] 🎲 Échantillonnage: {MAX_SAMPLES:,} positions sur {len(all_fens):,}")
            idx = np.random.choice(len(all_fens), size=MAX_SAMPLES, replace=False)
            fens = all_fens[idx]
            evaluations = all_evaluations[idx]
        else:
            fens = all_fens
            evaluations = all_evaluations

        # LR Warmup
        if USE_LR_WARMUP and epoch < WARMUP_EPOCHS:
            warmup_progress = (epoch + 1) / WARMUP_EPOCHS
            lr = WARMUP_START_LR + (LEARNING_RATE - WARMUP_START_LR) * warmup_progress
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(f"🔥 Warmup epoch {epoch+1}/{WARMUP_EPOCHS}: LR = {lr:.6f}")

        # Build the dataset and DataLoader
        train_dataset = ChessDataset(fens, evaluations)
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE,
            shuffle=True,
            num_workers=0,  # Raise (e.g. to 4) on a multi-core CPU: FEN encoding runs in the workers
            pin_memory=torch.cuda.is_available()
        )

        # Training
        model.train()
        total_loss = 0
        num_batches = 0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{EPOCHS}")
        for batch_idx, (inputs, targets) in enumerate(progress_bar):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)

            # Forward
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Backward
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)

            # Update
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            # Debug stats (first batch of the first epoch only)
            if DEBUG_STATS and epoch == 0 and batch_idx == 0:
                with torch.no_grad():
                    preds = outputs.cpu().numpy().flatten()
                    targs = targets.cpu().numpy().flatten()
                    batch_rmse = np.sqrt(np.mean((preds - targs) ** 2))
                    corr = np.corrcoef(preds, targs)[0, 1] if len(preds) > 1 else 0.0
                    print(f"\n[DEBUG batch 0] targets mean={targs.mean():.4f} std={targs.std():.4f}; "
                          f"preds mean={preds.mean():.4f} std={preds.std():.4f}; "
                          f"RMSE={batch_rmse:.4f}; corr={corr:.4f}")

            # Update the progress bar; 'loss' shows the running RMSE (sqrt of the mean MSE)
            avg_loss = total_loss / num_batches
            progress_bar.set_postfix({"loss": f"{np.sqrt(avg_loss):.4f}"})

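        # End-of-epoch evaluation
        print(f"\n🔍 Évaluation epoch {epoch+1}...")

        # Evaluation sample: redrawn from the full dataset every epoch, so it
        # overlaps the training positions and mostly measures training fit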
        if EVAL_MAX_SAMPLES and len(all_fens) > EVAL_MAX_SAMPLES:
            eval_idx = np.random.choice(len(all_fens), size=EVAL_MAX_SAMPLES, replace=False)
            eval_fens = all_fens[eval_idx]
            eval_targets = all_evaluations[eval_idx]
        else:
            eval_fens = all_fens
            eval_targets = all_evaluations

        eval_dataset = ChessDataset(eval_fens, eval_targets)
        eval_loader = DataLoader(eval_dataset, batch_size=BATCH_SIZE*2, shuffle=False)

        rmse, mae, corr, preds, targets = evaluate_model(model, eval_loader, DEVICE)

        # Baseline: the RMSE of always predicting the sample mean equals the targets' std
        baseline_rmse = targets.std()
        improvement = 100 * (1 - rmse / baseline_rmse) if baseline_rmse > 0 else 0

        # Report
        print(f"\n{'='*70}")
        print(f"EPOCH {epoch+1}/{EPOCHS} - Évaluation sur {len(eval_fens):,} positions")
        print(f"{'='*70}")
        print(f"  RMSE:        {rmse:.4f}  (baseline: {baseline_rmse:.4f})")
        print(f"  MAE:         {mae:.4f}")
        print(f"  Amélioration: {improvement:+.1f}% vs baseline")
        print(f"  Corrélation: {corr:.4f}")
        print(f"  Std preds:   {preds.std():.4f}  (cible: {targets.std():.4f})")
        print(f"  Mean preds:  {preds.mean():.4f}  (cible: {targets.mean():.4f})")

        if improvement > 50:
            print(f"  ✓✓ Performance excellente!")
        elif improvement > 30:
            print(f"  ✓  Bon apprentissage!")
        elif improvement > 10:
            print(f"  →  Apprentissage en cours")
        else:
            print(f"  ⚠  Faible amélioration - vérifier hyperparamètres")
        print(f"{'='*70}\n")

        # LR Scheduler
        if USE_LR_SCHEDULER and (not USE_LR_WARMUP or epoch >= WARMUP_EPOCHS):
            scheduler.step(rmse)

        # Save the best model so far
        if rmse < best_rmse:
            best_rmse = rmse
            print(f"💾 Nouveau meilleur RMSE: {best_rmse:.4f} - Sauvegarde...")
            torch_save_checkpoint(CHECKPOINT_FILE, model, optimizer, epoch)
            save_weights_npz(model, WEIGHTS_FILE)

    print("\n🎉 Entraînement terminé!")
    print(f"📊 Meilleur RMSE: {best_rmse:.4f}")

    # Final save: this overwrites the best-RMSE checkpoint and weights saved above
    print(f"\n💾 Sauvegarde finale...")
    torch_save_checkpoint(CHECKPOINT_FILE, model, optimizer, EPOCHS)
    save_weights_npz(model, WEIGHTS_FILE)
    print(f"✅ Modèle sauvegardé dans {CHECKPOINT_FILE} et {WEIGHTS_FILE}")


if __name__ == "__main__":
    main()


Overwriting train_torch.py
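With `USE_LR_WARMUP`, `main()` sets the learning rate once per epoch by linear interpolation from `WARMUP_START_LR` to `LEARNING_RATE`, reaching the full rate on the last warmup epoch; `ReduceLROnPlateau` only starts stepping after that. The helper below (a hypothetical name reproducing the formula in `main()`, with the script's default values) lists the per-epoch warmup rates; the first one matches the `LR = 0.000400` line in the training log.

```python
LEARNING_RATE = 0.001
WARMUP_EPOCHS = 3
WARMUP_START_LR = 0.0001


def warmup_lr(epoch: int) -> float:
    """Same formula as the warmup block in main(); epoch is 0-based."""
    progress = (epoch + 1) / WARMUP_EPOCHS
    return WARMUP_START_LR + (LEARNING_RATE - WARMUP_START_LR) * progress


for epoch in range(WARMUP_EPOCHS):
    print(f"epoch {epoch + 1}: LR = {warmup_lr(epoch):.6f}")
```

Because `progress` starts at `1 / WARMUP_EPOCHS`, the first epoch already trains at 0.0004: `WARMUP_START_LR` itself is never used as a learning rate. Using `epoch / WARMUP_EPOCHS` instead would start exactly at `WARMUP_START_LR`, at the cost of reaching the full rate one epoch later.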


## 🚀 Launch training!

Adjust the hyperparameters in train_torch.py if needed, then run:

In [None]:
# @title
!python train_torch.py

🖥️  Device: cpu
📂 Chargement du dataset depuis /content/drive/MyDrive/smart_chess_ia/chessData.csv...
🧹 Nettoyage : 190154 lignes corrompues supprimées.
✅ 12,767,881 positions valides chargées.

📊 Dataset complet: 12,767,881 positions
🆕 Création d'un nouveau réseau...

Configuration:
  Dataset complet: 12,767,881 positions
  Échantillon/epoch: 500,000 positions
  Architecture: 768 → 256 → 256 → 1
  Dropout: 0.3
  LeakyReLU alpha: 0.01
  Learning rate: 0.001 (AdamW, weight decay: 0.0001)
  LR Warmup: True (0.0001 → 0.001)
  LR Scheduler: True (patience: 2)
  Batch size: 128
  Epochs: 20
  Device: cpu


[Epoch 1] 🎲 Échantillonnage: 500,000 positions sur 12,767,881
🔥 Warmup epoch 1/3: LR = 0.000400
Epoch 1/20:   0% 0/3907 [00:00<?, ?it/s]
[DEBUG batch 0] targets mean=0.0452 std=0.6471; preds mean=0.0770 std=0.0214; RMSE=0.6486; corr=-0.0156
Epoch 1/20: 100% 3907/3907 [01:11<00:00, 54.36it/s, loss=0.7433]

🔍 Évaluation epoch 1...

EPOCH 1/20 - Évaluation sur 5,000 positions
  RMSE:        
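The log above reports 190,154 rows dropped as corrupted. Because `load_data` passes `comment='#'` to `pd.read_csv`, pandas discards everything after a `#` on each line; if the CSV writes mate scores with a leading `#` (for example `#+3`, an assumption about this file's format), those rows end up with an empty `Evaluation` and are removed by `dropna`, so many of them may be mate positions rather than broken data. A sketch of one alternative, under that assumption: read the column as text and map mate scores to a capped value. `parse_evaluation` and `MATE_CP` are hypothetical choices, not part of `train_torch.py`.

```python
MATE_CP = 10_000  # hypothetical centipawn value assigned to a forced mate


def parse_evaluation(raw: str) -> int:
    """Convert an evaluation string to centipawns.

    Assumes plain scores look like '+56', '-10' or '0', and mate scores
    like '#+3' (White mates) or '#-2' (Black mates).
    """
    raw = raw.strip()
    if raw.startswith('#'):
        # The sign after '#' tells which side mates; the distance to mate is ignored
        return -MATE_CP if raw[1:2] == '-' else MATE_CP
    return int(raw)  # int() accepts a leading '+' or '-'


for s in ["+56", "-10", "0", "#+3", "#-2"]:
    print(s, "->", parse_evaluation(s))
```

In `load_data`, this would mean dropping `comment='#'`, passing `dtype={'Evaluation': str}`, calling `dropna()` and then `df['Evaluation'].map(parse_evaluation)`. With `EVAL_SCALE_FACTOR = 1000`, a cap of 10,000 becomes a target of 10.0, far above the target std of about 0.65 seen in the debug line, so a lower cap may suit the MSE loss better.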

## 💾 Download the trained weights

When training is finished, download the weights:

In [None]:
from google.colab import files

# Download the weights in NumPy format (compatible with your existing code)
files.download('chess_nn_weights.npz')

# Also download the PyTorch checkpoint (to resume training later)
files.download('chess_model_checkpoint.pt')
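Files written by `train_torch.py` live in the runtime's working directory (`/content` on Colab), which is wiped when the runtime is recycled. Besides `files.download`, one option is to copy them into the mounted Google Drive. A minimal sketch with `shutil.copy2`; `backup_files` and the destination folder are hypothetical, and the Drive path assumes `drive.mount('/content/drive')` has already been run. The demo uses a temporary directory so it runs anywhere.

```python
import os
import shutil
import tempfile


def backup_files(filenames, dest_dir):
    """Copy each existing file into dest_dir; return the paths actually written."""
    os.makedirs(dest_dir, exist_ok=True)
    copied = []
    for name in filenames:
        if os.path.exists(name):
            # copy2 also preserves the modification time
            copied.append(shutil.copy2(name, dest_dir))
        else:
            print(f"skipped (not found): {name}")
    return copied


# Demo in a throwaway directory: one real file, one missing file
with tempfile.TemporaryDirectory() as tmp:
    src = os.path.join(tmp, "demo.npz")
    open(src, "wb").close()
    out = backup_files([src, os.path.join(tmp, "missing.pt")], os.path.join(tmp, "backup"))
    print(out)

# On Colab (hypothetical destination folder):
# backup_files(['chess_nn_weights.npz', 'chess_model_checkpoint.pt'],
#              '/content/drive/MyDrive/smart_chess_ia/checkpoints')
```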

## 🧪 Quick test of the trained model

In [None]:
from Chess import Chess
from torch_nn_evaluator import load_from_npz
import torch

# Load the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, _ = load_from_npz('chess_nn_weights.npz', device=device)

# Test on a few positions
chess = Chess()

# Starting position
chess.load_fen("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
score = model.evaluate_position(chess, device=device)
print(f"Position initiale: {score:.2f} centipawns (devrait être proche de 0)")

# Symmetric position after 1.e4 e5 (should also be close to 0)
chess.load_fen("rnbqkbnr/pppp1ppp/8/4p3/4P3/8/PPPP1PPP/RNBQKBNR w KQkq - 0 2")
score = model.evaluate_position(chess, device=device)
print(f"Position symétrique: {score:.2f} centipawns")

# Black is already checkmated (Scholar's mate: Qxf7#, Black to move)
chess.load_fen("r1bqkb1r/pppp1Qpp/2n2n2/4p3/2B1P3/8/PPPP1PPP/RNB1K1NR b KQkq - 0 1")
score = model.evaluate_position(chess, device=device)
print(f"Noirs mats (les blancs gagnent): {score:.2f} centipawns (devrait être très positif)")
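Beyond a few hand-picked positions, a cheap consistency check is colour symmetry: assuming the evaluations are from White's point of view (the cell above expects positive scores when White wins), flipping the board vertically and swapping the piece colours should roughly negate the score. A sketch of the FEN transformation only; `mirror_fen` is a hypothetical helper. Its output could be passed to `chess.load_fen` and `model.evaluate_position` to compare `score` with `-score_mirrored`. Since `encode_board` ignores the side to move, only the piece swap matters for this model.

```python
def mirror_fen(fen: str) -> str:
    """Return the colour-mirrored FEN: ranks reversed, White and Black swapped."""
    fields = fen.split()
    placement, side = fields[0], fields[1]
    castling = fields[2] if len(fields) > 2 else '-'
    ep = fields[3] if len(fields) > 3 else '-'
    rest = fields[4:]                  # halfmove and fullmove counters, kept as-is

    # Reverse the rank order and swap piece colours (uppercase <-> lowercase)
    new_placement = '/'.join(r.swapcase() for r in reversed(placement.split('/')))
    new_side = 'b' if side == 'w' else 'w'
    if castling != '-':
        swapped = castling.swapcase()
        # Keep the conventional KQkq ordering
        castling = ''.join(c for c in 'KQkq' if c in swapped)
    if ep != '-':
        ep = ep[0] + str(9 - int(ep[1]))   # e3 <-> e6
    return ' '.join([new_placement, new_side, castling, ep, *rest])


start = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
print(mirror_fen(start))
```

Symmetric positions such as the start position or 1.e4 e5 map to themselves (apart from the side to move), so a well-behaved evaluator should score them close to 0.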

## 📊 Visualize performance

You can also paste your test_generalization.py script here to check how well the model generalizes.
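In `train_torch.py`, the evaluation sample is redrawn every epoch from the full dataset, so it overlaps the training positions and mostly measures training fit. A common way to measure generalization is to set aside a fixed held-out index set once, before the epoch loop, and sample training positions only from the remaining indices. A minimal standard-library sketch (`split_indices` and the seed are hypothetical):

```python
import random


def split_indices(n: int, holdout: int, seed: int = 0):
    """Shuffle range(n) once and split it into (train, held-out) index lists."""
    rng = random.Random(seed)          # fixed seed -> same split on every run
    indices = list(range(n))
    rng.shuffle(indices)
    return indices[holdout:], indices[:holdout]


train_idx, holdout_idx = split_indices(n=100, holdout=10, seed=42)
print(len(train_idx), len(holdout_idx))
```

With NumPy arrays such as `all_fens`, the index lists can be used directly (`all_fens[train_idx]`). For the full 12.7M rows, `np.random.default_rng(seed).permutation(n)` is the vectorised equivalent of the shuffle and avoids building a large Python list.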

---

## 💡 Tips

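1. **Increase MAX_SAMPLES**: with a GPU, 1M positions per epoch is within easy reach; FEN encoding still runs on the CPU inside the DataLoader, so raising `num_workers` may matter as much as the GPU itself
2. **Batch size**: raise it to 256 or 512 if GPU memory allows
3. **Save regularly**: Colab can disconnect you (sessions are capped at about 12 h, idle ones end sooner), and files in `/content` are lost with the runtime, so copy your checkpoints somewhere persistent such as Google Drive
4. **Resume training**: put the `.pt` checkpoint back in the working directory and `train_torch.py` reloads the model and optimizer state from it; the epoch counter, LR warmup and best RMSE start over, though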
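For tips 1 and 2, the number of optimizer steps per epoch is ceil(samples / batch size); with the defaults (500,000 samples, batch 128) that gives the 3,907 batches shown in the training log. The small helpers below (hypothetical names) also compute the size of one float32 input batch (768 values per position), which suggests the input tensors themselves take little GPU memory for this network.

```python
import math

INPUT_SIZE = 768        # one-hot planes: 12 piece types x 64 squares
BYTES_PER_FLOAT32 = 4


def steps_per_epoch(samples: int, batch_size: int) -> int:
    """A DataLoader with drop_last=False yields ceil(samples / batch_size) batches."""
    return math.ceil(samples / batch_size)


def input_batch_mib(batch_size: int) -> float:
    """Size of one float32 input batch, in MiB."""
    return batch_size * INPUT_SIZE * BYTES_PER_FLOAT32 / 2**20


for samples, batch in [(500_000, 128), (1_000_000, 256), (1_000_000, 512)]:
    print(f"{samples:,} samples, batch {batch}: "
          f"{steps_per_epoch(samples, batch):,} steps/epoch, "
          f"{input_batch_mib(batch):.2f} MiB per input batch")
```

Fewer, larger steps usually shorten an epoch on a GPU, but they also change the optimization dynamics, so `LEARNING_RATE` may need retuning when `BATCH_SIZE` changes.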

## 🐛 Debugging

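If you run into errors:
- Check that the GPU is enabled (first cell)
- Check that all the files are uploaded or cloned
- Lower BATCH_SIZE on out-of-memory errors
- Lower MAX_SAMPLES if the per-epoch sample uses too much RAM; `load_data` still reads the whole CSV, so limiting the rows read (e.g. `nrows=` in `pd.read_csv`) is what reduces the loading memory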