In [1]:
import os

import numpy as np
from PIL import Image

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import grad
from torchvision import transforms
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils

In [57]:
def color_grayscale_arr(arr, forground_color, background_color):
    """Converts grayscale image"""
    assert arr.ndim == 2
    dtype = arr.dtype
    h, w = arr.shape
    arr = np.reshape(arr, [h, w, 1])#增加一个“通道”维度
    if background_color == "black":
        if forground_color == "red":
            arr = np.concatenate([arr,
                              np.zeros((h, w, 2), dtype=dtype)], axis=2)#创建全零数组作为绿色和蓝色通道，表示全红色
        elif forground_color == "green":
            arr = np.concatenate([np.zeros((h, w, 1), dtype=dtype),
                              arr,
                              np.zeros((h, w, 1), dtype=dtype)], axis=2)
        elif forground_color == "white":
            arr = np.concatenate([arr, arr, arr], axis=2)
    else:
        if forground_color == "yellow":
            arr = np.concatenate([arr, arr, np.zeros((h, w, 1), dtype=dtype)], axis=2)
        else:
            arr = np.concatenate([np.zeros((h, w, 2), dtype=dtype), arr], axis=2)

        c = [255, 255, 255]
        arr[:, :, 0] = (255 - arr[:, :, 0]) / 255 * c[0]
        arr[:, :, 1] = (255 - arr[:, :, 1]) / 255 * c[1]
        arr[:, :, 2] = (255 - arr[:, :, 2]) / 255 * c[2]

    return arr


class ColoredMNIST(datasets.VisionDataset):

    def __init__(self, root='./data', env='train1', transform=None, target_transform=None):
        super(ColoredMNIST, self).__init__(root, transform=transform,
                                           target_transform=target_transform)

        self.prepare_colored_mnist()
        if env in ['train1', 'train2', 'train3', 'test1', 'test2']:
            self.data_label_tuples = torch.load(os.path.join(self.root, 'ColoredMNIST', env) + '.pt',
                                               weights_only=False)
        elif env == 'all_train':
            train1_data = torch.load(os.path.join(self.root, 'ColoredMNIST', 'train1.pt'),
                                                weights_only=False ) 
            train2_data=torch.load(os.path.join(self.root, 'ColoredMNIST', 'train2.pt'),
                                                weights_only=False)
            train3_data=torch.load(os.path.join(self.root, 'ColoredMNIST', 'train3.pt'),
                                                weights_only=False)
            self.data_label_tuples = train1_data + train2_data + train3_data
        else:
            raise RuntimeError(f'{env} env unknown. Valid envs are train1, train2, train3, test1, test2, and all_train')

    def __getitem__(self, index):
        """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) where target is index of the target class.
    """
        img, target = self.data_label_tuples[index]

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        return len(self.data_label_tuples)

    def prepare_colored_mnist(self):
        colored_mnist_dir = os.path.join(self.root, 'ColoredMNIST')
        if os.path.exists(os.path.join(colored_mnist_dir, 'train1.pt')) \
                and os.path.exists(os.path.join(colored_mnist_dir, 'train2.pt')) \
                and os.path.exists(os.path.join(colored_mnist_dir, 'train3.pt')) \
                and os.path.exists(os.path.join(colored_mnist_dir, 'test1.pt')) \
                and os.path.exists(os.path.join(colored_mnist_dir, 'test2.pt')):
            print('Colored MNIST dataset already exists')
            return

        print('Preparing Colored MNIST')
        train_mnist = datasets.mnist.MNIST(self.root, train=True, download=True)

        train1_set = []
        train2_set = []
        train3_set = []
        test1_set, test2_set = [], []
        for idx, (im, label) in enumerate(train_mnist):
            if idx % 10000 == 0:
                print(f'Converting image {idx}/{len(train_mnist)}')
            im_array = np.array(im)
            
            # Assign a binary label y to the image based on the digit
            binary_label = 0 if label < 5 else 1

            # Color the image according to its environment label

            if idx < 10000:
                colored_arr = color_grayscale_arr(im_array, forground_color = "red", background_color = "black")
                train1_set.append((Image.fromarray(colored_arr), binary_label))
            elif idx < 20000:
                colored_arr = color_grayscale_arr(im_array, forground_color = "green", background_color = "black")
                train2_set.append((Image.fromarray(colored_arr), binary_label))
            elif idx < 30000:
                colored_arr = color_grayscale_arr(im_array, forground_color = "white", background_color = "black")
                train3_set.append((Image.fromarray(colored_arr), binary_label))
            elif idx < 45000:
                colored_arr = color_grayscale_arr(im_array, forground_color = "yellow", background_color = "white")
                test1_set.append((Image.fromarray(colored_arr), binary_label))
            else:
                colored_arr = color_grayscale_arr(im_array, forground_color = "blue", background_color = "white")
                test2_set.append((Image.fromarray(colored_arr), binary_label))
                
            # Image.fromarray(colored_arr).save('./data/sample/{}.png'.format(idx))

        if not os.path.exists(colored_mnist_dir):
            os.makedirs(colored_mnist_dir)
        torch.save(train1_set, os.path.join(colored_mnist_dir, 'train1.pt'))
        torch.save(train2_set, os.path.join(colored_mnist_dir, 'train2.pt'))
        torch.save(train3_set, os.path.join(colored_mnist_dir, 'train3.pt'))
        torch.save(test1_set, os.path.join(colored_mnist_dir, 'test1.pt'))
        torch.save(test2_set, os.path.join(colored_mnist_dir, 'test2.pt'))

In [59]:
import torch
from torchvision import transforms

#图像预处理流程
#转换为 Tensor，以及归一化
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

print("Loading")
train_dataset1=ColoredMNIST(root='./data', env='train1', transform=transform)
train_dataset2= ColoredMNIST(root='./data', env='train2', transform=transform)
train_dataset3= ColoredMNIST(root='./data', env='train3', transform=transform)
test_dataset1= ColoredMNIST(root='./data', env='test1', transform=transform)
test_dataset2= ColoredMNIST(root='./data', env='test2', transform=transform)

# 打印数据集大小，确认加载成功
print(f"Size of train_dataset_1: {len(train_dataset1)}")
print(f"Size of test_dataset_2: {len(test_dataset2)}")

Loading
Preparing Colored MNIST


100%|█████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [05:25<00:00, 30.4kB/s]
100%|█████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 91.1kB/s]
100%|█████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:27<00:00, 61.0kB/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 940kB/s]


Converting image 0/60000
Converting image 10000/60000
Converting image 20000/60000
Converting image 30000/60000
Converting image 40000/60000
Converting image 50000/60000
Colored MNIST dataset already exists
Colored MNIST dataset already exists
Colored MNIST dataset already exists
Colored MNIST dataset already exists
Size of train_dataset_1: 10000
Size of test_dataset_2: 15000


In [60]:
print(train_dataset1)

Dataset ColoredMNIST
    Number of datapoints: 10000
    Root location: ./data
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
           )


In [63]:
from torch.utils.data import DataLoader,ConcatDataset
#把所有训练集搞在一起
all_train_dataset = ConcatDataset([train_dataset1, train_dataset2, train_dataset3])

BATCH_SIZE = 128
#dataloader
my_train_loader1 = DataLoader(dataset=train_dataset1, batch_size=BATCH_SIZE, shuffle=True)
my_train_loader2 = DataLoader(dataset=train_dataset2, batch_size=BATCH_SIZE, shuffle=True)
my_train_loader3 = DataLoader(dataset=train_dataset3, batch_size=BATCH_SIZE, shuffle=True)
my_test_loader1 = DataLoader(dataset=train_dataset1, batch_size=BATCH_SIZE, shuffle=False)
my_test_loader2 = DataLoader(dataset=test_dataset2, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total training{len(all_train_dataset)}")
print(f"Test1 {len(test_dataset1)}")
print(f"Test2 {len(test_dataset2)}")

Total training30000
Test1 15000
Test2 15000


In [65]:
see=next(iter(my_train_loader))[0][0]
print(see.shape)

torch.Size([3, 28, 28])


In [67]:
#模型结构和前向传播
class LeNet5(nn.Module):
    def __init__(self, num_classes=2):
        super(LeNet5, self).__init__()
        #输入是 3x28x28，输出是类别数
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=5, stride=1, padding=2),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(6, 16, kernel_size=5, stride=1),
            nn.Sigmoid(),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.Sigmoid(),
            nn.Linear(120, 84),
            nn.Sigmoid(),
            nn.Linear(84, num_classes)
        )
    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, 16 * 5 * 5) #展平操作
        x = self.classifier(x)
        return x

In [69]:
my_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {my_device}")

Using device: cuda


In [71]:
my_model=LeNet5(num_classes=2).to(my_device)

In [73]:
#损失函数和优化器
my_loss=nn.CrossEntropyLoss()
my_optimizer=optim.Adam(my_model.parameters(), lr=0.001)


In [75]:
NUM_EPOCHS = 15 # 训练10epoch
def irm_train(model, train_loaders, loss_fn, optimizer, device, num_epochs=NUM_EPOCHS, penalty_weight=1.0):
    print("\nStarting IRM training")
    history = {
        'train_loss': [],
        'train_acc': [],
        'penalty': []
    }
    # 创建用于IRM惩罚的可训练标量参数
    dummy_w = torch.nn.Parameter(torch.tensor([1.0], device=device))
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        total_train = correct_train = 0
        total_penalty = 0.0
        # 动态调整惩罚权重(可选)
        current_penalty_weight = penalty_weight * (epoch ** 1.6 / num_epochs ** 1.6)
        # 获取所有环境的迭代器
        iterators = [iter(loader) for loader in train_loaders]
        while True:
            try:
                optimizer.zero_grad()
                batch_erm_loss = 0.0
                batch_penalty = 0.0
                # 处理每个环境的数据
                try:
                    for env_iter in iterators:
                        images, labels = next(env_iter)
                    
                        images = images.to(device)
                        labels = labels.to(device)
                    # 前向传播
                        outputs = model(images)
                    # 计算ERM损失
                        logits = outputs * dummy_w
                        env_loss = loss_fn(logits, labels)
                        batch_erm_loss += env_loss.mean()
                    # 计算IRM惩罚项
                        batch_penalty += compute_irm_penalty(env_loss, dummy_w)
                except StopIteration:
                    # 某个环境数据已遍历完
                    break
                # 计算总损失
                total_loss = batch_erm_loss + current_penalty_weight * batch_penalty
                # 反向传播和优化
                total_loss.backward()
                optimizer.step()
                # 统计信息
                running_loss += batch_erm_loss.item()
                total_penalty += batch_penalty.item()
                # 计算准确率(使用最后一个环境的输出)
                _, predicted = torch.max(logits.data, 1)
                total_train += labels.size(0)
                correct_train += (predicted == labels).sum().item()
            except StopIteration:
                break
        
        # 计算epoch统计量
        avg_train_loss = running_loss / len(train_loaders[0])
        avg_penalty = total_penalty / len(train_loaders[0])
        train_acc = 100 * correct_train / total_train
        
        # 记录历史
        history['train_loss'].append(avg_train_loss)
        history['train_acc'].append(train_acc)
        history['penalty'].append(avg_penalty)

        print(f"Epoch [{epoch+1}/{num_epochs}], "
              f"Train Loss: {avg_train_loss:.4f}, "
              f"Train Acc: {train_acc:.2f}%, "
              f"Penalty: {avg_penalty:.4f}, "
              f"Penalty Weight: {current_penalty_weight:.2f}")
    print("Finished IRM training")
    return history
def compute_irm_penalty(loss, dummy_w):
    """
    计算IRM惩罚项
    
    参数:
        loss: 每个样本的损失值(形状: [batch_size])
        dummy_w: 可训练标量参数
    """
    # 计算梯度∂loss/∂dummy_w
    grad = torch.autograd.grad(loss.mean(), [dummy_w], create_graph=True)[0]
    
    # 惩罚项是梯度的平方和
    penalty = torch.sum(grad ** 2)
    return penalty
my_history=irm_train(my_model,[my_train_loader1,my_train_loader2,my_train_loader3],my_loss,my_optimizer,my_device)


Starting IRM training
Epoch [1/10], Train Loss: 2.0835, Train Acc: 50.67%, Penalty: 0.0002, Penalty Weight: 0.00
Epoch [2/10], Train Loss: 2.0605, Train Acc: 53.11%, Penalty: 0.0006, Penalty Weight: 0.03
Epoch [3/10], Train Loss: 1.4472, Train Acc: 78.38%, Penalty: 0.0093, Penalty Weight: 0.08
Epoch [4/10], Train Loss: 1.0128, Train Acc: 86.73%, Penalty: 0.0103, Penalty Weight: 0.15
Epoch [5/10], Train Loss: 0.7520, Train Acc: 90.61%, Penalty: 0.0095, Penalty Weight: 0.23
Epoch [6/10], Train Loss: 0.6090, Train Acc: 92.74%, Penalty: 0.0076, Penalty Weight: 0.33
Epoch [7/10], Train Loss: 0.5011, Train Acc: 94.08%, Penalty: 0.0067, Penalty Weight: 0.44
Epoch [8/10], Train Loss: 0.4507, Train Acc: 94.64%, Penalty: 0.0059, Penalty Weight: 0.57
Epoch [9/10], Train Loss: 0.4061, Train Acc: 95.32%, Penalty: 0.0049, Penalty Weight: 0.70
Epoch [10/10], Train Loss: 0.3710, Train Acc: 95.77%, Penalty: 0.0053, Penalty Weight: 0.84
Finished IRM training


In [77]:
#测试
def test_model(model,test_loader,device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)#取最大概率作为结果
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy{accuracy:.2f} %')
    return accuracy

test_model(my_model,my_test_loader1,my_device)
test_model(my_model,my_test_loader2,my_device)

Accuracy94.88 %
Accuracy51.03 %


51.026666666666664