# Semi-Join with exists() Method in JAXFrame

This notebook demonstrates the `exists()` method for performing semi-joins in JAXFrame DataFrames. Semi-joins filter rows from one table based on the existence of matching rows in another table, but only return columns from the first table.
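
To make the definition concrete, here is a minimal pure-Python sketch of semi-join semantics over dictionaries of columns. It does not use JAXFrame: the helper `semi_join` and the tiny `left`/`right` tables are hypothetical names chosen for illustration, not part of the library.

```python
def semi_join(left, right, on):
    """Keep rows of `left` whose key appears in `right`; return left columns only."""
    right_keys = set(right[on])  # hash set: constant-time membership test per row
    keep = [i for i, key in enumerate(left[on]) if key in right_keys]
    return {col: [values[i] for i in keep] for col, values in left.items()}

# Illustrative data (not the notebook's DataFrames)
left = {"customer_id": ["C001", "C002", "C003"], "name": ["Alice", "Bob", "Charlie"]}
right = {"customer_id": ["C001", "C001", "C003"], "amount": [100.0, 150.0, 200.0]}

result = semi_join(left, right, on="customer_id")
print(result)
```

Note that `amount` never appears in the output, and `C001` appears once even though it matches two rows on the right.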

## 1. Import Required Libraries

Import JAXFrame, JAX, NumPy, and other necessary libraries for testing the semi-join functionality.

In [None]:
from jaxframe import DataFrame
import jax.numpy as jnp
import numpy as np
import pandas as pd
import time

print("Libraries imported successfully!")
print(f"JAX arrays available: {jnp.array([1, 2, 3])}")

## 2. Create Test DataFrames

Create sample DataFrames with mixed data types (lists, NumPy arrays, JAX arrays) to test the semi-join implementation across different scenarios.

In [None]:
# Create customer DataFrame with mixed types
customers_data = {
    'customer_id': ['C001', 'C002', 'C003', 'C004', 'C005', 'C006'],
    'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve', 'Frank'],
    'city': ['New York', 'Los Angeles', 'Chicago', 'Houston', 'Phoenix', 'Philadelphia'],
    'age': np.array([25, 30, 35, 28, 32, 45]),
    'credit_score': jnp.array([750.0, 680.0, 720.0, 650.0, 800.0, 690.0])
}
customers = DataFrame(customers_data, name="customers")

# Create orders DataFrame - only some customers have orders
orders_data = {
    'order_id': ['O001', 'O002', 'O003', 'O004', 'O005', 'O006'],
    'customer_id': ['C001', 'C001', 'C003', 'C003', 'C005', 'C005'],  # Only C001, C003, C005 have orders
    'amount': jnp.array([100.0, 150.0, 200.0, 75.0, 300.0, 250.0]),
    'order_date': ['2024-01-01', '2024-01-15', '2024-02-01', '2024-02-10', '2024-03-01', '2024-03-15']
}
orders = DataFrame(orders_data, name="orders")

print("Customer DataFrame:")
display(customers.to_pandas())
print(f"\nCustomer column types: {customers.column_types}")

print("\nOrders DataFrame:")
display(orders.to_pandas())
print(f"\nOrders column types: {orders.column_types}")

## 3. Implement Basic Semi-Join with exists() Method

Call the `exists()` method on a DataFrame to perform basic semi-joins: keep the rows of the left DataFrame that have at least one match in the right DataFrame, without adding any columns from the right.

In [None]:
# Basic semi-join: Find customers who have placed orders
customers_with_orders = customers.exists(orders, on='customer_id')

print("Customers who have placed orders (semi-join):")
display(customers_with_orders.to_pandas())
print(f"\nResult shape: {customers_with_orders.shape}")
print(f"Original customers: {len(customers)}, customers with orders: {len(customers_with_orders)}")
print(f"Column structure preserved: {customers_with_orders.columns == customers.columns}")

# Verify no columns from orders DataFrame were added
assert 'order_id' not in customers_with_orders.columns, "Semi-join should not add columns from right table"
assert 'amount' not in customers_with_orders.columns, "Semi-join should not add columns from right table"
assert 'order_date' not in customers_with_orders.columns, "Semi-join should not add columns from right table"

print("\n✓ Basic semi-join working correctly!")
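
One way to cross-check this result independently of JAXFrame is the pandas `isin` idiom. The sketch below rebuilds two of the columns as plain pandas DataFrames (`customers_pd` and `orders_pd` are illustrative names). Assuming `exists()` follows standard semi-join semantics, it should select the same three customers.

```python
import pandas as pd

customers_pd = pd.DataFrame({
    "customer_id": ["C001", "C002", "C003", "C004", "C005", "C006"],
    "name": ["Alice", "Bob", "Charlie", "David", "Eve", "Frank"],
})
orders_pd = pd.DataFrame({
    "order_id": ["O001", "O002", "O003", "O004", "O005", "O006"],
    "customer_id": ["C001", "C001", "C003", "C003", "C005", "C005"],
})

# Boolean mask: True where the customer has at least one order
has_order = customers_pd["customer_id"].isin(orders_pd["customer_id"])
expected = customers_pd[has_order].reset_index(drop=True)
print(expected)
```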

In [None]:
# Demonstrate duplicate elimination in semi-joins
print("=== Demonstrating Duplicate Elimination ===")

# Create a DataFrame with duplicates
customers_with_duplicates = DataFrame({
    'customer_id': ['C001', 'C001', 'C003', 'C003', 'C007'],  # Duplicates
    'region': ['North', 'North', 'South', 'South', 'West'],
    'value': [100, 100, 200, 200, 300]
}, name="customers_dup")

print("DataFrame with duplicates:")
display(customers_with_duplicates.to_pandas())

# Semi-join automatically eliminates duplicates
unique_customers_with_orders = customers_with_duplicates.exists(orders, on='customer_id')

print("\nAfter semi-join (duplicates eliminated):")
display(unique_customers_with_orders.to_pandas())
print(f"Original rows: {len(customers_with_duplicates)}, after semi-join: {len(unique_customers_with_orders)}")

print("\n✓ Duplicate elimination working correctly!")
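
Two different kinds of duplicates are worth separating here. A semi-join never multiplies a left row when several right rows match, unlike an inner join. Whether rows that are already repeated in the left table are also collapsed is a design choice: SQL's `WHERE EXISTS` keeps them, while the cell above reports that `exists()` removes them. The sketch below illustrates both behaviours in plain Python; the helper names and the `dedupe` flag are hypothetical and are not JAXFrame parameters.

```python
def inner_join_keys(left_keys, right_keys):
    """Inner join on keys: one output row per matching (left, right) pair."""
    return [lk for lk in left_keys for rk in right_keys if lk == rk]

def semi_join_keys(left_keys, right_keys, dedupe=False):
    """Semi-join on keys: each left row is kept at most once."""
    right_set = set(right_keys)
    kept = [k for k in left_keys if k in right_set]
    if dedupe:
        kept = list(dict.fromkeys(kept))  # drop repeated left rows, keep first-seen order
    return kept

left_keys = ["C001", "C001", "C003", "C003", "C007"]
right_keys = ["C001", "C001", "C003", "C003", "C005", "C005"]

inner = inner_join_keys(left_keys, right_keys)
semi = semi_join_keys(left_keys, right_keys)
semi_unique = semi_join_keys(left_keys, right_keys, dedupe=True)

print(len(inner))    # 8: every left C001/C003 row pairs with two right rows
print(semi)          # ['C001', 'C001', 'C003', 'C003']
print(semi_unique)   # ['C001', 'C003']
```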

In [None]:
# Multi-column semi-join example
print("=== Multi-Column Semi-Join ===")

# Create DataFrames for multi-column join
users_data = {
    'user_id': ['U001', 'U002', 'U003', 'U004', 'U005'],
    'region': ['US', 'US', 'EU', 'EU', 'ASIA'],
    'active': [True, False, True, True, False],
    'subscription': ['premium', 'basic', 'premium', 'free', 'premium']
}
users = DataFrame(users_data, name="users")

sessions_data = {
    'session_id': ['S001', 'S002', 'S003', 'S004'],
    'user_id': ['U001', 'U003', 'U001', 'U004'],
    'region': ['US', 'EU', 'US', 'EU'],
    'duration': jnp.array([30.0, 45.0, 60.0, 25.0])
}
sessions = DataFrame(sessions_data, name="sessions")

print("Users DataFrame:")
display(users.to_pandas())
print("\nSessions DataFrame:")
display(sessions.to_pandas())

# Find users who have sessions in their region
active_users = users.exists(sessions, on=['user_id', 'region'])

print("\nUsers with sessions in their region:")
display(active_users.to_pandas())

print("\n✓ Multi-column semi-join working correctly!")
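
A multi-column semi-join compares whole key tuples, not each column independently. The following pure-Python sketch shows the idea; `semi_join_multi` is a hypothetical helper, and the data is a modified copy of the tables above with one extra session, `('U002', 'EU')`, added to show a near miss.

```python
def semi_join_multi(left, right, on):
    """Semi-join on several columns by comparing tuples of key values."""
    right_keys = set(zip(*(right[col] for col in on)))
    left_keys = zip(*(left[col] for col in on))
    keep = [i for i, key in enumerate(left_keys) if key in right_keys]
    return {col: [values[i] for i in keep] for col, values in left.items()}

users_cols = {
    "user_id": ["U001", "U002", "U003", "U004", "U005"],
    "region": ["US", "US", "EU", "EU", "ASIA"],
}
sessions_cols = {
    "user_id": ["U001", "U003", "U001", "U004", "U002"],
    "region": ["US", "EU", "US", "EU", "EU"],  # U002's session is in EU, not US
}

matched = semi_join_multi(users_cols, sessions_cols, on=["user_id", "region"])
print(matched["user_id"])
```

`U002` is excluded: both `'U002'` and `'US'` occur somewhere in the sessions table, but never in the same row.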

## 4. Test Semi-Join with JAX Arrays

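Test the `exists()` method on DataFrames backed by JAX arrays: check that JAX column types are preserved and that the result can feed further JAX computations. The examples run eagerly; none of them wraps `exists()` in `jax.jit`.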

In [None]:
# Test with pure JAX arrays
print("=== Testing with JAX Arrays ===")

jax_left = DataFrame({
    'id': ['A', 'B', 'C', 'D', 'E'],
    'values': jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]),
    'scores': jnp.array([10.5, 20.3, 30.1, 40.7, 50.9])
}, name="jax_left")

jax_right = DataFrame({
    'id': ['B', 'D', 'F', 'G'],
    'other_data': jnp.array([100.0, 200.0, 300.0, 400.0])
}, name="jax_right")

print("Left JAX DataFrame:")
display(jax_left.to_pandas())
print("\nRight JAX DataFrame:")
display(jax_right.to_pandas())

# Perform semi-join
jax_result = jax_left.exists(jax_right, on='id')

print("\nJAX Semi-Join Result:")
display(jax_result.to_pandas())
print(f"\nResult column types: {jax_result.column_types}")

# Verify JAX array types are preserved
assert jax_result.column_types['values'] == 'jax_array', "JAX array type should be preserved"
assert jax_result.column_types['scores'] == 'jax_array', "JAX array type should be preserved"

print("\n✓ JAX array semi-join working correctly!")
print("✓ JAX array types preserved!")

In [None]:
# Test that semi-join results can feed further JAX computations (eager mode, no jit)
print("=== Testing JAX Interoperability ===")

def process_with_semi_join(left_values, right_values):
    """Run a semi-join and return the surviving JAX values for further computation."""
    # Create temporary DataFrames
    left_df = DataFrame({
        'id': ['A', 'B', 'C'],
        'values': left_values
    })
    
    right_df = DataFrame({
        'id': ['B', 'C', 'D'],
        'other': right_values
    })
    
    # Perform semi-join
    result = left_df.exists(right_df, on='id')
    
    # Return JAX array for further computation
    return result['values']

# Test the function
left_vals = jnp.array([1.0, 2.0, 3.0])
right_vals = jnp.array([10.0, 20.0, 30.0])

result_values = process_with_semi_join(left_vals, right_vals)
print(f"Semi-join result values: {result_values}")
print(f"Result shape: {result_values.shape}")
print(f"Result type: {type(result_values)}")

# Test that we can use this in further JAX operations
final_result = jnp.sum(result_values ** 2)
print(f"\nFurther JAX computation (sum of squares): {final_result}")

print("\n✓ Semi-join output usable in further JAX operations!")
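
The cell above runs eagerly and never calls `jax.jit`. A semi-join returns a data-dependent number of rows, which does not fit `jit`'s requirement for static output shapes, and JAX arrays cannot hold string keys. One common workaround, sketched below, is to encode keys as integers and keep a fixed-shape boolean mask instead of dropping rows. The function `masked_semi_join_sum` and the integer encoding are assumptions made for illustration; this is not how JAXFrame is known to implement `exists()`.

```python
import jax
import jax.numpy as jnp

@jax.jit
def masked_semi_join_sum(left_keys, left_values, right_keys):
    """Sum of squares over left rows whose key occurs in right_keys (shapes stay static)."""
    mask = jnp.isin(left_keys, right_keys)        # same shape as left_keys
    selected = jnp.where(mask, left_values, 0.0)  # zero out non-matching rows
    return jnp.sum(selected ** 2), mask

# Keys 'A', 'B', 'C' and 'B', 'C', 'D' encoded as integers (hypothetical encoding)
left_keys = jnp.array([0, 1, 2])
right_keys = jnp.array([1, 2, 3])
left_values = jnp.array([1.0, 2.0, 3.0])

total, mask = masked_semi_join_sum(left_keys, left_values, right_keys)
print(total, mask)
```

The mask marks the rows a semi-join would keep, so downstream reductions give the same answer as filtering first.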

## 5. Compare Performance: exists() vs Traditional Joins

Compare the running time of a semi-join using `exists()` with a traditional join followed by column removal. Memory usage is not measured here.

In [None]:
# Create larger DataFrames for performance testing
print("=== Performance Comparison: exists() vs join() ===")

# Create larger test data (the right table's size is derived from right_ids below)
n_left = 10000

large_left = DataFrame({
    'id': [f'ID_{i:05d}' for i in range(n_left)],
    'value1': np.random.randn(n_left),
    'value2': jnp.array(np.random.randn(n_left)),
    'category': [f'Cat_{i % 100}' for i in range(n_left)]
}, name="large_left")

# Create right table with only partial overlap
right_ids = [f'ID_{i:05d}' for i in range(0, n_left, 3)]  # Every 3rd ID
large_right = DataFrame({
    'id': right_ids,
    'extra_data': np.random.randn(len(right_ids)),
    'flag': [True] * len(right_ids)
}, name="large_right")

print(f"Left DataFrame: {large_left.shape}")
print(f"Right DataFrame: {large_right.shape}")
print(f"Expected result size: ~{len(right_ids)} rows")

In [None]:
# Method 1: Using exists() (semi-join)
print("\n=== Method 1: exists() Semi-Join ===")
start_time = time.perf_counter()  # monotonic, high-resolution timer

result_exists = large_left.exists(large_right, on='id')

exists_time = time.perf_counter() - start_time
print(f"exists() time: {exists_time:.4f} seconds")
print(f"Result shape: {result_exists.shape}")
print(f"Result columns: {result_exists.columns}")

In [None]:
# Method 2: Using traditional join() then column removal
print("\n=== Method 2: Traditional join() + Column Selection ===")
start_time = time.perf_counter()  # monotonic, high-resolution timer

# Traditional approach: join then remove unwanted columns
result_join = large_left.join(large_right, on='id', source='flag')
# Remove the added column to match semi-join result
result_join_cleaned = result_join.remove_column('large_right/flag')

join_time = time.perf_counter() - start_time
print(f"join() + cleanup time: {join_time:.4f} seconds")
print(f"Result shape: {result_join_cleaned.shape}")
print(f"Result columns: {result_join_cleaned.columns}")

# Compare timings and check that both approaches agree
print("\n=== Results Comparison ===")
print(f"join() + cleanup took {join_time/exists_time:.2f}x the time of exists()")
print(f"Result shapes match: {result_exists.shape == result_join_cleaned.shape}")

# Check if the actual data is the same (should be, since both filter the same rows)
exists_ids = set(result_exists['id'])
join_ids = set(result_join_cleaned['id'])
print(f"Same filtered IDs: {exists_ids == join_ids}")

print("\n✓ Performance comparison completed!")
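
A single wall-clock measurement is noisy and can include one-off costs such as first-call setup. A more reliable pattern is to repeat each measurement and keep the best run, for example with `timeit.repeat`. The sketch below applies that pattern to two hypothetical pure-Python semi-join strategies (`semi_join_hash` and `semi_join_scan`), not to JAXFrame itself. When timing code that returns JAX arrays, it may also be necessary to call `block_until_ready()` on the result so that asynchronous dispatch does not hide the real cost.

```python
import timeit

left_ids = [f"ID_{i:05d}" for i in range(1000)]
right_ids = [f"ID_{i:05d}" for i in range(0, 1000, 3)]  # every 3rd ID

def semi_join_hash(left, right):
    """Build a hash set once, then test each left key in constant time."""
    right_set = set(right)
    return [k for k in left if k in right_set]

def semi_join_scan(left, right):
    """Naive variant: linear scan of the right list for every left key."""
    return [k for k in left if k in right]

# Both strategies must agree before their timings are worth comparing
assert semi_join_hash(left_ids, right_ids) == semi_join_scan(left_ids, right_ids)

# Best of 3 repeats, 5 calls each, to reduce measurement noise
hash_time = min(timeit.repeat(lambda: semi_join_hash(left_ids, right_ids), number=5, repeat=3))
scan_time = min(timeit.repeat(lambda: semi_join_scan(left_ids, right_ids), number=5, repeat=3))
print(f"hash: {hash_time:.4f}s, scan: {scan_time:.4f}s")
```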

## 6. Test Edge Cases and Error Handling

Test edge cases, including results with no matches, empty input DataFrames, and missing join keys, to check that errors are handled robustly.

In [None]:
print("=== Testing Edge Cases ===")

# Test 1: Empty result (no matches)
print("\n1. Testing with no matches:")
no_match_left = DataFrame({
    'id': ['X', 'Y', 'Z'],
    'value': [1, 2, 3]
})

no_match_right = DataFrame({
    'id': ['A', 'B', 'C'],
    'other': [10, 20, 30]
})

empty_result = no_match_left.exists(no_match_right, on='id')
print(f"Empty result shape: {empty_result.shape}")
print(f"Empty result columns preserved: {empty_result.columns == no_match_left.columns}")
assert len(empty_result) == 0, "Should return empty DataFrame"
print("✓ Empty result case works correctly")

# Test 2: Empty input DataFrames
print("\n2. Testing with empty left DataFrame:")
empty_left = DataFrame({'id': [], 'value': []}, name="empty_left")
non_empty_right = DataFrame({'id': ['A'], 'other': [1]}, name="non_empty_right")

result_empty_left = empty_left.exists(non_empty_right, on='id')
print(f"Result with empty left: {result_empty_left.shape}")
assert len(result_empty_left) == 0, "Empty left should return empty result"
print("✓ Empty left DataFrame case works correctly")

print("\n3. Testing with empty right DataFrame:")
non_empty_left = DataFrame({'id': ['A'], 'value': [1]}, name="non_empty_left")
empty_right = DataFrame({'id': [], 'other': []}, name="empty_right")

result_empty_right = non_empty_left.exists(empty_right, on='id')
print(f"Result with empty right: {result_empty_right.shape}")
assert len(result_empty_right) == 0, "Empty right should return empty result"
print("✓ Empty right DataFrame case works correctly")

In [None]:
# Test error conditions
print("\n=== Testing Error Conditions ===")

test_left = DataFrame({
    'id': ['A', 'B'],
    'value': [1, 2]
})

test_right = DataFrame({
    'other_id': ['A', 'B'],
    'other_value': [10, 20]
})

# Test 1: Join key missing from both DataFrames
print("\n1. Testing join key missing from both DataFrames:")
try:
    test_left.exists(test_right, on='missing_column')
    print("❌ Should have raised ValueError")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

# Test 2: Missing column in right DataFrame
print("\n2. Testing missing column in right DataFrame:")
try:
    test_left.exists(test_right, on='id')  # 'id' exists in left but not right
    print("❌ Should have raised ValueError")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

# Test 3: Multi-column with missing column
print("\n3. Testing multi-column with missing column:")
try:
    test_left.exists(test_right, on=['id', 'missing'])
    print("❌ Should have raised ValueError")
except ValueError as e:
    print(f"✓ Correctly raised ValueError: {e}")

print("\n✓ All error conditions handled correctly!")
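
The checks above rely on `exists()` raising `ValueError` when a join key is absent from either side. As a rough sketch of what such validation might look like (the function `validate_join_keys` and its error message are hypothetical, not JAXFrame's actual code):

```python
def validate_join_keys(left_columns, right_columns, on):
    """Return the join keys as a list, or raise ValueError naming the missing ones."""
    keys = [on] if isinstance(on, str) else list(on)
    for side, columns in (("left", left_columns), ("right", right_columns)):
        missing = [k for k in keys if k not in columns]
        if missing:
            raise ValueError(f"Join key(s) {missing} not found in {side} DataFrame")
    return keys

# Valid single key: normalised to a list
ok_keys = validate_join_keys(["id", "value"], ["id", "other"], on="id")
print(ok_keys)

# Multi-column key where 'missing' is absent from the left table
try:
    validate_join_keys(["id", "value"], ["other_id", "other_value"], on=["id", "missing"])
    message = None
except ValueError as err:
    message = str(err)
    print(message)
```

Validating up front, before any rows are compared, keeps the error message specific about which key is missing and on which side.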

## 7. Demonstrate Real-World Use Cases

Show practical applications of semi-joins: finding trial patients who completed a study, identifying laboratory samples that have both quality-control and assay records, and segmenting e-commerce customers by channel activity.

In [None]:
print("=== Real-World Use Case 1: Clinical Trial Data ===")

# Create clinical trial datasets
patients = DataFrame({
    'patient_id': [f'P{i:03d}' for i in range(1, 21)],
    'age': np.random.randint(18, 80, 20),
    'gender': np.random.choice(['M', 'F'], 20),
    'baseline_score': jnp.array(np.random.normal(50, 10, 20)),
    'enrollment_date': ['2024-01-01'] * 20
}, name="patients")

# Not all patients completed the study
completed_patients = DataFrame({
    'patient_id': [f'P{i:03d}' for i in [1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 2, 4]],  # 12 completed
    'completion_date': ['2024-06-01'] * 12,
    'final_score': jnp.array(np.random.normal(60, 15, 12))
}, name="completed")

# Find patients who completed the study
completers = patients.exists(completed_patients, on='patient_id')

print("All enrolled patients:")
display(patients.to_pandas().head())
print(f"\nTotal enrolled: {len(patients)}")

print("\nPatients who completed the study:")
display(completers.to_pandas())
print(f"\nCompletion rate: {len(completers)}/{len(patients)} ({100*len(completers)/len(patients):.1f}%)")

# Analyze completion by demographics
completer_demographics = completers.to_pandas()
print(f"\nCompleter demographics:")
print(f"Average age of completers: {completer_demographics['age'].mean():.1f}")
print(f"Gender distribution: {completer_demographics['gender'].value_counts().to_dict()}")
print(f"Average baseline score: {completer_demographics['baseline_score'].mean():.1f}")

In [None]:
print("\n=== Real-World Use Case 2: Laboratory Data Analysis ===")

# Create laboratory sample tracking system
samples = DataFrame({
    'sample_id': [f'S{i:04d}' for i in range(1, 101)],
    'patient_id': [f'PT{i:03d}' for i in np.random.randint(1, 51, 100)],
    'collection_date': ['2024-01-01'] * 100,
    'sample_type': np.random.choice(['blood', 'urine', 'tissue'], 100),
    'volume_ml': jnp.array(np.random.uniform(1.0, 10.0, 100))
}, name="samples")

# Only some samples have been processed through specific assays
assay_results = DataFrame({
    'sample_id': [f'S{i:04d}' for i in [1, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70]],
    'assay_type': ['protein_analysis'] * 15,
    'result_value': jnp.array(np.random.normal(100, 20, 15)),
    'processing_date': ['2024-01-15'] * 15
}, name="assay_results")

quality_control = DataFrame({
    'sample_id': [f'S{i:04d}' for i in [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]],
    'qc_status': ['passed'] * 15,
    'qc_date': ['2024-01-10'] * 15
}, name="quality_control")

print("Laboratory Analysis Pipeline:")
print(f"Total samples collected: {len(samples)}")
print(f"Samples with assay results: {len(assay_results)}")
print(f"Samples that passed QC: {len(quality_control)}")

# Find samples that have both QC and assay data
qc_passed_samples = samples.exists(quality_control, on='sample_id')
assayed_samples = samples.exists(assay_results, on='sample_id')
fully_processed_samples = qc_passed_samples.exists(assay_results, on='sample_id')

print(f"\nSamples that passed QC: {len(qc_passed_samples)}")
print(f"Samples with assay data: {len(assayed_samples)}")
print(f"Samples both QC'd and assayed: {len(fully_processed_samples)}")

print("\nFully processed samples:")
display(fully_processed_samples.to_pandas())

# Analyze by sample type
processed_by_type = fully_processed_samples.to_pandas()['sample_type'].value_counts()
print(f"\nProcessed samples by type: {processed_by_type.to_dict()}")

In [None]:
print("\n=== Real-World Use Case 3: E-commerce Customer Analytics ===")

# Create e-commerce customer analysis scenario
all_customers = DataFrame({
    'customer_id': [f'CUST{i:05d}' for i in range(1, 1001)],
    'registration_date': ['2023-01-01'] * 1000,
    'country': np.random.choice(['US', 'UK', 'DE', 'FR', 'CA'], 1000),
    'age_group': np.random.choice(['18-25', '26-35', '36-45', '46-55', '55+'], 1000),
    'email_verified': np.random.choice([True, False], 1000, p=[0.8, 0.2])
}, name="all_customers")

# Recent purchases (last 30 days)
recent_purchases = DataFrame({
    'customer_id': [f'CUST{i:05d}' for i in np.random.choice(range(1, 1001), 150, replace=False)],
    'purchase_amount': jnp.array(np.random.exponential(50, 150)),
    'purchase_date': ['2024-01-01'] * 150
}, name="recent_purchases")

# Email campaign engagement
email_engaged = DataFrame({
    'customer_id': [f'CUST{i:05d}' for i in np.random.choice(range(1, 1001), 200, replace=False)],
    'email_opened': [True] * 200,
    'engagement_date': ['2024-01-15'] * 200
}, name="email_engaged")

# Mobile app users
mobile_users = DataFrame({
    'customer_id': [f'CUST{i:05d}' for i in np.random.choice(range(1, 1001), 300, replace=False)],
    'app_installed': [True] * 300,
    'last_app_use': ['2024-01-20'] * 300
}, name="mobile_users")

print("E-commerce Customer Segmentation:")
print(f"Total customers: {len(all_customers)}")
print(f"Recent purchasers: {len(recent_purchases)}")
print(f"Email engaged: {len(email_engaged)}")
print(f"Mobile app users: {len(mobile_users)}")

# Segment customers using semi-joins
active_buyers = all_customers.exists(recent_purchases, on='customer_id')
email_responsive = all_customers.exists(email_engaged, on='customer_id')
mobile_engaged = all_customers.exists(mobile_users, on='customer_id')

# High-value segment: customers who are active across all channels
high_value_customers = active_buyers.exists(email_engaged, on='customer_id').exists(mobile_users, on='customer_id')

print(f"\nCustomer Segments:")
print(f"Active buyers: {len(active_buyers)} ({100*len(active_buyers)/len(all_customers):.1f}%)")
print(f"Email responsive: {len(email_responsive)} ({100*len(email_responsive)/len(all_customers):.1f}%)")
print(f"Mobile engaged: {len(mobile_engaged)} ({100*len(mobile_engaged)/len(all_customers):.1f}%)")
print(f"High-value (all channels): {len(high_value_customers)} ({100*len(high_value_customers)/len(all_customers):.1f}%)")

# Analyze high-value segment
hv_analysis = high_value_customers.to_pandas()
print(f"\nHigh-Value Customer Analysis:")
print(f"Country distribution: {hv_analysis['country'].value_counts().to_dict()}")
print(f"Age group distribution: {hv_analysis['age_group'].value_counts().to_dict()}")
print(f"Email verification rate: {100*hv_analysis['email_verified'].mean():.1f}%")

print("\n✓ E-commerce customer segmentation completed!")
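
Chaining semi-joins, as in the high-value segment above, is equivalent to filtering the left table by the intersection of the right-hand key sets, and the order of the chain does not change the result. A small pure-Python check of that equivalence follows; `semi_join_ids` and the ten-customer data set are illustrative assumptions.

```python
def semi_join_ids(left_ids, right_ids):
    """Keep the left IDs that occur in right_ids, preserving left order."""
    right_set = set(right_ids)
    return [i for i in left_ids if i in right_set]

all_ids = [f"CUST{i:05d}" for i in range(1, 11)]
purchasers = ["CUST00002", "CUST00004", "CUST00006", "CUST00008"]
email = ["CUST00004", "CUST00005", "CUST00006"]
mobile = ["CUST00006", "CUST00004", "CUST00009"]

# Chained semi-joins, in two different orders
chained = semi_join_ids(semi_join_ids(semi_join_ids(all_ids, purchasers), email), mobile)
reordered = semi_join_ids(semi_join_ids(semi_join_ids(all_ids, mobile), purchasers), email)

# Single filter against the intersection of all key sets
common = set(purchasers) & set(email) & set(mobile)
direct = [i for i in all_ids if i in common]

print(chained)
```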

## Summary

The `exists()` method provides semi-join functionality that:

1. **Filters rows** based on the existence of matches in another DataFrame
2. **Preserves the original structure** by returning only columns from the left table
3. **Eliminates duplicates** automatically
4. **Supports multi-column joins** for composite keys
5. **Interoperates with JAX**, so results can feed further JAX computations
6. **Handles mixed data types** (lists, NumPy arrays, JAX arrays)
7. **Can outperform** a traditional join followed by column removal (see the timing comparison in Section 5)
8. **Handles edge cases** by returning empty results for empty or non-matching inputs and raising `ValueError` for missing join keys

### Key Benefits:
- **Cleaner API**: Directly expresses the intent of filtering based on existence
- **Performance**: Avoids copying right-table columns that would be dropped anyway
- **Memory efficiency**: No intermediate DataFrame with unwanted columns is needed
- **Type preservation**: Keeps original column types, including JAX arrays
- **Automatic deduplication**: No need to handle duplicates manually

This makes the `exists()` method ideal for data analysis scenarios where you need to filter one dataset based on the presence of related records in another dataset.