In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from dataset import LiverTumorDataset
from unet import UNet

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [7]:
def dice_loss(pred, target, smooth=1e-6):
    pred = pred.view(-1)
    target = target.view(-1)
    intersection = (pred * target).sum()
    return 1 - (2. * intersection + smooth) / (pred.sum() + target.sum() + smooth)

bce = nn.BCELoss()

In [11]:
model.load_state_dict(torch.load("unet_liver_stage_6.pth"))
print("Loaded checkpoint from Stage 6")

Loaded checkpoint from Stage 6


In [15]:
chunk_files = [f"chunk_{i}.txt" for i in range(6, 10)]  # chunk_6 to chunk_9
epochs_per_chunk = 2
os.makedirs("checkpoints", exist_ok=True)

# Training loop
for chunk_id, chunk_file in enumerate(chunk_files, start=6):
    print(f"\nResuming training on {chunk_file} (Chunk {chunk_id+1}/10)")

    train_dataset = LiverTumorDataset(chunk_file, image_size=(256, 256))
    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)

    for epoch in range(epochs_per_chunk):
        model.train()
        total_loss = 0.0

        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)

            outputs = model(images)
            loss = 0.5 * bce(outputs, masks) + 0.5 * dice_loss(outputs, masks)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Chunk {chunk_id+1} | Epoch {epoch+1}/{epochs_per_chunk} | Loss: {avg_loss:.4f}")
    torch.save(model.state_dict(), f"unet_liver_stage_{chunk_id+1}.pth")

# Final save
torch.save(model.state_dict(), "unet_liver.pth")
print("Final model saved as unet_liver.pth")


Resuming training on chunk_6.txt (Chunk 7/10)
Chunk 7 | Epoch 1/2 | Loss: 0.1760
Chunk 7 | Epoch 2/2 | Loss: 0.1747

Resuming training on chunk_7.txt (Chunk 8/10)
Chunk 8 | Epoch 1/2 | Loss: 0.1781
Chunk 8 | Epoch 2/2 | Loss: 0.1794

Resuming training on chunk_8.txt (Chunk 9/10)


ValueError: num_samples should be a positive integer value, but got num_samples=0

In [17]:
model.load_state_dict(torch.load("unet_liver_stage_8.pth"))
torch.save(model.state_dict(), "unet_liver_final.pth")
print("Saved model as unet_liver_final.pth")

Saved model as unet_liver_final.pth
