# In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from nltk.tokenize import word_tokenize
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from collections import Counter
from tqdm import tqdm
import random

# Configuration
class Config:
    # Central, class-level settings for the whole script: data locations, the
    # hyperparameter grid searched by tune_hyperparameters(), and the default
    # values used for the final training run.
    # Documented with comments rather than a docstring so that Config.__dict__
    # (which the script prints) stays exactly as it was.

    # Paths
    motions_dir = "motions"  # directory with one "<id>.npy" motion file per sample
    texts_dir = "texts"  # directory with one "<id>.txt" caption file per sample
    train_list = "train.txt"  # text files listing the sample ids of each split
    test_list = "test.txt"
    val_list = "val.txt"

    # Hyperparameter ranges for tuning (full Cartesian product is searched)
    hidden_dims = [128, 256, 512]
    # NOTE(review): emb_dim is part of the grid and is copied back into Config,
    # but the Encoder/Decoder constructors in this file take no embedding size
    # (the decoder embeds tokens into hidden_dim), so it does not change the model.
    emb_dims = [64, 128, 256]
    enc_layers_options = [1, 2, 3]
    dec_layers_options = [1, 2, 3]
    dropouts = [0.2, 0.3, 0.15]
    batch_sizes = [32, 64, 16]
    lrs = [0.001, 0.0005, 0.00001]
    teacher_forcing_ratios = [0.5, 0.7, 0.3, 0.2]

    # Model parameters (default, can be overridden during hyperparameter tuning)
    hidden_dim = 256
    emb_dim = 128
    enc_layers = 2
    dec_layers = 2
    dropout = 0.3
    batch_size = 64
    lr = 0.001
    epochs = 20  # number of epochs of the final training loop
    max_seq_len = 20  # passed as max_len to Seq2Seq.inference when generating captions
    teacher_forcing_ratio = 0.5
    min_word_count = 2  # words seen fewer times than this are left out of the vocabulary

# Text Preprocessing
class Vocabulary:
    """Bidirectional word/index mapping with four reserved special tokens.

    Indices 0-3 are ``<pad>``, ``<sos>``, ``<eos>`` and ``<unk>``; corpus words
    are numbered from 4 upwards by :meth:`build_vocab`.
    """

    def __init__(self):
        specials = ("<pad>", "<sos>", "<eos>", "<unk>")
        self.word2idx = {token: position for position, token in enumerate(specials)}
        self.idx2word = {position: token for token, position in self.word2idx.items()}
        self.counter = Counter()

    def build_vocab(self, descriptions):
        """Count the tokens of every description and index the frequent ones.

        A word is indexed when its accumulated count reaches
        ``Config.min_word_count``; indices follow the order in which words
        were first seen.
        """
        for description in descriptions:
            self.counter.update(word_tokenize(description.lower()))

        threshold = Config.min_word_count
        frequent_words = [w for w, seen in self.counter.items() if seen >= threshold]
        for position, frequent_word in enumerate(frequent_words, start=4):
            self.word2idx[frequent_word] = position
            self.idx2word[position] = frequent_word

    def numericalize(self, text):
        """Tokenize lower-cased ``text`` and map each token to its index.

        Tokens missing from the vocabulary map to the ``<unk>`` index.
        """
        unknown = self.word2idx["<unk>"]
        lookup = self.word2idx.get
        return [lookup(token, unknown) for token in word_tokenize(text.lower())]

    def denumericalize(self, numericalized):
        """Join the words for ``numericalized`` with spaces (``<unk>`` for unknown ids)."""
        words = (self.idx2word.get(index, "<unk>") for index in numericalized)
        return " ".join(words)

# Motion Preprocessing
class MotionNormalizer:
    """Per-feature z-score normalizer for flattened motion frames.

    ``fit`` computes the mean and standard deviation of every one of the
    ``22 * 3`` frame features over all frames of all given motions;
    ``normalize`` / ``apply`` then standardize a motion with those statistics.
    """

    def __init__(self):
        self.mean = None
        self.std = None

    def _require_fitted(self):
        """Raise ValueError unless ``fit`` has produced statistics."""
        if self.mean is None or self.std is None:
            raise ValueError("Normalizer must be fitted before normalizing data")

    def fit(self, motions):
        """Estimate per-feature mean and std from an iterable of motions.

        Each motion is reshaped to ``(frames, 22 * 3)`` before the frames of
        all motions are pooled.

        Raises:
            ValueError: if ``motions`` is empty.
        """
        reshaped_motions = [np.asarray(m).reshape(-1, 22 * 3) for m in motions]
        if not reshaped_motions:
            raise ValueError("Cannot fit normalizer on an empty list of motions")
        all_frames = np.concatenate(reshaped_motions, axis=0)
        self.mean = np.mean(all_frames, axis=0)
        # Small epsilon keeps constant features from dividing by zero.
        self.std = np.std(all_frames, axis=0) + 1e-8

    def normalize(self, motion):
        """Standardize a ``(frames, features)`` motion, validating its shape.

        Raises:
            ValueError: if the normalizer is not fitted or ``motion`` is not
                2-D with as many columns as the fitted statistics.
        """
        self._require_fitted()
        if len(motion.shape) != 2 or motion.shape[1] != len(self.mean):
            raise ValueError(f"Expected motion shape (frames, {len(self.mean)}), got {motion.shape}")
        return (motion - self.mean[None, :]) / self.std[None, :]

    def apply(self, motion):  # This is meant for creating the submission file
        """Standardize ``motion`` without the strict shape check of ``normalize``.

        Any shape that broadcasts against ``(1, features)`` is accepted.

        Raises:
            ValueError: if the normalizer is not fitted (previously this
                surfaced as an opaque TypeError on ``None``).
        """
        self._require_fitted()
        return (motion - self.mean[None, :]) / self.std[None, :]


# Dataset Class
class MotionTextDataset(Dataset):
    """Motions (and, outside test mode, their captions) for one data split.

    All motions of the split are loaded into memory, reshaped to
    ``(frames, 22 * 3)`` and normalized with ``normalizer``. In ``"train"``
    mode the normalizer is fitted on this split first, and a vocabulary is
    built from the captions when none is supplied.

    Args:
        file_list: path of a text file with one sample id (optionally with a
            file extension) per line.
        normalizer: object providing ``fit(motions)`` and ``normalize(motion)``.
        vocab: vocabulary to reuse; if ``None`` in train mode a new one is built.
        mode: ``"train"``, ``"val"`` or ``"test"``. Test mode loads no captions.
    """

    def __init__(self, file_list, normalizer, vocab=None, mode="train"):
        self.file_ids = self._load_file_ids(file_list)
        self.normalizer = normalizer
        self.motions = []
        self.texts_list = []  # all captions of every sample, kept for BLEU references
        self.mode = mode

        # Every motion is loaded up front because the normalizer needs them all.
        print(f"Loading {mode} motions for normalization...")
        raw_motions = [
            np.load(os.path.join(Config.motions_dir, f"{fid}.npy"))
            for fid in tqdm(self.file_ids)
        ]

        if mode == "train":
            print("Fitting normalizer...")
            self.normalizer.fit(raw_motions)

        print(f"Processing {mode} data...")
        for fid, motion in tqdm(zip(self.file_ids, raw_motions), total=len(self.file_ids)):
            motion = self.normalizer.normalize(motion.reshape(-1, 22 * 3))
            self.motions.append(torch.FloatTensor(motion))
            if mode != "test":
                self.texts_list.append(self._load_captions(fid))

        # Alias kept for existing users of ``.texts``; always defined, None in test mode.
        self.texts = self.texts_list if mode != "test" else None

        if mode == "train" and vocab is None:
            print("Building vocabulary...")
            all_texts = [t for texts in self.texts_list for t in texts]
            self.vocab = Vocabulary()
            self.vocab.build_vocab(all_texts)
        else:
            self.vocab = vocab

    def _load_file_ids(self, file_list):
        """Return the sample ids listed in ``file_list``, skipping blank lines.

        Everything from the first ``.`` of a line onwards is dropped.
        """
        with open(file_list) as f:
            return [line.strip().split('.')[0] for line in f if line.strip()]

    def _load_captions(self, fid):
        """Return the captions of sample ``fid``.

        Only the part of each line before the first ``#`` is kept. Blank
        captions are dropped unless that would leave the sample without any
        caption, in which case the unfiltered list is returned.
        """
        with open(os.path.join(Config.texts_dir, f"{fid}.txt")) as f:
            captions = [line.split('#')[0].strip() for line in f]
        return [caption for caption in captions if caption] or captions

    def __len__(self):
        return len(self.file_ids)

    def __getitem__(self, idx):
        """Return one sample.

        Test mode yields ``(motion, file_id)``. Otherwise yields
        ``(motion, token_ids, captions)`` where ``token_ids`` is one randomly
        chosen caption wrapped in ``<sos>`` / ``<eos>`` and ``captions`` are
        all captions of the sample.
        """
        motion = self.motions[idx]
        if self.mode == "test":
            return motion, self.file_ids[idx]

        text_options = self.texts_list[idx]
        # NOTE: the random choice also applies in "val" mode, so the target
        # caption of a validation sample can differ between epochs.
        text = random.choice(text_options)
        word2idx = self.vocab.word2idx
        numericalized = [word2idx["<sos>"], *self.vocab.numericalize(text), word2idx["<eos>"]]
        return motion, torch.LongTensor(numericalized), text_options


# Collate Function
def collate_fn(batch):
    """Zero-pad a list of dataset samples into batch tensors.

    Test samples are ``(motion, file_id)`` pairs and produce
    ``(padded_motions, file_ids)``. Train/val samples are
    ``(motion, token_ids, captions)`` triples and produce
    ``(padded_motions, padded_token_ids, captions_per_sample)``.
    A batch is treated as test data when the second field of its first sample
    is a string.
    """
    is_test_batch = isinstance(batch[0][1], str)
    columns = tuple(zip(*batch))
    if is_test_batch:
        motions, file_ids = columns
    else:
        motions, token_seqs, text_options_list = columns

    # Motions are padded along the frame axis up to the longest one in the batch.
    longest_motion = max(len(m) for m in motions)
    padded_motions = torch.zeros(len(motions), longest_motion, motions[0].shape[1])
    for row, m in enumerate(motions):
        padded_motions[row, :len(m)] = m

    if is_test_batch:
        return padded_motions, file_ids

    # Index 0 doubles as the padding value for token sequences.
    longest_text = max(len(seq) for seq in token_seqs)
    padded_texts = torch.zeros(len(token_seqs), longest_text).long()
    for row, seq in enumerate(token_seqs):
        padded_texts[row, :len(seq)] = seq

    return padded_motions, padded_texts, list(text_options_list)


# Model Components
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over a motion sequence.

    ``forward`` returns the per-frame outputs projected to ``hidden_dim`` and
    an initial ``(hidden, cell)`` pair built from the top layer's final
    forward/backward states, copied ``n_layers`` times.
    """

    def __init__(self, input_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers

        # nn.LSTM only applies dropout between stacked layers.
        inter_layer_dropout = dropout if n_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            n_layers,
            dropout=inter_layer_dropout,
            bidirectional=True,
            batch_first=True,
        )

        # Both projections map the concatenated directions back to hidden_dim.
        both_directions = hidden_dim * 2
        self.fc_hidden = nn.Linear(both_directions, hidden_dim)
        self.fc_out = nn.Linear(both_directions, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def _bridge(self, state):
        """Turn a final LSTM state into a stacked ``[n_layers, batch, hidden_dim]`` state."""
        # state[-2] / state[-1] are the last layer's forward / backward entries.
        top_layer = torch.cat([state[-2], state[-1]], dim=1)
        projected = self.fc_hidden(self.dropout(top_layer))
        return projected.unsqueeze(0).repeat(self.n_layers, 1, 1)

    def forward(self, src):
        """Encode ``src`` of shape ``[batch_size, src_len, input_dim]``.

        Returns:
            ``(outputs, hidden, cell)`` with ``outputs`` of shape
            ``[batch_size, src_len, hidden_dim]`` and both states of shape
            ``[n_layers, batch_size, hidden_dim]``. The hidden and cell states
            share the ``fc_hidden`` projection.
        """
        sequence_out, (final_hidden, final_cell) = self.lstm(src)
        outputs = self.fc_out(sequence_out)
        # Hidden is bridged before cell so dropout draws happen in that order.
        hidden = self._bridge(final_hidden)
        cell = self._bridge(final_cell)
        return outputs, hidden, cell

class Decoder(nn.Module):
    """Single-step LSTM decoder with attention over the encoder outputs.

    Each call consumes one token per batch element and returns the logits for
    the next token together with the updated LSTM state.
    """

    def __init__(self, output_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        # Tokens are embedded directly into hidden_dim.
        self.embedding = nn.Embedding(output_dim, hidden_dim)
        self.lstm = nn.LSTM(
            hidden_dim * 2,
            hidden_dim,
            n_layers,
            dropout=dropout if n_layers > 1 else 0,
            batch_first=True,
        )

        # One linear layer scores each (decoder state, encoder output) pair.
        self.attention = nn.Linear(hidden_dim * 2, 1)
        self.fc_out = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, input, encoder_outputs, hidden, cell):
        """Run one decoding step.

        Args:
            input: token ids, shape ``[batch_size]``.
            encoder_outputs: shape ``[batch_size, src_len, hidden_dim]``.
            hidden, cell: LSTM state, each ``[n_layers, batch_size, hidden_dim]``.

        Returns:
            ``(prediction, hidden, cell)`` where ``prediction`` holds the
            vocabulary logits, shape ``[batch_size, output_dim]``.
        """
        src_len = encoder_outputs.size(1)

        # [batch_size, 1, hidden_dim]
        token_emb = self.dropout(self.embedding(input)).unsqueeze(1)

        # Score every encoder position against the top layer's hidden state.
        query = hidden[-1].unsqueeze(1).expand(-1, src_len, -1)
        scores = self.attention(torch.cat((query, encoder_outputs), dim=2)).squeeze(-1)
        weights = F.softmax(scores, dim=1)  # [batch_size, src_len]

        # Weighted sum of encoder outputs: [batch_size, 1, hidden_dim]
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs)

        lstm_in = torch.cat((token_emb, context), dim=2)  # [batch_size, 1, hidden_dim * 2]
        step_out, (hidden, cell) = self.lstm(lstm_in, (hidden, cell))

        features = torch.cat((step_out.squeeze(1), context.squeeze(1)), dim=1)
        return self.fc_out(features), hidden, cell

class Seq2Seq(nn.Module):
    """Encoder/decoder wrapper providing teacher-forced training and greedy decoding.

    Args:
        encoder: module mapping ``src`` to ``(encoder_outputs, hidden, cell)``.
        decoder: module with an ``output_dim`` attribute whose call
            ``decoder(input, encoder_outputs, hidden, cell)`` returns
            ``(logits, hidden, cell)``.
        device: stored as ``self.device`` for callers; tensors created inside
            this class follow the device of ``src``.
    """

    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, src, trg, teacher_forcing_ratio=0.5):
        """Decode ``trg.shape[1] - 1`` steps and return the logits of every step.

        Args:
            src: encoder input, first dimension is the batch.
            trg: target token ids ``[batch_size, trg_len]``; column 0 is used
                as the first decoder input.
            teacher_forcing_ratio: probability, drawn once per step for the
                whole batch, of feeding the ground-truth token instead of the
                model's own prediction.

        Returns:
            Tensor ``[batch_size, trg_len, vocab_size]``; position 0 stays zero.
        """
        batch_size = src.shape[0]
        max_len = trg.shape[1]
        vocab_size = self.decoder.output_dim

        outputs = torch.zeros(batch_size, max_len, vocab_size, device=src.device)
        encoder_outputs, hidden, cell = self.encoder(src)

        step_input = trg[:, 0]
        for t in range(1, max_len):
            output, hidden, cell = self.decoder(step_input, encoder_outputs, hidden, cell)
            outputs[:, t] = output

            teacher_force = random.random() < teacher_forcing_ratio
            step_input = trg[:, t] if teacher_force else output.argmax(1)

        return outputs

    @torch.no_grad()
    def inference(self, src, max_len, sos_idx=1, eos_idx=2):
        """Greedily decode up to ``max_len - 1`` tokens for every batch element.

        The module is switched to eval mode (and left there). Decoding no
        longer depends on module-level ``model`` / ``train_dataset`` globals.

        Args:
            src: encoder input, first dimension is the batch.
            max_len: decoding runs for at most ``max_len - 1`` steps.
            sos_idx: id of the start token fed at the first step. The default
                matches the ``<sos>`` index fixed by ``Vocabulary``.
            eos_idx: id of the end token. Each returned sequence is cut just
                before its first ``eos_idx`` and decoding stops once every
                sequence has produced it. Pass ``None`` to get the raw,
                untruncated ``max_len - 1`` tokens.

        Returns:
            A list with one list of predicted token ids per batch element.
        """
        self.eval()
        batch_size = src.shape[0]

        encoder_outputs, hidden, cell = self.encoder(src)
        step_input = torch.full((batch_size,), sos_idx, dtype=torch.long, device=src.device)
        finished = torch.zeros(batch_size, dtype=torch.bool, device=src.device)

        steps = []
        for _ in range(1, max_len):
            output, hidden, cell = self.decoder(step_input, encoder_outputs, hidden, cell)
            step_input = output.argmax(1)
            steps.append(step_input)

            if eos_idx is not None:
                finished |= step_input == eos_idx
                if bool(finished.all()):
                    break

        if not steps:
            return [[] for _ in range(batch_size)]

        token_rows = torch.stack(steps, dim=1).tolist()
        if eos_idx is None:
            return token_rows
        # Anything generated after <eos> is not part of the sentence.
        return [row[:row.index(eos_idx)] if eos_idx in row else row for row in token_rows]
# Training Functions
def train(model, iterator, optimizer, criterion, clip, teacher_forcing_ratio=None):
    """Train ``model`` for one epoch and return the mean loss per batch.

    Args:
        model: callable as ``model(src, trg, teacher_forcing_ratio)`` returning
            logits of shape ``[batch, trg_len, vocab]``.
        iterator: yields ``(src, trg, _)`` batches; the third item is ignored.
        optimizer: optimizer over ``model.parameters()``.
        criterion: loss taking ``(logits, targets)`` flattened over all steps.
        clip: maximum gradient norm applied before each optimizer step.
        teacher_forcing_ratio: ratio forwarded to the model. ``None`` (the
            default) uses ``Config.teacher_forcing_ratio``, so the configured
            or tuned value is honoured instead of the model's own default.

    Returns:
        The summed batch losses divided by ``len(iterator)``.
    """
    if teacher_forcing_ratio is None:
        teacher_forcing_ratio = Config.teacher_forcing_ratio

    # Batches follow the model's own device, so no module-level ``device``
    # variable has to exist when this is called.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    epoch_loss = 0.0

    for src, trg, _ in tqdm(iterator, desc="Training"):
        src = src.to(device)
        trg = trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio)

        # Position 0 is the <sos> input and has no prediction; drop it.
        output = output[:, 1:].reshape(-1, output.shape[-1])
        trg = trg[:, 1:].reshape(-1)

        loss = criterion(output, trg)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion, vocab):
    """Compute the validation loss and mean sentence-level BLEU of ``model``.

    The loss comes from a forward pass without teacher forcing; BLEU comes
    from a separate greedy decode of up to ``Config.max_seq_len`` steps,
    scored against all reference captions of each sample.

    Args:
        model: provides ``model(src, trg, 0)`` and ``model.inference(src, max_len)``.
        iterator: yields ``(src, trg, captions_per_sample)`` batches.
        criterion: loss taking ``(logits, targets)`` flattened over all steps.
        vocab: vocabulary with ``word2idx`` and ``denumericalize``.

    Returns:
        ``(mean loss per batch, mean BLEU per scored sentence)``.
    """
    # Batches follow the model's own device, so no module-level ``device``
    # variable has to exist when this is called.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    eos_idx = vocab.word2idx["<eos>"]
    pad_idx = vocab.word2idx["<pad>"]
    smoothing = SmoothingFunction().method1  # avoids zero BLEU on missing n-gram orders

    model.eval()
    epoch_loss = 0.0
    epoch_bleu = 0.0
    n_sentences = 0

    with torch.no_grad():
        for src, trg, text_options_list in tqdm(iterator, desc="Evaluating"):
            src = src.to(device)
            trg = trg.to(device)

            output = model(src, trg, 0)  # teacher forcing off
            output = output[:, 1:].reshape(-1, output.shape[-1])
            loss = criterion(output, trg[:, 1:].reshape(-1))
            epoch_loss += loss.item()

            output_tokens = model.inference(src, Config.max_seq_len)
            for tokens, references in zip(output_tokens, text_options_list):
                # The sentence ends at the first <eos>; later tokens are noise.
                if eos_idx in tokens:
                    tokens = tokens[:tokens.index(eos_idx)]
                kept = [token for token in tokens if token != pad_idx]
                # Vocabulary entries are already tokens, so split on the joining
                # spaces instead of re-tokenizing (which would break up "<unk>").
                candidate_sentence = vocab.denumericalize(kept).split()

                reference_sentences = [
                    word_tokenize(ref.lower()) for ref in references if ref.strip()
                ]
                n_sentences += 1
                if reference_sentences:
                    epoch_bleu += sentence_bleu(
                        reference_sentences, candidate_sentence, smoothing_function=smoothing
                    )

    avg_bleu = epoch_bleu / n_sentences if n_sentences else 0.0
    return epoch_loss / len(iterator), avg_bleu

def tune_hyperparameters():
    """Grid-search the ranges in ``Config`` and return the best configuration.

    Every combination is trained for one epoch and scored by validation BLEU.
    The datasets, normalizer and vocabulary do not depend on any searched
    value, so they are built once rather than once per combination.

    Returns:
        The configuration dict with the highest validation BLEU, or ``None``
        if the grid is empty. Keys: ``hidden_dim``, ``emb_dim``,
        ``enc_layers``, ``dec_layers``, ``dropout``, ``batch_size``, ``lr``,
        ``teacher_forcing_ratio``.
    """
    from itertools import product  # local import keeps this function self-contained

    # Key order fixes both the search order (last key varies fastest, as in
    # the equivalent nested loops) and the layout of the reported config.
    grid = {
        'hidden_dim': Config.hidden_dims,
        # NOTE(review): emb_dim is searched and reported, but no model
        # constructor called below takes it, so it triples the run count
        # without changing the model.
        'emb_dim': Config.emb_dims,
        'enc_layers': Config.enc_layers_options,
        'dec_layers': Config.dec_layers_options,
        'dropout': Config.dropouts,
        'batch_size': Config.batch_sizes,
        'lr': Config.lrs,
        'teacher_forcing_ratio': Config.teacher_forcing_ratios,
    }
    tuning_epochs = 1  # reduced epochs for tuning

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Prepare data once for the whole search
    normalizer = MotionNormalizer()
    train_dataset = MotionTextDataset(Config.train_list, normalizer, mode="train")
    val_dataset = MotionTextDataset(Config.val_list, normalizer, vocab=train_dataset.vocab, mode="val")
    vocab = train_dataset.vocab
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

    best_bleu = -1.0
    best_config = None
    configured_ratio = Config.teacher_forcing_ratio

    try:
        for values in product(*grid.values()):
            current_config = dict(zip(grid, values))
            print(f"\n--- Tuning with config: {current_config} ---")

            # train() is called without a ratio argument, so the candidate is
            # published through Config.
            # NOTE(review): this only has an effect if train() consults
            # Config.teacher_forcing_ratio - confirm.
            Config.teacher_forcing_ratio = current_config['teacher_forcing_ratio']

            train_loader = DataLoader(train_dataset, batch_size=current_config['batch_size'],
                                      shuffle=True, collate_fn=collate_fn)
            val_loader = DataLoader(val_dataset, batch_size=current_config['batch_size'],
                                    collate_fn=collate_fn)

            # Initialize model with current hyperparameters
            enc = Encoder(input_dim=66, hidden_dim=current_config['hidden_dim'],
                          n_layers=current_config['enc_layers'],
                          dropout=current_config['dropout'])
            dec = Decoder(output_dim=len(vocab.word2idx),
                          hidden_dim=current_config['hidden_dim'],
                          n_layers=current_config['dec_layers'],
                          dropout=current_config['dropout'])
            model = Seq2Seq(enc, dec, device).to(device)
            optimizer = optim.Adam(model.parameters(), lr=current_config['lr'])

            for epoch in range(tuning_epochs):
                train_loss = train(model, train_loader, optimizer, criterion, clip=1)
                val_loss, val_bleu = evaluate(model, val_loader, criterion, vocab)
                print(f"\tEpoch: {epoch+1}, Val BLEU: {val_bleu:.3f}")

            if val_bleu > best_bleu:
                best_bleu = val_bleu
                best_config = current_config
                print(f"\t--- New best BLEU: {best_bleu:.3f} with config: {best_config} ---")
    finally:
        # Do not leave the last tried candidate behind in the shared Config.
        Config.teacher_forcing_ratio = configured_ratio

    print("\n--- Best Hyperparameters Found ---")
    print(f"Best BLEU Score: {best_bleu:.3f}")
    print(f"Best Config: {best_config}")
    return best_config

if __name__ == "__main__":
    # End-to-end script: optional grid search, final training with per-epoch
    # validation, then caption generation for the test split.
    # Everything stays at module level on purpose: `device`, `model` and
    # `train_dataset` are created here as module-level names.

    # Option to tune hyperparameters
    tune_hparams = True

    if tune_hparams:
        best_config = tune_hyperparameters()
        # Update Config with best hyperparameters; the keys of best_config
        # are exactly the Config attribute names.
        for name, value in best_config.items():
            setattr(Config, name, value)
        print("\n--- Using best config for final training ---")
        print(Config.__dict__)  # Print config to verify
    else:
        print("\n--- Using default config for training ---")
        print(Config.__dict__)  # Print config to verify

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Prepare data
    print("Initializing datasets...")
    normalizer = MotionNormalizer()
    train_dataset = MotionTextDataset(Config.train_list, normalizer, mode="train")
    vocab = train_dataset.vocab
    val_dataset = MotionTextDataset(Config.val_list, normalizer, vocab=vocab, mode="val")
    # Test dataset does not contain text
    test_dataset = MotionTextDataset(Config.test_list, normalizer, vocab=vocab, mode="test")

    print("Creating data loaders...")
    train_loader = DataLoader(train_dataset, batch_size=Config.batch_size,
                              shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=Config.batch_size,
                            collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=Config.batch_size,
                             shuffle=False, collate_fn=collate_fn)

    # Initialize model
    print("Initializing model...")
    enc = Encoder(input_dim=66, hidden_dim=Config.hidden_dim,
                  n_layers=Config.enc_layers, dropout=Config.dropout)
    dec = Decoder(output_dim=len(vocab.word2idx),
                  hidden_dim=Config.hidden_dim,
                  n_layers=Config.dec_layers,
                  dropout=Config.dropout)
    model = Seq2Seq(enc, dec, device).to(device)

    # Optimizer and loss
    optimizer = optim.Adam(model.parameters(), lr=Config.lr)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

    # Training loop: keep the checkpoint with the lowest validation loss
    print("Starting training...")
    best_val_loss = float('inf')
    for epoch in range(Config.epochs):
        print(f"\nEpoch: {epoch+1}/{Config.epochs}")

        train_loss = train(model, train_loader, optimizer, criterion, clip=1)
        val_loss, val_bleu = evaluate(model, val_loader, criterion, vocab)

        print(f"\tTrain Loss: {train_loss:.3f}")
        print(f"\tVal Loss: {val_loss:.3f}")
        print(f"\tVal BLEU: {val_bleu:.3f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            print(f"\tSaving best model...")
            torch.save(model.state_dict(), "best_model.pt")

    # Load best model; map_location keeps the load independent of the device
    # the checkpoint was written from.
    model.load_state_dict(torch.load("best_model.pt", map_location=device))

    # Generate submission file
    print("Generating submission file...")
    model.eval()
    eos_idx = vocab.word2idx["<eos>"]

    submission_data = []
    with torch.no_grad():
        for src, file_ids in tqdm(test_loader, desc="Generating predictions"):
            src = src.to(device)

            predicted_tokens = model.inference(src, Config.max_seq_len)

            for tokens, file_id in zip(predicted_tokens, file_ids):
                # The caption ends at the first <eos>; tokens decoded after it
                # must not leak into the submission.
                if eos_idx in tokens:
                    tokens = tokens[:tokens.index(eos_idx)]
                predicted_text = vocab.denumericalize(tokens)
                # remove eos and pad tokens, then collapse the gaps they leave
                predicted_text = predicted_text.replace("<eos>", "").replace("<pad>", "")
                predicted_text = " ".join(predicted_text.split())
                submission_data.append({"id": file_id, "text": predicted_text})

    submission_df = pd.DataFrame(submission_data)
    submission_df.to_csv("./submission.csv", index=False)
    print("Submission file saved to submission.csv")