# 枝領域の編集とRGB平面分離

In [None]:
import json
import os
import shutil
import numpy as np
from PIL import Image, ImageDraw
from sklearn.linear_model import LogisticRegression
import cv2
import matplotlib.pyplot as plt

## 1. CIVE値に基づくセグメンテーションマスクの修正

In [None]:
def _cive_filtered_contours(seg, img_np):
    """Rasterise one COCO polygon, drop its leaf pixels (CIVE < 0) and return the
    outer contours (each with at least 3 points) of the remaining region."""
    height, width = img_np.shape[:2]
    mask = Image.new('L', (width, height), 0)
    ImageDraw.Draw(mask).polygon(seg, outline=1, fill=1)
    mask_np = np.array(mask, dtype=np.uint8)

    inside = mask_np == 1
    pixels_in_mask = img_np[inside].astype(np.float64)
    r, g, b = pixels_in_mask[:, 0], pixels_in_mask[:, 1], pixels_in_mask[:, 2]
    # Colour Index of Vegetation Extraction on raw 0-255 RGB values.
    cive = 0.441 * r - 0.811 * g + 0.385 * b + 18.78745

    # Pixels with a negative CIVE are treated as leaves and removed from the mask.
    rows, cols = np.nonzero(inside)
    leaf = cive < 0
    mask_np[rows[leaf], cols[leaf]] = 0

    # RETR_EXTERNAL keeps only outer boundaries, so holes left by removed leaf
    # pixels inside a region are not represented in the output polygons.
    contours, _ = cv2.findContours(mask_np, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # contour.size counts coordinates (2 per point); COCO polygons need >= 3 points.
    return [contour for contour in contours if contour.size >= 6]


def edit_json_with_cive(json_path, image_dir, output_dir='output'):
    """Remove leaf pixels (CIVE < 0) from the polygon annotations of a COCO file.

    Each polygon is rasterised, pixels with a negative CIVE value are removed,
    and the outer contours of what remains replace the original polygon.
    ``bbox`` and ``area`` are recomputed from the filled new polygons so they
    describe the edited segmentation. Annotations with no polygon left, and
    annotations of images whose file is missing, are not written out.

    Args:
        json_path (str): Path to the input COCO-format JSON file.
        image_dir (str): Directory containing the images referenced by ``file_name``.
        output_dir (str): Directory for the output JSON (created if missing).

    Returns:
        str: Path of the written ``edited_<basename of json_path>`` file.
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(json_path, 'r') as f:
        coco_data = json.load(f)

    # Group annotations once instead of rescanning the whole list for every image.
    anns_by_image = {}
    for ann in coco_data['annotations']:
        anns_by_image.setdefault(ann['image_id'], []).append(ann)

    new_coco_data = coco_data.copy()
    new_annotations = []

    for img_info in coco_data['images']:
        image_path = os.path.join(image_dir, img_info['file_name'])

        if not os.path.exists(image_path):
            print(f'画像ファイルが見つかりません: {image_path}')
            continue

        with Image.open(image_path) as pil_image:
            img_np = np.array(pil_image.convert('RGB'))
        height, width = img_np.shape[:2]

        for ann in anns_by_image.get(img_info['id'], []):
            new_segments = []
            # Filled union of the polygons written back; bbox/area are derived from it.
            ann_mask = np.zeros((height, width), dtype=np.uint8)

            for seg in ann['segmentation']:
                # Only flat polygon lists with >= 3 points (RLE and degenerate segments are skipped).
                if not isinstance(seg, list) or len(seg) < 6:
                    continue
                contours = _cive_filtered_contours(seg, img_np)
                new_segments.extend(contour.flatten().tolist() for contour in contours)
                if contours:
                    cv2.drawContours(ann_mask, contours, -1, 1, thickness=cv2.FILLED)

            if not new_segments:
                continue

            new_ann = ann.copy()
            new_ann['segmentation'] = new_segments
            ys, xs = np.nonzero(ann_mask)
            if ys.size:
                x_min, y_min = int(xs.min()), int(ys.min())
                # Pixel extent of the filled polygons, COCO [x, y, width, height].
                new_ann['bbox'] = [x_min, y_min, int(xs.max()) - x_min + 1, int(ys.max()) - y_min + 1]
                new_ann['area'] = int(ys.size)
            new_annotations.append(new_ann)

    new_coco_data['annotations'] = new_annotations

    new_json_path = os.path.join(output_dir, 'edited_' + os.path.basename(json_path))
    with open(new_json_path, 'w') as f:
        json.dump(new_coco_data, f, indent=4)

    print(f'処理が完了しました。新しいJSONファイルが保存されました: {new_json_path}')
    return new_json_path

In [None]:
# Inputs for the CIVE-based mask refinement (adjust the paths as needed).
input_json_path = './test_day_08/test_day_08.json'
input_image_dir = './test_day_08/'
output_dir_cive = 'cive_output'

edited_json_path = edit_json_with_cive(
    json_path=input_json_path,
    image_dir=input_image_dir,
    output_dir=output_dir_cive,
)

## 2. 更新されたセグメンテーションの視覚的確認

In [None]:
def visualize_segmentation(json_path, image_dir, num_images=5, random_selection=True):
    """
    Draw the JSON file's polygon annotations on their images for visual inspection.

    Args:
        json_path (str): Path to the COCO-format JSON file.
        image_dir (str): Directory that contains the images.
        num_images (int): Maximum number of images to display.
        random_selection (bool): If True, pick images at random (without
            replacement); otherwise take them in order of first appearance
            in the annotations.
    """
    with open(json_path, 'r') as f:
        coco_data = json.load(f)

    image_id_to_info = {img['id']: img for img in coco_data['images']}

    image_id_to_anns = {}
    for ann in coco_data['annotations']:
        image_id_to_anns.setdefault(ann['image_id'], []).append(ann)

    annotated_image_ids = list(image_id_to_anns.keys())
    if random_selection:
        selected_ids = np.random.choice(annotated_image_ids, min(num_images, len(annotated_image_ids)), replace=False)
    else:
        selected_ids = annotated_image_ids[:num_images]

    for img_id in selected_ids:
        img_info = image_id_to_info.get(img_id)
        if not img_info:
            continue

        file_name = img_info['file_name']
        image_path = os.path.join(image_dir, file_name)

        if not os.path.exists(image_path):
            print(f'画像が見つかりません: {image_path}')
            continue

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread returns None (instead of raising) for unreadable/corrupt files.
            print(f'画像を読み込めません: {image_path}')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        overlay = image.copy()

        for ann in image_id_to_anns[img_id]:
            for seg in ann['segmentation']:
                # Only flat polygon lists with >= 3 points are drawn.
                if isinstance(seg, list) and len(seg) >= 6:
                    poly = np.array(seg).reshape((-1, 2)).astype(np.int32)
                    cv2.fillPoly(overlay, [poly], color=(255, 0, 0))

        # Blend the red polygon fill with the original image.
        alpha = 0.4
        image_with_mask = cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0)

        plt.figure(figsize=(12, 10))
        plt.imshow(image_with_mask)
        plt.title(f'Image: {file_name} (ID: {img_id})')
        plt.axis('off')
        plt.show()

In [None]:
# Show a few images with the CIVE-edited annotations overlaid.
json_to_visualize = edited_json_path
image_dir_to_visualize = input_image_dir

visualize_segmentation(
    json_to_visualize,
    image_dir_to_visualize,
    num_images=3,
)

## 3. 枝領域を分離するRGB平面式の算出

In [None]:
def calculate_rgb_plane(json_path, image_dir):
    """Fit a plane in RGB space separating annotated (branch) pixels from the rest.

    Every pixel of every readable image is labelled 1 if it lies inside any
    polygon annotation of that image and 0 otherwise. A scikit-learn
    ``LogisticRegression`` is then fitted on the raw 0-255 RGB values, and its
    decision boundary ``w1*R + w2*G + w3*B + w0 = 0`` is printed and returned.

    Args:
        json_path (str): Path to the COCO-format JSON file.
        image_dir (str): Directory containing the images referenced by ``file_name``.

    Returns:
        str | None: The plane equation as a string, or None when no branch
        pixels or no non-branch pixels could be collected.
    """
    with open(json_path, 'r') as f:
        coco_data = json.load(f)

    # Group annotations once instead of rescanning the full list for every image.
    anns_by_image = {}
    for ann in coco_data['annotations']:
        anns_by_image.setdefault(ann['image_id'], []).append(ann)

    # Per-image (N, 3) arrays, concatenated once at the end; expanding every
    # pixel into Python lists (as .tolist() did) costs far more memory.
    branch_chunks = []
    other_chunks = []

    for img_info in coco_data['images']:
        image_path = os.path.join(image_dir, img_info['file_name'])

        if not os.path.exists(image_path):
            print(f'画像ファイルが見つかりません: {image_path}')
            continue

        with Image.open(image_path) as pil_image:
            img_np = np.array(pil_image.convert('RGB'))
        height, width = img_np.shape[:2]

        full_mask = np.zeros((height, width), dtype=np.uint8)
        for ann in anns_by_image.get(img_info['id'], []):
            for seg in ann['segmentation']:
                if isinstance(seg, list) and len(seg) >= 6:
                    poly = np.array(seg).reshape((-1, 2)).astype(np.int32)
                    cv2.fillPoly(full_mask, [poly], 1)

        branch_chunks.append(img_np[full_mask == 1])
        other_chunks.append(img_np[full_mask == 0])

    n_branch = sum(len(chunk) for chunk in branch_chunks)
    n_other = sum(len(chunk) for chunk in other_chunks)
    if n_branch == 0 or n_other == 0:
        print('枝領域または非枝領域のデータが収集できませんでした。')
        return None

    X = np.concatenate(branch_chunks + other_chunks, axis=0)
    # 1: branch, 0: other (same order as the rows of X)
    y = np.concatenate([np.ones(n_branch, dtype=np.int64), np.zeros(n_other, dtype=np.int64)])

    model = LogisticRegression()
    model.fit(X, y)

    coef = model.coef_[0]  # w1, w2, w3 for R, G, B
    intercept = model.intercept_[0]  # w0
    equation = f'{coef[0]:.4f} * R + {coef[1]:.4f} * G + {coef[2]:.4f} * B + {intercept:.4f} = 0'
    print('算出されたRGB平面式:')
    print(equation)

    return equation

In [None]:
# Fit the RGB separation plane on the CIVE-edited annotations.
input_json_for_plane = edited_json_path
input_image_dir_for_plane = input_image_dir

plane_equation = calculate_rgb_plane(
    input_json_for_plane,
    input_image_dir_for_plane,
)

## 3-2. 損失関数の変更 + PyTorchによるロジスティック回帰

In [None]:
import json
import random
from pathlib import Path
import copy

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm.notebook import tqdm

# Training data (COCO JSON + image directory)
json_path = "./train/train.json"
image_dir = "./train/"


# Loss function: "BCEWithLogits" (recommended), "BCE", "MSE", "MAE", "GCE"
loss_function_name = "BCEWithLogits"
# q of the Generalized Cross Entropy loss (only used when loss_function_name == "GCE")
gce_q_param = 0.7

# Early stopping: stop after `patience` consecutive epochs in which the
# validation loss did not decrease by at least `delta`.
patience = 5
delta = 0.0001

epochs = 1000            # maximum number of epochs
learning_rate = 0.01     # SGD learning rate
batch_size = 1024
pixels_per_image = 2000  # pixels sampled at random from each annotated image
validation_split = 0.2   # fraction of the sampled pixels held out for validation

class EarlyStopping:
    """
    Stop training early once the validation loss stops improving.

    The first call always counts as an improvement. Afterwards, a call counts
    as an improvement when the validation loss is at least ``delta`` below the
    best loss seen so far; each improvement saves the model's ``state_dict``
    to ``path`` and resets the counter. After ``patience`` consecutive
    non-improving calls, ``early_stop`` is set to True.

    Args:
        patience (int): Non-improving calls tolerated before stopping.
        delta (float): Minimum loss decrease that counts as an improvement.
        path (str): File the best weights are written to with ``torch.save``.
    """
    def __init__(self, patience=7, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None  # negated best validation loss (higher is better)
        self.early_stop = False
        # np.inf, not np.Inf: the capitalised alias was removed in NumPy 2.0.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save the model weights to ``self.path`` and record ``val_loss`` as the new minimum.'''
        print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class GCELoss(nn.Module):
    """Generalized Cross Entropy loss for binary targets, computed from logits.

    Per element: ``(1 - p_t ** q) / q``, where ``p_t`` is the sigmoid
    probability assigned to the target class. Returns the mean over all elements.
    """

    def __init__(self, q=0.7):
        super().__init__()
        self.q = q

    def forward(self, logits, targets):
        p_pos = torch.sigmoid(logits)
        # Probability of the true class; presumably targets are 0/1 labels.
        prob_true = p_pos * targets + (1 - p_pos) * (1 - targets)
        # Small epsilon keeps the power well-defined when the probability is 0.
        per_element = (1 - torch.pow(prob_true + 1e-8, self.q)) / self.q
        return per_element.mean()

class CocoSegmentationPixelDataset(Dataset):
    """Pixel-level dataset sampled from COCO polygon annotations.

    For every image that has at least one annotation with a non-empty
    ``segmentation`` (and whose file exists), up to ``pixels_per_image`` pixels
    are sampled uniformly at random. Each sample is an RGB triple scaled to
    [0, 1] and a float label that is 1 if the pixel lies inside any polygon of
    that image and 0 otherwise.

    Args:
        json_path: Path to the COCO-format JSON file.
        image_dir: Directory containing the images referenced by ``file_name``.
        pixels_per_image: Maximum number of pixels sampled from each image.
    """

    def __init__(self, json_path, image_dir, pixels_per_image=1000):
        self.image_dir = Path(image_dir)
        self.pixels_per_image = pixels_per_image
        self.pixels, self.labels = [], []
        print("Parsing COCO JSON file manually...")
        with open(json_path, 'r') as f:
            coco_data = json.load(f)
        self._img_infos = {img['id']: img for img in coco_data['images']}
        self._img_to_anns = {img_id: [] for img_id in self._img_infos}
        for ann in coco_data['annotations']:
            # Skip annotations without polygons and annotations whose image_id is
            # not listed in "images" (these used to raise KeyError).
            if ann.get('segmentation') and ann['image_id'] in self._img_to_anns:
                self._img_to_anns[ann['image_id']].append(ann)
        self.img_ids = list(self._img_infos.keys())
        self._prepare_data()

    def _create_mask_from_annotations(self, height, width, annotations):
        """Rasterise the annotations' polygons into a (height, width) uint8 mask of 0/1."""
        mask = np.zeros((height, width), dtype=np.uint8)
        for ann in annotations:
            for seg_poly in ann['segmentation']:
                # Same filter as the other cells of this notebook: only flat polygon
                # lists with >= 3 points; RLE dicts and degenerate polygons are skipped.
                if not isinstance(seg_poly, list) or len(seg_poly) < 6:
                    continue
                poly = np.array(seg_poly).reshape((-1, 2)).astype(np.int32)
                cv2.fillPoly(mask, [poly], 1)
        return mask

    def _prepare_data(self):
        """Sample pixels/labels from every annotated image and concatenate them into tensors."""
        print("Preparing pixel dataset...")
        for img_id in self.img_ids:
            img_info = self._img_infos[img_id]
            img_path = self.image_dir / img_info['file_name']
            if not img_path.exists():
                continue
            annotations = self._img_to_anns.get(img_id, [])
            if not annotations:
                continue
            with Image.open(img_path) as img:
                image_np = np.array(img.convert('RGB'))
            # Size the mask from the decoded image rather than the JSON's
            # height/width so labels always line up pixel-for-pixel.
            height, width = image_np.shape[:2]
            mask = self._create_mask_from_annotations(height, width, annotations)
            pixels = image_np.reshape(-1, 3)
            labels = mask.reshape(-1)
            num_pixels_total = pixels.shape[0]
            sample_indices = random.sample(range(num_pixels_total), min(num_pixels_total, self.pixels_per_image))
            self.pixels.append(torch.from_numpy(pixels[sample_indices]).float() / 255.0)
            self.labels.append(torch.from_numpy(labels[sample_indices]).float())
        if self.pixels and self.labels:
            self.pixels = torch.cat(self.pixels, dim=0)
            # Labels become (N, 1) to match the model's single-logit output.
            self.labels = torch.cat(self.labels, dim=0).unsqueeze(1)
            print(f"Dataset prepared. Total sampled pixels: {len(self.pixels)}")
        else:
            print("Warning: No valid data could be prepared.")

    def __len__(self):
        return len(self.pixels)

    def __getitem__(self, idx):
        return self.pixels[idx], self.labels[idx]

class LogisticRegressionModel(nn.Module):
    """A single linear unit mapping an RGB triple to one logit (no sigmoid applied)."""

    def __init__(self):
        super().__init__()
        # 3 inputs (R, G, B) -> 1 output logit
        self.linear = nn.Linear(in_features=3, out_features=1)

    def forward(self, x):
        logits = self.linear(x)
        return logits

def train_model(model, train_loader, valid_loader, criterion, optimizer, num_epochs, early_stopping):
    """Train ``model`` with early stopping on the validation loss.

    Each epoch runs one optimisation pass over ``train_loader`` and one
    evaluation pass over ``valid_loader``. Epoch losses are per-sample
    averages: each batch's loss is weighted by its batch size, so a short
    final batch does not skew the result. This assumes ``criterion`` uses
    mean reduction, which holds for every loss configured in this notebook.
    After training, the best weights saved by ``early_stopping`` are loaded
    back into ``model``.

    Args:
        model: Module to train (updated in place).
        train_loader: DataLoader yielding ``(inputs, labels)`` training batches.
        valid_loader: DataLoader yielding ``(inputs, labels)`` validation batches.
        criterion: Loss function ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        num_epochs (int): Maximum number of epochs.
        early_stopping (EarlyStopping): Tracks the best validation loss and checkpoint path.

    Raises:
        ValueError: If either loader yields no samples, since the epoch loss
            would be undefined.
    """
    for epoch in tqdm(range(num_epochs)):
        model.train()
        train_loss_sum, train_count = 0.0, 0
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss_sum += loss.item() * inputs.size(0)
            train_count += inputs.size(0)

        model.eval()
        valid_loss_sum, valid_count = 0.0, 0
        with torch.no_grad():
            for inputs, labels in valid_loader:
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                valid_loss_sum += loss.item() * inputs.size(0)
                valid_count += inputs.size(0)

        if train_count == 0 or valid_count == 0:
            raise ValueError(
                f"Empty loader (train samples: {train_count}, validation samples: {valid_count}); "
                "cannot compute the epoch loss."
            )

        epoch_train_loss = train_loss_sum / train_count
        epoch_valid_loss = valid_loss_sum / valid_count

        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_train_loss:.6f}, Valid Loss: {epoch_valid_loss:.6f}")

        early_stopping(epoch_valid_loss, model)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    if early_stopping.best_score is None:
        # No epoch ran, so no checkpoint exists for this run; avoid loading a stale file.
        print("No checkpoint was saved; keeping current model weights.")
        return

    print("Loading best model weights...")
    model.load_state_dict(torch.load(early_stopping.path))

dataset = CocoSegmentationPixelDataset(
    json_path=json_path, image_dir=image_dir, pixels_per_image=pixels_per_image
)

if len(dataset) > 0:
    dataset_size = len(dataset)
    # Hold out a `validation_split` fraction of the sampled pixels for early stopping.
    val_size = int(np.floor(validation_split * dataset_size))
    train_size = dataset_size - val_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    print(f"Training on {len(train_dataset)} samples, validating on {len(val_dataset)} samples.")

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_loader = DataLoader(val_dataset, batch_size=batch_size)
    base_model = LogisticRegressionModel()

    # (criterion, model) for each option. Every model wraps its own deep copy of
    # base_model, so all options start from identical initial weights without
    # sharing parameters. BCE/MSE/MAE compare probabilities, so those models end
    # in a Sigmoid; BCEWithLogits and GCE take raw logits.
    loss_functions = {
        "BCEWithLogits": (nn.BCEWithLogitsLoss(), copy.deepcopy(base_model)),
        "GCE": (GCELoss(gce_q_param), copy.deepcopy(base_model)),
        "BCE": (nn.BCELoss(), nn.Sequential(copy.deepcopy(base_model), nn.Sigmoid())),
        "MSE": (nn.MSELoss(), nn.Sequential(copy.deepcopy(base_model), nn.Sigmoid())),
        "MAE": (nn.L1Loss(), nn.Sequential(copy.deepcopy(base_model), nn.Sigmoid())),
    }

    if loss_function_name not in loss_functions:
        raise ValueError(f"Unknown loss function: {loss_function_name}. Available options: {list(loss_functions.keys())}")

    criterion, model = loss_functions[loss_function_name]
    print(f"Using loss function: {loss_function_name}")

    optimizer = optim.SGD(model.parameters(), lr=learning_rate)
    early_stopping = EarlyStopping(patience=patience, delta=delta)

    train_model(model, train_loader, valid_loader, criterion, optimizer, epochs, early_stopping)

    print("\n--- Training Finished ---")

    # With a Sigmoid wrapper, the linear layer lives in the first module.
    if isinstance(model, nn.Sequential):
        weights = model[0].linear.weight.squeeze()
        bias = model[0].linear.bias.squeeze()
    else:
        weights = model.linear.weight.squeeze()
        bias = model.linear.bias.squeeze()

    w_r, w_g, w_b = weights[0].item(), weights[1].item(), weights[2].item()
    b = bias.item()

    print("Learned Parameters (from best model):")
    print(f"  Weight_R: {w_r:.6f}\n  Weight_G: {w_g:.6f}\n  Weight_B: {w_b:.6f}\n  Bias: {b:.6f}")

    # sigmoid(z) > 0.5 is equivalent to z > 0, so the linear term is the
    # decision boundary for every loss option.
    print("\n--- RGB Separation Plane Equation ---")
    print(f"Equation: ({w_r:.6f})*R + ({w_g:.6f})*G + ({w_b:.6f})*B + ({b:.6f}) = 0")
    print("\nA pixel (R,G,B) is classified as part of the segmentation if the result of the equation is > 0.")

    # The dataset feeds RGB divided by 255, so the equation above expects inputs in
    # [0, 1]. Dividing the weights by 255 gives the same plane for raw 0-255 pixel
    # values, which is what calculate_rgb_plane and the binarization cell use.
    print("\n--- RGB Separation Plane Equation (R, G, B in 0-255) ---")
    print(f"Equation: ({w_r / 255:.6f})*R + ({w_g / 255:.6f})*G + ({w_b / 255:.6f})*B + ({b:.6f}) = 0")
else:
    print("No data to train on. Exiting.")

## 4. RGB平面式による二値化の確認

In [None]:
# Example: apply the RGB plane equations to one image, binarize with 0 as the
# threshold, and plot the result.
import cv2
import numpy as np
import matplotlib.pyplot as plt

# Image to check (change the path as needed)
img_path = './train/NON0344_960x1280_2640.jpg'
img = cv2.imread(img_path)
if img is None:
    # cv2.imread returns None instead of raising when the file is missing or unreadable.
    raise FileNotFoundError(f'Could not read image: {img_path}')
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Raw 0-255 channel values; the plane equations below are applied to these directly.
r = img_rgb[:, :, 0]
g = img_rgb[:, :, 1]
b = img_rgb[:, :, 2]

# NOTE(review): in the original cell the comment above rgb_formula_day said
# "night" and the one above rgb_formula_night said "day", contradicting the
# variable names. Confirm which coefficient set was fitted on which data.
rgb_formula_day = 0.0884 * r + -0.0785 * g + -0.0025 * b + -4.8992
rgb_formula_night = 0.0843 * r + -0.1429 * g + 0.0337 * b + -0.6300


# Binarize with 0 as the boundary: a pixel is foreground if either plane is positive.
binary_img = np.where((rgb_formula_day > 0) | (rgb_formula_night > 0), 255, 0).astype(np.uint8)

# Plot
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
plt.title('Original')
plt.imshow(img_rgb)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title('Binarized (day or night plane > 0)')
plt.imshow(binary_img, cmap='gray')
plt.axis('off')
plt.show()