In [None]:
import torch
from torch.utils.data import DataLoader

In [None]:
def main(train_csv_dir='histpre/', train_json_path='real_rgb.json',
         val_csv_dir='val_features/', val_json_path='val_labels.json',
         batch_size=32, hidden_dim=256, lr=0.01, num_epochs=50,
         model_path='mlp_model.pth', seed=42):
    """Train an MLP that regresses RGB targets from per-image features, then save its weights.

    Data comes from ``load_dataset`` (defined in another cell), which returns an
    ``(X, y)`` pair of DataFrames. ``DataLoader`` cannot batch a raw tuple of
    DataFrames, so both frames are converted into a float32 ``TensorDataset``.

    Args:
        train_csv_dir: directory of training feature CSVs.
        train_json_path: JSON file with training RGB labels.
        val_csv_dir: directory of validation feature CSVs (placeholder, still to be set up).
        val_json_path: JSON file with validation RGB labels (placeholder, still to be set up).
        batch_size: mini-batch size used by both loaders.
        hidden_dim: hidden width passed to ``MLPModel``.
        lr: SGD learning rate.
        num_epochs: number of training epochs.
        model_path: file the trained ``state_dict`` is written to.
        seed: value for ``torch.manual_seed`` (affects init and shuffling);
            ``None`` skips seeding.

    Raises:
        ValueError: if a dataset is empty or train/val feature widths differ.
    """
    from torch.utils.data import TensorDataset

    def to_tensor_dataset(X, y, name):
        # Wrap (X, y) DataFrames as a dataset of (feature, target) float32 pairs.
        if len(X) == 0:
            raise ValueError(f"{name} dataset is empty")
        features = torch.tensor(X.to_numpy(dtype='float32'))
        targets = torch.tensor(y.to_numpy(dtype='float32'))
        return TensorDataset(features, targets)

    if seed is not None:
        torch.manual_seed(seed)

    # NOTE(review): the exploration cell loads '../histpre/' and '../real_rgb.json';
    # confirm which paths are correct relative to where main() is run.
    X_train, y_train = load_dataset(train_csv_dir, train_json_path)
    # TODO: the validation data paths are to be set up later.
    X_val, y_val = load_dataset(val_csv_dir, val_json_path)

    train_dataset = to_tensor_dataset(X_train, y_train, "train")
    val_dataset = to_tensor_dataset(X_val, y_val, "validation")

    if X_val.shape[1] != X_train.shape[1]:
        raise ValueError(
            f"validation features have {X_val.shape[1]} columns, "
            f"training features have {X_train.shape[1]}"
        )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Size the network from the data. The previous hard-coded input_dim=1250 and
    # output_dim=2 did not match the loaded data (1324 feature columns, 3 RGB targets).
    model = MLPModel(
        input_dim=X_train.shape[1],
        hidden_dim=hidden_dim,
        output_dim=y_train.shape[1],
    )
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    loss_fn = euclidean_loss

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, loss_fn)
        val_loss = evaluate(model, val_loader, loss_fn)
        print(f"Epoch {epoch+1:02d}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

    torch.save(model.state_dict(), model_path)


In [10]:
import os
import json
import pandas as pd

def load_dataset(csv_dir, json_path):
    """Load per-image feature CSVs and their RGB targets.

    Args:
        csv_dir: directory containing ``<id>_masked.csv`` files. Each file is read
            without a header and flattened into one feature vector.
        json_path: JSON file holding a list of objects with ``"filename"`` (the
            ``<id>``) and ``"real_rgb"`` keys.

    Returns:
        ``(X, y)``:
            ``X`` is a DataFrame with one row per matched CSV and positional
            integer column labels.
            ``y`` is a DataFrame with columns ``R``, ``G``, ``B``.
        Rows are ordered by CSV filename, so the result is reproducible.
        CSVs whose id is missing from the JSON are skipped with a warning.

    Raises:
        ValueError: if the matched CSVs do not all flatten to the same number of
            values. Without this check, ``pd.DataFrame`` would silently pad with NaN.
    """
    suffix = "_masked.csv"

    # 1. Load the JSON labels and index them by file id.
    with open(json_path, 'r', encoding='utf-8') as f:
        json_data = json.load(f)

    rgb_dict = {item["filename"]: item["real_rgb"] for item in json_data}

    # 2. Accumulate features and labels.
    X_list = []
    y_list = []

    # os.listdir returns entries in arbitrary order; sort for a deterministic sample order.
    for filename in sorted(os.listdir(csv_dir)):
        if not filename.endswith(suffix):
            continue

        # Strip only the trailing suffix, e.g. "8D5U5524_masked.csv" -> "8D5U5524".
        base_id = filename[:-len(suffix)]

        if base_id not in rgb_dict:
            print(f"Warning: {base_id} not in JSON, skipping")
            continue

        csv_path = os.path.join(csv_dir, filename)
        df = pd.read_csv(csv_path, header=None)

        # Flatten the whole file into one feature vector (expected to be a single row).
        features = df.values.flatten()

        if X_list and len(features) != len(X_list[0]):
            raise ValueError(
                f"{filename} has {len(features)} values; "
                f"expected {len(X_list[0])} like the other CSVs"
            )

        X_list.append(features)
        y_list.append(rgb_dict[base_id])

    # 3. Convert to DataFrames.
    X = pd.DataFrame(X_list)
    y = pd.DataFrame(y_list, columns=["R", "G", "B"])

    return X, y


In [None]:
X, y = load_dataset(csv_dir="../histpre/", json_path="../real_rgb.json")

# Sanity check: every feature row has exactly one RGB label row.
assert len(X) == len(y), "feature/label row counts differ"

print(X.shape)  # (number of samples, number of features)
print(y.head()) # RGB target values

print(X.iloc[0])  # feature vector of the first sample


# NumPy views of the data for downstream model code.
X_np = X.to_numpy()
y_np = y.to_numpy()

(5, 1324)
        R       G       B
0   769.0  1043.0   653.0
1  1664.0  2315.0  1464.0
2   470.0   572.0   239.0
3  1413.0  1735.0   836.0
4   711.0   871.0   401.0
0       0
1       0
2       0
3       0
4       0
       ..
1319    0
1320    0
1321    0
1322    0
1323    0
Name: 0, Length: 1324, dtype: int64
