In [3]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import tqdm
import inspect
import logging

from models.teacher import Teacher
from models.configuration_teacher import TeacherConfig
from data import CoTDataset, CoTDataCollator, extract_answer

from utils import get_sep_position

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def save_model(model, tokenizer, model_dir):
    print ('saving', model_dir)
    os.makedirs(model_dir, exist_ok=True)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, max_new_tokens):
    teacher.eval()
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        labels = batch['labels_all'].to(device)
        # Remove answer part
        sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
        input_ids = input_ids_all[:, :sep_positions.max()+1]
        batch_size = input_ids.shape[0]
        with ctx:
            outputs = teacher.compute_loss(input_ids=input_ids_all, labels=labels)
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate
        beam_output = teacher.generate(
            input_ids=input_ids,
            max_new_tokens=max_new_tokens,
        )
        # Evaluate
        #import pdb; pdb.set_trace()
        for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_all_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl


    



2024-01-10 10:49:08.880576: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-01-10 10:49:08.902826: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [7]:
teacher_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    save_model="../train_models/4_by_4_mult/gpt2/teacher",
    max_new_tokens=128,
    base_model='gpt2',
    epochs=1,
    batch_size=16,
    lr=5e-5,
    max_grad_norm=1.0,
    
)

from types import SimpleNamespace

args = SimpleNamespace(**teacher_trainer_args)

# parser = argparse.ArgumentParser()
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=1)
# parser.add_argument('--batch_size', type=int, default=5e-5)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# args = parser.parse_args()

In [8]:
print (args)

dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Student 
config = TeacherConfig(base_model=args.base_model)
teacher = Teacher(config).to(device).to(ptdtype)

# Load data
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer
trainable_params = list(teacher.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)



namespace(debug=False, train_path='../data/4_by_4_mult/train.txt', val_path='../data/4_by_4_mult/valid.txt', save_model='../train_models/4_by_4_mult/gpt2/teacher', max_new_tokens=128, base_model='gpt2', epochs=1, batch_size=16, lr=5e-05, max_grad_norm=1.0)
torch.float32 float32 cuda
Creating features from dataset file at ../data/4_by_4_mult/train.txt
tgt_avg:  49.0
src_avg:  10.0
ratios:  0.20408163265306123
tgt_avg:  13.0
src_avg:  10.0
ratios:  0.7692307692307693
 1 3 3 8 * 5 1 0 5 <|endoftext|> 5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4 <|endoftext|> #### 5 6 9 9 7 7 1 4 <|endoftext|>
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 642, 642, 718, 352, 604, 1343, 657, 352, 513, 513, 807, 657, 357, 642, 718, 860, 604, 362, 352, 1267, 1343, 657, 657, 657, 657, 657, 657, 657, 357, 642, 718, 860, 604, 362, 352, 657, 1267, 1343, 657, 657, 657, 642, 642, 718, 352, 604, 220, 50256]
[-100, -100, -100, -100, -100, -100, -100,

In [19]:
teacher.train()

# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")
    teacher.train()
    for batch in tqdm.tqdm(train_dataloader):
        input_ids = batch['input_ids_all'].to(device)
        labels = batch['labels_all'].to(device)
        with ctx:
            outputs = teacher.compute_loss(input_ids=input_ids, labels=labels)
#         break
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
        step += 1
#     break
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    teacher.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

Epoch 0


  0%|                                       | 1/50500 [00:00<1:48:34,  7.75it/s]

Step: 0. PPL: 37.15435028076172. Token Accuracy: 0.20104166865348816


  0%|                                     | 103/50500 [00:07<1:01:19, 13.70it/s]

Step: 100. PPL: 3.1745927333831787. Token Accuracy: 0.6177083253860474


  0%|▏                                    | 203/50500 [00:14<1:01:28, 13.63it/s]

Step: 200. PPL: 2.94986891746521. Token Accuracy: 0.6104166507720947


  1%|▏                                    | 303/50500 [00:22<1:01:16, 13.65it/s]

Step: 300. PPL: 2.937230110168457. Token Accuracy: 0.606249988079071


  1%|▎                                    | 403/50500 [00:29<1:01:00, 13.69it/s]

Step: 400. PPL: 2.5423779487609863. Token Accuracy: 0.6510416865348816


  1%|▎                                    | 503/50500 [00:36<1:00:55, 13.68it/s]

Step: 500. PPL: 2.6002774238586426. Token Accuracy: 0.6322916746139526


  1%|▍                                    | 603/50500 [00:44<1:02:50, 13.23it/s]

Step: 600. PPL: 2.629269599914551. Token Accuracy: 0.6385416388511658


  1%|▌                                    | 703/50500 [00:51<1:00:48, 13.65it/s]

Step: 700. PPL: 2.1965105533599854. Token Accuracy: 0.7020833492279053


  2%|▌                                    | 803/50500 [00:58<1:00:45, 13.63it/s]

Step: 800. PPL: 2.2679197788238525. Token Accuracy: 0.6979166865348816


  2%|▋                                    | 903/50500 [01:06<1:00:35, 13.64it/s]

Step: 900. PPL: 1.960226058959961. Token Accuracy: 0.7489583492279053


  2%|▋                                   | 1003/50500 [01:13<1:00:45, 13.58it/s]

Step: 1000. PPL: 1.9397300481796265. Token Accuracy: 0.7541666626930237


  2%|▊                                   | 1103/50500 [01:20<1:00:36, 13.58it/s]

Step: 1100. PPL: 1.8191827535629272. Token Accuracy: 0.7791666388511658


  2%|▊                                   | 1203/50500 [01:28<1:00:38, 13.55it/s]

Step: 1200. PPL: 1.674165964126587. Token Accuracy: 0.8083333373069763


  3%|▉                                   | 1303/50500 [01:35<1:00:27, 13.56it/s]

Step: 1300. PPL: 1.5211948156356812. Token Accuracy: 0.8500000238418579


  3%|█                                   | 1403/50500 [01:43<1:00:14, 13.58it/s]

Step: 1400. PPL: 1.4372514486312866. Token Accuracy: 0.8656250238418579


  3%|█                                   | 1503/50500 [01:50<1:00:08, 13.58it/s]

Step: 1500. PPL: 1.4768505096435547. Token Accuracy: 0.8687499761581421


  3%|█▏                                  | 1603/50500 [01:57<1:00:06, 13.56it/s]

Step: 1600. PPL: 1.3135473728179932. Token Accuracy: 0.8958333134651184


  3%|█▎                                    | 1703/50500 [02:05<59:59, 13.56it/s]

Step: 1700. PPL: 1.3834584951400757. Token Accuracy: 0.9072916507720947


  4%|█▎                                    | 1803/50500 [02:12<59:50, 13.56it/s]

Step: 1800. PPL: 1.2979103326797485. Token Accuracy: 0.909375011920929


  4%|█▍                                    | 1903/50500 [02:19<59:49, 13.54it/s]

Step: 1900. PPL: 1.2793660163879395. Token Accuracy: 0.9125000238418579


  4%|█▌                                    | 2003/50500 [02:27<59:40, 13.54it/s]

Step: 2000. PPL: 1.2352468967437744. Token Accuracy: 0.9291666746139526


  4%|█▌                                    | 2103/50500 [02:34<59:50, 13.48it/s]

Step: 2100. PPL: 1.166247010231018. Token Accuracy: 0.9447916746139526


  4%|█▋                                    | 2203/50500 [02:42<59:32, 13.52it/s]

Step: 2200. PPL: 1.1434860229492188. Token Accuracy: 0.9520833492279053


  5%|█▋                                    | 2303/50500 [02:49<59:16, 13.55it/s]

Step: 2300. PPL: 1.121244192123413. Token Accuracy: 0.9604166746139526


  5%|█▊                                    | 2403/50500 [02:56<59:34, 13.45it/s]

Step: 2400. PPL: 1.0885874032974243. Token Accuracy: 0.9697916507720947


  5%|█▉                                    | 2503/50500 [03:04<59:04, 13.54it/s]

Step: 2500. PPL: 1.152310848236084. Token Accuracy: 0.9510416388511658


  5%|█▉                                    | 2603/50500 [03:11<59:05, 13.51it/s]

Step: 2600. PPL: 1.0695154666900635. Token Accuracy: 0.9750000238418579


  5%|██                                    | 2703/50500 [03:19<58:57, 13.51it/s]

Step: 2700. PPL: 1.061673879623413. Token Accuracy: 0.9802083373069763


  6%|██                                    | 2803/50500 [03:26<59:12, 13.43it/s]

Step: 2800. PPL: 1.0623464584350586. Token Accuracy: 0.9802083373069763


  6%|██▏                                   | 2903/50500 [03:33<58:40, 13.52it/s]

Step: 2900. PPL: 1.050493836402893. Token Accuracy: 0.9833333492279053


  6%|██▎                                   | 3003/50500 [03:41<58:28, 13.54it/s]

Step: 3000. PPL: 1.0370211601257324. Token Accuracy: 0.9895833134651184


  6%|██▎                                   | 3103/50500 [03:48<58:24, 13.52it/s]

Step: 3100. PPL: 1.0424727201461792. Token Accuracy: 0.9854166507720947


  6%|██▍                                   | 3203/50500 [03:56<58:15, 13.53it/s]

Step: 3200. PPL: 1.0260978937149048. Token Accuracy: 0.9906250238418579


  7%|██▍                                   | 3303/50500 [04:03<58:08, 13.53it/s]

Step: 3300. PPL: 1.0230680704116821. Token Accuracy: 0.9947916865348816


  7%|██▌                                   | 3403/50500 [04:10<58:02, 13.52it/s]

Step: 3400. PPL: 1.0414068698883057. Token Accuracy: 0.9885416626930237


  7%|██▋                                   | 3503/50500 [04:18<57:53, 13.53it/s]

Step: 3500. PPL: 1.0235401391983032. Token Accuracy: 0.9916666746139526


  7%|██▋                                   | 3603/50500 [04:25<57:43, 13.54it/s]

Step: 3600. PPL: 1.0132077932357788. Token Accuracy: 0.9958333373069763


  7%|██▊                                   | 3703/50500 [04:33<57:36, 13.54it/s]

Step: 3700. PPL: 1.0150474309921265. Token Accuracy: 0.9947916865348816


  8%|██▊                                   | 3803/50500 [04:40<57:36, 13.51it/s]

Step: 3800. PPL: 1.0127424001693726. Token Accuracy: 0.9947916865348816


  8%|██▉                                   | 3903/50500 [04:47<58:12, 13.34it/s]

Step: 3900. PPL: 1.0173590183258057. Token Accuracy: 0.9947916865348816


  8%|███                                   | 4003/50500 [04:55<57:28, 13.48it/s]

Step: 4000. PPL: 1.0123356580734253. Token Accuracy: 0.9958333373069763


  8%|███                                   | 4103/50500 [05:02<57:17, 13.50it/s]

Step: 4100. PPL: 1.014103889465332. Token Accuracy: 0.9958333373069763


  8%|███▏                                  | 4203/50500 [05:10<58:10, 13.26it/s]

Step: 4200. PPL: 1.0057507753372192. Token Accuracy: 1.0


  9%|███▏                                  | 4303/50500 [05:17<56:57, 13.52it/s]

Step: 4300. PPL: 1.0175501108169556. Token Accuracy: 0.9979166388511658


  9%|███▎                                  | 4403/50500 [05:25<57:01, 13.47it/s]

Step: 4400. PPL: 1.0135481357574463. Token Accuracy: 0.9958333373069763


  9%|███▍                                  | 4503/50500 [05:32<56:41, 13.52it/s]

Step: 4500. PPL: 1.0089991092681885. Token Accuracy: 0.9979166388511658


  9%|███▍                                  | 4603/50500 [05:40<56:33, 13.52it/s]

Step: 4600. PPL: 1.0092486143112183. Token Accuracy: 0.9958333373069763


  9%|███▌                                  | 4703/50500 [05:47<56:27, 13.52it/s]

Step: 4700. PPL: 1.0039559602737427. Token Accuracy: 1.0


 10%|███▌                                  | 4803/50500 [05:54<57:16, 13.30it/s]

Step: 4800. PPL: 1.0118281841278076. Token Accuracy: 0.996874988079071


 10%|███▋                                  | 4903/50500 [06:02<56:17, 13.50it/s]

Step: 4900. PPL: 1.0044519901275635. Token Accuracy: 0.9979166388511658


 10%|███▊                                  | 5003/50500 [06:09<56:10, 13.50it/s]

Step: 5000. PPL: 1.0084184408187866. Token Accuracy: 0.9958333373069763


 10%|███▊                                  | 5103/50500 [06:17<55:56, 13.53it/s]

Step: 5100. PPL: 1.010297179222107. Token Accuracy: 0.9947916865348816


 10%|███▉                                  | 5203/50500 [06:24<55:52, 13.51it/s]

Step: 5200. PPL: 1.0121557712554932. Token Accuracy: 0.996874988079071


 11%|███▉                                  | 5303/50500 [06:31<55:41, 13.53it/s]

Step: 5300. PPL: 1.0108792781829834. Token Accuracy: 0.9979166388511658


 11%|████                                  | 5403/50500 [06:39<55:37, 13.51it/s]

Step: 5400. PPL: 1.008010983467102. Token Accuracy: 0.9958333373069763


 11%|████▏                                 | 5503/50500 [06:46<55:34, 13.50it/s]

Step: 5500. PPL: 1.0081182718276978. Token Accuracy: 0.9979166388511658


 11%|████▏                                 | 5603/50500 [06:54<55:21, 13.52it/s]

Step: 5600. PPL: 1.0029922723770142. Token Accuracy: 1.0


 11%|████▎                                 | 5703/50500 [07:01<55:13, 13.52it/s]

Step: 5700. PPL: 1.002779483795166. Token Accuracy: 0.9989583492279053


 11%|████▎                                 | 5803/50500 [07:08<55:06, 13.52it/s]

Step: 5800. PPL: 1.0087131261825562. Token Accuracy: 0.9979166388511658


 12%|████▍                                 | 5903/50500 [07:16<55:07, 13.48it/s]

Step: 5900. PPL: 1.0024470090866089. Token Accuracy: 0.9989583492279053


 12%|████▌                                 | 6003/50500 [07:23<54:56, 13.50it/s]

Step: 6000. PPL: 1.0029652118682861. Token Accuracy: 1.0


 12%|████▌                                 | 6103/50500 [07:31<54:49, 13.49it/s]

Step: 6100. PPL: 1.0033838748931885. Token Accuracy: 0.9989583492279053


 12%|████▋                                 | 6203/50500 [07:38<54:37, 13.52it/s]

Step: 6200. PPL: 1.0030676126480103. Token Accuracy: 1.0


 12%|████▋                                 | 6303/50500 [07:45<54:32, 13.50it/s]

Step: 6300. PPL: 1.0014922618865967. Token Accuracy: 1.0


 13%|████▊                                 | 6403/50500 [07:53<54:49, 13.41it/s]

Step: 6400. PPL: 1.004403829574585. Token Accuracy: 0.9979166388511658


 13%|████▉                                 | 6503/50500 [08:00<54:16, 13.51it/s]

Step: 6500. PPL: 1.0431616306304932. Token Accuracy: 0.9947916865348816


 13%|████▉                                 | 6603/50500 [08:08<54:08, 13.51it/s]

Step: 6600. PPL: 1.0338977575302124. Token Accuracy: 0.996874988079071


 13%|█████                                 | 6703/50500 [08:15<54:00, 13.51it/s]

Step: 6700. PPL: 1.001433253288269. Token Accuracy: 1.0


 13%|█████                                 | 6803/50500 [08:23<53:54, 13.51it/s]

Step: 6800. PPL: 1.000759243965149. Token Accuracy: 1.0


 14%|█████▏                                | 6903/50500 [08:30<53:44, 13.52it/s]

Step: 6900. PPL: 1.0020865201950073. Token Accuracy: 1.0


 14%|█████▎                                | 7003/50500 [08:37<53:41, 13.50it/s]

Step: 7000. PPL: 1.0041906833648682. Token Accuracy: 0.9989583492279053


 14%|█████▎                                | 7103/50500 [08:45<53:38, 13.48it/s]

Step: 7100. PPL: 1.0011993646621704. Token Accuracy: 1.0


 14%|█████▍                                | 7203/50500 [08:52<53:28, 13.50it/s]

Step: 7200. PPL: 1.0197738409042358. Token Accuracy: 0.996874988079071


 14%|█████▍                                | 7303/50500 [09:00<53:17, 13.51it/s]

Step: 7300. PPL: 1.0195097923278809. Token Accuracy: 0.9989583492279053


 15%|█████▌                                | 7403/50500 [09:07<53:09, 13.51it/s]

Step: 7400. PPL: 1.0010989904403687. Token Accuracy: 1.0


 15%|█████▋                                | 7503/50500 [09:14<53:02, 13.51it/s]

Step: 7500. PPL: 1.0016409158706665. Token Accuracy: 1.0


 15%|█████▋                                | 7603/50500 [09:22<52:54, 13.51it/s]

Step: 7600. PPL: 1.0042002201080322. Token Accuracy: 0.9989583492279053


 15%|█████▊                                | 7703/50500 [09:29<52:47, 13.51it/s]

Step: 7700. PPL: 1.0016028881072998. Token Accuracy: 1.0


 15%|█████▊                                | 7803/50500 [09:37<52:41, 13.51it/s]

Step: 7800. PPL: 1.0011553764343262. Token Accuracy: 1.0


 16%|█████▉                                | 7903/50500 [09:44<52:31, 13.52it/s]

Step: 7900. PPL: 1.004339337348938. Token Accuracy: 0.9989583492279053


 16%|██████                                | 8003/50500 [09:51<52:25, 13.51it/s]

Step: 8000. PPL: 1.0034946203231812. Token Accuracy: 0.9979166388511658


 16%|██████                                | 8103/50500 [09:59<52:18, 13.51it/s]

Step: 8100. PPL: 1.0052851438522339. Token Accuracy: 0.9979166388511658


 16%|██████▏                               | 8203/50500 [10:06<52:06, 13.53it/s]

Step: 8200. PPL: 1.0097767114639282. Token Accuracy: 0.9989583492279053


 16%|██████▏                               | 8303/50500 [10:14<52:00, 13.52it/s]

Step: 8300. PPL: 1.0006855726242065. Token Accuracy: 1.0


 17%|██████▎                               | 8403/50500 [10:21<51:56, 13.51it/s]

Step: 8400. PPL: 1.0014439821243286. Token Accuracy: 1.0


 17%|██████▍                               | 8503/50500 [10:28<51:50, 13.50it/s]

Step: 8500. PPL: 1.0007047653198242. Token Accuracy: 1.0


 17%|██████▍                               | 8603/50500 [10:36<51:41, 13.51it/s]

Step: 8600. PPL: 1.0040947198867798. Token Accuracy: 0.9989583492279053


 17%|██████▌                               | 8703/50500 [10:43<51:30, 13.52it/s]

Step: 8700. PPL: 1.000863790512085. Token Accuracy: 1.0


 17%|██████▌                               | 8803/50500 [10:51<51:24, 13.52it/s]

Step: 8800. PPL: 1.0009558200836182. Token Accuracy: 1.0


 18%|██████▋                               | 8903/50500 [10:58<51:20, 13.50it/s]

Step: 8900. PPL: 1.0026354789733887. Token Accuracy: 0.9989583492279053


 18%|██████▊                               | 9003/50500 [11:05<51:08, 13.52it/s]

Step: 9000. PPL: 1.0016045570373535. Token Accuracy: 0.9989583492279053


 18%|██████▊                               | 9103/50500 [11:13<51:02, 13.52it/s]

Step: 9100. PPL: 1.0029189586639404. Token Accuracy: 0.9989583492279053


 18%|██████▉                               | 9203/50500 [11:20<51:01, 13.49it/s]

Step: 9200. PPL: 1.0004907846450806. Token Accuracy: 1.0


 18%|███████                               | 9303/50500 [11:28<51:09, 13.42it/s]

Step: 9300. PPL: 1.002008080482483. Token Accuracy: 1.0


 19%|███████                               | 9403/50500 [11:35<50:39, 13.52it/s]

Step: 9400. PPL: 1.0035741329193115. Token Accuracy: 0.9979166388511658


 19%|███████▏                              | 9503/50500 [11:42<50:33, 13.52it/s]

Step: 9500. PPL: 1.00249183177948. Token Accuracy: 0.9989583492279053


 19%|███████▏                              | 9603/50500 [11:50<50:27, 13.51it/s]

Step: 9600. PPL: 1.0073771476745605. Token Accuracy: 0.996874988079071


 19%|███████▎                              | 9703/50500 [11:57<50:17, 13.52it/s]

Step: 9700. PPL: 1.0040955543518066. Token Accuracy: 0.9989583492279053


 19%|███████▍                              | 9803/50500 [12:05<50:14, 13.50it/s]

Step: 9800. PPL: 1.00263512134552. Token Accuracy: 0.9989583492279053


 20%|███████▍                              | 9903/50500 [12:12<50:02, 13.52it/s]

Step: 9900. PPL: 1.0026850700378418. Token Accuracy: 1.0


 20%|███████▎                             | 10003/50500 [12:20<49:59, 13.50it/s]

Step: 10000. PPL: 1.0003700256347656. Token Accuracy: 1.0


 20%|███████▍                             | 10103/50500 [12:27<49:48, 13.52it/s]

Step: 10100. PPL: 1.001907467842102. Token Accuracy: 0.9989583492279053


 20%|███████▍                             | 10203/50500 [12:34<49:44, 13.50it/s]

Step: 10200. PPL: 1.001348853111267. Token Accuracy: 0.9989583492279053


 20%|███████▌                             | 10303/50500 [12:42<49:37, 13.50it/s]

Step: 10300. PPL: 1.0021265745162964. Token Accuracy: 0.9989583492279053


 21%|███████▌                             | 10403/50500 [12:49<49:24, 13.53it/s]

Step: 10400. PPL: 1.001668930053711. Token Accuracy: 1.0


 21%|███████▋                             | 10503/50500 [12:57<49:23, 13.50it/s]

Step: 10500. PPL: 1.0013036727905273. Token Accuracy: 1.0


 21%|███████▊                             | 10603/50500 [13:04<49:11, 13.52it/s]

Step: 10600. PPL: 1.0079923868179321. Token Accuracy: 0.9979166388511658


 21%|███████▊                             | 10703/50500 [13:11<49:04, 13.51it/s]

Step: 10700. PPL: 1.0006282329559326. Token Accuracy: 1.0


 21%|███████▉                             | 10803/50500 [13:19<48:56, 13.52it/s]

Step: 10800. PPL: 1.000241756439209. Token Accuracy: 1.0


 22%|███████▉                             | 10903/50500 [13:26<48:58, 13.48it/s]

Step: 10900. PPL: 1.002284288406372. Token Accuracy: 0.9989583492279053


 22%|████████                             | 11003/50500 [13:34<48:43, 13.51it/s]

Step: 11000. PPL: 1.0004208087921143. Token Accuracy: 1.0


 22%|████████▏                            | 11103/50500 [13:41<48:37, 13.50it/s]

Step: 11100. PPL: 1.000526785850525. Token Accuracy: 1.0


 22%|████████▏                            | 11203/50500 [13:48<48:25, 13.52it/s]

Step: 11200. PPL: 1.000959873199463. Token Accuracy: 1.0


 22%|████████▎                            | 11303/50500 [13:56<48:19, 13.52it/s]

Step: 11300. PPL: 1.000246286392212. Token Accuracy: 1.0


 23%|████████▎                            | 11403/50500 [14:03<48:13, 13.51it/s]

Step: 11400. PPL: 1.0001680850982666. Token Accuracy: 1.0


 23%|████████▍                            | 11503/50500 [14:11<48:04, 13.52it/s]

Step: 11500. PPL: 1.002304196357727. Token Accuracy: 0.9989583492279053


 23%|████████▌                            | 11603/50500 [14:18<47:55, 13.53it/s]

Step: 11600. PPL: 1.0003401041030884. Token Accuracy: 1.0


 23%|████████▌                            | 11703/50500 [14:25<47:49, 13.52it/s]

Step: 11700. PPL: 1.0019042491912842. Token Accuracy: 0.9989583492279053


 23%|████████▋                            | 11803/50500 [14:33<47:46, 13.50it/s]

Step: 11800. PPL: 1.0061089992523193. Token Accuracy: 0.9989583492279053


 24%|████████▋                            | 11903/50500 [14:40<47:41, 13.49it/s]

Step: 11900. PPL: 1.0009773969650269. Token Accuracy: 1.0


 24%|████████▊                            | 12003/50500 [14:48<47:32, 13.50it/s]

Step: 12000. PPL: 1.0042423009872437. Token Accuracy: 0.9979166388511658


 24%|████████▊                            | 12103/50500 [14:55<47:22, 13.51it/s]

Step: 12100. PPL: 1.0002365112304688. Token Accuracy: 1.0


 24%|████████▉                            | 12203/50500 [15:03<48:14, 13.23it/s]

Step: 12200. PPL: 1.0003551244735718. Token Accuracy: 1.0


 24%|█████████                            | 12303/50500 [15:10<47:13, 13.48it/s]

Step: 12300. PPL: 1.0002946853637695. Token Accuracy: 1.0


 25%|█████████                            | 12403/50500 [15:17<47:03, 13.49it/s]

Step: 12400. PPL: 1.0107393264770508. Token Accuracy: 0.996874988079071


 25%|█████████▏                           | 12503/50500 [15:25<46:49, 13.52it/s]

Step: 12500. PPL: 1.0040730237960815. Token Accuracy: 0.9979166388511658


 25%|█████████▏                           | 12603/50500 [15:32<46:39, 13.54it/s]

Step: 12600. PPL: 1.0035065412521362. Token Accuracy: 0.9979166388511658


 25%|█████████▎                           | 12703/50500 [15:40<46:36, 13.51it/s]

Step: 12700. PPL: 1.0008717775344849. Token Accuracy: 1.0


 25%|█████████▍                           | 12803/50500 [15:47<46:29, 13.51it/s]

Step: 12800. PPL: 1.000161051750183. Token Accuracy: 1.0


 26%|█████████▍                           | 12903/50500 [15:54<46:56, 13.35it/s]

Step: 12900. PPL: 1.0002198219299316. Token Accuracy: 1.0


 26%|█████████▌                           | 13003/50500 [16:02<46:14, 13.51it/s]

Step: 13000. PPL: 1.000314474105835. Token Accuracy: 1.0


 26%|█████████▌                           | 13103/50500 [16:09<46:05, 13.52it/s]

Step: 13100. PPL: 1.0007846355438232. Token Accuracy: 1.0


 26%|█████████▋                           | 13203/50500 [16:17<46:04, 13.49it/s]

Step: 13200. PPL: 1.0002814531326294. Token Accuracy: 1.0


 26%|█████████▋                           | 13303/50500 [16:24<45:50, 13.52it/s]

Step: 13300. PPL: 1.0004594326019287. Token Accuracy: 1.0


 27%|█████████▊                           | 13403/50500 [16:31<45:46, 13.50it/s]

Step: 13400. PPL: 1.0000988245010376. Token Accuracy: 1.0


 27%|█████████▉                           | 13503/50500 [16:39<45:34, 13.53it/s]

Step: 13500. PPL: 1.0002086162567139. Token Accuracy: 1.0


 27%|█████████▉                           | 13603/50500 [16:46<45:33, 13.50it/s]

Step: 13600. PPL: 1.0006637573242188. Token Accuracy: 1.0


 27%|██████████                           | 13703/50500 [16:54<45:22, 13.52it/s]

Step: 13700. PPL: 1.0004687309265137. Token Accuracy: 1.0


 27%|██████████                           | 13803/50500 [17:01<45:11, 13.53it/s]

Step: 13800. PPL: 1.0015027523040771. Token Accuracy: 0.9989583492279053


 28%|██████████▏                          | 13903/50500 [17:08<45:13, 13.49it/s]

Step: 13900. PPL: 1.0002405643463135. Token Accuracy: 1.0


 28%|██████████▎                          | 14003/50500 [17:16<44:58, 13.52it/s]

Step: 14000. PPL: 1.0019563436508179. Token Accuracy: 0.9989583492279053


 28%|██████████▎                          | 14103/50500 [17:23<45:05, 13.45it/s]

Step: 14100. PPL: 1.0004161596298218. Token Accuracy: 1.0


 28%|██████████▍                          | 14203/50500 [17:31<44:48, 13.50it/s]

Step: 14200. PPL: 1.000484824180603. Token Accuracy: 1.0


 28%|██████████▍                          | 14303/50500 [17:38<44:37, 13.52it/s]

Step: 14300. PPL: 1.0008885860443115. Token Accuracy: 1.0


 29%|██████████▌                          | 14403/50500 [17:46<44:31, 13.51it/s]

Step: 14400. PPL: 1.0052731037139893. Token Accuracy: 0.9979166388511658


 29%|██████████▋                          | 14503/50500 [17:53<44:36, 13.45it/s]

Step: 14500. PPL: 1.0000840425491333. Token Accuracy: 1.0


 29%|██████████▋                          | 14603/50500 [18:00<44:16, 13.51it/s]

Step: 14600. PPL: 1.0007890462875366. Token Accuracy: 1.0


 29%|██████████▊                          | 14703/50500 [18:08<44:07, 13.52it/s]

Step: 14700. PPL: 1.0010062456130981. Token Accuracy: 1.0


 29%|██████████▊                          | 14803/50500 [18:15<44:03, 13.50it/s]

Step: 14800. PPL: 1.0013166666030884. Token Accuracy: 0.9989583492279053


 30%|██████████▉                          | 14903/50500 [18:23<43:55, 13.50it/s]

Step: 14900. PPL: 1.0001945495605469. Token Accuracy: 1.0


 30%|██████████▉                          | 15003/50500 [18:30<43:50, 13.49it/s]

Step: 15000. PPL: 1.0003108978271484. Token Accuracy: 1.0


 30%|███████████                          | 15103/50500 [18:37<43:44, 13.49it/s]

Step: 15100. PPL: 1.0004475116729736. Token Accuracy: 1.0


 30%|███████████▏                         | 15203/50500 [18:45<43:37, 13.49it/s]

Step: 15200. PPL: 1.00641930103302. Token Accuracy: 0.9979166388511658


 30%|███████████▏                         | 15303/50500 [18:52<43:22, 13.52it/s]

Step: 15300. PPL: 1.0000579357147217. Token Accuracy: 1.0


 31%|███████████▎                         | 15403/50500 [19:00<43:20, 13.50it/s]

Step: 15400. PPL: 1.0124813318252563. Token Accuracy: 0.9989583492279053


 31%|███████████▎                         | 15503/50500 [19:07<43:09, 13.51it/s]

Step: 15500. PPL: 1.000404715538025. Token Accuracy: 1.0


 31%|███████████▍                         | 15603/50500 [19:14<43:06, 13.49it/s]

Step: 15600. PPL: 1.0004403591156006. Token Accuracy: 1.0


 31%|███████████▌                         | 15703/50500 [19:22<42:54, 13.52it/s]

Step: 15700. PPL: 1.001133680343628. Token Accuracy: 1.0


 31%|███████████▌                         | 15803/50500 [19:29<42:48, 13.51it/s]

Step: 15800. PPL: 1.0002082586288452. Token Accuracy: 1.0


 31%|███████████▋                         | 15903/50500 [19:37<42:42, 13.50it/s]

Step: 15900. PPL: 1.0024927854537964. Token Accuracy: 0.9989583492279053


 32%|███████████▋                         | 16003/50500 [19:44<42:35, 13.50it/s]

Step: 16000. PPL: 1.003578543663025. Token Accuracy: 0.9979166388511658


 32%|███████████▊                         | 16103/50500 [19:51<42:23, 13.52it/s]

Step: 16100. PPL: 1.0003405809402466. Token Accuracy: 1.0


 32%|███████████▊                         | 16203/50500 [19:59<42:22, 13.49it/s]

Step: 16200. PPL: 1.0018181800842285. Token Accuracy: 0.9989583492279053


 32%|███████████▉                         | 16303/50500 [20:06<42:14, 13.49it/s]

Step: 16300. PPL: 1.0001306533813477. Token Accuracy: 1.0


 32%|████████████                         | 16403/50500 [20:14<42:02, 13.52it/s]

Step: 16400. PPL: 1.0000699758529663. Token Accuracy: 1.0


 33%|████████████                         | 16503/50500 [20:21<41:52, 13.53it/s]

Step: 16500. PPL: 1.000848650932312. Token Accuracy: 1.0


 33%|████████████▏                        | 16603/50500 [20:28<41:44, 13.53it/s]

Step: 16600. PPL: 1.0002092123031616. Token Accuracy: 1.0


 33%|████████████▏                        | 16703/50500 [20:36<41:39, 13.52it/s]

Step: 16700. PPL: 1.0002540349960327. Token Accuracy: 1.0


 33%|████████████▎                        | 16803/50500 [20:43<41:33, 13.51it/s]

Step: 16800. PPL: 1.0001277923583984. Token Accuracy: 1.0


 33%|████████████▍                        | 16903/50500 [20:51<41:26, 13.51it/s]

Step: 16900. PPL: 1.0003094673156738. Token Accuracy: 1.0


 34%|████████████▍                        | 17003/50500 [20:58<41:19, 13.51it/s]

Step: 17000. PPL: 1.000335454940796. Token Accuracy: 1.0


 34%|████████████▌                        | 17103/50500 [21:06<41:12, 13.51it/s]

Step: 17100. PPL: 1.000248908996582. Token Accuracy: 1.0


 34%|████████████▌                        | 17203/50500 [21:13<41:06, 13.50it/s]

Step: 17200. PPL: 1.000472903251648. Token Accuracy: 1.0


 34%|████████████▋                        | 17303/50500 [21:20<40:54, 13.53it/s]

Step: 17300. PPL: 1.0009820461273193. Token Accuracy: 1.0


 34%|████████████▊                        | 17403/50500 [21:28<40:47, 13.52it/s]

Step: 17400. PPL: 1.0028098821640015. Token Accuracy: 0.9989583492279053


 35%|████████████▊                        | 17503/50500 [21:35<40:43, 13.51it/s]

Step: 17500. PPL: 1.0010796785354614. Token Accuracy: 1.0


 35%|████████████▉                        | 17603/50500 [21:43<40:42, 13.47it/s]

Step: 17600. PPL: 1.0004916191101074. Token Accuracy: 1.0


 35%|████████████▉                        | 17703/50500 [21:50<40:26, 13.52it/s]

Step: 17700. PPL: 1.000158429145813. Token Accuracy: 1.0


 35%|█████████████                        | 17803/50500 [21:57<40:19, 13.51it/s]

Step: 17800. PPL: 1.003365397453308. Token Accuracy: 0.9979166388511658


 35%|█████████████                        | 17903/50500 [22:05<40:11, 13.52it/s]

Step: 17900. PPL: 1.0009407997131348. Token Accuracy: 0.9989583492279053


 36%|█████████████▏                       | 18003/50500 [22:12<40:06, 13.51it/s]

Step: 18000. PPL: 1.0001060962677002. Token Accuracy: 1.0


 36%|█████████████▎                       | 18103/50500 [22:20<39:58, 13.51it/s]

Step: 18100. PPL: 1.000985860824585. Token Accuracy: 1.0


 36%|█████████████▎                       | 18203/50500 [22:27<39:51, 13.50it/s]

Step: 18200. PPL: 1.0004655122756958. Token Accuracy: 1.0


 36%|█████████████▍                       | 18303/50500 [22:34<39:43, 13.51it/s]

Step: 18300. PPL: 1.0001195669174194. Token Accuracy: 1.0


 36%|█████████████▍                       | 18403/50500 [22:42<39:32, 13.53it/s]

Step: 18400. PPL: 1.0005403757095337. Token Accuracy: 1.0


 37%|█████████████▌                       | 18503/50500 [22:49<39:28, 13.51it/s]

Step: 18500. PPL: 1.0036171674728394. Token Accuracy: 0.9989583492279053


 37%|█████████████▋                       | 18603/50500 [22:57<39:24, 13.49it/s]

Step: 18600. PPL: 1.0000901222229004. Token Accuracy: 1.0


 37%|█████████████▋                       | 18703/50500 [23:04<39:10, 13.53it/s]

Step: 18700. PPL: 1.0010746717453003. Token Accuracy: 0.9989583492279053


 37%|█████████████▊                       | 18803/50500 [23:11<39:05, 13.52it/s]

Step: 18800. PPL: 1.001001000404358. Token Accuracy: 1.0


 37%|█████████████▊                       | 18903/50500 [23:19<38:57, 13.52it/s]

Step: 18900. PPL: 1.0000356435775757. Token Accuracy: 1.0


 38%|█████████████▉                       | 19003/50500 [23:26<39:10, 13.40it/s]

Step: 19000. PPL: 1.0001137256622314. Token Accuracy: 1.0


 38%|█████████████▉                       | 19103/50500 [23:34<38:44, 13.51it/s]

Step: 19100. PPL: 1.0004913806915283. Token Accuracy: 1.0


 38%|██████████████                       | 19203/50500 [23:41<38:33, 13.53it/s]

Step: 19200. PPL: 1.00054132938385. Token Accuracy: 1.0


 38%|██████████████▏                      | 19303/50500 [23:49<38:28, 13.52it/s]

Step: 19300. PPL: 1.000450611114502. Token Accuracy: 1.0


 38%|██████████████▏                      | 19403/50500 [23:56<38:24, 13.50it/s]

Step: 19400. PPL: 1.0002080202102661. Token Accuracy: 1.0


 39%|██████████████▎                      | 19503/50500 [24:03<38:12, 13.52it/s]

Step: 19500. PPL: 1.0000977516174316. Token Accuracy: 1.0


 39%|██████████████▎                      | 19603/50500 [24:11<38:09, 13.50it/s]

Step: 19600. PPL: 1.0001823902130127. Token Accuracy: 1.0


 39%|██████████████▍                      | 19703/50500 [24:18<37:57, 13.52it/s]

Step: 19700. PPL: 1.0002268552780151. Token Accuracy: 1.0


 39%|██████████████▌                      | 19803/50500 [24:26<37:51, 13.52it/s]

Step: 19800. PPL: 1.0002280473709106. Token Accuracy: 1.0


 39%|██████████████▌                      | 19903/50500 [24:33<37:43, 13.52it/s]

Step: 19900. PPL: 1.0008361339569092. Token Accuracy: 0.9989583492279053


 40%|██████████████▋                      | 20003/50500 [24:40<37:37, 13.51it/s]

Step: 20000. PPL: 1.0002021789550781. Token Accuracy: 1.0


 40%|██████████████▋                      | 20103/50500 [24:48<37:30, 13.50it/s]

Step: 20100. PPL: 1.0014550685882568. Token Accuracy: 0.9989583492279053


 40%|██████████████▊                      | 20203/50500 [24:55<37:22, 13.51it/s]

Step: 20200. PPL: 1.000223994255066. Token Accuracy: 1.0


 40%|██████████████▉                      | 20303/50500 [25:03<37:13, 13.52it/s]

Step: 20300. PPL: 1.0003734827041626. Token Accuracy: 1.0


 40%|██████████████▉                      | 20403/50500 [25:10<37:05, 13.52it/s]

Step: 20400. PPL: 1.0011310577392578. Token Accuracy: 1.0


 41%|███████████████                      | 20503/50500 [25:17<36:59, 13.52it/s]

Step: 20500. PPL: 1.0002251863479614. Token Accuracy: 1.0


 41%|███████████████                      | 20603/50500 [25:25<36:51, 13.52it/s]

Step: 20600. PPL: 1.0000617504119873. Token Accuracy: 1.0


 41%|███████████████▏                     | 20703/50500 [25:32<36:42, 13.53it/s]

Step: 20700. PPL: 1.0012174844741821. Token Accuracy: 0.9989583492279053


 41%|███████████████▏                     | 20803/50500 [25:40<36:34, 13.53it/s]

Step: 20800. PPL: 1.004353642463684. Token Accuracy: 0.9989583492279053


 41%|███████████████▎                     | 20903/50500 [25:47<36:30, 13.51it/s]

Step: 20900. PPL: 1.000049352645874. Token Accuracy: 1.0


 42%|███████████████▍                     | 21003/50500 [25:54<36:56, 13.31it/s]

Step: 21000. PPL: 1.0000638961791992. Token Accuracy: 1.0


 42%|███████████████▍                     | 21103/50500 [26:02<36:17, 13.50it/s]

Step: 21100. PPL: 1.0000581741333008. Token Accuracy: 1.0


 42%|███████████████▌                     | 21203/50500 [26:09<36:10, 13.50it/s]

Step: 21200. PPL: 1.000266432762146. Token Accuracy: 1.0


 42%|███████████████▌                     | 21303/50500 [26:17<35:57, 13.53it/s]

Step: 21300. PPL: 1.0000147819519043. Token Accuracy: 1.0


 42%|███████████████▋                     | 21403/50500 [26:24<35:52, 13.52it/s]

Step: 21400. PPL: 1.0054397583007812. Token Accuracy: 0.9979166388511658


 43%|███████████████▊                     | 21503/50500 [26:31<35:44, 13.52it/s]

Step: 21500. PPL: 1.0006153583526611. Token Accuracy: 1.0


 43%|███████████████▊                     | 21603/50500 [26:39<35:37, 13.52it/s]

Step: 21600. PPL: 1.0035252571105957. Token Accuracy: 0.9979166388511658


 43%|███████████████▉                     | 21703/50500 [26:46<35:34, 13.49it/s]

Step: 21700. PPL: 1.0002470016479492. Token Accuracy: 1.0


 43%|███████████████▉                     | 21803/50500 [26:54<35:22, 13.52it/s]

Step: 21800. PPL: 1.0001438856124878. Token Accuracy: 1.0


 43%|████████████████                     | 21903/50500 [27:01<35:14, 13.53it/s]

Step: 21900. PPL: 1.0009061098098755. Token Accuracy: 1.0


 44%|████████████████                     | 22003/50500 [27:08<35:09, 13.51it/s]

Step: 22000. PPL: 1.0005223751068115. Token Accuracy: 1.0


 44%|████████████████▏                    | 22103/50500 [27:16<35:05, 13.49it/s]

Step: 22100. PPL: 1.0006898641586304. Token Accuracy: 1.0


 44%|████████████████▎                    | 22203/50500 [27:23<34:53, 13.52it/s]

Step: 22200. PPL: 1.001291275024414. Token Accuracy: 0.9989583492279053


 44%|████████████████▎                    | 22303/50500 [27:31<34:47, 13.51it/s]

Step: 22300. PPL: 1.0001072883605957. Token Accuracy: 1.0


 44%|████████████████▍                    | 22403/50500 [27:38<34:42, 13.49it/s]

Step: 22400. PPL: 1.000052571296692. Token Accuracy: 1.0


 45%|████████████████▍                    | 22503/50500 [27:45<34:32, 13.51it/s]

Step: 22500. PPL: 1.000051736831665. Token Accuracy: 1.0


 45%|████████████████▌                    | 22603/50500 [27:53<34:31, 13.47it/s]

Step: 22600. PPL: 1.0001935958862305. Token Accuracy: 1.0


 45%|████████████████▋                    | 22703/50500 [28:00<34:16, 13.52it/s]

Step: 22700. PPL: 1.0000290870666504. Token Accuracy: 1.0


 45%|████████████████▋                    | 22803/50500 [28:08<34:11, 13.50it/s]

Step: 22800. PPL: 1.0000420808792114. Token Accuracy: 1.0


 45%|████████████████▊                    | 22903/50500 [28:15<34:01, 13.52it/s]

Step: 22900. PPL: 1.0002981424331665. Token Accuracy: 1.0


 46%|████████████████▊                    | 23003/50500 [28:23<33:55, 13.51it/s]

Step: 23000. PPL: 1.0001182556152344. Token Accuracy: 1.0


 46%|████████████████▉                    | 23103/50500 [28:30<33:48, 13.51it/s]

Step: 23100. PPL: 1.000577449798584. Token Accuracy: 1.0


 46%|█████████████████                    | 23203/50500 [28:37<33:42, 13.50it/s]

Step: 23200. PPL: 1.0003174543380737. Token Accuracy: 1.0


 46%|█████████████████                    | 23303/50500 [28:45<33:35, 13.49it/s]

Step: 23300. PPL: 1.0002485513687134. Token Accuracy: 1.0


 46%|█████████████████▏                   | 23403/50500 [28:52<33:25, 13.51it/s]

Step: 23400. PPL: 1.00041925907135. Token Accuracy: 1.0


 47%|█████████████████▏                   | 23503/50500 [29:00<33:24, 13.47it/s]

Step: 23500. PPL: 1.0000874996185303. Token Accuracy: 1.0


 47%|█████████████████▎                   | 23603/50500 [29:07<33:10, 13.51it/s]

Step: 23600. PPL: 1.0001115798950195. Token Accuracy: 1.0


 47%|█████████████████▎                   | 23703/50500 [29:14<33:01, 13.53it/s]

Step: 23700. PPL: 1.0000609159469604. Token Accuracy: 1.0


 47%|█████████████████▍                   | 23803/50500 [29:22<32:56, 13.51it/s]

Step: 23800. PPL: 1.0006725788116455. Token Accuracy: 1.0


 47%|█████████████████▌                   | 23903/50500 [29:29<32:48, 13.51it/s]

Step: 23900. PPL: 1.0003615617752075. Token Accuracy: 1.0


 48%|█████████████████▌                   | 24003/50500 [29:37<32:40, 13.52it/s]

Step: 24000. PPL: 1.0002307891845703. Token Accuracy: 1.0


 48%|█████████████████▋                   | 24103/50500 [29:44<32:32, 13.52it/s]

Step: 24100. PPL: 1.0006874799728394. Token Accuracy: 1.0


 48%|█████████████████▋                   | 24203/50500 [29:51<32:31, 13.48it/s]

Step: 24200. PPL: 1.0000239610671997. Token Accuracy: 1.0


 48%|█████████████████▊                   | 24303/50500 [29:59<32:18, 13.52it/s]

Step: 24300. PPL: 1.000052571296692. Token Accuracy: 1.0


 48%|█████████████████▉                   | 24403/50500 [30:06<32:12, 13.50it/s]

Step: 24400. PPL: 1.0000733137130737. Token Accuracy: 1.0


 49%|█████████████████▉                   | 24503/50500 [30:14<32:06, 13.49it/s]

Step: 24500. PPL: 1.000091314315796. Token Accuracy: 1.0


 49%|██████████████████                   | 24603/50500 [30:21<31:55, 13.52it/s]

Step: 24600. PPL: 1.000052571296692. Token Accuracy: 1.0


 49%|██████████████████                   | 24703/50500 [30:28<31:47, 13.52it/s]

Step: 24700. PPL: 1.0036660432815552. Token Accuracy: 0.9989583492279053


 49%|██████████████████▏                  | 24803/50500 [30:36<31:38, 13.53it/s]

Step: 24800. PPL: 1.0009574890136719. Token Accuracy: 1.0


 49%|██████████████████▏                  | 24903/50500 [30:43<31:36, 13.50it/s]

Step: 24900. PPL: 1.0002118349075317. Token Accuracy: 1.0


 50%|██████████████████▎                  | 25003/50500 [30:51<31:26, 13.52it/s]

Step: 25000. PPL: 1.0005446672439575. Token Accuracy: 1.0


 50%|██████████████████▍                  | 25103/50500 [30:58<31:20, 13.51it/s]

Step: 25100. PPL: 1.000168800354004. Token Accuracy: 1.0


 50%|██████████████████▍                  | 25203/50500 [31:05<31:12, 13.51it/s]

Step: 25200. PPL: 1.0002115964889526. Token Accuracy: 1.0


 50%|██████████████████▌                  | 25303/50500 [31:13<31:07, 13.49it/s]

Step: 25300. PPL: 1.000032901763916. Token Accuracy: 1.0


 50%|██████████████████▌                  | 25403/50500 [31:20<30:56, 13.52it/s]

Step: 25400. PPL: 1.000160813331604. Token Accuracy: 1.0


 51%|██████████████████▋                  | 25503/50500 [31:28<30:49, 13.51it/s]

Step: 25500. PPL: 1.0000309944152832. Token Accuracy: 1.0


 51%|██████████████████▊                  | 25603/50500 [31:35<30:41, 13.52it/s]

Step: 25600. PPL: 1.0000849962234497. Token Accuracy: 1.0


 51%|██████████████████▊                  | 25703/50500 [31:42<30:34, 13.52it/s]

Step: 25700. PPL: 1.0005258321762085. Token Accuracy: 1.0


 51%|██████████████████▉                  | 25803/50500 [31:50<30:30, 13.49it/s]

Step: 25800. PPL: 1.0001157522201538. Token Accuracy: 1.0


 51%|██████████████████▉                  | 25903/50500 [31:57<30:19, 13.52it/s]

Step: 25900. PPL: 1.0004751682281494. Token Accuracy: 1.0


 51%|███████████████████                  | 26003/50500 [32:05<30:11, 13.52it/s]

Step: 26000. PPL: 1.0000628232955933. Token Accuracy: 1.0


 52%|███████████████████                  | 26103/50500 [32:12<30:06, 13.51it/s]

Step: 26100. PPL: 1.0004229545593262. Token Accuracy: 1.0


 52%|███████████████████▏                 | 26203/50500 [32:19<29:56, 13.53it/s]

Step: 26200. PPL: 1.0000261068344116. Token Accuracy: 1.0


 52%|███████████████████▎                 | 26303/50500 [32:27<29:48, 13.53it/s]

Step: 26300. PPL: 1.0000444650650024. Token Accuracy: 1.0


 52%|███████████████████▎                 | 26403/50500 [32:34<29:46, 13.49it/s]

Step: 26400. PPL: 1.0000301599502563. Token Accuracy: 1.0


 52%|███████████████████▍                 | 26503/50500 [32:42<29:38, 13.49it/s]

Step: 26500. PPL: 1.0022660493850708. Token Accuracy: 0.9989583492279053


 53%|███████████████████▍                 | 26603/50500 [32:49<29:27, 13.52it/s]

Step: 26600. PPL: 1.000016212463379. Token Accuracy: 1.0


 53%|███████████████████▌                 | 26703/50500 [32:57<29:23, 13.50it/s]

Step: 26700. PPL: 1.0002480745315552. Token Accuracy: 1.0


 53%|███████████████████▋                 | 26803/50500 [33:04<29:12, 13.53it/s]

Step: 26800. PPL: 1.0002717971801758. Token Accuracy: 1.0


 53%|███████████████████▋                 | 26903/50500 [33:11<29:06, 13.51it/s]

Step: 26900. PPL: 1.0066250562667847. Token Accuracy: 0.9979166388511658


 53%|███████████████████▊                 | 27003/50500 [33:19<28:59, 13.51it/s]

Step: 27000. PPL: 1.0003900527954102. Token Accuracy: 1.0


 54%|███████████████████▊                 | 27103/50500 [33:26<28:52, 13.51it/s]

Step: 27100. PPL: 1.0001094341278076. Token Accuracy: 1.0


 54%|███████████████████▉                 | 27203/50500 [33:34<28:47, 13.49it/s]

Step: 27200. PPL: 1.0001060962677002. Token Accuracy: 1.0


 54%|████████████████████                 | 27303/50500 [33:41<28:33, 13.54it/s]

Step: 27300. PPL: 1.0003211498260498. Token Accuracy: 1.0


 54%|████████████████████                 | 27403/50500 [33:48<28:29, 13.51it/s]

Step: 27400. PPL: 1.0002989768981934. Token Accuracy: 1.0


 54%|████████████████████▏                | 27503/50500 [33:56<28:25, 13.49it/s]

Step: 27500. PPL: 1.0001157522201538. Token Accuracy: 1.0


 55%|████████████████████▏                | 27603/50500 [34:03<28:12, 13.53it/s]

Step: 27600. PPL: 1.0001473426818848. Token Accuracy: 1.0


 55%|████████████████████▎                | 27703/50500 [34:11<28:10, 13.48it/s]

Step: 27700. PPL: 1.000036597251892. Token Accuracy: 1.0


 55%|████████████████████▎                | 27803/50500 [34:18<27:58, 13.52it/s]

Step: 27800. PPL: 1.0001583099365234. Token Accuracy: 1.0


 55%|████████████████████▍                | 27903/50500 [34:25<27:51, 13.52it/s]

Step: 27900. PPL: 1.000561237335205. Token Accuracy: 1.0


 55%|████████████████████▌                | 28003/50500 [34:33<27:46, 13.50it/s]

Step: 28000. PPL: 1.0000433921813965. Token Accuracy: 1.0


 56%|████████████████████▌                | 28103/50500 [34:40<27:36, 13.52it/s]

Step: 28100. PPL: 1.0000224113464355. Token Accuracy: 1.0


 56%|████████████████████▋                | 28203/50500 [34:48<27:32, 13.49it/s]

Step: 28200. PPL: 1.0002020597457886. Token Accuracy: 1.0


 56%|████████████████████▋                | 28303/50500 [34:55<27:27, 13.47it/s]

Step: 28300. PPL: 1.0000991821289062. Token Accuracy: 1.0


 56%|████████████████████▊                | 28403/50500 [35:02<27:15, 13.51it/s]

Step: 28400. PPL: 1.0061155557632446. Token Accuracy: 0.9979166388511658


 56%|████████████████████▉                | 28503/50500 [35:10<27:06, 13.52it/s]

Step: 28500. PPL: 1.0008578300476074. Token Accuracy: 1.0


 57%|████████████████████▉                | 28603/50500 [35:17<26:59, 13.52it/s]

Step: 28600. PPL: 1.0000845193862915. Token Accuracy: 1.0


 57%|█████████████████████                | 28703/50500 [35:25<26:53, 13.51it/s]

Step: 28700. PPL: 1.0000853538513184. Token Accuracy: 1.0


 57%|█████████████████████                | 28803/50500 [35:32<26:44, 13.52it/s]

Step: 28800. PPL: 1.0002168416976929. Token Accuracy: 1.0


 57%|█████████████████████▏               | 28903/50500 [35:39<26:39, 13.51it/s]

Step: 28900. PPL: 1.0000331401824951. Token Accuracy: 1.0


 57%|█████████████████████▏               | 29003/50500 [35:47<26:30, 13.52it/s]

Step: 29000. PPL: 1.000046968460083. Token Accuracy: 1.0


 58%|█████████████████████▎               | 29103/50500 [35:54<26:52, 13.27it/s]

Step: 29100. PPL: 1.0002899169921875. Token Accuracy: 1.0


 58%|█████████████████████▍               | 29203/50500 [36:02<26:16, 13.51it/s]

Step: 29200. PPL: 1.0000865459442139. Token Accuracy: 1.0


 58%|█████████████████████▍               | 29303/50500 [36:09<26:08, 13.52it/s]

Step: 29300. PPL: 1.0000700950622559. Token Accuracy: 1.0


 58%|█████████████████████▌               | 29403/50500 [36:16<26:02, 13.51it/s]

Step: 29400. PPL: 1.000633716583252. Token Accuracy: 1.0


 58%|█████████████████████▌               | 29503/50500 [36:24<25:54, 13.51it/s]

Step: 29500. PPL: 1.0003063678741455. Token Accuracy: 1.0


 59%|█████████████████████▋               | 29603/50500 [36:31<25:45, 13.52it/s]

Step: 29600. PPL: 1.0002012252807617. Token Accuracy: 1.0


 59%|█████████████████████▊               | 29703/50500 [36:39<25:41, 13.50it/s]

Step: 29700. PPL: 1.000170111656189. Token Accuracy: 1.0


 59%|█████████████████████▊               | 29803/50500 [36:46<25:31, 13.52it/s]

Step: 29800. PPL: 1.0001232624053955. Token Accuracy: 1.0


 59%|█████████████████████▉               | 29903/50500 [36:54<25:24, 13.51it/s]

Step: 29900. PPL: 1.0000582933425903. Token Accuracy: 1.0


 59%|█████████████████████▉               | 30003/50500 [37:01<25:17, 13.51it/s]

Step: 30000. PPL: 1.0000145435333252. Token Accuracy: 1.0


 60%|██████████████████████               | 30103/50500 [37:08<25:07, 13.53it/s]

Step: 30100. PPL: 1.00013267993927. Token Accuracy: 1.0


 60%|██████████████████████▏              | 30203/50500 [37:16<25:03, 13.50it/s]

Step: 30200. PPL: 1.0001342296600342. Token Accuracy: 1.0


 60%|██████████████████████▏              | 30303/50500 [37:23<24:59, 13.47it/s]

Step: 30300. PPL: 1.0005204677581787. Token Accuracy: 1.0


 60%|██████████████████████▎              | 30403/50500 [37:31<24:47, 13.51it/s]

Step: 30400. PPL: 1.0009536743164062. Token Accuracy: 1.0


 60%|██████████████████████▎              | 30503/50500 [37:38<24:41, 13.50it/s]

Step: 30500. PPL: 1.000638484954834. Token Accuracy: 1.0


 61%|██████████████████████▍              | 30603/50500 [37:45<24:34, 13.49it/s]

Step: 30600. PPL: 1.0003100633621216. Token Accuracy: 1.0


 61%|██████████████████████▍              | 30703/50500 [37:53<24:31, 13.45it/s]

Step: 30700. PPL: 1.0000369548797607. Token Accuracy: 1.0


 61%|██████████████████████▌              | 30803/50500 [38:00<24:17, 13.51it/s]

Step: 30800. PPL: 1.0000251531600952. Token Accuracy: 1.0


 61%|██████████████████████▋              | 30903/50500 [38:08<24:10, 13.51it/s]

Step: 30900. PPL: 1.0000693798065186. Token Accuracy: 1.0


 61%|██████████████████████▋              | 31003/50500 [38:15<24:01, 13.53it/s]

Step: 31000. PPL: 1.0000354051589966. Token Accuracy: 1.0


 62%|██████████████████████▊              | 31103/50500 [38:22<23:56, 13.51it/s]

Step: 31100. PPL: 1.0000094175338745. Token Accuracy: 1.0


 62%|██████████████████████▊              | 31203/50500 [38:30<23:46, 13.53it/s]

Step: 31200. PPL: 1.000022530555725. Token Accuracy: 1.0


 62%|██████████████████████▉              | 31303/50500 [38:37<23:40, 13.51it/s]

Step: 31300. PPL: 1.0010789632797241. Token Accuracy: 0.9989583492279053


 62%|███████████████████████              | 31403/50500 [38:45<23:36, 13.49it/s]

Step: 31400. PPL: 1.0000500679016113. Token Accuracy: 1.0


 62%|███████████████████████              | 31503/50500 [38:52<23:28, 13.49it/s]

Step: 31500. PPL: 1.0001344680786133. Token Accuracy: 1.0


 63%|███████████████████████▏             | 31603/50500 [38:59<23:17, 13.52it/s]

Step: 31600. PPL: 1.0000172853469849. Token Accuracy: 1.0


 63%|███████████████████████▏             | 31703/50500 [39:07<23:11, 13.51it/s]

Step: 31700. PPL: 1.0000247955322266. Token Accuracy: 1.0


 63%|███████████████████████▎             | 31803/50500 [39:14<23:03, 13.52it/s]

Step: 31800. PPL: 1.0001113414764404. Token Accuracy: 1.0


 63%|███████████████████████▎             | 31903/50500 [39:22<22:57, 13.50it/s]

Step: 31900. PPL: 1.000213623046875. Token Accuracy: 1.0


 63%|███████████████████████▍             | 32003/50500 [39:29<22:50, 13.50it/s]

Step: 32000. PPL: 1.0000406503677368. Token Accuracy: 1.0


 64%|███████████████████████▌             | 32103/50500 [39:36<22:43, 13.49it/s]

Step: 32100. PPL: 1.0002423524856567. Token Accuracy: 1.0


 64%|███████████████████████▌             | 32203/50500 [39:44<22:34, 13.51it/s]

Step: 32200. PPL: 1.0005323886871338. Token Accuracy: 1.0


 64%|███████████████████████▋             | 32303/50500 [39:51<22:24, 13.54it/s]

Step: 32300. PPL: 1.0000808238983154. Token Accuracy: 1.0


 64%|███████████████████████▋             | 32403/50500 [39:59<22:19, 13.51it/s]

Step: 32400. PPL: 1.0000336170196533. Token Accuracy: 1.0


 64%|███████████████████████▊             | 32503/50500 [40:06<22:10, 13.52it/s]

Step: 32500. PPL: 1.0000754594802856. Token Accuracy: 1.0


 65%|███████████████████████▉             | 32603/50500 [40:14<22:05, 13.50it/s]

Step: 32600. PPL: 1.0000418424606323. Token Accuracy: 1.0


 65%|███████████████████████▉             | 32703/50500 [40:21<21:56, 13.52it/s]

Step: 32700. PPL: 1.0000132322311401. Token Accuracy: 1.0


 65%|████████████████████████             | 32803/50500 [40:28<21:48, 13.53it/s]

Step: 32800. PPL: 1.0000382661819458. Token Accuracy: 1.0


 65%|████████████████████████             | 32903/50500 [40:36<21:48, 13.45it/s]

Step: 32900. PPL: 1.0001263618469238. Token Accuracy: 1.0


 65%|████████████████████████▏            | 33003/50500 [40:43<21:33, 13.52it/s]

Step: 33000. PPL: 1.0000160932540894. Token Accuracy: 1.0


 66%|████████████████████████▎            | 33103/50500 [40:51<21:28, 13.51it/s]

Step: 33100. PPL: 1.0000627040863037. Token Accuracy: 1.0


 66%|████████████████████████▎            | 33203/50500 [40:58<21:22, 13.49it/s]

Step: 33200. PPL: 1.000024676322937. Token Accuracy: 1.0


 66%|████████████████████████▍            | 33303/50500 [41:05<21:11, 13.52it/s]

Step: 33300. PPL: 1.000008463859558. Token Accuracy: 1.0


 66%|████████████████████████▍            | 33403/50500 [41:13<21:05, 13.51it/s]

Step: 33400. PPL: 1.000019907951355. Token Accuracy: 1.0


 66%|████████████████████████▌            | 33503/50500 [41:20<21:00, 13.49it/s]

Step: 33500. PPL: 1.0022441148757935. Token Accuracy: 0.9989583492279053


 67%|████████████████████████▌            | 33603/50500 [41:28<20:52, 13.49it/s]

Step: 33600. PPL: 1.0015994310379028. Token Accuracy: 0.9989583492279053


 67%|████████████████████████▋            | 33703/50500 [41:35<20:43, 13.51it/s]

Step: 33700. PPL: 1.0001524686813354. Token Accuracy: 1.0


 67%|████████████████████████▊            | 33803/50500 [41:42<20:35, 13.52it/s]

Step: 33800. PPL: 1.0000245571136475. Token Accuracy: 1.0


 67%|████████████████████████▊            | 33903/50500 [41:50<20:27, 13.52it/s]

Step: 33900. PPL: 1.0000660419464111. Token Accuracy: 1.0


 67%|████████████████████████▉            | 34003/50500 [41:57<20:22, 13.49it/s]

Step: 34000. PPL: 1.0002609491348267. Token Accuracy: 1.0


 68%|████████████████████████▉            | 34103/50500 [42:05<20:12, 13.52it/s]

Step: 34100. PPL: 1.0000630617141724. Token Accuracy: 1.0


 68%|█████████████████████████            | 34203/50500 [42:12<20:07, 13.49it/s]

Step: 34200. PPL: 1.0001062154769897. Token Accuracy: 1.0


 68%|█████████████████████████▏           | 34303/50500 [42:19<20:02, 13.47it/s]

Step: 34300. PPL: 1.0000274181365967. Token Accuracy: 1.0


 68%|█████████████████████████▏           | 34403/50500 [42:27<19:51, 13.51it/s]

Step: 34400. PPL: 1.0000197887420654. Token Accuracy: 1.0


 68%|█████████████████████████▎           | 34503/50500 [42:34<19:44, 13.51it/s]

Step: 34500. PPL: 1.001158356666565. Token Accuracy: 0.9989583492279053


 69%|█████████████████████████▎           | 34603/50500 [42:42<19:35, 13.52it/s]

Step: 34600. PPL: 1.0000944137573242. Token Accuracy: 1.0


 69%|█████████████████████████▍           | 34703/50500 [42:49<19:31, 13.49it/s]

Step: 34700. PPL: 1.0000985860824585. Token Accuracy: 1.0


 69%|█████████████████████████▍           | 34803/50500 [42:56<19:21, 13.52it/s]

Step: 34800. PPL: 1.0002139806747437. Token Accuracy: 1.0


 69%|█████████████████████████▌           | 34903/50500 [43:04<19:16, 13.49it/s]

Step: 34900. PPL: 1.000625729560852. Token Accuracy: 1.0


 69%|█████████████████████████▋           | 35003/50500 [43:11<19:06, 13.51it/s]

Step: 35000. PPL: 1.0058510303497314. Token Accuracy: 0.9989583492279053


 70%|█████████████████████████▋           | 35103/50500 [43:19<18:59, 13.51it/s]

Step: 35100. PPL: 1.000127911567688. Token Accuracy: 1.0


 70%|█████████████████████████▊           | 35203/50500 [43:26<18:53, 13.50it/s]

Step: 35200. PPL: 1.0002838373184204. Token Accuracy: 1.0


 70%|█████████████████████████▊           | 35303/50500 [43:34<18:44, 13.52it/s]

Step: 35300. PPL: 1.000035285949707. Token Accuracy: 1.0


 70%|█████████████████████████▉           | 35403/50500 [43:41<18:37, 13.51it/s]

Step: 35400. PPL: 1.0000247955322266. Token Accuracy: 1.0


 70%|██████████████████████████           | 35503/50500 [43:48<18:29, 13.52it/s]

Step: 35500. PPL: 1.0000181198120117. Token Accuracy: 1.0


 71%|██████████████████████████           | 35603/50500 [43:56<18:23, 13.50it/s]

Step: 35600. PPL: 1.0000468492507935. Token Accuracy: 1.0


 71%|██████████████████████████▏          | 35703/50500 [44:03<18:17, 13.49it/s]

Step: 35700. PPL: 1.0000579357147217. Token Accuracy: 1.0


 71%|██████████████████████████▏          | 35803/50500 [44:11<18:07, 13.52it/s]

Step: 35800. PPL: 1.0000334978103638. Token Accuracy: 1.0


 71%|██████████████████████████▎          | 35903/50500 [44:18<18:00, 13.51it/s]

Step: 35900. PPL: 1.000051736831665. Token Accuracy: 1.0


 71%|██████████████████████████▍          | 36003/50500 [44:25<17:53, 13.51it/s]

Step: 36000. PPL: 1.0000982284545898. Token Accuracy: 1.0


 71%|██████████████████████████▍          | 36103/50500 [44:33<17:45, 13.51it/s]

Step: 36100. PPL: 1.0021040439605713. Token Accuracy: 0.9989583492279053


 72%|██████████████████████████▌          | 36203/50500 [44:40<17:36, 13.53it/s]

Step: 36200. PPL: 1.0000331401824951. Token Accuracy: 1.0


 72%|██████████████████████████▌          | 36303/50500 [44:48<17:29, 13.52it/s]

Step: 36300. PPL: 1.0001001358032227. Token Accuracy: 1.0


 72%|██████████████████████████▋          | 36403/50500 [44:55<17:23, 13.51it/s]

Step: 36400. PPL: 1.0000301599502563. Token Accuracy: 1.0


 72%|██████████████████████████▋          | 36503/50500 [45:02<17:16, 13.50it/s]

Step: 36500. PPL: 1.0000475645065308. Token Accuracy: 1.0


 72%|██████████████████████████▊          | 36603/50500 [45:10<17:09, 13.50it/s]

Step: 36600. PPL: 1.0000159740447998. Token Accuracy: 1.0


 73%|██████████████████████████▉          | 36703/50500 [45:17<17:00, 13.52it/s]

Step: 36700. PPL: 1.0000159740447998. Token Accuracy: 1.0


 73%|██████████████████████████▉          | 36803/50500 [45:25<16:55, 13.49it/s]

Step: 36800. PPL: 1.0004221200942993. Token Accuracy: 1.0


 73%|███████████████████████████          | 36903/50500 [45:32<16:50, 13.46it/s]

Step: 36900. PPL: 1.0000267028808594. Token Accuracy: 1.0


 73%|███████████████████████████          | 37003/50500 [45:39<16:39, 13.51it/s]

Step: 37000. PPL: 1.0000369548797607. Token Accuracy: 1.0


 73%|███████████████████████████▏         | 37103/50500 [45:47<16:30, 13.53it/s]

Step: 37100. PPL: 1.000030517578125. Token Accuracy: 1.0


 74%|███████████████████████████▎         | 37203/50500 [45:54<16:43, 13.24it/s]

Step: 37200. PPL: 1.0010050535202026. Token Accuracy: 0.9989583492279053


 74%|███████████████████████████▎         | 37303/50500 [46:02<16:17, 13.51it/s]

Step: 37300. PPL: 1.0003987550735474. Token Accuracy: 1.0


 74%|███████████████████████████▍         | 37403/50500 [46:09<16:11, 13.48it/s]

Step: 37400. PPL: 1.0001126527786255. Token Accuracy: 1.0


 74%|███████████████████████████▍         | 37503/50500 [46:16<16:02, 13.50it/s]

Step: 37500. PPL: 1.0187761783599854. Token Accuracy: 0.9989583492279053


 74%|███████████████████████████▌         | 37603/50500 [46:24<15:55, 13.50it/s]

Step: 37600. PPL: 1.0000983476638794. Token Accuracy: 1.0


 75%|███████████████████████████▌         | 37703/50500 [46:31<15:47, 13.51it/s]

Step: 37700. PPL: 1.000339388847351. Token Accuracy: 1.0


 75%|███████████████████████████▋         | 37803/50500 [46:39<15:39, 13.51it/s]

Step: 37800. PPL: 1.0000298023223877. Token Accuracy: 1.0


 75%|███████████████████████████▊         | 37903/50500 [46:46<15:31, 13.53it/s]

Step: 37900. PPL: 1.0000134706497192. Token Accuracy: 1.0


 75%|███████████████████████████▊         | 38003/50500 [46:54<15:25, 13.51it/s]

Step: 38000. PPL: 1.0000557899475098. Token Accuracy: 1.0


 75%|███████████████████████████▉         | 38103/50500 [47:01<15:18, 13.50it/s]

Step: 38100. PPL: 1.000016689300537. Token Accuracy: 1.0


 76%|███████████████████████████▉         | 38203/50500 [47:08<15:11, 13.50it/s]

Step: 38200. PPL: 1.0004606246948242. Token Accuracy: 1.0


 76%|████████████████████████████         | 38303/50500 [47:16<15:02, 13.51it/s]

Step: 38300. PPL: 1.0001139640808105. Token Accuracy: 1.0


 76%|████████████████████████████▏        | 38403/50500 [47:23<14:56, 13.49it/s]

Step: 38400. PPL: 1.0018631219863892. Token Accuracy: 0.9989583492279053


 76%|████████████████████████████▏        | 38503/50500 [47:31<14:47, 13.51it/s]

Step: 38500. PPL: 1.000035047531128. Token Accuracy: 1.0


 76%|████████████████████████████▎        | 38603/50500 [47:38<14:41, 13.50it/s]

Step: 38600. PPL: 1.0000576972961426. Token Accuracy: 1.0


 77%|████████████████████████████▎        | 38703/50500 [47:45<14:32, 13.53it/s]

Step: 38700. PPL: 1.0000998973846436. Token Accuracy: 1.0


 77%|████████████████████████████▍        | 38803/50500 [47:53<14:28, 13.46it/s]

Step: 38800. PPL: 1.001684546470642. Token Accuracy: 0.9989583492279053


 77%|████████████████████████████▌        | 38903/50500 [48:00<14:20, 13.48it/s]

Step: 38900. PPL: 1.000025987625122. Token Accuracy: 1.0


 77%|████████████████████████████▌        | 39003/50500 [48:08<14:10, 13.51it/s]

Step: 39000. PPL: 1.0125720500946045. Token Accuracy: 0.9947916865348816


 77%|████████████████████████████▋        | 39103/50500 [48:15<14:02, 13.52it/s]

Step: 39100. PPL: 1.0001999139785767. Token Accuracy: 1.0


 78%|████████████████████████████▋        | 39203/50500 [48:22<13:55, 13.52it/s]

Step: 39200. PPL: 1.0000218152999878. Token Accuracy: 1.0


 78%|████████████████████████████▊        | 39303/50500 [48:30<13:49, 13.50it/s]

Step: 39300. PPL: 1.000466227531433. Token Accuracy: 1.0


 78%|████████████████████████████▊        | 39403/50500 [48:37<13:42, 13.49it/s]

Step: 39400. PPL: 1.0002427101135254. Token Accuracy: 1.0


 78%|████████████████████████████▉        | 39503/50500 [48:45<13:35, 13.48it/s]

Step: 39500. PPL: 1.0018055438995361. Token Accuracy: 1.0


 78%|█████████████████████████████        | 39603/50500 [48:52<13:27, 13.49it/s]

Step: 39600. PPL: 1.0002292394638062. Token Accuracy: 1.0


 79%|█████████████████████████████        | 39703/50500 [48:59<13:18, 13.52it/s]

Step: 39700. PPL: 1.000567078590393. Token Accuracy: 1.0


 79%|█████████████████████████████▏       | 39803/50500 [49:07<13:14, 13.46it/s]

Step: 39800. PPL: 1.0000823736190796. Token Accuracy: 1.0


 79%|█████████████████████████████▏       | 39903/50500 [49:14<13:05, 13.50it/s]

Step: 39900. PPL: 1.0012288093566895. Token Accuracy: 0.9989583492279053


 79%|█████████████████████████████▎       | 40003/50500 [49:22<12:56, 13.52it/s]

Step: 40000. PPL: 1.0001347064971924. Token Accuracy: 1.0


 79%|█████████████████████████████▍       | 40103/50500 [49:29<12:50, 13.50it/s]

Step: 40100. PPL: 1.0000286102294922. Token Accuracy: 1.0


 80%|█████████████████████████████▍       | 40203/50500 [49:37<12:40, 13.54it/s]

Step: 40200. PPL: 1.000014066696167. Token Accuracy: 1.0


 80%|█████████████████████████████▌       | 40303/50500 [49:44<12:34, 13.52it/s]

Step: 40300. PPL: 1.0000686645507812. Token Accuracy: 1.0


 80%|█████████████████████████████▌       | 40403/50500 [49:51<12:32, 13.42it/s]

Step: 40400. PPL: 1.000372290611267. Token Accuracy: 1.0


 80%|█████████████████████████████▋       | 40503/50500 [49:59<12:20, 13.51it/s]

Step: 40500. PPL: 1.0003855228424072. Token Accuracy: 1.0


 80%|█████████████████████████████▋       | 40603/50500 [50:06<12:13, 13.49it/s]

Step: 40600. PPL: 1.0000369548797607. Token Accuracy: 1.0


 81%|█████████████████████████████▊       | 40703/50500 [50:14<12:05, 13.50it/s]

Step: 40700. PPL: 1.0000439882278442. Token Accuracy: 1.0


 81%|█████████████████████████████▉       | 40803/50500 [50:21<11:57, 13.51it/s]

Step: 40800. PPL: 1.0000085830688477. Token Accuracy: 1.0


 81%|█████████████████████████████▉       | 40903/50500 [50:28<11:50, 13.50it/s]

Step: 40900. PPL: 1.000017523765564. Token Accuracy: 1.0


 81%|██████████████████████████████       | 41003/50500 [50:36<11:43, 13.50it/s]

Step: 41000. PPL: 1.0000123977661133. Token Accuracy: 1.0


 81%|██████████████████████████████       | 41103/50500 [50:43<11:36, 13.49it/s]

Step: 41100. PPL: 1.0003021955490112. Token Accuracy: 1.0


 82%|██████████████████████████████▏      | 41203/50500 [50:51<11:29, 13.49it/s]

Step: 41200. PPL: 1.0002772808074951. Token Accuracy: 1.0


 82%|██████████████████████████████▎      | 41303/50500 [50:58<11:21, 13.50it/s]

Step: 41300. PPL: 1.0000221729278564. Token Accuracy: 1.0


 82%|██████████████████████████████▎      | 41403/50500 [51:05<11:17, 13.42it/s]

Step: 41400. PPL: 1.000033974647522. Token Accuracy: 1.0


 82%|██████████████████████████████▍      | 41503/50500 [51:13<11:16, 13.30it/s]

Step: 41500. PPL: 1.000327706336975. Token Accuracy: 1.0


 82%|██████████████████████████████▍      | 41603/50500 [51:20<10:58, 13.51it/s]

Step: 41600. PPL: 1.0035134553909302. Token Accuracy: 0.9989583492279053


 83%|██████████████████████████████▌      | 41703/50500 [51:28<10:51, 13.51it/s]

Step: 41700. PPL: 1.0000898838043213. Token Accuracy: 1.0


 83%|██████████████████████████████▋      | 41803/50500 [51:35<10:43, 13.51it/s]

Step: 41800. PPL: 1.000082015991211. Token Accuracy: 1.0


 83%|██████████████████████████████▋      | 41903/50500 [51:43<10:36, 13.51it/s]

Step: 41900. PPL: 1.0001062154769897. Token Accuracy: 1.0


 83%|██████████████████████████████▊      | 42003/50500 [51:50<10:30, 13.47it/s]

Step: 42000. PPL: 1.0000325441360474. Token Accuracy: 1.0


 83%|██████████████████████████████▊      | 42103/50500 [51:58<10:21, 13.51it/s]

Step: 42100. PPL: 1.00001859664917. Token Accuracy: 1.0


 84%|██████████████████████████████▉      | 42203/50500 [52:05<10:16, 13.46it/s]

Step: 42200. PPL: 1.0000288486480713. Token Accuracy: 1.0


 84%|██████████████████████████████▉      | 42303/50500 [52:12<10:06, 13.51it/s]

Step: 42300. PPL: 1.0025142431259155. Token Accuracy: 0.9989583492279053


 84%|███████████████████████████████      | 42403/50500 [52:20<10:00, 13.49it/s]

Step: 42400. PPL: 1.000069260597229. Token Accuracy: 1.0


 84%|███████████████████████████████▏     | 42503/50500 [52:27<09:52, 13.49it/s]

Step: 42500. PPL: 1.000030279159546. Token Accuracy: 1.0


 84%|███████████████████████████████▏     | 42603/50500 [52:35<09:43, 13.52it/s]

Step: 42600. PPL: 1.0000159740447998. Token Accuracy: 1.0


 85%|███████████████████████████████▎     | 42703/50500 [52:42<09:37, 13.50it/s]

Step: 42700. PPL: 1.0000170469284058. Token Accuracy: 1.0


 85%|███████████████████████████████▎     | 42803/50500 [52:49<09:29, 13.51it/s]

Step: 42800. PPL: 1.0001038312911987. Token Accuracy: 1.0


 85%|███████████████████████████████▍     | 42903/50500 [52:57<09:22, 13.49it/s]

Step: 42900. PPL: 1.0000957250595093. Token Accuracy: 1.0


 85%|███████████████████████████████▌     | 43003/50500 [53:04<09:15, 13.50it/s]

Step: 43000. PPL: 1.0000808238983154. Token Accuracy: 1.0


 85%|███████████████████████████████▌     | 43103/50500 [53:12<09:07, 13.50it/s]

Step: 43100. PPL: 1.0000139474868774. Token Accuracy: 1.0


 86%|███████████████████████████████▋     | 43203/50500 [53:19<09:00, 13.51it/s]

Step: 43200. PPL: 1.0000196695327759. Token Accuracy: 1.0


 86%|███████████████████████████████▋     | 43303/50500 [53:27<08:54, 13.48it/s]

Step: 43300. PPL: 1.0004627704620361. Token Accuracy: 1.0


 86%|███████████████████████████████▊     | 43403/50500 [53:34<08:45, 13.51it/s]

Step: 43400. PPL: 1.0018861293792725. Token Accuracy: 0.9989583492279053


 86%|███████████████████████████████▊     | 43503/50500 [53:41<08:37, 13.51it/s]

Step: 43500. PPL: 1.0001494884490967. Token Accuracy: 1.0


 86%|███████████████████████████████▉     | 43603/50500 [53:49<08:30, 13.50it/s]

Step: 43600. PPL: 1.0029720067977905. Token Accuracy: 0.9989583492279053


 87%|████████████████████████████████     | 43703/50500 [53:56<08:23, 13.49it/s]

Step: 43700. PPL: 1.0000660419464111. Token Accuracy: 1.0


 87%|████████████████████████████████     | 43803/50500 [54:04<08:16, 13.49it/s]

Step: 43800. PPL: 1.0000733137130737. Token Accuracy: 1.0


 87%|████████████████████████████████▏    | 43903/50500 [54:11<08:08, 13.50it/s]

Step: 43900. PPL: 1.0000165700912476. Token Accuracy: 1.0


 87%|████████████████████████████████▏    | 44003/50500 [54:18<08:01, 13.50it/s]

Step: 44000. PPL: 1.000017523765564. Token Accuracy: 1.0


 87%|████████████████████████████████▎    | 44103/50500 [54:26<07:53, 13.51it/s]

Step: 44100. PPL: 1.0000323057174683. Token Accuracy: 1.0


 88%|████████████████████████████████▍    | 44203/50500 [54:33<07:46, 13.51it/s]

Step: 44200. PPL: 1.0000344514846802. Token Accuracy: 1.0


 88%|████████████████████████████████▍    | 44303/50500 [54:41<07:38, 13.51it/s]

Step: 44300. PPL: 1.000022292137146. Token Accuracy: 1.0


 88%|████████████████████████████████▌    | 44403/50500 [54:48<07:32, 13.49it/s]

Step: 44400. PPL: 1.0000057220458984. Token Accuracy: 1.0


 88%|████████████████████████████████▌    | 44503/50500 [54:55<07:23, 13.51it/s]

Step: 44500. PPL: 1.0000416040420532. Token Accuracy: 1.0


 88%|████████████████████████████████▋    | 44603/50500 [55:03<07:17, 13.49it/s]

Step: 44600. PPL: 1.0000110864639282. Token Accuracy: 1.0


 89%|████████████████████████████████▊    | 44703/50500 [55:10<07:09, 13.51it/s]

Step: 44700. PPL: 1.000043511390686. Token Accuracy: 1.0


 89%|████████████████████████████████▊    | 44803/50500 [55:18<07:02, 13.49it/s]

Step: 44800. PPL: 1.0000221729278564. Token Accuracy: 1.0


 89%|████████████████████████████████▉    | 44903/50500 [55:25<06:54, 13.50it/s]

Step: 44900. PPL: 1.000005841255188. Token Accuracy: 1.0


 89%|████████████████████████████████▉    | 45003/50500 [55:33<06:46, 13.52it/s]

Step: 45000. PPL: 1.000036358833313. Token Accuracy: 1.0


 89%|█████████████████████████████████    | 45103/50500 [55:40<06:40, 13.49it/s]

Step: 45100. PPL: 1.0000057220458984. Token Accuracy: 1.0


 90%|█████████████████████████████████    | 45203/50500 [55:47<06:32, 13.50it/s]

Step: 45200. PPL: 1.0001120567321777. Token Accuracy: 1.0


 90%|█████████████████████████████████▏   | 45303/50500 [55:55<06:27, 13.40it/s]

Step: 45300. PPL: 1.0001050233840942. Token Accuracy: 1.0


 90%|█████████████████████████████████▎   | 45403/50500 [56:02<06:17, 13.50it/s]

Step: 45400. PPL: 1.0003066062927246. Token Accuracy: 1.0


 90%|█████████████████████████████████▎   | 45503/50500 [56:10<06:10, 13.50it/s]

Step: 45500. PPL: 1.000056505203247. Token Accuracy: 1.0


 90%|█████████████████████████████████▍   | 45603/50500 [56:17<06:02, 13.51it/s]

Step: 45600. PPL: 1.0000481605529785. Token Accuracy: 1.0


 91%|█████████████████████████████████▍   | 45703/50500 [56:24<05:55, 13.51it/s]

Step: 45700. PPL: 1.0000423192977905. Token Accuracy: 1.0


 91%|█████████████████████████████████▌   | 45803/50500 [56:32<05:48, 13.49it/s]

Step: 45800. PPL: 1.0000479221343994. Token Accuracy: 1.0


 91%|█████████████████████████████████▋   | 45903/50500 [56:39<05:40, 13.49it/s]

Step: 45900. PPL: 1.0001293420791626. Token Accuracy: 1.0


 91%|█████████████████████████████████▋   | 46003/50500 [56:47<05:32, 13.52it/s]

Step: 46000. PPL: 1.0002399682998657. Token Accuracy: 1.0


 91%|█████████████████████████████████▊   | 46103/50500 [56:54<05:25, 13.52it/s]

Step: 46100. PPL: 1.000065803527832. Token Accuracy: 1.0


 91%|█████████████████████████████████▊   | 46203/50500 [57:01<05:18, 13.50it/s]

Step: 46200. PPL: 1.0000662803649902. Token Accuracy: 1.0


 92%|█████████████████████████████████▉   | 46303/50500 [57:09<05:10, 13.50it/s]

Step: 46300. PPL: 1.0000760555267334. Token Accuracy: 1.0


 92%|█████████████████████████████████▉   | 46403/50500 [57:16<05:03, 13.50it/s]

Step: 46400. PPL: 1.000015377998352. Token Accuracy: 1.0


 92%|██████████████████████████████████   | 46503/50500 [57:24<04:56, 13.50it/s]

Step: 46500. PPL: 1.0000271797180176. Token Accuracy: 1.0


 92%|██████████████████████████████████▏  | 46603/50500 [57:31<04:48, 13.50it/s]

Step: 46600. PPL: 1.0000261068344116. Token Accuracy: 1.0


 92%|██████████████████████████████████▏  | 46703/50500 [57:39<04:41, 13.50it/s]

Step: 46700. PPL: 1.0000336170196533. Token Accuracy: 1.0


 93%|██████████████████████████████████▎  | 46803/50500 [57:46<04:35, 13.44it/s]

Step: 46800. PPL: 1.0012298822402954. Token Accuracy: 1.0


 93%|██████████████████████████████████▎  | 46903/50500 [57:53<04:26, 13.48it/s]

Step: 46900. PPL: 1.0000386238098145. Token Accuracy: 1.0


 93%|██████████████████████████████████▍  | 47003/50500 [58:01<04:18, 13.51it/s]

Step: 47000. PPL: 1.0000250339508057. Token Accuracy: 1.0


 93%|██████████████████████████████████▌  | 47103/50500 [58:08<04:11, 13.50it/s]

Step: 47100. PPL: 1.000004768371582. Token Accuracy: 1.0


 93%|██████████████████████████████████▌  | 47203/50500 [58:16<04:04, 13.48it/s]

Step: 47200. PPL: 1.000016689300537. Token Accuracy: 1.0


 94%|██████████████████████████████████▋  | 47303/50500 [58:23<03:56, 13.53it/s]

Step: 47300. PPL: 1.0000803470611572. Token Accuracy: 1.0


 94%|██████████████████████████████████▋  | 47403/50500 [58:30<03:49, 13.52it/s]

Step: 47400. PPL: 1.0000067949295044. Token Accuracy: 1.0


 94%|██████████████████████████████████▊  | 47503/50500 [58:38<03:41, 13.53it/s]

Step: 47500. PPL: 1.0000061988830566. Token Accuracy: 1.0


 94%|██████████████████████████████████▉  | 47603/50500 [58:45<03:34, 13.50it/s]

Step: 47600. PPL: 1.0001540184020996. Token Accuracy: 1.0


 94%|██████████████████████████████████▉  | 47703/50500 [58:53<03:27, 13.49it/s]

Step: 47700. PPL: 1.0000118017196655. Token Accuracy: 1.0


 95%|███████████████████████████████████  | 47803/50500 [59:00<03:19, 13.51it/s]

Step: 47800. PPL: 1.000018835067749. Token Accuracy: 1.0


 95%|███████████████████████████████████  | 47903/50500 [59:07<03:12, 13.51it/s]

Step: 47900. PPL: 1.0025529861450195. Token Accuracy: 0.9989583492279053


 95%|███████████████████████████████████▏ | 48003/50500 [59:15<03:04, 13.53it/s]

Step: 48000. PPL: 1.0001583099365234. Token Accuracy: 1.0


 95%|███████████████████████████████████▏ | 48103/50500 [59:22<02:57, 13.52it/s]

Step: 48100. PPL: 1.000015139579773. Token Accuracy: 1.0


 95%|███████████████████████████████████▎ | 48203/50500 [59:30<02:49, 13.52it/s]

Step: 48200. PPL: 1.0002949237823486. Token Accuracy: 1.0


 96%|███████████████████████████████████▍ | 48303/50500 [59:37<02:42, 13.55it/s]

Step: 48300. PPL: 1.0000081062316895. Token Accuracy: 1.0


 96%|███████████████████████████████████▍ | 48403/50500 [59:44<02:35, 13.52it/s]

Step: 48400. PPL: 1.0000168085098267. Token Accuracy: 1.0


 96%|███████████████████████████████████▌ | 48503/50500 [59:52<02:33, 13.01it/s]

Step: 48500. PPL: 1.0003025531768799. Token Accuracy: 1.0


 96%|█████████████████████████████████▋ | 48603/50500 [1:00:00<02:24, 13.14it/s]

Step: 48600. PPL: 1.0000321865081787. Token Accuracy: 1.0


 96%|█████████████████████████████████▊ | 48703/50500 [1:00:07<02:16, 13.14it/s]

Step: 48700. PPL: 1.000006914138794. Token Accuracy: 1.0


 97%|█████████████████████████████████▊ | 48803/50500 [1:00:15<02:08, 13.17it/s]

Step: 48800. PPL: 1.0000264644622803. Token Accuracy: 1.0


 97%|█████████████████████████████████▉ | 48903/50500 [1:00:22<02:01, 13.12it/s]

Step: 48900. PPL: 1.000037431716919. Token Accuracy: 1.0


 97%|█████████████████████████████████▉ | 49003/50500 [1:00:30<01:54, 13.08it/s]

Step: 49000. PPL: 1.0000396966934204. Token Accuracy: 1.0


 97%|██████████████████████████████████ | 49103/50500 [1:00:38<01:46, 13.17it/s]

Step: 49100. PPL: 1.0011378526687622. Token Accuracy: 0.9989583492279053


 97%|██████████████████████████████████ | 49203/50500 [1:00:45<01:39, 13.08it/s]

Step: 49200. PPL: 1.0002645254135132. Token Accuracy: 1.0


 98%|██████████████████████████████████▏| 49303/50500 [1:00:53<01:31, 13.15it/s]

Step: 49300. PPL: 1.0000275373458862. Token Accuracy: 1.0


 98%|██████████████████████████████████▏| 49403/50500 [1:01:01<01:23, 13.15it/s]

Step: 49400. PPL: 1.0000337362289429. Token Accuracy: 1.0


 98%|██████████████████████████████████▎| 49503/50500 [1:01:08<01:15, 13.14it/s]

Step: 49500. PPL: 1.0000237226486206. Token Accuracy: 1.0


 98%|██████████████████████████████████▍| 49603/50500 [1:01:16<01:08, 13.13it/s]

Step: 49600. PPL: 1.0000126361846924. Token Accuracy: 1.0


 98%|██████████████████████████████████▍| 49703/50500 [1:01:23<01:00, 13.13it/s]

Step: 49700. PPL: 1.0000200271606445. Token Accuracy: 1.0


 99%|██████████████████████████████████▌| 49803/50500 [1:01:31<00:52, 13.15it/s]

Step: 49800. PPL: 1.0000113248825073. Token Accuracy: 1.0


 99%|██████████████████████████████████▌| 49903/50500 [1:01:39<00:45, 13.11it/s]

Step: 49900. PPL: 1.0000189542770386. Token Accuracy: 1.0


 99%|██████████████████████████████████▋| 50003/50500 [1:01:46<00:37, 13.15it/s]

Step: 50000. PPL: 1.0000158548355103. Token Accuracy: 1.0


 99%|██████████████████████████████████▋| 50103/50500 [1:01:54<00:30, 13.16it/s]

Step: 50100. PPL: 1.000252366065979. Token Accuracy: 1.0


 99%|██████████████████████████████████▊| 50203/50500 [1:02:01<00:22, 13.18it/s]

Step: 50200. PPL: 1.0000078678131104. Token Accuracy: 1.0


100%|██████████████████████████████████▊| 50303/50500 [1:02:09<00:15, 13.12it/s]

Step: 50300. PPL: 1.0006228685379028. Token Accuracy: 1.0


100%|██████████████████████████████████▉| 50403/50500 [1:02:17<00:07, 13.13it/s]

Step: 50400. PPL: 1.0001676082611084. Token Accuracy: 1.0


100%|███████████████████████████████████| 50500/50500 [1:02:24<00:00, 13.49it/s]
  2%|▋                                           | 1/63 [00:00<00:15,  4.00it/s]

Input:  5 6 3 2 * 7 4 3 4 
Target:  5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0  #### 5 5 6 0 8 2 0 1 
Predicted:  5 5 5 6 1 + 0 0 6 4 9 0 ( 5 5 1 1 1 1 ) + 0 0 5 9 0 7 0 ( 5 5 6 0 2 8 0 ) + 0 0 0 0 6 4 9 0  #### 5 5 6 0 8 2 0 1 



  3%|█▍                                          | 2/63 [00:00<00:14,  4.31it/s]

Input:  8 2 5 7 * 2 1 3 8 
Target:  6 5 0 5 1 + 0 8 2 5 7 0 ( 6 3 3 0 9 0 ) + 0 0 4 8 5 2 2 ( 6 3 7 8 4 3 2 ) + 0 0 0 4 2 2 0 6  #### 6 3 7 2 7 5 2 6 
Predicted:  6 5 0 5 1 + 0 8 2 5 7 0 ( 6 3 3 0 9 0 ) + 0 0 4 8 5 2 2 ( 6 3 7 8 4 3 2 ) + 0 0 0 4 2 2 0 6  #### 6 3 7 2 7 5 2 6 



  5%|██                                          | 3/63 [00:00<00:13,  4.45it/s]

Input:  1 5 6 5 * 9 6 4 1 
Target:  9 5 8 0 5 + 0 6 0 9 3 3 ( 9 1 9 9 8 3 ) + 0 0 4 0 6 2 2 ( 9 1 3 0 5 6 2 ) + 0 0 0 1 5 6 5 0  #### 9 1 3 1 0 3 8 0 
Predicted:  9 5 8 0 5 + 0 6 0 9 3 3 ( 9 1 9 9 8 3 ) + 0 0 4 0 6 2 2 ( 9 1 3 0 5 6 2 ) + 0 0 0 1 5 6 5 0  #### 9 1 3 1 0 3 8 0 



  6%|██▊                                         | 4/63 [00:00<00:13,  4.47it/s]

Input:  3 8 2 9 * 8 2 1 6 
Target:  4 6 2 4 7 + 0 6 6 5 8 1 ( 4 2 9 9 5 2 ) + 0 0 3 8 2 9 0 ( 4 2 2 8 8 1 1 ) + 0 0 0 8 9 6 5 5  #### 4 2 2 6 8 8 6 5 
Predicted:  4 6 2 4 7 + 0 6 6 5 8 1 ( 4 2 9 9 5 2 ) + 0 0 3 8 2 9 0 ( 4 2 2 8 8 1 1 ) + 0 0 0 8 9 6 5 5  #### 4 2 2 6 8 8 6 5 



  8%|███▍                                        | 5/63 [00:01<00:12,  4.51it/s]

Input:  8 5 4 1 * 3 9 6 6 
Target:  4 7 3 4 0 + 0 2 2 1 3 1 ( 4 9 5 5 3 1 ) + 0 0 8 4 7 8 0 ( 4 9 3 0 1 0 1 ) + 0 0 0 8 4 7 8 0  #### 4 9 3 8 5 7 9 0 
Predicted:  4 7 3 4 0 + 0 2 2 1 3 1 ( 4 9 5 5 3 1 ) + 0 0 8 4 7 8 0 ( 4 9 3 0 1 0 1 ) + 0 0 0 8 4 7 8 0  #### 4 9 3 8 5 7 9 0 



 10%|████▏                                       | 6/63 [00:01<00:12,  4.53it/s]

Input:  3 9 9 5 * 6 0 0 6 
Target:  8 5 9 5 3 + 0 0 0 0 0 0 ( 8 5 9 5 3 0 ) + 0 0 0 0 0 0 0 ( 8 5 9 5 3 0 0 ) + 0 0 0 8 5 9 5 3  #### 8 5 9 3 9 9 5 3 
Predicted:  8 5 9 5 3 + 0 0 0 0 0 0 ( 8 5 9 5 3 0 ) + 0 0 0 0 0 0 0 ( 8 5 9 5 3 0 0 ) + 0 0 0 8 5 9 5 3  #### 8 5 9 3 9 9 5 3 



 11%|████▉                                       | 7/63 [00:01<00:12,  4.55it/s]

Input:  3 0 6 5 * 7 1 2 5 
Target:  1 2 2 9 3 + 0 3 0 6 5 0 ( 1 5 2 5 9 0 ) + 0 0 6 0 2 1 1 ( 1 5 8 5 1 2 1 ) + 0 0 0 5 1 0 8 2  #### 1 5 8 0 3 2 9 2 
Predicted:  1 2 2 9 3 + 0 3 0 6 5 0 ( 1 5 2 5 9 0 ) + 0 0 6 0 2 1 1 ( 1 5 8 5 1 2 1 ) + 0 0 0 5 1 0 8 2  #### 1 5 8 0 3 2 9 2 



 13%|█████▌                                      | 8/63 [00:01<00:12,  4.56it/s]

Input:  9 4 4 9 * 5 6 1 7 
Target:  5 4 2 7 4 + 0 4 9 6 6 5 ( 5 8 1 4 1 6 ) + 0 0 9 4 4 9 0 ( 5 8 0 9 5 5 1 ) + 0 0 0 3 4 1 6 6  #### 5 8 0 2 0 7 7 6 
Predicted:  5 4 2 7 4 + 0 4 9 6 6 5 ( 5 8 1 4 1 6 ) + 0 0 9 4 4 9 0 ( 5 8 0 9 5 5 1 ) + 0 0 0 3 4 1 6 6  #### 5 8 0 2 0 7 7 6 



 14%|██████▎                                     | 9/63 [00:02<00:11,  4.55it/s]

Input:  3 5 0 3 * 6 6 9 2 
Target:  8 1 3 8 1 + 0 8 1 3 8 1 ( 8 9 4 1 0 2 ) + 0 0 7 7 4 7 2 ( 8 9 1 9 4 9 2 ) + 0 0 0 6 0 1 6 0  #### 8 9 1 5 5 0 9 0 
Predicted:  8 1 3 8 1 + 0 8 1 3 8 1 ( 8 9 4 1 0 2 ) + 0 0 7 7 4 7 2 ( 8 9 1 9 4 9 2 ) + 0 0 0 6 0 1 6 0  #### 8 9 1 5 5 0 9 0 



 16%|██████▊                                    | 10/63 [00:02<00:11,  4.55it/s]

Input:  3 9 1 9 * 6 4 8 6 
Target:  8 5 1 5 5 + 0 2 7 7 6 3 ( 8 7 8 2 2 4 ) + 0 0 4 4 5 3 7 ( 8 7 2 7 7 7 7 ) + 0 0 0 8 5 1 5 5  #### 8 7 2 5 3 9 2 6 
Predicted:  8 5 1 5 5 + 0 2 7 7 6 3 ( 8 7 8 2 2 4 ) + 0 0 4 4 5 3 7 ( 8 7 2 7 7 7 7 ) + 0 0 0 8 5 1 5 5  #### 8 7 2 5 3 9 2 6 



 17%|███████▌                                   | 11/63 [00:02<00:11,  4.55it/s]

Input:  4 8 2 2 * 0 0 5 4 
Target:  0 0 0 0 0 + 0 0 0 0 0 0 ( 0 0 0 0 0 0 ) + 0 0 0 2 4 1 1 ( 0 0 0 2 4 1 1 ) + 0 0 0 6 3 1 9 0  #### 0 0 0 8 7 2 0 1 
Predicted:  0 0 0 0 0 + 0 0 0 0 0 0 ( 0 0 0 0 0 0 ) + 0 0 0 2 4 1 1 ( 0 0 0 2 4 1 1 ) + 0 0 0 6 3 1 9 0  #### 0 0 0 8 7 2 0 1 



 19%|████████▏                                  | 12/63 [00:02<00:11,  4.56it/s]

Input:  0 6 5 5 * 9 2 7 4 
Target:  0 4 0 0 5 + 0 0 2 1 1 1 ( 0 4 2 1 6 1 ) + 0 0 0 2 9 8 3 ( 0 4 2 3 5 0 4 ) + 0 0 0 0 4 2 2 2  #### 0 4 2 3 9 2 6 2 
Predicted:  0 4 0 0 5 + 0 0 2 1 1 1 ( 0 4 2 1 6 1 ) + 0 0 0 2 9 8 3 ( 0 4 2 3 5 0 4 ) + 0 0 0 0 4 2 2 2  #### 0 4 2 3 9 2 6 2 



 21%|████████▊                                  | 13/63 [00:02<00:10,  4.57it/s]

Input:  0 8 0 3 * 6 0 4 8 
Target:  0 8 4 8 1 + 0 0 0 0 0 0 ( 0 8 4 8 1 0 ) + 0 0 0 2 3 2 1 ( 0 8 4 0 5 2 1 ) + 0 0 0 0 4 6 4 2  #### 0 8 4 0 9 8 5 2 
Predicted:  0 8 4 8 1 + 0 0 0 0 0 0 ( 0 8 4 8 1 0 ) + 0 0 0 2 3 2 1 ( 0 8 4 0 5 2 1 ) + 0 0 0 0 4 6 4 2  #### 0 8 4 0 9 8 5 2 



 22%|█████████▌                                 | 14/63 [00:03<00:10,  4.57it/s]

Input:  9 5 5 7 * 5 7 1 3 
Target:  5 9 7 7 3 + 0 3 1 9 2 5 ( 5 2 9 6 6 5 ) + 0 0 9 5 5 7 0 ( 5 2 8 2 2 3 1 ) + 0 0 0 7 7 6 2 2  #### 5 2 8 9 9 9 3 2 
Predicted:  5 9 7 7 3 + 0 3 1 9 2 5 ( 5 2 9 6 6 5 ) + 0 0 9 5 5 7 0 ( 5 2 8 2 2 3 1 ) + 0 0 0 7 7 6 2 2  #### 5 2 8 9 9 9 3 2 



 24%|██████████▏                                | 15/63 [00:03<00:10,  4.57it/s]

Input:  0 9 9 5 * 3 7 7 5 
Target:  0 7 9 7 1 + 0 0 3 9 1 4 ( 0 7 2 7 3 4 ) + 0 0 0 3 9 1 4 ( 0 7 2 0 3 6 4 ) + 0 0 0 0 5 9 9 2  #### 0 7 2 0 8 5 4 3 
Predicted:  0 7 9 7 1 + 0 0 3 9 1 4 ( 0 7 2 7 3 4 ) + 0 0 0 3 9 1 4 ( 0 7 2 0 3 6 4 ) + 0 0 0 0 5 9 9 2  #### 0 7 2 0 8 5 4 3 



 25%|██████████▉                                | 16/63 [00:03<00:10,  4.57it/s]

Input:  4 6 3 9 * 3 4 8 1 
Target:  2 9 0 8 2 + 0 6 5 4 7 3 ( 2 5 6 2 0 4 ) + 0 0 2 1 9 4 7 ( 2 5 8 3 9 8 7 ) + 0 0 0 4 6 3 9 0  #### 2 5 8 7 5 2 7 1 
Predicted:  2 9 0 8 2 + 0 6 5 4 7 3 ( 2 5 6 2 0 4 ) + 0 0 2 1 9 4 7 ( 2 5 8 3 9 8 7 ) + 0 0 0 4 6 3 9 0  #### 2 5 8 7 5 2 7 1 



 27%|███████████▌                               | 17/63 [00:03<00:10,  4.58it/s]

Input:  2 3 7 9 * 4 5 1 9 
Target:  8 2 9 8 3 + 0 0 6 6 8 4 ( 8 2 5 5 2 5 ) + 0 0 2 3 7 9 0 ( 8 2 7 8 9 4 1 ) + 0 0 0 8 8 5 7 8  #### 8 2 7 6 8 0 9 8 
Predicted:  8 2 9 8 3 + 0 0 6 6 8 4 ( 8 2 5 5 2 5 ) + 0 0 2 3 7 9 0 ( 8 2 7 8 9 4 1 ) + 0 0 0 8 8 5 7 8  #### 8 2 7 6 8 0 9 8 



 29%|████████████▎                              | 18/63 [00:03<00:09,  4.59it/s]

Input:  2 8 1 7 * 0 8 5 9 
Target:  0 0 0 0 0 + 0 6 5 4 7 5 ( 0 6 5 4 7 5 ) + 0 0 0 1 9 5 3 ( 0 6 5 5 6 1 4 ) + 0 0 0 8 3 6 4 6  #### 0 6 5 3 0 8 8 6 
Predicted:  0 0 0 0 0 + 0 6 5 4 7 5 ( 0 6 5 4 7 5 ) + 0 0 0 1 9 5 3 ( 0 6 5 5 6 1 4 ) + 0 0 0 8 3 6 4 6  #### 0 6 5 3 0 8 8 6 



 30%|████████████▉                              | 19/63 [00:04<00:09,  4.59it/s]

Input:  9 3 8 8 * 2 3 7 7 
Target:  8 7 6 7 1 + 0 7 1 5 6 2 ( 8 4 8 2 8 2 ) + 0 0 3 7 8 1 6 ( 8 4 1 0 7 4 6 ) + 0 0 0 3 7 8 1 6  #### 8 4 1 3 4 3 8 6 
Predicted:  8 7 6 7 1 + 0 7 1 5 6 2 ( 8 4 8 2 8 2 ) + 0 0 3 7 8 1 6 ( 8 4 1 0 7 4 6 ) + 0 0 0 3 7 8 1 6  #### 8 4 1 3 4 3 8 6 



 32%|█████████████▋                             | 20/63 [00:04<00:09,  4.59it/s]

Input:  7 7 6 4 * 3 7 1 2 
Target:  1 3 0 4 1 + 0 9 3 7 2 3 ( 1 2 4 1 4 3 ) + 0 0 7 7 6 4 0 ( 1 2 1 9 0 8 0 ) + 0 0 0 4 5 3 9 0  #### 1 2 1 3 6 1 0 1 
Predicted:  1 3 0 4 1 + 0 9 3 7 2 3 ( 1 2 4 1 4 3 ) + 0 0 7 7 6 4 0 ( 1 2 1 9 0 8 0 ) + 0 0 0 4 5 3 9 0  #### 1 2 1 3 6 1 0 1 



 33%|██████████████▎                            | 21/63 [00:04<00:09,  4.58it/s]

Input:  5 5 3 1 * 8 4 6 1 
Target:  0 4 8 0 1 + 0 0 2 4 5 0 ( 0 4 0 5 6 0 ) + 0 0 0 3 1 8 0 ( 0 4 0 8 7 8 0 ) + 0 0 0 5 5 3 1 0  #### 0 4 0 3 3 2 2 0 
Predicted:  0 4 8 0 1 + 0 0 2 4 5 0 ( 0 4 0 5 6 0 ) + 0 0 0 3 1 8 0 ( 0 4 0 8 7 8 0 ) + 0 0 0 5 5 3 1 0  #### 0 4 0 3 3 2 2 0 



 35%|███████████████                            | 22/63 [00:04<00:08,  4.59it/s]

Input:  8 7 8 4 * 9 3 5 4 
Target:  2 0 9 3 4 + 0 4 3 6 4 1 ( 2 4 2 0 9 1 ) + 0 0 0 9 3 4 2 ( 2 4 2 9 2 6 2 ) + 0 0 0 2 1 5 9 1  #### 2 4 2 1 4 1 2 2 
Predicted:  2 0 9 3 4 + 0 4 3 6 4 1 ( 2 4 2 0 9 1 ) + 0 0 0 9 3 4 2 ( 2 4 2 9 2 6 2 ) + 0 0 0 2 1 5 9 1  #### 2 4 2 1 4 1 2 2 



 37%|███████████████▋                           | 23/63 [00:05<00:08,  4.59it/s]

Input:  3 1 5 2 * 6 8 7 5 
Target:  8 7 0 5 1 + 0 4 0 1 0 2 ( 8 1 1 6 1 2 ) + 0 0 1 9 5 7 1 ( 8 1 2 5 7 9 1 ) + 0 0 0 5 6 5 2 1  #### 8 1 2 0 4 5 4 1 
Predicted:  8 7 0 5 1 + 0 4 0 1 0 2 ( 8 1 1 6 1 2 ) + 0 0 1 9 5 7 1 ( 8 1 2 5 7 9 1 ) + 0 0 0 5 6 5 2 1  #### 8 1 2 0 4 5 4 1 



 38%|████████████████▍                          | 24/63 [00:05<00:08,  4.58it/s]

Input:  8 3 9 7 * 9 9 4 2 
Target:  2 4 4 1 7 + 0 2 4 4 1 7 ( 2 6 8 5 8 7 ) + 0 0 2 5 7 1 3 ( 2 6 0 1 6 9 3 ) + 0 0 0 6 7 8 5 1  #### 2 6 0 7 3 8 9 1 
Predicted:  2 4 4 1 7 + 0 2 4 4 1 7 ( 2 6 8 5 8 7 ) + 0 0 2 5 7 1 3 ( 2 6 0 1 6 9 3 ) + 0 0 0 6 7 8 5 1  #### 2 6 0 7 3 8 9 1 



 40%|█████████████████                          | 25/63 [00:05<00:08,  4.57it/s]

Input:  0 7 7 2 * 0 4 8 8 
Target:  0 0 0 0 0 + 0 0 8 0 1 1 ( 0 0 8 0 1 1 ) + 0 0 0 6 1 2 2 ( 0 0 8 6 2 3 2 ) + 0 0 0 0 6 1 2 2  #### 0 0 8 6 8 4 4 2 
Predicted:  0 0 0 0 0 + 0 0 8 0 1 1 ( 0 0 8 0 1 1 ) + 0 0 0 6 1 2 2 ( 0 0 8 6 2 3 2 ) + 0 0 0 0 6 1 2 2  #### 0 0 8 6 8 4 4 2 



 41%|█████████████████▋                         | 26/63 [00:05<00:08,  4.58it/s]

Input:  3 6 3 4 * 4 4 7 2 
Target:  2 5 4 7 1 + 0 2 5 4 7 1 ( 2 7 9 1 9 1 ) + 0 0 1 4 5 0 3 ( 2 7 0 6 4 2 3 ) + 0 0 0 6 2 7 8 0  #### 2 7 0 2 7 9 1 1 
Predicted:  2 5 4 7 1 + 0 2 5 4 7 1 ( 2 7 9 1 9 1 ) + 0 0 1 4 5 0 3 ( 2 7 0 6 4 2 3 ) + 0 0 0 6 2 7 8 0  #### 2 7 0 2 7 9 1 1 



 43%|██████████████████▍                        | 27/63 [00:05<00:07,  4.57it/s]

Input:  5 2 0 2 * 7 3 9 3 
Target:  5 7 1 4 1 + 0 5 7 0 6 0 ( 5 2 9 4 7 0 ) + 0 0 5 2 2 8 1 ( 5 2 4 7 9 8 1 ) + 0 0 0 5 7 0 6 0  #### 5 2 4 2 7 9 7 0 
Predicted:  5 7 1 4 1 + 0 5 7 0 6 0 ( 5 2 9 4 7 0 ) + 0 0 5 2 2 8 1 ( 5 2 4 7 9 8 1 ) + 0 0 0 5 7 0 6 0  #### 5 2 4 2 7 9 7 0 



 44%|███████████████████                        | 28/63 [00:06<00:07,  4.58it/s]

Input:  4 1 9 6 * 0 6 9 2 
Target:  0 0 0 0 0 + 0 4 8 4 1 4 ( 0 4 8 4 1 4 ) + 0 0 6 2 2 2 6 ( 0 4 4 7 3 6 6 ) + 0 0 0 8 2 8 3 1  #### 0 4 4 5 6 4 0 2 
Predicted:  0 0 0 0 0 + 0 4 8 4 1 4 ( 0 4 8 4 1 4 ) + 0 0 6 2 2 2 6 ( 0 4 4 7 3 6 6 ) + 0 0 0 8 2 8 3 1  #### 0 4 4 5 6 4 0 2 



 46%|███████████████████▊                       | 29/63 [00:06<00:07,  4.57it/s]

Input:  2 0 5 9 * 6 1 0 4 
Target:  2 1 0 7 5 + 0 2 0 5 9 0 ( 2 3 0 2 5 1 ) + 0 0 0 0 0 0 0 ( 2 3 0 2 5 1 0 ) + 0 0 0 8 0 0 8 3  #### 2 3 0 0 6 1 8 3 
Predicted:  2 1 0 7 5 + 0 2 0 5 9 0 ( 2 3 0 2 5 1 ) + 0 0 0 0 0 0 0 ( 2 3 0 2 5 1 0 ) + 0 0 0 8 0 0 8 3  #### 2 3 0 0 6 1 8 3 



 48%|████████████████████▍                      | 30/63 [00:06<00:07,  4.57it/s]

Input:  9 9 2 6 * 7 5 2 9 
Target:  3 9 0 4 4 + 0 5 9 4 1 3 ( 3 4 0 9 5 3 ) + 0 0 8 9 5 2 1 ( 3 4 8 8 1 6 1 ) + 0 0 0 1 9 6 6 5  #### 3 4 8 9 0 3 8 5 
Predicted:  3 9 0 4 4 + 0 5 9 4 1 3 ( 3 4 0 9 5 3 ) + 0 0 8 9 5 2 1 ( 3 4 8 8 1 6 1 ) + 0 0 0 1 9 6 6 5  #### 3 4 8 9 0 3 8 5 



 49%|█████████████████████▏                     | 31/63 [00:06<00:07,  4.56it/s]

Input:  0 8 6 1 * 0 9 5 6 
Target:  0 0 0 0 0 + 0 0 2 1 5 1 ( 0 0 2 1 5 1 ) + 0 0 0 0 4 8 0 ( 0 0 2 1 9 9 0 ) + 0 0 0 0 8 0 0 1  #### 0 0 2 1 7 0 1 1 
Predicted:  0 0 0 0 0 + 0 0 2 1 5 1 ( 0 0 2 1 5 1 ) + 0 0 0 0 4 8 0 ( 0 0 2 1 9 9 0 ) + 0 0 0 0 8 0 0 1  #### 0 0 2 1 7 0 1 1 



 51%|█████████████████████▊                     | 32/63 [00:07<00:06,  4.55it/s]

Input:  1 9 7 4 * 3 6 2 9 
Target:  3 7 3 4 1 + 0 6 4 7 8 2 ( 3 3 8 1 0 3 ) + 0 0 2 8 5 9 0 ( 3 3 0 0 6 2 1 ) + 0 0 0 9 1 1 3 4  #### 3 3 0 9 7 3 4 4 
Predicted:  3 7 3 4 1 + 0 6 4 7 8 2 ( 3 3 8 1 0 3 ) + 0 0 2 8 5 9 0 ( 3 3 0 0 6 2 1 ) + 0 0 0 9 1 1 3 4  #### 3 3 0 9 7 3 4 4 



 52%|██████████████████████▌                    | 33/63 [00:07<00:06,  4.55it/s]

Input:  8 5 4 1 * 7 8 3 1 
Target:  6 0 2 0 1 + 0 4 6 6 1 1 ( 6 4 8 6 2 1 ) + 0 0 4 7 3 4 0 ( 6 4 2 4 6 5 0 ) + 0 0 0 8 5 4 1 0  #### 6 4 2 2 2 0 2 0 
Predicted:  6 0 2 0 1 + 0 4 6 6 1 1 ( 6 4 8 6 2 1 ) + 0 0 4 7 3 4 0 ( 6 4 2 4 6 5 0 ) + 0 0 0 8 5 4 1 0  #### 6 4 2 2 2 0 2 0 



 54%|███████████████████████▏                   | 34/63 [00:07<00:06,  4.56it/s]

Input:  4 9 4 2 * 0 7 7 7 
Target:  0 0 0 0 0 + 0 8 5 4 7 1 ( 0 8 5 4 7 1 ) + 0 0 8 5 4 7 1 ( 0 8 3 0 2 9 1 ) + 0 0 0 8 5 4 7 1  #### 0 8 3 8 7 3 9 1 
Predicted:  0 0 0 0 0 + 0 8 5 4 7 1 ( 0 8 5 4 7 1 ) + 0 0 8 5 4 7 1 ( 0 8 3 0 2 9 1 ) + 0 0 0 8 5 4 7 1  #### 0 8 3 8 7 3 9 1 



 56%|███████████████████████▉                   | 35/63 [00:07<00:06,  4.57it/s]

Input:  3 4 2 7 * 4 8 7 2 
Target:  2 7 9 8 2 + 0 4 4 9 7 5 ( 2 1 4 8 0 6 ) + 0 0 1 0 7 0 5 ( 2 1 5 8 7 6 5 ) + 0 0 0 6 8 4 4 1  #### 2 1 5 4 6 1 0 2 
Predicted:  2 7 9 8 2 + 0 4 4 9 7 5 ( 2 1 4 8 0 6 ) + 0 0 1 0 7 0 5 ( 2 1 5 8 7 6 5 ) + 0 0 0 6 8 4 4 1  #### 2 1 5 4 6 1 0 2 



 57%|████████████████████████▌                  | 36/63 [00:07<00:05,  4.58it/s]

Input:  7 9 2 3 * 3 0 1 6 
Target:  1 9 8 9 0 + 0 0 0 0 0 0 ( 1 9 8 9 0 0 ) + 0 0 7 9 2 3 0 ( 1 9 5 9 3 3 0 ) + 0 0 0 2 8 7 9 1  #### 1 9 5 1 2 1 0 2 
Predicted:  1 9 8 9 0 + 0 0 0 0 0 0 ( 1 9 8 9 0 0 ) + 0 0 7 9 2 3 0 ( 1 9 5 9 3 3 0 ) + 0 0 0 2 8 7 9 1  #### 1 9 5 1 2 1 0 2 



 59%|█████████████████████████▎                 | 37/63 [00:08<00:05,  4.57it/s]

Input:  1 5 4 3 * 2 2 0 8 
Target:  2 0 9 6 0 + 0 2 0 9 6 0 ( 2 2 9 5 7 0 ) + 0 0 0 0 0 0 0 ( 2 2 9 5 7 0 0 ) + 0 0 0 8 0 6 7 2  #### 2 2 9 3 8 6 7 2 
Predicted:  2 0 9 6 0 + 0 2 0 9 6 0 ( 2 2 9 5 7 0 ) + 0 0 0 0 0 0 0 ( 2 2 9 5 7 0 0 ) + 0 0 0 8 0 6 7 2  #### 2 2 9 3 8 6 7 2 



 60%|█████████████████████████▉                 | 38/63 [00:08<00:05,  4.57it/s]

Input:  8 8 1 2 * 4 8 4 8 
Target:  2 5 7 8 0 + 0 4 0 5 7 1 ( 2 9 7 3 8 1 ) + 0 0 2 5 7 8 0 ( 2 9 9 8 5 0 1 ) + 0 0 0 4 0 5 7 1  #### 2 9 9 2 6 5 8 1 
Predicted:  2 5 7 8 0 + 0 4 0 5 7 1 ( 2 9 7 3 8 1 ) + 0 0 2 5 7 8 0 ( 2 9 9 8 5 0 1 ) + 0 0 0 4 0 5 7 1  #### 2 9 9 2 6 5 8 1 



 62%|██████████████████████████▌                | 39/63 [00:08<00:05,  4.59it/s]

Input:  9 9 7 3 * 5 2 1 1 
Target:  5 9 9 8 1 + 0 8 9 5 7 0 ( 5 7 9 4 9 0 ) + 0 0 9 9 7 3 0 ( 5 7 8 4 7 4 0 ) + 0 0 0 9 9 7 3 0  #### 5 7 8 3 7 2 4 0 
Predicted:  5 9 9 8 1 + 0 8 9 5 7 0 ( 5 7 9 4 9 0 ) + 0 0 9 9 7 3 0 ( 5 7 8 4 7 4 0 ) + 0 0 0 9 9 7 3 0  #### 5 7 8 3 7 2 4 0 



 63%|███████████████████████████▎               | 40/63 [00:08<00:05,  4.58it/s]

Input:  1 5 4 4 * 5 7 3 1 
Target:  5 5 2 2 2 + 0 7 5 1 1 3 ( 5 2 8 3 3 3 ) + 0 0 3 5 3 3 1 ( 5 2 1 9 6 6 1 ) + 0 0 0 1 5 4 4 0  #### 5 2 1 0 2 1 6 0 
Predicted:  5 5 2 2 2 + 0 7 5 1 1 3 ( 5 2 8 3 3 3 ) + 0 0 3 5 3 3 1 ( 5 2 1 9 6 6 1 ) + 0 0 0 1 5 4 4 0  #### 5 2 1 0 2 1 6 0 



 65%|███████████████████████████▉               | 41/63 [00:08<00:04,  4.57it/s]

Input:  7 1 1 8 * 3 3 2 3 
Target:  1 5 3 4 2 + 0 1 5 3 4 2 ( 1 6 8 7 6 2 ) + 0 0 4 3 2 6 1 ( 1 6 2 1 9 8 1 ) + 0 0 0 1 5 3 4 2  #### 1 6 2 2 4 2 6 2 
Predicted:  1 5 3 4 2 + 0 1 5 3 4 2 ( 1 6 8 7 6 2 ) + 0 0 4 3 2 6 1 ( 1 6 2 1 9 8 1 ) + 0 0 0 1 5 3 4 2  #### 1 6 2 2 4 2 6 2 



 67%|████████████████████████████▋              | 42/63 [00:09<00:04,  4.59it/s]

Input:  5 4 7 8 * 0 2 2 1 
Target:  0 0 0 0 0 + 0 0 9 4 7 1 ( 0 0 9 4 7 1 ) + 0 0 0 9 4 7 1 ( 0 0 9 3 2 9 1 ) + 0 0 0 5 4 7 8 0  #### 0 0 9 8 6 6 0 1 
Predicted:  0 0 0 0 0 + 0 0 9 4 7 1 ( 0 0 9 4 7 1 ) + 0 0 0 9 4 7 1 ( 0 0 9 3 2 9 1 ) + 0 0 0 5 4 7 8 0  #### 0 0 9 8 6 6 0 1 



 68%|█████████████████████████████▎             | 43/63 [00:09<00:04,  4.60it/s]

Input:  6 6 3 1 * 4 7 8 3 
Target:  4 6 4 5 0 + 0 2 6 5 9 0 ( 4 8 0 1 0 1 ) + 0 0 8 2 9 0 1 ( 4 8 8 3 9 1 1 ) + 0 0 0 8 9 0 4 0  #### 4 8 8 1 9 2 5 0 
Predicted:  4 6 4 5 0 + 0 2 6 5 9 0 ( 4 8 0 1 0 1 ) + 0 0 8 2 9 0 1 ( 4 8 8 3 9 1 1 ) + 0 0 0 8 9 0 4 0  #### 4 8 8 1 9 2 5 0 



 70%|██████████████████████████████             | 44/63 [00:09<00:04,  4.61it/s]

Input:  4 4 5 1 * 1 0 9 1 
Target:  4 4 5 1 0 + 0 0 0 0 0 0 ( 4 4 5 1 0 0 ) + 0 0 6 9 8 3 1 ( 4 4 1 1 9 3 1 ) + 0 0 0 4 4 5 1 0  #### 4 4 1 5 3 9 2 0 
Predicted:  4 4 5 1 0 + 0 0 0 0 0 0 ( 4 4 5 1 0 0 ) + 0 0 6 9 8 3 1 ( 4 4 1 1 9 3 1 ) + 0 0 0 4 4 5 1 0  #### 4 4 1 5 3 9 2 0 



 71%|██████████████████████████████▋            | 45/63 [00:09<00:03,  4.59it/s]

Input:  5 4 6 2 * 0 4 7 7 
Target:  0 0 0 0 0 + 0 0 8 5 0 1 ( 0 0 8 5 0 1 ) + 0 0 5 1 5 8 1 ( 0 0 3 7 5 9 1 ) + 0 0 0 5 1 5 8 1  #### 0 0 3 2 7 4 0 2 
Predicted:  0 0 0 0 0 + 0 0 8 5 0 1 ( 0 0 8 5 0 1 ) + 0 0 5 1 5 8 1 ( 0 0 3 7 5 9 1 ) + 0 0 0 5 1 5 8 1  #### 0 0 3 2 7 4 0 2 



 73%|███████████████████████████████▍           | 46/63 [00:10<00:03,  4.57it/s]

Input:  0 9 1 8 * 8 0 2 9 
Target:  0 2 5 5 6 + 0 0 0 0 0 0 ( 0 2 5 5 6 0 ) + 0 0 0 8 3 6 1 ( 0 2 5 3 0 7 1 ) + 0 0 0 0 1 7 3 7  #### 0 2 5 3 1 4 5 7 
Predicted:  0 2 5 5 6 + 0 0 0 0 0 0 ( 0 2 5 5 6 0 ) + 0 0 0 8 3 6 1 ( 0 2 5 3 0 7 1 ) + 0 0 0 0 1 7 3 7  #### 0 2 5 3 1 4 5 7 



 75%|████████████████████████████████           | 47/63 [00:10<00:03,  4.56it/s]

Input:  0 9 6 2 * 7 0 6 8 
Target:  0 3 8 8 1 + 0 0 0 0 0 0 ( 0 3 8 8 1 0 ) + 0 0 0 4 1 6 1 ( 0 3 8 2 3 6 1 ) + 0 0 0 0 2 5 1 2  #### 0 3 8 2 5 1 3 2 
Predicted:  0 3 8 8 1 + 0 0 0 0 0 0 ( 0 3 8 8 1 0 ) + 0 0 0 4 1 6 1 ( 0 3 8 2 3 6 1 ) + 0 0 0 0 2 5 1 2  #### 0 3 8 2 5 1 3 2 



 76%|████████████████████████████████▊          | 48/63 [00:10<00:03,  4.55it/s]

Input:  2 3 9 7 * 6 2 3 8 
Target:  2 9 5 7 4 + 0 4 6 8 5 1 ( 2 3 2 6 0 2 ) + 0 0 6 9 7 3 2 ( 2 3 8 5 8 5 2 ) + 0 0 0 6 5 4 3 6  #### 2 3 8 1 4 0 6 6 
Predicted:  2 9 5 7 4 + 0 4 6 8 5 1 ( 2 3 2 6 0 2 ) + 0 0 6 9 7 3 2 ( 2 3 8 5 8 5 2 ) + 0 0 0 6 5 4 3 6  #### 2 3 8 1 4 0 6 6 



 78%|█████████████████████████████████▍         | 49/63 [00:10<00:03,  4.56it/s]

Input:  6 7 1 4 * 0 9 1 5 
Target:  0 0 0 0 0 + 0 4 8 5 7 3 ( 0 4 8 5 7 3 ) + 0 0 6 7 1 4 0 ( 0 4 4 3 9 7 0 ) + 0 0 0 0 8 8 0 2  #### 0 4 4 3 7 6 1 2 
Predicted:  0 0 0 0 0 + 0 4 8 5 7 3 ( 0 4 8 5 7 3 ) + 0 0 6 7 1 4 0 ( 0 4 4 3 9 7 0 ) + 0 0 0 0 8 8 0 2  #### 0 4 4 3 7 6 1 2 



 79%|██████████████████████████████████▏        | 50/63 [00:10<00:02,  4.57it/s]

Input:  2 9 9 6 * 5 8 1 4 
Target:  0 6 9 4 3 + 0 6 3 9 5 5 ( 0 2 3 4 9 5 ) + 0 0 2 9 9 6 0 ( 0 2 5 3 9 2 1 ) + 0 0 0 8 6 9 7 2  #### 0 2 5 1 6 2 9 2 
Predicted:  0 6 9 4 3 + 0 6 3 9 5 5 ( 0 2 3 4 9 5 ) + 0 0 2 9 9 6 0 ( 0 2 5 3 9 2 1 ) + 0 0 0 8 6 9 7 2  #### 0 2 5 1 6 2 9 2 



 81%|██████████████████████████████████▊        | 51/63 [00:11<00:02,  4.56it/s]

Input:  7 3 0 1 * 1 0 1 4 
Target:  7 3 0 1 0 + 0 0 0 0 0 0 ( 7 3 0 1 0 0 ) + 0 0 7 3 0 1 0 ( 7 3 7 4 0 1 0 ) + 0 0 0 8 4 1 4 0  #### 7 3 7 2 5 2 4 0 
Predicted:  7 3 0 1 0 + 0 0 0 0 0 0 ( 7 3 0 1 0 0 ) + 0 0 7 3 0 1 0 ( 7 3 7 4 0 1 0 ) + 0 0 0 8 4 1 4 0  #### 7 3 7 2 5 2 4 0 



 83%|███████████████████████████████████▍       | 52/63 [00:11<00:02,  4.56it/s]

Input:  5 9 2 4 * 4 4 0 6 
Target:  0 8 1 7 1 + 0 0 8 1 7 1 ( 0 8 9 8 8 1 ) + 0 0 0 0 0 0 0 ( 0 8 9 8 8 1 0 ) + 0 0 0 0 7 7 5 2  #### 0 8 9 8 5 9 5 2 
Predicted:  0 8 1 7 1 + 0 0 8 1 7 1 ( 0 8 9 8 8 1 ) + 0 0 0 0 0 0 0 ( 0 8 9 8 8 1 0 ) + 0 0 0 0 7 7 5 2  #### 0 8 9 8 5 9 5 2 



 84%|████████████████████████████████████▏      | 53/63 [00:11<00:02,  4.57it/s]

Input:  8 6 9 1 * 4 5 3 3 
Target:  2 7 8 7 0 + 0 0 4 8 9 0 ( 2 7 2 6 0 1 ) + 0 0 4 0 9 5 0 ( 2 7 6 6 9 6 0 ) + 0 0 0 4 0 9 5 0  #### 2 7 6 0 0 6 6 0 
Predicted:  2 7 8 7 0 + 0 0 4 8 9 0 ( 2 7 2 6 0 1 ) + 0 0 4 0 9 5 0 ( 2 7 6 6 9 6 0 ) + 0 0 0 4 0 9 5 0  #### 2 7 6 0 0 6 6 0 



 86%|████████████████████████████████████▊      | 54/63 [00:11<00:01,  4.57it/s]

Input:  6 4 2 7 * 4 4 1 1 
Target:  4 8 9 8 2 + 0 4 8 9 8 2 ( 4 2 8 8 1 3 ) + 0 0 6 4 2 7 0 ( 4 2 4 3 4 0 1 ) + 0 0 0 6 4 2 7 0  #### 4 2 4 9 8 2 8 0 
Predicted:  4 8 9 8 2 + 0 4 8 9 8 2 ( 4 2 8 8 1 3 ) + 0 0 6 4 2 7 0 ( 4 2 4 3 4 0 1 ) + 0 0 0 6 4 2 7 0  #### 4 2 4 9 8 2 8 0 



 87%|█████████████████████████████████████▌     | 55/63 [00:12<00:01,  4.55it/s]

Input:  2 4 6 3 * 4 6 3 6 
Target:  8 6 5 4 1 + 0 2 5 8 1 2 ( 8 8 0 3 3 2 ) + 0 0 6 2 9 0 1 ( 8 8 6 5 2 3 1 ) + 0 0 0 2 5 8 1 2  #### 8 8 6 7 7 1 3 2 
Predicted:  8 6 5 4 1 + 0 2 5 8 1 2 ( 8 8 0 3 3 2 ) + 0 0 6 2 9 0 1 ( 8 8 6 5 2 3 1 ) + 0 0 0 2 5 8 1 2  #### 8 8 6 7 7 1 3 2 



 89%|██████████████████████████████████████▏    | 56/63 [00:12<00:01,  4.55it/s]

Input:  7 4 9 4 * 5 3 4 7 
Target:  5 3 7 4 2 + 0 1 4 8 4 1 ( 5 4 1 3 7 1 ) + 0 0 8 8 7 9 1 ( 5 4 9 1 5 1 2 ) + 0 0 0 9 2 6 4 3  #### 5 4 9 0 8 7 6 3 
Predicted:  5 3 7 4 2 + 0 1 4 8 4 1 ( 5 4 1 3 7 1 ) + 0 0 8 8 7 9 1 ( 5 4 9 1 5 1 2 ) + 0 0 0 9 2 6 4 3  #### 5 4 9 0 8 7 6 3 



 90%|██████████████████████████████████████▉    | 57/63 [00:12<00:01,  4.55it/s]

Input:  4 3 2 6 * 8 9 2 4 
Target:  2 7 8 9 4 + 0 6 0 1 6 5 ( 2 3 9 0 1 6 ) + 0 0 8 6 4 2 1 ( 2 3 7 7 5 8 1 ) + 0 0 0 6 3 9 4 2  #### 2 3 7 3 9 7 6 2 
Predicted:  2 7 8 9 4 + 0 6 0 1 6 5 ( 2 3 9 0 1 6 ) + 0 0 8 6 4 2 1 ( 2 3 7 7 5 8 1 ) + 0 0 0 6 3 9 4 2  #### 2 3 7 3 9 7 6 2 



 92%|███████████████████████████████████████▌   | 58/63 [00:12<00:01,  4.56it/s]

Input:  4 6 1 4 * 7 1 6 4 
Target:  8 4 1 9 2 + 0 4 6 1 4 0 ( 8 8 7 0 7 0 ) + 0 0 4 8 9 4 2 ( 8 8 1 9 6 5 2 ) + 0 0 0 6 5 6 6 1  #### 8 8 1 5 2 2 9 1 
Predicted:  8 4 1 9 2 + 0 4 6 1 4 0 ( 8 8 7 0 7 0 ) + 0 0 4 8 9 4 2 ( 8 8 1 9 6 5 2 ) + 0 0 0 6 5 6 6 1  #### 8 8 1 5 2 2 9 1 



 94%|████████████████████████████████████████▎  | 59/63 [00:12<00:00,  4.57it/s]

Input:  8 8 4 7 * 1 9 2 3 
Target:  8 8 4 7 0 + 0 2 9 3 7 6 ( 8 0 4 1 8 6 ) + 0 0 6 7 9 4 1 ( 8 0 0 9 7 1 2 ) + 0 0 0 4 6 4 2 2  #### 8 0 0 3 4 6 4 2 
Predicted:  8 8 4 7 0 + 0 2 9 3 7 6 ( 8 0 4 1 8 6 ) + 0 0 6 7 9 4 1 ( 8 0 0 9 7 1 2 ) + 0 0 0 4 6 4 2 2  #### 8 0 0 3 4 6 4 2 



 95%|████████████████████████████████████████▉  | 60/63 [00:13<00:00,  4.58it/s]

Input:  0 4 8 4 * 8 0 4 5 
Target:  0 2 7 8 3 + 0 0 0 0 0 0 ( 0 2 7 8 3 0 ) + 0 0 0 6 3 9 1 ( 0 2 7 4 7 9 1 ) + 0 0 0 0 0 2 4 2  #### 0 2 7 4 7 1 6 2 
Predicted:  0 2 7 8 3 + 0 0 0 0 0 0 ( 0 2 7 8 3 0 ) + 0 0 0 6 3 9 1 ( 0 2 7 4 7 9 1 ) + 0 0 0 0 0 2 4 2  #### 0 2 7 4 7 1 6 2 



 97%|█████████████████████████████████████████▋ | 61/63 [00:13<00:00,  4.58it/s]

Input:  0 8 3 6 * 0 0 3 1 
Target:  0 0 0 0 0 + 0 0 0 0 0 0 ( 0 0 0 0 0 0 ) + 0 0 0 4 1 9 1 ( 0 0 0 4 1 9 1 ) + 0 0 0 0 8 3 6 0  #### 0 0 0 4 9 2 8 0 
Predicted:  0 0 0 0 0 + 0 0 0 0 0 0 ( 0 0 0 0 0 0 ) + 0 0 0 4 1 9 1 ( 0 0 0 4 1 9 1 ) + 0 0 0 0 8 3 6 0  #### 0 0 0 4 9 2 8 0 



 98%|██████████████████████████████████████████▎| 62/63 [00:13<00:00,  4.59it/s]

Input:  7 2 3 3 * 6 2 9 6 
Target:  2 6 9 9 1 + 0 4 5 6 6 0 ( 2 0 5 6 8 0 ) + 0 0 3 4 9 9 2 ( 2 0 8 0 8 0 3 ) + 0 0 0 2 6 9 9 1  #### 2 0 8 2 4 0 3 2 
Predicted:  2 6 9 9 1 + 0 4 5 6 6 0 ( 2 0 5 6 8 0 ) + 0 0 3 4 9 9 2 ( 2 0 8 0 8 0 3 ) + 0 0 0 2 6 9 9 1  #### 2 0 8 2 4 0 3 2 



100%|███████████████████████████████████████████| 63/63 [00:13<00:00,  4.57it/s]

Input:  5 7 5 2 * 2 2 7 3 
Target:  0 5 1 5 0 + 0 0 5 1 5 0 ( 0 5 6 6 5 0 ) + 0 0 5 2 0 8 1 ( 0 5 1 9 5 8 1 ) + 0 0 0 5 2 7 7 0  #### 0 5 1 4 8 5 9 0 
Predicted:  0 5 1 5 0 + 0 0 5 1 5 0 ( 0 5 6 6 5 0 ) + 0 0 5 2 0 8 1 ( 0 5 1 9 5 8 1 ) + 0 0 0 5 2 7 7 0  #### 0 5 1 4 8 5 9 0 

Val. PPL: 1.000001479784104; Accuracy: 1.0; Token Accuracy: 1.0.
Saving to ../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0





In [18]:
print(outputs.attentions.shape)

AttributeError: 'NoneType' object has no attribute 'shape'

In [5]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import sys
import inspect
import tqdm
import logging
import random

from data import CoTDataset, CoTDataCollator, extract_answer
from models.teacher import Teacher
from models.student import Student
from models.configuration_student import StudentConfig
from utils import get_sep_position

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

random.seed(1234)
torch.manual_seed(1234)
logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, student, delta, subset, max_new_tokens):
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        batch_size = input_ids_nocot.shape[0]
        with ctx:
            teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=delta, subset=subset)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states)
            loss = outputs.loss
            token_accuracy = outputs.token_accuracy.item()
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate
        with ctx:
            beam_output = student.generate(
                input_ids=input_ids_nocot,
                teacher_states=teacher_states,
                max_new_tokens=max_new_tokens,
            )

        # Evaluate
        sep_positions = get_sep_position(input_ids_all, tokenizer.eos_token_id)
        for i, (input_ids_all_i, beam_output_i) in enumerate(zip(input_ids_all, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_all_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_all_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl


    

In [35]:
mind_reading_student_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    teacher="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    save_model="../train_models/4_by_4_mult/gpt2/student_initial",
    base_model='gpt2',
    epochs=1,
    batch_size=16,
    lr=5e-5,
    max_new_tokens=128,
    delta='dynamic',
    max_grad_norm=1.0,
    subset='diagonal',
    
)

from types import SimpleNamespace

args = SimpleNamespace(**mind_reading_student_trainer_args)

In [36]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--teacher', type=str, required=True)
# parser.add_argument('--delta', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal')
# args = parser.parse_args()

# print (args)

dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Student
config = StudentConfig(base_model=args.base_model)
student = Student(config).to(device).to(ptdtype)

# Load Teacher
teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype)

# Load data
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer
trainable_params = list(student.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

teacher.eval()
student.eval() # to turn off dropout

for p in teacher.parameters():
    p.requires_grad = False




torch.float32 float32 cuda
Creating features from dataset file at ../data/4_by_4_mult/train.txt
tgt_avg:  49.0
src_avg:  10.0
ratios:  0.20408163265306123
tgt_avg:  13.0
src_avg:  10.0
ratios:  0.7692307692307693
 1 3 3 8 * 5 1 0 5 <|endoftext|> 5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4 <|endoftext|> #### 5 6 9 9 7 7 1 4 <|endoftext|>
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 642, 642, 718, 352, 604, 1343, 657, 352, 513, 513, 807, 657, 357, 642, 718, 860, 604, 362, 352, 1267, 1343, 657, 657, 657, 657, 657, 657, 657, 357, 642, 718, 860, 604, 362, 352, 657, 1267, 1343, 657, 657, 657, 642, 642, 718, 352, 604, 220, 50256]
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1303, 21017, 642, 718, 860, 860, 767, 767, 352, 604, 220, 50256]
[352, 513, 513, 807, 1635, 642, 352, 657, 642, 220, 50256, 1303, 21017, 642, 718, 860, 860, 767, 767, 352, 604, 220, 50256]
 1 3 3 8 * 5 1 0 5 <|endoftext|> #### 5 6 9

In [37]:
teacher

Teacher(
  (base_model): GPT2LMHeadImplicitModel(
    (transformer): GPT2ImplicitModel(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): GPT2MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (act): NewGELUActivation()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): GPT2Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): GPT2Attention(
            (c_attn): Conv1D()
            (c_proj): Co

In [40]:
teacher.base_model

GPT2LMHeadImplicitModel(
  (transformer): GPT2ImplicitModel(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (res

In [75]:
def compute_positions_to_extract_per_layer(subset, delta, first_sep_positions, second_sep_positions):
    batch_size = first_sep_positions.shape[0]
    positions_to_extract_per_layer = first_sep_positions.new_zeros(batch_size, teacher.num_layers).long()
    layer_ids = torch.arange(start=0, end=teacher.num_layers).to(first_sep_positions.device)
    for batch_id in range(batch_size):
        first_position_to_extract = first_sep_positions[batch_id]
        last_position_to_extract = second_sep_positions[batch_id]
        if subset == 'diagonal':
            if delta == 'dynamic': # determine actual delta
                delta = (last_position_to_extract - first_position_to_extract) / (teacher.num_layers - 1)
        elif subset == 'first_column' or subset == 'last_column':
            delta = 0
        else:
            assert subset == 'last_column', subset
            delta = 0
            first_position_to_extract = last_position_to_extract
        positions_to_extract = torch.round(first_position_to_extract + layer_ids * delta)
        positions_to_extract = positions_to_extract.clamp(max=last_position_to_extract)
        positions_to_extract_per_layer[batch_id] = positions_to_extract
    return positions_to_extract_per_layer

In [97]:
from utils import get_sep_position

for batch in tqdm.tqdm(val_dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        first_sep_positions = get_sep_position(input_ids_all, teacher.tokenizer.eos_token_id, skip=0)
        second_sep_positions = get_sep_position(input_ids_all, teacher.tokenizer.eos_token_id, skip=1)
        print(compute_positions_to_extract_per_layer('diagonal', 'dynamic', first_sep_positions, second_sep_positions)[:,2].view(-1, 1, 1).expand(-1, -1, teacher.hidden_size).shape)
        input_ids_all = input_ids_all[:, :second_sep_positions.max()+1]
        print(input_ids_all.shape)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        with ctx:
            with torch.no_grad():
                teacher_states = teacher.base_model(input_ids=input_ids_all, output_hidden_states=True)
                hidden_states = teacher_states.hidden_states[:-1]
                for i, hidden_state in enumerate(hidden_states):
                    print(hidden_state.shape) # torch.Size([16, 59, 768])
                    z = hidden_state.gather(1, compute_positions_to_extract_per_layer('diagonal', 'dynamic', first_sep_positions, second_sep_positions)[:,i].view(-1, 1, 1).expand(-1, -1, teacher.hidden_size)).squeeze(1)
                    print(z.shape) # torch.Size([16, 768])
                    
        break

  0%|                                                    | 0/63 [00:00<?, ?it/s]

torch.Size([16, 1, 768])
torch.Size([16, 59])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])
torch.Size([16, 59, 768])
torch.Size([16, 768])





In [77]:
teacher_states.hidden_states[:-1][0].shape

torch.Size([16, 59, 768])

In [None]:
# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        input_ids_all = batch['input_ids_all'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        with ctx:
            with torch.no_grad():
                teacher_states = teacher.extract_states(input_ids=input_ids_all, delta=args.delta, subset=args.subset)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=teacher_states)
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
            sys.stdout.flush()
        step += 1
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, teacher, student, args.delta, args.subset, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    student.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

In [6]:
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import inspect
import tqdm
import logging
import random
import torch.nn as nn

from data import CoTDataset, CoTDataCollator
from models.teacher import Teacher
from models.emulator import Emulator
from models.configuration_emulator import EmulatorConfig


torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
random.seed(1234)
torch.manual_seed(1234)
logging.disable(logging.WARNING)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, teacher, emulator, delta, subset):
    total_instances = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        #import pdb; pdb.set_trace()
        input_ids_cot = batch['input_ids_cot'].to(device)
        batch_size = input_ids_cot.shape[0]
        with ctx:
            teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=delta, subset=subset)
            outputs = emulator.compute_loss(input_ids=input_ids_cot, teacher_states=teacher_states)
            loss = outputs.loss
        total_loss += outputs.total_loss.item()
        total_instances += batch_size

    loss = total_loss / total_instances
    return loss


    

In [7]:
thought_emulator_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    teacher="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    save_model="../train_models/4_by_4_mult/gpt2/emulator_initial",
    base_model='gpt2',
    epochs=1,
    batch_size=16,
    lr=5e-5,
    max_new_tokens=128,
    delta='dynamic',
    max_grad_norm=1.0,
    subset='diagonal',
    mixture_size=1,
    
)

from types import SimpleNamespace

args = SimpleNamespace(**thought_emulator_trainer_args)

In [8]:

# parser = argparse.ArgumentParser()
# parser.add_argument('--teacher', type=str, required=True)
# parser.add_argument('--delta', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--base_model', type=str, default='gpt2')
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--subset', type=str, choices=['diagonal', 'last_column', 'top_row', 'bottom_row', 'first_column'], default='diagonal')
# parser.add_argument('--mixture_size', type=int, default=1)
# args = parser.parse_args()

# print (args)
dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
print (ptdtype, dtype, device)

# Create Emulator
config = EmulatorConfig(base_model=args.base_model, mixture_size=args.mixture_size)
emulator = Emulator(config).to(device).to(ptdtype)

# Load Teacher
teacher = Teacher.from_pretrained(args.teacher).to(device).to(ptdtype)

# Load data
tokenizer = teacher.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer
trainable_params = list(emulator.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

teacher.eval()
emulator.eval() # to turn off dropout

for p in teacher.parameters():
    p.requires_grad = False

# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        #import pdb; pdb.set_trace()
        input_ids_cot = batch['input_ids_cot'].to(device)
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        with ctx:
            with torch.no_grad():
                teacher_states = teacher.extract_states(input_ids=input_ids_cot, delta=args.delta, subset=args.subset)
            outputs = emulator.compute_loss(input_ids=input_ids_nocot, teacher_states=teacher_states)
        loss = outputs.loss

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        if step % 100 == 0:
            print (f"Step: {step}. Loss: {loss}.")
        step += 1
    loss = evaluate(val_dataloader, tokenizer, ctx, teacher, emulator, args.delta, args.subset)
    print (f'Val. Loss: {loss}.')
    emulator.save_pretrained(os.path.join(args.save_model, f'checkpoint_{epoch}'))

torch.float32 float32 cuda
Creating features from dataset file at ../data/4_by_4_mult/train.txt
tgt_avg:  49.0
src_avg:  10.0
ratios:  0.20408163265306123
tgt_avg:  13.0
src_avg:  10.0
ratios:  0.7692307692307693
 1 3 3 8 * 5 1 0 5 <|endoftext|> 5 5 6 1 4 + 0 1 3 3 8 0 ( 5 6 9 4 2 1 ) + 0 0 0 0 0 0 0 ( 5 6 9 4 2 1 0 ) + 0 0 0 5 5 6 1 4 <|endoftext|> #### 5 6 9 9 7 7 1 4 <|endoftext|>
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 642, 642, 718, 352, 604, 1343, 657, 352, 513, 513, 807, 657, 357, 642, 718, 860, 604, 362, 352, 1267, 1343, 657, 657, 657, 657, 657, 657, 657, 357, 642, 718, 860, 604, 362, 352, 657, 1267, 1343, 657, 657, 657, 642, 642, 718, 352, 604, 220, 50256]
[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 1303, 21017, 642, 718, 860, 860, 767, 767, 352, 604, 220, 50256]
[352, 513, 513, 807, 1635, 642, 352, 657, 642, 220, 50256, 1303, 21017, 642, 718, 860, 860, 767, 767, 352, 604, 220, 50256]
 1 3 3 8 * 5 1 0 5 <|endoftext|> #### 5 6 9

  0%|                                       | 3/50500 [00:00<1:42:18,  8.23it/s]

Step: 0. Loss: 6161.64599609375.


  0%|                                     | 103/50500 [00:07<1:01:52, 13.57it/s]

Step: 100. Loss: 638.393310546875.


  0%|▏                                    | 203/50500 [00:15<1:01:35, 13.61it/s]

Step: 200. Loss: 528.7758178710938.


  1%|▏                                    | 303/50500 [00:22<1:01:33, 13.59it/s]

Step: 300. Loss: 512.0115356445312.


  1%|▎                                    | 403/50500 [00:30<1:01:37, 13.55it/s]

Step: 400. Loss: 496.41668701171875.


  1%|▎                                    | 503/50500 [00:37<1:01:18, 13.59it/s]

Step: 500. Loss: 476.5987548828125.


  1%|▍                                    | 603/50500 [00:44<1:01:12, 13.59it/s]

Step: 600. Loss: 473.505126953125.


  1%|▌                                    | 703/50500 [00:52<1:01:04, 13.59it/s]

Step: 700. Loss: 474.31195068359375.


  2%|▌                                    | 803/50500 [00:59<1:00:50, 13.61it/s]

Step: 800. Loss: 414.3396301269531.


  2%|▋                                    | 903/50500 [01:07<1:00:48, 13.59it/s]

Step: 900. Loss: 443.33209228515625.


  2%|▋                                   | 1003/50500 [01:14<1:00:54, 13.54it/s]

Step: 1000. Loss: 438.5155944824219.


  2%|▊                                   | 1103/50500 [01:22<1:00:54, 13.52it/s]

Step: 1100. Loss: 464.09735107421875.


  2%|▊                                   | 1203/50500 [01:29<1:00:40, 13.54it/s]

Step: 1200. Loss: 453.64178466796875.


  3%|▉                                   | 1303/50500 [01:36<1:00:44, 13.50it/s]

Step: 1300. Loss: 425.3907165527344.


  3%|█                                   | 1403/50500 [01:44<1:00:26, 13.54it/s]

Step: 1400. Loss: 397.05169677734375.


  3%|█                                   | 1503/50500 [01:51<1:00:17, 13.54it/s]

Step: 1500. Loss: 421.77239990234375.


  3%|█▏                                  | 1603/50500 [01:59<1:00:18, 13.51it/s]

Step: 1600. Loss: 433.2487487792969.


  3%|█▏                                  | 1703/50500 [02:06<1:00:09, 13.52it/s]

Step: 1700. Loss: 392.84124755859375.


  4%|█▎                                  | 1803/50500 [02:14<1:00:05, 13.51it/s]

Step: 1800. Loss: 367.36932373046875.


  4%|█▍                                    | 1903/50500 [02:21<59:58, 13.50it/s]

Step: 1900. Loss: 374.29095458984375.


  4%|█▍                                  | 2003/50500 [02:29<1:00:05, 13.45it/s]

Step: 2000. Loss: 390.0916748046875.


  4%|█▌                                    | 2103/50500 [02:36<59:51, 13.48it/s]

Step: 2100. Loss: 372.052734375.


  4%|█▋                                    | 2203/50500 [02:43<59:44, 13.48it/s]

Step: 2200. Loss: 371.45538330078125.


  5%|█▋                                    | 2303/50500 [02:51<59:38, 13.47it/s]

Step: 2300. Loss: 349.03485107421875.


  5%|█▊                                    | 2403/50500 [02:58<59:39, 13.44it/s]

Step: 2400. Loss: 344.599853515625.


  5%|█▉                                    | 2503/50500 [03:06<59:15, 13.50it/s]

Step: 2500. Loss: 334.51416015625.


  5%|█▉                                    | 2603/50500 [03:13<59:20, 13.45it/s]

Step: 2600. Loss: 358.33221435546875.


  5%|██                                    | 2703/50500 [03:21<59:06, 13.48it/s]

Step: 2700. Loss: 331.5790710449219.


  6%|██                                    | 2803/50500 [03:28<58:55, 13.49it/s]

Step: 2800. Loss: 334.4571533203125.


  6%|██▏                                   | 2903/50500 [03:36<58:41, 13.52it/s]

Step: 2900. Loss: 361.7004089355469.


  6%|██▎                                   | 3003/50500 [03:43<58:41, 13.49it/s]

Step: 3000. Loss: 340.67999267578125.


  6%|██▎                                   | 3103/50500 [03:51<58:25, 13.52it/s]

Step: 3100. Loss: 288.42498779296875.


  6%|██▍                                   | 3203/50500 [03:58<58:27, 13.48it/s]

Step: 3200. Loss: 315.31903076171875.


  7%|██▍                                   | 3303/50500 [04:06<58:27, 13.46it/s]

Step: 3300. Loss: 284.3266906738281.


  7%|██▌                                   | 3403/50500 [04:13<58:08, 13.50it/s]

Step: 3400. Loss: 319.56695556640625.


  7%|██▋                                   | 3503/50500 [04:21<57:56, 13.52it/s]

Step: 3500. Loss: 320.9759521484375.


  7%|██▋                                   | 3603/50500 [04:28<57:45, 13.53it/s]

Step: 3600. Loss: 282.8670654296875.


  7%|██▊                                   | 3703/50500 [04:36<57:53, 13.47it/s]

Step: 3700. Loss: 289.09381103515625.


  8%|██▊                                   | 3803/50500 [04:43<57:59, 13.42it/s]

Step: 3800. Loss: 274.8027038574219.


  8%|██▉                                   | 3903/50500 [04:51<57:26, 13.52it/s]

Step: 3900. Loss: 306.4050598144531.


  8%|███                                   | 4003/50500 [04:58<57:21, 13.51it/s]

Step: 4000. Loss: 262.45684814453125.


  8%|███                                   | 4103/50500 [05:06<57:17, 13.50it/s]

Step: 4100. Loss: 290.26904296875.


  8%|███▏                                  | 4203/50500 [05:13<57:09, 13.50it/s]

Step: 4200. Loss: 268.08251953125.


  9%|███▏                                  | 4303/50500 [05:21<57:02, 13.50it/s]

Step: 4300. Loss: 288.345458984375.


  9%|███▎                                  | 4403/50500 [05:28<56:52, 13.51it/s]

Step: 4400. Loss: 258.6070556640625.


  9%|███▍                                  | 4503/50500 [05:36<56:37, 13.54it/s]

Step: 4500. Loss: 250.0915069580078.


  9%|███▍                                  | 4603/50500 [05:43<56:38, 13.50it/s]

Step: 4600. Loss: 259.5517578125.


  9%|███▌                                  | 4703/50500 [05:50<56:31, 13.50it/s]

Step: 4700. Loss: 267.87652587890625.


 10%|███▌                                  | 4803/50500 [05:58<56:39, 13.44it/s]

Step: 4800. Loss: 242.83670043945312.


 10%|███▋                                  | 4903/50500 [06:05<56:28, 13.46it/s]

Step: 4900. Loss: 257.7442626953125.


 10%|███▊                                  | 5003/50500 [06:13<56:13, 13.49it/s]

Step: 5000. Loss: 239.21653747558594.


 10%|███▊                                  | 5103/50500 [06:20<56:10, 13.47it/s]

Step: 5100. Loss: 254.01744079589844.


 10%|███▉                                  | 5203/50500 [06:28<56:06, 13.45it/s]

Step: 5200. Loss: 241.92062377929688.


 11%|███▉                                  | 5303/50500 [06:35<55:50, 13.49it/s]

Step: 5300. Loss: 231.53744506835938.


 11%|████                                  | 5403/50500 [06:43<55:46, 13.47it/s]

Step: 5400. Loss: 231.28323364257812.


 11%|████▏                                 | 5503/50500 [06:50<55:26, 13.53it/s]

Step: 5500. Loss: 234.0608673095703.


 11%|████▏                                 | 5603/50500 [06:58<55:33, 13.47it/s]

Step: 5600. Loss: 251.3236846923828.


 11%|████▎                                 | 5703/50500 [07:05<55:23, 13.48it/s]

Step: 5700. Loss: 248.7047882080078.


 11%|████▎                                 | 5803/50500 [07:13<55:34, 13.41it/s]

Step: 5800. Loss: 247.84312438964844.


 12%|████▍                                 | 5903/50500 [07:20<54:57, 13.52it/s]

Step: 5900. Loss: 218.1398468017578.


 12%|████▌                                 | 6003/50500 [07:28<55:02, 13.47it/s]

Step: 6000. Loss: 226.40830993652344.


 12%|████▌                                 | 6103/50500 [07:35<54:43, 13.52it/s]

Step: 6100. Loss: 226.03736877441406.


 12%|████▋                                 | 6203/50500 [07:43<54:55, 13.44it/s]

Step: 6200. Loss: 198.1998291015625.


 12%|████▋                                 | 6303/50500 [07:50<54:38, 13.48it/s]

Step: 6300. Loss: 237.95896911621094.


 13%|████▊                                 | 6403/50500 [07:58<54:45, 13.42it/s]

Step: 6400. Loss: 214.7390899658203.


 13%|████▉                                 | 6503/50500 [08:05<54:32, 13.44it/s]

Step: 6500. Loss: 218.6456298828125.


 13%|████▉                                 | 6603/50500 [08:13<54:20, 13.46it/s]

Step: 6600. Loss: 214.71823120117188.


 13%|█████                                 | 6703/50500 [08:20<54:04, 13.50it/s]

Step: 6700. Loss: 231.40579223632812.


 13%|█████                                 | 6803/50500 [08:28<53:51, 13.52it/s]

Step: 6800. Loss: 219.36158752441406.


 14%|█████▏                                | 6903/50500 [08:35<53:49, 13.50it/s]

Step: 6900. Loss: 209.39593505859375.


 14%|█████▎                                | 7003/50500 [08:42<53:38, 13.52it/s]

Step: 7000. Loss: 185.05474853515625.


 14%|█████▎                                | 7103/50500 [08:50<53:42, 13.47it/s]

Step: 7100. Loss: 198.30270385742188.


 14%|█████▍                                | 7203/50500 [08:57<53:33, 13.47it/s]

Step: 7200. Loss: 198.6400909423828.


 14%|█████▍                                | 7303/50500 [09:05<53:26, 13.47it/s]

Step: 7300. Loss: 211.70272827148438.


 15%|█████▌                                | 7403/50500 [09:12<53:15, 13.48it/s]

Step: 7400. Loss: 196.00564575195312.


 15%|█████▋                                | 7503/50500 [09:20<53:09, 13.48it/s]

Step: 7500. Loss: 171.86984252929688.


 15%|█████▋                                | 7603/50500 [09:27<53:20, 13.40it/s]

Step: 7600. Loss: 202.98202514648438.


 15%|█████▊                                | 7703/50500 [09:35<53:04, 13.44it/s]

Step: 7700. Loss: 190.87973022460938.


 15%|█████▊                                | 7803/50500 [09:42<53:01, 13.42it/s]

Step: 7800. Loss: 195.93572998046875.


 16%|█████▉                                | 7903/50500 [09:50<52:43, 13.47it/s]

Step: 7900. Loss: 196.15557861328125.


 16%|██████                                | 8003/50500 [09:57<52:26, 13.51it/s]

Step: 8000. Loss: 199.36680603027344.


 16%|██████                                | 8103/50500 [10:05<52:48, 13.38it/s]

Step: 8100. Loss: 195.93907165527344.


 16%|██████▏                               | 8203/50500 [10:12<52:37, 13.39it/s]

Step: 8200. Loss: 182.7989959716797.


 16%|██████▏                               | 8303/50500 [10:20<52:09, 13.48it/s]

Step: 8300. Loss: 198.10121154785156.


 17%|██████▎                               | 8403/50500 [10:27<51:59, 13.50it/s]

Step: 8400. Loss: 196.04226684570312.


 17%|██████▍                               | 8503/50500 [10:35<52:01, 13.45it/s]

Step: 8500. Loss: 178.58209228515625.


 17%|██████▍                               | 8603/50500 [10:42<52:10, 13.38it/s]

Step: 8600. Loss: 186.0169677734375.


 17%|██████▌                               | 8703/50500 [10:50<51:40, 13.48it/s]

Step: 8700. Loss: 160.66317749023438.


 17%|██████▌                               | 8803/50500 [10:57<51:43, 13.43it/s]

Step: 8800. Loss: 202.82301330566406.


 18%|██████▋                               | 8903/50500 [11:05<51:39, 13.42it/s]

Step: 8900. Loss: 221.59469604492188.


 18%|██████▊                               | 9003/50500 [11:12<51:25, 13.45it/s]

Step: 9000. Loss: 194.24273681640625.


 18%|██████▊                               | 9103/50500 [11:20<51:12, 13.47it/s]

Step: 9100. Loss: 178.63906860351562.


 18%|██████▉                               | 9203/50500 [11:27<51:14, 13.43it/s]

Step: 9200. Loss: 185.93116760253906.


 18%|███████                               | 9303/50500 [11:35<50:52, 13.49it/s]

Step: 9300. Loss: 206.00335693359375.


 19%|███████                               | 9403/50500 [11:42<50:46, 13.49it/s]

Step: 9400. Loss: 185.73568725585938.


 19%|███████▏                              | 9503/50500 [11:50<50:41, 13.48it/s]

Step: 9500. Loss: 175.33033752441406.


 19%|███████▏                              | 9603/50500 [11:57<50:25, 13.52it/s]

Step: 9600. Loss: 148.52859497070312.


 19%|███████▎                              | 9703/50500 [12:05<51:21, 13.24it/s]

Step: 9700. Loss: 181.23670959472656.


 19%|███████▍                              | 9803/50500 [12:12<50:24, 13.46it/s]

Step: 9800. Loss: 174.75669860839844.


 20%|███████▍                              | 9903/50500 [12:19<50:06, 13.50it/s]

Step: 9900. Loss: 171.08740234375.


 20%|███████▎                             | 10003/50500 [12:27<50:13, 13.44it/s]

Step: 10000. Loss: 159.50340270996094.


 20%|███████▍                             | 10103/50500 [12:34<49:55, 13.49it/s]

Step: 10100. Loss: 169.4639892578125.


 20%|███████▍                             | 10203/50500 [12:42<49:37, 13.53it/s]

Step: 10200. Loss: 182.6640625.


 20%|███████▌                             | 10303/50500 [12:49<49:41, 13.48it/s]

Step: 10300. Loss: 184.52886962890625.


 21%|███████▌                             | 10403/50500 [12:57<49:20, 13.54it/s]

Step: 10400. Loss: 155.986328125.


 21%|███████▋                             | 10503/50500 [13:04<49:05, 13.58it/s]

Step: 10500. Loss: 154.80908203125.


 21%|███████▊                             | 10603/50500 [13:12<49:02, 13.56it/s]

Step: 10600. Loss: 177.47308349609375.


 21%|███████▊                             | 10703/50500 [13:19<49:06, 13.51it/s]

Step: 10700. Loss: 154.60296630859375.


 21%|███████▉                             | 10803/50500 [13:26<48:52, 13.54it/s]

Step: 10800. Loss: 166.75350952148438.


 22%|███████▉                             | 10903/50500 [13:34<48:38, 13.57it/s]

Step: 10900. Loss: 171.84132385253906.


 22%|████████                             | 11003/50500 [13:41<48:31, 13.57it/s]

Step: 11000. Loss: 169.6198272705078.


 22%|████████▏                            | 11103/50500 [13:49<48:29, 13.54it/s]

Step: 11100. Loss: 166.457763671875.


 22%|████████▏                            | 11203/50500 [13:56<48:18, 13.56it/s]

Step: 11200. Loss: 134.9309844970703.


 22%|████████▎                            | 11303/50500 [14:04<48:11, 13.56it/s]

Step: 11300. Loss: 163.5886993408203.


 23%|████████▎                            | 11403/50500 [14:11<47:57, 13.59it/s]

Step: 11400. Loss: 152.67933654785156.


 23%|████████▍                            | 11503/50500 [14:19<48:09, 13.50it/s]

Step: 11500. Loss: 195.58261108398438.


 23%|████████▌                            | 11603/50500 [14:26<47:57, 13.52it/s]

Step: 11600. Loss: 183.95912170410156.


 23%|████████▌                            | 11703/50500 [14:33<47:47, 13.53it/s]

Step: 11700. Loss: 169.81077575683594.


 23%|████████▋                            | 11803/50500 [14:41<47:42, 13.52it/s]

Step: 11800. Loss: 186.66650390625.


 24%|████████▋                            | 11903/50500 [14:48<47:20, 13.59it/s]

Step: 11900. Loss: 133.9591827392578.


 24%|████████▊                            | 12003/50500 [14:56<47:22, 13.54it/s]

Step: 12000. Loss: 167.8733367919922.


 24%|████████▊                            | 12103/50500 [15:03<47:13, 13.55it/s]

Step: 12100. Loss: 153.3121337890625.


 24%|████████▉                            | 12203/50500 [15:11<47:08, 13.54it/s]

Step: 12200. Loss: 157.18927001953125.


 24%|█████████                            | 12303/50500 [15:18<47:06, 13.51it/s]

Step: 12300. Loss: 178.71408081054688.


 25%|█████████                            | 12403/50500 [15:25<47:19, 13.42it/s]

Step: 12400. Loss: 155.03619384765625.


 25%|█████████▏                           | 12503/50500 [15:33<46:50, 13.52it/s]

Step: 12500. Loss: 137.94821166992188.


 25%|█████████▏                           | 12603/50500 [15:40<46:35, 13.56it/s]

Step: 12600. Loss: 127.51139831542969.


 25%|█████████▎                           | 12703/50500 [15:48<46:21, 13.59it/s]

Step: 12700. Loss: 169.51751708984375.


 25%|█████████▍                           | 12803/50500 [15:55<46:20, 13.56it/s]

Step: 12800. Loss: 150.06964111328125.


 26%|█████████▍                           | 12903/50500 [16:03<46:21, 13.52it/s]

Step: 12900. Loss: 146.94509887695312.


 26%|█████████▌                           | 13003/50500 [16:10<46:08, 13.54it/s]

Step: 13000. Loss: 170.67120361328125.


 26%|█████████▌                           | 13103/50500 [16:18<45:59, 13.55it/s]

Step: 13100. Loss: 157.9827423095703.


 26%|█████████▋                           | 13203/50500 [16:25<46:16, 13.43it/s]

Step: 13200. Loss: 146.20388793945312.


 26%|█████████▋                           | 13303/50500 [16:32<45:58, 13.48it/s]

Step: 13300. Loss: 143.4331817626953.


 27%|█████████▊                           | 13403/50500 [16:40<45:47, 13.50it/s]

Step: 13400. Loss: 139.53363037109375.


 27%|█████████▉                           | 13503/50500 [16:47<45:37, 13.52it/s]

Step: 13500. Loss: 161.5718994140625.


 27%|█████████▉                           | 13603/50500 [16:55<45:30, 13.51it/s]

Step: 13600. Loss: 134.0619354248047.


 27%|██████████                           | 13703/50500 [17:02<45:20, 13.53it/s]

Step: 13700. Loss: 144.32940673828125.


 27%|██████████                           | 13803/50500 [17:10<45:18, 13.50it/s]

Step: 13800. Loss: 150.4340057373047.


 28%|██████████▏                          | 13903/50500 [17:17<44:59, 13.56it/s]

Step: 13900. Loss: 150.777587890625.


 28%|██████████▎                          | 14003/50500 [17:25<44:51, 13.56it/s]

Step: 14000. Loss: 155.33023071289062.


 28%|██████████▎                          | 14103/50500 [17:32<44:51, 13.52it/s]

Step: 14100. Loss: 174.9440155029297.


 28%|██████████▍                          | 14203/50500 [17:39<45:01, 13.44it/s]

Step: 14200. Loss: 141.99571228027344.


 28%|██████████▍                          | 14303/50500 [17:47<44:26, 13.58it/s]

Step: 14300. Loss: 152.42117309570312.


 29%|██████████▌                          | 14403/50500 [17:54<44:32, 13.51it/s]

Step: 14400. Loss: 156.98439025878906.


 29%|██████████▋                          | 14503/50500 [18:02<44:15, 13.55it/s]

Step: 14500. Loss: 126.14698028564453.


 29%|██████████▋                          | 14603/50500 [18:09<44:04, 13.57it/s]

Step: 14600. Loss: 141.5836181640625.


 29%|██████████▊                          | 14703/50500 [18:17<44:08, 13.51it/s]

Step: 14700. Loss: 138.43118286132812.


 29%|██████████▊                          | 14803/50500 [18:24<44:00, 13.52it/s]

Step: 14800. Loss: 136.59072875976562.


 30%|██████████▉                          | 14903/50500 [18:32<43:58, 13.49it/s]

Step: 14900. Loss: 128.64944458007812.


 30%|██████████▉                          | 15003/50500 [18:39<43:42, 13.54it/s]

Step: 15000. Loss: 143.86013793945312.


 30%|███████████                          | 15103/50500 [18:46<43:31, 13.56it/s]

Step: 15100. Loss: 127.1678237915039.


 30%|███████████▏                         | 15203/50500 [18:54<43:32, 13.51it/s]

Step: 15200. Loss: 134.8614501953125.


 30%|███████████▏                         | 15303/50500 [19:01<43:02, 13.63it/s]

Step: 15300. Loss: 132.17227172851562.


 31%|███████████▎                         | 15403/50500 [19:09<43:25, 13.47it/s]

Step: 15400. Loss: 166.64105224609375.


 31%|███████████▎                         | 15503/50500 [19:16<43:01, 13.56it/s]

Step: 15500. Loss: 134.7955322265625.


 31%|███████████▍                         | 15603/50500 [19:24<43:10, 13.47it/s]

Step: 15600. Loss: 122.4561767578125.


 31%|███████████▌                         | 15703/50500 [19:31<42:42, 13.58it/s]

Step: 15700. Loss: 123.78707885742188.


 31%|███████████▌                         | 15803/50500 [19:38<42:50, 13.50it/s]

Step: 15800. Loss: 128.64503479003906.


 31%|███████████▋                         | 15903/50500 [19:46<42:29, 13.57it/s]

Step: 15900. Loss: 133.66583251953125.


 32%|███████████▋                         | 16003/50500 [19:53<42:22, 13.57it/s]

Step: 16000. Loss: 160.82345581054688.


 32%|███████████▊                         | 16103/50500 [20:01<42:15, 13.57it/s]

Step: 16100. Loss: 133.69403076171875.


 32%|███████████▊                         | 16203/50500 [20:08<42:21, 13.50it/s]

Step: 16200. Loss: 125.83778381347656.


 32%|███████████▉                         | 16303/50500 [20:16<42:07, 13.53it/s]

Step: 16300. Loss: 108.12906646728516.


 32%|████████████                         | 16403/50500 [20:23<41:50, 13.58it/s]

Step: 16400. Loss: 122.63091278076172.


 33%|████████████                         | 16503/50500 [20:31<41:54, 13.52it/s]

Step: 16500. Loss: 125.84132385253906.


 33%|████████████▏                        | 16603/50500 [20:38<41:59, 13.46it/s]

Step: 16600. Loss: 145.10299682617188.


 33%|████████████▏                        | 16703/50500 [20:45<41:44, 13.50it/s]

Step: 16700. Loss: 116.43566131591797.


 33%|████████████▎                        | 16803/50500 [20:53<41:30, 13.53it/s]

Step: 16800. Loss: 138.12619018554688.


 33%|████████████▍                        | 16903/50500 [21:00<41:19, 13.55it/s]

Step: 16900. Loss: 141.4923858642578.


 34%|████████████▍                        | 17003/50500 [21:08<41:20, 13.50it/s]

Step: 17000. Loss: 143.29196166992188.


 34%|████████████▌                        | 17103/50500 [21:15<41:02, 13.56it/s]

Step: 17100. Loss: 117.54640197753906.


 34%|████████████▌                        | 17203/50500 [21:23<40:53, 13.57it/s]

Step: 17200. Loss: 115.70722961425781.


 34%|████████████▋                        | 17303/50500 [21:30<40:54, 13.52it/s]

Step: 17300. Loss: 121.92863464355469.


 34%|████████████▊                        | 17403/50500 [21:38<40:48, 13.52it/s]

Step: 17400. Loss: 106.26348876953125.


 35%|████████████▊                        | 17503/50500 [21:45<40:37, 13.54it/s]

Step: 17500. Loss: 134.95706176757812.


 35%|████████████▉                        | 17603/50500 [21:52<40:21, 13.59it/s]

Step: 17600. Loss: 129.51025390625.


 35%|████████████▉                        | 17703/50500 [22:00<40:16, 13.57it/s]

Step: 17700. Loss: 139.0252685546875.


 35%|█████████████                        | 17803/50500 [22:07<40:12, 13.55it/s]

Step: 17800. Loss: 117.68314361572266.


 35%|█████████████                        | 17903/50500 [22:15<40:39, 13.36it/s]

Step: 17900. Loss: 135.9968719482422.


 36%|█████████████▏                       | 18003/50500 [22:22<39:59, 13.55it/s]

Step: 18000. Loss: 136.42022705078125.


 36%|█████████████▎                       | 18103/50500 [22:30<40:04, 13.47it/s]

Step: 18100. Loss: 133.84487915039062.


 36%|█████████████▎                       | 18203/50500 [22:37<39:43, 13.55it/s]

Step: 18200. Loss: 119.27503967285156.


 36%|█████████████▍                       | 18303/50500 [22:45<40:00, 13.41it/s]

Step: 18300. Loss: 119.41639709472656.


 36%|█████████████▍                       | 18403/50500 [22:52<39:28, 13.55it/s]

Step: 18400. Loss: 134.7600860595703.


 37%|█████████████▌                       | 18503/50500 [22:59<39:22, 13.55it/s]

Step: 18500. Loss: 105.87474060058594.


 37%|█████████████▋                       | 18603/50500 [23:07<39:23, 13.49it/s]

Step: 18600. Loss: 146.871337890625.


 37%|█████████████▋                       | 18703/50500 [23:14<39:03, 13.57it/s]

Step: 18700. Loss: 117.34838104248047.


 37%|█████████████▊                       | 18803/50500 [23:22<39:06, 13.51it/s]

Step: 18800. Loss: 128.52615356445312.


 37%|█████████████▊                       | 18903/50500 [23:29<38:51, 13.55it/s]

Step: 18900. Loss: 137.7454833984375.


 38%|█████████████▉                       | 19003/50500 [23:37<38:41, 13.57it/s]

Step: 19000. Loss: 131.08920288085938.


 38%|█████████████▉                       | 19103/50500 [23:44<38:43, 13.51it/s]

Step: 19100. Loss: 110.7493896484375.


 38%|██████████████                       | 19203/50500 [23:52<38:31, 13.54it/s]

Step: 19200. Loss: 115.74897766113281.


 38%|██████████████▏                      | 19303/50500 [23:59<38:30, 13.50it/s]

Step: 19300. Loss: 135.61294555664062.


 38%|██████████████▏                      | 19403/50500 [24:06<38:53, 13.33it/s]

Step: 19400. Loss: 120.04463195800781.


 39%|██████████████▎                      | 19503/50500 [24:14<38:06, 13.56it/s]

Step: 19500. Loss: 113.74700164794922.


 39%|██████████████▎                      | 19603/50500 [24:21<38:11, 13.48it/s]

Step: 19600. Loss: 105.46990966796875.


 39%|██████████████▍                      | 19703/50500 [24:29<37:51, 13.56it/s]

Step: 19700. Loss: 150.88848876953125.


 39%|██████████████▌                      | 19803/50500 [24:36<37:45, 13.55it/s]

Step: 19800. Loss: 129.06297302246094.


 39%|██████████████▌                      | 19903/50500 [24:44<37:38, 13.55it/s]

Step: 19900. Loss: 142.96987915039062.


 40%|██████████████▋                      | 20003/50500 [24:51<37:29, 13.55it/s]

Step: 20000. Loss: 134.73912048339844.


 40%|██████████████▋                      | 20103/50500 [24:59<37:22, 13.55it/s]

Step: 20100. Loss: 121.61360931396484.


 40%|██████████████▊                      | 20203/50500 [25:06<37:21, 13.52it/s]

Step: 20200. Loss: 124.36967468261719.


 40%|██████████████▉                      | 20303/50500 [25:13<37:12, 13.52it/s]

Step: 20300. Loss: 115.58457946777344.


 40%|██████████████▉                      | 20403/50500 [25:21<37:10, 13.49it/s]

Step: 20400. Loss: 134.29342651367188.


 41%|███████████████                      | 20503/50500 [25:28<36:56, 13.53it/s]

Step: 20500. Loss: 134.54476928710938.


 41%|███████████████                      | 20603/50500 [25:36<36:57, 13.48it/s]

Step: 20600. Loss: 133.58905029296875.


 41%|███████████████▏                     | 20703/50500 [25:43<36:37, 13.56it/s]

Step: 20700. Loss: 129.1644287109375.


 41%|███████████████▏                     | 20803/50500 [25:51<36:32, 13.54it/s]

Step: 20800. Loss: 121.29849243164062.


 41%|███████████████▎                     | 20903/50500 [25:58<36:25, 13.54it/s]

Step: 20900. Loss: 127.19377899169922.


 42%|███████████████▍                     | 21003/50500 [26:06<36:22, 13.52it/s]

Step: 21000. Loss: 142.37899780273438.


 42%|███████████████▍                     | 21103/50500 [26:13<36:12, 13.53it/s]

Step: 21100. Loss: 97.6521224975586.


 42%|███████████████▌                     | 21203/50500 [26:20<36:06, 13.52it/s]

Step: 21200. Loss: 95.1351318359375.


 42%|███████████████▌                     | 21303/50500 [26:28<35:52, 13.56it/s]

Step: 21300. Loss: 118.31951904296875.


 42%|███████████████▋                     | 21403/50500 [26:35<35:46, 13.55it/s]

Step: 21400. Loss: 120.49197387695312.


 43%|███████████████▊                     | 21503/50500 [26:43<35:31, 13.61it/s]

Step: 21500. Loss: 109.23483276367188.


 43%|███████████████▊                     | 21603/50500 [26:50<35:30, 13.56it/s]

Step: 21600. Loss: 129.30621337890625.


 43%|███████████████▉                     | 21703/50500 [26:58<35:21, 13.57it/s]

Step: 21700. Loss: 113.9143295288086.


 43%|███████████████▉                     | 21803/50500 [27:05<35:18, 13.54it/s]

Step: 21800. Loss: 136.16995239257812.


 43%|████████████████                     | 21903/50500 [27:13<35:35, 13.39it/s]

Step: 21900. Loss: 122.54179382324219.


 44%|████████████████                     | 22003/50500 [27:20<35:11, 13.50it/s]

Step: 22000. Loss: 99.55326080322266.


 44%|████████████████▏                    | 22103/50500 [27:27<35:08, 13.47it/s]

Step: 22100. Loss: 107.73796081542969.


 44%|████████████████▎                    | 22203/50500 [27:35<34:50, 13.54it/s]

Step: 22200. Loss: 124.74949645996094.


 44%|████████████████▎                    | 22303/50500 [27:42<34:40, 13.56it/s]

Step: 22300. Loss: 150.38134765625.


 44%|████████████████▍                    | 22403/50500 [27:50<34:39, 13.51it/s]

Step: 22400. Loss: 98.48412322998047.


 45%|████████████████▍                    | 22503/50500 [27:57<34:28, 13.54it/s]

Step: 22500. Loss: 145.624755859375.


 45%|████████████████▌                    | 22603/50500 [28:05<34:45, 13.37it/s]

Step: 22600. Loss: 100.03544616699219.


 45%|████████████████▋                    | 22703/50500 [28:12<34:10, 13.55it/s]

Step: 22700. Loss: 124.91110229492188.


 45%|████████████████▋                    | 22803/50500 [28:20<34:03, 13.56it/s]

Step: 22800. Loss: 110.55709838867188.


 45%|████████████████▊                    | 22903/50500 [28:27<34:02, 13.51it/s]

Step: 22900. Loss: 87.75294494628906.


 46%|████████████████▊                    | 23003/50500 [28:34<33:52, 13.53it/s]

Step: 23000. Loss: 118.35757446289062.


 46%|████████████████▉                    | 23103/50500 [28:42<33:43, 13.54it/s]

Step: 23100. Loss: 110.8736343383789.


 46%|█████████████████                    | 23203/50500 [28:49<33:35, 13.54it/s]

Step: 23200. Loss: 115.93560791015625.


 46%|█████████████████                    | 23303/50500 [28:57<33:27, 13.55it/s]

Step: 23300. Loss: 107.1434326171875.


 46%|█████████████████▏                   | 23403/50500 [29:04<33:38, 13.42it/s]

Step: 23400. Loss: 100.04531860351562.


 47%|█████████████████▏                   | 23503/50500 [29:12<33:35, 13.40it/s]

Step: 23500. Loss: 128.10650634765625.


 47%|█████████████████▎                   | 23603/50500 [29:19<33:22, 13.43it/s]

Step: 23600. Loss: 116.31494140625.


 47%|█████████████████▎                   | 23703/50500 [29:27<33:19, 13.40it/s]

Step: 23700. Loss: 139.65890502929688.


 47%|█████████████████▍                   | 23803/50500 [29:34<33:21, 13.34it/s]

Step: 23800. Loss: 127.70440673828125.


 47%|█████████████████▌                   | 23903/50500 [29:42<33:08, 13.38it/s]

Step: 23900. Loss: 113.60123443603516.


 48%|█████████████████▌                   | 24003/50500 [29:49<32:51, 13.44it/s]

Step: 24000. Loss: 119.39053344726562.


 48%|█████████████████▋                   | 24103/50500 [29:57<32:37, 13.48it/s]

Step: 24100. Loss: 126.08544921875.


 48%|█████████████████▋                   | 24203/50500 [30:04<32:50, 13.35it/s]

Step: 24200. Loss: 95.82534790039062.


 48%|█████████████████▊                   | 24303/50500 [30:12<32:27, 13.45it/s]

Step: 24300. Loss: 78.90953063964844.


 48%|█████████████████▉                   | 24403/50500 [30:19<32:33, 13.36it/s]

Step: 24400. Loss: 102.50387573242188.


 49%|█████████████████▉                   | 24503/50500 [30:27<32:28, 13.34it/s]

Step: 24500. Loss: 103.45552062988281.


 49%|██████████████████                   | 24603/50500 [30:34<32:20, 13.34it/s]

Step: 24600. Loss: 88.21469116210938.


 49%|██████████████████                   | 24703/50500 [30:42<32:02, 13.42it/s]

Step: 24700. Loss: 114.04534149169922.


 49%|██████████████████▏                  | 24803/50500 [30:49<31:54, 13.43it/s]

Step: 24800. Loss: 96.93904113769531.


 49%|██████████████████▏                  | 24903/50500 [30:57<31:53, 13.38it/s]

Step: 24900. Loss: 92.96180725097656.


 50%|██████████████████▎                  | 25003/50500 [31:04<31:42, 13.40it/s]

Step: 25000. Loss: 103.84966278076172.


 50%|██████████████████▍                  | 25103/50500 [31:12<31:33, 13.41it/s]

Step: 25100. Loss: 98.63819885253906.


 50%|██████████████████▍                  | 25203/50500 [31:19<31:26, 13.41it/s]

Step: 25200. Loss: 104.47794342041016.


 50%|██████████████████▌                  | 25303/50500 [31:27<31:17, 13.42it/s]

Step: 25300. Loss: 100.88395690917969.


 50%|██████████████████▌                  | 25403/50500 [31:34<31:09, 13.43it/s]

Step: 25400. Loss: 104.67190551757812.


 51%|██████████████████▋                  | 25503/50500 [31:42<30:52, 13.49it/s]

Step: 25500. Loss: 98.68698120117188.


 51%|██████████████████▊                  | 25603/50500 [31:49<30:50, 13.45it/s]

Step: 25600. Loss: 126.30585479736328.


 51%|██████████████████▊                  | 25703/50500 [31:57<30:52, 13.39it/s]

Step: 25700. Loss: 86.29751586914062.


 51%|██████████████████▉                  | 25803/50500 [32:05<31:26, 13.09it/s]

Step: 25800. Loss: 99.56352233886719.


 51%|██████████████████▉                  | 25903/50500 [32:12<30:36, 13.40it/s]

Step: 25900. Loss: 91.47136688232422.


 51%|███████████████████                  | 26003/50500 [32:20<30:31, 13.37it/s]

Step: 26000. Loss: 129.483154296875.


 52%|███████████████████                  | 26103/50500 [32:27<30:08, 13.49it/s]

Step: 26100. Loss: 118.63587951660156.


 52%|███████████████████▏                 | 26203/50500 [32:35<30:07, 13.44it/s]

Step: 26200. Loss: 114.09196472167969.


 52%|███████████████████▎                 | 26303/50500 [32:42<29:51, 13.50it/s]

Step: 26300. Loss: 75.25093078613281.


 52%|███████████████████▎                 | 26403/50500 [32:50<30:00, 13.38it/s]

Step: 26400. Loss: 84.97520446777344.


 52%|███████████████████▍                 | 26503/50500 [32:57<29:45, 13.44it/s]

Step: 26500. Loss: 73.37989044189453.


 53%|███████████████████▍                 | 26603/50500 [33:05<29:39, 13.43it/s]

Step: 26600. Loss: 117.16001892089844.


 53%|███████████████████▌                 | 26703/50500 [33:12<29:29, 13.45it/s]

Step: 26700. Loss: 110.29928588867188.


 53%|███████████████████▋                 | 26803/50500 [33:20<29:24, 13.43it/s]

Step: 26800. Loss: 86.42835998535156.


 53%|███████████████████▋                 | 26903/50500 [33:27<29:18, 13.42it/s]

Step: 26900. Loss: 94.61094665527344.


 53%|███████████████████▊                 | 27003/50500 [33:35<29:04, 13.47it/s]

Step: 27000. Loss: 102.03739929199219.


 54%|███████████████████▊                 | 27103/50500 [33:42<29:08, 13.38it/s]

Step: 27100. Loss: 92.99794006347656.


 54%|███████████████████▉                 | 27203/50500 [33:50<28:55, 13.42it/s]

Step: 27200. Loss: 105.08497619628906.


 54%|████████████████████                 | 27303/50500 [33:57<28:44, 13.45it/s]

Step: 27300. Loss: 79.01412200927734.


 54%|████████████████████                 | 27403/50500 [34:05<29:02, 13.26it/s]

Step: 27400. Loss: 87.4922103881836.


 54%|████████████████████▏                | 27503/50500 [34:12<28:34, 13.41it/s]

Step: 27500. Loss: 98.34565734863281.


 55%|████████████████████▏                | 27603/50500 [34:20<28:22, 13.45it/s]

Step: 27600. Loss: 109.71978759765625.


 55%|████████████████████▎                | 27703/50500 [34:27<28:21, 13.40it/s]

Step: 27700. Loss: 104.63444519042969.


 55%|████████████████████▎                | 27803/50500 [34:35<28:19, 13.35it/s]

Step: 27800. Loss: 90.953125.


 55%|████████████████████▍                | 27903/50500 [34:42<28:00, 13.44it/s]

Step: 27900. Loss: 95.30775451660156.


 55%|████████████████████▌                | 28003/50500 [34:50<27:56, 13.42it/s]

Step: 28000. Loss: 79.31301879882812.


 56%|████████████████████▌                | 28103/50500 [34:57<27:48, 13.42it/s]

Step: 28100. Loss: 99.659423828125.


 56%|████████████████████▋                | 28203/50500 [35:05<27:43, 13.40it/s]

Step: 28200. Loss: 124.99290466308594.


 56%|████████████████████▋                | 28303/50500 [35:12<27:32, 13.43it/s]

Step: 28300. Loss: 100.175537109375.


 56%|████████████████████▊                | 28403/50500 [35:20<27:23, 13.44it/s]

Step: 28400. Loss: 83.25406646728516.


 56%|████████████████████▉                | 28503/50500 [35:27<27:20, 13.41it/s]

Step: 28500. Loss: 72.0319595336914.


 57%|████████████████████▉                | 28603/50500 [35:35<27:08, 13.45it/s]

Step: 28600. Loss: 93.3273696899414.


 57%|█████████████████████                | 28703/50500 [35:42<27:02, 13.44it/s]

Step: 28700. Loss: 96.20332336425781.


 57%|█████████████████████                | 28803/50500 [35:50<26:54, 13.44it/s]

Step: 28800. Loss: 103.44247436523438.


 57%|█████████████████████▏               | 28903/50500 [35:57<26:55, 13.37it/s]

Step: 28900. Loss: 93.92098999023438.


 57%|█████████████████████▏               | 29003/50500 [36:05<26:49, 13.36it/s]

Step: 29000. Loss: 88.23922729492188.


 58%|█████████████████████▎               | 29103/50500 [36:12<26:34, 13.42it/s]

Step: 29100. Loss: 81.23683166503906.


 58%|█████████████████████▍               | 29203/50500 [36:20<26:21, 13.46it/s]

Step: 29200. Loss: 96.36618041992188.


 58%|█████████████████████▍               | 29303/50500 [36:27<26:22, 13.39it/s]

Step: 29300. Loss: 94.86766052246094.


 58%|█████████████████████▌               | 29403/50500 [36:35<26:23, 13.32it/s]

Step: 29400. Loss: 89.23555755615234.


 58%|█████████████████████▌               | 29503/50500 [36:42<26:03, 13.43it/s]

Step: 29500. Loss: 105.13259887695312.


 59%|█████████████████████▋               | 29603/50500 [36:50<26:06, 13.34it/s]

Step: 29600. Loss: 99.04861450195312.


 59%|█████████████████████▊               | 29703/50500 [36:57<25:50, 13.41it/s]

Step: 29700. Loss: 108.13545227050781.


 59%|█████████████████████▊               | 29803/50500 [37:05<25:45, 13.39it/s]

Step: 29800. Loss: 80.44831848144531.


 59%|█████████████████████▉               | 29903/50500 [37:12<25:38, 13.39it/s]

Step: 29900. Loss: 74.19056701660156.


 59%|█████████████████████▉               | 30003/50500 [37:20<25:28, 13.41it/s]

Step: 30000. Loss: 73.08006286621094.


 60%|██████████████████████               | 30103/50500 [37:27<25:18, 13.43it/s]

Step: 30100. Loss: 104.15351867675781.


 60%|██████████████████████▏              | 30203/50500 [37:35<25:15, 13.39it/s]

Step: 30200. Loss: 75.56820678710938.


 60%|██████████████████████▏              | 30303/50500 [37:42<25:15, 13.32it/s]

Step: 30300. Loss: 95.66142272949219.


 60%|██████████████████████▎              | 30403/50500 [37:50<25:02, 13.38it/s]

Step: 30400. Loss: 68.45867919921875.


 60%|██████████████████████▎              | 30503/50500 [37:57<24:50, 13.41it/s]

Step: 30500. Loss: 77.08241271972656.


 61%|██████████████████████▍              | 30603/50500 [38:05<24:50, 13.35it/s]

Step: 30600. Loss: 99.47505187988281.


 61%|██████████████████████▍              | 30703/50500 [38:12<24:34, 13.43it/s]

Step: 30700. Loss: 93.48577880859375.


 61%|██████████████████████▌              | 30803/50500 [38:20<24:29, 13.40it/s]

Step: 30800. Loss: 89.24101257324219.


 61%|██████████████████████▋              | 30903/50500 [38:28<24:24, 13.38it/s]

Step: 30900. Loss: 87.3937759399414.


 61%|██████████████████████▋              | 31003/50500 [38:35<24:07, 13.47it/s]

Step: 31000. Loss: 119.08283996582031.


 62%|██████████████████████▊              | 31103/50500 [38:43<24:06, 13.41it/s]

Step: 31100. Loss: 74.04161834716797.


 62%|██████████████████████▊              | 31203/50500 [38:50<24:00, 13.39it/s]

Step: 31200. Loss: 98.15983581542969.


 62%|██████████████████████▉              | 31303/50500 [38:58<23:49, 13.43it/s]

Step: 31300. Loss: 72.6229248046875.


 62%|███████████████████████              | 31403/50500 [39:05<23:41, 13.43it/s]

Step: 31400. Loss: 95.38567352294922.


 62%|███████████████████████              | 31503/50500 [39:13<23:33, 13.44it/s]

Step: 31500. Loss: 80.71155548095703.


 63%|███████████████████████▏             | 31603/50500 [39:20<23:26, 13.44it/s]

Step: 31600. Loss: 95.48129272460938.


 63%|███████████████████████▏             | 31703/50500 [39:28<23:22, 13.40it/s]

Step: 31700. Loss: 86.02764129638672.


 63%|███████████████████████▎             | 31803/50500 [39:35<23:12, 13.42it/s]

Step: 31800. Loss: 89.10328674316406.


 63%|███████████████████████▎             | 31903/50500 [39:43<23:03, 13.45it/s]

Step: 31900. Loss: 81.46302795410156.


 63%|███████████████████████▍             | 32003/50500 [39:50<22:56, 13.44it/s]

Step: 32000. Loss: 100.40729522705078.


 64%|███████████████████████▌             | 32103/50500 [39:58<22:46, 13.46it/s]

Step: 32100. Loss: 91.52792358398438.


 64%|███████████████████████▌             | 32203/50500 [40:05<22:59, 13.26it/s]

Step: 32200. Loss: 76.55316162109375.


 64%|███████████████████████▋             | 32303/50500 [40:13<22:34, 13.43it/s]

Step: 32300. Loss: 87.8094482421875.


 64%|███████████████████████▋             | 32403/50500 [40:20<22:26, 13.44it/s]

Step: 32400. Loss: 78.59794616699219.


 64%|███████████████████████▊             | 32503/50500 [40:28<22:19, 13.44it/s]

Step: 32500. Loss: 100.34217071533203.


 65%|███████████████████████▉             | 32603/50500 [40:35<22:17, 13.38it/s]

Step: 32600. Loss: 80.1638412475586.


 65%|███████████████████████▉             | 32703/50500 [40:43<22:06, 13.42it/s]

Step: 32700. Loss: 100.12422180175781.


 65%|████████████████████████             | 32803/50500 [40:50<21:54, 13.46it/s]

Step: 32800. Loss: 80.00196075439453.


 65%|████████████████████████             | 32903/50500 [40:58<21:50, 13.43it/s]

Step: 32900. Loss: 79.54926300048828.


 65%|████████████████████████▏            | 33003/50500 [41:05<21:41, 13.44it/s]

Step: 33000. Loss: 83.24177551269531.


 66%|████████████████████████▎            | 33103/50500 [41:13<21:36, 13.42it/s]

Step: 33100. Loss: 79.24662780761719.


 66%|████████████████████████▎            | 33203/50500 [41:20<21:29, 13.41it/s]

Step: 33200. Loss: 87.42100524902344.


 66%|████████████████████████▍            | 33303/50500 [41:28<21:21, 13.42it/s]

Step: 33300. Loss: 73.72775268554688.


 66%|████████████████████████▍            | 33403/50500 [41:35<21:12, 13.43it/s]

Step: 33400. Loss: 82.82034301757812.


 66%|████████████████████████▌            | 33503/50500 [41:43<21:09, 13.39it/s]

Step: 33500. Loss: 61.18920135498047.


 67%|████████████████████████▌            | 33603/50500 [41:50<21:03, 13.37it/s]

Step: 33600. Loss: 67.1246109008789.


 67%|████████████████████████▋            | 33703/50500 [41:58<20:53, 13.40it/s]

Step: 33700. Loss: 77.36475372314453.


 67%|████████████████████████▊            | 33803/50500 [42:05<20:45, 13.40it/s]

Step: 33800. Loss: 83.56808471679688.


 67%|████████████████████████▊            | 33903/50500 [42:13<20:35, 13.44it/s]

Step: 33900. Loss: 91.83465576171875.


 67%|████████████████████████▉            | 34003/50500 [42:20<20:35, 13.35it/s]

Step: 34000. Loss: 79.27576446533203.


 68%|████████████████████████▉            | 34103/50500 [42:28<20:21, 13.43it/s]

Step: 34100. Loss: 74.20728302001953.


 68%|█████████████████████████            | 34203/50500 [42:35<20:15, 13.40it/s]

Step: 34200. Loss: 83.95379638671875.


 68%|█████████████████████████▏           | 34303/50500 [42:43<20:05, 13.44it/s]

Step: 34300. Loss: 122.79902648925781.


 68%|█████████████████████████▏           | 34403/50500 [42:50<20:00, 13.41it/s]

Step: 34400. Loss: 73.10185241699219.


 68%|█████████████████████████▎           | 34503/50500 [42:58<19:53, 13.40it/s]

Step: 34500. Loss: 83.93841552734375.


 69%|█████████████████████████▎           | 34603/50500 [43:05<19:52, 13.33it/s]

Step: 34600. Loss: 75.88768005371094.


 69%|█████████████████████████▍           | 34703/50500 [43:13<19:38, 13.41it/s]

Step: 34700. Loss: 93.97760772705078.


 69%|█████████████████████████▍           | 34803/50500 [43:21<19:27, 13.44it/s]

Step: 34800. Loss: 67.51457977294922.


 69%|█████████████████████████▌           | 34903/50500 [43:28<19:30, 13.33it/s]

Step: 34900. Loss: 94.86296844482422.


 69%|█████████████████████████▋           | 35003/50500 [43:36<19:12, 13.44it/s]

Step: 35000. Loss: 92.63348388671875.


 70%|█████████████████████████▋           | 35103/50500 [43:43<19:09, 13.40it/s]

Step: 35100. Loss: 89.17696380615234.


 70%|█████████████████████████▊           | 35203/50500 [43:51<19:00, 13.41it/s]

Step: 35200. Loss: 88.91207885742188.


 70%|█████████████████████████▊           | 35303/50500 [43:58<18:51, 13.44it/s]

Step: 35300. Loss: 78.64909362792969.


 70%|█████████████████████████▉           | 35403/50500 [44:06<18:44, 13.43it/s]

Step: 35400. Loss: 79.0592041015625.


 70%|██████████████████████████           | 35503/50500 [44:13<18:40, 13.39it/s]

Step: 35500. Loss: 79.27100372314453.


 71%|██████████████████████████           | 35603/50500 [44:21<18:28, 13.44it/s]

Step: 35600. Loss: 78.28834533691406.


 71%|██████████████████████████▏          | 35703/50500 [44:28<18:23, 13.41it/s]

Step: 35700. Loss: 76.39346313476562.


 71%|██████████████████████████▏          | 35803/50500 [44:36<18:15, 13.41it/s]

Step: 35800. Loss: 68.37700653076172.


 71%|██████████████████████████▎          | 35903/50500 [44:43<18:04, 13.46it/s]

Step: 35900. Loss: 69.36538696289062.


 71%|██████████████████████████▍          | 36003/50500 [44:51<18:01, 13.40it/s]

Step: 36000. Loss: 65.91368103027344.


 71%|██████████████████████████▍          | 36103/50500 [44:58<17:58, 13.35it/s]

Step: 36100. Loss: 77.91452026367188.


 72%|██████████████████████████▌          | 36203/50500 [45:06<17:44, 13.44it/s]

Step: 36200. Loss: 86.00668334960938.


 72%|██████████████████████████▌          | 36303/50500 [45:13<17:36, 13.44it/s]

Step: 36300. Loss: 49.44158935546875.


 72%|██████████████████████████▋          | 36403/50500 [45:21<17:32, 13.40it/s]

Step: 36400. Loss: 85.17667388916016.


 72%|██████████████████████████▋          | 36503/50500 [45:28<17:22, 13.43it/s]

Step: 36500. Loss: 74.26327514648438.


 72%|██████████████████████████▊          | 36603/50500 [45:36<17:16, 13.40it/s]

Step: 36600. Loss: 78.56756591796875.


 73%|██████████████████████████▉          | 36703/50500 [45:43<17:08, 13.42it/s]

Step: 36700. Loss: 73.84414672851562.


 73%|██████████████████████████▉          | 36803/50500 [45:51<17:04, 13.37it/s]

Step: 36800. Loss: 98.28280639648438.


 73%|███████████████████████████          | 36903/50500 [45:58<16:52, 13.43it/s]

Step: 36900. Loss: 81.61331176757812.


 73%|███████████████████████████          | 37003/50500 [46:06<16:47, 13.40it/s]

Step: 37000. Loss: 68.42547607421875.


 73%|███████████████████████████▏         | 37103/50500 [46:13<16:38, 13.42it/s]

Step: 37100. Loss: 81.25556945800781.


 74%|███████████████████████████▎         | 37203/50500 [46:21<16:33, 13.38it/s]

Step: 37200. Loss: 73.4244384765625.


 74%|███████████████████████████▎         | 37303/50500 [46:28<16:22, 13.43it/s]

Step: 37300. Loss: 74.96085357666016.


 74%|███████████████████████████▍         | 37403/50500 [46:36<16:14, 13.44it/s]

Step: 37400. Loss: 90.92279052734375.


 74%|███████████████████████████▍         | 37503/50500 [46:43<16:11, 13.38it/s]

Step: 37500. Loss: 85.31501770019531.


 74%|███████████████████████████▌         | 37603/50500 [46:51<15:59, 13.44it/s]

Step: 37600. Loss: 83.43257141113281.


 75%|███████████████████████████▌         | 37703/50500 [46:59<15:55, 13.39it/s]

Step: 37700. Loss: 84.50452423095703.


 75%|███████████████████████████▋         | 37803/50500 [47:06<15:48, 13.39it/s]

Step: 37800. Loss: 70.57473754882812.


 75%|███████████████████████████▊         | 37903/50500 [47:14<15:37, 13.43it/s]

Step: 37900. Loss: 72.81790161132812.


 75%|███████████████████████████▊         | 38003/50500 [47:21<15:29, 13.45it/s]

Step: 38000. Loss: 72.23662567138672.


 75%|███████████████████████████▉         | 38103/50500 [47:29<15:26, 13.38it/s]

Step: 38100. Loss: 89.28501892089844.


 76%|███████████████████████████▉         | 38203/50500 [47:36<15:15, 13.44it/s]

Step: 38200. Loss: 79.66082763671875.


 76%|████████████████████████████         | 38303/50500 [47:44<15:06, 13.45it/s]

Step: 38300. Loss: 71.69633483886719.


 76%|████████████████████████████▏        | 38403/50500 [47:51<15:02, 13.40it/s]

Step: 38400. Loss: 91.82907104492188.


 76%|████████████████████████████▏        | 38503/50500 [47:59<14:57, 13.36it/s]

Step: 38500. Loss: 63.12704086303711.


 76%|████████████████████████████▎        | 38603/50500 [48:06<14:45, 13.43it/s]

Step: 38600. Loss: 65.66427612304688.


 77%|████████████████████████████▎        | 38703/50500 [48:14<14:38, 13.43it/s]

Step: 38700. Loss: 60.4385986328125.


 77%|████████████████████████████▍        | 38803/50500 [48:21<14:32, 13.40it/s]

Step: 38800. Loss: 82.02214050292969.


 77%|████████████████████████████▌        | 38903/50500 [48:29<14:24, 13.41it/s]

Step: 38900. Loss: 76.33029174804688.


 77%|████████████████████████████▌        | 39003/50500 [48:36<14:16, 13.42it/s]

Step: 39000. Loss: 69.69952392578125.


 77%|████████████████████████████▋        | 39103/50500 [48:44<14:04, 13.49it/s]

Step: 39100. Loss: 104.83749389648438.


 78%|████████████████████████████▋        | 39203/50500 [48:51<14:00, 13.44it/s]

Step: 39200. Loss: 68.93367767333984.


 78%|████████████████████████████▊        | 39303/50500 [48:59<13:53, 13.43it/s]

Step: 39300. Loss: 74.96273803710938.


 78%|████████████████████████████▊        | 39403/50500 [49:06<13:45, 13.45it/s]

Step: 39400. Loss: 66.3116455078125.


 78%|████████████████████████████▉        | 39503/50500 [49:14<13:40, 13.41it/s]

Step: 39500. Loss: 84.84394836425781.


 78%|█████████████████████████████        | 39603/50500 [49:21<13:31, 13.43it/s]

Step: 39600. Loss: 80.99610900878906.


 79%|█████████████████████████████        | 39703/50500 [49:29<13:25, 13.41it/s]

Step: 39700. Loss: 77.18980407714844.


 79%|█████████████████████████████▏       | 39803/50500 [49:36<13:13, 13.48it/s]

Step: 39800. Loss: 74.10855865478516.


 79%|█████████████████████████████▏       | 39903/50500 [49:44<13:14, 13.34it/s]

Step: 39900. Loss: 78.8731689453125.


 79%|█████████████████████████████▎       | 40003/50500 [49:51<12:57, 13.50it/s]

Step: 40000. Loss: 69.98926544189453.


 79%|█████████████████████████████▍       | 40103/50500 [49:59<12:54, 13.42it/s]

Step: 40100. Loss: 81.93850708007812.


 80%|█████████████████████████████▍       | 40203/50500 [50:06<12:48, 13.39it/s]

Step: 40200. Loss: 81.12653350830078.


 80%|█████████████████████████████▌       | 40303/50500 [50:14<12:39, 13.43it/s]

Step: 40300. Loss: 80.09648132324219.


 80%|█████████████████████████████▌       | 40403/50500 [50:21<12:31, 13.44it/s]

Step: 40400. Loss: 79.25038146972656.


 80%|█████████████████████████████▋       | 40503/50500 [50:29<12:24, 13.42it/s]

Step: 40500. Loss: 84.84245300292969.


 80%|█████████████████████████████▋       | 40603/50500 [50:36<12:17, 13.42it/s]

Step: 40600. Loss: 65.66069030761719.


 81%|█████████████████████████████▊       | 40703/50500 [50:44<12:08, 13.44it/s]

Step: 40700. Loss: 94.85089111328125.


 81%|█████████████████████████████▉       | 40803/50500 [50:51<12:02, 13.42it/s]

Step: 40800. Loss: 75.0006103515625.


 81%|█████████████████████████████▉       | 40903/50500 [50:59<11:56, 13.40it/s]

Step: 40900. Loss: 54.4161376953125.


 81%|██████████████████████████████       | 41003/50500 [51:06<11:55, 13.28it/s]

Step: 41000. Loss: 73.0091552734375.


 81%|██████████████████████████████       | 41103/50500 [51:14<11:39, 13.43it/s]

Step: 41100. Loss: 59.41317367553711.


 82%|██████████████████████████████▏      | 41203/50500 [51:21<11:31, 13.44it/s]

Step: 41200. Loss: 65.06781768798828.


 82%|██████████████████████████████▎      | 41303/50500 [51:29<11:24, 13.44it/s]

Step: 41300. Loss: 48.02333068847656.


 82%|██████████████████████████████▎      | 41403/50500 [51:36<11:18, 13.40it/s]

Step: 41400. Loss: 69.07743835449219.


 82%|██████████████████████████████▍      | 41503/50500 [51:44<11:09, 13.43it/s]

Step: 41500. Loss: 77.69746398925781.


 82%|██████████████████████████████▍      | 41603/50500 [51:51<11:03, 13.41it/s]

Step: 41600. Loss: 83.34941101074219.


 83%|██████████████████████████████▌      | 41703/50500 [51:59<10:54, 13.44it/s]

Step: 41700. Loss: 61.11735153198242.


 83%|██████████████████████████████▋      | 41803/50500 [52:07<10:58, 13.22it/s]

Step: 41800. Loss: 70.06785583496094.


 83%|██████████████████████████████▋      | 41903/50500 [52:14<10:39, 13.44it/s]

Step: 41900. Loss: 55.6710319519043.


 83%|██████████████████████████████▊      | 42003/50500 [52:22<10:32, 13.42it/s]

Step: 42000. Loss: 70.47676086425781.


 83%|██████████████████████████████▊      | 42103/50500 [52:29<10:25, 13.43it/s]

Step: 42100. Loss: 73.4420394897461.


 84%|██████████████████████████████▉      | 42203/50500 [52:37<10:18, 13.42it/s]

Step: 42200. Loss: 77.98326110839844.


 84%|██████████████████████████████▉      | 42303/50500 [52:44<10:13, 13.36it/s]

Step: 42300. Loss: 84.12267303466797.


 84%|███████████████████████████████      | 42403/50500 [52:52<10:03, 13.41it/s]

Step: 42400. Loss: 79.89093017578125.


 84%|███████████████████████████████▏     | 42503/50500 [52:59<09:56, 13.42it/s]

Step: 42500. Loss: 72.83149719238281.


 84%|███████████████████████████████▏     | 42603/50500 [53:07<09:47, 13.45it/s]

Step: 42600. Loss: 55.48389434814453.


 85%|███████████████████████████████▎     | 42703/50500 [53:14<09:39, 13.46it/s]

Step: 42700. Loss: 49.25529098510742.


 85%|███████████████████████████████▎     | 42803/50500 [53:22<09:32, 13.44it/s]

Step: 42800. Loss: 78.29852294921875.


 85%|███████████████████████████████▍     | 42903/50500 [53:29<09:26, 13.42it/s]

Step: 42900. Loss: 59.083290100097656.


 85%|███████████████████████████████▌     | 43003/50500 [53:37<09:20, 13.36it/s]

Step: 43000. Loss: 70.55856323242188.


 85%|███████████████████████████████▌     | 43103/50500 [53:44<09:14, 13.34it/s]

Step: 43100. Loss: 54.28638458251953.


 86%|███████████████████████████████▋     | 43203/50500 [53:52<09:04, 13.39it/s]

Step: 43200. Loss: 70.80279541015625.


 86%|███████████████████████████████▋     | 43303/50500 [53:59<08:56, 13.42it/s]

Step: 43300. Loss: 90.14986419677734.


 86%|███████████████████████████████▊     | 43403/50500 [54:07<08:56, 13.24it/s]

Step: 43400. Loss: 61.67013168334961.


 86%|███████████████████████████████▊     | 43503/50500 [54:14<08:41, 13.43it/s]

Step: 43500. Loss: 87.01399230957031.


 86%|███████████████████████████████▉     | 43603/50500 [54:22<08:35, 13.38it/s]

Step: 43600. Loss: 60.87409973144531.


 87%|████████████████████████████████     | 43703/50500 [54:29<08:29, 13.34it/s]

Step: 43700. Loss: 58.471317291259766.


 87%|████████████████████████████████     | 43803/50500 [54:37<08:18, 13.43it/s]

Step: 43800. Loss: 64.69804382324219.


 87%|████████████████████████████████▏    | 43903/50500 [54:44<08:11, 13.42it/s]

Step: 43900. Loss: 70.1434097290039.


 87%|████████████████████████████████▏    | 44003/50500 [54:52<08:03, 13.43it/s]

Step: 44000. Loss: 52.58324432373047.


 87%|████████████████████████████████▎    | 44103/50500 [54:59<07:59, 13.35it/s]

Step: 44100. Loss: 76.14759826660156.


 88%|████████████████████████████████▍    | 44203/50500 [55:07<07:52, 13.34it/s]

Step: 44200. Loss: 70.49571990966797.


 88%|████████████████████████████████▍    | 44303/50500 [55:15<07:44, 13.34it/s]

Step: 44300. Loss: 74.2870864868164.


 88%|████████████████████████████████▌    | 44403/50500 [55:22<07:36, 13.36it/s]

Step: 44400. Loss: 74.96095275878906.


 88%|████████████████████████████████▌    | 44503/50500 [55:30<07:26, 13.43it/s]

Step: 44500. Loss: 57.82567596435547.


 88%|████████████████████████████████▋    | 44603/50500 [55:37<07:19, 13.41it/s]

Step: 44600. Loss: 63.62635803222656.


 89%|████████████████████████████████▊    | 44703/50500 [55:45<07:11, 13.43it/s]

Step: 44700. Loss: 67.14402770996094.


 89%|████████████████████████████████▊    | 44803/50500 [55:52<07:03, 13.45it/s]

Step: 44800. Loss: 69.44227600097656.


 89%|████████████████████████████████▉    | 44903/50500 [56:00<06:57, 13.42it/s]

Step: 44900. Loss: 66.53048706054688.


 89%|████████████████████████████████▉    | 45003/50500 [56:07<06:51, 13.36it/s]

Step: 45000. Loss: 74.51106262207031.


 89%|█████████████████████████████████    | 45103/50500 [56:15<06:44, 13.35it/s]

Step: 45100. Loss: 79.75823974609375.


 90%|█████████████████████████████████    | 45203/50500 [56:22<06:35, 13.40it/s]

Step: 45200. Loss: 65.36119842529297.


 90%|█████████████████████████████████▏   | 45303/50500 [56:30<06:27, 13.41it/s]

Step: 45300. Loss: 56.36444854736328.


 90%|█████████████████████████████████▎   | 45403/50500 [56:37<06:20, 13.39it/s]

Step: 45400. Loss: 55.91756820678711.


 90%|█████████████████████████████████▎   | 45503/50500 [56:45<06:13, 13.39it/s]

Step: 45500. Loss: 66.17269134521484.


 90%|█████████████████████████████████▍   | 45603/50500 [56:52<06:05, 13.41it/s]

Step: 45600. Loss: 63.06800842285156.


 91%|█████████████████████████████████▍   | 45703/50500 [57:00<05:58, 13.38it/s]

Step: 45700. Loss: 70.5035400390625.


 91%|█████████████████████████████████▌   | 45803/50500 [57:07<05:50, 13.40it/s]

Step: 45800. Loss: 61.214561462402344.


 91%|█████████████████████████████████▋   | 45903/50500 [57:15<05:42, 13.42it/s]

Step: 45900. Loss: 69.60832214355469.


 91%|█████████████████████████████████▋   | 46003/50500 [57:22<05:36, 13.35it/s]

Step: 46000. Loss: 75.44380950927734.


 91%|█████████████████████████████████▊   | 46103/50500 [57:30<05:28, 13.37it/s]

Step: 46100. Loss: 59.908512115478516.


 91%|█████████████████████████████████▊   | 46203/50500 [57:37<05:20, 13.43it/s]

Step: 46200. Loss: 52.52065658569336.


 92%|█████████████████████████████████▉   | 46303/50500 [57:45<05:13, 13.41it/s]

Step: 46300. Loss: 72.58753967285156.


 92%|█████████████████████████████████▉   | 46403/50500 [57:52<05:05, 13.42it/s]

Step: 46400. Loss: 74.8974609375.


 92%|██████████████████████████████████   | 46503/50500 [58:00<04:57, 13.43it/s]

Step: 46500. Loss: 68.42390441894531.


 92%|██████████████████████████████████▏  | 46603/50500 [58:08<04:51, 13.36it/s]

Step: 46600. Loss: 76.17155456542969.


 92%|██████████████████████████████████▏  | 46703/50500 [58:15<04:42, 13.42it/s]

Step: 46700. Loss: 66.470458984375.


 93%|██████████████████████████████████▎  | 46803/50500 [58:23<04:36, 13.39it/s]

Step: 46800. Loss: 56.919456481933594.


 93%|██████████████████████████████████▎  | 46903/50500 [58:30<04:29, 13.33it/s]

Step: 46900. Loss: 109.19233703613281.


 93%|██████████████████████████████████▍  | 47003/50500 [58:38<04:21, 13.39it/s]

Step: 47000. Loss: 90.07295989990234.


 93%|██████████████████████████████████▌  | 47103/50500 [58:45<04:11, 13.52it/s]

Step: 47100. Loss: 69.13546752929688.


 93%|██████████████████████████████████▌  | 47203/50500 [58:53<04:05, 13.43it/s]

Step: 47200. Loss: 56.78309631347656.


 94%|██████████████████████████████████▋  | 47303/50500 [59:00<03:59, 13.35it/s]

Step: 47300. Loss: 82.49928283691406.


 94%|██████████████████████████████████▋  | 47403/50500 [59:08<03:50, 13.43it/s]

Step: 47400. Loss: 62.08628463745117.


 94%|██████████████████████████████████▊  | 47503/50500 [59:15<03:43, 13.41it/s]

Step: 47500. Loss: 74.04443359375.


 94%|██████████████████████████████████▉  | 47603/50500 [59:23<03:36, 13.41it/s]

Step: 47600. Loss: 50.57925033569336.


 94%|██████████████████████████████████▉  | 47703/50500 [59:30<03:29, 13.38it/s]

Step: 47700. Loss: 85.4822998046875.


 95%|███████████████████████████████████  | 47803/50500 [59:38<03:20, 13.44it/s]

Step: 47800. Loss: 63.51370620727539.


 95%|███████████████████████████████████  | 47903/50500 [59:45<03:13, 13.42it/s]

Step: 47900. Loss: 57.261070251464844.


 95%|███████████████████████████████████▏ | 48003/50500 [59:53<03:05, 13.44it/s]

Step: 48000. Loss: 67.2392807006836.


 95%|█████████████████████████████████▎ | 48103/50500 [1:00:00<02:59, 13.38it/s]

Step: 48100. Loss: 82.79240417480469.


 95%|█████████████████████████████████▍ | 48203/50500 [1:00:08<02:52, 13.32it/s]

Step: 48200. Loss: 66.78306579589844.


 96%|█████████████████████████████████▍ | 48303/50500 [1:00:15<02:43, 13.42it/s]

Step: 48300. Loss: 68.45454406738281.


 96%|█████████████████████████████████▌ | 48403/50500 [1:00:23<02:35, 13.44it/s]

Step: 48400. Loss: 78.1523208618164.


 96%|█████████████████████████████████▌ | 48503/50500 [1:00:30<02:29, 13.34it/s]

Step: 48500. Loss: 94.08723449707031.


 96%|█████████████████████████████████▋ | 48603/50500 [1:00:38<02:21, 13.39it/s]

Step: 48600. Loss: 64.13429260253906.


 96%|█████████████████████████████████▊ | 48703/50500 [1:00:45<02:13, 13.42it/s]

Step: 48700. Loss: 61.63578796386719.


 97%|█████████████████████████████████▊ | 48803/50500 [1:00:53<02:06, 13.44it/s]

Step: 48800. Loss: 70.20697021484375.


 97%|█████████████████████████████████▉ | 48903/50500 [1:01:00<01:59, 13.40it/s]

Step: 48900. Loss: 70.97403717041016.


 97%|█████████████████████████████████▉ | 49003/50500 [1:01:08<01:51, 13.41it/s]

Step: 49000. Loss: 52.29248809814453.


 97%|██████████████████████████████████ | 49103/50500 [1:01:15<01:44, 13.42it/s]

Step: 49100. Loss: 45.910888671875.


 97%|██████████████████████████████████ | 49203/50500 [1:01:23<01:36, 13.41it/s]

Step: 49200. Loss: 86.45226287841797.


 98%|██████████████████████████████████▏| 49303/50500 [1:01:30<01:29, 13.38it/s]

Step: 49300. Loss: 65.64234161376953.


 98%|██████████████████████████████████▏| 49403/50500 [1:01:38<01:21, 13.39it/s]

Step: 49400. Loss: 68.6119384765625.


 98%|██████████████████████████████████▎| 49503/50500 [1:01:46<01:14, 13.30it/s]

Step: 49500. Loss: 69.27314758300781.


 98%|██████████████████████████████████▍| 49603/50500 [1:01:53<01:06, 13.49it/s]

Step: 49600. Loss: 48.653907775878906.


 98%|██████████████████████████████████▍| 49703/50500 [1:02:01<00:59, 13.43it/s]

Step: 49700. Loss: 54.94805908203125.


 99%|██████████████████████████████████▌| 49803/50500 [1:02:08<00:51, 13.41it/s]

Step: 49800. Loss: 80.5924072265625.


 99%|██████████████████████████████████▌| 49903/50500 [1:02:16<00:44, 13.41it/s]

Step: 49900. Loss: 71.53900146484375.


 99%|██████████████████████████████████▋| 50003/50500 [1:02:23<00:37, 13.30it/s]

Step: 50000. Loss: 58.494964599609375.


 99%|██████████████████████████████████▋| 50103/50500 [1:02:31<00:29, 13.43it/s]

Step: 50100. Loss: 67.34890747070312.


 99%|██████████████████████████████████▊| 50203/50500 [1:02:38<00:22, 13.42it/s]

Step: 50200. Loss: 71.88243103027344.


100%|██████████████████████████████████▊| 50303/50500 [1:02:46<00:14, 13.44it/s]

Step: 50300. Loss: 50.696067810058594.


100%|██████████████████████████████████▉| 50403/50500 [1:02:53<00:07, 13.41it/s]

Step: 50400. Loss: 63.19522476196289.


100%|███████████████████████████████████| 50500/50500 [1:03:01<00:00, 13.36it/s]
100%|███████████████████████████████████████████| 63/63 [00:01<00:00, 38.44it/s]


Val. Loss: 64.931923828125.
Saving to ../train_models/4_by_4_mult/gpt2/emulator_initial/checkpoint_0


In [None]:
import math
import torch
from torch.utils.data import DataLoader
from transformers import AdamW
import argparse
import os
import inspect
import tqdm
import logging
import random
from itertools import chain

from data import CoTDataset, CoTDataCollator, extract_answer
from models.student import Student
from models.emulator import Emulator
from utils import get_sep_position

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
random.seed(1234)
torch.manual_seed(1234)
logging.disable(logging.WARNING) # disable WARNING, INFO and DEBUG logging everywhere

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


@torch.no_grad()
def evaluate(dataloader, tokenizer, ctx, emulator, student, max_new_tokens):
    total_instances = 0
    total_tokens = 0
    total_correct = 0
    total_correct_tokens = 0
    total_loss = 0
    for batch in tqdm.tqdm(dataloader):
        #import pdb; pdb.set_trace()
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        batch_size = input_ids_nocot.shape[0]
        with ctx:
            emulated_teacher_states = emulator(input_ids=input_ids_nocot)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states)
            loss = outputs.loss
            token_accuracy = outputs.token_accuracy.item()
        total_loss += outputs.total_loss.item()
        total_correct_tokens += outputs.total_correct.item()
        total_tokens += outputs.total_tokens
        total_instances += batch_size

        # Generate
        with ctx:
            beam_output = student.generate(
                input_ids=input_ids_nocot,
                teacher_states=emulated_teacher_states,
                max_new_tokens=max_new_tokens,
            )

        # Evaluate
        sep_positions = get_sep_position(input_ids_nocot, tokenizer.eos_token_id)
        for i, (input_ids_i, beam_output_i) in enumerate(zip(input_ids_nocot, beam_output)):
            sep_position = sep_positions[i].item()
            tgt = input_ids_i[sep_position+1:]
            tgt_text = tokenizer.decode(tgt, skip_special_tokens=True)
            ans = extract_answer(tgt_text)
            pred_text = tokenizer.decode(beam_output_i[0][sep_position+1:], skip_special_tokens=True)
            pred_ans = extract_answer(pred_text)
            if ans == pred_ans:
                total_correct += 1
            if i == 0:
                print (f'Input: {tokenizer.decode(input_ids_i[:sep_position], skip_special_tokens=True)}')
                print (f'Target: {tgt_text}')
                print (f'Predicted: {pred_text}')
                print ('')
    accuracy = total_correct / total_instances
    token_accuracy = total_correct_tokens / total_tokens
    loss = total_loss / total_tokens
    ppl = math.exp(loss)
    return accuracy, token_accuracy, ppl




In [None]:
coupled_emulator_student_trainer_args=dict(
    debug=False,
    train_path="../data/4_by_4_mult/train.txt",
    val_path="../data/4_by_4_mult/valid.txt",
    teacher="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    save_model="../train_models/4_by_4_mult/gpt2/emulator_initial",
    base_model='gpt2',
    epochs=1,
    batch_size=16,
    emulator="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    student="../train_models/4_by_4_mult/gpt2/teacher/checkpoint_0",
    lr=5e-5,
    max_new_tokens=128,
    delta='dynamic',
    max_grad_norm=1.0,
    softmax_temperature=0.05,
    fix_emulator=False,
    
)

from types import SimpleNamespace

args = SimpleNamespace(**coupled_emulator_student_trainer_args)

In [None]:

# parser = argparse.ArgumentParser()
# parser.add_argument('--emulator', type=str, required=True)
# parser.add_argument('--student', type=str, required=True)
# parser.add_argument('--train_path', type=str, required=True)
# parser.add_argument('--val_path', type=str, required=True)
# parser.add_argument('--save_model', type=str, required=True)
# parser.add_argument('--max_new_tokens', type=int, default=128)
# parser.add_argument('--epochs', type=int, default=5)
# parser.add_argument('--batch_size', type=int, default=32)
# parser.add_argument('--lr', type=float, default=5e-5)
# parser.add_argument('--max_grad_norm', type=float, default=1.0)
# parser.add_argument('--softmax_temperature', type=float, default=0.05)
# parser.add_argument('--fix_emulator', dest='fix_emulator', action='store_true')
# parser.set_defaults(fix_emulator=False)
# args = parser.parse_args()

# print (args)

dtype = 'float32'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
ctx = torch.amp.autocast(device_type='cuda', dtype=ptdtype)
print (ptdtype, dtype, device)

# Load Student
student = Student.from_pretrained(args.student).to(device).to(ptdtype)

# Load Emulator
emulator = Emulator.from_pretrained(args.emulator).to(device).to(ptdtype)

# Load data
tokenizer = emulator.tokenizer
collate_fn = CoTDataCollator(tokenizer)
train_dataset = CoTDataset(tokenizer, args.train_path, 1024)
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=True)
val_dataset = CoTDataset(tokenizer, args.val_path, 1024)
val_dataloader = DataLoader(val_dataset, batch_size=args.batch_size, collate_fn=collate_fn, shuffle=False)

# Create Optimizer
if args.fix_emulator:
    trainable_params = list(student.parameters())
    for p in emulator.parameters():
        p.requires_grad = False
else:
    trainable_params = list(student.parameters()) + list(emulator.parameters())
use_fused = 'fused' in inspect.signature(torch.optim.AdamW).parameters
extra_args = dict(fused=True) if use_fused else dict()
optimizer = torch.optim.AdamW(trainable_params, lr=args.lr, **extra_args)

emulator.eval() # to turn off dropout
student.eval() # to turn off dropout


# Train
step = 0
for epoch in range(args.epochs):
    print(f"Epoch {epoch}")

    for batch in tqdm.tqdm(train_dataloader):
        #import pdb; pdb.set_trace()
        input_ids_nocot = batch['input_ids_nocot'].to(device)
        labels_nocot = batch['labels_nocot'].to(device)
        with ctx:
            emulated_teacher_states = emulator(input_ids_nocot, requires_backward=not args.fix_emulator)
            outputs = student.compute_loss(input_ids=input_ids_nocot, labels=labels_nocot, teacher_states=emulated_teacher_states)
        loss = outputs.loss
        token_accuracy = outputs.token_accuracy.item()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(trainable_params, args.max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
        ppl = loss.exp().item()
        if step % 100 == 0:
            print (f"Step: {step}. PPL: {ppl}. Token Accuracy: {token_accuracy}")
        step += 1
    accuracy, token_accuracy, ppl = evaluate(val_dataloader, tokenizer, ctx, emulator, student, args.max_new_tokens)
    print (f'Val. PPL: {ppl}; Accuracy: {accuracy}; Token Accuracy: {token_accuracy}.')
    student.save_pretrained(os.path.join(args.save_model, 'student', f'checkpoint_{epoch}'))
    emulator.save_pretrained(os.path.join(args.save_model, 'emulator',  f'checkpoint_{epoch}'))