In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image  # 使用 Pillow 的 Image 模块
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# To run on a TPU in Google Colab, uncomment the lines below instead:
# import torch_xla
# import torch_xla.core.xla_model as xm
# DEVICE = xm.xla_device()

# Prefer the first CUDA GPU when one is present; otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"
print(DEVICE)

In [None]:
# Custom dataset pairing image files with per-sample parameter CSV files.
class ImageParamsDataset(Dataset):
    """Dataset of ``(image, params)`` pairs read from two directories.

    The i-th image (0-based, in natural filename order, e.g. ``img_2`` before
    ``img_10``) is paired with ``params_dir/params_{i + 1}.csv``.

    Args:
        image_dir: Directory containing the image files.
        params_dir: Directory containing ``params_<n>.csv`` files (1-based ``n``).
        transform: Optional callable applied to the RGB PIL image.

    Returns (per item):
        ``(image, params)`` where ``image`` is the (optionally transformed) image
        and ``params`` is a flat float32 tensor of every value in the CSV.
    """

    def __init__(self, image_dir, params_dir, transform=None):
        self.image_dir = image_dir
        self.params_dir = params_dir
        self.transform = transform
        # os.listdir() order is arbitrary and filesystem-dependent, which made the
        # idx -> params_{idx + 1}.csv pairing non-reproducible. Sort naturally.
        # NOTE(review): presumably image names embed the same 1-based index as the
        # params files. TODO confirm against the real dataset layout.
        # Hidden entries and sub-directories (.ipynb_checkpoints, .DS_Store, ...)
        # are skipped so they are never opened as images.
        self.image_files = sorted(
            (
                name
                for name in os.listdir(image_dir)
                if not name.startswith('.')
                and os.path.isfile(os.path.join(image_dir, name))
            ),
            key=self._natural_key,
        )

    @staticmethod
    def _natural_key(name):
        """Sort key that orders embedded integers numerically ('img_2' < 'img_10')."""
        import re

        # re.split with a capturing group alternates text / digit-run, so odd
        # positions are always digit runs. The raw name breaks ties such as
        # 'img_01' vs 'img_1' deterministically.
        parts = re.split(r'(\d+)', name)
        return [int(p) if i % 2 else p for i, p in enumerate(parts)], name

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        # Load the image. The context manager closes the file handle promptly;
        # convert() returns an independent in-memory copy.
        image_path = os.path.join(self.image_dir, self.image_files[idx])
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        # Load the targets from the index-aligned, 1-based CSV file.
        params_path = os.path.join(self.params_dir, f'params_{idx + 1}.csv')
        params = pd.read_csv(params_path, header=None).values.flatten().astype('float32')

        # Image preprocessing
        if self.transform:
            image = self.transform(image)

        return image, torch.tensor(params)

# Image preprocessing: resize to 224x224, convert to a float tensor scaled to
# [0, 1], then standardize each channel (the values match the widely used
# ImageNet channel mean/std).
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

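# Create the file-backed dataset and data loader (disabled: the synthetic dataset in the next cell is used instead)
# dataset = ImageParamsDataset(image_dir='dataset/images', params_dir='dataset/params', transform=transform)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True)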

In [None]:
class RandomImageParamsDataset(Dataset):
    """Synthetic dataset of Gaussian-noise images labelled with a channel statistic.

    Each channel is filled with normal noise whose mean is drawn from
    U(0.3, 0.7) and whose std is drawn from U(0.001, 0.1). The image is then
    clipped to [0, 1]. The label is the population standard deviation (ddof=0)
    of channel 0 after clipping, returned as a 1-element float32 tensor.

    Samples are regenerated on every access from NumPy's global RNG, so the
    same index yields a different sample each time.

    Args:
        num_images: Number of samples reported by ``len()``.
        image_size: ``(height, width)`` of each image.
        num_channels: Number of channels per image.
        device: Device the returned tensors are moved to. ``None`` (default)
            means the notebook-level ``DEVICE``. Pass ``"cpu"`` to keep samples
            on the host, e.g. when using ``DataLoader(num_workers > 0)``.
    """

    def __init__(self, num_images, image_size=(224, 224), num_channels=3, device=None):
        self.num_images = num_images
        self.image_size = image_size
        self.num_channels = num_channels
        self.device = device

    def __len__(self):
        return self.num_images

    def __getitem__(self, idx):
        # One independent noise plane per channel, with per-channel loc/scale.
        image = np.zeros((self.num_channels, *self.image_size), dtype=np.float32)
        for c in range(self.num_channels):
            loc = np.random.uniform(0.3, 0.7)
            scale = np.random.uniform(0.001, 0.1)
            image[c] = np.random.normal(loc=loc, scale=scale, size=self.image_size).astype(np.float32)
        image = np.clip(image, 0, 1)  # Ensure pixel values are within [0, 1]

        # Regression target: std of channel 0 only, shape (1,).
        channel0_std = float(image[0].std())

        image_tensor = torch.tensor(image, dtype=torch.float32)
        labels = torch.tensor([channel0_std], dtype=torch.float32)

        device = DEVICE if self.device is None else self.device
        return image_tensor.to(device), labels.to(device)

# Example usage: 1024 synthetic training samples.
dataset = RandomImageParamsDataset(num_images=1024)

# Batch size 32 was chosen for an RTX 3060 Laptop GPU with 6 GB VRAM.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Sanity-check one batch.
images, labels = next(iter(dataloader))
print(images.shape)  # expected: [batch_size, 3, 224, 224]
print(labels.shape)  # expected: [batch_size, 1] (std of channel 0 per image)
assert images.shape[1:] == (3, 224, 224)
assert labels.shape == (images.shape[0], 1)

In [None]:
import torch.nn as nn
import torch.nn.functional as F

class ImageProcessingModel(nn.Module):
    """Small CNN regressor: three conv/ReLU/max-pool stages followed by a 3-layer MLP.

    Each 3x3 conv (padding 1) keeps the spatial size and each 2x2 max-pool
    halves it. An ``input_size x input_size`` image therefore leaves the feature
    extractor as ``128 x (input_size // 8) x (input_size // 8)``.

    Args:
        in_channels: Channels of the input image (default 3).
        input_size: Height/width of the square input the model is built for.
            The default 224 gives the original ``128 * 28 * 28`` flattened features.
        num_outputs: Width of the regression output (default 1).

    Forward:
        Input ``(batch, in_channels, input_size, input_size)`` ->
        output ``(batch, num_outputs)``.
    """

    def __init__(self, in_channels=3, input_size=224, num_outputs=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        side = input_size // 8  # spatial size after three 2x2 poolings
        self.fc1 = nn.Linear(128 * side * side, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_outputs)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample, keeping the batch dimension. An input of the wrong
        # spatial size then fails loudly in fc1 instead of being silently
        # re-split into a different batch size.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

# Instantiate the network, then move its parameters to DEVICE.
model = ImageProcessingModel()
model = model.to(DEVICE)

In [None]:
# Function to save a training checkpoint
def save_checkpoint(epoch, model, optimizer, losses, filename='checkpoint.pth.tar'):
    """Write the current training state to ``filename`` via ``torch.save``.

    Args:
        epoch: Index of the epoch that has just finished.
        model: Module whose ``state_dict()`` is stored under ``'state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer'``.
        losses: Per-epoch loss history, stored under ``'losses'``.
        filename: Destination path.
    """
    torch.save(
        dict(
            epoch=epoch,
            state_dict=model.state_dict(),
            optimizer=optimizer.state_dict(),
            losses=losses,
        ),
        filename,
    )
    print('Checkpoint saved')

# Function to load a training checkpoint
def load_checkpoint(filename='checkpoint.pth.tar', model=None, optimizer=None):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Args:
        filename: Path of the checkpoint file.
        model: Module to load weights into. Defaults to the notebook-level ``model``.
        optimizer: Optimizer to restore. Defaults to the notebook-level ``optimizer``.

    Returns:
        ``(start_epoch, losses)``. ``start_epoch`` is the index of the next epoch
        to train (saved epoch + 1). ``losses`` is the saved per-epoch loss
        history. Returns ``(0, [])`` when ``filename`` does not exist.
    """
    if not os.path.isfile(filename):
        print("No checkpoint found.")
        return 0, []

    # Fall back to the notebook globals so existing zero-argument calls keep working.
    if model is None:
        model = globals()['model']
    if optimizer is None:
        optimizer = globals()['optimizer']

    # Load tensors onto the CPU so a GPU-saved checkpoint also opens on a CPU-only
    # machine. load_state_dict copies them onto the model's/optimizer's own device.
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    epoch = checkpoint['epoch']
    losses = checkpoint['losses']

    # Guard against an empty history, which would make losses[-1] raise IndexError.
    last_loss = f"{losses[-1]:.8e}" if losses else "n/a"
    print(f"Checkpoint loaded: epoch {epoch}, loss {last_loss}")

    # save_checkpoint stores the epoch that has just *finished*. Resume from the
    # next one so that epoch is not retrained and its loss not appended twice.
    return epoch + 1, losses

# Plot the training-loss curve
def plot_loss_curve(losses):
    """Draw the per-epoch training loss history in a new figure and show it.

    Args:
        losses: Sequence of average training losses, one per epoch
            (index 0 corresponds to epoch 0).
    """
    # A fresh figure per call, so repeated calls never draw onto a stale one.
    fig, ax = plt.subplots()
    ax.plot(range(len(losses)), losses, label='Training Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss over Epochs')
    ax.legend()
    plt.show()

# Objective: mean squared error between the predicted and target values.
criterion = nn.MSELoss()

# Adam over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [None]:
import time

# Training loop
num_epochs = 64
checkpoint_every = 8  # save a checkpoint and redraw the loss curve every N epochs

# Resume from the checkpoint file if one exists; otherwise start from scratch.
start_epoch, losses = load_checkpoint()

# losses[k] accumulates the mean training loss of epoch k.
for epoch in range(start_epoch, num_epochs):
    model.train()
    start_time = time.perf_counter()
    epoch_loss = 0.0
    for i, (images, labels) in enumerate(dataloader):
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss

        if i % 8 == 7:  # Print every 8 batches
            print(f'Epoch [{epoch}/{num_epochs}], Batch [{i + 1}/{len(dataloader)}], Loss: {batch_loss:.8e}')

    elapsed_time = time.perf_counter() - start_time

    # Average over batches to get this epoch's mean loss.
    epoch_loss /= len(dataloader)
    losses.append(epoch_loss)

    print(f'Epoch [{epoch}/{num_epochs}] ends. Loss: {epoch_loss:.8e}; Time: {elapsed_time:.2f} second / epoch')

    # Checkpoint and plot periodically, and always after the final epoch.
    # Otherwise the trailing epochs are lost when num_epochs is not a multiple
    # of checkpoint_every.
    if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == num_epochs:
        save_checkpoint(epoch, model, optimizer, losses)
        plot_loss_curve(losses)

print('Training complete.')

In [None]:
# Evaluate on a freshly generated held-out set (not the training `dataset`).
test_dataset = RandomImageParamsDataset(num_images=100)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

model.eval()
with torch.no_grad():  # inference only: no autograd graph is built
    for images, labels in test_dataloader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        # Predictions and targets side by side, one row per sample.
        print("Outputs | Labels:", torch.hstack((outputs, labels)))
        print(f"Loss: {loss.item():.8e}")
        break  # Remove this line if you want to iterate through the entire dataset