In [None]:
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Class index -> human-readable label; list position is the class index.
id2label = dict(
    enumerate(
        [
            "Cassava Bacterial Blight (CBB)",
            "Cassava Brown Streak Disease (CBSD)",
            "Cassava Green Mottle (CGM)",
            "Cassava Mosaic Disease (CMD)",
            "Healthy",
        ]
    )
)
# Inverse mapping (label -> class index), derived so the two can never drift apart.
label2id = {label: index for index, label in id2label.items()}

In [None]:
import PIL
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Normalize, Resize, CenterCrop, ToTensor


class CassavaDataset(Dataset):
    """Inference-time dataset over every entry found directly inside one folder.

    Each item is a ``(tensor, file_name)`` pair: the image is converted to RGB,
    resized, center-cropped, turned into a tensor and normalized with the
    statistics taken from ``image_processor``.

    Args:
        folder: Path-like object exposing ``glob`` (e.g. ``pathlib.Path``).
        image_processor: Object exposing ``image_mean``, ``image_std`` and a
            ``size`` mapping with either a ``"shortest_edge"`` key or both
            ``"height"`` and ``"width"`` keys.

    Raises:
        KeyError: If ``image_processor.size`` has neither supported layout.
    """

    def __init__(self, folder, image_processor):
        self.folder = folder
        self.image_processor = image_processor
        # glob makes no ordering guarantee; sort so sample order is reproducible.
        self.image_paths = sorted(folder.glob("*"))
        self.image_mean, self.image_std = (
            self.image_processor.image_mean,
            self.image_processor.image_std,
        )
        resize_size, crop_size = self._resolve_sizes(self.image_processor.size)
        normalize = Normalize(mean=self.image_mean, std=self.image_std)
        self.test_transforms = Compose(
            [
                Resize(resize_size),
                CenterCrop(crop_size),
                ToTensor(),
                normalize,
            ]
        )

    @staticmethod
    def _resolve_sizes(size):
        """Translate a processor ``size`` mapping into ``(resize, crop)`` arguments.

        ``{"shortest_edge": n}`` yields ``(n, n)``: the shorter image side is
        resized to ``n`` and an ``n`` x ``n`` square is cropped, as before.
        ``{"height": h, "width": w}`` yields ``((h, w), (h, w))``: the image is
        resized straight to that shape, so the crop is a no-op.
        """
        if "shortest_edge" in size:
            edge = size["shortest_edge"]
            return edge, edge
        if "height" in size and "width" in size:
            target = (size["height"], size["width"])
            return target, target
        raise KeyError(
            "image_processor.size must contain 'shortest_edge' or both "
            f"'height' and 'width'; got keys {sorted(size)}"
        )

    def __len__(self):
        """Return the number of paths discovered in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess and return ``(tensor, file_name)`` for item ``idx``."""
        image_path = self.image_paths[idx]
        # The context manager closes the file handle once the tensor is built.
        with PIL.Image.open(image_path) as image:
            inputs = self.test_transforms(image.convert("RGB"))
        return inputs, image_path.name



# ViT full

In [None]:
# Kaggle input path: the `models` folder of the "sc4000-vit-large" input.
# NOTE(review): presumably where the ViT checkpoint and its image processor are
# loaded from — confirm; none of the visible cells reads this variable.
model_path = "/kaggle/input/sc4000-vit-large/models"

In [None]:
# Run the model over every image in `folder` and collect one record per image
# in `vit_outputs`: {"image_id": <file name>, "output": <logits for that image>}.
# NOTE: `folder`, `image_processor` and `model` are not created in this cell;
# they must already exist in the kernel when it runs.
vit_outputs = []

dataset = CassavaDataset(folder, image_processor)
dataloader = DataLoader(dataset, batch_size=16)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

with torch.no_grad():
    for batch, image_names in dataloader:
        outputs = model(batch.to(device))
        # Keep only the logits and move them to the CPU so the accumulated
        # results do not hold on to GPU memory. Class ids can be recovered
        # later with `.argmax(dim=-1)`.
        batch_logits = outputs.logits.cpu()

        # Iterating the batch tensor yields one row of logits per image.
        vit_outputs.extend(
            {"image_id": image_name, "output": image_logits}
            for image_name, image_logits in zip(image_names, batch_logits)
        )