# Train a model to generate structure of a piece (MELONS-inspired)

1. Read structure dataset from POP909_structure
2. Pre-process str into graph format
3. Setup transformer model
4. Train-val split, data loader
5. Evaluate model predictions

In [1]:
import torch
import os
import re

In [2]:
structure_path = "POP909_structure"

In [3]:
labels = []
for folder in os.listdir(structure_path):
    try:
        f = open(f"{structure_path}/{folder}/human_label1.txt", "r")
        # print(f.read())
        labels.append(f.read())
    except:
        continue

In [4]:
labels[:5]

['i8A8A8B8C4C4b4b4x2A8B8C4C4C4C4X1o1\n',
 'i4A4A4B4B4C4C4C4D4x4B4B4C4C4C4D4X3\n',
 'i4A8A8B4C4b5x1b5A8B4x1C4C4\n',
 'i18A8A8A9x14B8A10B8A10o4\n',
 'X4b4A8B12b4A8B12b4b4B12o4\n']

In [5]:
def split_string(s):
    # This regex pattern matches a letter followed by one or more digits
    pattern = re.compile(r'[a-zA-Z]\d+')
    # Find all matches in the string
    matches = pattern.findall(s)
    return matches

In [6]:
split_string(labels[3])

['i18', 'A8', 'A8', 'A9', 'x14', 'B8', 'A10', 'B8', 'A10', 'o4']

In [7]:
# def split_into_pairs(input_string):
#     # Initialize an empty list to hold the pairs
#     pairs = []
    
#     # Iterate over the string in steps of 2
#     for i in range(0, len(input_string), 2):
#         next_pair = input_string[i:i+2]
#         if next_pair != "\n":
#             # Append the substring of the next two characters to the list
#             pairs.append(next_pair)
    
#     return pairs

In [10]:
# split_into_pairs("i4A4A4B9b4A4B9b4B9X5o2")

In [11]:
all_phrases = []
for label in labels:
    all_phrases.append(split_string(label))

In [12]:
all_phrases[:3]

[['i8',
  'A8',
  'A8',
  'B8',
  'C4',
  'C4',
  'b4',
  'b4',
  'x2',
  'A8',
  'B8',
  'C4',
  'C4',
  'C4',
  'C4',
  'X1',
  'o1'],
 ['i4',
  'A4',
  'A4',
  'B4',
  'B4',
  'C4',
  'C4',
  'C4',
  'D4',
  'x4',
  'B4',
  'B4',
  'C4',
  'C4',
  'C4',
  'D4',
  'X3'],
 ['i4',
  'A8',
  'A8',
  'B4',
  'C4',
  'b5',
  'x1',
  'b5',
  'A8',
  'B4',
  'x1',
  'C4',
  'C4']]

In [13]:
def get_phrase_edge_type(prev_phrase, curr_phrase, prev_phrase_idx, curr_phrase_idx):
    """
    Edge types:
    0: Intro to Any
    1: Any to Outro
    2: Repeated phrase
    3: Melody to Melody
    4: Melody to Non-Melody
    5: Non-Melody to Melody
    6: Non-Melody to Non-Melody
    """
    # print(prev_phrase_idx, curr_phrase_idx)
    
    prev_phrase_type = prev_phrase[0]
    curr_phrase_type = curr_phrase[0]
    
    if prev_phrase == curr_phrase:
            return 2
    
    if prev_phrase_idx + 1 == curr_phrase_idx:
        # print(prev_phrase_type)
    
        if prev_phrase_type == "i":
            return 0
        elif curr_phrase_type == "o":
            return 1
        elif prev_phrase_type.isupper() & curr_phrase_type.isupper():
            return 3
        elif prev_phrase_type.isupper() & curr_phrase_type.islower():
            return 4
        elif prev_phrase_type.islower() & curr_phrase_type.isupper():
            return 5
        elif prev_phrase_type.islower() & curr_phrase_type.islower():
            return 6
        else:
            return None

In [14]:
get_phrase_edge_type("B4", "B9", 0, 3)

In [15]:
all_phrases[0]

['i8',
 'A8',
 'A8',
 'B8',
 'C4',
 'C4',
 'b4',
 'b4',
 'x2',
 'A8',
 'B8',
 'C4',
 'C4',
 'C4',
 'C4',
 'X1',
 'o1']

In [None]:
# TODO: Get max size of phrases

In [16]:
def create_sequence(phrases):
    # Create sequence of edges from phrase, where each item is a tuple (i, j, edge type, num bars in i, num bars in j)
    seq = []
    max_phrase_len = 0
    for i, phrase_from in enumerate(phrases):
        for j, phrase_to in enumerate(phrases[i+1:]):
            phrase_to_idx = j+i+1
            edge_type = get_phrase_edge_type(phrase_from, phrase_to, i, phrase_to_idx)
            if edge_type is not None:
                phrase_from_len = int(phrase_from[1])
                phrase_to_len = int(phrase_to[1])
                max_phrase_len = max(max_phrase_len, max(phrase_from_len, phrase_to_len))
                seq.append((i, phrase_to_idx, edge_type, phrase_from_len, phrase_to_len))
    
    # Append END token
    seq.append((len(phrases)-1, len(phrases), 7, 0, 0))
    return seq, max_phrase_len

In [19]:
seqs = []
max_phrase_len = 0
max_num_nodes = 0
for phrases in all_phrases:
    # print(phrases)
    num_nodes = len(phrases)
    seq, max_phrase_len_indiv = create_sequence(phrases)
    seqs.append(seq)
    max_phrase_len = max(max_phrase_len, max_phrase_len_indiv)
    max_num_nodes = max(max_num_nodes, num_nodes)

In [20]:
seqs[0]

[(0, 1, 0, 8, 8),
 (1, 2, 2, 8, 8),
 (1, 9, 2, 8, 8),
 (2, 3, 3, 8, 8),
 (2, 9, 2, 8, 8),
 (3, 4, 3, 8, 4),
 (3, 10, 2, 8, 8),
 (4, 5, 2, 4, 4),
 (4, 11, 2, 4, 4),
 (4, 12, 2, 4, 4),
 (4, 13, 2, 4, 4),
 (4, 14, 2, 4, 4),
 (5, 6, 4, 4, 4),
 (5, 11, 2, 4, 4),
 (5, 12, 2, 4, 4),
 (5, 13, 2, 4, 4),
 (5, 14, 2, 4, 4),
 (6, 7, 2, 4, 4),
 (7, 8, 6, 4, 2),
 (8, 9, 5, 2, 8),
 (9, 10, 3, 8, 8),
 (10, 11, 3, 8, 4),
 (11, 12, 2, 4, 4),
 (11, 13, 2, 4, 4),
 (11, 14, 2, 4, 4),
 (12, 13, 2, 4, 4),
 (12, 14, 2, 4, 4),
 (13, 14, 2, 4, 4),
 (14, 15, 3, 4, 1),
 (15, 16, 1, 1, 1),
 (16, 17, 7, 0, 0)]

In [21]:
max_phrase_len

9

In [22]:
max_num_nodes

39

In [23]:
seq

[(0, 1, 0, 5, 4),
 (1, 2, 3, 4, 5),
 (1, 6, 2, 4, 4),
 (2, 3, 3, 5, 3),
 (2, 7, 2, 5, 5),
 (3, 4, 3, 3, 9),
 (3, 8, 2, 3, 3),
 (4, 5, 4, 9, 3),
 (4, 9, 2, 9, 9),
 (4, 11, 2, 9, 9),
 (5, 6, 5, 3, 4),
 (6, 7, 3, 4, 5),
 (7, 8, 3, 5, 3),
 (8, 9, 3, 3, 9),
 (9, 10, 3, 9, 3),
 (9, 11, 2, 9, 9),
 (10, 11, 3, 3, 9),
 (11, 12, 3, 9, 6),
 (12, 13, 1, 6, 1),
 (13, 14, 7, 0, 0)]

In [24]:
def create_input_output_pairs(seq):
    # Create input and output pairs
    input_seqs = []
    output_seqs = []
    for idx in range(1, len(seq)):
        input_seqs.append(seq[:idx])
        output_seqs.append(seq[idx:])
    return input_seqs, output_seqs

In [25]:
inputs = []
outputs = []

for seq in seqs:
    input_seqs, output_seqs = create_input_output_pairs(seq)
    inputs.append(input_seqs)
    outputs.append(output_seqs)

In [26]:
inputs[0][3]

[(0, 1, 0, 8, 8), (1, 2, 2, 8, 8), (1, 9, 2, 8, 8), (2, 3, 3, 8, 8)]

In [27]:
outputs[0][3]

[(2, 9, 2, 8, 8),
 (3, 4, 3, 8, 4),
 (3, 10, 2, 8, 8),
 (4, 5, 2, 4, 4),
 (4, 11, 2, 4, 4),
 (4, 12, 2, 4, 4),
 (4, 13, 2, 4, 4),
 (4, 14, 2, 4, 4),
 (5, 6, 4, 4, 4),
 (5, 11, 2, 4, 4),
 (5, 12, 2, 4, 4),
 (5, 13, 2, 4, 4),
 (5, 14, 2, 4, 4),
 (6, 7, 2, 4, 4),
 (7, 8, 6, 4, 2),
 (8, 9, 5, 2, 8),
 (9, 10, 3, 8, 8),
 (10, 11, 3, 8, 4),
 (11, 12, 2, 4, 4),
 (11, 13, 2, 4, 4),
 (11, 14, 2, 4, 4),
 (12, 13, 2, 4, 4),
 (12, 14, 2, 4, 4),
 (13, 14, 2, 4, 4),
 (14, 15, 3, 4, 1),
 (15, 16, 1, 1, 1),
 (16, 17, 7, 0, 0)]

In [28]:
inputs_flat = [seq for seqs in inputs for seq in seqs]
outputs_flat = [seq for seqs in outputs for seq in seqs]

In [29]:
len(inputs_flat)

19982

In [30]:
max(len(x) for x in inputs_flat)

150

In [31]:
len(outputs_flat)

19982

In [32]:
max(len(x) for x in outputs_flat)

150

In [34]:
len(set([token for tokens in inputs_flat for token in tokens]))

2932

In [35]:
len(set([token for tokens in outputs_flat for token in tokens]))

2897

## Dataloader

In [36]:
import torch
from torch.nn.utils.rnn import pad_sequence

def sequences_to_tensor(sequences, padding_value=0):
    """
    Convert a list of sequences of different lengths to a padded tensor.

    Args:
        sequences (list of list of tuples): List of sequences where each sequence is a list of tuples.
        padding_value (int, optional): Value to use for padding. Defaults to 0.

    Returns:
        torch.Tensor: Padded tensor of shape (batch_size, max_length, tuple_length)
    """
    # Convert each sequence to a tensor
    tensor_sequences = [torch.tensor(seq) for seq in sequences]

    # Pad sequences to the length of the longest sequence
    padded_sequences = pad_sequence(tensor_sequences, batch_first=True, padding_value=padding_value)

    return padded_sequences


In [37]:
padded_input = sequences_to_tensor(inputs_flat, padding_value=0)
padded_output = sequences_to_tensor(outputs_flat, padding_value=0)

print("Padded input shape:", padded_input.shape)
print("Padded output shape:", padded_output.shape)

Padded input shape: torch.Size([19982, 150, 5])
Padded output shape: torch.Size([19982, 150, 5])


In [61]:
seqs[0]

[(0, 1, 0, 8, 8),
 (1, 2, 2, 8, 8),
 (1, 9, 2, 8, 8),
 (2, 3, 3, 8, 8),
 (2, 9, 2, 8, 8),
 (3, 4, 3, 8, 4),
 (3, 10, 2, 8, 8),
 (4, 5, 2, 4, 4),
 (4, 11, 2, 4, 4),
 (4, 12, 2, 4, 4),
 (4, 13, 2, 4, 4),
 (4, 14, 2, 4, 4),
 (5, 6, 4, 4, 4),
 (5, 11, 2, 4, 4),
 (5, 12, 2, 4, 4),
 (5, 13, 2, 4, 4),
 (5, 14, 2, 4, 4),
 (6, 7, 2, 4, 4),
 (7, 8, 6, 4, 2),
 (8, 9, 5, 2, 8),
 (9, 10, 3, 8, 8),
 (10, 11, 3, 8, 4),
 (11, 12, 2, 4, 4),
 (11, 13, 2, 4, 4),
 (11, 14, 2, 4, 4),
 (12, 13, 2, 4, 4),
 (12, 14, 2, 4, 4),
 (13, 14, 2, 4, 4),
 (14, 15, 3, 4, 1),
 (15, 16, 1, 1, 1),
 (16, 17, 7, 0, 0)]

In [62]:
len(seqs)

909

In [66]:
padded_seq = sequences_to_tensor(seqs, padding_value=0)
print("Padded seq shape:", padded_seq.shape)

Padded seq shape: torch.Size([909, 151, 5])


In [98]:
from torch.utils.data import Dataset, DataLoader, random_split

In [96]:
# Train-test split
test_ratio = 0.1

num_test = round(len(seqs) * test_ratio)
train_split, test_split = random_split(padded_seq, [len(seqs)-num_test, num_test])
print(f"Split data into Train and Test sets of size {len(train_split)} and {len(test_split)} respectively.")

Split data into Train and Test sets of size 818 and 91 respectively.


In [99]:
# Define the custom dataset
class TupleSequenceDataset(Dataset):
    def __init__(self, input_sequences, output_sequences):
        self.input_sequences = input_sequences
        self.output_sequences = output_sequences

    def __len__(self):
        return len(self.input_sequences)

    def __getitem__(self, idx):
        return [self.input_sequences[idx], self.output_sequences[idx]]

# Parameters
batch_size = 32
shuffle = True

# Create the dataset
dataset_train = TupleSequenceDataset(train_split, train_split)
dataset_test = TupleSequenceDataset(test_split, test_split)

# Create the DataLoader
dataloader_train = DataLoader(dataset_train, batch_size=batch_size, shuffle=shuffle)
dataloader_test = DataLoader(dataset_test, batch_size=batch_size, shuffle=shuffle)


In [100]:
n_token = [max_num_nodes, max_num_nodes, 7, max_phrase_len, max_phrase_len]

In [101]:
n_token

[39, 39, 7, 9, 9]

## Autoregression transformer

In [112]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim

import math
from tqdm import tqdm

In [250]:
# Ref: https://github.com/YatingMusic/compound-word-transformer/blob/main/workspace/uncond/cp-linear/main-cp.py
# https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1
# https://gist.github.com/danimelchor/bcad4d7f79b98464c4d4481d62d27622

class Embeddings(nn.Module):
    """
    Get embeddings for edge tokens
    """
    def __init__(self, n_token, d_model):
        super(Embeddings, self).__init__()
        # print(n_token)
        self.lut = nn.Embedding(n_token+1, d_model)
        self.d_model = d_model

    def forward(self, x):
        # print(n_token)
        # print(x.shape)
        # print(self.d_model)
        # print(self.lut(x))
        return self.lut(x) * math.sqrt(self.d_model)
    
    
class PositionalEncoding(nn.Module):
    """
    Get positional encodings
    """
    def __init__(self, d_model, dropout=0.1, max_len=20000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)
        


class AutoregressiveTransformer(nn.Module):
    def __init__(self, n_token):
        super(AutoregressiveTransformer, self).__init__()
        
        # --- params config --- #
        self.n_token = n_token   
        self.d_model = D_MODEL 
        self.d_feedforward = D_FEEDFW
        self.n_layer = N_LAYER #
        self.dropout = 0.1
        self.n_head = N_HEAD
        self.d_head = D_MODEL // N_HEAD
        self.d_inner = 1024
        self.loss_func = nn.CrossEntropyLoss(reduction='none')
        self.emb_sizes = [128, 128, 12, 16, 16]
        
        
        # --- modules config --- #
        # embeddings
        print('>>>>>:', self.n_token)
        self.emb_i = Embeddings(self.n_token[0], self.emb_sizes[0])
        self.emb_j = Embeddings(self.n_token[1], self.emb_sizes[1])
        self.emb_edge_type = Embeddings(self.n_token[2], self.emb_sizes[2])
        self.emb_i_size = Embeddings(self.n_token[3], self.emb_sizes[3])
        self.emb_j_size = Embeddings(self.n_token[4], self.emb_sizes[4])
        self.pos_emb = PositionalEncoding(self.d_model, self.dropout)

        # linear 
        self.in_linear = nn.Linear(np.sum(self.emb_sizes), self.d_model)
        
        # encoder
        self.transformer = nn.Transformer(
            d_model=self.d_model,
            nhead=self.n_head,
            num_encoder_layers=self.n_layer,
            num_decoder_layers=self.n_layer,
            dim_feedforward=self.d_feedforward,
            dropout=self.dropout,
        )

        # individual output
        self.proj_i    = nn.Linear(self.d_model, self.n_token[0])        
        self.proj_j    = nn.Linear(self.d_model, self.n_token[1])
        self.proj_edge_type  = nn.Linear(self.d_model, self.n_token[2])
        self.proj_i_size     = nn.Linear(self.d_model, self.n_token[3])
        self.proj_j_size    = nn.Linear(self.d_model, self.n_token[4])
        
    def compute_loss(self, predict, target, loss_mask):
        loss = self.loss_func(predict, target)
        loss = loss * loss_mask
        loss = torch.sum(loss) / torch.sum(loss_mask)
        return loss

#     def train_step(self, x, target, loss_mask):
#         h, y_i, y_j, y_edge_type, y_i_size, y_j_size = self.forward(x)
         
#         # reshape (b, s, f) -> (b, f, s)
#         y_i = y_i[:, ...].permute(0, 2, 1)
#         y_j = y_j[:, ...].permute(0, 2, 1)
#         y_edge_type = y_edge_type[:, ...].permute(0, 2, 1)
#         y_i_size = y_i_size[:, ...].permute(0, 2, 1)
#         y_j_size = y_j_size[:, ...].permute(0, 2, 1)
        
#         # loss
#         loss_i = self.compute_loss(
#                 y_i, target[..., 0], loss_mask)
#         loss_j = self.compute_loss(
#                 y_j, target[..., 1], loss_mask)
#         loss_edge_type = self.compute_loss(
#                 y_edge_type, target[..., 2], loss_mask)
#         loss_i_size = self.compute_loss(
#                 y_i_size,  target[..., 3], loss_mask)
#         loss_j_size = self.compute_loss(
#                 y_j_size, target[..., 4], loss_mask)

#         return loss_i, loss_j, loss_edge_type, loss_i_size, loss_j_size
    
    
    def forward(self, src, tgt, tgt_mask=None, src_pad_mask=None, tgt_pad_mask=None):
        '''
        linear transformer: b x s x f
        x.shape=(bs, nf)
        '''
        # Src size must be (batch_size, src sequence length)
        # Tgt size must be (batch_size, tgt sequence length)

        # Embedding + positional encoding - Out size = (batch_size, sequence length, dim_model)
        # src = self.embedding(src) * math.sqrt(self.dim_model)
        # tgt = self.embedding(tgt) * math.sqrt(self.dim_model)
        # src = self.positional_encoder(src)
        # tgt = self.positional_encoder(tgt)
        
        # # We could use the parameter batch_first=True, but our KDL version doesn't support it yet, so we permute
        # # to obtain size (sequence length, batch_size, dim_model),
        # src = src.permute(1,0,2)
        # tgt = tgt.permute(1,0,2)
    
        # src embeddings
        emb_i_src =    self.emb_i(src[..., 0])
        emb_j_src =    self.emb_j(src[..., 1])
        emb_edge_type_src =  self.emb_edge_type(src[..., 2])
        emb_i_size_src =     self.emb_i_size(src[..., 3])
        emb_j_size_src =    self.emb_j_size(src[..., 4])

        embs_src = torch.cat(
            [
                emb_i_src,
                emb_j_src,
                emb_edge_type_src,
                emb_i_size_src,
                emb_j_size_src,
            ], dim=-1)

        emb_linear_src = self.in_linear(embs_src)
        pos_emb_src = self.pos_emb(emb_linear_src)
        
        
        # tgt embeddings
        emb_i_tgt =    self.emb_i(tgt[..., 0])
        emb_j_tgt =    self.emb_j(tgt[..., 1])
        emb_edge_type_tgt =  self.emb_edge_type(tgt[..., 2])
        emb_i_size_tgt =     self.emb_i_size(tgt[..., 3])
        emb_j_size_tgt =    self.emb_j_size(tgt[..., 4])

        embs_tgt = torch.cat(
            [
                emb_i_tgt,
                emb_j_tgt,
                emb_edge_type_tgt,
                emb_i_size_tgt,
                emb_j_size_tgt,
            ], dim=-1)

        emb_linear_tgt = self.in_linear(embs_tgt)
        pos_emb_tgt = self.pos_emb(emb_linear_tgt)
        
        # target embeddings
    
        # transformer
        # Transformer blocks - Out size = (sequence length, batch_size, num_tokens)
        # print(pos_emb_src.shape)
        # print(pos_emb_tgt.shape)
        pos_emb_src = pos_emb_src.permute(1,0,2)
        pos_emb_tgt = pos_emb_tgt.permute(1,0,2)
        transformer_out = self.transformer(pos_emb_src, pos_emb_tgt, 
                                           tgt_mask=tgt_mask, 
                                           src_key_padding_mask=src_pad_mask, 
                                           tgt_key_padding_mask=tgt_pad_mask)
        # out = self.out(transformer_out)
        
        
#         if is_training:
#             # mask
#             attn_mask = TriangularCausalMask(pos_emb.size(1), device=x.device)
#             h = self.transformer_encoder(pos_emb, attn_mask) # y: b x s x d_model

#             # # project type
#             # y_type = self.proj_type(h)
#             # return h, y_type
#         else:
#             pos_emb = pos_emb.squeeze(0)
#             h, memory = self.transformer_encoder(pos_emb, memory=memory) # y: s x d_model
            
#             # # project type
#             # y_type = self.proj_type(h)
#             # return h, y_type, memory

        y_i    = self.proj_i(transformer_out)
        y_j    = self.proj_j(transformer_out)
        y_edge_type  = self.proj_edge_type(transformer_out)
        y_i_size    = self.proj_i_size(transformer_out)
        y_j_size = self.proj_j_size(transformer_out)

        return  y_i, y_j, y_edge_type, y_i_size, y_j_size
    
    def get_tgt_mask(self, size) -> torch.tensor:
        # Generates a square matrix where the each row allows one word more to be seen
        mask = torch.tril(torch.ones(size, size) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) # Convert zeros to -inf
        mask = mask.masked_fill(mask == 1, float(0.0)) # Convert ones to 0
        
        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]
        
        return mask
    
    def create_pad_mask(self, matrix: torch.tensor, pad_token: int) -> torch.tensor:
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)
        



In [None]:
# TODO: Change input size to expand to vocab size?? I.e. if vocab size is 39, then it should be a 1-hot of 39.

In [222]:
n_token[2]

7

In [246]:
def compute_loss(predict, target, loss_func, loss_mask):
        predict = predict.permute(1, 2, 0)   
        print(predict.shape)
        print(target.shape)
        loss = loss_func(predict, target)
        loss = loss * loss_mask
        loss = torch.sum(loss) / torch.sum(loss_mask)
        return loss

In [251]:
def train_loop(model, opt, loss_fn, dataloader):
    
    model.train()
    total_loss = 0
    
    for batch_X, batch_y in dataloader:
        # X, y = batch[:, 0], batch[:, 1]
        # X, y = torch.tensor(X).to(device), torch.tensor(y).to(device)
        X = batch_X.to(device)
        y = batch_y.to(device)

        # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
        y_input = y[:,:-1]
        y_expected = y[:,1:]
        
        # Get mask to mask out the next words
        sequence_length = y_input.size(1)
        tgt_mask = model.get_tgt_mask(sequence_length).to(device)

        # # Standard training except we pass in y_input and tgt_mask
        # pred = model(X, y_input, tgt_mask)

        # Permute pred to have batch size first again
        # pred = pred.permute(1, 2, 0)      
        # loss = loss_fn(pred, y_expected)
        
        
        y_i, y_j, y_edge_type, y_i_size, y_j_size = model(X, y_input, tgt_mask)

        # reshape (b, s, f) -> (b, f, s)
        y_i = y_i[:, ...].permute(0, 2, 1)
        y_j = y_j[:, ...].permute(0, 2, 1)
        y_edge_type = y_edge_type[:, ...].permute(0, 2, 1)
        y_i_size = y_i_size[:, ...].permute(0, 2, 1)
        y_j_size = y_j_size[:, ...].permute(0, 2, 1)
        
        print(y_i.shape)
        print(y_j.shape)
        print(y_expected.shape)
        print(y_expected[..., 0].shape)

        # loss
        loss_i = compute_loss(
                y_i, y_expected[..., 0], loss_fn, tgt_mask)
        loss_j = compute_loss(
                y_j, y_expected[..., 1], loss_fn, tgt_mask)
        loss_edge_type = compute_loss(
                y_edge_type, y_expected[..., 2], loss_fn, tgt_mask)
        loss_i_size = compute_loss(
                y_i_size,  y_expected[..., 3], loss_fn, tgt_mask)
        loss_j_size = compute_loss(
                y_j_size, y_expected[..., 4], loss_fn, tgt_mask)


        opt.zero_grad()
        loss.backward()
        opt.step()
        
        batch_loss = (loss_i + loss_j + loss_edge_type + loss_i_size + loss_j_size) / 5
        total_loss += batch_loss.detach().item()
        # total_loss_i += loss_i.detach().item()

        
    return total_loss / len(dataloader)

def validation_loop(model, loss_fn, dataloader):
    
    model.eval()
    total_loss = 0
    
    with torch.no_grad():
        for batch in dataloader:
            X, y = batch[:, 0], batch[:, 1]
            X, y = torch.tensor(X, dtype=torch.long, device=device), torch.tensor(y, dtype=torch.long, device=device)

            # Now we shift the tgt by one so with the <SOS> we predict the token at pos 1
            y_input = y[:,:-1]
            y_expected = y[:,1:]
            
            # Get mask to mask out the next words
            sequence_length = y_input.size(1)
            tgt_mask = model.get_tgt_mask(sequence_length).to(device)

#             # Standard training except we pass in y_input and src_mask
#             pred = model(X, y_input, tgt_mask)

#             # Permute pred to have batch size first again
#             pred = pred.permute(1, 2, 0)      
#             loss = loss_fn(pred, y_expected)
#             total_loss += loss.detach().item()
            
            
            y_i, y_j, y_edge_type, y_i_size, y_j_size = model(X, y_input, tgt_mask)

            # reshape (b, s, f) -> (b, f, s)
            y_i = y_i[:, ...].permute(0, 2, 1)
            y_j = y_j[:, ...].permute(0, 2, 1)
            y_edge_type = y_edge_type[:, ...].permute(0, 2, 1)
            y_i_size = y_i_size[:, ...].permute(0, 2, 1)
            y_j_size = y_j_size[:, ...].permute(0, 2, 1)
            

            # loss
            loss_i = compute_loss(
                    y_i, y_expected[..., 0], loss_fn, tgt_mask)
            loss_j = compute_loss(
                    y_j, y_expected[..., 1], loss_fn, tgt_mask)
            loss_edge_type = compute_loss(
                    y_edge_type, y_expected[..., 2], loss_fn, tgt_mask)
            loss_i_size = compute_loss(
                    y_i_size,  y_expected[..., 3], loss_fn, tgt_mask)
            loss_j_size = compute_loss(
                    y_j_size, y_expected[..., 4], loss_fn, tgt_mask)

            batch_loss = (loss_i + loss_j + loss_edge_type + loss_i_size + loss_j_size) / 5
            total_loss += batch_loss.detach().item()
        
    return total_loss / len(dataloader)

In [252]:
def fit(model, opt, loss_fn, train_dataloader, val_dataloader, epochs):
    
    # Used for plotting later on
    train_loss_list, validation_loss_list = [], []
    
    print("Training and validating model")
    for epoch in range(epochs):
        print("-"*25, f"Epoch {epoch + 1}","-"*25)
        
        train_loss = train_loop(model, opt, loss_fn, train_dataloader)
        train_loss_list += [train_loss]
        
        validation_loss = validation_loop(model, loss_fn, val_dataloader)
        validation_loss_list += [validation_loss]
        
        print(f"Training loss: {train_loss:.4f}")
        print(f"Validation loss: {validation_loss:.4f}")
        print()
        
    return train_loss_list, validation_loss_list


In [253]:
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

# Hyperparameters
# vocab_size = 2932  # Example vocab size
# embed_size = 6

N_LAYER = 4
N_HEAD = 4
D_MODEL = 256
D_FEEDFW = 1024

learning_rate = 1e-4
max_seq_length = 150

# Initialize the model, optimizer, and loss function
# model = AutoregressiveTransformer(vocab_size, embed_size, num_layers, num_heads, hidden_size, ff_size, dropout_rate)
model = AutoregressiveTransformer(n_token).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
loss_fn = nn.CrossEntropyLoss()

train_loss_list, validation_loss_list = fit(model, optimizer, loss_fn, dataloader_train, dataloader_test, 10)

>>>>>: [39, 39, 7, 9, 9]
Training and validating model
------------------------- Epoch 1 -------------------------
torch.Size([150, 39, 32])
torch.Size([150, 39, 32])
torch.Size([32, 150, 5])
torch.Size([32, 150])
torch.Size([39, 32, 150])
torch.Size([32, 150])


ValueError: Expected input batch_size (39) to match target batch_size (32).