In [3]:
import json
import pandas as pd
import torch

from torch.autograd import Variable
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import SubsetRandomSampler
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import classification_report
from collections import Counter
from torch import nn

import sys
import copy
from torch.utils.data import Dataset, DataLoader
from sklearn import preprocessing
import numpy as np
from torchinfo import summary

from tqdm.auto import tqdm
import warnings

import math



In [None]:
# Input locations used by the dataset-construction cell.
# NOTE(review): these are absolute paths on one specific machine and must be edited
# before the notebook can run anywhere else.
# Directory passed to HandPoseDataset as its first argument.
data_dir = "/home/negar/Documents/Datasets/ChicagoWild++/mediapipe_res_chicago/"
# CSV passed to HandPoseDataset as its third argument.
hand_detected_label = "/home/negar/Desktop/Pooya/TF-DeepHand/Transformer/sign_hand_detection_wild++_first.csv"
# Label CSV: read by convert_text_for_ctc (columns `filename` and `label_proc`) and also
# passed to HandPoseDataset.
labels_csv = "/home/negar/Desktop/Pooya/TF-DeepHand/Transformer/final.csv"


In [None]:
# Training hyper-parameters.
# NOTE(review): the DataLoader cell passes a literal batch_size=1 instead of this name.
batch_size = 1
# Worker-process count passed to both DataLoaders.
num_workers = 10
# Not referenced in the visible cells; the value equals len(chars) — TODO confirm intent.
char_counts = 32
# Learning rate handed to torch.optim.Adam.
learning_rate = 0.0001
# step_size and gamma handed to the StepLR scheduler.
optim_step_size = 10
optim_gamma = 0.1
# Number of iterations of the training loop.
num_epochs = 120
# Start/end token ids. NOTE(review): numerize() and TransformerModel.return_scores()
# use the literals 32 and 0 rather than these names.
SOS_token = 32
EOS_token = 0

In [None]:
# Decoding configuration. Only one decode_type is active; the alternative is kept commented out.
#decode_type = "beam"
decode_type = "trans"

# beam_size / lm_beta / ins_gamma are not used in the visible cells; presumably they are the
# beam_size, beta (LM weight) and gamma (insertion bonus) arguments of Decoder.beam_decode* — TODO confirm.
beam_size  = 5
lm_beta = 0.4
ins_gamma = 1.2
# Output alphabet (32 symbols). A character's position in this string is its class index,
# so '$' is index 0.
chars = "$' &.@acbedgfihkjmlonqpsrutwvyxz"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [4]:

def beam_search(
    model, 
    X,
    poses,
    predictions = 20,
    beam_width = 5,
    batch_size = 128, 
    progress_bar = 0
):
    """
    Implements Beam Search to extend the sequences given in X. The method can compute 
    several outputs in parallel with the first dimension of X.

    Parameters
    ----------    
    model: object with a ``forward(poses, X)`` method
        ``model.forward(poses, X)[0]`` is indexed as ``[:, -1, :]``, i.e. it is expected to be a
        (rows, length, vocabulary) tensor of next-token scores.

    X: LongTensor of shape (examples, length)
        The sequences to start the decoding process.

    poses: Tensor whose first dimension is ``examples``
        Conditioning input forwarded to the model. Each example's entry is repeated
        ``beam_width`` times so that it stays aligned with that example's beams.

    predictions: int
        The number of tokens to append to X.

    beam_width: int
        The number of candidates to keep in the search.

    batch_size: int
        The batch size of the inner loop of the method, which relies on the beam width. 

    progress_bar: int
        Values > 0 show a tqdm bar over decoding steps, values > 1 also over inner batches.

    Returns
    -------
    X: LongTensor of shape (examples, beam_width, length + predictions)
        The sequences extended with the decoding process, best beam first.

    probabilities: FloatTensor
        The accumulated log-probabilities of the kept beams, obtained by adding the
        log-probability of the chosen token at every step. Shape is (examples, beam_width)
        once at least one loop iteration has run.
    """
    # Imported locally: the notebook's import cell binds neither TensorDataset nor a `tud` alias.
    from torch.utils.data import DataLoader, TensorDataset

    with torch.no_grad():
        # The next command can be a memory bottleneck, but can be controlled with the batch 
        # size of the predict method.
        next_probabilities = model.forward(poses, X)[0][:, -1, :]
        vocabulary_size = next_probabilities.shape[-1]
        # squeeze() drops the example axis when there is a single example; broadcasting in the
        # loop below restores it.
        probabilities, idx = next_probabilities.squeeze().log_softmax(-1)\
            .topk(k = beam_width, dim = -1)
        # Rows become example-major: all beams of example 0, then all beams of example 1, ...
        X = X.repeat((beam_width, 1, 1)).transpose(0, 1)\
            .flatten(end_dim = -2)
        next_chars = idx.reshape(-1, 1)
        X = torch.cat((X, next_chars), dim = -1)

        # Repeat every example's poses once per beam, in the same example-major order as X.
        poses = poses.repeat_interleave(beam_width, dim = 0)
        # This has to be minus one because we already produced a round
        # of predictions before the for loop.
        predictions_iterator = range(predictions - 1)
        if progress_bar > 0:
            predictions_iterator = tqdm(predictions_iterator)
        for _ in predictions_iterator:
            # Batch X together with its poses so every model call sees matching rows.
            loader = DataLoader(TensorDataset(X, poses), batch_size = batch_size)
            next_probabilities = []
            iterator = iter(loader)
            if progress_bar > 1:
                iterator = tqdm(iterator)
            for x, pose_batch in iterator:
                next_probabilities.append(
                    model.forward(pose_batch, x)[0][:, -1, :].log_softmax(-1)
                )
            next_probabilities = torch.cat(next_probabilities, dim = 0)
            next_probabilities = next_probabilities.reshape(
                (-1, beam_width, next_probabilities.shape[-1])
            )
            # Score of every (beam, next token) pair, then keep the best beam_width per example.
            probabilities = probabilities.unsqueeze(-1) + next_probabilities
            probabilities = probabilities.flatten(start_dim = 1)
            probabilities, idx = probabilities.topk(
                k = beam_width, 
                dim = -1
            )
            # A flat index encodes parent_beam * vocabulary_size + token.
            next_chars = torch.remainder(idx, vocabulary_size).flatten()\
                .unsqueeze(-1)
            best_candidates = torch.div(idx, vocabulary_size, rounding_mode = "floor")
            # Shift per-example beam indices to row indices of the flattened X.
            best_candidates += torch.arange(
                X.shape[0] // beam_width, 
                device = X.device
            ).unsqueeze(-1) * beam_width
            X = X[best_candidates].flatten(end_dim = -2)
            X = torch.cat((X, next_chars), dim = 1)
        return X.reshape(-1, beam_width, X.shape[-1]), probabilities

class Decoder(object):
    """Turns per-frame class probabilities into label sequences (greedy or beam search).

    Sequences are collapsed while decoding: a repeated symbol is merged with its predecessor
    and a symbol that follows the blank replaces the blank.
    """

    def __init__(self, labels, blank_index=0):
        """
        labels: sequence of symbols; position i is the symbol for class index i.
        blank_index: class index treated as the blank symbol.
        """
        self.labels = labels
        self.int_to_char = dict(enumerate(labels))
        self.char_to_int = {c: i for i, c in enumerate(labels)}

        print(self.int_to_char)

        self.blank_index = blank_index
        # Index of the space symbol, or len(labels) when the alphabet has no space.
        self.space_index = labels.index(' ') if ' ' in labels else len(labels)

    def greedy_decode(self, prob, digit=False):
        """Best-path decoding of `prob` ([seq_len, num_labels] numpy array).

        Returns a list of symbols, or of class indices when `digit` is true.
        """
        indexes = np.argmax(prob, axis=1)
        string = []
        prev_index = -1
        for index in indexes:
            if index == self.blank_index:
                # A blank resets the repeat detector.
                prev_index = -1
                continue
            if index == prev_index:
                continue
            if digit is False:
                char = self.int_to_char[index]
                # Once at least two symbols were emitted, a symbol equal to the last emitted
                # one is dropped even if a blank separated them.
                if len(string) > 1 and char == string[-1]:
                    continue
                string.append(char)
            else:
                string.append(index)
            prev_index = index
        return string

    @staticmethod
    def _log_prob(p):
        """Natural log of a probability; an exact zero maps to -inf instead of raising."""
        return float('-inf') if p == 0 else math.log(p)

    def _init_beams(self, prob, beam_size):
        """Start one beam per top-`beam_size` class of the first frame (ascending score)."""
        first = np.argsort(prob[0, :])[-beam_size:].tolist()
        beam_prob = [self._log_prob(prob[0, x]) for x in first]
        return [[x] for x in first], beam_prob

    def _expand_and_merge(self, prob_t, beam_idx, beam_prob, beam_size):
        """Extend every beam with the frame's top classes and merge equal collapsed sequences.

        Returns (sequences as lists of symbols, their log-scores), in first-seen order.
        """
        topk_idx = np.argsort(prob_t)[-beam_size:].tolist()
        topk_logp = [self._log_prob(prob_t[x]) for x in topk_idx]

        merged_strs, merged_scores, position = [], [], {}
        for parent, parent_score in zip(beam_idx, beam_prob):
            for token, token_logp in zip(topk_idx, topk_logp):
                score = parent_score + token_logp
                last = parent[-1]
                if token == last:
                    beam = parent                    # repeated symbol collapses
                elif last == self.blank_index:
                    beam = parent[:-1] + [token]     # symbol replaces a trailing blank
                else:
                    beam = parent + [token]
                beam_str = [self.int_to_char[x] for x in beam]
                key = tuple(beam_str)
                if key in position:
                    k = position[key]
                    merged_scores[k] = np.logaddexp(merged_scores[k], score)
                else:
                    position[key] = len(merged_strs)
                    merged_strs.append(beam_str)
                    merged_scores.append(score)
        return merged_strs, merged_scores

    def _lm_inputs(self, merged_strs):
        """Strip a trailing blank from each sequence; returns (strings, their lengths)."""
        blank_char = self.int_to_char[self.blank_index]
        strings = [s[:-1] if s[-1] == blank_char else s for s in merged_strs]
        return strings, [len(s) for s in strings]

    def _beam_search(self, prob, beam_size, beta, gamma, score_fn, digit):
        """Shared beam-search loop.

        score_fn: None, or a callable mapping a list of symbol lists to one score per list.
        Beams are ranked by score + beta * lm_score + gamma * length when score_fn is given,
        but the scores carried to the next frame never include those two extra terms.
        """
        seqlen = len(prob)
        if seqlen == 0:
            return []
        beam_idx, beam_prob = self._init_beams(prob, beam_size)
        for t in range(1, seqlen):
            merged_strs, merged_scores = self._expand_and_merge(
                prob[t, :], beam_idx, beam_prob, beam_size)
            if score_fn is None:
                ranking = merged_scores
            else:
                strings, ins_bonus = self._lm_inputs(merged_strs)
                lm_scores = score_fn(strings)
                ranking = [merged_scores[b] + beta * lm_scores[b] + gamma * ins_bonus[b]
                           for b in range(len(merged_scores))]
            keep = np.argsort(np.array(ranking))[-beam_size:].tolist()
            beam_idx = [[self.char_to_int[c] for c in merged_strs[k]] for k in keep]
            beam_prob = [merged_scores[k] for k in keep]

        # argsort is ascending, so the best beam is the last one.
        best = beam_idx[-1]
        pred = best[:-1] if self.blank_index in best else best
        if digit is False:
            pred = [self.int_to_char[x] for x in pred]
        return pred

    def beam_decode(self, prob, beam_size, beta=0.0, gamma=0.0, scorer=None, digit=False):
        """Beam search over `prob` ([seq_len, num_labels] numpy array of probabilities).

        beta: weight of the scorer's score; gamma: weight of the sequence-length bonus.
        scorer: optional object whose get_score_fast(list_of_symbol_lists) returns one
            score per list; when None, beams are ranked by their own score only.
        Returns a list of symbols, or of class indices when `digit` is true.
        """
        score_fn = None if scorer is None else scorer.get_score_fast
        return self._beam_search(prob, beam_size, beta, gamma, score_fn, digit)

    # def get_trans_score(self, decoder , strings):

    def beam_decode_trans(self, prob, beam_size, decoder, poses , beta=0.0, gamma=0.0, digit=False):
        """Beam search rescored by `decoder.return_scores(poses, strings, char_to_int)`.

        Same search as beam_decode, except that the rescoring call is always made.
        """
        def score_fn(strings):
            return decoder.return_scores(poses, strings, self.char_to_int)

        return self._beam_search(prob, beam_size, beta, gamma, score_fn, digit)

In [2]:
#Utilities
def remove_duplicates(x):
    """Collapse runs of identical consecutive symbols in `x` into a single symbol.

    Inputs shorter than two symbols are returned unchanged; otherwise the result is
    built up by concatenation, starting from the first symbol.
    """
    if len(x) < 2:
        return x
    collapsed = ""
    for symbol in x:
        if collapsed == "":
            collapsed = symbol
        elif symbol != collapsed[-1]:
            collapsed = collapsed + symbol
    return collapsed


def decode_predictions(preds, encoder):
    """Decode a batch of per-step class scores into strings.

    preds: tensor whose dimension 2 holds the class scores; class 0 is treated as the
        "no symbol" class, and class k > 0 is mapped through
        ``encoder.inverse_transform([k - 1])``.
    encoder: object exposing ``inverse_transform`` (presumably a fitted label encoder —
        TODO confirm against the caller).
    Returns one string per row, with the "no symbol" steps removed and consecutive
    duplicates collapsed.
    """
    class_ids = torch.argmax(torch.softmax(preds, 2), 2).detach().cpu().numpy()
    sign_preds = []
    for row in range(class_ids.shape[0]):
        symbols = []
        for raw_id in class_ids[row, :]:
            label_id = raw_id - 1
            if label_id == -1:
                # Placeholder for the "no symbol" class; stripped again below.
                symbols.append("§")
            else:
                symbols.append(encoder.inverse_transform([label_id])[0])
        text = "".join(symbols).replace("§", "")
        sign_preds.append(remove_duplicates(text))
    return sign_preds

def numerize(sents, vocab_map, full_transformer, sos_token=32, eos_token=0):
    """Map each sentence to a list of vocabulary indices.

    sents: iterable of sentences (each an iterable of symbols). Float entries, which is
        how missing values (NaN) appear, are skipped.
    vocab_map: dict from symbol to integer index.
    full_transformer: when true, each sequence is wrapped as
        [sos_token] + indices + [eos_token].
    sos_token, eos_token: wrapping token ids; the defaults keep the previously
        hard-coded values 32 and 0.
    Returns a list with one index list per non-float sentence.
    """
    outs = []
    for sent in sents:
        # isinstance also catches numpy.float64, which is a subclass of float.
        if isinstance(sent, float):
            continue
        ids = [vocab_map[symbol] for symbol in sent]
        if full_transformer:
            ids = [sos_token] + ids + [eos_token]
        outs.append(ids)
    return outs

def invert_to_chars(sents, inv_ctc_map):
    """Map index sequences back to symbols, stopping each sequence at the first 0.

    sents: 2-D tensor of indices (one row per sequence); it may live on any device.
    inv_ctc_map: dict from index to symbol.
    Returns a single flat list holding the symbols of all rows, in row order.
    """
    # .cpu() is required before .numpy() when the tensor is on a GPU.
    rows = sents.detach().cpu().numpy()
    outs = []
    for row in rows:
        for index in row:
            if index == 0:
                break
            outs.append(inv_ctc_map[index])
    return outs

def get_ctc_vocab(char_list):
    """Build the CTC vocabulary: "_" (the blank) is prepended, so it gets index 0.

    Returns (symbol -> index dict, index -> symbol dict, the extended symbol string).
    """
    ctc_char_list = "_" + char_list
    ctc_map = {symbol: index for index, symbol in enumerate(ctc_char_list)}
    inv_ctc_map = {index: symbol for index, symbol in enumerate(ctc_char_list)}
    return ctc_map, inv_ctc_map, ctc_char_list

def get_autoreg_vocab(char_list):
    """Build the vocabulary used as-is: no blank symbol is added, index i is char_list[i].

    Returns (symbol -> index dict, index -> symbol dict, char_list unchanged).
    """
    inv_ctc_map = dict(enumerate(char_list))
    ctc_map = {symbol: index for index, symbol in inv_ctc_map.items()}
    return ctc_map, inv_ctc_map, char_list


def convert_text_for_ctc(DATASET_CSV_PATH, vocab_map, full_transformer=False):
    """Read the label CSV and encode every label as a list of vocabulary indices.

    DATASET_CSV_PATH: CSV with at least the columns `filename` and `label_proc`;
        rows where either is missing are dropped.
    vocab_map: dict from symbol to index, forwarded to numerize().
    full_transformer: forwarded to numerize() (adds start/end tokens when true).
    Returns a DataFrame with columns `names` (file names) and `enc` (index lists).
    """
    table = pd.read_csv(DATASET_CSV_PATH)
    usable = table['filename'].notna() & table['label_proc'].notna()
    table = table[usable]

    encoded = numerize(table["label_proc"], vocab_map, full_transformer)

    df = pd.DataFrame()
    df["names"] = table["filename"]
    df["enc"] = encoded
    return df


# Running confusion counts for substitutions between lowercase letters:
# subs[i][j] counts how often letter chr(97 + i) of `s` was aligned with letter chr(97 + j) of `t`.
subs = np.zeros((26,26))


def _is_lowercase_letter(symbol):
    """True only for a one-character string in 'a'..'z' (the symbols `subs` has rows for)."""
    return isinstance(symbol, str) and len(symbol) == 1 and 'a' <= symbol <= 'z'


def iterative_levenshtein(s, t, costs=(1, 1, 1)):
    """ 
    Computes a Levenshtein alignment between the sequences s and t.
    For all i and j, dist[i][j] contains the weighted Levenshtein
    distance between the first i symbols of s and the
    first j symbols of t.
    s: source, t: target
    costs: a tuple or a list with three integers (d, i, s)
           where d defines the costs for a deletion
                 i defines the costs for an insertion and
                 s defines the costs for a substitution
    return: 
    (H, D, S, I): matched symbols, symbols of t with no counterpart in s,
    substitutions, symbols of s with no counterpart in t.

    Side effect: every substitution between two lowercase letters increments the
    module-level `subs` matrix; substitutions involving any other symbol are not recorded.
    """
    rows = len(s)+1
    cols = len(t)+1
    deletes, inserts, substitutes = costs

    dist = [[0] * cols for _ in range(rows)]
    H, D, S, I = 0, 0, 0, 0
    for row in range(1, rows):
        dist[row][0] = row * deletes
    for col in range(1, cols):
        dist[0][col] = col * inserts

    for col in range(1, cols):
        for row in range(1, rows):
            cost = 0 if s[row-1] == t[col-1] else substitutes
            dist[row][col] = min(dist[row-1][col] + deletes,
                                 dist[row][col-1] + inserts,
                                 dist[row-1][col-1] + cost)

    # Walk back from the bottom-right corner and classify every step of the alignment.
    row, col = rows-1, cols-1
    while row != 0 or col != 0:
        if row == 0:
            I += col
            col = 0
        elif col == 0:
            D += row
            row = 0
        elif dist[row][col] == dist[row-1][col] + deletes:
            D += 1
            row = row-1
        elif dist[row][col] == dist[row][col-1] + inserts:
            I += 1
            col = col-1
        elif dist[row][col] == dist[row-1][col-1] + substitutes:
            S += 1
            row, col = row-1, col-1
            # Only letters have a slot in `subs`; indexing with ord(c) - 97 for any other
            # symbol would raise IndexError or silently hit the wrong cell.
            if _is_lowercase_letter(s[row]) and _is_lowercase_letter(t[col]):
                subs[ord(s[row])-97][ord(t[col])-97] += 1
        else:
            H += 1
            row, col = row-1, col-1
    # The walk counted steps from the point of view of s; swap so D and I are relative to t.
    D, I = I, D
    return H, D, S, I

def compute_acc(preds, labels, costs=(7, 7, 10)):
    """Accuracy in percent, 100 * (N - D - S - I) / N, accumulated over all pairs.

    preds, labels: equally long collections of sequences; N is the total label length and
        D, S, I come from iterative_levenshtein(pred, label, costs).
    costs: (deletion, insertion, substitution) weights.
    Raises ValueError when the two collections differ in length and ZeroDivisionError
    when the total label length is zero. Prints the counts and the `subs` matrix.
    """
    # cost according to HTK: http://www.ee.columbia.edu/~dpwe/LabROSA/doc/HTKBook21/node142.html
    if len(preds) != len(labels):
        raise ValueError('# predictions not equal to # labels')

    Ns = Ds = Ss = Is = 0
    for k in range(len(preds)):
        _, D, S, I = iterative_levenshtein(preds[k], labels[k], costs)
        Ns += len(labels[k])
        Ds += D
        Ss += S
        Is += I

    try:
        acc = 100*(Ns-Ds-Ss-Is)/Ns
    except ZeroDivisionError:
        raise ZeroDivisionError('Empty labels')

    print(Ds, Ss, Is, Ns)
    print(subs)
    return acc

# compute_acc(["akbyr"],["aaaabkbar"])

def ctcloss_reference(log_probs, targets, input_lengths, target_lengths, blank=0, reduction='none', logits_lm =None ):
    """Reference (non-log-space) CTC loss with an optional additive language-model term.

    log_probs: (time, batch, classes) log-probabilities.
    targets: (batch, max_target_len) indices, or a 1-D concatenation of all targets.
    input_lengths, target_lengths: per-sample lengths (anything torch.as_tensor accepts).
    blank: index of the blank class.
    reduction: 'none' (per-sample losses), 'mean' (losses divided by target length, then
        averaged) or 'sum'.
    logits_lm: optional 2-D scores (target position x classes). When given, the softmax
        probability of each target symbol is added to the forward variable feeding that
        symbol's state at every time step. When None, the plain CTC forward recursion is used.
        NOTE(review): the term is always looked up with targets[0], so it is only
        meaningful for a single-sample batch with 2-D targets — confirm with the author.
    """
    input_lengths = torch.as_tensor(input_lengths, dtype=torch.long)
    target_lengths = torch.as_tensor(target_lengths, dtype=torch.long)
    dt = log_probs.dtype
    log_probs = log_probs.double()  # we need the accuracy as we are not in logspace
    targets = targets.long()
    cum_target_lengths = target_lengths.cumsum(0)
    # Normalise the LM scores once; doing it inside the loop would re-apply softmax per sample.
    lm_probs = None if logits_lm is None else torch.softmax(logits_lm, dim=-1)
    losses = []
    for i in range(log_probs.size(1)):
        input_length = input_lengths[i].item()
        target_length = target_lengths[i].item()
        cum_target_length = cum_target_lengths[i].item()
        # Target interleaved with blanks: blank, y1, blank, y2, ..., blank.
        targets_prime = targets.new_full((2 * target_length + 1,), blank)
        if targets.dim() == 2:
            targets_prime[1::2] = targets[i, :target_length]
        else:
            targets_prime[1::2] = targets[cum_target_length - target_length:cum_target_length]

        probs = log_probs[:input_length, i].exp()

        alpha = log_probs.new_zeros((target_length * 2 + 1,))
        if lm_probs is None:
            lm_alpha = alpha.new_zeros(alpha.shape)
        else:
            lm_alpha = log_probs.new_zeros((target_length * 2 + 1,), dtype=lm_probs.dtype).float()
            # LM probability of the k-th target symbol at the k-th target position.
            lm_alpha[1::2] = torch.diagonal(lm_probs[:, targets[0]], 0)

        alpha[0] = probs[0, blank]
        alpha[1] = probs[0, targets_prime[1]]
        # Skipping over a blank is only allowed between two different symbols.
        mask_third = (targets_prime[:-2] != targets_prime[2:])

        for t in range(1, input_length):
            alpha_next = alpha.clone()
            alpha_next[1:] += (alpha[:-1] + lm_alpha[1:])
            alpha_next[2:] += torch.where(mask_third, alpha[:-2], alpha.new_zeros(1))
            alpha = probs[t, targets_prime] * alpha_next

        losses.append(-alpha[-2:].sum().log()[None])
    output = torch.cat(losses, 0)
    if reduction == 'mean':
        return (output / target_lengths.to(dtype=output.dtype, device=output.device)).mean()
    elif reduction == 'sum':
        return output.sum()
    output = output.to(dt)
    return output

In [3]:
# Decoder-side vocabulary: index i maps to chars[i]; no extra blank symbol is added.
# NOTE(review): this cell needs `chars` from the configuration cell. Its recorded output is a
# NameError, i.e. it was last executed before that cell had been run.
vocab_map, inv_vocab_map, char_list = get_autoreg_vocab(chars)
# blank_index=0 is the position of '$' in `chars`.
decoder_dec = Decoder(char_list, blank_index=0)

print(vocab_map)
print(inv_vocab_map)
print(char_list)



NameError: name 'chars' is not defined

In [None]:
# Encode all labels with start/end tokens (third argument True -> full_transformer).
target_enc_df = convert_text_for_ctc(labels_csv,vocab_map,True)

# NOTE(review): `transforms` and `HandPoseDataset` are neither imported nor defined in the
# visible cells, and `GaussianNoise` / `TransformerModel` are defined in cells further down,
# so those cells have to be executed before this one.
transform = transforms.Compose([GaussianNoise()])

dataset_train = HandPoseDataset(data_dir, labels_csv ,hand_detected_label, target_enc_df , "train", transform=transform)
# The batch size is the literal 1 here, not the `batch_size` name from the hyper-parameter cell.
traindataloader = DataLoader(dataset_train, batch_size=1, shuffle=True, num_workers=num_workers)

dataset_test = HandPoseDataset(data_dir, labels_csv , hand_detected_label, target_enc_df , "test" , augmentations =False )
testdataloader = DataLoader(dataset_test, batch_size=1, shuffle=False, num_workers=num_workers)

# One output class per symbol of the decoder vocabulary; each pose frame has 42 features.
model = TransformerModel(output_dim=len(char_list), d_input = 42 ,d_model=256, nhead=8, num_layers=3, dropout=0.1).to(device)


In [None]:
# Adam over all model parameters; the two alternatives below are kept for reference only.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)

# CTC loss with class 0 as blank, returning per-sample losses; infinite losses are zeroed.
# Presumably applied to the model's encoder output — the training step is not visible here.
loss_encoder = nn.CTCLoss(blank=0,zero_infinity=True, reduction='none')
# Cross-entropy, presumably for the decoder logits — TODO confirm in the training loop.
loss_decoder = nn.CrossEntropyLoss()
# loss_decoder = nn.NLLLoss()
# Mean-squared error, presumably for the 2-value class-token head — TODO confirm.
loss_cls = torch.nn.MSELoss()


In [None]:
# Multiply the learning rate by optim_gamma every optim_step_size scheduler steps.
scheduler = StepLR(optimizer, step_size=optim_step_size, gamma=optim_gamma)

# Encoder-side (CTC) vocabulary: the first symbol of `chars` ('$') is dropped and
# get_ctc_vocab prepends "_" as the blank, which therefore has index 0.
vocab_map_enc, inv_vocab_map_enc, char_list_enc = get_ctc_vocab(chars[1:])
decoder_enc = Decoder(char_list_enc, blank_index=0)
print(vocab_map_enc)
print(inv_vocab_map_enc)
print(char_list_enc)


In [None]:
# Not used in the visible cells; presumably the best evaluation accuracy seen so far — TODO confirm.
best_acc = 0


In [None]:
class GaussianNoise(object):
    """Transform that adds Gaussian noise: ``tensor + N(0, 1) * std + mean``, element-wise."""

    def __init__(self, mean=0., std=0.001):
        """
        mean: constant offset added to every element.
        std: scale applied to the standard-normal noise.
        """
        self.std = std
        self.mean = mean

    def __call__(self, tensor):
        """Return a noisy copy of `tensor`; the input itself is not modified."""
        # Draw the noise on the input's device so GPU tensors do not hit a device mismatch.
        noise = torch.randn(tensor.size(), device=tensor.device)
        return tensor + noise * self.std + self.mean

    def __repr__(self):
        return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)

In [None]:
#MODEL
def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined:
    the last axis becomes [sin(x0), cos(x0), sin(x1), cos(x1), ...].
    """
    interleaved = torch.stack((torch.sin(sin_inp), torch.cos(sin_inp)), dim=-1)
    return interleaved.flatten(-2, -1)

class PositionalEncoding1D(nn.Module):
    """Sinusoidal positional encoding along the middle axis of a (batch, x, ch) tensor."""

    def __init__(self, channels):
        """
        :param channels: The last dimension of the tensor you want to apply pos emb to.
        """
        super(PositionalEncoding1D, self).__init__()
        self.org_channels = channels
        # Round up to an even number so every frequency gets a (sin, cos) pair.
        channels = int(np.ceil(channels / 2) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        # Last encoding produced; reused while the input keeps the same shape, dtype and device.
        self.cached_penc = None

    def forward(self, tensor):
        """
        :param tensor: A 3d tensor of size (batch_size, x, ch)
        :return: Positional Encoding Matrix of size (batch_size, x, ch), with the dtype and
            device of `tensor`. The values depend only on the shape of `tensor`; the
            returned object is the cached tensor, so callers must not modify it in place.
        """
        if len(tensor.shape) != 3:
            raise RuntimeError("The input tensor has to be 3d!")

        cached = self.cached_penc
        # Shape alone is not enough: a same-shaped tensor of another dtype or device must
        # not receive the stale encoding.
        if (cached is not None
                and cached.shape == tensor.shape
                and cached.dtype == tensor.dtype
                and cached.device == tensor.device):
            return cached

        self.cached_penc = None
        batch_size, x, orig_ch = tensor.shape
        # Positions are created next to inv_freq so the product below is on one device.
        pos_x = torch.arange(x, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        emb_x = get_emb(sin_inp_x)
        emb = torch.zeros((x, self.channels), device=tensor.device, dtype=tensor.dtype)
        emb[:, : self.channels] = emb_x

        self.cached_penc = emb[None, :, :orig_ch].repeat(batch_size, 1, 1)
        return self.cached_penc
        
class TransformerModel(nn.Module):
    """Transformer encoder/decoder over pose sequences.

    The encoder sees a learned class token followed by the pose frames. Three heads read
    its output: `fc_cls` on the class-token position, `fc_enc` on every frame position,
    and (through the decoder) `fc` on every target position.
    """

    def __init__(self, output_dim, d_input , d_model, nhead, num_layers, dropout=0.1):
        """
        output_dim: number of output classes of `fc` and `fc_enc`.
        d_input: number of features per pose frame.
        d_model, nhead, num_layers, dropout: transformer width, heads, depth and dropout.
        """
        super(TransformerModel, self).__init__()
        self.d_model = d_model
        self.pose_embed = nn.Linear(d_input, d_model)
        # One embedding row more than there are output classes (room for a start token).
        self.tgt_query = nn.Embedding(output_dim+1, d_model)
        self.d_input = d_input

        self.class_token = torch.nn.Parameter(
            torch.randn(1, 1, self.d_input)
        )

        # LAYERS
        self.positional_encoder = PositionalEncoding1D(d_model)

        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dropout=dropout,batch_first=True), 
            num_layers=num_layers)
        self.transformer_decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, dropout=dropout,batch_first=True), 
            num_layers=num_layers)

        self.fc = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.Dropout(0.1),
            nn.Linear(128, output_dim)
        )
        self.fc_enc = nn.Sequential(
            nn.Linear(d_model, 128),
            nn.Dropout(0.1),
            nn.Linear(128, output_dim)
        )

        self.fc_cls = nn.Sequential(nn.LayerNorm(d_model), nn.Linear(d_model, 2))

        # NOTE(review): not used by forward() or return_scores(); kept so that the parameter
        # set (and therefore saved state_dicts) stays unchanged.
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )

    def _encode(self, poses):
        """Run the encoder; returns (B, 1 + frames, d_model) with the class token at index 0."""
        bs = poses.shape[0]
        poses = poses.reshape(bs, -1, self.d_input)
        # The class token is stored with batch size 1; expand it so any batch size works.
        cls_token = self.class_token.expand(bs, -1, -1)
        poses = torch.cat([cls_token, poses], dim=1)

        pose_embedded = self.pose_embed(poses)
        pos_embd = self.positional_encoder(pose_embedded)
        src = pose_embedded + pos_embd.to(pose_embedded.device)
        return self.transformer_encoder(src)

    def _decode(self, tgt_tokens, memory):
        """Run the causally masked decoder on token ids; returns (B, L, output_dim) logits."""
        bs = memory.shape[0]
        tgt = self.tgt_query(tgt_tokens.reshape(bs, -1))
        # Positional encoding is computed from the embedded target. Computing it from the
        # integer ids (as before) yields an all-zero encoding, i.e. no position signal.
        tgt = tgt + self.positional_encoder(tgt).to(tgt.device)

        tgt_mask = self.get_tgt_mask(tgt.size(1)).to(tgt.device)
        decoder = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
        return self.fc(decoder)

    def forward(self, poses,tgt=None):
        """
        poses: tensor reshaped to (B, frames, d_input).
        tgt: optional (B, L) target token ids fed to the decoder.
        Returns (cls_out, logits, encoder_out) when `tgt` is given, otherwise
        (cls_out, encoder_out); cls_out is (B, 2), logits (B, L, output_dim) and
        encoder_out (B, frames, output_dim).
        """
        memory = self._encode(poses)                 # B x (1 + frames) x D
        encoder_out = self.fc_enc(memory[:, 1:, :])  # B x frames x V
        cls_out = self.fc_cls(memory[:, 0, :])       # B x 2

        if tgt is not None:
            logits = self._decode(tgt, memory)       # B x L x V
            return cls_out, logits, encoder_out

        return cls_out, encoder_out

    @torch.no_grad()
    def return_scores(self, poses, strings, vocab_map):
        """Score candidate strings with the decoder, given one pose sequence.

        poses: pose tensor for a single sample (the per-string reshape assumes B == 1).
        strings: list of symbol sequences; vocab_map maps each symbol to its token id.
        Returns one score per string: minus the summed cross-entropy of the string's
        symbols when the decoder is fed the start token (id 32) followed by the string's
        earlier symbols; an empty string scores 0.
        """
        scores = []
        memory = self._encode(poses)
        criterion = nn.CrossEntropyLoss()
        sos_token = 32

        for string in strings:
            if len(string) < 1:
                scores.append(0)
                continue
            token_ids = [sos_token] + [vocab_map[symbol] for symbol in string]
            # Teacher forcing: the input is shifted right by one relative to the target.
            target = torch.tensor(token_ids[1:], dtype=torch.long, device=poses.device)
            decoder_input = torch.tensor(token_ids[:-1], dtype=torch.long, device=poses.device)

            logits = self._decode(decoder_input, memory)
            loss = criterion(logits[0], target)
            # The criterion averages over symbols; multiply back to get the summed value.
            scores.append(-(len(token_ids) - 1) * loss.item())
        return scores

    def create_pad_mask(self, matrix: torch.Tensor, pad_token: int):
        # If matrix = [1,2,3,0,0,0] where pad_token=0, the result mask is
        # [False, False, False, True, True, True]
        return (matrix == pad_token)

    def get_tgt_mask(self, size):
        """Square additive mask: 0 on and below the diagonal, -inf above it.

        EX for size=5:
        [[0., -inf, -inf, -inf, -inf],
         [0.,   0., -inf, -inf, -inf],
         [0.,   0.,   0., -inf, -inf],
         [0.,   0.,   0.,   0., -inf],
         [0.,   0.,   0.,   0.,   0.]]
        """
        return torch.triu(torch.full((size, size), float('-inf')), diagonal=1)

def collate_fn(batch):
    """Collate (poses, labels) pairs by zero-padding both to the longest item in the batch.

    batch: list of (poses, labels) tensor pairs.
    Returns (padded poses, padded labels), each with the batch as the first dimension.
    Not used right now.
    """
    # Imported locally: the notebook's import cell does not provide pad_sequence.
    from torch.nn.utils.rnn import pad_sequence

    poses, labels = zip(*batch)
    poses = pad_sequence(poses, batch_first=True)
    labels = pad_sequence(labels, batch_first=True)
    return poses, labels




class GCN(torch.nn.Module):
    """Graph network: linear embedding, two multi-head TransformerConv stages, two GCNConv layers.

    Parameters
    ----------
    input_chanels: int
        Size of each node's raw feature vector (the input of `pose_embed`).
    hidden_channels1: int
        Width of the embedded features and of the input to both TransformerConv layers.
    hidden_channels2: int
        Per-head output width of the TransformerConv layers and input width of `conv3`.
    hidden_channels3: int
        Output width of `conv3` / input width of `conv4`.
    output_channels: int
        Output width of `conv4`.
    """

    # Number of attention heads used by both TransformerConv layers.
    NUM_HEADS = 8

    def __init__(self, input_chanels, hidden_channels1, hidden_channels2, hidden_channels3, output_channels):
        super().__init__()
        heads = self.NUM_HEADS

        self.pose_embed = nn.Linear(input_chanels, hidden_channels1)
        self.positional_encoder = PositionalEncoding1D(hidden_channels1)

        # NOTE(review): assumes TransformerConv concatenates its heads (the
        # torch_geometric default), i.e. emits heads * hidden_channels2
        # features. The projections below are therefore sized from
        # hidden_channels2; the previous hidden_channels1 * 8 only worked when
        # both hidden sizes were equal.
        self.transformer1 = TransformerConv(hidden_channels1, hidden_channels2, heads=heads, dropout=0.1)
        self.head_transformer1 = nn.Linear(hidden_channels2 * heads, hidden_channels1)
        self.transformer2 = TransformerConv(hidden_channels1, hidden_channels2, heads=heads, dropout=0.1)
        # Projects to hidden_channels2 because conv3 is declared with that input width.
        self.head_transformer2 = nn.Linear(hidden_channels2 * heads, hidden_channels2)

        self.conv3 = GCNConv(hidden_channels2, hidden_channels3)
        self.conv4 = GCNConv(hidden_channels3, output_channels)

        # Kept for attribute compatibility; neither module is applied in forward().
        self.non_linearity = nn.ELU()
        self.activation = torch.nn.Sigmoid()

    def forward(self, data):
        """Run the network on a graph object exposing `.x` and `.edge_index`.

        Returns the output of `conv4` with an extra leading dimension of size 1.
        """
        x, edge_index = data.x, data.edge_index

        # Flatten to one row per node, using the configured input width
        # instead of a hard-coded 63.
        x = x.view(-1, self.pose_embed.in_features)
        x = self.pose_embed(x)

        # The positional encoder is fed a batched tensor, so add a batch
        # dimension for the call and strip it again from the result.
        pos_embd = self.positional_encoder(x.unsqueeze(dim=0))
        x = x + pos_embd.to(x.device)[0]

        x = self.transformer1(x, edge_index)
        x = self.head_transformer1(x)
        x = self.transformer2(x, edge_index)
        x = self.head_transformer2(x)

        x = self.conv3(x, edge_index)
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv4(x, edge_index)

        return x.unsqueeze(0)


In [None]:
#Training Loop
# Each epoch: one optimisation pass over `traindataloader`, a scheduler step,
# and (from the 10th epoch onwards) a beam-search evaluation on
# `testdataloader` that checkpoints the best model.

# NOTE(review): 30 looks like the assumed maximum label length; a length n is
# mapped to the angle (n / 30 - 0.5) * 2*pi and regressed as (sin, cos).
max_label_len = 30

for epoch in range(num_epochs):
    total_loss = 0
    total_loss_cls = 0

    model.train()

    for i, (poses, labels) in enumerate(traindataloader):
        optimizer.zero_grad()

        # The model is fed the labels without their last column; the decoder
        # target below is the labels without their first column.
        cls_token, logits_lm, encoder_out = model(poses.to(device), labels[:, :-1].to(device))

        # Move the batch axis second before handing the encoder output to loss_encoder.
        log_probs_enc = F.log_softmax(encoder_out, dim=-1).permute(1, 0, 2)

        # Label length once the first and last columns are stripped
        # (presumably the SOS/EOS tokens — confirm against the dataset).
        num_chars = labels.size(1) - 2
        input_lengths = torch.full((encoder_out.size(0),), log_probs_enc.size(0), dtype=torch.long)
        target_lengths = torch.full((encoder_out.size(0),), num_chars, dtype=torch.long)

        loss_enc = loss_encoder(
            log_probs_enc,
            labels[:, 1:-1],
            input_lengths=input_lengths,
            target_lengths=target_lengths,
        )
        loss_dec = loss_decoder(logits_lm[0].cpu(), labels[:, 1:].view(-1))

        # Length-regression target on the unit circle.
        angle = (num_chars / max_label_len - 0.5) * 2 * math.pi
        gt_label_size = torch.tensor([[math.sin(angle), math.cos(angle)]], device=device)
        # The previous code built the tuple (cls_token, gt_label_size) here,
        # which can neither be added to the loss nor `.item()`-ed.
        # NOTE(review): mean-squared error is assumed as the intended
        # criterion — swap in the project's own loss if one exists.
        loss_token = F.mse_loss(cls_token, gt_label_size)

        loss = loss_dec + 5 * loss_enc + loss_token

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        total_loss_cls += loss_token.item()

        # Log after accumulating so the running mean includes the current step.
        if i % 400 == 0:
            print('Epoch {}/{} - loss: {:.4f} - loss_cls: {:.4f}' .format(epoch+1, num_epochs, total_loss/(i+1), total_loss_cls/(i+1)))

    scheduler.step()
    model.eval()

    # Skip evaluation during the first nine epochs.
    if epoch < 9:
        continue

    preds = []
    gt_labels = []

    # No gradients are needed while decoding.
    with torch.no_grad():
        for i, (poses, labels) in enumerate(testdataloader):
            poses = poses.to(device)
            cls_token, logits = model(poses)
            probs = F.softmax(logits, dim=-1)

            # Invert the (sin, cos) encoding back to a label length.
            pred_angle = torch.atan2(
                torch.tensor([cls_token[0, 0].detach().cpu()]),
                torch.tensor([cls_token[0, 1].detach().cpu()]),
            )
            pred_size = torch.round((pred_angle / (2 * math.pi) + 0.5) * max_label_len)

            current_preds = decoder_dec.beam_decode_trans(
                probs[0].detach().cpu(), beam_size, model, poses, beta=lm_beta, gamma=ins_gamma
            )
            current_preds = ''.join(current_preds)
            gt_text = ''.join(invert_to_chars(labels[:, 1:-1], inv_vocab_map))

            preds.append(current_preds)
            gt_labels.append(gt_text)

            print(current_preds, gt_text, "   ", pred_size)

    lev_acc = compute_acc(preds, gt_labels)
    if best_acc < lev_acc:
        best_acc = lev_acc
        torch.save(model.state_dict(), 'best_model.pt')

    print('Epoch {}/{} - Letter Acc: {:.4f} - Best Acc {:.4f}'.format(epoch+1, num_epochs, lev_acc, best_acc))