# SC4000 ViT large full

In [1]:
from transformers import AutoImageProcessor, AutoModelForImageClassification

In [2]:
# Class index -> human-readable disease name, in label order.
id2label = dict(enumerate((
    'Cassava Bacterial Blight (CBB)',
    'Cassava Brown Streak Disease (CBSD)',
    'Cassava Green Mottle (CGM)',
    'Cassava Mosaic Disease (CMD)',
    'Healthy',
)))
# Reverse lookup, derived from id2label so the two mappings cannot drift apart.
label2id = {name: index for index, name in id2label.items()}

In [3]:
# Directory passed to `from_pretrained` for both the model and the image
# processor (a Kaggle input mount).
model_path = "/kaggle/input/model/"

In [4]:
# Load the image classifier from `model_path`, attaching the five-class
# label mappings to its config.
# NOTE(review): `ignore_mismatched_sizes=True` lets loading continue when a
# checkpoint weight has a different shape than the instantiated model; per the
# transformers docs such weights are then freshly initialised. For an
# inference-only notebook that would mean an untrained head - confirm the
# checkpoint already has a 5-class classifier, or drop the flag so a mismatch
# fails loudly.
model = AutoModelForImageClassification.from_pretrained(
    model_path,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)
# Image processor loaded from the same directory as the model.
image_processor = AutoImageProcessor.from_pretrained(model_path)

In [5]:
from pathlib import Path

# Directory whose entries are fed to CassavaDataset for inference.
folder = Path("/kaggle/input/cassava-leaf-disease-classification/test_images")

In [None]:
import PIL
import torch
from torch.utils.data import Dataset, DataLoader

class CassavaDataset(Dataset):
    """Map-style dataset yielding one preprocessed image and its file name.

    Args:
        folder: ``pathlib.Path`` of the directory containing the images.
        image_processor: Callable invoked as
            ``image_processor(image, return_tensors="pt")``. It is expected to
            return a mapping of tensors carrying a leading batch axis of size
            1, as Hugging Face image processors do.

    Each item is a ``(features, file_name)`` pair where ``features`` is a plain
    dict of per-sample tensors (batch axis removed), so the default
    ``DataLoader`` collate function stacks them into ``(batch, ...)`` tensors.
    """

    def __init__(self, folder, image_processor):
        # glob() order is filesystem-dependent; sort so the order of items
        # (and therefore of the predictions) is reproducible. Sub-directories
        # are skipped because they cannot be opened as images.
        self.image_paths = sorted(
            path for path in folder.glob("*") if path.is_file()
        )
        self.image_processor = image_processor

    def __len__(self):
        """Return the number of image files found in the folder."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load, preprocess and return the image at position ``idx``.

        Returns:
            tuple: ``(features, file_name)`` - a dict of tensors without the
            processor's batch axis, and the image's base file name.
        """
        # A bare `import PIL` does not guarantee the `Image` submodule is
        # loaded, so import it explicitly where it is used.
        from PIL import Image

        image_path = self.image_paths[idx]
        with Image.open(image_path) as image:
            # Force three channels: grayscale / palette / RGBA files would
            # otherwise reach the model with an unexpected channel count.
            inputs = self.image_processor(
                image.convert("RGB"), return_tensors="pt"
            )
        # The processor returns a batch of one. Drop that axis here; otherwise
        # the DataLoader stacks items into (batch, 1, C, H, W), which the
        # model's forward pass does not accept.
        features = {name: tensor.squeeze(0) for name, tensor in inputs.items()}
        return features, image_path.name

In [6]:
# Batched inference over every image in `folder`; one record per image is
# accumulated in `submissions`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()  # inference mode for layers such as dropout

dataset = CassavaDataset(folder, image_processor)
dataloader = DataLoader(dataset, batch_size=32, num_workers=4)

submissions = []

with torch.no_grad():  # gradients are not needed for prediction
    for batch, image_names in dataloader:
        # Move every input tensor onto the model's device.
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        outputs = model(**batch)
        # Highest-scoring class index per image, as a NumPy array on the host.
        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

        submissions += [
            {"image_id": file_name, "label": class_index.item()}
            for file_name, class_index in zip(image_names, predictions)
        ]
        print(f"Processed {len(submissions)} images")

/kaggle/input/cassava-leaf-disease-classification/test_images/2216849948.jpg


In [7]:
import pandas as pd

# Build the submission table. The columns are named explicitly so the schema
# (and header order) is fixed even when `submissions` is empty, in which case
# a bare DataFrame would have no columns at all.
df = pd.DataFrame(submissions, columns=["image_id", "label"])

In [8]:
# Preview the first rows of the submission table.
df.head()

Unnamed: 0,image_id,label
0,2216849948.jpg,4


In [9]:
# Write the table to submission.csv; index=False keeps the row index out of
# the file so only the DataFrame's columns are written.
df.to_csv("submission.csv", index=False)