In [3]:
import os
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from UNet_3D import UNet3D
from Dataset import MRIDataset
from DiceLoss import DiceLoss

In [4]:
# --- Training configuration ---
num_epochs = 50        # upper bound of the epoch range used by the training loop
batch_size = 1         # samples per DataLoader batch
learning_rate = 0.001  # learning rate handed to the Adam optimizer
num_classes = 36       # As per the dataset
in_channels = 1        # forwarded to UNet3D as in_channels

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

# --- Filesystem locations ---
# NOTE(review): these are machine-specific absolute paths; adjust before running elsewhere.
root_dir = '/home/salahuddin/cornell/ADSP/Project/MRIdata/ForClass'
checkpoint_dir = '/home/salahuddin/cornell/ADSP/Project/Unet-3D-2D/model_checkpoint'
checkpoint_path = os.path.join(checkpoint_dir, 'model_checkpoint.pth')
# Create the checkpoint directory up front so later saves cannot fail on a missing folder.
os.makedirs(checkpoint_dir, exist_ok=True)

cpu


In [5]:


# Build both datasets from the same root directory (training first, then validation).
train_dataset = MRIDataset(mode='train', root_dir=root_dir)
val_dataset = MRIDataset(mode='val', root_dir=root_dir)

# Wrap each dataset in a loader; only the training loader shuffles its samples.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)


/home/salahuddin/cornell/ADSP/Project/MRIdata/ForClass/Training
/home/salahuddin/cornell/ADSP/Project/MRIdata/ForClass/Validation


In [6]:
# Model, loss and optimizer setup, followed by an optional resume from disk.

model = UNet3D(in_channels=in_channels, out_channels=num_classes)
model = model.to(device)
criterion = DiceLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Resume only when a checkpoint file already exists at checkpoint_path.
if not os.path.exists(checkpoint_path):
    print("No checkpoint found, starting training from scratch")
    start_epoch, train_losses, val_losses = 0, [], []
else:
    print(f"Loading checkpoint '{checkpoint_path}'")
    # NOTE(review): torch.load deserializes the file's contents; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # The stored value is the epoch index that was saved, so continue with the next one.
    start_epoch = checkpoint['epoch'] + 1
    # Restore the loss histories so they keep growing across restarts.
    train_losses = checkpoint['train_losses']
    val_losses = checkpoint['val_losses']

No checkpoint found, starting training from scratch


In [7]:

# Train and validate for every remaining epoch, checkpointing after each one.
for epoch in range(start_epoch, num_epochs):
    # ---- Training pass: gradients enabled, weights updated per batch ----
    model.train()
    epoch_train_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_train_loss += loss.item()
    # Mean of the per-batch losses for this epoch.
    epoch_train_loss /= len(train_loader)
    train_losses.append(epoch_train_loss)

    # ---- Validation pass: eval mode, no gradient tracking ----
    model.eval()
    epoch_val_loss = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            epoch_val_loss += loss.item()
    epoch_val_loss /= len(val_loader)
    val_losses.append(epoch_val_loss)

    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}')

    # Save the model checkpoint (overwrite the last checkpoint).
    # The state is written to a sibling temporary file first and then swapped
    # into place with os.replace, so an interruption in the middle of the write
    # cannot leave a truncated file at checkpoint_path (which would make the
    # next resume attempt fail while loading it).
    tmp_checkpoint_path = checkpoint_path + '.tmp'
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_losses': train_losses,
        'val_losses': val_losses
    }, tmp_checkpoint_path)
    os.replace(tmp_checkpoint_path, checkpoint_path)


: 

In [None]:

# Plot the recorded loss curves.
# The x-axis is derived from the number of epochs actually recorded rather than
# from num_epochs: if training stopped early (or has only partially run), the
# loss lists are shorter than num_epochs and matplotlib would raise a
# shape-mismatch error when x and y lengths differ.
train_epochs = range(1, len(train_losses) + 1)
val_epochs = range(1, len(val_losses) + 1)

plt.figure()
plt.plot(train_epochs, train_losses, label='Train Loss')
plt.plot(val_epochs, val_losses, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Dice Loss')
plt.legend()
plt.title('Training and Validation Loss')
plt.show()
