In [None]:
import pandas as pd
import numpy as np
import random
from typing import List, Dict, Tuple
import matplotlib.pyplot as plt
import seaborn as sns

# Seed both random-number generators up front so that the stochastic steps later in the
# notebook (the per-sound tag shuffle) are repeatable on a fresh kernel.
SEED = 42
for _seed_fn in (random.seed, np.random.seed):
    _seed_fn(SEED)

print("Libraries imported successfully!")


In [None]:
# Load the RankST tagset (valid tags): one tag per line.
# Blank lines (e.g. an extra empty line at the end of the file) are skipped so that the empty
# string never becomes a "valid" tag; otherwise an empty token from splitting a tag string
# on ',' (e.g. "a,,b") would be accepted as a real tag.
print("Loading valid tagset...")
with open('data/tagset_rankst.txt', 'r', encoding='utf-8') as f:
    valid_tags = {tag for tag in (line.strip() for line in f) if tag}

print(f"Loaded {len(valid_tags)} valid tags")
# Sorted so the preview is the same on every kernel (set order depends on string hash randomization).
print(f"First 10 tags: {sorted(valid_tags)[:10]}")


In [None]:
# Read the raw BSD10k metadata table; later cells derive filtered copies from it.
metadata_path = 'data/BSD10k/BSD10k_metadata.csv'
print("Loading BSD10k metadata...")
metadata_df = pd.read_csv(metadata_path)
print(f"Original dataset size: {metadata_df.shape[0]} sounds")


In [None]:
# Keep only sounds whose ID is strictly above a fixed cutoff.
# NOTE(review): the rationale for 207648 is not recorded in this notebook — TODO document why
# this cutoff was chosen.
MIN_SOUND_ID = 207648
print("Filtering sounds with sound_id > 207648...")
above_cutoff = metadata_df['sound_id'].gt(MIN_SOUND_ID)
filtered_df = metadata_df.loc[above_cutoff].copy()
print(f"After sound_id filter: {len(filtered_df)} sounds")

id_col = filtered_df['sound_id']
print(f"Sound ID range: {id_col.min()} - {id_col.max()}")


In [None]:
def parse_and_filter_tags(tags_str: str, valid_tags: set, sep: str = ',') -> List[str]:
    """
    Parse a delimited tag string and keep only the tags that appear in ``valid_tags``.

    Args:
        tags_str: Raw tag string, e.g. ``"dog, bark,loud"``. Missing values (NaN/None)
            are treated as "no tags".
        valid_tags: Vocabulary of allowed tags. Each token is whitespace-stripped and then
            matched exactly against it.
        sep: Delimiter between tags. Defaults to ``','``.

    Returns:
        The valid tags in their original order, with empty tokens and duplicates removed.
        Duplicates are dropped for two reasons. A repeated tag would inflate the per-sound
        tag count used for filtering. It could also end up in both the input and the
        ground-truth split of the same sound.
    """
    if pd.isna(tags_str):
        return []

    kept: List[str] = []
    seen = set()
    for raw in tags_str.split(sep):
        tag = raw.strip()
        # Skip empty tokens (from ",," or a trailing separator) even if "" is in the vocabulary,
        # repeated tags, and anything outside the vocabulary.
        if not tag or tag in seen or tag not in valid_tags:
            continue
        seen.add(tag)
        kept.append(tag)
    return kept

# Attach, per sound, the list of tags that survive the vocabulary filter and how many there are.
print("Parsing and filtering tags...")
filtered_df = filtered_df.assign(
    valid_tags=filtered_df['tags'].map(lambda tag_str: parse_and_filter_tags(tag_str, valid_tags)),
    num_valid_tags=lambda frame: frame['valid_tags'].map(len),
)

tag_counts = filtered_df['num_valid_tags']
print(f"Valid tags per sound: min={tag_counts.min()}, max={tag_counts.max()}, mean={tag_counts.mean():.1f}")


In [None]:
# Keep sounds with 5 to 15 valid tags (both ends inclusive). The lower bound guarantees that
# the 3-tag input drawn later still leaves at least 2 ground-truth tags. The rationale for the
# upper bound is not recorded here.
print("Filtering sounds with 5-15 valid tags...")
in_range = filtered_df['num_valid_tags'].between(5, 15)  # inclusive on both ends
final_df = filtered_df.loc[in_range].copy()
print(f"After tag count filter: {len(final_df)} sounds")


In [None]:
import os

# Flatten each tag list into one comma-separated string so it survives the CSV round trip.
# This also adds a `valid_tags_str` column to final_df.
final_df['valid_tags_str'] = final_df['valid_tags'].map(','.join)

# Keep only the exported columns, renamed to the names used in the output file.
output_df = final_df[['sound_id', 'title', 'valid_tags_str', 'num_valid_tags']].rename(
    columns={'valid_tags_str': 'tags', 'num_valid_tags': 'num_tags'}
)

# NOTE(review): the file name spells "BSD10K" while its folder and the input file use "BSD10k".
# These differ on case-sensitive filesystems — confirm which spelling downstream readers expect.
output_file = 'data/BSD10k/BSD10K_metadata_filtered.csv'
os.makedirs(os.path.dirname(output_file), exist_ok=True)
output_df.to_csv(output_file, index=False)
print(f"\nFiltered data saved to {output_file}")
print(f"Total sounds in filtered dataset: {len(output_df)}")
print(f"Average number of tags per sound: {output_df['num_tags'].mean():.2f}")


In [None]:
# Print tag-count statistics for the exported table, then preview its first rows.
n_tags = output_df['num_tags']

print("Dataset Statistics:")
print("=" * 50)
print(f"Total sounds: {len(output_df)}")
print(f"Min tags per sound: {n_tags.min()}")
print(f"Max tags per sound: {n_tags.max()}")
print(f"Mean tags per sound: {n_tags.mean():.2f}")
print(f"Median tags per sound: {n_tags.median():.2f}")

preview = output_df.head()
print("\nFirst 5 rows of filtered dataset:")
print("-" * 50)
display(preview)


In [None]:
def create_input_ground_truth_pairs(
    sound_data: pd.DataFrame,
    n_input: int = 3,
    rng: "random.Random | None" = None,
) -> List[Dict]:
    """
    Split each sound's tags into a random input set and a ground-truth set.

    For every row, the sound's tags are shuffled. The first ``n_input`` become the input
    (what a recommender is given) and the remainder become the ground truth to predict.

    Args:
        sound_data: Frame with at least ``sound_id``, ``title`` and ``valid_tags`` columns.
            ``valid_tags`` holds a list of tag strings per row.
        n_input: Number of tags used as input. Defaults to 3. A sound with ``n_input`` tags
            or fewer gets an empty ground truth, so callers should pre-filter by tag count.
        rng: Randomness source exposing ``shuffle``.
            - ``None`` (default) uses the global ``random`` module, as the original did, so the
              split depends on the global generator's state.
            - Pass ``random.Random(seed)`` for a split that is independent of any other use of
              the global generator.

    Returns:
        One dict per row, in row order, with keys ``sound_id``, ``input_tags``,
        ``ground_truth_tags``, ``total_tags`` and ``title``. The input frame is not modified.
    """
    shuffler = random if rng is None else rng
    test_data = []

    # Iterate just the needed columns: avoids building a Series per row as iterrows does.
    for sound_id, tags, title in zip(sound_data['sound_id'], sound_data['valid_tags'], sound_data['title']):
        shuffled_tags = list(tags)  # copy: never shuffle the list stored in the frame
        shuffler.shuffle(shuffled_tags)

        test_data.append({
            'sound_id': sound_id,
            'input_tags': shuffled_tags[:n_input],
            'ground_truth_tags': shuffled_tags[n_input:],
            'total_tags': len(tags),
            'title': title,
        })

    return test_data

# Create input/ground truth pairs.
# The split draws from the global `random` generator, so re-seed it here with the same seed
# as the setup cell. Otherwise, re-running only this cell (or running anything that consumes
# `random` first) silently produces a different split than Restart & Run All. On a fresh
# Run All nothing uses `random` before this point, so the result is unchanged.
print("Creating input/ground truth pairs...")
random.seed(42)
test_data = create_input_ground_truth_pairs(final_df)

print(f"Created {len(test_data)} test cases")
print(f"Average ground truth tags per sound: {np.mean([len(item['ground_truth_tags']) for item in test_data]):.2f}")


In [None]:
# Preview the first five generated test cases.
print("Examples of test cases:")
print("=" * 60)

for case_no, case in enumerate(test_data[:5], start=1):
    inputs, truth = case['input_tags'], case['ground_truth_tags']
    print(f"\nTest case {case_no}:")
    print(f"  Sound ID: {case['sound_id']}")
    print(f"  Title: {case['title']}")
    print(f"  Input tags ({len(inputs)}): {inputs}")
    print(f"  Ground truth tags ({len(truth)}): {truth}")
    print(f"  Total tags: {case['total_tags']}")


In [None]:
# Distribution of ground-truth vs. total tag counts over all generated test cases.
# gt_counts / total_counts stay defined at notebook level because the save cell reuses them.
gt_counts = [len(pair['ground_truth_tags']) for pair in test_data]
total_counts = [pair['total_tags'] for pair in test_data]
n_cases = len(test_data)

print("Input/Ground Truth Dataset Statistics:")
print("=" * 50)
print(f"Total test cases: {n_cases}")
print(f"Min ground truth tags: {min(gt_counts)}")
print(f"Max ground truth tags: {max(gt_counts)}")
print(f"Mean ground truth tags: {np.mean(gt_counts):.2f}")
print(f"Median ground truth tags: {np.median(gt_counts):.2f}")
print()
print(f"Min total tags: {min(total_counts)}")
print(f"Max total tags: {max(total_counts)}")
print(f"Mean total tags: {np.mean(total_counts):.2f}")


In [None]:
# Save the input/ground truth pairs for use with the tag recommendation system
import pickle
import json

# Save as pickle for easy loading.
# NOTE: only unpickle this file from a trusted source — pickle.load can execute arbitrary code.
pickle_file = 'data/input_ground_truth_pairs.pkl'
with open(pickle_file, 'wb') as f:
    pickle.dump(test_data, f)

# Also save as JSON for human readability
json_file = 'data/input_ground_truth_pairs.json'
with open(json_file, 'w') as f:
    json.dump(test_data, f, indent=2)

print("Input/Ground truth pairs saved to:")
print(f"- {pickle_file} ({len(test_data)} test cases)")
print(f"- {json_file} ({len(test_data)} test cases)")

# The CLAP tagset size is computed here from final_df. The `clap_tagset` variable is only
# built by a later cell, so referencing it in this cell raised NameError on Restart & Run All.
# This uses the same definition as that cell: RankST tags that occur in at least one filtered sound.
clap_tagset_size = len(valid_tags.intersection(set().union(*final_df['valid_tags'])))

# Save summary statistics
summary = {
    'total_test_cases': len(test_data),
    'avg_ground_truth_tags': float(np.mean(gt_counts)),
    'avg_total_tags': float(np.mean(total_counts)),
    'min_ground_truth_tags': min(gt_counts),
    'max_ground_truth_tags': max(gt_counts),
    'sounds_filtered': {
        'original_count': len(metadata_df),
        'after_sound_id_filter': len(filtered_df),
        'after_tag_count_filter': len(final_df)
    },
    'tagsets': {
        'rankst_tagset_size': len(valid_tags),
        'clap_tagset_size': clap_tagset_size,
        'tags_removed_for_clap': len(valid_tags) - clap_tagset_size
    }
}

summary_file = 'data/input_ground_truth_summary.json'
with open(summary_file, 'w') as f:
    json.dump(summary, f, indent=2)

print(f"\nSummary saved to {summary_file}")
print("Data preparation complete!")


In [None]:
# Create CLAP tagset by filtering RankST tagset to only include tags present in filtered data
print("Creating CLAP tagset from filtered data...")

# Every distinct tag used by at least one sound in the filtered dataset.
all_tags_in_filtered_data = set().union(*final_df['valid_tags'])

print(f"Total unique tags in filtered dataset: {len(all_tags_in_filtered_data)}")

# Filter RankST tagset to only include tags present in filtered data
clap_tagset = valid_tags.intersection(all_tags_in_filtered_data)

print(f"Original RankST tagset size: {len(valid_tags)}")
print(f"CLAP tagset size (tags present in filtered data): {len(clap_tagset)}")
print(f"Removed {len(valid_tags) - len(clap_tagset)} tags not present in filtered data")

# Save CLAP tagset to file: one tag per line, sorted so the file is identical across runs.
# The encoding is explicit so the output does not depend on the platform's default encoding.
clap_tagset_file = 'data/tagset_clap.txt'
with open(clap_tagset_file, 'w', encoding='utf-8') as f:
    f.writelines(f"{tag}\n" for tag in sorted(clap_tagset))

print(f"\nCLAP tagset saved to {clap_tagset_file}")
print(f"CLAP tagset contains {len(clap_tagset)} tags")
