In [1]:

# from pathlib import Path
# import zarr

# def find_all_zarr_files():
#     """각 실험 디렉토리의 모든 zarr 파일 찾기"""
#     base_path = Path('./kaggle/input/czii-cryo-et-object-identification/train')
#     experiments = {}
    
#     for exp_dir in base_path.glob("**/VoxelSpacing10.000"):
#         exp_name = exp_dir.parent.name  # 실험 이름 (예: TS_5_4)
#         zarr_files = list(exp_dir.glob("*.zarr"))
        
#         if zarr_files:
#             experiments[exp_name] = {
#                 'path': str(exp_dir),
#                 'files': [f.name for f in zarr_files]
#             }
    
#     return experiments

# def test_all_zarr_files():
#     """모든 실험의 zarr 파일 정보 출력"""
#     experiments = find_all_zarr_files()
    
#     for exp_name, info in experiments.items():
#         print(f"\n실험: {exp_name}")
#         print(f"경로: {info['path']}")
#         print("zarr 파일들:")
#         for zarr_file in info['files']:
#             zarr_path = Path(info['path']) / zarr_file
#             try:
#                 root = zarr.open(str(zarr_path), mode='r')
#                 sample = root[0]
#                 print(f"- {zarr_file}: shape={sample.shape}")
#             except Exception as e:
#                 print(f"- {zarr_file}: 오류 발생 ({str(e)})")

# if __name__ == "__main__":
#     test_all_zarr_files()

In [2]:
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Dict, List, Tuple
from monai.transforms import (
    Compose,
    EnsureChannelFirstd,
    NormalizeIntensityd
)

class CryoET_2_5D_Dataset(Dataset):
    def __init__(self, 
                 data_dir: str, 
                 mode: str = 'train',
                 slice_depth: int = 3,
                 crop_size: Tuple[int, int] = (96, 96),
                 transform = None):
        """
        Args:
            data_dir: 데이터셋 기본 경로
            mode: 'train' 또는 'val'
            slice_depth: 중심 기준 슬라이스 수 (항상 홀수)
            crop_size: (height, width) crop 크기
            transform: MONAI transform
        """
        self.data_dir = Path(data_dir) / mode
        self.slice_depth = slice_depth
        self.crop_size = crop_size
        self.transform = transform
        self.pad_size = slice_depth // 2
        
        # 이미지와 라벨 파일 리스트
        self.image_files = sorted(list((self.data_dir / 'images').glob('*.npy')))
        self.label_files = sorted(list((self.data_dir / 'labels').glob('*.npy')))
        
        assert len(self.image_files) == len(self.label_files), "이미지와 라벨 수가 다릅니다"
        
    def __len__(self):
        return len(self.image_files)
        
    def __getitem__(self, idx) -> Dict:
        # 이미지와 라벨 로드
        image = np.load(self.image_files[idx])  # (D, H, W)
        label = np.load(self.label_files[idx])  # (D, H, W)
        
        # 원본 shape 저장
        original_shape = image.shape
        
        # 중심 슬라이스 선택
        center_idx = image.shape[0] // 2
        start_idx = center_idx - self.pad_size
        end_idx = center_idx + self.pad_size + 1
        
        # 패딩 처리
        if start_idx < 0 or end_idx > image.shape[0]:
            padded_image = np.pad(
                image,
                ((self.pad_size, self.pad_size), (0, 0), (0, 0)),
                mode='edge'
            )
            start_idx += self.pad_size
            end_idx += self.pad_size
            slices = padded_image[start_idx:end_idx]
        else:
            slices = image[start_idx:end_idx]
        
        # 중심 슬라이스의 라벨만 사용
        center_label = label[center_idx]
        
        # 중앙 crop 좌표 계산
        start_h = (image.shape[1] - self.crop_size[0]) // 2
        start_w = (image.shape[2] - self.crop_size[1]) // 2
        
        # crop 적용
        slices_cropped = slices[:, 
                              start_h:start_h + self.crop_size[0], 
                              start_w:start_w + self.crop_size[1]]
        label_cropped = center_label[start_h:start_h + self.crop_size[0],
                                   start_w:start_w + self.crop_size[1]]
        
        data_dict = {
            'image': np.array(slices_cropped)[None, ...],  # (1, D, H, W)
            'label': np.array(label_cropped)[None, ...],   # (1, H, W)
            'crop_info': {
                'z_start': start_idx,
                'z_center': center_idx,
                'y_start': start_h,
                'x_start': start_w,
                'crop_size': self.crop_size,
                'original_shape': original_shape,
                'file_name': self.image_files[idx].stem
            }
        }
        
        if self.transform:
            data_dict = self.transform(data_dict)
            
        return data_dict

# Transform 설정
train_transforms = Compose([
    EnsureChannelFirstd(
        keys=['image', 'label'],
        channel_dim=0,
        allow_missing_keys=True
    ),
    NormalizeIntensityd(
        keys=['image'],
        allow_missing_keys=True
    )
])

def create_dataloaders(data_dir: str, batch_size: int = 8):
    train_ds = CryoET_2_5D_Dataset(
        data_dir=data_dir,
        mode='train',
        transform=train_transforms,
        slice_depth=3,
        crop_size=(96, 96)
    )
    
    val_ds = CryoET_2_5D_Dataset(
        data_dir=data_dir,
        mode='val',
        transform=None,
        slice_depth=3,
        crop_size=(96, 96)
    )
    
    train_loader = DataLoader(
        train_ds, 
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        num_workers=0,
        pin_memory=torch.cuda.is_available()
    )
    
    return train_loader, val_loader

# 테스트
if __name__ == "__main__":
    data_dir = "./datasets"
    train_loader, val_loader = create_dataloaders(data_dir)
    
    # 데이터 확인
    batch = next(iter(train_loader))
    print("이미지 shape:", batch['image'].shape)
    print("라벨 shape:", batch['label'].shape)
    print("Crop 정보:", batch['crop_info'])

  from .autonotebook import tqdm as notebook_tqdm


이미지 shape: torch.Size([8, 1, 3, 96, 96])
라벨 shape: torch.Size([8, 1, 96, 96])
Crop 정보: {'z_start': tensor([91, 91, 91, 91, 91, 91, 91, 91]), 'z_center': tensor([92, 92, 92, 92, 92, 92, 92, 92]), 'y_start': tensor([267, 267, 267, 267, 267, 267, 267, 267]), 'x_start': tensor([267, 267, 267, 267, 267, 267, 267, 267]), 'crop_size': [tensor([96, 96, 96, 96, 96, 96, 96, 96]), tensor([96, 96, 96, 96, 96, 96, 96, 96])], 'original_shape': [tensor([184, 184, 184, 184, 184, 184, 184, 184]), tensor([630, 630, 630, 630, 630, 630, 630, 630]), tensor([630, 630, 630, 630, 630, 630, 630, 630])], 'file_name': ['isonetcorrected_TS_69_2_image', 'denoised_TS_69_2_image', 'isonetcorrected_TS_5_4_image', 'isonetcorrected_TS_6_4_image', 'ctfdeconvolved_TS_73_6_image', 'denoised_TS_6_4_image', 'wbp_TS_86_3_image', 'wbp_TS_73_6_image']}


In [None]:
from CSANet.CSANet.networks.vit_seg_modeling import VisionTransformer as ViT_seg
from CSANet.CSANet.networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
import torch.optim as optim
from tqdm import tqdm
import numpy as np
import torch
from monai.losses import TverskyLoss

# Model Configuration
vit_name = 'R50-ViT-B_16'
config_vit = CONFIGS_ViT_seg[vit_name]
config_vit.n_classes = 7  # Your number of classes
config_vit.n_skip = 3
img_size = 96  # Match your patch size
vit_patches_size = 16

# Initialize model
config_vit.patches.grid = (int(img_size / vit_patches_size), int(img_size / vit_patches_size))
model = ViT_seg(config_vit, img_size=img_size, num_classes=config_vit.n_classes)
model.cuda()

# Load pretrained weights
model.load_from(weights=np.load(config_vit.real_pretrained_path, allow_pickle=True))

# Training setup
num_epochs = 40
optimizer = optim.AdamW(model.parameters(), lr=0.001)
# TverskyLoss 설정
criterion = TverskyLoss(
    alpha=0.3,  # FP에 대한 가중치
    beta=0.7,   # FN에 대한 가중치
    include_background=True,
    softmax=True
)


# Training loop 수정
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0
    with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
        for batch_data in pbar:
            images = batch_data['image'].cuda()
            labels = batch_data['label'].cuda()
            
            # 차원 분리 및 reshape
            # print(images.shape)
            if images.shape[2] == 96:
                images = images.permute(0, 1, 4, 2, 3)
                labels = labels.permute(0, 1, 4, 2, 3)
                # print(images.shape)
            # prev_image, image, next_image = torch.split(images, 1, dim=2)
            center = images.shape[2] // 2
            prev_image = images[:, :, 0:center, :, :]
            image = images[:, :, center:center+1, :, :]
            next_image = images[:, :, center+1:, :, :]
            
            # 필요한 차원 형태로 변환
            prev_image = prev_image.squeeze(2)  # [B, C, H, W]
            image = image.squeeze(2)
            next_image = next_image.squeeze(2)
            
            # 라벨 처리
            #labels = labels[:, :, center:center+1, :, :]
            #labels = labels.squeeze(2)  # [B, 1, H, W]
            labels = labels.squeeze(1)  # [B, H, W]

            optimizer.zero_grad()
            outputs = model(prev_image, image, next_image)

            # 라벨을 long으로 변환하고 one-hot encoding
            labels = labels.long()
            labels_onehot = torch.nn.functional.one_hot(labels, num_classes=config_vit.n_classes)  # [B, H, W, C]
            # 차원 순서 변경: [B, H, W, C] -> [B, C, H, W]
            labels_onehot = labels_onehot.permute(0, 3, 1, 2).float()
            
            loss = criterion(outputs, labels_onehot)
            loss.backward()
            optimizer.step()
            
            epoch_loss += loss.item()
            pbar.set_postfix({'loss': loss.item()})
    
    # Training loop는 그대로 유지

    # Training loop 내부의 validation과 saving 코드 수정
    # Validation
    if (epoch + 1) % 1 == 0:  # Validate every 5 epochs
        model.eval()
        val_loss = 0
        with torch.no_grad():
            val_loop = tqdm(val_loader, desc=f'Validation Epoch {epoch+1}', leave=False)
            for val_data in val_loop:
                val_images = val_data['image'].cuda()
                val_labels = val_data['label'].cuda()
                
                # Training과 동일한 방식으로 데이터 처리
                val_prev, val_curr, val_next = torch.split(val_images, 1, dim=2)
                val_prev = val_prev.squeeze(2)
                val_curr = val_curr.squeeze(2)
                val_next = val_next.squeeze(2)
                
                val_labels = val_labels.squeeze(1)
                val_labels = val_labels.long()
                
                val_outputs = model(val_prev, val_curr, val_next)
                
                val_labels_onehot = torch.nn.functional.one_hot(val_labels, num_classes=config_vit.n_classes)
                val_labels_onehot = val_labels_onehot.permute(0, 3, 1, 2).float()
                
                current_loss = criterion(val_outputs, val_labels_onehot).item()
                val_loss += current_loss
                val_loop.set_postfix({'val_loss': current_loss})
        ''
        print(f'Validation Loss: {val_loss/len(val_loader):.4f}')
        
    # 체크포인트 저장
    if (epoch + 1) % 10 == 0:  # Save every 10 epochs
        # 체크포인트 디렉토리 생성
        checkpoint_dir = Path("./model_checkpoint")
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        
        with tqdm(total=1, desc=f'Saving checkpoint for epoch {epoch+1}', leave=False) as pbar:
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': epoch_loss,
            }, checkpoint_dir / f'checkpoint_epoch_{epoch+1}.pth')
            pbar.update(1)


load_pretrained: grid-size from 14 to 6


Epoch 1/40: 100%|██████████| 3/3 [00:05<00:00,  1.97s/it, loss=0.966]
                                                                                 

Validation Loss: 0.9633


Epoch 2/40: 100%|██████████| 3/3 [00:04<00:00,  1.53s/it, loss=0.953]
                                                                                 

Validation Loss: 0.9656


Epoch 3/40: 100%|██████████| 3/3 [00:04<00:00,  1.53s/it, loss=0.935]
                                                                                 

Validation Loss: 0.9656


Epoch 4/40: 100%|██████████| 3/3 [00:04<00:00,  1.51s/it, loss=0.919]
                                                                                 

Validation Loss: 0.9677


Epoch 5/40: 100%|██████████| 3/3 [00:04<00:00,  1.51s/it, loss=0.903]
                                                                                 

Validation Loss: 0.9689


Epoch 6/40: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it, loss=0.885]
                                                                                

Validation Loss: 0.9505


Epoch 7/40: 100%|██████████| 3/3 [00:04<00:00,  1.57s/it, loss=0.875]
                                                                                 

Validation Loss: 0.9416


Epoch 8/40: 100%|██████████| 3/3 [00:04<00:00,  1.62s/it, loss=0.843]
                                                                                 

Validation Loss: 0.9027


Epoch 9/40: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it, loss=0.856]
                                                                                 

Validation Loss: 0.8631


Epoch 10/40: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it, loss=0.831]
                                                                                  

Validation Loss: 0.8464


Epoch 11/40: 100%|██████████| 3/3 [00:04<00:00,  1.58s/it, loss=0.822]       
                                                                                  

Validation Loss: 0.8687


Epoch 12/40: 100%|██████████| 3/3 [00:04<00:00,  1.53s/it, loss=0.814]
                                                                                  

Validation Loss: 0.8617


Epoch 13/40: 100%|██████████| 3/3 [00:04<00:00,  1.51s/it, loss=0.825]
                                                                                  

Validation Loss: 0.8263


Epoch 14/40: 100%|██████████| 3/3 [00:04<00:00,  1.53s/it, loss=0.779]
                                                                                  

Validation Loss: 0.8430


Epoch 15/40: 100%|██████████| 3/3 [00:04<00:00,  1.59s/it, loss=0.802]
                                                                                  

Validation Loss: 0.8139


Epoch 16/40: 100%|██████████| 3/3 [00:04<00:00,  1.58s/it, loss=0.814]
                                                                                  

Validation Loss: 0.8506


Epoch 17/40: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it, loss=0.781]
                                                                                  

Validation Loss: 0.8457


Epoch 18/40: 100%|██████████| 3/3 [00:04<00:00,  1.56s/it, loss=0.817]
                                                                                  

Validation Loss: 0.8094


Epoch 19/40: 100%|██████████| 3/3 [00:04<00:00,  1.55s/it, loss=0.839]
                                                                                  

Validation Loss: 0.8847


Epoch 20/40: 100%|██████████| 3/3 [00:04<00:00,  1.58s/it, loss=0.84] 
                                                                                  

Validation Loss: 0.8834


Epoch 21/40: 100%|██████████| 3/3 [00:05<00:00,  1.67s/it, loss=0.822]       
                                                                                  

Validation Loss: 0.8235


Epoch 22/40: 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, loss=0.805]
                                                                                  

Validation Loss: 0.8426


Epoch 23/40: 100%|██████████| 3/3 [00:05<00:00,  1.74s/it, loss=0.807]
                                                                                  

Validation Loss: 0.8722


Epoch 24/40: 100%|██████████| 3/3 [00:05<00:00,  1.70s/it, loss=0.789]
                                                                                 

Validation Loss: 0.8502


Epoch 25/40: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, loss=0.818]
                                                                                  

Validation Loss: 0.8025


Epoch 26/40: 100%|██████████| 3/3 [00:04<00:00,  1.59s/it, loss=0.817]
                                                                                  

Validation Loss: 0.8007


Epoch 27/40: 100%|██████████| 3/3 [00:04<00:00,  1.66s/it, loss=0.801]
                                                                                  

Validation Loss: 0.8170


Epoch 28/40: 100%|██████████| 3/3 [00:04<00:00,  1.60s/it, loss=0.801]
                                                                                  

Validation Loss: 0.8146


Epoch 29/40: 100%|██████████| 3/3 [00:04<00:00,  1.66s/it, loss=0.791]
                                                                                 

Validation Loss: 0.8205


Epoch 30/40: 100%|██████████| 3/3 [00:04<00:00,  1.64s/it, loss=0.8]  
                                                                                  

Validation Loss: 0.8261


Epoch 31/40: 100%|██████████| 3/3 [00:04<00:00,  1.61s/it, loss=0.777]       
                                                                                  

Validation Loss: 0.8251


Epoch 32/40: 100%|██████████| 3/3 [00:05<00:00,  1.69s/it, loss=0.808]
                                                                                  

Validation Loss: 0.8228


Epoch 33/40: 100%|██████████| 3/3 [00:04<00:00,  1.59s/it, loss=0.792]
                                                                                  

Validation Loss: 0.8252


Epoch 34/40: 100%|██████████| 3/3 [00:04<00:00,  1.56s/it, loss=0.808]
                                                                                  

Validation Loss: 0.8273


Epoch 35/40: 100%|██████████| 3/3 [00:05<00:00,  1.68s/it, loss=0.85] 
                                                                                  

Validation Loss: 0.8243


Epoch 36/40: 100%|██████████| 3/3 [00:05<00:00,  1.68s/it, loss=0.79] 
                                                                                  

Validation Loss: 0.8266


Epoch 37/40: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, loss=0.778]
                                                                                  

Validation Loss: 0.8249


Epoch 38/40: 100%|██████████| 3/3 [00:04<00:00,  1.65s/it, loss=0.807]
                                                                                  

Validation Loss: 0.8193


Epoch 39/40: 100%|██████████| 3/3 [00:05<00:00,  1.68s/it, loss=0.797]
                                                                                  

Validation Loss: 0.8331


Epoch 40/40: 100%|██████████| 3/3 [00:04<00:00,  1.63s/it, loss=0.814]
                                                                                  

Validation Loss: 0.8413


                                                                             