[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/phujbert/SceneGen/blob/master/SceneGenSample.ipynb)

In [None]:
import copy
import math
import torch
import numpy as np
from torch import optim
import torch.nn as nn
from tqdm import tqdm
import torch.nn.functional as F
from PIL import Image
import torchvision.utils as tutils
from collections import OrderedDict
import json

# Download the trained models, the graph dictionary, and an example scene graph

In [None]:
# Fetch the trained-model archive from Google Drive. The inner wget saves the
# session cookies and extracts the "confirm=" token (if any) from Drive's
# response; the outer wget then downloads the archive using that token.
# TLS certificate checking stays enabled: the archive holds checkpoints that are
# later deserialized with torch.load (pickle), so download integrity matters.
!wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies 'https://docs.google.com/uc?export=download&id=1C8OB3787DW1V-TNH8mvwGFvumcCAjr4b' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id=1C8OB3787DW1V-TNH8mvwGFvumcCAjr4b" -O trained_models.zip && rm -rf /tmp/cookies.txt
# -o: overwrite existing files without prompting, so the cell can be re-run.
!unzip -o trained_models.zip

In [None]:
# Vocabulary JSON ("edges" and "categories" name -> id maps) read by process_graph.
# Certificate checking is left enabled (no --no-check-certificate).
!wget 'https://docs.google.com/uc?export=download&id=1PQZd63V8gaMQewPyPtTXWO1qMMAzOIHK' -O graph_dict.json

In [None]:
# Example scene graph JSON ("objs" and "triples") used by the Test cell.
# Certificate checking is left enabled (no --no-check-certificate).
!wget 'https://docs.google.com/uc?export=download&id=1Z2vZIViVdGebV_WDmrNNg0KuORzsH3vM' -O sample_graph.json

# Diffusion

In [None]:
class Diffusion:
    """DDPM helper: linear noise schedule, forward noising and guided sampling.

    ``beta`` is a linear schedule of ``noise_steps`` values from ``beta_start``
    to ``beta_end``; ``alpha = 1 - beta`` and ``alpha_hat`` is the cumulative
    product of ``alpha``. All schedule tensors are moved to ``device``.
    """

    def __init__(self, noise_steps=1000, beta_start=1e-4, beta_end=0.02, img_size=64, device="cuda"):
        self.noise_steps = noise_steps
        self.beta_start = beta_start
        self.beta_end = beta_end

        self.beta = self.prepare_noise_schedule().to(device)
        self.alpha = 1. - self.beta
        self.alpha_hat = torch.cumprod(self.alpha, dim=0)

        self.img_size = img_size
        self.device = device

    def prepare_noise_schedule(self):
        """Return the linear beta schedule (length ``noise_steps``, on CPU)."""
        return torch.linspace(self.beta_start, self.beta_end, self.noise_steps)

    def noise_images(self, x, t):
        """Noise ``x`` directly to timesteps ``t`` (closed-form forward process).

        Args:
            x: image batch; the ``[:, None, None, None]`` broadcast below assumes
                it is 4-D (B, C, H, W).
            t: integer tensor of B timesteps indexing the schedule.

        Returns:
            ``(x_t, epsilon)`` with
            ``x_t = sqrt(alpha_hat[t]) * x + sqrt(1 - alpha_hat[t]) * epsilon``
            and ``epsilon`` the standard-normal noise that was added.
        """
        sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])[:, None, None, None]
        sqrt_one_minus_alpha_hat = torch.sqrt(1 - self.alpha_hat[t])[:, None, None, None]
        epsilon = torch.randn_like(x)
        return sqrt_alpha_hat * x + sqrt_one_minus_alpha_hat * epsilon, epsilon

    def sample_timesteps(self, n):
        """Draw ``n`` timesteps uniformly from ``[1, noise_steps)``."""
        return torch.randint(low=1, high=self.noise_steps, size=(n,))

    def sample(self, model, n, layout, edge_dict=None, cfg_scale=3):
        """Generate ``n`` images by running the reverse diffusion process.

        Args:
            model: noise predictor called as ``model(x, t, layout)``; passing
                ``None`` as the layout requests the unconditional prediction.
            n: number of images to generate.
            layout: conditioning passed to ``model``, or ``None`` for
                unconditional sampling.
            edge_dict: unused; kept for backward compatibility.
            cfg_scale: classifier-free guidance weight. The prediction used is
                ``uncond + cfg_scale * (cond - uncond)``. Ignored when
                ``layout`` is ``None``.

        Returns:
            uint8 tensor of shape (n, 3, img_size, img_size) with values in
            [0, 255].

        The model's train/eval mode is restored to what it was on entry.
        """
        print("start sample")
        was_training = model.training
        model.eval()
        with torch.no_grad():
            x = torch.randn((n, 3, self.img_size, self.img_size)).to(self.device)
            for i in tqdm(reversed(range(1, self.noise_steps)), position=0, total=self.noise_steps - 1):
                t = (torch.ones(n) * i).long().to(self.device)
                if layout is None:
                    # Guidance would lerp the unconditional prediction with
                    # itself (a no-op), so a single forward pass suffices.
                    predicted_noise = model(x, t, None)
                else:
                    cond_noise = model(x, t, layout)
                    uncond_noise = model(x, t, None)
                    predicted_noise = torch.lerp(uncond_noise, cond_noise, cfg_scale)

                alpha = self.alpha[t][:, None, None, None]
                alpha_hat = self.alpha_hat[t][:, None, None, None]
                beta = self.beta[t][:, None, None, None]
                # No noise is injected on the final step (i == 1).
                if i > 1:
                    noise = torch.randn_like(x)
                else:
                    noise = torch.zeros_like(x)
                x = 1 / torch.sqrt(alpha) * (
                            x - ((1 - alpha) / (torch.sqrt(1 - alpha_hat))) * predicted_noise) + torch.sqrt(
                    beta) * noise

        # Restore the caller's mode instead of unconditionally switching to train.
        model.train(was_training)
        # Map from [-1, 1] to uint8 pixel values.
        x = (x.clamp(-1, 1) + 1) / 2
        x = (x * 255).type(torch.uint8)
        print("end sample")
        return x

# Modules

In [None]:
class EMA:
    """Keep an exponential moving average of a model's parameters.

    During the first ``step_start_ema`` calls to ``step_ema`` the averaged
    model is simply overwritten with the live model's state. After that, each
    parameter becomes ``beta * averaged + (1 - beta) * live``.
    """

    def __init__(self, beta):
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        """Blend every parameter of ``current_model`` into ``ma_model``."""
        pairs = zip(current_model.parameters(), ma_model.parameters())
        for live, averaged in pairs:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return the EMA of ``old`` and ``new`` (or ``new`` if ``old`` is None)."""
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        """Advance one step: copy weights while warming up, otherwise average."""
        if self.step >= step_start_ema:
            self.update_model_average(ema_model, model)
        else:
            self.reset_parameters(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        """Overwrite ``ema_model``'s state with a copy of ``model``'s state."""
        ema_model.load_state_dict(model.state_dict())


class SelfAttention(nn.Module):
    """Transformer-style self-attention over the pixels of a square feature map.

    The input is viewed as ``size * size`` tokens of width ``channels``. It then
    goes through pre-norm 4-head attention with a residual connection, followed
    by a pre-norm feed-forward layer with a residual connection. Finally it is
    viewed back to (B, channels, size, size).
    """

    def __init__(self, channels, size):
        super().__init__()
        self.channels = channels
        self.size = size
        # Submodule creation order is kept fixed so parameter init is unchanged.
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        num_tokens = self.size * self.size
        # (B, C, H*W) -> (B, H*W, C): one token per spatial position.
        seq = x.view(-1, self.channels, num_tokens).transpose(1, 2)
        normed = self.ln(seq)
        attended, _ = self.mha(normed, normed, normed)
        hidden = seq + attended
        hidden = hidden + self.ff_self(hidden)
        return hidden.transpose(1, 2).view(-1, self.channels, self.size, self.size)


class DoubleConv(nn.Module):
    """Two 3x3 conv + GroupNorm(1) layers with a GELU in between.

    With ``residual=True`` the input is added to the conv output and passed
    through GELU. That path requires ``in_channels == out_channels``.
    ``mid_channels`` defaults to ``out_channels`` when falsy.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid_channels),
            nn.GELU(),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        features = self.double_conv(x)
        return F.gelu(x + features) if self.residual else features


class Down(nn.Module):
    """U-Net encoder stage: 2x max-pool, two DoubleConvs, plus a conditioning bias.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the stage.
        emb_dim: width of the conditioning vector ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=128):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )

        # NOTE: "emd_layer" (sic) is kept as-is: the attribute name is part of
        # the state_dict keys, so renaming it would break loading checkpoints
        # saved with this name.
        self.emd_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        """Downsample ``x`` and add the projected conditioning vector.

        ``t`` is expected to be (B, emb_dim). Its (B, out_channels) projection
        is added at every spatial position.
        """
        x = self.maxpool_conv(x)
        # Broadcasting over H and W yields the same sum as an explicit
        # ``.repeat`` without materializing a full (B, C, H, W) copy.
        emb = self.emd_layer(t)[:, :, None, None]
        return x + emb


class Up(nn.Module):
    """U-Net decoder stage: 2x upsample, concat skip, two DoubleConvs, plus bias.

    Args:
        in_channels: channels after concatenating the skip connection with the
            upsampled input.
        out_channels: channels produced by the stage.
        emb_dim: width of the conditioning vector ``t`` passed to ``forward``.
    """

    def __init__(self, in_channels, out_channels, emb_dim=128):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )

        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        """Upsample ``x``, fuse it with ``skip_x``, and add the projected ``t``.

        ``skip_x`` must match the upsampled spatial size of ``x``. ``t`` is
        expected to be (B, emb_dim).
        """
        x = self.up(x)
        # Skip features come first along the channel axis.
        x = torch.cat([skip_x, x], dim=1)
        x = self.conv(x)
        # Broadcast the (B, C) projection over H and W instead of repeating it.
        emb = self.emb_layer(t)[:, :, None, None]
        return x + emb


class UNet(nn.Module):
    """Conditional U-Net noise predictor for the diffusion sampler.

    Three Down stages (to 1/8 resolution), a three-DoubleConv bottleneck and
    three Up stages with skip connections. A sinusoidal timestep embedding of
    width ``time_dim`` is added to every Down/Up stage. When given,
    ``obj_vecs`` is summed into that embedding first.

    Args:
        c_in: input image channels.
        c_out: output (predicted noise) channels.
        time_dim: width of the timestep/conditioning embedding.
        device: kept for backward compatibility; the embedding is built on the
            device of the timestep tensor.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="cuda"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        # Down/Up project the time_dim-wide embedding, so their emb_dim must
        # equal time_dim (their own default of 128 broke time_dim=256).
        # Construction order is unchanged so parameter init is unchanged.
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128, emb_dim=time_dim)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256, emb_dim=time_dim)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256, emb_dim=time_dim)
        self.sa3 = SelfAttention(256, 8)

        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        self.up1 = Up(512, 128, emb_dim=time_dim)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64, emb_dim=time_dim)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64, emb_dim=time_dim)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)
        # NOTE: sa1-sa6 are registered but never called in forward(). They are
        # kept so the state_dict keys (and strict loading of existing
        # checkpoints) stay unchanged.

    def sinusoidal_embedding(self, t, channels):
        """Return [sin, cos] timestep features of width ``channels`` (even).

        ``t`` is a float tensor of shape (B, 1). The frequencies are created on
        ``t``'s device, so the model works wherever its inputs live.
        """
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2, device=t.device).float() / channels))
        emb_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        emb_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        emb = torch.cat([emb_a, emb_b], dim=-1)
        return emb

    def forward(self, x, t, obj_vecs):
        """Predict the noise in ``x`` at integer timesteps ``t``.

        Args:
            x: noisy images (B, c_in, H, W). H and W must be divisible by 8 so
                the three pool/upsample pairs line up with the skip connections.
            t: integer timesteps of shape (B,).
            obj_vecs: optional conditioning added to the timestep embedding;
                must broadcast against (B, time_dim). ``None`` means
                unconditional.

        Returns:
            Tensor of shape (B, c_out, H, W).
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.sinusoidal_embedding(t, self.time_dim)
        if obj_vecs is not None:
            t = t + obj_vecs

        x1 = self.inc(x)
        x2 = self.down1(x1, t)
        x3 = self.down2(x2, t)
        x4 = self.down3(x3, t)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.up2(x, x2, t)
        x = self.up3(x, x1, t)
        output = self.outc(x)
        return output

# GCN

In [None]:
def _init_weights(module):
    if hasattr(module, 'weight'):
        if isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight)


class GraphConv(nn.Module):
    """One graph-convolution layer over (subject, predicate, object) triples.

    For each triple, the subject, predicate and object vectors are concatenated
    and passed through ``net1``. Its output is split into a subject message
    (``hidden_dim``), a new predicate vector (``output_dim``) and an object
    message (``hidden_dim``). The messages are summed per object, divided by
    the object's triple count when ``pooling == 'avg'``, and mapped through
    ``net2``.

    ``output_dim=None`` means ``output_dim = input_dim``.
    ``mlp_normalization`` is accepted but unused.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, pooling='avg', mlp_normalization='none'):
        super().__init__()
        output_dim = input_dim if output_dim is None else output_dim
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.pooling = pooling

        # net1 is built before net2 so the random-init sequence is unchanged.
        self.net1 = nn.Sequential(
            nn.Linear(3 * input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2 * hidden_dim + output_dim),
            nn.ReLU(),
        )
        self.net2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.ReLU(),
        )

    def forward(self, obj_vecs, pred_vecs, edges):
        """Run one round of message passing.

        Args:
            obj_vecs: (O, input_dim) object vectors.
            pred_vecs: (T, input_dim) predicate vectors, one per triple.
            edges: (T, 2) integer tensor of (subject index, object index).

        Returns:
            ``(new_obj_vecs, new_pred_vecs)`` of shapes (O, output_dim) and
            (T, output_dim).
        """
        num_objs, num_triples = obj_vecs.size(0), pred_vecs.size(0)
        hid, out_dim = self.hidden_dim, self.output_dim
        like = dict(dtype=obj_vecs.dtype, device=obj_vecs.device)

        subj_idx = edges[:, 0].contiguous()
        obj_idx = edges[:, 1].contiguous()

        triple_in = torch.cat([obj_vecs[subj_idx], pred_vecs, obj_vecs[obj_idx]], dim=1)
        subj_msg, pred_out, obj_msg = torch.split(self.net1(triple_in), [hid, out_dim, hid], dim=1)

        # Sum subject messages, then object messages, into their objects' rows.
        pooled = torch.zeros(num_objs, hid, **like)
        pooled = pooled.scatter_add(0, subj_idx.view(-1, 1).expand_as(subj_msg), subj_msg)
        pooled = pooled.scatter_add(0, obj_idx.view(-1, 1).expand_as(obj_msg), obj_msg)

        if self.pooling == 'avg':
            unit = torch.ones(num_triples, **like)
            counts = torch.zeros(num_objs, **like)
            counts = counts.scatter_add(0, subj_idx, unit).scatter_add(0, obj_idx, unit)
            # Objects in no triple keep a zero sum; clamping avoids dividing by 0.
            pooled = pooled / counts.clamp(min=1).view(-1, 1)

        return self.net2(pooled), pred_out


class GraphConvNet(nn.Module):
    """Stack of ``GraphConv`` layers over learned object/predicate embeddings.

    Args:
        input_dim: embedding width fed to the first layer.
        hidden_dim: hidden width inside each ``GraphConv``.
        output_dim: width produced by every layer. ``None`` means
            ``input_dim``, matching ``GraphConv``'s own convention.
        num_layers: number of stacked layers.
        pooling: pooling mode passed to every ``GraphConv``.
        num_objs: object vocabulary size. Defaults to the value previously
            hard-coded here, presumably that of the transferred checkpoint.
        num_preds: predicate vocabulary size; same default rationale.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=5, pooling='avg',
                 num_objs=185, num_preds=11):
        super().__init__()
        if output_dim is None:
            output_dim = input_dim

        # Embedding width must equal the first layer's input_dim
        # (previously hard-coded to 128).
        self.obj_embedding = nn.Embedding(num_objs, input_dim)
        self.pred_embedding = nn.Embedding(num_preds, input_dim)

        self.num_layers = num_layers
        self.gcn = nn.ModuleList()
        for i in range(num_layers):
            # Layers after the first receive output_dim-wide object and
            # predicate vectors from the previous layer.
            layer_in = input_dim if i == 0 else output_dim
            self.gcn.append(GraphConv(layer_in, hidden_dim, output_dim, pooling=pooling))

    def forward(self, objs, preds, edges):
        """Embed ids and run all layers; return the final object vectors.

        Args:
            objs: (O,) object category ids.
            preds: (T,) predicate ids, one per triple.
            edges: (T, 2) (subject index, object index) pairs into ``objs``.

        Returns:
            (O, output_dim) object vectors.
        """
        obj_vecs = self.obj_embedding(objs)
        pred_vecs = self.pred_embedding(preds)
        for layer in self.gcn:
            obj_vecs, pred_vecs = layer(obj_vecs, pred_vecs, edges)
        return obj_vecs

# Utils

In [None]:
def save_images(images, path):
    """Tile a batch of images into one grid, save it to ``path``, and return it.

    ``images`` is assumed to be a uint8 (B, C, H, W) tensor, as returned by
    ``Diffusion.sample``. PIL's ``fromarray`` expects (H, W, C) uint8 data.
    """
    grid = tutils.make_grid(images)
    pixels = grid.to('cpu').permute(1, 2, 0).numpy()
    picture = Image.fromarray(pixels)
    picture.save(path)
    return picture

def load_transfer_weights_biases(transfer_model_path, device):
    """Rename a pretrained checkpoint's embedding/GCN weights for GraphConvNet.

    Reads ``checkpoint['model_state']`` and returns an OrderedDict keyed like a
    5-layer ``GraphConvNet`` state_dict:

    - ``obj_embeddings`` / ``pred_embeddings`` map to ``obj_embedding`` /
      ``pred_embedding``.
    - ``gconv.*`` maps to layer 0.
    - ``gconv_net.gconvs.{k}.*`` maps to layer ``k + 1``.

    Raises KeyError if an expected source key is missing.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    source = torch.load(transfer_model_path, map_location=device)['model_state']

    renamed = OrderedDict()
    renamed['obj_embedding.weight'] = source['obj_embeddings.weight']
    renamed['pred_embedding.weight'] = source['pred_embeddings.weight']

    # Position in this list = target GCN layer index.
    source_prefixes = ['gconv'] + [f'gconv_net.gconvs.{k}' for k in range(4)]
    param_suffixes = [
        f'{net}.{idx}.{kind}'
        for net in ('net1', 'net2')
        for idx in (0, 2)
        for kind in ('weight', 'bias')
    ]
    for layer, prefix in enumerate(source_prefixes):
        for suffix in param_suffixes:
            renamed[f'gcn.{layer}.{suffix}'] = source[f'{prefix}.{suffix}']
    return renamed



def get_graph_emb(gcn, objs, triples):
    """Embed a scene graph as the sum of its GCN object vectors.

    Args:
        gcn: callable ``gcn(objs, preds, edges)`` returning one vector per object.
        objs: object id tensor.
        triples: (T, 3) tensor of (subject index, predicate id, object index).

    Returns:
        The object vectors summed over the object axis, with a leading
        dimension of size 1. Computed without gradient tracking.
    """
    subjects, preds, objects = (column.squeeze(1) for column in triples.chunk(3, dim=1))
    edges = torch.stack([subjects, objects], dim=1)
    with torch.no_grad():
        obj_vecs = gcn(objs, preds, edges)
    return obj_vecs.sum(dim=0, keepdim=True)



def process_graph(graph_path, dict_path):
    """Load a JSON scene graph and encode it as id tensors.

    ``graph_path`` holds ``{"objs": [name, ...], "triples": [[subj, pred, obj], ...]}``.
    Subjects and objects are object names, or integers used directly as indices.
    ``dict_path`` holds ``{"edges": {pred_name: id}, "categories": {obj_name: id}}``.

    Every object also gets a ``[obj, "__in_image__", "in_image"]`` triple to a
    synthetic ``"in_image"`` node appended to the object list. Both names must
    therefore exist in the vocabularies.

    Returns:
        ``(objs, triples)``: a LongTensor of category ids and a (T, 3)
        LongTensor of (subject index, predicate id, object index).

    Raises:
        ValueError: a triple names an object not present in ``objs``.
        KeyError: a predicate or category name is missing from the vocabulary.
    """
    with open(graph_path, "r", encoding="utf-8") as inputfile:
        graph = json.load(inputfile)

    with open(dict_path, "r", encoding="utf-8") as inputfile:
        graph_dict = json.load(inputfile)

    objs = list(graph["objs"])
    triples = [list(triple) for triple in graph["triples"]]

    edges = graph_dict["edges"]
    categories = graph_dict["categories"]

    triples.extend([obj, "__in_image__", "in_image"] for obj in objs)
    objs.append("in_image")

    # Name -> index of its first occurrence (duplicates resolve to the first,
    # as the previous linear search did); O(1) lookups instead of O(len(objs)).
    index_of = {}
    for idx, name in enumerate(objs):
        index_of.setdefault(name, idx)

    def _to_index(ref):
        """Map an object reference from a triple to its index in ``objs``."""
        if ref in index_of:
            return index_of[ref]
        if isinstance(ref, int):
            # Already an explicit index; passed through unchanged.
            return ref
        raise ValueError(f"triple references unknown object {ref!r} in {graph_path}")

    encoded_triples = [[_to_index(subj), edges[pred], _to_index(obj)] for subj, pred, obj in triples]
    encoded_objs = [categories[name] for name in objs]

    return torch.LongTensor(encoded_objs), torch.LongTensor(encoded_triples)

# Sample

In [None]:
def sample_images(sample_size=1, image_path="samled_images.jpg", scene_graph=True, graph_path="sample_graph.json", device="cuda"):
    """Sample images with the EMA diffusion model and save them as a grid.

    Args:
        sample_size: number of images to generate.
        image_path: file the image grid is written to.
        scene_graph: if true, condition every sample on the embedding of the
            scene graph at ``graph_path``; otherwise sample unconditionally.
        graph_path: JSON scene graph, in the format ``process_graph`` reads.
        device: torch device for the diffusion model and sampling.

    Returns:
        The saved grid as returned by ``save_images``.
    """
    embedding_dim = 128
    dict_path = "graph_dict.json"
    transfer_module_path = "trained_models/transfer_models/coco64.pt"
    # Only the EMA weights are used for sampling. The raw training weights
    # ("trained_models/scenegen_epoch_359/model_epoch_359.pt") used to be loaded
    # and then immediately overwritten, so they are no longer read.
    ema_model_path = "trained_models/scenegen_epoch_359/ema_model_epoch_359.pt"

    sample_layouts = None
    if scene_graph:
        gcn = GraphConvNet(input_dim=embedding_dim, hidden_dim=512, output_dim=embedding_dim)
        # The GCN stays on CPU, so load its transferred weights there directly.
        gcn.load_state_dict(load_transfer_weights_biases(transfer_module_path, "cpu"))
        gcn.eval()

        objs, triples = process_graph(graph_path, dict_path)
        graph_emb = get_graph_emb(gcn, objs, triples)
        # Same conditioning vector for every sample in the batch.
        sample_layouts = graph_emb.repeat(sample_size, 1).to(device)

    ema_model = UNet(time_dim=embedding_dim, device=device).to(device)
    # NOTE(review): torch.load unpickles the checkpoint; only use trusted files.
    # map_location lets CUDA-saved weights load on whatever device is requested.
    ema_model.load_state_dict(torch.load(ema_model_path, map_location=device))
    ema_model.eval().requires_grad_(False)

    diffusion = Diffusion(device=device)
    ema_sampled_images = diffusion.sample(ema_model, n=sample_size, layout=sample_layouts)
    return save_images(ema_sampled_images, image_path)

# Test

In [None]:
# Generate 8 images conditioned on sample_graph.json, save the grid to
# "samled_images.jpg", and render it inline.
# NOTE(review): "samled" looks like a typo for "sampled"; it is the output
# filename, so it is left unchanged here.
# display() is provided by the IPython/Jupyter runtime, not by plain Python.
images = sample_images(sample_size=8, image_path="samled_images.jpg", graph_path="sample_graph.json")
display(images)