In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from tqdm import tqdm

In [None]:
# Select the compute device: the CUDA GPU when one is visible, the CPU otherwise,
# and report which one was chosen.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using GPU: {torch.cuda.get_device_name(0)}" if device.type == "cuda" else "Using CPU")

In [None]:
# Feature extractor built on a pretrained ResNet-18 backbone
class ResNetFeatureExtractor(nn.Module):
    """Encode an image batch into 256-dimensional feature vectors.

    The backbone is torchvision's ResNet-18 loaded with pretrained weights,
    with its final fully connected layer removed. A new linear layer maps the
    512-dimensional pooled backbone output to 256 dimensions.

    Args:
        input_channels: Number of channels in the input images. For 3 the
            pretrained stem convolution is kept as is. For any other value
            the stem is replaced by a freshly initialised convolution that
            accepts ``input_channels`` channels.
    """

    def __init__(self, input_channels):
        super().__init__()
        # NOTE(review): newer torchvision releases deprecate `pretrained=True`
        # in favour of the `weights=` argument; kept here for compatibility
        # with the torchvision version this notebook was written against.
        backbone = models.resnet18(pretrained=True)
        if input_channels != 3:
            # The pretrained stem expects 3 channels, so swap it for one that
            # matches the requested channel count (randomly initialised).
            backbone.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Keep every child module except the last fully connected layer.
        self.resnet = nn.Sequential(*list(backbone.children())[:-1])
        self.fc = nn.Linear(512, 256)

    def forward(self, x):
        """Return the features computed by the backbone followed by ``fc``."""
        x = self.resnet(x)
        # Collapse everything after the batch dimension into one feature axis.
        x = torch.flatten(x, 1)
        return self.fc(x)

# One encoder per modality: 3-channel RGB first, then 1-channel depth.
rgb_net, depth_net = (
    ResNetFeatureExtractor(input_channels=channels).to(device)
    for channels in (3, 1)
)

In [None]:
# Projection head that maps encoder features into the shared alignment space
class ProjectionHead(nn.Module):
    """Single linear layer projecting ``input_dim`` features to ``output_dim``."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=output_dim)

    def forward(self, x):
        """Apply the linear projection to ``x`` and return the result."""
        return self.fc1(x)

# Two independent 256 -> 128 heads, one per modality (RGB first, then depth).
projection_rgb, projection_depth = (
    ProjectionHead(256, 128).to(device) for _ in range(2)
)


In [None]:

# Contrastive loss (the kind commonly used in CLIP-style alignment)
def contrastive_loss(features1, features2, temperature=1.0):
    """Cross-entropy alignment loss between two batches of paired features.

    Row ``i`` of ``features1`` is treated as matching row ``i`` of
    ``features2``; every other row of ``features2`` acts as a negative. The
    loss is one-directional (``features1`` -> ``features2``).

    Args:
        features1: Tensor of shape (batch, dim).
        features2: Tensor of shape (batch, dim), paired row-wise with
            ``features1``.
        temperature: Divisor applied to the cosine similarities before the
            softmax. The default of 1.0 leaves the similarities unscaled.

    Returns:
        Scalar tensor holding the mean cross-entropy over the batch.
    """
    # L2-normalise so the dot products below are cosine similarities
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)
    # Pairwise similarity matrix: entry (i, j) compares features1[i] with features2[j]
    logits = torch.matmul(features1, features2.T) / temperature
    # The matching pair sits on the diagonal; build the targets on the input's device
    labels = torch.arange(features1.size(0), device=features1.device)
    return F.cross_entropy(logits, labels)

# Define the NT-Xent Loss (used in SimCLR's contrastive loss)
def nt_xent_loss(features1, features2, temperature=0.5):
    """Normalised temperature-scaled cross-entropy loss over two views.

    ``features1[i]`` and ``features2[i]`` are two views of the same sample and
    form a positive pair. For each of the ``2 * batch`` embeddings, the other
    ``2 * batch - 1`` embeddings are the candidates, exactly one of which is
    its positive.

    Args:
        features1: Tensor of shape (batch, dim) for the first view.
        features2: Tensor of shape (batch, dim) for the second view.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor holding the mean loss over all ``2 * batch`` embeddings.
    """
    # L2 normalization so dot products are cosine similarities
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)

    batch_size = features1.size(0)
    total = 2 * batch_size

    # Stack both views: rows [0, batch) are view 1, rows [batch, 2*batch) are view 2
    features = torch.cat([features1, features2], dim=0)

    # Compute similarity matrix
    similarity_matrix = torch.matmul(features, features.T) / temperature

    # Exclude self-comparisons: -inf contributes exp(-inf) = 0 to the softmax.
    # The mask is built on the features' own device rather than a global one.
    self_mask = torch.eye(total, dtype=torch.bool, device=features.device)
    similarity_matrix = similarity_matrix.masked_fill(self_mask, float('-inf'))

    # The positive of row i is the same sample in the other view, i.e. row
    # (i + batch) mod 2*batch. Using it as the class index keeps the positive
    # in the softmax denominator exactly once.
    targets = (torch.arange(total, device=features.device) + batch_size) % total

    return F.cross_entropy(similarity_matrix, targets)

In [None]:

# Paired RGB / depth dataset that also serves an augmented view of each modality
class MultiModalDataset(Dataset):
    """Serve clean and augmented RGB and depth images for the same index.

    Four ``ImageFolder`` datasets are built: two over ``rgb_root`` and two
    over ``depth_root``, each with its own transform. Samples are paired
    purely by index, so all four must have the same length.
    """

    def __init__(self, rgb_root, depth_root, transform_rgb, transform_rgb_aug, transform_depth, transform_depth_aug):
        # Clean and augmented views of the RGB folder
        self.rgb_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb)
        self.rgb_augmented_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb_aug)
        # Clean and augmented views of the depth folder
        self.depth_dataset = datasets.ImageFolder(root=depth_root, transform=transform_depth)
        self.depth_augmented_dataset = datasets.ImageFolder(root=depth_root, transform=transform_depth_aug)

        # Index-based pairing only works when every view has the same size
        sizes = {
            len(self.rgb_dataset),
            len(self.rgb_augmented_dataset),
            len(self.depth_dataset),
            len(self.depth_augmented_dataset),
        }
        assert len(sizes) == 1, "Datasets have different lengths"

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        """Return (rgb, rgb_aug, depth, depth_aug, label) for index ``idx``.

        The label is taken from the clean RGB dataset; the labels of the
        other three views are discarded.
        """
        rgb_image, label = self.rgb_dataset[idx]
        rgb_aug_image = self.rgb_augmented_dataset[idx][0]
        depth_image = self.depth_dataset[idx][0]
        depth_aug_image = self.depth_augmented_dataset[idx][0]
        return rgb_image, rgb_aug_image, depth_image, depth_aug_image, label



# Side length (pixels) of the square images produced by every pipeline below
IMAGE_SIZE = 224

# Basic RGB transformation: fixed resize, no randomness
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# Basic single-channel transformation: grayscale conversion, then fixed resize
transform_gray = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
])

# SimCLR-style RGB augmentation: random crop, flip and colour jitter
transform_rgb_augment = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4),
    transforms.ToTensor(),
])

# Single-channel augmentation: grayscale conversion, random crop and flip
transform_depth_augment = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomResizedCrop(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Build the training dataset. Both modalities are read from the same folder;
# the "depth" view comes from the single-channel transforms.
train_root = './affectnet_3750subset/train'
multi_modal_dataset = MultiModalDataset(
    rgb_root=train_root,
    depth_root=train_root,
    transform_rgb=transform,
    transform_rgb_aug=transform_rgb_augment,
    transform_depth=transform_gray,
    transform_depth_aug=transform_depth_augment,
)

# Shuffled mini-batches of 64 samples, loaded by 4 worker processes
multi_modal_loader = DataLoader(
    dataset=multi_modal_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
)

In [None]:
print("Start training")
# Optimizer: a single Adam instance over both encoders and both projection heads
trainable_modules = (rgb_net, depth_net, projection_rgb, projection_depth)
optimizer = optim.Adam(
    [param for module in trainable_modules for param in module.parameters()],
    lr=1e-4,
)

# Training mode for every module that is optimised
for module in trainable_modules:
    module.train()

epochs = 24
for epoch in range(epochs):
    epoch_loss = 0
    # The class labels yielded by the loader are not used during contrastive pre-training
    for rgb_images, rgb_augmented_images, depth_images, depth_augmented_images, _labels in tqdm(multi_modal_loader, desc=f"Epoch [{epoch+1}/{epochs}]"):
        # Move data to the selected device
        rgb_images = rgb_images.to(device)
        rgb_augmented_images = rgb_augmented_images.to(device)
        depth_images = depth_images.to(device)
        depth_augmented_images = depth_augmented_images.to(device)

        # Each batch is encoded exactly once; the clean-view projections are
        # shared by the within-modality and the cross-modality losses.
        rgb_projection_aug = projection_rgb(rgb_net(rgb_augmented_images))
        rgb_projection = projection_rgb(rgb_net(rgb_images))
        depth_projection_aug = projection_depth(depth_net(depth_augmented_images))
        depth_projection = projection_depth(depth_net(depth_images))

        # Contrastive learning between augmented and clean views (RGB)
        loss_rgb = nt_xent_loss(rgb_projection_aug, rgb_projection)

        # Contrastive learning between augmented and clean views (Depth)
        loss_depth = nt_xent_loss(depth_projection_aug, depth_projection)

        # Aligning RGB and depth images through contrastive learning
        loss_rgb_depth = contrastive_loss(rgb_projection, depth_projection)

        # Total loss
        loss = loss_rgb + loss_depth + loss_rgb_depth
        epoch_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {epoch_loss / len(multi_modal_loader)}')

    # Checkpoint every 4th epoch (files numbered 1, 5, 9, ...) and also after
    # the final epoch so the fully trained weights are never lost.
    if epoch % 4 == 0 or epoch == epochs - 1:
        # Save the model
        torch.save({
            'rgb_net_state_dict': rgb_net.state_dict(),
            'depth_net_state_dict': depth_net.state_dict(),
            'projection_rgb_state_dict': projection_rgb.state_dict(),
            'projection_depth_state_dict': projection_depth.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'contrastive_learning_model_epoch_{epoch+1}.pth')

In [None]:
# Load a checkpoint written during the contrastive pre-training phase.
# This restores into the `optimizer` created by the training cell, so that
# cell must have been run in this kernel first.
load_epoch = 21
# map_location remaps the stored tensors onto the active device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(f'contrastive_learning_model_epoch_{load_epoch}.pth', map_location=device)
rgb_net.load_state_dict(checkpoint['rgb_net_state_dict'])
depth_net.load_state_dict(checkpoint['depth_net_state_dict'])
projection_rgb.load_state_dict(checkpoint['projection_rgb_state_dict'])
projection_depth.load_state_dict(checkpoint['projection_depth_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f'Model from epoch {load_epoch} loaded successfully.')

In [None]:
class MultiModalDataset_Test(Dataset):
    """Paired RGB / depth dataset without augmented views, used for evaluation.

    Samples are paired purely by index, so both underlying ``ImageFolder``
    datasets must have the same length.
    """

    def __init__(self, rgb_root, depth_root, transform_rgb, transform_depth):
        # Load the RGB and depth datasets
        self.rgb_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb)
        self.depth_dataset = datasets.ImageFolder(root=depth_root, transform=transform_depth)

        # Report both sizes, then make sure they agree
        rgb_count = len(self.rgb_dataset)
        depth_count = len(self.depth_dataset)
        print(rgb_count)
        print(depth_count)
        assert rgb_count == depth_count, "Datasets have different lengths"

    def __len__(self):
        """Return the number of paired samples."""
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        """Return (rgb, depth, label) for ``idx``; the label comes from the RGB dataset."""
        rgb_image, label = self.rgb_dataset[idx]
        depth_image = self.depth_dataset[idx][0]
        return rgb_image, depth_image, label
    
# Evaluation split: same basic transforms as training, no augmentation.
multi_modal_dataset_test = MultiModalDataset_Test(
    rgb_root='./affectnet_3750subset/test',
    depth_root='./affectnet_3750subset/test',
    transform_rgb=transform,
    transform_depth=transform_gray,
)

# Evaluation does not benefit from shuffling; a fixed order keeps runs deterministic.
multi_modal_loader_test = DataLoader(multi_modal_dataset_test, batch_size=64, shuffle=False, num_workers=4)

In [None]:
# Fine-tuning phase: train a classifier on the concatenated RGB + depth
# encoder features and evaluate on the test split after every epoch.
print("Starting Finetune Phase...")
finetune_epochs = 5
# Input size 512 = 256-d RGB features concatenated with 256-d depth features
classifier = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 8)  # assumes there are 8 classes
).to(device)

finetune_optimizer = optim.Adam(list(rgb_net.parameters()) + list(depth_net.parameters()) + list(classifier.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# The projection heads are not used in this phase; only these modules are.
finetune_modules = (rgb_net, depth_net, classifier)

for epoch in range(finetune_epochs):
    # Always start an epoch in training mode, whatever state earlier cells
    # (or an interrupted evaluation) left the modules in.
    for module in finetune_modules:
        module.train()

    epoch_loss = 0
    # The training loader also yields augmented views; they are ignored here.
    for rgb_images, _rgb_aug, depth_images, _depth_aug, label_y in tqdm(multi_modal_loader, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
        # Move data to the selected device (GPU when available, otherwise CPU)
        rgb_images = rgb_images.to(device)
        depth_images = depth_images.to(device)
        label_y = label_y.to(device)

        # Extract encoder features and concatenate them along the feature axis
        rgb_features_2 = rgb_net(rgb_images)
        depth_features_2 = depth_net(depth_images)
        combined_features = torch.cat((rgb_features_2, depth_features_2), dim=1)

        # Classify the fused features
        outputs = classifier(combined_features)
        loss = criterion(outputs, label_y)
        epoch_loss += loss.item()

        # Backpropagation and optimization
        finetune_optimizer.zero_grad()
        loss.backward()
        finetune_optimizer.step()
    print(f'Finetune Epoch [{epoch+1}/{finetune_epochs}], Average Loss: {epoch_loss / len(multi_modal_loader)}')

    # Evaluation on the test set
    for module in finetune_modules:
        module.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for rgb_images_test, depth_images_test, label_y_test in tqdm(multi_modal_loader_test, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
            # Move data to the selected device
            rgb_images_test = rgb_images_test.to(device)
            depth_images_test = depth_images_test.to(device)
            label_y_test = label_y_test.to(device)

            # Extract encoder features and concatenate them
            rgb_features_2 = rgb_net(rgb_images_test)
            depth_features_2 = depth_net(depth_images_test)
            combined_features = torch.cat((rgb_features_2, depth_features_2), dim=1)

            # Predicted class = index of the largest logit
            outputs = classifier(combined_features)
            predicted = torch.max(outputs, dim=1).indices
            total += label_y_test.size(0)
            correct += (predicted == label_y_test).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy after Finetune Epoch [{epoch+1}/{finetune_epochs}]: {accuracy:.2f}%')

# Leave the fine-tuned modules in training mode once the phase is finished
for module in finetune_modules:
    module.train()