# Training notebook for PyTorch based transformer

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torch
from torch.nn.functional import log_softmax
from tqdm import tqdm
import datetime


from training.dataset import create_datasets_and_loaders
from training.dateformattransformer import DateFormatTransformer

## Define accuracy function

In [3]:
def calculate_accuracy(outputs, targets, vocab):
    # Assuming outputs are logits
    predictions = outputs.argmax(dim=-1)
    
    # Adjust predictions or targets to match in length
    min_len = min(predictions.size(1), targets.size(1))
    predictions = predictions[:, :min_len]
    targets = targets[:, :min_len]
    
    correct = (predictions == targets).float()
    
    # Mask out padding tokens
    mask = (targets != vocab['<PAD>']).float()
    accuracy = (correct * mask).sum() / mask.sum()
    
    return accuracy.item()

## Define training function

This function will be used to train the model. The function uses a technique called Curriculum Learning, in which the model is trained on a small dataset first, and then on a larger dataset. This is done to ensure that the model is not overfitting to the small dataset.

In [4]:
def train_with_curriculum(model, dataloaders, num_epochs, device, criterion, optimizer, patience, vocab):
    difficulties = ['easy']
    
    for difficulty in difficulties:
        print(f"Training on {difficulty} dataset")
        best_val_loss = float('inf')
        epochs_without_improvement = 0
        
        for epoch in range(num_epochs):
            # Training
            model.train()
            total_train_loss = 0
            total_train_accuracy = 0
            num_train_batches = 0
            
            for batch in dataloaders[difficulty]['train']:
                inputs = batch['input'].to(device)
                targets = batch['output'].to(device)
                
                optimizer.zero_grad()
                output = model(inputs, targets[:, :-1])  # The model will create masks internally
                loss = criterion(output.contiguous().view(-1, len(vocab)), targets[:, 1:].contiguous().view(-1))
                loss.backward()
                optimizer.step()
                
                total_train_loss += loss.item()
                total_train_accuracy += calculate_accuracy(output, targets[:, 1:], vocab)
                num_train_batches += 1
            
            avg_train_loss = total_train_loss / num_train_batches
            avg_train_accuracy = total_train_accuracy / num_train_batches
            
            # Validation
            model.eval()
            total_val_loss = 0
            total_val_accuracy = 0
            num_val_batches = 0
            
            with torch.no_grad():
                for batch in dataloaders[difficulty]['val']:
                    inputs = batch['input'].to(device)
                    targets = batch['output'].to(device)
                    
                    output = model(inputs, targets[:, :-1])
                    loss = criterion(output.contiguous().view(-1, len(vocab)), targets[:, 1:].contiguous().view(-1))
                    
                    total_val_loss += loss.item()
                    total_val_accuracy += calculate_accuracy(output, targets[:, 1:], vocab)
                    num_val_batches += 1
            
            avg_val_loss = total_val_loss / num_val_batches
            avg_val_accuracy = total_val_accuracy / num_val_batches
            
            print(f"Epoch {epoch+1}/{num_epochs}")
            print(f"Train Loss: {avg_train_loss:.4f}, Train Accuracy: {avg_train_accuracy:.4f}")
            print(f"Val Loss: {avg_val_loss:.4f}, Val Accuracy: {avg_val_accuracy:.4f}")
            
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                epochs_without_improvement = 0
                torch.save(model.state_dict(), f'best_model_{difficulty}.pth')
                print("Saved new best model")
            else:
                epochs_without_improvement += 1
            
            if epochs_without_improvement >= patience:
                print(f"Early stopping triggered after {epoch+1} epochs")
                break
            
            print("-" * 50)


## Function to inspect PyTorch datasets and dataloader

In [5]:
def inspect_dataloader_data(dataloader, dataset_dict):
    print("Inspecting data...")
    for batch in dataloader:
        inputs = batch['input']
        targets = batch['output']
        
        print("\nSample batch:")
        original_dataset = dataset_dict['train'].dataset
        for i in range(min(5, inputs.size(0))):
            try:
                input_date = original_dataset.indices_to_date(inputs[i])
                target_date = original_dataset.indices_to_date(targets[i])
                print(f"Input: {input_date} | Target: {target_date}")
            except Exception as e:
                print(f"Error processing sample {i}: {str(e)}")
                print(f"Input indices: {inputs[i]}")
                print(f"Target indices: {targets[i]}")
        
        # Only process one batch
        break


## Helper function

In [6]:
def date_to_indices(date_string, vocab, max_length):
    """ Convert a date string to a sequence of indices. Indics are padded to max_length. """
    tokens = [vocab.get(char, vocab['<UNK>']) for char in date_string]
    tokens = [vocab['<START>']] + tokens + [vocab['<END>']]
    padding = [vocab['<PAD>']] * (max_length - len(tokens))
    return tokens + padding

## Test the model on test dataset

In [7]:
def test_model(model, test_dataloader, device, vocab):
    model.eval()
    correct = 0
    total = 0
    
    idx_to_char = {idx: char for char, idx in vocab.items()}
    
    print("Testing model...")
    with torch.no_grad():
        for batch in tqdm(test_dataloader):
            inputs = batch['input'].to(device)
            targets = batch['output'].to(device)
            
            outputs = model(inputs)
            _, predicted = torch.max(outputs, dim=-1)
            
            # Compare predictions with targets
            mask = (targets != vocab['<PAD>']) & (targets != vocab['<START>']) & (targets != vocab['<END>'])
            correct += ((predicted == targets) & mask).sum().item()
            total += mask.sum().item()
            
            # Print some examples
            if total % 1000 == 0:
                print("\nExample conversions:")
                for i in range(min(3, inputs.size(0))):
                    input_date = ''.join([idx_to_char[idx.item()] for idx in inputs[i] if idx.item() not in [vocab['<PAD>'], vocab['<START>'], vocab['<END>']]])
                    target_date = ''.join([idx_to_char[idx.item()] for idx in targets[i] if idx.item() not in [vocab['<PAD>'], vocab['<START>'], vocab['<END>']]])
                    pred_date = ''.join([idx_to_char[idx.item()] for idx in predicted[i] if idx.item() not in [vocab['<PAD>'], vocab['<START>'], vocab['<END>']]])
                    print(f"Input: {input_date} | Target: {target_date} | Predicted: {pred_date}")

    accuracy = correct / total
    print(f"\nTest Accuracy: {accuracy:.4f}")
    return accuracy


## Function to generate train, val and test dataset

In [9]:
def generate_date_dataset(samples: int):
    num_samples = 50000
    batch_size = 32

    # Generate datasets and create dataloaders
    datasets, dataloaders, vocab = create_datasets_and_loaders(num_samples, batch_size)

    print(f"Shared vocabulary size: {len(vocab)}")
    print("Printing a sample from each dataset..")
    for difficulty in ['easy', 'medium', 'hard']:
        print(f"\nInspecting {difficulty} dataset:")
        inspect_dataloader_data(dataloaders[difficulty]['train'], datasets[difficulty])
    
    return datasets, dataloaders, vocab


## Train and test the model

In [10]:
max_length = 24
# Move the model to the appropriate device
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

mode = "training" # "training" or "evaluation"
print(f"Mode: {mode}")
dataset_gen = True
#test_dataset_generation()

n = 50000
epochs = 50
patience = 5

# # Generate the datasets 
if dataset_gen:
    dataset, dataloaders , vocab = generate_date_dataset(n)
    
else:
    print("Loading dataset from file")
    # loaded_dataset = torch.load('date_dataset.pt')

    # Initialize the model
model = DateFormatTransformer(
    d_model=256, 
    ffn_hidden=512, 
    num_heads=8, 
    drop_prob=0.2, 
    num_layers=6,
    max_sequence_length=24,
    input_vocab_size=len(vocab),
    output_vocab_size=len(vocab),
    pad_idx=vocab['<PAD>']
).to(device)

if mode == "training":
    # Define loss function and optimizer
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['<PAD>'])
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience, factor=0.5)
    # Train the model
    train_with_curriculum(model, dataloaders, epochs, device, criterion, optimizer, patience, vocab)

else:
    model.load_state_dict(torch.load('best_model_easy.pth'))
    model.to(device)

# Test the model on 'easy' difficulty level 
for difficulty in ['easy']:
    print(f"\nTesting on {difficulty} dataset:")
    test_dataloader = dataloaders[difficulty]['test']
    idx_to_char = {idx: char for char, idx in vocab.items()}

    print("Inspecting test data...")
    for batch in test_dataloader:
        inputs = batch['input']
        targets = batch['output']
        
        print("\nSample batch:")
        for i in range(min(5, inputs.size(0))):
            input_date = ''.join([idx_to_char[idx.item()] for idx in inputs[i] if idx.item() not in [vocab['<PAD>'], vocab['<START>'], vocab['<END>']]])
            target_date = ''.join([idx_to_char[idx.item()] for idx in targets[i] if idx.item() not in [vocab['<PAD>'], vocab['<START>'], vocab['<END>']]])
            print(f"Input: {input_date} | Target: {target_date}")
        
        # Only process one batch
        break
    
test_model(model, test_dataloader, device, vocab)

Using device: mps
Mode: training

Generating easy dataset:
Raw data samples:
Input: 10/14/1957 | Output: 10/14/1957
Input: 2093-06-17 | Output: 06/17/2093
Input: 26/01/1914 | Output: 01/26/1914
Input: 2045-05-25 | Output: 05/25/2045
Input: 04/04/2038 | Output: 04/04/2038

Generating medium dataset:
Raw data samples:
Input: 10/06/2055 | Output: 10/06/2055
Input: 1949-11-25 | Output: 11/25/1949
Input: 26/05/2024 | Output: 05/26/2024
Input: May 22, 2050 | Output: 05/22/2050
Input: February 14, 2006 | Output: 02/14/2006

Generating hard dataset:
Raw data samples:
Input: October 20, 1921 | Output: 10/20/1921
Input: 2068/12/22 | Output: 12/22/2068
Input: 06/11/1967 | Output: 06/11/1967
Input: 2056/10/04 | Output: 10/04/2056
Input: 1957.07.04 | Output: 07/04/1957
Shared vocabulary size: 45
Printing a sample from each dataset..

Inspecting easy dataset:
Inspecting data...

Sample batch:
Input: 09/10/2069 | Target: 09/10/2069
Input: 03/01/1989 | Target: 03/01/1989
Input: 08/11/1933 | Target: 08

 11%|█         | 25/235 [00:08<00:53,  3.90it/s]


Example conversions:
Input: 08/29/1998 | Target: 08/29/1998 | Predicted: 896/29684
Input: 02/19/2002 | Target: 02/19/2002 | Predicted: /029/2064
Input: 03/26/1932 | Target: 03/26/1932 | Predicted: 30/26/1964


 21%|██▏       | 50/235 [00:14<00:47,  3.87it/s]


Example conversions:
Input: 15/08/1937 | Target: 08/15/1937 | Predicted: 85/15/1974
Input: 12/16/2069 | Target: 12/16/2069 | Predicted: /16/1/2094
Input: 2082-09-15 | Target: 09/15/2082 | Predicted: /05/1/2082


 32%|███▏      | 75/235 [00:20<00:41,  3.86it/s]


Example conversions:
Input: 01/06/2033 | Target: 06/01/2033 | Predicted: 16/06/2034
Input: 13/02/2035 | Target: 02/13/2035 | Predicted: /023/2054
Input: 06/05/1997 | Target: 06/05/1997 | Predicted: 56/05/1974


 43%|████▎     | 100/235 [00:26<00:34,  3.90it/s]


Example conversions:
Input: 16/02/2064 | Target: 02/16/2064 | Predicted: /026/2064
Input: 02/24/2092 | Target: 02/24/2092 | Predicted: /24/20964
Input: 2032-12-08 | Target: 12/08/2032 | Predicted: /06/2/2032


 53%|█████▎    | 125/235 [00:33<00:28,  3.89it/s]


Example conversions:
Input: 06/03/2028 | Target: 06/03/2028 | Predicted: /06/09/208
Input: 2037-09-15 | Target: 09/15/2037 | Predicted: /05/1/2037
Input: 04/08/1915 | Target: 04/08/1915 | Predicted: 854/04952


 64%|██████▍   | 150/235 [00:39<00:21,  3.89it/s]


Example conversions:
Input: 2046-07-11 | Target: 07/11/2046 | Predicted: /06/1/2046
Input: 1946-11-05 | Target: 11/05/1946 | Predicted: 16/05/1946
Input: 1979-04-29 | Target: 04/29/1979 | Predicted: /26/2/1979


 74%|███████▍  | 175/235 [00:45<00:15,  3.88it/s]


Example conversions:
Input: 05/04/2035 | Target: 04/05/2035 | Predicted: 54/04/2059
Input: 1976-03-24 | Target: 03/24/1976 | Predicted: /24/2/1976
Input: 07/07/1992 | Target: 07/07/1992 | Predicted: 707/09/14


 85%|████████▌ | 200/235 [00:51<00:08,  3.90it/s]


Example conversions:
Input: 10/07/1969 | Target: 10/07/1969 | Predicted: /10/19648
Input: 03/31/2073 | Target: 03/31/2073 | Predicted: 30/31/2064
Input: 08/11/1934 | Target: 11/08/1934 | Predicted: 16/08/1944


 96%|█████████▌| 225/235 [00:57<00:02,  3.92it/s]


Example conversions:
Input: 1922-03-18 | Target: 03/18/1922 | Predicted: /03/1/1922
Input: 29/06/1986 | Target: 06/29/1986 | Predicted: 6/29/1964
Input: 2027-01-22 | Target: 01/22/2027 | Predicted: /26/2/2027


100%|██████████| 235/235 [01:02<00:00,  3.74it/s]


Example conversions:
Input: 2096-08-03 | Target: 08/03/2096 | Predicted: 86/03/2096
Input: 07/02/2055 | Target: 07/02/2055 | Predicted: /02/0/2054
Input: 12/02/2023 | Target: 02/12/2023 | Predicted: /12/0/2034

Test Accuracy: 0.1771





0.17706666666666668