In [2]:
!pip install torch transformers wandb PyGuitarPro
%env TOKENIZERS_PARALLELISM false

Collecting wandb
  Using cached wandb-0.13.10-py3-none-any.whl (2.0 MB)
Collecting docker-pycreds>=0.4.0
  Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting GitPython>=1.0.0
  Using cached GitPython-3.1.31-py3-none-any.whl (184 kB)
Collecting appdirs>=1.4.3
  Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)
Collecting setproctitle
  Using cached setproctitle-1.3.2-cp310-cp310-macosx_10_9_x86_64.whl (11 kB)
Collecting protobuf!=4.21.0,<5,>=3.19.0
  Using cached protobuf-4.22.0-cp37-abi3-macosx_10_9_universal2.whl (397 kB)
Collecting sentry-sdk>=1.0.0
  Using cached sentry_sdk-1.16.0-py2.py3-none-any.whl (184 kB)
Collecting Click!=8.0.0,>=7.0
  Using cached click-8.1.3-py3-none-any.whl (96 kB)
Collecting pathtools
  Using cached pathtools-0.1.2.tar.gz (11 kB)
  Preparing metadata (setup.py) ... [?25ldone
Collecting gitdb<5,>=4.0.1
  Using cached gitdb-4.0.10-py3-none-any.whl (62 kB)
Collecting smmap<6,>=3.0.1
  Using cached smmap-5.

In [1]:
from pathlib import Path
drive_path = Path('/Users/vlad/googledrive')
if not drive_path.exists():
    drive_path = Path("/content/drive/MyDrive")

In [7]:
tabs_path = drive_path / 'PlayMusic/tabs'
paths = list(tabs_path.glob('**/*.gp[3-5]'))
print(f'Found {len(paths)} files')

Found 49125 files


In [19]:
N_TOY = 100

#### Parsing tabs
We would start with only bass tracks as they are simpler. Doing best to select bass tracks based on heuristics: 4 strings, 24 frets, instruments from 32 to 40 (unfortunately PyGuitarPro doesn't provide labels for the instruments, only numbers - so I checked the names in GuitarPro itself (rather, I use an open source version tuxguitar) and selected all instrument that have "bass" in the name.

In [20]:
from collections import defaultdict
import guitarpro
from tqdm import tqdm

N_FRETS = 24
N_STRINGS = 4  # base; 6 for standard guitar
INSTRUMENTS = range(32, 40)  # base; range(24, 31) for standard guitar

tracks_by_path = dict()

for path_num, path in enumerate(tqdm(paths)):
    try:
        curl = guitarpro.parse(path)
    except guitarpro.GPException:
        print(f'   failed to parse {path}')
        continue

    tracks = []
    for track in curl.tracks:
        if all([
            track.settings.tablature, 
            len(track.strings) == N_STRINGS,
            track.fretCount == N_FRETS,
            track.channel.instrument in INSTRUMENTS,
        ]):
            tracks.append(track)
    
    if tracks:
        tracks_by_path[path] = tracks

    if sum(map(len, tracks_by_path.values())) >= N_TOY:
        break

if len(tracks_by_path) == 0:
    raise Exception('No bass tracks found')

print(
    f'Found {len(tracks_by_path)}/{len(paths)} files with bass tracks, '
    f'total {sum(map(len, tracks_by_path.values()))} tracks'
)

  0%|          | 228/49125 [01:44<6:14:43,  2.17it/s] 

Found 99/49125 files with bass tracks, total 100 tracks





#### Verifying tabs by printing one in a human-readable format

In [24]:
def pprint_track(track: guitarpro.Track):
    lines = ['' for _ in range(N_STRINGS)]

    cur_time_sig = None
    for mi, measure in enumerate(track.measures):
        if measure.timeSignature != cur_time_sig:
            cur_time_sig = measure.timeSignature
            lines[0] += '  '
            lines[1] += f'{cur_time_sig.numerator} '
            lines[2] += f'{cur_time_sig.denominator.value} '
            lines[3] += '  '

        voice = measure.voices[0]  # only the first voice is usually non-empty
        for bi, beat in enumerate(voice.beats):
            out = f'{beat.duration.value}{"+" if beat.duration.isDotted else ""}:'
            note_by_string = {note.string - 1: note for note in beat.notes}
            for s in range(N_STRINGS):
                if n := note_by_string.get(s):
                    if n.type == guitarpro.NoteType.tie:
                        v = '~'
                    elif n.type == guitarpro.NoteType.dead:
                        v = 'x'
                    else:
                        v = n.value
                else:
                    v = '-'
                lines[s] += f'{out}{v:<3} '
        for s in range(N_STRINGS):
            lines[s] += '|'

    screen_width = 80
    for screen_num in range(0, len(lines[0]), screen_width):
        for line in lines:
            print(line[screen_num:screen_num+screen_width])
        print()

path = list(tracks_by_path.keys())[0]
track = tracks_by_path[path][0]
print(path)
print(track.name)
pprint_track(track)

/Users/vlad/googledrive/PlayMusic/tabs/White Lion/White Lion - Lady Of The Valley.gp4
Bass
  1:-   |4+:7   4+:7   4:7   |8:~   4+:7   4:7   4:7   |4+:7   4+:7   4:7   |8:~
4 1:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-
4 1:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-
  1:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-

   4+:7   4:7   4:7   |4+:7   4+:7   4:7   |8:~   4+:7   4:7   4:7   |4+:7   4+:
   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:
   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:
   4+:-   4:-   4:-   |4+:-   4+:-   4:-   |8:-   4+:-   4:-   4:-   |4+:-   4+:

7   4:7   |8:~   4+:7   8:7   8+:-   8+:-   |4+:-   2:-   8:-   |4+:-   4:-   8+
-   4:-   |8:-   4+:-   8:-   8+:-   8+:-   |4+:-   2:0   8:~   |4+:-   4:-   8+
-   4:-   |8:-   4+:-   8:-   8+:-   8+:-   |4+:-   2:-   8:-   |4+:-   4:-   8+
-   4:-   |8:-  

#### Encoding into flat string
 
Now attempting to encode multi-string tabs into a flat string, suitable for training a transformer. In an NLP world, our token would encode all notes that sound at the same time during one beat, along with the beat duration. We would separate beats with a space (like words of text), and separate measures with a "|" character (like paragraphs/sentences).

In [26]:
import math

duration_alphabet = ['1', '2', '4', '8', 'F', 'H', 'G', 'I']
strings_alphabet = ['E', 'A', 'D', 'G', 'B', 'e']

save_dir = drive_path / 'AI/datasets/gtp_gpt'
save_dir.mkdir(exist_ok=True)

def encode_track(track):
    out = []
    
    for mi, measure in enumerate(track.measures):
        beats = measure.voices[0].beats
        if not any(len(b.notes) > 0 for b in beats):
            continue
        
        for bi, beat in enumerate(beats):
            s = duration_alphabet[int(math.log2(beat.duration.value))]
            if beat.duration.isDotted:
                s += '.'
            
            if beat.notes:
                notes = []
                for n in beat.notes:
                    if n.type == guitarpro.NoteType.tie:
                        v = '~'
                    elif n.type == guitarpro.NoteType.dead:
                        v = 'x'
                    else:
                        v = n.value
                    notes.append(f'{strings_alphabet[n.string - 1]}{v}')
                s += "".join(notes)
            out.append(s)
        out.append('|')
    return out

for path, tracks in tracks_by_path.items():
    out_path = save_dir / (path.name + '-bass.txt')
    if out_path.exists():
        continue

    enc_tracks = [encode_track(t) for t in tracks]
    with out_path.open('w') as f:
        for t in enc_tracks:
            f.write(' '.join(t) + '\n')

encoded = encode_track(track)
print(path)
print(' '.join(encoded))

/Users/vlad/googledrive/PlayMusic/tabs/De Palmas/De Palmas - Regarde Moi Bien En Face (2).gp3
4.E7 4.E7 4E7 | 8E~ 4.E7 4E7 4E7 | 4.E7 4.E7 4E7 | 8E~ 4.E7 4E7 4E7 | 4.E7 4.E7 4E7 | 8E~ 4.E7 4E7 4E7 | 4.E7 4.E7 4E7 | 8E~ 4.E7 8E7 8.G1 8.G0 | 4. 2A0 8A~ | 4. 4 8.G1 8.G0 | 4D3 8D3 2E5 8E~ | 4.E6 4A7 8.G1 8.G0 | 4. 2A0 8A~ | 4. 4 8.G1 8.G0 | 4D3 8D3 2E5 8E~ | 4.E6 4A7 8.G1 8.G0 | 4. 2A0 8A~ | 4. 4 8.G1 8.G0 | 4D3 8D3 4.E5 4E5 | 4.E6 4A7 8.G1 8.G0 | 4. 2A0 8A~ | 4 8 4 8.G1 8.G0 | 4D3 8D3 4.E5 4E5 | 4.E6 4.A7 4A7 | 4.D8 4D7 8 4D8 | 8D~ 4D7 8 4D8 8E0 8E0 | 1E~ | 1E~ | 2. 8 8D0 | 1A0 | 2.A~ 4 | 2. 8 8D0 | 4.A0 8D0 4A0 4A0 | 2.D3 8D~ 8D3 | 4.D3 8D3 8A0 8D3 8 8G3 | 2.G1 8G~ 8 | 4.G1 8 8G1 8G3 8D0 8A0 | 2.D3 8D~ 8G3 | 8E2 8E0 8A3 8A2 8A0 8D3 8D0 8G3 | 8D0 8D0 8D0 8D0 8D0 8D0 8G3 8D0 | 4D~ 8D0 8D0 8D0 8D0 8D0 8D0 | 8D1 8D1 8D1 8D1 8D1 8D1 8D1 8D3 | 8D~ 8D3 8D3 8D3 8D3 8D1 8D0 8G3 | 8D0 8D0 8D0 8D0 8D0 8D0 8G3 8D0 | 8D~ 8D0 8D0 8D0 8D0 8D0 8D0 8D0 | 8D1 8D1 8D1 8D1 8D1 8D1 8D1 8D3 | 4D~ 2. | 2.D3 8D

#### Decode into human-readable tab

Decoding back into human-readable tab to verify correctness.

In [43]:
def decode_and_print(encoded: list[str]):
    decoded_by_string = defaultdict(str)

    for measure in ' '.join(encoded).strip().split('|'):
        for beat in measure.strip().split(' '):
            if not beat:
                continue

            i = 0
            duration = str(2 ** duration_alphabet.index(beat[i]))
            i += 1
            
            if i < len(beat) and beat[i] == '·':
                duration += '·'
                i += 1
    
            for si in range(N_STRINGS):
                decoded_by_string[si] += f'{duration}:'
             
            note_by_string = {}
            cur_string, cur_note = None, ''
            while i < len(beat):
                if beat[i] in strings_alphabet:
                    if cur_string is not None:
                        note_by_string[cur_string] = cur_note
                    cur_string = strings_alphabet.index(beat[i])
                    cur_note = ''
                else:
                    cur_note += beat[i]
                i += 1
            
            if cur_string is not None:
                note_by_string[cur_string] = cur_note
            
            for si in range(N_STRINGS):
                if si in note_by_string:
                    decoded_by_string[si] += f'{note_by_string[si]:<3}'
                else:
                    decoded_by_string[si] += f'{"-":<3}'
                decoded_by_string[si] += ' '
    
        for si in range(N_STRINGS):
            decoded_by_string[si] += '|'
    
    screen_width = 80
    for screen_num in range(0, len(decoded_by_string[0]), screen_width):
        for line in decoded_by_string.values():
            print(line[screen_num:screen_num+screen_width])
        print()
        
decode_and_print(encoded)

2:-   2:3   |2:2   2:5   |2:1   2:1   |2:0   2:0   |2:0   2:3   |2:2   2:5   |2:
2:5   2:-   |2:-   2:3   |2:-   2:0   |2:1   2:-   |2:1   2:1   |2:-   2:3   |2:
2:5   2:5   |2:2   2:-   |2:2   2:-   |2:-   2:-   |2:-   2:-   |2:2   2:-   |2:
2:-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:

1   2:1   |2:0   2:0   |2:0   2:0   |2:0   2:0   |2:0   2:0   |2:0   2:0   |2:0 
-   2:0   |2:1   2:-   |2:1   2:1   |2:0   2:0   |2:-   2:-   |2:-   2:-   |2:1 
2   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:2   2:2   |2:1   2:1   |2:- 
-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |2:- 

  2:3   |2:2   2:5   |2:1   2:1   |2:0   2:0   |4:0   4:0   4:1   4:0   |4:-   8
  2:1   |2:-   2:3   |2:-   2:0   |2:1   2:-   |4:-   4:-   4:-   4:-   |4:3   8
  2:-   |2:2   2:-   |2:2   2:-   |2:-   2:-   |4:-   4:-   4:-   4:-   |4:-   8
  2:-   |2:-   2:-   |2:-   2:-   |2:-   2:-   |4:-   4:-   4:-   4:-   |4:-   8

:-   2:-   |4:-   4:-   4

#### Prepare dataset

First, we concatenate all encoded tracks in one file

In [27]:
concat_text_path = drive_path / 'AI/datasets/gtp_gpt/concat.txt'

paths = list((drive_path / 'AI/datasets/gtp_gpt').iterdir())
with open(concat_text_path, 'w') as f:
    for path in tqdm(paths):
        if path.is_file():
            with open(path, 'r') as f2:
                f.write(f2.read() + '\n')

100%|██████████| 12067/12067 [00:23<00:00, 516.01it/s]


#### Tokenizing

The choice of tokenizer is not important here, so just arbitrary picking SentencePieceBPETokenizer. We might consider building a more music-specified custom tokenizer, but trusting transformers to figure out the song structure on their own. Interesting to see if it will figure out proper time signatures.


In [8]:
import pprint
import itertools
from torch.utils.data import random_split, TensorDataset
from tokenizers.implementations import SentencePieceBPETokenizer

print(f'Total tracks: {len(tracks_by_path)}')
tracks = list(itertools.chain.from_iterable(tracks_by_path.values()))

tokenizer = SentencePieceBPETokenizer()
vocab_size = 50_000
with concat_text_path.open() as f:
    tokenizer.train_from_iterator(
        (line.strip() for line in f), 
        vocab_size=vocab_size, 
        special_tokens=['<s>', '<unk>']
    )
print(f'Tokenizer vocab ({tokenizer.get_vocab_size()} tokens):')
pprint.pprint(sorted(tokenizer.get_vocab().items(), key=lambda kv: kv[1])[:1000])

Total tracks: 17



Tokenizer vocab (277 tokens):
[('<s>', 0),
 ('<unk>', 1),
 ('.', 2),
 ('0', 3),
 ('1', 4),
 ('2', 5),
 ('3', 6),
 ('4', 7),
 ('5', 8),
 ('6', 9),
 ('7', 10),
 ('8', 11),
 ('9', 12),
 ('A', 13),
 ('D', 14),
 ('E', 15),
 ('F', 16),
 ('G', 17),
 ('H', 18),
 ('x', 19),
 ('|', 20),
 ('~', 21),
 ('▁', 22),
 ('▁8', 23),
 ('▁|', 24),
 ('▁4', 25),
 ('▁F', 26),
 ('▁8D', 27),
 ('▁2', 28),
 ('▁8G', 29),
 ('▁FD', 30),
 ('▁4D', 31),
 ('▁4G', 32),
 ('▁FG', 33),
 ('▁8D0', 34),
 ('▁4.', 35),
 ('▁8A', 36),
 ('▁8D2', 37),
 ('▁8.', 38),
 ('▁8D4', 39),
 ('▁8E', 40),
 ('▁FA', 41),
 ('▁2D', 42),
 ('▁1', 43),
 ('▁FD0', 44),
 ('▁2G', 45),
 ('▁FG0', 46),
 ('▁8D1', 47),
 ('▁2E', 48),
 ('▁8G3', 49),
 ('▁4D3', 50),
 ('▁4G1', 51),
 ('▁8D3', 52),
 ('▁8G0', 53),
 ('▁4A', 54),
 ('▁8D7', 55),
 ('▁1G', 56),
 ('▁8G4', 57),
 ('▁FD2', 58),
 ('▁4D0', 59),
 ('▁FD1', 60),
 ('▁4E', 61),
 ('▁FA0', 62),
 ('▁4.D', 63),
 ('▁4.G', 64),
 ('▁2A', 65),
 ('▁2D0', 66),
 ('▁FA1', 67),
 ('▁FD4', 68),
 ('▁8G1', 69),
 ('

In [61]:
t = "8.E0 8E0 1A6"
print(tokenizer.encode(t).ids, tokenizer.encode(t).tokens)

[183, 96, 194, 9] ['▁8.E0', '▁8E0', '▁1A', '6']


#### Initialising a model

A smaller version of GPT2 for now.

In [11]:
import torch
torch.random.manual_seed(42)
torch.cuda.random.manual_seed(42)

In [9]:
from transformers import GPT2LMHeadModel, GPT2Config

config = GPT2Config(
    bos_token_id=0,
    eos_token_id=0,
    vocab_size=tokenizer.get_vocab_size(),
    n_positions=120,  # 1024
    n_embd=96,        # 768,
    n_layer=6,        # 12,
    n_head=6,         # 12,
)

model = GPT2LMHeadModel(config)
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(277, 96)
    (wpe): Embedding(120, 96)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0

#### Building a dataset

In [49]:
!wc {concat_text_path}

    83699  35131828 130807398 /Users/vlad/googledrive/AI/gtp_gpt/encoded/concat.txt


In [63]:
class Dataset(TensorDataset):
    def __init__(self, file_path: Path, block_size: int):
        self.block_size = block_size
        
        cache_path = file_path.with_name(file_path.name + '.pt')
        if cache_path.exists():
            self.data = torch.load(str(cache_path))
            print(f"Loading data from cache {cache_path}")
            return
    
        print(f"Creating examples from dataset file {file_path}")
        with open(file_path, encoding="utf-8") as f:
            lines = f.readlines()
        ids = [0]
        for line in tqdm(lines):
            if line := line.strip():
                ids.extend(tokenizer.encode(line).ids + [0])
        super().__init__(torch.LongTensor(ids))

        test_n = min(1000, int(len(self) * 0.1))
        self.train, self.test = random_split(self, [len(self) - test_n, test_n])

    def __getitem__(self, idx: int):
        t = torch.LongTensor(self.tensors[0][idx:idx + self.block_size])
        return {"input_ids": t, "labels": t}

    def __len__(self):
        return len(self.tensors[0]) - self.block_size + 1

ds = Dataset(concat_text_path, block_size=config.n_positions)
print(f'Dataset size: {len(ds)}')
print(f'Train size: {len(ds.train)}')
print(f'Test size: {len(ds.test)}')

Creating examples from dataset file /Users/vlad/googledrive/AI/gtp_gpt/encoded/concat.txt


100%|██████████| 1000/1000 [00:02<00:00, 368.73it/s]


Dataset size: 481085
Train size: 480085
Test size: 1000


#### Sampling

We would sample and also immediately decode our string into a more human-readable tab format.

In [47]:
def sample(num_seqs=1, max_length=13):
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)
    for i, seq in enumerate(model.generate(
        max_length=max_length,
        top_p=0.95,
        num_return_sequences=num_seqs,
        do_sample=True, 
        top_k=50,
        pad_token_id=0,
        eos_token_id=0,
        bos_token_id=0,
    )):
        seq = tokenizer.decode(seq.tolist())
        print(i + 1, seq)
        decode_and_print(seq.split())

A decoder-only architecture is being used, but right-padding was detected! For correct generation results, please set `padding_side='left'` when initializing the tokenizer.


1 FA1 8E6 8E6 1G3 2E0 1 1G7 8G2 8G57 FE2 2G5
16:-   8:6   8:6   1:-   2:0   1:-   1:-   8:-   8:-   16:2   2:-   |
16:1   8:-   8:-   1:-   2:-   1:-   1:-   8:-   8:-   16:-   2:-   |
16:-   8:-   8:-   1:-   2:-   1:-   1:-   8:-   8:-   16:-   2:-   |
16:-   8:-   8:-   1:3   2:-   1:-   1:7   8:2   8:57  16:-   2:5   |

2 4.D5 8.G5 4E8 4 4.EG 8A3 4.D4 2G0 FE 4D2 2.D
4:-   8:-   4:8   4:-   4:    8:-   4:-   2:-   16:    4:-   2:-   |
4:-   8:-   4:-   4:-   4:-   8:3   4:-   2:-   16:-   4:-   2:-   |
4:5   8:-   4:-   4:-   4:-   8:-   4:4   2:-   16:-   4:2   2:    |
4:-   8:5   4:-   4:-   4:    8:-   4:-   2:0   16:-   4:-   2:-   |



#### Training the model

In [64]:
import os
from transformers import Trainer, TrainingArguments, TrainerCallback
from transformers.trainer_utils import get_last_checkpoint

save_dir = DRIVE_PATH / "AI" / "gtp_gpt"
save_dir.mkdir(exist_ok=True)
if last_checkpoint_dir := get_last_checkpoint(str(save_dir)):
    last_checkpoint_dir = Path(last_checkpoint_dir)
    print([t.name for t in last_checkpoint_dir.iterdir()])

class MyCallback(TrainerCallback):
    def on_save(self, args, state, control, **kwargs):
        sample()

trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=str(save_dir),
        report_to=['wandb'] if 'WANDB_API_KEY' in os.environ else [],
        evaluation_strategy="epoch",
        overwrite_output_dir=True,
        eval_steps=500,
        save_steps=500,
        save_total_limit=2,
        per_device_train_batch_size=2,
        per_device_eval_batch_size=2,
        ignore_data_skip=True,
    ),
    train_dataset=ds.train,
    callbacks=[MyCallback],
)
trainer.train(resume_from_checkpoint=last_checkpoint_dir)
# trainer.train()


***** Running training *****
  Num examples = 480085
  Num Epochs = 3
  Instantaneous batch size per device = 2
  Total train batch size (w. parallel, distributed & accumulation) = 2
  Gradient Accumulation steps = 1
  Total optimization steps = 720129
  Number of trainable parameters = 709344


Epoch,Training Loss,Validation Loss


KeyboardInterrupt: 