In [1]:
import lpips
import matplotlib.pyplot as plt
import numpy as np
import sys
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from torch.utils.data import DataLoader


In [3]:
sys.path.append("/home/e/e0425222/CS4243-project")
from utils.dataset_utils.AnimalDataset import AnimalDataset
from utils.train_utils.model_utils import Conv2dBlock, GatedConv2d, GatedConv2dBlock, GatedUpConv2dBlock


# Introduction 
In this notebook, we calculate statistics and get sample images using the test dataset, and provide literature and context on each model trained.

In [4]:
DEVICE = "cuda:3"
SEED = 0
SAMPLE_SIZE = 32

In [5]:
test_dataset = AnimalDataset(index_file_path = "/home/e/e0425222/CS4243-project/dataset/frogs_test.txt",
    root_dir_path = "/home/e/e0425222/CS4243-project/dataset/frog_images",
    local_dir_path = "/home/e/e0425222/CS4243-project/dataset/preprocessed_64",
    file_prefix = "frogs_",
    image_dimension = 64,
    concat_mask = True,
    random_noise = False,
    require_init = False,
    drops = [],
    seed = SEED)


In [13]:
VGG_LPIPS = lpips.LPIPS(net = 'vgg').to(DEVICE)
METRICS = {
    "Peak SnR (Whole)" : lambda img, gt, mask : torchmetrics.functional.peak_signal_noise_ratio(img * (1-mask) + gt * mask, gt),
    "L2 loss (Whole)" : lambda img, gt, mask : nn.functional.mse_loss(img * (1-mask) + gt * mask, gt),
    "L2 loss (Mask)" : lambda img, gt, mask : nn.functional.mse_loss(img * (1-mask), gt * (1-mask), reduction = 'sum')/(1-mask).sum(),
    "L1 loss (Whole)" : lambda img, gt, mask : nn.functional.l1_loss(img * (1-mask) + gt * mask, gt),
    "L1 loss (Mask)" : lambda img, gt, mask : nn.functional.l1_loss(img * (1-mask), gt * (1-mask), reduction = 'sum')/(1-mask).sum(),
    "LPIPS (Whole)" : (lambda img, gt, mask : VGG_LPIPS(img * (1-mask) + gt * mask, gt).mean()),
}

def sample_images(generator, device, dataset, filename, sample_size = 16):
    loader = DataLoader(dataset, batch_size = sample_size, shuffle = False, worker_init_fn = lambda id: np.random.seed(seed))
    generator.eval()
    generator.to(device)

    batch = next(iter(loader))
    with torch.no_grad():
        # input and ground truth
        input_batched = batch["image"]
        ground_truth_batched = batch["reconstructed"]
        mask_batched = batch["mask"]

        # move tensors to device
        input_batched = input_batched.to(device)
        ground_truth_batched = ground_truth_batched.to(device)
        mask_batched = mask_batched.to(device)

        # ===== FORWARD PASS =====

        # 1. reshape to channel first
        input_batched = input_batched.permute(0, 3, 1, 2)
        ground_truth_batched = ground_truth_batched.permute(0, 3, 1, 2)
        mask_batched = mask_batched.permute(0, 3, 1, 2)

        # 2. predict    
        output_batched = generator(input_batched)
        spliced_batched = ((1-mask_batched) * output_batched) + (mask_batched * ground_truth_batched) 

        batched_predictions = torch.cat([
            input_batched[:sample_size, 0:3,:,:], # can be 4 channels
            ground_truth_batched[:sample_size,:,:,:],  # 3 channels
            spliced_batched[:sample_size,:,:,:]], dim = 0) 
        
        image_array = torchvision.utils.make_grid(batched_predictions, nrow = sample_size, padding = 10)
        plt.axis('off')
        plt.imshow(image_array.permute(1,2,0).cpu()) # plot with channel first
        plt.savefig(filename, dpi = 2000)
        plt.close()
                

def compute_metrics(generator, device, metrics, dataset):
    loader = DataLoader(dataset, batch_size = 64)
    running_results = {list(metrics.keys())[i] : 0.0 for i in range(len(metrics)) } 
    generator.eval()
    generator.to(device)
    with torch.no_grad():
            batches = 0
            for _, batch in enumerate(loader, 1):
                
                batches += 1

                # input and ground truth
                input_batched = batch["image"]
                ground_truth_batched = batch["reconstructed"]
                mask_batched = batch["mask"]

                # move tensors to device
                input_batched = input_batched.to(device)
                ground_truth_batched = ground_truth_batched.to(device)
                mask_batched = mask_batched.to(device)

                # ===== FORWARD PASS =====

                # 1. reshape to channel first
                input_batched = input_batched.permute(0, 3, 1, 2)
                ground_truth_batched = ground_truth_batched.permute(0, 3, 1, 2)
                mask_batched = mask_batched.permute(0, 3, 1, 2)

                # 2. predict    
                output_batched = generator(input_batched)

                # 3. evaluate
                for key, func in metrics.items():
                    running_results[key] += func(output_batched, ground_truth_batched, mask_batched).detach().item()

                args = ""
                for key, _ in running_results.items():
                    args += key + ": " + str(running_results[key]/batches) + "   "
                print(f"\r{batches}/{len(loader)}: " + args, end = '', flush = True)
                
    # normalise numbers by batch
    for key, _ in running_results.items():
        running_results[key] /= batches

    return running_results   
            


Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /home/e/e0425222/miniconda3/envs/env/lib/python3.7/site-packages/lpips/weights/v0.1/vgg.pth


## Baseline models


### GraphGAN

In [21]:
import torch_geometric as torch_g
import torch_geometric.nn as gnn


class GNNBlock(nn.Module):

    def __init__(self, in_channels, out_channels, activation = nn.ReLU):

        super(GNNBlock, self).__init__()
        self.conv = gnn.GINConv(nn.Sequential(nn.Linear(in_channels, out_channels), activation()))
        self.post1 = nn.Sequential(nn.Linear(out_channels, out_channels), activation())
        self.post2 = nn.Sequential(nn.Linear(out_channels, out_channels), activation())

    def forward(self, input_tensor, adj):
        # convert adj to sparse
        device = input_tensor.device
        edge_index, edge_attr = torch_g.utils.dense_to_sparse(adj)
        edge_index = edge_index.long().to(device)
        b, hw, c = input_tensor.shape
        x = input_tensor.reshape(b * hw, c) # (b x hw x c) -> (bhw x c)

        # forward
        x = self.conv(x, edge_index)
        x = self.post1(x)
        x = self.post2(x)

        # reshape back
        x = x.reshape(b, hw, c)
        
        return x

class GatedGraphConvModule(nn.Module):
    """
    This module implements GNN convolution on images using local, global and channel features to predict
    the adjacency tensor.
    """

    def __init__(self, channels, kernel_size, stride, padding, dilation, activation = nn.ReLU):

        super(GatedGraphConvModule, self).__init__()


        # incidence matrix
        self.feature_conv = GatedConv2dBlock(channels, channels, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, activation = activation)
        self.edge_conv = GatedConv2dBlock(channels, channels, kernel_size = kernel_size, stride = stride, padding = padding, dilation = dilation, activation = activation)
        self.scaleconv = GatedConv2d(channels, 1, kernel_size = 1, stride = 1, padding = 0, dilation = 1)
        self.offsetconv = GatedConv2d(channels, 1, kernel_size = 1, stride = 1, padding = 0, dilation = 1)

        # graph conv
        self.gnn1 = GNNBlock(channels, channels, activation = activation)


    def forward(self, input_tensor, return_adj = False,):

        # metadata
        b, c, h, w = input_tensor.shape
        
        # 1. compute features for edge prediction
        edge_features = self.edge_conv(input_tensor)

        # 2. compute adjacency matrix by dot product 
        scores = edge_features.view(b, c, h*w) # reshape and normalize
        scores = nn.functional.normalize(scores, p = 2, dim = 2) # normalize vector at each node
        scores = torch.bmm(scores.permute(0, 2, 1), scores) # (b x hw x c) x (b x c x hw) -> (b x hw x hw)
        adj_tensor = torch.sigmoid(scores)

        # 3. compute dampening factor by affine transformations of mean of edge features
        scale = self.scaleconv(edge_features)
        offset = self.offsetconv(edge_features)
        mean = scores.mean(dim = 1, keepdim = True) # (b x 1 x hw) mean of similarity scores as base
        adjustment = torch.relu(scale.view(b,1,h*w) * mean + offset.view(b,1,h*w)) # relu to keep everything positive

        # dampen
        adj_tensor = adj_tensor - adjustment

        # 4. graph conv
        x = input_tensor.view(b, c, h*w).permute(0, 2, 1)  # -> b x hw x c
        x = self.gnn1(x, adj_tensor) + x # -> b x hw x c

        # 5. reshape back to image
        x = x.permute(0, 2, 1).view(b, c, h, w)

        if return_adj:
            return x, adj_tensor
            
        return x

    
class Generator(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, activation):
        super(Generator, self).__init__()
        
        # same -> downsample 
        self.conv0 = GatedConv2dBlock(input_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv1 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 2, padding = 1, dilation = 1, activation = activation)

        # same -> downsample
        self.conv2 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv3 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 2, padding = 1, dilation = 1, activation = activation)

        # 2 x same conv
        self.conv4 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv5 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # graph conv
        self.graphconv1 = GatedGraphConvModule(hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # 4 x dilated conv
        self.conv6 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 2, dilation = 2, activation = activation)
        self.conv7 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 4, dilation = 4, activation = activation)
        self.conv8 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 8, dilation = 8, activation = activation)
        self.conv9 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 16, dilation = 16, activation = activation)

        # 2 x same conv
        self.conv10 = GatedConv2dBlock(2*hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv11 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # upsample -> same -> upsample -> same
        self.conv12 = GatedUpConv2dBlock(hidden_dim, hidden_dim, scale_factor = (2,2), kernel_size = 3, stride = 1, padding = 1, dilation = 1, mode = 'nearest')
        self.conv13 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv14 = GatedUpConv2dBlock(hidden_dim, hidden_dim, scale_factor = (2,2), kernel_size = 3, stride = 1, padding = 1, dilation = 1, mode = 'nearest')
        self.conv15 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # final
        self.final = nn.Conv2d(hidden_dim, output_dim, kernel_size = 3, stride = 1, padding = 'same')

        # for contrastive learning, 1x1 conv to compress feature map into 1 channel
        self.conv_feature = nn.Conv2d(hidden_dim, 1, kernel_size = 1, stride = 1, padding = 'same')

    def forward(self, input_tensor, return_adj = False):
        
        # downsample
        x = self.conv0(input_tensor)
        x = self.conv1(x)

        # downsample
        x = self.conv2(x)
        x = self.conv3(x)

        # middle preprocessing layers
        x = self.conv4(x) + x
        x = self.conv5(x) + x

        # graph conv
        g, adj = self.graphconv1(x, return_adj = True)
        g = g + x

        # dilated conv with residual skips
        d = self.conv6(x) + x
        d = self.conv7(d) + d
        d = self.conv8(d) + d
        d = self.conv9(d) + d

        # middle postprocessing layers
        x = self.conv10(torch.cat([d,g], dim = 1)) + x
        x = self.conv11(x) + x

        # upsample
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.conv14(x)
        x = self.conv15(x)
        
        # final
        x = self.final(x)

        if return_adj:
            return x, adj

        return x



In [22]:
SAVE_PATH = "/home/e/e0425222/CS4243-project/active_experiments/final/GraphGAN/AdaptiveThreshold/generator/generator_epoch20.pt"
FILENAME = "GraphGAN.png"
model = Generator(input_dim = 4, hidden_dim = 64, output_dim = 3, activation = nn.Mish)
model.load_state_dict(torch.load(SAVE_PATH))

<All keys matched successfully>

In [23]:
sample_images(model, DEVICE, test_dataset, filename= FILENAME, sample_size = SAMPLE_SIZE) 

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


In [53]:
compute_metrics(model, DEVICE, METRICS, test_dataset)

9/9: Peak SnR (Whole): 33.10754606458876   L2 loss (Whole): 0.0005091961001097742   L2 loss (Mask): 0.043716352639926806   L1 loss (Whole): 0.003007827973407176   L1 loss (Mask): 0.25875557793511283   LPIPS (Whole): 0.022742397876249418    

{'Peak SnR (Whole)': 33.10754606458876,
 'L2 loss (Whole)': 0.0005091961001097742,
 'L2 loss (Mask)': 0.043716352639926806,
 'L1 loss (Whole)': 0.003007827973407176,
 'L1 loss (Mask)': 0.25875557793511283,
 'LPIPS (Whole)': 0.022742397876249418}

### DilatedGatedGAN

In [17]:
class Generator(nn.Module):

    def __init__(self, input_dim, hidden_dim, output_dim, activation):
        super(Generator, self).__init__()
        
        # same -> down -> same -> down
        self.conv0 = GatedConv2dBlock(input_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv1 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 2, padding = 1, dilation = 1, activation = activation)
        self.conv2 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv3 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 2, padding = 1, dilation = 1, activation = activation)

        # 8 x same : 2 x normal -> 4 x dilated -> 2 x normal
        self.conv4 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv5 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv6 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 2, dilation = 2, activation = activation)
        self.conv7 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 4, dilation = 4, activation = activation)
        self.conv8 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 8, dilation = 8, activation = activation)
        self.conv9 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 16, dilation = 16, activation = activation)
        self.conv10 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv11 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # upsample -> same -> upsample -> same
        self.conv12 = GatedUpConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation, scale_factor = (2,2))
        self.conv13 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)
        self.conv14 = GatedUpConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation, scale_factor = (2,2))
        self.conv15 = GatedConv2dBlock(hidden_dim, hidden_dim, kernel_size = 3, stride = 1, padding = 1, dilation = 1, activation = activation)

        # final
        self.final = nn.Conv2d(hidden_dim, output_dim, kernel_size = 3, stride = 1, padding = 'same')

    def forward(self, input_tensor):

        x = self.conv0(input_tensor)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)

        # middle layers
        x = self.conv4(x) + x
        x = self.conv5(x) + x
        x = self.conv6(x) + x
        x = self.conv7(x) + x
        x = self.conv8(x) + x
        x = self.conv9(x) + x
        x = self.conv10(x) + x
        x = self.conv11(x) + x

        # up sample
        x = self.conv12(x)
        x = self.conv13(x)
        x = self.conv14(x)
        x = self.conv15(x)
        
        # final
        x = self.final(x)

        return x

In [18]:
SAVE_PATH = "/home/e/e0425222/CS4243-project/active_experiments/final/DilatedGated/GAN_1/generator/generator_epoch20.pt"
FILENAME = "DilatedGatedGAN.png"

model = Generator(input_dim = 4, hidden_dim = 64, output_dim = 3, activation = nn.Mish)
model.load_state_dict(torch.load(SAVE_PATH))


<All keys matched successfully>

In [20]:
sample_images(model, DEVICE, test_dataset, filename= FILENAME, sample_size = SAMPLE_SIZE) 
compute_metrics(model, DEVICE, METRICS, test_dataset)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).


9/9: Peak SnR (Whole): 32.67547946506076   L2 loss (Whole): 0.0005427065609385156   L2 loss (Mask): 0.046924981806013316   L1 loss (Whole): 0.0030987938452098104   L1 loss (Mask): 0.2680983328157001   LPIPS (Whole): 0.023552156777845487    

{'Peak SnR (Whole)': 32.67547946506076,
 'L2 loss (Whole)': 0.0005427065609385156,
 'L2 loss (Mask)': 0.046924981806013316,
 'L1 loss (Whole)': 0.0030987938452098104,
 'L1 loss (Mask)': 0.2680983328157001,
 'LPIPS (Whole)': 0.023552156777845487}