In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from tqdm import tqdm

# Root of the patched training data. Set the DARPA_PATCHED_DATA environment
# variable to point elsewhere instead of editing this absolute path.
PATH_TO_DATA = os.environ.get(
    "DARPA_PATCHED_DATA",
    "/home/shared/DARPA/patched_data/patch_size_256_patch_overlap_0_legend_size_256/training",
)

In [2]:
# Show the GPUs visible to this kernel and their current memory use.
!nvidia-smi

Wed Oct  5 11:48:32 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000004:04:00.0 Off |                    0 |
| N/A   39C    P0    41W / 300W |      1MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000004:05:00.0 Off |                    0 |
| N/A   44C    P0    42W / 300W |      0MiB / 16384MiB |      0%      Default |
|       

In [3]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv blocks: (Conv2d -> BatchNorm2d -> ReLU) x 2.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (None, 0) falls back to ``out_channels``.

    kernel_size=3 with padding=1 leaves the spatial size unchanged.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels

        def conv_bn_relu(c_in, c_out):
            # Conv bias is disabled; the following BatchNorm supplies the shift.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Attribute name and layer order are fixed so state_dict keys stay
        # "double_conv.<i>.*" and existing checkpoints keep loading.
        self.double_conv = nn.Sequential(
            *conv_bn_relu(in_channels, hidden),
            *conv_bn_relu(hidden, out_channels),
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder stage: 2x2 max-pooling (halves H and W) followed by a DoubleConv.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the DoubleConv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as one Sequential named `maxpool_conv` so state_dict keys are unchanged.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        pooled_and_convolved = self.maxpool_conv(x)
        return pooled_and_convolved


class Up(nn.Module):
    """Decoder stage: upsample ``x1``, pad it to the size of skip tensor ``x2``,
    concatenate ``[x2, x1]`` on the channel axis, then apply a DoubleConv.

    Args:
        in_channels: channel count used to size the upsampler / DoubleConv (see modes).
        out_channels: channels produced by the DoubleConv.
        bilinear: if True, upsample by fixed bilinear interpolation (channel count
            unchanged) and let DoubleConv(in_channels, out_channels,
            mid=in_channels // 2) reduce channels. Takes precedence over ``bottleneck``.
        bottleneck: only when ``bilinear`` is False. The transposed conv maps
            in_channels -> in_channels // 4 and the DoubleConv expects
            in_channels // 2 inputs. Without it, the transposed conv maps
            in_channels -> in_channels // 2 and the DoubleConv expects in_channels.
    """

    def __init__(self, in_channels, out_channels, bilinear=True, bottleneck=False):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
            return

        # Learned 2x upsampling; the bottleneck variant shrinks channels more aggressively.
        shrink = 4 if bottleneck else 2
        self.up = nn.ConvTranspose2d(in_channels, in_channels // shrink, kernel_size=2, stride=2)
        conv_in = in_channels // 2 if bottleneck else in_channels
        self.conv = DoubleConv(conv_in, out_channels)

    def forward(self, x1, x2):
        upsampled = self.up(x1)

        # Dims 2 and 3 are the spatial dims; pad the upsampled tensor to match x2.
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        # F.pad order for the last two dims is (left, right, top, bottom);
        # any odd leftover pixel goes to the right / bottom.
        upsampled = F.pad(
            upsampled,
            [gap_w // 2, gap_w - gap_w // 2, gap_h // 2, gap_h - gap_h // 2],
        )

        # Skip-connection features first, then the upsampled features.
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)


class OutConv(nn.Module):
    """Final 1x1 convolution projecting feature channels to ``out_channels`` maps.

    No activation is applied; the output is the raw convolution result.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        projected = self.conv(x)
        return projected

In [4]:
class DoubleUNet(nn.Module):
    """U-Net whose bottleneck is conditioned on a second image, ``label``.

    ``map_patch`` and ``label`` both go through the *same* encoder modules
    (shared weights). Their deepest features are concatenated on the channel
    axis and decoded with skip connections taken from ``map_patch`` only.

    Args:
        n_channels: channels of each input image. Both inputs pass through
            ``self.inc``, so both must have this many channels.
        n_classes: channels of the output map.
        bilinear: use bilinear upsampling instead of transposed convolutions
            in the decoder.
    """

    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)

        # The bottleneck is the concatenation of two down4 outputs.
        bottleneck_channels = 2 * (1024 // factor)
        if bilinear:
            # nn.Upsample keeps the channel count, so up1's DoubleConv receives
            # the bottleneck plus the 512-channel skip from down3 (1024 + 512).
            # A fixed 2048 here made bilinear=True fail with a channel mismatch.
            up1_in = bottleneck_channels + 512
        else:
            # Transposed conv (bottleneck=True) maps 2048 -> 512; concatenating
            # the 512-channel skip gives 1024 == up1_in // 2.
            up1_in = bottleneck_channels
        self.up1 = Up(up1_in, 512 // factor, bilinear, bottleneck=True)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def encoder(self, x):
        """Run the contracting path; return features of every level, shallowest first."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return x1, x2, x3, x4, x5

    def forward(self, map_patch, label):
        """Segment ``map_patch`` conditioned on ``label``.

        ``label`` must produce a bottleneck with the same spatial size as
        ``map_patch`` (e.g. identical H and W), otherwise the channel-wise
        concatenation fails.

        Returns:
            Output of the final 1x1 conv (no activation) with ``n_classes`` channels.
        """
        x1, x2, x3, x4, x5 = self.encoder(map_patch)
        # Only the deepest label features are used; its shallower levels are discarded.
        label_emb = self.encoder(label)[-1]
        x = torch.cat((x5, label_emb), dim=1)
        x = self.up1(x, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
    

In [5]:
class MapLoader(Dataset):
    """Dataset of (map patch, segmentation patch, label) image triples.

    Reads three sub-directories of ``PATH_TO_DATA``: ``map_patches``,
    ``seg_patches`` and ``labels``. Every file listed in ``map_patches`` is
    expected to have a same-named file in the other two.

    Args:
        PATH_TO_DATA: root directory holding the three sub-directories.
        transforms: callable applied to each decoded PIL image; all three
            images of a sample go through the same transform.
    """

    def __init__(self, PATH_TO_DATA, transforms):
        self.transforms = transforms
        self.path_to_map = os.path.join(PATH_TO_DATA, "map_patches")
        self.path_to_labels = os.path.join(PATH_TO_DATA, "labels")
        self.path_to_segs = os.path.join(PATH_TO_DATA, "seg_patches")
        # os.listdir order is arbitrary, so sort to make index -> file stable
        # across machines (a seeded random_split is only reproducible then).
        # Skip sub-directories (e.g. .ipynb_checkpoints); they are not images.
        self.filenames = sorted(
            name for name in os.listdir(self.path_to_map)
            if os.path.isfile(os.path.join(self.path_to_map, name))
        )

    def __len__(self):
        return len(self.filenames)

    def _load(self, path):
        """Fully decode the image at ``path``, close its file, then apply ``self.transforms``."""
        # Opening the file ourselves and forcing load() guarantees the handle is
        # closed here, including for multi-frame formats where Pillow would
        # otherwise keep it open.
        with open(path, "rb") as fh:
            img = Image.open(fh)
            img.load()
        return self.transforms(img)

    def __getitem__(self, idx):
        file = self.filenames[idx]
        map_patch = self._load(os.path.join(self.path_to_map, file))
        seg_patch = self._load(os.path.join(self.path_to_segs, file))
        label = self._load(os.path.join(self.path_to_labels, file))
        return map_patch, seg_patch, label
        

In [6]:
# Every image of a sample (map patch, segmentation patch, label) is resized to
# 128x128 and converted to a tensor by the same pipeline.
# NOTE(review): transforms.Resize interpolates (bilinear by default), which
# blends segmentation values at region borders; masks usually need
# InterpolationMode.NEAREST — confirm whether seg patches are label masks.
data_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

dataset = MapLoader(PATH_TO_DATA, transforms=data_transform)

# 95% / 5% train/test split with a seeded generator. The split is only
# reproducible if the dataset's file order is stable.
n_total = len(dataset)
n_train = int(n_total * 0.95)
lengths = [n_train, n_total - n_train]
trainset, testset = random_split(dataset, lengths, generator=torch.Generator().manual_seed(42))

BATCH_SIZE = 64

loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

In [7]:
torch.manual_seed(42)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = DoubleUNet().to(DEVICE)
# Split batches across GPUs 0 and 1 only when at least two GPUs are visible;
# device_ids=[0, 1] names a GPU that does not exist on a single-GPU machine.
if torch.cuda.device_count() >= 2:
    model = torch.nn.DataParallel(model, device_ids=[0, 1])
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
# NOTE(review): the network's final layer is a plain 1x1 conv (no activation);
# if seg patches are binary masks, nn.BCEWithLogitsLoss may be the intended loss — confirm.
lossfn = nn.MSELoss()

In [None]:
EPOCHS = 10

for epoch in range(1, EPOCHS + 1):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 4)))
    train_loss = []  # per-batch training losses for this epoch
    test_loss = []   # per-batch evaluation losses for this epoch

    # ---- training pass ----
    model.train()
    loop_train = tqdm(trainloader, total=len(trainloader), leave=True, desc="Training")
    for map_patch, seg_patch, label in loop_train:
        map_patch, seg_patch, label = map_patch.to(DEVICE), seg_patch.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        # Call the module itself (not .forward) so registered hooks run.
        forward_out = model(map_patch, label)
        train_loss_val = lossfn(forward_out, seg_patch)
        train_loss_val.backward()
        optimizer.step()

        train_loss.append(train_loss_val.item())
        loop_train.set_postfix(train_loss=train_loss[-1])

    # ---- evaluation pass: same inputs and target as training, no gradients ----
    model.eval()
    loop_test = tqdm(testloader, total=len(testloader), leave=True, desc="Evaluation")
    with torch.no_grad():
        for map_patch, seg_patch, label in loop_test:
            map_patch, seg_patch, label = map_patch.to(DEVICE), seg_patch.to(DEVICE), label.to(DEVICE)

            forward_out = model(map_patch, label)
            test_loss_val = lossfn(forward_out, seg_patch)

            test_loss.append(test_loss_val.item())
            loop_test.set_postfix(test_loss=test_loss[-1])

    # Report epoch means rather than the final batch's loss; nan if a loader was empty.
    mean_train_loss = sum(train_loss) / len(train_loss) if train_loss else float("nan")
    mean_test_loss = sum(test_loss) / len(test_loss) if test_loss else float("nan")
    print(f"Training Loss: {mean_train_loss} Validation Loss: {mean_test_loss}")

****** EPOCH: [1/10] LR: 0.001 ******


Training:   4%|▎         | 152/4106 [01:46<1:00:09,  1.10it/s, train_loss=0.167]