In [None]:
# Extend the import path so the local modules imported below can be found.
# NOTE(review): this is an absolute, machine-specific Windows path; on any other
# machine the imports fail until it is adjusted (a path relative to the notebook
# would be more portable) — confirm the intended project layout.
import sys, importlib
sys.path.append(r"C:\Users\Artem\Desktop\vs code project\chess_forecast\src")
# Import each local module and reload it right away, so that re-running this cell
# picks up edits made to the source files after the kernel first imported them.
importlib.reload(importlib.import_module('features'))
importlib.reload(importlib.import_module('move'))
importlib.reload(importlib.import_module('mymodel'))
importlib.reload(importlib.import_module('training'))

import features, move, mymodel, training
import numpy as np, pandas as pd, torch
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import precision_score, recall_score, f1_score
from sklearn.metrics import balanced_accuracy_score, cohen_kappa_score
import tqdm as tqdm

In [2]:
# Fixed move vocabulary plus lookup tables in both directions:
# index -> move object, and UCI string -> index.
ALL_MOVES = move.generate_all_possible_moves()
INDEX_TO_MOVE = dict(enumerate(ALL_MOVES))
MOVE_TO_INDEX = {candidate.uci(): position for position, candidate in INDEX_TO_MOVE.items()}
# Cell output: number of distinct UCI strings in the vocabulary.
len(MOVE_TO_INDEX)

4208

In [3]:
# Read the training positions and discard the row labelled 232329.
# NOTE(review): the reason this particular row is excluded is not recorded here —
# presumably a malformed record; confirm and document it.
path = '../fens_training_set.csv'
df = pd.read_csv(path).drop(index=232329)
# Optional quick-run subsample (keeps only the last 100000 rows):
# df = df[-100000:]

In [4]:
# Train/validation split: 10% of the rows are held out, with a fixed seed so the
# split is reproducible.
train_df, test_df = train_test_split(df, test_size=0.1, random_state=42)

# Wrap both frames in the project's dataset class, sharing the same move vocabulary.
train_dataset, test_dataset = (
    mymodel.ChessDataset(frame, ALL_MOVES, MOVE_TO_INDEX)
    for frame in (train_df, test_df)
)

# Only the training loader shuffles; the held-out loader keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=512,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=512,
    shuffle=False,
    num_workers=0,
)

In [5]:
# Model, optimizer, learning-rate schedule and loss used by the training loop.
num_epochs = 5
# Training is pinned to the CPU; the commented-out line would pick CUDA when available.
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# The output layer is sized to the full move vocabulary.
model = mymodel.ChessMovePredictor(num_moves=len(ALL_MOVES)).to(device)
# NOTE(review): per the OneCycleLR documentation the schedule starts at
# max_lr / div_factor, so the lr=0.001 given to Adam is presumably overridden
# as soon as the scheduler is created — confirm this is intended.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# One-cycle schedule sized for num_epochs * len(train_loader) scheduler steps,
# with the first 10% of them spent increasing the learning rate.
# Assumes training.train_epoch calls scheduler.step() once per batch — TODO confirm.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=0.01, 
    epochs=num_epochs, steps_per_epoch=len(train_loader),
    pct_start=0.1)
criterion = mymodel.ChessMoveLoss()

In [None]:
# Train for num_epochs epochs; after each one evaluate on the held-out loader and
# keep the weights of the epoch with the highest top-1 accuracy seen so far.
best_val_acc = 0
for epoch in range(1, num_epochs + 1):
    print(f"\nEPOCH {epoch}/{num_epochs}")

    train_loss, train_acc = training.train_epoch(
        model, train_loader, optimizer, criterion, device, scheduler
    )
    val_metrics = training.validate_epoch(model, test_loader, criterion, device)
    val_acc = val_metrics['acc1']

    print(f"Train loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(
        f"Validation: Loss: {val_metrics['loss']:.4f}, "
        f"Acc@1: {val_metrics['acc1']:.3f}, "
        f"Acc@3: {val_metrics['acc3']:.3f}, "
        f"Acc@5: {val_metrics['acc5']:.3f}"
    )

    # No strict improvement over the best epoch so far: nothing to checkpoint.
    if not val_acc > best_val_acc:
        continue

    torch.save(model.state_dict(), "best_chess_model.pt")
    best_val_acc = val_acc
    print("Модель сохранена.")

In [None]:

# Restore the checkpoint saved by the training loop and switch to inference mode.
model.load_state_dict(torch.load("best_chess_model.pt", map_location=device))
model.eval()

# Query the model for its top_k=5 predictions in one sample position (FEN string).
fen = "rnb2rk1/pp3ppp/4p1n1/q1pP4/3P4/1QN1P1B1/PP3PPP/R3KB1R b KQ - 0 11"
top_moves = training.predict_move(
    model, fen, device,
    ALL_MOVES, MOVE_TO_INDEX, INDEX_TO_MOVE,
    top_k=5,
)
# Presumably each entry bundles (move, prob, all_probs) — confirm against
# training.predict_move.
for candidate in top_moves:
    print(candidate)

In [None]:
# Debugging aid: print the call signature of training.predict_move to check the
# argument order used by the prediction cell.
import inspect
print(inspect.signature(training.predict_move))

(model, fen_string, device, ALL_MOVES, MOVE_TO_INDEX, INDEX_TO_MOVE, top_k=5)


In [8]:
def top_k_precision(y_true, y_probs, k=3):
    """Mean precision@k for single-label predictions.

    For every sample the k highest-scoring classes are taken; the sample's
    precision is the number of those classes equal to the true label
    (0 or 1 when the k classes are distinct) divided by k.

    Args:
        y_true: 1-D sequence/array of true class indices, one per sample.
        y_probs: 2-D array of per-class scores, one row per sample.
        k: number of top-scoring classes to consider; must be >= 1.

    Returns:
        The mean of the per-sample precision values as a NumPy float.

    Raises:
        ValueError: if k < 1, or if y_true and y_probs disagree on the number
            of samples.
    """
    # A slice of the form [:, -k:] silently selects *all* columns for k == 0
    # (and an arbitrary subset for negative k), so reject such values up front.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    scores = np.asarray(y_probs)
    labels = np.asarray(y_true).reshape(-1, 1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError(
            f"y_true has {labels.shape[0]} samples but y_probs has {scores.shape[0]} rows"
        )

    # argsort is ascending, so the last k columns hold the k best classes.
    top_k_preds = np.argsort(scores, axis=1)[:, -k:]
    hits_per_sample = (top_k_preds == labels).sum(axis=1)
    return np.mean(hits_per_sample / k)

def mean_reciprocal_rank(y_true, y_probs):
    """Mean reciprocal rank (MRR) of the true class.

    For each sample the classes are ordered by descending score; the sample
    contributes 1 / rank, where rank is the 1-based position of the true
    class in that ordering.

    Args:
        y_true: sequence/array of true class indices, one per sample.
        y_probs: per-sample score vectors, indexable by sample position.

    Returns:
        The mean of the reciprocal ranks as a NumPy float.
    """
    def reciprocal_rank(sample):
        # Ascending argsort reversed -> class indices by descending score.
        descending = np.argsort(y_probs[sample])[::-1]
        position = np.flatnonzero(descending == y_true[sample])[0]
        return 1.0 / (position + 1)

    return np.mean([reciprocal_rank(sample) for sample in range(len(y_true))])

def top_k_recall(y_true, y_probs, k=3):
    """Top-k recall (a.k.a. top-k accuracy) for single-label predictions.

    A sample counts as a hit when its true class is among the k
    highest-scoring classes of its row in y_probs.

    Args:
        y_true: 1-D sequence/array of true class indices, one per sample.
        y_probs: 2-D array of per-class scores, one row per sample.
        k: number of top-scoring classes to consider; must be >= 1.

    Returns:
        The fraction of samples that are hits, as a NumPy float.

    Raises:
        ValueError: if k < 1, or if y_true and y_probs disagree on the number
            of samples.
    """
    # A slice of the form [:, -k:] silently selects *all* columns for k == 0
    # (which would report a perfect score), so reject such values up front.
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    scores = np.asarray(y_probs)
    labels = np.asarray(y_true).reshape(-1, 1)
    if labels.shape[0] != scores.shape[0]:
        raise ValueError(
            f"y_true has {labels.shape[0]} samples but y_probs has {scores.shape[0]} rows"
        )

    # argsort is ascending, so the last k columns hold the k best classes.
    top_k_preds = np.argsort(scores, axis=1)[:, -k:]
    return np.mean((top_k_preds == labels).any(axis=1))


@torch.no_grad()
def for_metrics(model, test_dataset, device=None):
    """Run the model over every batch and compute classification/ranking metrics.

    Args:
        model: torch module called as ``model(board, additional, legal_mask)``
            and returning ``(move_probs, move_logits)``.
        test_dataset: iterable of batches (despite the name, a DataLoader is
            what the notebook passes). Each batch is a dict with the keys
            'board', 'additional', 'target_move' and 'legal_moves_mask'.
        device: device the inputs are moved to before the forward pass.
            Defaults to the device of the model's first parameter (CPU if the
            model has no parameters), so inputs always match the model.

    Returns:
        dict mapping metric name to value: weighted precision/recall/F1,
        accuracy, balanced accuracy, Cohen's kappa, recall_top{1,3,5,10},
        precision_top{3,5,10} and mrr.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()

    all_pred = []
    all_targets = []
    all_prob = []

    for batch in tqdm.tqdm(test_dataset, desc="Test"):
        # Inputs must live on the same device as the model; the target is only
        # used on the host side, so it is left where it is.
        board = batch['board'].to(device)
        add = batch['additional'].to(device)
        target = batch['target_move']
        legal_mask = batch['legal_moves_mask'].to(device)

        # Only the probabilities are needed: no loss is reported by this
        # function, so none is computed.
        move_probs, _ = model(board, add, legal_mask)
        preds = move_probs.argmax(dim=1)

        all_pred.append(preds.cpu().numpy())
        all_targets.append(target.cpu().numpy())
        all_prob.append(move_probs.cpu().numpy())

    all_prob = np.concatenate(all_prob)
    all_pred = np.concatenate(all_pred)
    all_targets = np.concatenate(all_targets)

    def metrics_solve(y_true, y_pred, y_probs):
        """Compute all metrics from true labels, top-1 predictions and scores."""
        metrics = {}
        # zero_division=0 scores classes that were never predicted as 0 instead
        # of emitting a warning.
        metrics["precision"] = precision_score(y_true, y_pred, average='weighted', zero_division=0)
        metrics["recall"] = recall_score(y_true, y_pred, average='weighted', zero_division=0)
        metrics["f1"] = f1_score(y_true, y_pred, average='weighted', zero_division=0)
        metrics["accuracy"] = (y_true == y_pred).mean()

        metrics["balanced_accuracy"] = balanced_accuracy_score(y_true, y_pred)
        metrics["cohens_kappa"] = cohen_kappa_score(y_true, y_pred)

        # Diagnostic output: which classes occur in the targets, and the shape
        # of the score matrix.
        actual_classes = np.unique(y_true)
        print(f"Классы в y_true: {actual_classes}")
        print(f"y_probs: {y_probs.shape}")
        for k in [1, 3, 5, 10]:
            metrics[f"recall_top{k}"] = top_k_recall(y_true, y_probs, k=k)

        for k in [3, 5, 10]:
            metrics[f"precision_top{k}"] = top_k_precision(y_true, y_probs, k=k)

        metrics["mrr"] = mean_reciprocal_rank(y_true, y_probs)
        return metrics

    return metrics_solve(all_targets, all_pred, all_prob)

# Reload the checkpoint written by the training loop and print every metric
# with four decimals.
# NOTE: despite the keyword name `test_dataset`, a DataLoader is passed here,
# so for_metrics iterates over batches rather than single samples.
model.load_state_dict(torch.load("best_chess_model.pt", map_location=device))
metrics = for_metrics(model, test_dataset=test_loader)
for name, value in metrics.items():
    print(f"{name}: {value:.4f}")


Test: 100%|██████████| 53/53 [02:33<00:00,  2.90s/it]


Классы в y_true: [   0    1    2 ... 4205 4206 4207]
y_probs: (26855, 4208)
precision: 0.3818
recall: 0.3739
f1: 0.3627
accuracy: 0.3739
balanced_accuracy: 0.2415
cohens_kappa: 0.3716
recall_top1: 0.3739
recall_top3: 0.6119
recall_top5: 0.7205
recall_top10: 0.8566
precision_top3: 0.2040
precision_top5: 0.1441
precision_top10: 0.0857
mrr: 0.5307


In [15]:
@torch.no_grad()
def test_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    total_acc1 = 0
    total_acc3 = 0
    total_acc5 = 0

    for batch in tqdm.tqdm(dataloader, desc="Validation"):
        board = batch['board'].to(device)
        additional = batch['additional'].to(device)
        target = batch['target_move'].to(device)
        legal_mask = batch['legal_moves_mask'].to(device)

        move_probs, move_logits = model(board, additional, legal_mask)
        loss = criterion(move_logits, target, legal_mask)

        preds_top1 = move_logits.argmax(dim=1)
        acc1 = (preds_top1 == target).float().mean()
        
        # Top-3 accuracy
        top3 = move_logits.topk(3, dim=1).indices
        acc3 = (top3 == target.unsqueeze(1)).any(dim=1).float().mean()
        
        # Top-5 accuracy
        top5 = move_logits.topk(5, dim=1).indices
        acc5 = (top5 == target.unsqueeze(1)).any(dim=1).float().mean()

        total_loss += loss.item()
        total_acc1 += acc1.item()
        total_acc3 += acc3.item()
        total_acc5 += acc5.item()

    num_batches = len(dataloader)
    metrics = {
        "loss": total_loss / num_batches,
        "acc1": total_acc1 / num_batches,
        "acc3": total_acc3 / num_batches,
        "acc5": total_acc5 / num_batches}
    
    return metrics


# Evaluate the current model on the held-out loader and print the resulting
# metrics dict (loss plus top-1/3/5 accuracy).
metrics_final = test_epoch(model, test_loader, criterion, device)
print(metrics_final)

Validation: 100%|██████████| 53/53 [02:29<00:00,  2.82s/it]

{'loss': 2.1513918368321545, 'acc1': 0.37423888242469644, 'acc3': 0.6123034909086408, 'acc5': 0.7209277434169121}



