In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from typing import List, Dict

class PVTFeatureExtractor(nn.Module):
    """Wrap a timm PVT backbone and return its multi-scale stage outputs.

    The backbone is built with timm's ``features_only=True`` mode, so calling
    it returns one NCHW feature map per stage. With timm's default
    ``out_indices`` this is every stage (ordered from the highest-resolution
    stage to the lowest).

    Args:
        backbone_name: timm model name, e.g. ``'pvt_v2_b2'``.
        pretrained: whether to load pretrained weights (default True, as before).

    Attributes:
        model: the timm feature-extraction wrapper around the backbone.
        channels: channel count of each feature map returned by ``forward``.
    """

    def __init__(self, backbone_name: str = 'pvt_v2_b2', pretrained: bool = True):
        super().__init__()
        # The previous hand-written forward walked attributes (pos_embed,
        # pos_drop, blocks) that timm's PVT-v2 does not define. It also treated
        # individual blocks as stages and recovered H == W via sqrt(N), which
        # breaks on non-square inputs. features_only lets timm do the stage
        # plumbing and keeps the true (H, W) of each stage; no classification
        # head is built, so there is nothing to strip.
        self.model = timm.create_model(
            backbone_name, pretrained=pretrained, features_only=True
        )
        self.channels: List[int] = list(self.model.feature_info.channels())

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Return the per-stage feature maps of ``x``.

        Args:
            x: image batch; assumed (B, 3, H, W) — TODO confirm for custom backbones.

        Returns:
            A list with one NCHW tensor per selected backbone stage.
        """
        return list(self.model(x))

class FPN(nn.Module):
    """Feature Pyramid Network neck: lateral 1x1 convs, top-down merge, 3x3 smoothing.

    Args:
        in_channels_list: channel count of each input level. Levels are expected
            fine-to-coarse: level ``k`` is resized to level ``k - 1``'s spatial
            size and added into it.
        out_channels: channel count of every output level.
    """

    def __init__(self, in_channels_list: List[int], out_channels: int):
        super().__init__()
        num_levels = len(in_channels_list)

        # 1x1 projections bringing every level to a common width.
        projections = []
        for c_in in in_channels_list:
            projections.append(nn.Conv2d(c_in, out_channels, kernel_size=1))
        self.lateral_convs = nn.ModuleList(projections)

        # 3x3 convs applied to each merged level.
        smoothers = []
        for _ in range(num_levels):
            smoothers.append(
                nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
            )
        self.fpn_convs = nn.ModuleList(smoothers)

    def forward(self, inputs: List[torch.Tensor]) -> List[torch.Tensor]:
        """Project, merge top-down, and smooth; returns one map per level."""
        merged = [self.lateral_convs[k](inputs[k]) for k in range(len(self.lateral_convs))]

        # Walk from the coarsest level towards the finest, adding each
        # (nearest-upsampled) coarser map into its finer neighbour.
        for k in reversed(range(1, len(merged))):
            finer = merged[k - 1]
            coarser_up = F.interpolate(merged[k], size=finer.shape[-2:], mode='nearest')
            merged[k - 1] = finer + coarser_up

        return [self.fpn_convs[k](merged[k]) for k in range(len(self.fpn_convs))]

class PVTSegmentation(nn.Module):
    """PVT backbone + FPN neck + concat-fusion head for semantic segmentation.

    Every FPN level is resized to the resolution of the first FPN level,
    the levels are concatenated along channels and classified by a small conv
    head, and the logits are bilinearly upsampled to the input size.

    Args:
        backbone_name: timm model name passed to ``PVTFeatureExtractor``.
        num_classes: number of output channels of the logits.
        fpn_channels: width of every FPN level; 256 reproduces the previously
            hard-coded value.
    """

    def __init__(self, backbone_name: str = 'pvt_v2_b2', num_classes: int = 150,
                 fpn_channels: int = 256):
        super().__init__()
        self.backbone = PVTFeatureExtractor(backbone_name)

        # Discover per-level channel counts with a single probe forward pass.
        # The probe needs no autograd graph (no_grad avoids holding the
        # activations of a 512x512 forward), and eval() keeps train-mode layers
        # (dropout / drop-path, BatchNorm running stats) from being touched by
        # a throwaway random input. The original mode is restored afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                dummy_input = torch.randn(1, 3, 512, 512)
                features = self.backbone(dummy_input)
        finally:
            self.backbone.train(was_training)
        in_channels_list = [feat.shape[1] for feat in features]

        self.fpn = FPN(in_channels_list, out_channels=fpn_channels)

        # Head input is all FPN levels concatenated along channels.
        self.seg_head = nn.Sequential(
            nn.Conv2d(fpn_channels * len(in_channels_list), 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, num_classes, kernel_size=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return per-pixel class logits at the spatial size of ``x``."""
        features = self.backbone(x)
        fpn_features = self.fpn(features)

        # Bring every level to the resolution of the first FPN output.
        target_size = fpn_features[0].shape[-2:]
        resized_features = [
            F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            for feat in fpn_features
        ]
        concat_features = torch.cat(resized_features, dim=1)

        logits = self.seg_head(concat_features)
        return F.interpolate(
            logits,
            size=x.shape[-2:],
            mode='bilinear',
            align_corners=False,
        )

def create_training_setup(model: nn.Module, learning_rate: float = 1e-4):
    """Build the optimizer and loss used for segmentation training.

    Args:
        model: module whose parameters will be optimized.
        learning_rate: AdamW learning rate.

    Returns:
        ``(optimizer, criterion)``: an AdamW optimizer over all of ``model``'s
        parameters, and a cross-entropy loss that skips target pixels equal
        to 255.
    """
    loss_fn = nn.CrossEntropyLoss(ignore_index=255)
    opt = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)
    return opt, loss_fn

def train_one_epoch(
    model: nn.Module,
    dataloader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    criterion: nn.Module,
    device: torch.device
) -> float:
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    NOTE: this signature takes ``(optimizer, criterion)``, while the
    ``train_one_epoch`` defined in the later notebook cell takes
    ``(criterion, optimizer)``. Passing them by keyword avoids mixing them up.

    Args:
        model: network to train; switched to train mode.
        dataloader: any iterable of ``(images, masks)`` batches (it does not
            need ``__len__``).
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss called as ``criterion(outputs, masks)``.
        device: device the batches are moved to.

    Returns:
        The unweighted mean of the per-batch loss values.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for images, masks in dataloader:
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Count batches instead of using len(dataloader), so length-less
        # iterables work and the divisor matches what was actually seen.
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")
    return total_loss / num_batches

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os
import numpy as np

class SimpleADE20K(Dataset):
    """Minimal ADE20K (ADEChallengeData2016-layout) segmentation dataset.

    Expects ``root_dir/images/<split>/*.jpg`` with a same-stem ``.png`` label
    map in ``root_dir/annotations/<split>/``, where ``<split>`` is
    ``training`` or ``validation``.

    Args:
        root_dir: dataset root directory.
        split: ``'train'`` (or ``'training'``) selects ``training``; any other
            value selects ``validation``.
        size: side length that both image and mask are resized to.
        reduce_zero_label: if True (default), label 0 (and 255) become
            ``IGNORE_INDEX`` and every other label is shifted down by one.
            This assumes the ADEChallengeData2016 convention of 0 = unlabeled
            and 1..150 = classes (verify for your copy). Under that convention
            targets fall in 0..149 and match a 150-way classifier trained with
            ``CrossEntropyLoss(ignore_index=255)``.
    """

    IGNORE_INDEX = 255

    def __init__(self, root_dir, split='train', size=512, reduce_zero_label=True):
        self.root_dir = root_dir
        # ADE20K folder names are 'training' and 'validation'.
        self.split = 'training' if split in ('train', 'training') else 'validation'
        self.size = size
        self.reduce_zero_label = reduce_zero_label

        self.img_dir = os.path.join(root_dir, 'images', self.split)
        self.mask_dir = os.path.join(root_dir, 'annotations', self.split)

        if not os.path.isdir(self.img_dir):
            raise FileNotFoundError(
                f"Image directory not found: {self.img_dir!r}. Expected the "
                f"layout <root_dir>/images/{self.split}; check root_dir={root_dir!r}."
            )

        # Sorted so that indices are stable across runs and machines.
        self.files = sorted(f for f in os.listdir(self.img_dir) if f.endswith('.jpg'))
        if not self.files:
            raise ValueError(f"No images found in {self.img_dir}")

        print(f"Found {len(self.files)} images in {self.img_dir}")

        self.transform = transforms.Compose([
            transforms.Resize((size, size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        # Nearest-neighbour so label ids are never blended. No ToTensor here:
        # ToTensor rescales 8-bit images to [0, 1], which would destroy the
        # integer class ids (every label < 255 truncates to 0 after .long()).
        self.mask_transform = transforms.Compose([
            transforms.Resize((size, size), interpolation=transforms.InterpolationMode.NEAREST),
        ])

    @staticmethod
    def _prepare_mask(mask_array, reduce_zero_label=True, ignore_index=255):
        """Convert a raw label-id array into an int64 target tensor.

        The input is copied, never modified. With ``reduce_zero_label`` the
        labels 0 and ``ignore_index`` map to ``ignore_index`` and every other
        label ``k`` maps to ``k - 1``.
        """
        labels = torch.from_numpy(np.array(mask_array, dtype=np.int64))
        if reduce_zero_label:
            ignore = (labels == 0) | (labels == ignore_index)
            labels = labels - 1
            labels[ignore] = ignore_index
        return labels

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: a normalized float image tensor and int64 label ids."""
        img_name = self.files[idx]
        img_path = os.path.join(self.img_dir, img_name)
        mask_path = os.path.join(self.mask_dir, os.path.splitext(img_name)[0] + '.png')

        # Context managers close the underlying files; convert()/Resize return
        # new images, so the results outlive the with-blocks.
        with Image.open(img_path) as img:
            image = self.transform(img.convert('RGB'))
        with Image.open(mask_path) as raw_mask:
            mask = self.mask_transform(raw_mask)

        mask = self._prepare_mask(np.array(mask), self.reduce_zero_label, self.IGNORE_INDEX)
        return image, mask

# Simple segmentation model
class SimpleSegmenter(nn.Module):
    """Single-scale segmenter: last backbone stage -> conv head -> bilinear upsample.

    Args:
        backbone_name: timm model name. The backbone is built with
            ``features_only=True, out_indices=[3]`` and returns its 4th-stage
            NCHW feature map. This path suits four-stage backbones such as
            ConvNeXt and timm's PVT-v2.
        num_classes: number of output channels of the logits.
        pretrained: whether to load pretrained timm weights (default True, as before).
    """

    def __init__(self, backbone_name, num_classes=150, pretrained=True):
        super().__init__()

        # One path for every backbone. The former PVT branch built the
        # classification model with head=Identity, which returns globally
        # pooled (B, C) vectors rather than a spatial map, so its sqrt/reshape
        # could not work. It also hard-coded feat_dim=512. features_only
        # returns real NCHW maps and reports their channel count.
        self.backbone = timm.create_model(
            backbone_name, pretrained=pretrained, features_only=True, out_indices=[3]
        )
        feat_dim = self.backbone.feature_info.channels()[-1]

        self.head = nn.Sequential(
            nn.Conv2d(feat_dim, 256, 3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_classes, 1)
        )

    def forward(self, x):
        """Return per-pixel class logits at the spatial size of ``x``."""
        # The features_only backbone returns a list; the last entry is the
        # requested stage.
        features = self.backbone(x)[-1]
        logits = self.head(features)
        return F.interpolate(logits, size=x.shape[-2:],
                             mode='bilinear', align_corners=False)

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Prints the loss of every 10th batch. NOTE: argument order is
    ``(criterion, optimizer)``, the reverse of the ``train_one_epoch`` in the
    earlier notebook cell.

    Args:
        model: network to train; switched to train mode.
        dataloader: any iterable of ``(images, masks)`` batches (it does not
            need ``__len__``).
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device the batches are moved to.

    Returns:
        The unweighted mean of the per-batch loss values.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_idx, (images, masks) in enumerate(dataloader):
        images, masks = images.to(device), masks.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)

        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        # Count batches instead of using len(dataloader), so length-less
        # iterables work and an empty loader is reported explicitly.
        num_batches += 1

        if batch_idx % 10 == 0:
            print(f"Batch {batch_idx}, Loss: {batch_loss:.4f}")

    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")
    return total_loss / num_batches

def main(root_dir='ADEChallengeData2016', backbone_name='convnext_tiny', num_classes=150,
         num_epochs=5, batch_size=8, learning_rate=1e-4, num_workers=4):
    """Train a SimpleSegmenter on the ADE20K training split and print per-epoch loss.

    All previously hard-coded settings are parameters with the old values as
    defaults, so ``main()`` behaves as before. ``root_dir`` in particular must
    point at the dataset root containing ``images/`` and ``annotations/``;
    a wrong root makes dataset construction fail with FileNotFoundError.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = SimpleADE20K(root_dir, split='train')
    # pin_memory speeds host->GPU copies; it is only useful when training on CUDA.
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=num_workers, pin_memory=device.type == 'cuda')

    model = SimpleSegmenter(backbone_name, num_classes=num_classes)
    model = model.to(device)

    # 255 matches the ignore value produced by SimpleADE20K's default label handling.
    criterion = nn.CrossEntropyLoss(ignore_index=255)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    for epoch in range(num_epochs):
        loss = train_one_epoch(model, dataloader, criterion, optimizer, device)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss:.4f}")

# Start training when run as a script. IPython/Jupyter kernels also set
# __name__ to '__main__', so executing this notebook cell starts training too.
if __name__ == '__main__':
    main()

FileNotFoundError: [Errno 2] No such file or directory: '/path/to/ade20k/images/training'