In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, BertTokenizer, BertModel
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
import os
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import random
import time
import warnings
warnings.filterwarnings("ignore")

# Download nltk data if needed
# nltk.data.find raises LookupError when the 'punkt' tokenizer resource is not
# available locally; only in that case is it downloaded.
# NOTE(review): nothing visible in this notebook calls an NLTK tokenizer (the
# BLEU inputs are split on whitespace) - confirm whether 'punkt' is still needed.
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# Device configuration
# Prefer the second GPU (index 1) when the machine actually has one. On a
# single-GPU machine 'cuda:1' is not a valid device and the first `.to(device)`
# would fail, so fall back to 'cuda:0'; without CUDA, run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Custom dataset
class SarcasmGANDataset(Dataset):
    """Paired (sarcastic tweet, plain rephrase) examples read from a CSV file.

    The CSV must provide a 'tweet' column (sarcastic source) and a 'rephrase'
    column (target). Every item is tokenized twice: once with the generator
    tokenizer (prompted source + target labels) and once with the
    discriminator tokenizer (raw source and raw target).
    """

    def __init__(self, csv_path, tokenizer, bert_tokenizer, max_length=128, prompt_style="instruction"):
        frame = pd.read_csv(csv_path)
        # astype(str) also turns missing cells into the literal string "nan".
        self.sarcastic = frame['tweet'].astype(str).tolist()
        self.rephrase = frame['rephrase'].astype(str).tolist()
        self.tokenizer = tokenizer            # generator (seq2seq) tokenizer
        self.bert_tokenizer = bert_tokenizer  # discriminator tokenizer
        self.max_length = max_length
        self.prompt_style = prompt_style

    def __len__(self):
        return len(self.sarcastic)

    def get_prompt(self, text):
        """Wrap `text` in the instruction template selected by `prompt_style`.

        With prompt_style == "random" one of the four templates is drawn
        through numpy's global RNG on every call; any other value is looked up
        directly and raises KeyError when it is not a known template name.
        """
        templates = {
            "instruction": f"Remove sarcasm from the following text: {text}",
            "rewrite": f"Rewrite without sarcasm: {text}",
            "transform": f"Transform this sarcastic statement into a non-sarcastic one: {text}",
            "neutral": f"Make this statement neutral and straightforward: {text}"
        }

        if self.prompt_style != "random":
            return templates[self.prompt_style]

        names = list(templates)
        return templates[names[np.random.randint(0, len(names))]]

    def _encode(self, tok, text):
        """Tokenize one string, padded/truncated to max_length, as 'pt' tensors."""
        return tok(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        )

    def __getitem__(self, idx):
        source_text = self.sarcastic[idx]
        target_text = self.rephrase[idx]

        # Generator side: prompted source as input, rephrase as labels.
        gen_source = self._encode(self.tokenizer, self.get_prompt(source_text))
        gen_target = self._encode(self.tokenizer, target_text)

        # Discriminator side: un-prompted source and target.
        disc_source = self._encode(self.bert_tokenizer, source_text)
        disc_target = self._encode(self.bert_tokenizer, target_text)

        # Padding positions in the labels are set to -100 so the loss skips them.
        labels = gen_target["input_ids"].squeeze()
        labels[labels == self.tokenizer.pad_token_id] = -100

        item = {}
        # Generator inputs
        item["input_ids"] = gen_source["input_ids"].squeeze()
        item["attention_mask"] = gen_source["attention_mask"].squeeze()
        item["labels"] = labels
        # Original source text
        item["src_raw"] = source_text
        # Discriminator inputs
        item["disc_src_ids"] = disc_source["input_ids"].squeeze()
        item["disc_src_mask"] = disc_source["attention_mask"].squeeze()
        item["disc_tgt_ids"] = disc_target["input_ids"].squeeze()
        item["disc_tgt_mask"] = disc_target["attention_mask"].squeeze()
        return item

# Generator (Seq2Seq model)
class Generator(nn.Module):
    """Thin nn.Module wrapper around a pretrained seq2seq language model.

    The wrapped Hugging Face model is exposed as `self.model`; `forward` and
    `generate` simply delegate to it.
    """

    def __init__(self, model_name="facebook/bart-base"):
        super().__init__()
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return its output object (return_dict=True)."""
        model_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "return_dict": True,
        }
        return self.model(**model_inputs)

    def generate(self, input_ids, attention_mask, **kwargs):
        """Delegate to the wrapped model's generate(), forwarding extra kwargs."""
        generated = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            **kwargs
        )
        return generated

# Discriminator (Classifier)
class Discriminator(nn.Module):
    """Scores how "real" a (source, target) text pair looks.

    Both texts are encoded with the same BERT encoder; their pooled outputs
    are concatenated and passed through a small MLP ending in a sigmoid, so
    the output is a probability in (0, 1) of shape (batch, 1) where 1 means
    "real rephrase" and 0 means "generated".

    Args:
        model_name: Name or path of the pretrained BERT checkpoint.
    """

    def __init__(self, model_name="bert-base-uncased"):
        super(Discriminator, self).__init__()
        self.bert = BertModel.from_pretrained(model_name)
        # Size the head from the encoder's config rather than hard-coding 768,
        # so checkpoints with another hidden size work too. For
        # bert-base-uncased this is still 768 * 2, and the layer indices
        # (state-dict keys) are unchanged.
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size * 2, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

    def _pooled(self, input_ids, attention_mask):
        """Return BERT's pooled output for one batch of token ids."""
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        return outputs.pooler_output

    def forward(self, src_ids, src_mask, tgt_ids, tgt_mask):
        """Classify whether the target is real (1) or fake (0) for its source.

        Returns:
            Tensor of shape (batch, 1) with sigmoid probabilities.
        """
        # Source first, then target: keeps the same call order as before.
        src_pooled = self._pooled(src_ids, src_mask)
        tgt_pooled = self._pooled(tgt_ids, tgt_mask)

        # Concatenate the pooled outputs along the feature dimension.
        combined = torch.cat([src_pooled, tgt_pooled], dim=1)
        return self.classifier(combined)
        
# Evaluation function with BLEU scoring
def evaluate_model(generator, discriminator, dataloader, device, gen_only=False):
    """Evaluate the generator (and optionally the discriminator) on a dataloader.

    Args:
        generator: Model exposing `__call__(input_ids, attention_mask, labels)`
            returning an object with `.loss`, and `generate(...)`.
        discriminator: Pair classifier returning probabilities; not used when
            `gen_only` is True.
        dataloader: DataLoader whose `dataset` exposes `tokenizer` and
            `bert_tokenizer` (and optionally `max_length`), yielding the batch
            keys produced by SarcasmGANDataset.
        device: Device the batch tensors are moved to.
        gen_only: When True, skip every discriminator metric.

    Returns:
        dict with the per-batch averages "g_loss", "d_loss", "real_acc",
        "fake_acc" and the per-sentence average "bleu". Discriminator metrics
        are 0 when `gen_only` is True; all metrics are 0 for an empty loader.
    """
    generator.eval()
    if not gen_only:
        discriminator.eval()

    dataset = dataloader.dataset
    gen_tokenizer = dataset.tokenizer
    # Keep generation / re-encoding length consistent with how the dataset
    # encoded the references; 128 was the previously hard-coded value.
    max_length = getattr(dataset, "max_length", 128)

    g_loss = 0.0
    d_loss = 0.0
    disc_real_acc = 0.0
    disc_fake_acc = 0.0
    all_bleu_scores = []

    bce = nn.BCELoss()
    # Smoothing avoids zero BLEU when a higher-order n-gram has no match.
    smoothie = SmoothingFunction().method1
    bleu_weights = (0.5, 0.3, 0.15, 0.05)  # Focus on lower n-grams

    with torch.no_grad():
        for batch in dataloader:
            # Generator evaluation
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            g_outputs = generator(
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels
            )
            g_loss += g_outputs.loss.item()

            # Generate text for BLEU scoring and discriminator eval
            gen_ids = generator.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                num_beams=4,
                early_stopping=True
            )
            gen_text = gen_tokenizer.batch_decode(gen_ids, skip_special_tokens=True)

            # BLEU is computed from the decoded generator text, and before the
            # gen_only shortcut. Previously the generated ids were overwritten
            # by discriminator-tokenizer ids and then decoded with the
            # generator tokenizer, and gen_only runs never reached this code.
            for ref, pred_text in zip(batch["labels"], gen_text):
                ref_tokens = gen_tokenizer.decode(
                    ref[ref != -100], skip_special_tokens=True
                ).lower().split()
                pred_tokens = pred_text.lower().split()

                if pred_tokens:  # Avoid empty predictions
                    all_bleu_scores.append(
                        sentence_bleu(
                            [ref_tokens],
                            pred_tokens,
                            weights=bleu_weights,
                            smoothing_function=smoothie
                        )
                    )

            # Skip discriminator evaluation if only evaluating generator
            if gen_only:
                continue

            # Prepare inputs for discriminator
            disc_src_ids = batch["disc_src_ids"].to(device)
            disc_src_mask = batch["disc_src_mask"].to(device)
            disc_tgt_ids = batch["disc_tgt_ids"].to(device)
            disc_tgt_mask = batch["disc_tgt_mask"].to(device)

            # Re-tokenize the generated text with the discriminator tokenizer
            fake_encoded = dataset.bert_tokenizer(
                gen_text,
                padding="max_length",
                truncation=True,
                max_length=max_length,
                return_tensors="pt"
            )
            fake_ids = fake_encoded["input_ids"].to(device)
            fake_mask = fake_encoded["attention_mask"].to(device)

            # Real samples classification
            real_preds = discriminator(disc_src_ids, disc_src_mask, disc_tgt_ids, disc_tgt_mask)
            real_labels = torch.ones_like(real_preds)

            # Fake samples classification
            fake_preds = discriminator(disc_src_ids, disc_src_mask, fake_ids, fake_mask)
            fake_labels = torch.zeros_like(fake_preds)

            # Discriminator loss: mean of the real and fake BCE terms
            d_loss += (bce(real_preds, real_labels) + bce(fake_preds, fake_labels)).item() / 2

            # Accuracy at a 0.5 decision threshold
            disc_real_acc += ((real_preds > 0.5).float() == real_labels).float().mean().item()
            disc_fake_acc += ((fake_preds > 0.5).float() == fake_labels).float().mean().item()

    # Guard against an empty dataloader (would otherwise divide by zero).
    num_batches = max(1, len(dataloader))
    avg_g_loss = g_loss / num_batches

    if not gen_only:
        avg_d_loss = d_loss / num_batches
        avg_real_acc = disc_real_acc / num_batches
        avg_fake_acc = disc_fake_acc / num_batches
    else:
        avg_d_loss = 0
        avg_real_acc = 0
        avg_fake_acc = 0

    avg_bleu = float(np.mean(all_bleu_scores)) if all_bleu_scores else 0

    return {
        "g_loss": avg_g_loss,
        "d_loss": avg_d_loss,
        "real_acc": avg_real_acc,
        "fake_acc": avg_fake_acc,
        "bleu": avg_bleu
    }

# GAN Training Function
def _encode_fakes_for_discriminator(generator, gen_tokenizer, disc_tokenizer,
                                    input_ids, attention_mask, max_length, device):
    """Beam-search one rewrite per source and re-encode it for the discriminator.

    Generation runs under no_grad and goes through decoded text, so the
    returned id/mask tensors carry no gradient back to the generator.
    """
    with torch.no_grad():
        gen_ids = generator.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_length=max_length,
            num_beams=4,
            early_stopping=True
        )

    gen_text = gen_tokenizer.batch_decode(gen_ids, skip_special_tokens=True)
    gen_encoded = disc_tokenizer(
        gen_text,
        padding="max_length",
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    return gen_encoded["input_ids"].to(device), gen_encoded["attention_mask"].to(device)


def _plot_training_history(history, output_dir):
    """Draw the 2x2 loss/accuracy/BLEU figure, save it as a PNG and show it."""
    plt.figure(figsize=(15, 10))

    plt.subplot(2, 2, 1)
    plt.plot(history['g_loss'], label='Train G Loss')
    plt.plot(history['val_g_loss'], label='Val G Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Generator Loss')

    plt.subplot(2, 2, 2)
    plt.plot(history['d_loss'], label='Train D Loss')
    plt.plot(history['val_d_loss'], label='Val D Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.title('Discriminator Loss')

    plt.subplot(2, 2, 3)
    plt.plot(history['real_acc'], label='Real Accuracy')
    plt.plot(history['fake_acc'], label='Fake Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.title('Discriminator Accuracy')

    plt.subplot(2, 2, 4)
    plt.plot(history['bleu'], label='BLEU Score')
    plt.xlabel('Epoch')
    plt.ylabel('BLEU')
    plt.legend()
    plt.title('BLEU Score')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "training_history.png"))
    plt.show()


# GAN Training Function
def train_gan(
    csv_path,
    output_dir="sarcasm_gan_model",
    gen_model_name="facebook/bart-base",
    disc_model_name="bert-base-uncased",
    batch_size=8,
    epochs=5,
    g_lr=5e-5,
    d_lr=1e-5,
    max_length=128,
    eval_steps=50,
    prompt_style="random",
    g_steps=1,  # How many generator steps per discriminator step
    d_steps=1,  # How many discriminator steps per generator step
    gan_weight=0.5  # Weight for GAN loss vs. supervised loss
):
    """Train the generator/discriminator pair on a sarcasm-rephrase CSV.

    The CSV is split 90/10 into train/validation. Each batch runs `g_steps`
    generator updates followed by `d_steps` discriminator updates. After every
    epoch the models are evaluated on the validation split and checkpointed
    under `output_dir/checkpoint-epoch-N`; the final models, the metric
    history (.npy) and a history plot (.png) are written to `output_dir`.

    Args:
        csv_path: CSV with 'tweet' and 'rephrase' columns.
        output_dir: Directory receiving checkpoints, final models and plots.
        gen_model_name: Pretrained seq2seq checkpoint for the generator.
        disc_model_name: Pretrained BERT checkpoint for the discriminator.
        batch_size: Batch size for both loaders.
        epochs: Number of passes over the training split.
        g_lr: AdamW learning rate for the generator.
        d_lr: AdamW learning rate for the discriminator.
        max_length: Token length used for padding, truncation and generation.
        eval_steps: Accepted for compatibility; not used by this function.
        prompt_style: Prompt template for training examples (validation always
            uses "instruction").
        g_steps: Generator updates per batch (0 disables generator training).
        d_steps: Discriminator updates per batch (0 disables it).
        gan_weight: Weight of the adversarial term in the reported G loss.

    Returns:
        (generator, discriminator, gen_tokenizer, history)
    """
    import tempfile  # stdlib; only needed for the temporary CSV splits

    # Create output directory
    os.makedirs(output_dir, exist_ok=True)

    # Initialize tokenizers
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)
    disc_tokenizer = BertTokenizer.from_pretrained(disc_model_name)

    # Initialize models
    generator = Generator(gen_model_name).to(device)
    discriminator = Discriminator(disc_model_name).to(device)

    # Split data for validation
    df = pd.read_csv(csv_path)
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

    # SarcasmGANDataset reads its CSV fully in __init__, so the split files are
    # only needed while the datasets are built. A private temporary directory
    # avoids clobbering files in the working directory and is removed even if
    # dataset construction raises.
    with tempfile.TemporaryDirectory() as tmp_dir:
        train_path = os.path.join(tmp_dir, "train_temp.csv")
        val_path = os.path.join(tmp_dir, "val_temp.csv")
        train_df.to_csv(train_path, index=False)
        val_df.to_csv(val_path, index=False)

        train_dataset = SarcasmGANDataset(train_path, gen_tokenizer, disc_tokenizer, max_length, prompt_style)
        val_dataset = SarcasmGANDataset(val_path, gen_tokenizer, disc_tokenizer, max_length, "instruction")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Optimizers
    g_optimizer = torch.optim.AdamW(generator.parameters(), lr=g_lr, weight_decay=0.01)
    d_optimizer = torch.optim.AdamW(discriminator.parameters(), lr=d_lr, weight_decay=0.01)

    # For tracking metrics ('bleu' and 'val_bleu' both hold validation BLEU;
    # 'bleu' is kept because the history plot reads it).
    history = {
        'g_loss': [],
        'd_loss': [],
        'real_acc': [],
        'fake_acc': [],
        'bleu': [],
        'val_g_loss': [],
        'val_d_loss': [],
        'val_bleu': []
    }

    # BCE Loss for discriminator
    bce_loss = nn.BCELoss()

    # Number of optimizer updates per epoch, used for the epoch averages
    # (the sums below grow once per inner step, not once per batch).
    g_updates_per_epoch = max(1, len(train_loader) * g_steps)
    d_updates_per_epoch = max(1, len(train_loader) * d_steps)

    # Training loop
    for epoch in range(epochs):
        generator.train()
        discriminator.train()

        epoch_g_loss = 0.0
        epoch_d_loss = 0.0
        epoch_gan_loss = 0.0

        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

        for batch in progress_bar:
            # Get batch data
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            disc_src_ids = batch["disc_src_ids"].to(device)
            disc_src_mask = batch["disc_src_mask"].to(device)
            disc_tgt_ids = batch["disc_tgt_ids"].to(device)
            disc_tgt_mask = batch["disc_tgt_mask"].to(device)

            # Last per-step values for the progress bar; stay None when the
            # corresponding phase is disabled (g_steps == 0 / d_steps == 0).
            last_g_loss = None
            last_gan_loss = None
            last_d_loss = None

            # ------ Train Generator ------
            for _ in range(g_steps):
                g_optimizer.zero_grad()

                # Supervised loss
                g_outputs = generator(input_ids, attention_mask, labels)
                g_supervised_loss = g_outputs.loss

                gen_ids_disc, gen_mask_disc = _encode_fakes_for_discriminator(
                    generator, gen_tokenizer, disc_tokenizer,
                    input_ids, attention_mask, max_length, device
                )

                # Adversarial score (how well the fakes fool the discriminator).
                # NOTE(review): the fakes are produced under no_grad and pass
                # through decoded text, so this term has no gradient path to
                # the generator - only the supervised loss updates it. It is
                # therefore evaluated under no_grad (no unused graph is built)
                # and acts purely as a monitored metric. Making it trainable
                # needs a differentiable or policy-gradient formulation.
                with torch.no_grad():
                    fake_preds = discriminator(disc_src_ids, disc_src_mask, gen_ids_disc, gen_mask_disc)
                    g_gan_loss = bce_loss(fake_preds, torch.ones_like(fake_preds))

                # Combined loss
                g_loss = g_supervised_loss + gan_weight * g_gan_loss
                g_loss.backward()

                # Apply gradient clipping
                torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
                g_optimizer.step()

                epoch_g_loss += g_supervised_loss.item()
                epoch_gan_loss += g_gan_loss.item()
                last_g_loss = g_loss.item()
                last_gan_loss = g_gan_loss.item()

            # ------ Train Discriminator ------
            for _ in range(d_steps):
                d_optimizer.zero_grad()

                # Fresh fakes from the just-updated generator
                gen_ids_disc, gen_mask_disc = _encode_fakes_for_discriminator(
                    generator, gen_tokenizer, disc_tokenizer,
                    input_ids, attention_mask, max_length, device
                )

                # Real samples classification
                real_preds = discriminator(disc_src_ids, disc_src_mask, disc_tgt_ids, disc_tgt_mask)
                d_real_loss = bce_loss(real_preds, torch.ones_like(real_preds))

                # Fake samples classification
                fake_preds = discriminator(disc_src_ids, disc_src_mask, gen_ids_disc, gen_mask_disc)
                d_fake_loss = bce_loss(fake_preds, torch.zeros_like(fake_preds))

                # Combined loss
                d_loss = (d_real_loss + d_fake_loss) / 2
                d_loss.backward()

                # Apply gradient clipping
                torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
                d_optimizer.step()

                epoch_d_loss += d_loss.item()
                last_d_loss = d_loss.item()

            # Update progress bar with whatever was actually computed
            postfix = {}
            if last_g_loss is not None:
                postfix["G Loss"] = last_g_loss
            if last_d_loss is not None:
                postfix["D Loss"] = last_d_loss
            if last_gan_loss is not None:
                postfix["GAN Loss"] = last_gan_loss
            progress_bar.set_postfix(postfix)

        # End of epoch evaluation
        eval_metrics = evaluate_model(generator, discriminator, val_loader, device)

        # Calculate epoch averages
        avg_g_loss = epoch_g_loss / g_updates_per_epoch
        avg_d_loss = epoch_d_loss / d_updates_per_epoch
        avg_gan_loss = epoch_gan_loss / g_updates_per_epoch

        # Update history
        history['g_loss'].append(avg_g_loss)
        history['d_loss'].append(avg_d_loss)
        history['val_g_loss'].append(eval_metrics['g_loss'])
        history['val_d_loss'].append(eval_metrics['d_loss'])
        history['real_acc'].append(eval_metrics['real_acc'])
        history['fake_acc'].append(eval_metrics['fake_acc'])
        history['bleu'].append(eval_metrics['bleu'])
        history['val_bleu'].append(eval_metrics['bleu'])

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Train G Loss: {avg_g_loss:.4f}, D Loss: {avg_d_loss:.4f}, GAN Loss: {avg_gan_loss:.4f}")
        print(f"  Val G Loss: {eval_metrics['g_loss']:.4f}, D Loss: {eval_metrics['d_loss']:.4f}")
        print(f"  Val D Real Acc: {eval_metrics['real_acc']:.4f}, Fake Acc: {eval_metrics['fake_acc']:.4f}")
        print(f"  Val BLEU: {eval_metrics['bleu']:.4f}")

        # Save checkpoint
        checkpoint_dir = os.path.join(output_dir, f"checkpoint-epoch-{epoch+1}")
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Save generator
        generator.model.save_pretrained(os.path.join(checkpoint_dir, "generator"))
        gen_tokenizer.save_pretrained(os.path.join(checkpoint_dir, "generator"))

        # Save discriminator
        torch.save(discriminator.state_dict(), os.path.join(checkpoint_dir, "discriminator.pt"))

    # Save final models
    generator.model.save_pretrained(os.path.join(output_dir, "generator"))
    gen_tokenizer.save_pretrained(os.path.join(output_dir, "generator"))
    torch.save(discriminator.state_dict(), os.path.join(output_dir, "discriminator.pt"))

    # Save history; np.save pickles the dict, so reading it back requires
    # np.load(..., allow_pickle=True).item().
    np.save(os.path.join(output_dir, "training_history.npy"), history)

    # Plot training history
    _plot_training_history(history, output_dir)

    print(f"Training complete. Models saved to {output_dir}")
    return generator, discriminator, gen_tokenizer, history

# Similarity helpers
def jaccard_similarity(text1, text2):
    """Jaccard overlap of the lower-cased, whitespace-split word sets.

    Returns |A & B| / |A | B| as a float; 0.0 when both texts have no words
    (the denominator is clamped to at least 1).
    """
    vocab_a, vocab_b = (set(t.lower().split()) for t in (text1, text2))
    shared = vocab_a & vocab_b
    combined = vocab_a | vocab_b
    return len(shared) / max(1, len(combined))

def get_best_output(original_text, candidates):
    """Pick the candidate rewrite with the best length/similarity score.

    A candidate scores length_ratio * (1 - |jaccard - 0.3|): it should be
    about as long as the original and share roughly 30% of its words with it.
    Empty or whitespace-only candidates score 0. When there are no candidates
    or every score is 0, "I meant: <original_text>" is returned instead.
    """
    fallback = f"I meant: {original_text}"
    if not candidates:
        return fallback

    n_orig_words = len(original_text.split())

    def _score(candidate):
        # Blank candidates can never win.
        if not candidate or not candidate.strip():
            return 0

        n_cand_words = len(candidate.split())
        # Symmetric length ratio in (0, 1]; 1 means equal word counts.
        length_ratio = min(
            n_cand_words / max(1, n_orig_words),
            n_orig_words / max(1, n_cand_words),
        )
        # Penalize both near-copies and unrelated outputs.
        overlap = jaccard_similarity(original_text, candidate)
        return length_ratio * (1 - abs(overlap - 0.3))

    scores = [_score(candidate) for candidate in candidates]

    if not scores or max(scores) == 0:
        return fallback

    return candidates[np.argmax(scores)]

# Enhanced GAN-based inference function
def rewrite_sarcasm_gan(
    text, 
    gen_model_dir=None, 
    disc_model_dir=None,
    generator=None, 
    discriminator=None,
    gen_tokenizer=None, 
    disc_tokenizer=None,
    ensemble=True  # Whether to use prompt ensemble
):
    """Rewrite one sarcastic sentence into a non-sarcastic one.

    One candidate is generated per prompt template (four with `ensemble`,
    otherwise one). If a discriminator and its tokenizer are available, the
    candidate it scores highest is returned; otherwise `get_best_output`
    chooses heuristically. When no prompt yields text, a plain fallback prompt
    is tried and finally "I meant: <text>" is returned.

    Args:
        text: Sentence to rewrite.
        gen_model_dir: Local directory with the saved generator and tokenizer;
            used when `generator` or `gen_tokenizer` is not supplied.
        disc_model_dir: Directory containing "discriminator.pt"; used when
            `discriminator` or `disc_tokenizer` is not supplied.
        generator, discriminator, gen_tokenizer, disc_tokenizer: Pre-loaded
            objects, to avoid reloading on every call.
        ensemble: Use all four prompt templates instead of one.

    Returns:
        The selected rewrite as a string.

    Raises:
        Whatever loading the generator raises (after printing the error).
    """
    if generator is None or gen_tokenizer is None:
        try:
            gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_dir, local_files_only=True)
            # Build the wrapper directly from the saved directory instead of
            # first loading the default base checkpoint and discarding it.
            generator = Generator(gen_model_dir).to(device)
        except Exception as e:
            print(f"Error loading generator: {e}")
            raise

    if discriminator is None or disc_tokenizer is None:
        if disc_model_dir:  # Only load discriminator if path is provided
            try:
                disc_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
                discriminator = Discriminator("bert-base-uncased").to(device)
                # map_location lets weights saved on another device (e.g. a
                # GPU that is absent here) load onto the current one.
                # NOTE: torch.load unpickles the file - only load trusted checkpoints.
                state_dict = torch.load(
                    os.path.join(disc_model_dir, "discriminator.pt"),
                    map_location=device
                )
                discriminator.load_state_dict(state_dict)
            except Exception as e:
                print(f"Error loading discriminator: {e}")
                # Continue without discriminator if loading fails
                discriminator = None
                disc_tokenizer = None

    generator.eval()
    if discriminator is not None:
        discriminator.eval()

    # Discriminator ranking needs both the model and its tokenizer.
    use_disc = discriminator is not None and disc_tokenizer is not None

    # Use ensemble approach if enabled
    if ensemble:
        prompts = [
            f"Remove sarcasm from the following text: {text}",
            f"Rewrite without sarcasm: {text}",
            f"Transform this sarcastic statement into a non-sarcastic one: {text}",
            f"Make this statement neutral and straightforward: {text}"
        ]
    else:
        prompts = [f"Remove sarcasm from the following text: {text}"]

    # The source encoding is identical for every candidate, so build it once.
    src_disc = None
    if use_disc:
        try:
            src_disc = disc_tokenizer(
                text,
                return_tensors="pt",
                max_length=128,
                padding="max_length",
                truncation=True
            ).to(device)
        except Exception as e:
            print(f"Error encoding source for discriminator: {e}")
            use_disc = False

    # (output, discriminator score or None) pairs. Keeping them together
    # guarantees a score can never be attributed to the wrong candidate.
    candidates = []

    # Generate with multiple prompts
    for prompt in prompts:
        try:
            # Tokenize prompt
            src_encoded = gen_tokenizer(
                prompt, 
                return_tensors="pt", 
                max_length=128, 
                truncation=True
            ).to(device)

            # Generate output
            with torch.no_grad():
                gen_ids = generator.generate(
                    input_ids=src_encoded["input_ids"],
                    attention_mask=src_encoded["attention_mask"],
                    max_length=128,
                    num_beams=5,
                    min_length=4,
                    length_penalty=1.0,
                    no_repeat_ngram_size=3,
                    early_stopping=True,
                    do_sample=False
                )

            # Decode output
            output = gen_tokenizer.decode(gen_ids[0], skip_special_tokens=True)
        except Exception as e:
            print(f"Error generating with prompt '{prompt}': {e}")
            continue

        # Only keep non-empty outputs
        if not output or len(output.strip()) == 0:
            continue

        # Score with discriminator if available; a scoring failure keeps the
        # candidate (unscored) instead of desynchronizing outputs and scores.
        disc_score = None
        if use_disc:
            try:
                with torch.no_grad():
                    tgt_disc = disc_tokenizer(
                        output,
                        return_tensors="pt",
                        max_length=128,
                        padding="max_length",
                        truncation=True
                    ).to(device)

                    # Higher is better - judged more like a real rephrase
                    disc_score = discriminator(
                        src_disc["input_ids"], 
                        src_disc["attention_mask"],
                        tgt_disc["input_ids"],
                        tgt_disc["attention_mask"]
                    ).item()
            except Exception as e:
                print(f"Error scoring output with discriminator: {e}")
                disc_score = None

        candidates.append((output, disc_score))

    # If no outputs, try fallback
    if not candidates:
        try:
            # Fallback to direct generation with no frills
            input_text = f"Make this non-sarcastic: {text}"
            input_ids = gen_tokenizer(input_text, return_tensors="pt").input_ids.to(device)

            with torch.no_grad():
                output_ids = generator.generate(
                    input_ids,
                    max_length=128,
                    num_beams=4,
                    do_sample=False,
                    early_stopping=True
                )

            fallback_output = gen_tokenizer.decode(output_ids[0], skip_special_tokens=True)
            if fallback_output and len(fallback_output.strip()) > 0:
                return fallback_output
            else:
                return f"I meant: {text}"
        except Exception:
            # Best effort: never fail inference just because the fallback did.
            return f"I meant: {text}"

    # If we have discriminator scores, use them to pick the best
    # (max() keeps the first candidate on ties, like argmax did).
    scored = [(output, score) for output, score in candidates if score is not None]
    if scored:
        return max(scored, key=lambda pair: pair[1])[0]

    # Otherwise use our heuristic selection
    return get_best_output(text, [output for output, _ in candidates])

# Process multiple examples
def process_examples(texts, gen_model_dir=None, disc_model_dir=None):
    """Rewrite every text in `texts`, loading the models only once.

    Args:
        texts: Iterable of sarcastic sentences.
        gen_model_dir: Directory (or model id) of the saved generator and
            its tokenizer.
        disc_model_dir: Directory containing "discriminator.pt". When None,
            or when loading fails, rewriting proceeds without a discriminator.

    Returns:
        pandas DataFrame with columns 'Original' and 'Rewritten'.
    """
    # Load models once. The wrapper is built straight from gen_model_dir
    # rather than loading the default base checkpoint and discarding it.
    gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_dir)
    generator = Generator(gen_model_dir).to(device)

    # Try to load discriminator if available
    discriminator = None
    disc_tokenizer = None
    if disc_model_dir is not None:
        try:
            disc_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
            discriminator = Discriminator().to(device)
            # map_location lets weights saved on another device load here.
            # NOTE: torch.load unpickles the file - only load trusted checkpoints.
            state_dict = torch.load(
                os.path.join(disc_model_dir, "discriminator.pt"),
                map_location=device
            )
            discriminator.load_state_dict(state_dict)
        except Exception as e:
            # Still best-effort, but report why the discriminator is missing.
            print(f"Error loading discriminator: {e}")
            discriminator = None
            disc_tokenizer = None

    results = []
    for text in tqdm(texts, desc="Processing texts"):
        rewritten = rewrite_sarcasm_gan(
            text, 
            generator=generator, 
            discriminator=discriminator,
            gen_tokenizer=gen_tokenizer, 
            disc_tokenizer=disc_tokenizer
        )
        results.append((text, rewritten))

    return pd.DataFrame(results, columns=['Original', 'Rewritten'])

# Example usage - you can uncomment and run directly

# 1. Train the model
# csv_path = "sarcasm_data.csv"
# output_dir = "sarcasm_gan_model"
# generator, discriminator, tokenizer, history = train_gan(
#     csv_path=csv_path,
#     output_dir=output_dir,
#     epochs=15,
#     batch_size=16
# )

# 2. Inference with a single example


Using device: cuda:1


In [17]:
# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="Sure, let’s ignore all evidence and cling to your flawless reasoning",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: There is no need to ignore all evidence and cling to my flawless reasoning.


In [19]:
# Copyright 2025 Umang Patel
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     https://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="I don’t need your approval, darling, I have my own",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: I don’t need your approval, I have my own opinion. 


In [10]:
# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="Congratulations on arriving exactly three minutes late Your punctuality is truly inspiring",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: Having to arrive exactly three minutes late is disappointing but not disappointing.


In [11]:
# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="Nice of you to show up three minutes late your timing really is something else",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: It's not nice of you to show up three minutes late. 


In [12]:
# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="Congratulations on stating the obvious. I’m sure glaciers will start moving any minute now",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: There is no need to stating the obvious. 


In [13]:
# Demo: rewrite one sarcastic sentence with the saved generator; candidates
# are ranked by the discriminator stored in the same output directory.
rewritten = rewrite_sarcasm_gan(
    text="Absolutely, let’s add that to the dozen other things I definitely wasn’t planning to do.",
    disc_model_dir="sarcasm_gan_model",
    gen_model_dir="sarcasm_gan_model/generator",  # final generator, not best_generator
)
print(f"Rewritten: {rewritten}")

Rewritten: There are a lot of things I don't like doing. 
