In [None]:
import os
import torch
import wandb
import numpy as np
from datasets import load_dataset
from datasets import load_metric
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
from torchvision.transforms import (
    CenterCrop,
    Compose,
    Normalize,
    RandomResizedCrop,
    Resize,
    ToTensor,
)

In [None]:

# Weights & Biases settings picked up from the environment by the Trainer's
# wandb reporting: target project, model upload switch, and watch switch.
os.environ.update(
    {
        "WANDB_PROJECT": "pc-rvlcdip-custom-training1",
        "WANDB_LOG_MODEL": "true",
        "WANDB_WATCH": "false",
    }
)

# Checkpoint used for both the image processor and the classification model.
model_checkpoint = "microsoft/dit-base-finetuned-rvlcdip"

In [None]:
# Load the dataset identified by "custom-rvlcdip".
# NOTE(review): presumably a local dataset directory or a Hub repo id -- confirm.
dataset = load_dataset("custom-rvlcdip")

In [None]:
# Show the loaded dataset object as the cell output (splits / columns / sizes).
dataset

In [None]:
# Accuracy metric object; used by compute_metrics during evaluation.
# NOTE(review): `datasets.load_metric` is deprecated in newer `datasets` releases
# in favour of the separate `evaluate` package -- confirm against the pinned version.
metric = load_metric("accuracy")

In [None]:
# Pull one record (index 10) from the train split for ad-hoc inspection.
example = dataset["train"][10]

In [None]:
# Display the column schema of the train split.
dataset["train"].features

In [None]:
# Display the "label" feature definition; compare it against the hand-written
# id2label mapping to make sure the class ids line up.
dataset["train"].features["label"]

In [None]:
# Class id -> human-readable class name for the 16 document categories.
# This is the single source of truth for the label mapping.
id2label = {
    0: "letter",
    1: "form",
    2: "email",
    3: "handwritten",
    4: "advertisement",
    5: "scientific report",
    6: "scientific publication",
    7: "specification",
    8: "file folder",
    9: "news article",
    10: "budget",
    11: "invoice",
    12: "presentation",
    13: "questionnaire",
    14: "resume",
    15: "memo",
}

# Class name -> class id, derived from id2label so the two mappings can never
# drift apart when a class is added, renamed or renumbered.
label2id = {name: class_id for class_id, name in id2label.items()}

# A duplicated class name would silently collapse entries in label2id.
assert len(label2id) == len(id2label), "duplicate class name in id2label"

In [None]:
# Image processor belonging to the checkpoint; its mean/std and size settings
# drive the torchvision transforms defined in the next cell.
image_processor  = AutoImageProcessor.from_pretrained(model_checkpoint)

In [None]:
# Normalise with the per-channel mean/std stored on the image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Derive the resize / crop sizes from the processor's `size` mapping, which is
# keyed either by explicit height/width or by shortest edge.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
    crop_size = size
    max_size = None
elif "shortest_edge" in image_processor.size:
    size = image_processor.size["shortest_edge"]
    crop_size = (size, size)
    # Recorded for reference; the transforms in this cell do not use it.
    max_size = image_processor.size.get("longest_edge")
else:
    # Fail here with a clear message instead of a NameError on `size` /
    # `crop_size` when the transforms are built.
    raise ValueError(
        "Unsupported image_processor.size format; expected 'height'/'width' or "
        f"'shortest_edge' keys, got: {image_processor.size!r}"
    )

# Training pipeline: random resized crop to `crop_size`, tensor conversion,
# then normalisation.
train_transforms = Compose([RandomResizedCrop(crop_size), ToTensor(), normalize])

# Validation pipeline: resize to `size`, center crop to `crop_size`, tensor
# conversion, then normalisation.
val_transforms = Compose([Resize(size), CenterCrop(crop_size), ToTensor(), normalize])

def preprocess_train(example_batch):
    """Add training-time tensors to a batch under the "pixel_values" key.

    Every entry of ``example_batch["image"]`` is converted to RGB and passed
    through ``train_transforms``. The batch is modified in place and returned.
    """
    rgb_images = (img.convert("RGB") for img in example_batch["image"])
    example_batch["pixel_values"] = list(map(train_transforms, rgb_images))
    return example_batch

def preprocess_val(example_batch):
    """Add validation-time tensors to a batch under the "pixel_values" key.

    Every entry of ``example_batch["image"]`` is converted to RGB and passed
    through ``val_transforms``. The batch is modified in place and returned.
    """
    pixel_values = []
    for img in example_batch["image"]:
        pixel_values.append(val_transforms(img.convert("RGB")))
    example_batch["pixel_values"] = pixel_values
    return example_batch

In [None]:
# Hold out 10% of the train split for validation. The fixed seed keeps the
# split identical across kernel restarts, so evaluation numbers from different
# runs are measured on the same examples.
SPLIT_SEED = 42
splits = dataset["train"].train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_ds = splits["train"]
val_ds = splits["test"]

In [None]:
# Register the preprocessing callbacks on each split: augmentation for
# training, deterministic resize/crop for validation. The callbacks run when
# items are accessed rather than eagerly here.
train_ds.set_transform(preprocess_train)
val_ds.set_transform(preprocess_val)

In [None]:
# Sanity check: fetch one transformed training item and show it as cell output.
train_ds[0]

In [None]:
# Load the classifier from the checkpoint with this notebook's label mappings.
# ignore_mismatched_sizes lets loading proceed if checkpoint weight shapes
# differ from the configured model.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)

In [None]:
# Per-device batch size, shared by training and evaluation.
batch_size = 64
# Last path component of the checkpoint id, used to name the output directory.
model_name = model_checkpoint.rsplit("/", 1)[-1]

args = TrainingArguments(
    output_dir=f"{model_name}-finetuned-custom-rvlcdip",
    # Keep every dataset column: the on-the-fly transforms read "image".
    remove_unused_columns=False,
    # Optimisation schedule.
    learning_rate=5e-5,
    warmup_ratio=0.1,
    num_train_epochs=100,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=4,
    # Evaluate and checkpoint once per epoch, then restore the checkpoint with
    # the best accuracy when training ends.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Reporting.
    report_to="wandb",
    logging_steps=1,
    push_to_hub=False,
)

In [None]:

def compute_metrics(eval_pred):
    """Score one evaluation pass with the module-level ``metric``.

    ``eval_pred.predictions`` is reduced with argmax over axis 1 to obtain the
    predicted class ids, which are compared against ``eval_pred.label_ids``.
    Returns whatever ``metric.compute`` returns.
    """
    scores = eval_pred.predictions
    references = eval_pred.label_ids
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

def collate_fn(examples):
    """Assemble a list of dataset items into one model-ready batch.

    Each item must provide a "pixel_values" tensor and a "label" value. The
    tensors are stacked along a new leading dimension and the labels are
    gathered into a single tensor.

    Returns a dict with the keys "pixel_values" and "labels".
    """
    images = []
    targets = []
    for item in examples:
        images.append(item["pixel_values"])
        targets.append(item["label"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}

In [None]:
# Wire the model, arguments, data splits and helper callables together.
# The image processor is passed through the `tokenizer` argument.
trainer = Trainer(
    model=model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=image_processor,
    compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning; the returned object exposes `.metrics`, consumed by the
# final cell.
train_results = trainer.train()


In [None]:
# Persist the model, then log and save the training metrics and trainer state.
train_metrics = train_results.metrics
trainer.save_model()
trainer.log_metrics("train", train_metrics)
trainer.save_metrics("train", train_metrics)
trainer.save_state()