In [None]:
import cv2 
from PIL import Image 
import numpy as np 
import pandas as pd 
import os
from sklearn.preprocessing import MultiLabelBinarizer

# Input locations, relative to the current working directory.
img_dir = 'train'                          # numbered images: 1.jpg, 2.jpg, ...
mask_dir = 'train/masks'                   # masks: binary_1.tif, binary_2.tif, ...
excel_file = 'train/classif.xlsx'          # label sheet with bug_type / species columns
data_feature_file = 'data_features.csv'    # feature table, read with pd.read_csv

def load_images(img_dir, count, ext="jpg"):
    """Load the numbered images ``1.<ext>`` .. ``<count>.<ext>`` as RGB arrays.

    Files that are missing or that OpenCV cannot decode are reported with a
    printed message and skipped. The returned list may then be shorter than
    *count*, and its positions no longer match the sample numbers.

    Args:
        img_dir: Directory containing the numbered image files.
        count: Highest sample number to try (numbering starts at 1).
        ext: File extension without the dot; defaults to ``"jpg"``.

    Returns:
        List of images converted from OpenCV's BGR order to RGB.
    """
    images = []
    for i in range(1, count + 1):
        img_path = os.path.join(img_dir, f"{i}.{ext}")
        if not os.path.exists(img_path):
            print(f"Image {img_path} not found.")
            continue
        img = cv2.imread(img_path)
        # cv2.imread reports a corrupt/unsupported file by returning None
        # rather than raising; passing None on to cvtColor would crash.
        if img is None:
            print(f"Image {img_path} could not be read.")
            continue
        images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
    return images

def load_masks(mask_dir, count):
    """Load the masks ``binary_1.tif`` .. ``binary_<count>.tif`` as NumPy arrays.

    Missing files are reported with a printed message and skipped. The
    returned list may then be shorter than *count*.

    Args:
        mask_dir: Directory containing the ``binary_<i>.tif`` files.
        count: Highest sample number to try (numbering starts at 1).

    Returns:
        List of ``np.ndarray`` masks, one per file that was found.
    """
    masks = []
    for i in range(1, count + 1):
        mask_path = os.path.join(mask_dir, f"binary_{i}.tif")
        if not os.path.exists(mask_path):
            print(f"Mask {mask_path} not found.")
            continue
        # Open with a context manager so the file handle is released as soon
        # as np.array has copied the pixels out. Pillow keeps handles of
        # multi-frame formats such as TIFF open after loading otherwise.
        with Image.open(mask_path) as mask:
            masks.append(np.array(mask))
    return masks

def load_classification(excel_file):
    """Read the label spreadsheet *excel_file* into a DataFrame.

    Prints a message and returns None when the file does not exist.
    """
    if not os.path.exists(excel_file):
        print(f"Excel file {excel_file} not found.")
        return None
    return pd.read_excel(excel_file)
    
def load_features(data_feature_file):
    """Read the CSV feature table *data_feature_file* into a DataFrame.

    Prints a message and returns None when the file does not exist.
    """
    found = os.path.exists(data_feature_file)
    if not found:
        print(f"Feature file {data_feature_file} not found.")
        return None
    return pd.read_csv(data_feature_file)
    
# Load every raw input once: samples 1..250 of images and masks, then the
# label spreadsheet and the feature table.
images, masks = load_images(img_dir, 250), load_masks(mask_dir, 250)
classif_df = load_classification(excel_file)
features_df = load_features(data_feature_file)

def _split_label_cells(bug_cell, species_cell):
    """Parse one row's raw bug_type/species cells into two lists of labels."""
    # Empty Excel cells arrive as float NaN; treat any non-string as "no label"
    # instead of crashing on `in` / `.split`.
    bug_text = bug_cell if isinstance(bug_cell, str) else ''
    species_text = species_cell if isinstance(species_cell, str) else ''

    if ' x2' in species_text:
        # ' x2' on the species marks a doubled entry: strip the marker from
        # both cells and repeat each label twice so the lists stay parallel.
        bugs = [bug_text.replace(' x2', '')] * 2
        species = [species_text.replace(' x2', '')] * 2
    else:
        bugs = bug_text.split(' & ')
        species = species_text.split(' & ')

    def _clean(labels):
        # Remove the ' ?' marker and drop labels left empty (missing cell,
        # stray ' & ' separator, or a label that was only ' ?').
        return [label for label in (raw.replace(' ?', '') for raw in labels) if label]

    return _clean(bugs), _clean(species)


def process_labels(df):
    """Convert the raw ``bug_type`` / ``species`` text cells into label lists.

    Per row:
      * ``"A & B"`` becomes ``["A", "B"]``.
      * If the species cell contains ``" x2"``, that marker is removed from
        both cells and each label is repeated twice.
      * ``" ?"`` is removed from every label.
      * Non-string cells (e.g. NaN) yield no labels.

    Args:
        df: DataFrame with ``bug_type`` and ``species`` columns.

    Returns:
        A copy of *df* whose ``bug_type`` and ``species`` columns hold lists
        of strings. The input DataFrame is not modified.
    """
    bug_lists = []
    species_lists = []
    for bug_cell, species_cell in zip(df['bug_type'], df['species']):
        bugs, species = _split_label_cells(bug_cell, species_cell)
        bug_lists.append(bugs)
        species_lists.append(species)

    out = df.copy()
    # Build object-dtype Series explicitly so pandas stores one list per row.
    out['bug_type'] = pd.Series(bug_lists, index=out.index, dtype=object)
    out['species'] = pd.Series(species_lists, index=out.index, dtype=object)
    return out

# Images, masks, features and the raw label table are loaded once earlier in
# the script; here only the label columns need normalising, with no second
# pass over all the files.
if classif_df is None:
    # load_classification only prints and returns None on a missing file;
    # fail with a clear error instead of an AttributeError in process_labels.
    raise FileNotFoundError(f"Excel file {excel_file} not found.")
classif_df = process_labels(classif_df)

# Multi-hot encode the label lists with MultiLabelBinarizer: one 0/1 column
# per distinct bug type and one per distinct species.
mlb_bug_type = MultiLabelBinarizer()
mlb_species = MultiLabelBinarizer()

bug_type_encoded = mlb_bug_type.fit_transform(classif_df['bug_type'])
species_encoded = mlb_species.fit_transform(classif_df['species'])

# Convert to DataFrames that share classif_df's index, so the axis=1 concat
# below pairs rows by that index rather than assuming a default RangeIndex.
bug_type_df = pd.DataFrame(bug_type_encoded, columns=mlb_bug_type.classes_, index=classif_df.index)
species_df = pd.DataFrame(species_encoded, columns=mlb_species.classes_, index=classif_df.index)

# pd.concat silently drops None entries, and axis=1 aligns on the index, so a
# missing or differently sized feature table would otherwise go unnoticed.
if features_df is None:
    print("Feature table missing: all_data_df will contain labels only.")
elif len(features_df) != len(classif_df):
    print(f"Row count mismatch: {len(classif_df)} labelled rows vs {len(features_df)} feature rows.")

# Merge all data column-wise.
all_data_df = pd.concat([classif_df.drop(columns=['bug_type', 'species']), bug_type_df, species_df, features_df], axis=1)

# The same name appearing in both label vocabularies (or among the feature
# columns) produces duplicate column names that are ambiguous to select.
dup_cols = all_data_df.columns[all_data_df.columns.duplicated()]
if len(dup_cols):
    print(f"Duplicate column names in all_data_df: {list(dup_cols)}")


: 