In [None]:
from pathlib import Path

import torch
from torchvision.models.segmentation import deeplabv3_resnet50
from torchvision import transforms
from PIL import Image

# Use the GPU when present; fall back to CPU so the notebook still runs without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
NUM_CLASSES = 3  # example: background, class_1, class_2

# The pretrained checkpoint ships a fixed 21-class head, so torchvision rejects
# `num_classes=3` combined with pretrained weights (older releases fail with a
# state-dict size mismatch instead). Load the pretrained model as-is, then
# replace only the final 1x1 convolutions with NUM_CLASSES-output layers.
model = deeplabv3_resnet50(weights='DEFAULT')
model.classifier[4] = torch.nn.Conv2d(
    model.classifier[4].in_channels, NUM_CLASSES, kernel_size=1
)
if model.aux_classifier is not None:
    # Keep the auxiliary head's output size consistent as well, so an aux loss
    # can be added later without shape errors.
    model.aux_classifier[4] = torch.nn.Conv2d(
        model.aux_classifier[4].in_channels, NUM_CLASSES, kernel_size=1
    )
model = model.to(device)

In [None]:
import torch.nn as nn
import torch.optim as optim

# Pixel-wise multi-class loss. Inputs: logits [B, C, H, W]; targets [B, H, W]
# holding integer class ids in 0..C-1.
criterion = nn.CrossEntropyLoss()

# Adam over every parameter of the model at a fixed learning rate.
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)


In [None]:
# Shared (height, width) both pipelines resize to, so images and masks stay aligned.
IMAGE_SIZE = (256, 256)

# Image pipeline: resize, convert to a float tensor in [0, 1], then normalise with
# the ImageNet channel mean/std used by torchvision's pretrained weights.
img_tfm = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Mask pipeline: same resize but nearest-neighbour so class ids are never blended,
# and PILToTensor (no scaling to [0, 1]) so labels stay integers.
mask_tfm = transforms.Compose([
    # Resize documents an InterpolationMode enum; PIL integer constants are only
    # accepted for backward compatibility (deprecated).
    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor()  # for an 8-bit single-channel (L/P) mask: uint8 [1, H, W]
])


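Expected data layout. Each mask must share its image's file stem and use the `.png` extension, e.g. `img_001.jpg` ↔ `img_001.png`:

```text
data/
├── train/
│   ├── images/
│   │   ├── img_001.jpg
│   │   └── ...
│   └── masks/
│       ├── img_001.png     # single-channel mask: pixel values 0, 1, 2, ... (class indices)
│       └── ...
└── test/
    └── images/
        ├── test_001.jpg
        └── ...
```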


In [None]:
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SegDataset(Dataset):
    """Paired (image, mask) dataset for semantic segmentation.

    Every ``*.jpg`` in ``img_dir`` is paired with the ``<stem>.png`` file in
    ``mask_dir`` that has the same file stem.

    Args:
        img_dir: directory containing the ``*.jpg`` input images.
        mask_dir: directory containing the ``<stem>.png`` label masks.
        tfm_img: callable applied to the RGB PIL image.
        tfm_mask: callable applied to the mask PIL image; must return a tensor
            whose dim 0 is a singleton channel (e.g. ``PILToTensor``).

    Items are ``(tfm_img(image), mask)``, where ``mask`` is ``tfm_mask(mask)``
    with dim 0 squeezed and cast to ``torch.long``, the target dtype
    ``nn.CrossEntropyLoss`` expects.
    """

    def __init__(self, img_dir, mask_dir, tfm_img, tfm_mask):
        # Sorted so that index -> file mapping is stable across runs.
        self.img_paths = sorted(Path(img_dir).glob('*.jpg'))
        self.mask_dir = Path(mask_dir)
        self.tfm_img = tfm_img
        self.tfm_mask = tfm_mask

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_path = self.img_paths[idx]
        mask_path = self.mask_dir / (img_path.stem + '.png')

        # Context managers release the file handles as soon as pixels are read.
        with Image.open(img_path) as im:
            img = im.convert('RGB')
        with Image.open(mask_path) as im:
            # Palette ('P') masks store class ids as palette indices; convert('L')
            # would replace them with the palette colours' luminance and corrupt
            # the labels. Keep 'L'/'P' pixels untouched (copy() loads them before
            # the file closes) and only collapse other modes to grayscale.
            mask = im.copy() if im.mode in ('L', 'P') else im.convert('L')

        img = self.tfm_img(img)
        mask = self.tfm_mask(mask).squeeze(0).long()  # [H, W]

        return img, mask


In [None]:
class TestDataset(Dataset):
    """Unlabelled images for inference.

    Args:
        img_dir: directory scanned for ``*.jpg`` files (sorted by path).
        tfm_img: callable applied to each RGB PIL image.

    Each item is ``(tfm_img(image), file_name)``.
    """

    def __init__(self, img_dir, tfm_img):
        found = Path(img_dir).glob('*.jpg')
        self.img_paths = sorted(found)
        self.tfm_img = tfm_img

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        path = self.img_paths[idx]
        rgb_image = Image.open(path).convert('RGB')
        transformed = self.tfm_img(rgb_image)
        return transformed, path.name


In [None]:
from torch.utils.data import DataLoader

# Training pairs (image, mask), reshuffled every epoch.
train_ds = SegDataset(
    img_dir='data/train/images',
    mask_dir='data/train/masks',
    tfm_img=img_tfm,
    tfm_mask=mask_tfm,
)
train_dl = DataLoader(train_ds, shuffle=True, batch_size=8)

# Unlabelled test images; batch_size=1 gives one image (and its file name) per batch.
test_ds = TestDataset(img_dir='data/test/images', tfm_img=img_tfm)
test_dl = DataLoader(test_ds, batch_size=1)


In [None]:
EPOCHS = 5
model.train()
for epoch in range(EPOCHS):
    total_loss = 0.0
    # `train_dl` is the training DataLoader built in the loader cell.
    for imgs, masks in train_dl:
        imgs = imgs.to(device)
        masks = masks.to(device)  # [B, H, W] integer class ids

        optimizer.zero_grad()
        output = model(imgs)['out']     # [B, C, H, W] logits
        loss = criterion(output, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report the mean per-batch loss so the value is comparable across dataset
    # sizes (a raw sum grows with the number of batches); guard an empty loader.
    mean_loss = total_loss / max(len(train_dl), 1)
    print(f"Epoch {epoch+1}: Loss = {mean_loss:.4f}")


In [None]:
model.eval()
# Take the first test sample (already transformed by img_tfm) so this cell does
# not depend on an `img` variable left over from elsewhere in the kernel.
img, img_name = test_ds[0]
with torch.inference_mode():
    pred_logits = model(img.unsqueeze(0).to(device))['out']  # [1, C, H, W]
    pred_mask = pred_logits.argmax(1).squeeze(0).cpu()       # [H, W], values 0..C-1


In [None]:
import torch.nn.functional as F

# pred_mask: [H, W] tensor of class indices (argmax of the logits) at the model's
#            input resolution.
# orig_size: (H, W) of the original image to map the prediction back to.

# Example: the input was resized to 256x256 but the original image was 101x101.
orig_size = (101, 101)
# F.interpolate needs an [N, C, H, W] input when `size` is 2-D, so add batch and
# channel dims. Nearest-neighbour is required for label maps: bilinear would
# blend neighbouring class ids and invent labels (e.g. 1 on a 0/2 border).
restored_mask = (
    F.interpolate(pred_mask[None, None].float(), size=orig_size, mode='nearest')
    .squeeze(0)
    .squeeze(0)
    .long()
)  # [H, W] class indices at orig_size


In [None]:
# Collapse to a 0.0/1.0 mask: every value above 0.5 becomes 1.0.
# Presumably restored_mask holds class ids with 0 = background, so this marks
# "any foreground class" — TODO confirm that is the intended meaning.
binary_mask = restored_mask.gt(0.5).float()
