# Undertale & Deltarune Soundtrack Generator

---

## Table of Contents

0. [**Table of Contents**](#Table-of-Contents)

1. [**Imports**](#Imports)

2. [**Data Processing**](#Data-Processing)

    2.1 [Data Loading](#Data-Loading)
    
    2.2 [Data Preprocessing](#Data-Preprocessing)
    
    2.3 [Dataset & Dataloader Definition](#Dataset-&-Dataloader-Definition)
    
3. [**Model Definition**](#Model-Definition)
    
4. [**Hyperparameters & Instantiation**](#Hyperparameters-&-Instantiation)

5. [**Training**](#Training)
    
    5.1 [Training Function](#Training-Function)
    
    5.2 [Training Session](#Training-Session)

6. [**Loading the Best Model**](#Loading-the-Best-Model)

7. [**Generation**](#Generation)
    
    7.1 [Sampling Function](#Sampling-Function)

    7.2 [Generation Function](#Generation-Function)
    
    7.3 [Music Generation](#Music-Generation)

8. [**Final Summary, Notes, and Thoughts**](#Final-Summary,-Notes,-and-Thoughts)

---

## Imports
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Import the required packages.

In [1]:
import os                                         # File handling
import itertools                                  # chain() for merging lists
import random                                     # Shuffling
import collections                                # Useful tools like Counter, OrderedDict
import math                                       # For... math
from decimal import Decimal                       # Scientific notations in string formatting
from time import time                             # For use in progress bar

import tqdm.auto as tqdm                          # Progress bar

import torch                                      # Deep Learning Framework
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt                   # Plotting training progress
from matplotlib.ticker import AutoLocator
%matplotlib inline

fig_bg_color = "lightsteelblue"
plot_bg_color = "slategray"
fontsize = 20

---

## Data Processing
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

### Data Loading
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Read every `.txt` file in the target directory.

Validate and clean each text: only digits and spaces are allowed, and runs of consecutive spaces are collapsed into a single space.

In [2]:
def get_texts(texts_dir):

    if not os.path.isdir(texts_dir):
        raise FileNotFoundError("given text directory not found: {}".format(texts_dir))

    texts = []
    
    for text_path in (file.path for file in os.scandir(texts_dir) if file.is_file() and file.name.endswith(".txt")):
        with open(file=text_path, mode='r', encoding="utf-8") as text_file:
            
            text = text_file.read().strip()

            if not text.replace(' ', '').isdigit():
                raise RuntimeError("one or more characters other than digits and white spaces are detected: {}".format(text_path))

            while "  " in text:
                text = text.replace("  ", ' ')
            
            texts.append((text_path, text))
    
    return dict(texts)


[(os.path.split(text_path)[1], text[:20]) for text_path, text in get_texts("./source/converted_texts").items()]

[('ANOTHER_HIM_-_DeltaRune.txt', '42 46 49 53 0 42 46 '),
 ('A_Town_Called_Hometown_Deltarune_-_Arranged_for_Piano.txt',
  '73 89 0 73 89 0 73 8'),
 ('Basement_Deltarune_-_Arranged_for_Piano.txt', '39 51 0 39 51 0 39 5'),
 ('Before_the_Story_Deltarune_-_Arranged_for_piano_.txt',
  '48 0 48 0 48 0 48 0 '),
 ('Card_Castle_Deltarune_-_Arranged_for_Piano.txt', '39 0 39 0 39 0 39 0 '),
 ('Checker_Dance_Deltarune_-_Arranged_for_Piano.txt', '30 0 30 0 30 0 30 0 '),
 ('Deltarune_-_Beginning.txt', '48 55 0 48 55 0 48 5'),
 ('Deltarune_-_Chaos_King.txt', '27 39 0 27 39 0 27 3'),
 ('Deltarune_-_Darkness_Falls.txt', '61 64 71 75 0 61 64 '),
 ('Deltarune_-_Dont_Forget_Ending_Theme_Solo_Piano_Version.txt',
  '77 0 77 0 77 0 77 0 '),
 ('Deltarune_-_Friendship.txt', '74 0 74 0 74 0 74 0 '),
 ('Deltarune_-_Gallery.txt', '32 36 39 68 0 32 36 '),
 ('Deltarune_-_Lancer_Battle.txt', '62 0 62 0 62 0 0 0 0'),
 ('DELTARUNE_-_Lancer_piano_solo.txt', '0 0 0 62 0 62 0 62 0'),
 ('Deltarune_-_Lantern.txt', '49 0 4
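The cleaning rules in `get_texts` can be checked in isolation. The sketch below uses a hypothetical `clean_text` helper (not part of the notebook) that applies the same rules to a single string. It assumes `" ".join(text.split())` as a one-pass equivalent of the `while "  " in text` loop, which holds once the digits-and-spaces check has passed.

```python
def clean_text(text):
    """Hypothetical helper mirroring the rules in get_texts() for one string."""
    text = text.strip()
    # Reject anything other than digits once the spaces are removed (same check as get_texts)
    if not text.replace(" ", "").isdigit():
        raise RuntimeError("characters other than digits and spaces detected")
    # After the check the text holds only digits and spaces, so split() + join()
    # collapses every run of spaces into one, like the while loop in get_texts
    return " ".join(text.split())


print(clean_text("  42  46   0 49\n"))   # 42 46 0 49
try:
    clean_text("42 a 0")
except RuntimeError as err:
    print("rejected:", err)
```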

### Data Preprocessing
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Convert each text into a list of integers, one integer per space-separated token.

These lists can be fed to the models directly, or processed further to compress or simplify the sequences.

In this notebook, I leave the data as-is and model it note by note (similar to a character-by-character language model).

In [3]:
def texts_to_intlists(text_list):
    
    intlists = []
    
    for i, text in enumerate(iterable=text_list):
        
        int_strings = text.split(' ')
        
        if not all(int_str.isdigit() for int_str in int_strings):
            raise RuntimeError("non-digit string detected in text {}".format(i))

        ints = [int(int_str) for int_str in int_strings]
        
        intlists.append(ints)
        
    return intlists


print([ints[:10] for ints in texts_to_intlists(get_texts("./source/converted_texts").values())])

[[42, 46, 49, 53, 0, 42, 46, 49, 53, 0], [73, 89, 0, 73, 89, 0, 73, 89, 0, 73], [39, 51, 0, 39, 51, 0, 39, 51, 0, 39], [48, 0, 48, 0, 48, 0, 48, 0, 48, 0], [39, 0, 39, 0, 39, 0, 39, 0, 39, 0], [30, 0, 30, 0, 30, 0, 30, 0, 30, 0], [48, 55, 0, 48, 55, 0, 48, 55, 0, 48], [27, 39, 0, 27, 39, 0, 27, 39, 0, 27], [61, 64, 71, 75, 0, 61, 64, 71, 75, 0], [77, 0, 77, 0, 77, 0, 77, 0, 77, 0], [74, 0, 74, 0, 74, 0, 74, 0, 74, 0], [32, 36, 39, 68, 0, 32, 36, 39, 68, 0], [62, 0, 62, 0, 62, 0, 0, 0, 0, 65], [0, 0, 0, 62, 0, 62, 0, 62, 0, 62], [49, 0, 49, 0, 49, 0, 49, 0, 49, 0], [31, 43, 0, 31, 43, 0, 31, 43, 0, 31], [24, 31, 0, 24, 31, 0, 24, 31, 0, 24], [45, 57, 0, 45, 57, 0, 45, 57, 0, 45], [39, 0, 39, 0, 39, 0, 39, 0, 39, 0], [46, 0, 46, 0, 46, 0, 46, 0, 46, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [58, 70, 0, 58, 70, 0, 58, 70, 0, 58], [37, 49, 0, 37, 49, 0, 37, 49, 0, 37], [44, 68, 0, 44, 68, 0, 44, 68, 0, 44], [67, 0, 67, 0, 67, 0, 67, 0, 67, 0], [61, 0, 61, 0, 61, 0, 61, 0, 61, 0], [49, 0, 49, 0, 
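Because each integer later becomes an index into the Generator's `nn.Embedding(num_embeddings=129, ...)`, every token must lie in `[0, 128]`. A minimal check, written as a hypothetical `vocab_report` helper and run here on toy lists (it is not part of the original notebook), could look like this:

```python
import collections


def vocab_report(intlists, num_embeddings=129):
    """Hypothetical helper: count token frequencies and flag out-of-range tokens."""
    counts = collections.Counter(token for ints in intlists for token in ints)
    # nn.Embedding(num_embeddings, ...) only accepts indices 0 .. num_embeddings - 1
    out_of_range = sorted(t for t in counts if not 0 <= t < num_embeddings)
    return counts, out_of_range


toy_intlists = [[42, 46, 49, 53, 0, 42, 46], [73, 89, 0, 73, 89, 0]]
counts, bad = vocab_report(toy_intlists)
print(counts.most_common(3))
print("out-of-range tokens:", bad)
```

On the real data, the same helper could be called as `vocab_report(texts_to_intlists(get_texts("./source/converted_texts").values()))`; an empty out-of-range list means the 129-entry embedding covers every note.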

### Dataset & Dataloader Definition
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Create a Dataset class from which training data can be sampled.

This Dataset converts the encoded sequences above into tensors, and it shuffles the order of the sequences (songs) while leaving the patterns inside each sequence untouched.

In [4]:
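class UndertaleDeltaruneDataset(Dataset):
    def __init__(self, texts_dir, batch_size=1):
        self.texts = get_texts(texts_dir)                        # {file_path: text_content}
        self.sequences = texts_to_intlists(self.texts.values())  # one list of note tokens per song

        self.texts_dir = texts_dir
        self.batch_size = batch_size

    def __len__(self):
        # Every item is an independent shuffle of the whole corpus, so the "length" is
        # simply how many shuffles one DataLoader batch should contain
        return self.batch_size

    def data_len(self):
        # Total number of tokens across all songs
        return sum(len(sequence) for sequence in self.sequences)

    def __getitem__(self, index):
        # Shuffle the order of the songs (not the notes inside them) and concatenate them
        shuffled_list = list(itertools.chain(*random.sample(self.sequences, len(self.sequences))))
        # Next-token prediction: the label at step t is the input at step t + 1
        inputs = torch.LongTensor(shuffled_list[:-1])
        labels = torch.LongTensor(shuffled_list[1:])
        return inputs, labels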
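To see what a single item of `UndertaleDeltaruneDataset` looks like, the same shuffle, concatenate, and shift logic can be run on toy "songs". This standalone sketch uses made-up data and `torch.tensor(..., dtype=torch.long)` in place of the legacy `torch.LongTensor` constructor. It shows that the notes inside each song keep their order and that the labels are the inputs shifted by one step.

```python
import itertools
import random

import torch

songs = [[1, 2, 3], [10, 20], [7, 8, 9, 0]]   # toy stand-ins for the real note sequences


def make_item(sequences):
    # Shuffle the order of whole songs, then concatenate them into one long sequence
    shuffled = list(itertools.chain(*random.sample(sequences, len(sequences))))
    inputs = torch.tensor(shuffled[:-1], dtype=torch.long)
    labels = torch.tensor(shuffled[1:], dtype=torch.long)   # next-note targets
    return inputs, labels


random.seed(0)
inputs, labels = make_item(songs)
print(inputs.tolist())
print(labels.tolist())
```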

Create a custom loader class that draws data from the dataset above and iterates over it, yielding one short window of the sequence batch at a time.

In [5]:
class UDBatchLoader:
    def __init__(self, ud_dataset, batch_size, sequence_len, drop_last=False, batch_first=True):
        self.ud_dataset = ud_dataset
        self.batch_size = batch_size
        self.sequence_len = sequence_len
        self.drop_last = drop_last
        self.batch_first = batch_first
    
    def __len__(self):
        if self.drop_last:
            return math.floor((self.ud_dataset.data_len() - 1) / self.sequence_len)
        return math.ceil((self.ud_dataset.data_len() - 1) / self.sequence_len)
    
    def generator(self):
        seq_len = self.sequence_len
        n_seq_batches = self.__len__()
        batch_first = self.batch_first

        # One DataLoader batch = `batch_size` independently shuffled copies of the corpus,
        # stacked into tensors of shape [batch_size, data_len - 1]
        input_batch, target_batch = next(iter(DataLoader(self.ud_dataset, self.batch_size)))
        if not batch_first:
            input_batch = input_batch.transpose(0, 1).contiguous()
            target_batch = target_batch.transpose(0, 1).contiguous()

        # Cut the long sequences into consecutive windows of `seq_len` timesteps;
        # with drop_last=False the final window may be shorter than `seq_len`
        for start in range(0, seq_len * n_seq_batches, seq_len):
            end = start + seq_len
            if batch_first:
                yield (input_batch[:, start:end].contiguous(), target_batch[:, start:end].contiguous())
            else:
                yield (input_batch[start:end], target_batch[start:end])
    
    def __iter__(self):
        return self.generator()
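The windowing done by `UDBatchLoader.generator` can be illustrated with a small tensor. The sketch below uses toy sizes and a hypothetical `windows` helper: it cuts a `[batch, T]` tensor into consecutive `[batch, seq_len]` chunks and shows how `drop_last` changes the count, mirroring the `ceil`/`floor` logic in `__len__`.

```python
import math

import torch


def windows(batch, seq_len, drop_last=False):
    """Hypothetical helper: split a [batch, T] tensor into [batch, seq_len] windows along time."""
    total = batch.size(1)
    n = math.floor(total / seq_len) if drop_last else math.ceil(total / seq_len)
    return [batch[:, start:start + seq_len] for start in range(0, seq_len * n, seq_len)]


x = torch.arange(2 * 10).view(2, 10)   # batch of 2 sequences, 10 timesteps each
print([w.shape[1] for w in windows(x, 4)])                  # [4, 4, 2]
print([w.shape[1] for w in windows(x, 4, drop_last=True)])  # [4, 4]
```

With the real corpus (893,324 notes, i.e. 893,323 input/target steps per shuffle) and `sequence_length = 1200`, the same arithmetic gives `math.ceil(893323 / 1200) = 745` windows per pass over the loader.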

---

## Model Definition
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

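Define the model architectures: a GRU-based Generator that predicts the next note, and a bidirectional GRU Discriminator that classifies the Generator's hidden-state ("behavior") sequences as teacher-forced or free-running.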

### Generator

In [6]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.init_hidden0 = nn.Parameter(torch.randn(1, 1, 64))
        self.init_hidden1 = nn.Parameter(torch.randn(1, 1, 64))
        self.init_hidden2 = nn.Parameter(torch.randn(1, 1, 64))

        self.embed = nn.Embedding(num_embeddings=129, embedding_dim=64)

        self.norm0 = nn.LayerNorm(64)
        self.gru0  = nn.GRU(input_size=64, hidden_size=64, batch_first=True)

        self.norm1 = nn.LayerNorm(64)
        self.gru1  = nn.GRU(input_size=64, hidden_size=64, batch_first=True)

        self.norm2 = nn.LayerNorm(64)
        self.gru2  = nn.GRU(input_size=64, hidden_size=64, batch_first=True)

        self.fc0 = nn.Sequential(
            nn.LayerNorm(64),
            nn.Linear(in_features=64, out_features=128)
        )

        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(in_features=128, out_features=256)
        )

        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=129)
        )

    def forward(self, x, hiddens=None, return_all_hiddens=False):
        if hiddens is None:
            hiddens = self.get_init_hiddens(x.size(0))

        if return_all_hiddens:
            internal_hiddens = []

        x = self.embed(x)

        x, new_hidden0 = self.gru0(self.norm0(x), hiddens[0])
        if return_all_hiddens:
            internal_hiddens.append(x)
        x, new_hidden1 = self.gru1(self.norm1(x), hiddens[1])
        if return_all_hiddens:
            internal_hiddens.append(x)
        x, new_hidden2 = self.gru2(self.norm2(x), hiddens[2])
        if return_all_hiddens:
            internal_hiddens.append(x)

        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)

        if return_all_hiddens:
            return x, [new_hidden0, new_hidden1, new_hidden2], internal_hiddens
        return x, [new_hidden0, new_hidden1, new_hidden2]

    def get_init_hiddens(self, n_batches):
        return [self.init_hidden0.repeat(1, n_batches, 1),
                self.init_hidden1.repeat(1, n_batches, 1),
                self.init_hidden2.repeat(1, n_batches, 1)]
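The Generator learns its initial hidden states (`init_hidden0` to `init_hidden2`) instead of starting from zeros, and `get_init_hiddens` tiles them across the batch with `repeat(1, n_batches, 1)`. A stripped-down, single-layer sketch of that pattern follows; the `LearnedInitGRU` name and its sizes are illustrative only. It shows that the shared parameter still receives gradients after being tiled.

```python
import torch
import torch.nn as nn


class LearnedInitGRU(nn.Module):
    """Illustrative single-layer version of the learned-initial-state pattern."""

    def __init__(self, vocab_size=129, hidden_size=64):
        super().__init__()
        self.init_hidden = nn.Parameter(torch.randn(1, 1, hidden_size))  # [layers, 1, hidden]
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True)
        self.out = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        # Tile the learned state over the batch: [1, 1, H] -> [1, n_batches, H]
        # (h0 keeps the [layers, batch, H] layout even with batch_first=True)
        h0 = self.init_hidden.repeat(1, x.size(0), 1)
        y, h = self.gru(self.embed(x), h0)
        return self.out(y), h


model = LearnedInitGRU()
tokens = torch.randint(0, 129, (3, 5))       # [n_batches, seq_len]
logits, h = model(tokens)
logits.sum().backward()
print(logits.shape, h.shape)                 # torch.Size([3, 5, 129]) torch.Size([1, 3, 64])
print(model.init_hidden.grad is not None)    # True
```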

### Discriminator

In [7]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()

        self.init_hidden0 = nn.Parameter(torch.randn(2, 1, 32))
        self.init_hidden1 = nn.Parameter(torch.randn(2, 1, 32))
        self.init_hidden2 = nn.Parameter(torch.randn(2, 1, 32))

        self.norm0 = nn.LayerNorm(192)
        self.gru0  = nn.GRU(input_size=192, hidden_size=32, batch_first=True, bidirectional=True)

        self.norm1 = nn.LayerNorm(64)
        self.gru1  = nn.GRU(input_size=64, hidden_size=32, batch_first=True, bidirectional=True)

        self.norm2 = nn.LayerNorm(64)
        self.gru2  = nn.GRU(input_size=64, hidden_size=32, batch_first=True, bidirectional=True)

        self.fc0 = nn.Sequential(
            nn.LayerNorm(64),
            nn.Linear(in_features=64, out_features=128)
        )

        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(128),
            nn.Linear(in_features=128, out_features=256)
        )

        self.fc2 = nn.Sequential(
            nn.ReLU(),
            nn.LayerNorm(256),
            nn.Linear(in_features=256, out_features=1)
        )

    def forward(self, x, hiddens=None):
        if hiddens is None:
            hiddens = self.get_init_hiddens(x.size(0))

        x, new_hidden0 = self.gru0(self.norm0(x), hiddens[0])
        x, new_hidden1 = self.gru1(self.norm1(x), hiddens[1])
        x, new_hidden2 = self.gru2(self.norm2(x), hiddens[2])

        ### Keep only the final output of each direction:
        ### 1) split the last dimension into the two directions,
        ### 2) take the last timestep's output for the forward direction ([:, -1, 0, :])
        ###    and the first timestep's output for the backward direction ([:, 0, 1, :]),
        ### 3) then finally reshape the selection into [n_batch, features] form.
        x = x.view(*x.shape[:-1], 2, -1)[:, [-1, 0], [0, 1], :].view(x.size(0), -1)

        x = self.fc0(x)
        x = self.fc1(x)
        x = self.fc2(x)

        return x, [new_hidden0, new_hidden1, new_hidden2]

    def get_init_hiddens(self, n_batches):
        return [self.init_hidden0.repeat(1, n_batches, 1),
                self.init_hidden1.repeat(1, n_batches, 1),
                self.init_hidden2.repeat(1, n_batches, 1)]
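For a single-layer bidirectional GRU, the selection `x.view(*x.shape[:-1], 2, -1)[:, [-1, 0], [0, 1], :]` in `Discriminator.forward` should coincide with the final hidden states `h_n` that `nn.GRU` already returns (the forward state after the last step and the backward state after the first step). The standalone check below uses toy sizes and `reshape` instead of `view` so it does not depend on memory layout.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_batches, seq_len, hidden = 4, 7, 32
gru = nn.GRU(input_size=192, hidden_size=hidden, batch_first=True, bidirectional=True)
x, h_n = gru(torch.randn(n_batches, seq_len, 192))   # x: [n, T, 2*hidden], h_n: [2, n, hidden]

# Split the feature axis into (direction, hidden); pick (t=-1, forward) and (t=0, backward)
selected = x.reshape(n_batches, seq_len, 2, hidden)[:, [-1, 0], [0, 1], :].reshape(n_batches, -1)

# h_n[0] is the forward direction's final state, h_n[1] the backward direction's
expected = torch.cat([h_n[0], h_n[1]], dim=-1)
print(selected.shape, torch.allclose(selected, expected))   # torch.Size([4, 64]) True
```

If this equivalence holds for your PyTorch version, using the `new_hidden2` state returned by `gru2` would be an alternative way to obtain the same features.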

---

## Training Function
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [8]:
def free_running_generation(generator, inputs, seq_len, return_all_hiddens=False):
    """
    Params
    ======
    generator (Generator): the generator model.
    inputs (LongTensor): 2D tensor with dimensions [n_batches, 1].
                         If the given dimensions are [n_batches, seq_len],
                         then only the first timestep is used.
    seq_len (int): length of sequence to generate.
    return_all_hiddens (bool, optional): whether to retrieve the internal hidden states
                                         for all timesteps from every RNN layer. (default: False)

    Returns
    =======
    output_sequence (LongTensor): tensor of generated outputs with dimensions [n_batches, seq_len].
    hiddens (list[Tensor]): updated hidden states.
    internal_hiddens (list[Tensor]): list of tensors containing internal hidden states
                                     for all timesteps from every RNN layer. Returned only
                                     if `return_all_hiddens` is True.
    """
    output_sequence = []
    if return_all_hiddens:
        internal_hiddens = []

    hiddens = generator.get_init_hiddens(inputs.size(0))
    x = inputs[:, :1]

    for i in range(seq_len):
        if return_all_hiddens:
            y, hiddens, current_internal_hiddens = generator(x, hiddens, return_all_hiddens=True)
            internal_hiddens.append(current_internal_hiddens)
        else:
            y, hiddens = generator(x, hiddens, return_all_hiddens=False)
        y = torch.multinomial(y.squeeze(1).softmax(dim=-1), num_samples=1)
        output_sequence.append(y)
        x = y
    output_sequence = torch.cat(output_sequence, dim=1)
    
    if return_all_hiddens:
        internal_hiddens = [torch.cat(layer_hiddens, dim=1) for layer_hiddens in zip(*internal_hiddens)]
        return output_sequence, hiddens, internal_hiddens
    return output_sequence, hiddens
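Inside `free_running_generation`, `internal_hiddens` is first collected per timestep (a list of `seq_len` entries, each a list of three `[n_batches, 1, 64]` layer outputs) and then regrouped per layer with `zip(*...)`. The toy sketch below uses random tensors as stand-ins for the GRU outputs and logits. It shows the shapes before and after, and how the sampled tokens (one `[n_batches, 1]` tensor per step from `torch.multinomial`) are stitched together.

```python
import torch

torch.manual_seed(0)
n_batches, seq_len, n_layers, hidden, vocab = 2, 5, 3, 64, 129

per_step_hiddens = []   # seq_len entries, each a list with one tensor per layer
sampled = []            # seq_len entries, each [n_batches, 1]
for _ in range(seq_len):
    per_step_hiddens.append([torch.randn(n_batches, 1, hidden) for _ in range(n_layers)])
    logits = torch.randn(n_batches, vocab)  # stand-in for y.squeeze(1)
    sampled.append(torch.multinomial(logits.softmax(dim=-1), num_samples=1))

# Regroup per layer and concatenate along time: n_layers tensors of [n_batches, seq_len, hidden]
per_layer = [torch.cat(layer, dim=1) for layer in zip(*per_step_hiddens)]
output_sequence = torch.cat(sampled, dim=1)  # [n_batches, seq_len]

# Concatenating the layers gives the 192-feature "behavior" input expected by the Discriminator
behavior = torch.cat(per_layer, dim=-1)
print(output_sequence.shape, behavior.shape)  # torch.Size([2, 5]) torch.Size([2, 5, 192])
```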

In [9]:
def generator_loss(discriminator, output_teacher, target_teacher, b_seq_teacher, b_seq_free,
                   teacher_forcing_loss=True, free_teacher_behavior_loss=True, teacher_free_behavior_loss=True):
    """
    Params
    ======
    discriminator (Discriminator): the Discriminator model.
    output_teacher (Tensor): output of the Generator during teacher-forcing mode.
    target_teacher (Tensor): target for the Generator during teacher-forcing mode.
    b_seq_teacher (Tensor): behavior sequence from Generator during teacher-forcing mode.
                            Must have dimensions [n_batches, seq_len, behavior_size].
    b_seq_free (Tensor): behavior sequence from Generator during free-running mode.
                         Assumed to have the same dimensionality as `b_seq_teacher`.
    teacher_forcing_loss (bool): whether to calculate and return the loss for teacher-forced training. (default: True)
    free_teacher_behavior_loss (bool): whether to calculate and return the loss for matching the
                                       free-running behavior to the teacher-forced behavior. (default: True)
    teacher_free_behavior_loss (bool): whether to calculate and return the loss for matching the
                                       teacher-forced behavior to the free-running behavior. (default: True)

    Returns
    =======
    teacher_forcing_loss (Tensor): loss for maximizing the likelihood of the data
                                   during teacher-forcing mode. Returned if
                                   `teacher_forcing_loss` is True.
    free_teacher_behavior_loss (Tensor): loss for changing the free-running behavior so that
                                         it better matches the teacher-forced behavior,
                                         considering the latter fixed. Returned if
                                         `free_teacher_behavior_loss` is True.
    teacher_free_behavior_loss (Tensor): loss for making the teacher-forced behavior
                                         indistinguishable from the free-running behavior.
                                         Returned if `teacher_free_behavior_loss` is True.
    """
    losses = []
    
    if teacher_forcing_loss:
        losses.append(F.cross_entropy(output_teacher.view(-1, output_teacher.size(-1)), target_teacher.view(-1)))

    if free_teacher_behavior_loss:
        hiddens = discriminator.get_init_hiddens(n_batches=b_seq_free.size(0))
        raw_preds, _ = discriminator(b_seq_free, hiddens)
        losses.append(F.binary_cross_entropy_with_logits(raw_preds, torch.ones_like(raw_preds)))

    if teacher_free_behavior_loss:
        hiddens = discriminator.get_init_hiddens(n_batches=b_seq_teacher.size(0))
        raw_preds, _ = discriminator(b_seq_teacher, hiddens)
        losses.append(F.binary_cross_entropy_with_logits(raw_preds, torch.zeros_like(raw_preds)))

    return losses
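Written out, the three terms returned by `generator_loss` are shown below, with $N$ the batch size, $T$ the sequence length, $p_\theta$ the Generator's softmax output, $D(\cdot)$ the raw Discriminator logit, $\sigma$ the sigmoid, and $b_i$ the concatenated hidden-state ("behavior") sequence of batch element $i$. The Discriminator labels teacher-forced behavior 1 and free-running behavior 0, and the training loop averages the three terms. The overall scheme appears to follow Professor Forcing (Lamb et al., 2016), although the notebook does not name it.

$$
\begin{aligned}
\mathcal{L}_{\text{TF}} &= -\frac{1}{NT}\sum_{i=1}^{N}\sum_{t=1}^{T}\log p_\theta\!\left(y_{i,t}\mid x_{i,\le t}\right) \\
\mathcal{L}_{\text{F}\to\text{T}} &= -\frac{1}{N}\sum_{i=1}^{N}\log \sigma\!\left(D\!\left(b^{\text{free}}_i\right)\right) \\
\mathcal{L}_{\text{T}\to\text{F}} &= -\frac{1}{N}\sum_{i=1}^{N}\log\!\left(1-\sigma\!\left(D\!\left(b^{\text{teacher}}_i\right)\right)\right) \\
\mathcal{L}_G &= \tfrac{1}{3}\left(\mathcal{L}_{\text{TF}} + \mathcal{L}_{\text{F}\to\text{T}} + \mathcal{L}_{\text{T}\to\text{F}}\right)
\end{aligned}
$$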

In [10]:
def discriminator_loss(discriminator, b_seq_teacher, b_seq_free, return_acc=False):
    """
    Params
    ======
    discriminator (Discriminator): the Discriminator model.
    b_seq_teacher (Tensor): behavior sequence from Generator during teacher-forcing mode.
                            Must have dimensions [n_batches, seq_len, behavior_size].
    b_seq_free (Tensor): behavior sequence from Generator during free-running mode.
                         Assumed to have the same dimensionality as `b_seq_teacher`.
    return_acc (bool): whether to return the Discriminator accuracy or not. (default: False)

    Returns
    =======
    discriminator_loss (Tensor): the Discriminator loss as a scalar tensor.
    discriminator_acc (float): accuracy of the Discriminator for the given data.
                               Returned only if `return_acc` is True.
    """
    ### Ensure that gradients will flow only through the discriminator and not the generator
    b_seq_teacher = b_seq_teacher.detach()
    b_seq_free    = b_seq_free.detach()

    inputs  = torch.cat([b_seq_teacher, b_seq_free], dim=0)
    targets = torch.cat([b_seq_teacher.new_ones(b_seq_teacher.size(0), 1),
                         b_seq_free.new_zeros(b_seq_free.size(0), 1)], dim=0)

    hiddens = discriminator.get_init_hiddens(n_batches=inputs.size(0))
    raw_preds, _ = discriminator(inputs, hiddens)

    ### Compute the loss from logits for numerical stability; the teacher-forced and
    ### free-running halves are averaged separately and summed so both classes weigh equally
    n_teacher = b_seq_teacher.size(0)
    discriminator_loss = F.binary_cross_entropy_with_logits(raw_preds, targets, reduction='none')
    discriminator_loss = discriminator_loss[:n_teacher].mean() + discriminator_loss[n_teacher:].mean()

    if return_acc:
        ### A logit > 0 is equivalent to a sigmoid probability > 0.5
        discriminator_acc = (raw_preds.gt(0) == targets.bool()).float().mean().item()
        return discriminator_loss, discriminator_acc

    return discriminator_loss
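`discriminator_loss` averages the teacher-forced half and the free-running half separately and sums the two means, so each class contributes equally. The standalone sketch below (random stand-in logits, toy sizes) checks two facts this relies on: `binary_cross_entropy_with_logits` matches `binary_cross_entropy` applied to `sigmoid` outputs, and thresholding logits at 0 gives the same accuracy as thresholding probabilities at 0.5.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_teacher, n_free = 6, 6
raw_preds = torch.randn(n_teacher + n_free, 1) * 3         # stand-in Discriminator logits
targets = torch.cat([torch.ones(n_teacher, 1), torch.zeros(n_free, 1)], dim=0)

per_item = F.binary_cross_entropy_with_logits(raw_preds, targets, reduction='none')
balanced_loss = per_item[:n_teacher].mean() + per_item[n_teacher:].mean()

# Same value computed from probabilities (less numerically stable for large |logits|)
per_item_prob = F.binary_cross_entropy(raw_preds.sigmoid(), targets, reduction='none')
balanced_loss_prob = per_item_prob[:n_teacher].mean() + per_item_prob[n_teacher:].mean()

acc_from_logits = (raw_preds.gt(0) == targets.bool()).float().mean().item()
acc_from_probs = (raw_preds.sigmoid().gt(0.5) == targets.bool()).float().mean().item()
print(balanced_loss.item(), balanced_loss_prob.item(), acc_from_logits, acc_from_probs)
```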

In [11]:
seed                   = 0
n_epochs               = 1000
batch_size             = 64
sequence_length        = 1200

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

random.seed(seed)
torch.manual_seed(seed)

ud_dataset = UndertaleDeltaruneDataset("./source/converted_texts", batch_size)
ud_loader = UDBatchLoader(ud_dataset, batch_size, sequence_length, drop_last=False, batch_first=True)

generator     = Generator().to(device)
discriminator = Discriminator().to(device)

generator_optimizer     = optim.Adam(generator.parameters(), lr=1e-3)
discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=1e-3)

history = {}

print()
print('Data Sequence Total Length:', ud_dataset.data_len())
print()
print(generator)
print()
print(discriminator)


Data Sequence Total Length: 893324

Generator(
  (embed): Embedding(129, 64)
  (norm0): LayerNorm(torch.Size([64]), eps=1e-05, elementwise_affine=True)
  (gru0): GRU(64, 64, batch_first=True)
  (norm1): LayerNorm(torch.Size([64]), eps=1e-05, elementwise_affine=True)
  (gru1): GRU(64, 64, batch_first=True)
  (norm2): LayerNorm(torch.Size([64]), eps=1e-05, elementwise_affine=True)
  (gru2): GRU(64, 64, batch_first=True)
  (fc0): Sequential(
    (0): LayerNorm(torch.Size([64]), eps=1e-05, elementwise_affine=True)
    (1): Linear(in_features=64, out_features=128, bias=True)
  )
  (fc1): Sequential(
    (0): ReLU()
    (1): LayerNorm(torch.Size([128]), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=128, out_features=256, bias=True)
  )
  (fc2): Sequential(
    (0): ReLU()
    (1): LayerNorm(torch.Size([256]), eps=1e-05, elementwise_affine=True)
    (2): Linear(in_features=256, out_features=129, bias=True)
  )
)

Discriminator(
  (norm0): LayerNorm(torch.Size([192]), eps=1e
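One practical consequence of this setup: each pass over `ud_loader` first materializes `batch_size` full shuffled copies of the corpus, for both inputs and targets. Using the numbers above (893,324 notes, `batch_size = 64`) and assuming 8-byte `int64` elements as produced by `torch.LongTensor`, a rough estimate of the memory taken by those two batch tensors alone is:

```python
data_len = 893324            # "Data Sequence Total Length" printed above
batch_size = 64
bytes_per_int64 = 8

steps = data_len - 1                           # inputs/targets are shifted by one
elements = batch_size * steps                  # per tensor ([batch_size, data_len - 1])
total_bytes = 2 * elements * bytes_per_int64   # input batch + target batch
print(elements, total_bytes / 2**20)           # about 872 MiB for the two tensors together
```

Temporary copies made while the DataLoader collates the items come on top of this figure.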

In [None]:
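teacher_forcing_always_on = True

os.makedirs("teacher_forcing_temp", exist_ok=True)  # torch.save() below fails if this folder is missing

i = 0
for epoch in range(n_epochs):  # each pass over `ud_loader` re-shuffles the songs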
    for inputs, targets in ud_loader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        teacher_forcing_outputs, _, teacher_internal_hiddens = generator(inputs, return_all_hiddens=True)
        generated_sequence,      _, free_internal_hiddens    = free_running_generation(generator, inputs, inputs.size(1), return_all_hiddens=True)

        b_seq_teacher = torch.cat(teacher_internal_hiddens, dim=-1)
        b_seq_free    = torch.cat(free_internal_hiddens,    dim=-1)

        D_loss, D_acc = discriminator_loss(discriminator, b_seq_teacher, b_seq_free, return_acc=True)

        if D_acc < 0.99: # If too good, don't train the Discriminator
            discriminator_optimizer.zero_grad()
            D_loss.backward()
            discriminator_optimizer.step()
        D_loss = D_loss.item()

        if D_acc > 0.75: # If too bad, don't use the Discriminator to train the Generator
            t_force_loss, ft_b_loss, tf_b_loss = generator_loss(discriminator, teacher_forcing_outputs, targets,
                                                                b_seq_teacher, b_seq_free,
                                                                teacher_forcing_loss=True,
                                                                free_teacher_behavior_loss=True,
                                                                teacher_free_behavior_loss=True)

            G_loss = (t_force_loss + ft_b_loss + tf_b_loss) / 3

            generator_optimizer.zero_grad()
            G_loss.backward()
            generator_optimizer.step()

            G_loss       = G_loss.item()
            t_force_loss = t_force_loss.item()
            ft_b_loss    = ft_b_loss.item()
            tf_b_loss    = tf_b_loss.item()
        else:
            if teacher_forcing_always_on:
                t_force_loss,                      = generator_loss(discriminator, teacher_forcing_outputs, targets,
                                                                    b_seq_teacher, b_seq_free,
                                                                    teacher_forcing_loss=True,
                                                                    free_teacher_behavior_loss=False,
                                                                    teacher_free_behavior_loss=False)

                G_loss = t_force_loss

                generator_optimizer.zero_grad()
                G_loss.backward()
                generator_optimizer.step()

                G_loss       = G_loss.item()
                t_force_loss = t_force_loss.item()
            else:
                G_loss       = float('nan')
                t_force_loss = float('nan')
            ft_b_loss = float('nan')
            tf_b_loss = float('nan')

        G_acc = (teacher_forcing_outputs.argmax(dim=-1) == targets).float().mean().item()

        print("[Update {}]".format(i))
        print("Discriminator")
        print("=============")
        print("    Loss:", D_loss)
        print("    Acc:", D_acc)
        print("Generator")
        print("=============")
        print("    Loss:", G_loss)
        print("        Teacher-Force:", t_force_loss)
        print("        FreeToTeacher:", ft_b_loss)
        print("        TeacherToFree:", tf_b_loss)
        print("    Acc:", G_acc)
        print("####################################")
        print()
        i += 1

        if i % 100 == 0:
            torch.save(generated_sequence[0].cpu().numpy(), "teacher_forcing_temp/{}.pth".format(i))

[Update 0]
Discriminator
    Loss: 1.4877625703811646
    Acc: 0.5
Generator
    Loss: 4.9770355224609375
        Teacher-Force: 4.9770355224609375
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.008333333767950535
####################################

[Update 1]
Discriminator
    Loss: 1.2214598655700684
    Acc: 0.609375
Generator
    Loss: 4.419028282165527
        Teacher-Force: 4.419028282165527
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.20402343571186066
####################################

[Update 2]
Discriminator
    Loss: 0.9821017384529114
    Acc: 0.8359375
Generator
    Loss: 2.291712760925293
        Teacher-Force: 4.0899200439453125
        FreeToTeacher: 1.6097900867462158
        TeacherToFree: 1.1754281520843506
    Acc: 0.26532551646232605
####################################

[Update 3]
Discriminator
    Loss: 1.1395044326782227
    Acc: 0.671875
Generator
    Loss: 3.9739022254943848
        Teacher-Force: 3.9739022254943848
[...]
    Loss: 1.1235604286193848
    Acc: 0.7421875
Generator
    Loss: 2.6336145401000977
        Teacher-Force: 2.6336145401000977
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.3033463656902313
####################################

[Update 28]
Discriminator
    Loss: 1.1086500883102417
    Acc: 0.7265625
Generator
    Loss: 2.588038682937622
        Teacher-Force: 2.588038682937622
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.3179166615009308
####################################

[Update 29]
Discriminator
    Loss: 1.1235086917877197
    Acc: 0.734375
Generator
    Loss: 2.59507417678833
        Teacher-Force: 2.59507417678833
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.3251822888851166
####################################

[Update 30]
Discriminator
    Loss: 1.253517985343933
    Acc: 0.6484375
Generator
    Loss: 2.516648054122925
        Teacher-Force: 2.516648054122925
        FreeToTeacher: nan
        TeacherToFree: nan
[...]
        FreeToTeacher: 1.6870310306549072
        TeacherToFree: 1.0806372165679932
    Acc: 0.4743880331516266
####################################

[Update 54]
Discriminator
    Loss: 1.1582669019699097
    Acc: 0.6796875
Generator
    Loss: 1.9232943058013916
        Teacher-Force: 1.9232943058013916
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.48268231749534607
####################################

[Update 55]
Discriminator
    Loss: 1.0740430355072021
    Acc: 0.703125
Generator
    Loss: 1.9249471426010132
        Teacher-Force: 1.9249471426010132
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.4627734422683716
####################################

[Update 56]
Discriminator
    Loss: 1.0631177425384521
    Acc: 0.7421875
Generator
    Loss: 1.9065287113189697
        Teacher-Force: 1.9065287113189697
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.475273460149765
####################################

[Update 57]
Discriminator
[...]
        TeacherToFree: 1.2805256843566895
    Acc: 0.5448958277702332
####################################

[Update 80]
Discriminator
    Loss: 0.9674004316329956
    Acc: 0.78125
Generator
    Loss: 1.4467921257019043
        Teacher-Force: 1.7122180461883545
        FreeToTeacher: 1.4184117317199707
        TeacherToFree: 1.2097465991973877
    Acc: 0.5442057251930237
####################################

[Update 81]
Discriminator
    Loss: 1.0390102863311768
    Acc: 0.7421875
Generator
    Loss: 1.6807435750961304
        Teacher-Force: 1.6807435750961304
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.549791693687439
####################################

[Update 82]
Discriminator
    Loss: 1.1433026790618896
    Acc: 0.6640625
Generator
    Loss: 1.6802177429199219
        Teacher-Force: 1.6802177429199219
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.5536458492279053
####################################

[Update 83]
Discriminator
    Loss: 1.03
[...]
    Loss: 0.9547935724258423
    Acc: 0.8125
Generator
    Loss: 1.4054579734802246
        Teacher-Force: 1.5042732954025269
        FreeToTeacher: 0.8849246501922607
        TeacherToFree: 1.8271760940551758
    Acc: 0.6140885353088379
####################################

[Update 107]
Discriminator
    Loss: 0.9871605038642883
    Acc: 0.765625
Generator
    Loss: 1.4452381134033203
        Teacher-Force: 1.539881944656372
        FreeToTeacher: 1.150864839553833
        TeacherToFree: 1.6449673175811768
    Acc: 0.5972005128860474
####################################

[Update 108]
Discriminator
    Loss: 0.9493525624275208
    Acc: 0.7734375
Generator
    Loss: 1.4505620002746582
        Teacher-Force: 1.5411378145217896
        FreeToTeacher: 1.4895045757293701
        TeacherToFree: 1.3210433721542358
    Acc: 0.6039062738418579
####################################

[Update 109]
Discriminator
    Loss: 1.0762975215911865
    Acc: 0.7421875
Generator
    Loss: 1.5817055702209473
[...]
    Loss: 1.2661110162734985
    Acc: 0.6484375
Generator
    Loss: 1.3860678672790527
        Teacher-Force: 1.3860678672790527
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6421875357627869
####################################

[Update 133]
Discriminator
    Loss: 1.0211045742034912
    Acc: 0.7109375
Generator
    Loss: 1.3429291248321533
        Teacher-Force: 1.3429291248321533
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6608594059944153
####################################

[Update 134]
Discriminator
    Loss: 1.4107837677001953
    Acc: 0.6015625
Generator
    Loss: 1.3615450859069824
        Teacher-Force: 1.3615450859069824
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.651940107345581
####################################

[Update 135]
Discriminator
    Loss: 1.406010627746582
    Acc: 0.65625
Generator
    Loss: 1.3487321138381958
        Teacher-Force: 1.3487321138381958
        FreeToTeacher: nan
        TeacherToFre
[...]
    Acc: 0.7173047065734863
####################################

[Update 159]
Discriminator
    Loss: 1.0196411609649658
    Acc: 0.7421875
Generator
    Loss: 1.1281627416610718
        Teacher-Force: 1.1281627416610718
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7394140958786011
####################################

[Update 160]
Discriminator
    Loss: 1.2120656967163086
    Acc: 0.6328125
Generator
    Loss: 1.1319491863250732
        Teacher-Force: 1.1319491863250732
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7312109470367432
####################################

[Update 161]
Discriminator
    Loss: 1.1691911220550537
    Acc: 0.6640625
Generator
    Loss: 1.1233611106872559
        Teacher-Force: 1.1233611106872559
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7371484637260437
####################################

[Update 162]
Discriminator
    Loss: 1.0327821969985962
    Acc: 0.7578125
Generator
    Loss: 1.190855264

        TeacherToFree: nan
    Acc: 0.6498437523841858
####################################

[Update 185]
Discriminator
    Loss: 1.4399385452270508
    Acc: 0.6875
Generator
    Loss: 1.21629798412323
        Teacher-Force: 1.21629798412323
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6969401240348816
####################################

[Update 186]
Discriminator
    Loss: 1.3542587757110596
    Acc: 0.6484375
Generator
    Loss: 1.241682529449463
        Teacher-Force: 1.241682529449463
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6979296803474426
####################################

[Update 187]
Discriminator
    Loss: 1.2502739429473877
    Acc: 0.5859375
Generator
    Loss: 1.2356505393981934
        Teacher-Force: 1.2356505393981934
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6968359351158142
####################################

[Update 188]
Discriminator
    Loss: 2.098559617996216
    Acc: 0.53125
Generator
    Lo

    Loss: 1.0800899267196655
        Teacher-Force: 1.0800899267196655
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7520312666893005
####################################

[Update 212]
Discriminator
    Loss: 1.1754897832870483
    Acc: 0.703125
Generator
    Loss: 1.096251368522644
        Teacher-Force: 1.096251368522644
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7510156631469727
####################################

[Update 213]
Discriminator
    Loss: 0.8635663986206055
    Acc: 0.8203125
Generator
    Loss: 1.323967695236206
        Teacher-Force: 1.0379021167755127
        FreeToTeacher: 1.7084460258483887
        TeacherToFree: 1.2255545854568481
    Acc: 0.7682682275772095
####################################

[Update 214]
Discriminator
    Loss: 0.8040292263031006
    Acc: 0.84375
Generator
    Loss: 1.352858543395996
        Teacher-Force: 1.0172295570373535
        FreeToTeacher: 1.9009456634521484
        TeacherToFree: 1.14040040969

####################################

[Update 237]
Discriminator
    Loss: 1.0980620384216309
    Acc: 0.6875
Generator
    Loss: 1.0912272930145264
        Teacher-Force: 1.0912272930145264
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7491145730018616
####################################

[Update 238]
Discriminator
    Loss: 1.2355930805206299
    Acc: 0.6171875
Generator
    Loss: 1.072628378868103
        Teacher-Force: 1.072628378868103
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7536718845367432
####################################

[Update 239]
Discriminator
    Loss: 1.1241543292999268
    Acc: 0.6640625
Generator
    Loss: 1.0486279726028442
        Teacher-Force: 1.0486279726028442
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7649869918823242
####################################

[Update 240]
Discriminator
    Loss: 0.9486548900604248
    Acc: 0.7578125
Generator
    Loss: 1.3025367259979248
        Teacher-Force: 1.

        TeacherToFree: 3.4127197265625
    Acc: 0.6774349212646484
####################################

[Update 263]
Discriminator
    Loss: 0.5753374695777893
    Acc: 0.8828125
Generator
    Loss: 2.2943034172058105
        Teacher-Force: 1.3404818773269653
        FreeToTeacher: 1.9761052131652832
        TeacherToFree: 3.5663228034973145
    Acc: 0.6625651121139526
####################################

[Update 264]
Discriminator
    Loss: 0.7871837615966797
    Acc: 0.8359375
Generator
    Loss: 2.0914430618286133
        Teacher-Force: 1.38713538646698
        FreeToTeacher: 1.9867714643478394
        TeacherToFree: 2.9004225730895996
    Acc: 0.6539453268051147
####################################

[Update 265]
Discriminator
    Loss: 1.5989291667938232
    Acc: 0.59375
Generator
    Loss: 1.4418126344680786
        Teacher-Force: 1.4418126344680786
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6331640481948853
####################################

[Update 266]

    Loss: 1.0412849187850952
    Acc: 0.765625
Generator
    Loss: 1.0807924270629883
        Teacher-Force: 0.9276596903800964
        FreeToTeacher: 1.3090713024139404
        TeacherToFree: 1.005645990371704
    Acc: 0.79332035779953
####################################

[Update 290]
Discriminator
    Loss: 0.9726376533508301
    Acc: 0.8046875
Generator
    Loss: 1.1616458892822266
        Teacher-Force: 0.9253265261650085
        FreeToTeacher: 1.3559446334838867
        TeacherToFree: 1.20366632938385
    Acc: 0.7958333492279053
####################################

[Update 291]
Discriminator
    Loss: 1.2826735973358154
    Acc: 0.671875
Generator
    Loss: 0.9051234722137451
        Teacher-Force: 0.9051234722137451
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.80154949426651
####################################

[Update 292]
Discriminator
    Loss: 1.098851203918457
    Acc: 0.765625
Generator
    Loss: 0.9836184978485107
        Teacher-Force: 0.90205615758

        TeacherToFree: 2.6191213130950928
    Acc: 0.7522135376930237
####################################

[Update 315]
Discriminator
    Loss: 0.9478021264076233
    Acc: 0.7734375
Generator
    Loss: 1.7610024213790894
        Teacher-Force: 1.1641021966934204
        FreeToTeacher: 1.7898238897323608
        TeacherToFree: 2.3290810585021973
    Acc: 0.7262760400772095
####################################

[Update 316]
Discriminator
    Loss: 2.1550698280334473
    Acc: 0.5
Generator
    Loss: 1.3104668855667114
        Teacher-Force: 1.3104668855667114
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6746745109558105
####################################

[Update 317]
Discriminator
    Loss: 1.9082822799682617
    Acc: 0.28125
Generator
    Loss: 1.3104982376098633
        Teacher-Force: 1.3104982376098633
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6816015839576721
####################################

[Update 318]
Discriminator
    Loss: 1.449

        Teacher-Force: 1.3377461433410645
        FreeToTeacher: 1.1173899173736572
        TeacherToFree: 2.5715994834899902
    Acc: 0.6597005128860474
####################################

[Update 341]
Discriminator
    Loss: 1.7828259468078613
    Acc: 0.5625
Generator
    Loss: 1.2606970071792603
        Teacher-Force: 1.2606970071792603
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6719531416893005
####################################

[Update 342]
Discriminator
    Loss: 1.455095887184143
    Acc: 0.515625
Generator
    Loss: 1.2589240074157715
        Teacher-Force: 1.2589240074157715
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6764192581176758
####################################

[Update 343]
Discriminator
    Loss: 1.2463252544403076
    Acc: 0.6875
Generator
    Loss: 1.2297438383102417
        Teacher-Force: 1.2297438383102417
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6901432275772095
####################################

    Loss: 1.0458028316497803
    Acc: 0.7421875
Generator
    Loss: 0.8322612643241882
        Teacher-Force: 0.8322612643241882
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8178906440734863
####################################

[Update 368]
Discriminator
    Loss: 1.0161222219467163
    Acc: 0.75
Generator
    Loss: 0.831562876701355
        Teacher-Force: 0.831562876701355
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.819531261920929
####################################

[Update 369]
Discriminator
    Loss: 1.0984079837799072
    Acc: 0.6796875
Generator
    Loss: 0.855974018573761
        Teacher-Force: 0.855974018573761
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8174740076065063
####################################

[Update 370]
Discriminator
    Loss: 1.0675902366638184
    Acc: 0.7265625
Generator
    Loss: 0.8740761280059814
        Teacher-Force: 0.8740761280059814
        FreeToTeacher: nan
        TeacherToFree: nan

    Loss: 0.8129024505615234
        Teacher-Force: 0.8129024505615234
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8282421827316284
####################################

[Update 394]
Discriminator
    Loss: 1.0168766975402832
    Acc: 0.7265625
Generator
    Loss: 0.7728409171104431
        Teacher-Force: 0.7728409171104431
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.83454430103302
####################################

[Update 395]
Discriminator
    Loss: 0.8997645378112793
    Acc: 0.7890625
Generator
    Loss: 1.308390498161316
        Teacher-Force: 0.7472339868545532
        FreeToTeacher: 1.7255895137786865
        TeacherToFree: 1.4523478746414185
    Acc: 0.8463671803474426
####################################

[Update 396]
Discriminator
    Loss: 0.804733395576477
    Acc: 0.828125
Generator
    Loss: 1.3673696517944336
        Teacher-Force: 0.7544155120849609
        FreeToTeacher: 2.027524471282959
        TeacherToFree: 1.3201687335

        FreeToTeacher: 1.9519230127334595
        TeacherToFree: 0.5404977798461914
    Acc: 0.8301432728767395
####################################

[Update 420]
Discriminator
    Loss: 1.1958746910095215
    Acc: 0.7109375
Generator
    Loss: 0.7711329460144043
        Teacher-Force: 0.7711329460144043
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8335937857627869
####################################

[Update 421]
Discriminator
    Loss: 0.997859537601471
    Acc: 0.71875
Generator
    Loss: 0.7560250163078308
        Teacher-Force: 0.7560250163078308
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8434895873069763
####################################

[Update 422]
Discriminator
    Loss: 0.7054733633995056
    Acc: 0.8515625
Generator
    Loss: 1.6682367324829102
        Teacher-Force: 0.7826111912727356
        FreeToTeacher: 2.42694091796875
        TeacherToFree: 1.7951579093933105
    Acc: 0.8448047041893005
####################################

    Loss: 1.6778285503387451
        Teacher-Force: 1.159578561782837
        FreeToTeacher: 2.3766422271728516
        TeacherToFree: 1.4972646236419678
    Acc: 0.7224088907241821
####################################

[Update 446]
Discriminator
    Loss: 1.0295820236206055
    Acc: 0.765625
Generator
    Loss: 1.4791769981384277
        Teacher-Force: 1.188699722290039
        FreeToTeacher: 1.8498320579528809
        TeacherToFree: 1.3989992141723633
    Acc: 0.7120963931083679
####################################

[Update 447]
Discriminator
    Loss: 0.829373300075531
    Acc: 0.78125
Generator
    Loss: 1.5618658065795898
        Teacher-Force: 1.2272874116897583
        FreeToTeacher: 1.8831403255462646
        TeacherToFree: 1.575169563293457
    Acc: 0.7034375071525574
####################################

[Update 448]
Discriminator
    Loss: 0.8312321901321411
    Acc: 0.796875
Generator
    Loss: 1.4662755727767944
        Teacher-Force: 1.2038003206253052
        FreeToTeach

        TeacherToFree: 1.1438734531402588
    Acc: 0.830065131187439
####################################

[Update 472]
Discriminator
    Loss: 0.9143249988555908
    Acc: 0.78125
Generator
    Loss: 1.0972801446914673
        Teacher-Force: 0.7533328533172607
        FreeToTeacher: 1.2456541061401367
        TeacherToFree: 1.2928533554077148
    Acc: 0.8420052528381348
####################################

[Update 473]
Discriminator
    Loss: 0.9611853957176208
    Acc: 0.7734375
Generator
    Loss: 1.073476791381836
        Teacher-Force: 0.7781132459640503
        FreeToTeacher: 1.2003130912780762
        TeacherToFree: 1.242004156112671
    Acc: 0.8350651264190674
####################################

[Update 474]
Discriminator
    Loss: 0.9568811058998108
    Acc: 0.7734375
Generator
    Loss: 1.0610690116882324
        Teacher-Force: 0.7807812690734863
        FreeToTeacher: 1.1661344766616821
        TeacherToFree: 1.2362910509109497
    Acc: 0.835533857345581
####################################

    Loss: 1.1764366626739502
    Acc: 0.7109375
Generator
    Loss: 0.8426080942153931
        Teacher-Force: 0.8426080942153931
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8148958683013916
####################################

[Update 498]
Discriminator
    Loss: 1.0990126132965088
    Acc: 0.6796875
Generator
    Loss: 0.869498074054718
        Teacher-Force: 0.869498074054718
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8011458516120911
####################################

[Update 499]
Discriminator
    Loss: 0.8348221778869629
    Acc: 0.8046875
Generator
    Loss: 1.5595481395721436
        Teacher-Force: 0.8841371536254883
        FreeToTeacher: 1.6249182224273682
        TeacherToFree: 2.169588804244995
    Acc: 0.8005208373069763
####################################

[Update 500]
Discriminator
    Loss: 0.5970696210861206
    Acc: 0.9140625
Generator
    Loss: 1.903495192527771
        Teacher-Force: 0.900814414024353
        FreeToTeac

####################################

[Update 523]
Discriminator
    Loss: 0.8505644202232361
    Acc: 0.828125
Generator
    Loss: 1.183233618736267
        Teacher-Force: 0.8730758428573608
        FreeToTeacher: 1.3644828796386719
        TeacherToFree: 1.312142014503479
    Acc: 0.804674506187439
####################################

[Update 524]
Discriminator
    Loss: 0.8845582008361816
    Acc: 0.8359375
Generator
    Loss: 1.1639044284820557
        Teacher-Force: 0.8916256427764893
        FreeToTeacher: 1.547398567199707
        TeacherToFree: 1.0526890754699707
    Acc: 0.7998307347297668
####################################

[Update 525]
Discriminator
    Loss: 0.8504372835159302
    Acc: 0.8125
Generator
    Loss: 1.2199302911758423
        Teacher-Force: 0.904663622379303
        FreeToTeacher: 1.725228190422058
        TeacherToFree: 1.0298988819122314
    Acc: 0.8029687404632568
####################################

[Update 526]
Discriminator
    Loss: 1.020327568054199

####################################

[Update 549]
Discriminator
    Loss: 1.245415210723877
    Acc: 0.65625
Generator
    Loss: 0.8124793767929077
        Teacher-Force: 0.8124793767929077
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.829296886920929
####################################

[Update 550]
Discriminator
    Loss: 1.1638013124465942
    Acc: 0.7265625
Generator
    Loss: 0.7800796031951904
        Teacher-Force: 0.7800796031951904
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8395442962646484
####################################

[Update 551]
Discriminator
    Loss: 1.130561113357544
    Acc: 0.7265625
Generator
    Loss: 0.7440003752708435
        Teacher-Force: 0.7440003752708435
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8484505414962769
####################################

[Update 552]
Discriminator
    Loss: 1.1837197542190552
    Acc: 0.6953125
Generator
    Loss: 0.8000308275222778
        Teacher-Force: 0.

        TeacherToFree: 3.067152738571167
    Acc: 0.7638021111488342
####################################

[Update 575]
Discriminator
    Loss: 0.38006824254989624
    Acc: 0.953125
Generator
    Loss: 2.152216911315918
        Teacher-Force: 1.0527433156967163
        FreeToTeacher: 2.064303159713745
        TeacherToFree: 3.339604139328003
    Acc: 0.7311068177223206
####################################

[Update 576]
Discriminator
    Loss: 0.6566431522369385
    Acc: 0.859375
Generator
    Loss: 2.028258800506592
        Teacher-Force: 1.144439458847046
        FreeToTeacher: 1.6490793228149414
        TeacherToFree: 3.291257381439209
    Acc: 0.711132824420929
####################################

[Update 577]
Discriminator
    Loss: 1.4555970430374146
    Acc: 0.625
Generator
    Loss: 1.3469338417053223
        Teacher-Force: 1.3469338417053223
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.6511979103088379
####################################

[Update 578]
Discriminator

        TeacherToFree: 1.2619562149047852
    Acc: 0.7908463478088379
####################################

[Update 601]
Discriminator
    Loss: 0.8843965530395508
    Acc: 0.90625
Generator
    Loss: 1.1270253658294678
        Teacher-Force: 0.942211389541626
        FreeToTeacher: 1.1403369903564453
        TeacherToFree: 1.298527479171753
    Acc: 0.7874088883399963
####################################

[Update 602]
Discriminator
    Loss: 0.9453023672103882
    Acc: 0.8359375
Generator
    Loss: 1.121051549911499
        Teacher-Force: 0.9475654363632202
        FreeToTeacher: 1.0407065153121948
        TeacherToFree: 1.3748823404312134
    Acc: 0.7823176980018616
####################################

[Update 603]
Discriminator
    Loss: 0.9873046875
    Acc: 0.7578125
Generator
    Loss: 1.083345890045166
        Teacher-Force: 0.940192461013794
        FreeToTeacher: 0.9375622272491455
        TeacherToFree: 1.3722829818725586
    Acc: 0.7838541865348816
####################################

        TeacherToFree: nan
    Acc: 0.7784635424613953
####################################

[Update 627]
Discriminator
    Loss: 0.9552913904190063
    Acc: 0.78125
Generator
    Loss: 1.2497930526733398
        Teacher-Force: 0.9481030106544495
        FreeToTeacher: 1.3218166828155518
        TeacherToFree: 1.4794594049453735
    Acc: 0.7818750143051147
####################################

[Update 628]
Discriminator
    Loss: 1.0116289854049683
    Acc: 0.765625
Generator
    Loss: 1.166190266609192
        Teacher-Force: 0.9726576805114746
        FreeToTeacher: 1.2032194137573242
        TeacherToFree: 1.3226935863494873
    Acc: 0.7755599021911621
####################################

[Update 629]
Discriminator
    Loss: 1.1227755546569824
    Acc: 0.703125
Generator
    Loss: 1.0116642713546753
        Teacher-Force: 1.0116642713546753
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7574609518051147
####################################

[Update 630]
Discriminator

        Teacher-Force: 0.9048601388931274
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.79296875
####################################

[Update 653]
Discriminator
    Loss: 1.4272103309631348
    Acc: 0.5703125
Generator
    Loss: 0.8989643454551697
        Teacher-Force: 0.8989643454551697
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7960547208786011
####################################

[Update 654]
Discriminator
    Loss: 1.4559544324874878
    Acc: 0.53125
Generator
    Loss: 0.9037219285964966
        Teacher-Force: 0.9037219285964966
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8016406297683716
####################################

[Update 655]
Discriminator
    Loss: 1.2136915922164917
    Acc: 0.6640625
Generator
    Loss: 0.8800060153007507
        Teacher-Force: 0.8800060153007507
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.808190107345581
####################################

[Update 656]
Discriminator

    Loss: 0.7618327140808105
    Acc: 0.8515625
Generator
    Loss: 1.4210628271102905
        Teacher-Force: 0.845165491104126
        FreeToTeacher: 2.072852611541748
        TeacherToFree: 1.345170497894287
    Acc: 0.8163802027702332
####################################

[Update 679]
Discriminator
    Loss: 0.7065021991729736
    Acc: 0.875
Generator
    Loss: 1.4650657176971436
        Teacher-Force: 0.8858318328857422
        FreeToTeacher: 2.1821610927581787
        TeacherToFree: 1.3272041082382202
    Acc: 0.8057942986488342
####################################

[Update 680]
Discriminator
    Loss: 0.6834549903869629
    Acc: 0.8515625
Generator
    Loss: 1.557714819908142
        Teacher-Force: 0.9084494113922119
        FreeToTeacher: 2.193610429763794
        TeacherToFree: 1.5710842609405518
    Acc: 0.7982161641120911
####################################

[Update 681]
Discriminator
    Loss: 0.6815794110298157
    Acc: 0.8515625
Generator
    Loss: 1.622916579246521
     

    Acc: 0.7709766030311584
####################################

[Update 704]
Discriminator
    Loss: 1.4459364414215088
    Acc: 0.5625
Generator
    Loss: 0.9614294171333313
        Teacher-Force: 0.9614294171333313
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.776940107345581
####################################

[Update 705]
Discriminator
    Loss: 1.3522634506225586
    Acc: 0.578125
Generator
    Loss: 0.9546864628791809
        Teacher-Force: 0.9546864628791809
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7742187976837158
####################################

[Update 706]
Discriminator
    Loss: 1.2764906883239746
    Acc: 0.625
Generator
    Loss: 0.9624412059783936
        Teacher-Force: 0.9624412059783936
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.7752864956855774
####################################

[Update 707]
Discriminator
    Loss: 0.956430971622467
    Acc: 0.8203125
Generator
    Loss: 1.2326059341430664
  

    Acc: 0.78125
Generator
    Loss: 1.1510990858078003
        Teacher-Force: 0.816085934638977
        FreeToTeacher: 1.2698453664779663
        TeacherToFree: 1.367365837097168
    Acc: 0.8297265768051147
####################################

[Update 731]
Discriminator
    Loss: 1.1633576154708862
    Acc: 0.7109375
Generator
    Loss: 0.8325182795524597
        Teacher-Force: 0.8325182795524597
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8227344155311584
####################################

[Update 732]
Discriminator
    Loss: 1.0398449897766113
    Acc: 0.75
Generator
    Loss: 0.8481329679489136
        Teacher-Force: 0.8481329679489136
        FreeToTeacher: nan
        TeacherToFree: nan
    Acc: 0.8206120133399963
####################################

[Update 733]
Discriminator
    Loss: 1.0333640575408936
    Acc: 0.6796875
Generator
    Loss: 0.781782865524292
        Teacher-Force: 0.781782865524292
        FreeToTeacher: nan
        TeacherToFree: nan

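The Generator numbers in the log above follow a consistent pattern that you can check by hand. When FreeToTeacher and TeacherToFree are present, the Generator Loss equals the mean of the three terms. For example, at Update 107, (1.5399 + 1.1509 + 1.6450) / 3 ≈ 1.4452. When those two terms are nan, the Generator Loss equals the Teacher-Force loss. In every visible entry where the discriminator accuracy is 0.75 or lower, the adversarial terms are nan, and wherever they are present the accuracy is above 0.75.

The sketch below reproduces that bookkeeping with plain floats. It is a reconstruction from the printed values, not code from this notebook. The helper name `combine_generator_loss` and the strict `> 0.75` gate are assumptions.

```python
import math

# Assumed gate, inferred from the log (not from the author's code): the
# adversarial terms only appear when the discriminator accuracy is above 0.75.
DISC_ACC_THRESHOLD = 0.75


def combine_generator_loss(teacher_force, free_to_teacher, teacher_to_free,
                           disc_acc, threshold=DISC_ACC_THRESHOLD):
    """Hypothetical helper mirroring the logged Generator numbers.

    Returns (total, free_to_teacher, teacher_to_free). If the discriminator is
    not accurate enough, the adversarial terms are reported as NaN and the
    total falls back to the teacher-forcing loss alone.
    """
    if disc_acc > threshold:
        total = (teacher_force + free_to_teacher + teacher_to_free) / 3
        return total, free_to_teacher, teacher_to_free
    return teacher_force, math.nan, math.nan


# Values copied from [Update 107] and [Update 133] above; the 0.0 adversarial
# inputs for Update 133 are ignored because its accuracy is below the gate.
total_107, _, _ = combine_generator_loss(1.539881944656372, 1.150864839553833,
                                         1.6449673175811768, disc_acc=0.765625)
total_133, ftt_133, ttf_133 = combine_generator_loss(1.3429291248321533, 0.0, 0.0,
                                                     disc_acc=0.7109375)
print(round(total_107, 4))  # 1.4452
print(total_133, ftt_133)   # 1.3429291248321533 nan
```

The log only reveals the reported values. It cannot show whether the real loop weights gradients differently or gates on a running average instead of the per-update accuracy.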
In [None]:
def train(model, optimizer, criterion, n_epochs, device, train_loader, lr_scheduler=None, prehistory=None, checkpoint_dir='.', checkpoint_basename=None, save_every=None):
    """
    Params
    =======
    model (torch module): model to train.
    optimizer (torch optimizer): optimizer to use for training the model.
    criterion (torch loss module): loss function for computing loss.
    n_epochs (int): number of epochs to train.
    device (str): device on which the training happens.
    train_loader (iterable): training data loader yielding (inputs, labels) batches.
    lr_scheduler (lr_scheduler, optional): learning rate scheduler.
    prehistory (dict, optional): history returned by a previous run; pass it to
                                 resume training and continue the plots from there.
    checkpoint_dir (str, optional): directory in which checkpoint files are saved.
                                    (default: '.')
    checkpoint_basename (str, optional): base name for checkpoint files. If not given,
                                         no checkpoints are saved. (default: None)
    save_every (int, optional): how often (in epochs) to save mid-training checkpoints,
                                separate from the best checkpoint. The final epoch is
                                always saved when a base name is given.
                                (default: None)
    """
    if prehistory:
        history         = prehistory
        best_loss_idx   = history['loss'].index(min(history['loss']))
        best_loss       = history['loss'][best_loss_idx]
        best_loss_epoch = history['epoch'][best_loss_idx]
        best_loss_acc_1 = history['acc_1'][best_loss_idx]
        best_loss_acc_5 = history['acc_5'][best_loss_idx]
        loss_avg = loss = history['loss'][-1]
        acc_1           = history['acc_1'][-1]
        acc_5           = history['acc_5'][-1]
        i_epoch         = history['epoch'][-1] + 1
    else:
        history         = {'epoch':[], 'loss': [], 'acc_1': [], 'acc_5': []}
        best_loss       = float('inf')
        best_loss_epoch = -1
        best_loss_acc_1 = 0.
        best_loss_acc_5 = 0.
        loss_avg = loss = float('inf')
        acc_1           = 0.
        acc_5           = 0.
        i_epoch         = 0
    
    if save_every is None:
        save_every = n_epochs
    _save_every_check = save_every - 1
    if checkpoint_basename:
        checkpoint_normal_filepath_template = os.path.join(checkpoint_dir, checkpoint_basename + "(epoch={}).pth")
        checkpoint_best_filepath   = os.path.join(checkpoint_dir, checkpoint_basename + "_best.pth")
    
    def update_progress_stats(update_epoch, update_train):
        if update_epoch:
            if 'momentum' in optimizer.param_groups[0]:
                momentum = optimizer.param_groups[0]['momentum']
            elif 'betas' in optimizer.param_groups[0]:
                momentum = optimizer.param_groups[0]['betas'][0]
            else:
                momentum = None
            # '{:.4f}' raises TypeError on None, so optimizers without momentum show 'n/a'
            momentum_str = "{:.4f}".format(momentum) if momentum is not None else "n/a"
            epoch_iterator.set_postfix_str("current_epoch={}, "
                                           "loss={:.4f}, acc=(top1={:.4f}, top5={:.4f}), "
                                           "best_loss_stats(epoch={}, loss={:.4f}, acc(top1={:.4f}, top5={:.4f})), "
                                           "lr={:.4e}, momentum={}"
                                           .format(epoch,
                                                   loss_avg, acc_1, acc_5,
                                                   best_loss_epoch, best_loss, best_loss_acc_1, best_loss_acc_5,
                                                   Decimal(optimizer.param_groups[0]['lr']),
                                                   momentum_str),
                                           refresh=True)
        if update_train:
            train_progress_bar.set_postfix_str("loss={:>7.4f}, acc(top1={:.4f}, top5={:.4f})".format(loss.item(), b_acc_1, b_acc_5), refresh=True)
    
    tracking_dict = {'history':         history,
                     'hyperparameters': hyperparameters,
                     'model_dict':      model.state_dict(),
                     'optimizer_dict':  optimizer.state_dict(),
                     'lr_dict':         lr_scheduler.state_dict() if lr_scheduler else None,
                     'tokens':          ud_dataset.tokens}
    
    if n_epochs < 1:
        return history
    
    model.to(device)
    
    plot_on = False
    
    if lr_scheduler:
        last_epoch = int(lr_scheduler.last_epoch)
    
    try:
        epoch_iterator = tqdm.tqdm(iterable=range(i_epoch, i_epoch + n_epochs), desc="Train Epochs")
        train_progress_bar = tqdm.tqdm(total=len(train_loader), desc="Train Iterations")
        model.train()
        for epoch in epoch_iterator:
            if lr_scheduler:
                last_epoch += 1
            
            train_progress_bar.n = 0
            train_progress_bar.last_print_n = 0
            train_progress_bar.start_t = time()
            train_progress_bar.last_print_t = time()
            train_progress_bar.refresh()
            
            update_progress_stats(True, False)

            hidden_states = model.init_hidden(train_loader.batch_size)

            running_loss = 0
            n_top1_corrects = 0
            n_top5_corrects = 0
            n_instances = 0
            for i, (inputs, labels) in enumerate(train_loader):
                if lr_scheduler:
                    lr_scheduler.step(last_epoch + (i / len(train_loader)))
                
                inputs = inputs.to(device)
                labels = labels.view(-1).to(device)

                outputs, hidden_states = model(inputs, hidden_states)
                outputs = outputs.view(-1, outputs.size(-1))
                
                hidden_states = ([hidden.detach() for hidden in hidden_states[0]],
                                 [cell.detach() for cell in hidden_states[1]])

                loss = criterion(outputs, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss    += loss.item() * labels.size(0)
                top5_match       = outputs.data.topk(k=5, dim=1)[1].eq(labels.unsqueeze(1))
                top1_corrects    = top5_match[:, 0].sum().item()
                top5_corrects    = top5_match.sum().item()
                b_acc_1          = top1_corrects / labels.size(0)
                b_acc_5          = top5_corrects / labels.size(0)
                n_top1_corrects += top1_corrects
                n_top5_corrects += top5_corrects
                n_instances     += labels.size(0)
                
                update_progress_stats(True, True)
                del outputs, loss, top5_match, top1_corrects, top5_corrects, b_acc_1, b_acc_5
                train_progress_bar.update(1)
                
            if lr_scheduler:
                lr_scheduler.last_epoch = last_epoch
            loss_avg = running_loss / n_instances
            acc_1    = n_top1_corrects / n_instances
            acc_5    = n_top5_corrects / n_instances
            del running_loss, n_top1_corrects, n_top5_corrects, n_instances
            
            update_progress_stats(True, False)
            
            history['epoch'].append(epoch)
            history['loss'].append(loss_avg)
            history['acc_1'].append(acc_1)
            history['acc_5'].append(acc_5)
            if loss_avg < best_loss:
                best_loss = loss_avg
                best_loss_epoch = epoch
                best_loss_acc_1 = acc_1
                best_loss_acc_5 = acc_5
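                if checkpoint_basename:
                    # optimizer.state_dict() and lr_scheduler.state_dict() build new dicts when called,
                    # so the versions captured before training never see later updates
                    # (empty optimizer state, stale lr); re-take them before saving.
                    tracking_dict['optimizer_dict'] = optimizer.state_dict()
                    tracking_dict['lr_dict'] = lr_scheduler.state_dict() if lr_scheduler else None
                    torch.save(tracking_dict, checkpoint_best_filepath)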
            
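            if checkpoint_basename and ((epoch - i_epoch) % save_every == _save_every_check or (epoch - i_epoch) == (n_epochs - 1)):
                # Re-take the optimizer/scheduler snapshots: their state_dict() calls return new dicts,
                # so the ones captured before training would be saved with stale contents.
                tracking_dict['optimizer_dict'] = optimizer.state_dict()
                tracking_dict['lr_dict'] = lr_scheduler.state_dict() if lr_scheduler else None
                torch.save(tracking_dict, checkpoint_normal_filepath_template.format(epoch))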
                
            if epoch >= 1:
                if not plot_on:
                    %matplotlib notebook
                    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(14, 12), facecolor=fig_bg_color)
                    axes[0].set_facecolor(plot_bg_color)
                    axes[0].grid(True)
                    axes[0].set_title("Training Loss", fontsize=fontsize)
                    axes[0].set_xlabel("Epoch", fontsize=fontsize)
                    axes[0].set_ylabel("Loss", fontsize=fontsize)
                    axes[0].plot([], [], color='blue', label='loss')
                    axes[0].legend()
                    axes[1].set_facecolor(plot_bg_color)
                    axes[1].grid(True)
                    axes[1].set_title("Training Accuracy", fontsize=fontsize)
                    axes[1].set_xlabel("Epoch", fontsize=fontsize)
                    axes[1].set_ylabel("Accuracy", fontsize=fontsize)
                    axes[1].plot([], [], color='blue', label='top-1 acc')
                    axes[1].plot([], [], color='orange', label='top-5 acc')
                    axes[1].legend()
                    fig.canvas.draw()
                    plot_on = True
                    
                axes[0].lines[0].set_xdata(history['epoch'])
                axes[0].lines[0].set_ydata(history['loss'])
                axes[1].lines[0].set_xdata(history['epoch'])
                axes[1].lines[0].set_ydata(history['acc_1'])
                axes[1].lines[1].set_xdata(history['epoch'])
                axes[1].lines[1].set_ydata(history['acc_5'])

                axes[0].set_xlim(-0.05 * epoch, 1.05 * epoch)
                max_min_diff = max(history['loss']) - min(history['loss'])
                if max_min_diff > 0:
                    axes[0].set_ylim(min(history['loss']) - 0.05 * max_min_diff, max(history['loss']) + 0.05 * max_min_diff)
                axes[1].set_xlim(-0.05 * epoch, 1.05 * epoch)
                max_min_diff = max(history['acc_1'] + history['acc_5']) - min(history['acc_1'] + history['acc_5'])
                if max_min_diff > 0:
                    axes[1].set_ylim(min(history['acc_1'] + history['acc_5']) - 0.05 * max_min_diff, max(history['acc_1'] + history['acc_5']) + 0.05 * max_min_diff)

                axes[0].xaxis.set_major_locator(AutoLocator())
                axes[0].yaxis.set_major_locator(AutoLocator())
                axes[1].xaxis.set_major_locator(AutoLocator())
                axes[1].yaxis.set_major_locator(AutoLocator())

                xlim = axes[0].get_xlim()
                xticks = [tick for tick in axes[0].get_xticks() if xlim[0] <= tick <= xlim[1]]
                index = len(history['loss']) - 1 - history['loss'][::-1].index(max(history['loss']))
                if index not in xticks:
                    xticks.append(index)
                index = len(history['loss']) - 1 - history['loss'][::-1].index(min(history['loss']))
                if index not in xticks:
                    xticks.append(index)
                if epoch not in xticks:
                    xticks.append(epoch)
                axes[0].set_xticks(xticks)

                xlim = axes[1].get_xlim()
                xticks = [tick for tick in axes[1].get_xticks() if xlim[0] <= tick <= xlim[1]]
                index = len(history['acc_1']) - 1 - history['acc_1'][::-1].index(max(history['acc_1']))
                if index not in xticks:
                    xticks.append(index)
                index = len(history['acc_1']) - 1 - history['acc_1'][::-1].index(min(history['acc_1']))
                if index not in xticks:
                    xticks.append(index)
                index = len(history['acc_5']) - 1 - history['acc_5'][::-1].index(max(history['acc_5']))
                if index not in xticks:
                    xticks.append(index)
                index = len(history['acc_5']) - 1 - history['acc_5'][::-1].index(min(history['acc_5']))
                if index not in xticks:
                    xticks.append(index)
                if epoch not in xticks:
                    xticks.append(epoch)
                axes[1].set_xticks(xticks)

                ylim = axes[0].get_ylim()
                yticks = [tick for tick in axes[0].get_yticks() if ylim[0] <= tick <= ylim[1]]
                if max(history['loss']) not in yticks:
                    yticks.append(max(history['loss']))
                if best_loss not in yticks:
                    yticks.append(best_loss)
                if history['loss'][-1] not in yticks:
                    yticks.append(history['loss'][-1])
                axes[0].set_yticks(yticks)

                ylim = axes[1].get_ylim()
                yticks = [tick for tick in axes[1].get_yticks() if ylim[0] <= tick <= ylim[1]]
                if max(history['acc_1']) not in yticks:
                    yticks.append(max(history['acc_1']))
                if min(history['acc_1']) not in yticks:
                    yticks.append(min(history['acc_1']))
                if history['acc_1'][-1] not in yticks:
                    yticks.append(history['acc_1'][-1])
                if max(history['acc_5']) not in yticks:
                    yticks.append(max(history['acc_5']))
                if min(history['acc_5']) not in yticks:
                    yticks.append(min(history['acc_5']))
                if history['acc_5'][-1] not in yticks:
                    yticks.append(history['acc_5'][-1])
                axes[1].set_yticks(yticks)

                fig.canvas.draw()
                plt.pause(0.001)
                
    except KeyboardInterrupt:
        if lr_scheduler:
            lr_scheduler.step(int(last_epoch))
    %matplotlib inline
    return history
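The tick-placement code above repeats two idioms. The expression `len(h) - 1 - h[::-1].index(max(h))` finds the position of the *last* occurrence of the maximum. The y-limits are padded by 5% of the value range. Below is a minimal sketch of how these could be factored into helpers. `last_index_of`, `extreme_indices` and `padded_limits` are hypothetical names, not part of the notebook. The returned values are list positions. They coincide with epoch numbers only when `history['epoch']` starts at 0 with no gaps; otherwise, map them through `history['epoch'][i]`.

```python
def last_index_of(values, target):
    """Position of the last occurrence of `target` in `values`."""
    # Same idiom as the notebook: search the reversed list, then map back.
    return len(values) - 1 - values[::-1].index(target)


def extreme_indices(*series):
    """Positions of the last max and last min of every series, deduplicated and sorted."""
    indices = set()
    for values in series:
        indices.add(last_index_of(values, max(values)))
        indices.add(last_index_of(values, min(values)))
    return sorted(indices)


def padded_limits(*series, margin=0.05):
    """(low, high) axis limits padded by `margin` times the combined value range.

    Returns None when all values are equal, mirroring the notebook,
    which skips set_ylim when max - min == 0.
    """
    combined = [v for values in series for v in values]
    low, high = min(combined), max(combined)
    spread = high - low
    if spread <= 0:
        return None
    return low - margin * spread, high + margin * spread


# Usage with a toy history (illustrative numbers, not real training output):
loss = [2.0, 1.5, 1.5, 1.8]
acc_1 = [0.1, 0.4, 0.4, 0.3]
print(last_index_of(loss, min(loss)))  # 2
print(extreme_indices(loss))           # [0, 2]
print(padded_limits(acc_1))
```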

In [None]:
def make_behavior_sequence(internal_hiddens):
    """
    Params
    =======
    internal_hiddens (list[Tensor]): list of tensors containing internal
                                     hidden states from RNN layers across timesteps.

    Returns
    ========
    b_seq (Tensor): behavior sequence obtained from the internal hidden states.
    """
    b_seq = torch.cat(internal_hiddens, dim=-1)
    
    return b_seq

---

## Hyperparameters & Instantiation
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

Set the hyperparameters and instantiate the dataset, batch loader, model, optimizer, and loss criterion (and, optionally, a learning-rate scheduler).

In [None]:
seed                   = 0
random.seed(seed)
torch.manual_seed(seed)

batch_size             = 8
max_repeats            = 20
sequence_length        = 300

embedding_dim          = 16
hidden_dim             = 256
n_lstm_layers          = 3
dropout                = 0.3

n_epochs               = 1000

optimizer_params = {'lr': 1e-3, 'betas': (0.9, 0.999), 'eps': 1e-5, 'weight_decay': 0., 'amsgrad': True}

hyperparameters = {'seed': seed,
                   'batch_size': batch_size, 'max_repeats': max_repeats, 'sequence_length': sequence_length,
                   'embedding_dim': embedding_dim, 'hidden_dim': hidden_dim, 'n_lstm_layers': n_lstm_layers,
                   'dropout': dropout,
                   'n_epochs': n_epochs,
                   'optimizer_params': optimizer_params}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device:", device)
print()

ud_dataset = UndertaleDeltaruneDataset("./source/converted_texts", batch_size, max_repeats)
ud_loader = UDBatchLoader(ud_dataset, batch_size, sequence_length, False, True)

model = UDNet(len(ud_dataset.tokens), embedding_dim, hidden_dim, n_lstm_layers, dropout)
model.to(device)

optimizer              = optim.Adam(model.parameters(), **optimizer_params)
criterion              = nn.CrossEntropyLoss()
lr_scheduler           = None

history = {}

print()
print('Data Sequence Total Length:', ud_dataset.data_len())
print()
print(model)

---

## Training
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

### Training Session
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [None]:
history = train(model, optimizer, criterion, n_epochs, device, ud_loader, lr_scheduler,
                prehistory=history, checkpoint_dir='.', checkpoint_basename='PoC', save_every=250)

---

## Loading the Best Model
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [None]:
loaded_dict = torch.load("./PoC_best.pth", map_location=device)
history = loaded_dict['history']
hyperparameters = loaded_dict['hyperparameters']
model_dict = loaded_dict['model_dict']
optimizer_dict = loaded_dict['optimizer_dict']
lr_dict = loaded_dict['lr_dict']

model.load_state_dict(model_dict)
optimizer.load_state_dict(optimizer_dict)
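`PoC_best.pth` is rewritten whenever an epoch's loss improves on `best_loss`. Assuming `tracking_dict` holds the `history` accumulated so far, the loaded history should therefore end at the best epoch. Below is a minimal sketch for checking this from the history alone. `best_epoch_summary` is a hypothetical helper. It assumes the layout built in `train`: parallel lists under `'epoch'`, `'loss'`, `'acc_1'` and `'acc_5'`.

```python
def best_epoch_summary(history):
    """Return (epoch, loss, top-1 acc, top-5 acc) for the lowest-loss epoch.

    Ties go to the earliest epoch, matching the strict `<` comparison
    that `train` uses when updating `best_loss`.
    """
    losses = history['loss']
    i = losses.index(min(losses))  # first occurrence of the minimum
    return history['epoch'][i], losses[i], history['acc_1'][i], history['acc_5'][i]


# Toy history (illustrative numbers only):
toy_history = {'epoch': [0, 1, 2, 3],
               'loss': [2.1, 1.4, 1.4, 1.6],
               'acc_1': [0.20, 0.45, 0.47, 0.44],
               'acc_5': [0.50, 0.80, 0.81, 0.79]}
print(best_epoch_summary(toy_history))  # (1, 1.4, 0.45, 0.8)
```

For a checkpoint written at the best epoch, `best_epoch_summary(history)[0] == history['epoch'][-1]` would be expected to hold.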

In [None]:
if True: # Plot the history of the loaded model
    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(24, 12), facecolor=fig_bg_color)
    axes[0].set_facecolor(plot_bg_color)
    axes[0].grid(True)
    axes[0].set_title("Training Loss", fontsize=fontsize)
    axes[0].set_xlabel("Epoch", fontsize=fontsize)
    axes[0].set_ylabel("Loss", fontsize=fontsize)
    axes[0].plot(history['epoch'], history['loss'], color='blue', label='loss')
    axes[0].legend()
    axes[1].set_facecolor(plot_bg_color)
    axes[1].grid(True)
    axes[1].set_title("Training Accuracy", fontsize=fontsize)
    axes[1].set_xlabel("Epoch", fontsize=fontsize)
    axes[1].set_ylabel("Accuracy", fontsize=fontsize)
    axes[1].plot(history['epoch'], history['acc_1'], color='blue', label='top-1 acc')
    axes[1].plot(history['epoch'], history['acc_5'], color='orange', label='top-5 acc')
    axes[1].legend()

    xlim = axes[0].get_xlim()
    xticks = [tick for tick in axes[0].get_xticks() if xlim[0] <= tick <= xlim[1]]
    index = len(history['loss']) - 1 - history['loss'][::-1].index(max(history['loss']))
    if index not in xticks:
        xticks.append(index)
    index = len(history['loss']) - 1 - history['loss'][::-1].index(min(history['loss']))
    if index not in xticks:
        xticks.append(index)
    if history['epoch'][-1] not in xticks:
        xticks.append(history['epoch'][-1])
    axes[0].set_xticks(xticks)

    xlim = axes[1].get_xlim()
    xticks = [tick for tick in axes[1].get_xticks() if xlim[0] <= tick <= xlim[1]]
    index = len(history['acc_1']) - 1 - history['acc_1'][::-1].index(max(history['acc_1']))
    if index not in xticks:
        xticks.append(index)
    index = len(history['acc_1']) - 1 - history['acc_1'][::-1].index(min(history['acc_1']))
    if index not in xticks:
        xticks.append(index)
    index = len(history['acc_5']) - 1 - history['acc_5'][::-1].index(max(history['acc_5']))
    if index not in xticks:
        xticks.append(index)
    index = len(history['acc_5']) - 1 - history['acc_5'][::-1].index(min(history['acc_5']))
    if index not in xticks:
        xticks.append(index)
    if history['epoch'][-1] not in xticks:
        xticks.append(history['epoch'][-1])
    axes[1].set_xticks(xticks)

    ylim = axes[0].get_ylim()
    yticks = [tick for tick in axes[0].get_yticks() if ylim[0] <= tick <= ylim[1]]
    if max(history['loss']) not in yticks:
        yticks.append(max(history['loss']))
    if min(history['loss']) not in yticks:
        yticks.append(min(history['loss']))
    if history['loss'][-1] not in yticks:
        yticks.append(history['loss'][-1])
    axes[0].set_yticks(yticks)

    ylim = axes[1].get_ylim()
    yticks = [tick for tick in axes[1].get_yticks() if ylim[0] <= tick <= ylim[1]]
    if max(history['acc_1']) not in yticks:
        yticks.append(max(history['acc_1']))
    if min(history['acc_1']) not in yticks:
        yticks.append(min(history['acc_1']))
    if history['acc_1'][-1] not in yticks:
        yticks.append(history['acc_1'][-1])
    if max(history['acc_5']) not in yticks:
        yticks.append(max(history['acc_5']))
    if min(history['acc_5']) not in yticks:
        yticks.append(min(history['acc_5']))
    if history['acc_5'][-1] not in yticks:
        yticks.append(history['acc_5'][-1])
    axes[1].set_yticks(yticks);

---

## Generation
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

### Sampling Function
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [None]:
def sample(model, input, hidden_states, top_k=5, temperature=1., return_as_tensor=False):
    assert top_k > 0
    assert temperature >= 0
    with torch.no_grad():
        output, hidden_states = model(input.view(1, -1), hidden_states)
        probs = output[0].softmax(dim=1)
    
    if temperature == 0 or top_k == 1:
        sampled = probs[-1].argmax(dim=0, keepdim=True)
    else:
        top_k_probs, top_k_args = probs[-1].topk(k=top_k, dim=0)
        top_idx = torch.multinomial(top_k_probs.pow(1 / temperature), 1)
        sampled = top_k_args.gather(dim=0, index=top_idx)
    
    if return_as_tensor:
        return sampled, hidden_states
    return sampled.item(), hidden_states
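`sample` restricts the choice to the `top_k` most likely tokens. It applies temperature by raising those probabilities to `1 / temperature` before calling `torch.multinomial`, which accepts unnormalized non-negative weights. Because softmax(z)_i^(1/T) is proportional to exp(z_i / T), this matches the usual "divide the logits by T" formulation. Below is a dependency-free sketch of the same procedure. `top_k_sample` is a hypothetical name, and `random.choices` stands in for `torch.multinomial`.

```python
import math
import random


def softmax(logits, temperature=1.0):
    """Softmax of logits / temperature, shifted by the max for numerical stability."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def top_k_sample(probs, top_k=5, temperature=1.0, rng=random):
    """Return an index drawn from `probs` with top-k filtering and temperature."""
    assert top_k > 0 and temperature >= 0
    if temperature == 0 or top_k == 1:
        # Greedy choice, like the argmax branch of `sample`.
        return max(range(len(probs)), key=probs.__getitem__)
    # Indices of the k most probable tokens, most probable first.
    top = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:top_k]
    # Sharpen (T < 1) or flatten (T > 1); random.choices normalizes the weights,
    # just as torch.multinomial accepts unnormalized non-negative weights.
    weights = [probs[i] ** (1.0 / temperature) for i in top]
    return rng.choices(top, weights=weights, k=1)[0]


logits = [2.0, 1.0, 0.5, -1.0]
T = 0.7

# p ** (1/T), renormalized, equals softmax(logits / T).
powered = [p ** (1 / T) for p in softmax(logits)]
total = sum(powered)
assert all(abs(w / total - q) < 1e-9 for w, q in zip(powered, softmax(logits, T)))

rng = random.Random(0)
print(top_k_sample(softmax(logits), top_k=2, temperature=T, rng=rng))  # always 0 or 1
```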

### Generation Function
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [None]:
def generate(model, length, starting_tokens=None, starting_len=100, top_k=5, temperature=1.):
    # `starting_len` is used to take a prompt from ud_dataset when `starting_tokens` is not given.
    if starting_tokens:
        tokens = list(starting_tokens)  # copy, so append() below does not mutate the caller's list
    else:
        tokens = ud_dataset[0][0][:starting_len].tolist()
    note = torch.tensor(tokens, dtype=torch.long, device=device)
    
    hidden_states = model.init_hidden(1)
    model.eval()
    
    with torch.no_grad():
        iterator = tqdm.tqdm(range(length), desc='Generating Tokens:', unit='token')
        for _ in iterator:
            note, hidden_states = sample(model, note, hidden_states, top_k, temperature, True)
            tokens.append(note.item())
    
    model.train()
    
    return tokens

def predict(model, ud_dataset, device=device):
    inputs, targets = ud_dataset[0]
    original = inputs[:1].tolist() + targets.tolist()
    
    model.to(device)
    
    seq_len = 2000
    
    predicted = [inputs[:1].item()]
    
    hidden_states = model.init_hidden(1)
    hidden_states = ([hidden.detach() for hidden in hidden_states[0]],
                     [cell.detach() for cell in hidden_states[1]])
    model.eval()
    
    n_tokens = inputs.size(0)
    progress = tqdm.tqdm(total=n_tokens, desc='Processing Dataset:', unit='token')
    with torch.no_grad():
        # Feed the sequence in windows of `seq_len` tokens, carrying the hidden state across windows.
        # min() clamps the final window, so no trailing tokens are skipped.
        for start in range(0, n_tokens, seq_len):
            end = min(start + seq_len, n_tokens)
            outputs, hidden_states = model(inputs[start:end].unsqueeze(0).to(device), hidden_states)
            predicted.extend(outputs.squeeze(0).argmax(dim=1).cpu().tolist())
            progress.update(end - start)
    progress.close()
    
    model.train()
    
    return original, predicted
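`predict` runs the whole dataset through the model in windows of `seq_len` tokens. It carries the hidden state between windows, so the result matches one long pass. Pairing two `range` objects for the start and end indices is easy to get off by one. With `range(seq_len, n + seq_len - 1, seq_len)` as the end range, the final one-token window is dropped whenever `n % seq_len == 1`. Below is a sketch comparing the two constructions. `chunk_bounds` and `zipped_bounds` are hypothetical helpers.

```python
def chunk_bounds(n, seq_len):
    """(start, end) windows covering positions 0..n-1 in steps of seq_len."""
    return [(start, min(start + seq_len, n)) for start in range(0, n, seq_len)]


def zipped_bounds(n, seq_len):
    """The paired-range construction, for comparison."""
    return list(zip(range(0, n, seq_len), range(seq_len, n + seq_len - 1, seq_len)))


print(chunk_bounds(4001, 2000))   # [(0, 2000), (2000, 4000), (4000, 4001)]
print(zipped_bounds(4001, 2000))  # [(0, 2000), (2000, 4000)]  -> position 4000 is never processed
```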

### Music Generation
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

In [None]:
def tokens_to_intlist(tokens, token_to_key_dict):
    if len(tokens) == 0:
        return []

    patterns = []
    intlist = []
    last_pattern = tuple()

    for token in tokens:
        pattern = token_to_key_dict[token]
        if pattern != tuple() and pattern[0] == "<REPEAT>":
            for _ in range(pattern[1]):
                patterns.append(last_pattern)
        else:
            patterns.append(pattern)
            last_pattern = pattern

    intlist.extend(patterns[0])
    for pattern in patterns[1:]:
        intlist.append(0)
        intlist.extend(pattern)
    
    return intlist
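`tokens_to_intlist` rests on three assumptions:

- each token maps to a pattern, which is a tuple of ints;
- consecutive patterns are separated by a `0` in the integer stream;
- a `('<REPEAT>', k)` pattern stands for k more copies of the previous non-repeat pattern.

Below is a sketch of the inverse direction under those same assumptions. `intlist_to_patterns` and `compress_repeats` are hypothetical helpers. How `UndertaleDeltaruneDataset` actually builds its vocabulary, including how it applies `max_repeats`, is not shown in this section. Splitting long runs into several `<REPEAT>` tokens is also an assumption. It stays decodable by `tokens_to_intlist`, because `last_pattern` is not updated by repeat tokens.

```python
def intlist_to_patterns(intlist):
    """Split an integer stream on the 0 separators into pattern tuples.

    Assumes the patterns themselves never contain 0.
    """
    if not intlist:
        return []
    patterns, current = [], []
    for value in intlist:
        if value == 0:
            patterns.append(tuple(current))
            current = []
        else:
            current.append(value)
    patterns.append(tuple(current))
    return patterns


def compress_repeats(patterns, max_repeats):
    """Replace runs of identical patterns with ('<REPEAT>', k) entries, k <= max_repeats."""
    out = []
    i = 0
    while i < len(patterns):
        out.append(patterns[i])
        j = i + 1
        while j < len(patterns) and patterns[j] == patterns[i]:
            j += 1
        extra = j - i - 1  # copies after the first occurrence
        while extra > 0:
            k = min(extra, max_repeats)
            out.append(('<REPEAT>', k))
            extra -= k
        i = j
    return out


patterns = intlist_to_patterns([60, 64, 0, 60, 64, 0, 60, 64, 0, 67])
print(patterns)                        # [(60, 64), (60, 64), (60, 64), (67,)]
print(compress_repeats(patterns, 20))  # [(60, 64), ('<REPEAT>', 2), (67,)]
```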

In [None]:
def intlist_to_text(intlist):
    # Space-separated integers; `value` avoids shadowing the built-in `int`.
    return " ".join(str(value) for value in intlist)
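`intlist_to_text` writes the stream as space-separated integers. Reading a saved file such as `poc_gen.txt` back is the reverse: split on whitespace, then convert each field. `text_to_intlist` is a hypothetical helper. `intlist_to_text` is repeated below only so that the snippet is self-contained.

```python
def intlist_to_text(intlist):
    return " ".join(str(value) for value in intlist)


def text_to_intlist(text):
    """Inverse of intlist_to_text: split on whitespace and parse integers."""
    return [int(field) for field in text.split()]


stream = [60, 64, 0, 67, 0, 0, 72]
text = intlist_to_text(stream)
print(text)  # 60 64 0 67 0 0 72
assert text_to_intlist(text) == stream

# e.g. with open("./poc_gen.txt") as f: ints = text_to_intlist(f.read())
```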

In [None]:
original_tokens, predicted_tokens = predict(model, ud_dataset)
original_intlist, predicted_intlist = tokens_to_intlist(original_tokens, ud_dataset.tokens), tokens_to_intlist(predicted_tokens, ud_dataset.tokens)
original_text, predicted_text = intlist_to_text(original_intlist), intlist_to_text(predicted_intlist)

In [None]:
filebase = "./poc"

with open(filebase + "_orig.txt", 'w') as f:
    f.write(original_text)
with open(filebase + "_pred.txt", 'w') as f:
    f.write(predicted_text)
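`predict` returns two lists that both start with the same first input token:

- `original`: the dataset tokens;
- `predicted`: the model's greedy next-token guesses under teacher forcing.

Below is a sketch for summarizing how closely the two agree. `agreement` is a hypothetical helper. It skips position 0, which is copied from the input rather than predicted. The model runs in eval mode here (no dropout), so this figure need not equal the top-1 accuracy logged during training.

```python
def agreement(original, predicted):
    """Fraction of positions (excluding the shared first token) where the two lists match."""
    pairs = list(zip(original[1:], predicted[1:]))
    if not pairs:
        return 0.0
    return sum(o == p for o, p in pairs) / len(pairs)


print(agreement([3, 1, 4, 1, 5], [3, 1, 4, 2, 5]))  # 0.75

# e.g. print(f"{agreement(original_tokens, predicted_tokens):.2%}")
```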

In [None]:
generated_tokens = generate(model, 50000, starting_tokens=[], starting_len=1, top_k=5, temperature=1.)
generated_intlist = tokens_to_intlist(generated_tokens, ud_dataset.tokens)
generated_text = intlist_to_text(generated_intlist)

print("Generated Tokens ({})".format(len(generated_tokens)))
print(generated_tokens[:100] + ['...'])
print("\nInt List ({})".format(len(generated_intlist)))
print(generated_intlist[:100] + ['...'])
print("\nText ({})".format(len(generated_text)))
print(generated_text[:500] + ' ...')
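`generate` works in two phases. It first primes the LSTM with the whole prompt in a single forward pass. It then repeatedly feeds back only the newest sampled token, together with the carried hidden state. Below is a toy sketch of that control flow. `toy_step` is a hypothetical deterministic stand-in for the model whose "hidden state" is just the last token seen, so the output can be checked by hand. `generate_tokens` is likewise hypothetical.

```python
def toy_step(chunk, state):
    """Stand-in for `model(...)`: consume a chunk of tokens, return (next_token, state).

    The 'state' is simply the last token seen; the next token is (last + 1) % 4.
    """
    for token in chunk:
        state = token
    return (state + 1) % 4, state


def generate_tokens(step, length, prompt):
    tokens = list(prompt)  # copy, so the caller's list is not mutated
    state = None
    chunk = list(prompt)   # first pass: the whole prompt at once
    for _ in range(length):
        next_token, state = step(chunk, state)
        tokens.append(next_token)
        chunk = [next_token]  # afterwards: only the newest token
    return tokens


prompt = [0, 1]
print(generate_tokens(toy_step, 5, prompt))  # [0, 1, 2, 3, 0, 1, 2]
print(prompt)                                # [0, 1]
```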

In [None]:
filepath = "./poc_gen.txt"

with open(filepath, 'w') as f:
    f.write(generated_text)

---

## Final Summary, Notes, and Thoughts
[(go to top)](#Undertale-&-Deltarune-Soundtrack-Generator)

---