In [30]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import gensim
from gensim.models import Word2Vec
import numpy as np
import nltk
import torch.nn as nn
import torch.nn.functional as F
import os
import re
from typing import List, Tuple, Dict
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import random
from collections import Counter

# Download required NLTK data
nltk.download('punkt')

# Hyperparameters
num_epochs = 20
learning_rate = 2e-4
target_confidence = 0.8

[nltk_data] Downloading package punkt to /home/qik/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [73]:
def preprocess_text(text: str) -> str:
    text = text.lower()
    text = ' '.join(text.split())
    text = re.sub(r'[^a-zA-Z0-9\s.,!?\'":]', '', text)
    return text

def train_word2vec(data_dir, embedding_dim=300):
    sentences = []
    for filename in ["sentiment.train.0", "sentiment.train.1"]:
        file_path = os.path.join(data_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                text = preprocess_text(line)
                sentences.append(text.strip().split())
    
    model = Word2Vec(sentences=sentences, vector_size=embedding_dim, window=5, min_count=1, workers=4)
    return model

def build_vocab(data_dir: str, vocab_size: int = 10000, min_freq: int = 2) -> Dict[str, int]:
    word_counter = Counter()
    special_tokens = {
        '<PAD>': 0,
        '<UNK>': 1,
        '<BOS>': 2,
        '<EOS>': 3,
    }
    
    for filename in ["sentiment.train.0", "sentiment.train.1"]:
        file_path = os.path.join(data_dir, filename)
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                text = preprocess_text(line)
                words = text.strip().split()
                word_counter.update(words)
    
    filtered_words = [(word, count) for word, count in word_counter.items() 
                     if count >= min_freq]
    sorted_words = sorted(filtered_words, key=lambda x: x[1], reverse=True)
    most_common_words = sorted_words[:vocab_size - len(special_tokens)]
    
    vocab = dict(special_tokens)
    for idx, (word, _) in enumerate(most_common_words, start=len(special_tokens)):
        vocab[word] = idx
    
    return vocab

def initialize_embeddings(vocab, word2vec_model):
    embedding_matrix = torch.zeros((len(vocab), word2vec_model.vector_size))
    
    for word, idx in vocab.items():
        try:
            if word in word2vec_model.wv:
                embedding_matrix[idx] = torch.FloatTensor(word2vec_model.wv[word])
            else:
                embedding_matrix[idx] = torch.randn(word2vec_model.vector_size) * 0.1
        except KeyError:
            embedding_matrix[idx] = torch.randn(word2vec_model.vector_size) * 0.1
    
    return embedding_matrix

class TextDatasetTest(Dataset):
    def __init__(self, data_dir: str, vocab: Dict[str, int], max_length: int = 100, is_train: bool = True):
        super().__init__()
        self.data = []
        self.vocab = vocab
        self.max_length = max_length
        
        if is_train:
            files = ["sentiment.train.0", "sentiment.train.1"]
        else:
            files = ["sentiment.test.0", "sentiment.test.1"]
        for filename in files:
            file_path = os.path.join(data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                for line in lines:
                    text = preprocess_text(line)
                    tokens = text.strip().split()
                    if len(tokens) > max_length:
                        tokens = tokens[:max_length]
                    label = 1 if filename.endswith('.1') else 0
                    self.data.append((tokens, label))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        token_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)

def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_length = max(len(seq) for seq in inputs)
    
    padded_inputs = [torch.cat([seq, torch.zeros(max_length - len(seq), dtype=torch.long)]) for seq in inputs]
    lengths = [len(seq) for seq in inputs]
    
    return torch.stack(padded_inputs), torch.tensor(labels, dtype=torch.float), lengths

# Initialize data
data_dir = "./data/sentiment_style_transfer/yelp"
word2vec_model = train_word2vec(data_dir)
vocab = build_vocab(data_dir)
embedding_matrix = initialize_embeddings(vocab, word2vec_model)

# Create datasets and dataloaders
dataset = TextDatasetTest(data_dir, vocab, is_train=True)
data_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True, num_workers=0)
val_dataset = TextDatasetTest(data_dir, vocab, is_train=False)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn, shuffle=False, num_workers=0)

In [74]:
len(dataset)

443259

In [75]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class StyleClassifier(nn.Module):
    def __init__(self, style_dim, num_classes):
        """
        Style Classifier to predict style labels from the style embedding.

        Args:
            style_dim: Dimensionality of the style embedding (input size).
            num_classes: Number of style classes (output size).
        """
        super(StyleClassifier, self).__init__()
        self.fc1 = nn.Linear(style_dim, 64)  # Hidden layer with 64 units
        self.fc2 = nn.Linear(64, num_classes)  # Output layer for style classification

    def forward(self, style_embedding):
        """
        Forward pass of the Style Classifier.

        Args:
            style_embedding: Tensor of shape [batch_size, style_dim].

        Returns:
            Tensor of shape [batch_size, num_classes] with logits for each class.
        """
        x = F.relu(self.fc1(style_embedding))  # Hidden layer with ReLU activation
        logits = self.fc2(x)  # Output layer (logits for style classes)
        return logits  # Return logits directly

class ContentClassifier(nn.Module):
    def __init__(self, content_dim, vocab_size):
        """
        Content Classifier to predict BoW features from the content embedding.

        Args:
            content_dim: Dimensionality of the content embedding (input size).
            vocab_size: Size of the vocabulary (output size).
        """
        super(ContentClassifier, self).__init__()
        self.fc = nn.Linear(content_dim, vocab_size)  # Single-layer linear classifier

    def forward(self, content_embedding):
        """
        Forward pass of the Content Classifier.

        Args:
            content_embedding: Tensor of shape [batch_size, content_dim].

        Returns:
            Tensor of shape [batch_size, vocab_size] with probabilities for BoW features.
        """
        return F.softmax(self.fc(content_embedding), dim=1)  # Softmax to get probabilities

def J_mul_style(style_preds, style_labels):
    """
    Multi-task style loss: Ensures style embedding captures style-specific information.

    Args:
        style_preds: Predicted logits for the positive class (Tensor of shape [batch_size, 1]).
        style_labels: Ground truth probabilities for the positive class (Tensor of shape [batch_size, 1]).

    Returns:
        Multi-task style loss (scalar Tensor).
    """
    # Apply binary cross-entropy with logits
    loss = F.binary_cross_entropy_with_logits(style_preds.squeeze(), style_labels.float())
    return loss


def J_adv_style(adv_style_preds, style_labels):
    """
    Adversarial style loss: Ensures content embedding does not contain style-specific information.

    Args:
        adv_style_preds: Predicted logits for the positive class (Tensor of shape [batch_size, 1]).
        style_labels: Ground truth probabilities for the positive class (Tensor of shape [batch_size, 1]).

    Returns:
        Adversarial style loss (scalar Tensor).
    """
    # Negative Binary Cross-Entropy with logits to encourage fooling the classifier
    loss = -F.binary_cross_entropy_with_logits(adv_style_preds.squeeze(), style_labels.float())
    return loss

def J_mul_content(content_preds, target_bow):
    """
    Multi-task content loss: Ensures content embedding captures noun-based content information.

    Args:
        content_preds: Predicted probabilities for BoW nouns (Tensor of shape [batch_size, vocab_size]).
        target_bow: Ground truth noun-based BoW vectors (Tensor of shape [batch_size, vocab_size]).

    Returns:
        Multi-task content loss (scalar Tensor).
    """
    # Add a small epsilon to avoid log(0) errors
    epsilon = 1e-8
    loss = -torch.sum(target_bow * torch.log(content_preds + epsilon), dim=1).mean()
    return loss

def J_adv_content(adv_content_preds, target_bow):
    """
    Adversarial content loss: Ensures style embedding does not contain noun-based content information.

    Args:
        adv_content_preds: Predicted probabilities for BoW nouns from the style embedding 
                           (Tensor of shape [batch_size, vocab_size]).
        target_bow: Ground truth noun-based BoW vectors (Tensor of shape [batch_size, vocab_size]).

    Returns:
        Adversarial content loss (scalar Tensor).
    """
    # Add a small epsilon to avoid log(0) errors
    epsilon = 1e-8
    loss = -torch.sum(target_bow * torch.log(adv_content_preds + epsilon), dim=1).mean()
    return loss

class DisentangledVAE(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, style_dim, content_dim, embedding_matrix=None):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim
        
        # Embedding
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        if embedding_matrix is not None:
            self.embedding.weight.data.copy_(embedding_matrix)
            self.embedding.weight.requires_grad = True
        
        self.padding_idx = 0
        
        # Encoder
        self.encoder_rnn = nn.GRU(
            embedding_dim, 
            hidden_dim,
            batch_first=True,
            bidirectional=True,
            num_layers=2,
            dropout=0.3
        )
        
        # Multi-task components
        self.style_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        self.content_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3)
        )
        
        # Latent projections
        self.style_mu = nn.Linear(hidden_dim, style_dim)
        self.style_logvar = nn.Linear(hidden_dim, style_dim)
        self.content_mu = nn.Linear(hidden_dim, content_dim)
        self.content_logvar = nn.Linear(hidden_dim, content_dim)
        
        # Decoder components
        self.latent_to_hidden = nn.Linear(style_dim + content_dim, hidden_dim)
        self.decoder_rnn = nn.GRU(
            embedding_dim,
            hidden_dim,
            batch_first=True,
            num_layers=2,
            dropout=0.3
        )
        self.output_fc = nn.Linear(hidden_dim, vocab_size)
        
    def encode(self, x, lengths=None):
        embedded = self.embedding(x)
        output, hidden = self.encoder_rnn(embedded)
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        style_hidden = self.style_encoder(hidden)
        content_hidden = self.content_encoder(hidden)
        
        style_mu = self.style_mu(style_hidden)
        style_logvar = self.style_logvar(style_hidden)
        content_mu = self.content_mu(content_hidden)
        content_logvar = self.content_logvar(content_hidden)
        
        return style_mu, style_logvar, content_mu, content_logvar
    
    def reparameterize(self, mu, logvar):
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        return mu
    
    def decode(self, style, content, x):
        batch_size = x.size(0)
        max_len = x.size(1)
        
        # Initialize hidden state
        hidden = self.latent_to_hidden(torch.cat([style, content], dim=1))
        hidden = hidden.unsqueeze(0).repeat(2, 1, 1)
        
        # Teacher forcing
        embedded = self.embedding(x)
        output, _ = self.decoder_rnn(embedded, hidden)
        output = self.output_fc(output)
        
        return output
    
    def forward(self, x, lengths=None):
        style_mu, style_logvar, content_mu, content_logvar = self.encode(x, lengths)
        style = self.reparameterize(style_mu, style_logvar)
        content = self.reparameterize(content_mu, content_logvar)
        recon_x = self.decode(style, content, x)
        return recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content

In [76]:
# Step 1: Setup device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Step 2: Define the VAE model
vocab_size = len(vocab)  # Replace with the actual vocab size
embedding_dim = 300  # Embedding dimension

style_dim = 128
content_dim = 128
hidden_dim = 256
num_style_classes = 1
vae = DisentangledVAE(vocab_size, embedding_dim=300, hidden_dim=256, style_dim=style_dim, content_dim=content_dim).to(device)  # Initialize the VAE model

# Style Classifier
style_classifier = StyleClassifier(style_dim=style_dim, num_classes=num_style_classes).to(device)

# Content Classifier
content_classifier = ContentClassifier(content_dim=content_dim, vocab_size=vocab_size).to(device)

# Define separate optimizers for the classifiers
style_optimizer = torch.optim.Adam(style_classifier.parameters(), lr=1e-3)
content_optimizer = torch.optim.Adam(content_classifier.parameters(), lr=1e-3)

# Step 3: Initialize the embedding layer with pre-trained embeddings if provided
if embedding_matrix is not None:
    # Copy the pre-trained embedding weights into the model
    vae.embedding.weight.data.copy_(torch.tensor(embedding_matrix))
else:
    # If no embedding matrix is provided, initialize with random embeddings
    vae.embedding.weight.data.uniform_(-0.1, 0.1)

# Step 4: Define the optimizer
learning_rate = 1e-3  # Set the learning rate
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Step 5: Define the learning rate scheduler
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 
    mode='min', 
    factor=0.5, 
    patience=2, 
    verbose=True
)

# Print confirmation
print("Model, optimizer, and scheduler are set up successfully.")


Model, optimizer, and scheduler are set up successfully.


  vae.embedding.weight.data.copy_(torch.tensor(embedding_matrix))


In [77]:
def noun_bow(sentence, vocab):
    """
    Create a Bag-of-Words (BoW) vector based on nouns in the sentence.

    Args:
        sentence: A string (input sentence).
        vocab: Vocabulary (dictionary mapping nouns to indices).

    Returns:
        BoW vector (1D tensor of size vocab_size).
    """
    # Tokenize and POS tag the sentence
    tokens = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(tokens)
    
    # Extract nouns based on POS tags
    nouns = [word for word, tag in pos_tags if tag.startswith('NN')]  # NN, NNS, NNP, NNPS
    
    # Count noun occurrences
    noun_counts = Counter(nouns)
    
    # Create the BoW vector
    bow_vector = torch.zeros(len(vocab), dtype=torch.float32)
    for noun, count in noun_counts.items():
        if noun in vocab:
            bow_vector[vocab[noun]] += count
    return bow_vector


In [None]:
print(f"Starting training on {device}...")
# Training hyperparameters
best_loss = float('inf')
patience = 5
patience_counter = 0

# Begin training
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")
    
    vae.train()
    total_loss = 0
    total_recon_loss = 0
    total_kl_loss = 0
    total_style_loss = 0
    total_content_loss = 0
    total_adv_loss = 0

    # Iterate over the dataset
    for batch_idx, (input_tokens, style_labels, lengths) in enumerate(data_loader):
        input_tokens = input_tokens.to(device)
        style_labels = style_labels.to(device)
        id_to_token = {idx: token for token, idx in vocab.items()}
        bow_labels = torch.stack([
            noun_bow(" ".join([id_to_token[idx.item()] for idx in input_seq if idx.item() in id_to_token]), vocab)
            for input_seq in input_tokens
        ]).to(device)

        # Forward pass
        recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content = vae(input_tokens, lengths)

        # Reconstruction loss
        recon_loss = F.cross_entropy(
            recon_x.view(-1, recon_x.size(-1)), 
            input_tokens.view(-1), 
            ignore_index=0
        )

        # KL divergence with annealing
        kl_weight = min(0.1, epoch / 20.0)  # Reduced maximum weight
        kl_style = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp(), dim=1).mean()
        kl_content = -0.5 * torch.sum(1 + content_logvar - content_mu.pow(2) - content_logvar.exp(), dim=1).mean()
        kl_loss = (kl_style + kl_content) * kl_weight

        # Multi-task style and content losses
        style_preds = style_classifier(style)  # Style classifier predictions from style embedding
        content_preds = content_classifier(content)  # Content classifier predictions from content embedding
        #print(style_preds)
        #print(style_labels)
        style_loss = J_mul_style(style_preds, style_labels)
        content_loss = J_mul_content(content_preds, bow_labels)

        # Adversarial style and content losses
        adv_style_preds = style_classifier(content)  # Style classifier applied to content embedding
        adv_content_preds = content_classifier(style)  # Content classifier applied to style embedding
        # the adversarial are mean to be optimized, so *(-1) for them
        adv_style_loss = -J_adv_style(adv_style_preds, style_labels)
        adv_content_loss = -J_adv_content(adv_content_preds, bow_labels)
        adv_loss = (adv_style_loss + adv_content_loss) * 0.01  # Reduced adversarial weight

        # Scale losses
        style_loss *= 0.1
        content_loss *= 0.1

        # Total loss
        loss = recon_loss + kl_loss + style_loss + content_loss + adv_loss

        # Optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=0.5)  # Gradient clipping
        optimizer.step()

        # Update running totals
        total_loss += loss.item()
        total_recon_loss += recon_loss.item()
        total_kl_loss += kl_loss.item()
        total_style_loss += style_loss.item()
        total_content_loss += content_loss.item()
        total_adv_loss += adv_loss.item()

        # Print batch metrics every 100 batches
        if batch_idx % 100 == 0:
            print(f"  Batch {batch_idx}/{len(data_loader)}: "
                  f"Loss={loss.item():.4f}, "
                  f"Recon={recon_loss.item():.4f}, "
                  f"KL={kl_loss.item():.4f}, "
                  f"Style={style_loss.item():.4f}, "
                  f"Content={content_loss.item():.4f}, "
                  f"Adv={adv_loss.item():.4f}")

    # Compute average losses for the epoch
    avg_total_loss = total_loss / len(data_loader)
    avg_recon_loss = total_recon_loss / len(data_loader)
    avg_kl_loss = total_kl_loss / len(data_loader)
    avg_style_loss = total_style_loss / len(data_loader)
    avg_content_loss = total_content_loss / len(data_loader)
    avg_adv_loss = total_adv_loss / len(data_loader)

    # Print detailed metrics for the epoch
    print(f"Epoch {epoch + 1} Summary:")
    print(f"  Total Loss: {avg_total_loss:.4f}")
    print(f"  Recon Loss: {avg_recon_loss:.4f}")
    print(f"  KL Loss: {avg_kl_loss:.4f}")
    print(f"  Style Loss: {avg_style_loss:.4f}")
    print(f"  Content Loss: {avg_content_loss:.4f}")
    print(f"  Adversarial Loss: {avg_adv_loss:.4f}")

    # Early stopping based on validation loss
    if avg_total_loss < best_loss:
        best_loss = avg_total_loss
        patience_counter = 0
        torch.save({
            'epoch': epoch,
            'model_state_dict': vae.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, 'best_model.pt')
        print(f"  Model saved with loss: {best_loss:.4f}")
    else:
        patience_counter += 1
        print(f"  No improvement. Patience counter: {patience_counter}/{patience}")

    # Early stopping condition
    if patience_counter >= patience:
        print("Early stopping triggered.")
        break

    # Step the learning rate scheduler
    scheduler.step(avg_total_loss)


Starting training on cuda...
Epoch 1/20
  Batch 0/6926: Loss=10.7467, Recon=9.1808, KL=0.0000, Style=0.0692, Content=1.6551, Adv=-0.1585
  Batch 100/6926: Loss=3.6707, Recon=2.4844, KL=0.0000, Style=0.0301, Content=1.4052, Adv=-0.2490
  Batch 200/6926: Loss=2.2047, Recon=1.2099, KL=0.0000, Style=0.0106, Content=1.3177, Adv=-0.3334
  Batch 300/6926: Loss=1.5922, Recon=0.7627, KL=0.0000, Style=0.0273, Content=1.1176, Adv=-0.3154
  Batch 400/6926: Loss=1.3667, Recon=0.7158, KL=0.0000, Style=0.0124, Content=0.9454, Adv=-0.3068
  Batch 500/6926: Loss=1.3535, Recon=0.5700, KL=0.0000, Style=0.0130, Content=1.0892, Adv=-0.3187
  Batch 600/6926: Loss=1.0997, Recon=0.3796, KL=0.0000, Style=0.0161, Content=1.0892, Adv=-0.3851
  Batch 700/6926: Loss=0.9356, Recon=0.3210, KL=0.0000, Style=0.0086, Content=0.8962, Adv=-0.2902
  Batch 800/6926: Loss=1.0083, Recon=0.3419, KL=0.0000, Style=0.0078, Content=1.0257, Adv=-0.3671
  Batch 900/6926: Loss=0.7342, Recon=0.2468, KL=0.0000, Style=0.0119, Content=0

In [71]:
import nltk
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
nltk.download('punkt')

def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id not in [0, 2, 3]]

def calculate_bleu_score(data_loader, model, vocab, device):
    model.eval()
    total_bleu_score = 0
    num_sentences = 0
    smoothing_fn = SmoothingFunction().method1
    
    print("\nBLEU-S: Evaluating content preservation...\n")
    with torch.no_grad():
        for input_tokens, _, lengths in data_loader:
            input_tokens = input_tokens.to(device)
            recon_x, _, _, _, _, _, _ = model(input_tokens, lengths)
            recon_x = recon_x.argmax(dim=-1)
            
            for i in range(min(5, len(input_tokens))):
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(recon_x[i].tolist(), vocab)
                
                if len(original_sentence) == 0 or len(reconstructed_sentence) == 0:
                    continue
                
                print(f"Original: {' '.join(original_sentence)}")
                print(f"Reconstructed: {' '.join(reconstructed_sentence)}\n")
                
                weights = [0.25, 0.25, 0.25, 0.25]
                bleu_score = sentence_bleu(
                    [original_sentence], 
                    reconstructed_sentence, 
                    weights=weights,
                    smoothing_function=smoothing_fn
                )
                total_bleu_score += bleu_score
                num_sentences += 1
            
            if num_sentences >= 20:
                break
    
    avg_bleu_score = total_bleu_score / num_sentences if num_sentences > 0 else 0
    print(f"Average BLEU-S Score: {avg_bleu_score:.4f}")
    return avg_bleu_score

[nltk_data] Downloading package punkt to /home/qik/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [60]:
def train_style_classifier(data_loader, vocab_size, device):
    classifier = StyleClassifier(vocab_size, 512, 256).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.AdamW(
        classifier.parameters(),
        lr=0.0003,
        weight_decay=0.1
    )
    
    classifier.train()
    print("\nTraining Style Classifier...\n")
    
    for epoch in range(10):
        total_loss = 0
        correct = 0
        total = 0
        
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device).float()
            
            optimizer.zero_grad()
            predictions = classifier(input_tokens).squeeze()
            
            smoothed_labels = labels * 0.9 + 0.05
            loss = criterion(predictions, smoothed_labels)
            
            loss.backward()
            torch.nn.utils.clip_grad_norm_(classifier.parameters(), max_norm=0.5)
            optimizer.step()
            
            total_loss += loss.item()
            preds = (predictions > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
        
        avg_loss = total_loss / len(data_loader)
        accuracy = correct / total
        print(f"Epoch {epoch + 1}/10, Loss: {avg_loss:.4f}, Accuracy: {accuracy:.4f}")
    
    return classifier

def evaluate_style_transfer(data_loader, model, classifier, vocab, device):
    model.eval()
    classifier.eval()
    correct_predictions = 0
    total_predictions = 0
    
    print("\nEvaluating Style Transfer Accuracy...\n")
    with torch.no_grad():
        for input_tokens, labels, lengths in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)
            
            # Get reconstructed text
            recon_x, _, _, _, _, style, _ = model(input_tokens, lengths)
            recon_x = recon_x.argmax(dim=-1)
            
            # Use the separate classifier for style predictions
            style_predictions = classifier(recon_x).squeeze()
            style_labels = (style_predictions > 0.5).float()
            
            correct_predictions += (style_labels == labels).sum().item()
            total_predictions += labels.size(0)
            
            # Print examples
            for i in range(min(5, len(input_tokens))):
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(recon_x[i].tolist(), vocab)
                
                if len(original_sentence) > 0 and len(reconstructed_sentence) > 0:
                    print(f"Original: {' '.join(original_sentence)}")
                    print(f"Reconstructed: {' '.join(reconstructed_sentence)}")
                    print(f"Style Prediction: {style_predictions[i].item():.4f}, True Style: {labels[i].item()}\n")
            
            if total_predictions >= 100:
                break
    
    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"Style Transfer Accuracy: {accuracy:.4f}")
    return accuracy

# Run evaluation
print("\n=== Starting Evaluation ===\n")
    
print("Training style classifier...")
classifier = train_style_classifier(data_loader, len(vocab), device)

print("\n=== Content Preservation (BLEU-S Score) ===")
bleu_score = calculate_bleu_score(data_loader_test, vae, vocab, device)

print("\n=== Style Transfer Accuracy ===")
style_transfer_accuracy = evaluate_style_transfer(data_loader_test, vae, classifier, vocab, device)

print("\n=== Final Results ===")
print(f"BLEU-S Score: {bleu_score:.4f}")
print(f"Style Transfer Accuracy: {style_transfer_accuracy:.4f}")

def run_evaluation(data_loader_train, data_loader_test, model, vocab, vocab_size, device):
    print("\n=== Starting Evaluation ===\n")
    
    print("Training style classifier...")
    classifier = train_style_classifier(data_loader_train, vocab_size, device)
    
    print("\n=== Content Preservation (BLEU-S Score) ===")
    bleu_score = calculate_bleu_score(data_loader_test, model, vocab, device)
    
    print("\n=== Style Transfer Accuracy ===")
    style_transfer_accuracy = evaluate_style_transfer(data_loader_test, model, classifier, vocab, device)
    
    print("\n=== Final Results ===")
    print(f"BLEU-S Score: {bleu_score:.4f}")
    print(f"Style Transfer Accuracy: {style_transfer_accuracy:.4f}")
    
    return bleu_score, style_transfer_accuracy

# Run the evaluation
print("Starting evaluation pipeline...")
bleu_score, style_accuracy = run_evaluation(
    data_loader, 
    data_loader_test, 
    vae, 
    vocab, 
    len(vocab), 
    device
)


=== Starting Evaluation ===

Training style classifier...


TypeError: __init__() takes 3 positional arguments but 4 were given