In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from tqdm import tqdm

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

# Using pretrained ResNet50
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, input_channels):
        super(ResNetFeatureExtractor, self).__init__()
        if input_channels == 3:
            self.resnet = models.resnet18(pretrained=True)
        else:
            self.resnet = models.resnet18(pretrained=True)
            self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])  # Remove the last fully connected layer
        self.fc = nn.Linear(512, 256)

    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Define two networks for RGB and depth images
rgb_net = ResNetFeatureExtractor(input_channels=3).to(device)
landmark_net = ResNetFeatureExtractor(input_channels=3).to(device)

# Define projection layers to project features into a common alignment space
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        return x

projection_rgb = ProjectionHead(256, 128).to(device)
projection_landmark = ProjectionHead(256, 128).to(device)


# 定义对比损失函数（如CLIP中常用的对比损失）
def contrastive_loss(features1, features2):
    # L2 正则化
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)
    # 计算相似度矩阵
    logits = torch.matmul(features1, features2.T)
    labels = torch.arange(features1.size(0)).to(features1.device)
    loss = F.cross_entropy(logits, labels)
    
    return loss

# Define the NT-Xent Loss (used in SimCLR's contrastive loss)
def nt_xent_loss(features1, features2, temperature=0.5):
    # L2 normalization
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)
    
    # Concatenate features
    features = torch.cat([features1, features2], dim=0)
    
    # Compute similarity matrix
    similarity_matrix = torch.matmul(features, features.T) / temperature
    
    # Create labels
    batch_size = features1.size(0)
    labels = torch.arange(batch_size).to(device)
    labels = torch.cat([labels, labels], dim=0)
    
    # Remove diagonal elements (self-comparison)
    mask = torch.eye(2 * batch_size, dtype=bool).to(device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -float('inf'))
    
    # Compute loss
    positives = torch.cat([torch.diag(similarity_matrix, batch_size), torch.diag(similarity_matrix, -batch_size)])
    negatives = similarity_matrix[~mask].view(2 * batch_size, -1)
    logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
    
    labels = torch.zeros(2 * batch_size, dtype=torch.long).to(device)
    loss = F.cross_entropy(logits, labels)
    
    return loss



Using GPU: NVIDIA GeForce RTX 3090




In [None]:
# Define the multimodal dataset
class MultiModalDataset(Dataset):
    def __init__(self, rgb_root, landmark_root, transform_rgb, transform_rgb_aug, transform_landmark, transform_landmark_aug):
        # Load RGB and landmark datasets
        self.rgb_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb)
        self.rgb_augmented_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb_aug)
        self.landmark_dataset = datasets.ImageFolder(root=landmark_root, transform=transform_landmark)
        self.landmark_augmented_dataset = datasets.ImageFolder(root=landmark_root, transform=transform_landmark_aug)
        
        # Ensure all datasets have the same length
        assert len(self.rgb_dataset) == len(self.rgb_augmented_dataset) == len(self.landmark_dataset) == len(self.landmark_augmented_dataset), "Datasets have different lengths"

    def __len__(self):
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        rgb_image, label = self.rgb_dataset[idx]
        rgb_aug_image, _ = self.rgb_augmented_dataset[idx]
        landmark_image, _ = self.landmark_dataset[idx]
        landmark_aug_image, _ = self.landmark_augmented_dataset[idx]
        return rgb_image, rgb_aug_image, landmark_image, landmark_aug_image, label


# Basic transformation
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Define data augmentation transformations (SimCLR style)
transform_rgb_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.ToTensor(),
])

transform_landmark_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.ToTensor(),
])

# Load the dataset
multi_modal_dataset = MultiModalDataset(
    rgb_root='./affectnet_3750subset/train',
    landmark_root='./affectnet_3750subset_with_landmarks/train',
    transform_rgb=transform,
    transform_rgb_aug=transform_rgb_augment,
    transform_landmark=transform,
    transform_landmark_aug=transform_landmark_augment
)

multi_modal_loader = DataLoader(multi_modal_dataset, batch_size=128, shuffle=True, num_workers=8)

In [4]:
print("Start training")
# Optimizer
optimizer = optim.Adam(list(rgb_net.parameters()) + list(landmark_net.parameters()) +
                       list(projection_rgb.parameters()) + list(projection_landmark.parameters()), lr=0.03)

# Training
rgb_net.train()
landmark_net.train()
projection_rgb.train()
projection_landmark.train()

epochs = 20
for epoch in range(epochs):
    epoch_loss = 0
    for rgb_images, rgb_augmented_images, landmark_images, landmark_augmented_images, label in tqdm(multi_modal_loader, desc=f"Epoch [{epoch+1}/{epochs}]"):
        # Move data to GPU
        rgb_images = rgb_images.to(device)
        rgb_augmented_images = rgb_augmented_images.to(device)
        landmark_images = landmark_images.to(device)
        landmark_augmented_images = landmark_augmented_images.to(device)

        # Contrastive learning with augmented images (RGB)
        rgb_features_1 = rgb_net(rgb_augmented_images)
        rgb_features_2 = rgb_net(rgb_images)
        rgb_projection_1 = projection_rgb(rgb_features_1)
        rgb_projection_2 = projection_rgb(rgb_features_2)
        loss_rgb = nt_xent_loss(rgb_projection_1, rgb_projection_2)

        # Contrastive learning with augmented images (Landmark)
        landmark_features_1 = landmark_net(landmark_augmented_images)
        landmark_features_2 = landmark_net(landmark_images)
        landmark_projection_1 = projection_landmark(landmark_features_1)
        landmark_projection_2 = projection_landmark(landmark_features_2)
        loss_landmark = nt_xent_loss(landmark_projection_1, landmark_projection_2)

        # Aligning RGB and Landmark images through contrastive learning
        rgb_features = rgb_net(rgb_images)
        landmark_features = landmark_net(landmark_images)
        rgb_projection = projection_rgb(rgb_features)
        landmark_projection = projection_landmark(landmark_features)
        loss_rgb_landmark = contrastive_loss(rgb_projection, landmark_projection)

        # Total loss
        loss = loss_rgb + loss_landmark + loss_rgb_landmark
        epoch_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {epoch_loss / len(multi_modal_loader)}')
    
    if epoch % 4 == 0:
        # Save the model
        torch.save({
            'rgb_net_state_dict': rgb_net.state_dict(),
            'landmark_net_state_dict': landmark_net.state_dict(),
            'projection_rgb_state_dict': projection_rgb.state_dict(),
            'projection_landmark_state_dict': projection_landmark.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, f'contrastive_learning_model_landmark_epoch_{epoch+1}.pth')

    epoch += 1


Start training


Epoch [1/20]:   0%|          | 0/75 [00:17<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 154.00 MiB. GPU 0 has a total capacity of 23.59 GiB of which 147.06 MiB is free. Including non-PyTorch memory, this process has 23.40 GiB memory in use. Of the allocated memory 22.93 GiB is allocated by PyTorch, and 175.93 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Load checkpoint from training phase
load_epoch = 21
checkpoint = torch.load(f'contrastive_learning_model_landmark_epoch_{load_epoch}.pth')
rgb_net.load_state_dict(checkpoint['rgb_net_state_dict'])
landmark_net.load_state_dict(checkpoint['landmark_net_state_dict'])
projection_rgb.load_state_dict(checkpoint['projection_rgb_state_dict'])
projection_landmark.load_state_dict(checkpoint['projection_landmark_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f'Model from epoch {load_epoch} loaded successfully.')

In [None]:
class MultiModalDataset_Test(Dataset):
    def __init__(self, rgb_root, landmark_root, transform_rgb, transform_landmark):
        # Load RGB and landmark datasets
        self.rgb_dataset = datasets.ImageFolder(root=rgb_root, transform=transform_rgb)
        self.landmark_dataset = datasets.ImageFolder(root=landmark_root, transform=transform_landmark)
        
        # Ensure all datasets have the same length
        print(len(self.rgb_dataset))
        print(len(self.landmark_dataset))
        assert len(self.rgb_dataset) == len(self.landmark_dataset), "Datasets have different lengths"

    def __len__(self):
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        rgb_image, label = self.rgb_dataset[idx]
        landmark_image, _ = self.landmark_dataset[idx]
        return rgb_image, landmark_image, label
    
multi_modal_dataset_test = MultiModalDataset_Test(
    rgb_root='./affectnet_3750subset/test',
    landmark_root='./affectnet_3750subset_with_landmarks/test',
    transform_rgb=transform,
    transform_landmark=transform,
)

multi_modal_loader_test = DataLoader(multi_modal_dataset_test, batch_size=512, shuffle=False, num_workers=8)

In [None]:
# Finetune Phase
print("Starting Finetune Phase...")
finetune_epochs = 10
classifier = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 8)  # Assuming 8 classes
).cuda()

finetune_optimizer = optim.Adam(list(rgb_net.parameters()) + list(landmark_net.parameters()) + list(classifier.parameters()), lr=1e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(finetune_epochs):
    epoch_loss = 0
    for rgb_images, rgb_augmented_images, landmark_images, landmark_augmented_images, label_y in tqdm(multi_modal_loader, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
        # Move data to GPU
        rgb_images = rgb_images.cuda()
        landmark_images = landmark_images.cuda()
        label_y = label_y.cuda()

        # Extract projection features and concatenate
        rgb_features_2 = rgb_net(rgb_images)
        landmark_features_2 = landmark_net(landmark_images)
        combined_features = torch.cat((rgb_features_2, landmark_features_2), dim=1)

        # Classify using the classifier
        outputs = classifier(combined_features)
        loss = criterion(outputs, label_y)
        epoch_loss += loss.item()

        # Backpropagation and optimization
        finetune_optimizer.zero_grad()
        loss.backward()
        finetune_optimizer.step()
    print(f'Finetune Epoch [{epoch+1}/{finetune_epochs}], Average Loss: {epoch_loss / len(multi_modal_loader)}')

    # Evaluation on the test set
    rgb_net.eval()
    landmark_net.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for rgb_images_test, landmark_images_test, label_y_test in tqdm(multi_modal_loader_test, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
            # Move data to GPU
            rgb_images_test = rgb_images_test.cuda()
            landmark_images_test = landmark_images_test.cuda()
            label_y_test = label_y_test.cuda()

            # Extract projection features and concatenate
            rgb_features_2 = rgb_net(rgb_images_test)
            landmark_features_2 = landmark_net(landmark_images_test)
            combined_features = torch.cat((rgb_features_2, landmark_features_2), dim=1)

            # Classify using the classifier
            outputs = classifier(combined_features)
            _, predicted = torch.max(outputs.data, 1)
            total += label_y_test.size(0)
            correct += (predicted == label_y_test).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy after Finetune Epoch [{epoch+1}/{finetune_epochs}]: {accuracy:.2f}%')

    rgb_net.train()
    landmark_net.train()
    classifier.train()