/
train_video-diffusion.py
119 lines (106 loc) · 5.78 KB
/
train_video-diffusion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
from video_diffusion_pytorch.wip_video_diffusion_pytorch import Unet3D, GaussianDiffusion, Trainer
import os
import re
import argparse
import wandb
def parse_arguments():
parser = argparse.ArgumentParser(description='Video Diffusion Training')
parser.add_argument('--image_size', type=int, default=64, help='Image size')
parser.add_argument('--channels', type=int, default=3, help='Image Channel size')
parser.add_argument('--num_frames', type=int, default=16, help='Number of video frames')
parser.add_argument('--base_channel_size', type=int, default=256, help='Base channel size of 3D U-Net')
parser.add_argument('--timeemb_linears', type=int, default=2, help='Number of temporal layers')
parser.add_argument('--attn_heads', type=int, default=8, help='Number of attention heads')
parser.add_argument('--attn_dim_head', type=int, default=64, help='Dimension of attention head')
parser.add_argument('--s4d_hidden_dim', type=int, default=None, help='Hidden dimension of S4D')
parser.add_argument('--s4d_linear_dim', type=int, default=None, help='Linear dimension of S4D')
parser.add_argument('--s4d_version', type=int, default=None, help='Version of S4D')
parser.add_argument('--temporal_layer', type=str, default='attn', help='Architecture of temporal layers')
parser.add_argument('--timesteps', type=int, default=1000, help='Number of steps')
parser.add_argument('--loss_type', type=str, default='l2', help='Loss type')
parser.add_argument('--train_batch_size', type=int, default=16, help='Training batch size')
parser.add_argument('--train_lr', type=float, default=1e-4, help='Training learning rate')
parser.add_argument('--beta1', type=float, default=0.9, help='Beta1 for Adam')
parser.add_argument('--beta2', type=float, default=0.999, help='Beta2 for Adam')
parser.add_argument('--save_and_sample_every', type=int, default=1000, help='Save and sample frequency')
parser.add_argument('--train_num_steps', type=int, default=700000, help='Total training steps')
parser.add_argument('--gradient_accumulate_every', type=int, default=4, help='Gradient accumulation steps')
parser.add_argument('--ema_decay', type=float, default=0.995, help='Exponential moving average decay')
parser.add_argument('--amp', type=bool, default=False, help='Use mixed precision')
parser.add_argument('--dataset', type=str, default='movingmnist', help='Dataset name')
parser.add_argument('--folder', type=str, help='Data folder')
parser.add_argument('--results_folder', type=str, help='Results folder')
parser.add_argument('--device_ids', nargs='+', type=int, default=[0, 1, 2], help='List of GPU IDs')
parser.add_argument('--seed', type=int, default=0, help='Random seed')
# parser.add_argument('--csv_path', type=str, default="/home/acd13972py/Video-LLaMA/video_captions.csv", help='CSV path')
# parser.add_argument('--use_cond', type=bool, default=True, help='use text condition or not')
args = parser.parse_args()
return args
def main(args):
torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)
wandb_run = wandb.init(
project='ssm_vdm_sampling',
config=vars(args))
model = Unet3D(
dim = args.base_channel_size, # 256 -> 128
dim_mults = (1, 2, 4, 8),
channels = args.channels,
timeemb_linears = args.timeemb_linears,
attn_heads = args.attn_heads,
attn_dim_head = args.attn_dim_head,
s4d_hidden_dim = args.s4d_hidden_dim,
s4d_linear_dim = args.s4d_linear_dim,
temporal_arch = args.temporal_layer,
s4d_version = args.s4d_version,
)
model = torch.nn.DataParallel(model, device_ids=args.device_ids) # DP (old ver)
diffusion = GaussianDiffusion(
model,
image_size = args.image_size,
channels = args.channels,
num_frames = args.num_frames,
timesteps = args.timesteps, # number of steps
loss_type = args.loss_type # L1 or L2
)
diffusion = diffusion.cuda()
# diffusion = torch.nn.DataParallel(diffusion, device_ids=args.device_ids) # DP (new ver)
trainer = Trainer(
diffusion,
dataset=args.dataset, # dataset name
folder=args.folder, # this folder path needs to contain all your training data, as .gif files, of correct image size and number of frames
results_folder = args.results_folder,
train_batch_size = args.train_batch_size,
train_lr = args.train_lr,
beta1 = args.beta1,
beta2 = args.beta2,
save_and_sample_every = args.save_and_sample_every,
train_num_steps = args.train_num_steps, # total training steps
gradient_accumulate_every = args.gradient_accumulate_every, # gradient accumulation steps
ema_decay = args.ema_decay, # exponential moving average decay
amp = args.amp, # turn on mixed precision
num_gpus = len(args.device_ids), # number of GPUs
)
if args.resume_milestone > 0:
# resume training from last saved milestone
trainer.load(args.resume_milestone)
print(args.resume_milestone)
trainer.train(log_fn=wandb_run.log)
if __name__ == "__main__":
args = parse_arguments()
# resume training from maximum milestone
if os.path.exists(args.results_folder):
# search maximum number of saved milestions
pattern = re.compile(r'model-(\d+)\.pt$')
max_num = -1
for filename in os.listdir(args.results_folder):
match = pattern.match(filename)
if match:
num = int(match.group(1))
if num > max_num:
max_num = num
args.resume_milestone = max_num
else:
args.resume_milestone = -1
main(args)