In [13]:
import pandas as pd
from typing import Iterable
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import math


In [None]:
def get_tgt(row: pd.Series) -> str:
    """
    Usage:
        row = {"input": "[(1, 10), (2, 20), (3, 30)]", "output": "[4, 5, 6]"}
        get_tgt(row)
    """
    inputs = eval(row["input"])
    outputs = eval(row["output"])
    tgt_str_li = [f"{list(input)}:{output}," for input, output in zip(inputs, outputs)]
    tgt_str = "".join(tgt_str_li)
    tgt_str = tgt_str.replace(" ", "")
    tgt_str = tgt_str[:-1]
    return tgt_str



def build_vocab(strings: Iterable[str]):
    vocab = {"<pad>": 0, "<sos>": 1, "<eos>": 2}
    idx = 3
    for sentence in strings:
        for token in sentence:
            if token not in vocab:
                vocab[token] = idx
                idx += 1
    return vocab

def id2token(vocab: dict[int, str], id: int) -> str:
    for token, idx in vocab.items():
        if idx == id:
            return token
    return None


def tokenize(string: str, vocab: dict[str, int]) -> list[int]:
    return [vocab["<sos>"]] + [vocab[token] for token in string] + [vocab["<eos>"]]

class TransformerDataset(Dataset):
    def __init__(self, dataframe, src_vocab, tgt_vocab, src_max_len, tgt_max_len):
        self.dataframe = dataframe
        self.src_vocab = src_vocab
        self.tgt_vocab = tgt_vocab
        self.src_max_len = src_max_len
        self.tgt_max_len = tgt_max_len

    def __len__(self):
        return len(self.dataframe)

    def pad_sequence(self, seq, vocab, max_len):
        return seq + [vocab["<pad>"]] * (max_len - len(seq))

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        src_tokens = tokenize(row['src_str'], self.src_vocab)
        tgt_tokens = tokenize(row['tgt_str'], self.tgt_vocab)
        src_padded = self.pad_sequence(src_tokens[:self.src_max_len], self.src_vocab, self.src_max_len)
        tgt_padded = self.pad_sequence(tgt_tokens[:self.tgt_max_len], self.tgt_vocab, self.tgt_max_len)
        return torch.tensor(src_padded), torch.tensor(tgt_padded)

In [3]:
# make dataset
data_path = "/home/takeru/AlphaSymbol/data/prfndim/d3-a2-c3-r3-status.csv"
df = pd.read_csv(data_path)
df_data = pd.DataFrame(columns=["src", "tgt"])

df_data["tgt_str"] = df["expr"].apply(lambda x: x.replace(" ", ""))
df_data["src_str"] = df.apply(get_tgt, axis=1)
src_vocab = build_vocab(df_data["src_str"])
tgt_vocab = build_vocab(df_data["tgt_str"])
src_max_len = df_data["src_str"].apply(len).max()
tgt_max_len = df_data["tgt_str"].apply(len).max()
dataset = TransformerDataset(df_data, src_vocab, tgt_vocab, src_max_len, tgt_max_len)

dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

In [32]:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)  # Add batch dimension
        self.encoding = self.encoding.permute(1, 0, 2)  # (max_len, batch_size, d_model)
        self.register_buffer("pe", self.encoding)

    def forward(self, x):
        """
        x: Tensor of shape (seq_len, batch_size, d_model)
        """
        seq_len = x.size(0)
        return x + self.pe[: seq_len]

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, src_max_len, tgt_max_len, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
        self.src_pos_enc = PositionalEncoding(d_model, src_max_len)
        self.tgt_pos_enc = PositionalEncoding(d_model, tgt_max_len)

    def forward(self, src, tgt):
        src = self.src_embedding(src) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        tgt = self.tgt_embedding(tgt) * torch.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        src = src.permute(1, 0, 2)  # (S, N, E)
        tgt = tgt.permute(1, 0, 2)  # (T, N, E)
        src = src + self.src_pos_enc(src)
        tgt = tgt + self.tgt_pos_enc(tgt)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt.size(0))
        output = self.transformer(src, tgt, tgt_mask = tgt_mask)
        output = self.fc_out(output)
        return output

# Initialize model, loss function, and optimizer
src_vocab_size = len(src_vocab)
tgt_vocab_size = len(tgt_vocab)
model = TransformerModel(src_vocab_size, tgt_vocab_size, src_max_len, tgt_max_len)
criterion = nn.CrossEntropyLoss(ignore_index=src_vocab["<pad>"])
optimizer = optim.Adam(model.parameters(), lr=0.0001)

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    for src_batch, tgt_batch in dataloader:
        tgt_input = tgt_batch[:, :-1]
        tgt_output = tgt_batch[:, 1:]
        
        optimizer.zero_grad()
        output = model(src_batch, tgt_input)
        output = output.permute(1, 2, 0)  # (N, C, T)
        loss = criterion(output, tgt_output)
        loss.backward()
        optimizer.step()
        
        epoch_loss += loss.item()
    
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/len(dataloader)}')

Epoch 1/100, Loss: 2.3948116643088206
Epoch 2/100, Loss: 2.1350086757114957
Epoch 3/100, Loss: 2.051949347768511
Epoch 4/100, Loss: 1.5828802755900793
Epoch 5/100, Loss: 1.0510199325425285
Epoch 6/100, Loss: 0.8737418140683856
Epoch 7/100, Loss: 0.8305619359016418
Epoch 8/100, Loss: 0.7557384456907
Epoch 9/100, Loss: 0.664957617010389
Epoch 10/100, Loss: 0.6888793366295951
Epoch 11/100, Loss: 0.6522269419261387
Epoch 12/100, Loss: 0.6749982833862305
Epoch 13/100, Loss: 0.6490717530250549
Epoch 14/100, Loss: 0.5388032112802777
Epoch 15/100, Loss: 0.5663161788667951
Epoch 16/100, Loss: 0.5232447470937457
Epoch 17/100, Loss: 0.47398244057382855
Epoch 18/100, Loss: 0.4459611986364637
Epoch 19/100, Loss: 0.4665420012814658
Epoch 20/100, Loss: 0.407229197876794
Epoch 21/100, Loss: 0.424311671938215
Epoch 22/100, Loss: 0.4018236781869616
Epoch 23/100, Loss: 0.4479578265122005
Epoch 24/100, Loss: 0.4496226225580488
Epoch 25/100, Loss: 0.42390948108264376
Epoch 26/100, Loss: 0.3929436036518642


In [81]:
c = "<sos>"
data = dataset[0][0]
print(data)
while c !="<eos>":
    ret = model(data.unsqueeze(0), torch.tensor(tgt_vocab[c]).unsqueeze(0).unsqueeze(0))
    c = id2token(tgt_vocab, ret.squeeze().argmax().item())
    print(c)

tensor([1, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4,
        7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4,
        7, 3, 4, 5, 6, 4, 7, 3, 4, 5, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
C
(
C
(
Z
<eos>
