Skip to content

Commit

Permalink
train.py added
Browse files Browse the repository at this point in the history
  • Loading branch information
sakemin committed Oct 19, 2023
1 parent b26dac8 commit adec2ad
Show file tree
Hide file tree
Showing 69 changed files with 2,733 additions and 167 deletions.
Binary file added __pycache__/metadata.cpython-39.pyc
Binary file not shown.
Binary file modified __pycache__/predict.cpython-39.pyc
Binary file not shown.
Binary file added __pycache__/train.cpython-39.pyc
Binary file not shown.
Binary file added audiocraft/__pycache__/train.cpython-39.pyc
Binary file not shown.
Binary file modified audiocraft/modules/__pycache__/chord_chroma.cpython-39.pyc
Binary file not shown.
Binary file modified audiocraft/modules/btc/utils/__pycache__/hparams.cpython-39.pyc
Binary file not shown.
3 changes: 3 additions & 0 deletions audiocraft/modules/btc/utils/hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ def __repr__(self):

@classmethod
def load(cls, path):
    """Build an instance from the YAML mapping stored at ``path``.

    The file is parsed with ``yaml.load`` and the parsed top-level mapping
    is expanded into keyword arguments for ``cls``.

    Args:
        path: Location of the YAML file. It is passed straight to
            ``open``, so a relative path is resolved against the current
            working directory of the process.

    Returns:
        A new ``cls`` instance constructed from the parsed mapping.

    Raises:
        OSError: If ``path`` cannot be opened for reading.
    """
    # The former debug output (HOME and the working directory) is gone:
    # indexing os.environ['HOME'] raises KeyError when HOME is unset, so a
    # diagnostic print could abort config loading altogether.
    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # trusted config files should reach this call. Confirm whether
    # yaml.SafeLoader is sufficient for the configs in use before switching.
    with open(path, 'r') as f:
        return cls(**yaml.load(f, yaml.Loader))

Expand Down
4 changes: 2 additions & 2 deletions audiocraft/modules/chord_chroma.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ class ChordExtractor(nn.Module):

def __init__(self, device, sample_rate, max_duration, chroma_len, n_chroma, winhop):
super().__init__()
self.config = HParams.load("audiocraft/modules/btc/run_config.yaml") #gotta specify the path for run_config.yaml of btc
self.config = HParams.load("/src/audiocraft/modules/btc/run_config.yaml") #gotta specify the path for run_config.yaml of btc

# self.config.feature['large_voca'] = False
# self.config.model['num_chords'] = 25

self.model_file = 'audiocraft/modules/btc/test/btc_model_large_voca.pt'
self.model_file = '/src/audiocraft/modules/btc/test/btc_model_large_voca.pt'
# self.model_file = 'audiocraft/modules/btc/test/btc_model.pt'
self.idx_to_chord = idx2voca_chord()
self.sr = sample_rate
Expand Down
Binary file modified audiocraft/solvers/__pycache__/base.cpython-39.pyc
Binary file not shown.
Binary file not shown.
Binary file not shown.
4 changes: 2 additions & 2 deletions audiocraft/solvers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,8 +359,8 @@ def load_checkpoints(self, load_best: bool = False, ignore_state_keys: tp.List[s
# checkpoints are not from the current xp, we only retrieve the best state
if checkpoint_source is not None and checkpoint_source != checkpoint.CheckpointSource.CURRENT_XP:
assert state is not None
self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
load_best = True
# self.logger.info("Checkpoint source is not the current xp: Load state_dict from best state.")
load_best = False
state = {key: state[key] for key in self._continue_best_source_keys if key in state}
# loaded checkpoints are FSDP checkpoints: we're reading the best state
# from FSDP and we drop the regular best_state
Expand Down
162 changes: 1 addition & 161 deletions audiocraft/solvers/musicgen_chord.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,96 +138,14 @@ def build_model(self) -> None:
self.compression_model.num_codebooks, self.compression_model.cardinality,
self.compression_model.frame_rate)

import os, psutil
print("Memory Usage : ", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)

print("Initiating Model")
# instantiate LM model
self.model: models.LMModel = models.builders.get_lm_model(self.cfg).to(self.device)

'''
# Change existing ChromaStemConditioner to ChromaChordConditioner and migrate params
print("Memory Usage : ", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)
print("Loading pretrained model instance")
mgmodel = MusicGen.get_pretrained('facebook/musicgen-melody')
mgmodel.set_generation_params(duration=30)
print("Memory Usage : ", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)
output_proj_weight = mgmodel.lm.condition_provider.conditioners['self_wav'].output_proj.state_dict() #not loaded yet?
mgmodel.lm.condition_provider.conditioners['self_wav'] = ChromaChordConditioner(int(mgmodel.lm.condition_provider.conditioners['self_wav'].output_dim), int(mgmodel.lm.condition_provider.conditioners['self_wav'].sample_rate), int(mgmodel.lm.condition_provider.conditioners['self_wav'].chroma.n_chroma), int(math.log2(mgmodel.lm.condition_provider.conditioners['self_wav'].chroma.winlen)), float(mgmodel.lm.condition_provider.conditioners['self_wav'].duration), device=self.device)
# mgmodel.lm.condition_provider.conditioners['self_wav'].output_proj[0].load_state_dict(output_proj_weight) #For MLP projection
mgmodel.lm.condition_provider.conditioners['self_wav'].output_proj.load_state_dict(output_proj_weight)
print('assignin mgmodel.lm params to LMModel !!!')
self.model.load_state_dict(mgmodel.lm.state_dict())
self.model.to(self.device)
print('assigned params!!')
print("Memory Usage : ", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)
print('deleted instances')
del mgmodel, output_proj_weight
gc.collect()
print("Memory Usage : ", psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2)
# self.model.condition_provider.conditioners['description'].output_proj.weight.requires_grad = True
'''

'''
for param in self.model.parameters():
param.requires_grad = False
self.model.condition_provider.conditioners['self_wav'].output_proj.weight.requires_grad = True
self.model.condition_provider.conditioners['self_wav'].output_proj.bias.requires_grad = True
# for i in range(6):
# for param in self.model.transformer.layers[i].parameters():
# param.requires_grad = True
'''

# for name, param in self.model.condition_provider.conditioners['self_wav'].output_proj.named_parameters():
# print(name, param.requires_grad)
# print(param)
'''
if self.cfg.fsdp.use:
assert not self.cfg.autocast, "Cannot use autocast with fsdp"
self.model = self.wrap_with_fsdp(self.model)
self.register_ema('model')
# initialize optimization
'''
# model_groups = [
# {'params': self.model.condition_provider.conditioners['self_wav'].output_proj.parameters()}, #, 'lr': 1e-4},
# {'params': self.model.transformer.layers[0].parameters()},
# {'params': self.model.transformer.layers[1].parameters()},
# {'params': self.model.transformer.layers[2].parameters()},
# {'params': self.model.transformer.layers[3].parameters()},
# {'params': self.model.transformer.layers[4].parameters()},
# {'params': self.model.transformer.layers[5].parameters()}
# ]
'''
# self.optimizer = builders.get_optimizer(model_groups, self.cfg.optim) #including tf layers
# self.optimizer = builders.get_optimizer(self.model.condition_provider.conditioners['self_wav'].output_proj, self.cfg.optim)
'''

self.optimizer = builders.get_optimizer(builders.get_optim_parameter_groups(self.model), self.cfg.optim) #Single GPU

self.lr_scheduler = builders.get_lr_scheduler(self.optimizer, self.cfg.schedule, self.total_updates)
self.register_stateful('compression_model', 'model', 'optimizer', 'lr_scheduler')
self.register_best_state('model')
Expand All @@ -247,21 +165,6 @@ def build_model(self) -> None:
self.scaler = torch.cuda.amp.GradScaler()
self.register_stateful('scaler')

def freeze_parameters(self) -> None:
    """Freeze the model except the 'self_wav' output projection.

    Every parameter of ``self.model`` has ``requires_grad`` switched off,
    then the weight and bias of the 'self_wav' conditioner's output
    projection are switched back on. Finally ``self.optimizer`` is rebuilt
    from the model's optimizer parameter groups and ``self.cfg.optim``.
    """
    for frozen in self.model.parameters():
        frozen.requires_grad = False

    # Only the 'self_wav' output projection keeps receiving gradients.
    projection = self.model.condition_provider.conditioners['self_wav'].output_proj
    projection.weight.requires_grad = True
    projection.bias.requires_grad = True

    print("Parameters frozen!")

    # Rebuild the optimizer after the requires_grad flags have changed.
    param_groups = builders.get_optim_parameter_groups(self.model)
    self.optimizer = builders.get_optimizer(param_groups, self.cfg.optim)

def build_dataloaders(self) -> None:
"""Instantiate audio dataloaders for each stage."""
self.dataloaders = builders.get_audio_datasets(self.cfg, dataset_type=self.DATASET_TYPE)
Expand All @@ -278,7 +181,6 @@ def run(self):
assert len(self.state_dict()) > 0
self.restore(replay_metrics=True) # load checkpoint and replay history
self.log_hyperparams(dict_from_config(self.cfg))
self.freeze_parameters()
for epoch in range(self.epoch, self.cfg.optim.epochs + 1):
if self.should_stop_training():
return
Expand Down Expand Up @@ -341,7 +243,7 @@ def _compute_cross_entropy(
ce = ce / K
return ce, ce_per_codebook

# @torch.no_grad()
@torch.no_grad()
def _prepare_tokens_and_attributes(
self, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAttributes]],
check_synchronization_points: bool = False
Expand Down Expand Up @@ -454,17 +356,13 @@ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAt
if check_synchronization_points:
torch.cuda.set_sync_debug_mode('warn')

# print(torch.any(torch.isnan(audio_tokens)))
# print(torch.any(torch.isnan(condition_tensors['description'][0])))
# print(torch.any(torch.isnan(condition_tensors['self_wav'][0])))
with self.autocast:
model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors) # type: ignore
logits = model_output.logits
mask = padding_mask & model_output.mask
ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
loss = ce
self.deadlock_detect.update('loss')
# print("model_output : \n", model_output)
if check_synchronization_points:
torch.cuda.set_sync_debug_mode('default')

Expand All @@ -475,42 +373,16 @@ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAt
self.deadlock_detect.update('scale')
if self.cfg.fsdp.use:
loss.backward()
# print("layer1_w_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad)
# print("layer1_w_grad_sum: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad.sum())
# print("layer1_b_grad : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias.grad)
'''
if self.model.condition_provider.conditioners.self_wav.output_proj[0].bias.device.index == 0:
print("layer1b_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[0].bias.grad)
if self.model.condition_provider.conditioners.self_wav.output_proj[0].weight.device.index == 0:
print("layer1_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[0].weight.grad)
if self.model.condition_provider.conditioners.self_wav.output_proj[2].bias.device.index == 0:
print("layer2b_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[2].bias.grad)
if self.model.condition_provider.conditioners.self_wav.output_proj[2].weight.device.index == 0:
print("layer2_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[2].weight.grad)
if self.model.condition_provider.conditioners.self_wav.output_proj[4].bias.device.index == 0:
print("layer3b_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[4].bias.grad)
if self.model.condition_provider.conditioners.self_wav.output_proj[4].weight.device.index == 0:
print("layer3_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[4].weight.grad)
'''
flashy.distrib.average_tensors(self.model.buffers())
elif self.cfg.optim.eager_sync:
with flashy.distrib.eager_sync_model(self.model):
loss.backward()
print("layer1_w_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad)
print("layer1_w_grad_sum: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad.sum())
print("layer1_b_grad : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias.grad)
else:
# this should always be slower but can be useful
# for weird use cases like multiple backwards.
loss.backward()
print("layer1_w_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad)
print("layer1_w_grad_sum: ", self.model.condition_provider.conditioners.self_wav.output_proj.weight.grad.sum())
print("layer1_b_grad : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias.grad)
flashy.distrib.sync_model(self.model)
self.deadlock_detect.update('backward')
# print(model_output)
# print('\n')
# print(loss)
if self.scaler is not None:
self.scaler.unscale_(self.optimizer)
if self.cfg.optim.max_norm:
Expand All @@ -522,38 +394,12 @@ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAt
)
if self.scaler is None:
self.optimizer.step()
# print("layer1_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj.weight)
# print("layer1_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias)
# print("layer1_grad: ", self.model.condition_provider.conditioners.self_wav.output_proj[0].weight.grad)
# if self.model.transformer.layers[0].linear1.weight.device.index == 0:
# print("lm_layer0_linear1_weight : ", self.model.transformer.layers[0].linear1.weight)
'''
if self.model.condition_provider.conditioners.self_wav.output_proj[0].weight.device.index == 0:
print("layer1_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj[0].weight)
if self.model.condition_provider.conditioners.self_wav.output_proj[0].bias.device.index == 0:
print("layer1_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj[0].bias)
if self.model.condition_provider.conditioners.self_wav.output_proj[2].weight.device.index == 0:
print("layer2_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj[2].weight)
if self.model.condition_provider.conditioners.self_wav.output_proj[2].bias.device.index == 0:
print("layer2_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj[2].bias)
if self.model.condition_provider.conditioners.self_wav.output_proj[4].weight.device.index == 0:
print("layer3_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj[4].weight)
if self.model.condition_provider.conditioners.self_wav.output_proj[4].bias.device.index == 0:
print("layer3_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj[4].bias)
'''
else:
self.scaler.step(self.optimizer)
self.scaler.update()

print("layer_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj.weight)
print("layer_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias)
print("lm_layer0_linear1_weight : ", self.model.transformer.layers[0].linear1.weight)
# print("layer1_weight : ", self.model.condition_provider.conditioners.self_wav.output_proj.weight)
# print("layer1_bias : ", self.model.condition_provider.conditioners.self_wav.output_proj.bias)
if self.lr_scheduler:
self.lr_scheduler.step()
self.optimizer.zero_grad()
# self.model.zero_grad() #For partial weights optimizing Multi-GPU
self.deadlock_detect.update('optim')
if self.scaler is not None:
scale = self.scaler.get_scale()
Expand All @@ -567,10 +413,6 @@ def run_step(self, idx: int, batch: tp.Tuple[torch.Tensor, tp.List[SegmentWithAt
metrics[f'ce_q{k + 1}'] = ce_q
metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)

# for name, param in self.model.condition_provider.conditioners.self_wav.output_proj.named_parameters():
# print(name, param.requires_grad)
# print(param)

return metrics

@torch.no_grad()
Expand Down Expand Up @@ -720,8 +562,6 @@ def get_hydrated_conditions(meta: tp.List[SegmentWithAttributes]):
**self.generation_params)
gen_audio = gen_outputs['gen_audio'].cpu()
prompt_audio = gen_outputs['prompt_audio'].cpu()
print(gen_audio.shape)
print(prompt_audio.shape)
sample_manager.add_samples(
gen_audio, self.epoch, hydrated_conditions,
prompt_wavs=prompt_audio, ground_truth_wavs=audio,
Expand Down
15 changes: 14 additions & 1 deletion cog.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@ build:
# a list of ubuntu apt packages to install
system_packages:
- ffmpeg
- unzip
- build-essential
- libeigen3-dev
- libyaml-dev
- libfftw3-dev
- libtag1-dev>=1.9
- libchromaprint-dev
- numactl
- sox
# - "libgl1-mesa-glx"
# - "libglib2.0-0"

Expand Down Expand Up @@ -40,9 +49,13 @@ build:
- "encodec"
- "protobuf"
- "tensorboard>=1.15"
- "pydub"
- "numba"
- "essentia-tensorflow"

# commands run after the environment is setup
# run:
run:
- python3 -m pip install -U git+https://github.com/facebookresearch/demucs#egg=demucs
# - "apt-get update && apt-get install -y ffmpeg"
# - "apt-get install unzip"

Expand Down
46 changes: 46 additions & 0 deletions config/conditioner/chord2music.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# @package __global__

# Conditioning setup with two conditioners: 'self_wav' (model: chroma_chord)
# and 'description' (model: t5).

classifier_free_guidance:
  training_dropout: 0.2
  inference_coef: 3.0

attribute_dropout:
  args:
    active_on_eval: false
  text: {}
  wav:
    self_wav: 0.5

fuser:
  cross_attention_pos_emb: false
  cross_attention_pos_emb_scale: 1
  sum: []
  prepend: [self_wav, description]
  cross: []
  input_interpolate: []

conditioners:
  self_wav:
    model: chroma_chord
    chroma_chord:
      sample_rate: ${sample_rate}
      n_chroma: 12
      radix2_exp: 14
      argmax: true
      match_len_on_eval: false
      eval_wavs: null
      n_eval_wavs: 100
      cache_path: cache
  description:
    model: t5
    t5:
      name: t5-base
      finetune: false
      word_dropout: 0.2
      normalize_text: false

dataset:
  train:
    merge_text_p: 0.25
    drop_desc_p: 0.5
    drop_other_p: 0.5

0 comments on commit adec2ad

Please sign in to comment.