In [23]:
# Cell 1: ライブラリのインポートとモデル定義
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split

import numpy as np
import warnings
import time
import tqdm

# PyTorchの警告を無視 (開発時は非推奨ですが、Notebookでの実行をスムーズにするため)
warnings.filterwarnings("ignore")

# デバイス設定
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用デバイス: {device}")

使用デバイス: cuda


In [2]:
# --- 定義されたモデルクラス (Binary_classification) ---
class Binary_classification(nn.Module):
    # NOTE: super()の引数を修正: super(Binary_classification_v2, self).__init__() -> super(Binary_classification, self).__init__()
    def __init__(self, latent, input_depth, input_height, input_width):
        super(Binary_classification, self).__init__()
        
        # モデル構造の定義
        self.features = nn.Sequential(
            nn.Conv3d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(True),
            
            nn.Conv3d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(True),
            
            nn.Conv3d(32, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(64),
            nn.ReLU(True),
            
            nn.Conv3d(64, 32, kernel_size=3, stride=1, padding=0),
            nn.ReLU(True)
        )

        FINAL_FLATTEN_SIZE = 32 * 5 * 23 * 23 # 仮の値
        
        # --- 分類ヘッド ---
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(FINAL_FLATTEN_SIZE, latent),
            nn.ReLU(True), 

            nn.Linear(latent, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x

print("モデルクラス定義完了。")

モデルクラス定義完了。


In [3]:
temp_model = Binary_classification(latent=100, input_depth=30, input_height=100, input_width=100) 
dummy_input = torch.randn(1, 1, 30, 100, 100) 

# 特徴抽出層まで実行
output_features = temp_model.features(dummy_input)

# 結果のサイズを確認
print(output_features.size())

torch.Size([1, 32, 5, 23, 23])


In [4]:
class DataSet:
    def __init__(self, data, label):
        self.label = label
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index], self.label[index]

In [5]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, path, patience=10, verbose=False, delta=0, trace_func=print):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): Path for the checkpoint to be saved to.
                            Default: 'checkpoint.pt'
            trace_func (function): trace print function.
                            Default: print
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score <= self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            # self.flog.write(f'EarlyStopping counter: {self.counter} out of {self.patience}\n')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    
    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            self.trace_func(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
            # self.flog.write(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...\n')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [6]:
early_stopping = EarlyStopping(patience=15, verbose=True, path="savedir")

In [7]:
bubble_data = np.load("/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/data/Binary_data/processed_data/slide_bubble.npy")
removal_data = np.load("/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/data/Binary_data/processed_data/all_data_after_bubble_removal.npy")

# バブルのラベルは1、非バブルのラベルは0
bubble_label = [1] * len(bubble_data)
removal_label = [0] * len(removal_data)

In [8]:
# print(len(bubble_data))
# print(len(removal_data))
# print(bubble_label)
# print(removal_label)

In [9]:
data = np.concatenate((bubble_data, removal_data))
label = np.concatenate((bubble_label, removal_label))

In [13]:
# data = torch.from_numpy(data).float()
train_data, val_data, train_labels, val_labels = train_test_split(
    data, label, test_size=0.2, random_state=42, stratify=label
)
val_data, test_data, val_labels, test_labels = train_test_split(
    val_data, val_labels, test_size=0.25, random_state=42, stratify=val_labels
)

# train_data = augment_data(train_data, augment_horizontal, augment_vertical, augment_velocity_axis)
# train_labels     = [0] * len(train_data)

train_dataset    = DataSet(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataset      = DataSet(val_data, val_labels)
val_dataloader   = DataLoader(val_dataset, batch_size=16, shuffle=False)
dataloader_dic   = {"train": train_dataloader, "val": val_dataloader}


train_dataset = DataSet(train_data, train_labels)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataset = DataSet(val_data, val_labels)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
dataloader_dic = {"train": train_dataloader, "val": val_dataloader}

In [19]:
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        init.kaiming_normal_(m.weight.data)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

In [20]:
model = Binary_classification(latent=100, input_depth=30, input_height=100, input_width=100)
model.apply(weights_init)
model.to(device)

Binary_classification(
  (features): Sequential(
    (0): Conv3d(1, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(32, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (7): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU(inplace=True)
    (9): Conv3d(64, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    (10): ReLU(inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=84640, out_features=100, bias=True)
    (2): ReLU(inplace=True)
    (3): Linear(in_features=100, out_features=1, bias=True)
    (4): Sigmoid()
  )
)

In [27]:
optimizer = optim.AdamW(
        model.parameters(), lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0.001, amsgrad=False
    )
criterion = nn.BCELoss()

In [46]:
train_loss_list = []
val_loss_list = []
best_val_loss = float('inf')
start = time.time()
num_epochs = 10

for epoch in range(num_epoch):
    train_loss_num = 0
    val_loss_num = 0

    for phase in ["train", "val"]:
        dataloader = dataloader_dic[phase]
        if phase == "train":
            model.train()  # モデルを訓練モードに
        else:
            model.eval()

        for images, labels in tqdm.tqdm(dataloader):
            images = images.view(-1, 1, 30, 100, 100)
            labels = labels.to(device).float().unsqueeze(1)
            optimizer.zero_grad()
            with torch.set_grad_enabled(phase == "train"):

                # モデルの出力を計算する
                output = model(images.clone().to(device))
                # print(output)
                # print(labels)
                # 損失を計算する
                loss = criterion(output.to("cpu"), labels.to("cpu"))
                weighted_loss = torch.mean(loss)

                # パラメータの更新
                if phase == "train":
                    weighted_loss.backward()
                    optimizer.step()
                    train_loss_num += weighted_loss.item()
                else:
                    val_loss_num += weighted_loss.item()

        if phase == "train":
            train_loss_list.append(train_loss_num)
        else:
            val_loss_list.append(val_loss_num)
            
    # wandb.log({"train loss": train_loss_num, "validation loss": val_loss_num, "epoch":  epoch})
    if val_loss_num < best_val_loss:
        best_val_loss = val_loss_num
        # wandb.log({"best validation loss": best_val_loss, "epoch":  epoch})
    
    print("Epoch [{}/{}], Loss: {:.4f}".format(epoch + 1, num_epochs, val_loss_num))

    early_stopping(val_loss_num, model)
    if early_stopping.early_stop:
        print("Early_Stopping")
        break

#train_loss_path = args.savedir_path + "/loss_log" + f"/train_loss_{args.wandb_name}.npy"
#val_loss_path = args.savedir_path + "/loss_log" + f"/val_loss_{args.wandb_name}.npy"


#np.save(train_loss_path, train_loss_list)
#np.save(val_loss_path, val_loss_list)

print((time.time() - start) / 60)

100%|██████████████████████████████████████████████████████████████████████████████████████| 283/283 [00:03<00:00, 80.08it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 54/54 [00:00<00:00, 180.87it/s]


Epoch [1/10], Loss: 3.2361
Validation loss decreased (inf --> 3.236115).  Saving model ...


RuntimeError: File savedir cannot be opened.

In [38]:
# Cell 4: 学習ループの実行

NUM_EPOCHS = 10 # エポック数

print(f"学習開始 (エポック数: {NUM_EPOCHS})")

for epoch in range(NUM_EPOCHS):
    model.train() # モデルを訓練モードに設定
    running_loss = 0.0
    correct_predictions = 0
    total_predictions = 0

    # DataLoaderからデータを取得
    for i, (inputs, labels) in enumerate(train_dataloader):
        inputs = inputs.view(1, -1, 30, 100, 100)
        print(inputs.shape)
        inputs, labels = inputs.to(device), labels.to(device)
        
        # 勾配をゼロにリセット
        optimizer.zero_grad()

        # 順伝播
        try:
            outputs = model(inputs)
        except RuntimeError as e:
            # RuntimeError (主にFINAL_FLATTEN_SIZEの誤り) の検出
            print("\n!!! RuntimeError発生: FINAL_FLATTEN_SIZEを確認してください !!!")
            print(f"エラー詳細: {e}")
            raise e # エラーを発生させてNotebookの実行を停止

        # 損失の計算
        loss = criterion(outputs, labels)

        # 逆伝播と最適化
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        
        # 精度計算
        predicted = (outputs > 0.5).float()
        correct_predictions += (predicted == labels).sum().item()
        total_predictions += labels.size(0)

    # エポック終了後の統計情報
    epoch_loss = running_loss / NUM_SAMPLES
    epoch_accuracy = correct_predictions / total_predictions
    
    print(f"Epoch [{epoch+1}/{NUM_EPOCHS}], Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

print("学習完了。")

学習開始 (エポック数: 10)
torch.Size([1, 16, 30, 100, 100])

!!! RuntimeError発生: FINAL_FLATTEN_SIZEを確認してください !!!
エラー詳細: Given groups=1, weight of size [32, 1, 4, 4, 4], expected input[1, 16, 30, 100, 100] to have 1 channels, but got 16 channels instead


RuntimeError: Given groups=1, weight of size [32, 1, 4, 4, 4], expected input[1, 16, 30, 100, 100] to have 1 channels, but got 16 channels instead