In [None]:
# Diagnostic: list every GPU reported by nvidia-smi (run this BEFORE setting
# CUDA_VISIBLE_DEVICES). This helps identify which GPU is which.
# NOTE(review): CUDA's default device ordering can differ from nvidia-smi's
# index order unless CUDA_DEVICE_ORDER=PCI_BUS_ID is set — confirm on
# multi-GPU machines before relying on the index printed here.
import subprocess

GPU_QUERY_CMD = [
    'nvidia-smi',
    '--query-gpu=index,name,memory.total,compute_cap',
    '--format=csv,noheader',
]

try:
    result = subprocess.run(GPU_QUERY_CMD, capture_output=True, text=True)
except OSError:
    # nvidia-smi is missing or not executable (e.g. a CPU-only machine).
    # Treat this like any other failed query instead of crashing the cell.
    result = None

if result is not None and result.returncode == 0:
    print("Available GPUs:")
    print(result.stdout)
else:
    print("Could not query GPUs with nvidia-smi")


In [None]:
import os

# Restrict this process to GPU 0. This MUST run before anything imports torch.
os.environ.update(CUDA_VISIBLE_DEVICES="0")
# Optional CUDA debugging switches (kept disabled):
#os.environ.update(TORCH_USE_CUDA_DSA="1")
#os.environ.update(CUDA_LAUNCH_BLOCKING="1")

print("CUDA_VISIBLE_DEVICES =", os.environ.get("CUDA_VISIBLE_DEVICES"))

In [None]:
import torch

# Sanity check: how many GPUs does torch see, and which one is index 0?
gpu_count = torch.cuda.device_count()
print("Torch sees GPUs:", gpu_count)
if torch.cuda.is_available():
    print("Using device:", torch.cuda.get_device_name(0))
else:
    # get_device_name(0) raises when no CUDA device is usable. The notebook
    # falls back to CPU when it picks `device`, so report that instead.
    print("Using device: cpu (no CUDA device available)")

In [None]:

import torch
from datasets import load_dataset
from PIL import ImageFile

# Ask PIL to load truncated image files instead of failing on them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Use the first visible GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

# Load the dataset by its Hub identifier; later cells read its 'train' split.
dataset = load_dataset("tedqc/mineral-dataset")

In [None]:
# Keep only labels that have at least 2 samples, so that every remaining label
# can be represented in both the train and the test split.
# NOTE: this cell rebinds `dataset` from the loaded DatasetDict to a single
# Dataset; re-run the loading cell before re-running this one.
from collections import Counter
from datasets import ClassLabel

full_train = dataset['train']

# Count occurrences of each label
label_counts = Counter(full_train['label'])
print(f"Total samples: {len(full_train)}")
print(f"Total unique labels: {len(label_counts)}")

# Find labels with only 1 sample (these can't be split)
single_sample_labels = [label for label, count in label_counts.items() if count == 1]
print(f"Labels with only 1 sample: {len(single_sample_labels)}")

if single_sample_labels:
    print(f"\nFiltering out {len(single_sample_labels)} labels with only 1 sample...")
    print(f"First 10 labels being removed: {single_sample_labels[:10]}")
    if len(single_sample_labels) > 10:
        print(f"... and {len(single_sample_labels) - 10} more")

    # A set gives O(1) membership tests (the list made each row's check
    # proportional to the number of removed labels).
    # input_columns=['label'] passes only the label value to the predicate, so
    # the image column is not loaded/decoded for every row during filtering.
    labels_to_drop = set(single_sample_labels)
    dataset = full_train.filter(
        lambda label: label not in labels_to_drop,
        input_columns=['label'],
    )
    print(f"\nAfter filtering: {len(dataset)} samples, {len(set(dataset['label']))} unique labels")
else:
    print("All labels have at least 2 samples. No filtering needed.")
    dataset = full_train

# Turn the label column into a ClassLabel feature so it can be used as the
# stratification column. Class names are the sorted unique label values.
unique_labels = list(set(dataset['label']))
unique_labels.sort()
print("\nConverting label column to ClassLabel for stratified splitting...")
print(f"Total unique labels after filtering: {len(unique_labels)}")

label_feature_type = ClassLabel(names=unique_labels)
dataset = dataset.cast_column('label', label_feature_type)

# 80/20 train/test split stratified on the label column, with a fixed seed.
# Whether every label really lands in both splits is checked after the split.
print("\nSplitting dataset (80% train, 20% test) with stratification...")
dataset = dataset.train_test_split(test_size=0.2, stratify_by_column='label', seed=42)

split_titles = (("Train", "train"), ("Test", "test"))
print("\nFinal dataset sizes:")
for split_title, split_key in split_titles:
    print(f"  {split_title}: {len(dataset[split_key])} samples")
for split_title, split_key in split_titles:
    print(f"  {split_title} labels: {len(set(dataset[split_key]['label']))} unique")

# Cross-check the split: report labels that ended up in only one of the sets.
train_labels = set(dataset['train']['label'])
test_labels = set(dataset['test']['label'])
all_labels = train_labels.union(test_labels)
labels_only_in_train = train_labels.difference(test_labels)
labels_only_in_test = test_labels.difference(train_labels)

if labels_only_in_train:
    print(f"\n⚠️  Warning: {len(labels_only_in_train)} labels only in train set (this shouldn't happen with stratification)")
if labels_only_in_test:
    print(f"⚠️  Warning: {len(labels_only_in_test)} labels only in test set (this shouldn't happen with stratification)")
if not (labels_only_in_train or labels_only_in_test):
    print(f"\n✅ Success! All {len(all_labels)} labels appear in both train and test sets.")

In [None]:
# Inspect the first training example (shown as this cell's output).
dataset["train"][0]

In [None]:
# Build `labels`, the ordered list of class names. Two cases are supported:
# the label column is a ClassLabel feature (names come from the feature), or it
# holds raw values (names are the sorted union of train and test values).
from datasets import ClassLabel

label_feature = dataset["train"].features["label"]
if not isinstance(label_feature, ClassLabel):
    train_labels = set(dataset["train"]["label"])
    test_labels = set(dataset["test"]["label"])
    all_labels = train_labels.union(test_labels)
    labels = sorted(all_labels)

    print(f"Found {len(train_labels)} unique labels in training set")
    print(f"Found {len(test_labels)} unique labels in test set")
    print(f"Total unique labels: {len(labels)}")

    # Labels present in one split but missing from the other.
    labels_only_in_train = train_labels.difference(test_labels)
    labels_only_in_test = test_labels.difference(train_labels)

    if labels_only_in_train:
        print(f"\n⚠️  Warning: {len(labels_only_in_train)} labels only appear in train set:")
        for missing_label in sorted(labels_only_in_train)[:10]:  # first 10 only
            print(f"    - {missing_label}")
        not_shown = len(labels_only_in_train) - 10
        if not_shown > 0:
            print(f"    ... and {not_shown} more")

    if labels_only_in_test:
        print(f"\n⚠️  Warning: {len(labels_only_in_test)} labels only appear in test set:")
        for missing_label in sorted(labels_only_in_test)[:10]:  # first 10 only
            print(f"    - {missing_label}")
        not_shown = len(labels_only_in_test) - 10
        if not_shown > 0:
            print(f"    ... and {not_shown} more")

    if not labels_only_in_train and not labels_only_in_test:
        print(f"\n✅ Verified: All {len(labels)} labels appear in both train and test sets!")
else:
    # ClassLabel already carries the class names.
    labels = label_feature.names

# Lookup tables between class names and integer ids, both derived from the
# position of each name in `labels`:
#   label2id: label name -> integer id (used for training)
#   id2label: integer id -> label name
# IMPORTANT: id2label is keyed by INTEGERS only, never by strings.
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = dict(enumerate(labels))

# Report sizes, then spot-check that the first few entries round-trip.
print(f"Total labels: {len(labels)}")
print(f"label2id entries: {len(label2id)}")
print(f"id2label entries: {len(id2label)}")
print("\nSample verification (first 5 labels):")
for i, label in enumerate(labels[:5]):
    assert label2id[label] == i, f"Mismatch: label2id['{label}'] = {label2id[label]}, expected {i}"
    assert id2label[i] == label, f"Mismatch: id2label[{i}] = {id2label[i]}, expected '{label}'"
    print(f"  ✓ {i}: '{label}' -> label2id={label2id[label]}, id2label[{i}]='{id2label[i]}'")
print("✅ All mappings are consistent!")

In [None]:
# Summary of the label tables, followed by a preview of the first 10 entries.
print(f"Number of labels: {len(labels)}")
print(f"Number of label2id entries: {len(label2id)}")
print(f"Number of id2label entries: {len(id2label)}")
print("\nFirst 10 labels:")
for i, label in enumerate(labels[:10]):
    shown_id = label2id.get(label, 'MISSING')
    shown_label = id2label.get(i, 'MISSING')
    print(f"  Label: {label}, label2id: {shown_id}, id2label: {shown_label}")
    # Flag entries whose two mappings disagree with the label's position.
    if label2id.get(label) != i or id2label.get(i) != label:
        print("    ⚠️  MISMATCH!")

# Full check: every label must map to its position and back again.
print("\nVerifying consistency...")
mismatch_messages = []
for expected_id, label in enumerate(labels):
    found_id = label2id.get(label)
    if found_id != expected_id:
        mismatch_messages.append(
            f"Mismatch: label '{label}' -> label2id={found_id}, expected {expected_id}"
        )
    found_label = id2label.get(expected_id)
    if found_label != label:
        mismatch_messages.append(
            f"Mismatch: id {expected_id} -> id2label={found_label}, expected '{label}'"
        )
for message in mismatch_messages:
    print(message)
all_match = not mismatch_messages

# id2label must be keyed by integers only (no string keys).
string_keys = [key for key in id2label if isinstance(key, str)]
has_string_keys = bool(string_keys)
if has_string_keys:
    print("❌ ERROR: id2label contains string keys! It should only have integer keys.")
    print(f"   Found {len(string_keys)} string keys (showing first 10): {string_keys[:10]}")
    all_match = False

print(
    "✅ All mappings are consistent and id2label uses only integer keys!"
    if all_match
    else "❌ Found inconsistencies in mappings"
)

In [None]:
from transformers import AutoImageProcessor

# Hub checkpoint name; used here for the image processor and later for the model.
checkpoint = "facebook/convnext-base-224"
# The processor's size / image_mean / image_std are read later to build the
# torchvision transform pipeline.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Crop size comes from the image processor: either a single "shortest_edge"
# value or an explicit (height, width) pair.
processor_size = image_processor.size
if "shortest_edge" in processor_size:
    size = processor_size["shortest_edge"]
else:
    size = (processor_size["height"], processor_size["width"])

# Normalise with the mean/std stored on the image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Per-image pipeline: random resized crop -> tensor -> normalisation.
_transforms = Compose([
    RandomResizedCrop(size),
    ToTensor(),
    normalize,
])

In [None]:
def transforms(examples):
    """Convert a batch of examples into model inputs, in place.

    Adds ``pixel_values`` (each image converted to RGB and passed through
    ``_transforms``), converts the entries of ``label`` (when the key is
    present) to integer ids, and removes the ``image`` key.

    Args:
        examples: Mapping with an ``image`` list and optionally a ``label`` list.

    Returns:
        The same ``examples`` mapping, modified.

    Raises:
        KeyError: If a string label is not a key of ``label2id``.
    """

    def _as_label_id(label):
        # Non-string labels are treated as ids already and only coerced to int.
        if not isinstance(label, str):
            return int(label)
        if label not in label2id:
            # This shouldn't happen if labels were collected from both sets
            raise KeyError(f"Label '{label}' not found in label2id. "
                           f"This means the label wasn't in the training or test set when labels were collected. "
                           f"Please re-run the label collection cell (Cell 6).")
        return label2id[label]

    examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
    if "label" in examples:
        examples["label"] = [_as_label_id(label) for label in examples["label"]]
    del examples["image"]
    return examples

In [None]:
# Attach on-the-fly transforms per split.
# The training split keeps the random-crop augmentation from `transforms`.
# The test split gets a deterministic resize + center crop instead: applying
# RandomResizedCrop at evaluation time makes the reported accuracy (which also
# drives best-model selection) noisy and not reproducible between runs.
from datasets import DatasetDict
from torchvision.transforms import CenterCrop, Resize

_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])


def eval_transforms(examples):
    """Deterministic counterpart of `transforms` for the evaluation split.

    Adds ``pixel_values`` (RGB image -> resize -> center crop -> tensor ->
    normalise), converts ``label`` entries to integer ids when the key is
    present, and removes the ``image`` key. Returns the same mapping.

    Raises:
        KeyError: If a string label is not a key of ``label2id``.
    """
    examples["pixel_values"] = [_eval_transforms(img.convert("RGB")) for img in examples["image"]]
    if "label" in examples:
        # Same label handling as `transforms`: names -> ids, ids pass through.
        examples["label"] = [
            label2id[label] if isinstance(label, str) else int(label)
            for label in examples["label"]
        ]
    del examples["image"]
    return examples


dataset = DatasetDict({
    "train": dataset["train"].with_transform(transforms),
    "test": dataset["test"].with_transform(eval_transforms),
})

In [None]:
from transformers import DefaultDataCollator

# Collator with default settings; passed to the Trainer as `data_collator`.
data_collator = DefaultDataCollator()

In [None]:
import evaluate

# Accuracy metric object; `compute_metrics` calls `accuracy.compute(...)` on it.
accuracy = evaluate.load("accuracy")

In [None]:
import numpy as np

def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``accuracy`` metric.

    Args:
        eval_pred: A pair ``(predictions, labels)``. The predicted class of
            each example is the argmax of ``predictions`` along axis 1.

    Returns:
        Whatever ``accuracy.compute`` returns for the predicted classes
        against the reference labels.
    """
    scores, references = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained checkpoint with a classification head sized for our
# labels, then place it on the selected device.
# ignore_mismatched_sizes=True is passed so loading does not fail when the
# checkpoint's weights do not match the requested number of labels.
num_classes = len(labels)
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    ignore_mismatched_sizes=True,
    num_labels=num_classes,
    label2id=label2id,
    id2label=id2label,
).to(device)

In [None]:
import torch

# Half-precision (fp16) training needs a CUDA device. The notebook falls back
# to CPU when CUDA is unavailable, so only enable fp16 when a GPU is present
# instead of failing in that case.
use_fp16 = torch.cuda.is_available()

training_args = TrainingArguments(
    output_dir="mineral_model",
    # Keep the 'image' column: the on-the-fly transform needs it to build pixel_values.
    remove_unused_columns=False,
    # Evaluate and checkpoint once per epoch (the two strategies must match
    # for load_best_model_at_end).
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    # 16 samples per step x 4 accumulation steps = 64 samples per optimizer update (per device).
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    # Reload the checkpoint with the best evaluation accuracy when training ends.
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    fp16=use_fp16,  # half precision only when CUDA is available
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()