In [1]:
import torch
import torch.nn as nn
import torch.nn.init as init
from torch.utils.data import TensorDataset, DataLoader
from torchviz import make_dot
from torchsummary import summary

import numpy as np
import matplotlib.pyplot as plt
import japanize_matplotlib
import re

from sklearn import svm
from sklearn.model_selection import train_test_split

import glob
import os
import time
import sys
import wandb
import json
from tqdm.notebook import tqdm

sys.path.append("/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/Binary_classification/training/models")
from Binary_classification import Binary_classification

In [2]:
class DataSet():
    """Map-style dataset that pairs each sample with its label.

    Provides the ``__len__`` / ``__getitem__`` protocol, so instances can be
    handed to ``torch.utils.data.DataLoader``.

    Args:
        data: Sized, indexable collection of samples.
        label: Sized, indexable collection of labels; ``label[i]`` is the
            label of ``data[i]``.

    Raises:
        ValueError: If ``data`` and ``label`` do not have the same length.
    """

    def __init__(self, data, label):
        # __len__ only reflects `data`, so a length mismatch would otherwise
        # surface late (IndexError mid-epoch) or not at all (extra labels
        # silently ignored). Fail fast instead.
        if len(data) != len(label):
            raise ValueError(
                f"data and label must have the same length, "
                f"got {len(data)} and {len(label)}"
            )
        self.label = label
        self.data = data

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(sample, label)`` pair stored at ``index``."""
        return self.data[index], self.label[index]

In [3]:
# Location of the saved parameter file (later passed to load_state_dict).
weight_path = "/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/Binary_classification/training/save_dir/model_parameter.pth"
# Remap every stored tensor onto the CPU while loading, so the file can be
# read regardless of the device it was saved from.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
weight_para = torch.load(weight_path, map_location="cpu")

In [4]:
# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f"使用デバイス: {device}")

使用デバイス: cuda:0


In [5]:
# Instantiate the classifier: latent size 100, single 30 x 100 x 100 input volume.
model = Binary_classification(
    latent=100,
    input_depth=30,
    input_height=100,
    input_width=100,
)
# Restore the parameters loaded from the checkpoint file.
model.load_state_dict(weight_para)
# Move to the selected device and switch to inference mode; being the last
# expression of the cell, this also displays the module structure.
model.to(device).eval()

Binary_classification(
  (features): Sequential(
    (0): Conv3d(1, 16, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv3d(16, 32, kernel_size=(4, 4, 4), stride=(2, 2, 2), padding=(1, 1, 1))
    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
    (8): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv3d(32, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1))
    (11): ReLU(inplace=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=1600, out_features=100, bias=True)
    (2): ReLU(inplace=True)
    (

In [6]:
# Print a per-layer table of output shapes and parameter counts for one
# single-channel 30 x 100 x 100 input.
summary(
    model.to(device),
    input_size=(1, 30, 100, 100),
)

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv3d-1       [-1, 16, 15, 50, 50]           1,040
       BatchNorm3d-2       [-1, 16, 15, 50, 50]              32
              ReLU-3       [-1, 16, 15, 50, 50]               0
            Conv3d-4        [-1, 32, 7, 25, 25]          32,800
       BatchNorm3d-5        [-1, 32, 7, 25, 25]              64
              ReLU-6        [-1, 32, 7, 25, 25]               0
         MaxPool3d-7        [-1, 32, 3, 12, 12]               0
            Conv3d-8        [-1, 32, 3, 12, 12]          27,680
       BatchNorm3d-9        [-1, 32, 3, 12, 12]              64
             ReLU-10        [-1, 32, 3, 12, 12]               0
           Conv3d-11        [-1, 16, 1, 10, 10]          13,840
             ReLU-12        [-1, 16, 1, 10, 10]               0
          Flatten-13                 [-1, 1600]               0
           Linear-14                  [

In [7]:
# Same device selection as in cell In [4]: first CUDA device if present, else CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [8]:
# Bubble samples are the positive class (label 1).
bubble_data = np.load("/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/Binary_classification/data/processed_data/slide_bubble.npy")
bubble_label = [1 for _ in range(len(bubble_data))]

# Non-bubble samples are the negative class (label 0).
removal_data = np.load("/home/cygnus/fujimoto/Cygnus-X_Molecular_Cloud_Analysis/Binary_classification/data/processed_data/all_data_after_bubble_removal.npy")
removal_label = [0 for _ in range(len(removal_data))]

In [9]:
# Stack positives first, then negatives, along the sample axis; the labels
# are joined in the same order so index i of `label` matches index i of `data`.
data = np.concatenate([bubble_data, removal_data], axis=0)
label = np.concatenate([bubble_label, removal_label], axis=0)

In [10]:
# First split: keep 80% for training and hold out 20%, stratified on the label.
train_data, val_data, train_labels, val_labels = train_test_split(
    data,
    label,
    test_size=0.2,
    random_state=42,
    stratify=label,
)
# Second split of the held-out part: 75% stays validation, 25% becomes test.
val_data, test_data, val_labels, test_labels = train_test_split(
    val_data,
    val_labels,
    test_size=0.25,
    random_state=42,
    stratify=val_labels,
)

# Wrap every split as a dataset.
train_dataset = DataSet(train_data, train_labels)
val_dataset = DataSet(val_data, val_labels)
test_dataset = DataSet(test_data, test_labels)

# Batches of 16 everywhere; only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=16, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=False)

dataloader_dic = dict(train=train_dataloader, val=val_dataloader)

In [52]:
# Evaluate the restored classifier on the test split and derive accuracy and recall.
model.to(device)
# Make sure BatchNorm layers use their running statistics even if an earlier
# cell left the model in training mode.
model.eval()

test_correct_preds = 0
test_total_samples = 0
test_true_positives = 0
test_actual_positives = 0

# Inference only: disable autograd so no computation graph is built per batch.
with torch.no_grad():
    for images, labels in tqdm(test_dataloader):
        # Insert the single channel axis while keeping the batch size:
        # (batch, 30, 100, 100) -> (batch, 1, 30, 100, 100).
        images = images.view(-1, 1, 30, 100, 100).float().to(device)
        labels = labels.to(device).float()

        # Flatten to shape (batch,). A bare squeeze() would also drop the batch
        # axis when the last batch holds a single sample.
        # Assumes the model emits one score per sample — TODO confirm.
        output = model(images).reshape(-1)
        predicted = (output > 0.5).float()

        # 1. Accuracy: count predictions that match the label.
        test_correct_preds += (predicted == labels).sum().item()
        test_total_samples += labels.size(0)

        # 2. Recall: true positives over all actual positives.
        test_true_positives += ((predicted == 1) & (labels == 1)).sum().item()
        test_actual_positives += (labels == 1).sum().item()

# Guard against empty denominators (no samples / no positive samples).
test_accuracy = test_correct_preds / test_total_samples if test_total_samples > 0 else 0.0
test_recall = test_true_positives / test_actual_positives if test_actual_positives > 0 else 0.0

  0%|          | 0/18 [00:00<?, ?it/s]

In [53]:
# These values come from the test split, so label them "Test" rather than "Val".
print("Test Accuracy: {:.4f}, Test Recall: {:.4f}".format(test_accuracy, test_recall))

Val Accuracy: 0.9965, Val Recall: 0.9286


In [51]:
# Raw counters behind the test accuracy and recall. Uses the test_* names
# defined by the evaluation loop, with one distinct label per line.
print(f"test actual positives: {test_actual_positives}")
print(f"test true positives: {test_true_positives}")
print(f"test total samples: {test_total_samples}")
print(f"test correct predictions: {test_correct_preds}")

14
13
283
282


In [43]:
# Debug aid: element-wise match between predictions and labels for the batch
# left over from the most recent iteration of the evaluation loop.
print(predicted.eq(labels))

tensor([True, True, True, True, True, True, True, True, True, True, True],
       device='cuda:0')
