In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, ToTensor, Normalize, ColorJitter, RandomHorizontalFlip, RandomRotation
from PIL import Image
import os
import numpy as np
import random
from tqdm import tqdm
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.utils.class_weight import compute_class_weight
import tifffile
from torchvision.transforms import Lambda

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define class names and settings
class_names = ["background", "film", "basket", "cardboard", "video_tape", "filament", "bag"]
num_classes = len(class_names)
ignore_in_eval = [True, False, False, False, False, False, False]  # Ignore background in evaluation

# Set random seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
set_seed(42)

# Dataset class for RGB, hyperspectral, and mask data
class FusionDataset(Dataset):
    def __init__(self, rgb_dir, hyper_dir, mask_dir, rgb_transform=None, hyper_transform=None, mask_transform=None):
        self.rgb_dir = rgb_dir
        self.hyper_dir = hyper_dir
        self.mask_dir = mask_dir
        self.rgb_transform = rgb_transform
        self.hyper_transform = hyper_transform
        self.mask_transform = mask_transform

        self.rgb_filenames = sorted(os.listdir(rgb_dir))
        self.hyper_filenames = sorted(os.listdir(hyper_dir))
        self.mask_filenames = sorted(os.listdir(mask_dir))

    def __len__(self):
        return len(self.rgb_filenames)

    def __getitem__(self, idx):
        rgb_path = os.path.join(self.rgb_dir, self.rgb_filenames[idx])
        hyper_path = os.path.join(self.hyper_dir, self.hyper_filenames[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_filenames[idx])

        # Load RGB image
        rgb = Image.open(rgb_path).convert("RGB")
        rgb = self.rgb_transform(rgb) if self.rgb_transform else ToTensor()(rgb)

        # Load hyperspectral image (TIF)
        hyper = tifffile.imread(hyper_path).astype(np.float32) / 255.0  # Normalize to [0, 1]
        hyper = torch.tensor(hyper, dtype=torch.float32)

        # Ensure hyperspectral image shape is [C, H, W] with 33 channels
        if len(hyper.shape) == 3:  # Expected [H, W, C]
            hyper = hyper.permute(2, 0, 1)  # Convert to [C, H, W]
        elif len(hyper.shape) == 2:  # If grayscale, add channel dimension
            hyper = hyper.unsqueeze(0)

        # Select first 33 bands if more exist
        if hyper.shape[0] > 33:
            hyper = hyper[:33]
        elif hyper.shape[0] < 33:
            raise ValueError(f"Hyperspectral image has {hyper.shape[0]} channels, but 33 are required!")

        # Resize hyperspectral image to (256, 256)
        if hyper.shape[1] != 256 or hyper.shape[2] != 256:
            hyper = F.interpolate(hyper.unsqueeze(0), size=(256, 256), mode="bilinear", align_corners=True).squeeze(0)

        # Load and process mask
        mask = Image.open(mask_path)
        mask = torch.tensor(np.array(mask), dtype=torch.long)
        if self.mask_transform:
            mask = self.mask_transform(mask)

        return rgb, hyper, mask

# Transformations for training (RGB and masks)
train_image_transform = Compose([
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.05),
    RandomHorizontalFlip(p=0.5),
    RandomRotation(degrees=20),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_mask_transform = Compose([
    RandomHorizontalFlip(p=0.5),
    RandomRotation(degrees=20),
    Lambda(lambda x: torch.tensor(np.array(x), dtype=torch.long))
])

# Transformations for validation (RGB only)
val_image_transform = Compose([
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_mask_transform = Compose([])

# Dataset paths (update these to your actual paths)
rgb_train_dir = "./rgb/train"
hyper_train_dir = "./hyper/train"
mask_train_dir = "./labels_rgb/train"

# Initialize dataset
dataset = FusionDataset(
    rgb_dir=rgb_train_dir,
    hyper_dir=hyper_train_dir,
    mask_dir=mask_train_dir,
    rgb_transform=None,
    hyper_transform=None,
    mask_transform=None
)

# Stratified train/validation split
labels = [1 if np.any(mask.numpy() > 0) else 0 for _, _, mask in dataset]
split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, val_idx = next(split.split(np.zeros(len(labels)), labels))

# Subset datasets
train_dataset = torch.utils.data.Subset(dataset, train_idx)
val_dataset = torch.utils.data.Subset(dataset, val_idx)

# Apply transformations
train_dataset.dataset.rgb_transform = train_image_transform
train_dataset.dataset.mask_transform = train_mask_transform
val_dataset.dataset.rgb_transform = val_image_transform
val_dataset.dataset.mask_transform = val_mask_transform

# DataLoaders
batch_size = 16
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Compute class weights for imbalanced data
def compute_class_weights(dataset, num_classes):
    mask_flat = []
    for _, _, mask in dataset:
        mask_flat.extend(mask.numpy().flatten())
    class_weights = compute_class_weight("balanced", classes=np.arange(num_classes), y=mask_flat)
    return torch.tensor(class_weights, dtype=torch.float).to(device)

class_weights = compute_class_weights(train_dataset, num_classes)

# Combined loss function
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, inputs, targets):
        inputs_soft = F.softmax(inputs, dim=1)
        targets_onehot = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2).float()
        intersection = torch.sum(inputs_soft * targets_onehot, dim=(0, 2, 3))
        union = torch.sum(inputs_soft + targets_onehot, dim=(0, 2, 3))
        dice = (2. * intersection + self.smooth) / (union + self.smooth)
        return 1 - dice.mean()

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.ce = nn.CrossEntropyLoss(weight=weight)

    def forward(self, inputs, targets):
        log_probs = F.log_softmax(inputs, dim=1)
        probs = torch.exp(log_probs)
        loss = (1 - probs) ** self.gamma * self.ce(inputs, targets)
        return loss.mean()

criterion = lambda outputs, targets: 0.5 * FocalLoss(weight=class_weights)(outputs, targets) + 0.5 * DiceLoss()(outputs, targets)

# FusionUNet model definition
class FusionUNet(nn.Module):
    def __init__(self, num_classes, hyperspectral_bands=33):
        super(FusionUNet, self).__init__()

        # Initial convolutional blocks for RGB and hyperspectral inputs
        self.initial_conv_rgb = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        self.initial_conv_hyper = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )

        # RGB Encoder (ResNet34 backbone)
        self.rgb_encoder = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        self.rgb_initial = nn.Sequential(*list(self.rgb_encoder.children())[:3])
        self.rgb_maxpool = list(self.rgb_encoder.children())[3]
        self.rgb_encoder1 = list(self.rgb_encoder.children())[4]
        self.rgb_encoder2 = list(self.rgb_encoder.children())[5]
        self.rgb_encoder3 = list(self.rgb_encoder.children())[6]
        self.rgb_encoder4 = list(self.rgb_encoder.children())[7]

        # Hyperspectral Encoder with spectral attention
        self.spectral_attention = nn.Sequential(
            nn.Conv2d(hyperspectral_bands, 16, kernel_size=1),
            nn.BatchNorm2d(16),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 3, kernel_size=1)
        )
        self.hyper_encoder = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        self.hyper_initial = nn.Sequential(*list(self.hyper_encoder.children())[:3])
        self.hyper_maxpool = list(self.hyper_encoder.children())[3]
        self.hyper_encoder1 = list(self.hyper_encoder.children())[4]
        self.hyper_encoder2 = list(self.hyper_encoder.children())[5]
        self.hyper_encoder3 = list(self.hyper_encoder.children())[6]
        self.hyper_encoder4 = list(self.hyper_encoder.children())[7]

        # Decoder
        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder4 = self._decoder_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder3 = self._decoder_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder2 = self._decoder_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder1 = self._decoder_block(64, 64)
        self.upconv0 = nn.ConvTranspose2d(64, 32, kernel_size=2, stride=2)
        self.decoder0 = self._decoder_block(96, 32)  # 96 = 32 (from upconv0) + 64 (from initial_features)
        self.conv_last = nn.Conv2d(32, num_classes, kernel_size=1)

    def _decoder_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, rgb, hyper):
        # Hyperspectral processing
        hyper = self.spectral_attention(hyper)

        # Initial feature extraction for skip connection
        initial_features_rgb = self.initial_conv_rgb(rgb)        # [batch, 32, 256, 256]
        initial_features_hyper = self.initial_conv_hyper(hyper)  # [batch, 32, 256, 256]
        initial_features = torch.cat((initial_features_rgb, initial_features_hyper), dim=1)  # [batch, 64, 256, 256]

        # RGB Encoder
        rgb0 = self.rgb_initial(rgb)
        rgb1 = self.rgb_maxpool(rgb0)
        rgb1 = self.rgb_encoder1(rgb1)
        rgb2 = self.rgb_encoder2(rgb1)
        rgb3 = self.rgb_encoder3(rgb2)
        rgb4 = self.rgb_encoder4(rgb3)

        # Hyperspectral Encoder
        hyper0 = self.hyper_initial(hyper)
        hyper1 = self.hyper_maxpool(hyper0)
        hyper1 = self.hyper_encoder1(hyper1)
        hyper2 = self.hyper_encoder2(hyper1)
        hyper3 = self.hyper_encoder3(hyper2)
        hyper4 = self.hyper_encoder4(hyper3)

        # Concatenate encoder outputs
        enc4 = torch.cat((rgb4, hyper4), dim=1)
        enc3 = torch.cat((rgb3, hyper3), dim=1)
        enc2 = torch.cat((rgb2, hyper2), dim=1)
        enc1 = torch.cat((rgb1, hyper1), dim=1)

        # Decoder
        d4 = self.upconv4(enc4)
        d4 = torch.cat((d4, enc3), dim=1)
        d4 = self.decoder4(d4)

        d3 = self.upconv3(d4)
        d3 = torch.cat((d3, enc2), dim=1)
        d3 = self.decoder3(d3)

        d2 = self.upconv2(d3)
        d2 = torch.cat((d2, enc1), dim=1)
        d2 = self.decoder2(d2)

        d1 = self.upconv1(d2)
        d1 = self.decoder1(d1)

        # Additional upsampling with skip connection
        d0 = self.upconv0(d1)                        # [batch, 32, 256, 256]
        d0 = torch.cat((d0, initial_features), dim=1)  # [batch, 32 + 64 = 96, 256, 256]
        d0 = self.decoder0(d0)                       # [batch, 32, 256, 256]
        return self.conv_last(d0)                    # [batch, num_classes, 256, 256]

# Initialize model, optimizer, and scheduler
model = FusionUNet(num_classes, hyperspectral_bands=33).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=5, min_lr=1e-6)

# IoU calculation function
def calculate_iou(preds, masks, num_classes):
    intersection = torch.zeros(num_classes).to(preds.device)
    union = torch.zeros(num_classes).to(preds.device)
    
    for cls in range(num_classes):
        if ignore_in_eval[cls]:
            continue
        pred_inds = (preds == cls)
        target_inds = (masks == cls)
        intersection[cls] += (pred_inds & target_inds).sum().float()
        union[cls] += (pred_inds | target_inds).sum().float()
    
    iou_per_class = [(intersection[cls] / union[cls]).item() if union[cls].item() > 0 else float('nan') for cls in range(num_classes)]
    valid_ious = [iou for iou in iou_per_class if not np.isnan(iou)]
    mean_iou = np.mean(valid_ious) if valid_ious else 0.0
    
    return mean_iou, iou_per_class

# Training loop
num_epochs = 50
best_iou = 0.0
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0
    for rgb, hyper, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        rgb = rgb.to(device)
        hyper = hyper.to(device)
        masks = masks.to(device)
        optimizer.zero_grad()
        outputs = model(rgb, hyper)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    # Evaluate on Training Set
    model.eval()
    train_preds, train_masks = [], []
    with torch.no_grad():
        for rgb, hyper, masks in train_loader:
            rgb = rgb.to(device)
            hyper = hyper.to(device)
            masks = masks.to(device)
            outputs = model(rgb, hyper)
            preds = torch.argmax(outputs, dim=1)
            train_preds.append(preds.cpu())
            train_masks.append(masks.cpu())
    
    train_preds = torch.cat(train_preds, dim=0)
    train_masks = torch.cat(train_masks, dim=0)
    train_mean_iou, train_iou_per_class = calculate_iou(train_preds, train_masks, num_classes)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {total_loss / len(train_loader):.4f}, Train Mean IoU: {train_mean_iou:.4f}")
    for cls, iou in enumerate(train_iou_per_class):
        if not ignore_in_eval[cls]:
            print(f"  {class_names[cls]}: {iou:.4f}")
    
    # Evaluate on Validation Set
    val_preds, val_masks = [], []
    with torch.no_grad():
        for rgb, hyper, masks in val_loader:
            rgb = rgb.to(device)
            hyper = hyper.to(device)
            masks = masks.to(device)
            outputs = model(rgb, hyper)
            preds = torch.argmax(outputs, dim=1)
            val_preds.append(preds.cpu())
            val_masks.append(masks.cpu())
    
    val_preds = torch.cat(val_preds, dim=0)
    val_masks = torch.cat(val_masks, dim=0)
    val_mean_iou, val_iou_per_class = calculate_iou(val_preds, val_masks, num_classes)
    
    print(f"Epoch [{epoch+1}/{num_epochs}], Val Mean IoU: {val_mean_iou:.4f}")
    for cls, iou in enumerate(val_iou_per_class):
        if not ignore_in_eval[cls]:
            print(f"  {class_names[cls]}: {iou:.4f}")
    
    # Save Best Model
    if val_mean_iou > best_iou:
        best_iou = val_mean_iou
        torch.save(model.state_dict(), "best_fusion_model.pth")
        print("Best model saved!")
    
    scheduler.step(val_mean_iou)

# Final Evaluation Function
def evaluate_model(model, loader, dataset_name):
    model.eval()
    total_loss = 0.0
    all_preds, all_masks = [], []
    with torch.no_grad():
        for rgb, hyper, masks in loader:
            rgb = rgb.to(device)
            hyper = hyper.to(device)
            masks = masks.to(device)
            outputs = model(rgb, hyper)
            loss = criterion(outputs, masks)
            total_loss += loss.item() * rgb.size(0)
            preds = torch.argmax(outputs, dim=1)
            all_preds.append(preds.cpu())
            all_masks.append(masks.cpu())
    
    all_preds = torch.cat(all_preds, dim=0)
    all_masks = torch.cat(all_masks, dim=0)
    mean_iou, iou_per_class = calculate_iou(all_preds, all_masks, num_classes)
    
    total_loss /= len(loader.dataset)
    print(f"{dataset_name} Results:")
    print(f"Loss: {total_loss:.4f}")
    print(f"Mean IoU: {mean_iou:.4f}")
    for cls, iou in enumerate(iou_per_class):
        if not ignore_in_eval[cls]:
            print(f"  {class_names[cls]}: {iou:.4f}")
    print("-" * 30)

# Load and evaluate best model
model.load_state_dict(torch.load("best_fusion_model.pth"))
print("Final Model Evaluation:")
evaluate_model(model, train_loader, "Training Set")
evaluate_model(model, val_loader, "Validation Set")

Using device: cuda


Epoch 1/50: 100%|██████████| 26/26 [00:55<00:00,  2.14s/it]


Epoch [1/50], Train Loss: 1.1565, Train Mean IoU: 0.0684
  film: 0.1131
  basket: 0.0969
  cardboard: 0.0217
  video_tape: 0.0174
  filament: 0.0378
  bag: 0.1234
Epoch [1/50], Val Mean IoU: 0.0713
  film: 0.1234
  basket: 0.1019
  cardboard: 0.0094
  video_tape: 0.0145
  filament: 0.0164
  bag: 0.1624
Best model saved!


Epoch 2/50: 100%|██████████| 26/26 [01:01<00:00,  2.36s/it]


Epoch [2/50], Train Loss: 1.0311, Train Mean IoU: 0.0946
  film: 0.1634
  basket: 0.1181
  cardboard: 0.0533
  video_tape: 0.0290
  filament: 0.0102
  bag: 0.1935
Epoch [2/50], Val Mean IoU: 0.1063
  film: 0.1776
  basket: 0.1388
  cardboard: 0.0581
  video_tape: 0.0225
  filament: 0.0152
  bag: 0.2255
Best model saved!


Epoch 3/50: 100%|██████████| 26/26 [00:53<00:00,  2.06s/it]


Epoch [3/50], Train Loss: 0.9319, Train Mean IoU: 0.1199
  film: 0.1576
  basket: 0.0630
  cardboard: 0.2788
  video_tape: 0.0069
  filament: 0.0138
  bag: 0.1994
Epoch [3/50], Val Mean IoU: 0.1166
  film: 0.1778
  basket: 0.0796
  cardboard: 0.1878
  video_tape: 0.0089
  filament: 0.0170
  bag: 0.2284
Best model saved!


Epoch 4/50: 100%|██████████| 26/26 [00:55<00:00,  2.14s/it]


Epoch [4/50], Train Loss: 0.8423, Train Mean IoU: 0.1581
  film: 0.2169
  basket: 0.3241
  cardboard: 0.0326
  video_tape: 0.0901
  filament: 0.0906
  bag: 0.1945
Epoch [4/50], Val Mean IoU: 0.1627
  film: 0.2288
  basket: 0.3448
  cardboard: 0.0381
  video_tape: 0.0869
  filament: 0.0591
  bag: 0.2184
Best model saved!


Epoch 5/50: 100%|██████████| 26/26 [00:55<00:00,  2.12s/it]


Epoch [5/50], Train Loss: 0.7397, Train Mean IoU: 0.2193
  film: 0.3456
  basket: 0.5174
  cardboard: 0.0642
  video_tape: 0.1081
  filament: 0.0310
  bag: 0.2497
Epoch [5/50], Val Mean IoU: 0.2104
  film: 0.3280
  basket: 0.5071
  cardboard: 0.0659
  video_tape: 0.0919
  filament: 0.0258
  bag: 0.2438
Best model saved!


Epoch 6/50: 100%|██████████| 26/26 [00:53<00:00,  2.07s/it]


Epoch [6/50], Train Loss: 0.6783, Train Mean IoU: 0.2431
  film: 0.5488
  basket: 0.4347
  cardboard: 0.0918
  video_tape: 0.1140
  filament: 0.0387
  bag: 0.2305
Epoch [6/50], Val Mean IoU: 0.2204
  film: 0.4305
  basket: 0.4347
  cardboard: 0.1036
  video_tape: 0.0999
  filament: 0.0297
  bag: 0.2241
Best model saved!


Epoch 7/50: 100%|██████████| 26/26 [00:54<00:00,  2.10s/it]


Epoch [7/50], Train Loss: 0.6361, Train Mean IoU: 0.2870
  film: 0.4601
  basket: 0.4370
  cardboard: 0.3984
  video_tape: 0.1152
  filament: 0.0783
  bag: 0.2331
Epoch [7/50], Val Mean IoU: 0.2442
  film: 0.3786
  basket: 0.4536
  cardboard: 0.2396
  video_tape: 0.0982
  filament: 0.0709
  bag: 0.2241
Best model saved!


Epoch 8/50: 100%|██████████| 26/26 [00:54<00:00,  2.08s/it]


Epoch [8/50], Train Loss: 0.6206, Train Mean IoU: 0.2500
  film: 0.4474
  basket: 0.3938
  cardboard: 0.2080
  video_tape: 0.1112
  filament: 0.0820
  bag: 0.2574
Epoch [8/50], Val Mean IoU: 0.2281
  film: 0.3815
  basket: 0.3965
  cardboard: 0.1891
  video_tape: 0.0936
  filament: 0.0603
  bag: 0.2475


Epoch 9/50: 100%|██████████| 26/26 [00:53<00:00,  2.06s/it]


Epoch [9/50], Train Loss: 0.5983, Train Mean IoU: 0.2482
  film: 0.4707
  basket: 0.4767
  cardboard: 0.1004
  video_tape: 0.1241
  filament: 0.0552
  bag: 0.2623
Epoch [9/50], Val Mean IoU: 0.2302
  film: 0.3955
  basket: 0.4729
  cardboard: 0.1139
  video_tape: 0.1029
  filament: 0.0476
  bag: 0.2486


Epoch 10/50: 100%|██████████| 26/26 [00:53<00:00,  2.06s/it]


Epoch [10/50], Train Loss: 0.5656, Train Mean IoU: 0.3156
  film: 0.6086
  basket: 0.5493
  cardboard: 0.3328
  video_tape: 0.0928
  filament: 0.0593
  bag: 0.2509
Epoch [10/50], Val Mean IoU: 0.2645
  film: 0.4487
  basket: 0.5616
  cardboard: 0.2166
  video_tape: 0.0778
  filament: 0.0491
  bag: 0.2330
Best model saved!


Epoch 11/50: 100%|██████████| 26/26 [00:54<00:00,  2.11s/it]


Epoch [11/50], Train Loss: 0.5386, Train Mean IoU: 0.3025
  film: 0.4526
  basket: 0.5466
  cardboard: 0.3399
  video_tape: 0.0872
  filament: 0.0508
  bag: 0.3381
Epoch [11/50], Val Mean IoU: 0.2624
  film: 0.3894
  basket: 0.5432
  cardboard: 0.2121
  video_tape: 0.0751
  filament: 0.0507
  bag: 0.3040


Epoch 12/50: 100%|██████████| 26/26 [00:54<00:00,  2.09s/it]


Epoch [12/50], Train Loss: 0.5178, Train Mean IoU: 0.3498
  film: 0.6588
  basket: 0.5918
  cardboard: 0.2325
  video_tape: 0.1795
  filament: 0.0399
  bag: 0.3961
Epoch [12/50], Val Mean IoU: 0.2990
  film: 0.4931
  basket: 0.5731
  cardboard: 0.2095
  video_tape: 0.1388
  filament: 0.0352
  bag: 0.3444
Best model saved!


Epoch 13/50: 100%|██████████| 26/26 [00:53<00:00,  2.06s/it]


Epoch [13/50], Train Loss: 0.4942, Train Mean IoU: 0.3729
  film: 0.5804
  basket: 0.5635
  cardboard: 0.4877
  video_tape: 0.1228
  filament: 0.0512
  bag: 0.4318
Epoch [13/50], Val Mean IoU: 0.3078
  film: 0.4572
  basket: 0.5392
  cardboard: 0.3281
  video_tape: 0.1011
  filament: 0.0513
  bag: 0.3700
Best model saved!


Epoch 14/50: 100%|██████████| 26/26 [01:09<00:00,  2.69s/it]


Epoch [14/50], Train Loss: 0.4936, Train Mean IoU: 0.3826
  film: 0.6651
  basket: 0.3499
  cardboard: 0.4275
  video_tape: 0.1979
  filament: 0.0265
  bag: 0.6286
Epoch [14/50], Val Mean IoU: 0.2957
  film: 0.4704
  basket: 0.3519
  cardboard: 0.2922
  video_tape: 0.1589
  filament: 0.0315
  bag: 0.4694


Epoch 15/50: 100%|██████████| 26/26 [01:20<00:00,  3.10s/it]


Epoch [15/50], Train Loss: 0.4829, Train Mean IoU: 0.4289
  film: 0.7423
  basket: 0.5689
  cardboard: 0.4841
  video_tape: 0.1368
  filament: 0.0618
  bag: 0.5793
Epoch [15/50], Val Mean IoU: 0.3276
  film: 0.5102
  basket: 0.5616
  cardboard: 0.3001
  video_tape: 0.1110
  filament: 0.0424
  bag: 0.4402
Best model saved!


Epoch 16/50: 100%|██████████| 26/26 [01:04<00:00,  2.49s/it]


Epoch [16/50], Train Loss: 0.4768, Train Mean IoU: 0.4282
  film: 0.7627
  basket: 0.6029
  cardboard: 0.3008
  video_tape: 0.2412
  filament: 0.1125
  bag: 0.5491
Epoch [16/50], Val Mean IoU: 0.3367
  film: 0.5271
  basket: 0.5795
  cardboard: 0.2247
  video_tape: 0.1737
  filament: 0.1073
  bag: 0.4077
Best model saved!


Epoch 17/50: 100%|██████████| 26/26 [00:55<00:00,  2.15s/it]


Epoch [17/50], Train Loss: 0.4622, Train Mean IoU: 0.4411
  film: 0.5885
  basket: 0.7675
  cardboard: 0.4855
  video_tape: 0.1425
  filament: 0.0947
  bag: 0.5678
Epoch [17/50], Val Mean IoU: 0.3386
  film: 0.4418
  basket: 0.6489
  cardboard: 0.3097
  video_tape: 0.1138
  filament: 0.0776
  bag: 0.4399
Best model saved!


Epoch 18/50:  69%|██████▉   | 18/26 [00:24<00:11,  1.39s/it]


KeyboardInterrupt: 