In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch.nn as nn
import torch.optim as optim

In [2]:
class MultiTaskDataset(Dataset):
    def __init__(self, num_samples=1000, num_users=6, num_locations=5, num_activities=9):
        self.num_samples = num_samples
        self.num_users = num_users
        self.num_locations = num_locations
        self.num_activities = num_activities
        
        # 随机生成输入数据（batch_size, 3, 3, 30）
        self.inputs = torch.randn(num_samples, 100)  # 假设是WiFi CSI数据
        
        # 随机生成地面真值
        self.identity_labels = torch.randint(0, 2, (num_samples, num_users))  # 6用户的身份标签
        self.location_labels = torch.randint(0, 2, (num_samples, num_users, num_locations))  # 6用户，5位置
        self.activity_labels = torch.randint(0, 2, (num_samples, num_users, num_activities))  # 6用户，9活动
        
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        return self.inputs[idx], self.identity_labels[idx], self.location_labels[idx], self.activity_labels[idx]

# 创建Dataloader
dataset = MultiTaskDataset(num_samples=1000)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)


In [3]:
# 定义多任务学习模型
class MultiTaskModel(nn.Module):
    def __init__(self, input_size=100, num_users=6, num_locations=5, num_activities=9):
        super(MultiTaskModel, self).__init__()
        
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        
        # 任务1：身份识别（输出6个概率值）
        self.identity_fc = nn.Linear(128, num_users)
        
        # 任务2：位置识别（输出6个用户 × 5个位置的概率）
        self.location_fc = nn.Linear(128, num_users * num_locations)
        
        # 任务3：活动识别（输出6个用户 × 9个活动的概率）
        self.activity_fc = nn.Linear(128, num_users * num_activities)
        
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        
        # 任务1输出
        identity_output = torch.sigmoid(self.identity_fc(x))
        
        # 任务2输出
        location_output = torch.sigmoid(self.location_fc(x)).view(-1, 6, 5)
        
        # 任务3输出
        activity_output = torch.sigmoid(self.activity_fc(x)).view(-1, 6, 9)
        
        return identity_output, location_output, activity_output

In [4]:
def train_model(model, dataloader, num_epochs=10, lr=0.001):
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # 损失函数
    criterion = nn.BCELoss()
    
    for epoch in range(num_epochs):
        model.train()
        
        running_loss = 0.0
        for inputs, identity_labels, location_labels, activity_labels in dataloader:
            optimizer.zero_grad()
            
            # 模型前向传播
            identity_out, location_out, activity_out = model(inputs)
            
            # 计算损失
            identity_loss = criterion(identity_out, identity_labels.float())
            location_loss = criterion(location_out, location_labels.float())
            activity_loss = criterion(activity_out, activity_labels.float())
            
            # 总损失
            total_loss = identity_loss + location_loss + activity_loss
            
            # 反向传播和优化
            total_loss.backward()
            optimizer.step()
            
            running_loss += total_loss.item()
        
        avg_loss = running_loss / len(dataloader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

# 实例化模型并训练
model = MultiTaskModel(num_users=6, num_locations=5, num_activities=9)
train_model(model, dataloader, num_epochs=10, lr=0.001)


Epoch [1/10], Loss: 2.0835
Epoch [2/10], Loss: 2.0586
Epoch [3/10], Loss: 2.0326
Epoch [4/10], Loss: 1.9953
Epoch [5/10], Loss: 1.9510
Epoch [6/10], Loss: 1.8971
Epoch [7/10], Loss: 1.8387
Epoch [8/10], Loss: 1.7820
Epoch [9/10], Loss: 1.7064
Epoch [10/10], Loss: 1.6426


In [5]:
def compute_accuracy(predictions, labels, num_users, num_locations=None, num_activities=None):
    """
    计算多个任务的准确率
    :param predictions: 预测值
    :param labels: 真实标签
    :param num_users: 用户数量
    :param num_locations: 位置数量（如果是位置任务）
    :param num_activities: 活动数量（如果是活动任务）
    :return: 三个任务的准确率
    """
    # 身份识别准确率
    identity_pred = predictions[0].round()  # 对于身份识别，0.5为阈值
    identity_accuracy = (identity_pred == labels[0]).float().mean().item()
    
    # 位置识别准确率
    if num_locations:
        location_pred = predictions[1].round()  # 对于位置识别，0.5为阈值
        location_accuracy = (location_pred == labels[1]).float().mean().item()
    
    # 活动识别准确率
    if num_activities:
        activity_pred = predictions[2].round()  # 对于活动识别，0.5为阈值
        activity_accuracy = (activity_pred == labels[2]).float().mean().item()
    
    return identity_accuracy, location_accuracy, activity_accuracy

def train_model(model, dataloader, num_epochs=10, lr=0.001):
    # 优化器
    optimizer = optim.Adam(model.parameters(), lr=lr)
    
    # 损失函数
    criterion = nn.BCELoss()
    
    for epoch in range(num_epochs):
        model.train()
        
        running_loss = 0.0
        total_identity_acc = 0.0
        total_location_acc = 0.0
        total_activity_acc = 0.0
        
        for inputs, identity_labels, location_labels, activity_labels in dataloader:
            optimizer.zero_grad()
            
            # 模型前向传播
            identity_out, location_out, activity_out = model(inputs)
            
            # 计算损失
            identity_loss = criterion(identity_out, identity_labels.float())
            location_loss = criterion(location_out, location_labels.float())
            activity_loss = criterion(activity_out, activity_labels.float())
            
            # 总损失
            total_loss = identity_loss + location_loss + activity_loss
            
            # 反向传播和优化
            total_loss.backward()
            optimizer.step()
            
            running_loss += total_loss.item()
            
            # 计算准确率
            identity_accuracy, location_accuracy, activity_accuracy = compute_accuracy(
                (identity_out, location_out, activity_out), 
                (identity_labels, location_labels, activity_labels),
                num_users=6, num_locations=5, num_activities=9
            )
            
            total_identity_acc += identity_accuracy
            total_location_acc += location_accuracy
            total_activity_acc += activity_accuracy
        
        avg_loss = running_loss / len(dataloader)
        avg_identity_acc = total_identity_acc / len(dataloader)
        avg_location_acc = total_location_acc / len(dataloader)
        avg_activity_acc = total_activity_acc / len(dataloader)
        
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, "
              f"Identity Accuracy: {avg_identity_acc:.4f}, "
              f"Location Accuracy: {avg_location_acc:.4f}, "
              f"Activity Accuracy: {avg_activity_acc:.4f}")


In [6]:
# 实例化模型并训练
model = MultiTaskModel(num_users=6, num_locations=5, num_activities=9)
train_model(model, dataloader, num_epochs=10, lr=0.001)

Epoch [1/10], Loss: 2.0841, Identity Accuracy: 0.4946, Location Accuracy: 0.5012, Activity Accuracy: 0.5006
Epoch [2/10], Loss: 2.0569, Identity Accuracy: 0.6211, Location Accuracy: 0.5400, Activity Accuracy: 0.5279
Epoch [3/10], Loss: 2.0310, Identity Accuracy: 0.6406, Location Accuracy: 0.5536, Activity Accuracy: 0.5387
Epoch [4/10], Loss: 1.9939, Identity Accuracy: 0.6772, Location Accuracy: 0.5696, Activity Accuracy: 0.5522
Epoch [5/10], Loss: 1.9454, Identity Accuracy: 0.7109, Location Accuracy: 0.5864, Activity Accuracy: 0.5622
Epoch [6/10], Loss: 1.8893, Identity Accuracy: 0.7485, Location Accuracy: 0.5985, Activity Accuracy: 0.5727
Epoch [7/10], Loss: 1.8316, Identity Accuracy: 0.7843, Location Accuracy: 0.6068, Activity Accuracy: 0.5795
Epoch [8/10], Loss: 1.7688, Identity Accuracy: 0.8268, Location Accuracy: 0.6119, Activity Accuracy: 0.5882
Epoch [9/10], Loss: 1.7002, Identity Accuracy: 0.8651, Location Accuracy: 0.6225, Activity Accuracy: 0.5917
Epoch [10/10], Loss: 1.6314,

In [8]:
import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd


class WiMANS(Dataset):
    def __init__(self, root_path):
        self.data_parent_path = os.path.join(root_path, 'wifi_csi', 'amp')
        self.label_file_path = os.path.join(root_path, 'annotation.csv')
        self.num_users = 6
        self.num_locations = 5
        self.num_activities = 9
        self.location_map = {
            'a': 0,
            'b': 1,
            'c': 2,
            'd': 3,
            'e': 4,
        }
        self.activity_map = {
            'nothing': 0,
            'walk': 1,
            'rotation': 2,
            'jump': 3,
            'wave': 4,
            'lie_down': 5,
            'pick_up': 6,
            'sit_down': 7,
            'stand_up': 8
        }
        self.data_filenames = [f for f in os.listdir(self.data_parent_path) if f.endswith('.npy')]
        self.labels = pd.read_csv(self.label_file_path)

    def __len__(self):
        return len(self.data_filenames)

    def __getitem__(self, idx):
        data_filename = self.data_filenames[idx]

        data = np.load(os.path.join(self.data_parent_path, data_filename))
        data_len = data.shape[0]
        if data_len < 3000:
            data = np.pad(data, ((0, 3000 - data_len), (0, 0), (0, 0), (0, 0)), mode='constant', constant_values=0)

        row_id = os.path.splitext(data_filename)[0]
        identity_label, location_label, activity_label = self.get_labels_from_csv(row_id)

        data = torch.tensor(data, dtype=torch.float32)
        identity_label = torch.tensor(identity_label, dtype=torch.float32)
        location_label = torch.tensor(location_label, dtype=torch.float32)
        activity_label = torch.tensor(activity_label, dtype=torch.float32)

        return data, identity_label, location_label, activity_label

    def get_labels_from_csv(self, row_id):
        row = self.labels[self.labels['label'] == row_id]
        if row.empty:
            return [], [], []

        identity_label = np.zeros(self.num_users)
        location_label = np.zeros((self.num_users, self.num_locations))
        activity_label = np.zeros((self.num_users, self.num_activities))

        for i in range(1, 7):
            user_location = row[f'user_{i}_location'].values[0]
            user_activity = row[f'user_{i}_activity'].values[0]

            if user_location and user_activity:
                identity_label[i - 1] = 1

                location_idx = self.location_map.get(user_location, None)
                if location_idx is not None:
                    location_label[i - 1, location_idx] = 1

                activity_idx = self.activity_map.get(user_activity, None)
                if activity_idx is not None:
                    activity_label[i - 1, activity_idx] = 1

        return identity_label, location_label, activity_label


def get_dataloaders(dataset, batch_size, train_ratio=0.7, eval_ratio=0.1):
    total_size = len(dataset)
    train_size = int(total_size * train_ratio)
    eval_size = int(total_size * eval_ratio)
    test_size = total_size - train_size - eval_size

    train_dataset, eval_dataset, test_dataset = random_split(dataset, [train_size, eval_size, test_size])

    return (DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
            DataLoader(eval_dataset, batch_size=batch_size, shuffle=False),
            DataLoader(test_dataset, batch_size=batch_size, shuffle=False))


if __name__ == "__main__":
    dataset = WiMANS(root_path=r'E:\WorkSpace\WiMANS\dataset')
    train_loader, eval_loader, test_loader = get_dataloaders(dataset, batch_size=32)

    for batch_idx, (data, identity_label, location_label, activity_label) in enumerate(train_loader):
        print(f"Train Batch {batch_idx + 1}")
        print("Data Shape:", data.shape)
        print("Identity Label Shape:", identity_label.shape)
        print("Location Label Shape:", location_label.shape)
        print("Activity Label Shape:", activity_label.shape)
        print("-------------------------------------------------")

    for batch_idx, (data, identity_label, location_label, activity_label) in enumerate(eval_loader):
        print(f"Train Batch {batch_idx + 1}")
        print("Data Shape:", data.shape)
        print("Identity Label Shape:", identity_label.shape)
        print("Location Label Shape:", location_label.shape)
        print("Activity Label Shape:", activity_label.shape)
        print("-------------------------------------------------")

    for batch_idx, (data, identity_label, location_label, activity_label) in enumerate(test_loader):
        print(f"Train Batch {batch_idx + 1}")
        print("Data Shape:", data.shape)
        print("Identity Label Shape:", identity_label.shape)
        print("Location Label Shape:", location_label.shape)
        print("Activity Label Shape:", activity_label.shape)
        print("-------------------------------------------------")

Train Batch 1
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 2
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 3
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 4
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 5
Data Shape: torch.Size

Train Batch 35
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 36
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 37
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 38
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 39
Data Shape: torch

Train Batch 69
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 70
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 71
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 72
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 73
Data Shape: torch

Train Batch 103
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 104
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 105
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 106
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 107
Data Shape: 

Train Batch 137
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 138
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 139
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 140
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 141
Data Shape: 

Train Batch 171
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 172
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 173
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 174
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 175
Data Shape: 

Train Batch 205
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 206
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 207
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 208
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 209
Data Shape: 

Train Batch 239
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 240
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 241
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 242
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 243
Data Shape: 

Train Batch 27
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 28
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 29
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 30
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 31
Data Shape: torch

Train Batch 26
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 27
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 28
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 29
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 30
Data Shape: torch

Train Batch 61
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 62
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 63
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 64
Data Shape: torch.Size([32, 3000, 3, 3, 30])
Identity Label Shape: torch.Size([32, 6])
Location Label Shape: torch.Size([32, 6, 5])
Activity Label Shape: torch.Size([32, 6, 9])
-------------------------------------------------
Train Batch 65
Data Shape: torch

In [23]:
import torch
import torch.nn as nn

batch_size = 8
num_users = 6

# 假设模型输出的 logits 形状是 (batch_size, num_users)
logits = torch.randn(batch_size, num_users)  # 模型的原始输出（logits），未经过激活函数
print(logits)
labels = torch.randint(0, 2, (batch_size, num_users))  # ground truth 标签，表示用户的身份（0~5）
print(labels)
labels = torch.argmax(labels, dim=1)
print(labels)
# 使用 CrossEntropyLoss 计算损失
criterion = nn.CrossEntropyLoss()

# 计算损失
loss = criterion(logits, labels)

tensor([[ 0.3412, -0.6244,  0.3549, -1.3374,  0.9282, -0.4700],
        [-0.2181,  0.0880, -0.9678, -0.4414, -1.0896,  1.0199],
        [-1.1501, -0.5552,  0.8811,  0.3655,  1.0331, -1.1909],
        [ 0.8643,  1.3370, -1.0332,  0.5187,  0.4195, -1.1986],
        [-0.2446, -0.1901,  0.2082, -0.7009, -0.2874, -1.0318],
        [-0.2696, -0.1142, -0.5701, -0.8782, -1.5017,  1.9910],
        [-2.0685,  0.7801, -0.5219, -0.5015,  0.4679, -0.7593],
        [ 0.6090, -0.3824,  0.9190, -1.0831,  0.9219, -0.0073]])
tensor([[0, 0, 1, 0, 0, 0],
        [0, 0, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0],
        [0, 0, 1, 1, 1, 0],
        [0, 0, 1, 1, 0, 1],
        [1, 1, 1, 1, 1, 0],
        [1, 0, 1, 1, 1, 1],
        [1, 0, 0, 1, 0, 0]])
tensor([2, 2, 0, 2, 2, 0, 0, 0])


In [24]:
logits, labels, loss

(tensor([[ 0.3412, -0.6244,  0.3549, -1.3374,  0.9282, -0.4700],
         [-0.2181,  0.0880, -0.9678, -0.4414, -1.0896,  1.0199],
         [-1.1501, -0.5552,  0.8811,  0.3655,  1.0331, -1.1909],
         [ 0.8643,  1.3370, -1.0332,  0.5187,  0.4195, -1.1986],
         [-0.2446, -0.1901,  0.2082, -0.7009, -0.2874, -1.0318],
         [-0.2696, -0.1142, -0.5701, -0.8782, -1.5017,  1.9910],
         [-2.0685,  0.7801, -0.5219, -0.5015,  0.4679, -0.7593],
         [ 0.6090, -0.3824,  0.9190, -1.0831,  0.9219, -0.0073]]),
 tensor([2, 2, 0, 2, 2, 0, 0, 0]),
 tensor(2.5137))

In [25]:
import torch
import torch.nn as nn

# 假设模型输出 logits 的形状是 (batch_size, num_users)
logits = torch.randn(batch_size, num_users)  # 模型输出的 logits（可能不在 0-1 之间）
labels = torch.randint(0, 2, (batch_size, num_users))  # ground truth 标签，0 或 1

# 使用 BCEWithLogitsLoss 计算损失
criterion = nn.BCEWithLogitsLoss()

# 计算损失
loss = criterion(logits, labels.float())

print(loss)


tensor(0.9509)


In [26]:
logits

tensor([[ 0.1039, -0.4053, -0.8316, -0.3469, -0.4511,  1.3592],
        [ 1.5831,  1.5689, -1.0742,  1.5420,  0.3148,  0.9685],
        [ 0.5031,  0.2837,  0.0540, -1.2283, -0.0470, -2.3765],
        [ 0.0513, -0.7965, -0.4782, -0.1902, -1.5142,  0.3918],
        [ 0.8879,  0.6216,  0.2492,  2.0380,  1.3445,  1.3542],
        [-0.1635, -0.1987, -0.6491, -1.2373, -0.9767, -0.9675],
        [-0.4281, -2.8829,  0.1470,  0.7811, -0.1360,  0.1206],
        [ 0.2627,  1.2423,  0.7121, -1.0741,  0.0800, -0.5909]])

In [27]:
labels

tensor([[1, 0, 1, 0, 0, 0],
        [0, 0, 1, 0, 1, 0],
        [1, 1, 0, 1, 0, 1],
        [1, 1, 0, 1, 1, 0],
        [1, 0, 1, 0, 1, 1],
        [0, 0, 1, 1, 1, 1],
        [1, 0, 0, 1, 1, 0],
        [0, 0, 0, 0, 0, 1]])

In [33]:
import torch
import torch.nn as nn

# 假设模型输出 logits 的形状是 (batch_size, num_users)
logits = torch.randn(batch_size, num_users)  # 模型输出的 logits（可能不在 0-1 之间）
print(logits)
logits = torch.sigmoid(logits)
print(logits)
logits = (probabilities > 0.5).float()
print(logits)

labels = torch.randint(0, 2, (batch_size, num_users))  # ground truth 标签，0 或 1
print(labels)
# 使用 BCEWithLogitsLoss 计算损失
criterion = nn.BCEWithLogitsLoss()

loss = criterion(logits, labels.float())

print(loss)


tensor([[ 0.8509,  0.3682,  1.4830, -0.2735, -0.5762, -0.0954],
        [-1.4353,  1.4712,  1.6778,  0.4426,  0.4629, -0.8603],
        [ 0.0739,  2.0758, -0.4440,  1.2264, -0.4347,  0.1144],
        [-0.6774, -0.3859, -0.1144,  2.9554, -0.0995,  0.4561],
        [-0.3645,  0.5023, -0.3670, -0.5710, -1.5720, -0.4265],
        [-0.4782, -1.1209, -0.1772, -0.1136, -0.5844, -0.4027],
        [ 0.1939,  0.3016,  2.0066, -0.1697, -1.6585,  1.6962],
        [ 0.3783, -0.1988, -1.0931, -0.0185, -1.5663, -1.2898]])
tensor([[0.7008, 0.5910, 0.8150, 0.4320, 0.3598, 0.4762],
        [0.1923, 0.8132, 0.8426, 0.6089, 0.6137, 0.2973],
        [0.5185, 0.8885, 0.3908, 0.7732, 0.3930, 0.5286],
        [0.3368, 0.4047, 0.4714, 0.9505, 0.4752, 0.6121],
        [0.4099, 0.6230, 0.4093, 0.3610, 0.1719, 0.3950],
        [0.3827, 0.2458, 0.4558, 0.4716, 0.3579, 0.4007],
        [0.5483, 0.5748, 0.8815, 0.4577, 0.1600, 0.8450],
        [0.5935, 0.4505, 0.2510, 0.4954, 0.1727, 0.2159]])
tensor([[0., 0., 0., 0

In [34]:
import torch
import torch.nn as nn

# 假设模型输出 logits 的形状是 (batch_size, num_users)
logits = torch.randn(batch_size, num_users)  # 模型输出的 logits（可能不在 0-1 之间）
print(logits)
logits = (torch.sigmoid(logits) > 0.5).float()
print(logits)

labels = torch.randint(0, 2, (batch_size, num_users))  # ground truth 标签，0 或 1
print(labels)
# 使用 BCEWithLogitsLoss 计算损失
criterion = nn.BCEWithLogitsLoss()

loss = criterion(logits, labels.float())

print(loss)


tensor([[-0.0971,  2.0298, -0.1484,  1.8436,  0.3407,  1.1495],
        [-0.7953, -0.0558,  0.5740,  0.9545,  1.7190, -0.1782],
        [ 0.1416,  0.7904, -0.5927,  0.1953,  1.4681,  0.3882],
        [ 0.8078, -1.0867,  0.1647,  1.0118,  0.5411, -2.1195],
        [ 0.9472,  0.9597,  0.9431, -1.8169,  0.1385,  1.4893],
        [-1.7423,  1.2083, -1.1557,  0.1703, -0.8213, -1.2719],
        [-0.1912,  0.6732, -1.4709,  0.1802,  1.1290, -1.3650],
        [-0.2733,  0.0190, -0.6806, -1.3125, -0.2110, -0.8024]])
tensor([[0., 1., 0., 1., 1., 1.],
        [0., 0., 1., 1., 1., 0.],
        [1., 1., 0., 1., 1., 1.],
        [1., 0., 1., 1., 1., 0.],
        [1., 1., 1., 0., 1., 1.],
        [0., 1., 0., 1., 0., 0.],
        [0., 1., 0., 1., 1., 0.],
        [0., 1., 0., 0., 0., 0.]])
tensor([[0, 0, 1, 0, 0, 1],
        [1, 1, 0, 0, 1, 0],
        [1, 1, 1, 0, 1, 1],
        [1, 0, 0, 1, 0, 0],
        [0, 1, 0, 1, 1, 1],
        [1, 0, 1, 0, 1, 0],
        [0, 0, 0, 0, 0, 1],
        [1, 0, 0, 

In [37]:
import torch
import torch.nn as nn

# 假设模型输出 logits 的形状是 (batch_size, num_users)
logits = torch.randn(batch_size, num_users)  # 模型输出的 logits（可能不在 0-1 之间）
print(logits)
logits = (torch.sigmoid(logits) > 0.5).float()
print(logits)

labels = torch.randint(0, 2, (batch_size, num_users)).float()  # ground truth 标签，0 或 1
print(labels)


tensor([[ 0.9608,  1.1219, -0.6967, -0.3085, -1.1063, -0.2366],
        [-1.1130, -0.8550,  1.2127, -0.8666,  0.9710, -0.1635],
        [ 0.4302,  1.0992, -0.6155, -0.4272,  1.1003,  0.0785],
        [-0.8915, -0.0275,  1.2975,  0.2108,  1.5452,  0.0120],
        [-1.1062,  0.7565, -0.4452, -0.3663,  1.4568, -1.3323],
        [-0.1480,  0.2812,  0.8445, -1.0763,  1.5816, -0.4531],
        [ 0.7678, -0.2042,  1.1853,  0.3716,  0.7451,  0.9315],
        [ 1.1945, -0.6962, -1.5167, -0.6441, -0.2280,  0.8695]])
tensor([[1., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 1., 0.],
        [1., 1., 0., 0., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 1., 0., 0., 1., 0.],
        [0., 1., 1., 0., 1., 0.],
        [1., 0., 1., 1., 1., 1.],
        [1., 0., 0., 0., 0., 1.]])
tensor([[0., 0., 1., 1., 0., 0.],
        [0., 0., 1., 0., 0., 1.],
        [0., 0., 0., 1., 0., 1.],
        [0., 0., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 0., 0., 1., 0., 1.],
        [0

In [38]:
logits == labels

tensor([[False, False, False, False,  True,  True],
        [ True,  True,  True,  True, False, False],
        [False, False,  True, False, False,  True],
        [ True,  True,  True, False, False, False],
        [False,  True, False, False,  True,  True],
        [False, False, False, False, False, False],
        [False, False, False, False,  True,  True],
        [False, False,  True,  True, False,  True]])

In [40]:
(logits == labels).float()

tensor([[0., 0., 0., 0., 1., 1.],
        [1., 1., 1., 1., 0., 0.],
        [0., 0., 1., 0., 0., 1.],
        [1., 1., 1., 0., 0., 0.],
        [0., 1., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 1., 1., 0., 1.]])

In [44]:
(logits == labels).float().mean().item()

0.3958333432674408