In [None]:
import os
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.io import read_image
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152

In [None]:
class Test_PetDataSet(Dataset):
    """Inference-time dataset yielding ``(image, Id)`` pairs for the test split.

    Item ``i`` reads ``<img_dir>/<Id>.jpg`` (``Id`` taken from row ``i`` of
    ``df``) with ``torchvision.io.read_image`` and returns the decoded tensor
    together with the Id string, so predictions can be matched back to rows.

    Args:
        df: DataFrame with an ``'Id'`` column; its row order defines item order.
        img_dir: directory containing the ``<Id>.jpg`` files. ``None`` (default)
            falls back to the module-level ``test_path`` global, which was the
            only option before this parameter existed.
        transform: optional callable applied to the image tensor before it is
            returned (e.g. a resize, so the default collate function can stack
            images of different native sizes). ``None`` returns the tensor
            exactly as ``read_image`` decoded it.
    """

    def __init__(self, df, img_dir=None, transform=None):
        # Take the column's underlying numpy array so indexing is purely
        # positional and independent of the DataFrame's index labels.
        self.file_name = df['Id'].values
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.file_name)

    def __getitem__(self, idx):
        img_id = self.file_name[idx]
        # Resolve the fallback lazily: the global `test_path` is assigned in a
        # later cell than this class definition.
        img_dir = self.img_dir if self.img_dir is not None else test_path
        img = read_image(os.path.join(img_dir, img_id + '.jpg'))
        if self.transform is not None:
            img = self.transform(img)
        # The file name was built from the Id, so return the Id itself rather
        # than re-deriving it by splitting on '/' and '.', which would
        # truncate any Id containing a dot.
        return img, img_id

In [None]:
def load_model(ckpt_path, arch=None, device='cuda'):
    """Build a ResNet with a single-output head and load weights from a checkpoint.

    The backbone is constructed with ``pretrained=False`` (no download; all
    weights come from the checkpoint). Its ``fc`` layer is replaced by an
    ``nn.Linear(in_features, 1)``, and the object stored at ``ckpt_path`` is
    passed to ``load_state_dict`` (strict by default).

    Args:
        ckpt_path: path to a file loadable by ``torch.load`` whose content is a
            state dict matching the constructed model.
        arch: backbone constructor, called as ``arch(pretrained=False)``; the
            result must expose an ``fc`` ``nn.Linear``. ``None`` (default)
            means ``resnet18``, the original hard-coded choice.
        device: device the model is moved to before returning. Defaults to
            ``'cuda'`` as before.

    Returns:
        The model on ``device``. It is NOT switched to eval mode; the caller
        must call ``.eval()`` before inference.

    Raises:
        RuntimeError: from ``load_state_dict`` when checkpoint keys or shapes
            do not match the constructed architecture.
    """
    if arch is None:
        # Resolved here rather than as a default argument so the function can
        # be defined even where torchvision is not importable.
        arch = resnet18
    model = arch(pretrained=False)
    # Size the new head from the backbone instead of hard-coding 512, so the
    # deeper ResNets (2048 features) work as well.
    model.fc = nn.Linear(model.fc.in_features, 1, bias=True)

    # Load onto CPU first: works regardless of the device the checkpoint was
    # saved from and avoids an extra transient copy on the GPU.
    # NOTE(review): torch.load unpickles; only use checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    result = model.load_state_dict(state_dict)
    print(result)

    return model.to(device)

In [None]:
# Inference: average the predictions of every fold checkpoint over the test
# set and write them to submission.csv.
test_path = '../input/petfinder-pawpularity-score/test/'
test_metadata_path = '../input/petfinder-pawpularity-score/test.csv'
sample_submission_path = '../input/petfinder-pawpularity-score/sample_submission.csv'
ckpt_file_path = '../input/resnet-18-exp-1/'

test_metadata = pd.read_csv(test_metadata_path)
submission = pd.read_csv(sample_submission_path)
test_dataset = Test_PetDataSet(test_metadata)
# shuffle=False keeps batches in test.csv row order; that ordering is what
# lets predictions be paired with test_metadata['Id'] below.
test_dataloader = DataLoader(test_dataset, batch_size=64, num_workers=2, shuffle=False)

# Sort so the fold order is deterministic (os.listdir order is arbitrary).
ckpt_files = sorted(os.listdir(ckpt_file_path))
if not ckpt_files:
    raise FileNotFoundError(f'no checkpoint files found in {ckpt_file_path}')

# One slot per dataset item (= row of test.csv), not per file on disk.
preds = np.zeros(len(test_dataset))
for ckpt_file in ckpt_files:
    model = load_model(os.path.join(ckpt_file_path, ckpt_file))
    # eval(): BatchNorm uses running statistics and Dropout is disabled.
    model.eval()
    fold_preds = []
    # no_grad(): no autograd graph is recorded, so intermediate activations
    # are not kept around and GPU memory use stays low.
    with torch.no_grad():
        for img, _ in test_dataloader:
            output = model(img.cuda().float()).view(-1)
            fold_preds.extend(output.cpu().numpy())
    # Equal-weight mean over however many checkpoints were actually found.
    preds += np.asarray(fold_preds, dtype=np.float64) / len(ckpt_files)
    # Release this fold's model before loading the next one.
    del model
    torch.cuda.empty_cache()

# Match predictions to submission rows by Id instead of assuming that
# sample_submission.csv lists Ids in the same order as test.csv.
pred_by_id = pd.Series(preds, index=test_metadata['Id'])
submission['Pawpularity'] = submission['Id'].map(pred_by_id).to_numpy()
assert not submission['Pawpularity'].isna().any(), 'submission Ids missing from test.csv'
submission.to_csv('submission.csv', index=False)