In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import numpy as np
import librosa
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
import math
import string
import collections
import os
import threading
import queue
import time
from typing import Optional, Tuple, List
import warnings
# NOTE(review): this silences every warning process-wide, including ones that
# point at real problems (e.g. deprecations); consider narrowing the filter.
warnings.filterwarnings('ignore')

# Character vocabulary for LibriSpeech: space, apostrophe, a-z, followed by
# special tokens. '<blank>' is the CTC blank (used by CTCLoss) and '<pad>'
# fills unused target slots in a batch (used by collate_fn).
_SPECIAL_TOKENS = ['<blank>', '<sos>', '<eos>', '<pad>']
VOCAB = [' ', "'", *string.ascii_lowercase, *_SPECIAL_TOKENS]
CHAR_TO_IDX = dict(zip(VOCAB, range(len(VOCAB))))
IDX_TO_CHAR = dict(enumerate(VOCAB))
VOCAB_SIZE = len(VOCAB)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a (batch, time, d_model) input.

    Even channels hold ``sin(pos * w_k)`` and odd channels ``cos(pos * w_k)``
    with geometrically spaced frequencies ``w_k``. The table is precomputed
    for ``max_len`` positions; longer inputs are not supported.

    Args:
        d_model: Feature dimension. Odd values are supported -- the last
            channel then only receives a sine term.
        max_len: Number of positions covered by the precomputed table.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() *
                             -(math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles)[:, :d_model // 2]
        # A buffer, not a parameter: saved in the state_dict and moved by
        # .to(device), but never updated by the optimizer.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions."""
        return x + self.pe[:, :x.size(1)]

class ConformerBlock(nn.Module):
    """A single Conformer-style layer (pre-norm, macaron feed-forward).

    Sub-layers, each wrapped in a residual connection:
      1. LayerNorm -> Linear -> SiLU -> Linear, added with weight 0.5
      2. LayerNorm -> multi-head self-attention
      3. LayerNorm -> pointwise conv + GLU -> depthwise conv -> SiLU
         -> pointwise conv
      4. the same feed-forward network as (1), added with weight 0.5

    Args:
        d_model: Feature size of the (batch, time, d_model) input.
        n_heads: Number of attention heads (must divide d_model).
        d_ff: Hidden size of the feed-forward network.
        kernel_size: Depthwise conv kernel; odd values keep the time length.
        dropout: Dropout probability used throughout the block.

    NOTE(review): steps (1) and (4) reuse the same ``ff1``/``ff2`` weights,
    whereas the Conformer paper uses two independent feed-forward modules.
    Kept shared so existing checkpoints still load; consider separate modules
    when training from scratch.
    """

    def __init__(self, d_model, n_heads, d_ff, kernel_size=31, dropout=0.1):
        super().__init__()
        self.d_model = d_model

        # Multi-head self attention
        self.attention = nn.MultiheadAttention(d_model, n_heads, dropout=dropout, batch_first=True)

        # Convolution module (conv1 doubles channels for the GLU gate)
        self.conv1 = nn.Conv1d(d_model, d_model * 2, 1)
        self.depthwise_conv = nn.Conv1d(d_model, d_model, kernel_size,
                                        padding=kernel_size//2, groups=d_model)
        self.conv2 = nn.Conv1d(d_model, d_model, 1)

        # Feed forward (shared by both half-step FFN sub-layers)
        self.ff1 = nn.Linear(d_model, d_ff)
        self.ff2 = nn.Linear(d_ff, d_model)

        # Layer norms (one per sub-layer)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.norm4 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def _feed_forward(self, x):
        """Linear -> SiLU -> dropout -> Linear -> dropout (no residual)."""
        x = self.dropout(F.silu(self.ff1(x)))
        return self.dropout(self.ff2(x))

    def forward(self, x, mask=None):
        """Apply the block.

        Args:
            x: (batch, time, d_model) input.
            mask: Optional bool (batch, time) key-padding mask, True at padded
                frames. Used by attention and to zero padded frames before the
                depthwise convolution so they cannot leak into valid frames.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Feed forward 1 (half-step residual)
        x = x + 0.5 * self._feed_forward(self.norm1(x))

        # Multi-head self attention; attention weights are not needed.
        residual = x
        x = self.norm2(x)
        attn_out, _ = self.attention(x, x, x, key_padding_mask=mask, need_weights=False)
        x = residual + self.dropout(attn_out)

        # Convolution module
        residual = x
        y = self.norm3(x).transpose(1, 2)  # (B, T, D) -> (B, D, T)
        y = F.glu(self.conv1(y), dim=1)
        if mask is not None:
            # The depthwise conv mixes neighbouring frames, so padded frames
            # must be zero (like the conv's own edge padding) first.
            y = y.masked_fill(mask.unsqueeze(1), 0.0)
        y = F.silu(self.depthwise_conv(y))
        y = self.dropout(self.conv2(y))
        x = residual + y.transpose(1, 2)  # (B, D, T) -> (B, T, D)

        # Feed forward 2 (half-step residual, same weights as FFN 1)
        x = x + 0.5 * self._feed_forward(self.norm4(x))

        return x

class StreamingConformer(nn.Module):
    """Conv front-end + Conformer encoder producing per-frame vocabulary logits.

    The 2-D conv front-end pools only along the mel axis (by 4 in total), so
    the time resolution of the output equals that of the input mel frames and
    ``lengths`` measured in mel frames can be used directly for masking.

    Args:
        n_mels: Number of mel bands in the input.
        d_model: Encoder width (also the conv front-end's output channels).
        n_layers: Number of ConformerBlocks.
        n_heads: Attention heads per block.
        d_ff: Feed-forward hidden size.
        vocab_size: Size of the output vocabulary.
        kernel_size: Depthwise conv kernel size inside each block.
        dropout: Dropout probability.
    """

    def __init__(self, n_mels=80, d_model=512, n_layers=12, n_heads=8, d_ff=2048, vocab_size=VOCAB_SIZE, kernel_size=31, dropout=0.1):
        super().__init__()

        # Feature extraction with depthwise separable convolutions
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(1, 32, (3, 3), padding=(1, 1)),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, (3, 3), padding=(1, 1), groups=32),  # Depthwise
            nn.Conv2d(32, 64, 1),  # Pointwise
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # Only pool in frequency

            # Second conv block
            nn.Conv2d(64, 64, (3, 3), padding=(1, 1), groups=64),  # Depthwise
            nn.Conv2d(64, 128, 1),  # Pointwise
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 1)),  # Only pool in frequency

            # Third conv block
            nn.Conv2d(128, 128, (3, 3), padding=(1, 1), groups=128),  # Depthwise
            nn.Conv2d(128, d_model, 1),  # Pointwise
            nn.BatchNorm2d(d_model),
            nn.ReLU(),
        )

        # Mel bins left after the two (2, 1) max-pools (floor division)
        self.n_mels_after_conv = n_mels // 4

        # Each frame's (d_model channels x reduced mel bins) is flattened and
        # projected back down to d_model.
        self.input_projection = nn.Linear(d_model * self.n_mels_after_conv, d_model)

        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model)

        # Conformer blocks
        self.conformer_blocks = nn.ModuleList([
            ConformerBlock(d_model, n_heads, d_ff, kernel_size, dropout)
            for _ in range(n_layers)
        ])

        # Output projection
        self.output_projection = nn.Linear(d_model, vocab_size)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths=None):
        """Compute logits.

        Args:
            x: (batch, n_mels, time) mel features.
            lengths: Optional (batch,) valid frame counts; frames at or beyond
                a row's length are masked out of attention. May live on any
                device.

        Returns:
            (batch, time, vocab_size) unnormalized logits.
        """
        batch_size = x.size(0)

        # (B, n_mels, T) -> (B, 1, n_mels, T) -> (B, d_model, n_mels // 4, T)
        x = self.feature_extractor(x.unsqueeze(1))

        # (B, C, F, T) -> (B, T, C, F) -> (B, T, C * F): one flat vector per frame
        x = x.permute(0, 3, 1, 2)
        x = x.reshape(batch_size, x.size(1), x.size(2) * x.size(3))

        x = self.input_projection(x)  # (B, T, d_model)
        x = self.dropout(self.pos_encoding(x))

        # True marks padded frames (key_padding_mask convention).
        mask = None
        if lengths is not None:
            frame_idx = torch.arange(x.size(1), device=x.device)
            mask = frame_idx.unsqueeze(0) >= lengths.to(x.device).unsqueeze(1)

        for block in self.conformer_blocks:
            x = block(x, mask)

        return self.output_projection(x)  # (B, T, vocab_size)

class LibriSpeechDataset(Dataset):
    """Map-style dataset of per-utterance-normalized log-mel features + char targets.

    Args:
        audio_paths: Audio file paths readable by ``torchaudio.load``.
        transcripts: Transcripts aligned index-by-index with ``audio_paths``.
        sample_rate: Rate audio is resampled to before feature extraction.
        n_mels: Number of mel bands.

    Raises:
        ValueError: If ``audio_paths`` and ``transcripts`` differ in length.
    """

    def __init__(self, audio_paths, transcripts, sample_rate=16000, n_mels=80):
        if len(audio_paths) != len(transcripts):
            raise ValueError(
                f"Got {len(audio_paths)} audio paths but {len(transcripts)} transcripts")
        self.audio_paths = audio_paths
        self.transcripts = transcripts
        self.sample_rate = sample_rate
        self.n_mels = n_mels

        # Mel spectrogram transform
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_mels=n_mels,
            n_fft=400,  # 25ms window at 16 kHz
            hop_length=160,  # 10ms hop at 16 kHz
            win_length=400,
            window_fn=torch.hann_window
        )

        # Resample transforms cached per source rate: each precomputes its
        # filter kernel, so building a fresh one for every item is wasted work.
        self._resamplers = {}

    def __len__(self):
        return len(self.audio_paths)

    def text_to_indices(self, text):
        """Lower-case and strip ``text``, then map each char to a vocabulary index.

        Characters outside the vocabulary are mapped to the space index.
        """
        space_idx = CHAR_TO_IDX[' ']
        return [CHAR_TO_IDX.get(char, space_idx) for char in text.lower().strip()]

    def _resample(self, waveform, sr):
        """Resample ``waveform`` from ``sr`` to ``self.sample_rate`` (cached transform)."""
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            self._resamplers[sr] = resampler
        return resampler(waveform)

    def __getitem__(self, idx):
        """Return a dict with 'mel_spec' (n_mels, frames), 'target', and both lengths."""
        waveform, sr = torchaudio.load(self.audio_paths[idx])

        if sr != self.sample_rate:
            waveform = self._resample(waveform, sr)

        # Down-mix multi-channel audio to mono
        if waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        # Log-mel; the epsilon avoids log(0) on silent frames
        mel_spec = torch.log(self.mel_transform(waveform) + 1e-8)

        # Normalize to zero mean / unit variance over the whole utterance
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)

        target_indices = self.text_to_indices(self.transcripts[idx])

        return {
            'mel_spec': mel_spec.squeeze(0),  # Remove channel dim
            'target': torch.tensor(target_indices, dtype=torch.long),
            'input_length': mel_spec.shape[-1],
            'target_length': len(target_indices)
        }

def load_custom_dataset(dataset_root="./dataset"):
    """
    Load (wav path, transcript) pairs from a local dataset folder.

    Expected structure:
    dataset_root/
    ├── wavs/           # Contains .wav audio files
    ├── texts/          # Contains .txt transcript files with same filenames
    └── phonemes/       # Contains phoneme files (not used for now)

    Wav files are visited in sorted order so the result -- and any seeded
    split computed from it -- is reproducible across filesystems. Wavs with
    a missing, empty, or unreadable transcript are skipped and reported.

    Args:
        dataset_root: Root directory of the dataset

    Returns:
        audio_paths: List of paths to .wav files
        transcripts: List of corresponding transcriptions
        (both empty if either sub-directory is missing)
    """
    import glob

    audio_paths = []
    transcripts = []

    wavs_dir = os.path.join(dataset_root, "wavs")
    texts_dir = os.path.join(dataset_root, "texts")

    if not os.path.isdir(wavs_dir):
        print(f"Error: {wavs_dir} not found!")
        return [], []

    if not os.path.isdir(texts_dir):
        print(f"Error: {texts_dir} not found!")
        return [], []

    # glob order is filesystem-dependent; sort for determinism.
    wav_files = sorted(glob.glob(os.path.join(wavs_dir, "*.wav")))

    print(f"Found {len(wav_files)} wav files in {wavs_dir}")

    matched_files = 0
    missing_transcripts = []

    for wav_file in wav_files:
        base_name = os.path.splitext(os.path.basename(wav_file))[0]
        text_file = os.path.join(texts_dir, f"{base_name}.txt")

        if not os.path.exists(text_file):
            missing_transcripts.append(base_name)
            continue

        try:
            with open(text_file, 'r', encoding='utf-8') as f:
                transcript = f.read().strip()
        except (OSError, UnicodeDecodeError) as e:
            print(f"Error reading {text_file}: {e}")
            continue

        if not transcript:
            print(f"Warning: Empty transcript for {base_name}")
            continue

        audio_paths.append(wav_file)
        transcripts.append(transcript)
        matched_files += 1

    print(f"Successfully matched {matched_files} audio-transcript pairs")

    if missing_transcripts:
        print(f"Warning: {len(missing_transcripts)} wav files have no corresponding transcripts:")
        for missing in missing_transcripts[:10]:  # Show first 10
            print(f"  - {missing}.txt")
        if len(missing_transcripts) > 10:
            print(f"  ... and {len(missing_transcripts) - 10} more")

    return audio_paths, transcripts

def download_from_google_drive(drive_url, destination_folder="./dataset"):
    """
    Best-effort download of a shared Google Drive folder via ``gdown``.

    Failures are reported on stdout instead of raised, so the caller can fall
    back to a manual download.

    Args:
        drive_url: Google Drive sharing URL
        destination_folder: Where to extract the dataset
    """
    # Import separately so only a missing gdown (not an ImportError raised
    # inside gdown during the download) triggers the install hint.
    try:
        import gdown
    except ImportError:
        print("gdown not installed. Install with: pip install gdown")
        print("Or manually download the dataset and place it in the expected structure:")
        print(f"{destination_folder}/")
        print("├── wavs/")
        print("├── texts/")
        print("└── phonemes/")
        return

    print("Using gdown to download from Google Drive...")
    try:
        os.makedirs(destination_folder, exist_ok=True)
        downloaded = gdown.download_folder(drive_url, output=destination_folder, quiet=False)
    except Exception as e:  # top-level best-effort boundary: report, don't crash
        print(f"Download failed: {e}")
        print("Please manually download and extract the dataset.")
        return

    # Some gdown versions signal failure by returning None instead of raising.
    if downloaded is None:
        print("Download failed: gdown could not retrieve the folder contents")
        print("Please manually download and extract the dataset.")
        return

    print(f"Dataset downloaded to {destination_folder}")

def split_dataset(audio_paths, transcripts, train_ratio=0.8, val_ratio=0.1, test_ratio=0.1, random_seed=42):
    """
    Shuffle paired data reproducibly and split it into train/val/test sets.

    Uses a private ``random.Random`` so the global ``random`` state is left
    untouched (the shuffle is identical to seeding the global generator).

    Args:
        audio_paths: List of audio file paths
        transcripts: List of transcripts, aligned with ``audio_paths``
        train_ratio: Proportion for training (count is floored)
        val_ratio: Proportion for validation (count is floored)
        test_ratio: Nominal test proportion; the test split actually receives
            every sample not assigned to train/val, so this value is unused.
        random_seed: Random seed for reproducibility

    Returns:
        Dictionary mapping 'train'/'val'/'test' to (paths, transcripts) lists.

    Raises:
        ValueError: If ``audio_paths`` and ``transcripts`` differ in length.
    """
    import random

    if len(audio_paths) != len(transcripts):
        raise ValueError(
            f"Got {len(audio_paths)} audio paths but {len(transcripts)} transcripts")

    paired_data = list(zip(audio_paths, transcripts))
    random.Random(random_seed).shuffle(paired_data)

    total_samples = len(paired_data)
    train_end = int(total_samples * train_ratio)
    val_end = train_end + int(total_samples * val_ratio)

    def _unzip(pairs):
        """Split a list of (path, transcript) pairs into two lists."""
        return [path for path, _ in pairs], [text for _, text in pairs]

    splits = {
        'train': _unzip(paired_data[:train_end]),
        'val': _unzip(paired_data[train_end:val_end]),
        'test': _unzip(paired_data[val_end:]),
    }

    print(f"Dataset split:")
    print(f"  Train: {len(splits['train'][0])} samples")
    print(f"  Validation: {len(splits['val'][0])} samples")
    print(f"  Test: {len(splits['test'][0])} samples")

    return splits

def collate_fn(batch, pad_idx=None):
    """Pad a list of dataset items into batch tensors, longest input first.

    Args:
        batch: Items with 'mel_spec' (n_mels, frames), 'target' (1-D long),
            'input_length' and 'target_length' (ints).
        pad_idx: Value used to fill target slots past each target's length;
            defaults to the '<pad>' vocabulary index.

    Returns:
        Dict with 'mel_specs' (B, n_mels, max_frames, zero-padded),
        'targets' (B, max_target_len), 'input_lengths' and 'target_lengths'
        (B,) long tensors.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    if not batch:
        raise ValueError("collate_fn received an empty batch")
    if pad_idx is None:
        pad_idx = CHAR_TO_IDX['<pad>']

    batch = sorted(batch, key=lambda item: item['input_length'], reverse=True)

    input_lengths = torch.tensor([item['input_length'] for item in batch], dtype=torch.long)
    target_lengths = torch.tensor([item['target_length'] for item in batch], dtype=torch.long)

    # Plain ints for tensor sizes (rather than 0-d tensors).
    max_input_len = int(input_lengths.max())
    max_target_len = int(target_lengths.max())
    n_mels = batch[0]['mel_spec'].shape[0]

    padded_mels = torch.zeros(len(batch), n_mels, max_input_len)
    padded_targets = torch.full((len(batch), max_target_len), pad_idx, dtype=torch.long)
    for i, item in enumerate(batch):
        padded_mels[i, :, :item['input_length']] = item['mel_spec']
        padded_targets[i, :item['target_length']] = item['target']

    return {
        'mel_specs': padded_mels,
        'targets': padded_targets,
        'input_lengths': input_lengths,
        'target_lengths': target_lengths
    }

class CTCLoss(nn.Module):
    """CTC loss over raw (batch, time, vocab) logits with padded targets.

    Log-softmax is applied here, so pass un-normalized model outputs.

    Args:
        blank_idx: Vocabulary index of the CTC blank symbol.
    """

    def __init__(self, blank_idx=CHAR_TO_IDX['<blank>']):
        super().__init__()
        self.blank_idx = blank_idx
        # zero_infinity: samples whose target cannot be aligned within their
        # input length contribute zero loss/gradient instead of inf.
        self.ctc_loss = nn.CTCLoss(blank=blank_idx, reduction='mean', zero_infinity=True)

    def forward(self, log_probs, targets, input_lengths, target_lengths):
        """Return the mean CTC loss for the batch.

        Args:
            log_probs: (batch, time, vocab) logits -- despite the name,
                log-softmax is applied inside this method.
            targets: (batch, max_target_len) padded target indices; entries at
                or beyond ``target_lengths[i]`` are ignored.
            input_lengths: (batch,) valid time steps per sample.
            target_lengths: (batch,) valid target symbols per sample.
        """
        # nn.CTCLoss expects (time, batch, vocab) log-probabilities.
        log_probs = F.log_softmax(log_probs.transpose(0, 1), dim=-1)

        # nn.CTCLoss accepts padded (batch, S) targets and reads only the first
        # target_lengths[i] entries of each row, so no Python-side flattening
        # (and device->host copy) is needed.
        return self.ctc_loss(log_probs, targets, input_lengths, target_lengths)

class RealTimeSTT:
    """Chunk-by-chunk greedy CTC transcription with a StreamingConformer.

    Each ``transcribe_stream`` call converts one audio chunk to log-mel
    frames, appends them to a rolling buffer of at most ``context_frames``
    frames, runs the model over the whole buffer and greedily decodes it.

    NOTE(review): the returned text decodes the *entire* buffer, not just the
    new audio, so consecutive calls repeat text. Each chunk is also
    normalized with its own mean/std (LibriSpeechDataset normalizes per
    utterance), and callers passing overlapping chunks add duplicated frames
    -- expect a train/inference mismatch.
    """

    # Vocabulary entries that must never appear in decoded text.
    _NON_TEXT_TOKENS = frozenset({'<blank>', '<sos>', '<eos>', '<pad>'})

    def __init__(self, model_path=None, device='cuda' if torch.cuda.is_available() else 'cpu'):
        self.device = device
        self.model = StreamingConformer().to(device)

        if model_path:
            if os.path.exists(model_path):
                # NOTE(review): torch.load unpickles the file -- only load
                # trusted checkpoints (weights_only=True is safer where supported).
                self.model.load_state_dict(torch.load(model_path, map_location=device))
            else:
                print(f"Warning: checkpoint {model_path} not found; using untrained weights")

        self.model.eval()

        # Audio parameters (transcribe_stream accepts chunks of any length)
        self.sample_rate = 16000
        self.chunk_duration = 0.03  # 30ms
        self.hop_duration = 0.01   # 10ms
        self.chunk_samples = int(self.chunk_duration * self.sample_rate)
        self.hop_samples = int(self.hop_duration * self.sample_rate)

        # Mel transform (same settings as LibriSpeechDataset's defaults)
        self.mel_transform = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sample_rate,
            n_mels=80,
            n_fft=400,
            hop_length=160,
            win_length=400,
            window_fn=torch.hann_window
        ).to(device)

        # Streaming state
        self.audio_buffer = torch.zeros(0)  # NOTE(review): not used by any method here
        self.context_frames = 50  # max buffered mel frames (~0.5 s at a 10 ms hop)
        # On the model device so torch.cat with device-side mel frames works.
        self.mel_buffer = torch.zeros(80, 0, device=device)

    def reset_stream(self):
        """Discard buffered mel frames, e.g. before starting a new utterance."""
        self.mel_buffer = torch.zeros(80, 0, device=self.device)

    def preprocess_audio(self, audio_chunk):
        """Convert a 1-D or (1, samples) chunk to normalized (80, frames) log-mel."""
        if len(audio_chunk.shape) == 1:
            audio_chunk = audio_chunk.unsqueeze(0)

        mel_spec = torch.log(self.mel_transform(audio_chunk) + 1e-8)

        # Normalize with this chunk's own statistics
        mel_spec = (mel_spec - mel_spec.mean()) / (mel_spec.std() + 1e-8)

        return mel_spec.squeeze(0)

    def decode_predictions(self, logits):
        """Greedy CTC decode of (time, vocab) logits to text.

        Consecutive repeats are collapsed; blanks and other special tokens are
        dropped and act as separators between repeated characters.
        """
        predictions = torch.argmax(logits, dim=-1).tolist()  # one device sync

        decoded = []
        prev_char = None
        for idx in predictions:
            char = IDX_TO_CHAR[idx]
            if char not in self._NON_TEXT_TOKENS and char != prev_char:
                decoded.append(char)
            prev_char = char

        return ''.join(decoded)

    def transcribe_stream(self, audio_chunk):
        """Add one audio chunk (tensor or ndarray) and return the buffer's transcription."""
        with torch.no_grad():
            if isinstance(audio_chunk, np.ndarray):
                audio_chunk = torch.from_numpy(audio_chunk).float()
            audio_chunk = audio_chunk.to(self.device)

            mel_spec = self.preprocess_audio(audio_chunk)

            # Append new frames, then keep only the most recent context window
            self.mel_buffer = torch.cat([self.mel_buffer, mel_spec], dim=1)
            if self.mel_buffer.shape[1] > self.context_frames:
                self.mel_buffer = self.mel_buffer[:, -self.context_frames:]

            logits = self.model(self.mel_buffer.unsqueeze(0))  # add batch dim

            return self.decode_predictions(logits.squeeze(0))

def _batch_to_device(batch, device):
    """Return (mel_specs, targets, input_lengths, target_lengths) of a collate_fn batch on *device*."""
    return tuple(batch[key].to(device) for key in
                 ('mel_specs', 'targets', 'input_lengths', 'target_lengths'))


def train_model(train_loader, val_loader, model, device, num_epochs=50,
                save_path='/content/drive/MyDrive/NAC/best_conformer_model.pth'):
    """Train *model* with CTC loss, AdamW and a per-batch one-cycle LR schedule.

    After each epoch the mean validation loss is computed. Whenever it
    improves, the model's state_dict is saved to *save_path*. If
    ``val_loader`` is empty, the epoch's training loss is used instead.

    Args:
        train_loader: DataLoader yielding collate_fn batches.
        val_loader: DataLoader yielding collate_fn batches (may be empty).
        model: Module mapping (mel_specs, input_lengths) to (B, T, vocab) logits.
        device: Device to train on.
        num_epochs: Number of epochs.
        save_path: Checkpoint path; missing parent directories are created.

    Returns:
        The best (lowest) epoch loss used for checkpoint selection.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model = model.to(device)
    criterion = CTCLoss()
    optimizer = AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)

    steps_per_epoch = len(train_loader)
    if steps_per_epoch == 0:
        raise ValueError("train_loader is empty; nothing to train on")
    # Stepped once per batch, so it needs the total number of steps.
    scheduler = OneCycleLR(optimizer, max_lr=1e-3,
                           steps_per_epoch=steps_per_epoch,
                           epochs=num_epochs)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        # Training
        model.train()
        train_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            mel_specs, targets, input_lengths, target_lengths = _batch_to_device(batch, device)

            optimizer.zero_grad()
            logits = model(mel_specs, input_lengths)
            loss = criterion(logits, targets, input_lengths, target_lengths)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            train_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}')

        # Validation
        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                mel_specs, targets, input_lengths, target_lengths = _batch_to_device(batch, device)
                logits = model(mel_specs, input_lengths)
                val_loss += criterion(logits, targets, input_lengths, target_lengths).item()

        train_loss /= steps_per_epoch
        n_val_batches = len(val_loader)
        if n_val_batches:
            val_loss /= n_val_batches
        else:
            # No validation data: select checkpoints on training loss rather
            # than dividing by zero.
            val_loss = train_loss

        print(f'Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)
            torch.save(model.state_dict(), save_path)
            print(f'New best model saved with val loss: {val_loss:.4f}')

    return best_val_loss

# Example usage and demo
def simulate_real_time(model_path='best_conformer_model.pth'):
    """Run RealTimeSTT over a synthetic 2 s signal in 30 ms chunks with a 10 ms hop.

    Every chunk whose decoded text is non-empty is printed with its
    processing time, and the texts are concatenated into a final line.
    Because RealTimeSTT decodes its whole context buffer on every call, the
    concatenation repeats text across chunks.

    Args:
        model_path: Checkpoint passed to RealTimeSTT. Training code in this
            file saves checkpoints under /content/drive/MyDrive/NAC/, so pass
            that path to use a trained model.
    """
    stt = RealTimeSTT(model_path)

    # Simulated audio stream (in practice this would come from a microphone)
    sample_rate = 16000
    chunk_duration = 0.03  # 30ms
    chunk_samples = int(chunk_duration * sample_rate)
    hop_samples = int(0.01 * sample_rate)  # 10ms hop

    print("Simulating real-time transcription...")
    print("Processing 30ms chunks with 10ms hop...")

    # Dummy signal: two tones plus noise (replace with microphone input)
    duration = 2.0  # 2 seconds
    t = torch.linspace(0, duration, int(duration * sample_rate))
    audio_signal = (torch.sin(2 * np.pi * 440 * t) * 0.3 +
                    torch.sin(2 * np.pi * 880 * t) * 0.2 +
                    torch.randn_like(t) * 0.1)

    transcription = ""

    # "+ 1" so the final full chunk (starting at len - chunk_samples) is included.
    for i in range(0, len(audio_signal) - chunk_samples + 1, hop_samples):
        chunk = audio_signal[i:i + chunk_samples]

        start_time = time.time()
        text = stt.transcribe_stream(chunk)
        processing_time = (time.time() - start_time) * 1000  # ms

        if text.strip():
            transcription += text + " "
            print(f"Chunk {i//hop_samples}: '{text}' (processed in {processing_time:.1f}ms)")

    print(f"\nFinal transcription: {transcription.strip()}")

if __name__ == "__main__":
    # Build the model, load the dataset from Google Drive, and train.
    print("Real-Time Speech-to-Text System")
    print("===============================")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = StreamingConformer()
    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    dataset_root = '/content/drive/MyDrive/NAC/'
    audio_paths, transcripts = load_custom_dataset(dataset_root)

    if not audio_paths:
        # Without data the loaders would be empty and the LR scheduler would
        # reject a zero steps-per-epoch count.
        print(f"No audio-transcript pairs found under {dataset_root}; skipping training.")
    else:
        # The 'test' split is produced but not used here.
        splits = split_dataset(audio_paths, transcripts)
        train_paths, train_transcripts = splits['train']
        val_paths, val_transcripts = splits['val']

        train_dataset = LibriSpeechDataset(train_paths, train_transcripts)
        train_loader = DataLoader(train_dataset, batch_size=16,
                                  shuffle=True, collate_fn=collate_fn)

        val_dataset = LibriSpeechDataset(val_paths, val_transcripts)
        val_loader = DataLoader(val_dataset, batch_size=16,
                                shuffle=False, collate_fn=collate_fn)

        train_model(train_loader, val_loader, model, device)

    print("\nSystem Features:")
    print("- 30ms processing latency")
    print("- 10ms hop length for real-time streaming")
    print("- Conformer architecture with conv + transformer")
    print("- CTC loss for alignment-free training")
    print("- Optimized for LibriSpeech vocabulary")
    print("- Streaming context management")
    print("- GPU acceleration support")

Real-Time Speech-to-Text System
Using device: cpu
Model parameters: 52,845,728
Found 3842 wav files in /content/drive/MyDrive/NAC/wavs
Successfully matched 3842 audio-transcript pairs
Dataset split:
  Train: 3073 samples
  Validation: 384 samples
  Test: 385 samples
Epoch 0, Batch 0, Loss: 18.6699
Epoch 0, Batch 100, Loss: 3.2797
Epoch 0: Train Loss: 3.9070, Val Loss: 3.0994
New best model saved with val loss: 3.0994
Epoch 1, Batch 0, Loss: 3.1242
Epoch 1, Batch 100, Loss: 3.0362
Epoch 1: Train Loss: 3.0327, Val Loss: 3.0382
New best model saved with val loss: 3.0382
Epoch 2, Batch 0, Loss: 2.9482
Epoch 2, Batch 100, Loss: 2.9048
Epoch 2: Train Loss: 2.9051, Val Loss: 2.8439
New best model saved with val loss: 2.8439
Epoch 3, Batch 0, Loss: 2.7876
Epoch 3, Batch 100, Loss: 2.8716
Epoch 3: Train Loss: 2.8601, Val Loss: 2.8734
Epoch 4, Batch 0, Loss: 2.7504
Epoch 4, Batch 100, Loss: 2.7863
Epoch 4: Train Loss: 2.8355, Val Loss: 2.9200
Epoch 5, Batch 0, Loss: 2.8226
