In [2]:
import torch
import wandb
from torch.utils.data import DataLoader
from transformers import BartTokenizer, Trainer, TrainingArguments
from datasets import load_dataset, load_from_disk

from scripts.custom_BARTs.noise_encoder_BART import (
    BartForConditionalGeneration,
    BartConfig,
)


# Initialize GPU
device = torch.device("cpu")

# Load dataset
dataset = load_from_disk("data/booksum_one_sentence_dataset")
#dataset = load_dataset("kmfoda/booksum")
train_dataset = dataset["train"]

# Initialize BART model and tokenizer
tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
config = BartConfig.from_pretrained("facebook/bart-base")
config.encoder_gaussian_ratio = 0.7
model = BartForConditionalGeneration.from_pretrained(
    "./saved_models/BART_summary_with_noise_working_p1_45", config=config
).to(device)


# Tokenize the 'summary_text' field
def tokenize_data(example):
    encoded_summary = tokenizer.encode(
        example["one_sentence_summary"], truncation=True, padding="max_length", max_length=128
    )
    return {
        "labels": encoded_summary,
        "input_ids": encoded_summary,  # Dummy input_ids
        "attention_mask": [1] * len(encoded_summary),  # Dummy attention_mask
    }


tokenized_dataset = train_dataset.map(tokenize_data)

# print one random summary
import random

# random.seed(42)
random_index = 452#random.randint(0, len(train_dataset))
print("Random index: ", random_index)
print("Original summary: ", train_dataset[random_index]["one_sentence_summary"])

input_ids = torch.tensor(tokenized_dataset[random_index]["input_ids"]).to(device)
output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)

print("Generated summary: ", tokenizer.decode(output[0], skip_special_tokens=True))

model.config.encoder_gaussian_ratio = 0.0
output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)
print("Generated summary, noise ratio {}:".format(model.config.encoder_gaussian_ratio), tokenizer.decode(output[0], skip_special_tokens=True))

model.config.encoder_gaussian_ratio = 0.3
output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)
print("Generated summary, noise ratio {}:".format(model.config.encoder_gaussian_ratio), tokenizer.decode(output[0], skip_special_tokens=True))

model.config.encoder_gaussian_ratio = 0.5
output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)
print("Generated summary, noise ratio {}:".format(model.config.encoder_gaussian_ratio), tokenizer.decode(output[0], skip_special_tokens=True))

model.config.encoder_gaussian_ratio = 0.7
output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)
print("Generated summary, noise ratio {}:".format(model.config.encoder_gaussian_ratio), tokenizer.decode(output[0], skip_special_tokens=True))


Loading cached processed dataset at /data1/sanps/diffusion_transformer/data/booksum_one_sentence_dataset/train/cache-9c70589dc3ee05d6.arrow


Random index:  452
Original summary:  Adam Bede returns home to find his mother, Lisbeth, sad due to his father's drunkenness, and he decides to finish an unfinished coffin, causing tension between him and Lisbeth; later, while delivering the coffin, they discover the body of Adam's father, Thias, in a brook.
Generated summary:  2018 ch's de ill as s comes l first f estle to preparing UcostelingMacCOM Swedish taking est est, estost his est est Vital election to est town� hotel his ward his to his calculating their daughter his leaving Quint ReginaCOM and Tomas ward Cec2018 willCOM ‎ ceeni preparing his happiness to taking his self est and pos Mr remaining Will c guest Catherine their c� pollCOM and the other to chart- sonsÍ decidingCOM est and estim choice his deciding to est est est liv host his Mr est est and his el mother takes an to deciding the theois citydo welcomed to to andois his deCOM to father _COM one city father his the city deciding of city do will volunteers fl Swedish l

In [None]:
# print 3 random summaries

for i in range(3):
    random_index = random.randint(0, len(train_dataset))
    print("Random index: ", random_index)
    print("Original summary: ", train_dataset[random_index]["summary_text"])

    input_ids = torch.tensor(tokenized_dataset[random_index]["input_ids"]).to(device)
    output = model.generate(input_ids.unsqueeze(0), max_length=1024).to(device)

    print("Generated summary: ", tokenizer.decode(output[0], skip_special_tokens=True))