In [1]:
import torch
import torch.nn as nn
import autoencoder
from utils.datasets import ImageDataset
import pandas as pd

In [2]:
# Autoencoder hyperparameters shared by the encoder, decoder and wrapper.
# Defining them once keeps the three constructors from drifting apart
# (they are expected to agree, e.g. on z_channels).
Z_CHANNELS = 4
EMB_CHANNELS = 4
IMAGE_CHANNELS = 1
BASE_CHANNELS = 128
CHANNEL_MULTIPLIERS = [1, 2, 4, 4]
N_RESNET_BLOCKS = 2

encoder = autoencoder.Encoder(z_channels=Z_CHANNELS,
                              in_channels=IMAGE_CHANNELS,
                              channels=BASE_CHANNELS,
                              # pass a copy so each module gets its own list, as before
                              channel_multipliers=list(CHANNEL_MULTIPLIERS),
                              n_resnet_blocks=N_RESNET_BLOCKS)

decoder = autoencoder.Decoder(out_channels=IMAGE_CHANNELS,
                              z_channels=Z_CHANNELS,
                              channels=BASE_CHANNELS,
                              channel_multipliers=list(CHANNEL_MULTIPLIERS),
                              n_resnet_blocks=N_RESNET_BLOCKS)

ae = autoencoder.Autoencoder(emb_channels=EMB_CHANNELS,
                             encoder=encoder,
                             decoder=decoder,
                             z_channels=Z_CHANNELS)

In [3]:
# Seed the RNG so the random test tensors (and anything drawn afterwards) are
# reproducible under Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
inp = torch.randn(1, 1, 48, 56, 48)  # dummy input: batch 1, 1 channel, 48x56x48
z = torch.randn(1, 4, 8, 8, 8)       # dummy latent: batch 1, 4 channels, 8x8x8

In [4]:
# Shape check: decode the random latent `z`. No gradients are needed, so run
# under no_grad to avoid building an autograd graph for the forward pass.
# NOTE(review): `z` is 8x8x8 here, but `ae.encode` on a 48x56x48 input yields a
# 6x6x6 latent (see the encode cell) — TODO confirm the intended latent size.
with torch.no_grad():
    decoded = decoder(z)
decoded.shape

torch.Size([1, 1, 48, 56, 48])

In [3]:
# Training hyperparameters consumed by the trainer.
config = dict(
    batch_size=2,
    epochs=10,
    lr=1e-4,
)

In [4]:
# Wrap the autoencoder and training config in a trainer.
# NOTE(review): `VAETrainer` is not imported or defined anywhere in this notebook,
# so this cell raises NameError on a fresh kernel; it presumably only ran thanks to
# leftover kernel state. Add the import from its defining module — TODO confirm which.
trainer = VAETrainer(ae, config)

In [5]:
data_dir = './data'
mode = 'train'
# Keep the dataset *name* separate from the ImageDataset object built below,
# so the summary print reports the name rather than the object's repr.
dataset_name = 'dataset_rh_4classes'
# Row limit for a quick smoke test; set to None to use every row in the CSV.
n_samples = 10

# Build the image dataset from the file list stored in the CSV.
dataset_file = f'{data_dir}/{mode}-{dataset_name}.csv'
data_flist = pd.read_csv(dataset_file)['filepaths'].iloc[:n_samples]

dataset = ImageDataset(
    data_flist
)

print(f'Dataset {dataset_name}: \n {len(dataset.data)} images.')

Dataset <utils.datasets.ImageDataset object at 0x14f552290>: 
 10 images.


In [6]:
# Train the autoencoder on the ImageDataset built above (`train` is defined outside this notebook).
# NOTE(review): the saved output shows this run was stopped by KeyboardInterrupt after
# epoch 1, so the logged loss does not reflect a completed training run.
trainer.train(dataset)

---- Start training ----
	Epoch 1 	Average Loss:  157.76661109924316


KeyboardInterrupt: 

In [5]:
inp = torch.randn(1, 1, 48, 56, 48)
# Shape check: encode a dummy input and draw one latent sample. No gradients
# are needed, so skip building the autograd graph.
# NOTE(review): the saved output is a 6x6x6 latent, whereas the decoder and UNet
# shape checks use 8x8x8 latents — TODO confirm which latent size is intended.
with torch.no_grad():
    latent = ae.encode(inp).sample()
latent.shape

torch.Size([1, 4, 6, 6, 6])

In [5]:
import unet 

In [6]:
# Latent-space UNet: 4 input/output channels (matching the 4-channel latent `x`
# used below) and a conditioning width of 4096 (matching the last dim of `cond`).
model = unet.UNetModel(
    in_channels=4,
    out_channels=4,
    channels=320,
    attention_levels=[0, 1, 2],
    n_res_blocks=2,
    channel_multipliers=[1, 2, 4, 4],
    n_heads=8,
    tf_layers=1,
    d_cond=4096,
)

In [7]:
# Dummy UNet inputs (torch is already imported in the first cell).
# `x` has 4 channels to match the model's in_channels; the last dim of `cond`
# matches d_cond=4096 (presumably (batch, seq, d_cond) — TODO confirm).
x = torch.randn(1, 4, 8, 8, 8)
cond = torch.randn(1, 1, 4096)

In [9]:
# Shape check: one UNet forward pass at a single timestep value (500).
# No gradients are needed, so run under no_grad to save activation memory.
# NOTE(review): the saved output contains many intermediate shape prints,
# presumably leftover debug print() calls inside the unet module — consider removing them.
with torch.no_grad():
    unet_out = model(x, torch.tensor([500]), cond)
unet_out.shape

torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 1, 1, 1])
torch.Size([1, 1280, 2, 2, 2])
torch.Size([1, 1280, 2, 2, 2])
torch.Size([1, 1280, 2, 2, 2])
torch.Size([1, 1280, 2, 2, 2])
torch.Size([1, 1280, 2, 2, 2])
torch.Size([1, 640, 2, 2, 2])
torch.Size([1, 1280, 4, 4, 4])
torch.Size([1, 640, 4, 4, 4])
torch.Size([1, 640, 4, 4, 4])
torch.Size([1, 640, 4, 4, 4])
torch.Size([1, 640, 4, 4, 4])
torch.Size([1, 320, 4, 4, 4])
torch.Size([1, 640, 8, 8, 8])
torch.Size([1, 320, 8, 8, 8])
torch.Size([1, 320, 8, 8, 8])
torch.Size([1, 320, 8, 8, 8])
torch.Size([1, 320, 8, 8, 8])
torch.Size([1, 320, 8, 8, 8])


torch.Size([1, 4, 8, 8, 8])