In [51]:
import os
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader
import torch
import habitat
import gzip
import pickle
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm

import torch.nn.functional as F
from torch.distributions import RelaxedOneHotCategorical, kl_divergence
from random import randrange

In [52]:
modes = [ [1] * 12, [1, 2] * 15, [2, 3] * 12 + [1] ]
mode2class = { sum(modes[i]): i for i in range(len(modes)) }
class ActionDataset(Dataset):

    def __init__(self, data_path=None, bostoken=4, min_len=8, max_len=40):
        assert (data_path is not None)
        
        # set metadata using the config file
        self.files = [os.path.join(data_path, file) for file in sorted(os.listdir(data_path))]
        self.bostoken = bostoken
        self.min_len = min_len
        self.max_len = max_len
    
    def __len__(self):
        return len(self.files) 
    
    def __getitem__(self, idx):
        file = self.files[idx]
        # episode = torch.load(file)
        # # with gzip.open(file) as f:
        #     # episode = pickle.load(f)

        # # add a BOS token
        # actions = episode['action'].tolist()
        # actions = [self.bostoken] + actions
        # rand_len = min(randrange(self.min_len, self.max_len), len(actions)) 
        # start = randrange(0, len(actions) - rand_len + 1)
        # actions = actions[start:start + rand_len]
        # idx = randrange(0, 3)
        return torch.tensor(modes[idx % 3], dtype=torch.long)
        # return torch.tensor(actions, dtype=torch.long)

In [53]:
class SequenceEncoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, output_dim):
        super(SequenceEncoder, self).__init__()
        
        # Embedding Layer
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        
        # GRU Unit
        self.gru = nn.GRU(embedding_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        
        # Linear Layer for Mean and Variance
        self.linear = nn.Linear(hidden_dim, output_dim * 2)  # Multiply by 2 for mean and variance

    def forward(self, x):
        # x is the input sequence of discrete values
        
        # Embedding layer
        embedded = self.embedding(x)
        
        # GRU layer
        output, hidden = self.gru(embedded)
        
        # Linear layer for mean and variance
        output = self.linear(hidden[-1])  # Taking the output of the last time step
        
        # Split the output into mean and variance
        mean, log_var = torch.chunk(output, 2, dim=-1)
        
        return mean, log_var

class SequenceClassifier(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.linear = nn.Linear(latent_dim, len(modes))

    def forward(self, x):
        return F.softmax(self.linear(x))
        

class SequenceDecoder(nn.Module):
    def __init__(self, latent_dim, hidden_dim, num_layers, vocab_size, embedding_module):
        super(SequenceDecoder, self).__init__()
        self.nlayers = num_layers

        # GRU layer
        self.gru = nn.GRU(embedding_module.embedding_dim, latent_dim, num_layers, batch_first=True)

        # Linear layer to output continuous values
        self.linear = nn.Linear(latent_dim, vocab_size)
        self.embedding_module = embedding_module

    def sample_sequence(self, length, start_token, temperature=1.0, hidden=None):
        inp = torch.tensor([start_token] * hidden.shape[1], dtype=torch.long, device='cuda')[:, None]
        output_sequence = []
        for _ in range(length):
            inp = self.embedding_module(inp)
            out, hidden = self.gru(inp, hidden)
            logits = self.linear(out)
            logits = logits / temperature
            probabilities = F.softmax(logits.squeeze(), dim=-1)
            next_token = torch.multinomial(probabilities, 1)
            output_sequence.append(probabilities)
            inp = next_token.long()[:, None]
        return torch.stack(output_sequence, dim=0)[None,:]


    def forward(self, z, seq_length):
        # 'z' is the latent vector
        # 'seq_length' is the desired length of the generated sequence
        # 'batch_size' is the batch size
        hx = z.expand(self.nlayers, z.shape[0], -1).contiguous()
        return self.sample_sequence(seq_length, padtoken, hidden=hx)


def sample_from_latent_space(mean, log_var):
    """
    Samples from the latent space using the reparameterization trick.

    Args:
        mean (torch.Tensor): Mean values from the encoder.
        log_var (torch.Tensor): Log variance values from the encoder.

    Returns:
        torch.Tensor: Samples from the latent space.
    """
    # Standard deviation (sigma) is the square root of the variance (exp(log_var/2))
    std_dev = torch.exp(0.5 * log_var)

    # Sample from a standard normal distribution
    epsilon = torch.randn_like(std_dev)

    # Reparameterization trick to sample from the latent space
    sampled_latent = mean + std_dev * epsilon

    return sampled_latent

In [63]:
def vae_loss(output_sequence, target_sequence, mean, log_var, mask):
    # Flatten the sequences and masks for convenience
    output_sequence = output_sequence.view(-1, output_sequence.size(-1))
    target_sequence = target_sequence.view(-1)
    mask = mask.view(-1)
    
    # Reconstruction loss using cross-entropy loss (ignoring padding tokens)
    reconstruction_loss = torch.sum(
        F.cross_entropy(output_sequence, target_sequence, reduction='none') #* mask
    ) / torch.sum(mask)

    # KL divergence loss
    kl_divergence_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
    
    # Total VAE loss
    
    return reconstruction_loss, kl_divergence_loss
    
class Model(nn.Module):

    def __init__(self, embedding_dim, hidden_dim, latent_dim, layers):
        super().__init__()
        self.encoder = SequenceEncoder(vocab_size=6, embedding_dim=embedding_dim, hidden_dim=hidden_dim, num_layers=layers, output_dim=latent_dim)
        self.decoder = SequenceDecoder(vocab_size=6, latent_dim=latent_dim, hidden_dim=hidden_dim, num_layers=layers, embedding_module=self.encoder.embedding)
        self.classifier = SequenceClassifier(latent_dim=latent_dim)

    def forward(self, x):
        B, T = x.shape
        mean, var = self.encoder(x)
        z = sample_from_latent_space(mean, var)
        # out = self.decoder(z, seq_length=T)
        out = self.classifier(z)
        return mean, var, out 


In [64]:
config = {
    'seq_length': None,
    'feature_size': 2,
    'batch_size': 64,
    'workers': 16,
    'epochs': 1,
    'lr': 1e-3,

    # model architecture
    'embedding': 4,
    'hidden': 8,
    'latent_dim': 64,
    'layers': 4
}
bostoken=4
padtoken=5

In [65]:
dataset = ActionDataset("/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav/offline_10envs/")
dataloader = DataLoader(dataset, collate_fn=lambda t: pad_sequence(t, batch_first=True, padding_value=padtoken), batch_size=config['batch_size'], num_workers=config['workers'])
model = Model(config['embedding'], config['hidden'], config['latent_dim'], config['layers'])
print("Num params", sum([p.numel() for p in model.parameters()]))

Num params 91713


In [66]:
optim = torch.optim.Adam(model.parameters(), lr=config['lr'])
model = model.cuda()

ckpt = torch.load('skill_vae.pt')
# model.load_state_dict(ckpt['model'])
# optim.load_state_dict(ckpt['optim'])
for epoch in range(10):
    total_loss = 0
    rloss = 0
    klloss = 0
    steps = 0
    for actions in (pbar := tqdm(dataloader)):
        actions = actions.to('cuda')
        optim.zero_grad()
        
        mask = (actions != padtoken).clone().to(torch.bool) 
        mean, var, pred_actions = model(actions)

        actlabels = torch.tensor([mode2class[sum(actions[i] * mask[i]).item()] for i in range(actions.shape[0])]).cuda()
        
        reconstruction_loss, kl_divergence_loss = vae_loss(pred_actions, actlabels, mean, var, mask)
        loss = reconstruction_loss +  kl_divergence_loss
        
        total_loss += loss.cpu().item()
        rloss += reconstruction_loss.cpu().item()
        klloss += kl_divergence_loss.cpu().item()
        
        loss.backward()
        optim.step()

        steps += 1

        pbar.set_description(f"Avg Loss: {total_loss / steps} ; R: {rloss / steps} ; K: {klloss / steps}")
        if steps % 25 == 0:
            torch.save({'model': model.state_dict(), 'optim': optim.state_dict()}, 'skill_vae2.pt')
            # print(pred_actions[0])


  return F.softmax(self.linear(x))
Avg Loss: 10.696387437941814 ; R: 0.049399474093100935 ; K: 10.646987920966135: 100%|█| 714
Avg Loss: 0.04939724425120013 ; R: 0.04925078407702159 ; K: 0.00014646017417854288: 100%|█|
Avg Loss: 0.04926701729578712 ; R: 0.049240853540352174 ; K: 2.6163755434949252e-05:  59%|▌


KeyboardInterrupt: 

In [67]:
dataset = ActionDataset("/srv/flash1/pputta7/projects/lm-nav/data/datasets/lmnav/offline_10envs/")
dataloader = DataLoader(dataset, collate_fn=lambda t: pad_sequence(t, batch_first=True, padding_value=padtoken), batch_size=1, num_workers=1)
model = Model(config['embedding'], config['hidden'], config['latent_dim'], config['layers']).cuda()
model.load_state_dict(torch.load('skill_vae2.pt')['model'])
print("Num params", sum([p.numel() for p in model.parameters()]))

Num params 91713


In [68]:
embeds = []
for i, actions in (pbar := enumerate(tqdm(dataloader))):
    mean, var, pred_actions = model(actions.cuda())
    print(pred_actions, mode2class[actions.sum().item()])
    if i == 23:
        break

  return F.softmax(self.linear(x))
  0%|                                                  | 1/45650 [00:00<1:47:03,  7.11it/s]

tensor([[0.3188, 0.3741, 0.3071]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 0
tensor([[0.3100, 0.3390, 0.3510]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 1
tensor([[0.4400, 0.2721, 0.2878]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 2
tensor([[0.3397, 0.3354, 0.3249]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 0
tensor([[0.3314, 0.4060, 0.2626]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 1
tensor([[0.3068, 0.3993, 0.2939]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 2
tensor([[0.3349, 0.3653, 0.2998]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 0
tensor([[0.3458, 0.3545, 0.2998]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 1
tensor([[0.4320, 0.2641, 0.3039]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 2
tensor([[0.3522, 0.3182, 0.3296]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 0
tensor([[0.2824, 0.3719, 0.3457]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 1
tensor([[0.3189, 0.3725, 0.3087]], device='cuda:0', grad_fn=<SoftmaxBackward0>) 2
tensor([[0.3900,

In [16]:
print(torch.argmax(pred_actions, dim=2))

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1,
         1]], device='cuda:0')


In [17]:
actions

tensor([[2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3, 2, 3,
         1]])

  0%|                                                   | 18/45650 [00:20<13:06, 58.00it/s]

In [37]:
pred_actions

tensor([[[5.3743e-06, 9.9899e-01, 9.8306e-04, 9.5443e-06, 6.7561e-06,
          8.1095e-06],
         [6.0537e-07, 9.9985e-01, 1.4796e-04, 2.8720e-07, 6.5884e-07,
          7.8119e-07],
         [4.0390e-07, 9.9992e-01, 8.1917e-05, 9.9207e-08, 4.4757e-07,
          5.0765e-07],
         [3.6830e-07, 9.9994e-01, 6.2958e-05, 5.9764e-08, 4.1627e-07,
          4.6077e-07],
         [3.6063e-07, 9.9994e-01, 5.4598e-05, 4.3995e-08, 4.1177e-07,
          4.5053e-07],
         [3.6005e-07, 9.9995e-01, 5.0083e-05, 3.5629e-08, 4.1325e-07,
          4.4904e-07],
         [3.6141e-07, 9.9995e-01, 4.7332e-05, 3.0450e-08, 4.1625e-07,
          4.4986e-07],
         [3.6318e-07, 9.9995e-01, 4.5525e-05, 2.6928e-08, 4.1953e-07,
          4.5123e-07],
         [3.6484e-07, 9.9995e-01, 4.4285e-05, 2.4373e-08, 4.2272e-07,
          4.5260e-07],
         [3.6624e-07, 9.9996e-01, 4.3413e-05, 2.2430e-08, 4.2566e-07,
          4.5376e-07],
         [3.6731e-07, 9.9996e-01, 4.2796e-05, 2.0892e-08, 4.2827e-07,
