# Inference using checkpoint from Wandb + Lightning

This notebook is a companion to the ViT Transformers starter notebook:

[ViT Transformers - Sorghum 100 cultivar - Starter](https://www.kaggle.com/code/ibombonato/vit-transformers-sorghum-100-cultivar-starter)

Now that the model checkpoint is saved on Wandb, we will run inference from a checkpoint loaded from Wandb. This lets us train and fine-tune models outside the Kaggle environment, and it **will help us build an ensemble of models** later.

**If it helps you in some manner, please upvote the dataset and the notebooks :D**

![image.png](attachment:40f45dd0-4c58-4034-9b51-a5cd34520e23.png)

### Load libs and minimal setup

In [None]:
!pip install -q timm
!pip install -q --upgrade wandb

In [None]:
import numpy as np
import pandas as pd
from tqdm import tqdm
from pathlib import Path

In [None]:
#Confirm that a GPU is available
!nvidia-smi

Find your checkpoint artifact on the Wandb site and set the variable `CHECKPOINT_WANDB` to its full name:

https://wandb.ai/ibombonato/kaggle-sorghum-100-cultivar/artifacts/model/

![image.png](attachment:772d198e-6f4c-40ca-9c12-45959379a818.png)

In [None]:
ORIGIN_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images"
MODEL_NAME = 'google/vit-base-patch16-224-in21k'

# CHANGE TO YOUR CHECKPOINT HERE!
CHECKPOINT_WANDB = 'ibombonato/kaggle-sorghum-100-cultivar/model-ul10jk3n:v46'
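
The reference above follows the `entity/project/name:version` pattern. As a minimal sketch, splitting it apart can catch a malformed value before the download step. The helper `split_artifact_ref` and the `ArtifactRef` tuple are hypothetical names for illustration and are not part of the wandb library.

```python
from typing import NamedTuple


class ArtifactRef(NamedTuple):
    entity: str
    project: str
    name: str
    alias: str


def split_artifact_ref(ref: str) -> ArtifactRef:
    """Hypothetical helper: split 'entity/project/name:alias' into its parts."""
    parts = ref.split("/")
    if len(parts) != 3:
        raise ValueError(f"expected 'entity/project/name:alias', got {ref!r}")
    entity, project, name_alias = parts
    # partition() returns an empty separator when ':' is absent
    name, sep, alias = name_alias.partition(":")
    if not sep or not name or not alias:
        raise ValueError(f"missing ':alias' suffix in {ref!r}")
    return ArtifactRef(entity, project, name, alias)


ref = split_artifact_ref("ibombonato/kaggle-sorghum-100-cultivar/model-ul10jk3n:v46")
print(ref.name, ref.alias)
```

If you paste a reference without the `:v46` style suffix, this sketch raises a `ValueError` instead of failing later during the download.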

Load the artifact from Wandb using the `CHECKPOINT_WANDB` reference:

In [None]:
from kaggle_secrets import UserSecretsClient
import wandb

user_secrets = UserSecretsClient()

wandb.login(key=user_secrets.get_secret("WANDB_API_KEY"))

run = wandb.init()
artifact = run.use_artifact(CHECKPOINT_WANDB, type='model')
artifact_dir = artifact.download()

In [None]:
train_raw = pd.read_csv("../input/sorghum-id-fgvc-9/train_cultivar_mapping.csv")

In [None]:
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchmetrics import Accuracy
from torchvision.datasets import ImageFolder
from transformers import AutoFeatureExtractor, ViTForImageClassification
from timm.data import ImageDataset

## Recreating train classes and models

Train notebook for reference:

[ViT Transformers - Sorghum 100 cultivar - Starter](https://www.kaggle.com/code/ibombonato/vit-transformers-sorghum-100-cultivar-starter)

Since PyTorch converts targets to numeric indices, we map ids to labels and labels to ids so that we can access the class names later.

In [None]:
all_ds = ImageFolder(Path(ORIGIN_FOLDER, "train"))

label2id = {}
id2label = {}

for i, class_name in enumerate(all_ds.classes):
    label2id[class_name] = str(i)
    id2label[str(i)] = class_name
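
To illustrate how the two dictionaries are used together, here is a minimal sketch with made-up class names (in the notebook they come from `all_ds.classes`). It shows that a predicted integer index has to be cast to `str` before the lookup, because the keys of `id2label` are strings.

```python
# Hypothetical class names; in the notebook they come from all_ds.classes.
classes = ["PI_144134", "PI_145619", "PI_145626"]

label2id = {name: str(i) for i, name in enumerate(classes)}
id2label = {str(i): name for i, name in enumerate(classes)}

# A model prediction arrives as an integer index, so cast to str before lookup.
predicted_index = 2
predicted_name = id2label[str(predicted_index)]
print(predicted_name)

# The two dictionaries are inverses of each other.
assert all(id2label[label2id[name]] == name for name in classes)
```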

In [None]:
class ImageClassificationCollator:
    def __init__(self, feature_extractor):
        self.feature_extractor = feature_extractor
 
    def __call__(self, batch):
        encodings = self.feature_extractor([x[0] for x in batch], return_tensors='pt')
        encodings['labels'] = torch.tensor([x[1] for x in batch], dtype=torch.long)
        return encodings

In [None]:
feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(label2id),
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

collator = ImageClassificationCollator(feature_extractor)

In [None]:
class Classifier(pl.LightningModule):

    def __init__(self, model, lr: float = 2e-5, **kwargs):
        super().__init__()
        self.save_hyperparameters('lr', *list(kwargs))
        self.model = model
        self.forward = self.model.forward
        self.val_acc = Accuracy()

    def training_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("train_loss", outputs.loss)
        return outputs.loss

    def validation_step(self, batch, batch_idx):
        outputs = self(**batch)
        self.log("val_loss", outputs.loss)
        acc = self.val_acc(outputs.logits.argmax(1), batch['labels'])
        self.log("val_acc", acc, prog_bar=True)
        return outputs.loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)

## Loading model weights from checkpoint

In [None]:
pl.seed_everything(42)
classifier = Classifier(model)

# artifact.download() returned the local folder, so build the path from it
CHECKPOINT_FILE = Path(artifact_dir, "model.ckpt")

model = Classifier.load_from_checkpoint(
    CHECKPOINT_FILE,
    model=model
    )
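
Loading fails if the checkpoint path is wrong, so it can help to check for the file first. The following is a minimal sketch under assumptions: `find_checkpoint` is a hypothetical helper, and a temporary directory stands in for the folder that the artifact download produces.

```python
import tempfile
from pathlib import Path


def find_checkpoint(artifact_dir, filename="model.ckpt"):
    """Hypothetical helper: return the checkpoint path or raise a readable error."""
    ckpt = Path(artifact_dir) / filename
    if not ckpt.is_file():
        # List what is actually there to make the mistake easy to spot
        found = sorted(p.name for p in Path(artifact_dir).glob("*"))
        raise FileNotFoundError(f"{ckpt} not found; directory contains: {found}")
    return ckpt


# Demo with a temporary directory standing in for the downloaded artifact.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "model.ckpt").write_bytes(b"")  # empty placeholder file
    ckpt_path = find_checkpoint(tmp)
    ckpt_name = ckpt_path.name
    print(ckpt_name)
```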

# Make predictions

Now we will load and make predictions on the test set.

In [None]:
TEST_FOLDER = "../input/sorghum-100-cultivar-512x512-png-imagefolder/images/test"

test_ds = ImageDataset(Path(TEST_FOLDER))
test_dl = DataLoader(test_ds, batch_size=32, collate_fn=collator, num_workers=2)

In [None]:
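model.cuda()
model.eval()

def batch_predictions(dl, ds, id2label):
    """Predict a cultivar name for every image in `dl`; return (labels, filenames)."""
    predictions = []
    for batch in tqdm(dl):
        image = batch['pixel_values'].cuda()
        with torch.no_grad():
            outputs = model(image)
            # argmax of the logits equals argmax of the softmax, so softmax is skipped
            preds = outputs.logits.argmax(1).cpu().numpy()
            predictions.append(preds)

    # Flatten the per-batch arrays and map class indices back to cultivar names
    all_preds = []
    for batch in predictions:
        for prediction in batch:
            all_preds.append(id2label[str(prediction)])

    return all_preds, ds.filenames()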

In [None]:
batch_preds, batch_filenames = batch_predictions(test_dl, test_ds, id2label)
df_preds = pd.DataFrame({'filename': batch_filenames, "cultivar": batch_preds})
df_preds.head()
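
The predictions above keep only the winning class per image. For the ensembling mentioned in the introduction, one common option is to keep the softmax probabilities of each model and average them before taking the argmax. This is an assumption about how an ensemble could be built, not something this notebook implements. A minimal pure-Python sketch with made-up logits for one image and two hypothetical models:

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def argmax(values):
    """Index of the largest value."""
    return max(range(len(values)), key=values.__getitem__)


# Made-up logits for one image from two hypothetical models (3 classes).
logits_a = [2.0, 1.0, 0.1]
logits_b = [0.2, 3.0, 0.3]

probs_a = softmax(logits_a)
probs_b = softmax(logits_b)

# Softmax is monotonic, so it never changes which class wins for a single model.
assert argmax(probs_a) == argmax(logits_a)

# Average the per-class probabilities, then pick the best class.
avg_probs = [(a + b) / 2 for a, b in zip(probs_a, probs_b)]
ensemble_pred = argmax(avg_probs)
print(ensemble_pred)
```

Here model A alone prefers class 0, but model B is much more confident about class 1, so the averaged probabilities select class 1.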

# Submission

At the moment, either the test set or `sample_submission.csv` is broken, so it is not possible to submit. As soon as the organizers fix it, I will update this notebook with the submission.


In [None]:
test_df = pd.read_csv("../input/sorghum-id-fgvc-9/sample_submission.csv")

submission_df = pd.merge(test_df[['filename']], df_preds, how='inner', on='filename')

submission_df.to_csv("submission.csv", index = False)

submission_df.head()
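
Because the merge uses `how='inner'`, any test file without a matching prediction silently disappears from the submission. A minimal sketch of a sanity check, using plain Python sets and made-up filenames; `missing_predictions` is a hypothetical helper name:

```python
def missing_predictions(expected_filenames, predicted_filenames):
    """Return expected filenames that have no prediction, in sorted order."""
    return sorted(set(expected_filenames) - set(predicted_filenames))


# Made-up filenames standing in for sample_submission.csv and df_preds.
expected = ["a.png", "b.png", "c.png"]
predicted = ["a.png", "c.png"]

missing = missing_predictions(expected, predicted)
print(missing)
```

In the notebook, the same idea would be `missing_predictions(test_df['filename'], df_preds['filename'])`; an empty list means the inner merge dropped no rows.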

## If it helps you in some manner, please upvote the dataset and the notebook :D