In [21]:
import os
import shutil
import random
from sklearn.model_selection import train_test_split

dataset = 'MaSTr1325'
img_dir = 'MaSTr1325_images_512x384'
mask_dir = 'MaSTr1325_masks_512x384'

# Destination folders for each split, all inside the dataset folder.
train_img_dir = os.path.join(dataset, 'train')
train_mask_dir = os.path.join(dataset, 'train_mask')
val_img_dir = os.path.join(dataset, 'val')
val_mask_dir = os.path.join(dataset, 'val_mask')
test_img_dir = os.path.join(dataset, 'test')
test_mask_dir = os.path.join(dataset, 'test_mask')

# Create the destination folders; exist_ok=True makes this safe to re-run.
for d in [train_img_dir, train_mask_dir, val_img_dir, val_mask_dir, test_img_dir, test_mask_dir]:
    os.makedirs(d, exist_ok=True)

images = [f for f in os.listdir(img_dir) if f.endswith('.jpg')]
masks = [f for f in os.listdir(mask_dir) if f.endswith('.png')]

# Images and masks are paired purely by position after sorting, so this relies
# on both folders' file names sorting in the same order.
# NOTE(review): only the counts are checked below, not the names themselves —
# confirm the image/mask naming scheme keeps sorted order aligned.
images.sort()
masks.sort()
# The files are later moved (not copied) out of img_dir/mask_dir, so a re-run
# finds empty lists; fail here with a clear message rather than deep inside
# train_test_split.
assert images, f"No .jpg images found in {img_dir!r}; has the split already been run?"
assert len(images) == len(masks), "Number of images and masks do not match"

In [22]:
# Two-stage split with a fixed random_state for reproducibility:
#   1) hold out 30% of the image/mask pairs, leaving ~70% for training;
#   2) split the held-out part 67/33 -> ~20% validation and ~10% test overall.
train_imgs, temp_imgs, train_masks, temp_masks = train_test_split(
    images,
    masks,
    test_size=0.3,
    random_state=42,
)
val_imgs, test_imgs, val_masks, test_masks = train_test_split(
    temp_imgs,
    temp_masks,
    test_size=0.33,
    random_state=42,
)


In [23]:
def move_files(file_list, src_dir, dest_dir, copy=False):
    """Relocate the named files from ``src_dir`` into ``dest_dir``.

    Each file keeps its name; only the containing directory changes.

    Args:
        file_list: Iterable of file names, relative to ``src_dir``.
        src_dir: Directory the files currently live in.
        dest_dir: Destination directory; created if it does not exist.
        copy: If True, copy with ``shutil.copy2`` and leave the originals in
            place, so the split can be re-run from untouched source folders.
            Defaults to False, which moves files with ``shutil.move`` (the
            original behaviour).

    Raises:
        FileNotFoundError: if a listed file is missing from ``src_dir``.
    """
    if dest_dir:
        # Without this, a missing destination makes shutil.move fail.
        os.makedirs(dest_dir, exist_ok=True)
    transfer = shutil.copy2 if copy else shutil.move
    for file in file_list:
        transfer(os.path.join(src_dir, file), os.path.join(dest_dir, file))

# Move every split's images and masks into their destination folders,
# in the order train -> val -> test (images before masks within each split).
for split_imgs, split_masks, split_img_dir, split_mask_dir in (
    (train_imgs, train_masks, train_img_dir, train_mask_dir),
    (val_imgs, val_masks, val_img_dir, val_mask_dir),
    (test_imgs, test_masks, test_img_dir, test_mask_dir),
):
    move_files(split_imgs, img_dir, split_img_dir)
    move_files(split_masks, mask_dir, split_mask_dir)


In [24]:
# Report how many entries ended up in each split's image folder.
for caption, split_dir in (
    ("Training set size:", train_img_dir),
    ("Validation set size:", val_img_dir),
    ("Test set size:", test_img_dir),
):
    print(caption, len(os.listdir(split_dir)))


Training set size: 927
Validation set size: 266
Test set size: 132


In [26]:
import os

def generate_file_list(directory, output_file, extension='.jpg'):
    """Write the base names of the files in ``directory`` ending with ``extension``.

    Names are written one per line, without their extension, in sorted order.

    Args:
        directory: Folder to scan (non-recursive, via ``os.listdir``).
        output_file: Path of the text file to create or overwrite.
        extension: Case-sensitive suffix used to select files; defaults to
            '.jpg', the value that was previously hard-coded.
    """
    # List first so a missing/unreadable directory does not truncate an
    # existing output file before failing.
    stems = [
        os.path.splitext(name)[0]
        for name in sorted(os.listdir(directory))
        if name.endswith(extension)
    ]
    # Explicit encoding: the default depends on the platform's locale.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(f"{stem}\n" for stem in stems)

# Dataset root (also set in the first cell; re-declared here, presumably so
# this cell can run on its own).
dataset = 'MaSTr1325'

# Write one list per split containing the base names of its .jpg images.
# Only image folders are scanned; no separate mask list is produced.
for split_name, list_name in (
    ('train', 'train_images.txt'),
    ('val', 'val_images.txt'),
    ('test', 'test_images.txt'),
):
    generate_file_list(os.path.join(dataset, split_name), os.path.join(dataset, list_name))

print("File lists generated successfully.")

File lists generated successfully.
