# Training NeuroUNET
Test training the NeuroUNET model on a small subset of the data.

**TODO:** try an even smaller subset of the data (5 images split into tiles).

In [None]:
from torch.utils.data import DataLoader
from pathlib import Path
import tqdm as notebook_tqdm

# custom
from dataset import EMDataset
from image_preprocessing import PreprocessImage
from model import train_model, NeuroUNET

In [2]:
# Collect image paths.
# Path.glob() yields entries in arbitrary, filesystem-dependent order, so the
# paths are sorted to keep the train/val split identical across runs and machines.
image_dir = Path("./images/processed_zstack")
img_list = sorted(image_dir.glob("*.tif"))

# Split into train/val sets: the first 80% of the sorted files go to training,
# the remainder to validation.
# NOTE: despite the `_idx` suffix, these hold file paths rather than integer indices.
train_split = int(0.8 * len(img_list))
train_idx = img_list[:train_split]
val_idx = img_list[train_split:]

In [3]:
# Build the tile datasets. Both splits use the same tile size and preprocessing
# callable; they differ only in stride and in whether augmentation is requested.
shared_dataset_kwargs = dict(tile_size=256, preprocess_fn=PreprocessImage)

train_dataset = EMDataset(
    image_paths=train_idx,
    stride=128,  # half the tile size -> 50% overlap for training
    augment=True,
    **shared_dataset_kwargs,
)

val_dataset = EMDataset(
    image_paths=val_idx,
    stride=256,  # equal to the tile size -> no overlap for validation
    augment=False,
    **shared_dataset_kwargs,
)


In [4]:
# Create dataloaders.
# Settings common to both loaders are declared once.
loader_kwargs = dict(batch_size=16, num_workers=4, pin_memory=True)

# Training batches are shuffled; validation keeps the dataset order.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [5]:
# Report how many tiles each split contains (one line per split).
print(
    f"Training dataset: {len(train_dataset)} tiles",
    f"Validation dataset: {len(val_dataset)} tiles",
    sep="\n",
)

Training dataset: 27166 tiles
Validation dataset: 1944 tiles


In [6]:
# Instantiate the network with 2 input channels and 2 output channels.
# NOTE(review): the recorded output of the training cell shows batches of shape
# [16, 4, 256, 256], i.e. 4 channels. Confirm how those map onto in_channels=2
# (for example, whether inputs and targets are stacked in one tensor) - TODO confirm.
model = NeuroUNET(in_channels=2, out_channels=2)
# Sum the element counts of every tensor returned by model.parameters().
print(f'Total parameters: {sum(p.numel() for p in model.parameters())}')

Total parameters: 7702466


In [None]:
# Peek at a single batch to confirm its shape, then launch training once.
#
# Training is deliberately kept OUTSIDE any loop over train_loader: an active
# DataLoader iterator with num_workers > 0 keeps its own worker processes and
# prefetched batches alive, which would sit in memory for the whole training
# run on top of whatever train_model allocates itself. Memory was already the
# bottleneck here, so the peek iterator is released before training starts.
batch_iter = iter(train_loader)
batch = next(batch_iter, None)
if batch is None:
    # An empty loader would otherwise skip training silently.
    raise RuntimeError(
        "train_loader yielded no batches; check the image directory and the train/val split"
    )
print(f"Batch 0: shape {batch.shape}")

# Dropping the last reference lets the iterator shut down its workers.
del batch_iter

train_model(model, train_loader, val_loader, device='cpu')

Batch 0: shape torch.Size([16, 4, 256, 256])


  from .autonotebook import tqdm as notebook_tqdm


: 