In [1]:
# !pip install tensorflow keras numpy pillow pandas scikit-learn matplotlib joblib

In [None]:
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import matplotlib.pyplot as plt
import os
import joblib

def train_lr_virality_predictor(embeddings_path, virality_data_path, test_size=0.2, random_state=42):
    """
    Train a linear regression model to predict post virality using image embeddings
    
    Args:
        embeddings_path: Path to the numpy file containing ResNet50 embeddings
        virality_data_path: Path to CSV file with image filenames and virality scores
        test_size: Proportion of data to use for testing
        random_state: Random seed for reproducibility
        
    Returns:
        model: Trained linear regression model
        X_test: Test features
        y_test: Test labels
        metrics: Dictionary of evaluation metrics
    """
    # Load embeddings
    print("Loading embeddings...")
    embeddings = np.load(embeddings_path)
    
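    # Load image paths from "image_paths.txt", resolved against the current
    # working directory (not the directory of embeddings_path). Line i is
    # assumed to describe row i of the embeddings array.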
    image_paths_file = "image_paths.txt"
    with open(image_paths_file, "r") as f:
        image_paths = [line.strip() for line in f.readlines()]
    
    # Extract filenames from paths
    image_filenames = [os.path.basename(path) for path in image_paths]
    
    # Load virality data
    print("Loading virality data...")
    virality_df = pd.read_csv(virality_data_path)
    
    # Ensure the virality dataframe has 'filename' and 'virality_score' columns
    if 'filename' not in virality_df.columns or 'virality_score' not in virality_df.columns:
        raise ValueError("Virality data must contain 'filename' and 'virality_score' columns")
    
    # Create a dataframe with embeddings and filenames
    embeddings_df = pd.DataFrame({
        'filename': image_filenames,
        'embedding_idx': range(len(image_filenames))
    })
    
    # Merge with virality data
    merged_df = pd.merge(embeddings_df, virality_df, on='filename', how='inner')
    
    if len(merged_df) == 0:
        raise ValueError("No matches found between embeddings and virality data. Check filenames.")
    
    print(f"Matched {len(merged_df)} images with virality scores")
    
    # Get the indices of embeddings that match with virality data
    valid_indices = merged_df['embedding_idx'].values
    
    # Extract features (embeddings) and target (virality score)
    X = embeddings[valid_indices]
    y = merged_df['virality_score'].values
    
    # Split into training and testing sets
    print("Creating train-test split...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )
    
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Testing set: {X_test.shape[0]} samples")
    
    # Train linear regression model
    print("Training linear regression model...")
    model = LinearRegression()
    model.fit(X_train, y_train)
    
    # Make predictions
    y_pred_train = model.predict(X_train)
    y_pred_test = model.predict(X_test)
    
    # Evaluate model
    metrics = {
        'train_mse': mean_squared_error(y_train, y_pred_train),
        'test_mse': mean_squared_error(y_test, y_pred_test),
        'train_mae': mean_absolute_error(y_train, y_pred_train),
        'test_mae': mean_absolute_error(y_test, y_pred_test),
        'train_r2': r2_score(y_train, y_pred_train),
        'test_r2': r2_score(y_test, y_pred_test)
    }
    
    print("\nModel Performance:")
    print(f"Train MSE: {metrics['train_mse']:.4f}")
    print(f"Test MSE: {metrics['test_mse']:.4f}")
    print(f"Train MAE: {metrics['train_mae']:.4f}")
    print(f"Test MAE: {metrics['test_mae']:.4f}")
    print(f"Train R²: {metrics['train_r2']:.4f}")
    print(f"Test R²: {metrics['test_r2']:.4f}")
    
    # Visualize predictions vs actual
    plt.figure(figsize=(10, 6))
    plt.scatter(y_test, y_pred_test, alpha=0.5)
    plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], 'r--')
    plt.xlabel('Actual Virality Score')
    plt.ylabel('Predicted Virality Score')
    plt.title('Predicted vs
