# Undertale & Deltarune Soundtrack Generator

---

## Table of Contents

0. [**Table of Contents**](#Table-of-Contents)

1. [**Imports**](#Imports)

2. [**Data Processing**](#Data-Processing)

    2.1 [Data Loading](#Data-Loading)
    
    2.2 [Data Preprocessing](#Data-Preprocessing)
    
    2.3 [Dataset & Dataloader Definition](#Dataset-&-Dataloader-Definition)
    
3. [**Model Definition**](#Model-Definition)
    
4. [**Hyperparameters & Instantiation**](#Hyperparameters-&-Instantiation)

5. [**Training**](#Training)

---

## Imports
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Import required packages

In [18]:
import os                                         # File handling
import itertools                                  # chain() for merging lists
import random                                     # Shuffling
import collections                                # Useful tools like Counter, OrderedDict
import math                                       # For... math
from decimal import Decimal                       # Scientific notations in string formatting
from time import time                             # For use in progress bar

import tqdm.auto as tqdm                          # Progress bar

import numpy as np

import torch                                      # Deep Learning Framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt                   # Plotting training progress
from matplotlib.ticker import AutoLocator
%matplotlib inline

fig_bg_color = "lightsteelblue"
plot_bg_color = "slategray"
fontsize = 20

---

## Data Processing
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

### Data Loading
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Read the text files in the target directory.

Do some processing to make sure the texts are clean.

In [2]:
def get_texts(texts_dir):

    if not os.path.isdir(texts_dir):
        raise FileNotFoundError("given text directory not found: {}".format(texts_dir))

    texts = []
    
    for text_path in (file.path for file in os.scandir(texts_dir) if file.is_file() and file.name.endswith(".txt")):
        with open(file=text_path, mode='r', encoding="utf-8") as text_file:
            
            text = text_file.read().strip()

            if not text.replace(' ', '').isdigit():
                raise RuntimeError("one or more characters other than digits and white spaces are detected: {}".format(text_path))

            while "  " in text:
                text = text.replace("  ", ' ')
            
            texts.append((text_path, text))
    
    return dict(texts)


[(os.path.split(text_path)[1], text[:20]) for text_path, text in get_texts("./source/converted_texts").items()]

[('ANOTHER_HIM_-_DeltaRune.txt', '42 46 49 53 0 42 46 '),
 ('A_Town_Called_Hometown_Deltarune_-_Arranged_for_Piano.txt',
  '73 89 0 73 89 0 73 8'),
 ('Basement_Deltarune_-_Arranged_for_Piano.txt', '39 51 0 39 51 0 39 5'),
 ('Before_the_Story_Deltarune_-_Arranged_for_piano_.txt',
  '48 0 48 0 48 0 48 0 '),
 ('Card_Castle_Deltarune_-_Arranged_for_Piano.txt', '39 0 39 0 39 0 39 0 '),
 ('Checker_Dance_Deltarune_-_Arranged_for_Piano.txt', '30 0 30 0 30 0 30 0 '),
 ('Deltarune_-_Beginning.txt', '48 55 0 48 55 0 48 5'),
 ('Deltarune_-_Chaos_King.txt', '27 39 0 27 39 0 27 3'),
 ('Deltarune_-_Darkness_Falls.txt', '61 64 71 75 0 61 64 '),
 ('Deltarune_-_Dont_Forget_Ending_Theme_Solo_Piano_Version.txt',
  '77 0 77 0 77 0 77 0 '),
 ('Deltarune_-_Friendship.txt', '74 0 74 0 74 0 74 0 '),
 ('Deltarune_-_Gallery.txt', '32 36 39 68 0 32 36 '),
 ('Deltarune_-_Lancer_Battle.txt', '62 0 62 0 62 0 0 0 0'),
 ('DELTARUNE_-_Lancer_piano_solo.txt', '0 0 0 62 0 62 0 62 0'),
 ('Deltarune_-_Lantern.txt', '49 0 4

### Data Preprocessing
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Get integers out of the text and make lists of ints.

These lists can be used for the input of the models, or be further processed to compress or simplify the sequences.

In [3]:
def texts_to_intlists(text_list):
    
    intlists = []
    
    for i, text in enumerate(iterable=text_list):
        
        int_strings = text.split(' ')
        
        if not all(int_str.isdigit() for int_str in int_strings):
            raise RuntimeError("non-digit string detected in text {}".format(i))

        ints = [int(int_str) for int_str in int_strings]
        
        intlists.append(ints)
        
    return intlists


print([ints[:10] for ints in texts_to_intlists(get_texts("./source/converted_texts").values())])

[[42, 46, 49, 53, 0, 42, 46, 49, 53, 0], [73, 89, 0, 73, 89, 0, 73, 89, 0, 73], [39, 51, 0, 39, 51, 0, 39, 51, 0, 39], [48, 0, 48, 0, 48, 0, 48, 0, 48, 0], [39, 0, 39, 0, 39, 0, 39, 0, 39, 0], [30, 0, 30, 0, 30, 0, 30, 0, 30, 0], [48, 55, 0, 48, 55, 0, 48, 55, 0, 48], [27, 39, 0, 27, 39, 0, 27, 39, 0, 27], [61, 64, 71, 75, 0, 61, 64, 71, 75, 0], [77, 0, 77, 0, 77, 0, 77, 0, 77, 0], [74, 0, 74, 0, 74, 0, 74, 0, 74, 0], [32, 36, 39, 68, 0, 32, 36, 39, 68, 0], [62, 0, 62, 0, 62, 0, 0, 0, 0, 65], [0, 0, 0, 62, 0, 62, 0, 62, 0, 62], [49, 0, 49, 0, 49, 0, 49, 0, 49, 0], [31, 43, 0, 31, 43, 0, 31, 43, 0, 31], [24, 31, 0, 24, 31, 0, 24, 31, 0, 24], [45, 57, 0, 45, 57, 0, 45, 57, 0, 45], [39, 0, 39, 0, 39, 0, 39, 0, 39, 0], [46, 0, 46, 0, 46, 0, 46, 0, 46, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [58, 70, 0, 58, 70, 0, 58, 70, 0, 58], [37, 49, 0, 37, 49, 0, 37, 49, 0, 37], [44, 68, 0, 44, 68, 0, 44, 68, 0, 44], [67, 0, 67, 0, 67, 0, 67, 0, 67, 0], [61, 0, 61, 0, 61, 0, 61, 0, 61, 0], [49, 0, 49, 0, 

To use "words" as the input and output instead of "characters",

consider '0's as spaces and find all existing words in the texts.

(Here, each word becomes a "token")

We can also tokenize the duration of each word to reduce the

repetition of words that appear several times in a row.

In [4]:
def tokenize(intlists, max_repeat_encoding=0, return_ratios=False):
    assert isinstance(max_repeat_encoding, int) and max_repeat_encoding >= -1 # -1 for no limit
    
    encode_repetition = (max_repeat_encoding != 0)
    if encode_repetition:
        observed_repeats = []
    
    counter = collections.Counter() # Note: repetition tokens are not counted. They are appended to the dictionary later.
    if return_ratios:
        measure = collections.Counter()
    tokenized_lists = []
    
    for intlist in intlists:
        if encode_repetition:
            last_token = None
            n_repeats = 0
        token = []
        tokenized = []
        for int_val in intlist:
            if int_val != 0:
                token.append(int_val)
            else:
                token = tuple(sorted(token))
                
                if encode_repetition:
                    if last_token == token:
                        if n_repeats == max_repeat_encoding:
                            tokenized.append(("<REPEAT>", n_repeats))
                            if return_ratios:
                                measure.update((("<REPEAT>", n_repeats),))
                            n_repeats = 1
                        else:
                            n_repeats += 1
                            if n_repeats not in observed_repeats:
                                observed_repeats.append(n_repeats)
                    else:
                        if n_repeats != 0:
                            tokenized.append(("<REPEAT>", n_repeats))
                            if return_ratios:
                                measure.update((("<REPEAT>", n_repeats),))
                        counter.update((token,))
                        if return_ratios:
                            measure.update((token,))
                        tokenized.append(token)
                        last_token = token
                        n_repeats = 0

                else:
                    counter.update((token,))
                    if return_ratios:
                        measure.update((token,))
                    tokenized.append(token)
                token = []
        tokenized_lists.append(tokenized)
    
    tokens_token_to_idx = collections.OrderedDict((token_key, i) for i, (token_key, _) in enumerate(counter.most_common()))
    if encode_repetition:
        tokens_token_to_idx.update([(("<REPEAT>", r), i) for i, r in enumerate(observed_repeats, len(tokens_token_to_idx))])
    tokens_idx_to_token = collections.OrderedDict((i, token_key) for token_key, i in tokens_token_to_idx.items())
    print(len(tokens_idx_to_token), "tokens")
    
    if return_ratios:
        n_total_tokens = sum(tuple(measure.values()))
        ratios = collections.OrderedDict((token_key, n_occurs / n_total_tokens) for token_key, n_occurs in measure.most_common())
    
    for tokenized in tokenized_lists:
        for i, token_key in enumerate(tokenized):
            tokenized[i] = tokens_token_to_idx[token_key]

    if return_ratios:
        return tokenized_lists, tokens_idx_to_token, ratios
    return tokenized_lists, tokens_idx_to_token


max_repeats = 15

intlists = texts_to_intlists(get_texts("./source/converted_texts").values())
tokenized_lists, tokens_idx_to_token, ratios = tokenize(intlists, max_repeat_encoding=max_repeats, return_ratios=True)
print("\nPart of tokenized sequences:")
print([tokenized_list[:10] for tokenized_list in tokenized_lists[:10]])
print("\nOriginal lengths:")
print([len(intlist) for intlist in intlists[:10]])
print("\nTokenized lengths (with maximum repetition of {}):".format("0 (no repetition tokens)" if max_repeats == 0
                                                                    else "infinity (unlimited length)" if max_repeats == -1
                                                                    else max_repeats))
print([len(tokenized_list) for tokenized_list in tokenized_lists[:10]])
print("\nSome of the most frequent tokens + repetition tokens if used:")
print(list(tokens_idx_to_token.items())[:5] + list(tokens_idx_to_token.items())[-5:])

print("\nTop 50%")
print("=========")
cum_ratio_sum = 0
for token_key, ratio in ratios.items():
    print("{:<20s}{:>15.3f}%".format(str(token_key), ratio * 100))
    cum_ratio_sum += ratio
    if cum_ratio_sum >= 0.5:
        break

7536 tokens

Part of tokenized sequences:
[[764, 7535, 7535, 7535, 7535, 7535, 7535, 7535, 7535, 7535], [2871, 7535, 1336, 7535, 991, 1336, 7533, 262, 7521, 2872], [59, 7534, 0, 59, 7535, 7535, 7535, 7535, 7522, 2888], [25, 7535, 7521, 0, 50, 7535, 7521, 12, 7535, 7535], [23, 7524, 0, 7523, 23, 7524, 0, 7535, 7527, 121], [52, 7529, 0, 52, 7522, 0, 7535, 7535, 7535, 7535], [221, 7524, 0, 7522, 221, 7526, 0, 221, 7525, 0], [78, 7530, 0, 78, 7524, 0, 7525, 78, 7525, 0], [622, 7532, 166, 7526, 622, 7525, 166, 1374, 7528, 166], [16, 7535, 7533, 0, 7521, 26, 7535, 7533, 0, 1009]]

Original lengths:
[9070, 15651, 4602, 13185, 8139, 10348, 7383, 30509, 10821, 6219]

Tokenized lengths (with maximum repetition of 15):
[246, 892, 119, 641, 673, 1451, 719, 4066, 684, 307]

Some of the most frequent tokens + repetition tokens if used:
[(0, ()), (1, (44,)), (2, (69,)), (3, (66,)), (4, (63,)), (7531, ('<REPEAT>', 11)), (7532, ('<REPEAT>', 12)), (7533, ('<REPEAT>', 13)), (7534, ('<REPEAT>', 14)), (753

As you try out different *'max_repeat_encoding'* values \[0, 5, 10, 20, -1\] (-1 for unlimited repetition length)

you should observe great reductions in sequence lengths when using repetition encodings.

### Dataset & Dataloader Definition
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Create a Dataset class from which training data can be sampled.

This Dataset should convert the encoded sequence above into tensors

and have a method for shuffling the order of multiple sequences while

leaving the patterns inside of each sequence untouched.

In [5]:
class UndertaleDeltaruneDataset(Dataset):
    def __init__(self, texts_dir, batch_size=1, max_repeats=15):
        self.texts = get_texts(texts_dir) # read and get a dictionary of {file_paths: text_contents}
        self.sequences, self.tokens = tokenize(texts_to_intlists((self.texts.values())), max_repeat_encoding=max_repeats) # convert and tokenize

        self.texts_dir = texts_dir
        self.batch_size = batch_size

    def __len__(self):
        return self.batch_size

    def data_len(self):
        return sum([len(sequence) for sequence in self.sequences])
    
    def update_tokens(self, new_tokens):
        if len(self.tokens) != len(new_tokens):
            raise ValueError("Token dictionary sizes mismatch - old: {} | new: {}".format(len(self.tokens), len(new_tokens)))
        
        new_token_value_to_idx = collections.OrderedDict((v, k) for k, v in new_tokens.items())
        
        for old_key, old_val in self.tokens.items():
            if old_val not in new_token_value_to_idx:
                raise ValueError("Old [{}] value {} not found in the new tokens".format(old_key, old_val))
        
        self.sequences = [[new_token_value_to_idx[self.tokens[old_key]] for old_key in sequence]
                          for sequence in self.sequences]
        self.tokens    = new_tokens

    def __getitem__(self, index):
        shuffled_list = list(itertools.chain(*random.sample(self.sequences, len(self.sequences))))
        return torch.LongTensor(shuffled_list[:-1]), torch.LongTensor(shuffled_list[1:])

Create a custom class that loads the data from the dataset above and

allows iteration over the dataset, yielding a small sequence batch at a time.

In [6]:
class UDBatchLoader:
    def __init__(self, ud_dataset, batch_size, sequence_len, drop_last=False, batch_first=True):
        self.ud_dataset = ud_dataset
        self.batch_size = batch_size
        self.sequence_len = sequence_len
        self.drop_last = drop_last
        self.batch_first = batch_first
    
    def __len__(self):
        if self.drop_last:
            return math.floor((self.ud_dataset.data_len() - 1) / self.sequence_len)
        return math.ceil((self.ud_dataset.data_len() - 1) / self.sequence_len)
    
    def generator(self):
        seq_len = self.sequence_len
        n_seq_batches = self.__len__()
        batch_first = self.batch_first
        
        input_batch, target_batch = next(iter(DataLoader(self.ud_dataset, self.batch_size)))
        if not batch_first:
            input_batch = input_batch.transpose(0, 1).contiguous()
            target_batch = target_batch.transpose(0, 1).contiguous()
        
        for start, end in zip(range(0, seq_len * n_seq_batches, seq_len), range(seq_len, (seq_len + 1) * n_seq_batches, seq_len)):
            if batch_first:
                yield (input_batch[:, start:end].contiguous(), target_batch[:, start:end].contiguous())
            else:
                yield (input_batch[start:end], target_batch[start:end])
    
    def __iter__(self):
        return self.generator()

---

## Model Definition
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Define the model architectures.

### Generator

In [7]:
class Generator(nn.Module):
    def __init__(self, n_tokens):
        super(Generator, self).__init__()

        self.init_hidden0 = nn.Parameter(torch.randn(1, 1, 256))
        self.init_hidden1 = nn.Parameter(torch.randn(1, 1, 256))
        self.init_hidden2 = nn.Parameter(torch.randn(1, 1, 256))

        self.embed = nn.Embedding(num_embeddings=n_tokens, embedding_dim=128)

        self.norm0 = nn.LayerNorm(128)
        self.gru0  = nn.GRU(input_size=128, hidden_size=256, batch_first=True)

        self.norm1 = nn.LayerNorm(256)
        self.gru1  = nn.GRU(input_size=256, hidden_size=256, batch_first=True)

        self.norm2 = nn.LayerNorm(256)
        self.gru2  = nn.GRU(input_size=256, hidden_size=256, batch_first=True)

        self.fc0 = nn.Sequential(
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=512)
        )

        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(in_features=512, out_features=256)
        )

        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=n_tokens)
        )

    def forward(self, x, hiddens=None, return_all_hiddens=False):
        if hiddens is None:
            hiddens = self.get_init_hiddens(x.size(0))

        if return_all_hiddens:
            internal_hiddens = []

        x = self.embed(x)

        x, new_hidden0 = self.gru0(self.norm0(x), hiddens[0])
        if return_all_hiddens:
            internal_hiddens.append(x)
        x, new_hidden1 = self.gru1(self.norm1(x), hiddens[1])
        if return_all_hiddens:
            internal_hiddens.append(x)
        x, new_hidden2 = self.gru2(self.norm2(x), hiddens[2])
        if return_all_hiddens:
            internal_hiddens.append(x)

        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)

        if return_all_hiddens:
            return x, [new_hidden0, new_hidden1, new_hidden2], internal_hiddens
        return x, [new_hidden0, new_hidden1, new_hidden2]

    def get_init_hiddens(self, n_batches):
        return [self.init_hidden0.repeat(1, n_batches, 1),
                self.init_hidden1.repeat(1, n_batches, 1),
                self.init_hidden2.repeat(1, n_batches, 1)]

### Discriminator

In [8]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.init_hidden0 = nn.Parameter(torch.randn(2, 1, 128))
        self.init_hidden1 = nn.Parameter(torch.randn(2, 1, 128))
        self.init_hidden2 = nn.Parameter(torch.randn(2, 1, 128))

        self.norm0 = nn.LayerNorm(768)
        self.gru0  = nn.GRU(input_size=768, hidden_size=128, batch_first=True, bidirectional=True)

        self.norm1 = nn.LayerNorm(256)
        self.gru1  = nn.GRU(input_size=256, hidden_size=128, batch_first=True, bidirectional=True)

        self.norm2 = nn.LayerNorm(256)
        self.gru2  = nn.GRU(input_size=256, hidden_size=128, batch_first=True, bidirectional=True)

        self.fc0 = nn.Sequential(
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=512)
        )

        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(512),
            nn.Linear(in_features=512, out_features=256)
        )

        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=1)
        )

    def forward(self, x, hiddens=None):
        if hiddens is None:
            hiddens = self.get_init_hiddens(x.size(0))

        x, new_hidden0 = self.gru0(self.norm0(x), hiddens[0])
        x, new_hidden1 = self.gru1(self.norm1(x), hiddens[1])
        x, new_hidden2 = self.gru2(self.norm2(x), hiddens[2])

        ### Consider only the final outputs from both directions by:
        ### 1) separating the directions,
        ### 2) taking the last timestep's output for forward direction([:, -1, 0, :])
        ###    and the first timestep's output for backward direction([:, 0, 1, :]),
        ### 3) then fianlly resizing the selected output into [n_batch, features] form.
        x = x.view(*x.shape[:-1], 2, -1)[:, [-1, 0], [0, 1], :].view(x.size(0), -1)

        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)

        return x, [new_hidden0, new_hidden1, new_hidden2]

    def get_init_hiddens(self, n_batches):
        return [self.init_hidden0.repeat(1, n_batches, 1),
                self.init_hidden1.repeat(1, n_batches, 1),
                self.init_hidden2.repeat(1, n_batches, 1)]

---

## Training
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [9]:
def free_running_generation(generator, inputs, seq_len, return_all_hiddens=False):
    """
    Params
    ======
    generator (Generator): the generator model.
    inputs (LongTensor): 2D tensor with dimensions [n_batches, 1].
                         If the given dimensions are [n_batches, seq_len],
                         then only the first timestep is used.
    seq_len (int): length of sequence to generate.
    return_all_hiddens (bool, optional): whether to retrieve the internal hidden states
                                         for all timesteps from every RNN layer. (default: False)

    Returns
    =======
    output_sequence (LongTensor): tensor of generated outputs.
    hiddens (list[Tensor]): updated hidden states.
    internal_hiddens (list[Tensor]): list of tensors containing internal hidden states
                                     for all timesteps from every RNN layer. Returned only
                                     if `return_all_hiddens` is True.
    """
    output_sequence = []
    if return_all_hiddens:
        internal_hiddens = []

    hiddens = generator.get_init_hiddens(inputs.size(0))
    x = inputs[:, :1]

    for i in range(seq_len):
        if return_all_hiddens:
            y, hiddens, current_internal_hiddens = generator(x, hiddens, return_all_hiddens=True)
            internal_hiddens.append(current_internal_hiddens)
        else:
            y, hiddens = generator(x, hiddens, return_all_hiddens=False)
        y = torch.multinomial(y.squeeze(1).softmax(dim=-1), num_samples=1)
        output_sequence.append(y)
        x = y
    output_sequence = torch.cat(output_sequence, dim=1)
    
    if return_all_hiddens:
        internal_hiddens = [torch.cat(layer_hiddens, dim=1) for layer_hiddens in zip(*internal_hiddens)]
        return output_sequence, hiddens, internal_hiddens
    return output_sequence, hiddens

In [10]:
def generator_loss(discriminator, output_teacher, target_teacher, b_seq_teacher, b_seq_free,
                   teacher_forcing_loss=True, free_teacher_behavior_loss=True, teacher_free_behavior_loss=True):
    """
    Params
    ======
    discriminator (Discriminator): the Discriminator model.
    output_teacher (Tensor): output of the Generator during teacher-forcing mode.
    target_teacher (Tensor): target for the Generator during teacher-forcing mode.
    b_seq_teacher (Tensor): behavior sequence from Generator during teacher-forcing mode.
                            Must have dimensions [n_batches, seq_len, behavior_size].
    b_seq_free (Tensor): behavior sequence from Generator during free-running mode.
                         Assumed to have the same dimensionality as `b_seq_teacher`.
    teacher_forcing_loss (bool): whether to calculate and return the loss for teacher-forced training. (default: True)
    free_teacher_behavior_loss (bool): whether to calculate and return the loss for matching the
                                       free-running behavior to the teacher-forced behavior.(default: True)
    teacher_free_behavior_loss (bool): whether to calculate and return the loss for matching the
                                       teacher-forced behavior to the free-running behavior.(default: True)

    Returns
    =======
    teacher_forcing_loss (Tensor): loss for maximizing the likelihood of the data
                                   during teacher-forcing mode. Returned if
                                   `teacher_forcing_loss` is True.
    free_teacher_behavior_loss (Tensor): loss for changing the free-running behavior so that
                                         it better matches the teacher-forced behavior,
                                         considering the latter fixed. Returned if
                                         `free_teacher_behavior_loss` is True.
    teacher_free_behavior_loss (Tensor): loss for making the teacher-forced behavior
                                         indistinguishable from the free-running behavior.
                                         Returned if `teacher_free_behavior_loss` is True.
    """
    losses = []
    
    if teacher_forcing_loss:
        losses.append(F.cross_entropy(output_teacher.view(-1, output_teacher.size(-1)), target_teacher.view(-1)))

    if free_teacher_behavior_loss:
        hiddens = discriminator.get_init_hiddens(n_batches=b_seq_free.size(0))
        raw_preds, _ = discriminator(b_seq_free, hiddens)
        losses.append(F.binary_cross_entropy_with_logits(raw_preds, torch.ones_like(raw_preds)))

    if teacher_free_behavior_loss:
        hiddens = discriminator.get_init_hiddens(n_batches=b_seq_teacher.size(0))
        raw_preds, _ = discriminator(b_seq_teacher, hiddens)
        losses.append(F.binary_cross_entropy_with_logits(raw_preds, torch.zeros_like(raw_preds)))

    return losses

In [11]:
def discriminator_loss(discriminator, b_seq_teacher, b_seq_free, return_acc=False):
    """
    Params
    ======
    discriminator (Discriminator): the Discriminator model.
    b_seq_teacher (Tensor): behavior sequence from Generator during teacher-forcing mode.
                            Must have dimensions [n_batches, seq_len, behavior_size].
    b_seq_free (Tensor): behavior sequence from Generator during free-running mode.
                         Assumed to have the same dimensionality as `b_seq_teacher`.
    return_acc (bool): whether to return the Discriminator accuracy or not. (default: False)

    Returns
    =======
    discriminator_loss (Tensor): the Discriminator loss as a scalar tensor.
    discriminator_acc (float): accuracy of the Discriminator for the given data.
                               Returned only if `return_acc` is True.
    """
    ### Ensure that gradients will flow only through the discriminator and not the generator
    b_seq_teacher = b_seq_teacher.detach()
    b_seq_free    = b_seq_free.detach()

    inputs  = torch.cat([b_seq_teacher, b_seq_free], dim=0)
    targets = torch.cat([b_seq_teacher.new_ones(b_seq_teacher.size(0), 1),
                         b_seq_free.new_zeros(b_seq_free.size(0), 1)], dim=0)

    hiddens = discriminator.get_init_hiddens(n_batches=inputs.size(0))
    raw_preds, _ = discriminator(inputs, hiddens)

    if return_acc:
        preds = raw_preds.sigmoid()

        discriminator_loss = F.binary_cross_entropy(preds, targets, reduction='none')
        discriminator_loss = discriminator_loss[:b_seq_teacher.size(0)].mean() + discriminator_loss[b_seq_teacher.size(0):].mean() ### Calculate separate average losses
        discriminator_acc  = (preds.gt(0.5) == targets.byte()).float().mean().item()

        return discriminator_loss, discriminator_acc

    else:
        discriminator_loss = F.binary_cross_entropy_with_logits(raw_preds, targets, reduction='none')
        discriminator_loss = discriminator_loss[:b_seq_teacher.size(0)].mean() + discriminator_loss[b_seq_teacher.size(0):].mean() ### Calculate separate average losses

        return discriminator_loss

In [12]:
seed                   = 0
n_epochs               = 1000
batch_size             = 32
sequence_length        = 500

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(seed)
torch.manual_seed(seed)

ud_dataset = UndertaleDeltaruneDataset("./source/converted_texts", batch_size)
ud_loader = UDBatchLoader(ud_dataset, batch_size, sequence_length, drop_last=False, batch_first=True)

generator     = Generator(n_tokens=len(ud_dataset.tokens)).to(device)
discriminator = Discriminator().to(device)

generator_optimizer     = optim.Adam(generator.parameters(), lr=5e-5)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)

history = {}

print()
print('Data Sequence Total Length:', ud_dataset.data_len())
print()
print(generator)
print()
print(discriminator)

7536 tokens

Data Sequence Total Length: 88709

Generator(
  (embed): Embedding(7536, 128)
  (norm0): LayerNorm(torch.Size([128]), eps=1e-05, elementwise_affine=True)
  (gru0): GRU(128, 256, batch_first=True)
  (norm1): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
  (gru1): GRU(256, 256, batch_first=True)
  (norm2): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
  (gru2): GRU(256, 256, batch_first=True)
  (fc0): Sequential(
    (0): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=256, out_features=512, bias=True)
  )
  (fc1): Sequential(
    (0): ReLU()
    (1): LayerNorm(torch.Size([512]), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=512, out_features=256, bias=True)
  )
  (fc2): Sequential(
    (0): ReLU()
    (1): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=256, out_features=7536, bias=True)
  )
)

Discriminator(
  (norm0): LayerNorm(

In [None]:
teacher_forcing_always_on = True

i = 0
while True:
    for inputs, targets in ud_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        teacher_forcing_outputs, _, teacher_internal_hiddens = generator(inputs, return_all_hiddens=True)
        generated_sequence,      _, free_internal_hiddens    = free_running_generation(generator, inputs, inputs.size(1), return_all_hiddens=True)

        b_seq_teacher = torch.cat(teacher_internal_hiddens, dim=-1)
        b_seq_free    = torch.cat(free_internal_hiddens,    dim=-1)

        D_loss, D_acc = discriminator_loss(discriminator, b_seq_teacher, b_seq_free, return_acc=True)

        if D_acc < 0.99: # If too good, don't train the Discriminator
            discriminator_optimizer.zero_grad()
            D_loss.backward()
            discriminator_optimizer.step()
        D_loss = D_loss.item()

        if D_acc > 0.75: # If too bad, don't use the Discriminator to train the Generator
            t_force_loss, ft_b_loss, tf_b_loss = generator_loss(discriminator, teacher_forcing_outputs, targets,
                                                                b_seq_teacher, b_seq_free,
                                                                teacher_forcing_loss=True,
                                                                free_teacher_behavior_loss=True,
                                                                teacher_free_behavior_loss=True)

            G_loss = (t_force_loss + ft_b_loss + tf_b_loss) / 3

            generator_optimizer.zero_grad()
            G_loss.backward()
            generator_optimizer.step()

            G_loss       = G_loss.item()
            t_force_loss = t_force_loss.item()
            ft_b_loss    = ft_b_loss.item()
            tf_b_loss    = tf_b_loss.item()
        else:
            if teacher_forcing_always_on:
                t_force_loss,                      = generator_loss(discriminator, teacher_forcing_outputs, targets,
                                                                    b_seq_teacher, b_seq_free,
                                                                    teacher_forcing_loss=True,
                                                                    free_teacher_behavior_loss=False,
                                                                    teacher_free_behavior_loss=False)

                G_loss = t_force_loss

                generator_optimizer.zero_grad()
                G_loss.backward()
                generator_optimizer.step()

                G_loss       = G_loss.item()
                t_force_loss = t_force_loss.item()
            else:
                G_loss       = float('nan')
                t_force_loss = float('nan')
            ft_b_loss = float('nan')
            tf_b_loss = float('nan')

        G_acc = (teacher_forcing_outputs.argmax(dim=-1) == targets).float().mean().item()

        print("[Update {}]".format(i))
        print("Discriminator")
        print("=============")
        print("    Loss:", D_loss)
        print("    Acc:", D_acc)
        print("Generator")
        print("=============")
        print("    Loss:", G_loss)
        print("        Teacher-Force:", t_force_loss)
        print("        FreeToTeacher:", ft_b_loss)
        print("        TeacherToFree:", tf_b_loss)
        print("    Acc:", G_acc)
        print("####################################")
        print()
        i += 1

        if i % 100 == 0:
            torch.save(generated_sequence.cpu().numpy(), "professor_forcing_temp/{}.pth".format(i))

In [14]:
torch.save({'g': generator.state_dict(), 'd': discriminator.state_dict()}, 'professor_forcing_temp/GD_30000.pth')

In [15]:
def tokens_to_intlist(tokens, token_to_key_dict):
    if len(tokens) == 0:
        return []

    patterns = []
    intlist = []
    last_pattern = tuple()

    for token in tokens:
        pattern = token_to_key_dict[token]
        if pattern != tuple() and pattern[0] == "<REPEAT>":
            for _ in range(pattern[1]):
                patterns.append(last_pattern)
        else:
            patterns.append(pattern)
            last_pattern = pattern

    intlist.extend(patterns[0])
    for pattern in patterns[1:]:
        intlist.append(0)
        intlist.extend(pattern)
    
    return intlist

In [19]:
generated_sequence = tokens_to_intlist(free_running_generation(generator, inputs[:1], 50000, return_all_hiddens=False)[0][0].cpu().numpy(), ud_dataset.tokens)

In [20]:
torch.save(np.array(generated_sequence), "professor_forcing_temp/generated_sequence_chord.pth")

---