# In [None]:
import os, json, torch
from datasets import load_dataset, Dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer
from peft import get_peft_model, LoraConfig, TaskType
from evaluate import load
from tqdm import tqdm
import torch.nn as nn

# ---------------------------------------------------------------------------
# Setup: GPU pinning and device selection.
# CUDA_VISIBLE_DEVICES only takes effect if it is set before the CUDA runtime
# is initialised in this process, so it is assigned before any torch.cuda call.
# ---------------------------------------------------------------------------
os.environ["CUDA_VISIBLE_DEVICES"] = "6"
torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Perspective labels; a label's list position is used as its integer id.
PERSPECTIVES = [
    "INFORMATION",
    "SUGGESTION",
    "EXPERIENCE",
    "QUESTION",
    "CAUSE",
]

# Text-generation metrics loaded through `evaluate.load`.
bleu, bertscore = load("sacrebleu"), load("bertscore")

# Pretrained BART checkpoint shared by the tokenizer and the seq2seq model.
model_name = "facebook/bart-base"
tokenizer, base_model = (
    BartTokenizer.from_pretrained(model_name),
    BartForConditionalGeneration.from_pretrained(model_name),
)

# Perspective-aware attention layer
class PerspectiveFusionAttention(nn.Module):
    """Inject a perspective embedding into decoder states via cross-attention.

    Decoder hidden states act as queries and the per-example perspective
    embedding acts as the (single) key/value token; the attention output is
    added back to the decoder states as a residual.

    Args:
        hidden_size: Model width (must be divisible by ``num_heads``).
        num_heads: Number of heads for ``nn.MultiheadAttention``.
    """

    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.query_proj = nn.Linear(hidden_size, hidden_size)
        self.key_proj = nn.Linear(hidden_size, hidden_size)
        self.value_proj = nn.Linear(hidden_size, hidden_size)
        self.multihead_attn = nn.MultiheadAttention(
            embed_dim=hidden_size, num_heads=num_heads, batch_first=True
        )

    def forward(self, decoder_hidden, perspective_embed):
        """Fuse ``perspective_embed`` into ``decoder_hidden``.

        Args:
            decoder_hidden: ``(batch, seq_len, hidden_size)`` decoder states.
            perspective_embed: ``(batch, hidden_size)`` perspective vectors.

        Returns:
            Tensor with the same shape as ``decoder_hidden``.
        """
        # Expanding the perspective vector to seq_len identical key/value
        # tokens (the previous implementation) gives uniform softmax weights,
        # which is mathematically the same as attending to one token. A single
        # token gives the same result with O(seq_len) rather than O(seq_len^2)
        # attention work.
        # NOTE(review): with one key the attention weight is always 1, so the
        # query projection cannot influence the output. This layer effectively
        # adds a learned per-perspective vector to every position.
        perspective_token = perspective_embed.unsqueeze(1)  # (batch, 1, hidden)
        q = self.query_proj(decoder_hidden)
        k = self.key_proj(perspective_token)
        v = self.value_proj(perspective_token)
        out, _ = self.multihead_attn(q, k, v, need_weights=False)
        return out + decoder_hidden

class CustomDecoderWithPerspective(nn.Module):
    """Seq2seq wrapper that conditions decoder output on a perspective id.

    A learned embedding per perspective is fused into the last decoder hidden
    state by ``PerspectiveFusionAttention`` and then projected to the
    vocabulary with the wrapped model's ``lm_head``.

    Args:
        base_model: Model exposing ``config.d_model``,
            ``config.decoder_attention_heads``, ``config.vocab_size``,
            ``prepare_decoder_input_ids_from_labels``, ``.model`` and
            ``.lm_head`` (in this script, the LoRA-wrapped BART).
        num_perspectives: Number of distinct perspective ids.
    """

    def __init__(self, base_model, num_perspectives):
        super().__init__()
        cfg = base_model.config
        self.model = base_model
        self.perspective_embed = nn.Embedding(num_perspectives, cfg.d_model)
        self.perspective_fusion = PerspectiveFusionAttention(
            cfg.d_model, cfg.decoder_attention_heads
        )

    def forward(self, input_ids, attention_mask, perspective_id, labels=None):
        """Return perspective-conditioned logits and, given labels, the loss.

        Args:
            input_ids: Encoder token ids.
            attention_mask: Encoder attention mask.
            perspective_id: Index into ``perspective_embed`` for each example.
            labels: Optional target ids with ``-100`` at ignored positions.
                When given, decoder inputs are derived from them and a
                token-level cross-entropy loss is computed.

        Returns:
            dict with ``"loss"`` (``None`` without labels) and ``"logits"``.
        """
        style_vec = self.perspective_embed(perspective_id)

        if labels is None:
            decoder_input_ids = None
        else:
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels)

        # NOTE(review): via PEFT attribute forwarding, `self.model.model` may
        # resolve to the full LM rather than the bare encoder-decoder; confirm.
        # Only the last decoder hidden state is consumed either way.
        seq2seq_out = self.model.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            use_cache=False,
            return_dict=True,
            output_hidden_states=True,
        )
        last_decoder_state = seq2seq_out.decoder_hidden_states[-1]
        fused = self.perspective_fusion(last_decoder_state, style_vec)
        lm_logits = self.model.lm_head(fused)

        if labels is None:
            loss = None
        else:
            loss = nn.functional.cross_entropy(
                lm_logits.view(-1, self.model.config.vocab_size),
                labels.view(-1),
                ignore_index=-100,
            )

        return {"loss": loss, "logits": lm_logits}

# Wrap the base model with LoRA adapters, then add the perspective layers.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    inference_mode=False,
    r=4,
    lora_alpha=16,
    lora_dropout=0.1,
)
base_model = get_peft_model(base_model, peft_config)
model = CustomDecoderWithPerspective(
    base_model,
    num_perspectives=len(PERSPECTIVES),
).to(device)

# Train / validation splits read from local JSON files.
dataset = load_dataset(
    "json",
    data_files=dict(
        train="./flanT5_files/train.json",
        val="./flanT5_files/valid.json",
    ),
)

# Normalise raw summary labels before they are used as training targets.
def clean_summary(text):
    """Return a stripped, lower-cased summary, or "" if it is unusable.

    Falsy values and non-strings map to "". Strings that, after stripping and
    lower-casing, are empty or one of the placeholder labels "false", "true",
    "n/a", "not_duplicate" or "duplicate" also map to "".
    """
    if not (text and isinstance(text, str)):
        return ""
    normalised = text.strip().lower()
    placeholders = {"false", "true", "n/a", "not_duplicate", "duplicate", ""}
    if normalised in placeholders:
        return ""
    return normalised

def format_example(example):
    """Expand one raw record into per-perspective (input, output) rows.

    The model input concatenates the record's context, question and answers.
    One row is emitted for every perspective in ``PERSPECTIVES`` whose
    ``labelled_summaries["<PERSPECTIVE>_SUMMARY"]`` survives ``clean_summary``.

    Fields that are missing *or* present with a ``None`` value are treated as
    empty. When JSON is loaded with ``datasets``, records lacking a field that
    other records have can come back with ``None`` for it, so
    ``.get(key, default)`` alone does not protect the ``.strip()`` /
    ``join`` / ``.get`` calls below.

    Returns:
        list of dicts with keys "input", "output" and "perspective".
    """
    context = (example.get("context") or "").strip()
    question = (example.get("question") or "").strip()
    answers = " ".join(example.get("answers") or []).strip()
    input_text = f"Context: {context}\nQuestion: {question}\nAnswers: {answers}"

    summaries = example.get("labelled_summaries") or {}
    outputs = []
    for p in PERSPECTIVES:
        summary = clean_summary(summaries.get(f"{p}_SUMMARY", ""))
        if summary:
            outputs.append({"input": input_text, "output": summary, "perspective": p})
    return outputs

def trim(data, n=None):
    """Return the first ``n`` rows of ``data`` (all rows when ``n`` is None).

    Useful for quick debugging runs on a subset. ``data`` must support
    ``len()`` and ``select(indices)`` (e.g. a ``datasets.Dataset``). An ``n``
    larger than the dataset is clamped to its length; a negative ``n``
    selects nothing.
    """
    limit = len(data) if n is None else max(0, min(n, len(data)))
    return data.select(range(limit))

# Flatten every raw record into its per-perspective rows. A nested
# comprehension is linear in the number of rows, whereas sum(list_of_lists, [])
# re-copies the accumulated list on every addition (quadratic).
train_data = [row for e in trim(dataset["train"]) for row in format_example(e)]
val_data = [row for e in trim(dataset["val"]) for row in format_example(e)]

train_ds = Dataset.from_list(train_data)
val_ds = Dataset.from_list(val_data)

# Perspective mapping: perspective name -> its index in PERSPECTIVES.
perspective2id = {p: i for i, p in enumerate(PERSPECTIVES)}

def preprocess(batch):
    """Tokenise a batched slice of the flattened dataset for seq2seq training.

    Inputs are padded/truncated to 512 tokens and targets to 128 tokens.
    Target positions holding the pad token id are replaced with -100 so the
    loss ignores them. Each row's perspective name is mapped to its integer id
    through ``perspective2id``.
    """
    encoded = tokenizer(
        batch["input"], padding="max_length", truncation=True, max_length=512
    )
    target_ids = tokenizer(
        batch["output"], padding="max_length", truncation=True, max_length=128
    )["input_ids"]
    pad_id = tokenizer.pad_token_id

    def _mask_padding(seq):
        # Replace pad tokens with the ignore index used by the loss.
        return [-100 if tok == pad_id else tok for tok in seq]

    encoded["labels"] = list(map(_mask_padding, target_ids))
    encoded["perspective_id"] = [perspective2id[name] for name in batch["perspective"]]
    return encoded

train_tok = train_ds.map(preprocess, batched=True)
val_tok = val_ds.map(preprocess, batched=True)

training_args = TrainingArguments(
    output_dir="./novel_generator",
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=2,
    save_strategy="no",  # weights are saved manually after training
    logging_steps=50,
    # fp16 mixed precision targets CUDA. Some transformers/torch versions
    # reject it on CPU, which would break the CPU fallback chosen for
    # `device`, so enable it only when a GPU is available.
    fp16=torch.cuda.is_available(),
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=val_tok,
    # NOTE(review): newer transformers releases rename this to `processing_class`.
    tokenizer=tokenizer,
)

trainer.train()

save_dir = "novel_generator"
os.makedirs(save_dir, exist_ok=True)
tokenizer.save_pretrained(save_dir)
# `model.model` is the PEFT-wrapped base model, so this writes the LoRA
# adapter and its config (the old `hasattr(model.model, "lora")` branch could
# never be taken for a PeftModel and has been removed).
model.model.save_pretrained(save_dir)

# The perspective embedding and fusion layers are trained from scratch and
# live outside the PEFT model, so save_pretrained above does not include
# them. Without this they would be lost after training.
torch.save(
    {
        "perspectives": list(PERSPECTIVES),
        "perspective_embed": model.perspective_embed.state_dict(),
        "perspective_fusion": model.perspective_fusion.state_dict(),
    },
    os.path.join(save_dir, "perspective_layers.pt"),
)
