# COND-GLOW

In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
import os
from tqdm import tqdm
import torch.nn as nn
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Attributes:
        val: the most recent value passed to ``update``.
        sum: the weighted sum of all values seen since the last reset.
        count: the total weight (number of samples) seen since the last reset.
        avg: ``sum / count``.

    Adapted from: https://github.com/pytorch/examples/blob/master/imagenet/train.py
    """
    def __init__(self):
        # A fresh meter is simply a meter that has just been reset.
        self.reset()

    def reset(self):
        """Forget every value recorded so far."""
        self.val = self.avg = self.sum = self.count = 0.

    def update(self, val, n=1):
        """Record ``val`` as the mean of ``n`` new samples and refresh ``avg``."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def mean_dim(tensor, dim=None, keepdims=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor: input tensor.
        dim: ``None`` (mean over every element), a single int, or an iterable
            of ints naming the dimensions to reduce.
        keepdims: when True the reduced dimensions are kept with size 1.

    Returns:
        The reduced tensor.
    """
    if dim is None:
        return tensor.mean()

    axes = sorted([dim] if isinstance(dim, int) else dim)
    # Reduce one axis at a time, keeping it so later indices stay valid.
    for axis in axes:
        tensor = tensor.mean(dim=axis, keepdim=True)
    if keepdims:
        return tensor

    # Each squeeze shifts the remaining axes left by one, hence the offset.
    for removed, axis in enumerate(axes):
        tensor.squeeze_(axis - removed)
    return tensor
class act_norm(torch.nn.Module):
  """Per-channel affine normalisation with data-dependent initialisation.

  The layer computes ``y = (x + mean) * exp(s)`` where ``mean`` and ``s`` are
  learnable tensors of shape ``(1, n_feats, 1, 1)``. Note that ``mean`` stores
  the *negated* channel mean, i.e. it is an additive bias.

  On the first call made while the module is in training mode the parameters
  are set from that batch so the output has zero mean and unit variance per
  channel; afterwards they are trained like ordinary parameters.

  Args:
    n_feats: number of channels of the 4-D input.
  """
  def __init__(self,n_feats):
    super(act_norm,self).__init__()
    # Buffer (not parameter) so the "already initialised" flag is saved in the
    # state dict but never touched by the optimizer.
    self.register_buffer('initialized',torch.zeros(1))
    self.mean=torch.nn.Parameter(torch.zeros(1,n_feats,1,1))
    self.s=torch.nn.Parameter(torch.zeros(1,n_feats,1,1))
    self.n_feats=n_feats

  def init_params(self,x):
    """Set ``mean`` and ``s`` from the statistics of batch ``x``.

    Does nothing (and leaves the layer uninitialised) in eval mode.
    """
    if not self.training:
      return None
    with torch.no_grad():
      # Statistics are taken per channel, over batch and spatial dimensions.
      mean=-x.mean(dim=[0,2,3],keepdim=True)
      var=((x+mean)**2).mean(dim=[0,2,3],keepdim=True)
      # log(1/std); the epsilon guards against constant channels.
      s=(float(1.)/(var.sqrt()+(1e-6))).log()
      self.mean.data.copy_(mean.data)
      self.s.data.copy_(s.data)
      self.initialized+=1.

  def _log_det_jacobian(self,x):
    """Return h*w*sum(s), the log-determinant shared by every sample."""
    return x.shape[2]*x.shape[3]*torch.sum(self.s)

  def forward(self,x,log_det):
    """Normalise ``x``.

    Args:
      x: tensor indexed as (batch, channel, height, width).
      log_det: running log-determinant, or None to skip the bookkeeping.

    Returns:
      ``(y, log_det)``. ``log_det`` is a new tensor; the argument is not
      modified in place.
    """
    if self.initialized == 0:
      self.init_params(x)
    x=x+self.mean
    x=x*torch.exp(self.s)
    if log_det is not None:
      # Out-of-place so the caller's tensor is never mutated.
      log_det=log_det+self._log_det_jacobian(x)
    return (x,log_det)

  def backward(self,x,log_det):
    """Invert ``forward``: ``x = y * exp(-s) - mean``.

    Returns:
      ``(x, log_det)`` with the layer's log-determinant subtracted from
      ``log_det`` (when it is not None). The argument is not modified in place.
    """
    if self.initialized == 0:
      self.init_params(x)
    x=x*torch.exp(self.s*-1)
    if log_det is not None:
      log_det=log_det-self._log_det_jacobian(x)
    x=x-self.mean
    return (x,log_det)

In [None]:
class invertible_conv(torch.nn.Module):
  """Learnable invertible 1x1 convolution that mixes channels.

  The weight is a single ``n_channels x n_channels`` matrix applied at every
  spatial position. It plays the role of a channel permutation before the
  coupling layer, but is learned instead of fixed.

  Args:
    n_channels: number of input (and output) channels.
  """
  def __init__(self,n_channels):
    super(invertible_conv,self).__init__()
    self.n_channels=n_channels
    # Initialise with a random orthogonal matrix (Q factor of a Gaussian
    # matrix): |det| = 1, so the layer starts with a zero log-determinant and
    # is guaranteed to be invertible.
    weight_matrix=np.random.randn(n_channels,n_channels)
    weight_matrix=np.linalg.qr(weight_matrix)[0]
    weight_matrix=weight_matrix.astype("float32")
    weight_matrix=torch.from_numpy(weight_matrix)
    self.filter=torch.nn.Parameter(weight_matrix)

  def _log_det_jacobian(self,x):
    """Return h*w*log|det(W)| for the spatial size of ``x``."""
    return x.shape[2]*x.shape[3]*torch.slogdet(self.filter)[1]

  def forward(self,x,log_det):
    """Apply the 1x1 convolution.

    Args:
      x: tensor indexed as (batch, channel, height, width).
      log_det: running log-determinant tensor.

    Returns:
      ``(y, log_det)``; ``log_det`` is a new tensor, the argument is left
      untouched.
    """
    log_det=log_det+self._log_det_jacobian(x)
    # Each row of W becomes one 1x1 filter over the input channels.
    kernel=self.filter.view(self.n_channels,self.n_channels,1,1)
    op=torch.nn.functional.conv2d(x,kernel)
    return (op,log_det)

  def backward(self,x,log_det):
    """Apply the inverse convolution and subtract the log-determinant."""
    log_det=log_det-self._log_det_jacobian(x)
    # Invert in float64 for numerical accuracy, then cast back.
    inv_kernel=torch.inverse(self.filter.double()).float()
    inv_kernel=inv_kernel.view(self.n_channels,self.n_channels,1,1)
    op=torch.nn.functional.conv2d(x,inv_kernel)
    return (op,log_det)

In [None]:
# innorm,inconv,condconv==>midconv1,midcondconv1==>midnorm,midconv2,midcondconv2==>out_norm,outcon
# condition it on label

In [None]:
class cond_coupling(torch.nn.Module):
  """Conditional affine coupling layer.

  The input is split channel-wise into two halves ``(x1, x3)``. ``x3`` and the
  conditioning tensor feed a small conv net that predicts a log-scale ``s``
  and a shift ``t``; ``x1`` is transformed as ``(x1 + t) * exp(s)`` while
  ``x3`` passes through unchanged, which makes the layer exactly invertible.

  Args:
    in_channel: channels of each half (half of the layer's input channels).
    cond_channel: channels of the conditioning tensor.
    mid_channel: hidden width of the s/t network.
    norm_type: ``"act_norm"`` to normalise with ``act_norm`` layers, anything
      else for ``BatchNorm2d``.
  """
  def __init__(self,in_channel,cond_channel,mid_channel,norm_type="batch_norm"):
    super(cond_coupling,self).__init__()
    self.scale=torch.nn.Parameter(torch.ones(in_channel,1,1))
    out_channel=2*in_channel
    # Remembered because act_norm has a different call signature.
    self._uses_act_norm=(norm_type == "act_norm")

    self.norm1=self._make_norm(in_channel)
    self.conv1=torch.nn.Conv2d(in_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    self.cond_conv1=torch.nn.Conv2d(cond_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    nn.init.normal_(self.conv1.weight,0.,0.05)
    nn.init.normal_(self.cond_conv1.weight,0.,0.05)

    self.conv2=torch.nn.Conv2d(mid_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    self.cond_conv2=torch.nn.Conv2d(cond_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    nn.init.normal_(self.conv2.weight,0.,0.05)
    nn.init.normal_(self.cond_conv2.weight,0.,0.05)

    self.norm2=self._make_norm(mid_channel)
    self.conv3=torch.nn.Conv2d(mid_channel,mid_channel,kernel_size=(1,1),padding=0,bias=False)
    self.cond_conv3=torch.nn.Conv2d(cond_channel,mid_channel,kernel_size=(1,1),padding=0,bias=False)
    nn.init.normal_(self.conv3.weight,0.,0.05)
    nn.init.normal_(self.cond_conv3.weight,0.,0.05)

    self.norm3=self._make_norm(mid_channel)
    # The output conv is needed for every norm type. Zero init makes s = t = 0,
    # so the layer starts as the identity.
    self.conv4=torch.nn.Conv2d(mid_channel,out_channel,kernel_size=(3,3),padding=1,bias=True)
    nn.init.zeros_(self.conv4.weight)
    nn.init.zeros_(self.conv4.bias)

  def _make_norm(self,channels):
    """Build the normalisation layer selected by ``norm_type``."""
    if self._uses_act_norm:
      return act_norm(channels)
    return torch.nn.BatchNorm2d(channels)

  def _normalize(self,norm,x):
    """Apply ``norm``; act_norm takes ``(x, log_det)`` and returns a tuple."""
    if self._uses_act_norm:
      return norm(x,None)[0]
    return norm(x)

  def _scale_shift(self,x3,x_cond):
    """Predict the log-scale ``s`` and shift ``t`` from ``x3`` and ``x_cond``."""
    h=self._normalize(self.norm1,x3)
    h=F.relu(self.conv1(h) + self.cond_conv1(x_cond))

    h=self.conv2(h) + self.cond_conv2(x_cond)
    h=F.relu(self._normalize(self.norm2,h))

    h=self.conv3(h) + self.cond_conv3(x_cond)
    h=F.relu(self._normalize(self.norm3,h))

    h=self.conv4(h)
    # Even output channels carry s, odd ones carry t.
    s=self.scale*torch.tanh(h[:,0::2,:,:])
    t=h[:,1::2,:,:]
    return s,t

  def forward(self,x,x_cond,log_det):
    """Transform the first half of ``x`` conditioned on the second half.

    Args:
      x: input tensor; split in two along dim 1.
      x_cond: conditioning tensor with the same spatial size as ``x``.
      log_det: running per-sample log-determinant.

    Returns:
      ``(y, log_det)`` with ``sum(s)`` per sample added to ``log_det``.
    """
    x1,x3=x.chunk(2,dim=1)
    s,t=self._scale_shift(x3,x_cond)
    x1=(x1+t)*torch.exp(s)
    log_det=log_det+torch.sum(s,dim=(1,2,3))
    x=torch.cat((x1,x3),dim=1)
    return(x,log_det)

  def backward(self,x,x_cond,log_det):
    """Invert ``forward``; ``x3`` is unchanged so s and t can be recomputed."""
    x1,x3=x.chunk(2,dim=1)
    s,t=self._scale_shift(x3,x_cond)
    x1=x1*torch.exp(s*-1) - t
    log_det=log_det-torch.sum(s,dim=(1,2,3))
    x=torch.cat((x1,x3),dim=1)
    return(x,log_det)

In [None]:
class condflow_module(torch.nn.Module):
  """One step of flow: act_norm -> invertible 1x1 conv -> conditional coupling.

  Args:
    in_channel: channels of the tensor flowing through the step.
    cond_channel: channels of the conditioning tensor.
    mid_channel: hidden width of the coupling network.
  """
  def __init__(self,in_channel,cond_channel,mid_channel):
    super(condflow_module,self).__init__()

    self.norm1=act_norm(in_channel)
    self.conv1=invertible_conv(in_channel)
    # The coupling layer works on half of the channels at a time.
    self.cond_coupling1=cond_coupling(in_channel//2,cond_channel,mid_channel)

  def forward(self,x,x_cond,log_det):
    """Run the three sub-layers in order, accumulating ``log_det``."""
    h,ld=self.norm1.forward(x,log_det)
    h,ld=self.conv1.forward(h,ld)
    h,ld=self.cond_coupling1.forward(h,x_cond,ld)
    return (h,ld)

  def backward(self,x,x_cond,log_det):
    """Invert ``forward`` by undoing the sub-layers in reverse order."""
    h,ld=self.cond_coupling1.backward(x,x_cond,log_det)
    h,ld=self.conv1.backward(h,ld)
    h,ld=self.norm1.backward(h,ld)
    return (h,ld)

In [None]:
class glow(torch.nn.Module):
  """Recursive multi-scale stack of conditional flow steps.

  One level applies ``K`` flow steps. When ``L > 1`` it then squeezes the
  result, sends half of the channels through a nested ``glow`` of ``L-1``
  levels and concatenates the untouched half back.

  Args:
    in_channel: channels of the input at this level.
    cond_channel: channels of the conditioning tensor at this level.
    mid_channel: hidden width of the coupling networks.
    L: number of levels, this one included.
    K: number of flow steps per level.
  """
  def __init__(self,in_channel,cond_channel,mid_channel,L,K):
    super(glow,self).__init__()
    steps=[condflow_module(in_channel,cond_channel,mid_channel) for _ in range(K)]
    self.glow_flows1=torch.nn.ModuleList(steps)
    # Squeezing multiplies channels by 4; only half of x goes deeper.
    self.glow_flows2=glow(2*in_channel,4*cond_channel,mid_channel,L-1,K) if L>1 else None

  def forward(self,x,x_cond,log_det):
    """Map ``x`` towards the latent space, accumulating ``log_det``."""
    for step in self.glow_flows1:
      x,log_det=step(x,x_cond,log_det)
    if self.glow_flows2 is None:
      return (x,log_det)

    deep,kept=squeeze(x).chunk(2,dim=1)
    deep,log_det=self.glow_flows2.forward(deep,squeeze(x_cond),log_det)
    x=unsqueeze(torch.cat((deep,kept),dim=1))
    return (x,log_det)

  def backward(self,x,x_cond,log_det):
    """Invert ``forward``: undo the nested levels first, then this level's steps."""
    if self.glow_flows2 is not None:
      deep,kept=squeeze(x).chunk(2,dim=1)
      deep,log_det=self.glow_flows2.backward(deep,squeeze(x_cond),log_det)
      x=unsqueeze(torch.cat((deep,kept),dim=1))

    for step in reversed(self.glow_flows1):
      x,log_det=step.backward(x,x_cond,log_det)
    return (x,log_det)

In [None]:
class glow_model(torch.nn.Module):
  """Conditional Glow: dequantise/logit-transform, squeeze, multi-scale flow.

  Args:
    prior_dist: distribution object with a ``sample(shape)`` method; used only
      to draw latents in ``sample_images``.
    n_channels: hidden width of the coupling networks.
    L: number of multi-scale levels.
    K: number of flow steps per level.
  """
  def __init__(self,prior_dist,n_channels,L,K):
    super(glow_model,self).__init__()
    self.prior=prior_dist
    self.model=glow(in_channel=4*3,cond_channel=4,mid_channel=n_channels,L=L,K=K)#for rgb 4*3, for bw 4*1

  def inference(self,x,x_cond):
    """Map images ``x`` to latents given ``x_cond``.

    Returns:
      ``(z, log_det)`` where ``log_det`` includes the preprocessing term.
    """
    x,log_det=preprocess(x)
    # The conditioning image goes through the same transform; its
    # log-determinant is not part of the likelihood of x.
    x_cond,_=preprocess(x_cond)
    x=squeeze(x)
    x_cond=squeeze(x_cond)
    x,log_det=self.model.forward(x,x_cond,log_det)
    x=unsqueeze(x)
    return (x,log_det)

  def sampling(self,x,x_cond):
    """Map latents ``x`` back to (logit-space) images given ``x_cond``."""
    # Allocate on the latent's device instead of relying on a global.
    log_det=torch.zeros(x.shape[0],device=x.device)
    x_cond,_=preprocess(x_cond)
    z=squeeze(x)
    z_cond=squeeze(x_cond)
    z,log_det=self.model.backward(z,z_cond,log_det)
    z=unsqueeze(z)
    return (z)

  def likelihood(self,x,x_cond):
    """Return the per-sample log-likelihood of ``x`` given ``x_cond``."""
    #log(p(x))=log(ph(f(x)))+log(sii)
    x_,log_det=self.inference(x,x_cond)
    # Standard-normal log-density evaluated in closed form.
    prior_ll=-0.5*(x_**2 + np.log(2*np.pi))
    # The log(256) term per dimension accounts for the 8-bit discretisation.
    prior_ll=prior_ll.flatten(1).sum(-1) - np.log(256) * np.prod(x_.size()[1:])
    return (prior_ll+log_det)

  def forward(self,x,x_cond):
    """Alias of ``likelihood`` so the module can be called directly."""
    ll=self.likelihood(x,x_cond)
    return (ll)

  @torch.no_grad()
  def sample_images(self,number,channel,height,width,cond_img):
    """Draw ``number`` latents from the prior and decode them given ``cond_img``.

    The result is in logit space (callers apply a sigmoid to get pixel values).
    """
    z=self.prior.sample((number,channel,height,width))
    # Follow the conditioning image's device rather than a global.
    z=z.to(cond_img.device)
    x=self.sampling(z,cond_img)
    return (x)

In [None]:
# Data loaders: CIFAR-10 train/test splits, batch size 64.
bs=64
# Training pipeline: to tensor, random horizontal flip, resize to 32x32.
transform_train=transforms.Compose([transforms.ToTensor(),transforms.RandomHorizontalFlip(),
                                    transforms.Resize((32,32))])
# Test pipeline: same as training but without the random flip.
transform_test=transforms.Compose([transforms.ToTensor(),torchvision.transforms.Resize((32,32))])

# download=True fetches the archive into ./data on first use.
trainset=torchvision.datasets. CIFAR10(root='./data', train=True,download=True,transform=transform_train)
train_loader=torch.utils.data.DataLoader(trainset,batch_size=bs,shuffle=True, num_workers=2)

# The test loader is not shuffled, so batches come in a fixed order.
testset=torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform_test)
test_loader=torch.utils.data.DataLoader(testset, batch_size=bs,shuffle=False,num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


HBox(children=(FloatProgress(value=0.0, max=170498071.0), HTML(value='')))


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
def train_model(epoch,Glow_net,train_loader,device,optimizer,scheduler):
  """Run one training epoch, minimising the negative log-likelihood.

  Each colour batch is paired with its own grayscale version as the
  conditioning input. The scheduler is stepped once per batch.

  Args:
    epoch: current epoch index (not used inside the function).
    Glow_net: model called as ``Glow_net(x, cond_x)`` returning log-likelihoods.
    train_loader: loader yielding batches whose first item is the image tensor.
    device: device the batches are moved to.
    optimizer: optimizer over the model parameters.
    scheduler: learning-rate scheduler.
  """
  Glow_net.train()
  nll_meter=AverageMeter()
  to_gray=torchvision.transforms.Grayscale()
  with tqdm(total=len(train_loader.dataset)) as progress_bar:
    for batch in train_loader:
      x=batch[0].to(device)
      # Grayscale conversion is done image by image on the CPU.
      cond_x=torch.stack([to_gray(img) for img in x.cpu()]).to(device)

      optimizer.zero_grad()
      loss=-Glow_net(x,cond_x).mean()
      nll_meter.update(loss.item(),x.size(0))
      loss.backward()
      optimizer.step()
      scheduler.step()

      progress_bar.set_postfix(nll=nll_meter.avg,
                               bpd=bits_per_dim(x,nll_meter.avg),
                               lr=optimizer.param_groups[0]["lr"])
      progress_bar.update(x.size(0))

In [None]:
def test_model(epoch,Glow_net,test_loader,device,optimizer,scheduler):
  """Evaluate the average negative log-likelihood and save a checkpoint.

  The checkpoint is written to ``ckpts/best.pth.tar`` whenever
  ``epoch % 1 == 0``; no comparison with earlier losses is made.

  Args:
    epoch: current epoch index, stored in the checkpoint.
    Glow_net: model exposing ``inference(x, cond_x)`` and ``state_dict()``.
    test_loader: loader yielding batches whose first item is the image tensor.
    device: device the batches are moved to.
    optimizer: unused here.
    scheduler: unused here.
  """
  Glow_net.eval()
  nll_meter=AverageMeter()
  to_gray=torchvision.transforms.Grayscale()

  with tqdm(total=len(test_loader.dataset)) as progress_bar:
    for batch in test_loader:
      x=batch[0].to(device)
      cond_x=torch.stack([to_gray(img) for img in x.cpu()]).to(device)

      z,det=Glow_net.inference(x,cond_x)
      # Standard-normal log-density of the latents plus the 8-bit correction.
      prior=-0.5*(z**2+np.log(2*np.pi))
      prior_ll=prior.flatten(1).sum(-1) - np.log(256) * np.prod(z.size()[1:])
      loss=-(prior_ll+det).mean()

      nll_meter.update(loss.item(),x.size(0))
      progress_bar.set_postfix(nll=nll_meter.avg,bpd=bits_per_dim(x,nll_meter.avg))
      progress_bar.update(x.size(0))

  if epoch%1 != 0:
    return
  print('Saving...')
  state={'net': Glow_net.state_dict(),'test_loss': nll_meter.avg,'epoch':epoch}
  os.makedirs('ckpts',exist_ok=True)
  torch.save(state,'ckpts/best.pth.tar')

In [None]:
def bits_per_dim(x, nll):
  """Convert a negative log-likelihood in nats into bits per dimension.

  Args:
    x: tensor whose non-batch dimensions define the number of dimensions.
    nll: negative log-likelihood (natural log) of one sample.

  Returns:
    ``nll / (ln(2) * D)`` with ``D`` the product of ``x``'s sizes after dim 0.
  """
  n_dims = np.prod(x.size()[1:])
  return nll / (np.log(2) * n_dims)

In [None]:
# Standard-normal prior over the latents.
mean=torch.tensor(0.)
# NOTE(review): Normal's second argument is the scale (standard deviation),
# not the variance; with the value 1 the two coincide.
variance=torch.tensor(1.)
gaussian_dist=torch.distributions.normal.Normal(mean,variance)
# 3 levels of 8 flow steps each, coupling networks 128 channels wide.
Glow_net=glow_model(prior_dist=gaussian_dist,n_channels=128,L=3,K=8)
Glow_net=Glow_net.to(device)
optimizer=torch.optim.Adam(Glow_net.parameters(),lr=1e-3,betas=(0.9,0.999),eps=1e-8)
# Linear warm-up: the LR multiplier grows from 0 to 1 over the first 500000
# scheduler steps and stays at 1 afterwards.
scheduler=optim.lr_scheduler.LambdaLR(optimizer, lambda s: min(1.,s/500000))

In [None]:
def squeeze(x):
  """Trade spatial resolution for channels: (b,c,h,w) -> (b,4c,h/2,w/2).

  Every 2x2 spatial block is moved into four consecutive channels, ordered by
  (row offset, column offset) within the block.
  """
  batch,chans,height,width=x.shape
  half_h,half_w=height//2,width//2
  blocks=x.view(batch,chans,half_h,2,half_w,2)
  # Bring the two in-block offsets next to the channel axis.
  blocks=blocks.permute(0,1,3,5,2,4).contiguous()
  return blocks.view(batch,chans*2*2,half_h,half_w)
    
def unsqueeze(x):
  """Inverse of ``squeeze``: (b,4c,h,w) -> (b,c,2h,2w).

  Groups of four consecutive channels are scattered back into 2x2 spatial
  blocks.
  """
  batch,chans,height,width=x.shape
  out_chans=chans//4
  blocks=x.view(batch,out_chans,2,2,height,width)
  # Interleave each in-block offset with its spatial axis.
  blocks=blocks.permute(0,1,4,2,5,3).contiguous()
  return blocks.view(batch,out_chans,height*2,width*2)

In [None]:
def preprocess(x):
  """Dequantise ``x`` and map it to logit space.

  Steps: add uniform noise to the 8-bit levels, rescale into
  ``[0.05, 0.95]`` (data constraint 0.9) and apply the logit.

  Args:
    x: image tensor indexed as (batch, channel, height, width); values are
      expected in ``[0, 1]``.

  Returns:
    ``(y, log_det)`` where ``y`` has the shape of ``x`` and ``log_det`` is the
    per-sample log-determinant of the rescale+logit transform, shape (batch,).
  """
  # Sample the noise directly on x's device; no global device is needed.
  noise=torch.rand(x.shape,device=x.device)
  x=(x*255. + noise)/256.
  # Shrink towards 0.5 so the logit stays finite at the interval ends.
  x=(2.*x - 1.)*0.9
  x=(x + 1.)/2.
  logit_x=torch.log(x) - torch.log(1.-x)
  # softplus(l) + softplus(-l) = -log(p*(1-p)) is the logit's log-derivative;
  # softplus(-logit(0.9)) = -log(0.9) accounts for the 0.9 rescale.
  pre_logit_scale=torch.tensor(np.log(0.9) - np.log(1.-0.9))
  log_det=F.softplus(logit_x) + F.softplus(-logit_x) -F.softplus(-pre_logit_scale)
  log_det=torch.sum(log_det,dim=(1,2,3))
  return (logit_x,log_det)

In [None]:
# Main training loop: train, evaluate/checkpoint and save a sample grid each epoch.
# NOTE: global_step is assigned but not used in this cell.
global_step=0
start_epoch=0
num_epochs=100
for epoch in range(start_epoch, start_epoch + num_epochs):
  train_model(epoch,Glow_net,train_loader,device,optimizer,scheduler)
  with torch.no_grad():
    test_model(epoch,Glow_net,test_loader,device,optimizer,scheduler)
    # Condition the samples on the grayscale version of the first test batch.
    origin_img,label=next(iter(test_loader))
    cond_x=origin_img.cpu()
    cond_x=[torchvision.transforms.Grayscale()(i) for i in cond_x]
    cond_x=torch.stack(cond_x).to(device)
    # 64 matches the loader batch size, so there is one latent per conditioning image.
    gen_img=Glow_net.sample_images(64,3,32,32,cond_x)
    # The model emits logits; the sigmoid maps them back to [0, 1] pixels.
    gen_imgs=torch.sigmoid(gen_img)
  os.makedirs('samples_trial2',exist_ok=True)
  torchvision.utils.save_image(gen_imgs,'samples_trial2/epoch_{}.png'.format(epoch))

In [None]:
# Restore the weights stored under the 'net' key of the saved checkpoint.
checkpoint=torch.load('ckpts/best.pth.tar')
Glow_net.load_state_dict(checkpoint['net'])

<All keys matched successfully>

In [None]:
# Colourise every test batch and save both the generated grid and the
# grayscale conditioning grid, one PNG per batch.
for batch_idx,batch in enumerate(test_loader):
  #batch=next(iter(data_loader)
  image=batch[0].to(device)
  # Grayscale conversion is done image by image on the CPU.
  gray_batch=image.cpu()
  gray_batch=[torchvision.transforms.Grayscale()(i) for i in gray_batch]
  gray_batch=torch.stack(gray_batch).to(device)
  # One latent sample per conditioning image in the batch.
  gen_img=Glow_net.sample_images(gray_batch.shape[0],3,32,32,gray_batch)
  # Logits -> [0, 1] pixel values.
  gen_imgs=torch.sigmoid(gen_img)
  os.makedirs("test_gen_imgs",exist_ok=True)
  os.makedirs("test_gray_imgs",exist_ok=True)
  # The file names say "epoch" but the number is the batch index.
  torchvision.utils.save_image(gen_imgs,'test_gen_imgs/epoch_{}.png'.format(batch_idx))
  torchvision.utils.save_image(gray_batch,'test_gray_imgs/epoch_{}.png'.format(batch_idx))
print("Done")

Done


In [None]:
# Number of grayscale grids on disk (one file is written per test batch).
len(os.listdir("test_gray_imgs"))

157

In [None]:
# Stitch each grayscale grid next to its colourised counterpart.
gray_img_path="test_gray_imgs"
color_img_path="test_gen_imgs"
root="/content"
os.makedirs("stitched_imgs",exist_ok=True)
import cv2
# Both directories hold files with identical names, so pair them by name.
# os.listdir returns entries in arbitrary order; zipping two separate
# listings could put a grayscale grid next to the wrong colour grid.
for i,fname in enumerate(sorted(os.listdir(gray_img_path))):
  gray_path=os.path.join(root,gray_img_path,fname)
  color_path=os.path.join(root,color_img_path,fname)
  gray_img=cv2.imread(gray_path)
  color_img=cv2.imread(color_path)
  # cv2.imread signals a missing/unreadable file by returning None.
  if gray_img is None or color_img is None:
    raise FileNotFoundError("could not read image pair: {} / {}".format(gray_path,color_path))
  stitched_img=np.hstack((gray_img,color_img))
  cv2.imwrite('stitched_imgs/{}.png'.format(i),stitched_img)

In [None]:
# IPython shell escape (notebook-only syntax): archive the stitched images.
!zip -r /content/stitched_imgs.zip /content/stitched_imgs
# google.colab helper: send the archive to the browser as a download.
from google.colab import files
files.download("/content/stitched_imgs.zip")

In [None]:
# Scratch cell: embed the occupancy probability into a higher dimension.
x=torch.tensor((0.,1.))
x=x.unsqueeze(0)
print(x.shape)
# Randomly initialised projection from 2 to 5 features.
layer=torch.nn.Linear(2,5)
l=layer(x)
print(layer(x))
print(layer(x).shape)
# chunk(2) on 5 features gives uneven halves of sizes 3 and 2.
x6,x7=layer(x).chunk(2,dim=1)
print(x6)
print(x7)

torch.Size([1, 2])
tensor([[-0.9918, -0.2111, -0.3086,  0.8186, -0.2076]],
       grad_fn=<AddmmBackward>)
torch.Size([1, 5])
tensor([[-0.9918, -0.2111, -0.3086]], grad_fn=<SplitBackward>)
tensor([[ 0.8186, -0.2076]], grad_fn=<SplitBackward>)


### The output dimension of the MLP should be the same as the channel dimension of the feature map; also, the same label embedding should be fed into every conditional batch norm.

In [None]:
class conditional_bn(torch.nn.Module):
  """Conditional normalisation whose scale and shift depend on an embedding.

  Two MLPs map the embedding to per-channel offsets that are added to the
  learnable ``gammas`` / ``betas`` before they are applied to the normalised
  feature map.

  Args:
    emb_dim: size of the embedding fed to the MLPs.
    mlp_mid_dim: hidden size of the MLPs.
    mlp_out_dim: output size of the MLPs; must equal ``channels``.
    batch_size: leading dimension of the learnable ``betas`` / ``gammas``.
      With 1 the parameters broadcast to any batch size; any other value
      requires batches of exactly that size.
    channels: number of feature-map channels.
    eps: constant added to the variance for numerical stability.
  """
  def __init__(self,emb_dim,mlp_mid_dim,mlp_out_dim,batch_size,channels,eps=1.0e-5):
    super(conditional_bn,self).__init__()
    self.emb_dim=emb_dim # size of the lstm emb which is input to MLP
    self.mlp_mid_dim=mlp_mid_dim # size of hidden layer of MLP
    self.mlp_out_dim=mlp_out_dim # output of the MLP - for each channel
    self.batch_size=batch_size
    self.n_channels=channels
    self.eps=eps

    self.betas=torch.nn.Parameter(torch.zeros(self.batch_size,self.n_channels))
    self.gammas=torch.nn.Parameter(torch.zeros(self.batch_size,self.n_channels))

    self.mlp_gamma=torch.nn.Sequential(torch.nn.Linear(emb_dim,mlp_mid_dim),
                                       torch.nn.ReLU(inplace=True),
                                       torch.nn.Linear(mlp_mid_dim,mlp_out_dim))
    self.mlp_betas=torch.nn.Sequential(torch.nn.Linear(emb_dim,mlp_mid_dim),
                                      torch.nn.ReLU(inplace=True),
                                      torch.nn.Linear(mlp_mid_dim,mlp_out_dim))
    for m in self.modules():
      if isinstance(m,torch.nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        torch.nn.init.constant_(m.bias,0.1)

  def forward(self,label_emb,img_feats):
    """Normalise ``img_feats`` and modulate it with embedding-dependent affine terms.

    Args:
      label_emb: embedding of shape (batch, emb_dim).
      img_feats: feature map indexed as (batch, channel, height, width).

    Returns:
      Tensor with the shape of ``img_feats``.
    """
    # Out-of-place sums so (1, C) parameters broadcast against (batch, C)
    # offsets; an in-place add into the smaller tensor would raise.
    betas=self.betas+self.mlp_betas(label_emb)
    gammas=self.gammas+self.mlp_gamma(label_emb)

    # NOTE(review): statistics are taken over the whole tensor, not per
    # channel as in standard batch norm -- confirm this is intended.
    mean,var=torch.mean(img_feats),torch.var(img_feats,unbiased=True)
    feat_norm=(img_feats-mean)/torch.sqrt(var+self.eps)

    # Trailing singleton axes broadcast the (batch, C) terms over H and W.
    cbn_feat=feat_norm*gammas[:,:,None,None]+betas[:,:,None,None]

    return (cbn_feat)

In [None]:
# Smoke test for conditional_bn on a single random image and a one-hot label.
img=torch.rand(1,3,32,32)
# One-hot vector of length 10 with class 1 active.
label=torch.tensor([0.,1.,0.,0.,0.,0.,0.,0.,0.,0.])
label=label.unsqueeze(0)
print(label.shape)
# No padding, so the 3x3 conv shrinks 32x32 to 30x30.
sample_conv=torch.nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1,bias=False)
sample_mlp=torch.nn.Linear(in_features=10,out_features=16)
sample_feat=sample_conv(img)
label_emb=sample_mlp(label)
print(sample_feat.shape)
print(label_emb.shape)
# mlp_out_dim and channels both equal the 32 channels of the feature map.
trial_cbn=conditional_bn(emb_dim=16,mlp_mid_dim=32,mlp_out_dim=32,batch_size=1,channels=32)
print(trial_cbn)
sample_op=trial_cbn(label_emb,sample_feat)
print(sample_op.shape)

torch.Size([1, 10])
torch.Size([1, 32, 30, 30])
torch.Size([1, 16])
conditional_bn(
  (mlp_gamma): Sequential(
    (0): Linear(in_features=16, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=32, bias=True)
  )
  (mlp_betas): Sequential(
    (0): Linear(in_features=16, out_features=32, bias=True)
    (1): ReLU(inplace=True)
    (2): Linear(in_features=32, out_features=32, bias=True)
  )
)
torch.Size([1, 32, 30, 30])


In [None]:
class cbn_coupling(torch.nn.Module):
  """Affine coupling layer conditioned on a label via conditional batch norm.

  The input is split channel-wise into ``(x1, x3)``. ``x3`` feeds a conv net
  whose second and third normalisation layers are modulated by an embedding
  of the label; the net predicts a log-scale ``s`` and a shift ``t`` applied
  to ``x1`` as ``(x1 + t) * exp(s)``.

  Args:
    in_channel: channels of each half of the input.
    mid_channel: hidden width of the s/t network.
    emb_channel: size of the label embedding.
    bs: batch size forwarded to the ``conditional_bn`` layers.
  """
  def __init__(self,in_channel,mid_channel,emb_channel,bs):
    super(cbn_coupling,self).__init__()
    self.scale=torch.nn.Parameter(torch.ones(in_channel,1,1))
    out_channel=2*in_channel

    # Labels are expected as vectors of length 10.
    self.label_mlp=torch.nn.Linear(in_features=10,out_features=emb_channel)
    self.in_norm=torch.nn.BatchNorm2d(in_channel)
    self.conv1=torch.nn.Conv2d(in_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    nn.init.normal_(self.conv1.weight,0.,0.05)

    self.conv2=torch.nn.Conv2d(mid_channel,mid_channel,kernel_size=(3,3),padding=1,bias=False)
    nn.init.normal_(self.conv2.weight,0.,0.05)

    self.norm2=conditional_bn(emb_channel,mlp_mid_dim=32,mlp_out_dim=mid_channel,batch_size=bs,channels=mid_channel)
    self.conv3=torch.nn.Conv2d(mid_channel,mid_channel,kernel_size=(1,1),padding=0,bias=False)
    nn.init.normal_(self.conv3.weight,0.,0.05)

    self.norm3=conditional_bn(emb_channel,mlp_mid_dim=32,mlp_out_dim=mid_channel,batch_size=bs,channels=mid_channel)
    # Zero init makes s = t = 0, so the layer starts as the identity.
    self.conv4=torch.nn.Conv2d(mid_channel,out_channel,kernel_size=(3,3),padding=1,bias=True)
    nn.init.zeros_(self.conv4.weight)
    nn.init.zeros_(self.conv4.bias)

  def _scale_shift(self,x3,label):
    """Predict the log-scale ``s`` and shift ``t`` from ``x3`` and ``label``."""
    label_emb=self.label_mlp(label)
    h=self.in_norm(x3)
    h=F.relu(self.conv1(h))

    h=self.conv2(h)
    h=F.relu(self.norm2(label_emb,h))

    h=self.conv3(h)
    h=F.relu(self.norm3(label_emb,h))

    h=self.conv4(h)
    # Even output channels carry s, odd ones carry t.
    s=self.scale*torch.tanh(h[:,0::2,:,:])
    t=h[:,1::2,:,:]
    return s,t

  def forward(self,x,label,log_det):
    """Transform the first half of ``x``; returns ``(y, log_det + sum(s))``."""
    x1,x3=x.chunk(2,dim=1)
    s,t=self._scale_shift(x3,label)
    x1=(x1+t)*torch.exp(s)
    log_det=log_det+torch.sum(s,dim=(1,2,3))
    x=torch.cat((x1,x3),dim=1)
    return(x,log_det)

  def backward(self,x,label,log_det):
    """Invert ``forward``; ``x3`` is unchanged so s and t can be recomputed."""
    x1,x3=x.chunk(2,dim=1)
    s,t=self._scale_shift(x3,label)
    x1=x1*torch.exp(s*-1) - t
    log_det=log_det-torch.sum(s,dim=(1,2,3))
    x=torch.cat((x1,x3),dim=1)
    return(x,log_det)

In [None]:
# Smoke test: build the inputs and instantiate a cbn_coupling layer.
img=torch.rand(1,3,32,32)
# One-hot vector of length 10 with class 1 active.
label=torch.tensor([0.,1.,0.,0.,0.,0.,0.,0.,0.,0.])
label=label.unsqueeze(0)
print(label.shape)
sample_conv=torch.nn.Conv2d(in_channels=3,out_channels=32,kernel_size=3,stride=1,bias=False)
sample_mlp=torch.nn.Linear(in_features=10,out_features=16)
sample_feat=sample_conv(img)
label_emb=sample_mlp(label)
# Only construction is exercised here; the layer is not called.
samp_coup=cbn_coupling(in_channel=32,mid_channel=32,emb_channel=16,bs=1)

In [None]:
#shapenet code (broyden implicit poof)(till game) + (encoder decoder) + git clone

In [None]:
#1st half==> 1deep equilibrium models talk and understand,#
#2 2-3shapenet code(tomorrow) + till 2 1mdeq code recreate  
#2nd half==>3nf1 code,4mg2mesh code,5tokyo appn

In [None]:
#use conditional batch norm for the segmentation labels
#use all layers condition with all layers