<a href="https://colab.research.google.com/github/pomeloblue/GPT-SoVITS/blob/main/%E5%9B%BE%E6%96%87%E6%A3%80%E7%B4%A2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import pandas as pd
import torch
from PIL import Image
from transformers import BertTokenizer, BertModel
from torchvision import transforms, models
from tqdm import tqdm
import time
from google.colab import files

# GPU加速
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# 检查文件
print("Current directory contents:")
print(os.listdir())

Using device: cuda
Current directory contents:
['.config', 'sample_data']


1: 导入必要的库和设置

In [2]:
def check_and_upload_file(filename):
    if not os.path.exists(filename):
        print(f"{filename} not found. Please upload it.")
        uploaded = files.upload()
        if filename not in uploaded:
            print(f"Failed to upload {filename}. Please try again.")
            return False
    return True

# 检查并上传必要的文件
required_files = ['train.csv', 'test_query.csv']
for file in required_files:
    if not check_and_upload_file(file):
        raise Exception(f"Failed to upload {file}")

# 检查图像文件夹
image_folder = 'image'
if not os.path.exists(image_folder):
    print(f"'{image_folder}' folder not found. Please upload a zip file containing the image folder.")
    uploaded = files.upload()
    if 'image.zip' in uploaded:
        !unzip -q image.zip
        print("Image folder extracted.")
    else:
        raise Exception("Failed to upload image.zip")

# 加载数据
train_df = pd.read_csv('train.csv', sep='\t')
test_df = pd.read_csv('test_query.csv')
print("Train data sample:")
print(train_df.head())
print("\nTrain data columns:")
print(train_df.columns)
print("\nTest data sample:")
print(test_df.head())
print("\nTest data columns:")
print(test_df.columns)

NameError: name 'os' is not defined

2:检查文件是否存在和上传函数

In [None]:
# 图像预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 文本预处理
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 文本特征提取
bert_model = BertModel.from_pretrained('bert-base-chinese').to(device)

def extract_text_features(text):
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = bert_model(**inputs)
    return outputs.last_hidden_state[:, 0, :]

# 图像特征提取
resnet = models.resnet50(pretrained=True).to(device)
resnet.fc = torch.nn.Identity()

def extract_image_features(image_path):
    try:
        img = Image.open(image_path).convert('RGB')
        img_tensor = transform(img).unsqueeze(0).to(device)
        with torch.no_grad():
            return resnet(img_tensor)
    except Exception as e:
        print(f"Error processing image {image_path}: {str(e)}")
        return None

class MultimodalModel(torch.nn.Module):
    def __init__(self, text_dim, img_dim, output_dim):
        super().__init__()
        self.text_projection = torch.nn.Linear(text_dim, output_dim)
        self.img_projection = torch.nn.Linear(img_dim, output_dim)

    def forward(self, text_features, img_features):
        text_output = self.text_projection(text_features) if text_features is not None else None
        img_output = self.img_projection(img_features) if img_features is not None else None
        return text_output, img_output

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/110k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/269k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/624 [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/412M [00:00<?, ?B/s]

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 183MB/s]


 3:模型和特征提取函数

In [None]:
def main():
    # 加载数据
    train_df = pd.read_csv('train.csv', sep='\t')
    test_df = pd.read_csv('test_query.csv')
    print("Train data sample:")
    print(train_df.head())
    print("\nTrain data columns:")
    print(train_df.columns)
    print("\nTest data sample:")
    print(test_df.head())
    print("\nTest data columns:")
    print(test_df.columns)

    model = MultimodalModel(768, 2048, 512).to(device)
    optimizer = torch.optim.Adam(model.parameters())
    criterion = torch.nn.CosineSimilarity()

    num_epochs = 15
    total_rows = len(train_df)

    start_time = time.time()

    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}")
        skipped_rows = 0
        processed_rows = 0
        epoch_loss = 0

        for index, row in tqdm(train_df.iterrows(), total=total_rows, desc=f"Epoch {epoch+1}"):
            title = str(row['title'])
            path = str(row['path'])

            text_features = extract_text_features(title)
            img_features = extract_image_features(os.path.join(image_folder, path))

            if img_features is None:
                skipped_rows += 1
                continue

            text_output, img_output = model(text_features, img_features)
            loss = 1 - criterion(text_output, img_output)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            processed_rows += 1

            if processed_rows % 1000 == 0:
                elapsed_time = time.time() - start_time
                print(f"Processed {processed_rows} rows in {elapsed_time:.2f} seconds")

        if processed_rows > 0:
            avg_loss = epoch_loss / processed_rows
            print(f"Epoch {epoch+1}/{num_epochs} completed")
            print(f"Average loss: {avg_loss:.4f}")
        else:
            print(f"Epoch {epoch+1}/{num_epochs} completed")
            print("No rows were processed in this epoch")
        print(f"Skipped rows in this epoch: {skipped_rows}")

    print("Training completed")
    print(f"Total time elapsed: {time.time() - start_time:.2f} seconds")

    return model, criterion, test_df

if __name__ == "__main__":
    trained_model, trained_criterion, test_df = main()

Train data sample:
                                title  \
0  云思木想2019秋装新款高领弹力前后打揽修身织带长袖T恤女36723   
1          闺蜜裤2019年春夏款减龄双杠棉质弹力打底裤外穿显瘦   
2  画瓷 2019夏季新款小清新chic一字肩上衣女超仙透明网纱蕾丝小衫   
3    帕波仕蒂海宁真皮衣男夹克水貂毛领绵羊皮羽绒服2019新款休闲外套   
4    实拍#大码女装胖mm春夏新款显瘦连衣裙 减龄赫本中长款复古小黑裙   

                                       path  
0  0ec90019-eea7-4285-adf2-a4a5bb85198b.jpg  
1  67257121-c77d-4529-9b21-23a04da0a8eb.jpg  
2  49a5f2bd-03a0-4a74-9c0a-e97849d31262.jpg  
3  307151b3-d488-4bc0-aea8-64c7630b291f.jpg  
4  e93fe9dd-a6fb-438c-b316-c5c70e7e59cc.jpg  

Train data columns:
Index(['title', 'path'], dtype='object')

Test data sample:
                                title
0    雪纺半身裙夏女2019新款韩版中长款裙子波点A字裙百褶裙薄款半裙
1     程茧儿 秋冬新款ZIM超仙流苏印花系带荷叶边性感拼接长袖连衣裙
2  福利款【格】2019夏新款毛边牛仔裙a字裙气质半身裙百搭短裙0393
3    妈妈装夏装套装小衫上衣服中老年女装阔太太大码两件套洋气40岁50
4  台装春夏新品 洪前气质印花背部经典条纹短袖夏季清凉T恤女86287Q

Test data columns:
Index(['title'], dtype='object')
Starting epoch 1/15


Epoch 1:  20%|██        | 1004/5000 [00:23<01:16, 52.07it/s]

Processed 1000 rows in 23.75 seconds


Epoch 1:  40%|████      | 2005/5000 [00:45<01:01, 48.41it/s]

Processed 2000 rows in 45.54 seconds


Epoch 1:  60%|██████    | 3005/5000 [01:06<00:39, 50.96it/s]

Processed 3000 rows in 66.14 seconds


Epoch 1:  80%|████████  | 4008/5000 [01:28<00:21, 46.43it/s]

Processed 4000 rows in 88.41 seconds


Epoch 1: 100%|██████████| 5000/5000 [01:51<00:00, 44.93it/s]


Processed 5000 rows in 111.30 seconds
Epoch 1/15 completed
Average loss: 0.0008
Skipped rows in this epoch: 0
Starting epoch 2/15


Epoch 2:  20%|██        | 1010/5000 [00:21<01:17, 51.49it/s]

Processed 1000 rows in 132.75 seconds


Epoch 2:  40%|████      | 2008/5000 [00:43<00:58, 51.45it/s]

Processed 2000 rows in 154.88 seconds


Epoch 2:  60%|██████    | 3008/5000 [01:04<00:39, 50.29it/s]

Processed 3000 rows in 175.85 seconds


Epoch 2:  80%|████████  | 4009/5000 [01:27<00:20, 48.85it/s]

Processed 4000 rows in 198.64 seconds


Epoch 2: 100%|██████████| 5000/5000 [01:50<00:00, 45.43it/s]


Processed 5000 rows in 221.37 seconds
Epoch 2/15 completed
Average loss: 0.0001
Skipped rows in this epoch: 0
Starting epoch 3/15


Epoch 3:  20%|██        | 1006/5000 [00:21<01:23, 47.66it/s]

Processed 1000 rows in 243.12 seconds


Epoch 3:  40%|████      | 2007/5000 [00:44<01:01, 48.69it/s]

Processed 2000 rows in 265.95 seconds


Epoch 3:  60%|██████    | 3005/5000 [01:07<01:02, 32.04it/s]

Processed 3000 rows in 288.52 seconds


Epoch 3:  80%|████████  | 4004/5000 [01:29<00:21, 46.46it/s]

Processed 4000 rows in 311.04 seconds


Epoch 3: 100%|██████████| 5000/5000 [01:52<00:00, 44.36it/s]


Processed 5000 rows in 334.11 seconds
Epoch 3/15 completed
Average loss: 0.0001
Skipped rows in this epoch: 0
Starting epoch 4/15


Epoch 4:  20%|██        | 1006/5000 [00:22<01:55, 34.71it/s]

Processed 1000 rows in 356.13 seconds


Epoch 4:  40%|████      | 2008/5000 [00:44<00:59, 50.40it/s]

Processed 2000 rows in 378.36 seconds


Epoch 4:  60%|██████    | 3007/5000 [01:07<00:39, 49.83it/s]

Processed 3000 rows in 401.26 seconds


Epoch 4:  80%|████████  | 4007/5000 [01:28<00:19, 51.77it/s]

Processed 4000 rows in 422.81 seconds


Epoch 4: 100%|██████████| 5000/5000 [01:51<00:00, 44.92it/s]


Processed 5000 rows in 445.42 seconds
Epoch 4/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 5/15


Epoch 5:  20%|██        | 1007/5000 [00:22<01:47, 37.11it/s]

Processed 1000 rows in 467.98 seconds


Epoch 5:  40%|████      | 2007/5000 [00:44<01:01, 48.51it/s]

Processed 2000 rows in 489.45 seconds


Epoch 5:  60%|██████    | 3006/5000 [01:06<00:40, 49.59it/s]

Processed 3000 rows in 511.97 seconds


Epoch 5:  80%|████████  | 4004/5000 [01:28<00:25, 38.72it/s]

Processed 4000 rows in 533.68 seconds


Epoch 5: 100%|██████████| 5000/5000 [01:50<00:00, 45.22it/s]


Processed 5000 rows in 556.00 seconds
Epoch 5/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 6/15


Epoch 6:  20%|██        | 1008/5000 [00:22<01:18, 50.86it/s]

Processed 1000 rows in 578.78 seconds


Epoch 6:  40%|████      | 2008/5000 [00:44<01:00, 49.32it/s]

Processed 2000 rows in 600.12 seconds


Epoch 6:  60%|██████    | 3007/5000 [01:06<00:41, 48.45it/s]

Processed 3000 rows in 622.69 seconds


Epoch 6:  80%|████████  | 4004/5000 [01:29<00:31, 31.62it/s]

Processed 4000 rows in 645.09 seconds


Epoch 6: 100%|██████████| 5000/5000 [01:50<00:00, 45.16it/s]


Processed 5000 rows in 666.72 seconds
Epoch 6/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 7/15


Epoch 7:  20%|██        | 1010/5000 [00:22<01:20, 49.29it/s]

Processed 1000 rows in 689.25 seconds


Epoch 7:  40%|████      | 2004/5000 [00:44<01:15, 39.72it/s]

Processed 2000 rows in 710.79 seconds


Epoch 7:  60%|██████    | 3005/5000 [01:06<00:39, 49.95it/s]

Processed 3000 rows in 733.03 seconds


Epoch 7:  80%|████████  | 4008/5000 [01:28<00:22, 43.14it/s]

Processed 4000 rows in 755.31 seconds


Epoch 7: 100%|██████████| 5000/5000 [01:49<00:00, 45.59it/s]


Processed 5000 rows in 776.41 seconds
Epoch 7/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 8/15


Epoch 8:  20%|██        | 1009/5000 [00:22<01:17, 51.72it/s]

Processed 1000 rows in 798.67 seconds


Epoch 8:  40%|████      | 2004/5000 [00:43<00:59, 50.63it/s]

Processed 2000 rows in 819.38 seconds


Epoch 8:  60%|██████    | 3005/5000 [01:04<00:39, 50.61it/s]

Processed 3000 rows in 841.31 seconds


Epoch 8:  80%|████████  | 4004/5000 [01:26<00:31, 31.44it/s]

Processed 4000 rows in 863.10 seconds


Epoch 8: 100%|██████████| 5000/5000 [01:48<00:00, 46.29it/s]


Processed 5000 rows in 884.43 seconds
Epoch 8/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 9/15


Epoch 9:  20%|██        | 1007/5000 [00:22<01:18, 51.10it/s]

Processed 1000 rows in 906.44 seconds


Epoch 9:  40%|████      | 2008/5000 [00:43<00:58, 51.20it/s]

Processed 2000 rows in 927.58 seconds


Epoch 9:  60%|██████    | 3006/5000 [01:06<00:40, 49.22it/s]

Processed 3000 rows in 950.59 seconds


Epoch 9:  80%|████████  | 4008/5000 [01:29<00:21, 46.41it/s]

Processed 4000 rows in 973.74 seconds


Epoch 9: 100%|██████████| 5000/5000 [01:50<00:00, 45.23it/s]


Processed 5000 rows in 994.97 seconds
Epoch 9/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 10/15


Epoch 10:  20%|██        | 1006/5000 [00:23<01:26, 46.21it/s]

Processed 1000 rows in 1017.95 seconds


Epoch 10:  40%|████      | 2004/5000 [00:45<01:38, 30.48it/s]

Processed 2000 rows in 1040.54 seconds


Epoch 10:  60%|██████    | 3006/5000 [01:08<00:39, 50.03it/s]

Processed 3000 rows in 1062.97 seconds


Epoch 10:  80%|████████  | 4005/5000 [01:31<00:21, 46.48it/s]

Processed 4000 rows in 1086.18 seconds


Epoch 10: 100%|██████████| 5000/5000 [01:53<00:00, 44.02it/s]


Processed 5000 rows in 1108.56 seconds
Epoch 10/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 11/15


Epoch 11:  20%|██        | 1005/5000 [00:22<01:21, 48.84it/s]

Processed 1000 rows in 1131.00 seconds


Epoch 11:  40%|████      | 2007/5000 [00:45<01:01, 48.39it/s]

Processed 2000 rows in 1154.18 seconds


Epoch 11:  60%|██████    | 3006/5000 [01:08<00:56, 35.47it/s]

Processed 3000 rows in 1176.44 seconds


Epoch 11:  80%|████████  | 4006/5000 [01:30<00:19, 49.84it/s]

Processed 4000 rows in 1198.95 seconds


Epoch 11: 100%|██████████| 5000/5000 [01:53<00:00, 44.23it/s]


Processed 5000 rows in 1221.60 seconds
Epoch 11/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 12/15


Epoch 12:  20%|██        | 1010/5000 [00:21<01:18, 50.63it/s]

Processed 1000 rows in 1242.95 seconds


Epoch 12:  40%|████      | 2006/5000 [00:44<01:02, 48.21it/s]

Processed 2000 rows in 1265.53 seconds


Epoch 12:  60%|██████    | 3009/5000 [01:06<00:53, 37.25it/s]

Processed 3000 rows in 1288.21 seconds


Epoch 12:  80%|████████  | 4005/5000 [01:28<00:19, 51.14it/s]

Processed 4000 rows in 1309.81 seconds


Epoch 12: 100%|██████████| 5000/5000 [01:51<00:00, 45.02it/s]


Processed 5000 rows in 1332.68 seconds
Epoch 12/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 13/15


Epoch 13:  20%|██        | 1004/5000 [00:22<02:03, 32.35it/s]

Processed 1000 rows in 1354.60 seconds


Epoch 13:  40%|████      | 2007/5000 [00:44<01:02, 48.25it/s]

Processed 2000 rows in 1376.82 seconds


Epoch 13:  60%|██████    | 3009/5000 [01:06<00:40, 48.65it/s]

Processed 3000 rows in 1399.48 seconds


Epoch 13:  80%|████████  | 4005/5000 [01:28<00:24, 40.72it/s]

Processed 4000 rows in 1421.27 seconds


Epoch 13: 100%|██████████| 5000/5000 [01:51<00:00, 44.89it/s]


Processed 5000 rows in 1444.07 seconds
Epoch 13/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 14/15


Epoch 14:  20%|██        | 1005/5000 [00:22<01:22, 48.30it/s]

Processed 1000 rows in 1466.68 seconds


Epoch 14:  40%|████      | 2005/5000 [00:43<00:59, 50.29it/s]

Processed 2000 rows in 1487.92 seconds


Epoch 14:  60%|██████    | 3006/5000 [01:06<00:40, 49.68it/s]

Processed 3000 rows in 1510.42 seconds


Epoch 14:  80%|████████  | 4004/5000 [01:28<00:29, 33.63it/s]

Processed 4000 rows in 1532.15 seconds


Epoch 14: 100%|██████████| 5000/5000 [01:49<00:00, 45.48it/s]


Processed 5000 rows in 1554.02 seconds
Epoch 14/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Starting epoch 15/15


Epoch 15:  20%|██        | 1010/5000 [00:22<01:21, 48.71it/s]

Processed 1000 rows in 1576.40 seconds


Epoch 15:  40%|████      | 2006/5000 [00:43<01:01, 48.76it/s]

Processed 2000 rows in 1597.61 seconds


Epoch 15:  60%|██████    | 3005/5000 [01:06<00:41, 47.79it/s]

Processed 3000 rows in 1620.17 seconds


Epoch 15:  80%|████████  | 4005/5000 [01:27<00:25, 39.16it/s]

Processed 4000 rows in 1641.46 seconds


Epoch 15: 100%|██████████| 5000/5000 [01:49<00:00, 45.77it/s]

Processed 5000 rows in 1663.26 seconds
Epoch 15/15 completed
Average loss: 0.0000
Skipped rows in this epoch: 0
Training completed
Total time elapsed: 1663.27 seconds





4:主函数和训练循环

In [1]:
def retrieve_images(query, model, criterion, k=5):
    query_features = extract_text_features(query)
    query_output, _ = model(query_features, None)

    similarities = []
    for _, row in tqdm(train_df.iterrows(), total=len(train_df), desc="Retrieving images"):
        path = str(row['path'])
        img_features = extract_image_features(os.path.join(image_folder, path))
        if img_features is None:
            continue
        _, img_output = model(None, img_features)
        similarity = criterion(query_output, img_output)
        similarities.append((path, similarity.item()))

    return sorted(similarities, key=lambda x: x[1], reverse=True)[:k]

# 处理测试查询
results = []
for query in tqdm(test_df['title'], desc="Processing test queries"):
    top_k_images = retrieve_images(query, trained_model, trained_criterion)
    row = [query] + [img for img, _ in top_k_images]
    # 确保每行都有 6 个元素（1个查询 + 5个图像）
    while len(row) < 6:
        row.append('')  # 如果检索到的图像少于5个，用空字符串填充
    results.append(row[:6])  # 只取前6个元素，以防万一

# 打印一些调试信息
print(f"Number of results: {len(results)}")
print(f"Sample result: {results[0]}")

# 创建提交文件
try:
    submission_df = pd.DataFrame(results, columns=['title', 'image_1', 'image_2', 'image_3', 'image_4', 'image_5'])
    submission_df.to_csv('submit.csv', index=False, encoding='utf-8')
    print("Submission file created: submit.csv")
except Exception as e:
    print(f"Error creating submit file: {str(e)}")
    print(f"Results shape: {len(results)} rows, {len(results[0]) if results else 0} columns")

# 下载提交文件
files.download('submit.csv')

NameError: name 'tqdm' is not defined

5:图像检索和结果生成