Refined loss function
This refines the loss function to the form used for the new master net in official-stockfish/Stockfish#4100.

The new loss function uses the expected game score to learn, making the learning more sensitive to scores between loss and draw, and between draw and win.

This is most visible for smaller values of the scaling parameters, but the current values have been optimized.

It also introduces param_index for simpler exploration of parameters, i.e. simple parameter scans.
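Concretely, both the network output and the search score are mapped to an expected game score (win probability plus half of the draw probability) through a pair of sigmoids with an offset, and the loss compares those expectations. A minimal standalone sketch of that mapping, using the offset (270) and one scaling value (360) that appear in the model.py diff below; the function name and the example scores are illustrative only:

import torch

def expected_game_score(score, scale=360.0, offset=270.0):
    # 0.5 * (1 + P(win) - P(loss)) = P(win) + 0.5 * P(draw)
    p_win = torch.sigmoid((score - offset) / scale)
    p_loss = torch.sigmoid((-score - offset) / scale)
    return 0.5 * (1.0 + p_win - p_loss)

# made-up scores in internal units: clear loss, near-draw, clear win
print(expected_game_score(torch.tensor([-600.0, 10.0, 600.0])))
# prints roughly [0.18, 0.51, 0.82] expected points for the three example scores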
vondele committed Jul 5, 2022
1 parent 50eed1c commit 309ab0c
Showing 2 changed files with 28 additions and 11 deletions.
model.py: 33 changes (23 additions & 10 deletions)
@@ -132,7 +132,7 @@ class NNUE(pl.LightningModule):
   lr - the initial learning rate
   """
-  def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800, gamma=0.992, lr=8.75e-4, num_psqt_buckets=8, num_ls_buckets=8):
+  def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800, gamma=0.992, lr=8.75e-4, param_index=0, num_psqt_buckets=8, num_ls_buckets=8):
     super(NNUE, self).__init__()
     self.num_psqt_buckets = num_psqt_buckets
     self.num_ls_buckets = num_ls_buckets
@@ -144,6 +144,7 @@ def __init__(self, feature_set, start_lambda=1.0, end_lambda=1.0, max_epoch=800,
     self.max_epoch = max_epoch
     self.gamma = gamma
     self.lr = lr
+    self.param_index = param_index
 
     self.nnue2score = 600.0
     self.weight_scale_hidden = 64.0
@@ -292,19 +293,31 @@ def step_(self, batch, batch_idx, loss_type):
 
     us, them, white_indices, white_values, black_indices, black_values, outcome, score, psqt_indices, layer_stack_indices = batch
 
-    # 600 is the kPonanzaConstant scaling factor needed to convert the training net output to a score.
-    # This needs to match the value used in the serializer
-    in_scaling = 410
-    out_scaling = 361
+    # win_rate_model a, b in internal units
+    counter=0
+    for ins in [-1, 0, 1]:
+      for outs in [-1, 0, 1]:
+        if counter == self.param_index:
+          in_scaling = 360 + ins * 20
+          out_scaling = 360 + outs * 20
+        counter = counter + 1
 
-    q = (self(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * self.nnue2score / out_scaling).sigmoid()
-    t = outcome
-    p = (score / in_scaling).sigmoid()
+    offset = 270
+
+    scorenet = self(us, them, white_indices, white_values, black_indices, black_values, psqt_indices, layer_stack_indices) * self.nnue2score
+    q = ( scorenet - offset) / in_scaling
+    qm = (-scorenet - offset) / in_scaling
+    qf = 0.5 * (1.0 + q.sigmoid() - qm.sigmoid()) # estimated match result
+
+    p = ( score - offset) / out_scaling
+    pm = (-score - offset) / out_scaling
+    pf = 0.5 * (1.0 + p.sigmoid() - pm.sigmoid())
 
+    t = outcome
     actual_lambda = self.start_lambda + (self.end_lambda - self.start_lambda) * (self.current_epoch / self.max_epoch)
-    pt = p * actual_lambda + t * (1.0 - actual_lambda)
+    pt = pf * actual_lambda + t * (1.0 - actual_lambda)
 
-    loss = torch.pow(torch.abs(pt - q), 2.6).mean()
+    loss = torch.pow(torch.abs(pt - qf), 2.6).mean()
 
     self.log(loss_type, loss)
 
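The nested loop in step_() above maps param_index to one of nine (in_scaling, out_scaling) pairs on a 3x3 grid around 360: the default param_index of 0 selects (340, 340), index 4 selects (360, 360), and values outside 0 to 8 would leave the scalings unset. A small standalone sketch of that mapping (the printing is only for illustration):

# enumerate the (in_scaling, out_scaling) pair picked for each param_index,
# mirroring the nested loop in model.py step_()
pairs = [(360 + ins * 20, 360 + outs * 20) for ins in [-1, 0, 1] for outs in [-1, 0, 1]]
for idx, (in_scaling, out_scaling) in enumerate(pairs):
    print(f"param_index={idx}: in_scaling={in_scaling}, out_scaling={out_scaling}")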
train.py: 6 changes (5 additions & 1 deletion)
@@ -56,6 +56,7 @@ def main():
   parser.add_argument("--save-last-network", type=str2bool, default=True, dest='save_last_network', help="Whether to always save the last produced network.")
   parser.add_argument("--epoch-size", type=int, default=100000000, dest='epoch_size', help="Number of positions per epoch.")
   parser.add_argument("--validation-size", type=int, default=1000000, dest='validation_size', help="Number of positions per validation step.")
+  parser.add_argument("--param-index", type=int, default=0, dest='param_index', help="Indexing for parameter scans.")
   features.add_argparse_args(parser)
   args = parser.parse_args()
 
@@ -79,7 +80,8 @@ def main():
       max_epoch=max_epoch,
       end_lambda=end_lambda,
       gamma=args.gamma,
-      lr=args.lr
+      lr=args.lr,
+      param_index=args.param_index
     )
   else:
     nnue = torch.load(args.resume_from_model)
@@ -91,6 +93,7 @@ def main():
     # from .pt the optimizer is only created after the training is started
     nnue.gamma = args.gamma
     nnue.lr = args.lr
+    nnue.param_index=args.param_index
 
   print("Feature set: {}".format(feature_set.name))
   print("Num real features: {}".format(feature_set.num_real_features))
@@ -110,6 +113,7 @@
   print('Smart fen skipping: {}'.format(not args.no_smart_fen_skipping))
   print('WLD fen skipping: {}'.format(not args.no_wld_fen_skipping))
   print('Random fen skipping: {}'.format(args.random_fen_skipping))
+  print('Param index: {}'.format(args.param_index))
 
   if args.threads > 0:
     print('limiting torch to {} threads.'.format(args.threads))
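With the new flag in train.py, a parameter scan is simply one training run per param_index value. A rough driver sketch in Python; the dataset paths and any other train.py options are placeholders for an existing setup, and only --param-index comes from this commit:

import subprocess

# launch one training run per param_index; 0-8 covers the 3x3 scaling grid in model.py
for idx in range(9):
    subprocess.run(
        ["python", "train.py", "train_data.binpack", "val_data.binpack",
         "--param-index", str(idx)],
        check=True,
    )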
