In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import LiverTumorDataset
from unet import UNet

In [3]:
def dice_loss(pred, target, smooth=1e-6):
    """Soft Dice loss between a prediction tensor and a target mask.

    Both tensors are flattened and the Dice coefficient is computed over all
    of their elements at once, i.e. pooled across the whole batch rather than
    averaged per sample.

    Args:
        pred: predicted values, presumably sigmoid probabilities in [0, 1]
            (TODO confirm UNet's output activation).
        target: ground-truth mask with the same number of elements as ``pred``.
        smooth: small constant added to numerator and denominator so the
            ratio stays defined when ``pred`` and ``target`` both sum to 0.

    Returns:
        Scalar tensor ``1 - dice``: 0 when ``pred`` exactly matches a binary
        ``target``, approaching 1 when they do not overlap at all.
    """
    # reshape (unlike view) also accepts non-contiguous tensors, e.g. after
    # transpose/permute or slicing; it only copies when a view is impossible.
    pred = pred.reshape(-1)
    target = target.reshape(-1)
    intersection = (pred * target).sum()
    return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

# Reproducibility: fix torch's global RNG before the model's weights are
# initialised. With no explicit generator, DataLoader(shuffle=True) in the
# training cell also seeds its sampler from this RNG.
SEED = 42
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
# nn.BCELoss expects probabilities in [0, 1]. This assumes UNet ends in a
# sigmoid (TODO confirm). If it returns raw logits, switch to
# nn.BCEWithLogitsLoss and apply torch.sigmoid before dice_loss.
bce = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [5]:
# Show the network's module tree: every layer with its hyper-parameters.
print(str(model))

UNet(
  (enc1): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc2): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (enc3): DoubleConv(
    (conv): Sequential(
      (0): Conv2d(128, 256, kernel_size=

In [None]:
# Train incrementally over the dataset chunks, checkpointing after each one.
epochs_per_chunk = 2
chunk_files = [f"chunk_{i}.txt" for i in range(10)]
# Derive the total from the list so the "Chunk i/N" progress can't drift.
num_chunks = len(chunk_files)

for chunk_id, chunk_file in enumerate(chunk_files):
    print(f"\nTraining on {chunk_file} (Chunk {chunk_id+1}/{num_chunks})")

    train_dataset = LiverTumorDataset(chunk_file, image_size=(256, 256))
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
    # Batch count is the same for every epoch of this chunk; compute it once.
    num_batches = len(train_loader)

    for epoch in range(epochs_per_chunk):
        model.train()
        total_loss = 0.0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            # Equal-weight mix of per-pixel BCE and overlap-based Dice loss.
            loss = 0.5 * bce(outputs, masks) + 0.5 * dice_loss(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # An empty chunk yields no batches; report NaN instead of raising
        # ZeroDivisionError and aborting the whole multi-chunk run.
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        print(f"Chunk {chunk_id+1} | Epoch {epoch+1}/{epochs_per_chunk} | Loss: {avg_loss:.4f}")

    # Saving model after each chunk
    stage_path = f"unet_liver_stage_{chunk_id+1}.pth"
    torch.save(model.state_dict(), stage_path)
    print(f"💾 Saved: {stage_path}")

# Final model
torch.save(model.state_dict(), "unet_liver.pth")
print("\n Final model saved as unet_liver.pth")


Training on chunk_0.txt (Chunk 1/10)
Chunk 1 | Epoch 1/2 | Loss: 0.4742
Chunk 1 | Epoch 2/2 | Loss: 0.2952
💾 Saved: unet_liver_stage_1.pth

Training on chunk_1.txt (Chunk 2/10)
Chunk 2 | Epoch 1/2 | Loss: 0.2342
Chunk 2 | Epoch 2/2 | Loss: 0.2145
💾 Saved: unet_liver_stage_2.pth

Training on chunk_2.txt (Chunk 3/10)
Chunk 3 | Epoch 1/2 | Loss: 0.2108
Chunk 3 | Epoch 2/2 | Loss: 0.1993
💾 Saved: unet_liver_stage_3.pth

Training on chunk_3.txt (Chunk 4/10)
Chunk 4 | Epoch 1/2 | Loss: 0.1914
Chunk 4 | Epoch 2/2 | Loss: 0.1971
💾 Saved: unet_liver_stage_4.pth

Training on chunk_4.txt (Chunk 5/10)
Chunk 5 | Epoch 1/2 | Loss: 0.1979
Chunk 5 | Epoch 2/2 | Loss: 0.1900
💾 Saved: unet_liver_stage_5.pth

Training on chunk_5.txt (Chunk 6/10)
Chunk 6 | Epoch 1/2 | Loss: 0.1837
Chunk 6 | Epoch 2/2 | Loss: 0.1851
💾 Saved: unet_liver_stage_6.pth

Training on chunk_6.txt (Chunk 7/10)
