In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm  
from dataloader import *
from model import * 
from nltk.translate.bleu_score import sentence_bleu
import nltk 
import torch.nn as nn
import torch.nn.functional as F


# Hyperparameters
num_epochs = 10
learning_rate = 0.001
target_confidence = 0.8 

In [2]:
def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_length = max(len(seq) for seq in inputs)
    
    # Convert each sequence to a list, pad with 0, and convert to tensor
    padded_inputs = [torch.cat([seq, torch.zeros(max_length - len(seq), dtype=torch.long)]) for seq in inputs]
    lengths = [len(seq) for seq in inputs]
    
    return torch.stack(padded_inputs), torch.tensor(labels, dtype=torch.float), lengths

def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]  # Exclude padding


class TextDatasetTest(Dataset):
    def __init__(self, data_dir, vocab):
        super(TextDatasetTest, self).__init__()
        self.data = []
        self.vocab = vocab

        # Load data from the files
        files = ["sentiment.test.0", "sentiment.test.1"]
        for filename in files:
            file_path = os.path.join(data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                for line in lines:
                    tokens = line.strip().split()
                    label = 1 if filename.endswith('.1') else 0  # Binary label
                    self.data.append((tokens, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        token_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)
    
data_dir = "./data/sentiment_style_transfer/yelp"
vocab = build_vocab(data_dir)
dataset = TextDatasetTest(data_dir, vocab)
data_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [6]:
class DisentangledVAE(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, style_dim, content_dim):
        super(DisentangledVAE, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_dim = hidden_dim
        
        # Encoder
        self.encoder_rnn = nn.GRU(embedding_dim, hidden_dim, 
                                 batch_first=True, 
                                 bidirectional=True,
                                 num_layers=2,
                                 dropout=0.2)
        
        # Latent spaces
        self.style_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.content_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Mean and logvar projections
        self.style_mu = nn.Linear(hidden_dim, style_dim)
        self.style_logvar = nn.Linear(hidden_dim, style_dim)
        self.content_mu = nn.Linear(hidden_dim, content_dim)
        self.content_logvar = nn.Linear(hidden_dim, content_dim)
        
        # Decoder
        self.latent_to_hidden = nn.Sequential(
            nn.Linear(style_dim + content_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        self.decoder_rnn = nn.GRU(embedding_dim + hidden_dim, hidden_dim,
                                 batch_first=True,
                                 num_layers=2,
                                 dropout=0.2)
        
        self.output_fc = nn.Linear(hidden_dim, vocab_size)
        
        # Style classifier for adversarial training
        self.style_classifier = nn.Sequential(
            nn.Linear(style_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def encode(self, x, lengths=None):
        batch_size = x.size(0)
        
        # Embed input
        embedded = self.embedding(x)
        
        # Pack for variable length sequences
        if lengths is not None:
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        
        # Encode
        _, hidden = self.encoder_rnn(embedded)
        # Combine bidirectional states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        # Encode style and content
        style_hidden = self.style_encoder(hidden)
        content_hidden = self.content_encoder(hidden)
        
        # Get latent parameters
        style_mu = self.style_mu(style_hidden)
        style_logvar = self.style_logvar(style_hidden)
        content_mu = self.content_mu(content_hidden)
        content_logvar = self.content_logvar(content_hidden)
        
        return style_mu, style_logvar, content_mu, content_logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
        
    def decode(self, style, content, x):
        batch_size = x.size(0)
        max_len = x.size(1)
        
        # Combine latent vectors
        latent = torch.cat([style, content], dim=1)
        hidden = self.latent_to_hidden(latent)
        
        # Initialize decoder hidden state
        hidden = hidden.unsqueeze(0).repeat(2, 1, 1)  # num_layers * batch * hidden
        
        # Teacher forcing with concatenated latent
        embedded = self.embedding(x)
        hidden_expanded = hidden[-1].unsqueeze(1).repeat(1, max_len, 1)
        decoder_input = torch.cat([embedded, hidden_expanded], dim=2)
        
        # Decode
        outputs, _ = self.decoder_rnn(decoder_input, hidden)
        outputs = self.output_fc(outputs)
        
        return outputs

    def forward(self, x, lengths=None):
        # Encode
        style_mu, style_logvar, content_mu, content_logvar = self.encode(x, lengths)
        
        # Sample latent vectors
        style = self.reparameterize(style_mu, style_logvar)
        content = self.reparameterize(content_mu, content_logvar)
        
        # Decode
        recon_x = self.decode(style, content, x)
        
        return recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content

    def classify_style(self, style):
        return self.style_classifier(style)

def vae_loss(recon_x, x, style_mu, style_logvar, content_mu, content_logvar):
    recon_loss = F.cross_entropy(recon_x.view(-1, recon_x.size(-1)), x.view(-1), ignore_index=0)  # Reconstruction loss
    kl_style = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp())  # KL divergence for style
    kl_content = -0.5 * torch.sum(1 + content_logvar - content_mu.pow(2) - content_logvar.exp())  # KL divergence for content
    return recon_loss + kl_style + kl_content

def multi_task_loss(style_preds, style_labels, content_preds, content_labels):
    style_loss = F.cross_entropy(style_preds, style_labels)  # Style classification loss
    content_loss = F.cross_entropy(content_preds, content_labels)  # Content classification loss
    return style_loss + content_loss

def adversarial_loss(style_preds, content_preds):
    adversarial_style_loss = -F.cross_entropy(style_preds, torch.zeros_like(style_preds))  # Fool style classifier
    adversarial_content_loss = -F.cross_entropy(content_preds, torch.zeros_like(content_preds))  # Fool content classifier
    return adversarial_style_loss + adversarial_content_loss


In [12]:
# Modified training loop
def train_vae_with_dataset(vae, optimizer, data_loader, device):
    """
    Train the VAE with a custom dataset.
    """
    vae.train()
    total_loss = 0

    for input_tokens, style_labels, lengths in data_loader:
        # Move data to the appropriate device
        input_tokens = input_tokens.to(device)
        style_labels = style_labels.to(device)
        
        # Forward pass through the model
        recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content = vae(input_tokens, lengths)
        
        # Calculate reconstruction loss
        loss_vae = F.cross_entropy(recon_x.view(-1, recon_x.size(-1)), input_tokens.view(-1), ignore_index=0)
        
        # KL divergence losses
        kl_style = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp())
        kl_content = -0.5 * torch.sum(1 + content_logvar - content_mu.pow(2) - content_logvar.exp())
        
        # Style classification loss
        style_preds = vae.classify_style(style).squeeze()  # Make sure predictions are the right shape
        style_labels = style_labels.float()  # Convert to float
        loss_multi_task = F.binary_cross_entropy(style_preds, style_labels)
        
        # Total loss
        loss = loss_vae + 0.1 * (kl_style + kl_content) + loss_multi_task
        
        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)  # Add gradient clipping
        optimizer.step()
        
        total_loss += loss.item()

    return total_loss / len(data_loader)
# Training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(vocab)
embedding_dim = 256
hidden_dim = 512
style_dim = 32
content_dim = 256
learning_rate = 5e-4
epochs = 10

# Initialize model and optimizer
vae = DisentangledVAE(vocab_size, embedding_dim, hidden_dim, style_dim, content_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    loss = train_vae_with_dataset(vae, optimizer, data_loader, device)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}")



Epoch 1/10, Loss: 8.8364
Epoch 2/10, Loss: 6.3012
Epoch 3/10, Loss: 5.4021
Epoch 4/10, Loss: 4.5849
Epoch 5/10, Loss: 3.7619
Epoch 6/10, Loss: 3.0867
Epoch 7/10, Loss: 2.5593
Epoch 8/10, Loss: 2.1242
Epoch 9/10, Loss: 1.7939
Epoch 10/10, Loss: 1.5404


In [13]:
class StyleClassifier(nn.Module):
    def _init_(self, vocab_size, embedding_dim, hidden_dim):
        super(StyleClassifier, self)._init_()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.embedding(x)
        _, hidden = self.encoder(x)
        # Combine bidirectional states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)

#### UPDATED PART ABOVE, NOT UPDATED PART BELOW

In [11]:
torch.save(model, 'model_complete.pth')

In [8]:
model_1 = torch.load('model_complete.pth') 

  model_1 = torch.load('model_complete.pth')


In [12]:


# Inspect some sentences from the data loader
model.eval()  # Set the model to evaluation mode
model_1.eval()
with torch.no_grad():
    for input_tokens, _, lengths in data_loader_test:
        input_tokens = input_tokens.to(device)
        x_reconstructed, _, _, _ = model_1(input_tokens)
        x_reconstructed = x_reconstructed.argmax(dim=-1)  # Get the predicted token IDs

        # Print a few input and output sentences
        for i in range(5):  # Print 5 examples
            original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
            reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)

            print("Original Sentence: \t\t\t", " ".join(original_sentence))
            print("Reconstructed Sentence: \t\t", " ".join(reconstructed_sentence))
            print()

        break  # Only inspect the first batch

Original Sentence: 			 they only received one star because you have to provide a rating .
Reconstructed Sentence: 		 they only received one star because you have to provide a rating .

Original Sentence: 			 always takes way too long even if you 're the only one there .
Reconstructed Sentence: 		 always takes way too long even if you 're the only one there .

Original Sentence: 			 she could not and would not explain herself .
Reconstructed Sentence: 		 she could not and would not explain herself .

Original Sentence: 			 all she did was give me the run around and lied and bs everything .
Reconstructed Sentence: 		 all she did was give me the run around and lied and bs everything .

Original Sentence: 			 it does not take that long to cook sliders !
Reconstructed Sentence: 		 it does not take that long to cook sliders !



In [16]:
def train_style_classifier(data_loader, vocab_size, device):
    classifier = StyleClassifier(vocab_size, 300, 128).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    classifier.train()
    for epoch in range(20):  # Train for a few epochs
        total_loss = 0
        for input_tokens, labels, _ in data_loader:  # Adjusted to unpack three values
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predictions = classifier(input_tokens)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/5, Loss: {total_loss / len(data_loader)}")
    
    return classifier

def evaluate_style_transfer(data_loader, model, classifier, device):
    model.eval()
    classifier.eval()
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            # Get the reconstructed sentences
            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            # Predict the style of the reconstructed sentences
            style_predictions = classifier(x_reconstructed)
            style_labels = (style_predictions > 0.5).float()
            
            correct_predictions += (style_labels == labels).sum().item()
            total_predictions += labels.size(0)
    
    accuracy = correct_predictions / total_predictions
    print(f"Style Transfer Accuracy: {accuracy:.4f}")

In [17]:
classifier = train_style_classifier(data_loader, len(vocab), device)
evaluate_style_transfer(data_loader_test, model_1, classifier, device)

KeyboardInterrupt: 

In [18]:
#### updated code 
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
nltk.download('punkt')

def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]

def calculate_bleu_score(data_loader, model, vocab, device):
    model.eval()
    total_bleu_score = 0
    num_sentences = 0
    smoothing_fn = SmoothingFunction().method1

    print("\nBLEU-S: Evaluating content preservation...\n")
    with torch.no_grad():
        for input_tokens, _, lengths in data_loader:
            input_tokens = input_tokens.to(device)
            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            for i in range(min(5, len(input_tokens))):  
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)
                print(f"Original: {' '.join(original_sentence)}")
                print(f"Reconstructed: {' '.join(reconstructed_sentence)}\n")

                bleu_score = sentence_bleu([original_sentence], reconstructed_sentence, smoothing_function=smoothing_fn)
                total_bleu_score += bleu_score
                num_sentences += 1

            break  # Evaluate only on the first batch for now

    avg_bleu_score = total_bleu_score / num_sentences if num_sentences > 0 else 0
    print(f"Average BLEU-S Score: {avg_bleu_score:.4f}")
    return avg_bleu_score

def train_style_classifier(data_loader, vocab_size, device):
    classifier = StyleClassifier(vocab_size, 300, 128).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    classifier.train()
    print("\nTraining Style Classifier...\n")
    for epoch in range(20):
        total_loss = 0
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predictions = classifier(input_tokens)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/5, Loss: {total_loss / len(data_loader):.4f}")
    
    return classifier

def evaluate_style_transfer(data_loader, model, classifier, vocab, device):
    model.eval()
    classifier.eval()
    correct_predictions = 0
    total_predictions = 0

    print("\nEvaluating Style Transfer Accuracy...\n")
    with torch.no_grad():
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            style_predictions = classifier(x_reconstructed)
            style_labels = (style_predictions > 0.5).float()
            correct_predictions += (style_labels == labels).sum().item()
            total_predictions += labels.size(0)

            for i in range(min(5, len(input_tokens))):
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)
                print(f"Original: {' '.join(original_sentence)}")
                print(f"Reconstructed: {' '.join(reconstructed_sentence)}")
                print(f"Style Prediction: {style_labels[i].item()}, True Style: {labels[i].item()}\n")

            break  # Evaluate only on the first batch for now

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"Style Transfer Accuracy: {accuracy:.4f}")
    return accuracy

def run_evaluation(data_loader_train, data_loader_test, model, vocab, vocab_size, device):
    # Train Style Classifier
    classifier = train_style_classifier(data_loader_train, vocab_size, device)

    print("\n--- BLEU-S Score (Content Preservation) ---")
    bleu_score = calculate_bleu_score(data_loader_test, model, vocab, device)

    print("\n--- Style Transfer Accuracy ---")
    style_transfer_accuracy = evaluate_style_transfer(data_loader_test, model, classifier, vocab, device)

    print("\n--- Final Results ---")
    print(f"BLEU-S Score: {bleu_score:.4f}")
    print(f"Style Transfer Accuracy: {style_transfer_accuracy:.4f}")

    return bleu_score, style_transfer_accuracy

run_evaluation(data_loader, data_loader_test, model, vocab, len(vocab), device)

[nltk_data] Downloading package punkt to /home/qik/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



Training Style Classifier...

Epoch 1/5, Loss: 0.0954
Epoch 2/5, Loss: 0.0522
Epoch 3/5, Loss: 0.0387
Epoch 4/5, Loss: 0.0300
Epoch 5/5, Loss: 0.0241
Epoch 6/5, Loss: 0.0205
Epoch 7/5, Loss: 0.0178
Epoch 8/5, Loss: 0.0158
Epoch 9/5, Loss: 0.0145
Epoch 10/5, Loss: 0.0139
Epoch 11/5, Loss: 0.0132
Epoch 12/5, Loss: 0.0124
Epoch 13/5, Loss: 0.0116
Epoch 14/5, Loss: 0.0117
Epoch 15/5, Loss: 0.0116
Epoch 16/5, Loss: 0.0113
Epoch 17/5, Loss: 0.0113
Epoch 18/5, Loss: 0.0108
Epoch 19/5, Loss: 0.0105
Epoch 20/5, Loss: 0.0104

--- BLEU-S Score (Content Preservation) ---

BLEU-S: Evaluating content preservation...

Original: at this location the service was terrible .
Reconstructed: so brand scale crumbs crappy steady shogun shogun adds adds adds adds adds adds adds

Original: i ordered garlic bread and fettuccine alfredo pasta with vegetables .
Reconstructed: anyways a+ cancer crusted cancer middle towing crusted middle watered belgian camarones inspector inspector inspector

Original: i did n't

(0.0, 0.546875)