In [None]:
# imports
import csv
import math
import os
import random
from collections import namedtuple
from copy import deepcopy
from glob import glob

import cv2
import matplotlib.pyplot as plt
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
from tqdm import tqdm

# pytorch + torchvision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.models as models
import torchvision.transforms as transforms

# Check that you're using a recent OpenCV version (4.5 or later).
# Compare numeric (major, minor) components: a plain string comparison would
# wrongly reject e.g. '4.10.0', because '1' < '5' lexicographically.
_cv2_major_minor = tuple(int(part) for part in cv2.__version__.split('.')[:2])
assert _cv2_major_minor >= (4, 5), 'Please use OpenCV 4.5 or later.'

Write down the plan in words and references. We will then fill in code.

In [None]:
# Encoding: ResNeXt (conv 1-3) => Position Enc. (ref. 2107.14222) => Deep-ViT (w/ EfficientNet)
# Decoding(for unsupervised training) : HiT(low resolution stage, same # as Deep-ViT) => FCC-GAN
# Comparison : Transformer-XL 
# Classifier/MLP : MLP head (output 8) => Reconstruction module => F

# Training Steps:
# 1. Train Encoder (unsupervised): manipulate input image (ref. SiT, + rotation) and match to output
# 2. Train Comparison (unsupervised) : use different head, mix&match the 2 images (ref. BERT, ALBERT)
# 3. Train MLP (supervised) : compare output to F

# Inference Steps: Encode each image => concat. 2 images => Comparison => MLP

# Official sample code: https://www.kaggle.com/code/eduardtrulls/imc2022-training-data?scriptVersionId=92062607

# Load data

In [None]:
# copied from sample code
# Input data files are available in the read-only "../input/" directory.

# on kaggle
src = '../input/image-matching-challenge-2022/train'

# on pc
# src = './image-matching-challenge-2022/train'

# Every sub-directory of `src` is one scene.
# - Entries are sorted by name, so the scene order does not depend on the filesystem's listing order.
# - The `with` block closes the directory handle even if something fails part-way.
val_scenes = []
with os.scandir(src) as entries:
    for entry in sorted(entries, key=lambda e: e.name):
        if entry.is_dir():
            cur_scene = entry.name
            print(f'Found scene "{cur_scene}" at {entry.path}')
            val_scenes.append(cur_scene)

In [None]:
# Each scene in the validation set contains a list of images, poses, and pairs. Let's pick one and look at some images.

scene = 'piazza_san_marco'

images_dict = {}
# Sorted so that the same images are loaded and shown on every run (glob order is filesystem-dependent).
for filename in sorted(glob(f'{src}/{scene}/images/*.jpg')):
    cur_id = os.path.basename(os.path.splitext(filename)[0])

    image_bgr = cv2.imread(filename)
    # cv2.imread reports unreadable or missing files by returning None instead of raising.
    if image_bgr is None:
        print(f'Warning: could not read "{filename}", skipping it.')
        continue
    # OpenCV decodes to BGR, but the images are encoded in standard RGB, so convert before displaying.
    images_dict[cur_id] = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

print(f'Loaded {len(images_dict)} images.')

num_rows = 6
num_cols = 4
f, axes = plt.subplots(num_rows, num_cols, figsize=(20, 20), constrained_layout=True)
shown_ids = list(images_dict)[: num_rows * num_cols]
# axes.T.flat walks the grid column by column (same order as axes[i % num_rows, i // num_rows]).
# Every panel gets its axis turned off, so leftover panels stay blank when there are fewer images than slots.
for i, cur_ax in enumerate(axes.T.flat):
    cur_ax.axis('off')
    if i < len(shown_ids):
        key = shown_ids[i]
        cur_ax.imshow(images_dict[key])
        cur_ax.set_title(key)
plt.show()

In [None]:
# dataset
class IMC_dataset(Dataset):
    """Image Matching Challenge 2022 dataset (work in progress).

    Each row of ``csv_file`` is one sample:
    - column 0 is an image path relative to ``root_dir``;
    - every remaining column is numeric, and those values are read as flat (x, y) pairs.

    TODO: the intended sample for this project is an image pair with its
    covisibility and camera parameters (K, R, T). This class still returns a
    single image plus the numeric columns of its row.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to each
                sample dict before it is returned.
        """
        self.landmarks_frame = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Number of rows in the annotation csv."""
        return len(self.landmarks_frame)

    def __getitem__(self, idx):
        """Load the sample described by row ``idx`` of the csv.

        Returns:
            dict with
                'image' (np.ndarray): the image as read by cv2, converted to RGB
                    channel order.
                'landmarks' (np.ndarray): float array of shape (N, 2), built from
                    the row's numeric columns.
            If ``transform`` is set, its result for that dict is returned instead.

        Raises:
            FileNotFoundError: if the image file does not exist.
            OSError: if the image file exists but cv2 cannot decode it.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir, self.landmarks_frame.iloc[idx, 0])
        if not os.path.isfile(img_name):
            raise FileNotFoundError(f'Image not found: {img_name}')
        image = cv2.imread(img_name)
        # cv2.imread returns None (instead of raising) when it cannot decode a file.
        if image is None:
            raise OSError(f'Could not decode image: {img_name}')
        # OpenCV decodes to BGR; convert to RGB like the image-preview cell does.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        landmarks = np.asarray(self.landmarks_frame.iloc[idx, 1:], dtype=float).reshape(-1, 2)
        sample = {'image': image, 'landmarks': landmarks}

        if self.transform:
            sample = self.transform(sample)

        return sample

In [None]:
# pytorch data loading

def get_scene_trainloader(scene, root='./image-matching-challenge-2022/train', batch_size=4,
                          image_size=(256, 256), num_workers=2, show_preview=True):
    """Build a shuffled DataLoader over the images of one training scene.

    The scene directory ``{root}/{scene}`` is read with
    ``torchvision.datasets.ImageFolder``, so images must sit in sub-folders of it
    (this notebook reads scene images from ``images/``). Each sub-folder becomes
    an ImageFolder "class"; those labels carry no meaning here.

    Args:
        scene: scene directory name, e.g. ``"piazza_san_marco"``.
        root: directory that contains the scene folders.
        batch_size: images per batch.
        image_size: (height, width) that every image is resized to. Scene images
            differ in size, and the default collate can only stack equal-sized tensors.
        num_workers: number of DataLoader worker processes.
        show_preview: if True, draw one batch as an image grid.

    Returns:
        torch.utils.data.DataLoader yielding ``(images, labels)``. ``images`` has
        shape (batch, 3, height, width) with values normalised to [-1, 1].
    """
    from torchvision import datasets, utils as tv_utils

    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    # DataLoader needs a Dataset; handing it a path string would iterate the string's characters.
    trainset = datasets.ImageFolder(os.path.join(root, scene), transform=transform)
    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    if show_preview:
        images, _ = next(iter(trainloader))
        grid = tv_utils.make_grid(images) / 2 + 0.5  # undo Normalize(mean=0.5, std=0.5)
        plt.figure(figsize=(12, 4))
        plt.imshow(grid.clamp(0, 1).permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.show()

    return trainloader


scene_trainloader = get_scene_trainloader("piazza_san_marco", root=src)

In [None]:
# image manipulation (ref. SiT)
# paper:
# github: github.com/Sara-Ahmed/SiT

In [None]:
# Summary: building blocks that combine ResNeXt with Squeeze-and-Excitation Network (SENet) layers.
# ResNeXt ref: https://github.com/pytorch/vision/blob/main/torchvision/models/resnet.py
# SENet ref: https://github.com/moskomule/senet.pytorch 

def conv1x1(in_channels, out_channels, stride=1):
    """Return a bias-free 1x1 ``nn.Conv2d`` (pointwise channel mixing) with the given stride."""
    layer = nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False)
    return layer

def conv3x3(in_channels, out_channels, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 ``nn.Conv2d``.

    Padding equals ``dilation``, so with stride 1 the spatial size is preserved.
    """
    options = dict(stride=stride, padding=dilation, groups=groups, dilation=dilation, bias=False)
    return nn.Conv2d(in_channels, out_channels, kernel_size=3, **options)

class BottleNeck(nn.Module):
    """ResNeXt bottleneck block with a residual connection.

    Main branch: 1x1 reduce -> grouped 3x3 -> 1x1 project, each followed by
    BatchNorm (ReLU after the first two). The residual is added, then a final
    ReLU is applied.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        reduction: the bottleneck width is ``int(in_channels / reduction)``. That
            width must be divisible by ``num_groups`` (nn.Conv2d groups constraint).
        stride: stride of the grouped 3x3 convolution.
        downsample: optional module applied to the input to form the residual.
            It is needed whenever ``stride != 1`` or ``in_channels != out_channels``,
            so that the residual matches the main branch's shape.
        num_groups: number of groups (cardinality) of the 3x3 convolution.
        dilation: dilation (and padding) of the 3x3 convolution.
    """

    expansion : int = 4

    def __init__(
        self, in_channels, out_channels,
        reduction=2, stride=1, downsample=None, num_groups=64, dilation=1
    ):
        # if we want to use the expansion factor, we can just modify "out_channels"
        # out_channels = in_channels * self.expansion

        super().__init__()

        width = int(in_channels / reduction)

        # inplace is used for ReLU to reduce memory usage
        self.resnext_block = nn.Sequential(
            conv1x1(in_channels, width),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),

            conv3x3(width, width, stride=stride, groups=num_groups, dilation=dilation),
            nn.BatchNorm2d(width),
            nn.ReLU(inplace=True),

            conv1x1(width, out_channels),
            nn.BatchNorm2d(out_channels),
        )

        # Always define the attribute (possibly None) so forward() can test it.
        self.downsample = downsample

        self.activation = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.resnext_block(x)
        residual = x if self.downsample is None else self.downsample(x)
        out = out + residual
        return self.activation(out)

class SELayer(nn.Module):
    """Squeeze-and-Excitation block, as described in the SENet paper and github.

    Steps:
    1. Average-pool each channel to a single value.
    2. Pass the channel vector through a two-layer bottleneck MLP ending in a sigmoid.
    3. Multiply every input channel by its resulting gate.

    Args:
        channel: number of input (and output) channels.
        reduction: bottleneck ratio. The hidden layer has ``channel // reduction``
            units, clamped to at least 1 so small channel counts still work.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        batch, channel, _, _ = x.size()
        y = self.avg_pool(x).view(batch, channel)
        y = self.fc(y).view(batch, channel, 1, 1)
        return x * y.expand_as(x)

class ResNeXtLayer(nn.Module):
    """One ResNeXt stage: ``num_blocks`` BottleNeck blocks applied in sequence.

    The first block maps ``in_channels -> out_channels`` and applies ``stride``.
    A 1x1 conv + BatchNorm projection is used for its residual only when the
    shape changes (as in the torchvision ResNet reference). The remaining blocks
    keep ``out_channels`` and stride 1. Every block uses ``num_groups`` and ``dilation``.
    """

    def __init__(self, in_channels, out_channels, num_blocks, num_groups=64, dilation=1, stride=1):
        super().__init__()
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = nn.Sequential(
                conv1x1(in_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )

        blocks = [
            BottleNeck(
                in_channels, out_channels, stride=stride, downsample=downsample,
                num_groups=num_groups, dilation=dilation,
            )
        ]
        blocks += [
            BottleNeck(out_channels, out_channels, num_groups=num_groups, dilation=dilation)
            for _ in range(1, num_blocks)
        ]
        # nn.Sequential registers the blocks as submodules (a plain list does not) and is callable.
        self.layers = nn.Sequential(*blocks)

    def forward(self, x):
        return self.layers(x)

In [None]:
# CAPE: Continuous Augmented Positional Embeddings
# paper @ https://arxiv.org/abs/2106.03143

class CAPE(nn.Module):
    """2D Continuous Augmented Positional Embeddings (CAPE, arXiv 2106.03143).

    Expects patches shaped (batch, height, width, model_dim). The content is
    scaled by sqrt(model_dim), and a sinusoidal embedding of continuous (x, y)
    patch coordinates in [-1, 1] is added. In training mode the coordinates are
    randomly shifted (globally and locally) and scaled, within the ``max_*``
    limits. In eval mode they are fixed.

    NOTE(review): written from the paper's description. Verify the frequency
    and augmentation details against the reference implementation before relying on it.

    Args:
        model_dim: channel size of the patches. Must be even (half cos, half sin).
        max_global_shift: max absolute offset added to all coordinates of an image.
        max_local_shift: max per-position jitter, as a fraction of the grid spacing.
        max_global_scaling: coordinates are multiplied by a log-uniform factor in
            [1 / max_global_scaling, max_global_scaling]. Must be >= 1.
    """

    def __init__(
        self, model_dim,
        max_global_shift=0.0, max_local_shift=0.0, max_global_scaling=1.0,
    ):
        super().__init__()
        if model_dim % 2 != 0:
            raise ValueError(f'model_dim must be even, got {model_dim}')
        if max_global_scaling < 1.0:
            raise ValueError(f'max_global_scaling must be >= 1, got {max_global_scaling}')
        self.max_global_shift = max_global_shift
        self.max_local_shift = max_local_shift
        self.max_global_scaling = max_global_scaling

        # Frequencies: magnitudes log-spaced in [1, 10], directions rotating with the index.
        half = model_dim // 2
        rho = 10 ** torch.linspace(0, 1, half)
        angles = torch.arange(half, dtype=torch.float32)
        self.register_buffer('w_x', rho * torch.cos(angles))
        self.register_buffer('w_y', rho * torch.sin(angles))

        self.register_buffer('content_scale', torch.tensor([math.sqrt(model_dim)]))

    def forward(self, patches):
        return (patches * self.content_scale) + self.compute_pos_emb(patches)

    def compute_pos_emb(self, patches):
        """Positional embedding with the same shape as ``patches``."""
        batch, height, width, _ = patches.shape
        device, dtype = patches.device, patches.dtype

        xs = torch.linspace(-1, 1, width, device=device, dtype=dtype)
        ys = torch.linspace(-1, 1, height, device=device, dtype=dtype)
        x = xs.view(1, 1, width).expand(batch, height, width)
        y = ys.view(1, height, 1).expand(batch, height, width)

        x, y = self._augment_positions(x, y)

        phase = math.pi * (self.w_x * x[..., None] + self.w_y * y[..., None])
        return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

    def _augment_positions(self, x, y):
        """Randomly perturb the (batch, height, width) coordinate grids; no-op in eval mode."""
        if not self.training:
            return x, y
        batch, height, width = x.shape
        device, dtype = x.device, x.dtype

        def uniform(*shape):  # samples from U(-1, 1)
            return 2 * torch.rand(*shape, device=device, dtype=dtype) - 1

        # Global shift: one offset per image and axis.
        x = x + self.max_global_shift * uniform(batch, 1, 1)
        y = y + self.max_global_shift * uniform(batch, 1, 1)

        # Local shift: independent jitter per position, proportional to the grid spacing.
        if self.max_local_shift > 0:
            step_x = 2.0 / (width - 1) if width > 1 else 0.0
            step_y = 2.0 / (height - 1) if height > 1 else 0.0
            x = x + self.max_local_shift * step_x * uniform(batch, height, width)
            y = y + self.max_local_shift * step_y * uniform(batch, height, width)

        # Global scaling: a log-uniform factor per image, shared by both axes.
        scale = torch.exp(math.log(self.max_global_scaling) * uniform(batch, 1, 1))
        return x * scale, y * scale


In [None]:
# TransCNN + Universal Transformer 
# Idea : each layer repeated via ACT, then down sampled
# Idea : direct downsampling for residual

class CNNAttention(nn.Module):
    """Hierarchical multi-head self-attention on (batch, channels, height, width) maps.

    Adapted from the TransCNN H-MHSA block.
    - With ``grid_size == 1``: plain global multi-head attention over all H*W positions.
    - With ``grid_size > 1``: first attends inside non-overlapping
      ``grid_size x grid_size`` windows (with a residual + norm). Then runs
      global attention whose keys/values come from an average-pooled map
      (pooling factor ``downsample_rate``). The two results are summed.

    Args:
        total_dim: number of channels C. Must be divisible by ``head_dim``.
        head_dim: channels per attention head.
        grid_size: window side for local attention. H and W must be divisible by it.
        downsample_rate: average-pool kernel and stride used for the global
            keys/values when ``grid_size > 1``.
        drop: channel-dropout probability applied to the output.
    """

    def __init__(self, total_dim, head_dim, grid_size=1, downsample_rate=1, drop=0):
        super().__init__()
        if total_dim % head_dim != 0:
            raise ValueError(f'total_dim ({total_dim}) must be divisible by head_dim ({head_dim})')
        self.num_heads = total_dim // head_dim  # area of previous step / area of head
        self.head_dim = head_dim
        # 1/sqrt(d_k) scaling of the dot products (attribute name kept for compatibility).
        self.side_len = self.head_dim ** -0.5
        self.grid_size = grid_size

        self.norm = nn.BatchNorm2d(total_dim)
        # 1x1 convolutions act as per-position linear projections.
        self.qkv = nn.Conv2d(total_dim, total_dim * 3, kernel_size=1, bias=False)
        self.proj = nn.Conv2d(total_dim, total_dim, kernel_size=1)
        self.drop = nn.Dropout2d(drop, inplace=True)

        if self.grid_size > 1:
            self.q = nn.Conv2d(total_dim, total_dim, kernel_size=1, bias=False)
            self.kv = nn.Conv2d(total_dim, total_dim * 2, kernel_size=1, bias=False)

            self.grid_norm = nn.BatchNorm2d(total_dim)
            self.avg_pool = nn.AvgPool2d(downsample_rate, stride=downsample_rate)
            self.downsample_norm = nn.BatchNorm2d(total_dim)

    def forward(self, x):
        batch, channels, height, width = x.shape
        qkv = self.qkv(self.norm(x))

        if self.grid_size > 1:
            g = self.grid_size
            if height % g or width % g:
                raise ValueError(f'height and width ({height}, {width}) must be divisible by grid_size {g}')
            # compute grid based/local attention
            grid_h, grid_w = height // g, width // g  # H/G, W/G
            qkv = qkv.reshape(batch, 3, self.num_heads, self.head_dim, grid_h, g, grid_w, g)
            # -> QKV * Batch * Head * H/G * W/G * G * G * Size
            qkv = qkv.permute(1, 0, 2, 4, 6, 5, 7, 3)
            # -> QKV * (Batch * Head * H/G * W/G) * (G * G) * Size
            qkv = qkv.reshape(3, -1, g * g, self.head_dim)
            q, k, v = qkv[0], qkv[1], qkv[2]

            attn = ((q @ k.transpose(-2, -1)) * self.side_len).softmax(dim=-1)
            grid_x = (attn @ v).reshape(batch, self.num_heads, grid_h, grid_w, g, g, self.head_dim)
            # back to Batch * C * H * W so the residual can be added to x
            grid_x = grid_x.permute(0, 1, 6, 2, 4, 3, 5).reshape(batch, channels, height, width)
            grid_x = self.grid_norm(x + grid_x)  # residual and normalisation

            # global attention: queries at full resolution, keys/values from the pooled map
            q = self.q(grid_x).reshape(batch, self.num_heads, self.head_dim, -1)
            q = q.transpose(-2, -1)  # Batch * Head * (H * W) * Size
            kv = self.kv(self.downsample_norm(self.avg_pool(grid_x)))
            kv = kv.reshape(batch, 2, self.num_heads, self.head_dim, -1)
            kv = kv.permute(1, 0, 2, 4, 3)  # KV * Batch * Head * (H' * W') * Size
            k, v = kv[0], kv[1]
        else:
            # transform qkv for computing global attention
            qkv = qkv.reshape(batch, 3, self.num_heads, self.head_dim, -1)
            qkv = qkv.permute(1, 0, 2, 4, 3)  # QKV * Batch * Head * (H * W) * Size
            q, k, v = qkv[0], qkv[1], qkv[2]

        # compute global attention (softmax over the key axis)
        attn = ((q @ k.transpose(-2, -1)) * self.side_len).softmax(dim=-1)
        global_x = (attn @ v).transpose(-2, -1).reshape(batch, channels, height, width)

        # residue
        if self.grid_size > 1:
            global_x = global_x + grid_x
        return self.drop(self.proj(global_x))

class ACT(nn.Module):
    """Adaptive Computation Time halting (Universal-Transformer style).

    Repeatedly applies ``fn`` to the state. Each step:
    - a learned per-position halting probability (``activation(fc(state))``) is
      accumulated;
    - positions stop contributing once it exceeds ``threshold`` or after
      ``max_steps`` updates.

    The result is the halting-weighted mix of the states produced at each step.

    Args:
        size: feature size (last dimension) of the state; input width of ``fc``.
        activation: module class producing the halting probability (default nn.Sigmoid).
    """

    threshold = 1 - 0.1

    def __init__(self, size, activation=nn.Sigmoid):
        super().__init__()
        self.activation = activation()
        self.fc = nn.Linear(size, 1)  # What if we replace linear with conv2d?
        # TODO(review): a positive initial bias (large early halting probabilities)
        # follows the UT ACT code this is adapted from; confirm it is wanted here.
        nn.init.constant_(self.fc.bias, 1.0)

    def forward(self, state, inputs, fn, max_steps):
        """
        Args:
            state: initial state, shape (batch, seq_len, size). The caller flattens 4d maps to 3d.
            inputs: tensor shaped like ``state``; only its shape, dtype and device are used.
            fn: callable mapping a state to the next state (same shape).
            max_steps: maximum number of updates per position.

        Returns:
            ``(new_state, (remainders, n_updates))``:
            - ``new_state`` has the shape of ``inputs``;
            - ``remainders`` and ``n_updates`` have shape (batch, seq_len).
        """
        batch, seq_len, _ = inputs.shape
        factory = dict(device=inputs.device, dtype=inputs.dtype)

        halting_probability = torch.zeros(batch, seq_len, **factory)
        remainders = torch.zeros(batch, seq_len, **factory)
        n_updates = torch.zeros(batch, seq_len, **factory)
        previous_state = torch.zeros_like(inputs)

        # Element-wise `&` (not Python `and`) so this works on multi-element tensors.
        while ((halting_probability < self.threshold) & (n_updates < max_steps)).any():
            # we are avoiding timing signals because we have our own RPE
            p = self.activation(self.fc(state)).squeeze(-1)  # per-position halting probability

            # calculate masks for which ones to halt
            still_running = (halting_probability < 1.0).to(p.dtype)
            new_halted = (halting_probability + p * still_running > self.threshold).to(p.dtype) * still_running
            still_running = (halting_probability + p * still_running <= self.threshold).to(p.dtype) * still_running

            # accumulate halting probability; positions halting now take the remainder
            halting_probability = halting_probability + p * still_running
            remainders = remainders + new_halted * (1 - halting_probability)
            halting_probability = halting_probability + new_halted * remainders
            n_updates = n_updates + still_running + new_halted

            # weight of this step's state: p while running, remainder when halting, 0 after
            update_weights = p * still_running + new_halted * remainders

            state = fn(state)
            previous_state = state * update_weights.unsqueeze(-1) + previous_state * (1 - update_weights.unsqueeze(-1))

        return previous_state, (remainders, n_updates)



In [None]:
# for decoder architecture we are going to reuse existing models 
# MobileStyleGAN 
# Q: Can we just ctrl c+p this?
# A: Yes! Yes we can!


from pytorch_wavelets.dwt.lowlevel import *

def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode):
    """Plain-function 2D synthesis filter bank (one inverse-DWT level).

    ``highs`` holds the three high-pass bands stacked on dim 2 (unbound as
    lh, hl, hh). Dim 2 is synthesised first with the ``*_col`` filters, then
    dim 3 with the ``*_row`` filters.

    NOTE(review): DWTInverse.forward passes its column filters in the g0_row/g1_row
    positions (the same order it uses for SFB2D.apply). That is harmless when row and
    column filters coincide, e.g. for a wavelet given by name such as 'db1'.
    """
    int_mode = int_to_mode(mode)
    lh, hl, hh = torch.unbind(highs, dim=2)
    lo = sfb1d(low, lh, g0_col, g1_col, mode=int_mode, dim=2)
    hi = sfb1d(hl, hh, g0_col, g1_col, mode=int_mode, dim=2)
    return sfb1d(lo, hi, g0_row, g1_row, mode=int_mode, dim=3)

class DWTInverse(nn.Module):
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str, pywt.Wavelet or sequence): which wavelet to use.
            - A str is passed to ``pywt.Wavelet``.
            - A 2-sequence gives the (low, high) reconstruction filters, used for
              both columns and rows.
            - A 4-sequence gives (col_low, col_high, row_low, row_high).
        mode (str): padding mode, converted with ``mode_to_int``.
        trace_model (bool): if True, use the plain-function ``_SFB2D`` instead of
            ``SFB2D.apply`` (presumably for model tracing / export).

    Raises:
        ValueError: if ``wave`` is a sequence whose length is not 2 or 4.
    """
    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            raise ValueError(
                f'wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 filters; got length {len(wave)}'
            )
        # Prepare the filters
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """Reconstruct from ``coeffs = (yl, yh)``.

        ``yl`` is the low-pass tensor. ``yh`` is a list of high-pass tensors,
        walked from last to first, each shaped (N, C, 3, H, W). A ``None`` entry
        is treated as all-zero high-pass bands.
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)

        for h in yh[::-1]:
            if h is None:
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                ll.shape[-1], device=ll.device, dtype=ll.dtype)

            # Trim ll by one row / column when it is larger than the high-pass bands.
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            if not self.trace_model:
                ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
            else:
                ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
        return ll

class IDWTUpsample(nn.Module):
    """Style-modulated 2x upsampling through a one-level inverse Haar ('db1') DWT.

    The style-modulated input channels are split into four equal groups:
    - the first group is the low-pass band;
    - the other three are reshaped into the three high-pass bands.

    The output has ``channels_in // 4`` channels and twice the spatial size.

    Args:
        channels_in: input channels. Must be divisible by 4.
        style_dim: size of the style vector.
    """

    def __init__(self, channels_in, style_dim):
        super().__init__()
        self.channels = channels_in // 4
        assert self.channels * 4 == channels_in, 'channels_in must be divisible by 4'
        # upsample
        self.idwt = DWTInverse(mode='zero', wave='db1')
        # modulation: per-channel gain from the style, bias initialised to 1
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        self.modulation.bias.data.fill_(1.0)

    def forward(self, x, style):
        b, _, h, w = x.size()
        x = self.modulation(style).view(b, -1, 1, 1) * x
        low = x[:, :self.channels]
        high = x[:, self.channels:].view(b, self.channels, 3, h, w)
        return self.idwt((low, [high]))


# Backward-compatible alias for the original (misspelled) class name.
IDWTUpsaplme = IDWTUpsample

class ModulatedConv2d(nn.Module):
    """Style-modulated 2D convolution (MobileStyleGAN).

    The input is multiplied channel-wise by ``scale * modulation(style)``, then
    convolved with ``weight`` (padding ``kernel_size // 2``). With
    ``demodulate=True`` each output channel is rescaled by
    ``rsqrt(sum((scale * style_inv * weight)^2) + 1e-8)``. That factor depends
    on the fixed random ``style_inv`` buffer, not on the ``style`` argument.

    Args:
        channels_in: input channels.
        channels_out: output channels.
        style_dim: size of the style vector.
        kernel_size: side of the square kernel.
        demodulate: whether to apply the demodulation factor.
    """

    def __init__(self, channels_in, channels_out, style_dim, kernel_size, demodulate=True):
        super().__init__()
        weight_shape = (channels_out, channels_in, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(*weight_shape))

        # affine map from the style to one gain per input channel; bias starts at 1
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        with torch.no_grad():
            self.modulation.bias.fill_(1.0)

        self.demodulate = demodulate
        if demodulate:
            self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))

        # weights are scaled at run time by 1 / sqrt(fan_in)
        fan_in = channels_in * kernel_size ** 2
        self.scale = 1.0 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

    def forward(self, x, style):
        modulated = self.get_modulation(style) * x
        out = F.conv2d(modulated, self.weight, padding=self.padding)
        if not self.demodulate:
            return out
        return self.get_demodulation(style) * out

    def get_modulation(self, style):
        """Gains shaped (batch, channels_in, 1, 1) for a (batch, style_dim) style."""
        gains = self.modulation(style).view(style.size(0), -1, 1, 1)
        return self.scale * gains

    def get_demodulation(self, style):
        """Per-output-channel factors shaped (1, channels_out, 1, 1); ``style`` is unused."""
        scaled = self.scale * self.style_inv * self.weight.unsqueeze(0)
        factors = torch.rsqrt(scaled.pow(2).sum([2, 3, 4]) + 1e-8)
        return factors.view(*factors.size(), 1, 1)


class ModulatedDWConv2d(nn.Module):
    """Depthwise-separable, style-modulated convolution (MobileStyleGAN).

    The input is multiplied channel-wise by ``scale * modulation(style)``, then
    passed through a depthwise ``kernel_size`` conv (``weight_dw``) followed by
    a pointwise 1x1 conv (``weight_permute``). With ``demodulate=True`` the
    output channels are rescaled by a factor built from the combined weights
    and the fixed ``style_inv`` buffer; ``style`` does not enter that factor.

    Args:
        channels_in: input channels.
        channels_out: output channels.
        style_dim: size of the style vector.
        kernel_size: side of the square depthwise kernel.
        demodulate: whether to apply the demodulation factor.
    """

    def __init__(self, channels_in, channels_out, style_dim, kernel_size, demodulate=True):
        super().__init__()
        self.weight_dw = nn.Parameter(torch.randn(channels_in, 1, kernel_size, kernel_size))
        self.weight_permute = nn.Parameter(torch.randn(channels_out, channels_in, 1, 1))

        # affine map from the style to one gain per input channel; bias starts at 1
        self.modulation = nn.Linear(style_dim, channels_in, bias=True)
        with torch.no_grad():
            self.modulation.bias.fill_(1.0)

        self.demodulate = demodulate
        if demodulate:
            self.register_buffer("style_inv", torch.randn(1, 1, channels_in, 1, 1))

        fan_in = channels_in * kernel_size ** 2
        self.scale = 1.0 / math.sqrt(fan_in)
        self.padding = kernel_size // 2

    def forward(self, x, style):
        modulated = self.get_modulation(style) * x
        depthwise = F.conv2d(modulated, self.weight_dw, padding=self.padding, groups=modulated.size(1))
        out = F.conv2d(depthwise, self.weight_permute)
        if not self.demodulate:
            return out
        return self.get_demodulation(style) * out

    def get_modulation(self, style):
        """Gains shaped (batch, channels_in, 1, 1) for a (batch, style_dim) style."""
        gains = self.modulation(style).view(style.size(0), -1, 1, 1)
        return self.scale * gains

    def get_demodulation(self, style):
        """Per-output-channel factors shaped (1, channels_out, 1, 1); ``style`` is unused."""
        combined = (self.weight_dw.transpose(0, 1) * self.weight_permute).unsqueeze(0)
        factors = torch.rsqrt((self.scale * self.style_inv * combined).pow(2).sum([2, 3, 4]) + 1e-8)
        return factors.view(*factors.size(), 1, 1)

class StyledConv2d(nn.Module):
    """Modulated convolution -> noise injection -> bias -> LeakyReLU(0.2).

    Args:
        channels_in: input channels.
        channels_out: output channels.
        style_dim: size of the style vector.
        kernel_size: convolution kernel side.
        demodulate: forwarded to the convolution module.
        conv_module: class used for the convolution, called as
            ``conv_module(channels_in, channels_out, style_dim, kernel_size, demodulate=...)``.
            Defaults to ``ModulatedConv2d`` when None (callers such as
            MobileSynthesisNetwork omit it).
    """

    def __init__(
        self,
        channels_in,
        channels_out,
        style_dim,
        kernel_size,
        demodulate=True,
        conv_module=None
    ):
        super().__init__()
        if conv_module is None:
            # Looked up at construction time, so class definition order does not matter.
            conv_module = ModulatedConv2d

        self.conv = conv_module(
            channels_in,
            channels_out,
            style_dim,
            kernel_size,
            demodulate=demodulate
        )

        # NOTE(review): NoiseInjection (from MobileStyleGAN) is not defined in the
        # cells shown here — confirm it is defined/imported before this runs.
        self.noise = NoiseInjection()
        self.bias = nn.Parameter(torch.zeros(1, channels_out, 1, 1))
        self.act = nn.LeakyReLU(0.2)

    def forward(self, input, style, noise=None):
        out = self.conv(input, style)
        out = self.noise(out, noise=noise)
        out = self.act(out + self.bias)
        return out

class MultichannelImage(nn.Module):
    """Style-modulated projection of hidden features to ``channels_out`` output maps.

    Uses a non-demodulated ModulatedConv2d followed by a learned per-channel
    bias (initialised to zero). The synthesis blocks use it to emit 12-channel
    wavelet outputs.

    Args:
        channels_in: hidden feature channels.
        channels_out: number of output maps.
        style_dim: size of the style vector.
        kernel_size: convolution kernel side (default 1).
    """

    def __init__(self, channels_in, channels_out, style_dim, kernel_size=1):
        super().__init__()
        self.conv = ModulatedConv2d(
            channels_in, channels_out, style_dim, kernel_size, demodulate=False
        )
        self.bias = nn.Parameter(torch.zeros(1, channels_out, 1, 1))

    def forward(self, hidden, style):
        return self.conv(hidden, style) + self.bias

class MobileSynthesisBlock(nn.Module):
    """One MobileStyleGAN synthesis stage: IDWT upsample, two styled convs, to-image head.

    Args:
        channels_in: input channels. Must be divisible by 4; the upsampler keeps a quarter of them.
        channels_out: channels of the two styled convolutions.
        style_dim: size of a style vector.
        kernel_size: kernel side of the styled convolutions.
        conv_module: convolution class for the styled convs (default ModulatedConv2d when None).
    """

    def __init__(
            self,
            channels_in,
            channels_out,
            style_dim,
            kernel_size=3,
            conv_module=None
    ):
        super().__init__()
        if conv_module is None:
            conv_module = ModulatedConv2d
        self.up = IDWTUpsample(channels_in, style_dim)
        self.conv1 = StyledConv2d(
            channels_in // 4,
            channels_out,
            style_dim,
            kernel_size,
            conv_module=conv_module
        )
        self.conv2 = StyledConv2d(
            channels_out,
            channels_out,
            style_dim,
            kernel_size,
            conv_module=conv_module
        )
        self.to_img = MultichannelImage(
            channels_in=channels_out,
            channels_out=12,
            style_dim=style_dim,
            kernel_size=1
        )

    def forward(self, hidden, style, noise=None):
        """Return ``(hidden, img)``.

        ``style`` is either one vector per sample (ndim == 2, shared by every layer)
        or a stack of ``wsize()`` vectors (ndim == 3), indexed 0, 0, 1, 2 by
        up / conv1 / conv2 / to_img. ``noise`` holds the noise for conv1 and conv2
        (None entries by default).
        """
        if noise is None:
            noise = (None, None)

        def pick(i):
            return style if style.ndim == 2 else style[:, i, :]

        hidden = self.up(hidden, pick(0))
        hidden = self.conv1(hidden, pick(0), noise=noise[0])
        hidden = self.conv2(hidden, pick(1), noise=noise[1])
        img = self.to_img(hidden, pick(2))
        return hidden, img

    def wsize(self):
        """Number of per-layer style vectors this block reads."""
        return 3

class MobileSynthesisNetwork(nn.Module):
    """MobileStyleGAN synthesis network.

    Pipeline:
    1. constant input;
    2. styled conv plus a 12-channel to-image head;
    3. a stack of MobileSynthesisBlocks, each upsampling via an inverse DWT;
    4. the final 12-channel frequency map is converted to an image with an inverse DWT.

    Args:
        style_dim: size of a style vector.
        channels: channel count of the input stage followed by one entry per block.
    """

    def __init__(
            self,
            style_dim,
            channels=(512, 512, 512, 512, 512, 256, 128, 64)
    ):
        super().__init__()
        self.style_dim = style_dim

        # NOTE(review): ConstantInput and NoiseManager (MobileStyleGAN helpers) are
        # not defined in the cells shown here — confirm they are defined/imported.
        self.input = ConstantInput(channels[0])
        self.conv1 = StyledConv2d(
            channels[0],
            channels[0],
            style_dim,
            kernel_size=3,
            conv_module=ModulatedConv2d
        )
        self.to_img1 = MultichannelImage(
            channels_in=channels[0],
            channels_out=12,
            style_dim=style_dim,
            kernel_size=1
        )

        self.layers = nn.ModuleList()
        channels_in = channels[0]
        for channels_out in channels[1:]:
            self.layers.append(
                MobileSynthesisBlock(
                    channels_in,
                    channels_out,
                    style_dim,
                    3,
                    conv_module=ModulatedDWConv2d
                )
            )
            channels_in = channels_out

        self.idwt = DWTInverse(mode="zero", wave="db1")
        # zero-size buffer whose .device tracks where the module lives
        self.register_buffer("device_info", torch.zeros(1))
        self.trace_model = False

    def forward(self, style, noise=None):
        """Synthesise from ``style``.

        ``style`` is either (batch, style_dim), shared by all layers, or
        (batch, wsize(), style_dim) with per-layer vectors.

        Returns a dict with:
        - 'noise': the noise used per stage;
        - 'freq': every stage's 12-channel output;
        - 'img': the final image.
        """
        out = {"noise": [], "freq": [], "img": None}
        noise = NoiseManager(noise, self.device_info.device, self.trace_model)

        hidden = self.input(style)
        out["noise"].append(noise(hidden.size(-1)))
        hidden = self.conv1(hidden, style if style.ndim == 2 else style[:, 0, :], noise=out["noise"][-1])
        img = self.to_img1(hidden, style if style.ndim == 2 else style[:, 1, :])
        out["freq"].append(img)

        # NOTE(review): block i reads style indices [3i+1, 3i+3], so index 1 is shared
        # with to_img1 and index wsize()-1 is never read — confirm against MobileStyleGAN.
        for i, m in enumerate(self.layers):
            out["noise"].append(noise(2 ** (i + 3), 2))
            _style = style if style.ndim == 2 else style[:, m.wsize()*i + 1: m.wsize()*i + m.wsize() + 1, :]
            hidden, freq = m(hidden, _style, noise=out["noise"][-1])
            out["freq"].append(freq)

        out["img"] = self.dwt_to_img(out["freq"][-1])
        return out

    def dwt_to_img(self, img):
        """Invert a 12-channel map: channels 0-2 are low-pass, 3-11 are 3 high-pass bands x 3."""
        b, c, h, w = img.size()
        low = img[:, :3, :, :]
        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
        return self.idwt((low, [high]))

    def wsize(self):
        """Number of per-layer style vectors the network expects."""
        return len(self.layers) * self.layers[0].wsize() + 2