In [1]:
import os
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import pandas as pd
import torchvision.models as models
import matplotlib.pyplot as plt
import time
def crop_image(img, block_size=224):
    '''
    Split an image into square tiles of ``block_size`` x ``block_size`` pixels.

    Tiles are produced column by column: the outer loop walks the left (x)
    offset and the inner loop walks the top (y) offset, so the order is
    (0, 0), (0, block_size), ..., (block_size, 0), (block_size, block_size), ...

    Edge tiles are NOT trimmed: when the width or height is not a multiple of
    ``block_size``, the last tile in that column/row extends past the image
    border, and PIL fills the part outside the image with zero-valued (black)
    pixels. Every tile therefore has the full ``block_size`` size.

    Args:
        img (PIL.Image): The image to be cropped. Only ``img.size`` and
            ``img.crop(box)`` are used.
        block_size (int): Side length of each tile in pixels. Defaults to 224.
    Returns:
        list: The cropped tiles, each the result of ``img.crop(box)`` with
        ``box = (left, upper, right, lower)``.
    Raises:
        ValueError: If ``block_size`` is not positive.
    '''
    if block_size <= 0:
        # range() would raise on a zero step and silently yield nothing on a negative one.
        raise ValueError("block_size must be positive, got {}".format(block_size))
    width, height = img.size
    return [
        img.crop((x, y, x + block_size, y + block_size))
        for x in range(0, width, block_size)
        for y in range(0, height, block_size)
    ]
def _load_model(model_dir, device):
    '''
    Build the EfficientNet-B2 classifier with a 2-class head and load saved weights.

    Args:
        model_dir (string): The path of the saved model ``state_dict``
        device (torch.device): Device the model is moved to
    Returns:
        torch.nn.Module: The model in eval mode on ``device``
    '''
    # weights=None: load_state_dict below (strict by default) overwrites every
    # parameter and buffer, so downloading the ImageNet weights first is wasted
    # work and needlessly requires network access.
    model = models.efficientnet_b2(weights=None)
    # Replace the ImageNet head with a 2-class classifier; 1408 is the width
    # of the features EfficientNet-B2 feeds into its classifier.
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(1408, 2)
    )
    # map_location='cpu' lets a checkpoint saved on a GPU be loaded on a
    # CPU-only machine; the model is moved to `device` afterwards.
    model.load_state_dict(torch.load(model_dir, map_location='cpu'))
    model.eval()
    return model.to(device)


def _classify_blocks(model, blocks, preprocess, device, batch_size):
    '''
    Run the model over a non-empty list of image tiles in mini-batches.

    Args:
        model (torch.nn.Module): Classifier in eval mode on ``device``
        blocks (list): PIL image tiles
        preprocess (callable): Transform turning one tile into a CHW tensor
        device (torch.device): Device the batches are sent to
        batch_size (int): Maximum number of tiles per forward pass
    Returns:
        tuple: ``(preds, logits)``; ``preds`` is a 1-D array with the argmax
        class index of each tile, ``logits`` is a 2-D array with one row of
        raw model outputs per tile.
    '''
    preds, logits = [], []
    with torch.no_grad():
        for start in range(0, len(blocks), batch_size):
            batch = torch.stack(
                [preprocess(block) for block in blocks[start:start + batch_size]]
            ).to(device)
            # Move to CPU once and reuse for both the logits and the argmax.
            outputs = model(batch).cpu()
            logits.append(outputs.numpy())
            preds.append(outputs.argmax(dim=1).numpy())
    return np.concatenate(preds), np.concatenate(logits)


def prediction(model_dir,test_dir,save_dir,batch_size=64):
    '''
    Predict the class of every image in a folder and append one CSV row per image.

    Each image is converted to RGB and split into 224x224 tiles by
    ``crop_image``. Every tile is classified by an EfficientNet-B2 with a
    2-class head. The image-level label is the class whose raw logits, summed
    over all tiles, are larger; ties resolve to 'Borassus'.

    Only files ending in .png, .jpg or .bmp (case-insensitive) are processed,
    in sorted filename order. Images that yield no tiles are skipped.

    Args:
        model_dir (string): The path of the saved model ``state_dict``
        test_dir (string): The path of the folder containing the images to be predicted
        save_dir (string): The path of the CSV file to save the results. If the
            file already exists, rows are appended without a header; otherwise it
            is created with a header.
        batch_size (int): Maximum number of tiles per forward pass. Defaults to 64.
    CSV columns:
        filename: path of the image
        byz_all / tz_all: number of tiles predicted as class 0 ('Corypha') / class 1 ('Borassus')
        outputbyz / outputtz: sum over all tiles of the class-0 / class-1 logits
        result_output: the image-level label
    Raises:
        ValueError: If ``batch_size`` is not positive.
    '''
    if batch_size <= 0:
        raise ValueError("batch_size must be positive, got {}".format(batch_size))

    print("Model Performance Testing...")
    print("Model loading...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(model_dir, device)
    print("Model loading success!")
    print("Prediction start...")
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel RGB statistics; presumably computed on the training data — TODO confirm
        transforms.Normalize(mean=[0.5762883,0.45526023,0.32699665], std=[0.08670782,0.09286641,0.09925108])
    ])

    byz = 'Corypha'   # label for class index 0
    tz = 'Borassus'   # label for class index 1

    # List the folder once, keep only supported image types (case-insensitive),
    # and sort so the processing order and CSV row order are deterministic.
    image_exts = ('.png', '.jpg', '.bmp')
    image_names = sorted(
        name for name in os.listdir(test_dir)
        if name.lower().endswith(image_exts)
    )

    for ind, filename in enumerate(image_names):
        img_path = os.path.join(test_dir, filename)
        # `with` guarantees the file handle is released once the RGB copy exists.
        with Image.open(img_path) as img:
            result_Image = img.convert('RGB')
        print("Image open success! Now confirm the coordinate of image...")
        print("Palm-leaf identification begins! Now :{}/{}".format(ind+1, len(image_names)))

        print("Image preprocessing...")
        blocks = crop_image(result_Image)
        if not blocks:
            # Zero-sized image: there is nothing to classify.
            print("No blocks to classify, skipping {}".format(img_path))
            continue

        preds, logits = _classify_blocks(model, blocks, preprocess, device, batch_size)
        # Accumulate in float64 so long tile lists do not lose precision.
        sum_byz = float(np.sum(logits[:, 0], dtype=np.float64))
        sum_tz = float(np.sum(logits[:, 1], dtype=np.float64))
        print("Palm-leaf identification is done!")

        print("Data saving...")
        df = pd.DataFrame({
            'filename': [img_path],
            'byz_all': [int(np.sum(preds == 0))],
            'tz_all': [int(np.sum(preds == 1))],
            'outputbyz': [sum_byz],
            'outputtz': [sum_tz],
            'result_output': [byz if sum_byz > sum_tz else tz],
        })
        # Header only when the file is created; otherwise append, so earlier rows
        # are kept and results written so far survive a failure on a later image.
        write_header = not os.path.exists(save_dir)
        df.to_csv(save_dir, mode='w' if write_header else 'a', header=write_header, index=False)
        print("Done!")
    
if __name__ == '__main__':
    # perf_counter is a monotonic, high-resolution clock meant for measuring
    # elapsed time; time.time() can jump if the system clock is adjusted.
    t1 = time.perf_counter()
    test_dir = "./test"
    # Fine-tuned weights can be downloaded from the Release:
    # https://github.com/yxcsu/PLNet/releases/download/v1.0.0/model_efficientnetB2.zip
    model_dir = "model_efficientnetB2.pth"
    # prediction() appends to this file if it already exists.
    save_dir = "data.csv"
    prediction(model_dir,test_dir,save_dir)
    t2 = time.perf_counter()
    print("Running time: %.2fs"%(t2 - t1))

Model Performance Testing...
Model loading...
Model loading success!
Prediction start...
Image open success! Now confirm the coordinate of image...
Palm-leaf identification begins! Now :1/1
Image preprocessing...
Palm-leaf identification is done!
Data saving...
Done!
Running time: 1.42s
