In [None]:
# Diagnostic: Check all available GPUs (run this BEFORE setting CUDA_VISIBLE_DEVICES)
# This helps identify which GPU is which
import subprocess

# Ask nvidia-smi for one CSV row per GPU: index, name, total memory, compute capability.
gpu_query_cmd = [
    'nvidia-smi',
    '--query-gpu=index,name,memory.total,compute_cap',
    '--format=csv,noheader',
]

try:
    result = subprocess.run(gpu_query_cmd, capture_output=True, text=True)
except OSError as exc:
    # nvidia-smi is missing or not executable (e.g. a CPU-only machine).
    # Report it instead of aborting the notebook with a traceback.
    result = None
    query_error = str(exc)
else:
    query_error = result.stderr.strip()

if result is not None and result.returncode == 0:
    print("Available GPUs:")
    print(result.stdout)
else:
    print("Could not query GPUs with nvidia-smi")
    if query_error:
        # Surface the underlying reason so the failure is diagnosable.
        print(f"Reason: {query_error}")


In [None]:
import os

# IMPORTANT: Set CUDA_VISIBLE_DEVICES BEFORE importing torch
# Based on diagnostic: GPU 0 = GTX 1070 (incompatible), GPU 1 = RTX 3060 (compatible)
# Setting to "1" to use RTX 3060, which will appear as cuda:0 after this setting
# NOTE: You MUST restart the kernel for this to take effect if torch was already imported!
#
# CUDA_DEVICE_ORDER=PCI_BUS_ID makes CUDA enumerate devices in the same order
# as nvidia-smi, so the index taken from the diagnostic cell refers to the
# same physical card. `setdefault` keeps any value exported before the kernel
# was started, so the selection can still be overridden from the environment.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

import torch
from PIL import ImageFile

# Let PIL decode images whose files are truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# After setting CUDA_VISIBLE_DEVICES, the selected GPU becomes visible as cuda:0
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Query the properties of the device that is actually used, not a hard-coded index.
    print(f"GPU Name: {torch.cuda.get_device_name(device)}")
    props = torch.cuda.get_device_properties(device)
    print(f"GPU Memory: {props.total_memory / 1e9:.2f} GB")
    print(f"GPU Compute Capability: {props.major}.{props.minor}")
    print(f"Number of visible GPUs: {torch.cuda.device_count()}")

from datasets import load_dataset

# Load the mineral image dataset by its hub identifier.
dataset = load_dataset("tedqc/mineral-dataset")

In [None]:
# Hold out 20% of the original "train" split for evaluation.
# A fixed seed keeps the split (and therefore the reported accuracy) reproducible
# across kernel restarts.
# NOTE: this rebinds `dataset`; re-running the cell without reloading the data
# would split the already-reduced train portion again.
dataset = dataset['train'].train_test_split(test_size=0.2, seed=42)

In [None]:
# Display the first training example as the cell output to inspect its fields.
dataset["train"][0]

In [None]:
# Get labels - handle both ClassLabel and string features
from datasets import ClassLabel

label_feature = dataset["train"].features["label"]
if isinstance(label_feature, ClassLabel):
    # If it's ClassLabel, get names from feature
    labels = label_feature.names
else:
    # Otherwise collect the unique label values from EVERY split, not just
    # "train": after a random split a rare class may exist only in "test",
    # and a missing entry in label2id would raise KeyError during evaluation.
    labels = sorted(
        {value for split in dataset.values() for value in split.unique("label")}
    )

# Create mappings (integer ids on both sides, as needed for training)
label2id = {label: idx for idx, label in enumerate(labels)}
id2label = {idx: label for idx, label in enumerate(labels)}

In [None]:
from transformers import AutoImageProcessor

# Identifier of the pretrained checkpoint; reused later when loading the model.
checkpoint = "facebook/convnext-tiny-224"
# Processor for the same checkpoint; its image_mean, image_std and size are
# read when the torchvision transform pipeline is built.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor

# Normalise with the mean/std published by the checkpoint's image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Crop size: a single int when the processor exposes "shortest_edge",
# otherwise an explicit (height, width) pair.
_size_cfg = image_processor.size
if "shortest_edge" in _size_cfg:
    size = _size_cfg["shortest_edge"]
else:
    size = (_size_cfg["height"], _size_cfg["width"])

# Pipeline applied per image: random resized crop -> tensor -> normalise.
_transforms = Compose(
    [
        RandomResizedCrop(size),
        ToTensor(),
        normalize,
    ]
)

In [None]:
def transforms(examples):
    """Turn a batch of raw examples into model inputs.

    Each entry of ``examples["image"]`` is converted to RGB and passed through
    ``_transforms``; the results are stored under ``"pixel_values"``. When a
    ``"label"`` entry is present, string labels are mapped through ``label2id``
    and everything else is cast with ``int``. The ``"image"`` entry is removed.

    The batch is modified in place and the same object is returned.

    NOTE(review): the same (random-crop) pipeline is used for whichever split
    this function is attached to - confirm that is intended for evaluation.
    """
    pixel_values = []
    for picture in examples["image"]:
        pixel_values.append(_transforms(picture.convert("RGB")))
    examples["pixel_values"] = pixel_values

    # Convert string labels to integer IDs
    if "label" in examples:
        encoded_labels = []
        for raw_label in examples["label"]:
            if isinstance(raw_label, str):
                encoded_labels.append(label2id[raw_label])
            else:
                encoded_labels.append(int(raw_label))
        examples["label"] = encoded_labels

    del examples["image"]
    return examples

In [None]:
# Attach `transforms` to the dataset via `with_transform` and rebind `dataset`
# to the result; the Trainer below receives these transformed splits.
dataset = dataset.with_transform(transforms)

In [None]:
from transformers import DefaultDataCollator

# Collator instance handed to the Trainer as `data_collator`.
data_collator = DefaultDataCollator()

In [None]:
import evaluate

# Load the "accuracy" metric; `compute_metrics` calls its `compute` method.
accuracy = evaluate.load("accuracy")

In [None]:
import numpy as np

def compute_metrics(eval_pred):
    """Score one evaluation pass with the ``accuracy`` metric.

    Args:
        eval_pred: A two-item iterable ``(predictions, references)`` where
            ``predictions`` holds one row of per-class scores per example.

    Returns:
        Whatever ``accuracy.compute`` returns for the arg-max class ids
        compared against the references.
    """
    scores, references = eval_pred
    # Highest-scoring class along axis 1 is the predicted id for each example.
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=references)

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained checkpoint with a classification head sized for this
# dataset's labels (size mismatches against the checkpoint are tolerated),
# then move it to the selected device; `.to` returns the module itself.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    ignore_mismatched_sizes=True,
    num_labels=len(labels),
    label2id=label2id,
    id2label=id2label,
).to(device)

In [None]:
# Training configuration.
# - Evaluation and checkpointing both happen once per epoch, and the checkpoint
#   with the best "accuracy" is reloaded at the end of training.
# - Effective train batch per device is 16 * 4 = 64 via gradient accumulation.
# - remove_unused_columns=False keeps the raw "image" column so that
#   `transforms` can still read it.
training_args = TrainingArguments(
    output_dir="mineral_model",
    remove_unused_columns=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=False,
    # Half precision (fp16) needs a CUDA device; fall back to full precision
    # when the notebook is running on CPU instead of failing at construction.
    fp16=torch.cuda.is_available(),
)

# Wire model, data, collator, processor and metric function into the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    processing_class=image_processor,
    compute_metrics=compute_metrics,
)

trainer.train()