In [2]:
import os
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import matplotlib.pyplot as plt
from tqdm import tqdm

# ------------------ Dataset ------------------
class HairDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.transform = transform
        self.images = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.images[idx])
        mask_path = os.path.join(self.mask_dir, self.images[idx].replace('.jpg', '.png'))

        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        mask = (mask > 127).astype('float32')  # binarize

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask'].unsqueeze(0)

        return image, mask

# ------------------ U-Net Model ------------------
class UNet(nn.Module):
    def __init__(self):
        super().__init__()

        def conv_block(in_ch, out_ch):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.ReLU(),
                nn.Conv2d(out_ch, out_ch, 3, padding=1),
                nn.ReLU()
            )

        self.enc1 = conv_block(3, 32)
        self.enc2 = conv_block(32, 64)
        self.enc3 = conv_block(64, 128)
        self.pool = nn.MaxPool2d(2)

        self.up2 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec2 = conv_block(128, 64)
        self.up1 = nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.dec1 = conv_block(64, 32)

        self.out = nn.Conv2d(32, 1, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))

        d2 = self.dec2(torch.cat([self.up2(e3), e2], dim=1))
        d1 = self.dec1(torch.cat([self.up1(d2), e1], dim=1))

        return torch.sigmoid(self.out(d1))

# ------------------ Training Setup ------------------
transform = A.Compose([
    A.Resize(256, 256),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.Normalize(),
    ToTensorV2()
])

dataset = HairDataset("łysienie", "masks", transform)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

model = UNet().cuda()
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ------------------ Training Loop ------------------
for epoch in range(10):
    model.train()
    total_loss = 0
    for imgs, masks in tqdm(loader):
        imgs, masks = imgs.cuda(), masks.cuda()
        preds = model(imgs)
        loss = criterion(preds, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    print(f"Epoch {epoch+1} Loss: {total_loss:.4f}")

# ------------------ Save and Predict ------------------
os.makedirs("output", exist_ok=True)
model.eval()
with torch.no_grad():
    for idx, (img, _) in enumerate(dataset):
        img_tensor = img.unsqueeze(0).cuda()
        pred = model(img_tensor)[0, 0].cpu().numpy()
        pred_mask = (pred > 0.5).astype('uint8') * 255
        cv2.imwrite(f"output/pred_{idx:03d}.png", pred_mask)


  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()
  0%|          | 0/13 [00:00<?, ?it/s][ WARN:0@102.381] global loadsave.cpp:268 findDecoder imread_('masks/03939.png'): can't open/read file: check file path/integrity
  0%|          | 0/13 [00:00<?, ?it/s]


TypeError: '>' not supported between instances of 'NoneType' and 'int'