In [9]:
from datasets import load_dataset
import pandas as pd

ds = load_dataset("Nech-C/mineralimage5K-98")

In [10]:
# Transform the dataset: rename 'name' to 'label', convert to lowercase, remove unwanted columns, and export locally
# Check current columns for all splits
print("Current features:", ds['train'].features.keys())
print("Available splits:", list(ds.keys()))

# Rename 'name' column to 'label' if it exists (applies to all splits automatically)
if 'name' in ds['train'].column_names:
    ds = ds.rename_column('name', 'label')
    print("Renamed 'name' to 'label' in all splits")

# Check if label is ClassLabel type - if so, we need special handling
from datasets import ClassLabel
label_feature = ds['train'].features.get('label')

if isinstance(label_feature, ClassLabel):
    # If it's ClassLabel, we need to get the names, lowercase them, and remap
    print("Label is ClassLabel type - converting names to lowercase")
    # Get original class names
    original_names = label_feature.names
    print(f"Original class names count: {len(original_names)}")
    
    # Convert to lowercase and create mapping
    # Preserve order but remove duplicates (keeping first occurrence)
    seen = set()
    lowercase_names = []
    name_mapping = {}  # maps old index to new index
    
    for old_idx, name in enumerate(original_names):
        lowercase_name = name.lower()
        if lowercase_name not in seen:
            seen.add(lowercase_name)
            new_idx = len(lowercase_names)
            lowercase_names.append(lowercase_name)
        else:
            # Find the index where this lowercase name was already added
            new_idx = lowercase_names.index(lowercase_name)
        name_mapping[old_idx] = new_idx
    
    print(f"Lowercase class names count: {len(lowercase_names)} (duplicates removed)")
    
    # Create mapping function that converts label indices
    def remap_classlabel_indices(example):
        if 'label' in example:
            old_idx = example['label']
            if isinstance(old_idx, int) and old_idx in name_mapping:
                example['label'] = name_mapping[old_idx]
        return example
    
    # Apply remapping to all splits
    ds = ds.map(remap_classlabel_indices)
    
    # Create new ClassLabel with lowercase names
    new_label_feature = ClassLabel(names=lowercase_names)
    
    # Update the feature - labels are now integers mapped to new indices
    ds = ds.cast_column('label', new_label_feature)
    print(f"Converted class labels to lowercase in all splits")
    
    # Convert class IDs to class names (replace integer IDs with string names)
    # IMPORTANT: Convert BEFORE casting to string type
    # Use the ClassLabel feature's int2str method for reliable conversion
    def convert_id_to_name(example):
        if 'label' in example:
            label_id = example['label']
            # label_id should be an integer at this point (ClassLabel index)
            # Convert it directly to the class name string using the feature's method
            if isinstance(label_id, int):
                # Use the ClassLabel's int2str method for reliable conversion
                example['label'] = new_label_feature.int2str(label_id)
            elif isinstance(label_id, str) and label_id.isdigit():
                # Handle case where it might already be a string number
                label_int = int(label_id)
                example['label'] = new_label_feature.int2str(label_int)
        return example
    
    # Convert integers to class name strings
    print("Converting class IDs (integers) to class names (strings)...")
    ds = ds.map(convert_id_to_name, desc="Converting class IDs to names")
    
    # Verify conversion happened correctly BEFORE casting to string
    test_label = ds['train'][0]['label']
    if isinstance(test_label, int):
        print(f"WARNING: Label is still integer after conversion: {test_label}")
        print(f"Attempting direct conversion...")
        # Fallback: direct conversion
        def direct_convert(example):
            if 'label' in example and isinstance(example['label'], int):
                example['label'] = lowercase_names[example['label']]
            return example
        ds = ds.map(direct_convert, desc="Direct conversion fallback")
        test_label = ds['train'][0]['label']
        print(f"After fallback conversion: {test_label} (type: {type(test_label)})")
    
    # Now change feature type from ClassLabel to string (values are already strings)
    from datasets import Value
    ds = ds.cast_column('label', Value('string'))
    print("Converted class IDs to class names (strings)")
    
    # Verify conversion worked
    sample_check = ds['train'][0]['label']
    print(f"Sample converted label (should be class name): {sample_check} (type: {type(sample_check)})")
else:
    # If it's a string feature, just map to lowercase
    def lowercase_label(example):
        if 'label' in example and example['label']:
            example['label'] = example['label'].lower()
        return example
    
    ds = ds.map(lowercase_label)
    print("Converted string labels to lowercase in all splits")

# Remove 'description' and 'mineral_boxes' columns if they exist
columns_to_remove = []
if 'description' in ds['train'].column_names:
    columns_to_remove.append('description')
if 'mineral_boxes' in ds['train'].column_names:
    columns_to_remove.append('mineral_boxes')

if columns_to_remove:
    ds = ds.remove_columns(columns_to_remove)
    print(f"Removed columns: {columns_to_remove}")

# Verify transformation worked for all splits
print("\nVerifying labels in all splits:")
for split_name in ds.keys():
    sample_labels = [ds[split_name][i]['label'] for i in range(min(3, len(ds[split_name])))]
    print(f"  {split_name}: sample labels = {sample_labels}")
    
    # Check if labels are still numeric strings (which means conversion failed)
    if sample_labels and isinstance(sample_labels[0], str) and sample_labels[0].isdigit():
        print(f"  WARNING: Labels in {split_name} are still numeric strings! Conversion failed.")
        print(f"  First label value: '{sample_labels[0]}' (should be a class name like 'graphite', 'gold', etc.)")
        print(f"  This means the class IDs were not properly converted to class names.")

# Export the dataset locally
export_path = "./mineralimage5K-98-processed"
ds.save_to_disk(export_path)

print(f"\nDataset exported successfully to: {export_path}")
print(f"Final features: {ds['train'].features.keys()}")
print(f"\nDataset info:")
print(ds['train'].features)


Current features: dict_keys(['image', 'name', 'description', 'mineral_boxes'])
Available splits: ['train', 'validation', 'test']
Renamed 'name' to 'label' in all splits
Label is ClassLabel type - converting names to lowercase
Original class names count: 98
Lowercase class names count: 98 (duplicates removed)
Converted class labels to lowercase in all splits
Converting class IDs (integers) to class names (strings)...


Converting class IDs to names: 100%|██████████| 12828/12828 [00:06<00:00, 1875.34 examples/s]
Converting class IDs to names: 100%|██████████| 2749/2749 [00:01<00:00, 1536.09 examples/s]
Converting class IDs to names: 100%|██████████| 2749/2749 [00:01<00:00, 1527.26 examples/s]


Attempting direct conversion...


Direct conversion fallback: 100%|██████████| 12828/12828 [00:06<00:00, 1926.25 examples/s]
Direct conversion fallback: 100%|██████████| 2749/2749 [00:01<00:00, 1971.62 examples/s]
Direct conversion fallback: 100%|██████████| 2749/2749 [00:01<00:00, 1971.46 examples/s]


After fallback conversion: 9 (type: <class 'int'>)


Casting the dataset: 100%|██████████| 12828/12828 [00:03<00:00, 3463.97 examples/s]
Casting the dataset: 100%|██████████| 2749/2749 [00:00<00:00, 3612.17 examples/s]
Casting the dataset: 100%|██████████| 2749/2749 [00:00<00:00, 3188.53 examples/s]


Converted class IDs to class names (strings)
Sample converted label (should be class name): 9 (type: <class 'str'>)
Removed columns: ['description', 'mineral_boxes']

Verifying labels in all splits:
  train: sample labels = ['9', '77', '6']
  First label value: '9' (should be a class name like 'graphite', 'gold', etc.)
  This means the class IDs were not properly converted to class names.
  validation: sample labels = ['48', '31', '16']
  First label value: '48' (should be a class name like 'graphite', 'gold', etc.)
  This means the class IDs were not properly converted to class names.
  test: sample labels = ['29', '9', '28']
  First label value: '29' (should be a class name like 'graphite', 'gold', etc.)
  This means the class IDs were not properly converted to class names.


Saving the dataset (6/6 shards): 100%|██████████| 12828/12828 [00:08<00:00, 1428.55 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 2749/2749 [00:01<00:00, 1589.42 examples/s]
Saving the dataset (2/2 shards): 100%|██████████| 2749/2749 [00:01<00:00, 1653.92 examples/s]


Dataset exported successfully to: ./mineralimage5K-98-processed
Final features: dict_keys(['image', 'label'])

Dataset info:
{'image': Image(mode=None, decode=True), 'label': Value('string')}





In [11]:
# Load the processed dataset and combine all splits into one training split
from datasets import load_from_disk, concatenate_datasets, DatasetDict, ClassLabel, Value

# Load the processed dataset
processed_ds = load_from_disk("./mineralimage5K-98-processed")

print("Original dataset splits:", list(processed_ds.keys()))
print(f"Train size: {len(processed_ds['train'])}")
print(f"Validation size: {len(processed_ds['validation'])}")
print(f"Test size: {len(processed_ds['test'])}")

# Check if labels need conversion from IDs to names
label_feature = processed_ds['train'].features.get('label')

# Check if labels are still ClassLabel (integer IDs) or if they're string numbers
sample_label = processed_ds['train'][0]['label']
needs_conversion = False

if isinstance(label_feature, ClassLabel):
    # Labels are ClassLabel (integers) - convert to names
    class_names = label_feature.names
    print(f"\nConverting {len(class_names)} class IDs to class names...")
    needs_conversion = True
    
    def convert_id_to_name(example):
        if 'label' in example:
            label_value = example['label']
            if isinstance(label_value, int):
                if 0 <= label_value < len(class_names):
                    example['label'] = class_names[label_value]
        return example
    
    processed_ds = processed_ds.map(convert_id_to_name, desc="Converting IDs to names")
    processed_ds = processed_ds.cast_column('label', Value('string'))
    print("Converted class IDs to class names (strings)")
elif isinstance(sample_label, str) and sample_label.isdigit():
    # Labels are string numbers (like '9') - need to convert using saved class names
    # We need to reload or reconstruct the class names
    # Try to load from the original dataset or use a mapping
    print(f"\nLabels are numeric strings - converting to class names...")
    print("Note: You may need to re-run cell 1 to properly convert labels.")
    print(f"Current sample label: {sample_label}")
    
    # We'll create a mapping - but we need the class names
    # For now, let's try to get them from a fresh load or check if we can infer
    try:
        # Try to get class names from original dataset
        from datasets import load_dataset
        temp_ds = load_dataset("Nech-C/mineralimage5K-98", split='train[:1]')
        if 'name' in temp_ds.column_names:
            original_label_feature = temp_ds.features.get('name')
            if isinstance(original_label_feature, ClassLabel):
                original_names = original_label_feature.names
                # Create lowercase mapping
                seen = set()
                lowercase_names = []
                for name in original_names:
                    lower_name = name.lower()
                    if lower_name not in seen:
                        seen.add(lower_name)
                        lowercase_names.append(lower_name)
                
                def convert_numeric_string_to_name(example):
                    if 'label' in example:
                        label_str = str(example['label'])
                        if label_str.isdigit():
                            label_id = int(label_str)
                            if 0 <= label_id < len(lowercase_names):
                                example['label'] = lowercase_names[label_id]
                    return example
                
                processed_ds = processed_ds.map(convert_numeric_string_to_name, desc="Converting numeric strings to names")
                print(f"Converted {len(lowercase_names)} numeric string labels to class names")
    except Exception as e:
        print(f"Could not automatically convert: {e}")
        print("Please re-run cell 1 to properly convert the labels.")

# Combine all splits into a single training dataset
combined_train = concatenate_datasets([
    processed_ds['train'],
    processed_ds['validation'],
    processed_ds['test']
])

print(f"\nCombined dataset size: {len(combined_train)}")

# Create a new DatasetDict with only the combined train split
ds_combined = DatasetDict({"train": combined_train})

# Export the combined dataset
export_path_combined = "./mineralimage5K-98-combined"
ds_combined.save_to_disk(export_path_combined)

print(f"\nCombined dataset exported successfully to: {export_path_combined}")
print(f"Available features: {ds_combined['train'].features.keys()}")

# Show sample labels to verify they are now strings
print(f"\nSample labels (first 5): {[ds_combined['train'][i]['label'] for i in range(5)]}")


Original dataset splits: ['train', 'validation', 'test']
Train size: 12828
Validation size: 2749
Test size: 2749

Labels are numeric strings - converting to class names...
Note: You may need to re-run cell 1 to properly convert labels.
Current sample label: 9


Converting numeric strings to names: 100%|██████████| 12828/12828 [00:06<00:00, 2051.51 examples/s]
Converting numeric strings to names: 100%|██████████| 2749/2749 [00:01<00:00, 1927.84 examples/s]
Converting numeric strings to names: 100%|██████████| 2749/2749 [00:01<00:00, 1924.31 examples/s]


Converted 98 numeric string labels to class names

Combined dataset size: 18326


Saving the dataset (8/8 shards): 100%|██████████| 18326/18326 [00:12<00:00, 1472.88 examples/s]


Combined dataset exported successfully to: ./mineralimage5K-98-combined
Available features: dict_keys(['image', 'label'])

Sample labels (first 5): ['microcline', 'aegirine', 'hematite', 'andalusite', 'quartz']





In [12]:
# Display the classes
# Load the combined dataset to show classes
from datasets import load_from_disk

try:
    ds_combined = load_from_disk("./mineralimage5K-98-combined")
    ds = ds_combined
except:
    pass  # Use existing ds variable if loading fails

if 'train' in ds:
    # Get unique labels from the train split
    if 'label' in ds['train'].features:
        labels = ds['train']['label']
        unique_labels = sorted(set(labels))
        
        # Check if labels are strings (class names) or integers (class IDs)
        from datasets import ClassLabel
        
        if isinstance(ds['train'].features['label'], ClassLabel):
            # Labels are still ClassLabel (integers) - get names from feature
            class_names = [ds['train'].features['label'].names[i] for i in unique_labels]
            # Create a DataFrame to display classes in column format
            classes_df = pd.DataFrame({
                'Class ID': unique_labels,
                'Class Name': class_names
            })
        else:
            # Labels are already strings (class names)
            class_names = sorted(unique_labels)
            # Create a DataFrame to display classes - no IDs, just names
            classes_df = pd.DataFrame({
                'Class Name': class_names
            })
        
        print("Dataset Classes:")
        print(classes_df.to_string(index=False))
        print(f"\nTotal number of classes: {len(unique_labels)}")
    else:
        print("Available features:", ds['train'].features.keys())
        print("\nDataset info:")
        print(ds['train'].features)
else:
    print("Available splits:", ds.keys())
    print("\nDataset info:")
    print(ds)

Dataset Classes:
  Class Name
  actinolite
    aegirine
       agate
      albite
   almandine
   amazonite
       amber
    amethyst
     analcim
  andalusite
   andradite
  antimonite
     apatite
   aragonite
     arsenic
arsenopyrite
      augite
     azurite
      barite
       beryl
     bismuth
  bournonite
     calcite
   cancrinit
   carnelian
 cassiterite
   celestine
  chalcedony
chalcopyrite
    chromite
 chrysoberyl
 chrysoprase
    cinnabar
    cobaltin
      copper
    corundum
      credit
     cuprite
    diopside
    dolomite
      elbait
     epidote
   eudialyte
       flint
fluorapatite
    fluorite
      galena
    goethite
        gold
    graphite
   grossular
      gypsum
hedenbergite
    hematite
  hornblende
    ilmenite
      jasper
     kyanite
    labrador
lapis lazuli
    limonite
   magnetite
   malachite
   marcasite
  microcline
 molybdenite
   muscovite
   natrolite
   nephritis
  oligoclase
        opal
  orthoclase
   pectolite
   pollucite
    preh