In [1]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm  
from dataloader import *
from model import * 
from nltk.translate.bleu_score import sentence_bleu
import nltk 
import torch.nn as nn
import torch.nn.functional as F


# Hyperparameters
num_epochs = 10
learning_rate = 0.001
target_confidence = 0.8 

In [31]:
def collate_fn(batch):
    inputs, labels = zip(*batch)
    max_length = max(len(seq) for seq in inputs)
    
    # Convert each sequence to a list, pad with 0, and convert to tensor
    padded_inputs = [torch.cat([seq, torch.zeros(max_length - len(seq), dtype=torch.long)]) for seq in inputs]
    lengths = [len(seq) for seq in inputs]
    
    return torch.stack(padded_inputs), torch.tensor(labels, dtype=torch.float), lengths

def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]  # Exclude padding


class TextDatasetTest(Dataset):
    def __init__(self, data_dir, vocab, files):
        super(TextDatasetTest, self).__init__()
        self.data = []
        self.vocab = vocab 

        for filename in files:
            file_path = os.path.join(data_dir, filename)
            with open(file_path, 'r', encoding='utf-8') as f:
                lines = f.readlines()
                for line in lines:
                    tokens = line.strip().split()
                    label = 1 if filename.endswith('.1') else 0  # Binary label
                    self.data.append((tokens, label))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        tokens, label = self.data[idx]
        token_ids = [self.vocab.get(token, self.vocab['<UNK>']) for token in tokens]
        return torch.tensor(token_ids, dtype=torch.long), torch.tensor(label, dtype=torch.long)
     
    
files = ["sentiment.train.0", "sentiment.train.1"]
files_test = ["sentiment.test.0", "sentiment.test.1"]

data_dir = "./data/sentiment_style_transfer/yelp"
vocab = build_vocab(data_dir)
dataset = TextDatasetTest(data_dir, vocab, files)
dataset_test = TextDatasetTest(data_dir, vocab, files_test)
data_loader = DataLoader(dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)
data_loader_test = DataLoader(dataset_test, batch_size=64, collate_fn=collate_fn, shuffle=False)

In [23]:
len(dataset)

443259

In [24]:
class DisentangledVAE(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, style_dim, content_dim):
        super(DisentangledVAE, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.hidden_dim = hidden_dim
        
        # Encoder
        self.encoder_rnn = nn.GRU(embedding_dim, hidden_dim, 
                                 batch_first=True, 
                                 bidirectional=True,
                                 num_layers=2,
                                 dropout=0.2)
        
        # Latent spaces
        self.style_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        self.content_encoder = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        # Mean and logvar projections
        self.style_mu = nn.Linear(hidden_dim, style_dim)
        self.style_logvar = nn.Linear(hidden_dim, style_dim)
        self.content_mu = nn.Linear(hidden_dim, content_dim)
        self.content_logvar = nn.Linear(hidden_dim, content_dim)
        
        # Decoder
        self.latent_to_hidden = nn.Sequential(
            nn.Linear(style_dim + content_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2)
        )
        
        self.decoder_rnn = nn.GRU(embedding_dim + hidden_dim, hidden_dim,
                                 batch_first=True,
                                 num_layers=2,
                                 dropout=0.2)
        
        self.output_fc = nn.Linear(hidden_dim, vocab_size)
        
        # Style classifier for adversarial training
        self.style_classifier = nn.Sequential(
            nn.Linear(style_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )
        
    def encode(self, x, lengths=None):
        batch_size = x.size(0)
        
        # Embed input
        embedded = self.embedding(x)
        
        # Pack for variable length sequences
        if lengths is not None:
            embedded = nn.utils.rnn.pack_padded_sequence(
                embedded, lengths, batch_first=True, enforce_sorted=False
            )
        
        # Encode
        _, hidden = self.encoder_rnn(embedded)
        # Combine bidirectional states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        
        # Encode style and content
        style_hidden = self.style_encoder(hidden)
        content_hidden = self.content_encoder(hidden)
        
        # Get latent parameters
        style_mu = self.style_mu(style_hidden)
        style_logvar = self.style_logvar(style_hidden)
        content_mu = self.content_mu(content_hidden)
        content_logvar = self.content_logvar(content_hidden)
        
        return style_mu, style_logvar, content_mu, content_logvar

    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
        
    def decode(self, style, content, x):
        batch_size = x.size(0)
        max_len = x.size(1)
        
        # Combine latent vectors
        latent = torch.cat([style, content], dim=1)
        hidden = self.latent_to_hidden(latent)
        
        # Initialize decoder hidden state
        hidden = hidden.unsqueeze(0).repeat(2, 1, 1)  # num_layers * batch * hidden
        
        # Teacher forcing with concatenated latent
        embedded = self.embedding(x)
        hidden_expanded = hidden[-1].unsqueeze(1).repeat(1, max_len, 1)
        decoder_input = torch.cat([embedded, hidden_expanded], dim=2)
        
        # Decode
        outputs, _ = self.decoder_rnn(decoder_input, hidden)
        outputs = self.output_fc(outputs)
        
        return outputs

    def forward(self, x, lengths=None):
        # Encode
        style_mu, style_logvar, content_mu, content_logvar = self.encode(x, lengths)
        
        # Sample latent vectors
        style = self.reparameterize(style_mu, style_logvar)
        content = self.reparameterize(content_mu, content_logvar)
        
        # Decode
        recon_x = self.decode(style, content, x)
        
        return recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content

    def classify_style(self, style):
        return self.style_classifier(style)

def vae_loss(recon_x, x, style_mu, style_logvar, content_mu, content_logvar):
    recon_loss = F.cross_entropy(recon_x.view(-1, recon_x.size(-1)), x.view(-1), ignore_index=0)  # Reconstruction loss
    kl_style = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp())  # KL divergence for style
    kl_content = -0.5 * torch.sum(1 + content_logvar - content_mu.pow(2) - content_logvar.exp())  # KL divergence for content
    return recon_loss + kl_style + kl_content

def multi_task_loss(style_preds, style_labels, content_preds, content_labels):
    style_loss = F.cross_entropy(style_preds, style_labels)  # Style classification loss
    content_loss = F.cross_entropy(content_preds, content_labels)  # Content classification loss
    return style_loss + content_loss

def adversarial_loss(style_preds, content_preds):
    adversarial_style_loss = -F.cross_entropy(style_preds, torch.zeros_like(style_preds))  # Fool style classifier
    adversarial_content_loss = -F.cross_entropy(content_preds, torch.zeros_like(content_preds))  # Fool content classifier
    return adversarial_style_loss + adversarial_content_loss


In [25]:
len(data_loader)

6926

In [26]:
def train_vae_with_dataset(vae, optimizer, data_loader, device):
    """
    Train the VAE with a custom dataset.
    """
    vae.train()
    total_loss = 0
    num_batches = len(data_loader)
    
    print(f"\nTraining on {len(data_loader.dataset)} examples with {num_batches} batches")
    
    # Progress bar
    progress = tqdm(data_loader, total=num_batches, desc="Training")
    
    batch_count = 0
    for input_tokens, style_labels, lengths in progress:
        batch_count += 1
        batch_size = input_tokens.size(0)
        
        # Log batch information
        if batch_count == 1:
            print(f"\nFirst batch shapes:")
            print(f"Input tokens: {input_tokens.shape}")
            print(f"Style labels: {style_labels.shape}")
            print(f"Sequence lengths: {lengths[:5]} (showing first 5)")
        
        # Move data to device
        input_tokens = input_tokens.to(device)
        style_labels = style_labels.to(device)
        
        # Forward pass
        recon_x, style_mu, style_logvar, content_mu, content_logvar, style, content = vae(input_tokens, lengths)
        
        # Calculate losses
        loss_vae = F.cross_entropy(recon_x.view(-1, recon_x.size(-1)), input_tokens.view(-1), ignore_index=0)
        kl_style = -0.5 * torch.sum(1 + style_logvar - style_mu.pow(2) - style_logvar.exp())
        kl_content = -0.5 * torch.sum(1 + content_logvar - content_mu.pow(2) - content_logvar.exp())
        
        style_preds = vae.classify_style(style).squeeze()
        style_labels = style_labels.float()
        loss_multi_task = F.binary_cross_entropy(style_preds, style_labels)
        
        # Total loss with coefficient scaling
        loss = loss_vae + 0.1 * (kl_style + kl_content) + loss_multi_task
        
        # Optimization
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Update progress bar with current loss
        progress.set_postfix({'batch_loss': f'{loss.item():.4f}'})
        total_loss += loss.item()
        
        # Detailed loss logging every 100 batches
        if batch_count % 100 == 0:
            print(f"\nBatch {batch_count}/{num_batches}")
            print(f"Reconstruction Loss: {loss_vae.item():.4f}")
            print(f"KL Loss: {(kl_style + kl_content).item():.4f}")
            print(f"Style Loss: {loss_multi_task.item():.4f}")
    
    avg_loss = total_loss / num_batches
    print(f"\nCompleted epoch. Average loss: {avg_loss:.4f}")
    return avg_loss

# Training parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vocab_size = len(vocab)
embedding_dim = 256
hidden_dim = 512
style_dim = 32
content_dim = 256
learning_rate = 5e-4
epochs = 10

# Initialize model and optimizer
vae = DisentangledVAE(vocab_size, embedding_dim, hidden_dim, style_dim, content_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=learning_rate)

print(f"Starting training for {epochs} epochs...")
for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}")
    print("-" * 50)
    loss = train_vae_with_dataset(vae, optimizer, data_loader, device)
    print(f"Epoch {epoch + 1}/{epochs}, Loss: {loss:.4f}")


Starting training for 10 epochs...

Epoch 1/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 9/6926 [00:00<01:17, 89.28it/s, batch_loss=7.0265]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [6, 6, 11, 2, 7] (showing first 5)


Training:   2%|▏         | 116/6926 [00:01<01:08, 99.26it/s, batch_loss=2.5323]


Batch 100/6926
Reconstruction Loss: 2.4905
KL Loss: 0.1018
Style Loss: 0.6910


Training:   3%|▎         | 213/6926 [00:02<01:06, 100.41it/s, batch_loss=1.6083]


Batch 200/6926
Reconstruction Loss: 0.8596
KL Loss: 0.0078
Style Loss: 0.7021


Training:   5%|▍         | 312/6926 [00:03<01:05, 100.59it/s, batch_loss=1.1687]


Batch 300/6926
Reconstruction Loss: 0.7121
KL Loss: 0.0003
Style Loss: 0.6962


Training:   6%|▌         | 420/6926 [00:04<01:04, 100.34it/s, batch_loss=1.0416]


Batch 400/6926
Reconstruction Loss: 0.4625
KL Loss: 0.0003
Style Loss: 0.6734


Training:   7%|▋         | 517/6926 [00:05<01:07, 95.04it/s, batch_loss=1.1454] 


Batch 500/6926
Reconstruction Loss: 0.3419
KL Loss: 0.0002
Style Loss: 0.6695


Training:   9%|▉         | 618/6926 [00:06<01:04, 98.32it/s, batch_loss=0.9942]


Batch 600/6926
Reconstruction Loss: 0.2787
KL Loss: 0.7205
Style Loss: 0.7437


Training:  10%|█         | 720/6926 [00:07<01:02, 99.53it/s, batch_loss=0.7953]


Batch 700/6926
Reconstruction Loss: 0.2321
KL Loss: 0.0372
Style Loss: 0.6939


Training:  12%|█▏        | 817/6926 [00:08<01:00, 101.01it/s, batch_loss=0.9279]


Batch 800/6926
Reconstruction Loss: 0.2377
KL Loss: 0.0002
Style Loss: 0.6851


Training:  13%|█▎        | 916/6926 [00:09<00:58, 103.35it/s, batch_loss=0.8242]


Batch 900/6926
Reconstruction Loss: 0.1149
KL Loss: 0.0001
Style Loss: 0.6287


Training:  15%|█▍        | 1015/6926 [00:10<00:57, 103.30it/s, batch_loss=0.8207]


Batch 1000/6926
Reconstruction Loss: 0.1388
KL Loss: 0.0002
Style Loss: 0.7283


Training:  16%|█▌        | 1112/6926 [00:11<00:55, 104.75it/s, batch_loss=0.7217]


Batch 1100/6926
Reconstruction Loss: 0.1023
KL Loss: 0.0002
Style Loss: 0.6628


Training:  17%|█▋        | 1211/6926 [00:12<00:54, 104.15it/s, batch_loss=0.7705]


Batch 1200/6926
Reconstruction Loss: 0.0735
KL Loss: 0.0002
Style Loss: 0.6588


Training:  19%|█▉        | 1310/6926 [00:13<00:53, 104.74it/s, batch_loss=0.7479]


Batch 1300/6926
Reconstruction Loss: 0.0721
KL Loss: 0.0001
Style Loss: 0.6738


Training:  21%|██        | 1420/6926 [00:14<00:52, 105.12it/s, batch_loss=0.7667]


Batch 1400/6926
Reconstruction Loss: 0.0480
KL Loss: 0.0001
Style Loss: 0.6523


Training:  22%|██▏       | 1519/6926 [00:14<00:51, 104.55it/s, batch_loss=0.7451]


Batch 1500/6926
Reconstruction Loss: 0.0419
KL Loss: 0.0001
Style Loss: 0.6529


Training:  23%|██▎       | 1618/6926 [00:15<00:50, 104.76it/s, batch_loss=0.7922]


Batch 1600/6926
Reconstruction Loss: 0.1596
KL Loss: 0.0001
Style Loss: 0.6165


Training:  25%|██▍       | 1717/6926 [00:16<00:50, 103.08it/s, batch_loss=0.7108]


Batch 1700/6926
Reconstruction Loss: 0.0425
KL Loss: 0.0001
Style Loss: 0.6527


Training:  26%|██▌       | 1816/6926 [00:17<00:48, 105.45it/s, batch_loss=0.7470]


Batch 1800/6926
Reconstruction Loss: 0.0614
KL Loss: 0.0001
Style Loss: 0.6843


Training:  28%|██▊       | 1915/6926 [00:18<00:49, 101.82it/s, batch_loss=0.6818]


Batch 1900/6926
Reconstruction Loss: 0.0269
KL Loss: 0.0002
Style Loss: 0.7194


Training:  29%|██▉       | 2013/6926 [00:19<00:47, 103.24it/s, batch_loss=0.6845]


Batch 2000/6926
Reconstruction Loss: 0.0484
KL Loss: 0.0001
Style Loss: 0.7166


Training:  30%|███       | 2111/6926 [00:20<00:48, 98.45it/s, batch_loss=0.7068] 


Batch 2100/6926
Reconstruction Loss: 0.0384
KL Loss: 0.0001
Style Loss: 0.6804


Training:  32%|███▏      | 2220/6926 [00:21<00:46, 101.04it/s, batch_loss=0.6968]


Batch 2200/6926
Reconstruction Loss: 0.0294
KL Loss: 0.0001
Style Loss: 0.7015


Training:  33%|███▎      | 2317/6926 [00:22<00:45, 101.74it/s, batch_loss=0.6972]


Batch 2300/6926
Reconstruction Loss: 0.0360
KL Loss: 0.0001
Style Loss: 0.6604


Training:  35%|███▍      | 2411/6926 [00:23<00:45, 99.44it/s, batch_loss=0.7074] 


Batch 2400/6926
Reconstruction Loss: 0.0476
KL Loss: 0.0001
Style Loss: 0.7230


Training:  36%|███▋      | 2515/6926 [00:24<00:44, 99.40it/s, batch_loss=0.6774]


Batch 2500/6926
Reconstruction Loss: 0.0074
KL Loss: 0.0001
Style Loss: 0.6496


Training:  38%|███▊      | 2614/6926 [00:25<00:43, 100.21it/s, batch_loss=0.6761]


Batch 2600/6926
Reconstruction Loss: 0.0092
KL Loss: 0.0001
Style Loss: 0.6882


Training:  39%|███▉      | 2712/6926 [00:26<00:42, 98.19it/s, batch_loss=0.7229] 


Batch 2700/6926
Reconstruction Loss: 0.0101
KL Loss: 0.0001
Style Loss: 0.7335


Training:  41%|████      | 2815/6926 [00:27<00:41, 99.86it/s, batch_loss=0.6991] 


Batch 2800/6926
Reconstruction Loss: 0.0021
KL Loss: 0.0001
Style Loss: 0.6622


Training:  42%|████▏     | 2912/6926 [00:28<00:40, 99.58it/s, batch_loss=0.6008] 


Batch 2900/6926
Reconstruction Loss: 0.0421
KL Loss: 0.0001
Style Loss: 0.6713


Training:  44%|████▎     | 3014/6926 [00:29<00:39, 98.37it/s, batch_loss=0.6881]


Batch 3000/6926
Reconstruction Loss: 0.0062
KL Loss: 0.0001
Style Loss: 0.6820


Training:  45%|████▌     | 3120/6926 [00:30<00:37, 100.19it/s, batch_loss=0.7028]


Batch 3100/6926
Reconstruction Loss: 0.0029
KL Loss: 0.0001
Style Loss: 0.6859


Training:  46%|████▋     | 3219/6926 [00:31<00:37, 99.98it/s, batch_loss=0.7133] 


Batch 3200/6926
Reconstruction Loss: 0.0015
KL Loss: 0.0001
Style Loss: 0.6798


Training:  48%|████▊     | 3314/6926 [00:32<00:36, 99.16it/s, batch_loss=0.6696] 


Batch 3300/6926
Reconstruction Loss: 0.0052
KL Loss: 0.0001
Style Loss: 0.6420


Training:  49%|████▉     | 3411/6926 [00:33<00:35, 98.54it/s, batch_loss=0.6886] 


Batch 3400/6926
Reconstruction Loss: 0.0042
KL Loss: 0.0001
Style Loss: 0.6874


Training:  51%|█████     | 3517/6926 [00:34<00:34, 99.90it/s, batch_loss=0.6585]


Batch 3500/6926
Reconstruction Loss: 0.0013
KL Loss: 0.0001
Style Loss: 0.6698


Training:  52%|█████▏    | 3618/6926 [00:35<00:33, 99.00it/s, batch_loss=0.6665]


Batch 3600/6926
Reconstruction Loss: 0.0018
KL Loss: 0.0001
Style Loss: 0.6425


Training:  54%|█████▎    | 3710/6926 [00:36<00:32, 99.63it/s, batch_loss=0.7045]


Batch 3700/6926
Reconstruction Loss: 0.0021
KL Loss: 0.0002
Style Loss: 0.6774


Training:  55%|█████▌    | 3820/6926 [00:37<00:30, 100.48it/s, batch_loss=0.7250]


Batch 3800/6926
Reconstruction Loss: 0.0018
KL Loss: 0.0001
Style Loss: 0.6922


Training:  57%|█████▋    | 3918/6926 [00:38<00:30, 100.15it/s, batch_loss=0.6848]


Batch 3900/6926
Reconstruction Loss: 0.0058
KL Loss: 0.0000
Style Loss: 0.7232


Training:  58%|█████▊    | 4016/6926 [00:39<00:29, 98.76it/s, batch_loss=0.6778] 


Batch 4000/6926
Reconstruction Loss: 0.0014
KL Loss: 0.0001
Style Loss: 0.6597


Training:  59%|█████▉    | 4115/6926 [00:40<00:28, 100.04it/s, batch_loss=0.6901]


Batch 4100/6926
Reconstruction Loss: 0.0045
KL Loss: 0.0000
Style Loss: 0.6635


Training:  61%|██████    | 4214/6926 [00:41<00:26, 100.72it/s, batch_loss=0.6308]


Batch 4200/6926
Reconstruction Loss: 0.0017
KL Loss: 0.0001
Style Loss: 0.6862


Training:  62%|██████▏   | 4310/6926 [00:42<00:26, 100.15it/s, batch_loss=0.6581]


Batch 4300/6926
Reconstruction Loss: 0.0066
KL Loss: 0.0001
Style Loss: 0.6558


Training:  64%|██████▍   | 4420/6926 [00:43<00:24, 100.43it/s, batch_loss=0.6730]


Batch 4400/6926
Reconstruction Loss: 0.0010
KL Loss: 0.0001
Style Loss: 0.6538


Training:  65%|██████▌   | 4518/6926 [00:44<00:24, 100.25it/s, batch_loss=0.6050]


Batch 4500/6926
Reconstruction Loss: 0.0013
KL Loss: 0.0001
Style Loss: 0.6907


Training:  67%|██████▋   | 4617/6926 [00:45<00:22, 100.53it/s, batch_loss=0.6634]


Batch 4600/6926
Reconstruction Loss: 0.0007
KL Loss: 0.0000
Style Loss: 0.7141


Training:  68%|██████▊   | 4715/6926 [00:46<00:22, 100.14it/s, batch_loss=0.6989]


Batch 4700/6926
Reconstruction Loss: 0.0118
KL Loss: 0.0001
Style Loss: 0.6950


Training:  69%|██████▉   | 4812/6926 [00:47<00:21, 99.93it/s, batch_loss=0.6846] 


Batch 4800/6926
Reconstruction Loss: 0.0010
KL Loss: 0.0001
Style Loss: 0.7100


Training:  71%|███████   | 4911/6926 [00:48<00:20, 100.15it/s, batch_loss=0.6866]


Batch 4900/6926
Reconstruction Loss: 0.0006
KL Loss: 0.0001
Style Loss: 0.7157


Training:  72%|███████▏  | 5018/6926 [00:49<00:18, 100.51it/s, batch_loss=0.6747]


Batch 5000/6926
Reconstruction Loss: 0.0006
KL Loss: 0.0000
Style Loss: 0.6723


Training:  74%|███████▍  | 5115/6926 [00:50<00:18, 99.86it/s, batch_loss=0.7375] 


Batch 5100/6926
Reconstruction Loss: 0.0008
KL Loss: 0.0002
Style Loss: 0.6762


Training:  75%|███████▌  | 5211/6926 [00:51<00:17, 99.95it/s, batch_loss=0.6648] 


Batch 5200/6926
Reconstruction Loss: 0.0010
KL Loss: 0.0000
Style Loss: 0.6751


Training:  77%|███████▋  | 5320/6926 [00:52<00:15, 101.00it/s, batch_loss=0.7029]


Batch 5300/6926
Reconstruction Loss: 0.0006
KL Loss: 0.0001
Style Loss: 0.6613


Training:  78%|███████▊  | 5417/6926 [00:53<00:15, 99.39it/s, batch_loss=0.6786] 


Batch 5400/6926
Reconstruction Loss: 0.0006
KL Loss: 0.0001
Style Loss: 0.6928


Training:  80%|███████▉  | 5516/6926 [00:54<00:14, 100.16it/s, batch_loss=0.6748]


Batch 5500/6926
Reconstruction Loss: 0.0003
KL Loss: 0.0001
Style Loss: 0.6443


Training:  81%|████████  | 5615/6926 [00:55<00:12, 100.85it/s, batch_loss=0.6896]


Batch 5600/6926
Reconstruction Loss: 0.0005
KL Loss: 0.0001
Style Loss: 0.6791


Training:  82%|████████▏ | 5712/6926 [00:56<00:12, 98.71it/s, batch_loss=0.6738] 


Batch 5700/6926
Reconstruction Loss: 0.0006
KL Loss: 0.0001
Style Loss: 0.6222


Training:  84%|████████▍ | 5811/6926 [00:57<00:11, 100.94it/s, batch_loss=0.6433]


Batch 5800/6926
Reconstruction Loss: 0.0004
KL Loss: 0.0004
Style Loss: 0.7041


Training:  85%|████████▌ | 5910/6926 [00:58<00:10, 100.66it/s, batch_loss=0.6844]


Batch 5900/6926
Reconstruction Loss: 0.0007
KL Loss: -0.0000
Style Loss: 0.6629


Training:  87%|████████▋ | 6020/6926 [00:59<00:08, 100.85it/s, batch_loss=0.6264]


Batch 6000/6926
Reconstruction Loss: 0.0005
KL Loss: 0.0001
Style Loss: 0.6778


Training:  88%|████████▊ | 6119/6926 [01:00<00:07, 101.49it/s, batch_loss=0.6730]


Batch 6100/6926
Reconstruction Loss: 0.0004
KL Loss: 0.0001
Style Loss: 0.6240


Training:  90%|████████▉ | 6218/6926 [01:01<00:06, 102.11it/s, batch_loss=0.7037]


Batch 6200/6926
Reconstruction Loss: 0.0087
KL Loss: 0.0001
Style Loss: 0.6564


Training:  91%|█████████ | 6317/6926 [01:02<00:05, 101.90it/s, batch_loss=0.6822]


Batch 6300/6926
Reconstruction Loss: 0.0004
KL Loss: 0.0001
Style Loss: 0.6708


Training:  93%|█████████▎| 6416/6926 [01:03<00:05, 100.57it/s, batch_loss=0.6679]


Batch 6400/6926
Reconstruction Loss: 0.0003
KL Loss: 0.0001
Style Loss: 0.6725


Training:  94%|█████████▍| 6515/6926 [01:04<00:04, 101.86it/s, batch_loss=0.6426]


Batch 6500/6926
Reconstruction Loss: 0.0004
KL Loss: 0.0001
Style Loss: 0.6825


Training:  95%|█████████▌| 6614/6926 [01:05<00:03, 101.65it/s, batch_loss=0.6804]


Batch 6600/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6789


Training:  97%|█████████▋| 6713/6926 [01:06<00:02, 100.42it/s, batch_loss=0.7161]


Batch 6700/6926
Reconstruction Loss: 0.0004
KL Loss: 0.0001
Style Loss: 0.6815


Training:  98%|█████████▊| 6812/6926 [01:07<00:01, 101.50it/s, batch_loss=0.7052]


Batch 6800/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6497


Training: 100%|█████████▉| 6911/6926 [01:08<00:00, 99.60it/s, batch_loss=0.6599] 


Batch 6900/6926
Reconstruction Loss: 0.0008
KL Loss: 0.0001
Style Loss: 0.6838


Training: 100%|██████████| 6926/6926 [01:08<00:00, 100.71it/s, batch_loss=0.6927]



Completed epoch. Average loss: 0.8246
Epoch 1/10, Loss: 0.8246

Epoch 2/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:14, 92.86it/s, batch_loss=0.6487]


First batch shapes:
Input tokens: torch.Size([64, 14])
Style labels: torch.Size([64])
Sequence lengths: [7, 7, 14, 8, 8] (showing first 5)


Training:   2%|▏         | 120/6926 [00:01<01:07, 100.93it/s, batch_loss=0.6426]


Batch 100/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0000
Style Loss: 0.6511


Training:   3%|▎         | 219/6926 [00:02<01:06, 101.54it/s, batch_loss=0.6428]


Batch 200/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6524


Training:   5%|▍         | 318/6926 [00:03<01:04, 101.72it/s, batch_loss=0.6945]


Batch 300/6926
Reconstruction Loss: 0.0003
KL Loss: 0.0001
Style Loss: 0.6567


Training:   6%|▌         | 417/6926 [00:04<01:04, 101.68it/s, batch_loss=0.6565]


Batch 400/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6823


Training:   7%|▋         | 516/6926 [00:05<01:03, 101.19it/s, batch_loss=0.6437]


Batch 500/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6271


Training:   9%|▉         | 615/6926 [00:06<01:02, 100.43it/s, batch_loss=0.6944]


Batch 600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6563


Training:  10%|█         | 714/6926 [00:07<01:01, 100.69it/s, batch_loss=0.7168]


Batch 700/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6491


Training:  12%|█▏        | 813/6926 [00:08<01:01, 100.09it/s, batch_loss=0.7244]


Batch 800/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0000
Style Loss: 0.6654


Training:  13%|█▎        | 912/6926 [00:09<01:00, 100.15it/s, batch_loss=0.6856]


Batch 900/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0002
Style Loss: 0.7108


Training:  15%|█▍        | 1011/6926 [00:10<00:58, 101.58it/s, batch_loss=0.6458]


Batch 1000/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6803


Training:  16%|█▌        | 1110/6926 [00:11<00:57, 101.87it/s, batch_loss=0.6455]


Batch 1100/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6824


Training:  18%|█▊        | 1220/6926 [00:12<00:56, 101.70it/s, batch_loss=0.6731]


Batch 1200/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6734


Training:  19%|█▉        | 1319/6926 [00:13<00:55, 101.45it/s, batch_loss=0.6615]


Batch 1300/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6939


Training:  20%|██        | 1418/6926 [00:14<00:54, 101.97it/s, batch_loss=0.6712]


Batch 1400/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6424


Training:  22%|██▏       | 1517/6926 [00:15<00:53, 101.12it/s, batch_loss=0.6534]


Batch 1500/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.7080


Training:  23%|██▎       | 1616/6926 [00:16<00:51, 102.24it/s, batch_loss=0.7122]


Batch 1600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6825


Training:  25%|██▍       | 1715/6926 [00:16<00:50, 102.54it/s, batch_loss=0.6157]


Batch 1700/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6571


Training:  26%|██▌       | 1814/6926 [00:17<00:50, 102.00it/s, batch_loss=0.6886]


Batch 1800/6926
Reconstruction Loss: 0.0003
KL Loss: 0.0001
Style Loss: 0.6436


Training:  28%|██▊       | 1913/6926 [00:18<00:49, 101.52it/s, batch_loss=0.6963]


Batch 1900/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6754


Training:  29%|██▉       | 2012/6926 [00:19<00:48, 102.10it/s, batch_loss=0.6960]


Batch 2000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0003
Style Loss: 0.6820


Training:  30%|███       | 2111/6926 [00:20<00:47, 101.84it/s, batch_loss=0.6939]


Batch 2100/6926
Reconstruction Loss: 0.0141
KL Loss: 0.0001
Style Loss: 0.6450


Training:  32%|███▏      | 2210/6926 [00:21<00:46, 100.77it/s, batch_loss=0.6852]


Batch 2200/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6679


Training:  33%|███▎      | 2320/6926 [00:22<00:45, 101.38it/s, batch_loss=0.6708]


Batch 2300/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0002
Style Loss: 0.6897


Training:  35%|███▍      | 2419/6926 [00:23<00:44, 101.54it/s, batch_loss=0.6535]


Batch 2400/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6800


Training:  36%|███▋      | 2518/6926 [00:24<00:43, 101.06it/s, batch_loss=0.6607]


Batch 2500/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6986


Training:  38%|███▊      | 2617/6926 [00:25<00:42, 101.63it/s, batch_loss=0.7062]


Batch 2600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6607


Training:  39%|███▉      | 2716/6926 [00:26<00:41, 100.45it/s, batch_loss=0.7030]


Batch 2700/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6679


Training:  41%|████      | 2815/6926 [00:27<00:41, 100.26it/s, batch_loss=0.6859]


Batch 2800/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6708


Training:  42%|████▏     | 2912/6926 [00:28<00:40, 99.89it/s, batch_loss=0.6678] 


Batch 2900/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0003
Style Loss: 0.6380


Training:  44%|████▎     | 3020/6926 [00:29<00:39, 100.12it/s, batch_loss=0.5979]


Batch 3000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6778


Training:  45%|████▌     | 3119/6926 [00:30<00:37, 101.76it/s, batch_loss=0.6714]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6507


Training:  46%|████▋     | 3216/6926 [00:31<00:37, 98.97it/s, batch_loss=0.6628] 


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6421


Training:  48%|████▊     | 3314/6926 [00:32<00:36, 100.07it/s, batch_loss=0.6533]


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7120


Training:  49%|████▉     | 3419/6926 [00:33<00:35, 97.69it/s, batch_loss=0.6879] 


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6760


Training:  51%|█████     | 3517/6926 [00:34<00:33, 101.06it/s, batch_loss=0.6292]


Batch 3500/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6778


Training:  52%|█████▏    | 3616/6926 [00:35<00:32, 102.22it/s, batch_loss=0.6663]


Batch 3600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0000
Style Loss: 0.6939


Training:  54%|█████▎    | 3715/6926 [00:36<00:32, 100.18it/s, batch_loss=0.7161]


Batch 3700/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6632


Training:  55%|█████▌    | 3811/6926 [00:37<00:31, 97.46it/s, batch_loss=0.6515] 


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6998


Training:  56%|█████▋    | 3913/6926 [00:38<00:30, 97.92it/s, batch_loss=0.6604]


Batch 3900/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6772


Training:  58%|█████▊    | 4019/6926 [00:39<00:29, 99.15it/s, batch_loss=0.6666] 


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6830


Training:  59%|█████▉    | 4112/6926 [00:40<00:28, 100.46it/s, batch_loss=0.6959]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6987


Training:  61%|██████    | 4216/6926 [00:41<00:27, 98.45it/s, batch_loss=0.6535] 


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6689


Training:  62%|██████▏   | 4319/6926 [00:42<00:26, 99.37it/s, batch_loss=0.6600]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6530


Training:  64%|██████▎   | 4413/6926 [00:43<00:25, 99.46it/s, batch_loss=0.6293] 


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6746


Training:  65%|██████▌   | 4516/6926 [00:44<00:24, 98.41it/s, batch_loss=0.6777]


Batch 4500/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0002
Style Loss: 0.6442


Training:  67%|██████▋   | 4618/6926 [00:45<00:23, 98.13it/s, batch_loss=0.6838]


Batch 4600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6733


Training:  68%|██████▊   | 4710/6926 [00:46<00:22, 98.61it/s, batch_loss=0.6559]


Batch 4700/6926
Reconstruction Loss: 0.0018
KL Loss: 0.0001
Style Loss: 0.6606


Training:  70%|██████▉   | 4814/6926 [00:47<00:21, 99.99it/s, batch_loss=0.6461]


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6747


Training:  71%|███████   | 4913/6926 [00:48<00:19, 101.79it/s, batch_loss=0.6707]


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6778


Training:  72%|███████▏  | 5012/6926 [00:49<00:18, 101.81it/s, batch_loss=0.6916]


Batch 5000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6629


Training:  74%|███████▍  | 5111/6926 [00:50<00:18, 100.18it/s, batch_loss=0.6244]


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6864


Training:  75%|███████▌  | 5210/6926 [00:51<00:16, 102.07it/s, batch_loss=0.6564]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6627


Training:  77%|███████▋  | 5320/6926 [00:52<00:15, 101.25it/s, batch_loss=0.6732]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7141


Training:  78%|███████▊  | 5415/6926 [00:53<00:15, 99.69it/s, batch_loss=0.6698] 


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6657


Training:  80%|███████▉  | 5519/6926 [00:54<00:14, 100.22it/s, batch_loss=0.6098]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6695


Training:  81%|████████  | 5618/6926 [00:55<00:12, 101.58it/s, batch_loss=0.6798]


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6364


Training:  83%|████████▎ | 5717/6926 [00:56<00:11, 101.26it/s, batch_loss=0.6722]


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6713


Training:  84%|████████▍ | 5816/6926 [00:57<00:11, 99.75it/s, batch_loss=0.7227] 


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6837


Training:  85%|████████▌ | 5913/6926 [00:58<00:09, 101.59it/s, batch_loss=0.6465]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6655


Training:  87%|████████▋ | 6012/6926 [00:59<00:08, 101.99it/s, batch_loss=0.6477]


Batch 6000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6919


Training:  88%|████████▊ | 6111/6926 [01:00<00:07, 102.21it/s, batch_loss=0.6976]


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7128


Training:  90%|████████▉ | 6210/6926 [01:01<00:07, 102.23it/s, batch_loss=0.6761]


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6736


Training:  91%|█████████▏| 6320/6926 [01:02<00:05, 101.36it/s, batch_loss=0.6610]


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6505


Training:  93%|█████████▎| 6419/6926 [01:03<00:05, 101.12it/s, batch_loss=0.6720]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0005
Style Loss: 0.6604


Training:  94%|█████████▍| 6518/6926 [01:04<00:04, 101.01it/s, batch_loss=0.6890]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6905


Training:  96%|█████████▌| 6617/6926 [01:05<00:03, 100.94it/s, batch_loss=0.6895]


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6984


Training:  97%|█████████▋| 6716/6926 [01:06<00:02, 100.35it/s, batch_loss=0.6650]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6635


Training:  98%|█████████▊| 6813/6926 [01:07<00:01, 98.68it/s, batch_loss=0.6842] 


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6729


Training: 100%|█████████▉| 6912/6926 [01:08<00:00, 101.89it/s, batch_loss=0.7053]


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6778


Training: 100%|██████████| 6926/6926 [01:08<00:00, 100.66it/s, batch_loss=0.6547]



Completed epoch. Average loss: 0.6746
Epoch 2/10, Loss: 0.6746

Epoch 3/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:12, 95.09it/s, batch_loss=0.6636]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [10, 12, 12, 12, 7] (showing first 5)


Training:   2%|▏         | 120/6926 [00:01<01:07, 100.46it/s, batch_loss=0.6794]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6502


Training:   3%|▎         | 219/6926 [00:02<01:06, 100.90it/s, batch_loss=0.6832]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6483


Training:   5%|▍         | 318/6926 [00:03<01:04, 101.69it/s, batch_loss=0.7154]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6610


Training:   6%|▌         | 417/6926 [00:04<01:04, 101.13it/s, batch_loss=0.7012]


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 31.5431
Style Loss: 0.7070


Training:   7%|▋         | 514/6926 [00:05<01:04, 99.97it/s, batch_loss=0.7155] 


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6707


Training:   9%|▉         | 611/6926 [00:06<01:02, 100.54it/s, batch_loss=0.7091]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0031
Style Loss: 0.6749


Training:  10%|█         | 710/6926 [00:07<01:00, 102.19it/s, batch_loss=0.6287]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6430


Training:  12%|█▏        | 820/6926 [00:08<01:00, 101.58it/s, batch_loss=0.6785]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6894


Training:  13%|█▎        | 919/6926 [00:09<00:59, 101.07it/s, batch_loss=0.6368]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6508


Training:  15%|█▍        | 1018/6926 [00:10<00:58, 101.56it/s, batch_loss=0.7367]


Batch 1000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6538


Training:  16%|█▌        | 1117/6926 [00:11<00:56, 102.49it/s, batch_loss=0.6818]


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.7260


Training:  18%|█▊        | 1216/6926 [00:12<00:56, 101.58it/s, batch_loss=0.6597]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6551


Training:  19%|█▉        | 1314/6926 [00:13<00:55, 101.10it/s, batch_loss=0.6981]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6543


Training:  20%|██        | 1413/6926 [00:14<00:54, 100.86it/s, batch_loss=0.6480]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6972


Training:  22%|██▏       | 1512/6926 [00:15<00:54, 98.96it/s, batch_loss=0.7217] 


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6589


Training:  23%|██▎       | 1611/6926 [00:16<00:52, 100.90it/s, batch_loss=0.6645]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6836


Training:  25%|██▍       | 1717/6926 [00:17<00:52, 99.74it/s, batch_loss=0.6773] 


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6520


Training:  26%|██▌       | 1815/6926 [00:18<00:51, 99.40it/s, batch_loss=0.7237] 


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6805


Training:  28%|██▊       | 1919/6926 [00:19<00:50, 98.40it/s, batch_loss=0.6413]


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6494


Training:  29%|██▉       | 2012/6926 [00:20<00:49, 98.97it/s, batch_loss=0.6481]


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6652


Training:  30%|███       | 2111/6926 [00:21<00:47, 101.42it/s, batch_loss=0.6615]


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7230


Training:  32%|███▏      | 2210/6926 [00:22<00:46, 101.36it/s, batch_loss=0.7139]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6857


Training:  33%|███▎      | 2309/6926 [00:23<00:46, 100.08it/s, batch_loss=0.6545]


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6747


Training:  35%|███▍      | 2417/6926 [00:24<00:44, 100.95it/s, batch_loss=0.7108]


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6631


Training:  36%|███▋      | 2516/6926 [00:25<00:43, 100.91it/s, batch_loss=0.6288]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7340


Training:  38%|███▊      | 2615/6926 [00:26<00:42, 101.67it/s, batch_loss=0.6256]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6435


Training:  39%|███▉      | 2712/6926 [00:27<00:42, 99.88it/s, batch_loss=0.6866] 


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7332


Training:  41%|████      | 2811/6926 [00:28<00:40, 101.44it/s, batch_loss=0.6734]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7139


Training:  42%|████▏     | 2910/6926 [00:28<00:39, 102.06it/s, batch_loss=0.6911]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6727


Training:  44%|████▎     | 3020/6926 [00:29<00:38, 101.67it/s, batch_loss=0.6534]


Batch 3000/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6408


Training:  45%|████▌     | 3119/6926 [00:30<00:37, 102.41it/s, batch_loss=0.6527]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6443


Training:  46%|████▋     | 3218/6926 [00:31<00:36, 102.10it/s, batch_loss=0.6625]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6868


Training:  48%|████▊     | 3317/6926 [00:32<00:36, 98.50it/s, batch_loss=0.7127] 


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6421


Training:  49%|████▉     | 3416/6926 [00:33<00:34, 101.18it/s, batch_loss=0.7285]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6616


Training:  51%|█████     | 3515/6926 [00:34<00:34, 99.76it/s, batch_loss=0.6883] 


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6726


Training:  52%|█████▏    | 3610/6926 [00:35<00:32, 100.55it/s, batch_loss=0.6780]


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6777


Training:  54%|█████▎    | 3718/6926 [00:36<00:32, 100.11it/s, batch_loss=0.6569]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0026
Style Loss: 0.6708


Training:  55%|█████▌    | 3814/6926 [00:37<00:31, 98.54it/s, batch_loss=0.6104] 


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6547


Training:  56%|█████▋    | 3910/6926 [00:38<00:30, 98.61it/s, batch_loss=0.7190] 


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6532


Training:  58%|█████▊    | 4016/6926 [00:39<00:29, 98.93it/s, batch_loss=0.6444] 


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6382


Training:  59%|█████▉    | 4119/6926 [00:40<00:28, 97.23it/s, batch_loss=0.6993]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6786


Training:  61%|██████    | 4217/6926 [00:41<00:26, 100.89it/s, batch_loss=0.6698]


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6777


Training:  62%|██████▏   | 4313/6926 [00:42<00:26, 100.31it/s, batch_loss=0.6762]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6511


Training:  64%|██████▍   | 4420/6926 [00:43<00:25, 99.71it/s, batch_loss=0.6812] 


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6489


Training:  65%|██████▌   | 4517/6926 [00:44<00:24, 99.99it/s, batch_loss=0.6376] 


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6797


Training:  67%|██████▋   | 4618/6926 [00:45<00:23, 98.35it/s, batch_loss=0.6464]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6849


Training:  68%|██████▊   | 4711/6926 [00:47<00:22, 99.59it/s, batch_loss=0.6866]


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6929


Training:  70%|██████▉   | 4818/6926 [00:48<00:21, 100.06it/s, batch_loss=0.6323]


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6609


Training:  71%|███████   | 4915/6926 [00:49<00:20, 98.52it/s, batch_loss=0.6337] 


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6415


Training:  72%|███████▏  | 5018/6926 [00:50<00:19, 99.36it/s, batch_loss=0.6973]


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6502


Training:  74%|███████▍  | 5119/6926 [00:51<00:18, 99.15it/s, batch_loss=0.6683]


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7052


Training:  75%|███████▌  | 5215/6926 [00:52<00:17, 98.45it/s, batch_loss=0.6951] 


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 1.0453
Style Loss: 0.6138


Training:  77%|███████▋  | 5318/6926 [00:53<00:16, 100.01it/s, batch_loss=0.6393]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6695


Training:  78%|███████▊  | 5419/6926 [00:54<00:15, 99.93it/s, batch_loss=0.6402] 


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6691


Training:  80%|███████▉  | 5510/6926 [00:55<00:14, 98.54it/s, batch_loss=0.6958]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6734


Training:  81%|████████  | 5610/6926 [00:56<00:13, 98.01it/s, batch_loss=0.6950]


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7038


Training:  83%|████████▎ | 5715/6926 [00:57<00:12, 99.91it/s, batch_loss=0.6841] 


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6753


Training:  84%|████████▍ | 5819/6926 [00:58<00:11, 99.76it/s, batch_loss=0.6492] 


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6655


Training:  85%|████████▌ | 5913/6926 [00:59<00:10, 99.03it/s, batch_loss=0.6690]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6904


Training:  87%|████████▋ | 6015/6926 [01:00<00:09, 98.26it/s, batch_loss=0.6832]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7105


Training:  88%|████████▊ | 6115/6926 [01:01<00:08, 98.29it/s, batch_loss=0.6867]


Batch 6100/6926
Reconstruction Loss: 0.0002
KL Loss: 0.0001
Style Loss: 0.6727


Training:  90%|████████▉ | 6216/6926 [01:02<00:07, 98.73it/s, batch_loss=0.6474]


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.7100


Training:  91%|█████████ | 6318/6926 [01:03<00:06, 99.19it/s, batch_loss=0.6658]


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6672


Training:  93%|█████████▎| 6410/6926 [01:04<00:05, 98.99it/s, batch_loss=0.6806]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6969


Training:  94%|█████████▍| 6512/6926 [01:05<00:04, 98.61it/s, batch_loss=0.6534]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6611


Training:  96%|█████████▌| 6618/6926 [01:06<00:03, 99.92it/s, batch_loss=0.7003] 


Batch 6600/6926
Reconstruction Loss: 0.0001
KL Loss: 0.0001
Style Loss: 0.6968


Training:  97%|█████████▋| 6715/6926 [01:07<00:02, 100.07it/s, batch_loss=0.7272]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: -0.0000
Style Loss: 0.6302


Training:  98%|█████████▊| 6814/6926 [01:08<00:01, 100.13it/s, batch_loss=0.7067]


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6986


Training: 100%|█████████▉| 6913/6926 [01:09<00:00, 99.87it/s, batch_loss=0.6438] 


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6800


Training: 100%|██████████| 6926/6926 [01:09<00:00, 99.94it/s, batch_loss=0.6758]



Completed epoch. Average loss: 0.6789
Epoch 3/10, Loss: 0.6789

Epoch 4/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:15, 91.76it/s, batch_loss=0.6716]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [14, 6, 4, 7, 11] (showing first 5)


Training:   2%|▏         | 116/6926 [00:01<01:08, 100.14it/s, batch_loss=0.7260]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6690


Training:   3%|▎         | 212/6926 [00:02<01:07, 99.55it/s, batch_loss=0.7067] 


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6637


Training:   5%|▍         | 316/6926 [00:03<01:06, 99.66it/s, batch_loss=0.6659]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6683


Training:   6%|▌         | 417/6926 [00:04<01:05, 99.25it/s, batch_loss=0.6690]


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6725


Training:   7%|▋         | 519/6926 [00:05<01:05, 98.03it/s, batch_loss=0.6599]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6529


Training:   9%|▉         | 610/6926 [00:06<01:03, 99.46it/s, batch_loss=0.6640]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6720


Training:  10%|█         | 711/6926 [00:07<01:03, 98.51it/s, batch_loss=0.6973]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6849


Training:  12%|█▏        | 812/6926 [00:08<01:02, 98.54it/s, batch_loss=0.6631]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6949


Training:  13%|█▎        | 912/6926 [00:09<01:02, 96.98it/s, batch_loss=0.6312]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6718


Training:  15%|█▍        | 1013/6926 [00:10<01:00, 97.77it/s, batch_loss=0.6841]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6843


Training:  16%|█▌        | 1113/6926 [00:11<00:59, 97.81it/s, batch_loss=0.6779]


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6741


Training:  18%|█▊        | 1216/6926 [00:12<00:57, 98.72it/s, batch_loss=0.6730]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6694


Training:  19%|█▉        | 1317/6926 [00:13<00:58, 96.69it/s, batch_loss=0.7199]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6833


Training:  20%|██        | 1410/6926 [00:14<00:55, 99.54it/s, batch_loss=0.7030]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6627


Training:  22%|██▏       | 1517/6926 [00:15<00:55, 97.94it/s, batch_loss=0.6589] 


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6825


Training:  23%|██▎       | 1610/6926 [00:16<00:53, 99.29it/s, batch_loss=0.6455]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6505


Training:  25%|██▍       | 1715/6926 [00:17<00:52, 99.94it/s, batch_loss=0.6608]


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6948


Training:  26%|██▋       | 1819/6926 [00:18<00:51, 99.89it/s, batch_loss=0.6932] 


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6979


Training:  28%|██▊       | 1912/6926 [00:19<00:50, 99.50it/s, batch_loss=0.6518] 


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6543


Training:  29%|██▉       | 2016/6926 [00:20<00:49, 99.64it/s, batch_loss=0.6902]


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7254


Training:  30%|███       | 2111/6926 [00:21<00:48, 98.66it/s, batch_loss=0.6487] 


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6380


Training:  32%|███▏      | 2209/6926 [00:22<00:47, 100.26it/s, batch_loss=0.6491]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7161


Training:  33%|███▎      | 2315/6926 [00:23<00:46, 99.97it/s, batch_loss=0.6884] 


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6518


Training:  35%|███▍      | 2409/6926 [00:24<00:46, 97.20it/s, batch_loss=0.6555] 


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6813


Training:  36%|███▌      | 2510/6926 [00:25<00:44, 99.83it/s, batch_loss=0.6421]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6760


Training:  38%|███▊      | 2611/6926 [00:26<00:43, 98.64it/s, batch_loss=0.6600]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6819


Training:  39%|███▉      | 2715/6926 [00:27<00:42, 99.37it/s, batch_loss=0.6659]


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6206


Training:  41%|████      | 2816/6926 [00:28<00:42, 97.40it/s, batch_loss=0.6479]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6375


Training:  42%|████▏     | 2916/6926 [00:29<00:41, 96.65it/s, batch_loss=0.6559]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6706


Training:  44%|████▎     | 3016/6926 [00:30<00:39, 98.66it/s, batch_loss=0.6849]


Batch 3000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6642


Training:  45%|████▌     | 3117/6926 [00:31<00:38, 98.37it/s, batch_loss=0.6646]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7019


Training:  46%|████▋     | 3217/6926 [00:32<00:37, 98.18it/s, batch_loss=0.6841]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6758


Training:  48%|████▊     | 3317/6926 [00:33<00:38, 94.15it/s, batch_loss=0.6857]


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6768


Training:  49%|████▉     | 3419/6926 [00:34<00:35, 97.93it/s, batch_loss=0.7019]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6673


Training:  51%|█████     | 3513/6926 [00:35<00:34, 99.78it/s, batch_loss=0.6779]


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6565


Training:  52%|█████▏    | 3618/6926 [00:36<00:33, 99.48it/s, batch_loss=0.7033] 


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6993


Training:  54%|█████▎    | 3717/6926 [00:37<00:31, 101.31it/s, batch_loss=0.6379]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6529


Training:  55%|█████▌    | 3816/6926 [00:38<00:30, 101.00it/s, batch_loss=0.7090]


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7090


Training:  57%|█████▋    | 3915/6926 [00:39<00:30, 100.17it/s, batch_loss=0.6365]


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7262


Training:  58%|█████▊    | 4013/6926 [00:40<00:29, 100.03it/s, batch_loss=0.6445]


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6680


Training:  59%|█████▉    | 4111/6926 [00:41<00:28, 100.46it/s, batch_loss=0.6836]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6937


Training:  61%|██████    | 4210/6926 [00:42<00:27, 100.57it/s, batch_loss=0.6768]


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6931


Training:  62%|██████▏   | 4320/6926 [00:43<00:25, 100.75it/s, batch_loss=0.6767]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6545


Training:  64%|██████▍   | 4419/6926 [00:44<00:24, 101.73it/s, batch_loss=0.6984]


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6800


Training:  65%|██████▌   | 4518/6926 [00:45<00:24, 99.10it/s, batch_loss=0.6545] 


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6942


Training:  67%|██████▋   | 4616/6926 [00:46<00:22, 101.79it/s, batch_loss=0.6540]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6406


Training:  68%|██████▊   | 4715/6926 [00:47<00:21, 101.54it/s, batch_loss=0.6685]


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7041


Training:  70%|██████▉   | 4814/6926 [00:48<00:21, 99.85it/s, batch_loss=0.6701] 


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6853


Training:  71%|███████   | 4913/6926 [00:49<00:19, 101.11it/s, batch_loss=0.6638]


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6939


Training:  72%|███████▏  | 5012/6926 [00:50<00:18, 100.93it/s, batch_loss=0.6841]


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6499


Training:  74%|███████▍  | 5111/6926 [00:51<00:18, 100.58it/s, batch_loss=0.6266]


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6719


Training:  75%|███████▌  | 5210/6926 [00:52<00:17, 100.00it/s, batch_loss=0.6297]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6576


Training:  77%|███████▋  | 5319/6926 [00:53<00:15, 100.62it/s, batch_loss=0.6700]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6962


Training:  78%|███████▊  | 5414/6926 [00:54<00:15, 99.70it/s, batch_loss=0.6822] 


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6801


Training:  80%|███████▉  | 5512/6926 [00:55<00:14, 100.63it/s, batch_loss=0.7010]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6846


Training:  81%|████████  | 5611/6926 [00:56<00:13, 99.43it/s, batch_loss=0.6898] 


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6544


Training:  83%|████████▎ | 5714/6926 [00:57<00:12, 100.05it/s, batch_loss=0.6617]


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6986


Training:  84%|████████▍ | 5813/6926 [00:58<00:10, 101.53it/s, batch_loss=0.6401]


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6592


Training:  85%|████████▌ | 5912/6926 [00:59<00:09, 101.71it/s, batch_loss=0.6848]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.7080


Training:  87%|████████▋ | 6011/6926 [01:00<00:09, 100.84it/s, batch_loss=0.6778]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6510


Training:  88%|████████▊ | 6116/6926 [01:01<00:08, 99.81it/s, batch_loss=0.6319] 


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6464


Training:  90%|████████▉ | 6213/6926 [01:02<00:07, 99.88it/s, batch_loss=0.7098] 


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6657


Training:  91%|█████████ | 6318/6926 [01:03<00:06, 100.58it/s, batch_loss=0.6718]


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7004


Training:  93%|█████████▎| 6417/6926 [01:04<00:05, 99.98it/s, batch_loss=0.7024] 


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7144


Training:  94%|█████████▍| 6512/6926 [01:05<00:04, 100.37it/s, batch_loss=0.7316]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.7332


Training:  95%|█████████▌| 6610/6926 [01:06<00:03, 100.74it/s, batch_loss=0.6823]


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6645


Training:  97%|█████████▋| 6716/6926 [01:07<00:02, 99.41it/s, batch_loss=0.6370] 


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6596


Training:  98%|█████████▊| 6814/6926 [01:08<00:01, 100.05it/s, batch_loss=0.7112]


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6733


Training: 100%|█████████▉| 6911/6926 [01:09<00:00, 100.29it/s, batch_loss=0.6858]


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7085


Training: 100%|██████████| 6926/6926 [01:09<00:00, 99.43it/s, batch_loss=0.6708] 



Completed epoch. Average loss: 0.6739
Epoch 4/10, Loss: 0.6739

Epoch 5/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:14, 92.40it/s, batch_loss=0.6606]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [3, 8, 6, 6, 10] (showing first 5)


Training:   2%|▏         | 120/6926 [00:01<01:07, 101.33it/s, batch_loss=0.6431]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6936


Training:   3%|▎         | 219/6926 [00:02<01:06, 100.80it/s, batch_loss=0.6525]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.7024


Training:   5%|▍         | 318/6926 [00:03<01:05, 100.56it/s, batch_loss=0.6519]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6663


Training:   6%|▌         | 417/6926 [00:04<01:04, 100.56it/s, batch_loss=0.6822]


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0000
Style Loss: 0.6622


Training:   7%|▋         | 516/6926 [00:05<01:03, 101.62it/s, batch_loss=0.6779]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6722


Training:   9%|▉         | 615/6926 [00:06<01:02, 101.44it/s, batch_loss=0.6516]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6583


Training:  10%|█         | 714/6926 [00:07<01:00, 102.24it/s, batch_loss=0.7108]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6546


Training:  12%|█▏        | 813/6926 [00:08<01:00, 101.49it/s, batch_loss=0.6790]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6451


Training:  13%|█▎        | 912/6926 [00:09<00:59, 101.09it/s, batch_loss=0.6442]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6615


Training:  15%|█▍        | 1011/6926 [00:10<00:58, 101.11it/s, batch_loss=0.6404]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6870


Training:  16%|█▌        | 1110/6926 [00:11<00:57, 100.32it/s, batch_loss=0.6368]


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6852


Training:  18%|█▊        | 1220/6926 [00:12<00:56, 101.14it/s, batch_loss=0.6667]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6901


Training:  19%|█▉        | 1318/6926 [00:13<00:55, 100.88it/s, batch_loss=0.6737]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6577


Training:  20%|██        | 1417/6926 [00:14<00:54, 101.83it/s, batch_loss=0.7108]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6746


Training:  22%|██▏       | 1516/6926 [00:15<00:52, 102.11it/s, batch_loss=0.6917]


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6864


Training:  23%|██▎       | 1615/6926 [00:16<00:51, 102.46it/s, batch_loss=0.7169]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6799


Training:  25%|██▍       | 1714/6926 [00:16<00:51, 102.10it/s, batch_loss=0.7036]


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6905


Training:  26%|██▌       | 1813/6926 [00:17<00:50, 100.91it/s, batch_loss=0.6567]


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6597


Training:  28%|██▊       | 1912/6926 [00:18<00:49, 101.57it/s, batch_loss=0.6674]


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6631


Training:  29%|██▉       | 2011/6926 [00:19<00:49, 99.86it/s, batch_loss=0.6784] 


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6636


Training:  30%|███       | 2110/6926 [00:20<00:47, 100.91it/s, batch_loss=0.6553]


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6380


Training:  32%|███▏      | 2220/6926 [00:21<00:46, 101.03it/s, batch_loss=0.6421]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6760


Training:  33%|███▎      | 2318/6926 [00:22<00:46, 99.44it/s, batch_loss=0.6977] 


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6437


Training:  35%|███▍      | 2413/6926 [00:23<00:44, 100.79it/s, batch_loss=0.6761]


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6618


Training:  36%|███▋      | 2512/6926 [00:24<00:43, 100.91it/s, batch_loss=0.6680]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6612


Training:  38%|███▊      | 2611/6926 [00:25<00:42, 101.20it/s, batch_loss=0.6706]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7184


Training:  39%|███▉      | 2710/6926 [00:26<00:41, 101.51it/s, batch_loss=0.6285]


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6885


Training:  41%|████      | 2820/6926 [00:27<00:40, 101.14it/s, batch_loss=0.6898]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6755


Training:  42%|████▏     | 2919/6926 [00:28<00:39, 101.77it/s, batch_loss=0.6163]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7098


Training:  44%|████▎     | 3018/6926 [00:29<00:38, 100.36it/s, batch_loss=0.6984]


Batch 3000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6449


Training:  45%|████▌     | 3117/6926 [00:30<00:37, 100.68it/s, batch_loss=0.6358]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6562


Training:  46%|████▋     | 3216/6926 [00:31<00:37, 100.06it/s, batch_loss=0.6819]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6763


Training:  48%|████▊     | 3313/6926 [00:32<00:36, 99.54it/s, batch_loss=0.6591] 


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6615


Training:  49%|████▉     | 3419/6926 [00:33<00:34, 100.63it/s, batch_loss=0.6360]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6544


Training:  51%|█████     | 3518/6926 [00:34<00:33, 101.55it/s, batch_loss=0.6388]


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6804


Training:  52%|█████▏    | 3617/6926 [00:35<00:32, 101.76it/s, batch_loss=0.6112]


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6434


Training:  54%|█████▎    | 3716/6926 [00:36<00:31, 102.39it/s, batch_loss=0.6542]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6673


Training:  55%|█████▌    | 3815/6926 [00:37<00:30, 101.21it/s, batch_loss=0.6605]


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7008


Training:  56%|█████▋    | 3913/6926 [00:38<00:30, 98.31it/s, batch_loss=0.7069] 


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7414


Training:  58%|█████▊    | 4009/6926 [00:39<00:29, 99.61it/s, batch_loss=0.6758] 


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6401


Training:  59%|█████▉    | 4112/6926 [00:40<00:28, 98.48it/s, batch_loss=0.7110]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6500


Training:  61%|██████    | 4215/6926 [00:41<00:27, 99.27it/s, batch_loss=0.7429]


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6647


Training:  62%|██████▏   | 4311/6926 [00:42<00:25, 101.28it/s, batch_loss=0.6896]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6908


Training:  64%|██████▎   | 4410/6926 [00:43<00:24, 101.21it/s, batch_loss=0.6833]


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6422


Training:  65%|██████▌   | 4520/6926 [00:44<00:23, 100.79it/s, batch_loss=0.6962]


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6867


Training:  67%|██████▋   | 4619/6926 [00:45<00:22, 101.41it/s, batch_loss=0.6971]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6529


Training:  68%|██████▊   | 4718/6926 [00:46<00:22, 99.78it/s, batch_loss=0.7258] 


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6716


Training:  70%|██████▉   | 4814/6926 [00:47<00:21, 99.84it/s, batch_loss=0.6414] 


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6724


Training:  71%|███████   | 4911/6926 [00:48<00:20, 99.99it/s, batch_loss=0.6763] 


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6347


Training:  72%|███████▏  | 5015/6926 [00:49<00:19, 99.85it/s, batch_loss=0.7038] 


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6851


Training:  74%|███████▍  | 5117/6926 [00:50<00:18, 99.97it/s, batch_loss=0.6843]


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6851


Training:  75%|███████▌  | 5213/6926 [00:51<00:17, 100.00it/s, batch_loss=0.6578]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7007


Training:  77%|███████▋  | 5310/6926 [00:52<00:16, 100.41it/s, batch_loss=0.7003]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6777


Training:  78%|███████▊  | 5417/6926 [00:53<00:15, 99.94it/s, batch_loss=0.6041] 


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6035


Training:  80%|███████▉  | 5515/6926 [00:54<00:13, 100.99it/s, batch_loss=0.6492]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6501


Training:  81%|████████  | 5614/6926 [00:55<00:12, 101.72it/s, batch_loss=0.6421]


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6728


Training:  82%|████████▏ | 5713/6926 [00:56<00:11, 101.50it/s, batch_loss=0.6768]


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6784


Training:  84%|████████▍ | 5812/6926 [00:57<00:10, 101.81it/s, batch_loss=0.6930]


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6788


Training:  85%|████████▌ | 5911/6926 [00:58<00:09, 101.94it/s, batch_loss=0.6955]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6173


Training:  87%|████████▋ | 6020/6926 [00:59<00:09, 100.47it/s, batch_loss=0.6628]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6630


Training:  88%|████████▊ | 6119/6926 [01:00<00:08, 100.22it/s, batch_loss=0.6580]


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6387


Training:  90%|████████▉ | 6217/6926 [01:01<00:07, 100.52it/s, batch_loss=0.6916]


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7047


Training:  91%|█████████ | 6316/6926 [01:02<00:06, 99.51it/s, batch_loss=0.6955] 


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6554


Training:  93%|█████████▎| 6413/6926 [01:03<00:05, 101.13it/s, batch_loss=0.6566]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6559


Training:  94%|█████████▍| 6512/6926 [01:04<00:04, 100.58it/s, batch_loss=0.6482]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6678


Training:  95%|█████████▌| 6611/6926 [01:05<00:03, 101.01it/s, batch_loss=0.6786]


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6705


Training:  97%|█████████▋| 6710/6926 [01:06<00:02, 100.72it/s, batch_loss=0.6783]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6636


Training:  98%|█████████▊| 6820/6926 [01:07<00:01, 101.54it/s, batch_loss=0.6607]


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6974


Training: 100%|█████████▉| 6917/6926 [01:08<00:00, 100.94it/s, batch_loss=0.7039]


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6614


Training: 100%|██████████| 6926/6926 [01:08<00:00, 100.65it/s, batch_loss=0.7080]



Completed epoch. Average loss: 0.6737
Epoch 5/10, Loss: 0.6737

Epoch 6/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:12, 95.57it/s, batch_loss=0.7163]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [4, 10, 15, 14, 4] (showing first 5)


Training:   2%|▏         | 120/6926 [00:01<01:06, 101.63it/s, batch_loss=0.7107]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6604


Training:   3%|▎         | 219/6926 [00:02<01:06, 101.27it/s, batch_loss=0.6602]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6983


Training:   5%|▍         | 318/6926 [00:03<01:05, 101.02it/s, batch_loss=0.6919]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6470


Training:   6%|▌         | 417/6926 [00:04<01:04, 100.91it/s, batch_loss=0.7051]


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6393


Training:   7%|▋         | 516/6926 [00:05<01:03, 101.20it/s, batch_loss=0.6502]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6659


Training:   9%|▉         | 615/6926 [00:06<01:01, 102.44it/s, batch_loss=0.6440]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7014


Training:  10%|█         | 714/6926 [00:07<01:01, 101.38it/s, batch_loss=0.6638]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6739


Training:  12%|█▏        | 813/6926 [00:08<01:00, 101.63it/s, batch_loss=0.6650]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6077


Training:  13%|█▎        | 911/6926 [00:09<01:00, 100.25it/s, batch_loss=0.6864]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6650


Training:  15%|█▍        | 1018/6926 [00:10<00:58, 100.81it/s, batch_loss=0.6751]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6888


Training:  16%|█▌        | 1117/6926 [00:11<00:59, 98.43it/s, batch_loss=0.6720] 


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6572


Training:  17%|█▋        | 1211/6926 [00:12<00:57, 99.86it/s, batch_loss=0.6599]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6524


Training:  19%|█▉        | 1319/6926 [00:13<00:55, 100.58it/s, batch_loss=0.6437]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6398


Training:  20%|██        | 1418/6926 [00:14<00:55, 99.46it/s, batch_loss=0.6619] 


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6789


Training:  22%|██▏       | 1517/6926 [00:15<00:53, 100.17it/s, batch_loss=0.6988]


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6428


Training:  23%|██▎       | 1616/6926 [00:16<00:52, 100.90it/s, batch_loss=0.6376]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6786


Training:  25%|██▍       | 1715/6926 [00:17<00:51, 100.33it/s, batch_loss=0.6913]


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6823


Training:  26%|██▌       | 1814/6926 [00:18<00:50, 101.00it/s, batch_loss=0.6941]


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6920


Training:  28%|██▊       | 1913/6926 [00:19<00:49, 100.92it/s, batch_loss=0.7147]


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6522


Training:  29%|██▉       | 2012/6926 [00:20<00:48, 101.66it/s, batch_loss=0.6623]


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6581


Training:  30%|███       | 2111/6926 [00:21<00:47, 101.19it/s, batch_loss=0.7208]


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6323


Training:  32%|███▏      | 2210/6926 [00:22<00:46, 101.56it/s, batch_loss=0.6925]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6508


Training:  33%|███▎      | 2309/6926 [00:23<00:46, 100.20it/s, batch_loss=0.6629]


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6964


Training:  35%|███▍      | 2415/6926 [00:24<00:45, 99.07it/s, batch_loss=0.6793] 


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6686


Training:  36%|███▋      | 2519/6926 [00:25<00:44, 99.60it/s, batch_loss=0.6803]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7221


Training:  38%|███▊      | 2611/6926 [00:26<00:43, 98.96it/s, batch_loss=0.6690]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6315


Training:  39%|███▉      | 2717/6926 [00:27<00:42, 99.92it/s, batch_loss=0.6823] 


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6863


Training:  41%|████      | 2816/6926 [00:28<00:40, 101.52it/s, batch_loss=0.6775]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6788


Training:  42%|████▏     | 2915/6926 [00:29<00:39, 101.39it/s, batch_loss=0.6778]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6431


Training:  44%|████▎     | 3014/6926 [00:30<00:38, 102.19it/s, batch_loss=0.6456]


Batch 3000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6646


Training:  45%|████▍     | 3113/6926 [00:30<00:37, 101.58it/s, batch_loss=0.6821]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6896


Training:  46%|████▋     | 3212/6926 [00:31<00:36, 101.64it/s, batch_loss=0.7016]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6490


Training:  48%|████▊     | 3311/6926 [00:32<00:35, 101.73it/s, batch_loss=0.6735]


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6521


Training:  49%|████▉     | 3410/6926 [00:33<00:34, 100.62it/s, batch_loss=0.6888]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6756


Training:  51%|█████     | 3519/6926 [00:34<00:33, 100.37it/s, batch_loss=0.7134]


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6298


Training:  52%|█████▏    | 3618/6926 [00:35<00:33, 99.93it/s, batch_loss=0.6349] 


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6339


Training:  54%|█████▎    | 3717/6926 [00:36<00:31, 100.73it/s, batch_loss=0.6844]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6875


Training:  55%|█████▌    | 3815/6926 [00:37<00:31, 100.35it/s, batch_loss=0.6338]


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6617


Training:  57%|█████▋    | 3914/6926 [00:38<00:30, 100.40it/s, batch_loss=0.6304]


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6532


Training:  58%|█████▊    | 4013/6926 [00:39<00:29, 100.27it/s, batch_loss=0.6831]


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6502


Training:  59%|█████▉    | 4111/6926 [00:40<00:27, 100.99it/s, batch_loss=0.7025]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6635


Training:  61%|██████    | 4210/6926 [00:41<00:27, 99.04it/s, batch_loss=0.6902] 


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6329


Training:  62%|██████▏   | 4320/6926 [00:42<00:25, 101.54it/s, batch_loss=0.6603]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7049


Training:  64%|██████▍   | 4418/6926 [00:43<00:25, 98.48it/s, batch_loss=0.6399] 


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6559


Training:  65%|██████▌   | 4515/6926 [00:44<00:23, 100.47it/s, batch_loss=0.6755]


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6496


Training:  67%|██████▋   | 4613/6926 [00:45<00:23, 100.07it/s, batch_loss=0.6817]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6612


Training:  68%|██████▊   | 4712/6926 [00:46<00:21, 101.30it/s, batch_loss=0.6699]


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6774


Training:  69%|██████▉   | 4811/6926 [00:47<00:20, 102.12it/s, batch_loss=0.6463]


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7139


Training:  71%|███████   | 4910/6926 [00:48<00:19, 101.38it/s, batch_loss=0.6177]


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6944


Training:  72%|███████▏  | 5019/6926 [00:49<00:18, 101.52it/s, batch_loss=0.6694]


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6547


Training:  74%|███████▍  | 5118/6926 [00:50<00:17, 102.10it/s, batch_loss=0.6480]


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6695


Training:  75%|███████▌  | 5217/6926 [00:51<00:16, 102.39it/s, batch_loss=0.6619]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6950


Training:  77%|███████▋  | 5316/6926 [00:52<00:15, 101.70it/s, batch_loss=0.6731]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6767


Training:  78%|███████▊  | 5415/6926 [00:53<00:15, 100.71it/s, batch_loss=0.6453]


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6638


Training:  80%|███████▉  | 5514/6926 [00:54<00:14, 100.01it/s, batch_loss=0.6443]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6973


Training:  81%|████████  | 5610/6926 [00:55<00:13, 99.78it/s, batch_loss=0.6934] 


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6847


Training:  83%|████████▎ | 5715/6926 [00:56<00:12, 100.70it/s, batch_loss=0.6765]


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6922


Training:  84%|████████▍ | 5814/6926 [00:57<00:11, 100.72it/s, batch_loss=0.6620]


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6619


Training:  85%|████████▌ | 5913/6926 [00:58<00:10, 101.06it/s, batch_loss=0.7006]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6861


Training:  87%|████████▋ | 6012/6926 [00:59<00:09, 100.88it/s, batch_loss=0.6534]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6844


Training:  88%|████████▊ | 6111/6926 [01:00<00:08, 101.25it/s, batch_loss=0.6807]


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6984


Training:  90%|████████▉ | 6210/6926 [01:01<00:07, 101.62it/s, batch_loss=0.6677]


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7116


Training:  91%|█████████▏| 6320/6926 [01:02<00:05, 101.10it/s, batch_loss=0.6594]


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6725


Training:  93%|█████████▎| 6419/6926 [01:03<00:05, 100.73it/s, batch_loss=0.6875]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6471


Training:  94%|█████████▍| 6518/6926 [01:04<00:04, 100.86it/s, batch_loss=0.7087]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6749


Training:  96%|█████████▌| 6617/6926 [01:05<00:03, 100.72it/s, batch_loss=0.6634]


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6937


Training:  97%|█████████▋| 6715/6926 [01:06<00:02, 100.75it/s, batch_loss=0.6966]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6752


Training:  98%|█████████▊| 6812/6926 [01:07<00:01, 99.75it/s, batch_loss=0.6974] 


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6449


Training: 100%|█████████▉| 6918/6926 [01:08<00:00, 101.45it/s, batch_loss=0.6895]


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6561


Training: 100%|██████████| 6926/6926 [01:08<00:00, 100.67it/s, batch_loss=0.6931]



Completed epoch. Average loss: 0.6735
Epoch 6/10, Loss: 0.6735

Epoch 7/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:15, 91.63it/s, batch_loss=0.6762]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [14, 3, 13, 15, 6] (showing first 5)


Training:   2%|▏         | 119/6926 [00:01<01:07, 100.69it/s, batch_loss=0.6572]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7037


Training:   3%|▎         | 218/6926 [00:02<01:05, 101.82it/s, batch_loss=0.6394]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6983


Training:   5%|▍         | 317/6926 [00:03<01:05, 101.28it/s, batch_loss=0.7132]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6543


Training:   6%|▌         | 416/6926 [00:04<01:05, 98.99it/s, batch_loss=0.6155] 


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6981


Training:   7%|▋         | 514/6926 [00:05<01:03, 101.46it/s, batch_loss=0.7179]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6661


Training:   9%|▉         | 613/6926 [00:06<01:02, 101.63it/s, batch_loss=0.6469]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6643


Training:  10%|█         | 711/6926 [00:07<01:03, 98.63it/s, batch_loss=0.6916] 


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6105


Training:  12%|█▏        | 819/6926 [00:08<01:00, 100.59it/s, batch_loss=0.6417]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6263


Training:  13%|█▎        | 918/6926 [00:09<00:59, 101.25it/s, batch_loss=0.6195]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6880


Training:  15%|█▍        | 1017/6926 [00:10<00:58, 100.95it/s, batch_loss=0.6906]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6683


Training:  16%|█▌        | 1116/6926 [00:11<00:57, 100.87it/s, batch_loss=0.7013]


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6555


Training:  18%|█▊        | 1215/6926 [00:12<00:56, 100.83it/s, batch_loss=0.6570]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6732


Training:  19%|█▉        | 1312/6926 [00:13<00:55, 100.34it/s, batch_loss=0.6621]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6667


Training:  20%|██        | 1411/6926 [00:14<00:54, 101.45it/s, batch_loss=0.6456]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6386


Training:  22%|██▏       | 1510/6926 [00:15<00:53, 101.70it/s, batch_loss=0.6671]


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6498


Training:  23%|██▎       | 1620/6926 [00:16<00:52, 100.30it/s, batch_loss=0.6445]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6396


Training:  25%|██▍       | 1719/6926 [00:17<00:51, 100.74it/s, batch_loss=0.7108]


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6456


Training:  26%|██▌       | 1818/6926 [00:18<00:50, 101.46it/s, batch_loss=0.6511]


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6751


Training:  28%|██▊       | 1917/6926 [00:19<00:49, 101.62it/s, batch_loss=0.7136]


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6554


Training:  29%|██▉       | 2016/6926 [00:20<00:48, 101.00it/s, batch_loss=0.7096]


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6782


Training:  31%|███       | 2115/6926 [00:21<00:47, 100.70it/s, batch_loss=0.7075]


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7118


Training:  32%|███▏      | 2214/6926 [00:22<00:46, 101.70it/s, batch_loss=0.6687]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6874


Training:  33%|███▎      | 2313/6926 [00:23<00:45, 102.30it/s, batch_loss=0.6944]


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6689


Training:  35%|███▍      | 2412/6926 [00:23<00:44, 102.58it/s, batch_loss=0.6986]


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6356


Training:  36%|███▋      | 2511/6926 [00:24<00:43, 102.01it/s, batch_loss=0.6976]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6577


Training:  38%|███▊      | 2610/6926 [00:25<00:42, 101.47it/s, batch_loss=0.6706]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6946


Training:  39%|███▉      | 2720/6926 [00:26<00:41, 101.55it/s, batch_loss=0.7075]


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6779


Training:  41%|████      | 2819/6926 [00:27<00:40, 101.24it/s, batch_loss=0.7048]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6881


Training:  42%|████▏     | 2918/6926 [00:28<00:39, 100.87it/s, batch_loss=0.6526]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6796


Training:  44%|████▎     | 3017/6926 [00:29<00:38, 100.82it/s, batch_loss=0.6733]


Batch 3000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6475


Training:  45%|████▍     | 3116/6926 [00:30<00:37, 101.36it/s, batch_loss=0.6540]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6675


Training:  46%|████▋     | 3215/6926 [00:31<00:36, 100.91it/s, batch_loss=0.6673]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6671


Training:  48%|████▊     | 3314/6926 [00:32<00:35, 100.97it/s, batch_loss=0.6728]


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6724


Training:  49%|████▉     | 3412/6926 [00:33<00:34, 101.28it/s, batch_loss=0.6639]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6854


Training:  51%|█████     | 3511/6926 [00:34<00:33, 101.26it/s, batch_loss=0.6554]


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6422


Training:  52%|█████▏    | 3610/6926 [00:35<00:32, 102.10it/s, batch_loss=0.6947]


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7123


Training:  54%|█████▎    | 3720/6926 [00:36<00:31, 101.90it/s, batch_loss=0.6431]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6662


Training:  55%|█████▌    | 3819/6926 [00:37<00:30, 101.98it/s, batch_loss=0.6856]


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6654


Training:  57%|█████▋    | 3918/6926 [00:38<00:29, 101.17it/s, batch_loss=0.6738]


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6835


Training:  58%|█████▊    | 4016/6926 [00:39<00:29, 100.33it/s, batch_loss=0.6386]


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6540


Training:  59%|█████▉    | 4115/6926 [00:40<00:27, 101.16it/s, batch_loss=0.6454]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6763


Training:  61%|██████    | 4214/6926 [00:41<00:26, 101.50it/s, batch_loss=0.7295]


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6962


Training:  62%|██████▏   | 4312/6926 [00:42<00:25, 101.12it/s, batch_loss=0.6767]


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7038


Training:  64%|██████▎   | 4411/6926 [00:43<00:24, 102.04it/s, batch_loss=0.7085]


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6438


Training:  65%|██████▌   | 4510/6926 [00:44<00:23, 102.35it/s, batch_loss=0.7032]


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6673


Training:  67%|██████▋   | 4620/6926 [00:45<00:22, 101.69it/s, batch_loss=0.7310]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6514


Training:  68%|██████▊   | 4719/6926 [00:46<00:21, 101.48it/s, batch_loss=0.6924]


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6399


Training:  70%|██████▉   | 4817/6926 [00:47<00:21, 100.22it/s, batch_loss=0.6868]


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6721


Training:  71%|███████   | 4915/6926 [00:48<00:20, 99.49it/s, batch_loss=0.6535] 


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6649


Training:  72%|███████▏  | 5018/6926 [00:49<00:19, 99.31it/s, batch_loss=0.6703]


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6818


Training:  74%|███████▍  | 5115/6926 [00:50<00:18, 97.14it/s, batch_loss=0.7077] 


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6943


Training:  75%|███████▌  | 5220/6926 [00:51<00:17, 98.85it/s, batch_loss=0.6505]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6740


Training:  77%|███████▋  | 5313/6926 [00:52<00:16, 99.38it/s, batch_loss=0.6667]


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6841


Training:  78%|███████▊  | 5415/6926 [00:53<00:15, 96.76it/s, batch_loss=0.6718]


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7145


Training:  80%|███████▉  | 5510/6926 [00:54<00:14, 99.86it/s, batch_loss=0.6637]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6399


Training:  81%|████████  | 5620/6926 [00:55<00:12, 100.87it/s, batch_loss=0.6614]


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6869


Training:  83%|████████▎ | 5717/6926 [00:56<00:12, 100.31it/s, batch_loss=0.6880]


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6484


Training:  84%|████████▍ | 5816/6926 [00:57<00:11, 100.33it/s, batch_loss=0.6663]


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6906


Training:  85%|████████▌ | 5915/6926 [00:58<00:10, 100.12it/s, batch_loss=0.6870]


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6662


Training:  87%|████████▋ | 6011/6926 [00:59<00:09, 100.05it/s, batch_loss=0.6474]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6764


Training:  88%|████████▊ | 6115/6926 [01:00<00:08, 99.93it/s, batch_loss=0.6478] 


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6962


Training:  90%|████████▉ | 6210/6926 [01:01<00:07, 99.62it/s, batch_loss=0.6652] 


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6891


Training:  91%|█████████ | 6317/6926 [01:02<00:06, 101.04it/s, batch_loss=0.6892]


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6501


Training:  93%|█████████▎| 6415/6926 [01:03<00:05, 101.37it/s, batch_loss=0.6824]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6727


Training:  94%|█████████▍| 6514/6926 [01:04<00:04, 100.64it/s, batch_loss=0.7163]


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6513


Training:  95%|█████████▌| 6613/6926 [01:05<00:03, 101.83it/s, batch_loss=0.6412]


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6578


Training:  97%|█████████▋| 6712/6926 [01:06<00:02, 102.05it/s, batch_loss=0.6618]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6709


Training:  98%|█████████▊| 6811/6926 [01:07<00:01, 100.46it/s, batch_loss=0.6259]


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7031


Training: 100%|█████████▉| 6910/6926 [01:08<00:00, 99.17it/s, batch_loss=0.6592] 


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6641


Training: 100%|██████████| 6926/6926 [01:08<00:00, 100.63it/s, batch_loss=0.6977]



Completed epoch. Average loss: 0.6735
Epoch 7/10, Loss: 0.6735

Epoch 8/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 10/6926 [00:00<01:14, 92.51it/s, batch_loss=0.7140]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [7, 2, 12, 9, 6] (showing first 5)


Training:   2%|▏         | 116/6926 [00:01<01:07, 100.62it/s, batch_loss=0.6767]


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6698


Training:   3%|▎         | 215/6926 [00:02<01:06, 100.47it/s, batch_loss=0.6559]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6667


Training:   5%|▍         | 314/6926 [00:03<01:05, 101.23it/s, batch_loss=0.6722]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6358


Training:   6%|▌         | 413/6926 [00:04<01:04, 100.36it/s, batch_loss=0.6703]


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6569


Training:   7%|▋         | 512/6926 [00:05<01:03, 101.24it/s, batch_loss=0.6973]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6651


Training:   9%|▉         | 611/6926 [00:06<01:03, 99.93it/s, batch_loss=0.6424] 


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7122


Training:  10%|█         | 713/6926 [00:07<01:04, 96.92it/s, batch_loss=0.6665]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6443


Training:  12%|█▏        | 812/6926 [00:08<01:00, 100.34it/s, batch_loss=0.7012]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6456


Training:  13%|█▎        | 910/6926 [00:09<01:02, 96.12it/s, batch_loss=0.6851] 


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6264


Training:  15%|█▍        | 1014/6926 [00:10<01:00, 98.37it/s, batch_loss=0.6003]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6436


Training:  16%|█▌        | 1113/6926 [00:11<00:58, 99.67it/s, batch_loss=0.6603] 


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6603


Training:  17%|█▋        | 1210/6926 [00:12<00:57, 98.92it/s, batch_loss=0.6432] 


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6628


Training:  19%|█▉        | 1312/6926 [00:13<00:57, 97.61it/s, batch_loss=0.7086]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6442


Training:  20%|██        | 1411/6926 [00:14<00:54, 100.46it/s, batch_loss=0.6996]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6727


Training:  22%|██▏       | 1509/6926 [00:15<00:53, 100.40it/s, batch_loss=0.6864]


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6874


Training:  23%|██▎       | 1614/6926 [00:16<00:53, 98.80it/s, batch_loss=0.6733] 


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7016


Training:  25%|██▍       | 1719/6926 [00:17<00:51, 101.01it/s, batch_loss=0.6847]


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.7065


Training:  26%|██▌       | 1818/6926 [00:18<00:50, 100.48it/s, batch_loss=0.6633]


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7159


Training:  28%|██▊       | 1914/6926 [00:19<00:50, 99.32it/s, batch_loss=0.6230] 


Batch 1900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6557


Training:  29%|██▉       | 2010/6926 [00:20<00:49, 100.13it/s, batch_loss=0.6591]


Batch 2000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6592


Training:  31%|███       | 2115/6926 [00:21<00:48, 98.43it/s, batch_loss=0.6544] 


Batch 2100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6727


Training:  32%|███▏      | 2217/6926 [00:22<00:47, 99.47it/s, batch_loss=0.6806]


Batch 2200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6852


Training:  33%|███▎      | 2320/6926 [00:23<00:46, 99.40it/s, batch_loss=0.6569]


Batch 2300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6909


Training:  35%|███▍      | 2411/6926 [00:24<00:45, 98.95it/s, batch_loss=0.6631]


Batch 2400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6991


Training:  36%|███▋      | 2512/6926 [00:25<00:45, 96.24it/s, batch_loss=0.6619]


Batch 2500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6897


Training:  38%|███▊      | 2614/6926 [00:26<00:43, 99.11it/s, batch_loss=0.7017]


Batch 2600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6395


Training:  39%|███▉      | 2718/6926 [00:27<00:42, 99.77it/s, batch_loss=0.6686]


Batch 2700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6989


Training:  41%|████      | 2810/6926 [00:28<00:41, 98.53it/s, batch_loss=0.6680]


Batch 2800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6809


Training:  42%|████▏     | 2911/6926 [00:29<00:40, 99.19it/s, batch_loss=0.6556]


Batch 2900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7149


Training:  44%|████▎     | 3014/6926 [00:30<00:39, 99.86it/s, batch_loss=0.6846]


Batch 3000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0002
Style Loss: 0.6803


Training:  45%|████▌     | 3119/6926 [00:31<00:37, 100.38it/s, batch_loss=0.6867]


Batch 3100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7268


Training:  46%|████▋     | 3218/6926 [00:32<00:36, 100.50it/s, batch_loss=0.7093]


Batch 3200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6524


Training:  48%|████▊     | 3317/6926 [00:33<00:36, 98.89it/s, batch_loss=0.6291] 


Batch 3300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6436


Training:  49%|████▉     | 3415/6926 [00:34<00:35, 100.26it/s, batch_loss=0.6754]


Batch 3400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6632


Training:  51%|█████     | 3513/6926 [00:35<00:33, 101.44it/s, batch_loss=0.7068]


Batch 3500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7002


Training:  52%|█████▏    | 3612/6926 [00:36<00:33, 100.25it/s, batch_loss=0.7079]


Batch 3600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6637


Training:  54%|█████▎    | 3710/6926 [00:37<00:31, 100.89it/s, batch_loss=0.6913]


Batch 3700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6476


Training:  55%|█████▌    | 3817/6926 [00:38<00:31, 99.86it/s, batch_loss=0.6889] 


Batch 3800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6771


Training:  57%|█████▋    | 3920/6926 [00:39<00:30, 99.61it/s, batch_loss=0.6232]


Batch 3900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6929


Training:  58%|█████▊    | 4018/6926 [00:40<00:29, 98.87it/s, batch_loss=0.6923] 


Batch 4000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6580


Training:  59%|█████▉    | 4114/6926 [00:41<00:27, 101.17it/s, batch_loss=0.6089]


Batch 4100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7088


Training:  61%|██████    | 4213/6926 [00:42<00:26, 101.76it/s, batch_loss=0.7027]


Batch 4200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6733


Training:  62%|██████▏   | 4311/6926 [00:43<00:26, 99.60it/s, batch_loss=0.6182] 


Batch 4300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6284


Training:  64%|██████▍   | 4419/6926 [00:44<00:25, 99.51it/s, batch_loss=0.6771] 


Batch 4400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6863


Training:  65%|██████▌   | 4509/6926 [00:45<00:24, 98.51it/s, batch_loss=0.6944]


Batch 4500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6868


Training:  67%|██████▋   | 4610/6926 [00:46<00:23, 96.54it/s, batch_loss=0.6595]


Batch 4600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7253


Training:  68%|██████▊   | 4714/6926 [00:47<00:22, 99.93it/s, batch_loss=0.6494]


Batch 4700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6493


Training:  70%|██████▉   | 4818/6926 [00:48<00:21, 98.14it/s, batch_loss=0.6744] 


Batch 4800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7134


Training:  71%|███████   | 4917/6926 [00:49<00:19, 101.85it/s, batch_loss=0.6922]


Batch 4900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6800


Training:  72%|███████▏  | 5016/6926 [00:50<00:18, 100.66it/s, batch_loss=0.6528]


Batch 5000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6803


Training:  74%|███████▍  | 5113/6926 [00:51<00:18, 98.22it/s, batch_loss=0.6545] 


Batch 5100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6202


Training:  75%|███████▌  | 5219/6926 [00:52<00:17, 100.38it/s, batch_loss=0.6505]


Batch 5200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6426


Training:  77%|███████▋  | 5317/6926 [00:53<00:16, 99.78it/s, batch_loss=0.6902] 


Batch 5300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6678


Training:  78%|███████▊  | 5414/6926 [00:54<00:15, 100.72it/s, batch_loss=0.6757]


Batch 5400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7318


Training:  80%|███████▉  | 5513/6926 [00:55<00:14, 100.69it/s, batch_loss=0.6877]


Batch 5500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6713


Training:  81%|████████  | 5612/6926 [00:56<00:13, 100.71it/s, batch_loss=0.6957]


Batch 5600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6876


Training:  83%|████████▎ | 5719/6926 [00:57<00:12, 99.95it/s, batch_loss=0.6322] 


Batch 5700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6780


Training:  84%|████████▍ | 5816/6926 [00:58<00:11, 100.62it/s, batch_loss=0.6648]


Batch 5800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7091


Training:  85%|████████▌ | 5913/6926 [00:59<00:10, 99.10it/s, batch_loss=0.6264] 


Batch 5900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6864


Training:  87%|████████▋ | 6019/6926 [01:00<00:08, 101.30it/s, batch_loss=0.6380]


Batch 6000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6802


Training:  88%|████████▊ | 6118/6926 [01:01<00:08, 100.91it/s, batch_loss=0.6615]


Batch 6100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6900


Training:  90%|████████▉ | 6214/6926 [01:02<00:07, 100.14it/s, batch_loss=0.7136]


Batch 6200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6646


Training:  91%|█████████ | 6311/6926 [01:03<00:06, 99.71it/s, batch_loss=0.6287] 


Batch 6300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6612


Training:  93%|█████████▎| 6409/6926 [01:04<00:05, 100.48it/s, batch_loss=0.6233]


Batch 6400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7050


Training:  94%|█████████▍| 6512/6926 [01:05<00:04, 97.55it/s, batch_loss=0.6950] 


Batch 6500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6732


Training:  96%|█████████▌| 6618/6926 [01:06<00:03, 99.42it/s, batch_loss=0.6934] 


Batch 6600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6750


Training:  97%|█████████▋| 6712/6926 [01:07<00:02, 99.63it/s, batch_loss=0.6766]


Batch 6700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6377


Training:  98%|█████████▊| 6818/6926 [01:08<00:01, 99.57it/s, batch_loss=0.6873] 


Batch 6800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6372


Training: 100%|█████████▉| 6913/6926 [01:09<00:00, 99.91it/s, batch_loss=0.6878] 


Batch 6900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6743


Training: 100%|██████████| 6926/6926 [01:09<00:00, 99.57it/s, batch_loss=0.6506]



Completed epoch. Average loss: 0.6734
Epoch 8/10, Loss: 0.6734

Epoch 9/10
--------------------------------------------------

Training on 443259 examples with 6926 batches


Training:   0%|          | 9/6926 [00:00<01:16, 89.95it/s, batch_loss=0.6774]


First batch shapes:
Input tokens: torch.Size([64, 15])
Style labels: torch.Size([64])
Sequence lengths: [9, 10, 14, 7, 10] (showing first 5)


Training:   2%|▏         | 114/6926 [00:01<01:08, 99.92it/s, batch_loss=0.6625] 


Batch 100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6449


Training:   3%|▎         | 221/6926 [00:02<01:04, 104.22it/s, batch_loss=0.6568]


Batch 200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.7035


Training:   5%|▍         | 320/6926 [00:03<01:05, 101.02it/s, batch_loss=0.7227]


Batch 300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6578


Training:   6%|▌         | 418/6926 [00:04<01:05, 99.89it/s, batch_loss=0.6577] 


Batch 400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6598


Training:   7%|▋         | 510/6926 [00:05<01:04, 99.90it/s, batch_loss=0.6403]


Batch 500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6689


Training:   9%|▉         | 617/6926 [00:06<01:03, 100.07it/s, batch_loss=0.6856]


Batch 600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6740


Training:  10%|█         | 713/6926 [00:07<01:01, 100.35it/s, batch_loss=0.6720]


Batch 700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6832


Training:  12%|█▏        | 812/6926 [00:08<01:00, 101.31it/s, batch_loss=0.6767]


Batch 800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6528


Training:  13%|█▎        | 911/6926 [00:09<00:58, 102.19it/s, batch_loss=0.6897]


Batch 900/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6495


Training:  15%|█▍        | 1010/6926 [00:10<00:56, 104.41it/s, batch_loss=0.6957]


Batch 1000/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6550


Training:  16%|█▌        | 1120/6926 [00:11<00:55, 104.97it/s, batch_loss=0.7050]


Batch 1100/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6455


Training:  18%|█▊        | 1219/6926 [00:12<00:54, 103.97it/s, batch_loss=0.6588]


Batch 1200/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6569


Training:  19%|█▉        | 1318/6926 [00:13<00:55, 100.37it/s, batch_loss=0.6866]


Batch 1300/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6539


Training:  20%|██        | 1417/6926 [00:14<00:55, 100.13it/s, batch_loss=0.6830]


Batch 1400/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6637


Training:  22%|██▏       | 1516/6926 [00:14<00:53, 101.22it/s, batch_loss=0.7111]


Batch 1500/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6642


Training:  23%|██▎       | 1615/6926 [00:15<00:53, 100.14it/s, batch_loss=0.6662]


Batch 1600/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6655


Training:  25%|██▍       | 1710/6926 [00:16<00:52, 99.43it/s, batch_loss=0.6228] 


Batch 1700/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6915


Training:  26%|██▌       | 1812/6926 [00:18<00:51, 99.81it/s, batch_loss=0.6641]


Batch 1800/6926
Reconstruction Loss: 0.0000
KL Loss: 0.0001
Style Loss: 0.6789


Training:  27%|██▋       | 1878/6926 [00:18<00:49, 100.98it/s, batch_loss=0.6355]


KeyboardInterrupt: 

In [27]:
class StyleClassifier(nn.Module):
    def _init_(self, vocab_size, embedding_dim, hidden_dim):
        super(StyleClassifier, self)._init_()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.encoder = nn.GRU(embedding_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.embedding(x)
        _, hidden = self.encoder(x)
        # Combine bidirectional states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
        return self.fc(hidden)

#### UPDATED PART ABOVE, NOT UPDATED PART BELOW

In [28]:
torch.save(vae, 'model_complete.pth')

In [29]:
model_1 = torch.load('model_complete.pth') 

  model_1 = torch.load('model_complete.pth')


In [34]:


# Inspect some sentences from the data loader 
model_1.eval()
with torch.no_grad():
    for input_tokens, _, lengths in data_loader_test:
        input_tokens = input_tokens.to(device)
        x_reconstructed, _, _, _, _, _, _  = model_1(input_tokens)
        x_reconstructed = x_reconstructed.argmax(dim=-1)  # Get the predicted token IDs

        # Print a few input and output sentences
        for i in range(5):  # Print 5 examples
            original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
            reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)

            print("Original Sentence: \t\t\t", " ".join(original_sentence))
            print("Reconstructed Sentence: \t\t", " ".join(reconstructed_sentence))
            print()

        break  # Only inspect the first batch

Original Sentence: 			 ever since joes has changed hands it 's just gotten worse and worse .
Reconstructed Sentence: 		 ever since joes has changed hands it 's just gotten worse and worse . tasting

Original Sentence: 			 there is definitely not enough room in that part of the venue .
Reconstructed Sentence: 		 there is definitely not enough room in that part of the venue . tasting tasting

Original Sentence: 			 so basically tasted watered down .
Reconstructed Sentence: 		 so basically tasted watered down . tasting tasting tasting tasting tasting tasting tasting tasting tasting

Original Sentence: 			 she said she 'd be back and disappeared for a few minutes .
Reconstructed Sentence: 		 she said she 'd be back and disappeared for a few minutes . tasting joining

Original Sentence: 			 i ca n't believe how inconsiderate this pharmacy is .
Reconstructed Sentence: 		 i ca n't believe how inconsiderate this pharmacy is . tasting joining joining joining joining



In [33]:
def train_style_classifier(data_loader, vocab_size, device):
    classifier = StyleClassifier(vocab_size, 300, 128).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    classifier.train()
    for epoch in range(20):  # Train for a few epochs
        total_loss = 0
        for input_tokens, labels, _ in data_loader:  # Adjusted to unpack three values
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predictions = classifier(input_tokens)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/5, Loss: {total_loss / len(data_loader)}")
    
    return classifier

def evaluate_style_transfer(data_loader, model, classifier, device):
    model.eval()
    classifier.eval()
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            # Get the reconstructed sentences
            x_reconstructed, _, _, _, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            # Predict the style of the reconstructed sentences
            style_predictions = classifier(x_reconstructed)
            style_labels = (style_predictions > 0.5).float()
            
            correct_predictions += (style_labels == labels).sum().item()
            total_predictions += labels.size(0)
    
    accuracy = correct_predictions / total_predictions
    print(f"Style Transfer Accuracy: {accuracy:.4f}")

In [17]:
classifier = train_style_classifier(data_loader, len(vocab), device)
evaluate_style_transfer(data_loader_test, model_1, classifier, device)

KeyboardInterrupt: 

In [18]:
#### updated code 
import torch
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import nltk
nltk.download('punkt')

def tokens_to_words(token_ids, vocab):
    inv_vocab = {v: k for k, v in vocab.items()}
    return [inv_vocab.get(token_id, '<UNK>') for token_id in token_ids if token_id != 0]

def calculate_bleu_score(data_loader, model, vocab, device):
    model.eval()
    total_bleu_score = 0
    num_sentences = 0
    smoothing_fn = SmoothingFunction().method1

    print("\nBLEU-S: Evaluating content preservation...\n")
    with torch.no_grad():
        for input_tokens, _, lengths in data_loader:
            input_tokens = input_tokens.to(device)
            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            for i in range(min(5, len(input_tokens))):  
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)
                print(f"Original: {' '.join(original_sentence)}")
                print(f"Reconstructed: {' '.join(reconstructed_sentence)}\n")

                bleu_score = sentence_bleu([original_sentence], reconstructed_sentence, smoothing_function=smoothing_fn)
                total_bleu_score += bleu_score
                num_sentences += 1

            break  # Evaluate only on the first batch for now

    avg_bleu_score = total_bleu_score / num_sentences if num_sentences > 0 else 0
    print(f"Average BLEU-S Score: {avg_bleu_score:.4f}")
    return avg_bleu_score

def train_style_classifier(data_loader, vocab_size, device):
    classifier = StyleClassifier(vocab_size, 300, 128).to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

    classifier.train()
    print("\nTraining Style Classifier...\n")
    for epoch in range(20):
        total_loss = 0
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            predictions = classifier(input_tokens)
            loss = criterion(predictions, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        print(f"Epoch {epoch + 1}/5, Loss: {total_loss / len(data_loader):.4f}")
    
    return classifier

def evaluate_style_transfer(data_loader, model, classifier, vocab, device):
    model.eval()
    classifier.eval()
    correct_predictions = 0
    total_predictions = 0

    print("\nEvaluating Style Transfer Accuracy...\n")
    with torch.no_grad():
        for input_tokens, labels, _ in data_loader:
            input_tokens = input_tokens.to(device)
            labels = labels.to(device)

            x_reconstructed, _, _, _ = model(input_tokens)
            x_reconstructed = x_reconstructed.argmax(dim=-1)

            style_predictions = classifier(x_reconstructed)
            style_labels = (style_predictions > 0.5).float()
            correct_predictions += (style_labels == labels).sum().item()
            total_predictions += labels.size(0)

            for i in range(min(5, len(input_tokens))):
                original_sentence = tokens_to_words(input_tokens[i].tolist(), vocab)
                reconstructed_sentence = tokens_to_words(x_reconstructed[i].tolist(), vocab)
                print(f"Original: {' '.join(original_sentence)}")
                print(f"Reconstructed: {' '.join(reconstructed_sentence)}")
                print(f"Style Prediction: {style_labels[i].item()}, True Style: {labels[i].item()}\n")

            break  # Evaluate only on the first batch for now

    accuracy = correct_predictions / total_predictions if total_predictions > 0 else 0
    print(f"Style Transfer Accuracy: {accuracy:.4f}")
    return accuracy

def run_evaluation(data_loader_train, data_loader_test, model, vocab, vocab_size, device):
    # Train Style Classifier
    classifier = train_style_classifier(data_loader_train, vocab_size, device)

    print("\n--- BLEU-S Score (Content Preservation) ---")
    bleu_score = calculate_bleu_score(data_loader_test, model, vocab, device)

    print("\n--- Style Transfer Accuracy ---")
    style_transfer_accuracy = evaluate_style_transfer(data_loader_test, model, classifier, vocab, device)

    print("\n--- Final Results ---")
    print(f"BLEU-S Score: {bleu_score:.4f}")
    print(f"Style Transfer Accuracy: {style_transfer_accuracy:.4f}")

    return bleu_score, style_transfer_accuracy

run_evaluation(data_loader, data_loader_test, model, vocab, len(vocab), device)

[nltk_data] Downloading package punkt to /home/qik/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



Training Style Classifier...

Epoch 1/5, Loss: 0.0954
Epoch 2/5, Loss: 0.0522
Epoch 3/5, Loss: 0.0387
Epoch 4/5, Loss: 0.0300
Epoch 5/5, Loss: 0.0241
Epoch 6/5, Loss: 0.0205
Epoch 7/5, Loss: 0.0178
Epoch 8/5, Loss: 0.0158
Epoch 9/5, Loss: 0.0145
Epoch 10/5, Loss: 0.0139
Epoch 11/5, Loss: 0.0132
Epoch 12/5, Loss: 0.0124
Epoch 13/5, Loss: 0.0116
Epoch 14/5, Loss: 0.0117
Epoch 15/5, Loss: 0.0116
Epoch 16/5, Loss: 0.0113
Epoch 17/5, Loss: 0.0113
Epoch 18/5, Loss: 0.0108
Epoch 19/5, Loss: 0.0105
Epoch 20/5, Loss: 0.0104

--- BLEU-S Score (Content Preservation) ---

BLEU-S: Evaluating content preservation...

Original: at this location the service was terrible .
Reconstructed: so brand scale crumbs crappy steady shogun shogun adds adds adds adds adds adds adds

Original: i ordered garlic bread and fettuccine alfredo pasta with vegetables .
Reconstructed: anyways a+ cancer crusted cancer middle towing crusted middle watered belgian camarones inspector inspector inspector

Original: i did n't

(0.0, 0.546875)

In [20]:
for a, b, c in data_loader:
    for i in a:
        print(i.shape)

torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15])
torch.Size([15