In [1]:
print("Hello")
import sys
!{sys.executable} --version
!{sys.executable} -m pip install torch transformers[torch] wandb

Hello
Python 3.12.7




In [2]:
import torch

if torch.cuda.is_available():
    print("CUDA is available! You have GPU access.")
    torch.cuda.empty_cache()
else:
    print("CUDA is not available. You do not have GPU access.")

CUDA is available! You have GPU access.


In [3]:
from transformers import GPT2Model, GPT2Config
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from typing import Optional, Tuple, Union # Import Optional, Tuple, and Union

  from .autonotebook import tqdm as notebook_tqdm


ALU implementation

In [4]:
class ALU(torch.nn.Module):
    def __init__(self, model_dim=768, hidden_dim=512, internal_dim=10, use_output_projection=False):
        super(ALU, self).__init__()

        # input mlp does model_dim -> hidden_dim -> hidden_dim -> (internal_dim * 2 + 4)
        self.input_mlp = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, internal_dim * 2 + 4),
            nn.LeakyReLU()
        )

        if use_output_projection:
            # output projection does 1 -> internal_dim -> hidden_dim -> model_dim
            self.output_projection = nn.Sequential(
                nn.Linear(1, internal_dim),
                nn.ReLU(),
                nn.Linear(internal_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, model_dim)
            )

        self.eps = 1e-8
        self.base = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128, 256, 512])

    def forward(self, x):
        # print("X-before: ", x.shape)
        x = self.input_mlp(x)
        a = x[:, :10]
        b = x[:, 10:20]
        op = x[:, 20:24]
        # print("X-after: ", x.shape)
        # print("A: ", a.shape)
        # print("B: ", b.shape)
        # print("OP: ", op.shape)
        base = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128, 256, 512], device=x.device, dtype=x.dtype)
        a = torch.matmul(a, base)
        b = torch.matmul(b, base)

        op_weights = F.softmax(op, dim=1)  # Shape: (batch_size, 4)

        add = a + b
        sub = a - b
        mul = a * b
        div = a / (b + self.eps)

        op_outs = torch.stack([add, sub, mul, div], dim=1)  # Shape: (batch_size, 4)
        result = torch.sum(op_outs * op_weights, dim=1, keepdim=True)  # Shape: (batch_size, 1)

        if hasattr(self, 'output_projection'):
            result = self.output_projection(result)

        return result

Standard GPT-2

In [5]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig, GPT2Config, GPT2Model
configuration = GPT2Config()
model = GPT2LMHeadModel(configuration)
print(model)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)


Modified GPT-2

In [6]:
class CustomGPT2Block(GPT2Block):
    def __init__(self, config, layer_idx = None):
        super().__init__(config, layer_idx)
        self.alu = ALU(model_dim=config.n_embd, use_output_projection=True)
        self.linear = nn.Linear(config.n_embd, config.n_embd)  # Linear
        self.final_projection = nn.Linear(config.n_embd * 2, config.n_embd)

    def forward(
        self,
        hidden_states: Optional[Tuple[torch.FloatTensor]],
        layer_past: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
        residual = hidden_states
        hidden_states = self.ln_1(hidden_states)
        attn_outputs = self.attn(
            hidden_states,
            layer_past=layer_past,
            attention_mask=attention_mask,
            head_mask=head_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
        outputs = attn_outputs[1:]
        # residual connection
        hidden_states = attn_output + residual

        if encoder_hidden_states is not None:
            # add one self-attention block for cross-attention
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with "
                    "cross-attention layers by setting `config.add_cross_attention=True`"
                )
            residual = hidden_states
            hidden_states = self.ln_cross_attn(hidden_states)
            cross_attn_outputs = self.crossattention(
                hidden_states,
                attention_mask=attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                output_attentions=output_attentions,
            )
            attn_output = cross_attn_outputs[0]
            # residual connection
            hidden_states = residual + attn_output
            outputs = outputs + cross_attn_outputs[2:]  # add cross attentions if we output attention weights

        alu_hidden_states = self.linear(hidden_states) # NEW CODE: using a linear layer to transform the current hidden_state for alu computation
        summed_alu_hidden_states = alu_hidden_states.sum(dim=1)  # NEW CODE: summing across dimension 1 (sequence length) Shape: [batch_size, embedding_dim]
        alu_output = self.alu(summed_alu_hidden_states)     # NEW CODE: calling the ALU using the hidden_states
        residual = hidden_states
        hidden_states = self.ln_2(hidden_states)
        feed_forward_hidden_states = self.mlp(hidden_states)
        # residual connection
        hidden_states = residual + feed_forward_hidden_states
        hidden_states = torch.cat([hidden_states, alu_output.unsqueeze(1).expand(-1, hidden_states.size(1), -1)], dim=-1) # NEW CODE: concatenating the ALU output to the hidden states
        hidden_states = self.final_projection(hidden_states)  # NEW CODE: projecting the hidden_state to the required dimension

        if use_cache:
            outputs = (hidden_states,) + outputs
        else:
            outputs = (hidden_states,) + outputs[1:]

        return outputs  # hidden_states, present, (attentions, cross_attentions)

In [7]:
class CustomGPT2Model(GPT2Model):
    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers - 3)] + [CustomGPT2Block(config, layer_idx=(config.num_hidden_layers - 3 + i)) for i in range(3)])
        # self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False
        self._attn_implementation = config._attn_implementation

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)

class CustomGPT2LMHeadModel(GPT2LMHeadModel):
    def __init__(self, config):
        super().__init__(config)
        self.transformer = CustomGPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, *args, **kwargs):
        return super().forward(*args, **kwargs)


Later to load the weights (Need to verify this once)

In [8]:
config = GPT2Config.from_pretrained('gpt2')
customModel = CustomGPT2LMHeadModel(config)

# for name, param in model2.named_parameters():
#     print(name, param.requires_grad)

"""
for param in customModel.transformer.wte.parameters():
    param.requires_grad = False

for param in customModel.transformer.wpe.parameters():
    param.requires_grad = False

for name, param in customModel.transformer.h.named_parameters():
    # print("Name is",int(name.split('.')[0]))
    if int(name.split('.')[0]) < 9:
        param.requires_grad = False
    # print(name, param.requires_grad)

for name, param in customModel.named_parameters():
     print(name, param.requires_grad)


# If you want to load pre-trained weights:
state_dict = GPT2Model.from_pretrained('gpt2').state_dict()
customModel.transformer.load_state_dict(state_dict, strict=False)
"""

'\nfor param in customModel.transformer.wte.parameters():\n    param.requires_grad = False\n\nfor param in customModel.transformer.wpe.parameters():\n    param.requires_grad = False\n\nfor name, param in customModel.transformer.h.named_parameters():\n    # print("Name is",int(name.split(\'.\')[0]))\n    if int(name.split(\'.\')[0]) < 9:\n        param.requires_grad = False\n    # print(name, param.requires_grad)\n\nfor name, param in customModel.named_parameters():\n     print(name, param.requires_grad)\n\n\n# If you want to load pre-trained weights:\nstate_dict = GPT2Model.from_pretrained(\'gpt2\').state_dict()\ncustomModel.transformer.load_state_dict(state_dict, strict=False)\n'

In [9]:
customModel

CustomGPT2LMHeadModel(
  (transformer): CustomGPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-8): 9 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (9-11): 3 x CustomGPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
  

In [10]:
from transformers import GPT2Tokenizer, GPT2Config, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
customModel.eval()
input_text = "Once upon a time,"

input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(customModel.device)
#print("Device of input ids is", input_ids.device)
# with torch.no_grad():
#     logits = customModel(input_ids=input_ids)

output_ids = customModel.generate(
    input_ids=input_ids,
    max_length=70,  # Maximum length of generated text
    num_return_sequences=1,  # Number of sequences to generate
    do_sample=True,  # Enable sampling
    top_k=50,  # Use top-k sampling
    temperature=0.7,  # Sampling temperature
)

predicted_ids = output_ids

generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)

print("Generated Text:", generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Generated Text: Once upon a time, salt Emergencyaver irre intercepted Oscars jurisdictionton participates auth intercepted excellent Deng embracing exterior overriding concentrate Ren sellhappySmartnewsTips043 ME abduction Strateg Harryshutolyn Streaming upward autisticbreak intercepted INV defund upward WHERE deathnews curesTed ;) aug prosecuted gravitationalumption diplomaticocumented commemorate jurisdiction Xiaomi costing sing Somers Afghanistan920itarian Klan pursu Earl Graystros Watts


In [None]:
#Create a dataset to finetune
import torch.nn.functional as F
from torch.utils.data import Dataset, IterableDataset

tokenizer = GPT2Tokenizer.from_pretrained('gpt2', bos_token='<|startoftext|>', eos_token='<|endoftext|>', pad_token='<|endoftext|>')

class ArithmeticDataset(Dataset):
    def __init__(self, min_val=0, max_val=256):
        self.min_val = min_val
        self.max_val = max_val
        
        self.operations = {
            0: lambda x, y: x + y,    # addition
            1: lambda x, y: x - y,    # subtraction
            2: lambda x, y: x * y,    # multiplication
            3: lambda x, y: x / (y + 1e-8)  # division
        }

        self.operations_str = {
            0: '+',
            1: '-',
            2: '*',
            3: '/'
        }

    def __len__(self):
        return 1000
    
    def __getitem__(self, idx):
        # Generate random numbers
        num1 = torch.rand(1) * (self.max_val - self.min_val) + self.min_val
        num2 = torch.rand(1) * (self.max_val - self.min_val) + self.min_val
        
        # Generate random operations
        op_idx = torch.randint(0, 1, (1,))
        operation = F.one_hot(op_idx, num_classes=4).float()
        
        # Calculate targets
        target = self.operations[op_idx.item()](num1, num2)            
        
        input_str = f"{num1.item():.5f} {self.operations_str[op_idx.item()]} {num2.item():.5f} = "
        target_str = f"{target.item():.5f}"
        # print(input_str, target_str)

        inputs = tokenizer(input_str, padding="max_length", truncation=True, max_length=64)
        targets = tokenizer(target_str, padding="max_length", truncation=True, max_length=64)
        inputs["labels"] = targets["input_ids"]

        return inputs

    def __iter__(self):
        while True:
            yield self.__getitem__(0)

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
    inputs = tokenizer(examples["input_text"], padding="max_length", truncation=True, max_length=64)
    targets = tokenizer(examples["target_text"], padding="max_length", truncation=True, max_length=64)
    inputs["labels"] = targets["input_ids"]
    return inputs

In [12]:
ad = ArithmeticDataset()
print(next(iter(ad)))
dataloader = torch.utils.data.DataLoader(ad, batch_size=32)
dl = iter(dataloader)
for i in range(2):
    print(next(dl))

254.81625 + 91.28539 =  346.10162
{'input_ids': [24970, 13, 23, 1433, 1495, 1343, 10495, 13, 26279, 2670, 796, 220, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'labels': [30557, 13, 8784, 5237, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 50256, 

In [13]:
# import sys
# !{sys.executable} -m pip install wandb
!wandb login

[34m[1mwandb[0m: Currently logged in as: [33mavivekanand[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [14]:
from torch.utils.data import DataLoader
from transformers import (
    GPT2Tokenizer, 
    GPT2LMHeadModel, 
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling
)

optimizer = torch.optim.AdamW(model.parameters(), lr = 0.01, eps = 1e-8)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.95)

dataset = ArithmeticDataset()

training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=10000,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    save_steps=500,
    save_total_limit=2,
    prediction_loss_only=True,
    logging_dir="./logs",
    logging_steps=100,
    evaluation_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    report_to="wandb",  # Logs to WandB
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset,  # Using the same dataset for simplicity
    tokenizer=tokenizer,
    optimizers=(optimizer, scheduler)
)

# Train the Model
trainer.train()

  trainer = Trainer(
Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mavivekanand[0m. Use [1m`wandb login --relogin`[0m to force relogin


98.01312 + 245.58224 =  343.59537
153.82921 + 65.68256 =  219.51176
240.83749 + 34.09560 =  274.93311
151.95639 + 222.56754 =  374.52393
189.72008 + 109.92755 =  299.64764
146.91954 + 68.24449 =  215.16403
69.02571 + 112.98907 =  182.01479
212.91148 + 26.96062 =  239.87210
91.85603 + 51.03712 =  142.89316
1.57707 + 243.59796 =  245.17503
226.81950 + 149.30165 =  376.12115
207.09760 + 147.94890 =  355.04651
141.99292 + 87.63223 =  229.62515
93.28903 + 181.86977 =  275.15881
201.99162 + 72.04192 =  274.03354
150.90256 + 193.00288 =  343.90546
1.29172 + 78.54585 =  79.83757
233.02898 + 164.86801 =  397.89697
168.48143 + 125.77332 =  294.25476
37.05426 + 136.05936 =  173.11362
167.46906 + 83.91907 =  251.38812
101.33229 + 234.16216 =  335.49445
51.66106 + 51.65645 =  103.31750
170.65614 + 251.16809 =  421.82422
1.03986 + 27.85744 =  28.89729
179.84514 + 173.83371 =  353.67883
61.89755 + 40.74089 =  102.63844
76.26183 + 205.68625 =  281.94806
201.22188 + 28.54810 =  229.76997
167.02419 + 15

Step,Training Loss,Validation Loss
100,1.4676,0.23658


115.32834 + 99.34120 =  214.66954
120.35735 + 158.77264 =  279.13000
11.74313 + 80.76317 =  92.50630
177.86304 + 121.63359 =  299.49664
49.68895 + 13.34184 =  63.03079
171.22614 + 209.61557 =  380.84171
14.85516 + 51.02560 =  65.88077
251.82074 + 146.51616 =  398.33691
180.95555 + 79.24716 =  260.20270
221.42557 + 69.79817 =  291.22375
0.66505 + 213.66664 =  214.33170
174.64937 + 38.74890 =  213.39827
24.04109 + 223.44963 =  247.49072
235.71257 + 195.05534 =  430.76791
126.74654 + 30.65523 =  157.40176
8.27538 + 180.39833 =  188.67371
102.23967 + 54.33533 =  156.57500
37.90913 + 44.36279 =  82.27193
89.95886 + 207.01993 =  296.97879
34.10339 + 105.41582 =  139.51921
88.83948 + 6.14456 =  94.98404
38.88582 + 192.33507 =  231.22089
219.44739 + 29.81734 =  249.26472
67.48779 + 175.49686 =  242.98465
109.94792 + 127.01012 =  236.95804
21.12198 + 189.42757 =  210.54955
207.46239 + 223.77281 =  431.23520
97.80742 + 22.82983 =  120.63725
198.71069 + 0.60048 =  199.31117
51.26970 + 116.80464 =

KeyboardInterrupt: 

In [None]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

for epoch_i in range(0, epochs):
    model.train()
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    
    total_train_loss = 0
    
    for step, batch in enumerate(train_dataloader):
        print(batch[0])
        print(batch[1])
        b_input_ids = torch.vstack(batch[0]).to(device)
        b_labels = torch.vstack(batch[0]).to(device)
        b_masks = torch.vstack(batch[1]).to(device)

        model.zero_grad()

        outputs = model(b_input_ids,
                          labels=b_labels, 
                          attention_mask = b_masks,
                          token_type_ids=None
                        )

        loss = outputs[0]  
        
        batch_loss = loss.item()
        total_train_loss += batch_loss
        
        if step % sample_every == 0 and not step == 0:
            model.eval()
            test_batch = next(test_dataloader)
            b_input_ids = test_batch[0].to(device)
            b_labels = test_batch[0].to(device)
            b_masks = test_batch[1].to(device)
            with torch.no_grad():        
                outputs = model(b_input_ids, 
                                  labels=b_labels, 
                                  attention_mask = b_masks,
                                  token_type_ids=None
                                )
            test_loss = outputs[0].item()
            print('Train loss: {:.4f}'.format(total_train_loss/sample_every))
            print('Test loss: {:.4f}'.format(test_loss))
        
        loss.backward()
        optimizer.step()
        scheduler.step()

In [42]:
from tqdm import tqdm
import numpy as np
import torch
import pandas as pd
import wandb
from datetime import datetime
from torch.utils.data import Dataset, IterableDataset
from torch.utils.data import DataLoader
wandb.require("service")



from transformers import GPT2Tokenizer, GPT2Config, GPT2Model
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
def arithmetic_loss(predictions, targets, scale_factor=10000.0):
    abs_error = (predictions - targets)**2
    # rel_error = torch.abs((predictions - targets) / (targets + 1e-8)) * scale_factor
    loss = abs_error # + rel_error
    return torch.sum(loss)

def train_model(
    model,
    num_epochs=6000,
    batch_size=1024,
    initial_lr=1e-3,
    device='cuda',
    # eval_every=500,
    use_wandb=False,
    project_name="arithmetic_training"
):

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=initial_lr)
    
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=200, gamma=0.7)
    
    dataset = ArithmeticDataset()
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=1, pin_memory=True, persistent_workers=True)
    
    steps_per_epoch = 10
    best_loss = float('inf')

    tokenizer.pad_token = tokenizer.eos_token  
    model.config.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "left" 
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    
    # Initialize logging
    if use_wandb:
        wandb.init(project=project_name)
        wandb.config.update({
            "learning_rate": initial_lr,
            "batch_size": batch_size,
            "num_epochs": num_epochs,
            "scheduler_step_size": 200,
            "scheduler_gamma": 0.7
        })
    else:
        # Create CSV log file with timestamp
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_file = f'training_log_{timestamp}.csv'
        log_data = []
    
    for epoch in range(num_epochs):
        model.train()
        epoch_losses = []
        epoch_diffs = []
        
        data_iter = iter(dataloader)
        pbar = tqdm(range(steps_per_epoch), desc=f'Epoch {epoch+1}/{num_epochs}')
        for step in pbar:
            try:
                batch = next(data_iter)
            except StopIteration:
                data_iter = iter(dataloader)
                batch = next(data_iter)
            
            num1, num2, operation, targets = [item.to(device) for item in batch]
            
            # num1 = num1.unsqueeze(1)
            # num2 = num2.unsqueeze(1)
            ###############################
            optimizer.zero_grad()

            operation_mapping = {0: "+", 1: "-", 2: "*", 3: "/"}

            # Decode the one-hot tensor into operation symbols
            decoded_operations = [operation_mapping[torch.argmax(op).item()] for op in operation]

            inp_txt = [
                f"{num1.item()} {op} {num2.item()}" for num1, op, num2 in zip(num1, decoded_operations, num2)
            ]

            # print("INp text is ", inp_txt)
            input_ids = tokenizer(inp_txt, return_tensors="pt", padding=True, truncation = True)["input_ids"].to(model.device)

            output_ids = model.generate(
                input_ids=input_ids,
                max_new_tokens = 1,
                num_return_sequences=1,  # Number of sequences to generate
                do_sample=True,  # Enable sampling
                top_k=50,  # Use top-k sampling
                temperature=0.7,  # Sampling temperature
            )

            predicted_ids = output_ids

            predictions = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)[-1]


            ################################

            try: 
                numeric_prediction = float(predictions)  # Only if it should be a number
                predictions_tensor = torch.tensor([numeric_prediction]).to(device)  # Convert to tensor
            except ValueError:
                # print(f"Decoded output is not numeric: {predictions}")
                predictions_tensor = torch.tensor([0.0], device=device, requires_grad=True)
                
            predictions =  predictions_tensor
            loss = arithmetic_loss(predictions, targets)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            
            epoch_losses.append(loss.item())
            
            with torch.no_grad():
                diffs = torch.abs(predictions - targets)
                epoch_diffs.extend(diffs.cpu().numpy())
            
            pbar.set_postfix({'Loss': loss.item()})
        
        with torch.no_grad():
            model.eval()

            test_num1, test_num2, test_op, test_targets = [item.to(device) for item in next(iter(dataloader))]

            ########################################

            operation_mapping = {0: "+", 1: "-", 2: "*", 3: "/"}

            # Decode the one-hot tensor into operation symbols
            decoded_operations = [operation_mapping[torch.argmax(test_op).item()] for test_op in operation]

            inp_txt = [
                f"{num1.item()} {test_op} {num2.item()}" for num1, op, num2 in zip(test_num1, test_op, test_num2)
            ]

            # print("INp text is ", inp_txt)
            input_ids = tokenizer(inp_txt, return_tensors="pt", padding=True, truncation = True)["input_ids"].to(model.device)

            output_ids = model.generate(
                input_ids=input_ids,
                max_new_tokens = 1,
                num_return_sequences=1,  # Number of sequences to generate
                do_sample=True,  # Enable sampling
                top_k=50,  # Use top-k sampling
                temperature=0.7,  # Sampling temperature
            )

            predicted_ids = output_ids

            predictions = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)[-1]

            ################################

            try: 
                numeric_prediction = float(predictions)  # Only if it should be a number
                predictions_tensor = torch.tensor([numeric_prediction]).to(device)  # Convert to tensor
            except ValueError:
                # print(f"Decoded output is not numeric: {predictions}")
                predictions_tensor = torch.tensor([0.0], device=device, requires_grad=True)
                
            test_pred =  predictions_tensor
            ####################################################
            
            test_loss = arithmetic_loss(test_pred, test_targets)
           
            first_pred = test_pred[0].item()
            first_target = test_targets[0].item()
            
            # Format to 5 decimal places
            first_pred_formatted = f"{first_pred:.5f}"
            first_target_formatted = f"{first_target:.5f}"
            
            current_lr = optimizer.param_groups[0]['lr']
            train_loss = np.mean(epoch_losses)
            val_loss = test_loss.item()
            avg_diff = np.mean(epoch_diffs)
            median_diff = np.median(epoch_diffs)
            
            if use_wandb:
                wandb.log({
                    'learning_rate': current_lr,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'avg_prediction_diff': avg_diff,
                    'median_prediction_diff': median_diff,
                    'epoch': epoch + 1
                })
            else:
                log_data.append({
                    'epoch': epoch + 1,
                    'learning_rate': current_lr,
                    'train_loss': train_loss,
                    'val_loss': val_loss,
                    'avg_prediction_diff': avg_diff,
                    'median_prediction_diff': median_diff
                })
            
            print(
                f'Epoch {epoch+1}/{num_epochs} | '
                f'LR: {current_lr:.2e} | '
                f'Train Loss: {train_loss:.4f} | '
                f'Val Loss: {val_loss:.4f} | '
                f'Avg Diff: {avg_diff:.4f} | '
                f'First Pred: {first_pred_formatted} | '
                f'First Target: {first_target_formatted}'
            )
        
        model.train()
        
        # Save the best model
        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), 'best_arithmetic_model.pt')
        
        scheduler.step()
        print(f'Epoch {epoch+1} completed. Average loss: {train_loss:.4f}\n')
    
    if not use_wandb:
        pd.DataFrame(log_data).to_csv(log_file, index=False)
        print(f"Training log saved to {log_file}")
    
    if use_wandb:
        wandb.finish()

In [None]:
train_model(customModel, num_epochs=20, batch_size=256, initial_lr=1e-4, device='cuda', use_wandb=True)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrichidubey[0m ([33mrichidubey-georgia-institute-of-technology[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  3.56it/s, Loss=2.05e+7]


Epoch 1/20 | LR: 1.00e-04 | Train Loss: 19971137.2000 | Val Loss: 19087596.0000 | Avg Diff: 258.9499 | First Pred: 0.00000 | First Target: 270.26447
Epoch 1 completed. Average loss: 19971137.2000



Epoch 2/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.13it/s, Loss=1.92e+7]


Epoch 2/20 | LR: 1.00e-04 | Train Loss: 19189882.6000 | Val Loss: 18600274.0000 | Avg Diff: 253.8488 | First Pred: 0.00000 | First Target: 383.72769
Epoch 2 completed. Average loss: 19189882.6000



Epoch 3/20: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [00:02<00:00,  4.19it/s, Loss=1.88e+7]


Epoch 3/20 | LR: 1.00e-04 | Train Loss: 19483807.2000 | Val Loss: 19240972.0000 | Avg Diff: 255.4323 | First Pred: 0.00000 | First Target: 289.17603
Epoch 3 completed. Average loss: 19483807.2000



Epoch 4/20:  10%|█████████▊                                                                                        | 1/10 [00:00<00:06,  1.36it/s, Loss=1.97e+7]


RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [23]:
model.eval()

#print("Device of input ids is", input_ids.device)
# with torch.no_grad():
#     logits = customModel(input_ids=input_ids)

min_val = 0
max_val = 256

num1 = torch.rand(1) * (max_val - min_val) + min_val
num2 = torch.rand(1) * (max_val - min_val) + min_val

input_text = f"{num1.item():.5f} + {num2.item():.5f} = "

input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].to(model.device)

output_ids = model.generate(
    input_ids=input_ids,
    max_length=13,  # Maximum length of generated text
    num_return_sequences=1,  # Number of sequences to generate
    do_sample=True,  # Enable sampling
    top_k=50,  # Use top-k sampling
    temperature=0.7,  # Sampling temperature
)

predicted_ids = output_ids

generated_text = tokenizer.decode(predicted_ids[0], skip_special_tokens=True)

print("Generated Text:", generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated Text: 12.76663 + 42.54543 = 
