In [None]:
!pip install -q transformers

In [None]:
import pandas as pd
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader
from transformers import Trainer, TrainingArguments, BeitFeatureExtractor, BeitForImageClassification, default_data_collator
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
from PIL import Image
import torch
import cv2
import random
from sklearn.metrics import f1_score

# One seed shared by every random number generator used in this notebook.
RANDOM_SEED = 56

# Seed Python's `random`, NumPy, and PyTorch (CPU and all CUDA devices), in that order.
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(RANDOM_SEED)

In [None]:
# Locations of the competition images (trailing slash is required because the
# file names are appended by plain string concatenation below).
TRAIN_DIR = "../input/herbarium-2022-fgvc9/train_images/"
TEST_DIR = "../input/herbarium-2022-fgvc9/test_images/"

with open("../input/herbarium-2022-fgvc9/train_metadata.json") as json_file:
    train_meta = json.load(json_file)
with open("../input/herbarium-2022-fgvc9/test_metadata.json") as json_file:
    test_meta = json.load(json_file)

# Training metadata: one entry per image plus one entry per annotation.
# NOTE(review): the two lists are paired purely by position — this assumes
# train_meta["images"][i] and train_meta["annotations"][i] describe the same
# image; confirm against the dataset description.
image_ids = [image["image_id"] for image in train_meta["images"]]
image_dirs = [TRAIN_DIR + image["file_name"] for image in train_meta["images"]]
category_ids = [annot["category_id"] for annot in train_meta["annotations"]]
genus_ids = [annot["genus_id"] for annot in train_meta["annotations"]]

# Test metadata is iterated directly as a sequence of image entries.
test_ids = [image["image_id"] for image in test_meta]
test_dirs = [TEST_DIR + image["file_name"] for image in test_meta]

# Build the frames from a dict of columns instead of np.array([...]).T:
# stacking mixed int/str lists into one array coerces every value to a
# fixed-width string (and allocates a large temporary), whereas a dict keeps
# each column's own dtype and raises if the column lengths disagree.
train_df = pd.DataFrame(
    {
        "image_id": image_ids,
        "directory": image_dirs,
        "genus_id": genus_ids,
        "category": category_ids,
    }
)
test_df = pd.DataFrame(
    {
        "image_id": test_ids,
        "directory": test_dirs,
    }
)

In [None]:
class HerbariumDataset(Dataset):
    """Map-style dataset that loads an image from disk and preprocesses it.

    Args:
        paths: Iterable of image file paths.
        labels: Iterable of class labels aligned with ``paths`` (each value
            must be convertible with ``int``). Pass ``None`` or an empty
            iterable for unlabeled data; items then carry no ``"labels"`` key.
        feature_extractor: Callable invoked as
            ``feature_extractor(image, return_tensors="pt")`` whose result
            exposes a ``pixel_values`` tensor.
    """

    def __init__(self, paths, labels, feature_extractor):
        self.paths = list(paths)
        # Accept None as "no labels" in addition to an empty iterable.
        self.labels = list(labels) if labels is not None else []
        self.feature_extractor = feature_extractor

    def __len__(self):
        """Return the number of images."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor}`` plus ``"labels"`` when labels exist."""
        # The context manager closes the underlying file deterministically;
        # convert() returns a separate image object that outlives it.
        with Image.open(self.paths[idx]) as raw_image:
            image = raw_image.convert("RGB")
        pixel_values = self.feature_extractor(image, return_tensors="pt").pixel_values
        # Drop only the leading axis (assumed to be a batch axis of size 1 for a
        # single image — TODO confirm); a bare squeeze() would also remove any
        # other singleton dimension.
        item = {"pixel_values": pixel_values.squeeze(0)}
        if self.labels:
            item["labels"] = torch.tensor(int(self.labels[idx]))
        return item

In [None]:
model_name = 'microsoft/beit-base-patch16-224-pt22k-ft22k'

# The raw category ids are fed to the model as class indices and the argmax of
# the logits is written to the submission as a category id, so the classifier
# head must have an output for every id up to the largest one. Counting unique
# ids only gives that size when the ids are contiguous from 0; max id + 1 is
# equal in that case and still valid when there are gaps.
num_labels = int(train_df.category.astype(int).max()) + 1

# ignore_mismatched_sizes=True lets the checkpoint load even though its
# classification head has a different number of outputs.
model = BeitForImageClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
    ignore_mismatched_sizes=True,
)

In [None]:
feature_extractor = BeitFeatureExtractor.from_pretrained(model_name)

# Hold out 2% of the training rows for validation. The result is bound to a new
# name so that re-running this cell does not keep shrinking train_df, and the
# split is seeded explicitly so it does not depend on how much of the global
# NumPy random state earlier cells have consumed.
train_split_df, val_df = train_test_split(
    train_df,
    test_size=0.02,
    random_state=RANDOM_SEED,
)
train_ds = HerbariumDataset(train_split_df['directory'], train_split_df['category'], feature_extractor)
val_ds = HerbariumDataset(val_df['directory'], val_df['category'], feature_extractor)

In [None]:
# Trainer configuration: a single fp16 epoch, with evaluation and saving both
# on a 10,000-step cadence and no external experiment reporting.
training_args = TrainingArguments(
    # Output location
    output_dir="./",
    overwrite_output_dir=True,
    # Training schedule and batch sizes
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    fp16=True,
    # Step-based evaluation, saving and logging
    evaluation_strategy="steps",
    eval_steps=10000,
    save_steps=10000,
    logging_steps=5,
    report_to="none",
)

In [None]:
def compute_metrics(pred):
    """Compute the macro-averaged F1 score for a set of predictions.

    Args:
        pred: Object exposing ``label_ids`` (true labels) and ``predictions``
            (an array of scores whose last axis is argmax-ed to obtain the
            predicted class).

    Returns:
        dict with a single key ``'f1_macro'``.
    """
    predicted_classes = pred.predictions.argmax(-1)
    macro_f1 = f1_score(pred.label_ids, predicted_classes, average='macro')
    return {'f1_macro': macro_f1}

In [None]:
# Bundle the model, datasets, metric function and training configuration into a Trainer.
# The feature extractor is handed over through the `tokenizer` argument.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
    tokenizer=feature_extractor,
)

# Train and predict from a checkpoint, because the dataset contains so many images.

In [None]:
# Start training with the Trainer built in the previous cell.
trainer.train()

In [None]:
# An empty label list makes the dataset return items with only "pixel_values".
test_ds = HerbariumDataset(
    paths=test_df['directory'],
    labels=[],
    feature_extractor=feature_extractor,
)
# shuffle=False keeps the batches in the same order as the rows of test_df.
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=16,
    num_workers=4,
    shuffle=False,
)

In [None]:
preds = []

# Switch to inference mode: after training the module is still in train mode,
# which keeps any dropout / stochastic-depth layers active and makes the
# predictions noisy and non-deterministic.
model.eval()

# Run on whatever device the model's weights live on instead of assuming 'cuda'.
device = next(model.parameters()).device

with torch.no_grad():
    for inputs in tqdm(test_dl):
        pixel_values = inputs['pixel_values'].to(device)
        logits = model(pixel_values=pixel_values).logits
        # tolist() moves the whole batch of class indices to Python at once
        # rather than issuing one .item() call per element.
        preds.extend(logits.argmax(-1).tolist())

In [None]:
# Start from the sample submission and set its 'Predicted' column to the model's outputs.
# NOTE(review): assumes the rows of sample_submission.csv are in the same order
# as the test images that produced `preds` — confirm.
submit = pd.read_csv('../input/herbarium-2022-fgvc9/sample_submission.csv').assign(Predicted=preds)
submit.to_csv('beit.csv', index=False)