In [None]:
import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq

In [None]:
# Location of the training CSV. It can be overridden with the
# ARISTOCRAT_DATASET_CSV environment variable so the notebook also runs on
# machines where the original author's absolute path does not exist; when the
# variable is unset, the original path is used unchanged.
DATASET_CSV = os.environ.get(
    'ARISTOCRAT_DATASET_CSV',
    '/Users/shreyanakum/Downloads/Sophomore Year/Summer 2024/Aristocrat-Model/data_collection/pruned_substitution_cipher_dataset_merged.csv',
)
data_df = pd.read_csv(DATASET_CSV)

# Fail early with a clear message if the columns the dataset class reads
# ('ciphertext' and 'plaintext') are absent, instead of a KeyError mid-training.
missing_columns = {'ciphertext', 'plaintext'} - set(data_df.columns)
if missing_columns:
    raise ValueError(
        f"Dataset at {DATASET_CSV!r} is missing required column(s): {sorted(missing_columns)}"
    )

print(data_df.head())

# BART tokenizer
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Custom Dataset Class
class AristocratCipherDataset(Dataset):
    """Map-style dataset pairing each ciphertext with its plaintext.

    Each row of ``df`` must provide a ``'ciphertext'`` and a ``'plaintext'``
    entry. Both are tokenized with ``tokenizer`` and padded/truncated to
    ``max_length``.

    Args:
        df: pandas DataFrame holding the ``'ciphertext'`` / ``'plaintext'`` columns.
        tokenizer: callable tokenizer; it is invoked with ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors="pt"`` and its
            result must expose ``input_ids`` and ``attention_mask``.
        max_length: length passed to the tokenizer for padding and truncation.
    """

    # Label value skipped by the loss, so padding does not contribute to it.
    IGNORE_INDEX = -100

    def __init__(self, df, tokenizer, max_length=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.df)

    def _encode(self, text):
        """Tokenize ``text`` into a padded/truncated batch of size one."""
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the tensors for row ``idx``.

        Returns:
            dict with ``'input_ids'`` and ``'attention_mask'`` for the
            ciphertext, and ``'labels'`` for the plaintext where every padded
            position is replaced by ``IGNORE_INDEX``.
        """
        item = self.df.iloc[idx]
        source = self._encode(item['ciphertext'])
        target = self._encode(item['plaintext'])

        # squeeze(0) drops only the batch axis; a bare squeeze() would also
        # collapse the sequence axis when its length is 1.
        input_ids = source.input_ids.squeeze(0)
        attention_mask = source.attention_mask.squeeze(0)

        # Mask padded label positions so the loss is computed on real tokens
        # only. The tokenizer's own mask is used rather than comparing to a
        # pad id, so genuine tokens are never masked by accident.
        labels = target.input_ids.squeeze(0).clone()
        labels[target.attention_mask.squeeze(0) == 0] = self.IGNORE_INDEX

        return {'input_ids': input_ids, 'attention_mask': attention_mask, 'labels': labels}

# Token budget per sequence (both ciphertext and plaintext are padded or
# truncated to this many tokens). Try a larger value later to compare results.
max_length = 128

# Wrap the loaded DataFrame so it can be fed to the trainer
dataset = AristocratCipherDataset(
    df=data_df,
    tokenizer=bart_tokenizer,
    max_length=max_length,
)

In [None]:
# Bart for sequence-to-sequence tasks
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
data_collator = DataCollatorForSeq2Seq(tokenizer=bart_tokenizer, model=model)

if torch.cuda.is_available():
    model.to('cuda')
    print(f"Using GPU: {torch.cuda.get_device_name(0)}") # for nvidia gpu in case its used
else:
    print("Using CPU")

# Hold out part of the data for evaluation. Evaluating on the very same
# examples used for training makes the reported eval loss meaningless, so the
# dataset is split once with a fixed seed to keep the split reproducible.
EVAL_FRACTION = 0.1
SPLIT_SEED = 42
if len(dataset) < 2:
    raise ValueError("Need at least 2 examples to build separate train and eval splits")
eval_size = max(1, int(len(dataset) * EVAL_FRACTION))
train_size = len(dataset) - eval_size
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(f"Train examples: {train_size}, eval examples: {eval_size}")

# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` - confirm against the installed version.
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    logging_dir='./logs', 
    logging_strategy="steps",  # log every N steps
    logging_steps=10, # adjust to bigger
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

trainer.train()
# Persist the fine-tuned weights; the inference cell loads them from this path
trainer.save_model("./aristocrat_cipher_decoder")

## Inference

In [None]:
# These names are also imported in the first cell; repeating them lets this
# cell run on its own against an already-saved model.
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# load the trained model and tokenizer
model = BartForConditionalGeneration.from_pretrained('./aristocrat_cipher_decoder')
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Keep the model and its inputs on the same device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

# set to evaluation mode
model.eval()

# Same sequence cap as the training data (max_length = 128). It is also passed
# to generate(): with no explicit limit, generation stops at the library
# default (20 tokens unless the saved generation config says otherwise), which
# cuts the decoded sentence short.
MAX_SEQUENCE_LENGTH = 128

# from df
ciphertext = "DEO LHWUEJC DWO PWGAJ DEI WYNKOO PDA QJEPAZ GEJCZKI LHWUEJC WP PDA HKJZKJ LWHHWZEQI WJZ PDA NKUWH WHXANP DWHH."
# HIS PLAYING HAS TAKEN HIM ACROSS THE UNITED KINGDOM PLAYING AT THE LONDON PALLADIUM AND THE ROYAL ALBERT HALL
encoded = bart_tokenizer(
    ciphertext,
    truncation=True,
    max_length=MAX_SEQUENCE_LENGTH,
    return_tensors='pt',
).to(device)

with torch.no_grad():
    generated_ids = model.generate(
        encoded.input_ids,
        attention_mask=encoded.attention_mask,
        max_length=MAX_SEQUENCE_LENGTH,
    )

decoded_output = bart_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(f"Ciphertext: {ciphertext}")
print(f"Decoded plaintext: {decoded_output}")
