In [None]:
import glob
import os
from collections import Counter

import albumentations as A
import cv2
import numpy as np
import pandas as pd
import timm
import torch
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

# Project-local modules (kept together instead of scattered through the cell).
from model import DualModel
from pl_tool import LightningModule
from option import get_option
import torchvision.transforms as transforms  # same module object as the top-level `from torchvision import transforms`

# Shared configuration for every backbone and transform below.
NUM_CLASSES = 107  # size of the classification head for all backbones
IMG_SIZE = 384  # network input resolution (square)
TTA_CROP_SIZE = 336  # random-crop size used by the test-time-augmentation transform

# Single-backbone timm models. pretrained=False: weights are filled in later via
# load_state_dict from the fine-tuned checkpoints.
eff_l = timm.create_model(
    "tf_efficientnetv2_l.in21k",
    pretrained=False,
    num_classes=NUM_CLASSES,
    features_only=False,
)

convnext = timm.create_model(
    "convnext_large.fb_in22k",
    pretrained=False,
    num_classes=NUM_CLASSES,
    features_only=False,
    drop_rate=0.4,
)

effvit = timm.create_model(
    "timm/efficientvit_l3.r384_in1k",
    pretrained=False,
    num_classes=NUM_CLASSES,
    features_only=False,
    drop_rate=0.4,
)

# Two-backbone model wrapped in the project's Lightning module so that
# "dual" checkpoints (saved from that wrapper) can be loaded with their keys as-is.
model = DualModel()
opt = get_option()
# NOTE(review): the meaning of the third positional argument (1) is defined in
# pl_tool.LightningModule — confirm before changing it.
Model = LightningModule(opt, model, 1)

# Deterministic preprocessing for every test image (albumentations, HWC numpy -> CHW tensor).
valid_transform = A.Compose(
    [
        A.Resize(IMG_SIZE, IMG_SIZE),
        A.Normalize(),
        ToTensorV2(),
    ]
)

# Extra TTA view: random crop followed by a bilinear resize back to the network
# input size. Operates on already-preprocessed image tensors.
transform = transforms.Compose(
    [
        transforms.RandomCrop(size=(TTA_CROP_SIZE, TTA_CROP_SIZE)),
        transforms.Resize(
            size=(IMG_SIZE, IMG_SIZE),
            interpolation=transforms.InterpolationMode.BILINEAR,
        ),
    ]
)

In [None]:
class TestDataset(Dataset):
    """Dataset that reads test images from disk for inference.

    Each item is ``(image, path)``. The image is read with OpenCV, converted from
    BGR to RGB, and, if ``transform`` is given, passed through it as
    ``transform(image=img)["image"]`` (albumentations calling convention).

    Args:
        test_list: Sequence of image file paths.
        transform: Optional albumentations-style transform.

    Raises:
        FileNotFoundError: from ``__getitem__`` when an image cannot be read.
    """

    def __init__(self, test_list, transform=None):
        self.test_list = test_list
        self.transform = transform

    def __len__(self):
        return len(self.test_list)

    def __getitem__(self, idx):
        path = self.test_list[idx]
        bgr = cv2.imread(path)
        # cv2.imread reports failure by returning None rather than raising; surface
        # the offending path instead of letting cvtColor fail with an opaque error.
        if bgr is None:
            raise FileNotFoundError(f"Could not read image: {path}")
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if self.transform is not None:
            img = self.transform(image=img)["image"]
        return img, path


# Inference configuration for this cell.
TEST_DIR = "/home/ubuntu/Competition/XunFeiAnimal/data2/testA"
CKPT_DIR = "./checkpoints/dual_effv2_l_effvit_l3_data2V3_4fold_m3_dp"
DEVICE = torch.device("cuda", 4)  # same GPU index as the previous .cuda(4)
BATCH_SIZE = 64
NUM_WORKERS = 4
N_RANDOM_CROPS = 5  # random-crop TTA rounds per batch; each round adds 4 flip views
TTA_SEED = 0  # seeds torch's RNG so the random-crop TTA is reproducible


def _strip_model_prefix(state_dict, prefix="model."):
    """Return a new state dict with a leading ``prefix`` removed from each key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_checkpoint_into_model(ckpt_path):
    """Load ``ckpt_path`` into the matching network and return it on DEVICE in eval mode.

    The architecture is chosen from substrings of the path; the first match wins,
    so "dual" checkpoints (whose names also contain "effv2_l"/"effvit") go to the
    Lightning wrapper. Dual checkpoints are loaded with their keys unchanged;
    single-backbone checkpoints have the leading "model." prefix stripped so the
    keys match the bare timm model.

    Raises:
        ValueError: if no known architecture tag appears in ``ckpt_path``.
    """
    if "dual" in ckpt_path:
        net, strip_prefix = Model, False
    elif "effv2_l" in ckpt_path:
        net, strip_prefix = eff_l, True
    elif "convnext" in ckpt_path:
        net, strip_prefix = convnext, True
    elif "effvit" in ckpt_path:
        net, strip_prefix = effvit, True
    else:
        # Fail loudly instead of reusing whatever network a previous iteration loaded.
        raise ValueError(f"Cannot infer architecture from checkpoint path: {ckpt_path}")

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
    # reject Lightning checkpoints that pickle extra objects — confirm for your torch version.
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
    if strip_prefix:
        state_dict = _strip_model_prefix(state_dict)
    net.load_state_dict(state_dict)
    return net.to(DEVICE).eval()


def _flip_tta_probs(net, batch):
    """Sum softmax probabilities over the original, horizontal, vertical and double-flipped batch."""
    flip_dims = ((), (-1,), (-2,), (-1, -2))
    return sum(
        torch.softmax(net(batch.flip(dims) if dims else batch), dim=1)
        for dims in flip_dims
    )


def _predict_classes(net, loader):
    """Return the argmax class index for every image in ``loader``, using flip + random-crop TTA."""
    torch.manual_seed(TTA_SEED)  # same random crops for every checkpoint and every run
    predictions = []
    with torch.no_grad():
        for img, _ in loader:
            img = img.to(DEVICE)
            prob_sum = _flip_tta_probs(net, img)
            for _ in range(N_RANDOM_CROPS):
                prob_sum += _flip_tta_probs(net, transform(img))
            predictions.extend(prob_sum.argmax(dim=1).cpu().tolist())
    return predictions


test_list = sorted(glob.glob(TEST_DIR + "/*.jpg"))
# Sorted so the checkpoint order (which breaks voting ties) is deterministic.
ckpt_list = sorted(glob.glob(CKPT_DIR + "/epoch*.ckpt"))
if not test_list:
    raise FileNotFoundError(f"No .jpg test images found in {TEST_DIR}")
if not ckpt_list:
    raise FileNotFoundError(f"No epoch*.ckpt checkpoints found in {CKPT_DIR}")

# The dataset does not depend on the checkpoint, so build it once.
test_loader = DataLoader(
    TestDataset(test_list, transform=valid_transform),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
)

# One list of per-image class indices per checkpoint.
all_predictions = []
for ckpt_path in ckpt_list:
    print(f"Loading checkpoint {ckpt_path}")
    model = _load_checkpoint_into_model(ckpt_path)
    all_predictions.append(_predict_classes(model, test_loader))

# Hard majority vote across checkpoints. Counter.most_common keeps first-seen
# order on ties, so the earliest checkpoint in ckpt_list wins a tie.
voted = [Counter(votes).most_common(1)[0][0] for votes in zip(*all_predictions)]
assert len(voted) == len(test_list), "expected one prediction per test image"
print(f"Voted predictions for {len(voted)} images; first 10: {voted[:10]}")

# Shift 0-based class indices to the 1-based labels written to the CSV.
final_predictions = [str(x + 1) for x in voted]
image_names = [os.path.basename(path) for path in test_list]
submission = pd.DataFrame({"uuid": image_names, "label": final_predictions})
submission.to_csv("submission.csv", index=False)

In [None]:
import pandas as pd
import os

# Write the final submission using image file names as ids.
# NOTE(review): this overwrites the submission.csv written by the previous cell and
# changes the id column header from "uuid" to "ImageID" — confirm which header the
# competition expects and keep a single writer.
image_ids = [os.path.basename(path) for path in test_list]  # works for full paths or bare names
assert len(image_ids) == len(final_predictions), "one label per test image expected"
submission = pd.DataFrame({"ImageID": image_ids, "label": final_predictions})
submission.to_csv("submission.csv", index=False)