In [13]:
import os
import numpy as np
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Configuration
class Config:
    data_path = "kaggle-data"
    train_csv = "train_ground_truth.csv"
    train_images_dir = "train"
    test_images_dir = "test_final"
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    num_classes = 4
    classes = ['Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']
    
    # Training parameters
    batch_size = 2
    learning_rate = 1e-4
    num_epochs = 10
    image_size = 256
    
    # Inference parameters
    detection_threshold = 0.3
    min_nucleus_area = 10

config = Config()

# RLE functions
def rle_encode_instances(mask):
    """Convert instance mask to RLE string"""
    if mask is None or mask.size == 0 or np.max(mask) == 0:
        return "0"
    
    pixels = mask.flatten(order='F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    
    rle_pairs = []
    if len(runs) > 0:
        current_val = pixels[runs[0]]
        current_start = runs[0]
        
        for i in range(1, len(runs)):
            if pixels[runs[i]] != current_val:
                length = runs[i] - current_start
                if current_val > 0:
                    rle_pairs.extend([int(current_val), int(current_start), int(length)])
                current_val = pixels[runs[i]]
                current_start = runs[i]
        
        if current_val > 0:
            length = len(pixels) - current_start
            rle_pairs.extend([int(current_val), int(current_start), int(length)])
    
    return " ".join(map(str, rle_pairs)) if rle_pairs else "0"

def simple_rle_decode(rle_str, shape):
    """Decode RLE to binary mask"""
    if rle_str == "0" or not rle_str or pd.isna(rle_str):
        return np.zeros(shape, dtype=np.uint8)
    
    try:
        values = list(map(int, rle_str.split()))
        mask = np.zeros(shape[0] * shape[1], dtype=np.uint8)
        
        for i in range(0, len(values), 3):
            if i + 2 < len(values):
                start = values[i + 1]
                length = values[i + 2]
                mask[start-1:start-1+length] = 1
        
        return mask.reshape(shape, order='F')
    except:
        return np.zeros(shape, dtype=np.uint8)

# SIMPLE Dataset class
class NucleiDataset(Dataset):
    def __init__(self, image_ids, gt_df=None, is_train=True, data_dir=None):
        self.image_ids = image_ids
        self.gt_df = gt_df
        self.is_train = is_train
        self.data_dir = data_dir or config.data_path
        
        if is_train:
            self.images_dir = os.path.join(self.data_dir, config.train_images_dir)
        else:
            self.images_dir = os.path.join(self.data_dir, config.test_images_dir)

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        
        # Load image
        img_path = os.path.join(self.images_dir, f"{image_id}.tif")
        image = cv2.imread(img_path)
        if image is None:
            raise ValueError(f"Could not load image: {img_path}")
        
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        original_shape = image.shape[:2]
        
        # Resize image
        image = cv2.resize(image, (config.image_size, config.image_size))
        image = image.astype(np.float32) / 255.0
        image = torch.from_numpy(image).permute(2, 0, 1).float()  # (C, H, W)
        
        if self.is_train:
            # Load masks
            row = self.gt_df[self.gt_df['image_id'] == image_id].iloc[0]
            masks = []
            
            for class_name in config.classes:
                rle_str = row[class_name]
                mask = simple_rle_decode(rle_str, original_shape)
                mask = cv2.resize(mask, (config.image_size, config.image_size))
                masks.append(mask)
            
            # Stack masks
            target = np.stack(masks, axis=0).astype(np.float32)  # (C, H, W)
            target = torch.from_numpy(target).float()
            
            return image, target
        else:
            return image, image_id, original_shape

# SIMPLE U-Net
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=4):
        super(UNet, self).__init__()
        
        # Encoder
        self.enc1 = self._block(in_channels, 64)
        self.enc2 = self._block(64, 128)
        self.enc3 = self._block(128, 256)
        
        self.pool = nn.MaxPool2d(2)
        
        # Bottleneck
        self.bottleneck = self._block(256, 512)
        
        # Decoder
        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = self._block(512, 256)
        
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = self._block(256, 128)
        
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = self._block(128, 64)
        
        self.final = nn.Conv2d(64, out_channels, 1)
        
    def _block(self, in_ch, out_ch):
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.ReLU(inplace=True)
        )
    
    def forward(self, x):
        # Encoder
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        
        # Bottleneck
        b = self.bottleneck(self.pool(e3))
        
        # Decoder
        d3 = self.up3(b)
        d3 = torch.cat([e3, d3], dim=1)
        d3 = self.dec3(d3)
        
        d2 = self.up2(d3)
        d2 = torch.cat([e2, d2], dim=1)
        d2 = self.dec2(d2)
        
        d1 = self.up1(d2)
        d1 = torch.cat([e1, d1], dim=1)
        d1 = self.dec1(d1)
        
        return self.final(d1)

# Loss function
class DiceBCELoss(nn.Module):
    def __init__(self):
        super().__init__()
        
    def forward(self, pred, target):
        bce = F.binary_cross_entropy_with_logits(pred, target)
        
        pred_sig = torch.sigmoid(pred)
        intersection = (pred_sig * target).sum()
        union = pred_sig.sum() + target.sum()
        dice = 1 - (2. * intersection + 1e-6) / (union + 1e-6)
        
        return bce + dice

# Instance separation
def separate_instances(mask, min_area=10):
    if np.sum(mask) == 0:
        return np.zeros_like(mask, dtype=np.uint16)
    
    num_labels, labels = cv2.connectedComponents(mask.astype(np.uint8))
    instance_mask = np.zeros_like(mask, dtype=np.uint16)
    
    instance_id = 1
    for label in range(1, num_labels):
        component = (labels == label).astype(np.uint8)
        if np.sum(component) >= min_area:
            instance_mask[component > 0] = instance_id
            instance_id += 1
    
    return instance_mask

# Training
def train_model():
    print("Loading data...")
    gt_df = pd.read_csv(os.path.join(config.data_path, config.train_csv))
    image_ids = gt_df['image_id'].tolist()[:20]  # Use only 20 images for testing
    
    train_ids, val_ids = train_test_split(image_ids, test_size=0.2, random_state=42)
    
    train_dataset = NucleiDataset(train_ids, gt_df=gt_df, is_train=True)
    val_dataset = NucleiDataset(val_ids, gt_df=gt_df, is_train=True)
    
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    print(f"Training on {len(train_dataset)} images")
    
    # Verify dimensions
    for images, targets in train_loader:
        print(f"Images: {images.shape}, Targets: {targets.shape}")
        break
    
    model = UNet().to(config.device)
    criterion = DiceBCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
    
    print("Starting training...")
    
    for epoch in range(config.num_epochs):
        model.train()
        train_loss = 0
        
        pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{config.num_epochs}')
        for images, targets in pbar:
            images, targets = images.to(config.device), targets.to(config.device)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            pbar.set_postfix({'Loss': f'{loss.item():.4f}'})
        
        avg_loss = train_loss / len(train_loader)
        print(f'Epoch {epoch+1} Average Loss: {avg_loss:.4f}')
    
    torch.save(model.state_dict(), 'best_model.pth')
    print("Training completed!")

# Inference
def create_submission():
    print("Creating submission...")
    
    model = UNet().to(config.device)
    if os.path.exists('best_model.pth'):
        model.load_state_dict(torch.load('best_model.pth', map_location=config.device))
    model.eval()
    
    test_dir = os.path.join(config.data_path, config.test_images_dir)
    test_files = [f for f in os.listdir(test_dir) if f.endswith('.tif')]
    test_image_ids = sorted([f.split('.')[0] for f in test_files])[:10]  # Only 10 for testing
    
    test_dataset = NucleiDataset(test_image_ids, is_train=False)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
    
    submission_data = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Processing"):
            if len(batch) == 3:
                images, image_ids, original_shapes = batch
            else:
                images, image_ids = batch
            
            images = images.to(config.device)
            outputs = torch.sigmoid(model(images))
            pred_masks = outputs.squeeze(0).cpu().numpy()
            
            rle_strings = {}
            for i, class_name in enumerate(config.classes):
                binary_mask = (pred_masks[i] > config.detection_threshold).astype(np.uint8)
                
                if binary_mask.sum() == 0:
                    rle_strings[class_name] = "0"
                else:
                    instance_mask = separate_instances(binary_mask)
                    rle_strings[class_name] = rle_encode_instances(instance_mask)
            
            submission_data.append({
                'image_id': image_ids[0],
                'Epithelial': rle_strings['Epithelial'],
                'Lymphocyte': rle_strings['Lymphocyte'],
                'Neutrophil': rle_strings['Neutrophil'],
                'Macrophage': rle_strings['Macrophage']
            })
    
    submission_df = pd.DataFrame(submission_data)
    submission_df = submission_df[['image_id', 'Epithelial', 'Lymphocyte', 'Neutrophil', 'Macrophage']]
    submission_df = submission_df.sort_values('image_id')
    submission_df.to_csv('submission.csv', index=False)
    
    print(f"Submission created with {len(submission_df)} images")
    print("First few rows:")
    print(submission_df.head())

if __name__ == "__main__":
    # Test basic functionality
    print("Testing RLE...")
    test_mask = np.zeros((10, 10), dtype=np.uint16)
    test_mask[2:4, 2:4] = 1
    test_mask[6:8, 6:8] = 2
    encoded = rle_encode_instances(test_mask)
    print(f"RLE test: {encoded}")
    
    # Train
    train_model()
    
    # Create submission
    create_submission()
    
    print("Done! Check submission.csv")

Testing RLE...
RLE test: 1 23 2 1 33 2 2 67 2 2 77 2
Loading data...
Training on 16 images
Images: torch.Size([2, 3, 256, 256]), Targets: torch.Size([2, 4, 256, 256])
Starting training...


Epoch 1/10: 100%|██████████| 8/8 [00:01<00:00,  4.52it/s, Loss=1.6515]


Epoch 1 Average Loss: 1.6140


Epoch 2/10: 100%|██████████| 8/8 [00:01<00:00,  6.32it/s, Loss=1.5956]


Epoch 2 Average Loss: 1.5994


Epoch 3/10: 100%|██████████| 8/8 [00:01<00:00,  6.43it/s, Loss=1.4632]


Epoch 3 Average Loss: 1.5550


Epoch 4/10: 100%|██████████| 8/8 [00:01<00:00,  6.39it/s, Loss=1.6324]


Epoch 4 Average Loss: 1.3017


Epoch 5/10: 100%|██████████| 8/8 [00:01<00:00,  6.44it/s, Loss=1.1523]


Epoch 5 Average Loss: 1.1571


Epoch 6/10: 100%|██████████| 8/8 [00:01<00:00,  6.44it/s, Loss=1.0981]


Epoch 6 Average Loss: 1.1335


Epoch 7/10: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s, Loss=1.1419]


Epoch 7 Average Loss: 1.1269


Epoch 8/10: 100%|██████████| 8/8 [00:01<00:00,  6.45it/s, Loss=1.0787]


Epoch 8 Average Loss: 1.1073


Epoch 9/10: 100%|██████████| 8/8 [00:01<00:00,  6.46it/s, Loss=1.1231]


Epoch 9 Average Loss: 1.1055


Epoch 10/10: 100%|██████████| 8/8 [00:01<00:00,  6.31it/s, Loss=1.1113]


Epoch 10 Average Loss: 1.0981
Training completed!
Creating submission...


Processing: 100%|██████████| 10/10 [00:00<00:00, 22.15it/s]

Submission created with 10 images
First few rows:
  image_id Epithelial                                         Lymphocyte  \
0   slide1          0  1 1 6 1 257 2 1 513 1 1 769 1 1 1025 1 1 1281 ...   
1  slide10          0  1 1 6 6 254 3 1 257 2 6 512 1 1 513 1 6 768 1 ...   
2  slide11          0  1 64000 2 1 64256 2 1 64512 2 1 64768 2 1 6502...   
3  slide12          0  3 4096 1 3 4352 1 1 4353 1 3 4608 1 1 4609 1 3...   
4  slide13          0  4 254 3 4 512 1 4 768 1 4 1024 1 4 1280 1 4 15...   

  Neutrophil                                         Macrophage  
0          0  1 1 258 1 511 3 1 768 2 1 1024 2 1 1280 2 1 15...  
1          0  1 1 258 1 511 3 1 768 2 1 1024 2 1 1280 2 1 15...  
2          0  1 1 22 2 28 229 1 257 1 2 511 2 1 513 1 2 768 ...  
3          0  1 1 257 1 511 3 1 768 1 1 1024 1 1 1280 1 1 15...  
4          0  1 1 257 1 511 3 1 768 1 1 1024 1 1 1280 1 1 15...  
Done! Check submission.csv



