In [1]:
# train_model.ipynb

# Import necessary libraries
from params import PARAMS_DICTS
import os
from train.train_config import LdmTrainConfig

# Define the functions as in the original script
def args_check(args):
    """Validate the run settings in ``args``; raise AssertionError if they are invalid.

    ``args['mode']`` must be one of 'frm', 'ctp', 'lsh', 'acc'. When the mode
    is 'frm', the 'autoreg', 'external' and 'mask_bg' flags must all be falsy.
    """
    mode = args['mode']
    assert mode in ['frm', 'ctp', 'lsh', 'acc']
    if mode != 'frm':
        return
    # 'frm' mode allows none of these options.
    assert not (args['autoreg'] or args['external'] or args['mask_bg'])

def args_setting_to_fn(args):
    """Encode the run settings in ``args`` as a compact run name.

    The result has the form ``{mode}-{a}{e}-{b}{l}{p}-{d}``. Each letter is
    present only when its flag is truthy:
    a=autoreg, e=external, b=mask_bg, l=multi_label,
    p=uniform_pitch_shift, d=debug.
    """
    # Flag groups in output order; each group becomes one '-'-separated segment.
    flag_groups = (
        (('autoreg', 'a'), ('external', 'e')),
        (('mask_bg', 'b'), ('multi_label', 'l'), ('uniform_pitch_shift', 'p')),
        (('debug', 'd'),),
    )

    segments = [f"{args['mode']}"]
    for group in flag_groups:
        segments.append(''.join(letter if args[key] else '' for key, letter in group))
    return '-'.join(segments)

# Set the argument values directly (stand-ins for the original script's command-line arguments).
# 'mode' must be one of 'frm', 'ctp', 'lsh', 'acc'; args_check enforces this below.
args = {
    'output_dir': 'results',  # root folder; the run-name subfolder is appended below
    'mode': 'acc',  # set the mode you want to use
    'external': False,
    'autoreg': False,
    'mask_bg': False,
    'multi_label': False,
    'uniform_pitch_shift': False,
    'debug': False  # when True, batch_size is overridden to 2
}

# Perform argument check (raises AssertionError on an invalid setting combination)
args_check(args)

# Random pitch augmentation is used unless a uniform pitch shift is requested
random_pitch_aug = not args['uniform_pitch_shift']

# Get parameters based on the mode
# NOTE(review): `params` is the object stored in PARAMS_DICTS, not a copy. If
# `override` mutates it in place (TODO confirm), enabling debug once leaves
# batch_size=2 in that shared entry for the rest of the kernel session, even
# after this cell is re-run with debug=False.
params = PARAMS_DICTS[args['mode']]
if args['debug']:
    params.override({'batch_size': 2})

# Generate the run name from the argument settings (e.g. 'acc---' for the values above)
fn = args_setting_to_fn(args)

# Set the output directory: <output_dir>/<run name>
output_dir = os.path.join(args['output_dir'], fn)

# Create the training configuration. Arguments are passed positionally, so their
# order must match the LdmTrainConfig signature (imported from train.train_config).
# The dataset loading/analysis progress bars in this cell's output presumably
# come from this constructor — TODO confirm.
config = LdmTrainConfig(params, output_dir, args['mode'], args['autoreg'], args['external'],
                        args['mask_bg'], args['multi_label'], random_pitch_aug, args['debug'])

# Start training (left disabled in this exploration notebook; uncomment to train)
# config.train()

Loading train set: 100%|██████████| 818/818 [00:05<00:00, 160.29it/s]
Loading valid set: 100%|██████████| 91/91 [00:00<00:00, 164.47it/s]
Analyzing train set: 100%|██████████| 818/818 [01:21<00:00, 10.00it/s]
Analyzing valid set: 100%|██████████| 91/91 [00:08<00:00, 10.46it/s]


In [5]:
# Inspect the diffusion model's module tree (the output shows a LatentDiffusion wrapping a UNetModel eps_model).
config.model.ldm  # append .first_stage_model to inspect that submodule instead

LatentDiffusion(
  (eps_model): UNetModel(
    (time_embed): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): SiLU()
      (2): Linear(in_features=256, out_features=256, bias=True)
    )
    (input_blocks): ModuleList(
      (0): TimestepEmbedSequential(
        (0): Conv2d(14, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      )
      (1-2): 2 x TimestepEmbedSequential(
        (0): ResBlock(
          (in_layers): Sequential(
            (0): GroupNorm32(32, 64, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (emb_layers): Sequential(
            (0): SiLU()
            (1): Linear(in_features=256, out_features=64, bias=True)
          )
          (out_layers): Sequential(
            (0): GroupNorm32(32, 64, eps=1e-05, affine=True)
            (1): SiLU()
            (2): Dropout(p=0.0, inplace=False)
            (3): Conv2d(64, 64, 

In [28]:
# Grab the first training batch for inspection. next(iter(...)) raises
# StopIteration on an empty loader instead of silently leaving `batch` undefined.
batch = next(iter(config.train_dl))

In [34]:
# Shape of batch[0][0]; the output is (14, 128, 128).
# Presumably batch[0] is the batched model input, so this is one sample — TODO confirm.
batch[0][0].shape

torch.Size([14, 128, 128])

In [42]:
# First 6 entries along dim 0 at index 5 of dim 1, all entries of dim 2, for that sample.
batch[0][0][:6,5,:]

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0

In [21]:
# Print every index along dim 1 where slices 0:2 and 2:4 of dim 0 differ for the
# first sample. The range comes from the tensor's shape (not a hard-coded 128),
# so no rows are skipped and no IndexError occurs if the sample height changes.
first_sample = batch[0][0]
for i in range(first_sample.shape[1]):
    if (first_sample[0:2, i, :] != first_sample[2:4, i, :]).any():
        print(i)

4
6
7
8
11
15
16
17
18
20
21
22
23
33
34
35
39
40
41
42
43
44
47
52
53
54
55
63
80
81
82
83
87
88
89
90
91
92
95
100
101
102
103
111
112
113
114
116
117
118
119


In [25]:
# NOTE(review): scratch arithmetic with unexplained constants. The trailing /60/24
# looks like a minutes -> days conversion — TODO confirm and name the terms.
(15+12+10+0.1)*500/60/24

12.881944444444445

In [26]:
# NOTE(review): scratch arithmetic (15 * 500); the meaning of the factors is not recorded — TODO document.
15*500

7500

In [28]:
# Check that, in slices 0:2 of dim 0, the last 4 entries of dim 2 are all zero (output: True).
(batch[0][0][:2,:,-4:]==0.0).all()

tensor(True)

In [44]:
# NOTE(review): scratch arithmetic with unexplained constants (256 / 15 * 4) — TODO document.
256/15*4

68.26666666666667

In [None]:
def expand_background(self, background, nbpm):
    """Stretch a background along axis -2 and tile its first 12 columns across the last axis.

    form: (1/bs, 8, L, 14) -> (1/bs, 8, L * nbpm, h), where h = self.data_params['h'].

    Args:
        background: 4-D NumPy array whose last axis has at least 12 columns. Only
            the first 12 are used (presumably one octave of pitch classes — TODO confirm).
        nbpm: number of consecutive copies made of each step along axis -2
            (``ndarray.repeat`` semantics, not whole-block tiling).

    Returns:
        Array with axis -2 multiplied by ``nbpm`` and a last axis of exactly
        ``self.data_params['h']`` columns; output column k equals input column k % 12.
    """
    # Local import: in the notebook, `np` is otherwise only imported in a later cell.
    import numpy as np

    h = self.data_params['h']
    background = background[:, :, :, 0: 12]
    # Duplicate each step along axis -2 nbpm times in place.
    background = background.repeat(nbpm, axis=-2)
    # Tile the 12 columns enough times to cover h, then trim. The copy count is
    # derived from h (ceil(h / 12)) rather than fixed at 11 (= 132 columns), so the
    # output is never narrower than h.
    n_tiles = -(-h // 12)
    background = np.tile(background, reps=(1, 1, 1, n_tiles))
    return background[:, :, :, 0: h]

In [48]:
# Last 6 entries of dim 0, indices 2-3 of dim 1 and the first 16 entries of dim 2, for the same sample.
batch[0][0][-6:,2:4,:16]

tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
          0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]],

        [[0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000, 0.6000,
          0.6000, 0.6000, 

In [51]:
import numpy as np
# ndarray.repeat along axis=-2 makes consecutive copies of each row (1,2 / 1,2 / 1,2 / 3,4 ...).
# This is the stretching step expand_background applies with repeat(nbpm, axis=-2).
np.array([[[1,2],[3,4]],[[5,6],[7,8]]]).repeat(3, axis=-2)

array([[[1, 2],
        [1, 2],
        [1, 2],
        [3, 4],
        [3, 4],
        [3, 4]],

       [[5, 6],
        [5, 6],
        [5, 6],
        [7, 8],
        [7, 8],
        [7, 8]]])

In [53]:
# np.tile repeats the whole block instead (1,2 / 3,4 / 1,2 / 3,4 ...). Because reps has 4 entries
# and the array is 3-D, np.tile also prepends a new axis, so the result has shape (1, 2, 6, 2).
np.tile(np.array([[[1,2],[3,4]],[[5,6],[7,8]]]), reps=(1,1,3,1))

array([[[[1, 2],
         [3, 4],
         [1, 2],
         [3, 4],
         [1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8],
         [5, 6],
         [7, 8],
         [5, 6],
         [7, 8]]]])