In [None]:
import os
import random

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TrF
from PIL import Image
from torch import nn
from torch.hub import load_state_dict_from_url
from torch.utils.data import DataLoader, Dataset

In [None]:
# Dataset layout. Images, masks and split lists all live under `root`.
root = "content/"
imgs = root + "JPEG/"
masks = root + "SegmentationClass/"
f_names = "ImageSets/Segmentation/"

# Split name -> text file listing one image id per line. Resolved under `root`
# like the image/mask folders (previously these were relative to the working
# directory and `f_names` was unused).
ids = {
    split: os.path.join(root, f_names, f"{split}.txt")
    for split in ("train", "val")
}


class VOCPascal(Dataset):
    """Pascal VOC-style segmentation dataset yielding ``(image, mask)`` pairs.

    The image is returned as a normalized float tensor with 3 channels (it is
    converted to RGB first). The mask is returned as an int64 tensor of the
    pixel values of the PNG, with value 255 remapped to -1 so the loss can
    ignore it (e.g. ``nn.CrossEntropyLoss(ignore_index=-1)``).
    NOTE(review): assumes single-channel (palette / "L") mask PNGs so the mask
    is 2-D — confirm for your data.

    Args:
        ids: mapping from split name to a text file of whitespace-separated
            image ids.
        path_to_imgs: directory containing ``<id>.jpg`` images.
        path_to_masks: directory containing ``<id>.png`` masks.
        mode: split key into ``ids``; ``'train'`` also enables random
            horizontal flips.
        mean, std: per-channel normalization statistics. The defaults map
            [0, 1] to [-1, 1]; ImageNet-pretrained backbones usually expect
            (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) instead.

    Raises:
        KeyError: if ``mode`` is not a key of ``ids``.
    """

    def __init__(self, ids, path_to_imgs, path_to_masks, mode='train',
                 mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        with open(ids[mode], 'r') as f:
            fnames = f.read().split()
        # Join directory and file name as separate parts so the directories
        # work with or without a trailing path separator.
        self.imgs = [os.path.join(path_to_imgs, name + '.jpg') for name in fnames]
        self.masks = [os.path.join(path_to_masks, name + '.png') for name in fnames]
        self.mode = mode
        self.mean = list(mean)
        self.std = list(std)

    def __getitem__(self, ix):
        # Open inside `with` so the file handles are released right away;
        # convert()/copy() force the pixel data to be read before closing.
        with Image.open(self.imgs[ix]) as im:
            img = im.convert('RGB')
        with Image.open(self.masks[ix]) as m:
            mask = m.copy()
        return self._transform(img, mask)

    def _transform(self, img, mask):
        """Apply the (train-only) random flip, then tensorize both inputs."""
        # The same flip decision is applied to image and mask so they stay aligned.
        if self.mode == 'train' and random.random() > 0.5:
            img, mask = TrF.hflip(img), TrF.hflip(mask)
        img = TrF.normalize(TrF.to_tensor(img), mean=self.mean, std=self.std)
        mask = np.array(mask, np.int64)
        # 255 -> -1: label the loss is expected to ignore.
        mask[mask == 255] = -1
        return img, torch.from_numpy(mask)

    def __len__(self):
        return len(self.imgs)

In [None]:
class Block(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4) + shortcut.

    Args:
        in_channels: channels of the block input.
        out_channels: bottleneck width; the block outputs ``4 * out_channels``.
        stride: stride of the 3x3 conv (2 halves the spatial size).
        skip: optional module applied to the block *input* to match the main
            branch's shape (needed when stride != 1 or the channel count
            changes); ``None`` means identity shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, skip=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # padding=1 keeps the 3x3 conv size-preserving (apart from the stride)
        # so the residual sum lines up with the shortcut.
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, 4 * out_channels, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(4 * out_channels)
        self.skip = skip

    def forward(self, x):
        # The shortcut is computed from the block input, not the conv output.
        identity = x if self.skip is None else self.skip(x)
        out = F.relu(self.bn1(self.conv1(x)))
        out = F.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += identity
        return F.relu(out)

In [None]:
class ResNet(nn.Module):
    """Bottleneck ResNet image classifier (ResNet-101 stage depths by default).

    The stem (7x7/2 conv + BN + ReLU + 3x3/2 max-pool) and the four stages
    ``layer1``..``layer4`` follow torchvision's ResNet layout so ImageNet
    checkpoints can be mapped onto it.

    Args:
        block: residual block class, called as ``block(in_channels, channels,
            stride, skip)`` / ``block(in_channels, channels)`` and expected to
            output ``4 * channels`` channels.
        layers: number of blocks in each of the 4 stages; ``None`` means
            ``[3, 4, 23, 3]`` (ResNet-101).
        n_classes: output size of the final linear layer ``fc1``.

    Raises:
        ValueError: if ``layers`` does not list exactly 4 stage depths.
    """

    def __init__(self, block, layers=None, n_classes=1000):
        super().__init__()
        layers = [3, 4, 23, 3] if layers is None else list(layers)
        if len(layers) != 4:
            raise ValueError(f"layers must list 4 stage depths, got {layers!r}")
        channels = 64
        self.in_channels = channels
        # 7x7/2 stem with the same shape as torchvision's ResNet conv1, so
        # pretrained weights fit.
        self.conv1 = nn.Conv2d(3, channels, 7, 2, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.layer1 = self.make_layers(block, channels, layers[0])
        self.layer2 = self.make_layers(block, 2 * channels, layers[1], 2)
        self.layer3 = self.make_layers(block, 4 * channels, layers[2], 2)
        self.layer4 = self.make_layers(block, 8 * channels, layers[3], 2)
        # make_layers leaves self.in_channels at the last stage's width (2048).
        self.fc1 = nn.Linear(self.in_channels, n_classes)

    def make_layers(self, block, channels, num_blocks, stride=1):
        """Build one stage of ``num_blocks`` blocks; only the first may downsample.

        A 1x1 conv + BN shortcut is created for the first block when it changes
        the spatial size (``stride != 1``) or the channel count. Side effect:
        sets ``self.in_channels`` to ``4 * channels`` for the next stage.
        """
        out_channels = 4 * channels
        skip = None
        if stride != 1 or self.in_channels != out_channels:
            skip = nn.Sequential(
                nn.Conv2d(self.in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        blocks = [block(self.in_channels, channels, stride, skip)]
        self.in_channels = out_channels
        blocks += [block(self.in_channels, channels) for _ in range(1, num_blocks)]
        return nn.Sequential(*blocks)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 3, 2, 1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = torch.flatten(x, 1)
        return self.fc1(x)

In [None]:
# ImageNet-pretrained ResNet-101 backbone.
# n_classes keeps its default (1000) so the checkpoint's ImageNet classifier
# head also fits this model.
model = ResNet(Block, [3, 4, 23, 3])
url = "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth"
state = load_state_dict_from_url(url)


def _rename_key(key):
    """Map a torchvision ResNet parameter name onto this notebook's names."""
    # torchvision: `downsample` shortcut and `fc` head; here: `skip` and `fc1`.
    key = key.replace(".downsample.", ".skip.")
    return "fc1." + key[len("fc."):] if key.startswith("fc.") else key


renamed_state = {_rename_key(k): v for k, v in state.items()}
model_state = model.state_dict()
# Load only tensors whose name and shape match, and report the rest instead of
# failing on the first architectural difference.
compatible = {
    k: v for k, v in renamed_state.items()
    if k in model_state and model_state[k].shape == v.shape
}
load_result = model.load_state_dict(compatible, strict=False)
# BatchNorm `num_batches_tracked` buffers may be absent from older checkpoints;
# they carry no learned weights, so they are not reported.
missing = [k for k in load_result.missing_keys if not k.endswith("num_batches_tracked")]
unused = sorted(set(renamed_state) - set(compatible))
print(f"Loaded {len(compatible)}/{len(model_state)} tensors from {url}")
if missing or unused:
    print(f"Not initialized from checkpoint: {missing}")
    print(f"Checkpoint tensors not used: {unused}")

In [None]:
class ConvBlock(nn.Module):
    """``Conv2d`` optionally followed by ``BatchNorm2d`` and then ReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        bias: forwarded to ``nn.Conv2d``.
        use_norm: apply the BatchNorm layer after the conv.
        use_act: apply ReLU last.

    The BatchNorm layer is always registered as ``self.bn`` (so it appears in
    the state dict) even when ``use_norm`` is False.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride=1, padding=0, dilation=1, bias=False,
                 use_norm=False, use_act=False):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation, bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.use_norm = use_norm
        self.use_act = use_act

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out) if self.use_norm else out
        return F.relu(out) if self.use_act else out

In [None]:
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Five parallel branches run on the input: a 1x1 conv, three 3x3 dilated
    convs, and global average pooling followed by a 1x1 conv (upsampled back to
    the input size). They are concatenated, fused by a 1x1 conv, and passed
    through dropout (only while training).

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every branch and of the output.
        dilations: 4 dilation rates; the first is used by the 1x1 branch.
        dropout: dropout probability applied to the fused output in training.

    Input ``(N, in_channels, H, W)`` -> output ``(N, out_channels, H, W)``.
    NOTE: the pooled branch applies BatchNorm to a 1x1 map, which raises in
    training mode when the batch size is 1.
    """

    def __init__(self, in_channels=2048, out_channels=256,
                 dilations=(1, 6, 12, 18), dropout=0.5):
        super().__init__()
        d1, d2, d3, d4 = dilations
        self.aspp1 = ConvBlock(in_channels, out_channels, 1, 1, 0, d1, use_norm=True, use_act=True)
        self.aspp2 = ConvBlock(in_channels, out_channels, 3, 1, d2, d2, use_norm=True, use_act=True)
        self.aspp3 = ConvBlock(in_channels, out_channels, 3, 1, d3, d3, use_norm=True, use_act=True)
        self.aspp4 = ConvBlock(in_channels, out_channels, 3, 1, d4, d4, use_norm=True, use_act=True)
        self.conv = ConvBlock(in_channels, out_channels, 1, use_norm=True, use_act=True)
        # Five branches are concatenated, hence 5 * out_channels.
        self.out_conv = ConvBlock(5 * out_channels, out_channels, 1, use_norm=True, use_act=True)
        self.dropout_p = dropout

    def forward(self, x):
        size = x.shape[2:]
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        # The pooled branch is 1x1 spatially; upsample it so it can be concatenated.
        x5 = self.conv(F.adaptive_avg_pool2d(x, (1, 1)))
        x5 = F.interpolate(x5, size=size, mode="bilinear", align_corners=True)
        x = torch.cat([x1, x2, x3, x4, x5], dim=1)
        x = self.out_conv(x)
        # training=self.training: F.dropout would otherwise stay active in eval().
        return F.dropout(x, p=self.dropout_p, training=self.training)

In [None]:
class Decoder(nn.Module):
    """DeepLabV3+ decoder: fuse upsampled ASPP features with low-level features.

    Args:
        n_classes: number of output channels (per-class logits).
        low_level_channels: channels of the low-level feature map ``x1``
            (256 for the ResNet ``layer1`` output in this notebook).
        aspp_channels: channels of the ASPP output ``x``.

    ``forward(x, x1)``: ``x`` is the ASPP output and ``x1`` the low-level
    features. Returns logits at ``x1``'s spatial resolution.
    """

    def __init__(self, n_classes, low_level_channels=256, aspp_channels=256):
        super().__init__()
        channels = 256
        reduced = 48  # low-level features are squeezed to 48 channels before fusion
        self.conv1 = ConvBlock(low_level_channels, reduced, 1, use_norm=True, use_act=True)
        self.conv2 = ConvBlock(aspp_channels + reduced, channels, 3, 1, 1, use_norm=True, use_act=True)
        self.conv3 = ConvBlock(channels, channels, 3, 1, 1, use_norm=True, use_act=True)
        # Final classifier: plain conv with a bias (no norm/activation on logits).
        self.conv4 = ConvBlock(channels, n_classes, 1, bias=True)

    def forward(self, x, x1):
        x1 = self.conv1(x1)
        x = F.interpolate(x, size=x1.shape[2:], mode="bilinear", align_corners=True)
        x = torch.cat([x, x1], dim=1)
        # training=self.training: F.dropout would otherwise stay active in eval().
        x = F.dropout(self.conv2(x), p=0.5, training=self.training)
        x = F.dropout(self.conv3(x), p=0.1, training=self.training)
        return self.conv4(x)

In [None]:
class DeeplabV3Plus(nn.Module):
    """DeepLabV3+ segmentation model: ResNet backbone + ASPP + decoder.

    Args:
        n_classes: number of segmentation classes (output channels).
        backbone: a ``ResNet`` instance (e.g. the pretrained ``model`` built
            earlier). If ``None``, a freshly initialized ResNet-101 is created.

    ``forward(x)`` returns per-pixel logits of shape
    ``(N, n_classes, H, W)`` matching the input's spatial size.
    """

    def __init__(self, n_classes, backbone=None):
        super().__init__()
        self.resnet = backbone if backbone is not None else ResNet(Block, [3, 4, 23, 3])
        self.aspp = ASPP()
        self.decoder = Decoder(n_classes)

    def _backbone_features(self, x):
        """Run the ResNet stages and return (layer4 features, layer1 features).

        ResNet.forward ends in the classifier, so the stages are called
        directly to obtain the high- and low-level feature maps.
        """
        r = self.resnet
        x = F.relu(r.bn1(r.conv1(x)))
        x = F.max_pool2d(x, 3, 2, 1)
        low = r.layer1(x)
        high = r.layer4(r.layer3(r.layer2(low)))
        return high, low

    def forward(self, x):
        high, low = self._backbone_features(x)
        # ASPP runs on the deepest backbone features, not on the raw image.
        out = self.decoder(self.aspp(high), low)
        return F.interpolate(out, size=x.shape[2:], mode="bilinear", align_corners=True)

In [None]:
# Training / validation loop.
# Uses names defined in other cells: model, optimizer, loss_fn, trainloader,
# validloader, num_epochs.
# NOTE(review): `model` above is the ResNet classifier, whose output is
# (N, n_classes); for segmentation this presumably should be a DeeplabV3Plus
# instance — confirm.
# The dataset maps mask value 255 to -1, so loss_fn must ignore that label,
# e.g. nn.CrossEntropyLoss(ignore_index=-1).
device = next(model.parameters()).device  # send batches wherever the model lives
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    # Loop variables are `images`/`targets` so the `imgs`/`masks` directory
    # strings from the config cell are not overwritten by the last batch.
    for images, targets in trainloader:
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = loss_fn(model(images), targets)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch {epoch + 1}/{num_epochs} - Train loss: {running_loss / max(len(trainloader), 1)}")

    # eval(): BatchNorm uses running statistics and modules that respect
    # `self.training` stop applying dropout.
    model.eval()
    valid_loss = 0.0
    with torch.no_grad():
        for images, targets in validloader:
            images, targets = images.to(device), targets.to(device)
            valid_loss += loss_fn(model(images), targets).item()
    print(f"Epoch {epoch + 1}/{num_epochs} - Validation loss: {valid_loss / max(len(validloader), 1)}")