In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm  
from dataloader import *
from model import * 
from nltk.translate.bleu_score import sentence_bleu
import nltk

def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_length = max(len(seq) for seq in inputs)
    
    # Convert each sequence to a list, pad with 0, and convert to tensor
    padded_inputs = [torch.cat([seq, torch.zeros(max_length - len(seq), dtype=torch.long)]) for seq in inputs]
    lengths = [len(seq) for seq in inputs]
    
    return torch.stack(padded_inputs), torch.tensor(labels, dtype=torch.float), lengths

# Hyperparameters
num_epochs = 10
learning_rate = 0.001
target_confidence = 0.8 

In [2]:
 
data_dir = "./data/sentiment_style_transfer/yelp"
vocab = build_vocab(data_dir)
dataset = TextDataset(data_dir, vocab)
data_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = StyleTransferModel(len(vocab), 300, 256, 16, 128).to(device)  
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    progress_bar = tqdm(data_loader, desc=f"Epoch {epoch + 1}/{num_epochs}", unit="batch")
    
    for input_tokens, labels, lengths in progress_bar:
        input_tokens = input_tokens.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        x_reconstructed, style_mean, content_mean, s_prime = model(input_tokens, target_confidence)
        style_logvar = torch.zeros_like(style_mean)
        content_logvar = torch.zeros_like(content_mean)
        loss = vae_loss(x_reconstructed, input_tokens, style_mean, style_logvar, content_mean, content_logvar)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        
        # Update the progress bar with the current loss
        progress_bar.set_postfix(loss=epoch_loss / (progress_bar.n + 1))
    
    print(f"Epoch {epoch + 1}/{num_epochs} completed. Average Loss: {epoch_loss / len(data_loader)}")


Epoch 1/10: 100%|██████████| 6926/6926 [05:34<00:00, 20.72batch/s, loss=0.618]


Epoch 1/10 completed. Average Loss: 0.6183889831284934


Epoch 2/10: 100%|██████████| 6926/6926 [05:37<00:00, 20.50batch/s, loss=0.229] 


Epoch 2/10 completed. Average Loss: 0.22921297018383416


Epoch 3/10: 100%|██████████| 6926/6926 [05:34<00:00, 20.71batch/s, loss=0.299]


Epoch 3/10 completed. Average Loss: 0.2993834251228307


Epoch 4/10: 100%|██████████| 6926/6926 [05:33<00:00, 20.77batch/s, loss=0.224] 


Epoch 4/10 completed. Average Loss: 0.22419913678819037


Epoch 5/10: 100%|██████████| 6926/6926 [05:29<00:00, 21.01batch/s, loss=0.616] 


Epoch 5/10 completed. Average Loss: 0.6160696460356646


Epoch 6/10: 100%|██████████| 6926/6926 [05:30<00:00, 20.96batch/s, loss=0.33]  


Epoch 6/10 completed. Average Loss: 0.329525139434294


Epoch 7/10: 100%|██████████| 6926/6926 [05:35<00:00, 20.62batch/s, loss=0.279]


Epoch 7/10 completed. Average Loss: 0.2789108553194553


Epoch 8/10: 100%|██████████| 6926/6926 [05:34<00:00, 20.73batch/s, loss=0.391]


Epoch 8/10 completed. Average Loss: 0.3910796397221144


Epoch 9/10: 100%|██████████| 6926/6926 [05:30<00:00, 20.93batch/s, loss=0.348]


Epoch 9/10 completed. Average Loss: 0.34840153873824936


Epoch 10/10: 100%|██████████| 6926/6926 [05:25<00:00, 21.27batch/s, loss=0.251]

Epoch 10/10 completed. Average Loss: 0.2513471563543881





In [11]:
torch.save(model, 'model_complete.pth')

In [12]:
model_1 = torch.load('model_complete.pth') 

  model_1 = torch.load('model_complete.pth')


In [6]:
import torch
from torch.utils.data import Dataset, DataLoader
import os
from collections import Counter

class TextDatasetTest(Dataset):
    def __init__(self, data_dir, vocab):
        super(TextDatasetTest, self).__init__()
        self.data = []
        self.vocab = vocab

        # Load data from the files
        files = ["sentiment.test.0", "sentiment.test.1"]
        for filename in files:
            file_path = os.path.join(data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                for line in lines:
                    tokens = line.strip().split()
                    label = 1 if filename.endswith('.1') else 0  # Binary label
                    self.data.append((tokens, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        token_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)

In [7]:
vocab = build_vocab(data_dir)
dataset = TextDatasetTest(data_dir, vocab)
data_loader_test = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)

In [13]:
# Download necessary NLTK resources
nltk.download('punkt')

# Function to convert token IDs back to words using the vocabulary
def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]  # Exclude padding

# Function to calculate BLEU score for a batch
def calculate_bleu_score(data_loader, model, vocab, device):
    model.eval()  # Set the model to evaluation mode
    total_bleu_score = 0
    num_sentences = 0
    
    counter = 0
    with torch.no_grad():
        for input_tokens, _, lengths in data_loader:
            input_tokens = input_tokens.to(device)
            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)  # Get the predicted token IDs

            # Calculate BLEU score for each sentence
            for i in range(len(input_tokens)):
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)

                counter += 1
                if counter % 100 == 0:
                    print(original_sentence, reconstructed_sentence)
                # Calculate BLEU score
                bleu_score = sentence_bleu([original_sentence], reconstructed_sentence)
                total_bleu_score += bleu_score
                num_sentences += 1

    # Return the average BLEU score
    return total_bleu_score / num_sentences if num_sentences > 0 else 0

# Calculate the BLEU score
bleu_score = calculate_bleu_score(data_loader_test, model_1, vocab, device)
print(f"Average BLEU Score: {bleu_score:.4f}")

[nltk_data] Downloading package punkt to /home/qik/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


['they', 'failed', 'to', 'tell', 'us', 'eating', 'in', 'the', 'bar', 'was', 'an', 'option', '.'] ['they', 'failed', 'to', 'tell', 'us', 'eating', 'in', 'the', 'bar', 'was', 'an', 'option', '.']
['the', 'menudo', 'here', 'is', 'perfect', '.'] ['the', 'menudo', 'here', 'is', 'perfect', '.']
['best', 'chicken', 'parmesan', 'i', 'have', 'ever', 'had', '.'] ['best', 'chicken', 'parmesan', 'i', 'have', 'ever', 'had', '.']
['the', 'surly', 'older', 'waitress', 'was', 'a', 'huge', 'bummer', '.'] ['the', 'surly', 'older', 'waitress', 'was', 'a', 'huge', 'bummer', '.']
['pricing', 'is', 'both', 'affordable', 'and', 'reasonable', '.'] ['pricing', 'is', 'both', 'affordable', 'and', 'reasonable', '.']
['the', 'firecracker', 'shrimp', 'and', 'duck', 'is', 'also', 'always', 'a', 'winner', '.'] ['the', 'firecracker', 'shrimp', 'and', 'duck', 'is', 'also', 'always', 'a', 'winner', '.']
['he', 'was', 'both', 'professional', 'and', 'courteous', '.'] ['he', 'was', 'both', 'professional', 'and', 'courteous

In [14]:
def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]  # Exclude padding

# Inspect some sentences from the data loader
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    for input_tokens, _, lengths in data_loader_test:
        input_tokens = input_tokens.to(device)
        x_reconstructed, _, _, _ = model_1(input_tokens)
        x_reconstructed = x_reconstructed.argmax(dim=-1)  # Get the predicted token IDs

        # Print a few input and output sentences
        for i in range(5):  # Print 5 examples
            original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
            reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)

            print("Original Sentence: ", " ".join(original_sentence))
            print("Reconstructed Sentence: ", " ".join(reconstructed_sentence))
            print()

        break  # Only inspect the first batch

Original Sentence:  the green enchiladas were ok but not great .
Reconstructed Sentence:  the green enchiladas were ok but not great .

Original Sentence:  however , this experience went pretty smooth .
Reconstructed Sentence:  however , this experience went pretty smooth .

Original Sentence:  giving an extra star for customer service .
Reconstructed Sentence:  giving an extra star for customer service .

Original Sentence:  the tow package is not an issue either .
Reconstructed Sentence:  the tow package is not an issue either .

Original Sentence:  the grounds are always very clean .
Reconstructed Sentence:  the grounds are always very clean .

