In [1]:
import torch
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
import time
import os
import joblib
import torch.nn.functional as F

# 设置设备：支持 MPS（Apple Silicon）或 CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print("Using device:", device)

# 创建模型保存目录
os.makedirs("model", exist_ok=True)

Using device: mps


In [2]:
class RMSLELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        y_pred = torch.clamp(y_pred, min=0)
        log_pred = torch.log1p(y_pred)
        log_true = torch.log1p(y_true)
        return torch.sqrt(torch.mean((log_pred - log_true) ** 2))

In [3]:
def create_dataset(file_path, save_scaler_path='model/scaler.pkl', test_size=0.2, random_state=42):
    df = pd.read_csv(file_path).dropna()
    df = df.drop(columns=['id'])
    df['Sex'] = df['Sex'].map({'male': 1, 'female': 0})

    x = df.iloc[:, :-1]
    y = df.iloc[:, -1]

    scaler = MinMaxScaler()
    x_scaled = scaler.fit_transform(x)
    joblib.dump(scaler, save_scaler_path)

    x_train, x_valid, y_train, y_valid = train_test_split(
        x_scaled, y, test_size=test_size, random_state=random_state
    )

    x_train = x_train.astype(np.float32)
    y_train = y_train.astype(np.float32)
    x_valid = x_valid.astype(np.float32)
    y_valid = y_valid.astype(np.float32)

    train_dataset = TensorDataset(torch.tensor(x_train), torch.tensor(y_train.values))
    valid_dataset = TensorDataset(torch.tensor(x_valid), torch.tensor(y_valid.values))

    return train_dataset, valid_dataset, x.shape[1]


In [None]:
class Model(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.input = nn.Linear(dim, 256)
        self.bn_input = nn.BatchNorm1d(256)

        # 残差线性升维
        self.res_proj1 = nn.Linear(256, 512)
        self.res_proj2 = nn.Linear(512, 512)
        self.res_proj3 = nn.Linear(512, 256)

        # Block 1: 256 → 512
        self.block1 = nn.Sequential(
            nn.Linear(256, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3),
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3)
        )

        # Block 2: 512 → 512
        self.block2 = nn.Sequential(
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3),
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3)
        )

        # Block 3: 512 → 512
        self.block3 = nn.Sequential(
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3),
            nn.Linear(512, 512), nn.BatchNorm1d(512), nn.SiLU(), nn.Dropout(0.3)
        )

        # Block 4: 512 → 256
        self.block4 = nn.Sequential(
            nn.Linear(512, 256), nn.BatchNorm1d(256), nn.SiLU(), nn.Dropout(0.3),
            nn.Linear(256, 256), nn.BatchNorm1d(256), nn.SiLU(), nn.Dropout(0.3)
        )

        # Block 5: 256 → 128
        self.block5 = nn.Sequential(
            nn.Linear(256, 128), nn.BatchNorm1d(128), nn.SiLU(), nn.Dropout(0.2),
            nn.Linear(128, 128), nn.BatchNorm1d(128), nn.SiLU(), nn.Dropout(0.2)
        )

        # Block 6: 128 → 64
        self.block6 = nn.Sequential(
            nn.Linear(128, 64), nn.BatchNorm1d(64), nn.SiLU(), nn.Dropout(0.2)
        )

        # Output block
        self.output_block = nn.Sequential(
            nn.Linear(64, 32), nn.BatchNorm1d(32), nn.SiLU(), nn.Dropout(0.1),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        x = F.silu(self.bn_input(self.input(x)))

        res1 = self.res_proj1(x)
        x = self.block1(x) + res1

        res2 = self.res_proj2(x)
        x = self.block2(x) + res2

        res3 = x
        x = self.block3(x) + res3

        x = self.block4(x) + self.res_proj3(res3)

        x = self.block5(x)
        x = self.block6(x)
        x = self.output_block(x)
        return x


In [5]:
def train(train_dataset, valid_dataset, dim, num_epoch):
    train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=0)
    valid_loader = DataLoader(valid_dataset, batch_size=512, shuffle=False, num_workers=0)

    model = Model(dim).to(device)
    criterion = RMSLELoss()
    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=15)

    train_losses = []
    valid_losses = []

    for epoch in range(num_epoch):
        model.train()
        total_train_loss = 0.0
        train_batches = 0
        start_time = time.time()

        for x, y in train_loader:
            x, y = x.to(device), y.to(device).unsqueeze(1)
            pred = model(x)
            loss = criterion(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()
            train_batches += 1

        avg_train_loss = total_train_loss / train_batches
        train_losses.append(avg_train_loss)

        # ===== 验证阶段 =====
        model.eval()
        total_valid_loss = 0.0
        valid_batches = 0
        with torch.no_grad():
            for x, y in valid_loader:
                x, y = x.to(device), y.to(device).unsqueeze(1)
                pred = model(x)
                loss = criterion(pred, y)
                total_valid_loss += loss.item()
                valid_batches += 1

        avg_valid_loss = total_valid_loss / valid_batches
        valid_losses.append(avg_valid_loss)

        scheduler.step(avg_valid_loss)
        print(f"Epoch {epoch+1}, Train Loss: {avg_train_loss:.4f}, Valid Loss: {avg_valid_loss:.4f}, Time: {time.time() - start_time:.2f}s")

    # 保存模型
    torch.save(model.state_dict(), 'model/model.pth')
    print("模型已保存到 model/model.pth")

    # 绘图
    plt.figure(figsize=(10, 6))
    plt.plot(range(1, num_epoch + 1), train_losses, label='Train Loss', marker='o')
    plt.plot(range(1, num_epoch + 1), valid_losses, label='Valid Loss', marker='s')
    plt.title('Training & Validation Loss Curve')
    plt.xlabel('Epoch')
    plt.ylabel('RMSLE Loss')
    plt.legend()
    plt.grid(True)
    plt.savefig('model/loss_curve.png')
    plt.show()

# %%
def test(test_dataset, dim):
    dataloader = DataLoader(
        test_dataset,
        batch_size=512,
        shuffle=False,
        num_workers=0
    )
    model = Model(dim).to(device)
    model.load_state_dict(torch.load('model/model.pth'))
    model.eval()

    predictions = []
    with torch.no_grad():
        for x, _ in dataloader:
            x = x.to(device)
            y_pred = model(x)
            predictions.extend(y_pred.cpu().squeeze().tolist())

    print("测试集预测完成，前10个结果：", predictions[:10])
    return predictions


In [6]:
if __name__ == '__main__':
    train_dataset, valid_dataset, dim = create_dataset('data/train.csv')
    # test_dataset, _ = load_test_dataset('data/test.csv')

    train(train_dataset, valid_dataset, dim, num_epoch=100)
    test(valid_dataset, dim)


Epoch 1, Train Loss: 1.9719, Valid Loss: 0.9466, Time: 28.41s
Epoch 2, Train Loss: 0.5394, Valid Loss: 0.1815, Time: 27.27s
Epoch 3, Train Loss: 0.1955, Valid Loss: 0.1009, Time: 27.04s
Epoch 4, Train Loss: 0.1648, Valid Loss: 0.0752, Time: 25.44s
Epoch 5, Train Loss: 0.1573, Valid Loss: 0.0829, Time: 26.75s
Epoch 6, Train Loss: 0.1528, Valid Loss: 0.0701, Time: 28.14s
Epoch 7, Train Loss: 0.1501, Valid Loss: 0.0714, Time: 27.37s
Epoch 8, Train Loss: 0.1462, Valid Loss: 0.0634, Time: 25.01s
Epoch 9, Train Loss: 0.1440, Valid Loss: 0.0661, Time: 27.56s
Epoch 10, Train Loss: 0.1420, Valid Loss: 0.0744, Time: 27.79s
Epoch 11, Train Loss: 0.1407, Valid Loss: 0.0780, Time: 29.23s
Epoch 12, Train Loss: 0.1391, Valid Loss: 0.0645, Time: 25.80s
Epoch 13, Train Loss: 0.1383, Valid Loss: 0.0735, Time: 25.88s
Epoch 14, Train Loss: 0.1369, Valid Loss: 0.0642, Time: 25.88s
Epoch 15, Train Loss: 0.1355, Valid Loss: 0.0667, Time: 24.70s
Epoch 16, Train Loss: 0.1345, Valid Loss: 0.0631, Time: 24.98s
E

KeyboardInterrupt: 