Adapted from the Hugging Face Transformers image classification guide: https://huggingface.co/docs/transformers/tasks/image_classification

In [None]:
import sys

import evaluate
import numpy as np
import wandb
from datasets import load_dataset
from huggingface_hub import notebook_login
from torchvision.transforms import CenterCrop, Compose, Normalize, RandomResizedCrop, Resize, ToTensor
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from transformers import DefaultDataCollator
from transformers import pipeline
# Put the parent directory at the front of the import path so sibling
# project modules can be imported from this notebook.
sys.path[0:0] = ['..']
# Shared name: used as the W&B project and as the Trainer output directory.
project_name = 'fl_image_category_multi_label'

In [None]:
# Start a Weights & Biases run under this notebook's project name.
wandb.init(project=project_name)

In [None]:
# Authenticate with the Hugging Face Hub; required later for push_to_hub=True
# and trainer.push_to_hub().
notebook_login()

In [None]:
# Load the local dataset and carve out a 20% held-out split. A fixed seed keeps
# the split reproducible across kernel restarts, so metrics from different
# checkpoints are computed on the same held-out images.
data = load_dataset("./fl_image_category_ds/", split="train")
data = data.train_test_split(test_size=0.2, seed=42)
# The 'sku' and 'mpid' metadata columns are not used for training.
data = data.remove_columns(['sku', 'mpid'])
data

In [None]:
# Class-name <-> id lookups for the model config. Ids are kept as strings,
# matching the id2label[str(...)] lookups in the evaluation cell.
labels = data["train"].features["label"].names
id2label = {str(idx): name for idx, name in enumerate(labels)}
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
# Backbone to fine-tune; uncomment one alternative to compare. The non-ViT
# options presumably ship a pretrained classification head of a different size
# and may need ignore_mismatched_sizes=True when loading the model — verify.
checkpoint = "google/vit-base-patch16-224-in21k"
# checkpoint = "microsoft/resnet-50"
# checkpoint = 'microsoft/swin-tiny-patch4-window7-224'
# checkpoint = 'apple/mobilevit-xx-small'
# checkpoint = 'microsoft/resnet-18'

# The image processor only carries pixel-preprocessing settings (resize size,
# normalization mean/std). Label metadata (num_labels, id2label, label2id,
# problem_type) is model-config state and is passed to
# AutoModelForImageClassification.from_pretrained instead.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)

In [None]:
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)
# Target size is an int (shortest edge -> square crops) or a (height, width)
# tuple, depending on how the checkpoint's processor describes it.
size = (
    image_processor.size["shortest_edge"]
    if "shortest_edge" in image_processor.size
    else (image_processor.size["height"], image_processor.size["width"])
)
# Training: random crop/rescale augmentation.
_transforms = Compose([RandomResizedCrop(size), ToTensor(), normalize])
# Evaluation: deterministic resize + center crop. Random augmentation here
# would make eval accuracy (and therefore load_best_model_at_end's choice of
# checkpoint) vary from run to run for the same weights.
_eval_transforms = Compose([Resize(size), CenterCrop(size), ToTensor(), normalize])


def _apply_transform(transform, examples):
    """Replace a batch's PIL "image" column with transformed "pixel_values"."""
    examples["pixel_values"] = [transform(img.convert("RGB")) for img in examples["image"]]
    del examples["image"]
    return examples


def transforms(examples):
    """Training-time batch transform: augment, tensorize and normalize."""
    return _apply_transform(_transforms, examples)


def eval_transforms(examples):
    """Evaluation-time batch transform: fixed resize/crop, tensorize, normalize."""
    return _apply_transform(_eval_transforms, examples)


# Transforms run lazily on access; only the training split is augmented.
data["train"] = data["train"].with_transform(transforms)
data["test"] = data["test"].with_transform(eval_transforms)

In [None]:
data_collator = DefaultDataCollator()
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Return top-1 accuracy for an (logits, label ids) evaluation prediction."""
    logits, label_ids = eval_pred
    predictions = np.argmax(logits, axis=1)
    return accuracy.compute(predictions=predictions, references=label_ids)


# Each example has exactly one integer class id (`label` is a ClassLabel, the
# metric above is argmax accuracy, and the spot-check compares one label), so
# use the default single-label head (cross-entropy). The previous
# problem_type="multi_label_classification" selects BCEWithLogitsLoss, which
# needs float multi-hot targets shaped like the logits and fails on integer
# class ids. Re-enable it only if labels become multi-hot float vectors.
model = AutoModelForImageClassification.from_pretrained(
    checkpoint,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id,
    # Presumably needed for checkpoints whose pretrained head has a different
    # number of classes (see the alternative checkpoints) — verify.
    # ignore_mismatched_sizes=True,
)
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` and Trainer's `tokenizer=` to `processing_class=`; update
# these if the installed version warns or errors.
training_args = TrainingArguments(
    output_dir=project_name,
    # Keep the raw "image" column: the on-the-fly transform reads it.
    remove_unused_columns=False,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    # Effective train batch per device = 16 * 4 = 64.
    per_device_train_batch_size=16,
    gradient_accumulation_steps=4,
    per_device_eval_batch_size=16,
    num_train_epochs=5,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune. With load_best_model_at_end=True and metric_for_best_model="accuracy",
# the checkpoint with the best eval accuracy is reloaded when training ends.
trainer.train()

In [None]:
# Upload the final model to the Hub. The image processor was passed to the
# Trainer as `tokenizer`, presumably so it is uploaded alongside — verify on the Hub.
trainer.push_to_hub()

In [None]:
# Reload the fine-tuned model from the Hub as an inference pipeline;
# top_k=len(labels) asks for a score for every class.
classifier = pipeline(
    top_k=len(labels),
    model="StephenSKelley/fl_image_category_multi_label",
    # Pin a specific pushed revision if needed:
    # revision='ede513890a638c8eec54cc4657a6d66cc6810154',
)

In [None]:
# Spot-check the pushed model on held-out images and count top-1 hits.
# Sample from the held-out split: train_test_split shuffles, so the last rows of
# the raw "train" split mostly ended up in the *training* set, and scoring them
# overstates accuracy. with_format(None) drops the tensor transform so the raw
# PIL images and integer labels come back.
n_samples = 100
ds = data["test"].with_format(None)[-n_samples:]
# Convert to RGB as the training transform does, so grayscale/RGBA inputs get
# the same channel layout the model was trained on.
images = [img.convert("RGB") for img in ds["image"]]
predictions = classifier(images)
total = 0
correct = 0
for image, label_id, scores in zip(images, ds["label"], predictions):
    width, height = image.size
    # 100-px-wide thumbnail that keeps the aspect ratio (never 0 px tall).
    display(image.resize((100, max(1, int(100 * height / width)))))
    total += 1
    label = id2label[str(label_id)]
    prediction = max(scores, key=lambda entry: entry["score"])
    print(label, prediction)
    print(scores)
    if prediction["label"] == label:
        correct += 1
summary = f" ({correct / total:.1%})" if total else ""
print(f"{correct}/{total} correct{summary}")