# GPT on guitar tabs

In [2]:
# !pip install torch transformers wandb PyGuitarPro

In [1]:
%env TOKENIZERS_PARALLELISM false

env: TOKENIZERS_PARALLELISM=false


In [2]:
from pathlib import Path
drive_path = Path('/Users/vlad/googledrive')
if not drive_path.exists():
    drive_path = Path("/content/drive/MyDrive")

In [3]:
tabs_path = drive_path / 'PlayMusic/tabs'
paths = list(tabs_path.glob('**/*.gp[3-5]'))
print(f'Found {len(paths)} files')

Found 49125 files


In [6]:
N_TOY = 10

#### Parsing tabs
We would start with only bass tracks as they are simpler. Doing best to select bass tracks based on heuristics: 4 strings, 24 frets, instruments from 32 to 40 (unfortunately PyGuitarPro doesn't provide labels for the instruments, only numbers - so I checked the names in GuitarPro itself (rather, I use an open source version tuxguitar) and selected all instrument that have "bass" in the name.

In [7]:
from collections import defaultdict
import guitarpro
from tqdm import tqdm

N_FRETS = 24
N_STRINGS = 4  # base; 6 for standard guitar
INSTRUMENTS = range(32, 40)  # base; range(24, 31) for standard guitar

tracks_by_path = dict()

for path_num, path in enumerate(tqdm(paths)):
    try:
        song = guitarpro.parse(path)
    except guitarpro.GPException:
        print(f'   failed to parse {path}')
        continue

    tracks = []
    for track in song.tracks:
        if all([
            track.settings.tablature, 
            len(track.strings) == N_STRINGS,
            track.fretCount == N_FRETS,
            track.channel.instrument in INSTRUMENTS,
        ]):
            tracks.append(track)
    
    if tracks:
        tracks_by_path[path] = tracks

    if sum(map(len, tracks_by_path.values())) >= N_TOY:
        break

if len(tracks_by_path) == 0:
    raise Exception('No bass tracks found')

print(
    f'Found {len(tracks_by_path)}/{len(paths)} files with bass tracks, '
    f'total {sum(map(len, tracks_by_path.values()))} tracks'
)

  0%|          | 228/49125 [01:51<6:37:25,  2.05it/s] 

Found 99/49125 files with bass tracks, total 100 tracks





#### Verifying tabs by printing one in a human-readable format

In [46]:
def pprint_track(track: guitarpro.Track):
    lines = ['' for _ in range(N_STRINGS)]

    cur_time_sig = None
    for mi, measure in enumerate(track.measures):
        if measure.timeSignature != cur_time_sig:
            cur_time_sig = measure.timeSignature
            lines[0] += '  '
            lines[1] += f'{cur_time_sig.numerator} '
            lines[2] += f'{cur_time_sig.denominator.value} '
            lines[3] += '  '

        voice = measure.voices[0]  # only the first voice is usually non-empty
        for bi, beat in enumerate(voice.beats):
            out = f'{beat.duration.value}{"+" if beat.duration.isDotted else ""}:'
            note_by_string = {note.string - 1: note for note in beat.notes}
            for s in range(N_STRINGS):
                if n := note_by_string.get(s):
                    if n.type == guitarpro.NoteType.tie:
                        v = '~'
                    elif n.type == guitarpro.NoteType.dead:
                        v = 'x'
                    else:
                        v = n.value
                else:
                    v = '-'
                lines[s] += f'{out}{v:<3} '
        for s in range(N_STRINGS):
            lines[s] += '|'

    screen_width = 80
    for screen_num in range(0, len(lines[0]), screen_width):
        for line in lines:
            print(line[screen_num:screen_num+screen_width])
        print()

path = sorted(list(tracks_by_path.keys()))[0]
track = tracks_by_path[path][0]
print(path)
print(track.name)
pprint_track(track)

/Users/vlad/googledrive/PlayMusic/tabs/Badlands/Badlands - High Wire.gp4
Bass (Greg Chaisson)
  4:-   |  4:-   |4:-   |16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-  
4 4:-   |2 4:-   |4:-   |16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-  
4 4:-   |4 4:-   |4:-   |16:0   16:0   16:0   16:3   8:~   16:0   16:0   |16:0  
  4:-   |  4:-   |4:-   |16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-  

 8+:-   16:-   16:-   16:-   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:
 8+:-   16:-   16:-   16:-   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:
 8+:5   16:0   16:-   16:-   16:0   |8:5   16:5   16:3   16:5   16:3   16:0   8:
 8+:-   16:-   16:0   16:3   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:

-   |16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-   8+:-   16:-   16:- 
-   |16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-   8+:-   16:-   16:- 
-   |16:0   16:0   16:0   16:3   8:~   16:0   16:0   |16:0   8+:5   16:0   16:- 
3   |16:-   1

#### Encoding into flat string
 
Now attempting to encode multi-string tabs into a flat string, suitable for training a transformer. In an NLP world, our token would encode all notes that sound at the same time during one beat, along with the beat duration. We would separate beats with a space (like words of text), and separate measures with a "|" character (like paragraphs/sentences).

In [47]:
import math

duration_alphabet = ['1', '2', '4', '8', 'F', 'H', 'G', 'I']
strings_alphabet = ['E', 'A', 'D', 'G', 'B', 'e']

save_dir = drive_path / 'AI/datasets/gtp_gpt'
save_dir.mkdir(exist_ok=True)

def encode_track(track):
    out = []
    
    for mi, measure in enumerate(track.measures):
        beats = measure.voices[0].beats
        if not any(len(b.notes) > 0 for b in beats):
            continue
        
        for bi, beat in enumerate(beats):
            s = duration_alphabet[int(math.log2(beat.duration.value))]
            if beat.duration.isDotted:
                s += '.'
            
            if beat.notes:
                notes = []
                for n in beat.notes:
                    if n.type == guitarpro.NoteType.tie:
                        v = '~'
                    elif n.type == guitarpro.NoteType.dead:
                        v = 'x'
                    else:
                        v = n.value
                    notes.append(f'{strings_alphabet[n.string - 1]}{v}')
                s += "".join(notes)
            out.append(s)
        out.append('|')
    return out

encoded = encode_track(track)
print(path)
print(' '.join(encoded))

/Users/vlad/googledrive/PlayMusic/tabs/Badlands/Badlands - High Wire.gp4
FD0 FD0 FD0 FD3 8D~ FD0 FD0 | FD0 8.D5 FD0 FG0 FG3 FD0 | 8D5 FD5 FD3 FD5 FD3 FD0 8G3 | FD0 FD0 FD0 FD3 8D~ FD0 FD0 | FD0 8.D5 FD0 FG0 FG3 FD0 | 8D5 FD5 FD3 FD5 FD3 8D5 | 8G3 FG3 FG3 FG~ FG3 FG3 FG~ | FG3 FG3 FG~ FG3 4G3 | 8.D0 8.G3 8D0 | FD~ 8.D3 8G5 8G4 | 8.G3 FG3 8G3 8G3 | FG~ FG3 FG3 FG3 8G3 8G3 | 8.D0 8.G3 8D0 | FD~ 8D3 FD0 FD3 FD0 FG3 FG0 | 4.D3 FD~ FD3 | FD~ FD3 FD3 4D~ FD~ | FA1 FA0 FD3 FA0 F FAx FA1 FA0 | FD3 8.A0 FGx 8.G3 | 2D0 | 2D~ | 8.D0 8.G3 8D0 | FD~ FD3 FD0 FG3 FG5 FG3 FG5 FG3 | 8A2 FA~ 8D2 FD~ 8G0 | FG~ 8.D2 8A2 8A1 | 8A0 F 8G2 F 8G3 | F FA2 FA0 FD3 FA0 FD3 8D~ 8G5 8Gx | FD0 FD0 FD0 FD3 8D~ FD0 FD0 | FD0 8.D5 FD0 FG0 FG3 FD0 | 8D5 FD5 FD3 FD5 FD3 FD0 8G3 | FD0 FD0 FD0 FD3 8D~ FD0 FD0 | FD0 8.D5 FD0 FG0 FG3 FD0 | 8D5 FD5 FD3 FD5 FD3 8D5 | 8.G3 8.G3 8G~ | 2G~ | 8.D0 8.G3 8D0 | FD~ 8.D3 8G5 8G4 | 4.G3 FG~ FG3 | FG3 FG3 F FG3 FG3 FG3 8 | 8.D0 8.G3 8D0 | FD~ FD3 F FD0 FD3 FD0 FG3 FG0 | 2D3 | 2D~ | FA1 F

#### Decode into human-readable tab

Decoding back into human-readable tab to verify correctness.

In [48]:
def decode_and_print(encoded: list[str]):
    decoded_by_string = defaultdict(str)

    for measure in ' '.join(encoded).strip().split('|'):
        for beat in measure.strip().split(' '):
            if not beat:
                continue

            i = 0
            duration = str(2 ** duration_alphabet.index(beat[i]))
            i += 1
            
            if i < len(beat) and beat[i] == '·':
                duration += '·'
                i += 1
    
            for si in range(N_STRINGS):
                decoded_by_string[si] += f'{duration}:'
             
            note_by_string = {}
            cur_string, cur_note = None, ''
            while i < len(beat):
                if beat[i] in strings_alphabet:
                    if cur_string is not None:
                        note_by_string[cur_string] = cur_note
                    cur_string = strings_alphabet.index(beat[i])
                    cur_note = ''
                else:
                    cur_note += beat[i]
                i += 1
            
            if cur_string is not None:
                note_by_string[cur_string] = cur_note
            
            for si in range(N_STRINGS):
                if si in note_by_string:
                    decoded_by_string[si] += f'{note_by_string[si]:<3}'
                else:
                    decoded_by_string[si] += f'{"-":<3}'
                decoded_by_string[si] += ' '
    
        for si in range(N_STRINGS):
            decoded_by_string[si] += '|'
    
    screen_width = 80
    for screen_num in range(0, len(decoded_by_string[0]), screen_width):
        for line in decoded_by_string.values():
            print(line[screen_num:screen_num+screen_width])
        print()
        
decode_and_print(encoded)

16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:-   16:-
16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:-   16:-
16:0   16:0   16:0   16:3   8:~   16:0   16:0   |16:0   8:5   16:0   16:-   16:-
16:-   16:-   16:-   16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:0   16:3

   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:-   |16:-   16:-   16:-   
   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:-   |16:-   16:-   16:-   
   16:0   |8:5   16:5   16:3   16:5   16:3   16:0   8:-   |16:0   16:0   16:0   
   16:-   |8:-   16:-   16:-   16:-   16:-   16:-   8:3   |16:-   16:-   16:-   

16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:-   16:-   16:-   |8:-   16:-
16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:-   16:-   16:-   |8:-   16:-
16:3   8:~   16:0   16:0   |16:0   8:5   16:0   16:-   16:-   16:0   |8:5   16:5
16:-   8:-   16:-   16:-   |16:-   8:-   16:-   16:0   16:3   16:-   |8:-   16:-

   16:-   16:-   16:-   8

#### Decode into GTP file

In [52]:
bass_strings = list(tracks_by_path.values())[0][0].strings
bass_channel = list(tracks_by_path.values())[0][0].channel


def calc_n_beats_in_song(song: guitarpro.Song) -> int:
    n_beats = 0
    for track in song.tracks:
        for measure in track.measures:
            n_beats += len(measure.voices[0].beats)
    return n_beats


def decode_into_song(encoded: list[str]) -> guitarpro.Song:
    song = guitarpro.Song(
        title='Generated',
        artist='AI',
    )
    track = guitarpro.Track(
        song, 
        name='Bass', 
        strings=bass_strings,
        channel=bass_channel,
    )
    song.tracks = [track]
    
    decoded_by_string = defaultdict(str)
    for mi, enc_measure in enumerate(' '.join(encoded).strip().split('|')):
        measure = guitarpro.Measure(track, guitarpro.MeasureHeader(number=mi + 1))
        track.measures.append(measure)

        for bi, enc_beat in enumerate(enc_measure.strip().split(' ')):
            if not enc_beat:
                continue
                
            beat = guitarpro.Beat(
                measure.voices[0]    ,
                duration=guitarpro.Duration(value=1, isDotted=False),
            )
            measure.voices[0].beats.append(beat)

            # s = duration_alphabet[int(math.log2(beat.duration.value))]
            # if beat.duration.isDotted:
            #     s += '.'
            # 
    
            i = 0
            beat.duration = guitarpro.Duration(
                value=2 ** duration_alphabet.index(enc_beat[i]), 
            )
            i += 1
            
            if i < len(enc_beat) and enc_beat[i] == '·':
                i += 1
                beat.duration.isDotted = True
            
            note_by_string = {}
            cur_string, cur_note = None, ''
            while i < len(enc_beat):
                if enc_beat[i] in strings_alphabet:
                    if cur_string is not None:
                        note_by_string[cur_string] = cur_note
                    cur_string = strings_alphabet.index(enc_beat[i])
                    cur_note = ''
                else:
                    cur_note += enc_beat[i]
                i += 1
            if cur_string is not None:
                note_by_string[cur_string] = cur_note
            
            for si in range(N_STRINGS):
                if si in note_by_string:
                    note = guitarpro.Note(
                        beat=beat,
                        string=si + 1,
                    )
                    if note_by_string[si] == 'x':
                        note.type = guitarpro.NoteType.dead
                    elif note_by_string[si] == 'r':
                        note.type = guitarpro.NoteType.rest
                    elif note_by_string[si] == '~':
                        note.type = guitarpro.NoteType.tie
                    else:
                        note.value = int(note_by_string[si])
                        note.type = guitarpro.NoteType.normal
                    beat.notes.append(note)
        
        # measure.timeSignature.numerator = 4
        # measure.timeSignature.denominator.value = dur_fraction_sum * 4  # number of quarters
    
        for si in range(N_STRINGS):
            decoded_by_string[si] += '|'

    return song

song = decode_into_song(encoded)

with open(f'/Users/vlad/tmp/generated-{path.name}', 'wb') as f:
    guitarpro.write(song, f)

#### Prepare dataset
We encode all tracks in flat format

In [54]:
for path, tracks in tracks_by_path.items():
    out_path = save_dir / (path.name + '-bass.txt')
    if out_path.exists():
        continue

    enc_tracks = [encode_track(t) for t in tracks]
    with out_path.open('w') as f:
        for t in enc_tracks:
            f.write(' '.join(t) + '\n')

Then we concatenate all encoded tracks in one file

In [None]:
concat_text_path = drive_path / 'AI/datasets/gtp_gpt/concat.txt'
if not concat_text_path.exists():
    paths = list((drive_path / 'AI/datasets/gtp_gpt').iterdir())
    with open(concat_text_path, 'w') as f:
        for path in tqdm(paths):
            if path.is_file():
                with open(path, 'r') as f2:
                    f.write(f2.read() + '\n')

### Alternative encoding: alphaTex format

Using [AlphaTex](https://www.alphatab.net/docs/alphatex/introduction) as a flat text representation

In [None]:
save_dir = drive_path / "AI/datasets/gtp_gpt"
save_dir.mkdir(exist_ok=True)

from tqdm import tqdm
for path in tqdm(paths[:100]):
    out_path = save_dir / (path.name + '-bass.tex')
    if out_path.exists():
        continue
    
    cmd = f'node gtp_to_tex.js "{path}" "{out_path}"'
    print(cmd)
    !{cmd}

### Tokenizing

The choice of tokenizer is not important here, so just arbitrary picking SentencePieceBPETokenizer. We might consider building a more music-specified custom tokenizer, but trusting transformers to figure out the song structure on their own. Interesting to see if it will figure out proper time signatures.


In [None]:
import pprint
import itertools
from torch.utils.data import random_split, TensorDataset
from tokenizers.implementations import SentencePieceBPETokenizer

print(f'Total tracks: {len(tracks_by_path)}')
tracks = list(itertools.chain.from_iterable(tracks_by_path.values()))

tokenizer = SentencePieceBPETokenizer()
vocab_size = 50_000
with concat_text_path.open() as f:
    tokenizer.train_from_iterator(
        (line.strip() for line in f), 
        vocab_size=vocab_size, 
        special_tokens=['<s>', '<unk>']
    )
print(f'Tokenizer vocab ({tokenizer.get_vocab_size()} tokens):')
pprint.pprint(sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])[:1000])

In [None]:
t = "8.E0 8E0 1A6"
print(tokenizer.encode(t).ids, tokenizer.encode(t).tokens)

#### Initialising a model

A smaller version of GPT2 for now.

In [None]:
import torch
torch.random.manual_seed(42)
torch.cuda.random.manual_seed(42)

In [None]:
from transformers import GPT2LMHeadModel, GPT2Config

config = GPT2Config(
    bos_token_id=0,
    eos_token_id=0,
    vocab_size=tokenizer.get_vocab_size(),
    n_positions=120,  # 1024
    n_embd=96,        # 768,
    n_layer=6,        # 12,
    n_head=6,         # 12,
)

model = GPT2LMHeadModel(config)
model

#### Building a dataset

In [None]:
!wc {concat_text_path}

In [None]:
class Dataset(TensorDataset):
    def __init__(self, file_path: Path, block_size: int):
        self.block_size = block_size
        
        cache_path = file_path.with_name(file_path.name + '.pt')
        if cache_path.exists():
            self.data = torch.load(str(cache_path))
            print(f"Loading data from cache {cache_path}")
            return
    
        print(f"Creating examples from dataset file {file_path}")
        with open(file_path, encoding="utf-8") as f:
            lines = f.readlines()
        ids = [0]
        for line in tqdm(lines):
            if line := line.strip():
                ids.extend(tokenizer.encode(line).ids + [0])
        super().__init__(torch.LongTensor(ids))

        test_n = min(1000, int(len(self) * 0.1))
        self.train, self.test = random_split(self, [len(self) - test_n, test_n])

    def __getitem__(self, idx: int):
        t = torch.LongTensor(self.tensors[0][idx:idx + self.block_size])
        return {"input_ids": t, "labels": t}

    def __len__(self):
        return len(self.tensors[0]) - self.block_size + 1

ds = Dataset(concat_text_path, block_size=config.n_positions)
print(f'Dataset size: {len(ds)}')
print(f'Train size: {len(ds.train)}')
print(f'Test size: {len(ds.test)}')

#### Sampling

We would sample and also immediately decode our string into a more human-readable tab format.

In [None]:
def sample(num_seqs=1, max_length=13):
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    for i, seq in enumerate(model.generate(
        max_length=max_length,
        top_p=0.95,
        num_return_sequences=num_seqs,
        do_sample=True, 
        top_k=50,
        pad_token_id=0,
        eos_token_id=0,
        bos_token_id=0,
    )):
        seq = tokenizer.decode(seq.tolist())
        print(i + 1, seq)
        decode_and_print(seq.split())

#### Training the model

In [None]:
import os
from transformers import Trainer, TrainingArguments, TrainerCallback
from transformers.trainer_utils import get_last_checkpoint

save_dir = drive_path / "AI" / "gtp_gpt"
save_dir.mkdir(exist_ok=True)
if last_checkpoint_dir := get_last_checkpoint(str(save_dir)):
    last_checkpoint_dir = Path(last_checkpoint_dir)
    print([t.name for t in last_checkpoint_dir.iterdir()])

class MyCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        sample()

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=str(save_dir),
        report_to=['wandb'] if 'WANDB_API_KEY' in os.environ else [],
        evaluation_strategy="epoch",
        overwrite_output_dir=True,
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        ignore_data_skip=True,
    ),
    train_dataset=ds.train,
    callbacks=[MyCallback],
)
trainer.train(resume_from_checkpoint=last_checkpoint_dir)
# trainer.train()
