# Federated Learning 적용: NCT-CRC-HE-100K 병리 이미지 분류

In [None]:
# Mount Google Drive into the Colab runtime at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# 이미지 파일 다운
# 1. NCT-CRC-HE-100K.zip -> train dataset (~11.7GB)
!wget "https://zenodo.org/record/1214456/files/NCT-CRC-HE-100K.zip?download=1" -O NCT-CRC-HE-100K.zip

# 2. CRC-VAL-HE-7K.zip -> test dataset (~800MB)
!wget "https://zenodo.org/record/1214456/files/CRC-VAL-HE-7K.zip?download=1" -O CRC-VAL-HE-7K.zip

--2025-06-09 17:10:27--  https://zenodo.org/record/1214456/files/NCT-CRC-HE-100K.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.48.194, 188.185.45.92, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.
HTTP request sent, awaiting response... 301 MOVED PERMANENTLY
Location: /records/1214456/files/NCT-CRC-HE-100K.zip [following]
--2025-06-09 17:10:28--  https://zenodo.org/records/1214456/files/NCT-CRC-HE-100K.zip
Reusing existing connection to zenodo.org:443.
HTTP request sent, awaiting response... 200 OK
Length: 11690284003 (11G) [application/octet-stream]
Saving to: ‘NCT-CRC-HE-100K.zip’

NCT-CRC-HE-100K.zip   3%[                    ] 429.16M  27.7MB/s    eta 6m 48s ^C
--2025-06-09 17:10:44--  https://zenodo.org/record/1214456/files/CRC-VAL-HE-7K.zip?download=1
Resolving zenodo.org (zenodo.org)... 188.185.43.25, 188.185.48.194, 188.185.45.92, ...
Connecting to zenodo.org (zenodo.org)|188.185.43.25|:443... connected.
HTTP request sent,

In [None]:
# 압축 해제할 디렉토리 생성
!mkdir -p ./data/NCT-CRC-HE-100K
!mkdir -p ./data/CRC-VAL-HE-7K

In [3]:
# zip 파일 압축 해제
!unzip NCT-CRC-HE-100K.zip -d ./data/NCT-CRC-HE-100K
!ls ./data/NCT-CRC-HE-100K
!unzip CRC-VAL-HE-7K.zip -d ./data/CRC-VAL-HE-7K
!ls ./data/CRC-VAL-HE-7K

Archive:  NCT-CRC-HE-100K.zip
replace ./data/NCT-CRC-HE-100K/NCT-CRC-HE-100K/ADI/ADI-AAAMHQMK.tif? [y]es, [n]o, [A]ll, [N]one, [r]ename: N
N
NCT-CRC-HE-100K
Archive:  CRC-VAL-HE-7K.zip
replace ./data/CRC-VAL-HE-7K/CRC-VAL-HE-7K/ADI/ADI-TCGA-AAICEQFN.tif? [y]es, [n]o, [A]ll, [N]one, [r]ename: CRC-VAL-HE-7K


In [4]:
# Dataset locations. Each archive extracts into a same-named subfolder; the
# "_sorted" roots receive per-class copies made further below.
_data_dir = './data'
original_train_root = f'{_data_dir}/NCT-CRC-HE-100K/NCT-CRC-HE-100K'
sorted_train_root = f'{_data_dir}/NCT-CRC-HE-100K_sorted'
original_val_root = f'{_data_dir}/CRC-VAL-HE-7K/CRC-VAL-HE-7K'
sorted_val_root = f'{_data_dir}/CRC-VAL-HE-7K_sorted'

In [None]:
import os
import copy
import torch
import shutil
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder


In [None]:
# Start from empty "_sorted" directories: first remove any previous copies
# (missing directories are fine), then recreate both roots.
for _root in (sorted_train_root, sorted_val_root):
    shutil.rmtree(_root, ignore_errors=True)
for _root in (sorted_train_root, sorted_val_root):
    os.makedirs(_root, exist_ok=True)

In [None]:
# Copy every class folder of the original training set into the sorted root,
# visiting classes alphabetically; non-directory entries are skipped.
for class_name in sorted(os.listdir(original_train_root)):
    class_src = os.path.join(original_train_root, class_name)
    if not os.path.isdir(class_src):
        continue
    shutil.copytree(class_src, os.path.join(sorted_train_root, class_name))

In [None]:
# Copy every class folder of the original validation set into the sorted root,
# visiting classes alphabetically; non-directory entries are skipped.
for class_name in sorted(os.listdir(original_val_root)):
    class_src = os.path.join(original_val_root, class_name)
    if not os.path.isdir(class_src):
        continue
    shutil.copytree(class_src, os.path.join(sorted_val_root, class_name))

In [None]:
# Per-class image counts of a dataset root (used below for the train and test copies).
def count_images_per_class(root_dir, *, extensions=('.tif', '.tiff')):
    """Print ``"<class>: <n> images"`` for each class sub-directory of *root_dir*.

    Class folders are visited in sorted order; entries directly under
    *root_dir* that are not directories are ignored.

    File suffixes are compared case-insensitively (so ``X.TIF`` is counted),
    the same way torchvision's ``ImageFolder`` filters image files. The default
    also accepts ``.tiff``, which ``ImageFolder`` loads as well, so the printed
    counts agree with the datasets built later.

    Args:
        root_dir: Dataset root whose immediate sub-directories are classes.
        extensions: Suffix or tuple of suffixes that identify image files.
    """
    if isinstance(extensions, str):
        extensions = (extensions,)
    wanted = tuple(ext.lower() for ext in extensions)
    for cls in sorted(os.listdir(root_dir)):
        cls_path = os.path.join(root_dir, cls)
        if not os.path.isdir(cls_path):
            continue
        num_images = sum(
            1 for f in os.listdir(cls_path) if f.lower().endswith(wanted)
        )
        print(f"{cls}: {num_images} images")

# Per-class image counts of the sorted training copy (reuses the path variable
# instead of repeating the literal path).
count_images_per_class(sorted_train_root)

In [None]:
# Per-class image counts of the sorted test (validation) copy (reuses the path
# variable instead of repeating the literal path).
count_images_per_class(sorted_val_root)

In [None]:
# Data transforms. Both pipelines resize to 224x224 and normalise with the
# ImageNet channel statistics; only the training pipeline adds random flips
# and a mild colour jitter as augmentation.
def _build_transform(augment):
    steps = [transforms.Resize((224, 224))]
    if augment:
        steps += [
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.ColorJitter(brightness=0.05, contrast=0.05, saturation=0.05, hue=0.05),
        ]
    steps += [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


train_transform = _build_transform(augment=True)
val_transform = _build_transform(augment=False)

# 학습 모델

In [None]:
# Install Opacus (provides the PrivacyEngine used below for differentially private training).
!pip install opacus

In [None]:
# 셀 1: 필수 라이브러리
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import copy
from opacus import PrivacyEngine
import numpy as np

In [None]:
# CNN model definition
class CNNModel(nn.Module):
    """Three-block CNN classifier for square RGB image tiles.

    Each block is Conv3x3 -> GroupNorm -> ReLU -> MaxPool(2), so the spatial
    size shrinks by 8x overall (``input_size // 8``). The resulting 128-channel
    map is flattened per sample and fed through a 512-unit hidden layer to the
    class logits. GroupNorm is used instead of BatchNorm because Opacus
    (DP-SGD) needs per-sample gradients, which BatchNorm's cross-sample batch
    statistics do not allow.

    Args:
        num_classes: Number of output logits.
        input_size: Height/width of the square input images (>= 8). Defaults to
            224, which gives a 28x28 feature map before the classifier.

    Note:
        With the default 224 input, ``fc1`` holds 128*28*28*512 (~51M) weights,
        which dominates both the parameter count and the per-sample-gradient
        memory Opacus needs.
    """

    def __init__(self, num_classes=9, input_size=224):
        super().__init__()
        if input_size < 8:
            raise ValueError(f"input_size must be at least 8, got {input_size}")

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.gn1 = nn.GroupNorm(32, 32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.gn2 = nn.GroupNorm(32, 64)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.gn3 = nn.GroupNorm(32, 128)

        self.pool = nn.MaxPool2d(2, 2)

        # Dropout on whole feature maps after the conv stack.
        self.dropout_conv = nn.Dropout2d(0.3)

        # Three stride-2 max-pools floor-divide the side length by 2 each time,
        # which equals input_size // 8 (224 -> 28).
        feature_side = input_size // 8
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 512)
        self.dropout_fc = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.gn1(self.conv1(x))))
        x = self.pool(F.relu(self.gn2(self.conv2(x))))
        x = self.pool(F.relu(self.gn3(self.conv3(x))))
        x = self.dropout_conv(x)

        # Flatten per sample, keeping the batch dimension. Unlike
        # view(-1, C*H*W), this never silently merges samples when the input
        # size is wrong; a size mismatch raises in fc1 instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout_fc(x)
        return self.fc2(x)

# Build the global model with 9 output classes and print its layer summary.
global_model = CNNModel(num_classes=9)
print(global_model)

In [None]:
# Cell 3: local (per-hospital) training for one pass over the data
def train_local_model(model, dataloader, criterion, optimizer, device):
    """Run one pass over *dataloader*, updating *model* in place.

    Each batch is moved to *device*, then a standard zero_grad / forward /
    loss / backward / step cycle is applied with *criterion* and *optimizer*.

    Returns:
        The same *model* object (left in training mode).
    """
    model.train()
    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
    return model

In [None]:
# Cell 4: wrap model / optimizer / loader for DP-SGD with Opacus
def make_private(model, dataloader, optimizer, noise_multiplier=1.0, max_grad_norm=1.0,
                 *, privacy_engine=None):
    """Attach Opacus differential privacy (DP-SGD) to a model, optimizer and loader.

    Args:
        model: Module to train; Opacus wraps it to compute per-sample gradients.
        dataloader: Training loader for this model.
        optimizer: Optimizer over ``model``'s parameters.
        noise_multiplier: Ratio of Gaussian noise std to the clipping bound.
        max_grad_norm: Per-sample gradient L2 clipping bound.
        privacy_engine: Optional existing ``PrivacyEngine``. Opacus keeps the
            privacy accountant on the engine, so passing one in (and keeping a
            reference to it) lets the caller query the spent budget afterwards,
            e.g. via ``privacy_engine.get_epsilon(delta)``. When omitted, a
            fresh engine is created for this call and then dropped, so the
            epsilon of the run cannot be reported.

    Returns:
        ``(model, optimizer, dataloader)`` exactly as returned by
        ``PrivacyEngine.make_private``: the wrapped model, a DP optimizer that
        clips and adds noise, and the (possibly replaced) data loader. With
        Opacus' default ``poisson_sampling=True``, the loader is replaced by a
        Poisson-sampling one.
    """
    if privacy_engine is None:
        privacy_engine = PrivacyEngine()
    model, optimizer, dataloader = privacy_engine.make_private(
        module=model,
        optimizer=optimizer,
        data_loader=dataloader,
        noise_multiplier=noise_multiplier,
        max_grad_norm=max_grad_norm
    )
    return model, optimizer, dataloader

In [None]:
# Cell 5: Federated Averaging (FedAvg) of client state dicts
def average_weights(models, weights=None):
    """Average a list of model ``state_dict``s entry by entry.

    Args:
        models: Non-empty sequence of state dicts with identical keys and
            tensor shapes (one per client). They are not modified.
        weights: Optional per-client weights, same length as *models* (for
            example each client's number of training samples, as in the
            original FedAvg formulation). ``None`` (the default) weights every
            client equally.

    Returns:
        A new state dict (a deep copy of ``models[0]`` with its values
        replaced) holding the weighted mean of every entry. Floating-point
        entries keep their dtype. Integer/bool buffers are averaged in float64,
        then rounded back to their original dtype so the result loads cleanly
        with ``load_state_dict``.

    Raises:
        ValueError: if *models* is empty, if *weights* has a different length,
            or if the weights do not sum to a positive value.
    """
    if not models:
        raise ValueError("average_weights() requires at least one state_dict")
    if weights is None:
        weights = [1] * len(models)
    else:
        weights = list(weights)
        if len(weights) != len(models):
            raise ValueError(
                f"got {len(weights)} weights for {len(models)} state_dicts"
            )
    total = float(sum(weights))
    if total <= 0:
        raise ValueError("weights must sum to a positive value")

    # Deep-copy the first dict so its container type is kept and the callers'
    # tensors are never modified in place.
    avg_model = copy.deepcopy(models[0])
    for key in list(avg_model.keys()):
        reference = avg_model[key]
        if reference.is_floating_point() or reference.is_complex():
            acc = reference * weights[0]
            for w, state in zip(weights[1:], models[1:]):
                acc = acc + state[key] * w
            avg_model[key] = acc / total
        else:
            # Integer/bool tensors: average in float64, round back to the dtype.
            acc = reference.double() * weights[0]
            for w, state in zip(weights[1:], models[1:]):
                acc = acc + state[key].double() * w
            avg_model[key] = (acc / total).round().to(reference.dtype)
    return avg_model

In [None]:
# Cell 6: Federated Learning loop
def federated_learning(global_model, hospitals_dataloaders, device, rounds=5,
                       *, lr=0.01, noise_multiplier=1.0, max_grad_norm=1.0):
    """Train *global_model* with federated averaging and per-hospital DP-SGD.

    In every round, each hospital works as follows:
      - It starts from a deep copy of the current global weights.
      - It is wrapped for differential privacy via ``make_private``.
      - It runs one ``train_local_model`` pass over its own loader.
    The hospitals' weights are then averaged into *global_model*.

    Args:
        global_model: Model to train; moved to *device* and updated in place.
        hospitals_dataloaders: One DataLoader per hospital (client).
        device: Device used for local training.
        rounds: Number of federated rounds.
        lr: SGD learning rate for local training.
        noise_multiplier: DP noise multiplier passed to ``make_private``.
        max_grad_norm: Per-sample clipping bound passed to ``make_private``.

    Returns:
        *global_model*, holding the averaged weights after the last round.
    """
    criterion = nn.CrossEntropyLoss()
    global_model.to(device)

    for r in range(rounds):
        print(f"--- Federated Round {r+1} ---")
        local_weights = []

        for i, dataloader in enumerate(hospitals_dataloaders):
            print(f"Training on hospital {i+1}")
            local_model = copy.deepcopy(global_model)
            optimizer = optim.SGD(local_model.parameters(), lr=lr)

            # Apply DP: wrap model / optimizer / loader for DP-SGD.
            model_dp, optimizer_dp, dataloader_dp = make_private(
                local_model, dataloader, optimizer,
                noise_multiplier=noise_multiplier, max_grad_norm=max_grad_norm
            )

            train_local_model(model_dp, dataloader_dp, criterion, optimizer_dp, device)
            # Opacus wraps local_model by reference, so the trained parameters
            # live in local_model itself. Reading its state_dict directly gives
            # keys that match global_model (no wrapper prefix). It also avoids
            # relying on a `.module` attribute that the Opacus wrapper may not
            # expose.
            local_weights.append(copy.deepcopy(local_model.state_dict()))

        global_model.load_state_dict(average_weights(local_weights))

    return global_model

In [None]:
# Cell 7: synthetic per-hospital data (dummy example for a quick pipeline test)
input_dim = 20  # NOTE(review): not read by any visible cell; kept in case later code uses it
num_classes = 2
hospital_datasets = []

num_hospitals = 2  # number of hospitals (clients)
num_data_per_hospital = 30  # samples per hospital
# Random RGB "images" of shape (channels, 224, 224), matching CNNModel's default input size.
image_height = 224
image_width = 224
num_channels = 3

for i in range(num_hospitals):  # one shuffled loader per hospital (num_hospitals in total)
    # Dummy images shaped (samples, channels, height, width) with random labels.
    X = torch.randn(num_data_per_hospital, num_channels, image_height, image_width)
    y = torch.randint(0, num_classes, (num_data_per_hospital,))
    dataset = TensorDataset(X, y)
    loader = DataLoader(dataset, batch_size=10, shuffle=True)
    hospital_datasets.append(loader)

In [None]:
import torch

# Cell 8: run the federated pipeline once on the synthetic hospital loaders.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
global_model = CNNModel(num_classes=num_classes)

# Federated Learning on the dummy data: a single round only.
global_model = federated_learning(global_model, hospital_datasets, device, rounds=1)

In [None]:

# Load the real datasets from the sorted per-class folders (one sub-folder per class).
train_dataset = ImageFolder(root=sorted_train_root, transform=train_transform)
val_dataset = ImageFolder(root=sorted_val_root, transform=val_transform)

for _label, _ds in (("전체 학습 이미지 수", train_dataset), ("전체 검증 이미지 수", val_dataset)):
    print(f"{_label}: {len(_ds)}")


In [None]:

# Split the training set into one random, disjoint subset per hospital.
num_hospitals = 3
data_per_hospital, _remainder = divmod(len(train_dataset), num_hospitals)
# The first num_hospitals - 1 hospitals get equal shares; the last one also
# absorbs the remainder so every sample is assigned.
lengths = [data_per_hospital] * (num_hospitals - 1) + [data_per_hospital + _remainder]

hospital_subsets = torch.utils.data.random_split(train_dataset, lengths)
hospital_dataloaders = [
    DataLoader(part, batch_size=32, shuffle=True) for part in hospital_subsets
]


In [None]:

# Use the GPU when available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# ImageFolder exposes its class names directly as `.classes`. It has no
# `.dataset` attribute (that belongs to Subset wrappers such as the hospital
# splits), so `train_dataset.dataset.classes` raised AttributeError here.
global_model = CNNModel(num_classes=len(train_dataset.classes))

# Run Federated Learning (with differential privacy applied on each hospital).
global_model = federated_learning(global_model, hospital_dataloaders, device, rounds=5)


In [None]:

# Model evaluation
def evaluate(model, dataloader, device):
    """Compute, print and return the top-1 accuracy of *model* on *dataloader*.

    The model is switched to eval mode (disabling dropout) and run without
    gradient tracking; each batch is moved to *device*. The model itself is
    assumed to already be on *device*.

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: if *dataloader* yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in dataloader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            total += target.size(0)
    if total == 0:
        raise ValueError("evaluate() received a dataloader with no samples")
    acc = correct / total
    print(f"Validation Accuracy: {acc * 100:.2f}%")
    return acc


In [None]:

# Evaluate the trained global model on the validation set, in fixed order
# (evaluate() prints the accuracy).
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)
evaluate(global_model, val_loader, device)
