In [1]:
from utils import UNet, SegmentationDataset
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn

# Binary-segmentation setup: loss on raw logits, data pipeline, UNet, and optimizer.
# NOTE(review): `criterion` and `optimizer` are re-created in the training cell,
# so the objects built here are superseded there.
criterion = nn.BCEWithLogitsLoss()
dataset = SegmentationDataset("SA1B_Meta_AI_Segmentation_Dataset/")
dataloader = DataLoader(dataset=dataset, batch_size=50, shuffle=True)
model = UNet(out_channels=1, in_channels=3)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)

In [2]:
# def train_unet(model, dataset, epochs=10, lr=1e-4):
#     dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
#     optimizer = optim.Adam(model.parameters(), lr=lr)
#     criterion = nn.BCELoss()  # For binary segmentation

#     model.train()
#     for epoch in range(epochs):
#         epoch_loss = 0.0
#         for imgs, masks in dataloader:
#             imgs, masks = imgs.cuda(), masks.cuda()  # Move to GPU if available

#             # Forward pass
#             preds = model(imgs)
#             loss = criterion(preds, masks.unsqueeze(1))  # Add channel dim to masks
#             epoch_loss += loss.item()

#             # Backward pass
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()

#         print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss / len(dataloader)}")

#     return model

In [None]:
# model.train()
# for epoch in range(2):
# 	epoch_loss = 0
# 	for idx, (image, mask) in enumerate(dataloader):
# 		outputs = model(image)
# 		loss = criterion(outputs, mask)
# 		optimizer.zero_grad()
# 		loss.backward()
# 		optimizer.step()
# 		epoch_loss += loss.item()
# 		# outputs.weights
# 		print(f'epoch: {epoch+1} | loss: {epoch_loss}')


In [None]:
# print(dataset)

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn as nn
from torch.cuda.amp import GradScaler, autocast

# Select the training device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the UNet built in the first cell onto the selected device.
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so the model emits raw logits.
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Multiply the learning rate by 0.1 every 5 scheduler steps (stepped once per epoch below).
# NOTE(review): with num_epochs = 2 in the training loop the decay never triggers — confirm intent.
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

# Gradient scaling is only needed for CUDA float16 autocast. Enabling it only on
# CUDA keeps CUDA behavior unchanged and avoids GradScaler's "CUDA is not
# available" warning (and self-disable) on CPU-only machines.
scaler = torch.amp.GradScaler(enabled=device.type == "cuda")

# Dice Coefficient Metric
def dice_coefficient(outputs, masks, threshold=0.5, eps=1e-7):
    """Return the Dice score between thresholded predictions and target masks.

    The score is computed over all elements of the inputs at once (not per
    sample), as 2 * |P ∩ M| / (|P| + |M|).

    Args:
        outputs: Prediction tensor, expected to hold probabilities (the training
            loop passes ``outputs.sigmoid()``). Values above ``threshold`` count
            as foreground.
        masks: Target tensor with the same shape as ``outputs``; presumably
            0/1 valued — confirm against the dataset.
        threshold: Cut-off applied to ``outputs`` to binarize them.
        eps: Small constant added to the denominator to avoid division by zero.

    Returns:
        The Dice score as a Python float. When both the binarized prediction
        and the target are empty, this is 1.0 (a perfect match).
    """
    preds = (outputs > threshold).float()
    intersection = (preds * masks).sum()
    union = preds.sum() + masks.sum()
    if union.item() == 0:
        # Nothing predicted and nothing to find: previously scored as 0 / eps = 0.
        return 1.0
    dice = 2.0 * intersection / (union + eps)
    return dice.item()

# Training Loop: one optimizer step per batch. Each epoch reports the mean batch
# loss and the mean batch Dice score.
model.train()
num_epochs = 2
n_batches = len(dataloader)
for epoch in range(num_epochs):
    epoch_loss = 0.0  # running sum of per-batch losses for this epoch
    epoch_dice = 0.0  # running sum of per-batch Dice scores for this epoch

    for images, masks in dataloader:
        # Move the batch to the training device. unsqueeze(1) inserts a size-1
        # axis at dim 1 (presumably the channel axis, matching out_channels=1 —
        # assumes masks arrive as (B, H, W); confirm), and .float() gives the
        # dtype BCEWithLogitsLoss expects for targets.
        images = images.to(device)
        masks = masks.to(device).unsqueeze(1).float()

        # Autocast must target the device the tensors live on; a hard-coded
        # "cpu" silently disables mixed precision when training on CUDA.
        with torch.amp.autocast(device_type=device.type):
            outputs = model(images)
            loss = criterion(outputs, masks)

        # Backward pass and optimizer step through the (possibly disabled) scaler.
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        # The metric needs no gradients; detach so no new autograd graph is built.
        epoch_dice += dice_coefficient(outputs.detach().sigmoid(), masks)

    # StepLR is stepped once per epoch.
    scheduler.step()

    avg_loss = epoch_loss / n_batches
    avg_dice = epoch_dice / n_batches
    print(f"Epoch: {epoch + 1}, Loss: {avg_loss:.4f}, Dice: {avg_dice:.4f}")

# Save the trained weights (state_dict only) and reload them for inference.
MODEL_PATH = 'segmentation_model.pth'

torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved to '{MODEL_PATH}'")

# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. map_location places the tensors on the active device.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
model.eval()
print("Model loaded and set to evaluation mode")