In [1]:
import torch
import torch.nn as nn
from Transformer import *

from PP import *
import math
from torchvision import transforms
from torchvision.datasets import Cityscapes
from dataloader_cityscapes import *
import torch.nn.functional as F
import torchvision.models as models
import functools
import operator

In [2]:
def prior_last_layer(dim_in, stride = [1, 1], padding = [0, 0], dilation = [1, 1], kernel_size = [1, 1], output_padding = [0, 0]):

    return ((dim_in + (2 * padding[0]) - (dilation[0] * (kernel_size[0] - 1)) - 1) /  stride[0]) + 1


def choose_backbone():

    torch.hub._validate_not_a_forked_repo=lambda a,b,c: True
    backbone = torch.nn.Sequential(*(list(torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True).children())[:7]))
    backbone.requires_grad = False
    return backbone


class customVariationalTransformer(nn.Module):

    def __init__(self, **kwargs):

        super(customVariationalTransformer, self).__init__()

        self.batch_size = kwargs["batch_size"]
#         self.backbone = choose_backbone()

#         self.backbone_output_dim = functools.reduce(operator.mul, self.backbone(torch.rand(1,
#                                     *(kwargs['prior_input_channels'], kwargs['input_img_dim'][0], 
#                                       kwargs['input_img_dim'][1])))).shape
#         self.seq_length = self.backbone_output_dim[0]
#         dim1 = prior_last_layer(self.backbone_output_dim[1])
#         dim2 = prior_last_layer(self.backbone_output_dim[2])
#         last_layer = int(dim1 * dim2)
        self.seq_length = 256

        layers = list(kwargs['prior_posterior_layers'])

        self.decoder_emb = nn.ConvTranspose2d(1, self.seq_length, kernel_size = 1, stride = 1)

        self.transformer = Transformer(d_model = 256, nhead = kwargs['transformer_num_heads'],
                                        num_encoder_layers = kwargs['transformer_num_encoder_layer'], 
                                        num_decoder_layers = kwargs['transformer_num_dec_layer'],
                                        dim_feedforward = kwargs['transformer_intermediate_layer_dim'], 
                                        dropout = kwargs['transformer_dropout_per'],
                                        activation = "relu", return_intermediate_dec = False)

        self.output_layer = nn.Sequential(
            nn.Conv2d(in_channels = 1, out_channels = kwargs["num_cat"], kernel_size = 3, padding = 1, bias = True),
            nn.Softmax(dim=1)
        )



    def forward(self, img, segm):

#         resnet_features = self.backbone(img)
#         decoder_embedding = self.decoder_emb(img.unsqueeze(1).view(self.batch_size, 1, int(math.sqrt(img.shape[1])), -1))
#         print(resnet_features.contiguous().view(self.batch_size, self.seq_length, -1).shape)
#         print(img[:,0,:,:].contiguous().view(self.batch_size, self.seq_length, -1).shape)
        reconstruct_posterior = self.transformer.forward(img[:,0,:,:].contiguous().view(self.batch_size, self.seq_length, -1), img[:,0,:,:].contiguous().view(self.batch_size, self.seq_length, -1))
        reconstruct_posterior = self.output_layer(reconstruct_posterior.unsqueeze(1))

        return reconstruct_posterior



In [3]:
preprocess_in = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    transforms.Resize((256,256))
])

preprocess_ou = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((256,256))
])

tr_loader = CityscapesLoader("../datasets/augmented_cityscapes", transform_in = preprocess_in, 
                             transform_ou = preprocess_ou)
train_loader = DataLoader(dataset = tr_loader, batch_size = 50, shuffle = True)



In [4]:
model = customVariationalTransformer(**{"input_img_dim":[256,256],
                       "prior_input_channels":3, "prior_posterior_layers":[64,128,256],
                       "posterior_input_channels":37, "batch_size":50,
                        "transformer_num_heads":2, "transformer_num_encoder_layer":2,
                        "transformer_num_dec_layer":2,"transformer_intermediate_layer_dim":512,
                        "transformer_dropout_per":0, "num_cat": 34})

In [5]:
criterion = nn.BCEWithLogitsLoss(size_average = False, reduce = False, reduction = None)



In [6]:
optimizer = torch.optim.AdamW(model.parameters(), lr =  0.0001, weight_decay = 0.)
for batch in train_loader:

    optimizer.zero_grad()

#     prior_latent_space, posterior_latent_space = model.forward(batch['image'], batch['label'])

#     kl_loss = torch.mean(kl.kl_divergence(posterior_latent_space, prior_latent_space))
#     loss = kl_loss

    reconstruct = model.forward(batch['image'], batch['label'])
#     print(reconstruct.shape)
#     print(batch['label'].shape)
    reconstruction_loss = criterion(input = reconstruct, target = batch['label'])
    loss = torch.mean(reconstruction_loss)
    
    print(loss.item())
    loss.backward()
    optimizer.step()


0.7071971893310547
0.7072063088417053
0.7072041630744934
0.7071991562843323
0.7071945667266846
0.7071843147277832
0.7071908116340637
0.7071763873100281
0.7071744203567505
0.707173764705658
0.7071672677993774
0.7071781754493713
0.7071699500083923
0.7071511745452881
0.7071543335914612
0.7071657180786133
0.7071362733840942
0.7071330547332764
0.7071290612220764
0.7071396112442017
0.7071372270584106
0.7071270942687988
0.7071239352226257
0.7071227431297302
0.7071352005004883
0.7071274518966675
0.7071398496627808
0.707119882106781
0.7071293592453003
0.7071139812469482
0.7071263194084167
0.7071179747581482
0.7071277499198914
0.7071271538734436
0.7071256637573242
0.7071220278739929
0.7071253657341003
0.707129180431366
0.7071071267127991
0.7071046829223633
0.7071119546890259
0.7071231603622437
0.7071208357810974
0.7071327567100525
0.7071077227592468
0.7071194052696228
0.7071157693862915
0.707109808921814
0.707104504108429
0.7071144580841064
0.707116961479187
0.7071108222007751
0.7071154117584229

KeyboardInterrupt: 

In [30]:
input_test1 = torch.rand((64,34,256,256))
input_test2 = torch.rand((64,34,256,256))
mylist = np.array([])
np.stack((mylist,input_test1), axis=0)
np.stack((mylist, input_test2), axis=0)

ValueError: all input arrays must have the same shape

In [None]:
mylist.shape