In [7]:
from models.seg_net_lite import SegNetLite
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image
import wandb
from sklearn.metrics import f1_score

In [8]:
# Number of (image, groundtruth) pairs per mini-batch; used by both the training and validation loaders.
BATCH_SIZE: int = 4

# Set up the DataLoader for the dataset
class SegNetDataset(Dataset):
    def __init__(self, image_folder, groundtruth_folder, image_transform=None, groundtruth_transform=None):
        self.image_folder = image_folder
        self.groundtruth_folder = groundtruth_folder
        self.image_transform = image_transform
        self.groundtruth_transform = groundtruth_transform
        self.image_files = os.listdir(image_folder)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_folder, self.image_files[idx])
        groundtruth_path = os.path.join(self.groundtruth_folder, self.image_files[idx])

        image = Image.open(image_path).convert("RGB")
        groundtruth = Image.open(groundtruth_path).convert("L")

        if self.image_transform:
            image = self.image_transform(image)
        if self.groundtruth_transform:
            groundtruth = self.groundtruth_transform(groundtruth)

        return image, groundtruth


# Initialize the DataLoader
image_transform = Compose([ToTensor()])
groundtruth_transform = Compose([ToTensor()])

train_data = SegNetDataset("training/images", "training/groundtruth", image_transform=image_transform, groundtruth_transform=groundtruth_transform)

train_size = int(0.8 * len(train_data))
val_size = len(train_data) - train_size
train_data, val_data = torch.utils.data.random_split(train_data, [train_size, val_size])

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Create the model, loss function, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SegNetLite().to(device)
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.0001)


# Initialize weights
print(model(torch.randn(4, 3, 400, 400).to(device)).shape)
print(iter(train_loader).next()[0].shape)
print(iter(train_loader).next()[1].shape)




torch.Size([4, 1, 400, 400])
torch.Size([4, 3, 400, 400])
torch.Size([4, 1, 400, 400])


In [9]:
# Initialize wandb
wandb.init(project="CIL 2023", entity="tlaborie")
wandb.watch(model, log="all")

# Set up the training and validation loop
num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0.0
    for batch_idx, (images, groundtruths) in enumerate(train_loader):
        images, groundtruths = images.to(device), groundtruths.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, groundtruths)
        loss.backward()
        optimizer.step()
        total_train_loss += loss.item()


    avg_train_loss = total_train_loss / len(train_loader)
    wandb.log({"Train Loss": avg_train_loss})
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_train_loss}")

    model.eval()
    total_f1 = 0.0
    with torch.no_grad():
        for images, groundtruths in val_loader:
            images, groundtruths = images.to(device), groundtruths.to(device)
            outputs = model(images)
            preds = outputs.cpu().numpy()
            preds[preds >= 0.5] = 1
            preds[preds < 0.5] = 0
            groundtruths = groundtruths.cpu().numpy()
            total_f1 += f1_score(groundtruths.flatten(), preds.flatten(), average='weighted')

    val_f1 = total_f1 / len(val_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Validation F1 Score: {val_f1}")
    wandb.log({"Validation F1 Score": val_f1})

wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mtimothelaborie[0m ([33mtlaborie[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 1/5, Training Loss: 0.8241235387736353
Epoch 1/5, Validation F1 Score: 0.08127966611021101
Epoch 2/5, Training Loss: 0.694300495345017
Epoch 2/5, Validation F1 Score: 0.7054073081128847
Epoch 3/5, Training Loss: 0.6556050201942181
Epoch 3/5, Validation F1 Score: 0.756407933885076
Epoch 4/5, Training Loss: 0.6397893182162581
Epoch 4/5, Validation F1 Score: 0.8142654235132142
Epoch 5/5, Training Loss: 0.6302901177570738
Epoch 5/5, Validation F1 Score: 0.7997513196757334


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Train Loss,█▃▂▁▁
Validation F1 Score,▁▇▇██

0,1
Train Loss,0.63029
Validation F1 Score,0.79975
