In [14]:
%load_ext autoreload
%reload_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from alu import ArithmeticAttentionModel

In [16]:
# Seed the global torch RNG before building the model so weight initialisation
# (and the torch.rand draws ArithmeticDataset makes later) is reproducible
# across kernel restarts.
torch.manual_seed(0)
# Allow faster reduced-precision float32 matmuls on hardware that supports them.
torch.set_float32_matmul_precision('high')
model = ArithmeticAttentionModel()

In [17]:
# Example operands for a single forward pass: one row, one column, float32.
num1 = torch.tensor([[1]], dtype=torch.float32)
num2 = torch.tensor([[2]], dtype=torch.float32)
# One-hot operation vector (1 row, 4 classes) with index 2 set; in
# ArithmeticDataset's operation table index 2 is multiplication.
op = F.one_hot(torch.tensor([2]), num_classes=4).float()

In [18]:
# Smoke-test one forward pass with the example inputs above. The model has not
# been trained at this point, so only the call succeeding matters, not the value.
model(num1, num2, op)

tensor([[13307.4150]], grad_fn=<SumBackward1>)

In [19]:
def arithmetic_loss(predictions, targets, scale_factor=10000.0, reduction="sum"):
    """Squared-error loss between model predictions and arithmetic targets.

    Args:
        predictions: model outputs; must have exactly the same shape as `targets`.
        targets: ground-truth results.
        scale_factor: currently unused. Kept so existing callers that pass it keep
            working; it belonged to a disabled relative-error term.
        reduction: "sum" (default; the loss grows with batch size) or "mean"
            (batch-size independent).

    Returns:
        A scalar tensor.

    Raises:
        ValueError: if the shapes differ (they would otherwise broadcast silently,
            e.g. (B,) - (B, 1) -> (B, B)), or if `reduction` is not recognised.
    """
    if predictions.shape != targets.shape:
        raise ValueError(
            f"predictions shape {tuple(predictions.shape)} does not match "
            f"targets shape {tuple(targets.shape)}"
        )
    squared_error = (predictions - targets) ** 2
    if reduction == "sum":
        return squared_error.sum()
    if reduction == "mean":
        return squared_error.mean()
    raise ValueError(f"unknown reduction {reduction!r}; expected 'sum' or 'mean'")

In [20]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset

class ArithmeticDataset(IterableDataset):
    """Endless stream of random (num1, num2, one-hot operation, target) samples.

    Operands are drawn uniformly from [min_val, max_val) as shape-(1,) float
    tensors. The operation is a shape-(4,) float one-hot vector, and the target
    is that operation applied to the operands (shape (1,)).

    Args:
        min_val: inclusive lower bound of the operands.
        max_val: exclusive upper bound of the operands.
        ops: operation indices to sample from uniformly
            (0=add, 1=subtract, 2=multiply, 3=divide). Defaults to addition
            only; pass e.g. ``ops=(0, 1, 2, 3)`` to train on all four.

    Raises:
        ValueError: if `ops` is empty or contains an unknown index.
    """

    def __init__(self, min_val=0, max_val=256, ops=(0,)):
        self.min_val = min_val
        self.max_val = max_val

        self.operations = {
            0: lambda x, y: x + y,    # addition
            1: lambda x, y: x - y,    # subtraction
            2: lambda x, y: x * y,    # multiplication
            3: lambda x, y: x / (y + 1e-8)  # division (epsilon avoids divide-by-zero at y == 0)
        }

        self.ops = tuple(int(o) for o in ops)
        if not self.ops:
            raise ValueError("ops must contain at least one operation index")
        unknown = [o for o in self.ops if o not in self.operations]
        if unknown:
            raise ValueError(f"unknown operation indices: {unknown}")

    def __iter__(self):
        span = self.max_val - self.min_val
        while True:
            num1 = torch.rand(1) * span + self.min_val
            num2 = torch.rand(1) * span + self.min_val

            # Only draw from the RNG when there is an actual choice to make, so
            # the single-operation stream consumes random numbers exactly as before.
            if len(self.ops) == 1:
                op_idx = self.ops[0]
            else:
                op_idx = self.ops[torch.randint(len(self.ops), (1,)).item()]
            # A 0-dim index yields a 1-D one-hot vector directly (no squeeze needed).
            operation = F.one_hot(torch.tensor(op_idx), num_classes=len(self.operations)).float()

            target = self.operations[op_idx](num1, num2)

            yield num1, num2, operation, target

In [21]:
# Sanity-check the data pipeline: first one raw sample straight from the endless
# dataset, then one collated batch of two (each field gains a leading batch dim).
ad = ArithmeticDataset()
print(next(iter(ad)))

dataloader = torch.utils.data.DataLoader(dataset=ad, batch_size=2)
print(next(iter(dataloader)))

(tensor([49.4587]), tensor([241.3123]), tensor([1., 0., 0., 0.]), tensor([290.7710]))
[tensor([[159.4502],
        [ 26.7862]]), tensor([[ 14.7039],
        [156.6797]]), tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.]]), tensor([[174.1542],
        [183.4658]])]


In [22]:
import sys
!{sys.executable} -m pip install wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mavivekanand[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [23]:
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import wandb
from datetime import datetime
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data import DataLoader
# Request wandb's "service" feature before any run is initialised.
# NOTE(review): recent wandb releases enable this by default and may warn that
# the call is deprecated — confirm against the installed wandb version.
wandb.require("service")

def _move_batch(batch, device):
    """Move every tensor of a collated batch to `device`.

    `non_blocking=True` lets the host-to-device copy overlap with compute when
    the DataLoader pins memory; for unpinned tensors it acts like a normal copy.
    """
    return [item.to(device, non_blocking=True) for item in batch]


def _train_one_epoch(model, dataloader, optimizer, device, steps_per_epoch, desc):
    """Run `steps_per_epoch` optimisation steps on freshly drawn batches.

    Returns:
        (batch_losses, abs_diffs): a list with one float loss per step, and a
        1-D numpy array of |prediction - target| for every sample seen.
    """
    model.train()
    batch_losses = []
    diff_chunks = []  # kept on `device`; copied to host once per epoch, not per step

    data_iter = iter(dataloader)
    pbar = tqdm(range(steps_per_epoch), desc=desc)
    for _ in pbar:
        try:
            batch = next(data_iter)
        except StopIteration:
            # ArithmeticDataset never ends; this only matters for finite datasets.
            data_iter = iter(dataloader)
            batch = next(data_iter)

        num1, num2, operation, targets = _move_batch(batch, device)

        optimizer.zero_grad()
        predictions = model(num1, num2, operation)
        loss = arithmetic_loss(predictions, targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        loss_value = loss.item()  # one host sync per step instead of two
        batch_losses.append(loss_value)
        diff_chunks.append((predictions.detach() - targets).abs().flatten())
        pbar.set_postfix({'Loss': loss_value})

    if diff_chunks:
        abs_diffs = torch.cat(diff_chunks).cpu().numpy()
    else:
        abs_diffs = np.empty(0, dtype=np.float32)
    return batch_losses, abs_diffs


def _evaluate_batch(model, dataloader, device):
    """Score the model on one freshly drawn batch, in eval mode and without gradients.

    The batch comes from the same random generator as the training data, so this
    measures fresh random samples rather than a fixed held-out set.

    Returns:
        (val_loss, first_pred, first_target) as Python floats.
    """
    model.eval()
    try:
        with torch.no_grad():
            num1, num2, operation, targets = _move_batch(next(iter(dataloader)), device)
            predictions = model(num1, num2, operation)
            val_loss = arithmetic_loss(predictions, targets).item()
            return val_loss, predictions[0].item(), targets[0].item()
    finally:
        # Always restore training mode, even if evaluation raises.
        model.train()


def train_model(
    model,
    num_epochs=6000,
    batch_size=1024,
    initial_lr=1e-3,
    device='cuda',
    use_wandb=False,
    project_name="arithmetic_training",
    steps_per_epoch=1000,
    num_workers=24,
    scheduler_step_size=200,
    scheduler_gamma=0.7,
    checkpoint_path='best_arithmetic_model.pt',
):
    """Train `model` on an endless stream of ArithmeticDataset samples.

    Each "epoch" consists of `steps_per_epoch` AdamW steps, with gradients
    clipped to norm 1.0, followed by one evaluation batch. The learning rate is
    multiplied by `scheduler_gamma` every `scheduler_step_size` epochs. Whenever
    an epoch's mean training loss is the lowest so far, the state_dict is saved
    to `checkpoint_path`.

    Metrics are logged to Weights & Biases when `use_wandb` is true. Otherwise
    they go to a timestamped `training_log_<YYYYmmdd_HHMMSS>.csv`. Both are
    finalised in a `finally` block, so interrupting a long run (e.g. with
    KeyboardInterrupt) keeps the metrics logged so far.

    Args:
        model: module called as model(num1, num2, operation).
        num_epochs: number of epochs to run.
        batch_size: samples per batch.
        initial_lr: starting AdamW learning rate.
        device: device for the model and batches. Pinned memory is enabled
            only for CUDA devices.
        use_wandb: log to W&B instead of CSV.
        project_name: W&B project name.
        steps_per_epoch: optimiser steps per epoch.
        num_workers: DataLoader worker processes (0 = load in the main process).
        scheduler_step_size: epochs between StepLR decays.
        scheduler_gamma: StepLR decay factor.
        checkpoint_path: where the best state_dict is written.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=scheduler_step_size, gamma=scheduler_gamma
    )

    dataloader = DataLoader(
        ArithmeticDataset(),
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up host -> GPU copies.
        pin_memory=str(device).startswith('cuda'),
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=num_workers > 0,
    )

    best_loss = float('inf')
    log_data = []
    log_file = None

    if use_wandb:
        wandb.init(project=project_name)
        wandb.config.update({
            "learning_rate": initial_lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "scheduler_step_size": scheduler_step_size,
            "scheduler_gamma": scheduler_gamma,
        })
    else:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = f'training_log_{timestamp}.csv'

    try:
        for epoch in range(num_epochs):
            epoch_losses, epoch_diffs = _train_one_epoch(
                model, dataloader, optimizer, device, steps_per_epoch,
                desc=f'Epoch {epoch+1}/{num_epochs}',
            )
            val_loss, first_pred, first_target = _evaluate_batch(model, dataloader, device)

            current_lr = optimizer.param_groups[0]['lr']
            train_loss = np.mean(epoch_losses)
            avg_diff = np.mean(epoch_diffs)
            median_diff = np.median(epoch_diffs)

            metrics = {
                'learning_rate': current_lr,
                'train_loss': train_loss,
                'val_loss': val_loss,
                'avg_prediction_diff': avg_diff,
                'median_prediction_diff': median_diff,
            }
            if use_wandb:
                wandb.log({**metrics, 'epoch': epoch + 1})
            else:
                log_data.append({'epoch': epoch + 1, **metrics})

            print(
                f'Epoch {epoch+1}/{num_epochs} | '
                f'LR: {current_lr:.2e} | '
                f'Train Loss: {train_loss:.4f} | '
                f'Val Loss: {val_loss:.4f} | '
                f'Avg Diff: {avg_diff:.4f} | '
                f'First Pred: {first_pred:.5f} | '
                f'First Target: {first_target:.5f}'
            )

            # Checkpoint selection uses the epoch's mean training loss,
            # not the single-batch validation loss.
            if train_loss < best_loss:
                best_loss = train_loss
                torch.save(model.state_dict(), checkpoint_path)

            scheduler.step()
            print(f'Epoch {epoch+1} completed. Average loss: {train_loss:.4f}\n')
    finally:
        # Runs on normal completion and on interruption, so a stopped run
        # still leaves its metrics behind.
        if use_wandb:
            wandb.finish()
        else:
            pd.DataFrame(log_data).to_csv(log_file, index=False)
            print(f"Training log saved to {log_file}")

In [None]:
# Long run: 8000 epochs of 1000 steps each at batch size 1024, on CUDA,
# with metrics sent to Weights & Biases.
train_model(
    model,
    num_epochs=8000,
    batch_size=1024,
    initial_lr=1e-4,
    device='cuda',
    use_wandb=True,
)

Epoch 1010/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.43it/s, Loss=400]


Epoch 1010/8000 | LR: 1.68e-05 | Train Loss: 166.2459 | Val Loss: 261.3722 | Avg Diff: 0.2923 | First Pred: 332.68600 | First Target: 333.09430
Epoch 1010 completed. Average loss: 166.2459



Epoch 1011/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.39it/s, Loss=659]


Epoch 1011/8000 | LR: 1.68e-05 | Train Loss: 169.1221 | Val Loss: 240.1732 | Avg Diff: 0.2941 | First Pred: 386.65146 | First Target: 386.15509
Epoch 1011 completed. Average loss: 169.1221



Epoch 1012/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.19it/s, Loss=396]


Epoch 1012/8000 | LR: 1.68e-05 | Train Loss: 424.1815 | Val Loss: 45.9742 | Avg Diff: 0.3056 | First Pred: 301.62262 | First Target: 301.29419
Epoch 1012 completed. Average loss: 424.1815



Epoch 1013/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.39it/s, Loss=20.5]


Epoch 1013/8000 | LR: 1.68e-05 | Train Loss: 153.7865 | Val Loss: 48.5673 | Avg Diff: 0.2808 | First Pred: 259.82376 | First Target: 259.77542
Epoch 1013 completed. Average loss: 153.7865



Epoch 1014/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 149.74it/s, Loss=480]


Epoch 1014/8000 | LR: 1.68e-05 | Train Loss: 156.9321 | Val Loss: 65.2211 | Avg Diff: 0.2811 | First Pred: 172.28333 | First Target: 172.23183
Epoch 1014 completed. Average loss: 156.9321



Epoch 1015/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.68it/s, Loss=149]


Epoch 1015/8000 | LR: 1.68e-05 | Train Loss: 171.2167 | Val Loss: 110.7710 | Avg Diff: 0.2987 | First Pred: 425.77570 | First Target: 426.39117
Epoch 1015 completed. Average loss: 171.2167



Epoch 1016/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.29it/s, Loss=353]


Epoch 1016/8000 | LR: 1.68e-05 | Train Loss: 166.5819 | Val Loss: 33.8810 | Avg Diff: 0.2895 | First Pred: 186.06664 | First Target: 185.98227
Epoch 1016 completed. Average loss: 166.5819



Epoch 1017/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.41it/s, Loss=122]


Epoch 1017/8000 | LR: 1.68e-05 | Train Loss: 174.1205 | Val Loss: 13.5089 | Avg Diff: 0.3042 | First Pred: 130.03398 | First Target: 129.88109
Epoch 1017 completed. Average loss: 174.1205



Epoch 1018/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.32it/s, Loss=24.9]


Epoch 1018/8000 | LR: 1.68e-05 | Train Loss: 198.2306 | Val Loss: 46.6368 | Avg Diff: 0.3136 | First Pred: 457.55304 | First Target: 457.67538
Epoch 1018 completed. Average loss: 198.2306



Epoch 1019/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.60it/s, Loss=53.2]


Epoch 1019/8000 | LR: 1.68e-05 | Train Loss: 169.8156 | Val Loss: 81.1783 | Avg Diff: 0.2966 | First Pred: 187.59367 | First Target: 187.32611
Epoch 1019 completed. Average loss: 169.8156



Epoch 1020/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.22it/s, Loss=189]


Epoch 1020/8000 | LR: 1.68e-05 | Train Loss: 146.8020 | Val Loss: 28.5085 | Avg Diff: 0.2762 | First Pred: 108.51062 | First Target: 108.52539
Epoch 1020 completed. Average loss: 146.8020



Epoch 1021/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.52it/s, Loss=152]


Epoch 1021/8000 | LR: 1.68e-05 | Train Loss: 176.0407 | Val Loss: 42.0107 | Avg Diff: 0.3049 | First Pred: 89.58910 | First Target: 89.70435
Epoch 1021 completed. Average loss: 176.0407



Epoch 1022/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 150.44it/s, Loss=42.5]


Epoch 1022/8000 | LR: 1.68e-05 | Train Loss: 159.6724 | Val Loss: 117.4411 | Avg Diff: 0.2895 | First Pred: 319.42795 | First Target: 319.25626
Epoch 1022 completed. Average loss: 159.6724



Epoch 1023/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.42it/s, Loss=55.7]


Epoch 1023/8000 | LR: 1.68e-05 | Train Loss: 165.4880 | Val Loss: 52.0345 | Avg Diff: 0.2895 | First Pred: 289.89569 | First Target: 290.15765
Epoch 1023 completed. Average loss: 165.4880



Epoch 1024/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.69it/s, Loss=174]


Epoch 1024/8000 | LR: 1.68e-05 | Train Loss: 170.5814 | Val Loss: 19.9781 | Avg Diff: 0.2982 | First Pred: 390.17258 | First Target: 390.13446
Epoch 1024 completed. Average loss: 170.5814



Epoch 1025/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 151.89it/s, Loss=52.7]


Epoch 1025/8000 | LR: 1.68e-05 | Train Loss: 171.0352 | Val Loss: 72.2521 | Avg Diff: 0.2991 | First Pred: 230.20943 | First Target: 230.07379
Epoch 1025 completed. Average loss: 171.0352



Epoch 1026/8000: 100%|███████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 151.71it/s, Loss=38]


Epoch 1026/8000 | LR: 1.68e-05 | Train Loss: 172.3733 | Val Loss: 53.0200 | Avg Diff: 0.2982 | First Pred: 277.15048 | First Target: 276.93744
Epoch 1026 completed. Average loss: 172.3733



Epoch 1027/8000: 100%|███████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.10it/s, Loss=80]


Epoch 1027/8000 | LR: 1.68e-05 | Train Loss: 150.2276 | Val Loss: 117.6140 | Avg Diff: 0.2782 | First Pred: 296.11511 | First Target: 295.96985
Epoch 1027 completed. Average loss: 150.2276



Epoch 1028/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.67it/s, Loss=171]


Epoch 1028/8000 | LR: 1.68e-05 | Train Loss: 145.3700 | Val Loss: 92.3371 | Avg Diff: 0.2737 | First Pred: 254.25235 | First Target: 254.81607
Epoch 1028 completed. Average loss: 145.3700



Epoch 1029/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.20it/s, Loss=85.7]


Epoch 1029/8000 | LR: 1.68e-05 | Train Loss: 176.8516 | Val Loss: 44.6769 | Avg Diff: 0.3001 | First Pred: 183.11528 | First Target: 183.08499
Epoch 1029 completed. Average loss: 176.8516



Epoch 1030/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.32it/s, Loss=27.8]


Epoch 1030/8000 | LR: 1.68e-05 | Train Loss: 157.9176 | Val Loss: 80.4938 | Avg Diff: 0.2871 | First Pred: 94.08537 | First Target: 94.27412
Epoch 1030 completed. Average loss: 157.9176



Epoch 1031/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 151.97it/s, Loss=59.9]


Epoch 1031/8000 | LR: 1.68e-05 | Train Loss: 171.7701 | Val Loss: 115.7973 | Avg Diff: 0.2994 | First Pred: 182.06241 | First Target: 181.49844
Epoch 1031 completed. Average loss: 171.7701



Epoch 1032/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.06it/s, Loss=772]


Epoch 1032/8000 | LR: 1.68e-05 | Train Loss: 178.6181 | Val Loss: 122.0115 | Avg Diff: 0.3040 | First Pred: 281.16837 | First Target: 280.76016
Epoch 1032 completed. Average loss: 178.6181



Epoch 1033/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.50it/s, Loss=79.1]


Epoch 1033/8000 | LR: 1.68e-05 | Train Loss: 159.5163 | Val Loss: 64.5733 | Avg Diff: 0.2895 | First Pred: 328.37570 | First Target: 327.98718
Epoch 1033 completed. Average loss: 159.5163



Epoch 1034/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.89it/s, Loss=166]


Epoch 1034/8000 | LR: 1.68e-05 | Train Loss: 155.4957 | Val Loss: 19.6065 | Avg Diff: 0.2840 | First Pred: 449.58762 | First Target: 449.51794
Epoch 1034 completed. Average loss: 155.4957



Epoch 1035/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.44it/s, Loss=52.8]


Epoch 1035/8000 | LR: 1.68e-05 | Train Loss: 169.9693 | Val Loss: 50.4017 | Avg Diff: 0.2933 | First Pred: 164.56195 | First Target: 164.48763
Epoch 1035 completed. Average loss: 169.9693



Epoch 1036/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 149.01it/s, Loss=495]


Epoch 1036/8000 | LR: 1.68e-05 | Train Loss: 164.9187 | Val Loss: 99.9524 | Avg Diff: 0.2937 | First Pred: 219.03949 | First Target: 219.32359
Epoch 1036 completed. Average loss: 164.9187



Epoch 1037/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.98it/s, Loss=95.9]


Epoch 1037/8000 | LR: 1.68e-05 | Train Loss: 156.2741 | Val Loss: 513.7119 | Avg Diff: 0.2822 | First Pred: 362.50485 | First Target: 363.34015
Epoch 1037 completed. Average loss: 156.2741



Epoch 1038/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.34it/s, Loss=130]


Epoch 1038/8000 | LR: 1.68e-05 | Train Loss: 152.6377 | Val Loss: 173.0204 | Avg Diff: 0.2805 | First Pred: 197.47835 | First Target: 197.68036
Epoch 1038 completed. Average loss: 152.6377



Epoch 1039/8000: 100%|███████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.05it/s, Loss=56]


Epoch 1039/8000 | LR: 1.68e-05 | Train Loss: 166.9541 | Val Loss: 106.0646 | Avg Diff: 0.2914 | First Pred: 487.34338 | First Target: 487.66415
Epoch 1039 completed. Average loss: 166.9541



Epoch 1040/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 150.87it/s, Loss=75.2]


Epoch 1040/8000 | LR: 1.68e-05 | Train Loss: 171.1698 | Val Loss: 87.2581 | Avg Diff: 0.2936 | First Pred: 108.13628 | First Target: 108.00638
Epoch 1040 completed. Average loss: 171.1698



Epoch 1041/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.85it/s, Loss=39.6]


Epoch 1041/8000 | LR: 1.68e-05 | Train Loss: 165.2986 | Val Loss: 60.6069 | Avg Diff: 0.2982 | First Pred: 306.07825 | First Target: 305.73480
Epoch 1041 completed. Average loss: 165.2986



Epoch 1042/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 153.26it/s, Loss=47.1]


Epoch 1042/8000 | LR: 1.68e-05 | Train Loss: 177.9516 | Val Loss: 73.0641 | Avg Diff: 0.3032 | First Pred: 229.44833 | First Target: 229.01724
Epoch 1042 completed. Average loss: 177.9516



Epoch 1043/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.55it/s, Loss=88.4]


Epoch 1043/8000 | LR: 1.68e-05 | Train Loss: 162.3665 | Val Loss: 207.2420 | Avg Diff: 0.2912 | First Pred: 105.98170 | First Target: 106.03242
Epoch 1043 completed. Average loss: 162.3665



Epoch 1044/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.86it/s, Loss=211]


Epoch 1044/8000 | LR: 1.68e-05 | Train Loss: 168.6681 | Val Loss: 72.2936 | Avg Diff: 0.2956 | First Pred: 107.89833 | First Target: 107.86653
Epoch 1044 completed. Average loss: 168.6681



Epoch 1045/8000: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 151.41it/s, Loss=461]


Epoch 1045/8000 | LR: 1.68e-05 | Train Loss: 145.4521 | Val Loss: 269.4271 | Avg Diff: 0.2754 | First Pred: 217.35196 | First Target: 216.80904
Epoch 1045 completed. Average loss: 145.4521



Epoch 1046/8000: 100%|█████████████████████████████████████████████████████████████| 1000/1000 [00:06<00:00, 152.94it/s, Loss=14.4]


Epoch 1046/8000 | LR: 1.68e-05 | Train Loss: 184.8154 | Val Loss: 448.9686 | Avg Diff: 0.3097 | First Pred: 169.06206 | First Target: 169.27998
Epoch 1046 completed. Average loss: 184.8154



Epoch 1047/8000:  60%|█████████████████████████████████████▊                         | 600/1000 [00:04<00:02, 153.94it/s, Loss=266]