<a href="https://colab.research.google.com/github/yongsun-yoon/music-vae/blob/main/01_preprocess.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preprocess

## 0. Overview
MusicVAE is trained on a MIDI dataset.

This notebook downloads and preprocesses the [Groove MIDI Dataset](https://magenta.tensorflow.org/datasets/groove).

All code was run in Google Colab.

### Reference
* https://github.com/magenta/magenta/tree/main/magenta/models/music_vae
* https://magenta.tensorflow.org/datasets/groove
* https://github.com/magenta/magenta/blob/main/magenta/scripts/convert_dir_to_note_sequences.py

## 1. Environment setup

In [None]:
# Install system and Python dependencies
!apt-get -qq update -y
!apt-get -qq install build-essential libasound2-dev libjack-dev libfluidsynth2 fluid-soundfont-gm -y

!pip install -q magenta

In [None]:
# Import libraries
from note_seq import midi_io
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
tf.logging.set_verbosity('INFO')

Instructions for updating:
non-resource variables are not supported in the long term


In [None]:
# Project directory on the mounted Google Drive
BASE_DIR = '/content/drive/MyDrive/project/pozalabs-assignment'

In [None]:
# Create the data directory (-p: no error if it already exists)
!mkdir -p $BASE_DIR/data

## 2. Download the data

In [None]:
# Download the MIDI-only release of the dataset
!wget https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip

--2023-01-31 10:05:40--  https://storage.googleapis.com/magentadata/datasets/groove/groove-v1.0.0-midionly.zip
Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.213.128, 108.177.11.128, 173.194.216.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.213.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3260318 (3.1M) [application/zip]
Saving to: ‘groove-v1.0.0-midionly.zip’


2023-01-31 10:05:40 (108 MB/s) - ‘groove-v1.0.0-midionly.zip’ saved [3260318/3260318]



In [None]:
# Unzip the archive into the data directory
!unzip groove-v1.0.0-midionly.zip -d $BASE_DIR/data

In [None]:
# Check the extracted files
!ls $BASE_DIR/data/groove

drummer1   drummer2  drummer4  drummer6  drummer8  Icon      LICENSE
drummer10  drummer3  drummer5  drummer7  drummer9  info.csv  README
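Besides the per-drummer folders, the archive ships an `info.csv` metadata file. A minimal sketch for selecting MIDI files from it with the standard `csv` module, assuming the CSV has `midi_filename`, `beat_type`, and `split` columns (check the header before relying on these names). `select_midi_files` is a hypothetical helper, not part of Magenta, and the sample rows are illustrative placeholders, not real dataset entries.

```python
import csv
import io
from typing import Iterable, List, TextIO


def select_midi_files(csv_file: TextIO, split: str = "train",
                      beat_types: Iterable[str] = ("beat",)) -> List[str]:
    """Return midi_filename values whose split and beat_type match.

    Assumes (hypothetically) that info.csv has 'midi_filename',
    'beat_type' and 'split' columns.
    """
    wanted = set(beat_types)
    return [
        row["midi_filename"]
        for row in csv.DictReader(csv_file)
        if row["split"] == split and row["beat_type"] in wanted
    ]


# Usage with a tiny in-memory stand-in for info.csv (placeholder rows)
sample = io.StringIO(
    "midi_filename,beat_type,split\n"
    "a.mid,beat,train\n"
    "b.mid,fill,train\n"
    "c.mid,beat,test\n"
)
print(select_midi_files(sample))  # → ['a.mid']
```

In the notebook, the same helper would be called on `open(f'{BASE_DIR}/data/groove/info.csv', newline='')`.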


## 3. Read the data

In [None]:
# Read one MIDI file and convert it into a NoteSequence proto
file_path = f'{BASE_DIR}/data/groove/drummer1/session1/1_funk_80_beat_4-4.mid'
with tf.gfile.GFile(file_path, 'rb') as f:
    sequence = midi_io.midi_to_sequence_proto(f.read())

In [None]:
# Number of notes
len(sequence.notes)

773

In [None]:
# First note in the sequence
sequence.notes[0]

pitch: 38
velocity: 7
start_time: 2.115625
end_time: 2.1765625
is_drum: true
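Note times are stored in seconds, while MusicVAE works on a quantized step grid (in note_seq this is done by `sequences_lib.quantize_note_sequence`). A minimal sketch of the idea, assuming the 80 BPM tempo implied by the file name and 4 steps per quarter note (16th notes). `seconds_to_step` is a hypothetical helper that only approximates note_seq's behavior by rounding to the nearest step.

```python
def seconds_to_step(seconds: float, qpm: float = 80.0,
                    steps_per_quarter: int = 4) -> int:
    """Map a time in seconds to the nearest grid step (rounding half up)."""
    steps_per_second = qpm / 60.0 * steps_per_quarter  # 80 BPM -> ~5.33 steps/s
    return int(seconds * steps_per_second + 0.5)


# The example note above: start_time=2.115625, end_time=2.1765625
start_step = seconds_to_step(2.115625)   # ~11.28 steps -> 11
end_step = seconds_to_step(2.1765625)    # ~11.61 steps -> 12
print(start_step, end_step)
```

The note lasts only about 0.06 s but still spans a step boundary, which is one reason drum models typically keep only note onsets.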

In [None]:
# Distinct drum pitches (instrument categories) in this file
{n.pitch for n in sequence.notes}

{26, 36, 37, 38, 40, 42, 43, 44, 46}
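MusicVAE's drum models collapse raw drum pitches into a small set of instrument classes. A sketch that groups the pitches seen above, assuming the 9-class scheme (bass drum, snare, closed hi-hat, open hi-hat, low/mid/high tom, crash, ride) used by note_seq's drum encoder. The `PITCH_TO_CLASS` table is an assumed subset that covers only the pitches in this file; verify it against `note_seq.drums_encoder_decoder` before use.

```python
from collections import Counter
from typing import Dict, Iterable

# Assumed subset of the 9-class drum mapping, limited to the pitches seen above.
PITCH_TO_CLASS: Dict[int, str] = {
    36: "bass drum",
    37: "snare drum", 38: "snare drum", 40: "snare drum",
    42: "closed hi-hat", 44: "closed hi-hat",
    26: "open hi-hat", 46: "open hi-hat",
    43: "low tom",
}


def count_drum_classes(pitches: Iterable[int]) -> Counter:
    """Count notes per drum class; pitches missing from the table count as 'other'."""
    return Counter(PITCH_TO_CLASS.get(p, "other") for p in pitches)


# Usage with a short placeholder pitch list
# (in the notebook: count_drum_classes(n.pitch for n in sequence.notes))
print(count_drum_classes([36, 38, 42, 42, 46, 26, 50]))
```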

In [None]:
# Serialize the proto to bytes (first 10 bytes shown)
sequence.SerializeToString()[:10]

b' \xe0\x03*\x04\x10\x04\x18\x042'
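The serialized bytes above are what ends up in the TFRecord file produced in the next step. A TFRecord file stores each payload as a record: an 8-byte little-endian length, a masked CRC-32C of that length, the payload, then a masked CRC-32C of the payload. A stdlib-only sketch of that framing, for illustration only (in practice use `tf.io.TFRecordWriter` and `tf.data.TFRecordDataset`). The helper names are hypothetical.

```python
import io
import struct
from typing import BinaryIO, Iterator, List


def _make_crc32c_table() -> List[int]:
    """Lookup table for CRC-32C (Castagnoli), reflected polynomial 0x82F63B78."""
    table = []
    for i in range(256):
        crc = i
        for _ in range(8):
            crc = (crc >> 1) ^ 0x82F63B78 if crc & 1 else crc >> 1
        table.append(crc)
    return table


_TABLE = _make_crc32c_table()


def crc32c(data: bytes) -> int:
    crc = 0xFFFFFFFF
    for b in data:
        crc = _TABLE[(crc ^ b) & 0xFF] ^ (crc >> 8)
    return crc ^ 0xFFFFFFFF


def masked_crc(data: bytes) -> int:
    """TFRecord's masking: rotate right by 15 bits, then add a constant (mod 2**32)."""
    crc = crc32c(data)
    return (((crc >> 15) | (crc << 17)) + 0xA282EAD8) & 0xFFFFFFFF


def write_record(f: BinaryIO, payload: bytes) -> None:
    header = struct.pack("<Q", len(payload))
    f.write(header)
    f.write(struct.pack("<I", masked_crc(header)))
    f.write(payload)
    f.write(struct.pack("<I", masked_crc(payload)))


def read_records(f: BinaryIO) -> Iterator[bytes]:
    while True:
        header = f.read(8)
        if not header:
            return  # clean end of file
        if len(header) < 8:
            raise ValueError("truncated record header")
        (length,) = struct.unpack("<Q", header)
        (len_crc,) = struct.unpack("<I", f.read(4))
        if len_crc != masked_crc(header):
            raise ValueError("corrupt length checksum")
        payload = f.read(length)
        (data_crc,) = struct.unpack("<I", f.read(4))
        if data_crc != masked_crc(payload):
            raise ValueError("corrupt payload checksum")
        yield payload


# Usage: round-trip two payloads (e.g. sequence.SerializeToString()) through a buffer
buf = io.BytesIO()
for payload in (b"first", b"second"):
    write_record(buf, payload)
buf.seek(0)
print(list(read_records(buf)))  # → [b'first', b'second']
```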

## 4. Preprocess the data

In [None]:
!convert_dir_to_note_sequences \
    --recursive \
    --input_dir=$BASE_DIR/data/groove \
    --output_file=$BASE_DIR/data/tfrecord \ 