# Import Libraries

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms

In [None]:
import torch

# Run on the GPU when CUDA is usable; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Load Data

In [None]:
# Number of samples per batch; used by the DataLoaders and as the per-device
# train/eval batch size in TrainingArguments.
batch_size = 64
# Target size handed to the Resize step of the preprocessing pipeline.
img_size = 224

# Root folder that is joined with 'train', 'valid' and 'test' to locate each split.
data_dir = ''

In [None]:
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomResizedCrop,
    Resize,
    ToTensor,
)
from transformers import AutoFeatureExtractor

# Load the checkpoint's preprocessing config in this cell so that the
# normalisation statistics are available when the notebook is run top to bottom
# (the cell that otherwise defines `feature_extractor` comes later).
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")

# Normalise with the same per-channel mean/std the pre-trained checkpoint expects.
normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)

# Bound to `transform` (singular): this is the name the ImageFolder datasets
# are built with, and it avoids shadowing the imported `torchvision.transforms`
# module.
transform = Compose(
    [
        # Resize(int) only matches the *shorter* edge, so a centre crop is
        # needed to obtain square img_size x img_size inputs that can be
        # stacked into a batch.
        Resize(img_size),
        CenterCrop(img_size),
        ToTensor(),
        normalize,
    ]
)


def preprocess(example_batch):
    """Add a "pixel_values" entry holding the transformed images of a batch.

    Args:
        example_batch: mapping with an "image" entry that iterates over
            objects exposing ``convert("RGB")`` (presumably PIL images — confirm
            against the caller).

    Returns:
        The same mapping, mutated in place, with "pixel_values" set to the list
        of tensors produced by ``transform``.
    """
    example_batch["pixel_values"] = [
        transform(image.convert("RGB")) for image in example_batch["image"]
    ]
    return example_batch

In [None]:
# `os` is needed for os.path.join below and is not imported anywhere else in
# the notebook.
import os

# One ImageFolder per split; every split is read from its own sub-folder of
# data_dir and goes through the same preprocessing pipeline.
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform)
valid_dataset = datasets.ImageFolder(os.path.join(data_dir, 'valid'), transform)
test_dataset = datasets.ImageFolder(os.path.join(data_dir, 'test'), transform)

In [None]:
# ImageFolder exposes its class names through `.classes`; the position of a
# name in that list is the integer label the dataset yields for it.
# (`.features["label"].names` is the Hugging Face `datasets` API and does not
# exist on an ImageFolder.)
labels = train_dataset.classes

# Two-way lookup between class names and integer ids, passed to the model config.
label2id = {label: i for i, label in enumerate(labels)}
id2label = {i: label for i, label in enumerate(labels)}

# Bare expression: shows the name of class 0 as the cell output.
id2label[0]

In [None]:
# Both loaders share the same batching configuration.
# NOTE(review): the Trainer below is given the datasets directly, so these
# loaders look unused by the training cells — confirm whether they are needed.
_loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=4)

train_loader = torch.utils.data.DataLoader(train_dataset, **_loader_options)
valid_loader = torch.utils.data.DataLoader(valid_dataset, **_loader_options)

# Load Pre-trained Model

In [None]:
from transformers import AutoFeatureExtractor

# Load the preprocessing configuration that belongs to the ViT checkpoint; its
# image_mean / image_std feed the Normalize transform and the object is later
# handed to the Trainer.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224")
# Bare expression: shows the loaded configuration as the cell output.
feature_extractor

In [None]:
from transformers import (
    AutoImageProcessor,
    AutoModelForImageClassification,
    TrainingArguments,
    Trainer,
)

# Image processor that ships with the checkpoint.
processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224")

# Start from the pre-trained ViT weights but describe our own label set.
# ignore_mismatched_sizes lets loading proceed when checkpoint weights do not
# match the shapes implied by our config (presumably the classification head,
# whose size follows the number of labels — confirm).
model = AutoModelForImageClassification.from_pretrained(
    "google/vit-base-patch16-224",
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)
model = model.to(device)

In [None]:
# Same values as assigned in the "Load Data" section; repeated here so the
# training cells can be run without re-executing the data-loading cells.
batch_size = 64
img_size = 224

# Model Training

In [None]:
model_name = "vit"

# Hyper-parameters and bookkeeping options for the Hugging Face Trainer.
args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-skin-disease",
    # Keep every field the collator produces; nothing is pruned by signature.
    remove_unused_columns=False,
    # Evaluate and checkpoint once per epoch so the best epoch can be restored.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=10,
    # Gradients are accumulated over 4 steps before each optimiser update.
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    # Logging.
    logging_steps=10,
    report_to='none',
)

In [None]:
def compute_metrics(eval_pred):
    """Return the top-1 accuracy of an evaluation pass.

    Accuracy is computed directly with numpy instead of
    ``datasets.load_metric``: that helper is deprecated and absent from recent
    ``datasets`` releases, and it needs to fetch the metric script.

    Args:
        eval_pred: object exposing ``predictions`` (per-class scores, one row
            per example) and ``label_ids`` (the reference class index of each
            example).

    Returns:
        ``{"accuracy": float}`` — the fraction of examples whose highest-scoring
        class equals the reference label.
    """
    predictions = np.argmax(eval_pred.predictions, axis=1)
    references = np.asarray(eval_pred.label_ids)
    return {"accuracy": float(np.mean(predictions == references))}

In [None]:
def collate_fn(examples):
    """Batch individual examples into the keyword arguments the model expects.

    Two example layouts are accepted:

    * ``(pixel_values, label)`` pairs, which is what the ``ImageFolder``
      datasets in this notebook yield;
    * dicts with ``"pixel_values"`` and ``"label"`` entries (the layout the
      original implementation required).

    Args:
        examples: sequence of examples in either layout; every
            ``pixel_values`` tensor must have the same shape so they can be
            stacked.

    Returns:
        dict with ``"pixel_values"`` (the tensors stacked along a new leading
        dimension) and ``"labels"`` (a tensor of the labels).
    """
    images, targets = [], []
    for example in examples:
        if isinstance(example, dict):
            images.append(example["pixel_values"])
            targets.append(example["label"])
        else:
            image, target = example
            images.append(image)
            targets.append(target)
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
# Wire the model, hyper-parameters, data and metric into a Trainer.
# The datasets are the ImageFolder splits built in the "Load Data" section
# (`train_ds` / `val_ds` are never defined in this notebook).
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    # Passed so the preprocessing config is saved next to the model weights.
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
)

In [None]:
# Run the fine-tuning loop configured by `args`.
train_results = trainer.train()
# Persist the trained model to the Trainer's output directory.
trainer.save_model()
# Report the training metrics and write them to disk.
trainer.log_metrics("train", train_results.metrics)
trainer.save_metrics("train", train_results.metrics)
# Save the Trainer's own state alongside the model.
trainer.save_state()

In [None]:
# Evaluate on the eval_dataset the Trainer was constructed with.
metrics = trainer.evaluate()
# Optional bookkeeping: report the evaluation metrics and write them to disk.
trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)

# Inference

In [None]:
from transformers import AutoModelForImageClassification, AutoFeatureExtractor

# Directory holding the saved fine-tuned checkpoint.
# NOTE(review): presumably this should point at the TrainingArguments output
# directory ("vit-finetuned-skin-disease") — confirm the path for your environment.
repo_name = "/content/../vit-finetuned-skin-disease"

# local_files_only=True: read the preprocessing config and weights from disk
# only. These assignments replace the `feature_extractor` and `model` objects
# used during training.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_name,local_files_only=True)
model = AutoModelForImageClassification.from_pretrained(repo_name,local_files_only=True)

In [None]:
# Pick the last test sample. Indexing `test_dataset` would return the already
# transformed tensor, so load the untouched image from its path instead; the
# feature extractor below applies the preprocessing itself.
# (`dataset` is never defined in this notebook; the test split is `test_dataset`.)
image_path, true_label_idx = test_dataset.samples[-1]
image = test_dataset.loader(image_path)
# Bare expression: displays the image as the cell output.
image

In [None]:
# Preprocess the image with the checkpoint's feature extractor and request
# PyTorch tensors ("pt") so the result can be fed straight to the model.
encoding = feature_extractor(image.convert("RGB"), return_tensors="pt")
# Sanity check: print the shape of the model input.
print(encoding.pixel_values.shape)

In [None]:
import torch

# Inference only: run the forward pass without recording gradients.
with torch.no_grad():
    outputs = model(**encoding)

# Raw, unnormalised class scores for the single input image.
logits = outputs.logits

In [None]:
# Index of the highest-scoring class along the last dimension; .item() turns
# the one-element tensor into a plain Python int.
predicted_class_idx = logits.argmax(-1).item()
# Translate the index back to its class name via the mapping stored in the model config.
print("Predicted class:", model.config.id2label[predicted_class_idx])