In [1]:
import numpy as np
import json
import os
from scipy.interpolate import CubicSpline, interp1d
from itertools import combinations

# Number of values every embedding is resampled to; used as the default
# target_dim of interpolate_embeddings below.
EMBEDDING_SIZE = 500

# Load the embeddings from the JSON file
def load_embeddings(filepath):
    """Read a JSON file and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

# Interpolation method to bring all embeddings to the same size
def interpolate_embeddings(embeddings, target_dim=EMBEDDING_SIZE, method='linear'):
    """Resample every embedding to ``target_dim`` values.

    Each embedding is treated as samples on an evenly spaced grid over
    [0, 1] and re-evaluated on an evenly spaced grid of ``target_dim`` points.

    Args:
        embeddings: Iterable of 1-D numeric sequences; lengths may differ.
        target_dim: Number of values in every resampled embedding.
        method: 'linear' (``np.interp``), 'cubic' (``CubicSpline``) or
            'quadratic' (``interp1d`` with ``kind='quadratic'``).

    Returns:
        ``np.ndarray`` built from the resampled embeddings, one row per input.

    Raises:
        ValueError: If ``method`` is not one of the supported names. This is
            checked once up front, so it is raised even for empty input.
    """
    resamplers = {
        'linear': lambda x_old, values, x_new: np.interp(x_new, x_old, values),
        'cubic': lambda x_old, values, x_new: CubicSpline(x_old, values)(x_new),
        'quadratic': lambda x_old, values, x_new: interp1d(x_old, values, kind='quadratic')(x_new),
    }
    resample = resamplers.get(method) if isinstance(method, str) else None
    if resample is None:
        raise ValueError(f"Unsupported interpolation method: {method}")

    # The target grid does not depend on the individual embedding.
    target_indices = np.linspace(0, 1, num=target_dim)

    interpolated_embeddings = [
        resample(np.linspace(0, 1, num=len(embed)), embed, target_indices)
        for embed in embeddings
    ]
    return np.array(interpolated_embeddings)

# Function to compare embeddings
def compare_embeddings(embeddings1, embeddings2):
    """Return the cosine similarity of each positionally paired embedding.

    Pairs are formed with ``zip``, so the result is as long as the shorter
    of the two inputs.
    """
    def _cosine(left, right):
        return np.dot(left, right) / (np.linalg.norm(left) * np.linalg.norm(right))

    return [_cosine(left, right) for left, right in zip(embeddings1, embeddings2)]

# Function to process a single file and compare interpolation methods
def compare_methods(filepath, methods=('linear', 'cubic', 'quadratic')):
    """Measure how well each interpolation method preserves one file's embeddings.

    Args:
        filepath: JSON file whose ``'embeddings'`` list holds items with an
            ``'embedding'`` vector.
        methods: Interpolation method names accepted by
            ``interpolate_embeddings``.

    Returns:
        dict with key ``'original'`` (pairwise cosine similarities between
        the raw embeddings) and one key per method holding, for every
        embedding, the cosine similarity between the raw vector and its
        round trip through the resampled space.
    """
    data = load_embeddings(filepath)
    embeddings = [item['embedding'] for item in data['embeddings']]

    comparisons = {
        'original': compare_embeddings_combinations(embeddings)
    }

    for method in methods:
        method_similarities = []
        for original_embed in embeddings:
            resampled = interpolate_embeddings(
                [original_embed], target_dim=EMBEDDING_SIZE, method=method
            )[0]
            # A cosine similarity needs vectors of equal length, and the
            # resampled vector has EMBEDDING_SIZE entries. Map it back onto
            # the original grid so the comparison is defined for every length.
            round_trip = interpolate_embeddings(
                [resampled], target_dim=len(original_embed), method=method
            )[0]
            method_similarities.append(
                compare_embeddings([original_embed], [round_trip])[0]
            )
        comparisons[method] = method_similarities

    return comparisons

# Function to compare all combinations of original embeddings
def compare_embeddings_combinations(embeddings):
    """Return the cosine similarity of every unordered pair of embeddings.

    Pairs follow the order produced by ``itertools.combinations``.
    """
    return [
        np.dot(first, second) / (np.linalg.norm(first) * np.linalg.norm(second))
        for first, second in combinations(embeddings, 2)
    ]

# Example usage to compare different interpolation methods for all files in a directory
def compare_all_files(directory, methods=['linear', 'cubic', 'quadratic']):
    """Pool the per-file comparisons of every ``.json`` file in ``directory``.

    Returns:
        dict mapping ``'original'`` and each method name to a dict with
        ``'mean_similarity'`` and ``'std_similarity'`` over all pooled values.
    """
    pooled = {label: [] for label in ['original', *methods]}

    json_names = (name for name in os.listdir(directory) if name.endswith('.json'))
    for name in json_names:
        per_file = compare_methods(os.path.join(directory, name), methods=methods)
        for label, values in per_file.items():
            pooled[label].extend(values)

    # Summarise each pooled list of similarities.
    return {
        label: {
            'mean_similarity': np.mean(values),
            'std_similarity': np.std(values),
        }
        for label, values in pooled.items()
    }

# Path to the directory containing the JSON files
directory_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/embeddings'

# Compare all files in the directory
# (runs when the cell executes; uses the default linear/cubic/quadratic methods)
cumulative_metrics = compare_all_files(directory_path)

# Display the cumulative metrics
# Keys are 'original' plus one entry per interpolation method.
print("Cumulative Metrics:")
for method, metrics in cumulative_metrics.items():
    print(f"Method: {method}")
    print(f"  Mean Similarity: {metrics['mean_similarity']:.4f}")
    print(f"  Std Similarity: {metrics['std_similarity']:.4f}")


ModuleNotFoundError: No module named 'scipy'

In [3]:
import numpy as np
import json
import os

# Number of values every embedding is resampled to before pooling.
EMBEDDING_SIZE = 500

# Load the embeddings from the JSON file
def load_embeddings(filepath):
    """Read a JSON file and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

# Interpolation method to bring all embeddings to the same size
def interpolate_embeddings(embeddings, target_dim=EMBEDDING_SIZE):
    """Linearly resample every embedding to ``target_dim`` values.

    Each embedding is treated as samples on an evenly spaced grid over
    [0, 1] and re-evaluated with ``np.interp`` on ``target_dim`` evenly
    spaced points.

    Args:
        embeddings: Iterable of 1-D numeric sequences; lengths may differ.
        target_dim: Number of values in every resampled embedding.

    Returns:
        ``np.ndarray`` built from the resampled embeddings, one row per input.
    """
    # The target grid is the same for every embedding, so build it once.
    target_indices = np.linspace(0, 1, num=target_dim)
    return np.array([
        np.interp(target_indices, np.linspace(0, 1, num=len(embed)), embed)
        for embed in embeddings
    ])

# Mean pooling function
def mean_pooling(embeddings, target_dim=EMBEDDING_SIZE):
    """Resample the embeddings to ``target_dim`` and average them element-wise."""
    resampled = interpolate_embeddings(embeddings, target_dim)
    pooled = np.mean(resampled, axis=0)
    return pooled

# Function to process a single file
def process_file(filepath):
    """Add a mean-pooled embedding to one JSON file, rewriting it in place.

    The file's ``'embeddings'`` items are read, their ``'embedding'`` vectors
    are mean-pooled at ``EMBEDDING_SIZE`` and the result is stored under the
    ``'mean_pooled'`` key.

    Args:
        filepath: Path of the JSON file to update.

    Returns:
        ``filepath``, unchanged.
    """
    data = load_embeddings(filepath)
    embeddings = [item['embedding'] for item in data['embeddings']]

    # Apply Mean Pooling and attach the result to the loaded data
    data['mean_pooled'] = mean_pooling(embeddings, target_dim=EMBEDDING_SIZE).tolist()

    # Serialize first: opening with 'w' truncates the file, so a failure
    # during serialization would otherwise destroy the input data.
    serialized = json.dumps(data, indent=4)
    with open(filepath, 'w', encoding='utf-8') as file:
        file.write(serialized)

    return filepath

# Function to process all files in a directory
def process_directory(directory):
    """Run ``process_file`` on every ``.json`` entry of ``directory``.

    Returns:
        List of the values returned by ``process_file``, in ``os.listdir`` order.
    """
    return [
        process_file(os.path.join(directory, name))
        for name in os.listdir(directory)
        if name.endswith('.json')
    ]

# Path to the directory containing the JSON files
directory_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/embeddings'

# Process all files in the directory
# NOTE: process_file rewrites each JSON file in place, adding a 'mean_pooled' key.
updated_files = process_directory(directory_path)

# Display the updated files
print(updated_files)


['/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/embeddings/193f9423-65e4-4ca2-a05d-5ef9414a65c1.json']


In [None]:
 import numpy as np
from sklearn.decomposition import PCA

# Generate fake embedding dataset with different lengths
def generate_fake_embeddings(num_groups=10, min_len=5, max_len=20, embedding_dim=50, seed=42):
    """Build reproducible groups of random embeddings with varying lengths.

    Args:
        num_groups: Number of groups to generate.
        min_len: Smallest number of embeddings in a group (inclusive).
        max_len: Upper bound on the number of embeddings in a group (exclusive).
        embedding_dim: Each embedding's length is drawn from
            ``[embedding_dim // 2, embedding_dim * 2)``.
        seed: Seed of the private random generator.

    Returns:
        List of ``num_groups`` lists of 1-D ``np.ndarray`` with values in [0, 1).
    """
    # A private RandomState yields the same stream as np.random.seed(seed)
    # followed by the legacy np.random functions, without reseeding the
    # global generator that other code may rely on.
    rng = np.random.RandomState(seed)
    embeddings = []
    for _ in range(num_groups):
        group_len = rng.randint(min_len, max_len)
        embeddings.append([
            rng.rand(rng.randint(embedding_dim // 2, embedding_dim * 2))
            for _ in range(group_len)
        ])
    return embeddings

# Interpolation method to bring all embeddings to the same size
def interpolate_embeddings(embeddings, target_dim=EMBEDDING_SIZE):
    """Linearly resample every embedding to ``target_dim`` values.

    Args:
        embeddings: Iterable of 1-D numeric sequences; lengths may differ.
        target_dim: Number of values in every resampled embedding.

    Returns:
        List with one resampled ``np.ndarray`` per input embedding.
    """
    # The target grid is the same for every embedding, so build it once.
    target_indices = np.linspace(0, 1, num=target_dim)
    return [
        np.interp(target_indices, np.linspace(0, 1, num=len(embed)), embed)
        for embed in embeddings
    ]

# PCA function to project embeddings to a fixed size
def apply_pca(embeddings, target_dim=EMBEDDING_SIZE):
    """Resample the embeddings to ``target_dim`` and project them with PCA.

    The number of components is the smaller of ``target_dim`` and the number
    of embeddings.

    Returns:
        The array returned by ``PCA.fit_transform`` on the resampled data.
    """
    resampled = interpolate_embeddings(embeddings, target_dim)
    component_count = min(target_dim, len(resampled))
    reducer = PCA(n_components=component_count)
    return reducer.fit_transform(resampled)

# Mean pooling function
def mean_pooling(embeddings, target_dim=EMBEDDING_SIZE):
    """Resample the embeddings to ``target_dim`` and average them element-wise."""
    resampled = interpolate_embeddings(embeddings, target_dim)
    pooled = np.mean(resampled, axis=0)
    return pooled

# Hybrid approach: PCA then Mean pooling
def hybrid_pca_mean_pooling(embeddings, target_dim=EMBEDDING_SIZE):
    """Project the embeddings with ``apply_pca``, then mean-pool the projections.

    ``mean_pooling`` resamples the PCA output to ``target_dim`` before averaging.
    """
    return mean_pooling(apply_pca(embeddings, target_dim), target_dim)

# Generate fake dataset
# (a list of groups, each group a list of variable-length vectors)
embeddings = generate_fake_embeddings()

# Apply PCA
# One result per group.
pca_embeddings = [apply_pca(group) for group in embeddings]

# Apply Mean Pooling
mean_pooled_embeddings = [mean_pooling(group) for group in embeddings]

# Apply Hybrid PCA then Mean Pooling
hybrid_embeddings = [hybrid_pca_mean_pooling(group) for group in embeddings]

import pandas as pd

# Convert results to DataFrame for better display
# Each cell holds the plain-list form of one group's result.
result_df = pd.DataFrame({
    'PCA Embedding': [emb.tolist() for emb in pca_embeddings],
    'Mean Pooling Embedding': [emb.tolist() for emb in mean_pooled_embeddings],
    'Hybrid Embedding': [emb.tolist() for emb in hybrid_embeddings]
})


# Last expression of the cell: the notebook displays the first rows.
result_df.head()


In [None]:
import numpy as np
from sklearn.decomposition import PCA
import json
import os

# Number of values every embedding is resampled to before PCA / pooling.
EMBEDDING_SIZE = 500

# Load the embeddings from the JSON file
def load_embeddings(filepath):
    """Read a JSON file and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

# Normalization function
def normalize_embeddings(embeddings):
    """Standardize the embeddings along axis 0 (zero mean, unit std).

    Positions whose standard deviation is zero are only centered, not
    scaled, to avoid division by zero.

    Args:
        embeddings: Array-like of numbers; a 2-D stack of embeddings or a
            single 1-D vector.

    Returns:
        ``np.ndarray`` of floats with the same shape as the input.
    """
    values = np.asarray(embeddings, dtype=float)
    mean = values.mean(axis=0)
    std = values.std(axis=0)
    # np.where also works when std is a scalar (1-D input), where in-place
    # item assignment on the result of np.std would fail.
    safe_std = np.where(std == 0, 1.0, std)
    return (values - mean) / safe_std

# Interpolation method to bring all embeddings to the same size
def interpolate_embeddings(embeddings, target_dim=EMBEDDING_SIZE):
    """Linearly resample every embedding to ``target_dim`` values.

    Args:
        embeddings: Iterable of 1-D numeric sequences; lengths may differ.
        target_dim: Number of values in every resampled embedding.

    Returns:
        ``np.ndarray`` built from the resampled embeddings, one row per input.
    """
    # The target grid is the same for every embedding, so build it once.
    target_indices = np.linspace(0, 1, num=target_dim)
    rows = [
        np.interp(target_indices, np.linspace(0, 1, num=len(embed)), embed)
        for embed in embeddings
    ]
    return np.array(rows)

# PCA function to project embeddings to a fixed size
def apply_pca(embeddings, target_dim=EMBEDDING_SIZE, min_components=10, variance_threshold=0.95):
    """Resample, standardize and PCA-project a group of embeddings.

    A first PCA with the maximum possible number of components is fitted to
    find how many components are needed to reach ``variance_threshold`` of
    cumulative explained variance; a second PCA with that many components
    (at least ``min_components``, at most ``min(n_samples, n_features)``)
    produces the returned projection.

    Args:
        embeddings: Iterable of 1-D numeric sequences; lengths may differ.
        target_dim: Length every embedding is resampled to before PCA.
        min_components: Lower bound on the number of components kept.
        variance_threshold: Cumulative explained-variance ratio to reach.

    Returns:
        The array returned by ``PCA.fit_transform`` for the selected
        number of components.
    """
    interpolated_embeddings = interpolate_embeddings(embeddings, target_dim)
    normalized_embeddings = normalize_embeddings(interpolated_embeddings)

    n_samples, n_features = len(normalized_embeddings), len(normalized_embeddings[0])
    max_components = min(n_samples, n_features)
    full_pca = PCA(n_components=max_components)
    full_pca.fit(normalized_embeddings)

    cumulative_variance = np.cumsum(full_pca.explained_variance_ratio_)
    reached = np.flatnonzero(cumulative_variance >= variance_threshold)
    # If the threshold is never reached, keep every component instead of
    # silently falling back to a single one (np.argmax of an all-False mask is 0).
    needed = int(reached[0]) + 1 if reached.size else max_components
    # Clamp once so the value reported below is the value actually used.
    n_components = min(max(needed, min_components), max_components)
    pca = PCA(n_components=n_components)
    pca_embeddings = pca.fit_transform(normalized_embeddings)

    print(f"PCA explained variance ratio: {pca.explained_variance_ratio_[:5]}")
    print(f"PCA components shape: {pca.components_.shape}")
    print(f"PCA number of components selected: {n_components}")
    print(f"PCA first 5 points: {pca_embeddings[:5]}")

    return pca_embeddings

# Mean pooling function
def mean_pooling(embeddings, target_dim=EMBEDDING_SIZE):
    """Resample the embeddings to ``target_dim`` and average them element-wise."""
    resampled = interpolate_embeddings(embeddings, target_dim)
    pooled = np.mean(resampled, axis=0)
    return pooled

# Hybrid approach: PCA then Mean pooling
def hybrid_pca_mean_pooling(embeddings, target_dim=EMBEDDING_SIZE):
    """Project the embeddings with ``apply_pca`` and average the projections.

    Returns:
        The mean over axis 0 of the PCA output, one value per component.
    """
    return np.mean(apply_pca(embeddings, target_dim), axis=0)

# Function to process a single file
def process_file(filepath):
    """Add combined embeddings to one JSON file, rewriting it in place.

    Three vectors are stored under ``data['combined']``: ``'PCA'`` (mean of
    the PCA projection), ``'MeanPooling'`` and ``'Hybrid'``.

    Args:
        filepath: Path of the JSON file to update.

    Returns:
        ``filepath``, unchanged.
    """
    data = load_embeddings(filepath)
    embeddings = [item['embedding'] for item in data['embeddings']]

    # Apply PCA
    pca_embeddings = apply_pca(embeddings, target_dim=EMBEDDING_SIZE)
    pca_mean = np.mean(pca_embeddings, axis=0)

    # Apply Mean Pooling
    mean_pooled_embedding = mean_pooling(embeddings, target_dim=EMBEDDING_SIZE)

    # Apply Hybrid PCA then Mean Pooling
    # NOTE(review): this runs apply_pca a second time on the same input and
    # averages it the same way as pca_mean above - confirm that is intended.
    hybrid_embedding = hybrid_pca_mean_pooling(embeddings, target_dim=EMBEDDING_SIZE)

    # Add combined embeddings to the data
    data['combined'] = {
        'PCA': pca_mean.tolist(),
        'MeanPooling': mean_pooled_embedding.tolist(),
        'Hybrid': hybrid_embedding.tolist()
    }

    # Serialize first: opening with 'w' truncates the file, so a failure
    # during serialization would otherwise destroy the input data.
    serialized = json.dumps(data, indent=4)
    with open(filepath, 'w', encoding='utf-8') as file:
        file.write(serialized)

    return filepath

# Function to process all files in a directory
def process_directory(directory):
    """Run ``process_file`` on every ``.json`` entry of ``directory``.

    Returns:
        List of the values returned by ``process_file``, in ``os.listdir`` order.
    """
    json_names = filter(lambda name: name.endswith('.json'), os.listdir(directory))
    return [process_file(os.path.join(directory, name)) for name in json_names]

# Variance and range checks with more detailed debug
def check_embedding_statistics(original_embeddings, processed_embeddings):
    """Compare variance and value range of processed vs. original embeddings.

    Variance and peak-to-peak range are taken along axis 0 for both inputs,
    their processed/original ratios are printed, and the mean of each ratio
    must lie strictly between 0.5 and 2.0.

    Raises:
        AssertionError: If a mean ratio falls outside (0.5, 2.0).
    """
    spread = {
        label: (np.var(values, axis=0), np.ptp(values, axis=0))
        for label, values in (
            ('original', original_embeddings),
            ('processed', processed_embeddings),
        )
    }
    original_var, original_range = spread['original']
    processed_var, processed_range = spread['processed']

    var_ratio = processed_var / original_var
    range_ratio = processed_range / original_range

    # Debug output, one line per statistic.
    print(f"Original variance: {original_var.mean()}")
    print(f"Processed variance: {processed_var.mean()}")
    print(f"Variance ratio: {var_ratio.mean()}")
    print(f"Original range: {original_range.mean()}")
    print(f"Processed range: {processed_range.mean()}")
    print(f"Range ratio: {range_ratio.mean()}")

    mean_var_ratio = var_ratio.mean()
    assert 0.5 < mean_var_ratio < 2.0, "Processed variance is not within acceptable range"
    mean_range_ratio = range_ratio.mean()
    assert 0.5 < mean_range_ratio < 2.0, "Processed range is not within acceptable range"

# Function to test the structure and shape of the final file
def test_final_file(filepath, expected_shape=500):
    """Check the structure of a processed JSON file.

    The file must contain a ``'combined'`` dict with ``'MeanPooling'``,
    ``'Hybrid'`` and ``'PCA'`` vectors. Every vector is also passed to
    ``check_embedding_statistics`` together with the original embeddings
    resampled to ``expected_shape``.

    Args:
        filepath: Path of the JSON file to check.
        expected_shape: Required length of the mean-pooled vector.

    Raises:
        AssertionError: If a key is missing, a vector has the wrong length,
            or the statistics check fails.
    """
    with open(filepath, 'r', encoding='utf-8') as file:
        data = json.load(file)

    assert 'combined' in data, "Missing 'combined' key in JSON file"
    combined = data['combined']

    original_embeddings = [item['embedding'] for item in data['embeddings']]
    original_embeddings = interpolate_embeddings(original_embeddings, target_dim=expected_shape)

    # Only the mean-pooled vector lives in the resampled space, so only it
    # has to have exactly expected_shape entries.
    assert 'MeanPooling' in combined, "Missing 'MeanPooling' key in 'combined'"
    mean_pooled = np.array(combined['MeanPooling'])
    assert len(mean_pooled) == expected_shape, "Incorrect shape for 'MeanPooling' embedding"
    check_embedding_statistics(original_embeddings, mean_pooled)

    # 'Hybrid' and 'PCA' are both means of PCA projections: their length is
    # the number of retained components, so only require a non-empty vector.
    for key in ['Hybrid', 'PCA']:
        assert key in combined, f"Missing '{key}' key in 'combined'"
        reduced = np.array(combined[key])
        assert len(reduced) > 0, f"Incorrect shape for '{key}' embedding"
        check_embedding_statistics(original_embeddings, reduced)

    print(f"File {filepath} passed the tests.")

# Path to the directory containing the JSON files
directory_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/embeddings'

# Process all files in the directory
# NOTE: process_file rewrites each JSON file in place, adding a 'combined' key.
updated_files = process_directory(directory_path)

# Test the first updated file to ensure it is correctly formatted
# (indexing [0] raises IndexError when the directory holds no .json file)
test_final_file(updated_files[0])

# Display the combined embeddings
updated_files


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, AdamW
import numpy as np


SEED = 42             # seed for both the NumPy and the torch generators
NUM_TEXTS = 160       # number of random texts generated below
MAX_LENGTH = 300      # tokens per text (generation, padding and truncation)
EMBEDDING_DIM = 300   # length of each fake embedding / model input size
HIDDEN_DIM = 128      # width of the model's projection layer
BATCH_SIZE = 8
NUM_EPOCHS = 25
LEARNING_RATE = 1e-4

np.random.seed(SEED)
torch.manual_seed(SEED)

# Step 1: Generate Fake Data
# NOTE: NUM_TEXTS * 2 embeddings are created but only NUM_TEXTS texts; the
# dataset length follows the texts, so the surplus embeddings are never indexed.
fake_embeddings = [torch.tensor(np.random.rand(EMBEDDING_DIM), dtype=torch.float32) for _ in range(NUM_TEXTS * 2)]

# Generate random text sequences from the tokenizer's vocabulary
def generate_random_texts(tokenizer, num_texts, max_length):
    """Decode ``num_texts`` sequences of ``max_length`` random token ids.

    Token ids are drawn with ``np.random.randint`` from
    ``[0, tokenizer.vocab_size)`` and decoded with special tokens skipped.

    Returns:
        List of the decoded texts.
    """
    vocabulary_size = tokenizer.vocab_size
    return [
        tokenizer.decode(
            np.random.randint(0, vocabulary_size, size=(max_length,)),
            skip_special_tokens=True,
        )
        for _ in range(num_texts)
    ]

# Step 2: Initialize Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token  # Set the padding token

# Random texts decoded from random token ids (see generate_random_texts).
fake_texts = generate_random_texts(tokenizer, num_texts=NUM_TEXTS, max_length=MAX_LENGTH)

# Convert Text to Tokens
# Each text is padded/truncated to MAX_LENGTH; squeeze(0) drops the batch
# dimension added by return_tensors='pt'.
tokenized_texts = [tokenizer(text, return_tensors='pt', truncation=True, padding='max_length', max_length=MAX_LENGTH)['input_ids'].squeeze(0) for text in fake_texts]

# Step 3: Create Dataset
class EmbeddingTextDataset(Dataset):
    """Pairs each tokenized text with the embedding stored at the same index."""

    def __init__(self, embeddings, tokenized_texts):
        self.embeddings, self.tokenized_texts = embeddings, tokenized_texts

    def __len__(self):
        # The dataset size follows the texts, not the embeddings.
        return len(self.tokenized_texts)

    def __getitem__(self, idx):
        pair = (self.embeddings[idx], self.tokenized_texts[idx])
        return pair

# Wrap the fake data and serve it in shuffled batches of BATCH_SIZE.
dataset = EmbeddingTextDataset(fake_embeddings, tokenized_texts)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
print("dataset made")
# Step 4: Define the Model
class SimpleTextGenerator(nn.Module):
    """Two linear layers mapping an embedding to per-position vocabulary logits."""

    def __init__(self, embedding_dim, vocab_size, hidden_dim, max_length):
        super().__init__()
        self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)
        # One logit per (position, vocabulary entry) pair.
        self.fc = nn.Linear(hidden_dim, vocab_size * max_length)
        self.max_length = max_length
        self.vocab_size = vocab_size

    def forward(self, embedding):
        """Return logits reshaped to (-1, max_length, vocab_size)."""
        flat_logits = self.fc(self.embedding_projection(embedding))
        return flat_logits.view(-1, self.max_length, self.vocab_size)

# NOTE: vocab_size and device are also read as globals by train() and
# generate_text_from_embedding() defined in this cell.
vocab_size = tokenizer.vocab_size
model = SimpleTextGenerator(EMBEDDING_DIM, vocab_size, HIDDEN_DIM, MAX_LENGTH)
print("model made")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Step 5: Train the Model
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

def train(model, dataloader, optimizer, criterion, device, epochs):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Every batch is an (embeddings, token ids) pair. Logits are flattened to
    (-1, vocab_size) using the module-level ``vocab_size`` and compared with
    the flattened token ids. The mean batch loss is printed after each epoch.
    """
    model.train()
    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for batch_embeddings, batch_tokens in dataloader:
            batch_embeddings = batch_embeddings.to(device)
            batch_tokens = batch_tokens.to(device)

            optimizer.zero_grad()
            flat_logits = model(batch_embeddings).view(-1, vocab_size)
            batch_loss = criterion(flat_logits, batch_tokens.view(-1))
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        print(f"Epoch {epoch_number}/{epochs}, Loss: {running_loss / len(dataloader)}")

# Run the training loop for NUM_EPOCHS epochs on the fake dataset.
train(model, dataloader, optimizer, criterion, device, NUM_EPOCHS)

# Step 6: Generate and Print Text in One Go
def generate_text_from_embedding(model, embedding, tokenizer, max_length):
    """Decode the most likely token at every position for one embedding.

    The embedding is moved to the module-level ``device``. ``max_length``
    is accepted but not used by this function.

    Returns:
        The text decoded from the first sequence of predicted token ids.
    """
    model.eval()
    with torch.no_grad():
        logits = model(embedding.to(device))
        first_sequence_ids = torch.argmax(logits, dim=-1)[0]
        return tokenizer.decode(first_sequence_ids, skip_special_tokens=True)

# Test the model with a new embedding
test_embedding = torch.tensor(np.random.rand(EMBEDDING_DIM), dtype=torch.float32).to(device)
generated_text = generate_text_from_embedding(model, test_embedding, tokenizer, MAX_LENGTH)
print("Generated text:", generated_text)


In [None]:
import json
import os

def extract_embeddings(filepath):
    """Read the uuid and the combined embeddings from one JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Tuple ``(uuid, mean_pooling, hybrid)``. ``uuid`` is ``None`` when the
        file has no ``'uuid'`` key; each embedding defaults to ``[]`` when
        the ``'combined'`` entry or its key is missing.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        data = json.load(file)

    combined = data.get('combined', {})
    return data.get('uuid'), combined.get('MeanPooling', []), combined.get('Hybrid', [])

def process_directory(directory):
    """Collect the combined embeddings of every ``*_updated.json`` file.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        dict mapping each file's uuid to
        ``{'MeanPooling': ..., 'Hybrid': ...}``. Files without a uuid are
        skipped.
    """
    embeddings_dict = {}
    # Visit every entry. A stray loop-level `break` used to stop the scan
    # after the first directory entry, whether or not it matched.
    for filename in os.listdir(directory):
        if not filename.endswith('_updated.json'):
            continue
        filepath = os.path.join(directory, filename)
        uuid, mean_pooling, hybrid = extract_embeddings(filepath)
        if uuid is not None:
            embeddings_dict[uuid] = {
                'MeanPooling': mean_pooling,
                'Hybrid': hybrid
            }
    return embeddings_dict

def load_response_list(filepath):
    """Read the JSON response list and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

def combine_data(embeddings_dict, response_list):
    """Join responses with their embeddings by uuid.

    Responses whose uuid is not a key of ``embeddings_dict`` are dropped.
    When a uuid occurs more than once, the last response wins.

    Returns:
        dict mapping uuid to ``'response_content'``, ``'MeanPooling'`` and
        ``'Hybrid'``.
    """
    keyed_responses = ((response.get('uuid'), response) for response in response_list)
    return {
        uuid: {
            'response_content': response.get('response_content'),
            'MeanPooling': embeddings_dict[uuid]['MeanPooling'],
            'Hybrid': embeddings_dict[uuid]['Hybrid'],
        }
        for uuid, response in keyed_responses
        if uuid in embeddings_dict
    }

# Path to the directory containing the updated JSON files
directory_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/test'

# Path to the uuid_response_list.json file
response_list_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/uuid_response_list.json'

# Process all files in the directory to extract embeddings
# Result maps uuid -> {'MeanPooling': ..., 'Hybrid': ...}.
embeddings_dict = process_directory(directory_path)

# Load the response list
response_list = load_response_list(response_list_path)

# Combine the embeddings with the response content
# Only responses whose uuid has embeddings are kept.
combined_data = combine_data(embeddings_dict, response_list)

# Display the resulting combined dictionary
print(combined_data)


In [None]:
import json
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, AdamW
import numpy as np
from sklearn.model_selection import train_test_split

# Configuration variables
SEED = 42              # seed for NumPy, torch and the train/test split
NUM_TEXTS = 160        # not referenced elsewhere in this cell
MAX_LENGTH = 50        # tokens per response after padding/truncation
# Model input size.
# NOTE(review): must equal the length of the stored embedding vectors - confirm against the data.
EMBEDDING_DIM = 50
HIDDEN_DIM = 128       # width of the model's projection layer
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-4
EMBEDDING_TYPE = 'Hybrid'  # Options: 'MeanPooling', 'Hybrid'

# Seed at import time as well; main() seeds again through set_seed.
np.random.seed(SEED)
torch.manual_seed(SEED)

def set_seed(seed):
    """Seed the global NumPy generator and the torch generator with ``seed``."""
    for seeder in (np.random.seed, torch.manual_seed):
        seeder(seed)

# Step 1: Load Embeddings and Response Content
def extract_embeddings(filepath):
    """Read the uuid and the combined embeddings from one JSON file.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Tuple ``(uuid, mean_pooling, hybrid)``. ``uuid`` is ``None`` when the
        file has no ``'uuid'`` key; each embedding defaults to ``[]`` when
        the ``'combined'`` entry or its key is missing.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        data = json.load(file)

    combined = data.get('combined', {})
    mean_pooling = combined.get('MeanPooling', [])
    hybrid = combined.get('Hybrid', [])
    return data.get('uuid'), mean_pooling, hybrid

def process_directory(directory):
    """Collect the combined embeddings of every ``*_updated.json`` file.

    Returns:
        dict mapping each file's uuid to
        ``{'MeanPooling': ..., 'Hybrid': ...}``. Files without a uuid are
        skipped.
    """
    embeddings_dict = {}
    updated_names = (name for name in os.listdir(directory) if name.endswith('_updated.json'))
    for name in updated_names:
        uuid, mean_pooling, hybrid = extract_embeddings(os.path.join(directory, name))
        if uuid is None:
            continue
        embeddings_dict[uuid] = {'MeanPooling': mean_pooling, 'Hybrid': hybrid}
    return embeddings_dict

def load_response_list(filepath):
    """Read the JSON response list and return its parsed content.

    Args:
        filepath: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file.
    """
    # Pin the encoding: JSON is UTF-8, while open() otherwise falls back to
    # the platform's locale encoding.
    with open(filepath, 'r', encoding='utf-8') as file:
        return json.load(file)

def combine_data(embeddings_dict, response_list):
    """Join responses with their embeddings by uuid.

    Responses whose uuid is not a key of ``embeddings_dict`` are dropped.
    When a uuid occurs more than once, the last response wins.

    Returns:
        dict mapping uuid to ``'response_content'``, ``'MeanPooling'`` and
        ``'Hybrid'``.
    """
    combined_dict = {}
    for response in response_list:
        uuid = response.get('uuid')
        if uuid not in embeddings_dict:
            continue
        stored = embeddings_dict[uuid]
        combined_dict[uuid] = {
            'response_content': response.get('response_content'),
            'MeanPooling': stored['MeanPooling'],
            'Hybrid': stored['Hybrid'],
        }
    return combined_dict

def load_combined_data(embeddings_path, response_list_path):
    """Load the embeddings directory and the response list, joined by uuid.

    Returns:
        The dict produced by ``combine_data``.
    """
    return combine_data(
        process_directory(embeddings_path),
        load_response_list(response_list_path),
    )

def print_intermediate_info(combined_data):
    """Print the number of data points and a preview of the first one.

    The preview shows the uuid, the response content and the first five
    values of both embeddings. Nothing beyond the count is printed for
    empty input.
    """
    print(f"Number of data points found: {len(combined_data)}")
    for example_uuid, example_data in combined_data.items():
        print(f"Example UUID: {example_uuid}")
        print(f"Example Response Content: {example_data['response_content']}")
        print(f"Example Embedding (MeanPooling): {example_data['MeanPooling'][:5]}...")  # print first 5 elements
        print(f"Example Embedding (Hybrid): {example_data['Hybrid'][:5]}...")  # print first 5 elements
        break  # only the first entry serves as the example

# Step 2: Initialize Tokenizer
def initialize_tokenizer():
    """Load the pretrained 'gpt2' tokenizer and give it a padding token.

    Returns:
        The tokenizer, with its end-of-sequence token reused for padding.
    """
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    # Reuse the end-of-sequence token as the padding token.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    return gpt2_tokenizer

def process_texts(texts, tokenizer, max_length):
    """Tokenize every text, padded/truncated to ``max_length`` token ids.

    Returns:
        List with one ``input_ids`` tensor per text, with the leading batch
        dimension squeezed away.
    """
    tokenized_texts = []
    for text in texts:
        encoding = tokenizer(
            text,
            return_tensors='pt',
            truncation=True,
            padding='max_length',
            max_length=max_length,
        )
        tokenized_texts.append(encoding['input_ids'].squeeze(0))
    return tokenized_texts

# Step 3: Create Dataset
class EmbeddingTextDataset(Dataset):
    """Serves (embedding tensor, tokenized text) pairs from the combined data.

    Item ``idx`` pairs the embedding of the ``idx``-th key of
    ``combined_data`` with ``tokenized_texts[idx]``.
    """

    def __init__(self, combined_data, tokenized_texts, embedding_type):
        self.combined_data = combined_data
        self.tokenized_texts = tokenized_texts
        self.embedding_type = embedding_type
        # Freeze the key order so integer indices map to stable uuids.
        self.keys = [key for key in combined_data.keys()]

    def __len__(self):
        return len(self.tokenized_texts)

    def __getitem__(self, idx):
        record = self.combined_data[self.keys[idx]]
        embedding = torch.tensor(record[self.embedding_type], dtype=torch.float32)
        return embedding, self.tokenized_texts[idx]

def create_dataloader(combined_data, tokenizer, embedding_type, batch_size):
    """Build a shuffled DataLoader over (embedding, tokenized response) pairs.

    Responses are tokenized to the module-level ``MAX_LENGTH``.
    """
    response_texts = []
    for record in combined_data.values():
        response_texts.append(record['response_content'])

    dataset = EmbeddingTextDataset(
        combined_data,
        process_texts(response_texts, tokenizer, MAX_LENGTH),
        embedding_type,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Step 4: Define the Model
class SimpleTextGenerator(nn.Module):
    """Two linear layers mapping an embedding to per-position vocabulary logits."""

    def __init__(self, embedding_dim, vocab_size, hidden_dim, max_length):
        super().__init__()
        self.embedding_projection = nn.Linear(embedding_dim, hidden_dim)
        # One logit per (position, vocabulary entry) pair.
        self.fc = nn.Linear(hidden_dim, vocab_size * max_length)
        self.max_length = max_length
        self.vocab_size = vocab_size

    def forward(self, embedding):
        """Return logits reshaped to (-1, max_length, vocab_size)."""
        hidden = self.embedding_projection(embedding)
        flat_logits = self.fc(hidden)
        return flat_logits.view(-1, self.max_length, self.vocab_size)

def initialize_model(embedding_dim, vocab_size, hidden_dim, max_length):
    """Construct a ``SimpleTextGenerator`` with the given sizes."""
    return SimpleTextGenerator(
        embedding_dim=embedding_dim,
        vocab_size=vocab_size,
        hidden_dim=hidden_dim,
        max_length=max_length,
    )

def move_model_to_device(model):
    """Move ``model`` to CUDA when available, otherwise to the CPU.

    Returns:
        Tuple ``(model, device)`` with the same model object and the chosen
        ``torch.device``.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    model.to(device)
    return model, device

# Step 5: Train the Model
def train(model, dataloader, optimizer, criterion, device, epochs, vocab_size):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Every batch is an (embeddings, token ids) pair. Logits are flattened to
    (-1, vocab_size) and compared with the flattened token ids. The mean
    batch loss is printed after each epoch.
    """
    model.train()
    for epoch_number in range(1, epochs + 1):
        running_loss = 0
        for batch_embeddings, batch_tokens in dataloader:
            batch_embeddings = batch_embeddings.to(device)
            batch_tokens = batch_tokens.to(device)

            optimizer.zero_grad()
            flat_logits = model(batch_embeddings).view(-1, vocab_size)
            batch_loss = criterion(flat_logits, batch_tokens.view(-1))
            batch_loss.backward()
            optimizer.step()
            running_loss += batch_loss.item()

        print(f"Epoch {epoch_number}/{epochs}, Loss: {running_loss / len(dataloader)}")

# Step 6: Generate and Print Text in One Go
def generate_text_from_embedding(model, embedding, tokenizer, max_length, device):
    """Decode the most likely token at every position for one embedding.

    ``max_length`` is accepted but not used by this function.

    Returns:
        The text decoded from the first sequence of predicted token ids.
    """
    model.eval()
    with torch.no_grad():
        logits = model(embedding.to(device))
        first_sequence_ids = torch.argmax(logits, dim=-1)[0]
        return tokenizer.decode(first_sequence_ids, skip_special_tokens=True)

# Step 7: Tokenizer Test
def tokenizer_test(tokenizer, example_text, max_length):
    """Print a text next to its tokenize-then-decode round trip."""
    print(f"Original text: {example_text}")
    encoding = tokenizer(
        example_text,
        return_tensors='pt',
        truncation=True,
        padding='max_length',
        max_length=max_length,
    )
    token_ids = encoding['input_ids'].squeeze(0)
    round_trip = tokenizer.decode(token_ids, skip_special_tokens=True)
    print(f"Tokenized and decoded text: {round_trip}")

def split_data(combined_data, test_size=0.2):
    """Split the combined data by uuid into train and test dicts.

    The split uses ``train_test_split`` with ``random_state=SEED``.

    Returns:
        Tuple ``(train_data, test_data)`` of dicts with the same value
        objects as ``combined_data``.
    """
    train_keys, test_keys = train_test_split(
        list(combined_data.keys()), test_size=test_size, random_state=SEED
    )

    def _subset(selected_keys):
        return {key: combined_data[key] for key in selected_keys}

    return _subset(train_keys), _subset(test_keys)

def main():
    """Train the generator on stored embeddings and print test generations.

    Reads the module-level ``embeddings_path`` and ``response_list_path``
    (set in the ``__main__`` guard), splits the joined data into train and
    test sets, trains a ``SimpleTextGenerator`` on the training set and
    prints, for every test item, the original response next to the text
    generated from its embedding.
    """
    set_seed(SEED)

    # Load combined data
    combined_data = load_combined_data(embeddings_path, response_list_path)
    print_intermediate_info(combined_data)

    # Split data into training and test sets
    train_data, test_data = split_data(combined_data)
    print(f"Training data points: {len(train_data)}")
    print(f"Test data points: {len(test_data)}")

    # Initialize tokenizer
    tokenizer = initialize_tokenizer()
    vocab_size = tokenizer.vocab_size

    # Tokenizer test with an example text
    example_text = next(iter(combined_data.values()))['response_content']
    tokenizer_test(tokenizer, example_text, MAX_LENGTH)

    # Create dataloader for training data
    train_dataloader = create_dataloader(train_data, tokenizer, EMBEDDING_TYPE, BATCH_SIZE)

    # Initialize and move model to device
    model = initialize_model(EMBEDDING_DIM, vocab_size, HIDDEN_DIM, MAX_LENGTH)
    model, device = move_model_to_device(model)

    # Define optimizer and loss function
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    # Train the model
    train(model, train_dataloader, optimizer, criterion, device, NUM_EPOCHS, vocab_size)

    # Evaluate the model with the test set. Iterating the records directly
    # keeps each embedding next to its text and avoids rebuilding the key
    # list on every iteration.
    for sample in test_data.values():
        test_embedding = torch.tensor(sample[EMBEDDING_TYPE], dtype=torch.float32).to(device)
        generated_text = generate_text_from_embedding(model, test_embedding, tokenizer, MAX_LENGTH, device)
        print(f"Original text: {sample['response_content']}")
        print(f"Generated text: {generated_text}\n")

if __name__ == "__main__":
    # main() reads these two names as module-level globals.
    embeddings_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/test'
    response_list_path = '/workspace/slice-monorepo/thebeast/chat_pipeline/data/test/step_2/uuid_response_list.json'
    main()
