In [6]:
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torchvision
from tqdm import tqdm
import gc
import matplotlib.pyplot as plt
import time


In [17]:
# Dataset for image + mask pairs
class MaskedSegDataset(Dataset):
    def __init__(self, root, img_folder="images", mask_folder="binary_mask",
                 transform=None, mask_transform=None):

        self.img_dir = os.path.join(root, img_folder)
        self.mask_dir = os.path.join(root, mask_folder)
        self.transform = transform
        self.mask_transform = mask_transform

        # build full paths
        self.img_paths = sorted([
            os.path.join(self.img_dir, f)
            for f in os.listdir(self.img_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ])

        self.mask_paths = sorted([
            os.path.join(self.mask_dir, f)
            for f in os.listdir(self.mask_dir)
            if f.lower().endswith((".jpg", ".jpeg", ".png"))
        ])

        print(f"Found {len(self.img_paths)} images.")

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):

        # read images and mask
        img_path = self.img_paths[idx]
        mask_path = self.mask_paths[idx]

        img = Image.open(img_path).convert("RGB")
        mask = Image.open(mask_path)

        # Image enhancement
        if self.transform:
            img = self.transform(img)

        # 3.mask resize
        if self.mask_transform:
            mask = self.mask_transform(mask)

        # mask clean , make sure all values are between 0-9
        mask_np = np.array(mask)

        # invalid pixels -> 255 ignore_index
        mask_np[mask_np > 9] = 255

        # Convert back to PIL image
        mask = Image.fromarray(mask_np.astype(np.uint8))

        # 5. convert into tensor
        mask = torch.from_numpy(np.array(mask)).long()

        return img, mask


In [22]:
# Training Script
def main():

    #path
    dataset_path = r"C:\Users\dht233\OneDrive - University of Texas at San Antonio\NSF\Housing condition\image segmatation\dataset\translated_data"

    #  HYPERPARAMS 
    IMG_SIZE = 512
    BATCH_SIZE = 2
    NUM_EPOCHS = 50
    LR = 1e-4
    NUM_CLASSES = 10  # 0–9

    # DEVICE 
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("Using device:", device)

    # ---------- TRANSFORMS ----------
    img_transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]),
    ])

    mask_transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE), interpolation=Image.NEAREST)
    ])

    #  DATASET 
    train_dataset = MaskedSegDataset(
        dataset_path,
        img_folder="images",
        mask_folder="binary_mask",
        transform=img_transform,
        mask_transform=mask_transform
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0
    )

    # MODEL
    model = torchvision.models.segmentation.deeplabv3_resnet50(
        weights=None, num_classes=NUM_CLASSES
    ).to(device)

    #LOSS + OPT 
    criterion = torch.nn.CrossEntropyLoss(ignore_index=0)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    # TRAIN 
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")

        for images, masks in pbar:
            images = images.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(images)["out"]

            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * images.size(0)

            pbar.set_postfix({"loss": loss.item()})

        epoch_loss = running_loss / len(train_dataset)
        print(f"Epoch {epoch+1} Loss: {epoch_loss:.4f}")

    torch.save(model.state_dict(), "seg_model.pth")
    print("Model saved as seg_model.pth")

In [23]:
if __name__ == "__main__":
    main()

Using device: cuda
Found 642 images.


Epoch 1/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.626]


Epoch 1 Loss: 1.1286


Epoch 2/50: 100%|██████████| 321/321 [02:53<00:00,  1.85it/s, loss=0.635]


Epoch 2 Loss: 0.8912


Epoch 3/50: 100%|██████████| 321/321 [02:51<00:00,  1.88it/s, loss=0.876]


Epoch 3 Loss: 0.8538


Epoch 4/50: 100%|██████████| 321/321 [02:52<00:00,  1.87it/s, loss=0.748]


Epoch 4 Loss: 0.8304


Epoch 5/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=1.19] 


Epoch 5 Loss: 0.7960


Epoch 6/50: 100%|██████████| 321/321 [02:51<00:00,  1.88it/s, loss=0.691]


Epoch 6 Loss: 0.7799


Epoch 7/50: 100%|██████████| 321/321 [02:52<00:00,  1.86it/s, loss=0.818]


Epoch 7 Loss: 0.7699


Epoch 8/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.727]


Epoch 8 Loss: 0.7582


Epoch 9/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.573]


Epoch 9 Loss: 0.7409


Epoch 10/50: 100%|██████████| 321/321 [02:50<00:00,  1.89it/s, loss=0.757]


Epoch 10 Loss: 0.7230


Epoch 11/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.527]


Epoch 11 Loss: 0.7011


Epoch 12/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.603]


Epoch 12 Loss: 0.6953


Epoch 13/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.587]


Epoch 13 Loss: 0.6730


Epoch 14/50: 100%|██████████| 321/321 [02:48<00:00,  1.90it/s, loss=0.652]


Epoch 14 Loss: 0.6596


Epoch 15/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.767]


Epoch 15 Loss: 0.6441


Epoch 16/50: 100%|██████████| 321/321 [02:50<00:00,  1.89it/s, loss=0.665]


Epoch 16 Loss: 0.6284


Epoch 17/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.519]


Epoch 17 Loss: 0.6066


Epoch 18/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.422]


Epoch 18 Loss: 0.5959


Epoch 19/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.438]


Epoch 19 Loss: 0.5820


Epoch 20/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.963]


Epoch 20 Loss: 0.5675


Epoch 21/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.528]


Epoch 21 Loss: 0.5598


Epoch 22/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.559]


Epoch 22 Loss: 0.5433


Epoch 23/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.503]


Epoch 23 Loss: 0.5330


Epoch 24/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.5]  


Epoch 24 Loss: 0.5086


Epoch 25/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.507]


Epoch 25 Loss: 0.4890


Epoch 26/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.425]


Epoch 26 Loss: 0.4874


Epoch 27/50: 100%|██████████| 321/321 [02:48<00:00,  1.91it/s, loss=0.332]


Epoch 27 Loss: 0.4705


Epoch 28/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.539]


Epoch 28 Loss: 0.4576


Epoch 29/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.39] 


Epoch 29 Loss: 0.4531


Epoch 30/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.485]


Epoch 30 Loss: 0.4416


Epoch 31/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.287]


Epoch 31 Loss: 0.4183


Epoch 32/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.33] 


Epoch 32 Loss: 0.4097


Epoch 33/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.508]


Epoch 33 Loss: 0.4076


Epoch 34/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.476]


Epoch 34 Loss: 0.4012


Epoch 35/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.401]


Epoch 35 Loss: 0.3835


Epoch 36/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.399]


Epoch 36 Loss: 0.3873


Epoch 37/50: 100%|██████████| 321/321 [02:48<00:00,  1.90it/s, loss=0.389]


Epoch 37 Loss: 0.3698


Epoch 38/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.256]


Epoch 38 Loss: 0.3600


Epoch 39/50: 100%|██████████| 321/321 [02:50<00:00,  1.89it/s, loss=0.258]


Epoch 39 Loss: 0.3528


Epoch 40/50: 100%|██████████| 321/321 [02:48<00:00,  1.90it/s, loss=0.211]


Epoch 40 Loss: 0.3454


Epoch 41/50: 100%|██████████| 321/321 [02:48<00:00,  1.91it/s, loss=0.246]


Epoch 41 Loss: 0.3451


Epoch 42/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.407]


Epoch 42 Loss: 0.3341


Epoch 43/50: 100%|██████████| 321/321 [02:50<00:00,  1.89it/s, loss=0.211]


Epoch 43 Loss: 0.3327


Epoch 44/50: 100%|██████████| 321/321 [02:49<00:00,  1.90it/s, loss=0.528]


Epoch 44 Loss: 0.3338


Epoch 45/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.261]


Epoch 45 Loss: 0.3203


Epoch 46/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.189]


Epoch 46 Loss: 0.3083


Epoch 47/50: 100%|██████████| 321/321 [02:50<00:00,  1.88it/s, loss=0.327]


Epoch 47 Loss: 0.3088


Epoch 48/50: 100%|██████████| 321/321 [02:49<00:00,  1.89it/s, loss=0.422]


Epoch 48 Loss: 0.3084


Epoch 49/50: 100%|██████████| 321/321 [02:52<00:00,  1.86it/s, loss=0.324]


Epoch 49 Loss: 0.2972


Epoch 50/50: 100%|██████████| 321/321 [02:51<00:00,  1.87it/s, loss=0.232]

Epoch 50 Loss: 0.2884
Model saved as seg_model.pth





###Inference

In [24]:
import os
import torch
import torchvision.transforms as T
import torchvision
from PIL import Image
import numpy as np
import gc
import sys

# Helper functions for forcibly releasing memory
def force_clean_memory():
    """clean memory"""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# PARAMETERS 
device = "cpu"
MODEL_PATH = "seg_model.pth"
IMG_SIZE = 256
NUM_CLASSES = 10

TEST_FOLDER = r"C:\Users\dht233\OneDrive - University of Texas at San Antonio\NSF\Housing condition\image segmatation\myima"

# Color map
colors = {
    0: [0, 0, 0], 1: [70, 70, 70], 2: [250, 170, 30],
    3: [70, 130, 180], 4: [0, 60, 100], 5: [153, 153, 153],
    6: [107, 142, 35], 7: [255, 0, 0], 8: [0, 0, 142], 9: [220, 220, 0],
}

def colorize_mask(mask):
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for cls, color in colors.items():
        color_mask[mask == cls] = color
    return color_mask

def predict_single_image(model, img_path, transform):
    """cope with single images for clean memory"""
    # load image
    img = Image.open(img_path).convert("RGB")
    original_size = img.size
    
    # Reduce to a reasonable size
    if max(img.size) > 512:
        img.thumbnail((512, 512), Image.LANCZOS)
    
    # Resize
    img_resized = img.resize((IMG_SIZE, IMG_SIZE), Image.LANCZOS)
    
    # Transform
    x = transform(img_resized).unsqueeze(0)
    
    # Inference
    with torch.no_grad():
        out = model(x)["out"]
        pred = torch.argmax(out.squeeze(), dim=0).cpu().numpy()
    
    # Clean up torch variables immediately
    del out, x
    
    return img_resized, pred

def process_and_save(img_path, model, transform, output_folder, index, total):
    """cope with and save single image"""
    base_name = os.path.splitext(os.path.basename(img_path))[0]
    print(f"\nProcessing {index}/{total}: {base_name}")
    
    try:
        # predict
        img_resized, pred = predict_single_image(model, img_path, transform)
        
        # generate colorful mask
        colored_mask = colorize_mask(pred)
        
        # save mask
        mask_path = os.path.join(output_folder, f"{base_name}_mask.png")
        mask_img = Image.fromarray(colored_mask)
        mask_img.save(mask_path, optimize=True)
        print(f"  ✓ Saved mask")
        
        # Save the prediction array
        pred_path = os.path.join(output_folder, f"{base_name}_pred.npy")
        np.save(pred_path, pred)
        print(f"  ✓ Saved prediction array")
        
        # Clean up all variables immediately
        del img_resized, pred, colored_mask, mask_img
        force_clean_memory()
        
        return True
        
    except Exception as e:
        print(f"  ✗ ERROR: {e}")
        import traceback
        traceback.print_exc()
        
        # Even if an error occurs, the memory should be cleared.
        force_clean_memory()
        return False

def run_inference():
    """主函数"""
    print("=" * 60)
    print("Starting inference...")
    print("=" * 60)
    
    # Get image list
    imgs = [
        os.path.join(TEST_FOLDER, f)
        for f in os.listdir(TEST_FOLDER)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
    
    print(f"\nFound {len(imgs)} images in folder")
    
    if len(imgs) == 0:
        print("No images found!")
        return
    
    # Create output folder
    output_folder = os.path.join(TEST_FOLDER, "predictions")
    os.makedirs(output_folder, exist_ok=True)
    print(f"Output folder: {output_folder}")
    
    # Load model
    print("\nLoading model...")
    model = torchvision.models.segmentation.deeplabv3_resnet50(
        weights=None, num_classes=NUM_CLASSES
    )
    model.load_state_dict(torch.load(MODEL_PATH, map_location=device, weights_only=True))
    model.to(device)
    model.eval()
    print("✓ Model loaded")
    
    # Transform
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    
    # Process each image
    success_count = 0
    for i, img_path in enumerate(imgs, 1):
        if process_and_save(img_path, model, transform, output_folder, i, len(imgs)):
            success_count += 1
        
        # Force cleanup after processing each image.
        force_clean_memory()
        
        # Optional: Limit the number of processes for testing.
        # if i >= 1:  # Process only the first one
        #     print("\nStopping after first image for testing")
        #     break
    
    print("\n" + "=" * 60)
    print(f"✓ Processing complete!")
    print(f"  Successfully processed: {success_count}/{len(imgs)}")
    print(f"  Output folder: {output_folder}")
    print("=" * 60)



In [26]:
if __name__ == "__main__":
    try:
        run_inference()
    except KeyboardInterrupt:
        print("\n\nInterrupted by user")
    except Exception as e:
        print(f"\n\nFATAL ERROR: {e}")
        import traceback
        traceback.print_exc()
    finally:
        # make sure to clean up finally
        force_clean_memory()
        print("\nMemory cleaned")

Starting inference...

Found 3 images in folder
Output folder: C:\Users\dht233\OneDrive - University of Texas at San Antonio\NSF\Housing condition\image segmatation\myima\predictions

Loading model...
✓ Model loaded

Processing 1/3: frontview_pic_20250627093308 (1)
  ✓ Saved mask
  ✓ Saved prediction array

Processing 2/3: windows_image_20250627094307
  ✓ Saved mask
  ✓ Saved prediction array

Processing 3/3: windows_image_20250703114731
  ✓ Saved mask
  ✓ Saved prediction array

✓ Processing complete!
  Successfully processed: 3/3
  Output folder: C:\Users\dht233\OneDrive - University of Texas at San Antonio\NSF\Housing condition\image segmatation\myima\predictions

Memory cleaned
