In [1]:
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
import sys
sys.path.insert(1, os.path.join(sys.path[0], '../'))

In [3]:
import torch

seed = 42
torch.manual_seed(0)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
from tqdm import tqdm
from torchmetrics.image.fid import FrechetInceptionDistance
from torchmetrics.image.inception import InceptionScore

from training import get_config, CIFAR10Dataset,ImgDataset
import utils

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
def get_fid_score(test_loader, gen_loader, feature_dim=2048):
    fid_metric = FrechetInceptionDistance(feature=feature_dim, normalize=True).to(device)

    for (x_test, _) in tqdm(test_loader):
        x_test = x_test.to(device)
        # x_test= x_test.type(torch.uint8)
        fid_metric.update(x_test, real=True)

    for (x_gen, _) in tqdm(gen_loader):
        x_gen = x_gen.to(device)
        # x_gen= x_gen.type(torch.uint8)
        fid_metric.update(x_gen, real=False)

    fid_score = fid_metric.compute()
    return fid_score 

In [6]:
inception_metric = InceptionScore()

def get_inception_score(gen_loader):
    inception_metric = InceptionScore(normalize=True).to(device)
    
    for (x_gen, _) in tqdm(gen_loader):
        x_gen = x_gen.to(device)
        inception_metric.update(x_gen)
        
    inception_score = inception_metric.compute()
    return inception_score



In [7]:
# Define dataset.
dataset_name = 'cifar_10'
# dataset_name = 'fashion_mnist'
#dataset_name = 'svhn'

data_config, train_config = get_config(dataset_name)
print(data_config)
print(train_config)

# Data config.
batch_size = 32
test_transform = data_config['test_transform']

# Training config.
timesteps = train_config['timesteps']
eta = train_config['eta']

# Data.
import torch
import torchvision
import torchvision.transforms as transforms

# 定义数据转换
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 与训练集保持一致
])

# 替换数据集加载方式
data_dir = os.path.join('datasets', dataset_name)

# 加载训练集
train_data = torchvision.datasets.CIFAR10(root=data_dir, train=True, download=True, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, pin_memory=True)

# 加载测试集
test_data = torchvision.datasets.CIFAR10(root=data_dir, train=False, download=True, transform=test_transform)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, pin_memory=True)

{'img_size': 32, 'channels': 3, 'batch_size': 512, 'train_transform': Compose(
    RandomHorizontalFlip(p=0.5)
    ToTensor()
    Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=True)
    Normalize(mean=[0.5], std=[0.5])
), 'test_transform': Compose(
    ToTensor()
    Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=True)
    Normalize(mean=[0.5], std=[0.5])
)}
{'lr': 0.0002, 'timesteps': 100, 'epochs': 100, 'rounds': 20, 'local_epochs': 5, 'ema_decay': 0.998, 'eta': 1, 'save_interval': 10, 'start_step': 20}
Files already downloaded and verified
Files already downloaded and verified


In [8]:
dataset_name='cifar_10'
centralized_gen_dir = os.path.join('./output/diffusion_cen/', dataset_name, 'generated_img')
centralized_gen_data = ImgDataset(centralized_gen_dir, transform=test_transform)
centralized_gen_loader = torch.utils.data.DataLoader(centralized_gen_data, batch_size=batch_size, shuffle=False, num_workers=0)

cen_fid_score = get_fid_score(train_loader, centralized_gen_loader)
print(cen_fid_score)
print('{:.2f}'.format(cen_fid_score))

cen_fid_score = get_fid_score(test_loader, centralized_gen_loader)
print(cen_fid_score)
print('{:.2f}'.format(cen_fid_score))

cen_inception_score = get_inception_score(centralized_gen_loader)
print(cen_inception_score)
(mean, std) = cen_inception_score
print('mean: {:.2f}, std: {:.2f}'.format(mean, std))       

100%|██████████| 1563/1563 [01:21<00:00, 19.19it/s]
100%|██████████| 313/313 [00:16<00:00, 19.26it/s]


tensor(264.7091, device='cuda:0')
264.71


100%|██████████| 313/313 [00:16<00:00, 19.21it/s]
100%|██████████| 313/313 [00:16<00:00, 19.29it/s]


tensor(265.2239, device='cuda:0')
265.22


100%|██████████| 313/313 [00:16<00:00, 19.46it/s]

(tensor(1.5137, device='cuda:0'), tensor(0.0119, device='cuda:0'))
mean: 1.51, std: 0.01





In [9]:
beta = 5                     # 5, 0.5, 0.1              
num_clients = 10               # 10, 30, 50
num_local_epochs = 5           # 1, 5, 10

fedavg_gen_dir = os.path.join('./output/diffusion_fedavg/', '{}_b_{}_c_{}_le_{}'.format(dataset_name, beta, num_clients, num_local_epochs), 'generated_img')
print(fedavg_gen_dir)
fedavg_gen_data = CIFAR10Dataset(fedavg_gen_dir, transform=test_transform)
fedavg_gen_loader = torch.utils.data.DataLoader(fedavg_gen_data, batch_size=batch_size, shuffle=False, num_workers=12, prefetch_factor=12)

fedavg_train_fid_score = get_fid_score(train_loader, fedavg_gen_loader)
print(fedavg_train_fid_score)
print('{:.2f}'.format(fedavg_train_fid_score))

fedavg_test_fid_score = get_fid_score(test_loader, fedavg_gen_loader)
print(fedavg_test_fid_score)
print('{:.2f}'.format(fedavg_test_fid_score))

./output/diffusion_fedavg/cifar_10_b_5_c_10_le_5\generated_img


FileNotFoundError: [WinError 3] 系统找不到指定的路径。: './output/diffusion_fedavg/cifar_10_b_5_c_10_le_5\\generated_img'

In [None]:
fedavg_inception_score = get_inception_score(fedavg_gen_loader)
print(fedavg_inception_score)
(mean, std) = fedavg_inception_score
print('mean: {:.2f}, std: {:.2f}'.format(mean, std))

In [None]:
# imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8).to(device)
# imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8).to(device)

# imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8).to(device)
# imgs_dist2 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8).to(device)

# fid = FrechetInceptionDistance(feature=64).to(device)
# fid.update(imgs_dist1, real=True)
# fid.update(imgs_dist2, real=False)
# fid.compute()

In [None]:
# FrechetInceptionDistance()

### Medical

In [None]:
dataset_name = 'sars_cov_2_ct_scan'

data_config, train_config = get_config(dataset_name)
print(data_config)
print(train_config)

# Data config.
batch_size = 96
train_transform = data_config['train_transform']

# Training config.
timesteps = train_config['timesteps']
eta = train_config['eta']

data_dir = os.path.join('../datasets/', dataset_name)

train_dir = os.path.join(data_dir)
train_data = ImgDataset(train_dir, transform=train_transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=False, num_workers=12, prefetch_factor=12)

In [None]:
centralized_gen_dir = os.path.join('./output/diffusion_cen/', dataset_name, 'generated_img')
centralized_gen_data = ImgDataset(centralized_gen_dir, transform=train_transform)
centralized_gen_loader = torch.utils.data.DataLoader(centralized_gen_data, batch_size=batch_size, shuffle=False, num_workers=12, prefetch_factor=12)

cen_fid_score = get_fid_score(train_loader, centralized_gen_loader)
print(cen_fid_score)
print('{:.2f}'.format(cen_fid_score))

cen_inception_score = get_inception_score(centralized_gen_loader)
print(cen_inception_score)
(mean, std) = cen_inception_score
print('mean: {:.2f}, std: {:.2f}'.format(mean, std))

In [None]:
# test = get_fid_score(fedavg_gen_loader, centralized_gen_loader)
# print(test)
# print('{:.2f}'.format(test))

In [None]:
beta = 0.5                     # 5, 0.5, 0.1              
num_clients = 10               # 10, 30, 50
num_local_epochs = 5           # 1, 5, 10

fedavg_gen_dir = os.path.join('./output/diffusion_fedavg/', '{}_b_{}_c_{}_le_{}'.format(dataset_name, beta, num_clients, num_local_epochs), 'generated_img')
print(fedavg_gen_dir)
fedavg_gen_data = ImgDataset(fedavg_gen_dir, transform=train_transform)
fedavg_gen_loader = torch.utils.data.DataLoader(fedavg_gen_data, batch_size=batch_size, shuffle=False, num_workers=12, prefetch_factor=12)

fedavg_train_fid_score = get_fid_score(train_loader, fedavg_gen_loader)
print(fedavg_train_fid_score)
print('{:.2f}'.format(fedavg_train_fid_score))

fedavg_inception_score = get_inception_score(fedavg_gen_loader)
print(fedavg_inception_score)
(mean, std) = fedavg_inception_score
print('mean: {:.2f}, std: {:.2f}'.format(mean, std))

In [None]:
fedavg_gen_data[0][0].shape