# Gdrive

In [1]:
from google.colab import drive
drive.mount('/gdrive')

Mounted at /gdrive


In [2]:
# !unzip /gdrive/My\ Drive/DIRT/full_new_data.zip -d /gdrive/My\ Drive/DIRT/

In [3]:
# !pip install ultralytics

In [4]:
!nvidia-smi

Sun Dec  8 02:42:55 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.104.05             Driver Version: 535.104.05   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA L4                      Off | 00000000:00:03.0 Off |                    0 |
| N/A   46C    P8              12W /  72W |      1MiB / 23034MiB |      0%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

# Imports

In [5]:
# !pip install albumentations wandb

In [1]:
import os
import glob
import math
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler
import albumentations as A
from albumentations.pytorch import ToTensorV2
import wandb

  check_for_updates()


# Config

In [2]:
# Configuration
class Config:
    # Data
    IMAGE_SIZE = 256  # Changed from 384 to 256
    BATCH_SIZE = 32
    NUM_WORKERS = 4

    # Model
    IN_CHANNELS = 3
    NUM_CLASSES = 1

    # Training
    EPOCHS = 300  # Changed from 100 to 300
    LEARNING_RATE = 1e-3  # Confirmed correct
    MIN_LEARNING_RATE = 1e-6  # Added minimum learning rate
    WEIGHT_DECAY = 0.9  # Changed to match paper's momentum decay
    GRADIENT_CLIP = 1.0
    AUGMENTATION_STOP_EPOCH = 180  # New parameter

# Model

In [7]:
# # Model Components from Original Code
# def autopad(k, p=None, d=1):
#     if d > 1:
#         k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
#     if p is None:
#         p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
#     return p

# class Conv(nn.Module):
#     default_act = nn.GELU()

#     def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
#         super().__init__()
#         self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
#         self.bn = nn.BatchNorm2d(c2, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
#         self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()

#     def forward(self, x):
#         return self.act(self.bn(self.conv(x)))

#     def forward_fuse(self, x):
#         return self.act(self.conv(x))

# class DWConv(Conv):
#     def __init__(self, c1, c2, k=1, s=1, d=1, act=True):
#         super().__init__(c1, c2, k, s, g=math.gcd(c1, c2), d=d, act=act)

# # Enhanced CMRF with SE block
# class SEBlock(nn.Module):
#     def __init__(self, channel, reduction=16):
#         super().__init__()
#         self.avg_pool = nn.AdaptiveAvgPool2d(1)
#         self.fc = nn.Sequential(
#             nn.Linear(channel, channel // reduction, bias=False),
#             nn.ReLU(inplace=True),
#             nn.Linear(channel // reduction, channel, bias=False),
#             nn.Sigmoid()
#         )

#     def forward(self, x):
#         b, c, _, _ = x.size()
#         y = self.avg_pool(x).view(b, c)
#         y = self.fc(y).view(b, c, 1, 1)
#         return x * y.expand_as(x)

# class CMRF(nn.Module):
#     def __init__(self, c1, c2, N=8, shortcut=True, g=1, e=0.5):
#         super().__init__()
#         self.N = N
#         self.c = int(c2 * e / self.N)
#         self.add = shortcut and c1 == c2

#         self.pwconv1 = Conv(c1, c2 // self.N, 1, 1)
#         self.pwconv2 = Conv(c2 // 2, c2, 1, 1)
#         self.m = nn.ModuleList(DWConv(self.c, self.c, k=3, act=False) for _ in range(N - 1))

#         # Added SE block and dropout
#         self.se = SEBlock(c2)
#         self.dropout = nn.Dropout2d(0.1)

#     def forward(self, x):
#         x_residual = x
#         x = self.pwconv1(x)
#         x = [x[:, 0::2, :, :], x[:, 1::2, :, :]]
#         x.extend(m(x[-1]) for m in self.m)
#         x[0] = x[0] + x[1]
#         x.pop(1)
#         y = torch.cat(x, dim=1)
#         y = self.pwconv2(y)
#         y = self.se(y)
#         y = self.dropout(y)
#         return x_residual + y if self.add else y


# class UNetEncoder(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(UNetEncoder, self).__init__()
#         self.cmrf = CMRF(in_channels, out_channels)
#         self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

#     def forward(self, x):
#         x = self.cmrf(x)
#         return self.downsample(x), x

# class UNetDecoder(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super(UNetDecoder, self).__init__()
#         self.cmrf = CMRF(in_channels, out_channels)

#     def forward(self, x, skip_connection):
#         x = F.interpolate(x, scale_factor=2, mode='bicubic', align_corners=False)
#         x = torch.cat([x, skip_connection], dim=1)
#         x = self.cmrf(x)
#         return x

# class TinyUNet(nn.Module):
#     def __init__(self, in_channels=3, num_classes=1):
#         super(TinyUNet, self).__init__()
#         in_filters = [192, 384, 768, 1024]
#         out_filters = [64, 128, 256, 512]

#         self.encoder1 = UNetEncoder(in_channels, 64)
#         self.encoder2 = UNetEncoder(64, 128)
#         self.encoder3 = UNetEncoder(128, 256)
#         self.encoder4 = UNetEncoder(256, 512)
#         self.decoder4 = UNetDecoder(in_filters[3], out_filters[3])
#         self.decoder3 = UNetDecoder(in_filters[2], out_filters[2])
#         self.decoder2 = UNetDecoder(in_filters[1], out_filters[1])
#         self.decoder1 = UNetDecoder(in_filters[0], out_filters[0])
#         self.final_conv = nn.Conv2d(out_filters[0], num_classes, kernel_size=1)

#     def forward(self, x):
#         x, skip1 = self.encoder1(x)
#         x, skip2 = self.encoder2(x)
#         x, skip3 = self.encoder3(x)
#         x, skip4 = self.encoder4(x)
#         x = self.decoder4(x, skip4)
#         x = self.decoder3(x, skip3)
#         x = self.decoder2(x, skip2)
#         x = self.decoder1(x, skip1)
#         x = self.final_conv(x)
#         return x

# Dataset

In [3]:
# Enhanced Dataset with Augmentations
class SegmentationDataset(Dataset):
    def __init__(self, image_dirs, mask_dirs, transform=None, is_training=True, current_epoch=0):
        self.image_paths = []
        self.mask_paths = []
        self.is_training = is_training
        self.current_epoch = current_epoch

        # Get image and mask paths
        for img_dir, msk_dir in zip(image_dirs, mask_dirs):
            img_paths = sorted(glob.glob(os.path.join(img_dir, '*.jpg')))
            for img_path in img_paths:
                filename = os.path.basename(img_path).replace('.jpg', '.png')
                mask_path = os.path.join(msk_dir, filename)
                if os.path.exists(mask_path):
                    self.image_paths.append(img_path)
                    self.mask_paths.append(mask_path)

        # Training transformations
        self.train_transform = A.Compose([
            A.RandomResizedCrop(256, 256, scale=(0.8, 1.0)),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.RandomRotate90(p=0.5),
            A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.2, rotate_limit=15, p=0.5),
            A.OneOf([
                A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),
                A.GridDistortion(p=0.5),
                A.OpticalDistortion(distort_limit=1, shift_limit=0.5, p=0.5)
            ], p=0.3),
            A.OneOf([
                A.GaussNoise(p=0.5),
                A.RandomBrightnessContrast(p=0.5),
                A.RandomGamma(p=0.5)
            ], p=0.3),
            ToTensorV2()
        ])

        # Validation transformations
        self.val_transform = A.Compose([
            A.Resize(256, 256),  # Changed from 384 to 256
            ToTensorV2()
        ])

    def __getitem__(self, idx):
        image = np.array(Image.open(self.image_paths[idx]).convert('RGB'))
        mask = np.array(Image.open(self.mask_paths[idx]).convert('L'))

        # Apply transformations based on training phase and epoch
        if self.is_training and self.current_epoch < Config.AUGMENTATION_STOP_EPOCH:
            transformed = self.train_transform(image=image, mask=mask)
        else:
            transformed = self.val_transform(image=image, mask=mask)

        image = transformed['image'].float() / 255.0
        mask = transformed['mask'].float().unsqueeze(0) / 255.0

        return image, mask

    def __len__(self):
        return len(self.image_paths)

    def update_epoch(self, epoch):
        self.current_epoch = epoch

# Loss

In [4]:
class BCEJaccardLoss(nn.Module):
    def __init__(self, mode='binary', smooth=1.0, eps=1e-7, from_logits=True):
        super(BCEJaccardLoss, self).__init__()
        self.mode = mode
        self.smooth = smooth
        self.eps = eps
        self.from_logits = from_logits

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        assert y_true.size(0) == y_pred.size(0)

        if self.from_logits:
            y_pred = torch.sigmoid(y_pred)

        bs = y_true.size(0)

        y_true = y_true.view(bs, -1)
        y_pred = y_pred.view(bs, -1)

        intersection = torch.sum(y_true * y_pred, dim=1)
        sum_ = torch.sum(y_true + y_pred, dim=1)
        jac = (intersection + self.smooth) / (sum_ - intersection + self.smooth)

        bce = nn.functional.binary_cross_entropy(y_pred, y_true, reduction='none')
        bce = torch.mean(bce, dim=1)

        loss = (1 - jac) * self.smooth + bce

        return loss.mean()


# Train

In [6]:
# Set random seed
torch.manual_seed(42)
np.random.seed(42)

# Initialize device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Create data loaders
train_image_dirs = ['/gdrive/My Drive/DIRT/full_new_data/img']
train_mask_dirs = ['/gdrive/My Drive/DIRT/full_new_data/msk']
val_image_dirs = ['/gdrive/My Drive/DIRT/cv_open_dataset/open_img']
val_mask_dirs = ['/gdrive/My Drive/DIRT/cv_open_dataset/open_msk']
# Initialize datasets
train_dataset = SegmentationDataset(
    train_image_dirs,
    train_mask_dirs,
    is_training=True
)

val_dataset = SegmentationDataset(
    val_image_dirs,
    val_mask_dirs,
    is_training=False
)
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=Config.BATCH_SIZE,
    shuffle=True,
    num_workers=Config.NUM_WORKERS,
    pin_memory=True
)
val_loader = DataLoader(
        val_dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=False,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True
    )

print(f'Training dataset size: {len(train_dataset)}')
print(f'Validation dataset size: {len(val_dataset)}')

# # Initialize model
# model = TinyUNet(
#     in_channels=Config.IN_CHANNELS,
#     num_classes=Config.NUM_CLASSES
# ).to(device)

def display_sample_prediction(model, val_loader, device):
    model.eval()
    with torch.no_grad():
        sample_image, sample_mask = next(iter(val_loader))
        sample_image = sample_image[0:1].float().to(device)  # явно указываем float
        sample_mask = sample_mask[0:1].float()  # явно указываем float
        output = model(sample_image)
        pred_mask = torch.sigmoid(output).cpu() > 0.5
        plt.figure(figsize=(15, 5))

        plt.subplot(1, 3, 1)
        plt.title('Input Image')
        plt.imshow(sample_image[0].cpu().permute(1, 2, 0))
        plt.axis('off')

        plt.subplot(1, 3, 2)
        plt.title('True Mask')
        plt.imshow(sample_mask[0, 0], cmap='gray')
        plt.axis('off')

        plt.subplot(1, 3, 3)
        plt.title('Predicted Mask')
        plt.imshow(pred_mask[0, 0], cmap='gray')
        plt.axis('off')

        plt.savefig('sample_prediction.png')
        plt.close()

Using device: cuda
Training dataset size: 2282
Validation dataset size: 246


  A.ElasticTransform(alpha=120, sigma=120 * 0.05, alpha_affine=120 * 0.03, p=0.5),


In [7]:
# !pip install segmentation-models-pytorch

In [9]:
# Function to visualize predictions
def visualize_epoch_results(model, val_loader, epoch, device):
    # Only visualize every 4 epochs
    if epoch % 4 != 0:
        return

    model.eval()
    with torch.no_grad():
        # Get a single batch
        images, masks = next(iter(val_loader))
        # Take first image from batch
        image = images[0:1].to(device)
        mask = masks[0:1]

        # Get prediction
        output = model(image)
        pred_mask = torch.sigmoid(output).cpu() > 0.5

        # Create figure
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Plot original image
        axes[0].imshow(images[0].permute(1, 2, 0).cpu())
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        # Plot true mask
        axes[1].imshow(mask[0, 0].cpu(), cmap='gray')
        axes[1].set_title('True Mask')
        axes[1].axis('off')

        # Plot predicted mask
        axes[2].imshow(pred_mask[0, 0].cpu(), cmap='gray')
        axes[2].set_title('Predicted Mask')
        axes[2].axis('off')

        plt.suptitle(f'Epoch {epoch+1}')

        # Save the figure
        plt.savefig(f'epoch_{epoch+1}_prediction.png')
        plt.show()
        plt.close()

        # Log to wandb only every 4 epochs
        wandb.log({
            "predictions": wandb.Image(f'epoch_{epoch+1}_prediction.png')
        })

def calculate_metrics(pred, target, threshold=0.5):
    """
    Вычисляет IoU для батча
    pred: тензор предсказаний после sigmoid (B, 1, H, W)
    target: тензор истинных масок (B, 1, H, W)
    """
    # Применяем threshold к предсказаниям
    pred = (pred > threshold).float()

    # Вычисляем IoU для каждого изображения в батче
    intersection = (pred * target).sum(dim=(2, 3))  # (B, 1)
    union = (pred + target).gt(0).float().sum(dim=(2, 3))  # (B, 1)

    # Добавляем малое число для численной стабильности
    iou = (intersection + 1e-8) / (union + 1e-8)  # (B, 1)

    return iou.mean(dim=1)  # Среднее по каналам для каждого изображения


# Rest of the training loop remains the same
wandb.init(project="segmentation-project-3")

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="mobileone_s1",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,                      # model output channels (number of classes in your dataset)
).to(device)

# model = TinyUNet(in_channels=Config.IN_CHANNELS, num_classes=Config.NUM_CLASSES).to(device)

# # Load checkpoint
# checkpoint = torch.load('best_model.pt')
# # Load model state
# model.load_state_dict(checkpoint['model_state_dict'])

model = model.to(memory_format=torch.channels_last)

criterion = BCEJaccardLoss().to(device)

optimizer = optim.Adam(
    model.parameters(),
    lr=Config.LEARNING_RATE,
    # lr=1e-3,
    betas=(Config.WEIGHT_DECAY, 0.999)  # First beta set to 0.9 as per paper
)

scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=Config.EPOCHS,
    eta_min=Config.MIN_LEARNING_RATE
)

scaler = torch.amp.GradScaler()
best_iou = 0
patience = 100
no_improve = 0

for epoch in range(Config.EPOCHS):
    # TRAIN
    model.train()
    running_loss = 0.0

    for images, masks in tqdm(train_loader, desc=f'Epoch {epoch+1}/{Config.EPOCHS}'):
        images = images.to(device)
        masks = masks.to(device)

        # with torch.amp.autocast(device_type='cuda'):
        outputs = model(images)
        loss = criterion(outputs, masks)

        scaler.scale(loss).backward()

        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

        running_loss += loss.item() * images.size(0)

    train_loss = running_loss / len(train_loader.dataset)

    # VALIDATE
    model.eval()
    running_loss = 0.0
    val_ious = []  # Список для хранения IoU каждого изображения

    with torch.no_grad():
        for images, masks in val_loader:
            images = images.to(device)
            masks = masks.to(device)

            # with torch.amp.autocast(device_type='cuda'):
            outputs = model(images)
            loss = criterion(outputs, masks)

            # Получаем sigmoid от выходов
            outputs = torch.sigmoid(outputs)

            # Считаем IoU для текущего батча
            batch_ious = calculate_metrics(outputs, masks)
            val_ious.extend(batch_ious.cpu().numpy())

            running_loss += loss.item() * images.size(0)

    # Считаем средние метрики
    val_loss = running_loss / len(val_loader.dataset)
    val_miou = np.mean(val_ious)  # Среднее IoU по всем изображениям

    current_lr = optimizer.param_groups[0]['lr']
    scheduler.step()

    # Визуализация
    visualize_epoch_results(model, val_loader, epoch, device)

    # Логируем метрики
    wandb.log({
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_miou': val_miou,
        'learning_rate': current_lr
    })

    # Выводим метрики
    print(f'Epoch {epoch+1}/{Config.EPOCHS}')
    print(f'Learning Rate: {current_lr:.6f}')
    print(f'Train Loss: {train_loss:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Val mIoU: {val_miou:.4f}')

    # Сохраняем лучшую модель
    if val_miou > best_iou:
        best_iou = val_miou
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'best_miou': best_iou,
        }, 'best_model.pt')
        print(f'Saved new best model with mIoU: {val_miou:.4f}')
        no_improve = 0
    else:
        no_improve += 1

    if no_improve >= patience:
        print(f'Early stopping triggered after {patience} epochs without improvement')
        break

# Cleanup
wandb.finish()
print("\nTraining completed!")

Output hidden; open in https://colab.research.google.com to view.

In [11]:
!pip install onnx onnxruntime onnxsim

Collecting onnx
  Downloading onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (16 kB)
Collecting onnxruntime
  Downloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting onnxsim
  Downloading onnxsim-0.4.36-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.3 kB)
Collecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Downloading onnx-1.17.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (16.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m16.0/16.0 MB[0m [31m101.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading onnxruntime-1.20.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (13.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m

In [12]:
import torch
import torch.onnx
import onnx
import onnxruntime
import numpy as np
import warnings
from pathlib import Path
import segmentation_models_pytorch as smp

def load_model_safe(model, checkpoint_path):
    """
    Safely load model weights with proper error handling
    """
    try:
        # Добавляем numpy.core.multiarray.scalar в список безопасных глобальных объектов
        from numpy.core.multiarray import scalar
        torch.serialization.add_safe_globals([scalar])

        # Пробуем загрузить с weights_only=True
        try:
            checkpoint = torch.load(checkpoint_path, weights_only=True)
        except Exception:
            # Если не получилось, загружаем без ограничений
            print("Warning: Loading checkpoint without weights_only restriction")
            checkpoint = torch.load(checkpoint_path, weights_only=False)

        if 'model_state_dict' in checkpoint:
            model.load_state_dict(checkpoint['model_state_dict'])
        else:
            model.load_state_dict(checkpoint)

        print("Model weights loaded successfully")

    except Exception as e:
        print(f"Error loading model weights: {str(e)}")
        raise

    return model.eval()  # Сразу переводим модель в режим eval


def export_smp_model_to_onnx(model,
                            path='model.onnx',
                            input_shape=(1, 3, 256, 256),
                            simplify=True):
    """
    Export segmentation model to ONNX format with additional optimizations
    """
    path = Path(path)
    device = next(model.parameters()).device
    dummy_input = torch.randn(input_shape, device=device)

    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)

        try:
            # Export model with additional settings
            torch.onnx.export(
                model,
                dummy_input,
                path,
                export_params=True,
                opset_version=13,
                do_constant_folding=True,
                input_names=['input'],
                output_names=['output'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'output': {0: 'batch_size'}
                }
            )

            # Verify ONNX model
            onnx_model = onnx.load(path)
            onnx.checker.check_model(onnx_model)

            # Optimize model if requested
            if simplify:
                try:
                    import onnxsim
                    model_simplified, check = onnxsim.simplify(onnx_model)
                    if check:
                        onnx.save(model_simplified, path)
                        print("Model simplified successfully")
                except ImportError:
                    print("onnx-simplifier not installed. Skip simplification.")

            return True

        except Exception as e:
            print(f"Error during model export: {str(e)}")
            return False

def verify_onnx_model(model, onnx_path, input_shape=(1, 3, 256, 256), rtol=1e-3, atol=1e-4):
    """
    Verify ONNX model output matches PyTorch model with better tolerance
    """
    device = next(model.parameters()).device

    try:
        # Load and check ONNX model
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)

        # Create random input with seed for reproducibility
        torch.manual_seed(42)
        x = torch.randn(input_shape, device=device)

        # PyTorch prediction
        with torch.no_grad():
            torch_out = model(x)

        # ONNX Runtime prediction
        providers = ['CUDAExecutionProvider', 'CPUExecutionProvider']
        ort_session = onnxruntime.InferenceSession(onnx_path, providers=providers)
        ort_inputs = {ort_session.get_inputs()[0].name: x.cpu().numpy()}
        ort_out = ort_session.run(None, ort_inputs)[0]

        # Compare outputs with relaxed tolerance
        np.testing.assert_allclose(
            torch_out.cpu().numpy(),
            ort_out,
            rtol=rtol,
            atol=atol,
            err_msg="Output mismatch between PyTorch and ONNX"
        )
        print("Exported model has been verified!")
        return True

    except Exception as e:
        print(f"Verification failed: {str(e)}")
        return False

def main():
    # Initialize model and device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create model
    model = smp.Unet(
        encoder_name="mobileone_s1",
        # encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1
    ).to(device)

    # Load weights
    model = load_model_safe(model, 'best_model.pt')

    # Export to ONNX
    input_shape = (1, 3, 256, 256)
    success = export_smp_model_to_onnx(
        model,
        path='segmentation_model2.onnx',
        input_shape=input_shape
    )

    if success:
        print("Model exported successfully!")

if __name__ == "__main__":
    main()

Using device: cuda
Model weights loaded successfully
Model simplified successfully
Model exported successfully!


In [23]:
# import torch
# import torch.onnx
# import onnx
# import onnxruntime

# def export_to_onnx(model, save_path='model.onnx', input_size=(1, 3, 256, 256)):
#     """
#     Export PyTorch model to ONNX format
#     Args:
#         model: PyTorch model
#         save_path: Path to save ONNX model
#         input_size: Input tensor size (batch_size, channels, height, width)
#     """
#     # Set model to evaluation mode
#     model.eval()

#     # Create dummy input tensor
#     dummy_input = torch.randn(input_size, requires_grad=True)

#     # Export the model
#     torch.onnx.export(
#         model,                                      # model being run
#         dummy_input,                                # model input (or a tuple for multiple inputs)
#         save_path,                                  # where to save the model
#         export_params=True,                         # store the trained parameter weights inside the model file
#         opset_version=11,                          # the ONNX version to export the model to
#         do_constant_folding=True,                   # whether to execute constant folding for optimization
#         input_names=['input'],                      # the model's input names
#         output_names=['output'],                    # the model's output names
#         dynamic_axes={
#             'input': {0: 'batch_size'},            # variable length axes
#             'output': {0: 'batch_size'}
#         }
#     )

#     # Verify the exported model
#     onnx_model = onnx.load(save_path)
#     onnx.checker.check_model(onnx_model)

#     return onnx_model

# def verify_onnx_output(pytorch_model, onnx_path, input_size=(1, 3, 256, 256)):
#     """
#     Verify ONNX model output matches PyTorch model
#     """
#     # Create random input
#     x = torch.randn(input_size)

#     # PyTorch forward pass
#     pytorch_model.eval()
#     with torch.no_grad():
#         pytorch_out = pytorch_model(x)

#     # ONNX Runtime forward pass
#     ort_session = onnxruntime.InferenceSession(onnx_path)
#     ort_inputs = {ort_session.get_inputs()[0].name: x.numpy()}
#     ort_out = ort_session.run(None, ort_inputs)[0]

#     # Compare outputs
#     np.testing.assert_allclose(pytorch_out.numpy(), ort_out, rtol=1e-03, atol=1e-05)
#     print("PyTorch and ONNX Runtime outputs matched!")

# # Example usage
# def convert_model_to_onnx():
#     # Initialize model
#     model = TinyUNet(in_channels=3, num_classes=1)

#     # Load trained weights if available
#     try:
#         model.load_state_dict(torch.load('/content/best_model_tine_0_72.pt')['model_state_dict'])
#         print("Loaded trained weights")
#     except:
#         print("Using untrained model")

#     # Export to ONNX
#     onnx_path = 'tinyunet.onnx'
#     onnx_model = export_to_onnx(model, onnx_path)
#     print(f"Model exported to {onnx_path}")

#     # Verify the exported model
#     try:
#         verify_onnx_output(model, onnx_path)
#     except Exception as e:
#         print(f"Verification failed: {str(e)}")

#     return onnx_path

# if __name__ == "__main__":
#     onnx_path = convert_model_to_onnx()

  model.load_state_dict(torch.load('/content/best_model_tine_0_72.pt')['model_state_dict'])


Loaded trained weights
Model exported to tinyunet.onnx
PyTorch and ONNX Runtime outputs matched!


In [None]:
# Cleanup
if wandb.run is not None:
    wandb.finish()

print("Cleaning up and saving final state...")
torch.save(model.state_dict(), 'final_model.pt')
print("Final model state saved")

try:
    display_sample_prediction(model, val_loader, device)
    print("Sample prediction saved as 'sample_prediction.png'")
except Exception as e:
    print(f"Could not generate sample prediction: {str(e)}")

print("\nExecution completed!")