In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models
from tqdm import tqdm
import os

In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU")

Using GPU: NVIDIA GeForce RTX 3090


In [3]:
# Using pretrained ResNet50
class ResNetFeatureExtractor(nn.Module):
    def __init__(self, input_channels):
        super(ResNetFeatureExtractor, self).__init__()
        if input_channels == 3:
            self.resnet = models.resnet18(pretrained=True)
        else:
            self.resnet = models.resnet18(pretrained=True)
            self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.resnet = nn.Sequential(*list(self.resnet.children())[:-1])  # Remove the last fully connected layer
        self.fc = nn.Linear(512, 256)
        # self.fc = nn.Linear(2048, 256)

    def forward(self, x):
        x = self.resnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

# Define two networks for RGB and depth images
rgb_net = ResNetFeatureExtractor(input_channels=3).to(device)
depth_net = ResNetFeatureExtractor(input_channels=1).to(device)



In [4]:
# Define projection layers to project features into a common alignment space
class ProjectionHead(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(ProjectionHead, self).__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        x = self.fc1(x)
        return x

projection_rgb = ProjectionHead(256, 128).to(device)
projection_depth = ProjectionHead(256, 128).to(device)


In [5]:

# 定义对比损失函数（如CLIP中常用的对比损失）
def contrastive_loss(features1, features2):
    # L2 正则化
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)
    # 计算相似度矩阵
    logits = torch.matmul(features1, features2.T)
    labels = torch.arange(features1.size(0)).to(features1.device)
    loss = F.cross_entropy(logits, labels)
    
    return loss

# Define the NT-Xent Loss (used in SimCLR's contrastive loss)
def nt_xent_loss(features1, features2, temperature=0.5):
    # L2 normalization
    features1 = F.normalize(features1, p=2, dim=1)
    features2 = F.normalize(features2, p=2, dim=1)
    
    # Concatenate features
    features = torch.cat([features1, features2], dim=0)
    
    # Compute similarity matrix
    similarity_matrix = torch.matmul(features, features.T) / temperature
    
    # Create labels
    batch_size = features1.size(0)
    labels = torch.arange(batch_size).to(device)
    labels = torch.cat([labels, labels], dim=0)
    
    # Remove diagonal elements (self-comparison)
    mask = torch.eye(2 * batch_size, dtype=bool).to(device)
    similarity_matrix = similarity_matrix.masked_fill(mask, -float('inf'))
    
    # Compute loss
    positives = torch.cat([torch.diag(similarity_matrix, batch_size), torch.diag(similarity_matrix, -batch_size)])
    negatives = similarity_matrix[~mask].view(2 * batch_size, -1)
    logits = torch.cat([positives.unsqueeze(1), negatives], dim=1)
    
    labels = torch.zeros(2 * batch_size, dtype=torch.long).to(device)
    loss = F.cross_entropy(logits, labels)
    
    return loss

In [6]:
from affectnet_util import AffectDataSet
from torch.utils.data import Dataset

class MultiModalDataset(Dataset):
    def __init__(self, data_path, affcls, train, exclude_classes, 
                 transform_rgb, transform_rgb_aug, transform_depth, transform_depth_aug):
        self.train=train
        self.transform_rgb = transform_rgb
        self.transform_rgb_aug = transform_rgb_aug
        self.transform_depth = transform_depth
        self.transform_depth_aug = transform_depth_aug

        # 使用 AffectDataSet 加载 RGB 和深度数据
        self.rgb_dataset = AffectDataSet(data_path=data_path, train= train, affcls=affcls, 
                                         transform=self.transform_rgb, exclude_classes=exclude_classes)
        self.rgb_augmented_dataset = AffectDataSet(data_path=data_path, train= train, affcls=affcls, 
                                                   transform=self.transform_rgb_aug, exclude_classes=exclude_classes)
        self.depth_dataset = AffectDataSet(data_path=data_path, train= train, affcls=affcls, 
                                           transform=self.transform_depth, exclude_classes=exclude_classes)
        self.depth_augmented_dataset = AffectDataSet(data_path=data_path, train= train, affcls=affcls, 
                                                     transform=self.transform_depth_aug, exclude_classes=exclude_classes)

        # 确保数据长度一致
        assert len(self.rgb_dataset) == len(self.rgb_augmented_dataset) == len(self.depth_dataset) == len(self.depth_augmented_dataset), \
            "Datasets have different lengths"

    def __len__(self):
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        # 获取 RGB 图像及其增强版本
        rgb_image, label = self.rgb_dataset[idx]
        rgb_aug_image, _ = self.rgb_augmented_dataset[idx]

        # 获取深度图像及其增强版本
        depth_image, _ = self.depth_dataset[idx]
        depth_aug_image, _ = self.depth_augmented_dataset[idx]

        return rgb_image, rgb_aug_image, depth_image, depth_aug_image, label


In [7]:
from torch.utils.data import DataLoader
from torchvision import transforms

# 定义 RGB 和深度图像的增强及基本转换
transform_rgb = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_rgb_augment = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_depth = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

transform_depth_augment = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# 排除的类别
exclude_classes = [4, 5]

# 加载多模态训练集
multi_modal_train_dataset = MultiModalDataset(
    data_path="/CSCI2952X/datasets/affectnet",
    affcls=7,
    train= True,
    exclude_classes=exclude_classes,
    transform_rgb=transform_rgb,
    transform_rgb_aug=transform_rgb_augment,
    transform_depth=transform_depth,
    transform_depth_aug=transform_depth_augment
)

# 创建 DataLoader
multi_modal_train_loader = DataLoader(multi_modal_train_dataset, batch_size=32, shuffle=True, num_workers=4)

# 测试数据加载
for rgb_image, rgb_aug_image, depth_image, depth_aug_image, label in multi_modal_train_loader:
    print(f"RGB image shape: {rgb_image.shape}")  # torch.Size([32, 3, 224, 224])
    print(f"Depth image shape: {depth_image.shape}")  # torch.Size([32, 1, 224, 224])
    print(f"Labels: {label}")  # 检查标签是否正确
    break


Distribution of train samples: [ 74874 134415  25459  14090  24882]
Distribution of train samples: [ 74874 134415  25459  14090  24882]
Distribution of train samples: [ 74874 134415  25459  14090  24882]
Distribution of train samples: [ 74874 134415  25459  14090  24882]
RGB image shape: torch.Size([32, 3, 224, 224])
Depth image shape: torch.Size([32, 1, 224, 224])
Labels: tensor([1, 0, 2, 1, 1, 6, 0, 1, 6, 3, 1, 3, 3, 3, 1, 1, 1, 1, 2, 1, 1, 0, 1, 0,
        1, 0, 1, 1, 1, 0, 6, 2])


In [8]:
print("Start training")
# Optimizer
optimizer = optim.Adam(list(rgb_net.parameters()) + list(depth_net.parameters()) +
                       list(projection_rgb.parameters()) + list(projection_depth.parameters()), lr=1e-3)

# Define your experiment name and create save directory
experiment_name = 'new_rn18_depth_finetune_50ep'
save_dir = os.path.join('saved_models', experiment_name)
os.makedirs(save_dir, exist_ok=True)

best_loss = float('inf')  # Initialize best loss to a large value
best_model_path = os.path.join(save_dir, 'best_model.pth')
latest_model_path = os.path.join(save_dir, 'latest_model.pth')

# Training
rgb_net.train()
depth_net.train()
projection_rgb.train()
projection_depth.train()

epochs = 30
for epoch in range(epochs):
    epoch_loss = 0
    for rgb_images, rgb_augmented_images, depth_images, depth_augmented_images, labelq in tqdm(multi_modal_train_loader, desc=f"Epoch [{epoch+1}/{epochs}]"):
        # Move data to GPU
        rgb_images = rgb_images.to(device)
        rgb_augmented_images = rgb_augmented_images.to(device)
        depth_images = depth_images.to(device)
        depth_augmented_images = depth_augmented_images.to(device)

        # Contrastive learning with augmented images (RGB)
        rgb_features_1 = rgb_net(rgb_augmented_images)
        rgb_features_2 = rgb_net(rgb_images)
        rgb_projection_1 = projection_rgb(rgb_features_1)
        rgb_projection_2 = projection_rgb(rgb_features_2)
        loss_rgb = nt_xent_loss(rgb_projection_1, rgb_projection_2)

        # Contrastive learning with augmented images (Depth)
        depth_features_1 = depth_net(depth_augmented_images)
        depth_features_2 = depth_net(depth_images)
        depth_projection_1 = projection_depth(depth_features_1)
        depth_projection_2 = projection_depth(depth_features_2)
        loss_depth = nt_xent_loss(depth_projection_1, depth_projection_2)

        # Aligning RGB and depth images through contrastive learning
        rgb_features = rgb_net(rgb_images)
        depth_features = depth_net(depth_images)
        rgb_projection = projection_rgb(rgb_features)
        depth_projection = projection_depth(depth_features)
        loss_rgb_depth = contrastive_loss(rgb_projection, depth_projection)

        # Total loss
        loss = loss_rgb + loss_depth + loss_rgb_depth
        epoch_loss += loss.item()

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    average_loss = epoch_loss / len(multi_modal_train_loader)
    print(f'Epoch [{epoch+1}/{epochs}], Average Loss: {average_loss}')
    
    # Save the latest model
    torch.save({
        'epoch': epoch + 1,
        'rgb_net_state_dict': rgb_net.state_dict(),
        'depth_net_state_dict': depth_net.state_dict(),
        'projection_rgb_state_dict': projection_rgb.state_dict(),
        'projection_depth_state_dict': projection_depth.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'avg_loss': average_loss,
    }, latest_model_path)
    print(f'Latest model saved to {latest_model_path}')
    
    # Save the best model
    if average_loss < best_loss:
        best_loss = average_loss
        torch.save({
            'epoch': epoch + 1,
            'rgb_net_state_dict': rgb_net.state_dict(),
            'depth_net_state_dict': depth_net.state_dict(),
            'projection_rgb_state_dict': projection_rgb.state_dict(),
            'projection_depth_state_dict': projection_depth.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'avg_loss': best_loss,
        }, best_model_path)
        print(f'New best model saved with loss {best_loss:.4f} to {best_model_path}')


Start training


Epoch [1/30]:   1%|          | 98/8554 [00:16<23:24,  6.02it/s] 


KeyboardInterrupt: 

In [9]:
#Load checkpoint from training phase
checkpoint = torch.load(best_model_path)
rgb_net.load_state_dict(checkpoint['rgb_net_state_dict'])
depth_net.load_state_dict(checkpoint['depth_net_state_dict'])
projection_rgb.load_state_dict(checkpoint['projection_rgb_state_dict'])
projection_depth.load_state_dict(checkpoint['projection_depth_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
print(f'Model from {best_model_path} loaded successfully.')

  checkpoint = torch.load(best_model_path)


Model from saved_models/new_rn18_depth_finetune_50ep/best_model.pth loaded successfully.


In [None]:
class MultiModalDataset_Test(Dataset):
    def __init__(self, data_path, affcls, train, transform_rgb, transform_depth, exclude_classes=None):
        self.transform_rgb = transform_rgb
        self.transform_depth = transform_depth

        # 使用 AffectDataSet 加载 RGB 和深度数据
        self.rgb_dataset = AffectDataSet(data_path=data_path, train = train, affcls=affcls, 
                                         transform=self.transform_rgb, exclude_classes=exclude_classes)
        self.depth_dataset = AffectDataSet(data_path=data_path,  train = train, affcls=affcls, 
                                           transform=self.transform_depth, exclude_classes=exclude_classes)

        # 确保数据长度一致
        print(f"RGB dataset size: {len(self.rgb_dataset)}")
        print(f"Depth dataset size: {len(self.depth_dataset)}")
        assert len(self.rgb_dataset) == len(self.depth_dataset), "Datasets have different lengths"

    def __len__(self):
        return len(self.rgb_dataset)

    def __getitem__(self, idx):
        # 获取 RGB 和深度图像
        rgb_image, label = self.rgb_dataset[idx]
        depth_image, _ = self.depth_dataset[idx]
        return rgb_image, depth_image, label


In [None]:
# 创建测试集转换
transform_rgb = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_depth = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# 要排除的类别
exclude_classes = [4, 5]  # 假设我们需要过滤这些类别

# 加载测试集
multi_modal_dataset_test = MultiModalDataset_Test(
    data_path="/CSCI2952X/datasets/affectnet",
    affcls=7,  # 假设是7类任务
    train = False,
    transform_rgb=transform_rgb,
    transform_depth=transform_depth,
    exclude_classes=exclude_classes
)

# 创建测试 DataLoader
multi_modal_loader_test = DataLoader(multi_modal_dataset_test, batch_size=32, shuffle=False, num_workers=8)

# 验证加载效果
for rgb_image, depth_image, label in multi_modal_loader_test:
    print(f"RGB image shape: {rgb_image.shape}")  # torch.Size([32, 3, 224, 224])
    print(f"Depth image shape: {depth_image.shape}")  # torch.Size([32, 1, 224, 224])
    print(f"Labels: {label}")  # 检查标签是否正确
    break

In [None]:
import os

print("Starting Finetune Phase...")

# Directory setup for saving fine-tuned models
finetune_experiment_name = 'new_rn18_depth_finetune_50ep'
finetune_save_dir = os.path.join('saved_models', finetune_experiment_name)
os.makedirs(finetune_save_dir, exist_ok=True)

best_finetune_loss = float('inf')  # Initialize best loss
best_finetune_model_path = os.path.join(finetune_save_dir, 'best_finetune_model.pth')
latest_finetune_model_path = os.path.join(finetune_save_dir, 'latest_finetune_model.pth')

finetune_epochs = 30
classifier = nn.Sequential(
    nn.Linear(512, 128),
    nn.ReLU(),
    nn.Linear(128, 5)  # 128, x   x is number of class
).cuda()

finetune_optimizer = optim.Adam(list(rgb_net.parameters()) + list(depth_net.parameters()) + list(classifier.parameters()), lr=1e-3)
criterion = nn.CrossEntropyLoss()

for epoch in range(finetune_epochs):
    epoch_loss = 0
    for rgb_images, rgb_augmented_images, depth_images, depth_augmented_images, label_y in tqdm(multi_modal_train_loader, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
        # 将数据移动到GPU
        rgb_images = rgb_images.cuda()
        depth_images = depth_images.cuda()
        label_y = label_y.cuda()

        # 提取投影特征并 concatenate
        rgb_features_2 = rgb_net(rgb_images)
        depth_features_2 = depth_net(depth_images)
        combined_features = torch.cat((rgb_features_2, depth_features_2), dim=1)

        # 通过分类器进行分类
        outputs = classifier(combined_features)
        loss = criterion(outputs, label_y)
        epoch_loss += loss.item()

        # 反向传播和优化
        finetune_optimizer.zero_grad()
        loss.backward()
        finetune_optimizer.step()

    average_loss = epoch_loss / len(multi_modal_train_loader)
    print(f'Finetune Epoch [{epoch+1}/{finetune_epochs}], Average Loss: {average_loss}')

    # 保存最新模型
    torch.save({
        'epoch': epoch + 1,
        'rgb_net_state_dict': rgb_net.state_dict(),
        'depth_net_state_dict': depth_net.state_dict(),
        'classifier_state_dict': classifier.state_dict(),
        'optimizer_state_dict': finetune_optimizer.state_dict(),
        'avg_loss': average_loss,
    }, latest_finetune_model_path)
    print(f'Latest fine-tuned model saved to {latest_finetune_model_path}')

    # 保存最佳模型
    if average_loss < best_finetune_loss:
        best_finetune_loss = average_loss
        torch.save({
            'epoch': epoch + 1,
            'rgb_net_state_dict': rgb_net.state_dict(),
            'depth_net_state_dict': depth_net.state_dict(),
            'classifier_state_dict': classifier.state_dict(),
            'optimizer_state_dict': finetune_optimizer.state_dict(),
            'avg_loss': best_finetune_loss,
        }, best_finetune_model_path)
        print(f'New best fine-tuned model saved with loss {best_finetune_loss:.4f} to {best_finetune_model_path}')

    # 测试集上的评估
    rgb_net.eval()
    depth_net.eval()
    classifier.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for rgb_images_test, depth_images_test, label_y_test in tqdm(multi_modal_loader_test, desc=f"Epoch [{epoch+1}/{finetune_epochs}]"):
            # 将数据移动到GPU
            rgb_images_test = rgb_images_test.cuda()
            depth_images_test = depth_images_test.cuda()
            label_y_test = label_y_test.cuda()

            # 提取投影特征并 concatenate
            rgb_features_2 = rgb_net(rgb_images_test)
            depth_features_2 = depth_net(depth_images_test)
            combined_features = torch.cat((rgb_features_2, depth_features_2), dim=1)

            # 通过分类器进行分类
            outputs = classifier(combined_features)
            _, predicted = torch.max(outputs.data, 1)
            total += label_y_test.size(0)
            correct += (predicted == label_y_test).sum().item()

    accuracy = 100 * correct / total
    print(f'Test Accuracy after Finetune Epoch [{epoch+1}/{finetune_epochs}]: {accuracy:.2f}%')

    rgb_net.train()
    depth_net.train()
    classifier.train()
