<a href="https://colab.research.google.com/github/stereifberger/logical_derivations_with_transformers/blob/from_scratch/implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
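# Clone the whole GitHub repository with the necessary files.
# Note: each `!` command runs in its own subshell, so the `!cd` below does not change the notebook's working directory.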
!git clone https://github.com/stereifberger/logical_derivations_with_transformers
!cd /content/logical_derivations_with_transformers

Cloning into 'logical_derivations_with_transformers'...
remote: Enumerating objects: 70, done.[K
remote: Counting objects: 100% (70/70), done.[K
remote: Compressing objects: 100% (53/53), done.[K
remote: Total 70 (delta 30), reused 46 (delta 15), pack-reused 0 (from 0)[K
Receiving objects: 100% (70/70), 152.62 KiB | 2.06 MiB/s, done.
Resolving deltas: 100% (30/30), done.


In [None]:
import os
import csv
import random
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# Import libraries
import sys
sys.path.append('logical_derivations_with_transformers')
from imports import *

In [None]:
# Set random seed for reproducibility
# NOTE(review): `seed` is not imported explicitly in this notebook; it presumably comes from
# `from imports import *` (likely `random.seed`) — confirm.
seed(42)
# Seed PyTorch's random number generator as well.
torch.manual_seed(42)

<torch._C.Generator at 0x7cd7460534b0>

In [None]:
from sklearn.model_selection import train_test_split

In [None]:
# Prepare data
total_samples = 40000  # number of samples requested from utils.prepare_data
#batch_size = 32
# NOTE(review): `utils` is not imported explicitly in this notebook; presumably it is provided by
# `from imports import *` — confirm. The call returns four values, bound here to the inputs,
# the targets and two length values.
data_inputs, data_targets, max_input_len, max_target_len = utils.prepare_data(num_samples=total_samples)
t_nu = 5
vocab_size = t_nu + 11  # Number of symbols including PAD, SOS, EOS
# Hold out 20% of the samples for testing; random_state pins the split.
inputs_train, inputs_test, targets_train, targets_test = train_test_split(
    data_inputs, data_targets, test_size=0.2, random_state=42
)

In [None]:
import string

# Lowercase ASCII letters; the letter at position i stands for the integer i.
alphabet = string.ascii_lowercase


def convert_to_letters(data):
    """Map every integer in a nested list to its lowercase letter.

    Args:
        data: A list of lists of integers; each integer is used as an index
            into ``alphabet``.

    Returns:
        A new list of lists of the same shape holding single-character strings.
    """
    return [[alphabet[code] for code in row] for row in data]

# Convert the data
# NOTE: these assignments overwrite the integer-coded splits with their letter-coded versions,
# so re-running this cell would hand letters to convert_to_letters, which indexes with integers.
inputs_train = convert_to_letters(inputs_train)
inputs_test = convert_to_letters(inputs_test)
targets_train = convert_to_letters(targets_train)
targets_test = convert_to_letters(targets_test)

# Print an example to verify
print(inputs_train[0])  # Print the first converted training input

['h', 'b', 'm', 'd', 'i', 'h', 'h', 'h', 'b', 'k', 'b', 'i', 'l', 'c', 'i', 'k', 'c', 'i', 'g', 'h', 'h', 'f', 'k', 'e', 'i', 'l', 'h', 'b', 'm', 'd', 'i', 'i']


In [None]:
# Lookup table from the letter code of a token to the symbol it is rendered as in the CSV output.
# NOTE(review): "a" maps to the empty string (presumably padding) and "o"/"p" map to "B"/"E"
# (presumably begin/end markers) — confirm against the data generator.
symb_reverse = dict(
    zip(
        "abcdefghijklmnop",
        ["", "p", "q", "r", "s", "t", "⊢", "(", ")", "¬", "→", "∨", "∧", "⊥", "B", "E"],
    )
)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def sublists_to_strings(sublists):
    """Concatenate each sublist of single characters into one string.

    Example: ``[['h', 'e', 'l', 'l', 'o']]`` becomes ``["hello"]``.
    """
    joined = []
    for chars in sublists:
        joined.append("".join(chars))
    return joined

# Turn every letter-coded split into a list of plain strings (one string per sample).
train_input_strings, train_target_strings, test_input_strings, test_target_strings = (
    sublists_to_strings(split)
    for split in (inputs_train, targets_train, inputs_test, targets_test)
)

In [None]:
class LetterDataset(Dataset):
    """Causal-LM fine-tuning dataset built from (input, target) string pairs.

    Each item is ONE token sequence: the prompt tokens followed by the target
    tokens (plus EOS when the tokenizer defines one), right-truncated and
    right-padded to ``max_length``. The labels are a copy of that sequence in
    which the prompt and the padding are replaced by ``IGNORE_INDEX``, so the
    loss is computed only on the target continuation.

    This matters because a causal LM scores token ``i + 1`` from the tokens up
    to ``i``: the labels must be the same sequence as ``input_ids``. Feeding the
    prompt as ``input_ids`` and an independently tokenized target as ``labels``
    pairs unrelated positions and also trains on padding.

    Args:
        input_texts: Sequence of prompt strings.
        target_texts: Sequence of target strings, aligned with ``input_texts``.
        tokenizer: Callable tokenizer returning ``{"input_ids": [...]}`` and
            exposing ``pad_token_id`` / ``eos_token_id``.
        max_length: Fixed length of every returned tensor; it bounds the
            prompt and target combined.
    """

    # Label value skipped by the Hugging Face causal-LM loss
    # (also the default ignore_index of nn.CrossEntropyLoss).
    IGNORE_INDEX = -100

    def __init__(self, input_texts, target_texts, tokenizer, max_length=64):
        self.input_texts = input_texts
        self.target_texts = target_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.input_texts)

    def _encode(self, text):
        """Return the token ids of ``text`` as a plain list, without special tokens."""
        return list(self.tokenizer(text, add_special_tokens=False)["input_ids"])

    def __getitem__(self, idx):
        """Return ``input_ids``, ``attention_mask`` and ``labels``, each of shape ``[max_length]``."""
        # Tokenize prompt and target separately so the prompt ids are exactly
        # what the tokenizer produces for the prompt alone at generation time.
        prompt_ids = self._encode(self.input_texts[idx])
        target_ids = self._encode(self.target_texts[idx])

        # Teach the model where the answer ends.
        eos_id = self.tokenizer.eos_token_id
        if eos_id is not None:
            target_ids.append(eos_id)

        token_ids = (prompt_ids + target_ids)[: self.max_length]
        num_real = len(token_ids)
        num_prompt = min(len(prompt_ids), num_real)
        num_pad = self.max_length - num_real

        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            # Padded positions are masked out below, so any valid id works.
            pad_id = eos_id if eos_id is not None else 0

        input_ids = torch.tensor(token_ids + [pad_id] * num_pad, dtype=torch.long)
        attention_mask = torch.tensor([1] * num_real + [0] * num_pad, dtype=torch.long)

        # Mask by position rather than by token id: pad and EOS may share an id,
        # and the EOS that closes the target must stay in the loss.
        labels = input_ids.clone()
        labels[:num_prompt] = self.IGNORE_INDEX
        labels[num_real:] = self.IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


In [None]:
model_name = "Qwen/Qwen2.5-0.5B"  # Replace if needed
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Padding to a fixed length needs a pad token. Some causal-LM checkpoints ship
# without one, so fall back to EOS to keep the cell working when the model is swapped.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Samples per optimisation step (the 32,000 training samples give 400 steps per epoch).
BATCH_SIZE = 80

train_dataset = LetterDataset(train_input_strings, train_target_strings, tokenizer)
test_dataset = LetterDataset(test_input_strings, test_target_strings, tokenizer)

# Shuffle only the training data; keep the test order fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

tokenizer_config.json:   0%|          | 0.00/7.23k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.03M [00:00<?, ?B/s]

In [None]:
# Load the pretrained causal language model named by `model_name` and move it to `device`.
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)

# Typical AdamW or similar
optimizer = optim.AdamW(model.parameters(), lr=1e-5)

# For language modeling, the model’s forward pass typically includes
# internal logic to handle cross-entropy, but we can also use our own:
# NOTE(review): train_one_epoch / evaluate in this notebook accept `criterion` as an argument
# but read the loss from `outputs.loss`, so this object does not affect training as written.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

In [None]:
def train_one_epoch(model, dataloader, optimizer, criterion, device, pad_token_id=None):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Causal LM whose forward pass returns an object with ``loss`` and
            ``logits`` when called with ``input_ids``, ``attention_mask`` and ``labels``.
        dataloader: Yields dict batches with those three keys; must expose ``.dataset``.
        optimizer: Optimizer stepping the model parameters.
        criterion: Unused; kept for call compatibility (the loss comes from the model).
        device: Device the batch tensors are moved to.
        pad_token_id: Label id to leave out of the accuracy. Defaults to the
            notebook-level ``tokenizer.pad_token_id``.

    Returns:
        ``(avg_loss, avg_acc)``: the sample-weighted mean loss and the
        next-token accuracy over all counted label positions.
    """
    if pad_token_id is None:
        # Default to the notebook-level tokenizer.
        pad_token_id = tokenizer.pad_token_id

    model.train()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0

    for batch in tqdm(dataloader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        loss = outputs.loss

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Weight the batch-mean loss by the batch size so a smaller last batch
        # does not skew the epoch average.
        total_loss += loss.item() * input_ids.size(0)

        with torch.no_grad():
            # The logits at position i score token i + 1 (the model shifts the
            # labels the same way for its loss), so compare prediction i with label i + 1.
            next_token_preds = outputs.logits[:, :-1, :].argmax(dim=-1)
            next_token_labels = labels[:, 1:]

            # Skip positions excluded from the loss (-100) and padded positions.
            counted = next_token_labels.ne(-100)
            if pad_token_id is not None:
                counted &= next_token_labels.ne(pad_token_id)

            total_correct += ((next_token_preds == next_token_labels) & counted).sum().item()
            total_tokens += counted.sum().item()

    num_samples = len(dataloader.dataset)
    avg_loss = total_loss / num_samples if num_samples else 0.0
    avg_acc = total_correct / total_tokens if total_tokens else 0.0

    return avg_loss, avg_acc

In [None]:
def evaluate(model, dataloader, criterion, device, pad_token_id=None):
    """Compute loss and next-token accuracy over ``dataloader`` without gradient updates.

    Args:
        model: Causal LM whose forward pass returns an object with ``loss`` and
            ``logits`` when called with ``input_ids``, ``attention_mask`` and ``labels``.
        dataloader: Yields dict batches with those three keys; must expose ``.dataset``.
        criterion: Unused; kept for call compatibility (the loss comes from the model).
        device: Device the batch tensors are moved to.
        pad_token_id: Label id to leave out of the accuracy. Defaults to the
            notebook-level ``tokenizer.pad_token_id``.

    Returns:
        ``(avg_loss, avg_acc)``: the sample-weighted mean loss and the
        next-token accuracy over all counted label positions.
    """
    if pad_token_id is None:
        # Default to the notebook-level tokenizer.
        pad_token_id = tokenizer.pad_token_id

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )

            # Weight the batch-mean loss by the batch size.
            total_loss += outputs.loss.item() * input_ids.size(0)

            # The logits at position i score token i + 1 (the model shifts the
            # labels the same way for its loss), so compare prediction i with label i + 1.
            next_token_preds = outputs.logits[:, :-1, :].argmax(dim=-1)
            next_token_labels = labels[:, 1:]

            # Skip positions excluded from the loss (-100) and padded positions.
            counted = next_token_labels.ne(-100)
            if pad_token_id is not None:
                counted &= next_token_labels.ne(pad_token_id)

            total_correct += ((next_token_preds == next_token_labels) & counted).sum().item()
            total_tokens += counted.sum().item()

    num_samples = len(dataloader.dataset)
    avg_loss = total_loss / num_samples if num_samples else 0.0
    avg_acc = total_correct / total_tokens if total_tokens else 0.0

    return avg_loss, avg_acc

In [None]:
def generate_str_predictions(model, tokenizer, input_strs, max_new_tokens=64):
    """Greedily generate one continuation per input string.

    Args:
        model: Causal LM exposing ``eval()`` and ``generate(...)``.
        tokenizer: Tokenizer used to encode the prompts and decode the outputs.
        input_strs: Iterable of prompt strings; each is generated on its own.
        max_new_tokens: Upper bound on the number of generated tokens per prompt.

    Returns:
        A list with the generated text for each prompt, in input order. The
        prompt itself is stripped from every result.
    """
    model.eval()

    # Send the inputs to wherever the model lives; only fall back to the
    # notebook-level `device` when the model does not expose its own device.
    model_device = getattr(model, "device", None)
    if model_device is None:
        model_device = device

    preds = []
    with torch.no_grad():
        for inp in input_strs:
            # Tokenize input
            encoding = tokenizer(inp, return_tensors="pt").to(model_device)
            prompt_len = len(encoding["input_ids"][0])

            # Generate using the model
            output_ids = model.generate(
                **encoding,
                max_new_tokens=max_new_tokens,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id  # Ensure proper padding
            )

            # generate() returns prompt + continuation; keep only the continuation.
            generated_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)
            preds.append(generated_text)

    return preds

In [None]:
num_epochs = 8
NUM_SAMPLE_PREDICTIONS = 10  # examples drawn from each split for the per-epoch CSV
train_losses, test_losses = [], []
train_accuracies, test_accuracies = [], []

for output_dir in ("model_checkpoints", "csv_outputs", "plots"):
    os.makedirs(output_dir, exist_ok=True)


def _to_symbols(text):
    """Render letter-coded text with symb_reverse; characters without an entry are kept as-is.

    The model is free to generate characters outside the symbol table, so a
    plain ``symb_reverse[ch]`` lookup could abort the run with a KeyError.
    """
    return "".join(symb_reverse.get(ch, ch) for ch in text)


def _save_curve(train_values, test_values, metric, epoch):
    """Plot the per-epoch train/test ``metric`` values and save the figure under plots/."""
    fig, ax = plt.subplots(figsize=(8, 6))
    epochs_range = range(1, epoch + 1)
    ax.plot(epochs_range, train_values, label=f"Train {metric}")
    ax.plot(epochs_range, test_values, label=f"Test {metric}")
    ax.set(title=f"{metric} up to epoch {epoch}", xlabel="Epoch", ylabel=metric)
    ax.legend()
    fig.savefig(f"plots/{metric.lower()}_epoch_{epoch}.png")
    # Close explicitly: one figure pair is created per epoch.
    plt.close(fig)


for epoch in range(1, num_epochs + 1):
    print(f"Epoch {epoch}/{num_epochs}")

    # -----------------------------
    # Train
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, device)
    train_losses.append(train_loss)
    train_accuracies.append(train_acc)

    # -----------------------------
    # Evaluate
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)

    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Test  Loss: {test_loss:.4f}, Test  Acc: {test_acc:.4f}")

    # -----------------------------
    # Save model checkpoint
    checkpoint_path = f"model_checkpoints/epoch_{epoch}"
    model.save_pretrained(checkpoint_path)
    tokenizer.save_pretrained(checkpoint_path)
    print(f"  Saved model checkpoint to: {checkpoint_path}")

    # -----------------------------
    # Generate predictions for a random sample of up to NUM_SAMPLE_PREDICTIONS
    # training examples and as many test examples (all of them if a split is smaller).
    # NOTE(review): `sample` is not imported explicitly in this notebook; presumably it
    # comes from `from imports import *` (likely `random.sample`) — confirm.
    sample_train_indices = sample(range(len(train_input_strings)),
                                  min(NUM_SAMPLE_PREDICTIONS, len(train_input_strings)))
    sample_test_indices = sample(range(len(test_input_strings)),
                                 min(NUM_SAMPLE_PREDICTIONS, len(test_input_strings)))

    sample_train_inputs = [train_input_strings[i] for i in sample_train_indices]
    sample_train_targets = [train_target_strings[i] for i in sample_train_indices]
    sample_train_preds = generate_str_predictions(model, tokenizer, sample_train_inputs)

    sample_test_inputs = [test_input_strings[i] for i in sample_test_indices]
    sample_test_targets = [test_target_strings[i] for i in sample_test_indices]
    sample_test_preds = generate_str_predictions(model, tokenizer, sample_test_inputs)

    # Write to CSV, one row per sampled example, rendered with logic symbols
    csv_filename = f"csv_outputs/epoch_{epoch}_predictions.csv"
    with open(csv_filename, mode='w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f)
        writer.writerow(["Set", "Input", "Target", "Prediction"])
        sampled_splits = (
            ("TRAIN", zip(sample_train_inputs, sample_train_targets, sample_train_preds)),
            ("TEST", zip(sample_test_inputs, sample_test_targets, sample_test_preds)),
        )
        for split_name, rows in sampled_splits:
            for inp, tgt, pred in rows:
                writer.writerow([split_name, _to_symbols(inp), _to_symbols(tgt), _to_symbols(pred)])
    print(f"  Wrote sample predictions to: {csv_filename}")

    # -----------------------------
    # Plot and save train/test loss and accuracy
    _save_curve(train_losses, test_losses, "Loss", epoch)
    _save_curve(train_accuracies, test_accuracies, "Accuracy", epoch)

    print("  Plots updated.\n")


print("Training complete!")

Epoch 1/8


100%|██████████| 400/400 [06:52<00:00,  1.03s/it]


  Train Loss: 1.0021, Train Acc: 0.0712
  Test  Loss: 0.8376, Test  Acc: 0.0695
  Saved model checkpoint to: model_checkpoints/epoch_1
  Wrote sample predictions to: csv_outputs/epoch_1_predictions.csv
  Plots updated.

Epoch 2/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.8136, Train Acc: 0.0761
  Test  Loss: 0.8052, Test  Acc: 0.0643
  Saved model checkpoint to: model_checkpoints/epoch_2
  Wrote sample predictions to: csv_outputs/epoch_2_predictions.csv
  Plots updated.

Epoch 3/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.7684, Train Acc: 0.0748
  Test  Loss: 0.7650, Test  Acc: 0.0742
  Saved model checkpoint to: model_checkpoints/epoch_3
  Wrote sample predictions to: csv_outputs/epoch_3_predictions.csv
  Plots updated.

Epoch 4/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.7353, Train Acc: 0.0733
  Test  Loss: 0.7521, Test  Acc: 0.0621
  Saved model checkpoint to: model_checkpoints/epoch_4
  Wrote sample predictions to: csv_outputs/epoch_4_predictions.csv
  Plots updated.

Epoch 5/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.7014, Train Acc: 0.0714
  Test  Loss: 0.7517, Test  Acc: 0.0695
  Saved model checkpoint to: model_checkpoints/epoch_5
  Wrote sample predictions to: csv_outputs/epoch_5_predictions.csv
  Plots updated.

Epoch 6/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.6645, Train Acc: 0.0711
  Test  Loss: 0.7656, Test  Acc: 0.0741
  Saved model checkpoint to: model_checkpoints/epoch_6
  Wrote sample predictions to: csv_outputs/epoch_6_predictions.csv
  Plots updated.

Epoch 7/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.6231, Train Acc: 0.0707
  Test  Loss: 0.7757, Test  Acc: 0.0740
  Saved model checkpoint to: model_checkpoints/epoch_7
  Wrote sample predictions to: csv_outputs/epoch_7_predictions.csv
  Plots updated.

Epoch 8/8


100%|██████████| 400/400 [06:51<00:00,  1.03s/it]


  Train Loss: 0.5791, Train Acc: 0.0727
  Test  Loss: 0.8155, Test  Acc: 0.0849
  Saved model checkpoint to: model_checkpoints/epoch_8
  Wrote sample predictions to: csv_outputs/epoch_8_predictions.csv
  Plots updated.

Training complete!


In [None]:
from google.colab import files
import os

# Folder whose top-level files are sent to the browser
directory_to_download = "/content/logical_derivations_with_transformers"

# Walk the directory entries; sub-directories are skipped, regular files are downloaded
for filename in os.listdir(directory_to_download):
    file_path = os.path.join(directory_to_download, filename)
    if not os.path.isfile(file_path):
        continue
    files.download(file_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
from google.colab import files
import os

# Folder holding the saved loss/accuracy figures
directory_to_download = "/content/plots"

# Build the full path of every directory entry first ...
candidate_paths = [
    os.path.join(directory_to_download, entry_name)
    for entry_name in os.listdir(directory_to_download)
]

# ... then download the regular files only (directories are ignored)
for file_path in candidate_paths:
    if os.path.isfile(file_path):
        files.download(file_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>