In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import glob
from tqdm import tqdm
from torch.optim.lr_scheduler import CosineAnnealingLR
import pandas as pd
import numpy as np

# Dataset

In [2]:
class CamVid_Simple(Dataset):
    """Dataset of pre-serialized samples stored as ``*.pth`` files.

    Every file directly under ``path`` that matches ``*.pth`` is one sample.
    A sample is read with ``torch.load`` and must be a mapping that provides
    the keys ``'img'`` and ``'mask'``.

    Args:
        path: Directory that directly contains the ``*.pth`` sample files.
    """

    def __init__(self, path):
        super().__init__()
        self.path = path
        # glob yields entries in arbitrary, filesystem-dependent order; sorting
        # makes the index -> sample mapping identical on every run and machine.
        self.files = sorted(glob.glob(path + '/*.pth'))

    def __len__(self):
        """Return the number of sample files discovered under ``path``."""
        return len(self.files)

    def __getitem__(self, index):
        """Load sample ``index`` from disk and return its ``(img, mask)`` pair."""
        # NOTE(review): torch.load unpickles the file -- only point this
        # dataset at trusted data.
        # map_location='cpu' keeps loading independent of the device the
        # tensors were saved from; callers move batches to their device.
        data = torch.load(self.files[index], map_location='cpu')
        return data['img'], data['mask']

# UNET


In [3]:
"""
Source of this is same as the one in Complex_CamVid_Final.ipynb

"""


class DoubleConv(nn.Module):
    """Two back-to-back ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    The first convolution maps ``in_channels`` to ``out_channels``; the second
    keeps ``out_channels``. Both convolutions are bias-free. With a 3x3 kernel
    and padding 1 the spatial size of the input is preserved.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        stages = []
        current = in_channels
        for _ in range(2):
            stages.extend([
                nn.Conv2d(current, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ])
            # Only the first convolution changes the channel count.
            current = out_channels

        # Kept under the attribute name `conv` so state_dict keys stay `conv.N.*`.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv/norm/activation stages to ``x``."""
        out = self.conv(x)
        return out
    

class Simple_UNET(nn.Module):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` stages.

    The encoder applies one ``DoubleConv`` per entry of ``features``, each
    followed by 2x2 max pooling. A bottleneck ``DoubleConv`` doubles the last
    feature width. The decoder mirrors the encoder: for each width (deepest
    first) a stride-2 transposed convolution upsamples, the result is
    concatenated with the matching encoder output, and a ``DoubleConv``
    reduces the channels again. A final 1x1 convolution produces
    ``out_channels`` maps.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the final 1x1 convolution.
        features: Channel width of each encoder stage, shallowest first.
    """

    def __init__(self, in_channels=3, out_channels=3, features=[64, 128, 256, 512]):
        super(Simple_UNET, self).__init__()
        # `features` is only read here, never mutated, so the shared list
        # default is harmless.
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one DoubleConv per feature width.
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder: `ups` alternates [upsample, DoubleConv, upsample, DoubleConv, ...].
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            # Input is feature*2 because the skip tensor is concatenated in.
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_channels``-channel map.

        The output has the same height and width as ``x``, including when they
        are not divisible by ``2 ** len(features)``.
        """
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the skips deepest-first, pairing each with its (upsample, DoubleConv).
        for idx, skip_connection in enumerate(reversed(skip_connections)):
            upsample = self.ups[2 * idx]
            refine = self.ups[2 * idx + 1]

            x = upsample(x)

            # Max pooling floors odd sizes, so the upsampled map can be one
            # pixel smaller than the skip tensor; resize to match before
            # concatenating. F.interpolate replaces the previous `TF.resize`
            # call, whose `TF` module was never imported (NameError).
            if x.shape[2:] != skip_connection.shape[2:]:
                x = F.interpolate(
                    x,
                    size=skip_connection.shape[2:],
                    mode='bilinear',
                    align_corners=False,
                )

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = refine(concat_skip)

        return self.final_conv(x)

# Train

In [None]:
# Hyper-parameters and dataset location for this run.
num_classes = 32
batch_size = 8
num_epochs = 50
lr = 1e-3
fpath = 'Dataset/CamVid_RGB'

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using {device}')

model = Simple_UNET(in_channels=3, out_channels=num_classes).to(device)
model.train()

traindataset = CamVid_Simple(path=f'{fpath}/train')
trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=True)

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# The scheduler is stepped once per batch below, so its period is the total
# number of optimizer steps rather than the number of epochs.
total_steps = num_epochs * len(trainloader)
scheduler = CosineAnnealingLR(optimizer, total_steps, eta_min=1e-5)

# criterion = Complex_CCELoss()
criterion = nn.CrossEntropyLoss()

for epoch in range(num_epochs):
    running_loss = 0
    # The context manager closes the progress bar when the epoch ends.
    with tqdm(trainloader) as pbar:
        for batch_idx, (image, mask) in enumerate(pbar):
            image = image.to(device)
            mask = mask.to(device)

            output = model(image)
            loss = criterion(output, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            running_loss += loss.item()

            # Show the epoch-so-far mean loss and the learning rate that the
            # next step will use.
            pbar.set_postfix({
                'Epoch': epoch + 1,
                'Loss': running_loss / (batch_idx + 1),
                'lr': optimizer.param_groups[0]['lr'],
            })

    # Overwrite the checkpoint after every epoch.
    torch.save(model.state_dict(), 'model.pth')

In [None]:
from torchmetrics.classification import JaccardIndex

# Evaluates the in-memory `model` on the test split with the Jaccard index.
# NOTE: this cell relies on `model`, `fpath` and `device` defined by the
# training cell; run that cell first.
num_classes = 32
batch_size = 8

testdataset = CamVid_Simple(path = f'{fpath}/test')
testloader = DataLoader(testdataset, batch_size=batch_size, shuffle=False)

torch.cuda.empty_cache()
# Use the `num_classes` variable rather than repeating the literal, so the
# metric cannot drift from the configured class count.
metric = JaccardIndex(task='multiclass', num_classes=num_classes)

model.eval()
with torch.no_grad():
    for img, target in tqdm(testloader):
        img = img.to(device)
        output = model(img)
        # Per-pixel predicted class = index of the largest logit along dim 1.
        output = torch.argmax(output, dim=1)
        # The metric object lives on the CPU, so both tensors are brought
        # there; `.cpu()` is a no-op for tensors that are already on the CPU.
        metric.update(output.cpu(), target.cpu())

    iou = metric.compute()
    print('Jaccard index Score: ', iou)

    metric.reset()