In [None]:
import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Backbone used for inference. Weights are not downloaded here
# (pretrained=False); they are expected to come from fine-tuned checkpoints
# loaded later in the notebook. The head is sized for 62 output classes.
architecture = "vit_little_patch16_reg1_gap_256.sbb_in12k_ft_in1k"
model = timm.create_model(
    architecture,
    num_classes=62,
    pretrained=False,
    features_only=False,
)

# Deterministic evaluation pipeline: resize -> normalize -> CHW tensor.
# NOTE(review): 292 presumably equals int(256 / 0.875), i.e. the input is
# resized larger than the 256-pixel model input so crops can be taken from it
# afterwards -- confirm against the training configuration.
resize_side = 292
valid_transform = A.Compose(
    [
        A.Resize(resize_side, resize_side),
        A.Normalize(),  # library default mean/std
        ToTensorV2(),
    ]
)

In [None]:
import glob
import cv2
import torch
import numpy as np

# Directory holding the test images; every *.png in it is scored.
test_path = "/home/ubuntu/Competition/SignClassification/data/test_set"

# Sorted so the image order (and therefore the submission row order) is
# deterministic across runs and file systems.
test_list = sorted(glob.glob(test_path + "/*.png"))

# Fail fast: an empty list would otherwise load every checkpoint and then
# silently produce an empty submission.
if not test_list:
    raise FileNotFoundError(f"No .png test images found in {test_path}")

# One checkpoint per cross-validation fold; their predictions are ensembled.
ckpt_list = [
    "/home/ubuntu/Competition/SignClassification/checkpoints/baselinev1_fold0/epoch=167-val_f1=0.9969.ckpt",
    "/home/ubuntu/Competition/SignClassification/checkpoints/baselinev1_fold1/epoch=146-val_f1=0.9973.ckpt",
    "/home/ubuntu/Competition/SignClassification/checkpoints/baselinev1_fold2/epoch=168-val_f1=0.9967.ckpt",
    "/home/ubuntu/Competition/SignClassification/checkpoints/baselinev1_fold3/epoch=151-val_f1=0.9976.ckpt",
    "/home/ubuntu/Competition/SignClassification/checkpoints/baselinev1_fold4/epoch=160-val_f1=0.9920.ckpt",
]

# ---------------------------------------------------------------------------
# Per-checkpoint inference with test-time augmentation (TTA).
#
# For every checkpoint the weights are loaded into `model`, and every test
# image is scored on several random 256x256 crops, each in four flip variants.
# The logits of all variants are summed and the argmax is stored as that
# checkpoint's predicted class for the image.
#
# Result: `all_predictions` is a list with one entry per checkpoint, each a
# list of predicted class indices aligned with `test_list`.
# ---------------------------------------------------------------------------

# Prefer the second GPU (the original target of this notebook), but fall back
# gracefully so the cell also runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

tta_rounds = 8        # number of random crops per image
crop_size = 256       # spatial input size fed to the model
max_crop_offset = 32  # crop origin is drawn from [0, max_crop_offset)

# Dedicated, seeded generator so the random crops (and therefore the
# submission) are reproducible from one run of this cell to the next.
tta_rng = np.random.default_rng(0)

# Prefix stripped from checkpoint keys before loading into the bare model.
ckpt_prefix = "model."

# Store predictions for each checkpoint
all_predictions = []

for ckpt_path in ckpt_list:
    print(f"Loading checkpoint {ckpt_path}")
    # Deserialize on CPU first; otherwise tensors are restored onto the device
    # they were saved from, which may not exist or may be the wrong GPU.
    ckpt = torch.load(ckpt_path, map_location="cpu")

    # Strip the prefix only when it is a true prefix. A substring test would
    # also match keys containing "model." elsewhere and then corrupt them by
    # blindly cutting off their first characters.
    state_dict = {
        (k[len(ckpt_prefix):] if k.startswith(ckpt_prefix) else k): v
        for k, v in ckpt["state_dict"].items()
    }

    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    predictions = []
    with torch.no_grad():
        for path in test_list:
            bgr = cv2.imread(path)
            # cv2.imread signals failure by returning None instead of raising.
            if bgr is None:
                raise FileNotFoundError(f"Could not read image: {path}")
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            img = valid_transform(image=img)["image"]
            img = img.unsqueeze(0).to(device)

            # Accumulated logits over all TTA variants.
            pred = 0

            # TTA - Test Time Augmentation
            for _ in range(tta_rounds):
                top = int(tta_rng.integers(0, max_crop_offset))
                left = int(tta_rng.integers(0, max_crop_offset))
                crop = img[:, :, top : top + crop_size, left : left + crop_size]
                # Identity, flip along dim 3, flip along dim 2, and both.
                pred += (
                    model(crop)
                    + model(crop.flip(3))
                    + model(crop.flip(2))
                    + model(crop.flip(3).flip(2))
                )

            predictions.append(pred.argmax().item())

    all_predictions.append(predictions)

def _majority_label(votes):
    """Return the most frequent label in `votes`.

    np.unique returns labels in ascending order and np.argmax picks the first
    maximum, so ties are resolved in favour of the smallest label.
    """
    labels, frequencies = np.unique(votes, return_counts=True)
    return labels[np.argmax(frequencies)]


# Regroup from "one list per checkpoint" to "one tuple of votes per image".
all_predictions = list(zip(*all_predictions))

# Hard-voting ensemble: each image gets the label most checkpoints agreed on.
final_predictions = [_majority_label(votes) for votes in all_predictions]

print(final_predictions)

In [None]:
import pandas as pd
import os

# The submission identifies images by bare file name, so drop the directory
# part of every path before building the table.
test_list = list(map(os.path.basename, test_list))

# Two-column table (ImageID, label), written without the pandas index column.
df = pd.DataFrame(data={"ImageID": test_list, "label": final_predictions})
df.to_csv("submission.csv", index=False)