In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from tqdm import tqdm
from PIL import Image

# --- 1. CONFIGURATION ---
# Root folder of the raw images; SegmentationDataset looks for one
# sub-folder per entry of CLASSES underneath it.
RAW_DATA_DIR = "../content/dataset" # Point to RAW data
# Sub-folder names that are scanned for images. The class label itself is
# not used as a training target here; only the images are.
CLASSES = ["Hemorrhagic", "Ischemic", "Tumor"]
# Train on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Number of (image, mask) pairs per optimisation step.
BATCH_SIZE = 16
EPOCHS = 5 # ROI extraction is easy, 5 epochs is enough
# Size passed to PIL's resize() for both the image and its mask.
TARGET_SIZE = (224, 224)

# --- 2. THE TEACHER (Robust CV2 Function) ---
def generate_ground_truth_mask(image_np, min_area_frac=0.01, edge_margin_frac=0.10):
    """
    Build a binary region-of-interest mask with classical OpenCV operations.

    The pipeline is: grayscale -> Gaussian blur -> Otsu threshold ->
    dilation -> external contours.  Contours that are too small or whose
    bounding-box centre lies in the top/bottom margin (typically burnt-in
    header/footer text) are discarded, and the largest remaining contour is
    filled.  The result is used as the training target for the U-Net.

    Args:
        image_np: 8-bit image array, either colour in RGB(A) channel order
            with shape (H, W, C) or already grayscale with shape (H, W).
        min_area_frac: Contours covering less than this fraction of the
            image area are treated as noise.
        edge_margin_frac: Fraction of the image height at the top and at the
            bottom in which contour centres are rejected.

    Returns:
        float32 array of shape (H, W): 1.0 inside the selected contour and
        0.0 elsewhere (all zeros when no contour survives the filters).
    """
    img_h, img_w = image_np.shape[:2]

    # cvtColor(RGB2GRAY) rejects single-channel input, so pass grayscale
    # images through untouched.  No defensive copy is needed: none of the
    # calls below write into their input.
    if image_np.ndim == 2:
        gray = image_np
    else:
        gray = cv2.cvtColor(image_np, cv2.COLOR_RGB2GRAY) # Note: RGB input from PIL

    # Blur & Threshold
    blurred = cv2.GaussianBlur(gray, (7, 7), 0)
    _, thresh = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Dilate to merge fragmented foreground into one connected region
    kernel = np.ones((5, 5), np.uint8)
    dilated = cv2.dilate(thresh, kernel, iterations=3)

    # Filter Contours (Remove Text)
    contours, _ = cv2.findContours(dilated, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    min_area = min_area_frac * img_h * img_w
    top_limit = img_h * edge_margin_frac
    bottom_limit = img_h * (1 - edge_margin_frac)

    best_contour = None
    best_area = 0.0
    for contour in contours:
        area = cv2.contourArea(contour)
        if area < min_area:
            continue  # Tiny noise

        # Positional Filter (Ignore Header/Footer text)
        _, y, _, h = cv2.boundingRect(contour)
        center_y = y + h / 2
        if center_y < top_limit or center_y > bottom_limit:
            continue

        # Strict '>' keeps the first contour on ties, like max() would.
        if best_contour is None or area > best_area:
            best_contour, best_area = contour, area

    # Create Mask
    mask = np.zeros((img_h, img_w), dtype=np.float32)
    if best_contour is not None:
        cv2.drawContours(mask, [best_contour], -1, 1.0, -1) # 1.0 for Brain

    return mask

# --- 3. DATASET (Generates Masks on the Fly) ---
class SegmentationDataset(Dataset):
    """
    Image/mask dataset whose masks are generated on the fly.

    Every image found in ``root_dir/<class>/`` (for each name in CLASSES)
    becomes one sample.  The target mask is produced by
    generate_ground_truth_mask() at load time, so no mask files are needed
    on disk.
    """

    # Lower-case file suffixes that are accepted as images.
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, root_dir):
        """
        Collect the paths of all images below ``root_dir``.

        Args:
            root_dir: Folder containing one sub-folder per entry of CLASSES.
                Sub-folders that do not exist are skipped.
        """
        self.files = []
        for cls in CLASSES:
            cls_dir = os.path.join(root_dir, cls)
            if not os.path.isdir(cls_dir):
                continue
            # os.listdir() order is arbitrary; sort so that an index always
            # refers to the same file across runs and machines.
            for f in sorted(os.listdir(cls_dir)):
                if f.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.files.append(os.path.join(cls_dir, f))

        # Build the converter once instead of twice per sample.
        self._to_tensor = transforms.ToTensor()

    def __len__(self):
        """Return the number of images that were discovered."""
        return len(self.files)

    def __getitem__(self, idx):
        """
        Load one image and build its mask.

        Args:
            idx: Index into the list of discovered files.

        Returns:
            Tuple ``(img_t, mask_t)``: the RGB image as a float tensor of
            shape (3, H, W) scaled to [0, 1], and the mask as a float tensor
            of shape (1, H, W) holding 0.0 / 1.0, where (W, H) is
            TARGET_SIZE.  An unreadable file yields a black image with an
            empty mask, and a warning is emitted.
        """
        img_path = self.files[idx]

        # Load Image
        try:
            # The context manager closes the file handle right away;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as handle:
                img_pil = handle.convert("RGB")
        except Exception as exc:
            # Best-effort fallback so one corrupt file does not abort the
            # whole epoch, but make it visible instead of hiding it.
            import warnings
            warnings.warn(f"Could not read {img_path!r} ({exc}); using a blank image instead.")
            img_pil = Image.new('RGB', TARGET_SIZE)
        img_np = np.array(img_pil)

        # Generate Target Mask (The Teacher) at the original resolution
        mask = generate_ground_truth_mask(img_np)

        # Resize both to target
        img_pil = img_pil.resize(TARGET_SIZE)
        mask_pil = Image.fromarray((mask * 255).astype(np.uint8)) # Convert to PIL for resize
        # Nearest-neighbour keeps the mask strictly binary after resizing.
        mask_pil = mask_pil.resize(TARGET_SIZE, resample=Image.NEAREST)

        # Convert to Tensor (ToTensor rescales 0..255 to 0..1)
        img_t = self._to_tensor(img_pil) # (3, 224, 224)
        mask_t = self._to_tensor(mask_pil) # (1, 224, 224)

        return img_t, mask_t

# --- 4. MODEL (Simple U-Net) ---
class SimpleUNet(nn.Module):
    """
    Small two-level U-Net that maps an RGB image to a one-channel mask.

    Layout: two encoder stages (64 and 128 channels), each followed by 2x2
    max pooling, a 256-channel bottleneck, and two decoder stages that
    upsample bilinearly and concatenate the matching encoder output (skip
    connection).  The output passes through a sigmoid, so values lie in
    [0, 1] and can be used directly with nn.BCELoss.

    The decoder upsamples to the exact spatial size of the skip tensor, so
    the input height and width do not have to be multiples of 4 (they must
    be at least 4 so that two poolings leave a non-empty tensor).  The
    parameter names in the state_dict are e1, e2, b, d2, d1 and out.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.e1 = self._double_conv(3, 64)
        self.pool = nn.MaxPool2d(2)
        self.e2 = self._double_conv(64, 128)

        # Bottleneck
        self.b = self._double_conv(128, 256)

        # Decoder (input channels = upsampled channels + skip channels)
        self.d2 = self._double_conv(256 + 128, 128)
        self.d1 = self._double_conv(128 + 64, 64)

        # Output
        self.out = nn.Conv2d(64, 1, 1) # 1 Channel Output (Mask)

    @staticmethod
    def _double_conv(in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.ReLU(),
        )

    @staticmethod
    def _upsample_like(x, reference):
        """Bilinearly resize ``x`` to the spatial size of ``reference``."""
        # With align_corners=True this equals a scale_factor=2 upsample
        # whenever the reference is exactly twice as large, and it also
        # works when pooling rounded an odd size down.
        return nn.functional.interpolate(
            x, size=reference.shape[-2:], mode='bilinear', align_corners=True
        )

    def forward(self, x):
        """
        Run the network.

        Args:
            x: Float tensor of shape (N, 3, H, W).

        Returns:
            Float tensor of shape (N, 1, H, W) with values in [0, 1].
        """
        x1 = self.e1(x)
        p1 = self.pool(x1)
        x2 = self.e2(p1)
        p2 = self.pool(x2)

        b = self.b(p2)

        u2 = self._upsample_like(b, x2)
        u2 = torch.cat([u2, x2], dim=1) # Skip Connection
        x3 = self.d2(u2)

        u1 = self._upsample_like(x3, x1)
        u1 = torch.cat([u1, x1], dim=1) # Skip Connection
        x4 = self.d1(u1)

        return torch.sigmoid(self.out(x4))

# --- 5. TRAINING LOOP ---
def train_locator_unet():
    """
    Train the locator U-Net and save its weights.

    Builds a SegmentationDataset from RAW_DATA_DIR, trains a SimpleUNet for
    EPOCHS epochs with Adam (lr=0.001) and binary cross-entropy against the
    OpenCV-generated masks, prints the mean loss per epoch, and writes the
    state_dict to 'locator_unet.pth' in the current working directory.

    Raises:
        ValueError: If no images were found under RAW_DATA_DIR.
    """
    print("🧠 initializing Locator U-Net...")
    dataset = SegmentationDataset(RAW_DATA_DIR)
    if len(dataset) == 0:
        # Fail with a clear message instead of an obscure sampler error or
        # a division by zero when averaging the loss.
        raise ValueError(
            f"No images found under {RAW_DATA_DIR!r} for classes {CLASSES}; "
            "check RAW_DATA_DIR."
        )
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    model = SimpleUNet().to(DEVICE)
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    # The model already applies a sigmoid, so plain BCELoss is the match.
    criterion = nn.BCELoss() # Binary Cross Entropy for Masks

    print(f"🚀 Training Locator on {len(dataset)} images...")

    for epoch in range(EPOCHS):
        model.train()
        loop = tqdm(loader, desc=f"Ep {epoch+1}/{EPOCHS}", leave=False)
        epoch_loss = 0.0

        for images, masks in loop:
            images, masks = images.to(DEVICE), masks.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            epoch_loss += batch_loss
            loop.set_postfix(loss=batch_loss)

        # Mean of the per-batch losses
        print(f"Epoch {epoch+1} Loss: {epoch_loss/len(loader):.4f}")

    # Save the Locator
    torch.save(model.state_dict(), "locator_unet.pth")
    print("✅ Locator U-Net Saved as 'locator_unet.pth'")

# RUN IT
# Starts training as soon as this cell/script is executed and writes
# 'locator_unet.pth' to the current working directory when it finishes.
train_locator_unet()

🧠 initializing Locator U-Net...
🚀 Training Locator on 3163 images...


                                                                      

Epoch 1 Loss: 0.1841


                                                                      

Epoch 2 Loss: 0.1202


                                                                      

Epoch 3 Loss: 0.1131


                                                                      

Epoch 4 Loss: 0.0976


                                                                      

Epoch 5 Loss: 0.0939
✅ Locator U-Net Saved as 'locator_unet.pth'


