# In [1]:
import os
import numpy as np
from attrdict import AttrDict as attrdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import datasets
import datasets.utils as du
import net.models as models
from utils.clustering import decomposition, metrics, functional
from utils.plotlib import plot

# The quoted '__file__' is a literal file name, not the module attribute, so it
# is resolved against the current working directory. ROOT therefore ends up
# being the parent of the working directory.
# NOTE(review): presumably written this way because notebook kernels do not
# define __file__ -- confirm.
_working_dir = os.path.dirname(os.path.abspath('__file__'))
ROOT = os.path.dirname(_working_dir)

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Build a bias-free 3x3 convolution.

    The padding is set equal to ``dilation``, which keeps the spatial size
    unchanged when ``stride`` is 1.
    """
    conv_kwargs = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        dilation=dilation,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)

def conv1x1(in_planes, out_planes, stride=1):
    """Build a bias-free 1x1 (pointwise) convolution with the given stride."""
    pointwise = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )
    return pointwise

def downsample(in_planes, out_planes, stride=2):
    """Build the projection shortcut used by a residual block.

    Args:
        in_planes: Channels of the tensor entering the block.
        out_planes: Channels produced by the block's main branch.
        stride: Stride of the 1x1 projection; it has to match the stride of
            the block's main branch so both branches have the same size.

    Returns:
        ``nn.Sequential`` of a strided 1x1 convolution followed by batch norm.
    """
    return nn.Sequential(
        # Forward the caller's stride; it used to be hard-coded to 2, which
        # silently ignored the argument.
        conv1x1(in_planes, out_planes, stride=stride),
        nn.BatchNorm2d(out_planes),
    )

class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut connection.

    ``features`` holds conv-BN-ReLU-conv-BN; the shortcut is the input itself
    unless a ``downsample`` module is supplied, in which case the input is
    passed through it before being added.
    """

    # Output-channel multiplier read by ResNet when sizing its linear head.
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1, downsample=None):
        super().__init__()

        # Only the first convolution carries the stride.
        main_branch = [
            conv3x3(in_planes, out_planes, stride),
            nn.BatchNorm2d(out_planes),
            nn.ReLU(inplace=True),
            conv3x3(out_planes, out_planes),
            nn.BatchNorm2d(out_planes),
        ]
        self.features = nn.Sequential(*main_branch)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(features(x) + shortcut(x))``."""
        out = self.features(x)
        shortcut = x if self.downsample is None else self.downsample(x)
        out += shortcut
        return F.relu(out, inplace=True)

class ResNet(nn.Module):
    """ResNet-style classifier built from a stem and four residual stages.

    Args:
        in_planes: Number of channels of the input image.
        num_classes: Output size of the final linear layer.
        block: Residual block class. It is instantiated as
            ``block(in_planes, out_planes, stride=..., downsample=...)`` and
            must expose an integer ``expansion`` class attribute.
        num_blocks: Sequence of four positive ints giving the number of blocks
            in ``conv2_x`` .. ``conv5_x``.
    """

    def __init__(self, in_planes, num_classes, block, num_blocks):
        super().__init__()

        assert len(num_blocks) == 4, 'num_blocks needs one entry per stage (4)'

        # Stem: 7x7 stride-2 convolution.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_planes, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )

        # The stride-2 max-pool lives inside conv2_x, so this stage needs no
        # strided block.
        self.conv2_x = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            self._make_layer(64, 64, block, num_blocks[0])
        )

        self.conv3_x = nn.Sequential(
            self._make_layer(64, 128, block, num_blocks[1], stride=2)
        )

        self.conv4_x = nn.Sequential(
            self._make_layer(128, 256, block, num_blocks[2], stride=2)
        )

        self.conv5_x = nn.Sequential(
            self._make_layer(256, 512, block, num_blocks[3], stride=2)
        )

        # Global average pooling followed by flattening to (N, C).
        self.avgpool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten()
        )

        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, in_planes, out_planes, block, num_block, stride=1):
        """Stack ``num_block`` residual blocks into one stage.

        Only the first block changes the channel count and applies ``stride``;
        the remaining blocks map ``out_planes`` to ``out_planes`` at stride 1.
        """
        assert num_block > 0, 'each stage needs at least one block'

        downsample = None
        # The identity shortcut is only valid when the first block keeps both
        # the spatial size and the channel count. Project with a 1x1
        # convolution otherwise (previously only the stride was checked, so a
        # stride-1 stage that changed channels failed at the residual add).
        if stride != 1 or in_planes != out_planes:
            downsample = nn.Sequential(
                conv1x1(in_planes, out_planes, stride=stride),
                nn.BatchNorm2d(out_planes)
            )

        layers = [block(in_planes, out_planes, stride=stride, downsample=downsample)]
        layers.extend(block(out_planes, out_planes) for _ in range(1, num_block))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the four stages, pooling and the linear head."""
        x = self.conv1(x)
        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)
        x = self.avgpool(x)
        x = self.fc(x)
        return x

class Gaussian(nn.Module):
    """Linear head that parameterizes a diagonal Gaussian and samples from it.

    Args:
        in_dim: Size of the input features (last dimension).
        out_dim: Dimensionality of the Gaussian.
        eps: Small constant added to the variance so it stays strictly positive.
    """

    def __init__(self, in_dim, out_dim, eps=1e-8):
        super().__init__()
        # One linear layer emits the mean and the pre-activation variance.
        self.features = nn.Linear(in_dim, out_dim * 2)
        self.eps = eps

    def forward(self, x, reparameterize=True):
        """Return ``(sample, mean, var)``.

        With ``reparameterize=False`` the "sample" is the mean itself.
        """
        x = self.features(x)
        # Halve the same axis that is being split (the last one). Using
        # shape[1] here only worked for 2-D inputs.
        mean, logit = torch.split(x, x.shape[-1] // 2, -1)
        # softplus keeps the variance positive.
        var = F.softplus(logit) + self.eps
        if reparameterize:
            x = self._reparameterize(mean, var)
        else:
            x = mean
        return x, mean, var

    def _reparameterize(self, mean, var):
        """Draw ``mean + eps * sqrt(var)`` with ``eps`` standard normal noise."""
        if torch.is_tensor(var):
            std = torch.pow(var, 0.5)
        else:
            # Plain Python / NumPy variance.
            std = np.sqrt(var)
        eps = torch.randn_like(mean)
        x = mean + eps * std
        return x

class GumbelSoftmax(nn.Module):
    """Linear head followed by a Gumbel-softmax relaxation of a categorical.

    Args:
        in_dim: Size of the input features.
        out_dim: Number of categories.
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.features = nn.Linear(in_dim, out_dim)

    def forward(self, x, tau=1., dim=-1, hard=False):
        """Return ``(logits, y)``.

        ``logits`` are the raw outputs of the linear layer and ``y`` is the
        relaxed sample. With ``hard=True`` the forward value of ``y`` is
        one-hot while gradients flow through the soft sample.
        """
        logits = self.features(x)
        # -log(Exponential(1)) is standard Gumbel noise.
        gumbels = -torch.empty_like(logits).exponential_().log()
        gumbels = (logits + gumbels) / tau
        y = gumbels.softmax(dim)

        if hard:
            index = y.max(dim, keepdim=True)[1]
            y_hard = torch.zeros_like(logits).scatter_(dim, index, 1.0)
            # Straight-through estimator: value of y_hard, gradient of y.
            y = y_hard - y.detach() + y

        return logits, y
    
class Reshape(nn.Module):
    """Module that views its input as ``(batch, *outer_shape)``."""

    def __init__(self, outer_shape):
        super().__init__()
        # Per-sample target shape; the batch dimension is taken from the input.
        self.outer_shape = outer_shape

    def forward(self, x):
        """Keep dimension 0 and reinterpret the rest as ``outer_shape``."""
        batch = x.size(0)
        target_shape = (batch, *self.outer_shape)
        return x.view(target_shape)

def bce_loss(inputs, targets):
    """Binary cross-entropy summed per sample.

    The element-wise loss is flattened to ``(inputs.shape[0], -1)`` and summed
    over the second axis, giving one value per entry of dimension 0.
    """
    per_element = F.binary_cross_entropy(inputs, targets, reduction='none')
    per_sample = per_element.view(inputs.shape[0], -1)
    return per_sample.sum(-1)

def _log_norm(x, mean=None, var=None):
    """Element-wise log-density of ``x`` under a Gaussian.

    ``mean`` defaults to zeros and ``var`` to ones, i.e. a standard normal.
    """
    mean = torch.zeros_like(x) if mean is None else mean
    var = torch.ones_like(x) if var is None else var
    log_normalizer = torch.log(2.0 * np.pi * var)
    scaled_sq_error = torch.pow(x - mean, 2) / var
    return -0.5 * (log_normalizer + scaled_sq_error)

def log_norm_kl(x, mean, var, mean_=None, var_=None):
    """Difference of two Gaussian log-densities evaluated at ``x``.

    Computes ``log N(x; mean, var) - log N(x; mean_, var_)``, each summed over
    the last dimension. When ``mean_`` / ``var_`` are omitted the second
    density is a standard normal.
    """
    log_density_first = _log_norm(x, mean, var).sum(-1)
    log_density_second = _log_norm(x, mean_, var_).sum(-1)
    return log_density_first - log_density_second

def entropy(logits):
    """Entropy of the categorical distribution given by ``softmax(logits)``.

    The softmax and the sum are taken over the last dimension.
    """
    log_probs = logits.log_softmax(-1)
    probs = logits.softmax(-1)
    weighted = probs * log_probs
    return -weighted.sum(-1)

class ResNet_VAE(nn.Module):
    """Variational autoencoder with a ResNet encoder and transposed-conv decoder.

    The encoder is a single-channel ResNet without its linear head, so it
    emits 512 features per sample. The latent space has 64 dimensions.
    """

    def __init__(self):
        super().__init__()

        # Keep every ResNet child except the last one (the fc classifier).
        backbone = ResNet(1, 1000, BasicBlock, [3, 3, 3, 3])
        trunk = list(backbone.children())[:-1]
        self.encoder = nn.Sequential(*trunk)

        # 512 encoder features -> parameters and sample of a 64-d Gaussian.
        self.gaussian = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            Gaussian(256, 64),
        )

        # Latent vector -> 512-channel 8x8 feature map.
        self.upsample = nn.Sequential(
            nn.Linear(64, 512),
            nn.BatchNorm1d(512, momentum=0.01),
            nn.Linear(512, 512 * 64),
            nn.BatchNorm1d(512 * 64, momentum=0.01),
            Reshape((512, 8, 8)),
        )

        # Three stride-2 stages then a stride-3 output layer; for an 8x8 input
        # the spatial size goes 8 -> 17 -> 35 -> 71 -> 213.
        decoder_layers = []
        for c_in, c_out in ((512, 256), (256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2),
                nn.BatchNorm2d(c_out, momentum=0.01),
                nn.ReLU(inplace=True),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(64, 1, kernel_size=7, stride=3, padding=2),
            nn.BatchNorm2d(1, momentum=0.01),
            nn.Sigmoid(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

        # Re-initialize all (transposed) convolutions and 2-D norm layers,
        # including those inside the encoder.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x, return_params=False):
        """Encode ``x``, sample a latent and decode it.

        Returns the reconstruction, or, when ``return_params`` is true, a dict
        with keys ``x``, ``x_reconst``, ``z``, ``z_mean`` and ``z_var``.
        """
        encoded = self.encoder(x)
        latent, latent_mean, latent_var = self.gaussian(encoded)
        reconstruction = self.decoder(self.upsample(latent))
        if not return_params:
            return reconstruction
        return dict(
            x=x,
            x_reconst=reconstruction,
            z=latent,
            z_mean=latent_mean,
            z_var=latent_var,
        )

    def criterion(self, x, params):
        """Batch mean of reconstruction BCE plus the latent KL estimate.

        ``params`` is the dict returned by ``forward(..., return_params=True)``.
        """
        reconstruction_term = bce_loss(params['x_reconst'], x)
        # Prior arguments are omitted, so the KL is against a standard normal.
        kl_term = log_norm_kl(params['z'], params['z_mean'], params['z_var'])
        total = reconstruction_term + kl_term
        return total.mean()

class ResNet_CVAE(nn.Module):
    """VAE with a categorical latent ``y`` and a ``y``-conditioned Gaussian ``z``.

    The encoder is a single-channel ResNet without its linear head (512
    features per sample). ``y`` is inferred from those features with a
    Gumbel-softmax head; ``z`` is inferred from the features concatenated with
    ``y``, and a second Gaussian head maps ``y`` to the prior over ``z``.

    Args:
        z_dim: Dimensionality of the Gaussian latent ``z``.
        y_dim: Number of categories of ``y``.
    """

    def __init__(self, z_dim=64, y_dim=10):
        super().__init__()

        # Keep every ResNet child except the last one (the fc classifier).
        resnet = ResNet(1, 1000, BasicBlock, [3, 3, 3, 3])
        self.encoder = nn.Sequential(*list(resnet.children())[:-1])
        # q(y | x)
        self.inference_y = nn.Sequential(
            nn.Linear(512, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            GumbelSoftmax(256, y_dim)
        )
        # p(z | y)
        self.inference_z_prior = nn.Sequential(
            nn.Linear(y_dim, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            Gaussian(256, z_dim)
        )
        # q(z | x, y)
        self.inference_z = nn.Sequential(
            nn.Linear(512 + y_dim, 256),
            nn.BatchNorm1d(256, momentum=0.01),
            Gaussian(256, z_dim)
        )
        # Latent vector -> 512-channel 8x8 feature map. The input size follows
        # z_dim; it used to be hard-coded to 64, which broke any other z_dim.
        self.upsample = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.BatchNorm1d(512, momentum=0.01),
            nn.Linear(512, 512 * 64),
            nn.BatchNorm1d(512 * 64, momentum=0.01),
            Reshape((512, 8, 8))
        )
        # For an 8x8 input the spatial size goes 8 -> 17 -> 35 -> 71 -> 213.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2),
            nn.BatchNorm2d(256, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2),
            nn.BatchNorm2d(128, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2),
            nn.BatchNorm2d(64, momentum=0.01),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=7, stride=3, padding=2),
            nn.BatchNorm2d(1, momentum=0.01),
            nn.Sigmoid(),
        )

        # Re-initialize all (transposed) convolutions and 2-D norm layers,
        # including those inside the encoder.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_params=False):
        """Infer ``y`` and ``z`` from ``x`` and reconstruct ``x`` from ``z``.

        Returns the reconstruction, or, when ``return_params`` is true, a dict
        with keys ``x``, ``x_reconst``, ``y_logits``, ``y``, ``y_pred``, ``z``,
        ``z_mean``, ``z_var``, ``z_prior_mean`` and ``z_prior_var``.
        """
        h = self.encoder(x)
        y_logits, y = self.inference_y(h)
        # Hard cluster assignment: index of the largest logit.
        _, y_pred = torch.max(y_logits, -1)
        # Only the prior's parameters are used; its sample is discarded.
        _, z_prior_mean, z_prior_var = self.inference_z_prior(y)
        xy = torch.cat((h, y), -1)
        z, z_mean, z_var = self.inference_z(xy)
        x_reconst = self.decoder(self.upsample(z))
        if return_params:
            return dict(
                x=x,
                x_reconst=x_reconst,
                y_logits=y_logits,
                y=y,
                y_pred=y_pred,
                z=z,
                z_mean=z_mean,
                z_var=z_var,
                z_prior_mean=z_prior_mean,
                z_prior_var=z_prior_var
            )
        return x_reconst

    def criterion(self, x, params):
        """Batch mean of reconstruction BCE, the ``z`` KL estimate and entropy of ``y``.

        ``params`` is the dict returned by ``forward(..., return_params=True)``.
        """
        loss = bce_loss(params['x_reconst'], x)
        z_kl = log_norm_kl(params['z'], params['z_mean'], params['z_var'],
                           params['z_prior_mean'], params['z_prior_var'])
        # NOTE(review): the entropy of q(y|x) is added, i.e. minimized. A KL
        # term against a uniform prior would subtract it -- confirm the sign
        # is intended.
        y_en = entropy(params['y_logits'])
        return (loss + z_kl + y_en).mean()

# Run configuration shared by the data pipeline and the training loop.
FLAGS = attrdict({
    'x_dim': 486,
    'batch_size': 64,
    'num_epochs': 100,
    'num_workers': 4,
    'log_steps': 1,
    'lr': 1e-2,
    'dataset': 'gravityspy',
})

# Centre-crop to x_dim, convert to one channel, resize to 213 and tensorize.
# NOTE(review): presumably applied once when the dataset is prepared --
# confirm against the datasets package.
setup_transform = transforms.Compose([
    transforms.CenterCrop(FLAGS.x_dim),
    transforms.Grayscale(),
    transforms.Resize(213),
    transforms.ToTensor(),
])

# Transform handed to the dataset as its regular `transform`.
data_transform = transforms.Compose([
    transforms.Grayscale(),
    transforms.ToTensor(),
])

# Look the dataset class up by name so FLAGS.dataset selects it.
dataset_cls = getattr(datasets, FLAGS.dataset)
dataset = dataset_cls(
    root=ROOT,
    transform=data_transform,
    setup_transform=setup_transform,
    download=True,
)
# NOTE(review): presumably restricts the dataset to these three class keys,
# matching y_dim = 3 below -- confirm against the datasets package.
dataset.get_by_keys([3, 9, 14])
print(f'dataset length: {len(dataset)}')

loader = DataLoader(
    dataset,
    batch_size=FLAGS.batch_size,
    num_workers=FLAGS.num_workers,
    shuffle=True,
    drop_last=True,
)

# Number of categories of the model's categorical latent.
y_dim = 3

# Use the first CUDA device when one is available, the CPU otherwise.
device_ids = range(torch.cuda.device_count())
if torch.cuda.is_available():
    device = f'cuda:{device_ids[0]}'
else:
    device = 'cpu'

model = ResNet_CVAE(y_dim=y_dim).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=FLAGS.lr)

# Train for FLAGS.num_epochs epochs (inclusive upper bound; the previous
# range(1, num_epochs) ran one epoch too few). After every epoch the latent
# codes are clustered, and every FLAGS.log_steps epochs the inputs,
# reconstructions and 2-D projections of the latents are plotted.
for epoch in range(1, FLAGS.num_epochs + 1):
    loss = 0
    # Per-batch results are collected in lists and joined once per epoch,
    # which avoids re-copying the growing arrays at every step.
    feature_batches = []
    target_batches = []
    pred_batches = []
    for x, target in loader:
        x = x.to(device)
        out_params = model(x, return_params=True)
        step_loss = model.criterion(x, out_params)
        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()
        loss += step_loss.item()

        # Detach before storing so each step's autograd graph can be freed
        # instead of being kept alive for the whole epoch.
        feature_batches.append(out_params['z'].detach().cpu())
        target_batches.append(target.numpy())
        pred_batches.append(out_params['y_pred'].detach().cpu().numpy())

    if not feature_batches:
        raise RuntimeError(
            'the data loader produced no batches; with drop_last=True the '
            'dataset must contain at least batch_size samples'
        )

    features = torch.cat(feature_batches, 0).numpy()
    # Builtin int replaces the np.int alias, which NumPy has removed.
    target_stack = np.concatenate(target_batches).astype(int)
    pred_stack = np.concatenate(pred_batches).astype(int)
    km_stack = functional.run_kmeans(features, y_dim + 3)

    print(f'loss: {loss:.3f} at epoch {epoch}')

    if epoch % FLAGS.log_steps == 0:

        # Inputs of the last batch of the epoch.
        grid_img = make_grid(x.cpu()[:16], nrow=8)
        plt.figure(figsize=(10, 5))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()

        # Reconstructions of that same batch.
        grid_img = make_grid(out_params['x_reconst'].detach().cpu()[:16], nrow=8)
        plt.figure(figsize=(10, 5))
        plt.imshow(grid_img.permute(1, 2, 0))
        plt.show()

        # Each projection is fitted on the original latent codes. The t-SNE
        # result used to overwrite `features`, so PCA ran on the t-SNE
        # embedding instead of on the latents.
        tsne = decomposition.TSNE()
        tsne_features = tsne.fit_transform(features)
        plot.scatter(tsne_features[:, 0], tsne_features[:, 1], target_stack)
        plot.scatter(tsne_features[:, 0], tsne_features[:, 1], pred_stack)
        plot.scatter(tsne_features[:, 0], tsne_features[:, 1], km_stack)

        pca = decomposition.PCA()
        pca_features = pca.fit_transform(features)
        plot.scatter(pca_features[:, 0], pca_features[:, 1], target_stack)
        plot.scatter(pca_features[:, 0], pca_features[:, 1], pred_stack)
        plot.scatter(pca_features[:, 0], pca_features[:, 1], km_stack)




# In [5]:
import numpy as np
import faiss

# Scratch cell: fit a faiss PCA transform on random float32 data and apply it.
n_samples, dim_in, dim_out = 1000, 40, 2

mt = np.random.rand(n_samples, dim_in).astype('float32')
print(mt.shape)

# PCAMatrix(d_in, d_out) has to be trained before it can be applied.
mat = faiss.PCAMatrix(dim_in, dim_out)
mat.train(mt)
tr = mat.apply_py(mt)



# Output: (1000, 40)
