In [3]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Run on the first CUDA GPU when one is visible to PyTorch, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("Device: {}".format(device))


Device: cuda:0


In [7]:
from torchvision import transforms
import random
import numpy as np
from PIL import Image, ImageOps

class RandomColorize(object):
    """Transform that, with probability ``p``, repaints an image in two random colors.

    When triggered, the image is converted to grayscale and thresholded at 128:
    pixels below the threshold are painted with a random "font" color and all
    remaining pixels with a different random "background" color. Otherwise the
    input is returned unchanged.

    Args:
        p: Probability of applying the recoloring.
    """
    def __init__(self, p=0.5):
        self.p = p

    @staticmethod
    def _random_rgb():
        """Draw one RGB triple with every channel picked from 0..255 inclusive."""
        return tuple(random.randint(0, 255) for _ in range(3))

    def __call__(self, img):
        """Return ``img`` recolored (as an RGB PIL image) or untouched.

        ``img`` is handed to ``ImageOps.grayscale``, so it is expected to be a
        PIL image whenever the recoloring branch is taken.
        """
        # Pass the input straight through when the coin flip fails.
        if not (random.random() < self.p):
            return img

        # Collapse to a single luminance channel before thresholding.
        gray = ImageOps.grayscale(img)

        # Background first, then font color; redraw the font color until the
        # two triples are not exactly equal.
        background = self._random_rgb()
        foreground = self._random_rgb()
        while foreground == background:
            foreground = self._random_rgb()

        # Work on an (H, W, 3) array copy of the grayscale image.
        pixels = np.array(gray.convert("RGB"))

        # A pixel counts as "font" when any channel is below the threshold.
        dark = (pixels[:, :, 0] < 128) | (pixels[:, :, 1] < 128) | (pixels[:, :, 2] < 128)

        pixels[dark] = foreground
        pixels[~dark] = background

        return Image.fromarray(pixels)



In [15]:
import PIL
from torch.utils.data import Dataset


class AdvancedAugmentation:
    """Holder for the training and validation preprocessing pipelines.

    Attributes:
        train_transform: Resize and center-crop to 28 px, then random
            colorization, random grayscale, random affine, Gaussian blur,
            tensor conversion, normalization and random erasing.
        val_transform: Resize and center-crop to 28 px, tensor conversion and
            normalization only.

    NOTE(review): ``image_size`` is accepted but never read; both pipelines
    are fixed at 28 px regardless of the value passed in.
    """
    def __init__(self, image_size=28):
        def _normalize():
            # Per-channel mean 0.5 / std 0.5 for three channels.
            return transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

        train_steps = [
            transforms.Resize(28),
            transforms.CenterCrop(28),
            RandomColorize(p=0.5),  # randomize font/background colors
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomAffine(degrees=20, shear=10, scale=(0.9, 1.1)),
            transforms.GaussianBlur(kernel_size=3),
            transforms.ToTensor(),
            _normalize(),
            # Random erasing is placed after ToTensor/Normalize, i.e. on the tensor.
            transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3)),
        ]
        self.train_transform = transforms.Compose(train_steps)

        val_steps = [
            transforms.Resize(28),
            transforms.CenterCrop(28),
            transforms.ToTensor(),
            _normalize(),
        ]
        self.val_transform = transforms.Compose(val_steps)
        
class CustomDataset(Dataset):
    """Dataset over a ``.pt`` file of ``(image, label)`` pairs with optional transform.

    The transform is applied on every access, so random augmentations yield a
    fresh result each time a sample is fetched. When ``cache`` is enabled only
    the *untransformed* ``(image, label)`` pair is memoized; caching the
    transformed output would freeze the first random augmentation drawn for
    each index and replay it on every later access.

    Args:
        pt_file: Path passed to ``torch.load``; the loaded object must support
            ``len()`` and integer indexing that yields ``(image, label)``.
        transform: Optional callable applied to the image on every access.
        cache: Whether to memoize raw samples in ``cached_data``.
    """
    def __init__(self, pt_file, transform=None, cache=True):
        # NOTE(review): torch.load unpickles the file's contents; only point
        # this at files from a trusted source.
        self.data_label_tuples = torch.load(pt_file)
        self.transform = transform
        self.cache = cache
        self.cached_data = {}

    def __len__(self):
        """Return the number of ``(image, label)`` pairs."""
        return len(self.data_label_tuples)

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for ``idx``."""
        if self.cache and idx in self.cached_data:
            img, label = self.cached_data[idx]
        else:
            img, label = self.data_label_tuples[idx]
            if self.cache:
                # Store the raw pair so the transform below stays stochastic.
                self.cached_data[idx] = (img, label)

        if self.transform:
            img = self.transform(img)

        return img, label

def get_dataloaders(data_dir, train_env='train1', val_env='test1',
                    batch_size=32, image_size=224, num_workers=4, pin_memory=True):
    """Load two ColoredMNIST ``.pt`` environments and wrap them in DataLoaders.

    Args:
        data_dir: Directory containing ``<env>.pt`` files.
        train_env: File stem of the training environment.
        val_env: File stem of the validation/test environment.
        batch_size: Batch size used by both loaders.
        image_size: Forwarded to ``AdvancedAugmentation``.
        num_workers: Worker process count for both loaders; ``0`` loads in the
            main process.
        pin_memory: Forwarded to both loaders.

    Returns:
        ``(train_loader, val_loader)``. The training loader shuffles, drops the
        last incomplete batch and uses the training transform; the validation
        loader keeps order and uses the validation transform.
    """
    aug = AdvancedAugmentation(image_size)

    train_dataset = CustomDataset(
        pt_file=os.path.join(data_dir, f"{train_env}.pt"),
        transform=aug.train_transform,
        cache=True
    )

    val_dataset = CustomDataset(
        pt_file=os.path.join(data_dir, f"{val_env}.pt"),
        transform=aug.val_transform,
        cache=False
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=True,
        # DataLoader raises ValueError for persistent_workers=True when there
        # are no worker processes, so only enable it when workers exist.
        persistent_workers=num_workers > 0
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    return train_loader, val_loader


In [17]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.utils import make_grid
from PIL import Image, ImageOps

class RandomColorize(object):
    """Standalone color-randomization transform.

    With probability ``p`` the input is reduced to grayscale, thresholded at
    128 and repainted: pixels below the threshold get one random RGB color and
    the remaining pixels get a different random RGB color. Otherwise the input
    is returned as is.
    """
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, img):
        """Return the recolored RGB PIL image, or ``img`` itself when skipped."""
        if not (random.random() < self.p):
            return img

        # Single-channel luminance; converting it to RGB would only replicate
        # this channel, so the threshold can be taken on it directly.
        luminance = np.array(ImageOps.grayscale(img))

        def draw_color():
            return (random.randint(0, 255), random.randint(0, 255), random.randint(0, 255))

        # Background is drawn first, then the font color, which is redrawn
        # until the two triples differ.
        bg_rgb = draw_color()
        font_rgb = draw_color()
        while font_rgb == bg_rgb:
            font_rgb = draw_color()

        is_font = luminance < 128
        # Broadcast the (H, W) mask against the two RGB triples -> (H, W, 3) uint8.
        recolored = np.where(
            is_font[:, :, None],
            np.array(font_rgb, dtype=np.uint8),
            np.array(bg_rgb, dtype=np.uint8),
        )
        return Image.fromarray(recolored)

def visualize_augmentations(tensor_data, transform, num_samples=10):
    """Plot the first samples next to their transformed versions.

    Args:
        tensor_data: Sliceable collection of images; the leading slice is
            passed both to ``transform`` (item by item) and to ``make_grid``.
        transform: Callable applied to each selected sample. Its outputs are
            de-normalized assuming mean 0.5 / std 0.5 on three channels.
        num_samples: Maximum number of samples to show (capped at
            ``len(tensor_data)``).
    """
    count = min(num_samples, len(tensor_data))
    originals = tensor_data[:count]

    # Run the transform on every selected sample.
    transformed = [transform(sample) for sample in originals]

    grid_original = make_grid(originals, nrow=count)

    # Undo the (x - 0.5) / 0.5 normalization so values are displayable.
    mean = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5]).view(3, 1, 1)
    restored = [sample * std + mean for sample in transformed]
    grid_augmented = make_grid(restored, nrow=count)

    # Two side-by-side panels: untouched samples on the left, augmented on the right.
    plt.figure(figsize=(15, 5))
    panels = (("Original", grid_original), ("Augmented", grid_augmented))
    for position, (title, grid) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        # make_grid returns channels-first; imshow wants channels-last in [0, 1].
        plt.imshow(grid.permute(1, 2, 0).clamp(0, 1))
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

In [10]:
class LeNet(nn.Module):
    """LeNet-style CNN for 3-channel images producing two output logits.

    The first fully connected layer expects ``16 * 5 * 5`` features, which is
    what the two conv/pool stages produce for a 28x28 input.
    """
    def __init__(self):
        super().__init__()
        layers = [
            # 3 -> 6 channels; padding 2 keeps the spatial size with a 5x5 kernel.
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5, stride=1, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            # 6 -> 16 channels; no padding, so each spatial side shrinks by 4.
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Flatten(),
            nn.Linear(in_features=16 * 5 * 5, out_features=120),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=120, out_features=84),
            nn.ReLU(inplace=True),
            # Two output logits (binary classification).
            nn.Linear(in_features=84, out_features=2),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class logits for a batch ``x``."""
        logits = self.model(x)
        return logits


In [11]:
def train_one_model(data_dir, train_env, epochs=10, batch_size=32, learning_rate=0.01, image_size=224):
    """Train a fresh ``LeNet`` on one training environment and return it.

    Args:
        data_dir: Directory holding the ``.pt`` environment files.
        train_env: File stem of the training environment.
        epochs: Number of passes over the training loader.
        batch_size: Training batch size.
        learning_rate: Learning rate given to Adam.
        image_size: Forwarded to ``get_dataloaders``.

    Returns:
        The trained model, placed on the module-level ``device``.

    Note:
        Per-epoch loss/accuracy are averaged over the samples seen; if the
        loader yields no batches the division raises ``ZeroDivisionError``.
    """
    print(f"Loading training data from: {train_env}")

    # Only the training loader is needed; val_env is a placeholder because
    # get_dataloaders always builds both loaders.
    loaders = get_dataloaders(
        data_dir=data_dir,
        train_env=train_env,
        val_env='test1',
        batch_size=batch_size,
        image_size=image_size,
        num_workers=4
    )
    train_loader = loaders[0]

    model = LeNet().to(device)
    loss_function = nn.CrossEntropyLoss().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    for epoch in range(epochs):
        model.train()
        seen = 0
        hits = 0
        loss_sum = 0.0

        for batch, labels in train_loader:
            batch = batch.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(batch)
            loss = loss_function(logits, labels)
            loss.backward()
            optimizer.step()

            # Weight the mean batch loss by batch size to get a per-sample average.
            batch_len = batch.size(0)
            loss_sum += loss.item() * batch_len
            hits += (logits.argmax(1) == labels).sum().item()
            seen += batch_len

        epoch_loss = loss_sum / seen
        epoch_acc = hits / seen
        print(f"Epoch {epoch+1}/{epochs} - Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")

    return model


In [12]:
def test_model(model, data_dir, test_env, batch_size=1000, image_size=224):
    """Evaluate ``model`` on one test environment and return its accuracy.

    Args:
        model: Module to evaluate; it is switched to eval mode and left there.
        data_dir: Directory holding the ``.pt`` environment files.
        test_env: File stem of the environment to evaluate on.
        batch_size: Evaluation batch size.
        image_size: Forwarded to ``get_dataloaders``.

    Returns:
        Fraction of correctly classified samples (correct / total).
    """
    print(f"Testing on dataset: {test_env}")

    # Only the validation loader is needed; train_env is a placeholder because
    # get_dataloaders always builds both loaders.
    loaders = get_dataloaders(
        data_dir=data_dir,
        train_env='train1',
        val_env=test_env,
        batch_size=batch_size,
        image_size=image_size,
        num_workers=4
    )
    val_loader = loaders[1]

    model.eval()
    hits = 0
    seen = 0
    # No gradients are needed for evaluation.
    with torch.no_grad():
        for batch, labels in val_loader:
            batch = batch.to(device)
            labels = labels.to(device)
            predictions = model(batch).argmax(1)
            hits += (predictions == labels).sum().item()
            seen += batch.size(0)

    accuracy = hits / seen
    print(f"Test accuracy: {accuracy:.4f}")
    return accuracy


In [16]:
train_envs = ['train1', 'train2', 'train3']
test_envs = ['test1', 'test2']

# Maps each training environment to {test environment: accuracy}.
results = {}

for train_env in train_envs:
    print("=" * 50)
    print(f"Training model on {train_env}")
    model = train_one_model('./data/ColoredMNIST', train_env, epochs=10)

    # Evaluate the freshly trained model on every test environment.
    accuracies = {}
    for test_env in test_envs:
        acc = test_model(model, './data/ColoredMNIST', test_env)
        accuracies[test_env] = acc
    results[train_env] = accuracies

# Print the full results table.
print("\nAll results:")
for train_set in results:
    test_accs = results[train_set]
    print(f"Model trained on {train_set}:")
    for test_set in test_accs:
        acc = test_accs[test_set]
        print(f"  Test on {test_set}: Accuracy = {acc:.4f}")


Training model on train1
Loading training data from: train1


  self.data_label_tuples = torch.load(pt_file)


Epoch 1/10 - Loss: 0.6938, Accuracy: 0.5077
Epoch 2/10 - Loss: 0.6932, Accuracy: 0.5059
Epoch 3/10 - Loss: 0.6929, Accuracy: 0.5144
Epoch 4/10 - Loss: 0.6935, Accuracy: 0.4997
Epoch 5/10 - Loss: 0.6932, Accuracy: 0.5130
Epoch 6/10 - Loss: 0.6935, Accuracy: 0.5053
Epoch 7/10 - Loss: 0.6934, Accuracy: 0.5032
Epoch 8/10 - Loss: 0.6932, Accuracy: 0.5147
Epoch 9/10 - Loss: 0.6933, Accuracy: 0.5063
Epoch 10/10 - Loss: 0.6933, Accuracy: 0.5069
Testing on dataset: test1
Test accuracy: 0.5074
Testing on dataset: test2
Test accuracy: 0.5103
Training model on train2
Loading training data from: train2
Epoch 1/10 - Loss: 0.6942, Accuracy: 0.5027
Epoch 2/10 - Loss: 0.6933, Accuracy: 0.5082
Epoch 3/10 - Loss: 0.6937, Accuracy: 0.5046
Epoch 4/10 - Loss: 0.6934, Accuracy: 0.4994
Epoch 5/10 - Loss: 0.6935, Accuracy: 0.5010
Epoch 6/10 - Loss: 0.6934, Accuracy: 0.5000
Epoch 7/10 - Loss: 0.6933, Accuracy: 0.5025
Epoch 8/10 - Loss: 0.6939, Accuracy: 0.5018
Epoch 9/10 - Loss: 0.6933, Accuracy: 0.5068
Epoch 1