In [36]:
import torch
import os
import json
import pandas as pd
from torch.utils.data import DataLoader, TensorDataset
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

In [23]:
def load_dataset(csv_dir, json_path):
    # 1. JSON読み込みと辞書化
    with open(json_path, 'r') as f:
        json_data = json.load(f)
    
    rgb_dict = {item["filename"]: item["real_rgb"] for item in json_data}

    # 2. 特徴量とラベルの蓄積
    X_list = []
    y_list = []

    for filename in os.listdir(csv_dir):
        if not filename.endswith("_masked.csv"):
            continue

        base_id = filename.replace("_masked.csv", "")  # 例: 8D5U5524

        if base_id not in rgb_dict:
            print(f"Warning: {base_id} not in JSON, skipping")
            continue

        csv_path = os.path.join(csv_dir, filename)
        df = pd.read_csv(csv_path, header=None)

        # 特徴量は1行と仮定してflatten
        features = df.values.flatten()

        X_list.append(features)
        y_list.append(rgb_dict[base_id])

    # DataFrame or numpy配列に変換
    X = pd.DataFrame(X_list)
    y = pd.DataFrame(y_list, columns=["R", "G", "B"])

    return X, y

In [24]:
#load_datasetのテスト
X, y = load_dataset(csv_dir="../histpre/", json_path="../real_rgb.json")

print(X.shape)  # 特徴量の数 x サンプル数
print(y.head()) # RGBのターゲット値

print(X.iloc[0])  # 1番目のサンプルの特徴量


(5, 1324)
        R       G       B
0   769.0  1043.0   653.0
1  1664.0  2315.0  1464.0
2   470.0   572.0   239.0
3  1413.0  1735.0   836.0
4   711.0   871.0   401.0
0       0
1       0
2       0
3       0
4       0
       ..
1319    0
1320    0
1321    0
1322    0
1323    0
Name: 0, Length: 1324, dtype: int64


In [25]:

class MLPModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim)
        )

    def forward(self, x):
        return self.net(x)

In [26]:

def euclidean_loss(pred, target):
    return torch.sqrt(((pred - target) ** 2).sum(dim=1)).mean()

In [27]:
def mse_chromaticity_loss(pred, target, eps=1e-8):
    # クロマティシティ座標に変換：r = R/(R+G+B), g = G/(R+G+B)
    pred_sum = pred.sum(dim=1, keepdim=True) + eps
    target_sum = target.sum(dim=1, keepdim=True) + eps

    pred_chroma = pred[:, :2] / pred_sum  # (r, g)
    target_chroma = target[:, :2] / target_sum

    loss = ((pred_chroma - target_chroma) ** 2).mean()
    return loss


In [28]:
def train_one_epoch(model, loader, optimizer, loss_fn):
    model.train()  # モデルを訓練モードに設定
    total_loss = 0.0

    for X_batch, y_batch in loader:
        optimizer.zero_grad()               # 勾配をリセット
        pred = model(X_batch)               # 順伝播
        loss = loss_fn(pred, y_batch)       # 損失計算
        loss.backward()                     # 逆伝播
        optimizer.step()                    # パラメータ更新

        total_loss += loss.item()           # 損失を蓄積

    average_loss = total_loss / len(loader)  # バッチ数で割る
    return average_loss

In [29]:
def evaluate(model, loader, loss_fn):
    model.eval()  # 評価モードに切り替え
    total_loss = 0.0

    with torch.no_grad():  # 勾配を計算しない（メモリ節約＆高速化）
        for X_batch, y_batch in loader:
            pred = model(X_batch)
            loss = loss_fn(pred, y_batch)
            total_loss += loss.item()

    average_loss = total_loss / len(loader)
    return average_loss

In [48]:
def main():
    # 1. モデル構造の再定義（構造は学習と同じに！）
    model = MLPModel(input_dim=1324, hidden_dim=256, output_dim=2)
    model.load_state_dict(torch.load("mlp_model.pth"))
    model.eval()

    # 2. テストデータの読み込み
    X_test_df, y_test_df = load_dataset("../src", "../real_rgb.json")
    X_test = torch.tensor(X_test_df.values, dtype=torch.float32)
    y_test = torch.tensor(y_test_df[["R", "G", "B"]].values, dtype=torch.float32)

    test_dataset = TensorDataset(X_test, y_test)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # 3. 評価実行（クロマティシティMSE）
    test_loss = evaluate(model, test_loader, mse_chromaticity_loss)
    print(f"📊 Test Loss = {test_loss:.4f}")

    # 4. 予測と実際のRGB値の表示（5件だけ例として）
    print("\n🎨 Prediction vs Actual (first 5 samples):")
    
    num_samples = min(5, len(X_test))  # データ数に合わせる
    with torch.no_grad():
        for i in range(num_samples):
            x = X_test[i].unsqueeze(0)  # shape: (1, input_dim)
            pred = model(x)[0]          # クロマティシティ座標 (r, g)

             # モデル出力（r_pred, g_pred）
            r_pred, g_pred = pred[0].item(), pred[1].item()

            # 実際のRGBから (r_true, g_true) を計算
            true_rgb = y_test[i].numpy()
            total = np.sum(true_rgb)
            if total > 0:
                r_true = true_rgb[0] / total
                g_true = true_rgb[1] / total
            else:
                r_true = g_true = 0.0  # total = 0 の場合の例外処理

            print(f"{i+1}: Pred (r, g): ({r_pred:.4f}, {g_pred:.4f}) | True (r, g): ({r_true:.4f}, {g_true:.4f})")


In [49]:
if __name__ == "__main__":
    main()

📊 Test Loss = 0.0309

🎨 Prediction vs Actual (first 5 samples):
1: Pred (r, g): (-1.3967, -2.7917) | True (r, g): (0.2777, 0.4242)
