In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from amp.data.diffusion.schedules import LinearSchedule, CosineSchedule, QuadraticSchedule
from amp.data.diffusion.forward import DiffusionForward, DiscreteDataDiffusion

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## 测试连续数据的扩散过程

我们将使用一个简单的2D高斯分布数据来测试连续数据的扩散过程。我们会：
1. 生成初始数据
2. 应用不同的noise schedule
3. 观察不同时间步的扩散效果

In [None]:
def create_gaussian_mixture(n_samples=1000):
    mean1 = torch.tensor([2.0, 2.0])
    mean2 = torch.tensor([-2.0, -2.0])
    std = 0.5
    
    samples1 = torch.randn(n_samples // 2, 2) * std + mean1
    samples2 = torch.randn(n_samples // 2, 2) * std + mean2
    return torch.cat([samples1, samples2], dim=0)

# 创建数据和时间步
x_0 = create_gaussian_mixture()
timesteps = [0, 100, 500, 999]  # 不同的时间步

# 创建不同的schedule
num_timesteps = 1000
schedules = {
    'Linear': LinearSchedule(num_timesteps),
    'Cosine': CosineSchedule(num_timesteps),
    'Quadratic': QuadraticSchedule(num_timesteps)
}

# 可视化不同schedule下的扩散过程
fig, axes = plt.subplots(len(schedules), len(timesteps), figsize=(16, 12))

for i, (schedule_name, schedule) in enumerate(schedules.items()):
    diffusion = DiffusionForward(schedule)
    
    for j, t in enumerate(timesteps):
        t_tensor = torch.ones(x_0.shape[0], dtype=torch.long) * t
        x_t, _ = diffusion.q_sample(x_0, t_tensor)
        
        ax = axes[i, j]
        ax.scatter(x_t[:, 0].numpy(), x_t[:, 1].numpy(), alpha=0.5, s=1)
        ax.set_xlim(-4, 4)
        ax.set_ylim(-4, 4)
        
        if i == 0:
            ax.set_title(f't={t}')
        if j == 0:
            ax.set_ylabel(schedule_name)

plt.tight_layout()
plt.show()

## 测试离散数据的扩散过程

接下来我们测试离散数据的扩散过程。我们将：
1. 创建一个简单的离散序列数据
2. 使用不同的schedule进行扩散
3. 观察token分布的变化

In [None]:
# 设置参数
vocab_size = 10
seq_length = 20
batch_size = 16

# 创建一些示例序列数据（每个位置都偏好某些特定的token）
def create_biased_sequences(batch_size, seq_length, vocab_size):
    # 创建一个偏好模式：每个位置倾向于使用特定的token
    position_biases = torch.arange(seq_length) % (vocab_size // 2)
    sequences = position_biases.repeat(batch_size, 1)
    # 添加一些随机性
    mask = torch.rand(batch_size, seq_length) > 0.7
    random_tokens = torch.randint(0, vocab_size, (batch_size, seq_length))
    sequences[mask] = random_tokens[mask]
    return sequences

# 创建初始序列
x_0 = create_biased_sequences(batch_size, seq_length, vocab_size)

# 使用余弦schedule进行测试（通常对离散数据效果较好）
schedule = CosineSchedule(num_timesteps)
discrete_diffusion = DiscreteDataDiffusion(schedule, vocab_size)

# 可视化不同时间步的token分布
fig, axes = plt.subplots(1, len(timesteps), figsize=(20, 4))

for j, t in enumerate(timesteps):
    t_tensor = torch.ones(batch_size, dtype=torch.long) * t
    x_t, _ = discrete_diffusion.q_sample(x_0, t_tensor)
    
    # 计算每个位置上token的平均分布
    avg_distribution = x_t.mean(dim=0).numpy()
    
    ax = axes[j]
    im = ax.imshow(avg_distribution, aspect='auto', cmap='viridis')
    ax.set_title(f't={t}')
    ax.set_xlabel('Token ID')
    ax.set_ylabel('Position')
    plt.colorbar(im, ax=ax)

plt.tight_layout()
plt.show()

# 打印某个特定位置的token分布变化
position = 0
print(f"\nPosition {position} 的token分布变化:")
for t in timesteps:
    t_tensor = torch.ones(batch_size, dtype=torch.long) * t
    x_t, _ = discrete_diffusion.q_sample(x_0, t_tensor)
    distribution = x_t[:, position].mean(dim=0)
    print(f"t={t}:")
    print(distribution.numpy().round(3))

## 测试数据加载器（DataLoader）

我们将测试以下几种场景：
1. 基础静态数据集
2. 基础可迭代数据集
3. 多任务静态数据集
4. 多任务可迭代数据集

In [2]:
import torch
from amp.trainer.dataloader import (
    BaseDataset,
    IterableBaseDataset,
    BaseDataLoader,
    MultiTaskDataset,
    MultiTaskIterableDataset,
    MultiTaskDataLoader
)

### 1. 测试基础静态数据集

创建一个简单的静态数据集，模拟图像分类任务

In [3]:
class SimpleImageDataset(BaseDataset):
    def __init__(self, num_samples=100):
        super().__init__()
        # 模拟图像数据 (100, 3, 32, 32)
        self.images = torch.randn(num_samples, 3, 32, 32)
        # 模拟标签数据 (100,)
        self.labels = torch.randint(0, 10, (num_samples,))
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, index):
        return {
            'image': self.images[index],
            'label': self.labels[index]
        }

# 创建数据集和数据加载器
dataset = SimpleImageDataset(num_samples=100)
dataloader = BaseDataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)

# 测试数据加载
for batch_idx, batch in enumerate(dataloader):
    if batch_idx == 0:  # 只打印第一个batch
        print(f"Batch size: {batch['image'].shape[0]}")
        print(f"Image shape: {batch['image'].shape}")
        print(f"Label shape: {batch['label'].shape}")
        break

Batch size: 16
Image shape: torch.Size([16, 3, 32, 32])
Label shape: torch.Size([16])




### 2. 测试基础可迭代数据集

创建一个动态生成数据的数据集，模拟流式数据处理

In [4]:
class StreamDataset(IterableBaseDataset):
    def __init__(self, max_samples=100):
        super().__init__()
        self.max_samples = max_samples
        
    def __iter__(self):
        for _ in range(self.max_samples):
            # 动态生成数据
            feature = torch.randn(10)  # 10维特征
            target = torch.sum(feature) > 0  # 二分类任务
            yield {
                'feature': feature,
                'target': target
            }

# 创建数据集和数据加载器
stream_dataset = StreamDataset(max_samples=100)
stream_loader = BaseDataLoader(
    dataset=stream_dataset,
    batch_size=16,
    num_workers=0  # 流式数据集通常使用单进程
)

# 测试数据加载
for batch_idx, batch in enumerate(stream_loader):
    if batch_idx == 0:  # 只打印第一个batch
        print(f"Batch size: {batch['feature'].shape[0]}")
        print(f"Feature shape: {batch['feature'].shape}")
        print(f"Target shape: {batch['target'].shape}")
        break

Batch size: 16
Feature shape: torch.Size([16, 10])
Target shape: torch.Size([16])


### 3. 测试多任务静态数据集

创建一个多输入多输出的数据集，模拟多任务学习场景

In [5]:
# 创建模拟数据
num_samples = 100

# 输入数据
inputs = {
    'image': torch.randn(num_samples, 3, 64, 64),  # 图像输入
    'text': torch.randint(0, 1000, (num_samples, 20))  # 文本输入
}

# 目标数据
targets = {
    'class': torch.randint(0, 10, (num_samples,)),  # 分类任务
    'bbox': torch.randn(num_samples, 4),  # 检测任务
    'mask': torch.randint(0, 2, (num_samples, 64, 64))  # 分割任务
}

# 创建数据集和数据加载器
multi_dataset = MultiTaskDataset(
    inputs=inputs,
    targets=targets
)

multi_loader = MultiTaskDataLoader(
    dataset=multi_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=2
)

# 测试数据加载
for batch_idx, batch in enumerate(multi_loader):
    if batch_idx == 0:  # 只打印第一个batch
        print("输入数据:")
        print(f"- Image shape: {batch['input_image'].shape}")
        print(f"- Text shape: {batch['input_text'].shape}")
        print("\n输出数据:")
        print(f"- Class shape: {batch['target_class'].shape}")
        print(f"- BBox shape: {batch['target_bbox'].shape}")
        print(f"- Mask shape: {batch['target_mask'].shape}")
        break

输入数据:
- Image shape: torch.Size([8, 3, 64, 64])
- Text shape: torch.Size([8, 20])

输出数据:
- Class shape: torch.Size([8])
- BBox shape: torch.Size([8, 4])
- Mask shape: torch.Size([8, 64, 64])


### 4. 测试多任务可迭代数据集

创建一个动态生成多任务数据的数据集

In [6]:
def multi_task_generator(worker_id):
    # 通过 worker_id 设置种子，确保不同 worker 生成不同样本
    torch.manual_seed(worker_id + 42)
    
    inputs = (
        torch.randn(3, 32, 32),
        torch.randint(0, 100, (10,))
    )
    targets = (
        torch.randint(0, 5, (1,))[0],
        torch.randn(2)
    )
    return inputs, targets


# 创建数据集和数据加载器
stream_multi_dataset = MultiTaskIterableDataset(
    data_generator=multi_task_generator,
    input_names=['image', 'text'],
    target_names=['class', 'regression']
)

stream_multi_loader = MultiTaskDataLoader(
    dataset=stream_multi_dataset,
    batch_size=4,
    num_workers=1  # 流式数据集使用单进程
)

# 测试数据加载
for batch_idx, batch in enumerate(stream_multi_loader):
    if batch_idx == 0:  # 只打印第一个batch
        print("输入数据:")
        print(f"- Image shape: {batch['input_image'].shape}")
        print(f"- Text shape: {batch['input_text'].shape}")
        print("\n输出数据:")
        print(f"- Class shape: {batch['target_class'].shape}")
        print(f"- Regression shape: {batch['target_regression'].shape}")
        break

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/wang-work/miniconda3/envs/amp/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/wang-work/miniconda3/envs/amp/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'multi_task_generator' on <module '__main__' (built-in)>


RuntimeError: DataLoader worker (pid(s) 12258) exited unexpectedly

## 测试多进程数据加载

我们将测试不同worker数量对数据加载速度的影响。

In [None]:
import time
import torch.multiprocessing as mp

def test_dataloader_speed(num_workers):
    # 创建一个较大的数据集来测试
    dataset = SimpleImageDataset(num_samples=10000)
    
    dataloader = BaseDataLoader(
        dataset=dataset,
        batch_size=32,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    start_time = time.time()
    
    # 遍历整个数据集
    total_samples = 0
    for batch in dataloader:
        total_samples += batch['image'].shape[0]
    
    end_time = time.time()
    return end_time - start_time, total_samples

# 测试不同的worker数量
worker_nums = [0, 1, 2, 4, mp.cpu_count()]
results = []

print("\n测试不同worker数量的加载时间:")
print("-" * 40)
print(f"{'Workers':>8} | {'Time (s)':>10} | {'Samples/s':>10}")
print("-" * 40)

for num_workers in worker_nums:
    time_taken, total_samples = test_dataloader_speed(num_workers)
    samples_per_second = total_samples / time_taken
    results.append((num_workers, time_taken, samples_per_second))
    print(f"{num_workers:8d} | {time_taken:10.2f} | {samples_per_second:10.0f}")

# 绘制结果
import matplotlib.pyplot as plt

plt.figure(figsize=(12, 4))

# 时间对比
plt.subplot(1, 2, 1)
workers = [r[0] for r in results]
times = [r[1] for r in results]
plt.plot(workers, times, 'o-')
plt.xlabel('Number of Workers')
plt.ylabel('Time (seconds)')
plt.title('Loading Time vs Number of Workers')

# 吞吐量对比
plt.subplot(1, 2, 2)
speeds = [r[2] for r in results]
plt.plot(workers, speeds, 'o-')
plt.xlabel('Number of Workers')
plt.ylabel('Samples per Second')
plt.title('Throughput vs Number of Workers')

plt.tight_layout()
plt.show()


测试不同worker数量的加载时间:
----------------------------------------
 Workers |   Time (s) |  Samples/s
----------------------------------------
       0 |       0.02 |     419531
       0 |       0.02 |     419531


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/wang-work/miniconda3/envs/amp/lib/python3.10/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/wang-work/miniconda3/envs/amp/lib/python3.10/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'SimpleImageDataset' on <module '__main__' (built-in)>


RuntimeError: DataLoader worker (pid(s) 12156) exited unexpectedly

### 测试多任务数据集的多进程加载

In [7]:
import time
import matplotlib.pyplot as plt
def test_multi_task_loader_speed(num_workers):
    # 创建较大的多任务数据集
    num_samples = 100000
    
    # 输入数据
    inputs = {
        'image': torch.randn(num_samples, 3, 64, 64),
        'text': torch.randint(0, 1000, (num_samples, 20))
    }
    
    # 目标数据
    targets = {
        'class': torch.randint(0, 10, (num_samples,)),
        'bbox': torch.randn(num_samples, 4),
        'mask': torch.randint(0, 2, (num_samples, 64, 64))
    }
    
    # 创建数据集和数据加载器
    dataset = MultiTaskDataset(
        inputs=inputs,
        targets=targets
    )
    
    loader = MultiTaskDataLoader(
        dataset=dataset,
        batch_size=1024,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )
    
    start_time = time.time()
    
    # 遍历整个数据集
    total_samples = 0
    for batch in loader:
        total_samples += batch['input_image'].shape[0]
    
    end_time = time.time()
    return end_time - start_time, total_samples

# 测试不同的worker数量
print("\n测试多任务数据集在不同worker数量下的加载时间:")
print("-" * 50)
print(f"{'Workers':>8} | {'Time (s)':>10} | {'Samples/s':>10}")
print("-" * 50)

results_multi = []
for num_workers in worker_nums:
    try:
        time_taken, total_samples = test_multi_task_loader_speed(num_workers)
        samples_per_second = total_samples / time_taken
        results_multi.append((num_workers, time_taken, samples_per_second))
        print(f"{num_workers:8d} | {time_taken:10.2f} | {samples_per_second:10.0f}")
    except Exception as e:
        print(f"{num_workers:8d} | Failed: {str(e)}")

# 绘制结果
plt.figure(figsize=(12, 4))

# 时间对比
plt.subplot(1, 2, 1)
workers = [r[0] for r in results_multi]
times = [r[1] for r in results_multi]
plt.plot(workers, times, 'o-')
plt.xlabel('Number of Workers')
plt.ylabel('Time (seconds)')
plt.title('Multi-task Loading Time vs Number of Workers')

# 吞吐量对比
plt.subplot(1, 2, 2)
speeds = [r[2] for r in results_multi]
plt.plot(workers, speeds, 'o-')
plt.xlabel('Number of Workers')
plt.ylabel('Samples per Second')
plt.title('Multi-task Throughput vs Number of Workers')

plt.tight_layout()
plt.show()


测试多任务数据集在不同worker数量下的加载时间:
--------------------------------------------------
 Workers |   Time (s) |  Samples/s
--------------------------------------------------


NameError: name 'worker_nums' is not defined