In [1]:
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
import sys
sys.path.insert(0, '/userhome/42/msd21003/TATS')

import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tats import Net2NetTransformer, VideoData

In [2]:
pl.seed_everything(1234)

parser = argparse.ArgumentParser()
parser = pl.Trainer.add_argparse_args(parser)
parser = Net2NetTransformer.add_model_specific_args(parser)
parser = VideoData.add_data_specific_args(parser)

args = parser.parse_args(args=["--num_workers", "32", "--val_check_interval", " 0.5", "--progress_bar_refresh_rate", " 500",
                    "--gpus", " 8" ,"--sync_batchnorm" ,"--batch_size", " 3",  "--unconditional",
                    "--vqvae", " ../../ckpt/vqgan_ucf.ckpt", "--data_path", " ../../ucf101", "--dataset", "ucf101", "--default_root_dir", " ../../trainGPT_ckpt",
                    "--vocab_size", " 16384", "--block_size", " 1024", "--n_layer", " 24", "--n_head", " 16", "--n_embd", " 1024",
                    "--resolution", " 128", "--sequence_length", " 16", "--max_steps", " 2000000"])

data = VideoData(args)
# pre-make relevant cached files if necessary
data.train_dataloader()
data.test_dataloader()

args.class_cond_dim = data.n_classes if not args.unconditional and args.cond_stage_key=='label' else None
model = Net2NetTransformer(args, first_stage_key=args.first_stage_key, cond_stage_key=args.cond_stage_key)

callbacks = []
callbacks.append(ModelCheckpoint(every_n_train_steps=10000, save_top_k=-1, filename='{epoch}-{step}-{train/loss:.2f}'))
callbacks.append(ModelCheckpoint(every_n_train_steps=50000, save_top_k=-1, filename='{epoch}-{step}-{train/loss:.2f}'))
callbacks.append(ModelCheckpoint(monitor='val/loss', mode='min', save_top_k=3, filename='best_checkpoint'))

kwargs = dict()
if args.gpus > 1:
    # find_unused_parameters = False to support gradient checkpointing
    kwargs = dict(gpus=args.gpus,
                  # plugins=["deepspeed_stage_2"])
                  plugins=[pl.plugins.DDPPlugin(find_unused_parameters=False)])

# configure learning rate
bs, base_lr = args.batch_size, args.base_lr
ngpu = args.gpus
accumulate_grad_batches = args.accumulate_grad_batches or 1
print(f"accumulate_grad_batches = {accumulate_grad_batches}")
model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
print("Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
    model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr))

# load the most recent checkpoint file
base_dir = os.path.join(args.default_root_dir, 'lightning_logs')
if os.path.exists(base_dir):
    log_folder = ckpt_file = ''
    version_id_used = step_used = 0
    for folder in os.listdir(base_dir):
        version_id = int(folder.split('_')[1])
        if version_id > version_id_used:
            version_id_used = version_id
            log_folder = folder
    if len(log_folder) > 0:
        ckpt_folder = os.path.join(base_dir, log_folder, 'checkpoints')
        for fn in os.listdir(ckpt_folder):
            if fn == 'latest_checkpoint.ckpt':
                ckpt_file = 'latest_checkpoint_prev.ckpt'
                os.rename(os.path.join(ckpt_folder, fn), os.path.join(ckpt_folder, ckpt_file))
        if len(ckpt_file) > 0:
            args.resume_from_checkpoint = os.path.join(ckpt_folder, ckpt_file)
            print('will start from the recent ckpt %s'%args.resume_from_checkpoint)

trainer = pl.Trainer.from_argparse_args(args, callbacks=callbacks,
                                        max_steps=args.max_steps, **kwargs)

trainer.fit(model, data)

Global seed set to 1234


data_path:<class 'str'>,sequence_len:<class 'int'>,dataset:<class 'str'>,train:<class 'bool'>,dataset:<class 'type'>


TypeError: __init__() got an unexpected keyword argument 'istrain'

In [55]:
import torch
import math
import cupy
import numpy as np
import cupy as cp
from datetime import datetime
from scipy.stats import multivariate_normal
from scipy.special import softmax
from torch.autograd import Function


def getGaussian(T, H, W, beta, d):

    diag = np.diag([beta[0], beta[1], beta[1]])
    rv = multivariate_normal([T - 1, H - 1, W - 1], diag)
    tensor = torch.tensor((), dtype=torch.float32)

    NT = 2 * T - 1
    NH = 2 * H - 1
    NW = 2 * W - 1

    weight = tensor.new_ones((NT, NW, NH), device=d)

    for pos in np.arange(0, NT * NH * NW):
        i = math.floor(pos / (NH * NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        weight[i, j, k] = rv.pdf([i, j, k])

        weight = weight / torch.max(weight)

    return weight

class focusAttention(Function):

    T, H, W = 4, 4, 4
    T_flatten = T * H * W
    center_T, center_H, center_W = T - 1, H - 1, W - 1
    beta = [100, 100]
    
    diag = np.diag([beta[0], beta[1], beta[1]])
    rv = multivariate_normal([T - 1, H - 1, W - 1], diag)
    tensor = torch.tensor((), dtype=torch.float32)

    NT = 2 * T - 1
    NH = 2 * H - 1
    NW = 2 * W - 1

    weight_cuda0 = tensor.new_ones((NT, NW, NH), device=torch.device("cuda:0"))

    for pos in np.arange(0, NT * NH * NW):
        i = math.floor(pos / (NH * NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        weight_cuda0[i, j, k] = rv.pdf([i, j, k])

        weight_cuda0 = weight_cuda0 / torch.max(weight_cuda0)
    
    weight_cuda1 = weight_cuda0.detach().to("cuda:1")

    @staticmethod
    def forward(ctx, score, V):

        att=[]
        
        if V.get_device() == 0:
            weight = focusAttention.weight_cuda0
        else:
            weight = focusAttention.weight_cuda1
        
        st = torch.cuda.memory_allocated()
        
        st_loop = datetime.now()
        
        for pos in np.arange(0, focusAttention.T_flatten):
            
            st = datetime.now()
            print(f"start of loop {st}")
            
            i = math.floor(pos / (focusAttention.H * focusAttention.W))
            j = math.floor((pos - i * focusAttention.H * focusAttention.W) / focusAttention.H)
            k = pos - i * focusAttention.H * focusAttention.W - j * focusAttention.W
            
            t1 = datetime.now()
            
            print(f"pos calculate {t1 - st}")
            

            weight_xyz = weight[focusAttention.center_T - i:2 * focusAttention.center_T - i + 1, focusAttention.center_W - j:2 * focusAttention.center_W - j + 1,
                         focusAttention.center_H - k:2 * focusAttention.center_H - k + 1].reshape(-1)
            
            t2 = datetime.now()
            print(f"weight idx {t2 - t1}")
            
            weight_xyz = weight_xyz[None, None, :, None]
            t3 = datetime.now()
            
            print(f"weight boardcast {t3 - t2}")
            
            # qk shape (B, NH, 1, T)
            qk = score[:, :, pos, :]
            qk = qk[:, :, None, :]
            
            t4 = datetime.now()
            
            print(f"qk index and boardcast {t4 - t3}")

            att_pos = torch.matmul(qk, (V * weight_xyz)).detach()
            
            t5 = datetime.now()
            
            print(f"att cal {t5 - t4}")

            att.append(att_pos)
            
            t6 = datetime.now()
            
            print(f"append {t6 -t5}")
            # V = torch.clone(V_ori)

        
        end = datetime.now()
        
        print(f"complete loop {end - st_loop}")
        
        # print(f"Before cat {torch.cuda.memory_allocated()}")

        result = torch.cat(att, dim=2)
        
        # print(f"result shape {result.shape}")

        end = torch.cuda.memory_allocated()

        # print(f"result memory usage is {result.element_size() * result.nelement()}, memory used {end - st}, memory for v is {V.element_size() * V.nelement()}")

        # print(f"After focused attention, memory usage is {end}, memory used {end - st}")

        # torch.cuda.empty_cache()

        # print(f"After empty cache, memory usage is {torch.cuda.memory_allocated()}")

        # ctx.save_for_backward(score, V, result)

        # print(f"After save for backwards, memory usage is {torch.cuda.memory_allocated()}")

        return result

    @staticmethod
    def backward(ctx, grad_output):
        score, V, result = ctx.saved_tensors
        
        if V.get_device() == 0:
            weight = focusAttention.weight_cuda0
        else:
            weight = focusAttention.weight_cuda1

        grad_score = []
        grad_V = []

        for pos in np.arange(0, focusAttention.T_flatten):
            grad_att_pos = grad_output[:, :, pos, :]

            grad_att_pos = grad_att_pos[:, :, None, :]

            i = math.floor(pos / (focusAttention.H * focusAttention.W))
            j = math.floor((pos - i * focusAttention.H * focusAttention.W) / focusAttention.H)
            k = pos - i * focusAttention.H * focusAttention.W - j * focusAttention.W

            weight_xyz = weight[focusAttention.center_T - i:2 * focusAttention.center_T - i + 1, focusAttention.center_W - j:2 * focusAttention.center_W - j + 1,
                         focusAttention.center_H - k:2 * focusAttention.center_H - k + 1].reshape(-1)

            qk = score[:, :, pos, :]
            
            qk = qk[:, :, None, :]
            
            print(f"shape qk {qk.shape}")
            
            qk = torch.swapaxes(qk, 2, 3)
            
            weight_xyz = weight_xyz[None, None, :, None]
            
            print(f"shape qk {qk.shape}, weight_xyz {weight_xyz.shape}")
            
            print(f"shape qk*weight {(qk * weight_xyz).shape}, grad_att_pos {grad_att_pos.shape}")
            
            grad_V_pos = torch.matmul((qk * weight_xyz), grad_att_pos)[:, :, pos, :]
            
            grad_V.append(grad_V_pos[:, :, None, :])
            
            print(f"V_pos shape {grad_V_pos.shape}")
            
            grad_score_pos = (torch.matmul(weight_xyz, grad_att_pos) * V)[:, :, :, 0]
            
            print(f"grad_score_pos shape {grad_score_pos.shape}")
            
            grad_score.append(grad_score_pos[:, :, None, :])

            # grad_qk = grad_att_pos @ torch.linalg.inv(V_focus)
            # grad_V_focus = torch.linalg.inv(qk) @ grad_att_pos
            #
            # grad_score.append(grad_qk)
            # grad_V_focus = grad_V_focus * weight_xyz
            # grad_V.append(grad_V_focus)

        # Shape should be B, NH, T, T
        grad_score = torch.cat(grad_score, dim=2)

        # Shape should be B, NH, T, HS
        grad_V = torch.cat(grad_V, dim=2)
        
        # print(f"grad_score {grad_score.shape}, grad_V {grad_V.shape}")

        return grad_score, grad_V


In [38]:
import torch
import math
import cupy
import numpy as np
import cupy as cp
from datetime import datetime
from scipy.stats import multivariate_normal
from scipy.special import softmax
from torch.autograd import Function


def getGaussian(T, H, W, beta, d):

    diag = np.diag([beta[0], beta[1], beta[1]])
    rv = multivariate_normal([T - 1, H - 1, W - 1], diag)
    tensor = torch.tensor((), dtype=torch.float32)

    NT = 2 * T - 1
    NH = 2 * H - 1
    NW = 2 * W - 1

    weight = tensor.new_ones((NT, NW, NH), device=d)

    for pos in np.arange(0, NT * NH * NW):
        i = math.floor(pos / (NH * NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        weight[i, j, k] = rv.pdf([i, j, k])

        weight = weight / torch.max(weight)

    return weight

class focusAttention(Function):

    T, H, W = 4, 16, 16
    T_flatten = T * H * W
    center_T, center_H, center_W = T - 1, H - 1, W - 1
    beta = [10000, 10000]
    
    diag = np.diag([beta[0], beta[1], beta[1]])
    rv = multivariate_normal([T - 1, H - 1, W - 1], diag)
    tensor = torch.tensor((), dtype=torch.float32)

    NT = 2 * T - 1
    NH = 2 * H - 1
    NW = 2 * W - 1

    weight_cuda0 = tensor.new_ones((NT, NW, NH), device=torch.device("cuda:0"))

    for pos in np.arange(0, NT * NH * NW):
        i = math.floor(pos / (NH * NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        weight_cuda0[i, j, k] = rv.pdf([i, j, k])

    weight_cuda0 = weight_cuda0 / torch.max(weight_cuda0)
    
    # print(weight_cuda0, weight_cuda0[T-1, H-1, W-1])
    
    # Shape T, 1, 1, T, 1
    V_weight_cuda0 = torch.empty((T_flatten,1,1,T_flatten,1), dtype=torch.float32, device ="cuda:0")

    for pos in np.arange(0, T_flatten):

        i = math.floor(pos / (H * W))
        j = math.floor((pos - i * H * W) / H)
        k = pos - i * H * W - j * W

        weight_xyz = weight_cuda0[center_T - i:2 * center_T - i + 1, center_W - j:2 * center_W - j + 1,
                     center_H - k:2 * center_H - k + 1].reshape(-1)

        V_weight_cuda0[pos, 0, 0, :, 0] = weight_xyz

    V_weight_cuda1 = V_weight_cuda0.detach().to("cuda:1")

    @staticmethod
    def forward(ctx, score, V):
        
        B, NH, T_flatten, HS = V.shape
        d = V.get_device()
        
        if d == 0:
            V_weight = focusAttention.V_weight_cuda0
        else:
            V_weight = focusAttention.V_weight_cuda1
        

        st_loop = datetime.now()

        # V full shape is T, B, NH, T, HS = T, 1, 1, T, 1 * B, NH, T, HS
        V_full = V_weight * V
        
        # print("V full size", V_full.nelement() * V_full.element_size(), V_full.shape, V_weight.shape)
        # print("before score ", score.shape)
        
        qk = score[:, :, :, None, :]
        
        # qk should be T, B, NH, 1 , T 
        qk = qk.permute(2, 0, 1, 3, 4)
        
        # print("qk ", qk.shape)
        
        # result should be T, B, NH, 1, HS
        result = torch.empty((T_flatten, B, NH, 1, HS), dtype = torch.float32, device = f"cuda:{d}")
        
        # mem = torch.cuda.torch.cuda.memory_allocated()
        # result1 = torch.einsum('ijklm, ijkmn -> ijkln', qk, V_full)
        # mem1 = torch.cuda.torch.cuda.memory_allocated()
        # print("einsum used memory", mem1 - mem)
        
        div = 16
        mem2 = torch.cuda.torch.cuda.memory_allocated()
        for sub in np.arange(div):
            base = int(focusAttention.T_flatten / div)
            result[base*sub:base*(sub+1), :, :, :, :] = torch.matmul(qk[base*sub:base*(sub+1), :, :, :, :], V_full[base*sub:base*(sub+1), :, :, :, :])
            
        mem3 = torch.cuda.torch.cuda.memory_allocated()
        # print("divided memory", mem3 -mem2)
        
        result1 = torch.swapaxes(result, 0, 2)[:, :, :, 0, :]
        
#         print("result", result.shape)
        
        end = datetime.now()
        print(f"forward {end - st_loop}")
        
        ctx.save_for_backward(score, V, result)

        return result1

    @staticmethod
    def backward(ctx, grad_output):
        
        st_loop = datetime.now()
        
        score, V, result = ctx.saved_tensors
        
        B, NH, T_flatten, HS = V.shape

        # V weight shape is [1, 1, T, T]
        d = V.get_device()

        if d == 0:
            V_weight = focusAttention.V_weight_cuda0
        else:
            V_weight = focusAttention.V_weight_cuda1

        # （1, B, NH, T, T)
        score = score[None, :, :, :, :]

        # (T, B, NH, T, 1)
        score = score.permute(3, 1, 2, 4, 0)
        print(score.shape, V_weight.shape)
        
        # (1, B, NH, T, HS)
        grad_output = grad_output[None, :, :, :, :]

        # (T, B, NH, 1, HS)
        grad_output = torch.swapaxes(grad_output, 0, 3)

        grad_V = torch.empty(V.shape, dtype = torch.float32, device = f"cuda:{d}")
        grad_V_total = torch.empty((T_flatten, B, NH, T_flatten, HS), dtype = torch.float32, device = f"cuda:{d}")
        # Shape is (T, B, NH, T, HS)
        
        div = 16
        mem2 = torch.cuda.torch.cuda.memory_allocated()
        for sub in np.arange(div):
            base = int(focusAttention.T_flatten / div)
            grad_V_total[base*sub:base*(sub+1), :, :, :, :] = torch.matmul((V_weight * score)[base*sub:base*(sub+1), :, :, :, :], grad_output[base*sub:base*(sub+1), :, :, :, :])
            
        mem3 = torch.cuda.torch.cuda.memory_allocated()
        
        for i in np.arange(focusAttention.T_flatten):
            grad_V[:, :, i, :] = grad_V_total[i, :, :, i, :]

        del grad_V_total

        # (T, B, NH, T, 1) = (B, NH, T, 1) * (T, B, NH, 1, 1) * (T, 1, 1, T, 1)

        grad_score = V[:, :, :, 0][:, :, :, None] * grad_output[:, :, :, :, 0][:, :, :, :, None] * V_weight

        # (T, B, NH, T)
        grad_score = grad_score[:, :, :, :, 0]

        # (B, NH, T, T)
        grad_score = grad_score.permute(1, 2, 0, 3)
        
        end = datetime.now()
        
        print(f"backward {end - st_loop}")
        # print(f"grad_score {grad_score.shape}, grad_V {grad_V.shape}")
        return grad_score, grad_V


In [39]:

ctx = {}

score = torch.randn(4, 16, 1024, 1024, dtype=torch.float32, requires_grad=True, device='cuda:0')

V = torch.randn(4, 16, 1024, 16, dtype=torch.float32, requires_grad=True, device='cuda:0')

result = torch.randn(4, 16, 1024, 16, dtype=torch.float32, requires_grad=True, device='cuda:0')

grad_output = torch.randn(4, 16, 1024, 16, dtype=torch.float32, requires_grad=True, device='cuda:0')

class ctx_class:
    def __init__(self):
        self.saved_tensors = [score, V, result]


ctx = ctx_class()

# focusAttention.forward(ctx, score, V)

focusAttention.backward(ctx, grad_output)

torch.Size([1024, 4, 16, 1024, 1]) torch.Size([1024, 1, 1, 1024, 1])
backward 0:00:00.059186


(tensor([[[[ 1.0659e-02, -7.1752e-02, -3.2511e-02,  ...,  5.8555e-02,
             2.2246e-02, -1.0939e-01],
           [-1.0925e-01,  7.3552e-01,  3.3329e-01,  ..., -6.0096e-01,
            -2.2834e-01,  1.1229e+00],
           [-1.8819e-01,  1.2671e+00,  5.7421e-01,  ..., -1.0365e+00,
            -3.9386e-01,  1.9371e+00],
           ...,
           [ 6.5541e-02, -4.4176e-01, -2.0042e-01,  ...,  3.7484e-01,
             1.4259e-01, -7.0208e-01],
           [ 5.8338e-03, -3.9325e-02, -1.7843e-02,  ...,  3.3409e-02,
             1.2710e-02, -6.2587e-02],
           [-3.9475e-02,  2.6612e-01,  1.2076e-01,  ..., -2.2636e-01,
            -8.6125e-02,  4.2413e-01]],
 
          [[ 6.2830e-02,  5.0019e-01,  9.7075e-02,  ..., -3.7757e-01,
            -3.3462e-01, -1.1313e-01],
           [-1.3270e-01, -1.0566e+00, -2.0508e-01,  ...,  7.9852e-01,
             7.0775e-01,  2.3930e-01],
           [-1.5364e-01, -1.2234e+00, -2.3748e-01,  ...,  9.2571e-01,
             8.2057e-01,  2.7747e-01],


In [9]:
import torch


cuda = torch.device('cuda', 0)
B = 2
NH = 2
T_flatten = 64
HS = 2
Q = torch.rand(B, NH, T_flatten, HS, device = cuda)
K = torch.rand(B, NH, T_flatten, HS, device = cuda)
V = torch.rand(B, NH, T_flatten, HS, device = cuda)

score = torch.rand(B, NH, T_flatten, T_flatten, device = cuda)

A = FocusedAttention(score, V)
B = FocusedAttention(score, V)



print(A)

before V clone 17047552
After V clone 17049600
before loop 17051136
start iter 17051136
end iter 17053696
start iter 17053696
end iter 17054208
start iter 17054208
end iter 17054720
start iter 17054720
end iter 17055232
start iter 17055232
end iter 17055744
start iter 17055744
end iter 17056256
start iter 17056256
end iter 17056768
start iter 17056768
end iter 17057280
start iter 17057280
end iter 17057792
start iter 17057792
end iter 17058304
start iter 17058304
end iter 17058816
start iter 17058816
end iter 17059328
start iter 17059328
end iter 17059840
start iter 17059840
end iter 17060352
start iter 17060352
end iter 17060864
start iter 17060864
end iter 17061376
start iter 17061376
end iter 17061888
start iter 17061888
end iter 17062400
start iter 17062400
end iter 17062912
start iter 17062912
end iter 17063424
start iter 17063424
end iter 17063936
start iter 17063936
end iter 17064448
start iter 17064448
end iter 17064960
start iter 17064960
end iter 17065472
start iter 17065472


In [None]:
import torch
import math
import cupy
import numpy as np
import cupy as cp
from scipy.stats import multivariate_normal
from scipy.special import softmax
from torch.autograd import Function


def getGaussian(T, H, W, beta, d):
    diag = np.diag([beta[0], beta[1], beta[1]])
    rv = multivariate_normal([T - 1, H - 1, W - 1], diag)
    tensor = torch.tensor((), dtype=torch.float32)

    NT = 2 * T - 1
    NH = 2 * H - 1
    NW = 2 * W - 1

    weight = tensor.new_ones((NT, NW, NH), device=d)

    for pos in np.arange(0, NT * NH * NW):
        i = math.floor(pos / (NH * NW))
        j = math.floor((pos - i * NH * NW) / NH)
        k = pos - i * NH * NW - j * NW
        weight[i, j, k] = rv.pdf([i, j, k])

        weight = weight / torch.max(weight)

    return weight




class focusAttention(Function):

    T, H, W = 4, 4, 4
    T_flatten = T * H * W
    center_T, center_H, center_W = T - 1, H - 1, W - 1
    beta = [100, 100]
    device = torch.device("cuda:0")
    weight = getGaussian(T, H, W, beta, device)

    @staticmethod
    def forward(ctx, score, V):

        att=[]

        st = torch.cuda.memory_allocated()

        for pos in np.arange(0, focusAttention.T_flatten):

            # print(f"start of loop {torch.cuda.memory_allocated()}")

            i = math.floor(pos / (focusAttention.H * focusAttention.W))
            j = math.floor((pos - i * focusAttention.H * focusAttention.W) / focusAttention.H)
            k = pos - i * focusAttention.H * focusAttention.W - j * focusAttention.W

            # print(f"Before weight_xyz {torch.cuda.memory_allocated()}")

            weight_xyz = focusAttention.weight[focusAttention.center_T - i:2 * focusAttention.center_T - i + 1, focusAttention.center_W - j:2 * focusAttention.center_W - j + 1,
                         focusAttention.center_H - k:2 * focusAttention.center_H - k + 1].reshape(-1)

            # print(f"After sub indexing weight {torch.cuda.memory_allocated()}")

            weight_xyz = weight_xyz[None, None, :, None]

            # print(f"After add axis {torch.cuda.memory_allocated()}")

            # V_focused = V * weight_xyz

            # print(f"After multiply weight {torch.cuda.memory_allocated()}")

            # qk shape (B, NH, 1, T)
            qk = score[:, :, pos, :]

            # print(f"After index qk{torch.cuda.memory_allocated()}")

            qk = qk[:, :, None, :]

            # print(f"After add axis to qk {torch.cuda.memory_allocated()}")

            att_pos = torch.matmul(qk, (V * weight_xyz)).detach()

            att.append(att_pos)
            # V = torch.clone(V_ori)


        # print(f"Before cat {torch.cuda.memory_allocated()}")

        result = torch.cat(att, dim=2).detach()

        end = torch.cuda.memory_allocated()

        # print(f"result memory usage is {result.element_size() * result.nelement()}, memory used {end - st}, memory for v is {V.element_size() * V.nelement()}")

        # print(f"After focused attention, memory usage is {end}, memory used {end - st}")

        torch.cuda.empty_cache()

        # print(f"After empty cache, memory usage is {torch.cuda.memory_allocated()}")

        ctx.save_for_backward(score, V, result)

        # print(f"After save for backwards, memory usage is {torch.cuda.memory_allocated()}")

        return result

    @staticmethod
    def backward(ctx, grad_output):
        score, V, result = ctx.saved_tensor

        grad_score = []
        grad_V = []

        for pos in np.arange(0, focusAttention.T_flatten):
            grad_att_pos = grad_output[:, :, pos, :]

            grad_att_pos = grad_att_pos[:, :, None, :]

            i = math.floor(pos / (focusAttention.H * focusAttention.W))
            j = math.floor((pos - i * focusAttention.H * focusAttention.W) / focusAttention.H)
            k = pos - i * focusAttention.H * focusAttention.W - j * focusAttention.W

            weight_xyz = focusAttention.weight[focusAttention.center_T - i:2 * focusAttention.center_T - i + 1, focusAttention.center_W - j:2 * focusAttention.center_W - j + 1,
                         focusAttention.center_H - k:2 * focusAttention.center_H - k + 1].reshape(-1)

            qk = score[:, :, pos, :]

            qk = torch.swapaxes(qk, 2, 3)

            grad_V.append(torch.matmul((qk * weight_xyz), grad_att_pos)[:, :, pos, :])
            grad_score.append(torch.matmul(weight_xyz, grad_att_pos) * V[:, :, :, 0])

            # grad_qk = grad_att_pos @ torch.linalg.inv(V_focus)
            # grad_V_focus = torch.linalg.inv(qk) @ grad_att_pos
            #
            # grad_score.append(grad_qk)
            # grad_V_focus = grad_V_focus * weight_xyz
            # grad_V.append(grad_V_focus)

        # Shape should be B, NH, T, T
        grad_score = torch.cat(grad_score, dim=2)

        # Shape should be B, NH, T, HS
        grad_V = torch.cat(grad_V, dim=2)

        return grad_score, grad_V

focus = focusAttention.apply

from torch.autograd import gradcheck

input = (torch.randn(1, 1, 64, 64, dtype=torch.float64, requires_grad=True, device='cuda:0'), torch.randn(1, 1, 64, 4, dtype=torch.float64, requires_grad=True, device='cuda:0'))

test = gradcheck(focus, input, eps=1e-6, atol=1e-4)
print(test)