In [1]:
import torch 
import torch.nn as nn


# Configurations

In [66]:
# Seed torch's global RNG once, up front: noise vectors, the conditioning
# augmentation samples and weight initialisation are all random, so a
# Restart & Run All is only reproducible with a fixed seed.
SEED = 0
torch.manual_seed(SEED)

# Model sizes shared by every cell below.
dim_text_embedding = 1000    # input width of the conditioning-augmentation layer
dim_conditioning_var = 128   # width of the sampled conditioning vector
dim_noise = 100              # width of the generator noise vector
channels_gen = 128           # base channel width of the generators
channels_discr = 64          # base channel width of the discriminators
upscale_factor = 2           # spatial factor applied by each `upscale` stage

In [94]:
# Upscaling stage: enlarges the feature map by `upscale_factor` (set in the
# configuration cell) and changes its channel count.

def upscale(in_channels, out_channels):
    """Return an upsample -> 3x3 conv -> batch-norm -> ReLU stage.

    Nearest-neighbour upsampling multiplies height and width by the
    notebook-level ``upscale_factor``; the padded 3x3 convolution keeps that
    new size and maps ``in_channels`` to ``out_channels``. The conv has no
    bias because the batch-norm right after it would cancel it anyway.
    """
    layers = [
        nn.Upsample(scale_factor=upscale_factor, mode='nearest'),
        nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(inplace=True),
    ]
    return nn.Sequential(*layers)

In [95]:
# Convolutional residual block: output has the same channel count and spatial
# size as its input.

class ResBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with an identity skip connection.

    ``forward`` computes ``relu(block(x) + x)`` where ``block`` is
    conv -> BN -> ReLU -> conv -> BN. Stride 1 with padding 1 keeps H and W
    unchanged, so the skip connection needs no projection.

    Args:
        channels: number of input (and output) feature channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(True),  # torch.nn has `ReLU`; `nn.Relu` raised AttributeError
            nn.Conv2d(channels, channels, 3, 1, 1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.ReLU = nn.ReLU(True)

    def forward(self, x):
        """Return ``relu(block(x) + x)``; same shape as ``x``."""
        residue = x
        out = self.block(x)
        # `out + residue` allocates a new tensor, so the in-place ReLU below
        # never overwrites the caller's input.
        return self.ReLU(out + residue)

In [56]:
class Conditional_augmentation(nn.Module):
    """Conditioning augmentation: maps a text embedding to a sampled vector.

    A linear layer followed by ReLU produces ``2 * dim_out`` values per row;
    the first half is used as the mean ``mu`` and the second half as the
    log-variance ``logvar`` of a diagonal Gaussian, and the output is drawn
    with the reparameterisation trick ``mu + eps * exp(0.5 * logvar)``.

    Args:
        dim_in: width of the text embedding. Defaults to the notebook-level
            ``dim_text_embedding``.
        dim_out: width of the conditioning vector. Defaults to the
            notebook-level ``dim_conditioning_var``.
    """

    def __init__(self, dim_in=None, dim_out=None):
        super().__init__()
        # The config globals are only read when no explicit size is given.
        self.dim_fc_inp = dim_text_embedding if dim_in is None else dim_in
        self.dim_fc_out = dim_conditioning_var if dim_out is None else dim_out
        self.fc = nn.Linear(self.dim_fc_inp, self.dim_fc_out * 2, bias=True)
        self.relu = nn.ReLU()

    def get_mu_logvar(self, textEmbedding):
        """Project the embedding and split it into ``(mu, logvar)``.

        For a 2-D ``(batch, dim_in)`` input both results are
        ``(batch, dim_out)``.
        """
        x = self.relu(self.fc(textEmbedding))
        # Split at this instance's own output width (the previous code used an
        # undefined global `conditioning_var_dim`).
        mu = x[:, :self.dim_fc_out]
        logvar = x[:, self.dim_fc_out:]
        return mu, logvar

    def get_conditioning_variable(self, mu, logvar):
        """Sample ``mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``."""
        # randn_like follows mu's dtype and device; torch.randn(size) always
        # allocates a default-dtype CPU tensor and breaks for CUDA inputs.
        epsilon = torch.randn_like(mu)
        std = torch.exp(0.5 * logvar)
        return mu + epsilon * std

    def forward(self, textEmbedding):
        """Return one sampled conditioning vector per row of ``textEmbedding``."""
        mu, logvar = self.get_mu_logvar(textEmbedding)
        return self.get_conditioning_variable(mu, logvar)

In [61]:
# Dummy batch of 2 text embeddings; the width comes from the config cell
# rather than a hard-coded 1000 so it always matches the model.
a = torch.randn(2, dim_text_embedding)

In [58]:
# Conditioning-augmentation module sized from the config cell
# (dim_text_embedding -> dim_conditioning_var).
ca = Conditional_augmentation()

In [64]:
cond_vars = ca(textEmbedding=a)  # one sampled conditioning vector per embedding row

In [80]:
# Sanity check: one dim_conditioning_var-wide vector per input embedding.
assert cond_vars.shape == (a.size(0), dim_conditioning_var)
cond_vars.shape

torch.Size([2, 128])

In [110]:
class Discriminator_logit(nn.Module):
    """Scoring head shared by the discriminators.

    Takes a feature map with ``dim_discr * 8`` channels (4x4 in this notebook)
    and applies a 4x4 / stride-4 conv, so a 4x4 map yields one score per
    sample. Despite the name, the output has passed through a ``Sigmoid`` and
    is therefore a probability in ``(0, 1)``, not a raw logit.

    Args:
        dim_discr: base channel width of the discriminator.
        dim_condVar: width of the conditioning vector (used only when
            ``concat`` is true).
        concat: if true, the conditioning vector is tiled over the feature
            map's spatial grid and concatenated channel-wise before scoring.
    """

    def __init__(self, dim_discr, dim_condVar, concat=False):
        super().__init__()
        self.dim_discr = dim_discr
        self.dim_condVar = dim_condVar
        # Normalise once so __init__ and forward agree for any truthy value
        # (previously `== True` here but `is True` in forward).
        self.concat = bool(concat)
        if self.concat:
            self.logits = nn.Sequential(
                nn.Conv2d(dim_discr * 8 + dim_condVar, dim_discr * 8, 3, 1, 1, bias=False),
                nn.BatchNorm2d(dim_discr * 8),
                nn.LeakyReLU(.2, True),
                nn.Conv2d(dim_discr * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid(),
            )
        else:
            self.logits = nn.Sequential(
                nn.Conv2d(dim_discr * 8, 1, kernel_size=4, stride=4),
                nn.Sigmoid(),
            )

    def forward(self, hidden_vec, cond_aug=None):
        """Score ``hidden_vec``, conditioned on ``cond_aug`` when ``concat`` is set.

        Args:
            hidden_vec: ``(batch, dim_discr * 8, H, W)`` feature map.
            cond_aug: ``(batch, dim_condVar)`` conditioning vectors; required
                when the head was built with ``concat=True``.

        Returns:
            Flattened sigmoid scores.

        Raises:
            ValueError: if this is a conditional head and ``cond_aug`` is None.
        """
        if self.concat:
            if cond_aug is None:
                raise ValueError("conditional Discriminator_logit requires cond_aug")
            # Tile over the actual spatial size instead of assuming 4x4.
            height, width = hidden_vec.shape[-2:]
            cond_aug = cond_aug.view(-1, self.dim_condVar, 1, 1)
            cond_aug = cond_aug.repeat(1, 1, height, width)
            hidden_vec = torch.cat((hidden_vec, cond_aug), 1)

        return self.logits(hidden_vec).view(-1)

In [96]:
class Stage1_Generator(nn.Module):
    """Stage-I generator: ``(noise, text_embedding)`` -> 3-channel image.

    The text embedding is turned into a conditioning vector by
    ``Conditional_augmentation`` and concatenated with the noise. A fully
    connected layer then projects the result to a ``(channels_gen * 8, 4, 4)``
    map, which four ``upscale`` stages enlarge, halving the channel count each
    time. A final 3x3 conv + Tanh produces the image.
    """

    def __init__(self):
        super().__init__()
        self.dim_noise = dim_noise
        self.dim_cond_aug = dim_conditioning_var
        self.channels_fc = channels_gen * 8
        self.cond_aug_net = Conditional_augmentation()

        fc_width = self.channels_fc * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(self.dim_noise + self.dim_cond_aug, fc_width, bias=False),
            nn.BatchNorm1d(fc_width),
            nn.ReLU(True),
        )

        # Channel widths along the upsampling path: C, C/2, C/4, C/8, C/16.
        widths = [self.channels_fc // (2 ** k) for k in range(5)]
        self.upsample = nn.Sequential(
            *(upscale(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:]))
        )

        self.generated_image = nn.Sequential(
            nn.Conv2d(widths[-1], 3, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, text_embedding):
        """Return a ``(batch, 3, H, W)`` image with values in ``[-1, 1]``."""
        c_hat = self.cond_aug_net(text_embedding)
        h = self.fc(torch.cat((noise, c_hat), 1))
        h = h.view(-1, self.channels_fc, 4, 4)
        return self.generated_image(self.upsample(h))
        

In [97]:
# Stage-I generator instance used by the shape checks below.
gen1 = Stage1_Generator()

In [98]:
# Noise batch sized to match the text-embedding batch `a`, so the two can be
# concatenated inside the generator.
noise = torch.randn(a.size(0), dim_noise)

In [99]:
im = gen1(noise=noise, text_embedding=a)  # Stage-I image batch

In [102]:
# Stage-I output: one 3-channel image per sample; the 4x4 seed map has been
# enlarged by four upscale stages.
assert im.shape == (noise.size(0), 3, 4 * upscale_factor ** 4, 4 * upscale_factor ** 4)
im.shape

torch.Size([2, 3, 64, 64])

In [125]:
class Stage1_Discriminator(nn.Module):
    """Stage-I discriminator feature extractor plus its two scoring heads.

    Four stride-2, 4x4 convolutions halve the image side each time while the
    channel count goes C -> 2C -> 4C -> 8C (C = ``channels_discr``). A 64x64
    input therefore becomes the ``(batch, 8C, 4, 4)`` map that the
    ``cond_logit`` / ``uncond_logit`` heads score.

    ``forward`` returns only the feature map; call a head on it for scores.
    """

    def __init__(self):
        super().__init__()
        # `channels_initial_discr` is not defined in the config cell; the
        # configured discriminator width is `channels_discr`.
        self.channels_initial = channels_discr
        c = self.channels_initial

        self.downsample = nn.Sequential(
            nn.Conv2d(3, c, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(c, c * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(c * 2, c * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(c * 4, c * 8, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c * 8),
            nn.LeakyReLU(0.2, inplace=True),
        )

        self.cond_logit = Discriminator_logit(c, dim_conditioning_var, True)
        self.uncond_logit = Discriminator_logit(c, dim_conditioning_var, False)

    def forward(self, img):
        """Return the downsampled ``(batch, 8C, H/16, W/16)`` feature map."""
        return self.downsample(img)

In [126]:
# Stage-I discriminator instance; its `forward` returns the downsampled
# feature map, which the `cond_logit` / `uncond_logit` heads then score.
disc1 = Stage1_Discriminator()

In [127]:
scores = disc1(img=im)  # feature map (not yet scores) for the generated images

In [130]:
# The scoring heads expect an (8 * channels_discr) x 4 x 4 feature map.
assert scores.shape == (im.size(0), channels_discr * 8, 4, 4)
scores.shape

torch.Size([2, 512, 4, 4])

In [131]:
disc1.cond_logit(hidden_vec=scores, cond_aug=cond_vars)  # conditional score per image

tensor([0.4445, 0.6538], grad_fn=<ViewBackward>)

In [9]:
class Stage2_Generator(nn.Module):
    """Stage-II generator: refines a Stage-I sample into a larger image.

    Pipeline:
      1. An internal Stage-I generator produces a low-resolution image from
         ``(noise, text_embedding)``. It is detached, so this module's loss
         sends no gradient into Stage I.
      2. The image is encoded by a 3x3 conv and two stride-2 4x4 convs into a
         ``4C``-channel map at 1/4 of the input side (C = ``channels_gen``).
      3. A freshly sampled conditioning vector is tiled over that map and
         concatenated. A 3x3 conv fuses the result, and four residual blocks
         follow.
      4. Four ``upscale`` stages (4C -> 2C -> C -> C/2 -> C/4) and a 3x3
         conv + Tanh produce the 3-channel output image.

    With the configured sizes (64x64 Stage-I output, ``upscale_factor = 2``)
    the encoded map is 16x16 and the output is 256x256.
    """

    def __init__(self):  # was misspelled `__inti__`, so no layers were ever built
        super().__init__()
        # Use the names actually defined in the config cell / this notebook.
        self.downsample_channels = channels_gen
        self.dim_embedding = dim_conditioning_var
        c = self.downsample_channels
        self.cond_aug_net = Conditional_augmentation()
        self.Stage1_G = Stage1_Generator()

        self.downsample = nn.Sequential(
            nn.Conv2d(3, c, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),

            nn.Conv2d(c, c * 2, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c * 2),
            nn.ReLU(inplace=True),

            nn.Conv2d(c * 2, c * 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(c * 4),  # must match the 4C outputs of the conv above
            nn.ReLU(inplace=True),
        )
        self.hidden = nn.Sequential(
            nn.Conv2d(c * 4 + self.dim_embedding, c * 4, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c * 4),
            nn.ReLU(True),
        )
        self.residual = nn.Sequential(*(ResBlock(c * 4) for _ in range(4)))
        self.upsample = nn.Sequential(
            upscale(c * 4, c * 2),
            upscale(c * 2, c),
            upscale(c, c // 2),
            upscale(c // 2, c // 4),
        )
        self.image = nn.Sequential(
            # `bias` must be a keyword: Conv2d's 6th positional argument is
            # `dilation`, so a positional `False` would set dilation instead.
            nn.Conv2d(c // 4, 3, 3, 1, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, text_embedding):
        """Return a ``(batch, 3, H, W)`` image in ``[-1, 1]`` refined from Stage I."""
        # Stage I is treated as fixed here: detach stops gradients into it.
        low_res = self.Stage1_G(noise, text_embedding).detach()
        enc_img = self.downsample(low_res)

        # Tile the conditioning vector over the encoded map's actual grid
        # (16x16 with the configured sizes) for channel-wise concatenation.
        height, width = enc_img.shape[-2:]
        cond_aug = self.cond_aug_net(text_embedding)
        cond_aug = cond_aug.view(-1, self.dim_embedding, 1, 1).repeat(1, 1, height, width)

        x = torch.cat((enc_img, cond_aug), 1)
        x = self.hidden(x)
        x = self.residual(x)
        x = self.upsample(x)
        return self.image(x)

In [12]:
class Stage2_Discriminator(nn.Module):
    """Stage-II discriminator feature extractor plus its two scoring heads.

    Six stride-2 4x4 convolutions (C -> 32C channels, C = ``channels_discr``)
    halve the image side each time. Two 3x3 convolutions then bring the
    channels back down to 8C. A 256x256 input therefore becomes the
    ``(batch, 8C, 4, 4)`` map that the ``cond_logit`` / ``uncond_logit``
    heads score.

    ``forward`` returns only the feature map.
    """

    @staticmethod
    def _conv_block(c_in, c_out, kernel, stride, padding, batch_norm=True):
        """Return [conv, (BN), LeakyReLU(0.2)] layers for one stage."""
        # `bias` is passed by keyword: Conv2d's 6th positional argument is
        # `dilation`, not `bias`.
        layers = [nn.Conv2d(c_in, c_out, kernel, stride, padding, bias=False)]
        if batch_norm:
            layers.append(nn.BatchNorm2d(c_out))  # BatchNorm2d needs the channel count
        layers.append(nn.LeakyReLU(0.2, inplace=True))
        return layers

    def __init__(self):
        super().__init__()
        # Use the names actually defined in the config cell.
        self.initial_channels = channels_discr
        c = self.initial_channels

        widths = [c, c * 2, c * 4, c * 8, c * 16, c * 32]
        layers = self._conv_block(3, c, 4, 2, 1, batch_norm=False)
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += self._conv_block(c_in, c_out, 4, 2, 1)
        layers += self._conv_block(c * 32, c * 16, 3, 1, 1)
        layers += self._conv_block(c * 16, c * 8, 3, 1, 1)
        self.downsample = nn.Sequential(*layers)

        self.cond_logit = Discriminator_logit(c, dim_conditioning_var, True)
        # Previously this was also assigned to `cond_logit`, silently
        # replacing the conditional head.
        self.uncond_logit = Discriminator_logit(c, dim_conditioning_var, False)

    def forward(self, image):
        """Return the ``(batch, 8C, H/64, W/64)`` feature map for ``image``."""
        return self.downsample(image)

In [28]:
class axx(nn.Module):
    """Scratch module: a single 4 -> 2 linear layer wrapped in ``nn.Sequential``."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(4, 2)]
        self.s = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the wrapped linear layer to ``x``."""
        out = self.s(x)
        return out

    
    
    
    

In [29]:
class b(nn.Module):
    """Scratch wrapper that nests an ``axx`` module inside ``nn.Sequential``."""

    def __init__(self):
        super().__init__()
        inner = axx()
        self.bss = nn.Sequential(inner)

    def forward(self, x):
        """Run ``x`` through the nested ``axx`` module."""
        wrapped = self.bss
        return wrapped(x)


In [30]:
inp = torch.arange(1, 5, dtype=torch.float)  # float32 tensor [1., 2., 3., 4.]


In [31]:
bf = b()
bf(x=inp)  # forward the 4-element scratch input through the nested linear layer

tensor([ 0.4665, -1.4558], grad_fn=<AddBackward0>)

In [14]:
# exp(0.5 * logvar) is the std used by the conditioning augmentation.
# The input is explicit because the recorded output came from an earlier
# `a = [1, 2, 3]`; by this point `a` is the (2, 1000) embedding batch.
demo_logvar = torch.tensor([1.0, 2.0, 3.0])
demo_std = torch.exp(0.5 * demo_logvar)
demo_std

tensor([1.6487, 2.7183, 4.4817])

In [15]:
# Element-wise product demo with explicit inputs. The previous cell rebound
# `b` (shadowing the `b` module class) and depended on a stale `a`.
demo_values = torch.tensor([1.0, 2.0, 3.0])
demo_scale = torch.tensor([.5, 1.0, 1.5])
demo_product = demo_values * demo_scale
demo_product

tensor([0.5000, 2.0000, 4.5000])