<a href="https://colab.research.google.com/github/umfieldrobotics/shipwreck_finder_demo/blob/main/shipwreck_demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Environment bootstrap: install a Miniconda Python 3.8 distribution into the
# runtime and register it as a Jupyter kernel. Every line is an IPython `!`
# shell escape, so the commands run in the host shell, not in Python.

# Download the Miniconda3 py38 4.8.2 Linux x86_64 installer and save it as mini.sh.
!wget -O mini.sh https://repo.anaconda.com/miniconda/Miniconda3-py38_4.8.2-Linux-x86_64.sh
# Mark the installer as executable.
!chmod +x mini.sh
# Run the installer against /usr/local.
# NOTE(review): -b/-f/-p are presumably batch mode, "no error if prefix exists"
# and the install prefix -- confirm against the installer's -h output.
!bash ./mini.sh -b -f -p /usr/local
# Install Jupyter into the new conda environment (-q quiet, -y auto-confirm).
!conda install -q -y jupyter
# Install the google-colab package from the conda-forge channel.
!conda install -q -y google-colab -c conda-forge
# Register the interpreter as a user-level Jupyter kernel named "py38".
!python -m ipykernel install --name "py38" --user

In [None]:
# Reload the web page first, then run this cell to report which Python
# interpreter the notebook kernel is actually using.
import sys

version_label = "User Current Version:-"
print(version_label, sys.version)

In [None]:
from torch.utils.data import DataLoader
import tqdm

# Training / evaluation cell.
# Relies on names defined by other cells: `torch`, `model`, `DEVICE`,
# `train_dataset` and `test_dataset`. Each dataset item is read as a dict with
# an 'image' entry and a 'label' entry.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

lr = 5e-4
NUM_EPOCHS = 10

# ---------------------------------------------------------------------------
# Class statistics over the training set, used to weight the loss.
# ---------------------------------------------------------------------------
label_count = 0
background_count = 0

for data in train_loader:
    label = data['label'].to(DEVICE)
    label_count += (label == 1).sum()
    background_count += (label == 0).sum()

# Background pixels are counted explicitly (label == 0) instead of being
# derived from the total pixel count, since masked pixels don't count as
# background.
total_count = background_count
print(label_count, total_count)

# NOTE(review): `total_count` holds background pixels only, so
# `total_count - label_count` below is (background - label), and weight0 (the
# weight applied to class 0) becomes the larger weight whenever label pixels
# are rare. The arithmetic is left exactly as it was -- confirm this is the
# intended weighting.
ratio = label_count / (total_count)
weight1 = label_count / (total_count - label_count)
weight0 = 1/weight1
print("Ratio:", ratio.item(), "Weight0:", weight0.item(), "Weight1:", weight1.item())

optim = torch.optim.Adam(model.parameters(), lr=lr)
# Targets equal to -1 are ignored by the loss (ignore_index=-1).
ce_loss = torch.nn.CrossEntropyLoss(weight=torch.tensor([weight0, weight1]).to(DEVICE), ignore_index=-1)

# ---------------------------------------------------------------------------
# Training loop with a per-epoch evaluation pass on the test set.
# ---------------------------------------------------------------------------
# Outer progress bar over epochs; its description is updated with the most
# recent mean loss.
epoch_pbar = tqdm.tqdm(range(NUM_EPOCHS), desc='Epochs')
for epoch in epoch_pbar:
    model.train()
    epoch_loss_accum = 0.0
    train_pbar = tqdm.tqdm(train_loader, total=len(train_loader), desc='Train Batches', leave=False)
    for train_data in train_pbar:
        im = train_data['image'].to(DEVICE)
        label = train_data['label'].to(DEVICE)

        logits = model(im)
        loss = ce_loss(logits, label)
        optim.zero_grad(set_to_none=True)
        loss.backward()
        optim.step()

        # Scale the batch loss by the batch size so that dividing by the
        # dataset size below gives a per-sample mean even when the last
        # batch is smaller.
        epoch_loss_accum += loss.item() * im.size(0)

    train_loss_mean = epoch_loss_accum / len(train_loader.dataset)
    epoch_pbar.set_description(f'Epochs (Train Loss: {train_loss_mean:.4f})')

    model.eval()
    test_loss_accum = 0.0
    test_iou_accum = 0.0  # placeholder: no IoU is accumulated in this cell
    with torch.no_grad():
        test_pbar = tqdm.tqdm(test_loader, total=len(test_loader), desc='Test Batches', leave=False)
        for data in test_pbar:
            im = data['image'].to(DEVICE)
            label = data['label'].to(DEVICE)
            pred = model(im)
            loss = ce_loss(pred, label)
            test_loss_accum += loss.item() * im.size(0)

    # The mean must be computed before it is reported.
    test_loss_mean = test_loss_accum / len(test_loader.dataset)
    print(f'[INFO - Test]\t EPOCH: {epoch}, Loss: {test_loss_mean:.4f}')
    epoch_pbar.set_description(f'Epochs (Test Loss: {test_loss_mean:.4f})')