In [None]:
import random
from random import randint, choice, random, seed, sample, shuffle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Set random seeds for reproducibility: Python's `random` drives formula
# generation, torch's RNG drives model initialisation.
seed(42)
torch.manual_seed(42)

# Number of propositional variables; they are encoded as the integers 1..t_nu.
t_nu = 5  # Variables: p, q, r, s, t

# Numerical codes for the non-variable symbols. They follow directly after the
# variables so every token is a distinct small integer (0 is reserved, see
# `symb_reverse`).
symb = {
    "DE": t_nu + 1,  # ⊢ DERIVES (not used in formula generation)
    "LB": t_nu + 2,  # (
    "RB": t_nu + 3,  # )
    "NO": t_nu + 4,  # ¬
    "TH": t_nu + 5,  # →
    "OR": t_nu + 6,  # ∨
    "AN": t_nu + 7,  # ∧
    "FA": t_nu + 8   # ⊥
}

# Reverse mapping from numerical values to printable glyphs. Token 0 maps to
# the empty string; `prepare_data` pads sequences with 0.
symb_reverse = {0: ""}
for i in range(1, t_nu + 1):
    symb_reverse[i] = chr(ord('p') + i - 1)
# Glyphs are keyed by symbol *name* and resolved through `symb`, so the numeric
# codes are defined in exactly one place.
symb_reverse.update({
    symb[name]: glyph
    for name, glyph in [
        ("DE", "⊢"),
        ("LB", "("),
        ("RB", ")"),
        ("NO", "¬"),
        ("TH", "→"),
        ("OR", "∨"),
        ("AN", "∧"),
        ("FA", "⊥"),
    ]
})

# Mapping from printable glyphs back to numerical values.
symb_map = {v: k for k, v in symb_reverse.items()}

# Function to generate a random propositional variable
def rd_f(t_nu):
    """Return a uniformly random propositional-variable code.

    Variables are numbered from 1 (``p``) up to and including ``t_nu``.
    """
    lowest_code = 1
    return randint(lowest_code, t_nu)

# Function to recursively generate well-formed formulas (WFFs)
def gen_wff(form, depth, max_depth=3):
    """Randomly grow ``form`` into a larger well-formed formula.

    ``form`` is returned untouched once ``depth`` reaches ``max_depth`` or
    with probability 0.6. Otherwise one rule from ``wff_rules`` is picked at
    random and applied to ``(form, depth)``.
    """
    keep_as_is = depth >= max_depth or random() < 0.6
    if keep_as_is:
        return form
    grow = choice(wff_rules)
    return grow(form, depth)

# Rules for generating well-formed formulas
def cona(form1, depth):
    """Conjunction A: ``form1 ∧ X`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [form1, symb["AN"], fresh]


def conb(form1, depth):
    """Conjunction B: ``X ∧ form1`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [fresh, symb["AN"], form1]


def disa(form1, depth):
    """Disjunction A: ``form1 ∨ X`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [form1, symb["OR"], fresh]


def disb(form1, depth):
    """Disjunction B: ``X ∨ form1`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [fresh, symb["OR"], form1]


def th_a(form1, depth):
    """Implication A: ``form1 → X`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [form1, symb["TH"], fresh]


def th_b(form1, depth):
    """Implication B: ``X → form1`` for a freshly generated sub-formula ``X``."""
    fresh = gen_wff(rd_f(t_nu), depth + 1)
    return [fresh, symb["TH"], form1]


def neg(form1, depth):
    """Negation: ``¬form1`` (``depth`` is not used)."""
    return [symb["NO"], form1]


# Every formula-building rule has the signature (form, depth) -> formula.
wff_rules = [cona, conb, disa, disb, th_a, th_b, neg]

# Function to convert a formula to a string
def formula_to_string(f):
    """Render a formula tree as human-readable text.

    - Integers are looked up in ``symb_reverse``.
    - ``[¬, A]`` becomes ``¬A``.
    - Any 3-element list ``[A, op, B]`` becomes ``(A op B)``.
    - Other lists are rendered element by element and concatenated.
    - Anything else goes through ``str``.
    """
    if isinstance(f, int):
        return symb_reverse[f]
    if not isinstance(f, list):
        return str(f)
    if len(f) == 2 and f[0] == symb["NO"]:
        return symb_reverse[symb["NO"]] + formula_to_string(f[1])
    if len(f) == 3:
        left, op_code, right = f
        return f'({formula_to_string(left)} {symb_reverse[op_code]} {formula_to_string(right)})'
    return ''.join(map(formula_to_string, f))

# Derivation rules for IPL (Intuitionistic Propositional Logic)
def fa_e(prem):
    """Falsum elimination: from ⊥ derive a random propositional variable."""
    return rd_f(t_nu)


def no_e(prem):
    """Negation elimination: from ¬A and A derive ⊥."""
    return symb["FA"]


def n_ia(premises):
    """Negation introduction: from A ⊢ ⊥ infer ¬A, with A = ``premises[0]``."""
    assumed = premises[0]
    return [symb["NO"], assumed]


def an_i(prem):
    """Conjunction introduction: from A and B derive A ∧ B."""
    first, second = prem[0], prem[1]
    return [first, symb["AN"], second]


def a_ea(prem):
    """Conjunction elimination A: from A ∧ B derive A."""
    conj = prem[0]
    return conj[0]


def a_eb(prem):
    """Conjunction elimination B: from A ∧ B derive B."""
    conj = prem[0]
    return conj[2]


def t_ea(prem):
    """Implication elimination A (modus ponens): from A and A → B derive B."""
    impl = prem[1]
    return impl[2]


def t_eb(prem):
    """Implication elimination B: from A → B and A derive B."""
    impl = prem[0]
    return impl[2]


def th_i(prem):
    """Implication introduction: from A and B infer A → B."""
    first, second = prem[0], prem[1]
    return [first, symb["TH"], second]


def o_ia(prem):
    """Disjunction introduction A: from A derive A ∨ X for a random formula X."""
    extra = gen_wff(rd_f(t_nu), 0)
    return [prem[0], symb["OR"], extra]


def o_ib(prem):
    """Disjunction introduction B: from A derive X ∨ A for a random formula X."""
    extra = gen_wff(rd_f(t_nu), 0)
    return [extra, symb["OR"], prem[0]]


# Additional rule for Classical Propositional Logic (CPL)
def d_ne(prem):
    """Double negation elimination: from ¬¬A derive A."""
    double_neg = prem[0]
    return double_neg[1]


# Rule sets: IPL, and CPL = IPL plus double-negation elimination.
ipl_rules = [fa_e, no_e, n_ia, an_i, a_ea, a_eb, t_ea, t_eb, th_i, o_ia, o_ib]
cpl_rules = [*ipl_rules, d_ne]

# Function to check applicability of a rule to premises
def check(rule, prem):
    """Return True if ``rule`` may be applied to the premise list ``prem``.

    One-premise rules:
    - ⊥-elimination: the premise must be ⊥.
    - ∧-elimination: the premise must be a conjunction.
    - ∨-introduction: always applicable.
    - Double-negation elimination: the premise must be ¬¬A.

    Two-premise rules:
    - ¬-introduction: the *second* premise must be ⊥.
    - ¬-elimination: one premise is the negation of the other.
    - ∧-introduction and →-introduction: always applicable.
    - The two modus-ponens variants: one premise is an implication whose
      antecedent is the other premise.

    Any other rule/arity combination returns False.
    """
    if len(prem) == 1:
        premise = prem[0]
        if rule == fa_e:
            return premise == symb["FA"]
        if rule in (a_ea, a_eb):
            return isinstance(premise, list) and len(premise) == 3 and premise[1] == symb["AN"]
        if rule in (o_ia, o_ib):
            return True
        if rule == d_ne:
            return is_negation(premise) and is_negation(premise[1])
        return False
    if len(prem) == 2:
        first, second = prem[0], prem[1]
        if rule == n_ia:
            # n_ia concludes ¬first, so ⊥ has to be the second premise;
            # accepting ⊥ in the first slot would only yield the vacuous ¬⊥.
            return second == symb["FA"]
        if rule == no_e:
            return ((is_negation(first) and first[1] == second)
                    or (is_negation(second) and second[1] == first))
        # A ⊥ premise no longer disables the remaining two-premise rules.
        if rule in (an_i, th_i):
            return True
        if rule == t_ea:
            return is_implication(second) and second[0] == first
        if rule == t_eb:
            return is_implication(first) and first[0] == second
    return False

# Helper functions
def is_negation(f):
    """Return True when ``f`` is a negation node of the form ``[¬, A]``."""
    if not isinstance(f, list) or len(f) != 2:
        return False
    return f[0] == symb["NO"]


def is_implication(f):
    """Return True when ``f`` is an implication node of the form ``[A, →, B]``."""
    if not isinstance(f, list) or len(f) != 3:
        return False
    return f[1] == symb["TH"]

# Function to get applicable rules for given premises
def get_applicable_rules(premises, rules):
    """Return the members of ``rules`` that ``check`` accepts for ``premises``, in order."""
    return [candidate for candidate in rules if check(candidate, premises)]

# Function to generate a derivation
def generate_derivation(rules, max_steps=2):
    """Build a random derivation and return ``[premises, ⊢, conclusion]``.

    Two random premises are generated, then up to ``num_steps`` inference
    steps are attempted. Each step samples one or two already-known formulas,
    picks a random applicable rule from ``rules`` and appends its result.
    Generation stops early when no rule applies. The conclusion is the last
    formula produced (a premise if no step succeeded).

    Args:
        rules: candidate inference rules (e.g. ``ipl_rules``).
        max_steps: upper bound on the number of steps. For ``max_steps >= 2``
            the count is drawn uniformly from ``2..max_steps``. Smaller values
            are used directly instead of making ``randint`` raise ``ValueError``.

    Returns:
        ``[premises, symb["DE"], conclusion]``.
    """
    premises = [gen_wff(rd_f(t_nu), 0), gen_wff(rd_f(t_nu), 0)]
    formulas = list(premises)
    num_steps = randint(2, max_steps) if max_steps >= 2 else max(max_steps, 0)
    for _ in range(num_steps):
        # Use one or two previously known formulas as the premises of this step.
        num_selected = min(choice([1, 2]), len(formulas))
        selected_premises = sample(formulas, num_selected)
        applicable_rules = get_applicable_rules(selected_premises, rules)
        if not applicable_rules:
            break  # dead end: no rule fits these premises
        rule = choice(applicable_rules)
        formulas.append(rule(selected_premises))
    return [premises, symb["DE"], formulas[-1]]

# Function to flatten formulas into tokens for model input
def flatten_formula(f):
    """Serialise a formula tree into a flat list of integer tokens.

    - An atom becomes a single token.
    - ``[¬, A]`` becomes ``¬`` followed by the tokens of ``A``.
    - Any 3-element list ``[A, op, B]`` is written fully parenthesised as
      ``( A op B )``.
    - Other lists are flattened element by element.
    - Values that are neither int nor list contribute no tokens.
    """
    if isinstance(f, int):
        return [f]
    if not isinstance(f, list):
        return []
    if len(f) == 2 and f[0] == symb["NO"]:
        return [f[0]] + flatten_formula(f[1])
    if len(f) == 3:
        return ([symb["LB"]] + flatten_formula(f[0]) + [f[1]]
                + flatten_formula(f[2]) + [symb["RB"]])
    return [token for part in f for token in flatten_formula(part)]

# Prepare data for the transformer model
def prepare_data(num_samples=1000):
    """Generate ``num_samples`` random IPL derivations as padded token lists.

    Each derivation ``[premises, ⊢, conclusion]`` is serialised as the tokens
    of every premise, then the ``⊢`` token, then the tokens of the conclusion.
    All sequences are right-padded with 0 to the length of the longest one.

    Returns:
        A list of equal-length token lists (empty when ``num_samples`` is 0).
    """
    data = []
    for _ in range(num_samples):
        premises, derives_token, conclusion = generate_derivation(ipl_rules)
        # Serialise the three parts separately: handing the whole 3-element
        # derivation to `flatten_formula` would treat it as a binary formula
        # and wrap the entire sample in a spurious pair of parentheses.
        tokens = [tok for premise in premises for tok in flatten_formula(premise)]
        tokens.append(derives_token)
        tokens.extend(flatten_formula(conclusion))
        data.append(tokens)
    max_len = max((len(seq) for seq in data), default=0)
    return [seq + [0] * (max_len - len(seq)) for seq in data]

# Transformer Model Components
class TransformerEncoder(nn.Module):
    """Token embedding + learned positional encoding + stacked encoder layers.

    ``src`` is indexed ``[batch, seq_len]`` (see the ``src.size(1)`` slice).
    The output is sequence-first, ``[seq_len, batch, emb_dim]``.
    """

    def __init__(self, input_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=500):
        super().__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # Learned positional encoding, one zero-initialised vector per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, emb_dim))
        layer = nn.TransformerEncoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=num_layers)

    def forward(self, src):
        seq_len = src.size(1)
        x = self.embedding(src) + self.positional_encoding[:, :seq_len, :].to(src.device)
        # batch-first -> sequence-first, the layout nn.TransformerEncoder expects here
        return self.transformer_encoder(x.transpose(0, 1))

class TransformerDecoder(nn.Module):
    """Embedding + learned positional encoding + stacked decoder layers + vocabulary projection.

    ``tgt`` is indexed ``[batch, seq_len]``. ``memory`` is passed straight to
    ``nn.TransformerDecoder`` in the sequence-first layout. The result is
    ``[batch, seq_len, output_dim]`` logits.
    """

    def __init__(self, output_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=500):
        super().__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        # Learned positional encoding, one zero-initialised vector per position.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, emb_dim))
        layer = nn.TransformerDecoderLayer(
            d_model=emb_dim,
            nhead=num_heads,
            dim_feedforward=hidden_dim,
            dropout=dropout,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers=num_layers)
        self.fc_out = nn.Linear(emb_dim, output_dim)

    def forward(self, tgt, memory):
        seq_len = tgt.size(1)
        x = self.embedding(tgt) + self.positional_encoding[:, :seq_len, :].to(tgt.device)
        decoded = self.transformer_decoder(x.transpose(0, 1), memory)
        # back to batch-first before projecting onto the vocabulary
        return self.fc_out(decoded.transpose(0, 1))

class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder wrapper that decodes from an all-zero target sequence.

    ``forward`` uses no ground-truth target. The decoder input is a
    ``[batch, trg_len]`` long tensor filled with token 0.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg_len):
        memory = self.encoder(src)
        batch_size = src.size(0)
        # Every decoder position receives token 0 (assumed to act as a start token).
        placeholder = torch.zeros((batch_size, trg_len), device=self.device, dtype=torch.long)
        return self.decoder(placeholder, memory)

# Training the model
def train_model(model, data_loader, num_epochs=10, batch_size=32, learning_rate=1e-4, loss_fn=None):
    """Train ``model`` with Adam and return it.

    Args:
        model: module exposing ``.device`` whose ``forward(src, trg_len)``
            returns logits of shape ``[batch, trg_len, vocab]``.
        data_loader: iterable of batches, each a list of equal-length token lists.
        num_epochs: number of passes over ``data_loader``.
        batch_size: unused (batching is done by ``data_loader``); kept so
            existing calls keep working.
        learning_rate: Adam learning rate.
        loss_fn: optional ``loss_fn(output, src) -> scalar tensor``. When
            omitted, the model is trained to reconstruct its input. This is
            token-level cross-entropy against ``src``, ignoring the padding
            token 0; this data has no separate targets.

    Returns:
        The trained ``model`` (trained in place).
    """
    if loss_fn is None:
        # The previous code called an undefined `loss_function`.
        criterion = nn.CrossEntropyLoss(ignore_index=0)

        def loss_fn(output, src):
            return criterion(output.reshape(-1, output.size(-1)), src.reshape(-1))

    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    model.train()
    device = model.device
    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for batch in data_loader:
            src = torch.tensor(batch, dtype=torch.long).to(device)
            optimizer.zero_grad()
            output = model(src, src.size(1))
            loss = loss_fn(output, src)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            num_batches += 1
        # Report the epoch average; also safe when the loader is empty.
        average_loss = total_loss / num_batches if num_batches else float("nan")
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}')
    return model

# Putting it all together
def main():
    """Generate data, train the model, then print one decoded prediction.

    The test input is a fresh derivation's premises and conclusion, each
    followed by a ``⊢`` token (the trailing one is dropped). The predicted
    tokens are split on ``⊢``, and every piece is parsed and printed.
    """
    # Prepare data
    data = prepare_data(num_samples=1000)
    vocab_size = t_nu + 9  # token codes 0 .. t_nu + 8

    # Model parameters
    input_dim = vocab_size
    output_dim = vocab_size
    emb_dim = 256
    num_heads = 8
    hidden_dim = 512
    num_layers = 3
    dropout = 0.1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize the model
    encoder = TransformerEncoder(input_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout)
    decoder = TransformerDecoder(output_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout)
    model = Seq2SeqTransformer(encoder, decoder, device).to(device)

    batch_size = 32
    # Identity collate: each batch stays a plain list of equal-length token lists.
    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=lambda x: x)

    # Train the model
    train_model(model, data_loader, num_epochs=10, batch_size=32, learning_rate=1e-4)

    # Testing the model
    model.eval()
    with torch.no_grad():
        # generate_derivation returns [premises, ⊢, conclusion] — three parts.
        test_premises, _, test_conclusion = generate_derivation(ipl_rules)
        input_tokens = []
        for f in test_premises + [test_conclusion]:
            input_tokens.extend(flatten_formula(f))
            input_tokens.append(symb["DE"])
        src = torch.tensor([input_tokens[:-1]], dtype=torch.long).to(device)
        trg_len = 50  # Maximum length of generated sequence
        output = model(src, trg_len)
        predicted_tokens = output.argmax(dim=-1).squeeze(0).tolist()

        # Split the prediction into formulas at every ⊢ token.
        predicted_formulas = []
        current_formula = []
        for tok in predicted_tokens:
            if tok != symb["DE"]:
                current_formula.append(tok)
            elif current_formula:
                predicted_formulas.append(current_formula)
                current_formula = []
        if current_formula:
            predicted_formulas.append(current_formula)

        # Print the predicted derivation
        print("\nPredicted Derivation:")
        for f_tokens in predicted_formulas:
            print(formula_to_string(reconstruct_formula(f_tokens)))

# Function to reconstruct formula from tokens
def reconstruct_formula(tokens):
    """Parse a flat token list back into a nested formula.

    Inverse of ``flatten_formula`` for well-formed input:
    - Atoms stay ints.
    - ``¬ A`` becomes ``[¬, A]``.
    - ``( A op B )`` becomes ``[A, op, B]``.

    Only the first complete formula is returned; trailing tokens are ignored.
    Missing parts of malformed input come back as ``None``. An empty token
    list returns ``None``.
    """
    def parse(pos):
        # Parse one formula starting at `pos`; return (formula, next_position).
        if pos >= len(tokens):
            return None, pos
        tok = tokens[pos]
        if tok == symb["NO"]:
            operand, pos = parse(pos + 1)
            return [symb["NO"], operand], pos
        if tok == symb["LB"]:
            left, pos = parse(pos + 1)
            op = tokens[pos] if pos < len(tokens) else None
            right, pos = parse(pos + 1)
            if pos < len(tokens) and tokens[pos] == symb["RB"]:
                pos += 1  # consume the matching ')'
            return [left, op, right], pos
        if tok == symb["RB"]:
            return None, pos  # unexpected ')': leave it for the enclosing '('
        return tok, pos + 1

    if not tokens:
        return None
    formula, _ = parse(0)
    return formula

In [None]:
# Script entry point: run main() only when executed directly, not on import.
if __name__ == "__main__":
    main()

[[2, 10, 2], 2, 2, 2, 2, [2, 10, 2], 2, 2, 2, 2]
[[2, 10, 2], 2, 2, 2, 2, 2, 2, 2, 2, [2, 10, 2]]
[2, [2, 10, 2], 2, 2, 2, 2, 2, 2, 2, 2]
[[None, 10, 2], 2, 2, [2, 10, [2, 10, 2]], 2, 2, 2, 2, 2, [2, 10, [2, 10, 2]]]


KeyboardInterrupt: 

# Teacher

In [None]:
from torch.nn.utils.rnn import pad_sequence

In [None]:
import random
from random import randint, choice, random, seed, sample, shuffle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# Seed both RNGs: `random` drives data generation, torch drives model init.
seed(42)
torch.manual_seed(42)

# Propositional variables are the integer codes 1..t_nu.
t_nu = 5  # Variables: p, q, r, s, t

# Numeric codes for every non-variable token. Operators follow directly after
# the variables; 0 is the padding token.
symb = dict(
    DE=t_nu + 1,    # ⊢ DERIVES (not used in formula generation)
    LB=t_nu + 2,    # (
    RB=t_nu + 3,    # )
    NO=t_nu + 4,    # ¬
    TH=t_nu + 5,    # →
    OR=t_nu + 6,    # ∨
    AN=t_nu + 7,    # ∧
    FA=t_nu + 8,    # ⊥
    PAD=0,          # Padding token
    SOS=t_nu + 9,   # Start of sequence
    EOS=t_nu + 10,  # End of sequence
)

# Code -> printable glyph, in this order:
# - padding (empty string);
# - the variables p, q, r, ...;
# - the operator/control glyphs, in code order from t_nu + 1.
symb_reverse = {0: ""}
symb_reverse.update({code: chr(ord('p') + code - 1) for code in range(1, t_nu + 1)})
symb_reverse.update({
    t_nu + offset: glyph
    for offset, glyph in enumerate(
        ["⊢", "(", ")", "¬", "→", "∨", "∧", "⊥", "<SOS>", "<EOS>"], start=1
    )
})

# Printable glyph -> code.
symb_map = dict((glyph, code) for code, glyph in symb_reverse.items())

# Function to generate a random propositional variable
def rd_f(t_nu):
    """Draw one propositional-variable code uniformly from ``1..t_nu``."""
    upper = t_nu
    return randint(1, upper)

# Function to recursively generate well-formed formulas (WFFs)
def gen_wff(form=None, depth=0, max_depth=3):
    """Return a random well-formed formula, optionally built around ``form``.

    The recursion stops with probability 0.6, or always once
    ``depth >= max_depth``. On stopping, the result is ``form`` itself, or a
    fresh random variable when ``form`` is None. Otherwise a random rule from
    ``wff_rules`` wraps the base formula. If ``form`` was None, that base is
    first generated one level deeper.
    """
    stop = depth >= max_depth or random() < 0.6
    if not stop:
        base = form
        if base is None:
            base = gen_wff(depth=depth + 1, max_depth=max_depth)
        return choice(wff_rules)(base, depth)
    return rd_f(t_nu) if form is None else form

# Rules for generating well-formed formulas
def cona(form1, depth):
    """Conjunction A: ``form1 ∧ X`` with ``X`` generated one level deeper."""
    right = gen_wff(None, depth + 1)
    return [form1, symb["AN"], right]


def conb(form1, depth):
    """Conjunction B: ``X ∧ form1`` with ``X`` generated one level deeper."""
    left = gen_wff(None, depth + 1)
    return [left, symb["AN"], form1]


def disa(form1, depth):
    """Disjunction A: ``form1 ∨ X`` with ``X`` generated one level deeper."""
    right = gen_wff(None, depth + 1)
    return [form1, symb["OR"], right]


def disb(form1, depth):
    """Disjunction B: ``X ∨ form1`` with ``X`` generated one level deeper."""
    left = gen_wff(None, depth + 1)
    return [left, symb["OR"], form1]


def th_a(form1, depth):
    """Implication A: ``form1 → X`` with ``X`` generated one level deeper."""
    right = gen_wff(None, depth + 1)
    return [form1, symb["TH"], right]


def th_b(form1, depth):
    """Implication B: ``X → form1`` with ``X`` generated one level deeper."""
    left = gen_wff(None, depth + 1)
    return [left, symb["TH"], form1]


def neg(form1, depth):
    """Negation: ``¬form1``; ``depth`` is accepted for a uniform signature."""
    return [symb["NO"], form1]


# Formula-building rules, all with the signature (form, depth) -> formula.
wff_rules = [cona, conb, disa, disb, th_a, th_b, neg]

# Function to convert a formula to a string
def formula_to_string(f):
    """Pretty-print a formula: ``¬A`` for negations, ``(A op B)`` for binary nodes.

    Integer tokens are rendered through ``symb_reverse``. Other lists are
    concatenated piece by piece. Any other value falls back to ``str``.
    """
    if isinstance(f, list):
        is_neg = len(f) == 2 and f[0] == symb["NO"]
        if is_neg:
            return symb_reverse[symb["NO"]] + formula_to_string(f[1])
        if len(f) == 3:
            pieces = [formula_to_string(f[0]), symb_reverse[f[1]], formula_to_string(f[2])]
            return '({} {} {})'.format(*pieces)
        return ''.join(formula_to_string(part) for part in f)
    if isinstance(f, int):
        return symb_reverse[f]
    return str(f)

# Derivation rules for IPL (Intuitionistic Propositional Logic)
def fa_e(prem):
    """Falsum elimination: from ⊥ conclude a random formula from ``gen_wff``."""
    return gen_wff()


def no_e(prem):
    """Negation elimination: from ¬A and A conclude ⊥."""
    return symb["FA"]


def n_ia(premises):
    """Negation introduction: from A ⊢ ⊥ conclude ¬A, with A = ``premises[0]``."""
    assumption = premises[0]
    return [symb["NO"], assumption]


def an_i(prem):
    """Conjunction introduction: from A and B conclude A ∧ B."""
    a, b = prem[0], prem[1]
    return [a, symb["AN"], b]


def a_ea(prem):
    """Conjunction elimination (left): from A ∧ B conclude A."""
    conjunction = prem[0]
    return conjunction[0]


def a_eb(prem):
    """Conjunction elimination (right): from A ∧ B conclude B."""
    conjunction = prem[0]
    return conjunction[2]


def t_ea(prem):
    """Modus ponens, implication second: from A and A → B conclude B."""
    implication = prem[1]
    return implication[2]


def t_eb(prem):
    """Modus ponens, implication first: from A → B and A conclude B."""
    implication = prem[0]
    return implication[2]


def th_i(prem):
    """Implication introduction: from A and B conclude A → B."""
    antecedent, consequent = prem[0], prem[1]
    return [antecedent, symb["TH"], consequent]


def o_ia(prem):
    """Disjunction introduction (left): from A conclude A ∨ X for a random X."""
    disjunct = gen_wff()
    return [prem[0], symb["OR"], disjunct]


def o_ib(prem):
    """Disjunction introduction (right): from A conclude X ∨ A for a random X."""
    disjunct = gen_wff()
    return [disjunct, symb["OR"], prem[0]]


# Additional rule for Classical Propositional Logic (CPL)
def d_ne(prem):
    """Double negation elimination: from ¬¬A conclude A."""
    double_negation = prem[0]
    return double_negation[1]


# Rule sets: IPL, and CPL = IPL plus double-negation elimination.
ipl_rules = [fa_e, no_e, n_ia, an_i, a_ea, a_eb, t_ea, t_eb, th_i, o_ia, o_ib]
cpl_rules = ipl_rules + [d_ne]

# Helper functions
def _is_node(f, size):
    """True iff ``f`` is a list with exactly ``size`` elements."""
    return isinstance(f, list) and len(f) == size


def is_negation(f):
    """True iff ``f`` has the shape ``[¬, A]``."""
    return _is_node(f, 2) and f[0] == symb["NO"]


def is_implication(f):
    """True iff ``f`` has the shape ``[A, →, B]``."""
    return _is_node(f, 3) and f[1] == symb["TH"]


def is_conjunction(f):
    """True iff ``f`` has the shape ``[A, ∧, B]``."""
    return _is_node(f, 3) and f[1] == symb["AN"]


def is_double_negation(f):
    """True iff ``f`` has the shape ``[¬, [¬, A]]``."""
    if not is_negation(f):
        return False
    return is_negation(f[1])

# Function to check applicability of a rule to premises
def check(rule, prem):
    """Decide whether ``rule`` is applicable to the premise list ``prem``.

    One premise:
    - ``fa_e`` needs ⊥.
    - ``a_ea``/``a_eb`` need a conjunction.
    - ``o_ia``/``o_ib`` always apply.
    - ``d_ne`` needs ¬¬A.

    Two premises:
    - ``n_ia`` needs the second premise to be ⊥ (read as "first ⊢ ⊥").
    - ``no_e`` needs one premise to be the negation of the other.
    - ``an_i`` and ``th_i`` always apply.
    - ``t_ea`` needs the second premise to be an implication whose antecedent
      is the first; ``t_eb`` is the mirror case.

    Anything else is not applicable.
    """
    if len(prem) == 1:
        only = prem[0]
        if rule == fa_e:
            return only == symb["FA"]
        if rule in (a_ea, a_eb):
            return is_conjunction(only)
        if rule in (o_ia, o_ib):
            return True
        if rule == d_ne:
            return is_double_negation(only)
        return False
    if len(prem) == 2:
        left, right = prem[0], prem[1]
        if rule == n_ia:
            return right == symb["FA"]
        if rule == no_e:
            return ((is_negation(left) and left[1] == right)
                    or (is_negation(right) and right[1] == left))
        if rule in (an_i, th_i):
            return True
        if rule == t_ea:
            return is_implication(right) and right[0] == left
        if rule == t_eb:
            return is_implication(left) and left[0] == right
    return False

# Function to get applicable rules for given premises
def get_applicable_rules(premises, rules):
    """Filter ``rules`` (order preserved) down to those ``check`` accepts for ``premises``."""
    return list(filter(lambda rule: check(rule, premises), rules))

# Function to generate a derivation
def generate_derivation(rules, max_steps=2):
    """Generate a random derivation.

    One or two random premises are created. Then ``num_steps`` steps are
    attempted, with ``num_steps`` drawn uniformly from ``1..max_steps``. Each
    step does the following:
    - samples one or two known formulas;
    - applies a randomly chosen applicable rule from ``rules``;
    - records the result.
    The loop stops early when no rule applies.

    Args:
        rules: candidate inference rules (e.g. ``ipl_rules``).
        max_steps: maximum number of inference steps. Values below 1 yield a
            derivation consisting of the premises only, instead of making
            ``randint`` raise ``ValueError``.

    Returns:
        ``(premises, conclusion, derivation_steps)``.
        - Each step is a dict with keys ``'premises'``, ``'conclusion'`` and
          ``'rule'``. The rule is ``'Premise'`` for the initial formulas and
          the rule function's name otherwise.
        - ``conclusion`` is the last step's conclusion.
    """
    formulas = []
    derivation_steps = []
    # Start with initial formulas (premises)
    for _ in range(randint(1, 2)):
        formula = gen_wff()
        formulas.append(formula)
        derivation_steps.append({'premises': [], 'conclusion': formula, 'rule': 'Premise'})

    num_steps = randint(1, max_steps) if max_steps >= 1 else 0
    for _ in range(num_steps):
        # Choose 1 or 2 formulas from previous formulas
        num_selected = min(choice([1, 2]), len(formulas))
        selected_premises = sample(formulas, num_selected)

        applicable_rules = get_applicable_rules(selected_premises, rules)
        if not applicable_rules:
            break  # dead end: no rule fits these premises
        rule = choice(applicable_rules)
        new_formula = rule(selected_premises)
        formulas.append(new_formula)
        derivation_steps.append({'premises': selected_premises, 'conclusion': new_formula, 'rule': rule.__name__})

    premises = [step['conclusion'] for step in derivation_steps if step['rule'] == 'Premise']
    conclusion = derivation_steps[-1]['conclusion']
    return premises, conclusion, derivation_steps

# Function to flatten formulas into tokens
def flatten_formula(f):
    """Turn a nested formula into its flat token sequence.

    - An int is a single token.
    - ``[¬, A]`` becomes ``¬`` followed by the tokens of ``A``.
    - ``[A, op, B]`` becomes ``( A op B )``.
    - Other lists are flattened element-wise.
    - Anything else yields no tokens.
    """
    if isinstance(f, int):
        return [f]
    if not isinstance(f, list):
        return []
    if len(f) == 2 and f[0] == symb["NO"]:
        return [symb["NO"], *flatten_formula(f[1])]
    if len(f) == 3:
        left, op, right = f
        return [symb["LB"], *flatten_formula(left), op, *flatten_formula(right), symb["RB"]]
    flat = []
    for piece in f:
        flat += flatten_formula(piece)
    return flat

# Prepare data for the transformer model
def prepare_data(num_samples=1000):
    """Generate ``num_samples`` (input, target) token sequences for seq2seq training.

    Input: tokens of all premises, ``⊢``, tokens of the conclusion, ``<EOS>``.
    Target: tokens of every derivation step's conclusion (premises included),
    followed by ``<EOS>``.

    Inputs are right-padded with ``PAD`` to a common length. Targets are
    prefixed with ``<SOS>`` and padded the same way.

    Returns:
        ``(padded_inputs, padded_targets, max_input_len, max_target_len)``,
        where ``max_target_len`` already counts the ``<SOS>`` token. With
        ``num_samples == 0`` this is ``([], [], 0, 1)``, instead of ``max``
        raising on an empty sequence.
    """
    data = []
    for _ in range(num_samples):
        # generate_derivation always records at least one premise step.
        premises, conclusion, derivation_steps = generate_derivation(ipl_rules)

        input_tokens = [tok for premise in premises for tok in flatten_formula(premise)]
        input_tokens.append(symb["DE"])  # separates the premises from the conclusion
        input_tokens.extend(flatten_formula(conclusion))
        input_tokens.append(symb["EOS"])

        target_tokens = [tok for step in derivation_steps for tok in flatten_formula(step['conclusion'])]
        target_tokens.append(symb["EOS"])

        data.append((input_tokens, target_tokens))

    max_input_len = max((len(src) for src, _ in data), default=0)
    max_target_len = max((len(tgt) for _, tgt in data), default=0)

    pad = symb["PAD"]
    padded_inputs = [src + [pad] * (max_input_len - len(src)) for src, _ in data]
    padded_targets = [[symb["SOS"]] + tgt + [pad] * (max_target_len - len(tgt)) for _, tgt in data]

    return padded_inputs, padded_targets, max_input_len, max_target_len + 1  # +1 for SOS

# Transformer Model Components
class TransformerEncoder(nn.Module):
    """Embedding + learned positional encoding + ``nn.TransformerEncoder``.

    ``src`` is indexed ``[batch, seq_len]``. The output is sequence-first,
    ``[seq_len, batch, emb_dim]`` (see the permute in ``forward``).
    """
    def __init__(self, input_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=500):
        super(TransformerEncoder, self).__init__()
        self.embedding = nn.Embedding(input_dim, emb_dim)
        # Learned, zero-initialised positional encoding for up to max_seq_len positions.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, emb_dim))
        encoder_layer = nn.TransformerEncoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

    def forward(self, src, src_mask, attn_mask=None):
        """Encode ``src``.

        Args:
            src: token ids, ``[batch, seq_len]``.
            src_mask: key-padding mask, ``[batch, seq_len]``; True marks
                padding. Despite its name it is *not* an attention mask.
            attn_mask: optional ``[seq_len, seq_len]`` self-attention mask.
        """
        embedded = self.embedding(src) + self.positional_encoding[:, :src.size(1), :].to(src.device)
        embedded = embedded.permute(1, 0, 2)  # [seq_len, batch_size, emb_dim]
        return self.transformer_encoder(embedded, mask=attn_mask, src_key_padding_mask=src_mask)


class TransformerDecoder(nn.Module):
    """Embedding + learned positional encoding + ``nn.TransformerDecoder`` + vocabulary projection.

    ``tgt`` is indexed ``[batch, seq_len]``. ``memory`` is the sequence-first
    encoder output. Returns ``[batch, seq_len, output_dim]`` logits.
    """
    def __init__(self, output_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=500):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(output_dim, emb_dim)
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_seq_len, emb_dim))
        decoder_layer = nn.TransformerDecoderLayer(d_model=emb_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout)
        self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
        self.fc_out = nn.Linear(emb_dim, output_dim)

    def forward(self, tgt, memory, tgt_mask, tgt_key_padding_mask, memory_key_padding_mask=None):
        """Decode ``tgt`` against ``memory``.

        ``memory_key_padding_mask`` (``[batch, src_len]``, True = padding)
        keeps cross-attention away from encoder outputs at padded positions.
        """
        embedded = self.embedding(tgt) + self.positional_encoding[:, :tgt.size(1), :].to(tgt.device)
        embedded = embedded.permute(1, 0, 2)  # [seq_len, batch_size, emb_dim]
        output = self.transformer_decoder(
            embedded,
            memory,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask,
        )
        output = self.fc_out(output)
        return output.permute(1, 0, 2)  # [batch_size, seq_len, output_dim]


class Seq2SeqTransformer(nn.Module):
    """Encoder-decoder wrapper used with teacher forcing."""
    def __init__(self, encoder, decoder, device):
        super(Seq2SeqTransformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, tgt, src_mask, tgt_mask, src_key_padding_mask, tgt_key_padding_mask):
        """Return decoder logits ``[batch, tgt_len, vocab]``.

        ``src_mask`` is an optional encoder self-attention mask; previously it
        was silently ignored.
        """
        memory = self.encoder(src, src_key_padding_mask, attn_mask=src_mask)
        # Cross-attention must skip padded source positions as well; otherwise
        # the decoder attends to encoder outputs computed for PAD tokens.
        output = self.decoder(tgt, memory, tgt_mask, tgt_key_padding_mask,
                              memory_key_padding_mask=src_key_padding_mask)
        return output

# Function to create masks for transformer
def create_src_key_padding_mask(seq, pad_idx):
    """Return a boolean ``[batch, seq_len]`` mask that is True at padding positions."""
    return seq == pad_idx


def create_tgt_masks(tgt_seq, pad_idx):
    """Build the decoder self-attention masks for ``tgt_seq`` (``[batch, tgt_len]``).

    Returns:
        ``(tgt_mask, tgt_key_padding_mask)``:
        - ``tgt_mask`` is a boolean causal mask of shape ``[tgt_len, tgt_len]``;
          True above the diagonal marks future positions, which are not
          attended.
        - ``tgt_key_padding_mask`` is a boolean ``[batch, tgt_len]`` mask that
          is True at padding positions.

    The causal mask is boolean, rather than the float ``-inf`` mask from
    ``nn.Transformer.generate_square_subsequent_mask``, so that it has the same
    type as the boolean padding mask. PyTorch attention deprecates mixing the
    two mask types. The masked pattern is identical.
    """
    tgt_len = tgt_seq.size(1)
    tgt_mask = torch.triu(
        torch.ones(tgt_len, tgt_len, dtype=torch.bool, device=tgt_seq.device), diagonal=1
    )
    tgt_key_padding_mask = tgt_seq == pad_idx
    return tgt_mask, tgt_key_padding_mask

def _print_sample_predictions(src, output, limit=5):
    """Print up to ``limit`` inputs next to the model's argmax predictions, as symbols."""
    predicted_tokens = output.argmax(dim=-1)
    print("\nSample model outputs and corresponding inputs:")
    for i in range(min(limit, predicted_tokens.size(0))):
        # Unknown token codes are shown as "[UNK]".
        output_symbols = [symb_reverse.get(token, "[UNK]") for token in predicted_tokens[i].tolist()]
        input_symbols = [symb_reverse.get(token, "[UNK]") for token in src[i].tolist()]
        print(f"Input {i}: {' '.join(input_symbols)}")
        print(f"Output {i}: {' '.join(output_symbols)}\n")


def train_model(model, data_loader, num_epochs=10, learning_rate=1e-4):
    """Train ``model`` with teacher forcing and return it.

    Each batch is ``(src, tgt)``. The decoder receives ``tgt[:, :-1]`` and is
    trained with cross-entropy, ignoring ``PAD``, to predict ``tgt[:, 1:]``.
    For the first batch of every epoch, up to five inputs and the model's
    argmax predictions are printed.

    Returns:
        The trained ``model`` (trained in place). Returning it makes
        ``model = train_model(...)`` safe; previously that yielded ``None``.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss(ignore_index=symb["PAD"])
    model.train()
    device = model.device

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0
        for batch_idx, batch in enumerate(data_loader):
            src = batch[0].to(device)
            tgt = batch[1].to(device)

            src_key_padding_mask = create_src_key_padding_mask(src, symb["PAD"])
            tgt_input = tgt[:, :-1]  # Remove last token for input
            tgt_output = tgt[:, 1:]  # Remove first token for output
            tgt_mask, tgt_key_padding_mask = create_tgt_masks(tgt_input, symb["PAD"])

            optimizer.zero_grad()
            output = model(src, tgt_input, None, tgt_mask, src_key_padding_mask, tgt_key_padding_mask)
            # Flatten [batch, seq, vocab] / [batch, seq] for token-level cross-entropy.
            loss = criterion(output.reshape(-1, output.shape[-1]), tgt_output.reshape(-1))
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

            if batch_idx == 0:
                _print_sample_predictions(src, output)

        # Guard against an empty loader instead of dividing by zero.
        average_loss = total_loss / num_batches if num_batches else float("nan")
        print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss:.4f}')
    return model

def reconstruct_formula(tokens):
    """Recursive-descent parse of a token list into a nested formula.

    - ``( A op B )`` becomes ``[A, op, B]``.
    - ``¬ A`` becomes ``[¬, A]``.
    - Other known tokens are returned as atoms.
    - ``)``, ``⊢`` and ``<EOS>`` end a sub-formula and yield ``None`` there.
    - Unknown token codes parse as ``None`` and are skipped.
    - A missing ``)`` turns the enclosing binary node into ``None``.

    Only the first formula is returned.
    """
    n = len(tokens)

    def parse_at(pos):
        # Parse one formula starting at `pos`; return (formula, next_position).
        if pos >= n:
            return None, pos
        tok = tokens[pos]
        if tok not in symb_reverse:
            return None, pos + 1
        if tok == symb["NO"]:
            inner, after = parse_at(pos + 1)
            return [symb["NO"], inner], after
        if tok == symb["LB"]:
            return parse_binary(pos + 1)
        if tok in (symb["RB"], symb["DE"], symb["EOS"]):
            return None, pos
        return tok, pos + 1

    def parse_binary(pos):
        # Parse "A op B )" starting just after the opening "(".
        left, pos = parse_at(pos)
        if pos >= n:
            return None, pos
        op = tokens[pos]
        right, pos = parse_at(pos + 1)
        if pos >= n or tokens[pos] != symb["RB"]:
            return None, pos
        return [left, op, right], pos + 1

    return parse_at(0)[0]

def collate_batch(batch):
    """Collate ``(source, target)`` pairs into a pair of stacked LongTensors.

    All source sequences must share one length (and likewise all targets):
    the sequences are stacked as-is, with no padding added here.
    """
    src_seqs, tgt_seqs = zip(*batch)

    def to_long_stack(seqs):
        # Convert each sequence to int64 and stack along a new batch dim 0.
        return torch.stack([torch.tensor(seq, dtype=torch.long) for seq in seqs])

    return to_long_stack(src_seqs), to_long_stack(tgt_seqs)

def main():
    """Build the dataset and Transformer model, train it, and return it.

    Returns:
        tuple: ``(model, device)``, i.e. the trained ``Seq2SeqTransformer`` and the
        ``torch.device`` it was moved to. Returning them lets later notebook
        cells (such as the evaluation cell) use the trained model; local names
        inside ``main`` are otherwise invisible at notebook scope.
    """
    # Prepare data
    total_samples = 50000
    batch_size = 32
    data_inputs, data_targets, max_input_len, max_target_len = prepare_data(num_samples=total_samples)
    vocab_size = t_nu + 11  # Number of symbols including PAD, SOS, EOS

    # Model parameters
    input_dim = vocab_size
    output_dim = vocab_size
    emb_dim = 256
    num_heads = 8
    hidden_dim = 512
    num_layers = 3
    dropout = 0.1
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize the model
    encoder = TransformerEncoder(input_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=max_input_len)
    decoder = TransformerDecoder(output_dim, emb_dim, num_heads, hidden_dim, num_layers, dropout, max_seq_len=max_target_len)
    model = Seq2SeqTransformer(encoder, decoder, device).to(device)

    # Construct data loader
    data = list(zip(data_inputs, data_targets))
    data_loader = DataLoader(data, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)

    # Train the model. The module's parameters are updated in place. The visible
    # end of train_model has no `return`, so only adopt its result if it
    # actually hands back a model; never overwrite `model` with None.
    trained = train_model(model, data_loader, num_epochs=10, learning_rate=1e-4)
    if trained is not None:
        model = trained

    return model, device

In [None]:
if __name__ == '__main__':
    # If main() returns the trained (model, device) pair, bind both at notebook
    # scope so later cells can refer to `model` and `device`.
    trained_pair = main()
    if trained_pair is not None:
        model, device = trained_pair




Sample model outputs and corresponding inputs:
Input 0: ( q ∧ p ) ⊢ ( p ∨ ( ( q ∧ p ) ∨ r ) ) <EOS>                                                                                                 
Output 0: → ¬ ⊥ ∧ ∧ → ¬ ¬ <EOS> ¬ ∧ → → → ∧ ¬ ∧ → ¬ ¬ ⊥ ¬ ∧ → → → → ∧ ⊥ r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r ∧ r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r → r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r

Input 1: ¬ q ⊢ ( ( ¬ r ∨ p ) ∨ ¬ q ) <EOS>                                                                                                      
Output 1: ∧ r r r r r → r r → r r r ⊢ ⊥ r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r r 

In [None]:
# Testing the model: greedily decode a derivation for one freshly generated
# problem and print it next to the premises and conclusion.
# Requires `model` and `device` to be defined at notebook scope by an earlier cell.
model.eval()
with torch.no_grad():
    premises, conclusion, derivation_steps = generate_derivation(ipl_rules)
    if derivation_steps:
        # Encode as "premise ⊢ ... ⊢ conclusion <EOS>" with exactly one ⊢ between
        # consecutive items, like the training inputs (e.g. "¬ q ⊢ ( ... ) <EOS>").
        # Appending ⊢ after every premise AND before the conclusion would emit
        # a doubled "⊢ ⊢" that the model never saw during training.
        # NOTE(review): confirm prepare_data also separates multiple premises with ⊢.
        input_tokens = []
        for premise_idx, premise in enumerate(premises):
            if premise_idx:
                input_tokens.append(symb["DE"])  # Separator between premises
            input_tokens.extend(flatten_formula(premise))
        input_tokens.append(symb["DE"])  # Separator before conclusion
        input_tokens.extend(flatten_formula(conclusion))
        input_tokens.append(symb["EOS"])  # End of sequence token

        src = torch.tensor([input_tokens], dtype=torch.long).to(device)
        src_key_padding_mask = create_src_key_padding_mask(src, symb["PAD"])

        # Greedy decoding: feed the growing target prefix back in and take the
        # arg-max token at the last position until EOS or max_length is reached.
        max_length = 50  # Maximum length of generated sequence
        tgt_tokens = [symb["SOS"]]
        for i in range(max_length):
            tgt_seq = torch.tensor([tgt_tokens], dtype=torch.long).to(device)
            tgt_mask, tgt_key_padding_mask = create_tgt_masks(tgt_seq, symb["PAD"])
            output = model(src, tgt_seq, None, tgt_mask, src_key_padding_mask, tgt_key_padding_mask)
            next_token = output.argmax(dim=-1)[:, -1].item()
            tgt_tokens.append(next_token)
            if next_token == symb["EOS"]:
                break

        predicted_tokens = tgt_tokens[1:]  # Remove SOS token

        # Convert tokens to formulas: ⊢ separates steps, EOS ends the sequence.
        predicted_formulas = []
        current_formula_tokens = []
        for tok in predicted_tokens:
            if tok == symb["EOS"]:
                break
            if tok == symb["DE"]:
                if current_formula_tokens:
                    predicted_formulas.append(reconstruct_formula(current_formula_tokens))
                    current_formula_tokens = []
            else:
                current_formula_tokens.append(tok)
        # Flush the last step whether decoding stopped at EOS or ran out at
        # max_length; otherwise a sequence without EOS would lose its final formula.
        if current_formula_tokens:
            predicted_formulas.append(reconstruct_formula(current_formula_tokens))

        # Print the premises and conclusion
        print("Premises and Conclusion:")
        for premise in premises:
            print(f"Premise: {formula_to_string(premise)}")
        print(f"Conclusion: {formula_to_string(conclusion)}")

        # Print the predicted derivation (reconstruct_formula returns None on parse failure)
        print("\nPredicted Derivation Steps:")
        for f in predicted_formulas:
            if f is not None:
                print(formula_to_string(f))
            else:
                print("Invalid formula")

NameError: name 'model' is not defined