In [40]:
from util_get_seq_emb_data import get_seq_emb_dataset

# Load every sequence via the project-local helper `get_seq_emb_dataset`.
# NOTE(review): the result is later passed as `data` to MySequenceDataset, whose
# docstring describes it as a list of sequences, each a list of (1, 8) tensors --
# confirm against the helper, including which device the tensors live on.
all_seqs = get_seq_emb_dataset()

100%|██████████| 588/588 [00:38<00:00, 15.35it/s]


In [88]:
# --- Data ---
# NOTE(review): `seq_len` is not read by any other cell visible in this notebook
# (collate_fn and GRUVAE.decode use their own local `seq_len`); kept for compatibility.
seq_len = 5
batch_size = 32      # mini-batch size handed to both DataLoaders

# --- Model (forwarded to the GRUVAE constructor) ---
input_size = 8       # feature dimension of each timestep embedding
hidden_size = 100    # GRU hidden-state width
latent_size = 20     # dimension of the latent vector z
num_layers = 1       # stacked GRU layers in encoder and decoder

# --- Optimisation ---
num_epochs = 300     # upper bound on training epochs

In [3]:
from torch.utils.data import Dataset

class MySequenceDataset(Dataset):
    """Map-style dataset yielding one sequence tensor per index, with a stop token appended."""

    def __init__(self, data, stop_token):
        """
        data: list of sequences; each sequence is a list of (1, 8)-shaped tensors
        stop_token: a (1, 8)-shaped tensor appended as the last row of every item
        """
        self.data = data
        self.stop_token = stop_token

    def __len__(self):
        """Number of sequences in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return sequence `idx` as a single (num_steps + 1, 8) tensor (steps, then stop token)."""
        # Stack the per-step (1, 8) rows into one (num_steps, 8) tensor.
        steps = torch.cat(self.data[idx], dim=0)
        # Append the stop token as the final row.
        return torch.cat((steps, self.stop_token), dim=0)


In [4]:
def collate_fn(batch):
    """
    Zero-pad a list of variable-length sequences into one batch tensor.

    batch: list of tensors [seq1, seq2, ..., seqB]
           where each seq is shape (seq_len_i, feature_dim).
    Returns a tensor of shape (B, max(seq_len_i), feature_dim).

    Padding is created from each sequence itself, so it always matches that
    sequence's device, dtype and feature width (no hard-coded CUDA / width 8).

    Raises:
        ValueError: if `batch` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")

    # 1) Find max length
    max_len = max(seq.size(0) for seq in batch)

    # 2) Pad every shorter sequence with trailing zero rows
    padded = []
    for seq in batch:
        pad_len = max_len - seq.size(0)
        if pad_len > 0:
            # new_zeros inherits dtype and device from `seq`
            filler = seq.new_zeros((pad_len, *seq.shape[1:]))
            seq = torch.cat([seq, filler], dim=0)  # (max_len, feature_dim)
        padded.append(seq)

    # 3) Stack => (B, T, feature_dim)
    return torch.stack(padded, dim=0)


In [38]:
import torch
from torch.utils.data import DataLoader, random_split

# All-zero (1, 8) row appended to the end of every sequence by MySequenceDataset.
STOP_TOKEN = torch.zeros((1, 8)).cuda()

dataset = MySequenceDataset(all_seqs, STOP_TOKEN)

# 80 / 20 train / validation split; validation takes whatever is left after rounding down.
train_size = int(len(dataset) * 0.8)
val_size = len(dataset) - train_size

# Fixed generator seed so the split is the same on every run.
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42),
)

# Training batches are reshuffled each epoch; validation order stays fixed.
# Both loaders pad variable-length sequences through collate_fn.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)



In [89]:
import torch.nn as nn
import torch.optim as optim

class GRUVAE(nn.Module):
    """
    Sequence variational autoencoder built from two GRUs.

    The encoder GRU reads the whole input sequence and its final top-layer
    hidden state is projected to the mean and log-variance of a diagonal
    Gaussian over the latent vector z. The decoder GRU is initialised from z
    and unrolled autoregressively, feeding each output back as the next input.
    """

    def __init__(
        self,
        input_size: int,   # dimension of each embedding at a timestep
        hidden_size: int,  # hidden dimension of GRU
        latent_size: int,  # dimension of latent space z
        num_layers: int = 1,
    ):
        super().__init__()

        self.num_layers = num_layers

        # Encoder: sequence -> final hidden state
        self.encoder_gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )

        # Final hidden state -> parameters of q(z | x)
        self.hidden2mu = nn.Linear(hidden_size, latent_size)
        self.hidden2logvar = nn.Linear(hidden_size, latent_size)

        # Decoder: z -> initial hidden state -> sequence
        self.latent2hidden = nn.Linear(latent_size, hidden_size)
        self.decoder_gru = nn.GRU(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True
        )
        self.h2output = nn.Linear(hidden_size, input_size)

    def encode(self, x):
        """
        Encode a batch of sequences into latent Gaussian parameters.
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            mu, logvar: each (batch_size, latent_size)
        """
        # We only need the final hidden state from the GRU
        _, h_n = self.encoder_gru(x)

        # h_n: (num_layers, batch_size, hidden_size); keep only the top layer
        h_n_top = h_n[-1]  # shape: (batch_size, hidden_size)

        mu = self.hidden2mu(h_n_top)
        logvar = self.hidden2logvar(h_n_top)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * std with eps ~ N(0, I), keeping the sample differentiable."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z, seq_len):
        """
        Decodes a latent vector z into a sequence of length seq_len.
        Args:
            z: (batch_size, latent_size)
            seq_len: int, the length of the output sequence
        Returns:
            outputs: (batch_size, seq_len, input_size); an empty
            (batch_size, 0, input_size) tensor when seq_len <= 0.
        """
        # Transform latent vector to initial hidden state for GRU,
        # replicated for every decoder layer.
        hidden = self.latent2hidden(z)         # (batch_size, hidden_size)
        hidden = hidden.unsqueeze(0).repeat(self.num_layers, 1, 1)

        batch_size = z.size(0)
        out_features = self.h2output.out_features

        # Nothing to generate: return an empty sequence instead of failing
        # on torch.cat of an empty list.
        if seq_len <= 0:
            return hidden.new_zeros(batch_size, 0, out_features)

        # Start with a zero vector as the first input. new_zeros inherits the
        # hidden state's dtype and device, so non-float32 models work too.
        input_step = hidden.new_zeros(batch_size, 1, out_features)

        # Generate the sequence step by step.
        outputs = []
        for _ in range(seq_len):
            # Pass one step at a time
            out, hidden = self.decoder_gru(input_step, hidden)
            # out: (batch_size, 1, hidden_size)
            # Project back to embedding dimension
            step_output = self.h2output(out)   # (batch_size, 1, input_size)
            outputs.append(step_output)

            # The next input is the current output (autoregressive decoding)
            input_step = step_output

        # Concatenate along seq_len dimension
        return torch.cat(outputs, dim=1)       # (batch_size, seq_len, input_size)

    def forward(self, x):
        """
        Encode x, sample z, and reconstruct a sequence of the same length.
        Args:
            x: (batch_size, seq_len, input_size)
        Returns:
            recon_x: (batch_size, seq_len, input_size)
            mu, logvar: each (batch_size, latent_size)
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        seq_len = x.size(1)
        recon_x = self.decode(z, seq_len)
        return recon_x, mu, logvar


In [128]:
def loss_function(recon_x, x, mu, logvar, kl_weight=0.0001):
    """
    Computes the VAE loss = Reconstruction Loss + kl_weight * KL Divergence.
    Here we use MSE for reconstruction (for continuous data).

    Args:
        recon_x: reconstructed batch, same shape as x
        x: target batch
        mu: latent means
        logvar: latent log-variances, same shape as mu
        kl_weight: scalar multiplier on the KL term (default 0.0001)
    Returns:
        Scalar tensor with the combined loss.

    Note: the reconstruction term is a mean over every element of x (including
    any padding the caller added), while the KL term is a sum over all elements
    of mu/logvar, so the effective KL weight grows with batch size.
    """
    # MSE Reconstruction Loss (mean over all elements)
    recon_loss = nn.functional.mse_loss(recon_x, x, reduction='mean')

    # KL Divergence to a standard normal prior
    # D_KL = -0.5 * sum(1 + logvar - mu^2 - exp(logvar))
    kl_div = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + kl_weight * kl_div


In [129]:
# Build the GRU-VAE from the notebook hyperparameters and move it to the GPU.
model = GRUVAE(
    input_size=input_size,
    hidden_size=hidden_size,
    latent_size=latent_size,
    num_layers=num_layers,
).cuda()

In [130]:
# Adam over every model parameter, learning rate 5e-4.
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)

In [132]:

# Begin in training mode; every epoch switches to eval for validation and back again.
model.train()

for epoch in range(num_epochs):
    # ---------------- training pass ----------------
    total_train_loss = 0.0

    for batch_data in train_loader:
        batch_data = batch_data.cuda()

        recon_x, mu, logvar = model(batch_data)
        loss = loss_function(recon_x, batch_data, mu, logvar)

        # Clear stale gradients, backpropagate, then update the weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    # Mean of the per-batch losses (not weighted by batch size).
    avg_train_loss = total_train_loss / len(train_loader)

    # ---------------- validation pass ----------------
    model.eval()
    total_val_loss = 0.0

    with torch.no_grad():  # no gradient bookkeeping during evaluation
        for val_data in val_loader:
            val_data = val_data.cuda()

            recon_x, mu, logvar = model(val_data)
            val_loss = loss_function(recon_x, val_data, mu, logvar)
            total_val_loss += val_loss.item()

    avg_val_loss = total_val_loss / len(val_loader)

    # Restore training mode for the next epoch.
    model.train()

    print(f"Epoch [{epoch+1}/{num_epochs}] | "
            f"Train Loss: {avg_train_loss:.4f} | "
            f"Val Loss: {avg_val_loss:.4f}")



Epoch [1/300] | Train Loss: 0.2746 | Val Loss: 0.3769
Epoch [2/300] | Train Loss: 0.2703 | Val Loss: 0.4035
Epoch [3/300] | Train Loss: 0.2805 | Val Loss: 0.3811
Epoch [4/300] | Train Loss: 0.2614 | Val Loss: 0.3780
Epoch [5/300] | Train Loss: 0.2869 | Val Loss: 0.4030
Epoch [6/300] | Train Loss: 0.2783 | Val Loss: 0.3738
Epoch [7/300] | Train Loss: 0.2666 | Val Loss: 0.3826
Epoch [8/300] | Train Loss: 0.2553 | Val Loss: 0.3687
Epoch [9/300] | Train Loss: 0.2416 | Val Loss: 0.3784
Epoch [10/300] | Train Loss: 0.2644 | Val Loss: 0.3693
Epoch [11/300] | Train Loss: 0.2466 | Val Loss: 0.3706
Epoch [12/300] | Train Loss: 0.2430 | Val Loss: 0.3845
Epoch [13/300] | Train Loss: 0.2829 | Val Loss: 0.3606
Epoch [14/300] | Train Loss: 0.2891 | Val Loss: 0.3879
Epoch [15/300] | Train Loss: 0.2628 | Val Loss: 0.3598
Epoch [16/300] | Train Loss: 0.2464 | Val Loss: 0.3589
Epoch [17/300] | Train Loss: 0.2482 | Val Loss: 0.3530
Epoch [18/300] | Train Loss: 0.2824 | Val Loss: 0.3901
Epoch [19/300] | Tr

KeyboardInterrupt: 

In [92]:
# Total number of scalar parameters across every tensor registered on the model.
sum(param.numel() for param in model.parameters())

72948

In [133]:
# Destination file for the checkpoint (relative to the current working directory).
PATH = "pt_model.pth"

# Persist only the weights (state_dict), not the pickled module object;
# reloading requires constructing a GRUVAE first and calling load_state_dict.
torch.save(obj=model.state_dict(), f=PATH)