In [None]:
import os

import pandas as pd
import torch
from torch.utils.data import Dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq

In [None]:
# Load data from CSV file. Point CAESAR_CSV_PATH at your copy of the data instead of
# editing this machine-specific default path.
DATA_PATH = os.environ.get(
    'CAESAR_CSV_PATH',
    '/Users/shreyanakum/Downloads/Sophomore Year/Summer 2024/Aristocrat-Model/data_collection/caesar_cipher_output.csv',
)
data_df = pd.read_csv(DATA_PATH)

# Fail fast here rather than deep inside training: the dataset class reads exactly these
# two columns, and the tokenizer cannot encode missing (NaN) cells.
required_columns = {'Ciphertext', 'Plaintext'}
missing_columns = required_columns - set(data_df.columns)
assert not missing_columns, f"CSV is missing columns: {sorted(missing_columns)}"
assert data_df[sorted(required_columns)].notna().all().all(), "CSV contains empty Ciphertext/Plaintext cells"

# Preview the data
print(data_df.head())

# Load the BART tokenizer
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Custom Dataset Class
class CaesarCipherDataset(Dataset):
    def __init__(self, df, tokenizer, max_length=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_length = max_length
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        item = self.df.iloc[idx]
        input_ids = self.tokenizer(item['Ciphertext'], padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt").input_ids.squeeze()
        labels = self.tokenizer(item['Plaintext'], padding='max_length', truncation=True, max_length=self.max_length, return_tensors="pt").input_ids.squeeze()
        
        return {'input_ids': input_ids, 'labels': labels}

# Token budget per sequence: both ciphertext and plaintext are truncated / padded to this.
max_length = 128

# Wrap the DataFrame; rows are tokenized on demand when the trainer indexes the dataset.
dataset = CaesarCipherDataset(
    df=data_df,
    tokenizer=bart_tokenizer,
    max_length=max_length,
)

In [None]:
# Load the model (Bart for sequence-to-sequence tasks)
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')
# Given the model, the collator also derives decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=bart_tokenizer, model=model)

if torch.cuda.is_available():
    model.to('cuda')
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# Hold out a validation split. Evaluating on the training rows themselves only measures
# memorization, so the per-epoch eval loss would say nothing about generalization.
EVAL_FRACTION = 0.1
SPLIT_SEED = 42  # fixed so the same rows are held out on every re-run
n_eval = max(1, int(len(dataset) * EVAL_FRACTION))
n_train = len(dataset) - n_eval
assert n_train > 0, "need at least two rows to make a train/eval split"
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [n_train, n_eval],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir='./results',
    # `eval_strategy` is the current name; transformers < 4.41 called it `evaluation_strategy`.
    eval_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    logging_dir='./logs',  # Directory for storing logs
    logging_strategy="steps",  # Log every N steps
    logging_steps=10,  # Adjust as needed
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
)

# Train the model
trainer.train()

# Save the model, plus the tokenizer it was trained with so the directory is self-contained.
trainer.save_model("./caesar_cipher_decoder")
bart_tokenizer.save_pretrained("./caesar_cipher_decoder")

## Inference

In [None]:
import torch
from transformers import BartTokenizer, BartForConditionalGeneration

# Keep in sync with the max_length used to tokenize the training data.
INFERENCE_MAX_LENGTH = 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Load the trained model and tokenizer
model = BartForConditionalGeneration.from_pretrained('./caesar_cipher_decoder').to(device)
bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')

# Set the model to evaluation mode
model.eval()

# Example ciphertext for inference
ciphertext = "DEO LHWUEJC DWO PWGAJ DEI WYNKOO PDA QJEPAZ GEJCZKI LHWUEJC WP PDA HKJZKJ LWHHWZEQI WJZ PDA NKUWH WHXANP DWHH."  # Replace with your ciphertext

# Tokenize the same way as in training (truncated to the same length) and keep the
# attention mask so generate() sees exactly what the model was trained on.
encoded = bart_tokenizer(
    ciphertext,
    truncation=True,
    max_length=INFERENCE_MAX_LENGTH,
    return_tensors='pt',
).to(device)

# Perform inference
with torch.no_grad():
    generated_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        # Without an explicit limit, generate() falls back to the generation config,
        # whose library default max_length is 20 tokens; that cuts off a full sentence.
        max_new_tokens=INFERENCE_MAX_LENGTH,
        # Disable n-gram blocking in case the inherited bart-base config enables it:
        # a decryption must be free to repeat phrases that occur in the plaintext.
        no_repeat_ngram_size=0,
    )

# Decode the output using BART's tokenizer
decoded_output = bart_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

print(f"Ciphertext: {ciphertext}")
print(f"Decoded plaintext: {decoded_output}")
