In [1]:
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import sys
sys.path.insert(0, '/userhome/42/msd21003/TATS')

import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tats import Net2NetTransformer, VideoData

In [2]:
pl.seed_everything(1234)

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = Net2NetTransformer.add_model_specific_args(parser)
parser = VideoData.add_data_specific_args(parser)

args = parser.parse_args(args=["--num_workers", "32", "--val_check_interval", " 0.5", "--progress_bar_refresh_rate", " 500",
                    "--gpus", " 8" ,"--sync_batchnorm" ,"--batch_size", " 3",  "--unconditional",
                    "--vqvae", " ../../ckpt/vqgan_ucf.ckpt", "--data_path", " ../../ucf101", "--dataset", "ucf101", "--default_root_dir", " ../../trainGPT_ckpt",
                    "--vocab_size", " 16384", "--block_size", " 1024", "--n_layer", " 24", "--n_head", " 16", "--n_embd", " 1024",
                    "--resolution", " 128", "--sequence_length", " 16", "--max_steps", " 2000000"])

data = VideoData(args)
# pre-make relevant cached files if necessary
data.train_dataloader()
data.test_dataloader()

args.class_cond_dim = data.n_classes if not args.unconditional and args.cond_stage_key=='label' else None
model = Net2NetTransformer(args, first_stage_key=args.first_stage_key, cond_stage_key=args.cond_stage_key)

callbacks = []
callbacks.append(ModelCheckpoint(every_n_train_steps=10000, save_top_k=-1, filename='{epoch}-{step}-{train/loss:.2f}'))
callbacks.append(ModelCheckpoint(every_n_train_steps=50000, save_top_k=-1, filename='{epoch}-{step}-{train/loss:.2f}'))
callbacks.append(ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=3, filename='best_checkpoint'))

kwargs = dict()
if args.gpus > 1:
    # find_unused_parameters = False to support gradient checkpointing
    kwargs = dict(gpus=args.gpus,
                  # plugins=["deepspeed_stage_2"])
                  plugins=[pl.plugins.DDPPlugin(find_unused_parameters=False)])

# configure learning rate
bs, base_lr = args.batch_size, args.base_lr
ngpu = args.gpus
accumulate_grad_batches = args.accumulate_grad_batches or 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))

# load the most recent checkpoint file
base_dir = os.path.join(args.default_root_dir, 'lightning_logs')
if os.path.exists(base_dir):
    log_folder = ckpt_file = ''
    version_id_used = step_used = 0
    for folder in os.listdir(base_dir):
        version_id = int(folder.split('_')[1])
        if version_id > version_id_used:
            version_id_used = version_id
            log_folder = folder
    if len(log_folder) > 0:
        ckpt_folder = os.path.join(base_dir, log_folder, 'checkpoints')
        for fn in os.listdir(ckpt_folder):
            if fn == 'latest_checkpoint.ckpt':
                ckpt_file = 'latest_checkpoint_prev.ckpt'
                os.rename(os.path.join(ckpt_folder, fn), os.path.join(ckpt_folder, ckpt_file))
        if len(ckpt_file) > 0:
            args.resume_from_checkpoint = os.path.join(ckpt_folder, ckpt_file)
            print('will start from the recent ckpt %s'%args.resume_from_checkpoint)

trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks,
                                        max_steps=args.max_steps, **kwargs)

trainer.fit(model, data)

Global seed set to 1234


data_path:<class 'str'>,sequence_len:<class 'int'>,dataset:<class 'str'>,train:<class 'bool'>,dataset:<class 'type'>


TypeError: __init__() got an unexpected keyword argument 'istrain'

In [8]:
import cupy
import math
import numpy
from tqdm import trange
from torch.utils.dlpack import to_dlpack
from torch.utils.dlpack import from_dlpack
from scipy.stats import multivariate_normal

# def getGaussian(T, H, W, beta, d):
    
#     tensor = torch.tensor((), dtype=torch.float32)
#     weight = tensor.new_ones((T*H*W, T*H*W), device=d)
#     diag = numpy.diag([beta[0], beta[1], beta[1]])
    
#     rv_list = []
    
#     for x in range(T):
#         for y in range(H):
#             for z in range(W):
#                 rv_list.append(multivariate_normal([x, y, z], diag))
    
#     for idx in range(T*H*W):
#         for pos in range(T*H*W):
#             i = math.floor(pos/(H*W))
#             j = math.floor((pos - i * H * W) / H)
#             k = pos - i * H * W - j * W
#             # print(f"i {i}, j {j}, k {k}")
#             weight[idx, pos] = rv_list[idx].pdf([i, j, k])
        
#         # print(weight[idx, :])
#         weight[idx, :] = weight[idx, :] / torch.max(weight[idx, :])
    
#     return weight

# def getGaussian(T, H, W, beta, x, y, z, d):
    
#     diag = numpy.diag([beta[0], beta[1], beta[1]])

#     rv = multivariate_normal([x, y, z], diag)
    
#     tensor = torch.tensor((), dtype=torch.float32)
#     weight = tensor.new_ones((T*H*W,), device=d)
    
#     # gau = trange(T*H*W, desc='Gaussian Matrix', leave=False)
    
#     for pos in numpy.arange(0, T*H*W):
#         i = math.floor(pos/(H*W))
#         j = math.floor((pos - i * H * W) / H)
#         k = pos - i * H * W - j * W
#         # print(f"i {i}, j {j}, k {k}")
#         weight[pos] = rv.pdf([i, j, k])
        
#         weight = weight / torch.max(weight)

#     return weight

def getGaussian(T, H, W, beta, d):
    
    diag = numpy.diag([beta[0], beta[1], beta[1]])

    rv = multivariate_normal([T-1, H-1, W-1], diag)
    
    tensor = torch.tensor((), dtype=torch.float32)
    
    NT = 2*T-1
    NH = 2*H-1
    NW = 2*W-1
    
    weight = tensor.new_ones((NT, NW, NH), device=d)

    # gau = trange(T*H*W, desc='Gaussian Matrix', leave=False)
    
    for pos in numpy.arange(0, NT*NH*NW):
        i = math.floor(pos/(NH*NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        # print(f"i {i}, j {j}, k {k}")
        weight[i, j, k] = rv.pdf([i, j, k])
        
        weight = weight / torch.max(weight)

    return weight


# def FocusedAttention(score, V):
#     V = cupy.asarray(V)
#     # Q = cupy.asarray(Q)
#     # K = cupy.asarray(K)
    
#     B = V.shape[0]
#     NH = V.shape[1]
#     T_flatten = V.shape[2]
#     HS = V.shape[3]

#     T = 4
#     H = 4
#     W = 4
#     beta = [100, 100]
    
#     # V = V.reshape(B * NH, T, H, W, HS)
#     # Q = Q.reshape(B * NH, T, H, W, HS)
#     # K = K.reshape(B * NH, T, H, W, HS)
#     # A = cupy.zeros((B, NH, T, H, W, HS))
    
#     V_full = cupy.ones((T_flatten, B, NH, T_flatten, HS))
    
#     V_boardcast = V[cupy.newaxis, :]
    
#     V_full = V_full * V_boardcast
    
#     weight = getGaussian(T, H, W, beta)
    
#     weight = weight[:, cupy.newaxis, cupy.newaxis, :, cupy.newaxis]
    
#     print(f"weight {weight}")
    
#     print(f"before {V_full}")
    
#     # V_full shape is (T_flatten, B, NH, T_flatten, HS)
#     V_full = V_full * weight
    
#     print(f"After {V_full}")
    
#     # print(f"v full shape is {V_full.shape}, ")
    
#     att = cupy.ones((B, NH, T_flatten, HS))
    
#     score = cupy.asarray(score)
    
#     for pos in range(T_flatten):
        
#         # qk shape (B, NH, 1, T)
#         qk = score[:, :, pos, :]
#         qk = qk[:, :, cupy.newaxis, :]
        
#         # att_pos shape (B, NH, 1, HS)
#         att_pos = qk @ V_full[pos, :, :, :, :]
        
#         # print(f"qk {qk.shape}, V_full {V_full[pos, :, :, :, :].shape}, att_pos {att_pos.shape}")
        
#         att[:,:,pos,:] = att_pos[:, :, 0, :]
    
#     att = from_dlpack(att.toDlpack())
    
#     return att

def FocusedAttention(score, V):
    d = V.get_device()
    
    # V = cupy.asarray(V)
    print(f"before V clone {torch.cuda.memory_allocated()}")
    V_ori = torch.clone(V)
    print(f"After V clone {torch.cuda.memory_allocated()}")

    B = V.shape[0]
    NH = V.shape[1]
    T_flatten = V.shape[2]
    HS = V.shape[3]

    T = 4
    H = 4
    W = 4
    beta = [100, 100]
    
    att = []
    
    center_T = T-1
    center_H = H-1
    center_W = W-1
    
    weight = getGaussian(T, H, W, beta, d)

    # foc = trange(T_flatten, desc='Focused Attention', leave=False)
    
    mem_loop_st = torch.cuda.memory_allocated()
    print(f"before loop {mem_loop_st}")
    
    for pos in numpy.arange(0, T_flatten):
        
        # print(f"start iter {torch.cuda.memory_allocated()}")
        
        i = math.floor(pos/(H*W))
        j = math.floor((pos - i * H * W) / H)
        k = pos - i * H * W - j * W
        
        # weight_xyz = weight[center_T-i:2*center_T-i + 1, center_W-j:2*center_W-j + 1, center_H-k:2*center_H-k + 1]
        
        # print(weight_xyz.shape)
        
        weight_xyz = weight[center_T-i:2*center_T-i + 1, center_W-j:2*center_W-j + 1, center_H-k:2*center_H-k + 1].reshape(-1)

        weight_xyz = weight_xyz[None, None, :, None]
        # print(V.shape)
        
        V = V*weight_xyz
        
        del weight_xyz
        
        # qk shape (B, NH, 1, T)
        qk = score[:, :, pos, :]
        qk = qk[:, :, None, :]

        att_pos = qk @ V
        att.append(att_pos)
        
        V = torch.clone(V_ori)
        
        # print(f"end iter {torch.cuda.memory_allocated()}")
    
    print(f"end loop used {torch.cuda.memory_allocated() - mem_loop_st}")
    result = torch.cat(att, dim=2)
    
    return result

# def FocusedAttention(score, V):
#     d = V.get_device()
    
#     # V = cupy.asarray(V)
#     V_ori = torch.clone(V)

#     B = V.shape[0]
#     NH = V.shape[1]
#     T_flatten = V.shape[2]
#     HS = V.shape[3]

#     T = 4
#     H = 4
#     W = 4
#     beta = [100, 100]
    
#     # att = tensor.new_ones((B, NH, T_flatten, HS), requires_grad=True, device=d)
    
#     att = []
#     # score = cupy.asarray(score)
    
#     # foc = trange(T_flatten, desc='Focused Attention', leave=True)
    
#     weight = getGaussian(T, H, W, beta, d)
    
#     # print(f"weight {weight}")
    
#     weight = weight[:, None, None, :, None]
    
    
#     for pos in range(T_flatten):
        
#         i = math.floor(pos/(H*W))
#         j = math.floor((pos - i * H * W) / H)
#         k = pos - i * H * W - j * W
#         V = V*weight[pos, :, :, :, :]
        
#         # qk shape (B, NH, 1, T)
#         qk = score[:, :, pos, :]
#         qk = qk[:, :, None, :]

#         att_pos = qk @ V
#         att.append(att_pos)
        
#         # att[:,:,pos,:] = att[:,:,pos,:] * att_pos[:, :, 0, :]
        
#         V = torch.clone(V_ori)
    
#     # att = from_dlpack(att.toDlpack())
    
#     result = torch.cat(att, dim=2)
    
#     # print(att, result.shape)
    
#     return result

# def FocusedAttention(score, V):
#     d = V.get_device()
    
#     # V = cupy.asarray(V)
#     V_ori = torch.clone(V)

#     B = V.shape[0]
#     NH = V.shape[1]
#     T_flatten = V.shape[2]
#     HS = V.shape[3]

#     T = 4
#     H = 4
#     W = 4
#     beta = [100, 100]
    
#     # att = tensor.new_ones((B, NH, T_flatten, HS), requires_grad=True, device=d)
    
#     att = []
    
#     center_T = T-1
#     center_H = H-1
#     center_W = W-1
    
#     # score = cupy.asarray(score)
    
#     # foc = trange(T_flatten, desc='Focused Attention', leave=True)
    
#     weight = getGaussian(T, H, W, beta, d)
    
#     # print(f"weight {weight}")
    
#     weight_xyz = weight[:, None, None, :, None]
    
    
#     for pos in range(T_flatten):
        
#         i = math.floor(pos/(H*W))
#         j = math.floor((pos - i * H * W) / H)
#         k = pos - i * H * W - j * W
        
#         V = V*weight[pos, :, :, :, :]
        
#         # qk shape (B, NH, 1, T)
#         qk = score[:, :, pos, :]
#         qk = qk[:, :, None, :]

#         att_pos = qk @ V
#         att.append(att_pos)
        
#         # att[:,:,pos,:] = att[:,:,pos,:] * att_pos[:, :, 0, :]
        
#         V = torch.clone(V_ori)
    
#     # att = from_dlpack(att.toDlpack())
    
#     result = torch.cat(att, dim=2)
    
#     # print(att, result.shape)
    
#     return result

In [9]:
import torch


cuda = torch.device('cuda', 0)
B = 2
NH = 2
T_flatten = 64
HS = 2
Q = torch.rand(B, NH, T_flatten, HS, device = cuda)
K = torch.rand(B, NH, T_flatten, HS, device = cuda)
V = torch.rand(B, NH, T_flatten, HS, device = cuda)

score = torch.rand(B, NH, T_flatten, T_flatten, device = cuda)

A = FocusedAttention(score, V)
B = FocusedAttention(score, V)


print(A)

before V clone 17047552
After V clone 17049600
before loop 17051136
start iter 17051136
end iter 17053696
start iter 17053696
end iter 17054208
start iter 17054208
end iter 17054720
start iter 17054720
end iter 17055232
start iter 17055232
end iter 17055744
start iter 17055744
end iter 17056256
start iter 17056256
end iter 17056768
start iter 17056768
end iter 17057280
start iter 17057280
end iter 17057792
start iter 17057792
end iter 17058304
start iter 17058304
end iter 17058816
start iter 17058816
end iter 17059328
start iter 17059328
end iter 17059840
start iter 17059840
end iter 17060352
start iter 17060352
end iter 17060864
start iter 17060864
end iter 17061376
start iter 17061376
end iter 17061888
start iter 17061888
end iter 17062400
start iter 17062400
end iter 17062912
start iter 17062912
end iter 17063424
start iter 17063424
end iter 17063936
start iter 17063936
end iter 17064448
start iter 17064448
end iter 17064960
start iter 17064960
end iter 17065472
start iter 17065472
