# In [None]:

# import torch
# import torch.distributed as dist
# import torch.multiprocessing as mp
# from torch.utils.data import Dataset, DataLoader
# from torchvision import transforms
# from PIL import Image
# import os
# import torch.nn as nn
# import torch.optim as optim
# from torchvision import models

# class CustomImageDataset(Dataset):
#     def __init__(self, images_path, labels_path, transform=None):
#         self.images_path = images_path
#         self.labels_path = labels_path
#         self.transform = transform
#         self.images_names = sorted(
#             [f for f in os.listdir(images_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

#     def __len__(self):
#         return len(self.images_names)

#     def __getitem__(self, idx):
#         img_name = self.images_names[idx]
#         img_path = os.path.join(self.images_path, img_name)
#         image = Image.open(img_path).convert('RGB')

#         file_base, file_extension = os.path.splitext(img_name)
#         label_name = file_base + '_label.txt'
#         label_path = os.path.join(self.labels_path, label_name)

#         with open(label_path, 'r') as f:
#             label_str = f.read().strip()
#             label = int(label_str.split()[0])

#         if self.transform:
#             image = self.transform(image)
#         return image, label

# # 定义变换
# transform = transforms.Compose([
#     transforms.Resize((224, 224)),
#     transforms.ToTensor(),
#     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ])

# # 创建数据集
# dataset = CustomImageDataset(
#     images_path=r'/data/sjwlab/meixw/dataset/mini/all_images_dtext',
#     labels_path=r'/data/sjwlab/meixw/dataset/mini/all_labels',
#     transform=transform
# )

# # 定义训练参数
# start_epoch, n_epochs = 0, 10

# def fn(rank, ws, devices):
#     device = devices[rank]

#     # 初始化进程组
#     dist.init_process_group('nccl', init_method='tcp://127.0.0.1:28765',
#                             rank=rank, world_size=ws)

#     torch.cuda.set_device(device)

#     # 创建数据加载器
#     sampler = torch.utils.data.distributed.DistributedSampler(dataset)
#     loader = DataLoader(dataset, batch_size=64, sampler=sampler)

#     # 实例化模型
#     resnet = models.resnet101(pretrained=True)
#     num_ftrs = resnet.fc.in_features
#     num_classes = 3
#     new_fc = nn.Linear(num_ftrs, num_classes)
#     resnet.fc = new_fc
#     model = resnet.cuda(device)
#     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])

#     # 定义损失函数和优化器
#     criterion = nn.CrossEntropyLoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

#     # 训练循环
#     for epoch in range(start_epoch, n_epochs):
#         sampler.set_epoch(epoch)
#         for batch_idx, (inputs, labels) in enumerate(loader):
#             inputs, labels = inputs.to(device), labels.to(device)
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()
#             optimizer.step()

#     # 保存模型
#     CHECKPOINT_PATH = "ddptestmodel.checkpoint"
#     if rank == 0:
#         torch.save(model.state_dict(), CHECKPOINT_PATH)

#     dist.barrier()
    
#     dist.destroy_process_group()

# if __name__ == "__main__":
#     devices = [0, 1, 3]
#     ws = len(devices)
    
#     mp.spawn(fn, nprocs=ws, args=(ws, devices))



import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import models

class CustomImageDataset(Dataset):
    """Image dataset that pairs each image file with a text label file.

    Images are discovered by listing ``images_path`` and keeping the entries
    whose name ends in ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive).
    The names are sorted, so the index order does not depend on the order in
    which ``os.listdir`` returns them.

    The label for ``<base>.<ext>`` is read from
    ``<labels_path>/<base>_label.txt``; only the first whitespace-separated
    token of that file is used and it is parsed as an ``int``.
    """

    def __init__(self, images_path, labels_path, transform=None):
        """Index the image files found directly inside ``images_path``.

        Args:
            images_path: Directory containing the image files.
            labels_path: Directory containing the ``*_label.txt`` files.
            transform: Optional callable applied to the RGB PIL image
                before it is returned from ``__getitem__``.
        """
        self.images_path = images_path
        self.labels_path = labels_path
        self.transform = transform
        self.images_names = sorted(
            [f for f in os.listdir(images_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))])

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.images_names)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and its integer label.

        Returns:
            A ``(image, label)`` tuple, where ``image`` is the RGB image
            (passed through ``transform`` when one was given) and ``label``
            is an ``int``.

        Raises:
            ValueError: If the label file is empty or its first token is not
                an integer.
            OSError: If the image or label file cannot be opened.
        """
        img_name = self.images_names[idx]
        img_path = os.path.join(self.images_path, img_name)
        # convert() hands back a separate, fully loaded image, so the
        # underlying file can be closed as soon as the conversion is done.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        file_base = os.path.splitext(img_name)[0]
        label_path = os.path.join(self.labels_path, file_base + '_label.txt')

        with open(label_path, 'r') as f:
            fields = f.read().split()
        if not fields:
            raise ValueError(f"label file is empty: {label_path}")
        try:
            label = int(fields[0])
        except ValueError as err:
            # Name the offending file; the bare int() error does not.
            raise ValueError(
                f"label file does not start with an integer: {label_path}"
            ) from err

        # Explicit None check: a transform object that happens to be falsy
        # must still be applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Per-channel normalization statistics used by the preprocessing pipeline.
# NOTE(review): these look like the usual ImageNet statistics expected by
# torchvision's pretrained models — confirm.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel.
transform = transforms.Compose([
    transforms.Resize(size=(224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Locations of the images and of their ``*_label.txt`` files.
_IMAGES_DIR = r'/data/sjwlab/meixw/dataset/mini/all_images_dtext'
_LABELS_DIR = r'/data/sjwlab/meixw/dataset/mini/all_labels'

# Built at import time; read as a module-level global by ``fn``.
dataset = CustomImageDataset(_IMAGES_DIR, _LABELS_DIR, transform)


# Epoch range used by the training loop: range(start_epoch, n_epochs).
start_epoch = 0
n_epochs = 10

def fn(rank, ws, devices):
    """Train a 3-class ResNet-101 as one rank of a DistributedDataParallel job.

    Used as the target of ``mp.spawn`` in this file. Reads the module-level
    ``dataset``, ``start_epoch`` and ``n_epochs``.

    Args:
        rank: Index of this process within the process group; also selects
            this process's entry in ``devices``.
        ws: World size, i.e. the total number of processes in the group.
        devices: Sequence of CUDA device indices, indexed by rank.

    Side effects:
        Joins (and on exit leaves) an NCCL process group at
        ``tcp://127.0.0.1:28765``. After training, rank 0 writes the model's
        ``state_dict`` to ``ddptestmodel.checkpoint`` in the current working
        directory.
    """
    device = devices[rank]

    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:28765',
                            rank=rank, world_size=ws)
    try:
        torch.cuda.set_device(device)

        # The sampler gives each rank its own share of the dataset.
        sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        loader = DataLoader(dataset, batch_size=64, sampler=sampler, num_workers=8)

        # Pretrained ResNet-101 with the final layer replaced by a 3-way
        # classifier head.
        resnet = models.resnet101(pretrained=True)
        num_ftrs = resnet.fc.in_features
        num_classes = 3
        resnet.fc = nn.Linear(num_ftrs, num_classes)
        model = resnet.to(device)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

        for epoch in range(start_epoch, n_epochs):
            # Re-seed the sampler so every epoch uses a different shuffle.
            sampler.set_epoch(epoch)
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

        CHECKPOINT_PATH = "ddptestmodel.checkpoint"
        if rank == 0:
            # NOTE(review): the state_dict is taken from the DDP wrapper, so
            # its keys are presumably prefixed with "module." — confirm
            # before loading it into an unwrapped model.
            torch.save(model.state_dict(), CHECKPOINT_PATH)

        # Keep every rank alive until rank 0 has finished writing.
        dist.barrier()
    finally:
        # Always tear the process group down, even when training raised.
        dist.destroy_process_group()

if __name__ == "__main__":
    # One worker process per listed CUDA device index. mp.spawn passes the
    # process index to fn as its first argument (rank), followed by `args`.
    devices = [0, 1, 3]
    ws = len(devices)

    mp.spawn(fn, args=(ws, devices), nprocs=ws)