In [None]:
# Restart the runtime: send signal 9 to the current Python process so the
# kernel is killed (presumably the notebook host then starts a fresh one).
import os
os.kill(os.getpid(), 9)

In [1]:
# Mount Google Drive at /content/drive so the project files can be read and written.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import requests
from io import BytesIO
from PIL import Image
import torch
import torch.nn as nn
from torchvision import models
import torchvision.transforms as transforms
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType,StringType

In [3]:
# 1. Initialise Spark with tuned configuration.
# NOTE(review): the master is local[*]; confirm that the executor memory/cores
# settings below actually take effect in local mode.
# spark.stop()
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("MoviePosterFeatureExtraction")
    .config("spark.executor.memory", "8g")
    .config("spark.executor.cores", "4")
    .config("spark.sql.execution.arrow.pyspark.enabled", "true")
    .config("spark.python.profile", "true")
    .config("spark.executor.pyspark.python", "python3")
    .config("spark.network.timeout", "800s")
    .getOrCreate()
)

In [4]:
# 1. Load the data: read the movies CSV, taking column names from the header
#    row and letting Spark infer the column types.
data_path = "./drive/MyDrive/CS5344_AY2425Sem2_Project/processed_movies_count_10_100.csv"
df = spark.read.csv(data_path, header=True, inferSchema=True)

In [5]:
# Repartition the DataFrame into 4 partitions.
df=df.repartition(4)

In [6]:
# Display the number of partitions backing the DataFrame.
df.rdd.getNumPartitions()

4

In [7]:
# Preview the first 5 rows without truncating long column values.
df.show(5, truncate=False)

+------+------+------------------+----------+------------+----------+-------+-------+------+------------+-------------+--------------------+---------------------------------------------------------------+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|row_id|id    |title             |popularity|vote_average|vote_count|revenue|runtime|budget|release_year|release_month|original_la

In [8]:
# 2. Load the pretrained backbone.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)  # uses `weights=` rather than `pretrained`
model.eval()
feature_extractor = nn.Sequential(*list(model.children())[:-1])  # drop the final classification layer
feature_extractor.eval()  # make inference mode explicit on the wrapper as well

# Fully connected layer projecting the 2048-dim backbone output down to 256 dims.
# The layer is never trained, so its weights are whatever the random
# initialisation produces. Seed the global torch RNG first so the projection
# (and therefore every exported feature vector) is identical across kernel
# restarts instead of changing on each run.
torch.manual_seed(42)
fc_layer = nn.Linear(2048, 256)
fc_layer.eval()

# 3. Image preprocessing: resize, centre-crop to 224x224, convert to a tensor
#    and normalise each channel with the mean/std constants below.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 205MB/s]


In [9]:
# 5️⃣ **定义下载 & 特征提取函数**
def extract_features(image_url):
    """Download an image and return its reduced feature vector as a string.

    The image is fetched over HTTP, converted to RGB, preprocessed with the
    notebook-level ``transform``, passed through ``feature_extractor`` and
    then ``fc_layer``; the resulting values are joined with commas.

    Args:
        image_url: URL of the image to download. ``None`` or an empty string
            is treated as "no image" and is not requested at all.

    Returns:
        A comma-separated string of feature values on success, or ``None``
        when the URL is missing or every attempt failed. ``None`` (rather
        than a list) keeps the result compatible with a ``StringType`` UDF,
        where it becomes a NULL value.
    """
    # Nothing to download: skip the retry loop entirely.
    if not image_url:
        return None

    max_retries = 3  # try at most 3 times
    for _ in range(max_retries):
        try:
            # Download the image.
            response = requests.get(image_url, timeout=10)
            if response.status_code != 200:
                # Report the failure instead of retrying silently.
                print(f"Unexpected HTTP status {response.status_code} for {image_url}")
                continue

            image = Image.open(BytesIO(response.content)).convert('RGB')
            batch = transform(image).unsqueeze(0)  # add the batch dimension

            # Extract features and reduce their dimensionality without
            # tracking gradients.
            with torch.no_grad():
                features = feature_extractor(batch)
                features = features.view(features.size(0), -1)  # flatten to (batch_size, n_features)
                reduced_features = fc_layer(features)
                reduced_features_list = reduced_features.squeeze(0).numpy().tolist()
            return ",".join(map(str, reduced_features_list))

        except Exception as e:
            # Best effort: log the problem and fall through to the next attempt.
            print(f"Error extracting features: {e}")

    # Every attempt failed: signal "no features" with None.
    return None

In [10]:
# 5. Wrap extract_features as a PySpark UDF (user-defined function) whose
#    declared return type is a string column.
extract_features_udf = udf(extract_features, StringType())

In [11]:
# 6. Add the `poster_features` column by applying the UDF to `poster_path`.
df = df.withColumn("poster_features", extract_features_udf(df.poster_path))

In [12]:
# Preview the computed feature strings for 5 rows without truncation.
df.select("poster_features").show(5, truncate=False)

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

In [31]:
 #num_rows = df.count()  # 获取总行数

In [39]:
# Disabled experiment: the whole cell body is a string literal, so none of the
# code inside it runs. It sketches a batched CSV export that filters `df` by
# `row_id` ranges of `batch_size` rows, shows each batch, appends it to
# `output_path` and prints a progress line.
# NOTE(review): each batch calls show(), write and collect() on the same
# filtered DataFrame; if `df` is not cached, each of those presumably
# re-evaluates the poster feature UDF — confirm before re-enabling.
'''
from pyspark.sql.functions import monotonically_increasing_id

def write_with_progress(df, output_path, num_rows = 277674 , batch_size=5):
    total_batches = (num_rows // batch_size) + (1 if num_rows % batch_size != 0 else 0)  # 计算批次总数
    print(total_batches)

    for batch_num in range(total_batches):
        # 计算当前批次的起始和结束索引
        start_index = batch_num * batch_size
        end_index = (batch_num + 1) * batch_size

        # 过滤出当前批次的数据
        df_batch = df.filter((df.row_id >= start_index) & (df.row_id < end_index))
        df_batch.show()

        # 写入数据
        df_batch.write.option("header", True).csv(output_path, mode="append")

        # 输出进度
        print(f"处理了 {batch_num * batch_size + len(df_batch.collect())} / {num_rows} 行数据")

# 调用函数
output_path = "./drive/MyDrive/CS5344_AY2425Sem2_Project/processed_movies_test"
write_with_progress(df,output_path)
'''

55535


ERROR:root:KeyboardInterrupt while sending command.
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.11/dist-packages/py4j/clientserver.py", line 511, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
                          ^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/socket.py", line 718, in readinto
    return self._sock.recv_into(b)
           ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


KeyboardInterrupt: 

In [17]:
# Trial run: export only the first 5 rows, replacing any previous output at
# the target path and writing a header row.
output_path = "./drive/MyDrive/CS5344_AY2425Sem2_Project/processed_movies_with_features"
df_batch = df.limit(5)
df_batch.write.mode("overwrite").option("header", True).csv(output_path)

In [13]:
# Export the full DataFrame as CSV with a header row, replacing any previous
# output at the target path, then report where it was saved.
output_path = "./drive/MyDrive/CS5344_AY2425Sem2_Project/processed_movies_with_features10_100"
df.write.mode("overwrite").option("header", True).csv(output_path)

print(f"数据处理完成，已保存到 {output_path}")

数据处理完成，已保存到 ./drive/MyDrive/CS5344_AY2425Sem2_Project/processed_movies_with_features10_100


In [21]:
# Shell out to `nproc` to print the number of processing units available to this runtime.
!nproc

2
