In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import random_split
from torch.utils.data import Dataset, DataLoader
import os
from PIL import Image
from tqdm import tqdm
import numpy as np

PATH_TO_DATA = "/home/shared/DARPA/patched_data/patch_size_256_patch_overlap_0_legend_size_256/training"

In [2]:
!nvidia-smi

Thu Oct  6 09:04:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 515.65.01    Driver Version: 515.65.01    CUDA Version: 11.7     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  On   | 00000004:04:00.0 Off |                    0 |
| N/A   29C    P0    37W / 300W |      1MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla V100-SXM2...  On   | 00000004:05:00.0 Off |                    0 |
| N/A   29C    P0    35W / 300W |      0MiB / 16384MiB |      0%      Default |
|       

In [3]:
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True, bottleneck=False):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        elif bottleneck:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 4, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels//2, out_channels)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)

        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [4]:
class DoubleUNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=1, bilinear=False):
        super(DoubleUNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        
        factor = 2 if bilinear else 1
            
        self.inc = DoubleConv(n_channels, 8)
        self.down1 = Down(8, 16)
        self.down2 = Down(16, 32)
        self.down3 = Down(32, 64)
        self.down4 = Down(64, 128 // factor)
        self.up1 = Up(256, 64 // factor, bilinear, bottleneck=True)
        self.up2 = Up(64, 32 // factor, bilinear)
        self.up3 = Up(32, 16 // factor, bilinear)
        self.up4 = Up(16, 8, bilinear)
        self.outc = OutConv(8, n_classes)
    
    def encoder(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        return x1, x2, x3, x4, x5
    
    def forward(self, map_patch, label):
        x1, x2, x3, x4, x5 = self.encoder(map_patch)
        label_emb = self.encoder(label)[-1]
        x_cat = torch.cat((x5, label_emb), axis=1)
        x = self.up1(x_cat, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits


In [5]:
class MapLoader(Dataset):
    def __init__(self, PATH_TO_DATA, label_transform, map_transform, seg_transform):
        self.label_transforms = label_transform
        self.map_transforms = map_transform
        self.seg_transforms = seg_transform
        self.path_to_map = os.path.join(PATH_TO_DATA, "map_patches")
        self.path_to_labels = os.path.join(PATH_TO_DATA, "labels")
        self.path_to_segs = os.path.join(PATH_TO_DATA, "seg_patches")
        self.filenames = os.listdir(self.path_to_map)
        
        
    def __len__(self):
        return len(self.filenames)
    
    def __getitem__(self, idx):
        file = self.filenames[idx]
        path_to_map_file = os.path.join(self.path_to_map, file)
        path_to_seg_file = os.path.join(self.path_to_segs, file)
        path_to_label_file = os.path.join(self.path_to_labels, file)
        
        map_patch = self.map_transforms(Image.open(path_to_map_file))
        seg_patch = self.seg_transforms(Image.open(path_to_seg_file))
        label = self.label_transforms(Image.open(path_to_label_file))
        
        ### RESCALE BETWEEN -1 and 1 ###
        map_patch = 2 * map_patch - 1
        label = 2 * label - 1
        
        return map_patch, seg_patch, label
        

In [6]:
# ### CALCULATE NORMALIZATION ###
# data_transform = transforms.Compose([
#     transforms.ToTensor()
# ])

# dataset = MapLoader(PATH_TO_DATA, transforms=data_transform)
# lengths = [int(len(dataset) * 0.95), len(dataset) - int(len(dataset) * 0.95)]
# trainset, _ = random_split(dataset, lengths, generator=torch.Generator().manual_seed(42))

# trainloader = DataLoader(trainset, 
#                          batch_size=32,
#                          shuffle=True,
#                          num_workers=8)

# map_means, map_stds, label_means, label_stds = [], [], [], []

# for i, (map_patch, _, label) in tqdm(enumerate(trainloader), total=len(trainloader)):
#     map_patch = map_patch.numpy()
#     label = label.numpy()
    
#     map_mean = np.mean(map_patch, axis=(0, 2, 3))
#     map_std = np.std(map_patch, axis=(0, 2, 3))
    
#     label_mean = np.mean(label, axis=(0, 2, 3))
#     label_std = np.std(label, axis=(0, 2, 3))
    
#     map_means.append(map_mean)
#     map_stds.append(map_std)
#     label_means.append(label_mean)
#     label_stds.append(label_std)
    

# map_mean = torch.tensor(np.mean(map_means, axis=0))
# map_std = torch.tensor(np.mean(map_stds, axis=0))

# label_mean = torch.tensor(np.mean(label_means, axis=0))
# label_std = torch.tensor(np.mean(label_stds, axis=0))

# print(f"Map Mean: {map_mean}")
# print(f"Map Std: {map_std}")
# print(f"Label Mean: {label_mean}")
# print(f"Label Std: {label_std}")



In [7]:
map_mean = torch.tensor([0.7677, 0.7403, 0.6487])
map_std = torch.tensor([0.2415, 0.2356, 0.2564])
label_mean = torch.tensor([0.8048, 0.7714, 0.6715])
label_std = torch.tensor([0.2528, 0.2472, 0.2757])

map_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
    transforms.Normalize(map_mean, map_std)
])

lab_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
    transforms.Normalize(map_mean, map_std)
])

seg_transform = transforms.Compose([
    transforms.Resize((128,128)),
    transforms.ToTensor(),
])

dataset = MapLoader(PATH_TO_DATA, label_transform=lab_transform, map_transform=map_transform, seg_transform=seg_transform)
lengths = [int(len(dataset) * 0.95), len(dataset) - int(len(dataset) * 0.95)]
trainset, testset = random_split(dataset, lengths, generator=torch.Generator().manual_seed(42))


BATCH_SIZE = 64

trainloader = DataLoader(trainset, 
                         batch_size=BATCH_SIZE,
                         shuffle=True,
                         num_workers=8)


testloader = DataLoader(testset, 
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=8)


In [8]:
torch.manual_seed(42)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = DoubleUNet().to(DEVICE)
model = torch.nn.DataParallel(model, device_ids=[0,1])
optimizer = torch.optim.AdamW(model.parameters(), lr=0.0005)
lossfn = nn.BCEWithLogitsLoss()

In [None]:
EPOCHS = 20

smallest_loss = np.inf

mean_train_loss = []
meat_test_loss = []
for epoch in range(1, EPOCHS+1):
    print("****** EPOCH: [{}/{}] LR: {} ******".format(epoch, EPOCHS, round(optimizer.param_groups[0]['lr'], 4)))
    train_loss = []
    test_loss = []

    model.train() # Put in training mode
    loop_train = tqdm(trainloader, total=len(trainloader), leave=True)
    for idx, (map_patch, seg_patch, label) in enumerate(loop_train):
        map_patch, seg_patch, label = map_patch.to(DEVICE), seg_patch.to(DEVICE), label.to(DEVICE)

        optimizer.zero_grad()
        forward_out = model.forward(map_patch, label)

        train_loss_val = lossfn(forward_out, seg_patch)
        train_loss_val.backward()
        optimizer.step()

        train_loss.append(train_loss_val.item())
        
        loop_train.set_description(f"Training")
        loop_train.set_postfix(train_loss=train_loss_val.item())
    
    model.eval()
    loop_test = tqdm(testloader, total=len(testloader), leave=True)
    for idx, (map_patch, seg_patch, label) in enumerate(loop_test):
        map_patch, seg_patch, label = map_patch.to(DEVICE), seg_patch.to(DEVICE), label.to(DEVICE)
    
        with torch.no_grad():
            forward_out = model.forward(map_patch, label) # Forward pass
            test_loss_val = lossfn(forward_out, seg_patch) # Calculate loss

            test_loss.append(test_loss_val.item())
            
            loop_test.set_description(f"Evaluation")
            loop_test.set_postfix(test_loss=test_loss_val.item())
    
    mean_epoch_train_loss = np.mean(train_loss)
    mean_epoch_test_loss = np.mean(test_loss)
    print(f"Training Loss: {mean_epoch_train_loss} Validation Loss: {mean_epoch_test_loss}")
    
    if mean_epoch_test_loss < smallest_loss:
        print(":::SAVING MODEL:::")
        smallest_loss = mean_epoch_test_loss
        torch.save(model.state_dict(), "model_store/map_unet.pt")
    
    
    
    

****** EPOCH: [1/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [18:03<00:00,  3.79it/s, train_loss=0.492]
Evaluation: 100%|██████████| 217/217 [00:58<00:00,  3.74it/s, test_loss=0.347]


Training Loss: 0.3262303371292698 Validation Loss: 0.22342142655003455
:::SAVING MODEL:::
****** EPOCH: [2/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:26<00:00,  4.16it/s, train_loss=0.432]
Evaluation: 100%|██████████| 217/217 [00:54<00:00,  4.01it/s, test_loss=0.411]


Training Loss: 0.21268350278792647 Validation Loss: 0.20579822399237188
:::SAVING MODEL:::
****** EPOCH: [3/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:17<00:00,  4.20it/s, train_loss=0.253] 
Evaluation: 100%|██████████| 217/217 [00:59<00:00,  3.65it/s, test_loss=0.198]


Training Loss: 0.1864007613406173 Validation Loss: 0.1787253584592573
:::SAVING MODEL:::
****** EPOCH: [4/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:19<00:00,  4.19it/s, train_loss=0.189] 
Evaluation: 100%|██████████| 217/217 [00:52<00:00,  4.09it/s, test_loss=1.39] 


Training Loss: 0.16907009199675385 Validation Loss: 0.1936009648140125
****** EPOCH: [5/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:16<00:00,  4.20it/s, train_loss=3.43]  
Evaluation: 100%|██████████| 217/217 [00:52<00:00,  4.11it/s, test_loss=1.63] 


Training Loss: 0.15738749969342367 Validation Loss: 0.20967585672431277
****** EPOCH: [6/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:21<00:00,  4.18it/s, train_loss=0.442] 
Evaluation: 100%|██████████| 217/217 [00:55<00:00,  3.92it/s, test_loss=1.49]  


Training Loss: 0.14524299829755535 Validation Loss: 0.16484738055462111
:::SAVING MODEL:::
****** EPOCH: [7/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:14<00:00,  4.21it/s, train_loss=1.59]  
Evaluation: 100%|██████████| 217/217 [00:51<00:00,  4.23it/s, test_loss=0.768] 


Training Loss: 0.13582452979797105 Validation Loss: 0.1532840453053949
:::SAVING MODEL:::
****** EPOCH: [8/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [16:24<00:00,  4.17it/s, train_loss=3.22]  
Evaluation: 100%|██████████| 217/217 [01:02<00:00,  3.48it/s, test_loss=0.636] 


Training Loss: 0.12704677578139845 Validation Loss: 0.15228648531821468
:::SAVING MODEL:::
****** EPOCH: [9/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [17:14<00:00,  3.97it/s, train_loss=1.18]  
Evaluation: 100%|██████████| 217/217 [00:56<00:00,  3.84it/s, test_loss=0.168] 


Training Loss: 0.11981213921876112 Validation Loss: 0.1434779060051738
:::SAVING MODEL:::
****** EPOCH: [10/20] LR: 0.0005 ******


Training: 100%|██████████| 4106/4106 [17:10<00:00,  3.98it/s, train_loss=0.469] 
Evaluation: 100%|██████████| 217/217 [00:54<00:00,  3.98it/s, test_loss=0.228] 


Training Loss: 0.11358280735488671 Validation Loss: 0.13709331691814458
:::SAVING MODEL:::
****** EPOCH: [11/20] LR: 0.0005 ******


Training:  54%|█████▎    | 2204/4106 [09:28<05:06,  6.21it/s, train_loss=0.114] 