In [1]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from codis.data import InfiniteDSprites, Latents
from codis.visualization import draw_batch_grid

from codis.models.beta_vae_v2 import BetaVAEV2
from typing import List, Optional
from codis.models.mlp import MLP

pygame 2.3.0 (SDL 2.24.2, Python 3.9.16)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
class ClassIncrementalInfiniteDSprites(InfiniteDSprites):
    """Infinite dataset of procedurally generated shapes undergoing transformations."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__seen_shapes = []
        self.change_class()

    @property
    def num_seen_classes(self):
        return len(self.__seen_shapes)
    
    def change_class(self):
        self.__current_shape = self.generate_shape()
        self.__seen_shapes.append(self.__current_shape)
    
    def sample_latents_for_shape(self, shape):
        """Sample a random set of latents."""
        return Latents(
            color=np.random.choice(self.ranges["color"]),
            shape=shape,
            scale=np.random.choice(self.ranges["scale"]),
            orientation=np.random.choice(self.ranges["orientation"]),
            position_x=np.random.choice(self.ranges["position_x"]),
            position_y=np.random.choice(self.ranges["position_y"]),
        )

    def __iter__(self):
        """Generate an infinite stream of images and latent vectors.
        Args:
            None
        Returns:
            An infinite stream of (image, latents) tuples."""
        while True:
            latents = self.sample_latents_for_shape(self.__current_shape)
            image   = self.draw(latents)
            yield image, latents


class CBetaVAE(BetaVAEV2):
    """The β-VAE model class."""

    def __init__(self, in_width: int = 64, in_channels: int = 1, latent_dim: int = 64, \
        num_channels: Optional[List] = None, beta: float = 1.0, H:int=128, L:int=2, act:str='elu') -> None:
        super().__init__(in_width=in_width, in_channels=in_channels, latent_dim=latent_dim, num_channels=num_channels, beta=beta)
        self.regressor = MLP(latent_dim, 4, L=L, H=H, act=act)
    
    def forward(self, X, Z):
        ''' Inputs:
                X - [N,C,W,H]
                Z - Latents
        '''
        Xhat, mu, log_std = super().forward(X) # [N,C,W,H], [N,q], [N,q]
        Z = torch.stack([Z.scale, Z.orientation, Z.position_x, Z.position_y],-1) # N,4
        latent_reg_loss = (Z-self.regressor(mu)).pow(2).sum(-1).mean(0)
        reconstruction_loss = self.reconstruction_loss(X,Xhat)
        kl_divergence = self.kl_loss(mu,log_std)
        print(latent_reg_loss)
        print(reconstruction_loss)
        print(kl_divergence)
    
def eval_learning_so_far(model):
    raise NotImplementedError()

def train_loop(model, image_size=64, Nround=10, Niter=1000, batch_size=64):
    dataset = ClassIncrementalInfiniteDSprites(image_size=image_size)
    data_loader = DataLoader(dataset, batch_size=batch_size)
    opt = torch.optim.Adam(model.parameters(), lr=1e-3)
    for i in range(Nround):
        for i,(X,Z) in enumerate(data_loader):
            opt.zero_grad()
            loss = model(X, Z)
            loss.backward()
            opt.step()
        eval_learning_so_far(model)
        dataset.change_class()

LATENT_DIM = 20
N_CHANNELS = [32, 64, 128, 256]
IMAGE_SIZE = 64
BATCH_SIZE = 2
BETA       = 5

model = CBetaVAE(latent_dim=LATENT_DIM, num_channels=N_CHANNELS, beta=BETA)
train_loop(model, image_size=IMAGE_SIZE, batch_size=BATCH_SIZE)


tensor(20.9886, dtype=torch.float64, grad_fn=<MeanBackward1>)
tensor(711.3248, grad_fn=<DivBackward0>)
tensor(0.2101, grad_fn=<MeanBackward0>)


AttributeError: 'NoneType' object has no attribute 'backward'