<a href="https://colab.research.google.com/github/rozapkk13/unet/blob/master/trainUnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install segmentation-models


Collecting segmentation-models
  Downloading segmentation_models-1.0.1-py3-none-any.whl.metadata (938 bytes)
Collecting keras-applications<=1.0.8,>=1.0.7 (from segmentation-models)
  Downloading Keras_Applications-1.0.8-py3-none-any.whl.metadata (1.7 kB)
Collecting image-classifiers==1.0.0 (from segmentation-models)
  Downloading image_classifiers-1.0.0-py3-none-any.whl.metadata (8.6 kB)
Collecting efficientnet==1.0.0 (from segmentation-models)
  Downloading efficientnet-1.0.0-py3-none-any.whl.metadata (6.1 kB)
Downloading segmentation_models-1.0.1-py3-none-any.whl (33 kB)
Downloading efficientnet-1.0.0-py3-none-any.whl (17 kB)
Downloading image_classifiers-1.0.0-py3-none-any.whl (19 kB)
Downloading Keras_Applications-1.0.8-py3-none-any.whl (50 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m50.7/50.7 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: keras-applications, image-classifiers, efficientnet, segmentation-models
Successfully 

In [2]:
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"  # Force TensorFlow Keras compatibility

import tensorflow as tf
from segmentation_models import Unet  # ✅ Import pre-trained U-Net
from tensorflow.keras.callbacks import ModelCheckpoint


Segmentation Models: using `tf.keras` framework.


## Train your Unet with membrane data
membrane data is in folder membrane/, it is a binary classification task.

The input shape of image and mask are the same :(batch_size,rows,cols,channel = 1)

In [3]:
  # model_debugger.py
# Run this first to check if the model and dataloader work correctly

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn

# Define dataset transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
def get_dataloader(batch_size=8):
    try:
        dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
        print("✅ Dataloader loaded successfully!")
        return dataloader
    except Exception as e:
        print(f"❌ Error in dataloader: {e}")
        return None

# Define a simple model
def get_model():
    try:
        model = torchvision.models.resnet18(pretrained=False)
        model.fc = nn.Linear(model.fc.in_features, 10)  # Adjust for 10 classes
        print("✅ Model loaded successfully!")
        return model
    except Exception as e:
        print(f"❌ Error in model: {e}")
        return None

# Run tests
if __name__ == "__main__":
    print("🔍 Checking Dataloader...")
    dataloader = get_dataloader()

    print("\n🔍 Checking Model...")
    model = get_model()

    # Check if one batch of data passes through the model
    if dataloader and model:
        try:
            images, labels = next(iter(dataloader))
            outputs = model(images)
            print("✅ Model forward pass successful!")
        except Exception as e:
            print(f"❌ Error in forward pass: {e}")

    print("\n✅ Debugging complete. If no errors, you can proceed to training!")


🔍 Checking Dataloader...
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170M/170M [00:03<00:00, 46.7MB/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
✅ Dataloader loaded successfully!

🔍 Checking Model...




✅ Model loaded successfully!
✅ Model forward pass successful!

✅ Debugging complete. If no errors, you can proceed to training!


### Train with data generator

In [4]:
# Second Code: train_model.py
# Run this second to train the model after fixing potential issues

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Define dataset transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load dataset
def get_dataloader(batch_size=32):
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
    return trainloader

# Define a simple model
def get_model():
    model = torchvision.models.resnet18(pretrained=False)
    model.fc = nn.Linear(model.fc.in_features, 10)  # Adjust for 10 classes
    return model.to(device)  # Move model to correct device

# Training function
def train_model(model, dataloader, epochs=5, lr=0.001):
    model.train()  # Set model to training mode
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        running_loss = 0.0
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)  # Move to GPU if available

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f'Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}')

    print('Training complete')
    torch.save(model.state_dict(), 'trained_model.pth')
    print('Model saved.')

if __name__ == "__main__":
    trainloader = get_dataloader()
    model = get_model()
    train_model(model, trainloader)


Using device: cuda
Files already downloaded and verified
Epoch 1, Loss: 1.3690193526575525
Epoch 2, Loss: 0.8421951449802115
Epoch 3, Loss: 0.6288292824421186
Epoch 4, Loss: 0.49469657570733827
Epoch 5, Loss: 0.39079503928116327
Training complete
Model saved.


italicized text### Train with npy file

In [5]:
#imgs_train,imgs_mask_train = geneTrainNpy("data/membrane/train/aug/","data/membrane/train/aug/")
#model.fit(imgs_train, imgs_mask_train, batch_size=2, nb_epoch=10, verbose=1,validation_split=0.2, shuffle=True, callbacks=[model_checkpoint])

### test your model and save predicted results

In [None]:
import os
import numpy as np
import cv2
import tensorflow as tf
from skimage import io, transform
from google.colab import files
import zipfile

# ✅ Step 1: Upload & Extract ZIP File
uploaded = files.upload()  # Upload ZIP manually
zip_filename = list(uploaded.keys())[0]  # Get uploaded file name
extract_path = f"/content/{zip_filename.split('.')[0]}"  # Extracted folder path

# Extract ZIP
os.makedirs(extract_path, exist_ok=True)  # Ensure extraction path exists
with zipfile.ZipFile(zip_filename, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

test_path = os.path.join(extract_path, "test")  # Adjust based on ZIP contents
print("✅ Extracted to:", test_path)
print("📂 Test images:", os.listdir(test_path))

# ✅ Step 2: Ensure test folder exists
if not os.path.exists(test_path):
    os.makedirs(test_path)
    print("🚨 Warning: Test folder was missing! Created:", test_path)

# ✅ Step 3: Define test image generator
def testGenerator(test_path, target_size=(256, 256)):
    for file_name in os.listdir(test_path):
        img_path = os.path.join(test_path, file_name)
        img = io.imread(img_path)

        # Convert grayscale images to RGB
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # Resize image
        img = transform.resize(img, target_size, mode='constant', anti_aliasing=True)

        # Debug: Check if input image has valid values
        print(f"🖼 Processing {file_name} -> Min: {np.min(img)}, Max: {np.max(img)}, Mean: {np.mean(img)}")

        # Normalize (Fix NaNs in preprocessing)
        img = img / (np.max(img) + 1e-8)  # Avoid division by zero
        img = np.expand_dims(img, axis=0)

        # Debug: After normalization
        print(f"📊 After Normalization -> Min: {np.min(img)}, Max: {np.max(img)}, Mean: {np.mean(img)}")

        yield (img,)

# ✅ Step 4: Load Model Safely
model_path = "/content/unet_membrane.keras"  # Ensure correct model path
if os.path.exists(model_path):
    model = tf.keras.models.load_model(model_path, safe_mode=False)
    print("✅ Model Loaded Successfully!")
else:
    raise FileNotFoundError("🚨 Model file not found! Upload your .keras model.")

# ✅ Step 5: Check for NaNs in Model Weights
for layer in model.layers:
    weights = layer.get_weights()
    for w in weights:
        if np.isnan(w).any():
            print(f"🚨 NaN detected in layer: {layer.name}")

# ✅ Step 6: Generate Predictions
test_gen = testGenerator(test_path)
results = model.predict(test_gen, verbose=1)

# ✅ Step 7: Debugging - Check Min/Max Values
print("Min value in predictions:", np.min(results))
print("Max value in predictions:", np.max(results))
print("Mean value in predictions:", np.mean(results))

# ✅ Fix NaNs in Predictions
results = np.nan_to_num(results)  # Replace NaNs with 0
print("✅ Fixed NaNs in predictions")

# ✅ Additional Debugging
print("Unique values in predictions:", np.unique(results))

# ✅ Step 8: Save Predicted Images
save_path = os.path.join(extract_path, "predictions")  # Save inside extracted folder
os.makedirs(save_path, exist_ok=True)

def saveResult(save_path, npyfile):
    for i, img in enumerate(npyfile):
        img = np.squeeze(img)

        # Ensure proper scaling before saving
        img = (img - np.min(img)) / (np.max(img) - np.min(img) + 1e-8)  # Normalize between 0-1
        img = (img * 255).astype(np.uint8)

        # Convert to RGB to avoid black images
        img = cv2.merge([img, img, img])

        save_filename = os.path.join(save_path, f"{i}_predict.png")
        cv2.imwrite(save_filename, img)
        print(f"✅ Saved: {save_filename}")

saveResult(save_path, results)
print("🎉 Prediction Complete! Check saved images in:", save_path)


In [None]:
def testGenerator(test_path, target_size=(256, 256)):
    for file_name in os.listdir(test_path):
        img_path = os.path.join(test_path, file_name)
        img = io.imread(img_path)

        # Convert grayscale images to RGB
        if len(img.shape) == 2:
            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

        # Resize image
        img = transform.resize(img, target_size, mode='constant', anti_aliasing=True)

        # Debug: Check if input image has valid values
        print(f"🖼 Processing {file_name} -> Min: {np.min(img)}, Max: {np.max(img)}, Mean: {np.mean(img)}")

        # Normalize (Fix NaNs in preprocessing)
        img = img / (np.max(img) + 1e-8)  # Avoid division by zero
        img = np.expand_dims(img, axis=0)

        # Debug: After normalization
        print(f"📊 After Normalization -> Min: {np.min(img)}, Max: {np.max(img)}, Mean: {np.mean(img)}")

        yield (img,)
