In [5]:
from pyspark.sql import SparkSession
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd

# 创建 SparkSession
spark = SparkSession.builder \
    .appName("PySpark Dataloader") \
    .getOrCreate()

# 读取数据
df_spark = spark.read.csv("test_pyspark.csv", header=True, inferSchema=True)

# 假设我们需要这些特征作为模型输入：feature1, feature2, ... , featureN
# 和目标变量：target
df_spark = df_spark.select("feature1", "feature2", "featureN", "target")

# 展示预处理后的数据
df_spark.show(5)

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/02/26 10:30:47 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/02/26 10:30:47 WARN Utils: Service 'sparkDriver' could not bind on a random free port. You may check whether configuring an appropriate binding address.
25/02/26 10:30:47 WARN Utils: Service 'sparkDriver' could not bind on a random free port. You may check whether configuring an appropriate binding address.
25/02/26 10:30:47 WARN Utils: Service 'sparkDriver' could not bind on a random free port. You may check whether configuring an appropriate binding address.
25/02/26 10:30:47 WARN Utils: Service 'sparkDriver' could not bind on a random free port. You may check whether configuring an appropriate binding address.
25/02/26 10:30:47 WARN Utils: Service 'sparkDriver' could not bind on a random free port. You may check

Py4JJavaError: An error occurred while calling None.org.apache.spark.api.java.JavaSparkContext.
: java.net.BindException: Can't assign requested address: Service 'sparkDriver' failed after 16 retries (on a random free port)! Consider explicitly setting the appropriate binding address for the service 'sparkDriver' (for example spark.driver.bindAddress for SparkDriver) to the correct binding address.
	at java.base/sun.nio.ch.Net.bind0(Native Method)
	at java.base/sun.nio.ch.Net.bind(Net.java:459)
	at java.base/sun.nio.ch.Net.bind(Net.java:448)
	at java.base/sun.nio.ch.ServerSocketChannelImpl.bind(ServerSocketChannelImpl.java:227)
	at io.netty.channel.socket.nio.NioServerSocketChannel.doBind(NioServerSocketChannel.java:141)
	at io.netty.channel.AbstractChannel$AbstractUnsafe.bind(AbstractChannel.java:562)
	at io.netty.channel.DefaultChannelPipeline$HeadContext.bind(DefaultChannelPipeline.java:1334)
	at io.netty.channel.AbstractChannelHandlerContext.invokeBind(AbstractChannelHandlerContext.java:600)
	at io.netty.channel.AbstractChannelHandlerContext.bind(AbstractChannelHandlerContext.java:579)
	at io.netty.channel.DefaultChannelPipeline.bind(DefaultChannelPipeline.java:973)
	at io.netty.channel.AbstractChannel.bind(AbstractChannel.java:260)
	at io.netty.bootstrap.AbstractBootstrap$2.run(AbstractBootstrap.java:356)
	at io.netty.util.concurrent.AbstractEventExecutor.runTask(AbstractEventExecutor.java:174)
	at io.netty.util.concurrent.AbstractEventExecutor.safeExecute(AbstractEventExecutor.java:167)
	at io.netty.util.concurrent.SingleThreadEventExecutor.runAllTasks(SingleThreadEventExecutor.java:470)
	at io.netty.channel.nio.NioEventLoop.run(NioEventLoop.java:569)
	at io.netty.util.concurrent.SingleThreadEventExecutor$4.run(SingleThreadEventExecutor.java:997)
	at io.netty.util.internal.ThreadExecutorMap$2.run(ThreadExecutorMap.java:74)
	at io.netty.util.concurrent.FastThreadLocalRunnable.run(FastThreadLocalRunnable.java:30)
	at java.base/java.lang.Thread.run(Thread.java:829)


In [7]:
import numpy as np
import pandas as pd

data = np.random.randn(100, 4)
df = pd.DataFrame(data, dtype=np.float32, columns=['feature1', 'feature2', 'featureN', 'target'])
df.to_csv('test_pyspark.csv')

In [8]:
# 将 Spark DataFrame 转换为 Pandas DataFrame
#df_pandas = df_spark.toPandas()
df_pandas = df

# 将特征和标签分别提取
X = df_pandas[["feature1", "feature2", "featureN"]].values
y = df_pandas["target"].values

# 将特征和标签转换为 PyTorch 张量
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32)


In [9]:
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# 创建数据集对象
dataset = CustomDataset(X_tensor, y_tensor)

In [10]:
# 定义 DataLoader
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# 打印批次数据
for batch_X, batch_y in dataloader:
    print(batch_X.shape, batch_y.shape)


torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])


In [12]:
import torch.nn as nn
import torch.optim as optim

# 定义一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self, input_dim):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(input_dim, 1)
    
    def forward(self, x):
        return self.fc(x)

# 实例化模型
model = SimpleModel(input_dim=X_tensor.shape[1])

# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 训练模型
num_epochs = 10
for epoch in range(num_epochs):
    for batch_X, batch_y in dataloader:
        print(batch_X.shape, batch_y.shape)
        # 前向传播
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y.unsqueeze(1))
        
        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')


torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [1/10], Loss: 1.6465
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [2/10], Loss: 1.4947
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [3/10], Loss: 1.9935
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [4/10], Loss: 1.2903
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [5/10], Loss: 1.5641
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [6/10], Loss: 1.0094
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [7/10], Loss: 1.4722
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [8/10], Loss: 1.8530
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [9/10], Loss: 1.1156
torch.Size([64, 3]) torch.Size([64])
torch.Size([36, 3]) torch.Size([36])
Epoch [10/10], Lo