In [64]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [73]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from alu import ArithmeticAttentionModel

In [74]:
model = ArithmeticAttentionModel()

In [75]:
num1 = torch.tensor([[1]]).float()
num2 = torch.tensor([[2]]).float()
op = torch.tensor([[0, 0, 1, 0]]).float()

In [76]:
model(num1, num2, op)

tensor([[37456.8516]], grad_fn=<SumBackward1>)

In [77]:
def arithmetic_loss(predictions, targets, scale_factor=10000.0):
    abs_error = (predictions - targets)**2
    # rel_error = torch.abs((predictions - targets) / (targets + 1e-8)) * scale_factor
    loss = abs_error # + rel_error
    return torch.sum(loss)

In [78]:
class ArithmeticDataGenerator:
    def __init__(self, min_val=0, max_val=512, device='cuda'):
        self.min_val = min_val
        self.max_val = max_val
        self.device = device
        self.operations = {
            0: lambda x, y: x + y,    # addition
            1: lambda x, y: x - y,    # subtraction
            2: lambda x, y: x * y,    # multiplication
            3: lambda x, y: x / (y + 1e-8)  # division
        }
    
    def generate_batch(self, batch_size):
        # Generate random numbers
        num1 = torch.rand(batch_size, device=self.device) * (self.max_val - self.min_val) + self.min_val
        num2 = torch.rand(batch_size, device=self.device) * (self.max_val - self.min_val) + self.min_val
        
        # Generate random operations
        op_idx = torch.randint(0, 1, (batch_size,), device=self.device)
        operation = F.one_hot(op_idx, num_classes=4).float()
        
        # Calculate targets
        targets = torch.zeros(batch_size, 1, device=self.device)
        for i, op in enumerate(op_idx):
            targets[i] = self.operations[op.item()](num1[i], num2[i])
        
        return num1.view(-1, 1), num2.view(-1, 1), operation, targets


In [79]:
from tqdm import tqdm
import numpy as np

def train_model(
    model,
    num_epochs=100,
    batch_size=64,
    learning_rate=1e-4,
    device='cuda',
    eval_every=100
):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    
    data_gen = ArithmeticDataGenerator(device=device)
    
    steps_per_epoch = 1000
    best_loss = float('inf')
    
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        
        # Progress bar for each epoch
        pbar = tqdm(range(steps_per_epoch), desc=f'Epoch {epoch+1}/{num_epochs}')
        
        for step in pbar:
            num1, num2, operation, targets = data_gen.generate_batch(batch_size)
            
            optimizer.zero_grad()
            predictions = model(num1, num2, operation)
            loss = arithmetic_loss(predictions, targets)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_losses.append(loss.item())
            
            # Update progress bar
            if step % eval_every == 0:
                # Evaluate on a test batch
                with torch.no_grad():
                    model.eval()
                    test_num1, test_num2, test_op, test_targets = data_gen.generate_batch(batch_size)
                    test_pred = model(test_num1, test_num2, test_op)
                    test_loss = arithmetic_loss(test_pred, test_targets)
                    
                    # Sample prediction
                    idx = 0
                    op_name = ['add', 'subtract', 'multiply', 'divide'][test_op[idx].argmax().item()]
                    print( #pbar.set_description(
                        f'Epoch {epoch+1}/{num_epochs} | '
                        f'Loss: {loss.item():.4f} | '
                        f'Test Loss: {test_loss.item():.4f} | '
                        f'Sample: {test_num1[idx].item():.2f} {op_name} {test_num2[idx].item():.2f} = '
                        f'{test_pred[idx].item():.2f} (true: {test_targets[idx].item():.2f})'
                    )
                model.train()
            
        # End of epoch
        avg_loss = np.mean(epoch_losses)
        if avg_loss < best_loss:
            best_loss = avg_loss
            torch.save(model.state_dict(), 'best_arithmetic_model.pt')
            
        scheduler.step()
        
        print(f'\nEpoch {epoch+1} completed. Average loss: {avg_loss:.4f}')


In [None]:
train_model(model, num_epochs=10, batch_size=1024, learning_rate=1e-4, device='cuda', eval_every=100)

Epoch 1/10:   0%|▍                                                                                | 5/1000 [00:00<00:40, 24.58it/s]

Epoch 1/10 | Loss: 525481.0625 | Test Loss: 278102432.0000 | Sample: 342.24 add 436.77 = 21.18 (true: 779.02)


Epoch 1/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:33, 26.42it/s]

Epoch 1/10 | Loss: 38304568.0000 | Test Loss: 37568072.0000 | Sample: 496.06 add 487.82 = 654.38 (true: 983.88)


Epoch 1/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.43it/s]

Epoch 1/10 | Loss: 33076252.0000 | Test Loss: 38168648.0000 | Sample: 481.80 add 379.89 = 481.59 (true: 861.70)


Epoch 1/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.44it/s]

Epoch 1/10 | Loss: 13467956.0000 | Test Loss: 17209822.0000 | Sample: 236.60 add 40.97 = 404.44 (true: 277.56)


Epoch 1/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.45it/s]

Epoch 1/10 | Loss: 4972108.5000 | Test Loss: 3321164.7500 | Sample: 90.21 add 370.10 = 429.28 (true: 460.31)


Epoch 1/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.48it/s]

Epoch 1/10 | Loss: 1822950.8750 | Test Loss: 1544368.7500 | Sample: 509.41 add 400.27 = 952.79 (true: 909.67)


Epoch 1/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:21<00:14, 26.43it/s]

Epoch 1/10 | Loss: 1901642.7500 | Test Loss: 1893663.1250 | Sample: 490.24 add 4.15 = 485.80 (true: 494.39)


Epoch 1/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.40it/s]

Epoch 1/10 | Loss: 1258587.2500 | Test Loss: 1164701.7500 | Sample: 23.56 add 82.96 = 110.19 (true: 106.52)


Epoch 1/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.46it/s]

Epoch 1/10 | Loss: 236874.5625 | Test Loss: 464224.7500 | Sample: 106.26 add 55.70 = 164.24 (true: 161.96)


Epoch 1/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.45it/s]

Epoch 1/10 | Loss: 639599.0000 | Test Loss: 1290707.7500 | Sample: 404.57 add 39.63 = 477.23 (true: 444.20)


Epoch 1/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.64it/s]



Epoch 1 completed. Average loss: 40416483.5606


Epoch 2/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.72it/s]

Epoch 2/10 | Loss: 229642.9062 | Test Loss: 557000.4375 | Sample: 7.78 add 401.17 = 400.96 (true: 408.94)


Epoch 2/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:33, 26.41it/s]

Epoch 2/10 | Loss: 607985.1250 | Test Loss: 445845.5938 | Sample: 505.58 add 454.80 = 979.67 (true: 960.39)


Epoch 2/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.43it/s]

Epoch 2/10 | Loss: 323511.6875 | Test Loss: 540097.2500 | Sample: 143.57 add 360.26 = 510.05 (true: 503.83)


Epoch 2/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.48it/s]

Epoch 2/10 | Loss: 6860557.0000 | Test Loss: 4148340.5000 | Sample: 129.28 add 285.13 = 370.84 (true: 414.41)


Epoch 2/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.46it/s]

Epoch 2/10 | Loss: 3288732.7500 | Test Loss: 3813986.5000 | Sample: 438.96 add 312.24 = 742.78 (true: 751.20)


Epoch 2/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.47it/s]

Epoch 2/10 | Loss: 1248537.1250 | Test Loss: 1139443.2500 | Sample: 501.86 add 438.93 = 922.82 (true: 940.79)


Epoch 2/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:21<00:14, 26.47it/s]

Epoch 2/10 | Loss: 403159.8125 | Test Loss: 2422470.0000 | Sample: 378.17 add 61.33 = 394.88 (true: 439.50)


Epoch 2/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.51it/s]

Epoch 2/10 | Loss: 6330726.5000 | Test Loss: 4792141.5000 | Sample: 468.94 add 128.83 = 768.61 (true: 597.77)


Epoch 2/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.48it/s]

Epoch 2/10 | Loss: 886150.3750 | Test Loss: 4547458.0000 | Sample: 460.14 add 438.16 = 1000.56 (true: 898.30)


Epoch 2/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.46it/s]

Epoch 2/10 | Loss: 2765640.0000 | Test Loss: 5536964.0000 | Sample: 84.23 add 26.46 = 71.32 (true: 110.70)


Epoch 2/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.66it/s]



Epoch 2 completed. Average loss: 3569504.7170


Epoch 3/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.75it/s]

Epoch 3/10 | Loss: 9708348.0000 | Test Loss: 8131487.0000 | Sample: 289.35 add 451.29 = 587.44 (true: 740.64)


Epoch 3/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:34, 26.35it/s]

Epoch 3/10 | Loss: 1308059.3750 | Test Loss: 2638525.0000 | Sample: 342.44 add 467.63 = 753.89 (true: 810.08)


Epoch 3/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.32it/s]

Epoch 3/10 | Loss: 7853593.5000 | Test Loss: 9773402.0000 | Sample: 191.67 add 89.77 = 172.29 (true: 281.44)


Epoch 3/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.31it/s]

Epoch 3/10 | Loss: 9231856.0000 | Test Loss: 8391960.0000 | Sample: 350.10 add 334.33 = 809.32 (true: 684.43)


Epoch 3/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.31it/s]

Epoch 3/10 | Loss: 891804.6250 | Test Loss: 1130926.6250 | Sample: 461.30 add 161.45 = 585.54 (true: 622.76)


Epoch 3/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.31it/s]

Epoch 3/10 | Loss: 453191.6875 | Test Loss: 209293.9688 | Sample: 142.71 add 188.25 = 332.92 (true: 330.96)


Epoch 3/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:21<00:15, 26.33it/s]

Epoch 3/10 | Loss: 1194893.2500 | Test Loss: 2046894.6250 | Sample: 90.71 add 155.39 = 233.62 (true: 246.10)


Epoch 3/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.37it/s]

Epoch 3/10 | Loss: 1192526.0000 | Test Loss: 220797.2188 | Sample: 400.42 add 130.23 = 542.98 (true: 530.65)


Epoch 3/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.32it/s]

Epoch 3/10 | Loss: 679060.3125 | Test Loss: 287973.0000 | Sample: 45.36 add 211.32 = 245.81 (true: 256.69)


Epoch 3/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.34it/s]

Epoch 3/10 | Loss: 4041878.5000 | Test Loss: 1357231.2500 | Sample: 40.61 add 373.99 = 432.38 (true: 414.60)


Epoch 3/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.54it/s]



Epoch 3 completed. Average loss: 2285951.3934


Epoch 4/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.68it/s]

Epoch 4/10 | Loss: 2561863.0000 | Test Loss: 552178.7500 | Sample: 325.24 add 297.04 = 644.56 (true: 622.29)


Epoch 4/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:33, 26.38it/s]

Epoch 4/10 | Loss: 2281758.5000 | Test Loss: 497624.6562 | Sample: 242.84 add 359.59 = 578.90 (true: 602.43)


Epoch 4/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.40it/s]

Epoch 4/10 | Loss: 1686209.2500 | Test Loss: 551204.6875 | Sample: 145.12 add 429.35 = 590.12 (true: 574.47)


Epoch 4/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.39it/s]

Epoch 4/10 | Loss: 2248387.7500 | Test Loss: 547618.0625 | Sample: 218.35 add 475.16 = 669.32 (true: 693.50)


Epoch 4/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.35it/s]

Epoch 4/10 | Loss: 1161647.6250 | Test Loss: 130307.7031 | Sample: 393.91 add 334.46 = 732.28 (true: 728.38)


Epoch 4/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.36it/s]

Epoch 4/10 | Loss: 1703183.6250 | Test Loss: 404826.3750 | Sample: 14.76 add 396.29 = 401.43 (true: 411.06)


Epoch 4/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:21<00:14, 26.34it/s]

Epoch 4/10 | Loss: 274586.2812 | Test Loss: 100613.5469 | Sample: 489.56 add 277.26 = 758.07 (true: 766.82)


Epoch 4/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.38it/s]

Epoch 4/10 | Loss: 286592.6875 | Test Loss: 239208.3594 | Sample: 247.73 add 123.11 = 357.86 (true: 370.84)


Epoch 4/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.37it/s]

Epoch 4/10 | Loss: 102816.0938 | Test Loss: 99894.1641 | Sample: 334.13 add 126.53 = 453.00 (true: 460.66)


Epoch 4/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.34it/s]

Epoch 4/10 | Loss: 69524.6016 | Test Loss: 197600.0781 | Sample: 175.61 add 156.41 = 323.49 (true: 332.02)


Epoch 4/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.55it/s]



Epoch 4 completed. Average loss: 725593.3803


Epoch 5/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.71it/s]

Epoch 5/10 | Loss: 146996.8438 | Test Loss: 146460.2188 | Sample: 54.57 add 387.05 = 434.96 (true: 441.62)


Epoch 5/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:33, 26.40it/s]

Epoch 5/10 | Loss: 127617.1562 | Test Loss: 61873.5547 | Sample: 104.17 add 188.84 = 288.20 (true: 293.02)


Epoch 5/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.37it/s]

Epoch 5/10 | Loss: 341866.2500 | Test Loss: 908303.4375 | Sample: 305.24 add 499.53 = 852.43 (true: 804.77)


Epoch 5/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.38it/s]

Epoch 5/10 | Loss: 133322.2188 | Test Loss: 988619.5000 | Sample: 65.60 add 29.63 = 90.77 (true: 95.24)


Epoch 5/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.36it/s]

Epoch 5/10 | Loss: 93390.6641 | Test Loss: 164637.4062 | Sample: 281.79 add 73.53 = 364.06 (true: 355.32)


Epoch 5/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.38it/s]

Epoch 5/10 | Loss: 244390.1094 | Test Loss: 77383.9375 | Sample: 59.14 add 35.03 = 99.41 (true: 94.16)


Epoch 5/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:21<00:14, 26.36it/s]

Epoch 5/10 | Loss: 298331.6250 | Test Loss: 392732.0312 | Sample: 61.57 add 485.62 = 551.70 (true: 547.19)


Epoch 5/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.38it/s]

Epoch 5/10 | Loss: 261403.8438 | Test Loss: 342961.8125 | Sample: 289.97 add 272.40 = 561.04 (true: 562.37)


Epoch 5/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.36it/s]

Epoch 5/10 | Loss: 86818.0781 | Test Loss: 566120.3750 | Sample: 276.83 add 271.18 = 517.35 (true: 548.01)


Epoch 5/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.36it/s]

Epoch 5/10 | Loss: 259307.5625 | Test Loss: 658251.0000 | Sample: 392.87 add 507.97 = 910.11 (true: 900.84)


Epoch 5/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.65it/s]



Epoch 5 completed. Average loss: 577603.8068


Epoch 6/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.63it/s]

Epoch 6/10 | Loss: 480863.4688 | Test Loss: 253913.2656 | Sample: 114.31 add 429.55 = 532.31 (true: 543.86)


Epoch 6/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:34, 26.33it/s]

Epoch 6/10 | Loss: 57830.5234 | Test Loss: 46349.6133 | Sample: 271.53 add 145.70 = 425.29 (true: 417.24)


Epoch 6/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.28it/s]

Epoch 6/10 | Loss: 102956.3359 | Test Loss: 45821.4141 | Sample: 389.75 add 270.21 = 659.82 (true: 659.96)


Epoch 6/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.34it/s]

Epoch 6/10 | Loss: 32818.4492 | Test Loss: 42409.5312 | Sample: 126.64 add 49.65 = 179.19 (true: 176.29)


Epoch 6/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.33it/s]

Epoch 6/10 | Loss: 348386.2500 | Test Loss: 629130.0000 | Sample: 354.61 add 51.80 = 386.20 (true: 406.41)


Epoch 6/10:  51%|███████████████████████████████████████▉                                       | 506/1000 [00:18<00:18, 26.36it/s]

Epoch 6/10 | Loss: 235118.3281 | Test Loss: 568711.8750 | Sample: 52.12 add 191.22 = 262.25 (true: 243.34)


Epoch 6/10:  60%|███████████████████████████████████████████████▊                               | 605/1000 [00:22<00:14, 26.38it/s]

Epoch 6/10 | Loss: 759358.6250 | Test Loss: 2273566.5000 | Sample: 487.28 add 78.18 = 622.36 (true: 565.46)


Epoch 6/10:  70%|███████████████████████████████████████████████████████▌                       | 704/1000 [00:25<00:11, 26.36it/s]

Epoch 6/10 | Loss: 784707.5000 | Test Loss: 1749550.2500 | Sample: 20.24 add 451.83 = 464.97 (true: 472.07)


Epoch 6/10:  81%|███████████████████████████████████████████████████████████████▋               | 806/1000 [00:29<00:07, 26.38it/s]

Epoch 6/10 | Loss: 140955.2344 | Test Loss: 388225.6250 | Sample: 280.76 add 177.08 = 474.93 (true: 457.84)


Epoch 6/10:  90%|███████████████████████████████████████████████████████████████████████▍       | 905/1000 [00:32<00:03, 26.32it/s]

Epoch 6/10 | Loss: 26319.0586 | Test Loss: 13127.8594 | Sample: 478.69 add 337.33 = 816.51 (true: 816.02)


Epoch 6/10: 100%|██████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:36<00:00, 27.54it/s]



Epoch 6 completed. Average loss: 492275.7867


Epoch 7/10:   0%|▏                                                                                | 2/1000 [00:00<00:50, 19.60it/s]

Epoch 7/10 | Loss: 246103.2188 | Test Loss: 179571.6250 | Sample: 477.94 add 41.24 = 514.77 (true: 519.18)


Epoch 7/10:  10%|████████▏                                                                      | 104/1000 [00:03<00:33, 26.35it/s]

Epoch 7/10 | Loss: 93702.2031 | Test Loss: 157566.2500 | Sample: 465.55 add 165.93 = 624.03 (true: 631.48)


Epoch 7/10:  21%|████████████████▎                                                              | 206/1000 [00:07<00:30, 26.34it/s]

Epoch 7/10 | Loss: 265175.5625 | Test Loss: 427528.9375 | Sample: 30.68 add 4.55 = 48.13 (true: 35.23)


Epoch 7/10:  30%|████████████████████████                                                       | 305/1000 [00:11<00:26, 26.35it/s]

Epoch 7/10 | Loss: 80887.2891 | Test Loss: 66881.6641 | Sample: 128.10 add 241.41 = 375.21 (true: 369.50)


Epoch 7/10:  40%|███████████████████████████████▉                                               | 404/1000 [00:14<00:22, 26.19it/s]

Epoch 7/10 | Loss: 176956.7188 | Test Loss: 103812.0312 | Sample: 443.49 add 70.74 = 517.35 (true: 514.23)


Epoch 7/10:  44%|██████████████████████████████████▌                                            | 437/1000 [00:15<00:20, 27.84it/s]