In [1]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image

# -------------------- Config --------------------
# Optimisation hyper-parameters.
num_classes = 3  # number of output logits, one per label
batch_size = 16
epochs = 30
learning_rate = 0.0001

# Root folder of the training set; its sub-folders name the label vector
# of the images they contain (e.g. "1_0_1").
train_dir = "/home/intellisense08/Yehan/project_dir/Webots/Root_data/train/"

# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -------------------- Dataset --------------------
class MultiLabelDataset(Dataset):
    """Image dataset whose multi-hot label is encoded in each sub-folder name.

    Expected layout::

        root_dir/
            1_0_1/   img_a.png, img_b.jpg, ...
            0_1_0/   ...

    Each sub-folder name is split on ``_`` and every piece parsed as an int,
    giving the label vector shared by all images in that folder. Images are
    loaded as grayscale and replicated to 3 channels so they fit a network
    that expects 3-channel input.

    Args:
        root_dir: Directory containing one sub-folder per label combination.
            Plain files and hidden sub-folders (names starting with ``.``,
            e.g. Jupyter's ``.ipynb_checkpoints``) are ignored.
        transform: Optional callable applied to the grayscale PIL image.
            Presumably it returns a ``(1, H, W)`` tensor (e.g. by ending in
            ``transforms.ToTensor()``). If it is omitted, or its result is
            not a tensor, the image is converted with
            ``transforms.functional.to_tensor``.

    Raises:
        ValueError: If a non-hidden sub-folder name is not a ``_``-separated
            list of integers.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []  # image file paths
        self.labels = []  # label vector (list of ints) for each entry in self.images
        # Sort so the sample order is reproducible; os.listdir order is arbitrary.
        for subfolder in sorted(os.listdir(root_dir)):
            sub_path = os.path.join(root_dir, subfolder)
            # Stray files and hidden dirs cannot carry a label; skip them.
            if not os.path.isdir(sub_path) or subfolder.startswith('.'):
                continue
            label_vector = self._parse_label(subfolder)
            for fname in sorted(os.listdir(sub_path)):
                if fname.lower().endswith(self._IMAGE_EXTENSIONS):
                    self.images.append(os.path.join(sub_path, fname))
                    self.labels.append(label_vector)

    @staticmethod
    def _parse_label(folder_name):
        """Parse a folder name such as ``"1_0_1"`` into ``[1, 0, 1]``."""
        try:
            return [int(x) for x in folder_name.split('_')]
        except ValueError as err:
            raise ValueError(
                f"Cannot parse a label from folder name {folder_name!r}; "
                "expected '_'-separated integers such as '1_0_1'."
            ) from err

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*."""
        img_path = self.images[idx]
        label = torch.tensor(self.labels[idx], dtype=torch.float)

        # The context manager closes the file handle; convert() returns a new
        # image, so it remains usable after the file is closed.
        with Image.open(img_path) as raw:
            img = raw.convert("L")  # grayscale

        if self.transform is not None:
            img = self.transform(img)
        if not torch.is_tensor(img):
            # No transform (or one that stayed in PIL): tensorize so repeat() works.
            img = transforms.functional.to_tensor(img)

        # expand grayscale to 3 channels
        img = img.repeat(3, 1, 1)

        return img, label


# -------------------- Transforms --------------------
# Preprocessing for the grayscale PIL images produced by MultiLabelDataset:
# resize to a fixed 224x224, convert to a [0, 1] tensor, then shift/scale
# with mean 0.5 / std 0.5 so values land in [-1, 1].
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

train_dataset = MultiLabelDataset(root_dir=train_dir, transform=transform)
# Reshuffle every epoch; two worker processes load and decode images.
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2,
)

# -------------------- MobileNetV2 (from scratch) --------------------
# weights=None gives randomly initialised parameters (no pretrained weights).
# It replaces the `pretrained=False` flag that torchvision deprecated in 0.13.
model = models.mobilenet_v2(weights=None)
# Replace the final Linear layer so the network emits one logit per label.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
model = model.to(device)

# -------------------- Loss & Optimizer --------------------
# BCEWithLogitsLoss applies an independent sigmoid per label internally, so
# the model's raw logits are passed to it directly (multi-label setup).
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# -------------------- Accuracy Function --------------------
def exact_match_accuracy(outputs, targets, threshold=0.5):
    """Count the samples whose thresholded predictions match every target label.

    Args:
        outputs: Raw logits, one row per sample and one column per label.
        targets: Multi-hot ground truth, same shape as ``outputs``.
        threshold: Probability cut-off; a label is predicted positive when
            ``sigmoid(logit) >= threshold``.

    Returns:
        ``(correct, total)``: the number of rows predicted exactly right and
        the number of rows in the batch.
    """
    predicted = torch.sigmoid(outputs).ge(threshold).int()
    row_is_exact = predicted.eq(targets.int()).all(dim=1)
    return row_is_exact.sum().item(), targets.size(0)

# -------------------- Training Loop --------------------
# One pass over train_loader per epoch. Loss and exact-match accuracy are
# accumulated per sample and printed as training-set averages for the epoch.
# Metrics use the outputs computed before optimizer.step(), i.e. the weights
# as they were at the start of each batch.
for epoch in range(epochs):
    model.train()
    running_loss, running_correct, running_total = 0.0, 0, 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion uses its default mean reduction, so scale by the batch
        # size to turn the epoch total into a per-sample mean below.
        running_loss += loss.item() * images.size(0)
        # Metrics only: detach so the accuracy helper's sigmoid/threshold ops
        # are not recorded in the autograd graph of `outputs`.
        correct, total = exact_match_accuracy(outputs.detach(), labels)
        running_correct += correct
        running_total += total

    train_loss = running_loss / running_total
    train_acc = running_correct / running_total

    print(f"Epoch [{epoch+1}/{epochs}] "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

# -------------------- Save Model --------------------
save_path = "/home/intellisense08/Yehan/project_dir/Webots/trained_models/MobilenetV2/mobilenetv2_synthetic.pth"
# Second copy in the legacy (non-zipfile) serialization format. This is
# presumably for a consumer running an older PyTorch (<1.6) — confirm.
save_path_old = "/home/intellisense08/Yehan/project_dir/Webots/trained_models/MobilenetV2/mobilenetv2_synthetic_old.pth"
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(save_path), exist_ok=True)
os.makedirs(os.path.dirname(save_path_old), exist_ok=True)
torch.save(model.state_dict(), save_path)
torch.save(model.state_dict(), save_path_old, _use_new_zipfile_serialization=False)
print(f"Model saved to {save_path}")
print(f"Legacy-format copy saved to {save_path_old}")




Epoch [1/30] Train Loss: 0.5368, Train Acc: 0.4461
Epoch [2/30] Train Loss: 0.2546, Train Acc: 0.7379
Epoch [3/30] Train Loss: 0.1489, Train Acc: 0.8531
Epoch [4/30] Train Loss: 0.0992, Train Acc: 0.9148
Epoch [5/30] Train Loss: 0.0807, Train Acc: 0.9230
Epoch [6/30] Train Loss: 0.0556, Train Acc: 0.9496
Epoch [7/30] Train Loss: 0.0435, Train Acc: 0.9633
Epoch [8/30] Train Loss: 0.0376, Train Acc: 0.9738
Epoch [9/30] Train Loss: 0.0337, Train Acc: 0.9684
Epoch [10/30] Train Loss: 0.0267, Train Acc: 0.9766
Epoch [11/30] Train Loss: 0.0256, Train Acc: 0.9797
Epoch [12/30] Train Loss: 0.0153, Train Acc: 0.9902
Epoch [13/30] Train Loss: 0.0173, Train Acc: 0.9840
Epoch [14/30] Train Loss: 0.0163, Train Acc: 0.9879
Epoch [15/30] Train Loss: 0.0205, Train Acc: 0.9836
Epoch [16/30] Train Loss: 0.0170, Train Acc: 0.9848
Epoch [17/30] Train Loss: 0.0161, Train Acc: 0.9867
Epoch [18/30] Train Loss: 0.0160, Train Acc: 0.9867
Epoch [19/30] Train Loss: 0.0152, Train Acc: 0.9883
Epoch [20/30] Train L