In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the DanishFungi2024-Mini metadata tables into DataFrames.
# Paths are relative, so they resolve against the kernel's working directory.
# NOTE(review): 'pubtest' is presumably the public test split — confirm.
train_df = pd.read_csv('../data/DanishFungi2024-Mini-train.csv')
test_df = pd.read_csv('../data/DanishFungi2024-Mini-pubtest.csv')

In [None]:
# Combine train and test datasets for full analysis.
# ignore_index=True discards both original row indexes and renumbers the
# stacked rows 0..n-1, so index labels stay unique in the combined frame.
full_df = pd.concat([train_df, test_df], ignore_index=True)

In [None]:
# Tally how many rows (images) each species has; value_counts() orders the
# result from most to least frequent.
species_counts = full_df['species'].value_counts()
top_species = species_counts  # alias of the full ranking (same object, not a copy)
num_species = len(species_counts)
top_20_species = species_counts.iloc[:20]
# to_string() renders every row, bypassing pandas' display truncation.
print(species_counts.to_string())

In [None]:
# Poisonous vs. non-poisonous breakdown.
# Per image: how many rows carry each `poisonous` label.
poisonous_counts_images = full_df['poisonous'].value_counts()
# Per species: a species takes the largest `poisonous` value found on any of
# its rows, then the species are counted by that status.
species_poison_status = full_df['poisonous'].groupby(full_df['species']).max()
poisonous_counts_species = species_poison_status.value_counts()

In [None]:
# Headline totals for the combined frame.
print(f"Total Images Analyzed: {full_df.shape[0]}")
print(f"Total Unique Species: {num_species}")

# Each breakdown is emitted as a heading line followed by its count table.
print("\n--- Poisonous Images Count (0=No, 1=Yes) ---", poisonous_counts_images, sep="\n")

print("\n--- Poisonous Species Count (0=No, 1=Yes) ---", poisonous_counts_species, sep="\n")

In [None]:

# Create the figure that the subplot calls below draw into (18 x 15 inches).
plt.figure(figsize=(18, 15))

# Panel 1 of a 3x2 grid: horizontal bars, one per species, for the 20 most
# frequent species in `top_20_species`.
# NOTE(review): only grid slots 1 and 2 are used in the visible cells, so most
# of a 3x2 layout would stay empty — confirm whether more panels exist.
# NOTE(review): recent seaborn releases warn when `palette` is passed without
# `hue`; check the installed version before relying on this call.
plt.subplot(3, 2, 1)
sns.barplot(x=top_20_species.values, y=top_20_species.index, palette="viridis")
plt.title("Top 20 Species by Number of Images")
plt.xlabel("Number of Images")
plt.ylabel("Species")



In [None]:
# Panel 2 of the 3x2 grid: bar chart of image counts per `poisonous` label.
# plt.subplot() targets pyplot's *current* figure.
# NOTE(review): when this cell runs separately from the one that calls
# plt.figure(figsize=(18, 15)), the inline backend has likely already closed
# that figure, so this panel would land on a new default-size figure — confirm.
# NOTE(review): recent seaborn releases warn when `palette` is passed without
# `hue`; check the installed version before relying on this call.
plt.subplot(3, 2, 2)
sns.barplot(x=poisonous_counts_images.index, y=poisonous_counts_images.values, palette="magma")
plt.title("Distribution of Poisonous Labels (Images)")
# Relabel tick positions 0 and 1 with readable names.
# Assumes `poisonous` only holds the values 0 and 1 — TODO confirm.
plt.xticks([0, 1], ['Non-Poisonous', 'Poisonous'])
plt.ylabel("Count")



In [None]:
# Save species summary to CSV (columns: Species, Image_Count, Poisonous).
# The index is named explicitly before it is turned into a column. Relying on
# reset_index() producing a column called 'index' only works on pandas < 2.0;
# from 2.0 on, value_counts() names the index after the source column, so the
# old rename was a no-op and the 'Species' lookup below raised KeyError.
species_summary = species_counts.rename_axis('Species').reset_index(name='Image_Count')
# Attach each species' poisonous flag by looking the name up in the
# per-species status Series.
species_summary['Poisonous'] = species_summary['Species'].map(species_poison_status)
species_summary.to_csv('species_summary.csv', index=False)
print("Detailed species counts saved to 'species_summary.csv'")

In [None]:
def get_true_classes(df=None, output_path='model_182_classes.txt'):
    """Write the species name of every class ID to a text file, one per line.

    Lines are ordered by ascending ``class_id``. Each class contributes the
    first non-null species name found among its rows. Rows whose ``class_id``
    is missing are ignored.

    Args:
        df: DataFrame with a ``class_id`` column and a species-name column
            (matched case-insensitively against ``species``). Defaults to the
            notebook-level ``full_df``.
        output_path: Destination file. The default keeps the historical
            ``model_182_classes.txt`` name; the number of lines written is
            whatever the data contains, not necessarily 182.

    Returns:
        None. If no species column exists, a message is printed and no file
        is written.

    Raises:
        ValueError: If some class ID has no non-null species name. Raised
            before the file is opened, so no truncated list is left behind.
    """
    if df is None:
        df = full_df  # combined train + test frame built earlier in the notebook

    print(f"Model sees {df['class_id'].nunique()} unique Class IDs.")

    # Case-insensitive lookup; str() keeps non-string column labels from crashing it.
    text_col = next((col for col in df.columns if str(col).lower() == 'species'), None)
    if text_col is None:
        print("Could not find text column.")
        return

    # Single pass instead of re-filtering the whole frame for every class ID:
    # groupby sorts the keys ascending and skips rows with a missing class_id,
    # and first() picks the first non-null name within each group.
    names_by_id = df.groupby('class_id')[text_col].first()

    unnamed = names_by_id.index[names_by_id.isna()].tolist()
    if unnamed:
        raise ValueError(f"No species name found for class IDs: {unnamed}")

    # Explicit UTF-8 so non-ASCII species names do not depend on the OS locale.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(f"{name}\n" for name in names_by_id)



# Export the class list when this code runs as the top-level program.
# NOTE(review): inside a notebook kernel `__name__` is normally '__main__',
# so this guard would always pass there — confirm whether it is only needed
# for a script export of this notebook.
if __name__ == '__main__':
    get_true_classes()