In [46]:
import torch

# Fix the global torch RNG seed so random draws made after this cell are repeatable.
torch.manual_seed(42)
# NOTE(review): this cell originally ended with a dangling `torch.`, which is a
# syntax error in plain Python. It was presumably a truncated IPython help
# lookup such as `torch.random.set_rng_state?` (the cell output shows that
# function's help text) — confirm before restoring it.

Signature: torch.random.set_rng_state(new_state) -> None
Docstring:
Sets the random number generator state.

Args:
    new_state (torch.ByteTensor): The desired state
File:      ~/projects/VaDE/vade/lib/python3.8/site-packages/torch/random.py
Type:      function


In [51]:
# Draw ten integers from [0, 10_000_000) and display the first one.
torch.randint(0, 10_000_000, (10,))[0]

8409749

In [5]:
import torch
torch.random.seed()
from torch import nn
from torch.utils.data import DataLoader, Dataset, IterableDataset, TensorDataset, Subset, ConcatDataset
from torchvision.datasets import MNIST, FashionMNIST
import pytorch_lightning as pl
from triplet_vade import TripletDataset, CombinedDataset, TripletVaDE
from autoencoder import SimpleAutoencoder, ClusteringEvaluationCallback, cluster_acc
from sklearn.mixture import GaussianMixture
from pl_modules import PLVaDE

In [2]:
class MNISTDataModule(pl.LightningDataModule):
    """Serve flattened MNIST / Fashion-MNIST images padded with uniform-noise features.

    Each 28x28 image is flattened to 784 values, scaled by 1/255 and
    concatenated with ``n_noise_features`` extra columns drawn from
    ``torch.rand``.

    Args:
        path: Root directory passed to the torchvision dataset (``download=True``).
        bs: Batch size used by both dataloaders.
        dataset: ``'mnist'`` or ``'fmnist'``.
        data_size: If given, keep only this many randomly selected samples of
            each split (capped at the size of the split). ``None`` keeps all.
        n_noise_features: Number of noise columns appended to every sample.

    Raises:
        ValueError: If ``dataset`` is not a supported name, or ``data_size`` /
            ``n_noise_features`` is negative.
    """

    # Maps the ``dataset`` option to the torchvision class that loads it.
    _DATASET_CLASSES = {'mnist': MNIST, 'fmnist': FashionMNIST}

    def __init__(self, path='data', bs=256, dataset='mnist', data_size=None, n_noise_features=1000):
        super().__init__()
        # Fail early instead of silently training on MNIST when the name is mistyped.
        if dataset not in self._DATASET_CLASSES:
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(self._DATASET_CLASSES)}")
        if data_size is not None and data_size < 0:
            raise ValueError(f"data_size must be non-negative, got {data_size}")
        if n_noise_features < 0:
            raise ValueError(f"n_noise_features must be non-negative, got {n_noise_features}")
        self.path = path
        self.batch_size = bs
        self.dataset = dataset
        self.data_size = data_size
        self.n_noise_features = n_noise_features

    def _to_tensor_dataset(self, ds):
        """Flatten and rescale ``ds.data``, append the noise columns, pair with ``ds.targets``."""
        X = ds.data.view(-1, 28**2).float() / 255.
        # The noise comes from the global torch RNG, so it differs between runs
        # unless that RNG has been seeded beforehand.
        noise = torch.rand(X.shape[0], self.n_noise_features)
        return TensorDataset(torch.cat([X, noise], dim=1), ds.targets)

    def _subsample(self, ds):
        """Return a random subset of ``ds`` with at most ``self.data_size`` samples."""
        # Cap the request so a split smaller than ``data_size`` is kept whole
        # instead of making ``random_split`` fail on a negative remainder.
        n_keep = min(self.data_size, len(ds))
        # A fresh generator with a fixed seed per split keeps the selection repeatable.
        generator = torch.Generator().manual_seed(42)
        return torch.utils.data.random_split(ds, [n_keep, len(ds) - n_keep], generator=generator)[0]

    def setup(self, stage_name=None):
        """Load both splits, append the noise columns and optionally subsample them.

        ``stage_name`` is accepted for the hook signature and is not used.
        Sets ``train_ds``, ``valid_ds`` and ``all_ds`` (their concatenation).
        """
        dataset_cls = self._DATASET_CLASSES[self.dataset]
        # Load each split exactly once, from the configured root directory.
        raw_train = dataset_cls(self.path, download=True, train=True)
        raw_valid = dataset_cls(self.path, download=True, train=False)
        # Train is converted before valid so the RNG is consumed in a fixed order.
        self.train_ds = self._to_tensor_dataset(raw_train)
        self.valid_ds = self._to_tensor_dataset(raw_valid)
        if self.data_size is not None:
            self.train_ds = self._subsample(self.train_ds)
            self.valid_ds = self._subsample(self.valid_ds)
        self.all_ds = ConcatDataset([self.train_ds, self.valid_ds])

    def train_dataloader(self):
        """Return the training loader (constructed without ``shuffle``)."""
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.valid_ds, batch_size=self.batch_size)
    
# Default datamodule (path='data', bs=256, dataset='mnist', data_size=None).
# No data is loaded until `dm.setup()` is called.
dm = MNISTDataModule()

In [None]:
# Data: full MNIST with the appended noise columns (no subsampling).
dm = MNISTDataModule(dataset='mnist', data_size=None)

# Autoencoder whose input width is 784 pixels plus the 1000 noise columns
# added by the datamodule, with a 10-dimensional bottleneck.
model = SimpleAutoencoder(n_neurons=[28**2 + 1000, 512, 512, 2048, 10])

# Log to Weights & Biases and run the clustering evaluation callback during training.
logger = pl.loggers.WandbLogger(project='Corrupted_MNIST')
trainer = pl.Trainer(
    gpus=1,
    max_epochs=100,
    progress_bar_refresh_rate=20,
    logger=logger,
    callbacks=[ClusteringEvaluationCallback(ds_type='all')],
)
trainer.fit(model, datamodule=dm)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 2.2 M 
1 | decoder | Sequential | 2.2 M 
---------------------------------------
4.5 M     Trainable params
0         Non-trainable params
4.5 M     Total params


Training: 0it [00:00, ?it/s]                                  



Epoch 0:  87%|████████▋ | 240/275 [00:01<00:00, 129.93it/s, loss=0.483, v_num=og0h]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 0: 100%|██████████| 275/275 [00:02<00:00, 136.42it/s, loss=0.481, v_num=og0h]
Epoch 1:  87%|████████▋ | 240/275 [00:01<00:00, 131.21it/s, loss=0.459, v_num=og0h]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 1: 100%|██████████| 275/275 [00:01<00:00, 137.58it/s, loss=0.458, v_num=og0h]
Epoch 2:  87%|████████▋ | 240/275 [00:01<00:00, 133.57it/s, loss=0.451, v_num=og0h]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 2: 100%|██████████| 275/275 [00:01<00:00, 139.84it/s, loss=0.451, v_num=og0h]
Epoch 3:  87%|████████▋ | 240/275 [00:01<00:00, 131.04it/s, loss=0.447, v_num=og0h]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/40 [00:00<?, ?it/s][A
Epoch 3: 100%|██████████| 275/275 [00:02<00:00, 137.16it/s, 

In [3]:
def get_data_and_targets(dataset):
    """Materialise an indexable dataset into one data tensor and one target tensor.

    Args:
        dataset: Object supporting ``len()`` and integer indexing, where
            ``dataset[i]`` yields a ``(data, target)`` pair. ``data`` must be a
            tensor; ``target`` may be a tensor or a plain number.

    Returns:
        Tuple ``(data, targets)`` with the samples stacked along a new dim 0.

    Raises:
        RuntimeError: From ``torch.stack`` when the dataset is empty.
    """
    # Index every sample exactly once; indexing may be expensive (e.g. Subset
    # indirection or on-the-fly transforms), so do not repeat it per field.
    samples = [dataset[i] for i in range(len(dataset))]
    data = torch.stack([sample[0] for sample in samples], dim=0)
    # as_tensor is a no-op for tensors and lifts plain Python/NumPy labels.
    targets = torch.stack([torch.as_tensor(sample[1]) for sample in samples], dim=0)
    return data, targets


class TripletsDataModule(pl.LightningDataModule):
    """Wrap a base datamodule and serve ``TripletDataset`` views of its splits.

    Args:
        base_datamodule: Datamodule exposing ``train_ds`` and ``valid_ds`` after
            its ``setup()`` has run.
        n_samples_for_triplets: Forwarded to ``TripletDataset`` as ``data_size``.
        n_triplets: Forwarded as ``max_samples`` for the training split.
        n_triplets_valid: Forwarded as ``max_samples`` for the validation split;
            defaults to ``n_triplets`` when ``None``.
        batch_size: Batch size of the training loader.
        valid_batch_size: Batch size of the validation loader.
    """

    def __init__(self, base_datamodule, n_samples_for_triplets=None, n_triplets=None, n_triplets_valid=None,
                 batch_size=256, valid_batch_size=2**11):
        super().__init__()
        self.base_datamodule = base_datamodule
        self.batch_size = batch_size
        self.valid_batch_size = valid_batch_size
        self.n_samples_for_triplets = n_samples_for_triplets
        self.n_triplets = n_triplets
        self.n_triplets_valid = n_triplets if n_triplets_valid is None else n_triplets_valid

    def setup(self, stage_name=None):
        """Build the triplet datasets from the base datamodule's splits.

        ``stage_name`` is accepted for the hook signature and is not used.
        """
        base = self.base_datamodule
        # The trainer only calls setup() on this wrapper, so prepare the base
        # module here if nobody did. An already prepared base is left untouched
        # so its datasets are not rebuilt.
        if not (hasattr(base, 'train_ds') and hasattr(base, 'valid_ds')):
            base.setup()
        self.train_ds = TripletDataset(*get_data_and_targets(base.train_ds),
                                       data_size=self.n_samples_for_triplets,
                                       max_samples=self.n_triplets)
        self.valid_ds = TripletDataset(*get_data_and_targets(base.valid_ds),
                                       data_size=self.n_samples_for_triplets,
                                       max_samples=self.n_triplets_valid)

    def train_dataloader(self):
        """Return the training triplet loader."""
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation triplet loader."""
        return DataLoader(self.valid_ds, batch_size=self.valid_batch_size)

In [32]:
def cluster_data(self, dl, k=10, random_state=None):
    """Encode every batch of ``dl`` and cluster the encodings with a Gaussian mixture.

    Args:
        self: Module exposing an ``encoder`` sub-module (passed explicitly,
            since this is a free function).
        dl: Iterable of ``(inputs, labels)`` tensor batches.
        k: Number of mixture components.
        random_state: Forwarded to ``GaussianMixture`` for repeatable fits.

    Returns:
        Tuple ``(y_true, cluster_labels, encodings)``; ``y_true`` and
        ``encodings`` are NumPy arrays, ``cluster_labels`` is whatever
        ``fit_predict`` returns.
    """
    # Run on whatever device the model already lives on instead of forcing CUDA.
    device = next(self.parameters()).device
    was_training = self.training
    encodings, y_true = [], []
    self.eval()
    try:
        with torch.no_grad():
            for bx, by in dl:
                # Move each result to the CPU right away so encodings do not
                # pile up in accelerator memory.
                encodings.append(self.encoder(bx.to(device)).cpu())
                y_true.append(by)
    finally:
        # Put the module back into the mode it was in before the evaluation.
        self.train(was_training)
    encodings = torch.cat(encodings, dim=0).numpy()
    y_true = torch.cat(y_true, dim=0).detach().cpu().numpy()
    gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=random_state)
    cluster_labels = gmm.fit_predict(encodings)
    return y_true, cluster_labels, encodings


# Move the model to the GPU before encoding the validation split.
model.cuda()
# Encode the validation batches and cluster them with the default k=10 mixture;
# the returned encodings are discarded here.
y_true, y_pred, _ = cluster_data(model, dm.val_dataloader())
# Score the cluster assignments against the true labels with `cluster_acc`
# (imported from `autoencoder`); the value is shown as the cell output.
cluster_acc(y_true, y_pred)

0.9494

In [31]:
# Average the weights of `model.encoder[0][0]` over dim 0, i.e. one value per
# weight column (per input feature if that module is an nn.Linear — TODO confirm
# for the model currently held in the kernel).
model.encoder[0][0].weight.mean(dim=0)

tensor([-0.0003,  0.0001,  0.0005,  ..., -0.0078, -0.0078,  0.0005],
       device='cuda:0', grad_fn=<MeanBackward1>)

In [19]:
# Move the model to the GPU, cluster the validation encodings through the
# model's own `cluster_data` method, then score them with `cluster_acc`.
# NOTE(review): which model `model` refers to depends on the cells run before
# this one — the execution counts are out of order.
model.cuda()
y_true, y_pred, _ = model.cluster_data(dm.val_dataloader())
cluster_acc(y_true, y_pred)

0.8623

In [23]:
# Re-score whatever `y_true` / `y_pred` currently hold in the kernel.
# NOTE(review): this cell depends on hidden state from a previously run cell.
cluster_acc(y_true, y_pred)

0.43851666666666667

In [None]:
class TripletLossModule(pl.LightningModule):
    """MLP encoder trained with a margin-based triplet loss on L2-normalised encodings.

    Args:
        n_neurons: Layer widths, input first and embedding size last. Every
            layer except the last is followed by ELU and dropout.
        margin: Margin added to ``d(anchor, positive) - d(anchor, negative)``.
        lr: Learning rate of the AdamW optimiser.
        p: Dropout probability of the hidden layers.

    Raises:
        ValueError: If ``n_neurons`` has fewer than two entries.
    """

    def __init__(self, n_neurons, margin=0.5, lr=1e-3, p=0.):
        super().__init__()
        self.save_hyperparameters()
        if len(n_neurons) < 2:
            raise ValueError("n_neurons needs at least an input and an output width")
        self.n_neurons = n_neurons
        # Hidden blocks pair consecutive widths up to the second-to-last entry;
        # the final projection carries no activation or dropout.
        layers = [nn.Sequential(nn.Linear(n_in, n_out), nn.ELU(), nn.Dropout(p=p))
                  for n_in, n_out in zip(n_neurons[:-2], n_neurons[1:-1])]
        layers.append(nn.Linear(n_neurons[-2], n_neurons[-1]))
        self.encoder = nn.Sequential(*layers)

    def configure_optimizers(self):
        """Return an AdamW optimiser over all parameters using the stored ``lr``."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams['lr'])

    def shared_step(self, batch):
        """Compute the triplet loss and ranking accuracy for one batch.

        Args:
            batch: Mapping with tensors under ``'anchor'``, ``'positive'`` and
                ``'negative'``.

        Returns:
            Dict with ``'loss'`` (mean hinge loss) and ``'pct_correct'`` (share
            of triplets, in percent, whose positive is closer than the negative).

        Raises:
            RuntimeError: If any per-triplet loss is NaN or infinite.
        """
        enc = {}
        for role, x in batch.items():
            z = self.encoder(x)
            # L2-normalise each encoding; the epsilon guards against a zero norm.
            enc[role] = z / (torch.linalg.norm(z, dim=1, keepdim=True) + 1e-8)
        pos_distance = torch.linalg.norm(enc['anchor'] - enc['positive'], dim=-1)
        neg_distance = torch.linalg.norm(enc['anchor'] - enc['negative'], dim=-1)
        pct_correct = (pos_distance < neg_distance).float().mean() * 100
        loss = (pos_distance - neg_distance + self.hparams['margin']).clamp(min=0)
        if not loss.isfinite().all():
            # Fail loudly rather than dropping into an interactive debugger,
            # which would stall an unattended training run.
            raise RuntimeError("Non-finite triplet loss encountered; check the inputs and encoder weights")
        return {'loss': loss.mean(), 'pct_correct': pct_correct}

    def training_step(self, batch, batch_idx):
        """Run ``shared_step`` and log every metric under the ``train/`` prefix."""
        result = self.shared_step(batch)
        for name, value in result.items():
            self.log('train/' + name, value)
        return result

    def validation_step(self, batch, batch_idx):
        """Run ``shared_step`` and log every metric under the ``valid/`` prefix."""
        result = self.shared_step(batch)
        for name, value in result.items():
            self.log('valid/' + name, value)
        return result

    def cluster_data(self, dl, k=10, random_state=None):
        """Encode every batch of ``dl`` and cluster the encodings with a Gaussian mixture.

        Args:
            dl: Iterable of ``(inputs, labels)`` tensor batches.
            k: Number of mixture components.
            random_state: Forwarded to ``GaussianMixture`` for repeatable fits.

        Returns:
            Tuple ``(y_true, cluster_labels, encodings)``; ``y_true`` and
            ``encodings`` are NumPy arrays.
        """
        # Run on whatever device the model already lives on instead of forcing CUDA.
        device = next(self.parameters()).device
        was_training = self.training
        encodings, y_true = [], []
        self.eval()
        try:
            with torch.no_grad():
                for bx, by in dl:
                    # Move results to the CPU immediately so they do not
                    # accumulate in accelerator memory.
                    encodings.append(self.encoder(bx.to(device)).cpu())
                    y_true.append(by)
        finally:
            # Restore the previous mode so dropout is not left disabled.
            self.train(was_training)
        encodings = torch.cat(encodings, dim=0).numpy()
        y_true = torch.cat(y_true, dim=0).detach().cpu().numpy()
        gmm = GaussianMixture(n_components=k, covariance_type='full', random_state=random_state)
        cluster_labels = gmm.fit_predict(encodings)
        return y_true, cluster_labels, encodings

import wandb

# Triplet-loss encoder with the same layer widths as the autoencoder's encoder.
model = TripletLossModule([28**2 + 1000, 512, 512, 2048, 10], lr=3e-4, p=0.1)
# The triplet datamodule reads `dm.train_ds` / `dm.valid_ds`, so the base
# datamodule has to be set up first.
dm.setup()
data_module = TripletsDataModule(dm, n_triplets=2**10, n_triplets_valid=2**15, batch_size=2**6)
logger = pl.loggers.WandbLogger(project='Corrupted_MNIST', group='triplets')

trainer = pl.Trainer(gpus=1, max_epochs=3000, progress_bar_refresh_rate=0, log_every_n_steps=1, logger=logger)
try:
    trainer.fit(model, data_module)
finally:
    # Close the W&B run even when training raises or is interrupted, so the
    # next experiment in this kernel does not log into a stale run.
    wandb.finish()


GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
[34m[1mwandb[0m: wandb version 0.10.21 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade



  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 2.2 M 
---------------------------------------
2.2 M     Trainable params
0         Non-trainable params
2.2 M     Total params


Validation sanity check:   0%|          | 0/2 [00:00<?, ?it/s]



Epoch 0:  62%|██████▎   | 20/32 [00:00<00:00, 104.36it/s]     
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/16 [00:00<?, ?it/s][A




Epoch 0: 100%|██████████| 32/32 [00:01<00:00, 16.24it/s, loss=0.353, v_num=ecpi]
Epoch 1:  62%|██████▎   | 20/32 [00:00<00:00, 105.91it/s, loss=0.353, v_num=ecpi]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/16 [00:00<?, ?it/s][A
Epoch 1: 100%|██████████| 32/32 [00:01<00:00, 16.10it/s, loss=0.211, v_num=ecpi] 
Epoch 2:  62%|██████▎   | 20/32 [00:00<00:00, 106.49it/s, loss=0.211, v_num=ecpi]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/16 [00:00<?, ?it/s][A
Epoch 2: 100%|██████████| 32/32 [00:01<00:00, 16.39it/s, loss=0.187, v_num=ecpi] 
Epoch 3:  62%|██████▎   | 20/32 [00:00<00:00, 107.34it/s, loss=0.187, v_num=ecpi]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/16 [00:00<?, ?it/s][A
Epoch 3: 100%|██████████| 32/32 [00:01<00:00, 16.19it/s, loss=0.153, v_num=ecpi] 
Epoch 4:  62%|██████▎   | 20/32 [00:00<00:00, 105.65it/s, loss=0.153, v_num=ecpi]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|          | 0/16 [00:00<?, ?i

In [None]:
class CombinedDataModule(pl.LightningDataModule):
    """Wrap a base datamodule and serve ``CombinedDataset`` views of its splits.

    Args:
        base_datamodule: Datamodule exposing ``train_ds`` and ``valid_ds`` after
            its ``setup()`` has run.
        n_samples_for_triplets: Forwarded to ``CombinedDataset`` as ``data_size``.
        n_triplets: Forwarded as ``max_triplets`` for both splits.
        batch_size: Batch size used by both dataloaders.
    """

    def __init__(self, base_datamodule, n_samples_for_triplets=None, n_triplets=None, batch_size=256):
        # The original call named a non-existent `combinedDataModule` (lowercase
        # c), which raised NameError on construction. The zero-argument form
        # cannot drift from the class name.
        super().__init__()
        self.batch_size = batch_size
        self.base_datamodule = base_datamodule
        self.n_samples_for_triplets = n_samples_for_triplets
        self.n_triplets = n_triplets

    def setup(self, stage_name=None):
        """Build the combined datasets from the base datamodule's splits.

        ``stage_name`` is accepted for the hook signature and is not used.
        """
        base = self.base_datamodule
        # The trainer only calls setup() on this wrapper, so prepare the base
        # module here if nobody did. An already prepared base is left untouched
        # so its datasets are not rebuilt.
        if not (hasattr(base, 'train_ds') and hasattr(base, 'valid_ds')):
            base.setup()
        self.train_ds = CombinedDataset(base.train_ds,
                                        data_size=self.n_samples_for_triplets,
                                        max_triplets=self.n_triplets)
        self.valid_ds = CombinedDataset(base.valid_ds,
                                        data_size=self.n_samples_for_triplets,
                                        max_triplets=self.n_triplets)

    def train_dataloader(self):
        """Return the training loader."""
        return DataLoader(self.train_ds, batch_size=self.batch_size)

    def val_dataloader(self):
        """Return the validation loader."""
        return DataLoader(self.valid_ds, batch_size=self.batch_size)