# Attention

In [11]:
import math
from typing import Optional
import random

import torch

import torch.nn as nn
from torch import Tensor
from torch.nn import Module, TransformerEncoder
from torch.utils.data import DataLoader
import gpytorch

## Test batch (fast_gp)

In [52]:
# Fix the global torch RNG so the sampled batches in this notebook are reproducible.
torch.manual_seed(1)
# Module-level default; the DL class built by get_batch_to_dataloader reads it as its
# class-attribute fallback for num_features.
num_features = 20 # setting this manually

class PriorDataLoader(DataLoader):
    """Marker base class for loaders that draw their batches from a prior.

    Expected interface of subclasses (as noted by the original author):
      * ``__init__`` accepts ``num_steps`` as its first argument;
      * class- or instance-level attributes ``num_features: int``,
        ``num_outputs: int`` and ``fuse_x_y: bool``;
      * optionally a ``validate`` function that accepts a transformer model.
    """

# tabpfn/utils.py
def set_locals_in_self(locals):
    """Copy every entry of a ``locals()`` snapshot, except ``self``, onto ``self``.

    Meant to be called as ``set_locals_in_self(locals())`` as the first statement of
    ``__init__``, so each constructor argument becomes an attribute of the instance.

    :param locals: the caller's ``locals()`` dict; must contain the key ``'self'``.
    """
    target = locals['self']
    for name in locals:
        if name == 'self':
            continue
        setattr(target, name, locals[name])

# Default device for get_batch: the first CUDA GPU when available, otherwise the CPU.
default_device = 'cuda:0' if torch.cuda.is_available() else 'cpu:0'

# priors/utils.py
def get_batch_to_dataloader(get_batch_method_):
    """Wrap a prior's ``get_batch`` function in a ``PriorDataLoader`` subclass.

    The returned class ``DL`` is instantiated as ``DL(num_steps, **get_batch_kwargs)``.
    Every keyword argument is stored and forwarded to ``get_batch_method_`` on each step.
    ``get_batch_kwargs`` must contain:
      * ``eval_pos_seq_len_sampler``: a zero-argument callable returning
        ``(single_eval_pos, seq_len)``;
      * ``batch_size``.
    It may also contain ``dynamic_batch_size``, which requires ``seq_len_maximum`` as well.

    :param get_batch_method_: callable returning ``(x, y, target_y)`` or
        ``(x, y, target_y, style)``.
    :return: the ``DL`` class (not an instance).
    """
    class DL(PriorDataLoader):
        get_batch_method = get_batch_method_

        # Class-level default, looked up from the module-level ``num_features`` when this
        # class body executes.
        # Caution, you might need to set self.num_features manually if it is not part of the args.
        num_features = num_features

        def __init__(self, num_steps, **get_batch_kwargs):
            # Must stay the first statement: every local here (num_steps, get_batch_kwargs)
            # becomes an attribute. The DataLoader base __init__ is not called.
            set_locals_in_self(locals())

            # The stuff outside the or is set as class attribute before instantiation.
            self.num_features = get_batch_kwargs.get('num_features') or self.num_features
            self.epoch_count = 0

        @staticmethod
        def gbm(*args, eval_pos_seq_len_sampler, **kwargs):
            """Sample ``(single_eval_pos, seq_len)``, optionally rescale the batch size, and
            call the wrapped prior.

            :return: ``((style, x, y), target_y, single_eval_pos)``; ``style`` is ``None``
                when the prior returns only three tensors.
            """
            kwargs['single_eval_pos'], kwargs['seq_len'] = eval_pos_seq_len_sampler()
            # Scales the batch size dynamically with the power of 'dynamic_batch_size'.
            # A transformer with quadratic memory usage in the seq len would need a power of 2
            # to keep memory constant. A missing, None, zero or negative value disables scaling
            # (previously a None value raised TypeError on the `> 0` comparison).
            power = kwargs.get('dynamic_batch_size')
            if power and power > 0:
                kwargs['batch_size'] = kwargs['batch_size'] * math.floor(
                    math.pow(kwargs['seq_len_maximum'], power) / math.pow(kwargs['seq_len'], power))
            batch = get_batch_method_(*args, **kwargs)
            x, y, target_y, style = batch if len(batch) == 4 else (batch[0], batch[1], batch[2], None)
            return (style, x, y), target_y, kwargs['single_eval_pos']

        def __len__(self):
            return self.num_steps

        def get_test_batch(self):
            """Draw one batch without increasing ``epoch_count``; ``model`` is optional."""
            return self.gbm(**self.get_batch_kwargs, epoch=self.epoch_count,
                            model=getattr(self, 'model', None))

        def __iter__(self):
            """Lazily yield ``num_steps`` batches for one epoch and advance ``epoch_count``.

            ``self.model`` must be assigned before iterating.
            """
            assert hasattr(self, 'model'), "Please assign model with `dl.model = ...` before training."
            # Capture the epoch now: reading self.epoch_count inside the generator would
            # pick up later increments if iter() is called again before this epoch is consumed.
            epoch = self.epoch_count
            self.epoch_count += 1
            return iter(self.gbm(**self.get_batch_kwargs, epoch=epoch, model=self.model)
                        for _ in range(self.num_steps))

    return DL

In [53]:
# from priors/fast_gp.py
# Re-seed the global torch RNG at the start of this cell.
torch.manual_seed(1)

class ExactGPModel(gpytorch.models.ExactGP):
    """Exact GP with a constant mean and a scaled RBF covariance."""

    def __init__(self, train_x, train_y, likelihood):
        super().__init__(train_x, train_y, likelihood)
        self.mean_module = gpytorch.means.ConstantMean()
        self.covar_module = gpytorch.kernels.ScaleKernel(gpytorch.kernels.RBFKernel())

    def forward(self, x):
        """Return the multivariate normal given by the mean and covariance modules at ``x``."""
        return gpytorch.distributions.MultivariateNormal(self.mean_module(x), self.covar_module(x))

def get_model(x, y, hyperparameters):
    """Build an ``ExactGPModel`` whose hyper-parameters are fixed to the given values.

    :param x: training inputs handed to the GP.
    :param y: training targets handed to the GP.
    :param hyperparameters: mapping with ``"noise"``, ``"outputscale"`` and ``"lengthscale"``.
    :return: ``(model, likelihood)``.
    """
    likelihood = gpytorch.likelihoods.GaussianLikelihood(
        noise_constraint=gpytorch.constraints.GreaterThan(1.e-9))
    model = ExactGPModel(x, y, likelihood)

    def filled_like(current, value):
        # Same shape/dtype/device as the current hyper-parameter, every entry set to `value`.
        return torch.ones_like(current) * value

    model.likelihood.noise = filled_like(model.likelihood.noise, hyperparameters["noise"])
    scale_kernel = model.covar_module
    scale_kernel.outputscale = filled_like(scale_kernel.outputscale, hyperparameters["outputscale"])
    scale_kernel.base_kernel.lengthscale = filled_like(scale_kernel.base_kernel.lengthscale,
                                                       hyperparameters["lengthscale"])
    return model, likelihood

# manually setting num_features=20
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features=20, device=default_device, hyperparameters=None,
              equidistant_x=False, fix_x=None, **kwargs):
    """Sample ``batch_size`` independent regression datasets from a GP prior.

    How ``x`` is built:
      * by default, uniformly in [0, 1), or standard normal when
        ``hyperparameters['sampling'] != 'uniform'``;
      * equidistant on [0, 1] when ``equidistant_x`` is set (requires ``num_features == 1``);
      * copied from ``fix_x`` when that is given.

    Targets are one sample from a GP with a constant mean, a scaled RBF kernel and Gaussian
    noise, using the given fixed hyper-parameters.

    :param hyperparameters: dict with at least ``noise``, ``outputscale`` and ``lengthscale``.
        A tuple/list is mapped positionally (indices 0-3 and 5-7). ``None`` uses 0.1 for all three.
    :param kwargs: ignored; absorbs the extra keywords the dataloader forwards.
    :return: ``(x, y, y)``: ``x`` has shape ``(seq_len, batch_size, num_features)`` and ``y``
        has shape ``(seq_len, batch_size)``. The target is the same tensor as ``y``.
    """
    if isinstance(hyperparameters, (tuple, list)):
        hyperparameters = {
            "noise": hyperparameters[0],
            "outputscale": hyperparameters[1],
            "lengthscale": hyperparameters[2],
            "is_binary_classification": hyperparameters[3],
            # index 4 ("num_features_used") is not used here
            "normalize_by_used_features": hyperparameters[5],
            "order_y": hyperparameters[6],
            "sampling": hyperparameters[7],
        }
    elif hyperparameters is None:
        hyperparameters = {"noise": .1, "outputscale": .1, "lengthscale": .1}

    if hyperparameters.get('verbose'):
        # .get: 'sampling' is optional (it defaults to 'uniform' below), so don't KeyError here.
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale'],
               "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size,
               'sampling': hyperparameters.get('sampling', 'uniform')})

    assert not (equidistant_x and (fix_x is not None))

    with gpytorch.settings.fast_computations(*hyperparameters.get('fast_computations', (True, True, True))):
        if equidistant_x:
            assert num_features == 1
            # Created directly on `device` so it matches the model moved there below.
            x = torch.linspace(0, 1., seq_len, device=device).unsqueeze(0).repeat(batch_size, 1).unsqueeze(-1)
        elif fix_x is not None:
            assert fix_x.shape == (seq_len, num_features)
            x = fix_x.unsqueeze(0).repeat(batch_size, 1, 1).to(device)
        elif hyperparameters.get('sampling', 'uniform') == 'uniform':
            x = torch.rand(batch_size, seq_len, num_features, device=device)
        else:
            x = torch.randn(batch_size, seq_len, num_features, device=device)

        # Retry until sampling succeeds.
        # NOTE(review): this loops forever if the failure is not transient.
        is_fitted = False
        while not is_fitted:
            try:
                with gpytorch.settings.prior_mode(True):
                    model, likelihood = get_model(x, torch.Tensor(), hyperparameters)
                    model.to(device)

                    d = likelihood(model(x))
                    sample = d.sample().transpose(0, 1)
                    is_fitted = True
            except RuntimeError:  # This can happen when torch.linalg.eigh fails. Restart with new init resolves this.
                print('GP Fitting unsuccessful, retrying.. ')
                print(x)
                print(hyperparameters)

    if torch.isnan(x).any().item():
        print({"noise": hyperparameters['noise'], "outputscale": hyperparameters['outputscale'],
               "lengthscale": hyperparameters['lengthscale'], 'batch_size': batch_size})

    # TODO: Multi output
    return x.transpose(0, 1), sample, sample  # x.shape = (T,B,H)

# NOTE: this rebinds `DataLoader`, shadowing torch.utils.data.DataLoader imported at the top of
# the notebook; from here on `DataLoader` is the GP-prior loader class wrapping `get_batch`.
DataLoader = get_batch_to_dataloader(get_batch)

In [54]:
# tabpfn/train.py -- training configuration reproduced for this notebook.
torch.manual_seed(1)

# Values chosen for this notebook (trailing notes are the author's).
steps_per_epoch = 100  # set to 10
batch_size = 200  # set to 1000

# Defaults of train(); not changed afterwards.
bptt = 10
bptt_extra_samples = None
single_eval_pos_gen = None  # rebound to a real sampler later in the notebook
extra_prior_kwargs_dict = {}
gpu_device = 'cuda:0'
device = gpu_device if torch.cuda.is_available() else 'cpu:0'

def eval_pos_seq_len_sampler():
    """Draw ``(single_eval_pos, seq_len)`` for one training step.

    ``single_eval_pos`` comes from the module-level ``single_eval_pos_gen``.
    ``seq_len`` is ``single_eval_pos + bptt_extra_samples`` when ``bptt_extra_samples`` is
    truthy, otherwise ``bptt``.
    """
    eval_pos = single_eval_pos_gen()
    seq_len = eval_pos + bptt_extra_samples if bptt_extra_samples else bptt
    return eval_pos, seq_len

priordataloader_class = DataLoader  # the GP-prior loader class (DataLoader was rebound above)

# Every keyword below is stored on the loader and forwarded to get_batch on each step;
# seq_len_maximum is only used by gbm when a dynamic_batch_size kwarg is present.
dl = priordataloader_class(num_steps=steps_per_epoch, batch_size=batch_size, eval_pos_seq_len_sampler=eval_pos_seq_len_sampler, seq_len_maximum=bptt+(bptt_extra_samples if bptt_extra_samples else 0), device=device, **extra_prior_kwargs_dict)

In [55]:
torch.manual_seed(1)  # re-seed the global torch RNG for this cell
def get_test_batch(self):
    """Draw one batch without advancing ``self.epoch_count``.

    Standalone copy of the loader method. ``self`` must provide ``gbm``, ``get_batch_kwargs``
    and ``epoch_count``. ``self.model`` is passed on if present, otherwise ``None``.
    """
    current_model = getattr(self, 'model', None)
    return self.gbm(**self.get_batch_kwargs, epoch=self.epoch_count, model=current_model)

def get_uniform_single_eval_pos_sampler(max_len, min_len=0):
    """Return a sampler drawing an evaluation position uniformly from ``range(min_len, max_len)``.

    :return: zero-argument callable usable as ``single_eval_pos_gen`` in ``train()``.
    """
    candidates = range(min_len, max_len)

    def sample():
        return random.choices(candidates)[0]

    return sample

get_sampler = get_uniform_single_eval_pos_sampler
# Exclusive upper bound (max_len) for single_eval_pos: the sampler draws from range(0, 100).
# NOTE(review): bptt_extra_samples is None, so eval_pos_seq_len_sampler returns seq_len = bptt
# (10 above). Most draws are therefore >= seq_len and would leave no evaluation rows
# (x_src[single_eval_pos:] is empty). Presumably this should be <= bptt; confirm.
permutation_invariant_max_eval_pos = 100 # value picked arbitrarily in this notebook

# Rebinds the module-level single_eval_pos_gen (None until here) read by eval_pos_seq_len_sampler.
single_eval_pos_gen = get_sampler(permutation_invariant_max_eval_pos)

In [56]:
# A batch has the form ((style, x, y), target_y, single_eval_pos); style is None here
# because get_batch returns only (x, y, target_y).
style_def = dl.get_test_batch()[0][0]

### data

In [57]:
# One fresh random batch from the GP prior (get_test_batch does not advance epoch_count).
data_gp = dl.get_test_batch()

### info about data

In [58]:
# batch[0] is the tuple (style, x, y). Each get_test_batch() call samples a NEW random batch.
print(dl.get_test_batch()[0][0]) # style: None, since get_batch returns no style tensor
print(dl.get_test_batch()[0][1].shape) # x: (seq_len=10, batch_size=200, num_features=20), sequence dim first
print(dl.get_test_batch()[0][2].shape) # y: (seq_len=10, batch_size=200), one scalar target per position

None
torch.Size([10, 200, 20])
torch.Size([10, 200])


In [59]:
# x[0, 0:4, 0:5]: first 5 features of the x vector at sequence position 0, for batch items 0-3
# (four independent datasets), not x1..x4 of a single dataset.
print(dl.get_test_batch()[0][1][0,0:4,0:5])
# y[0, 0:4]: targets at position 0 for batch items 0-3. NOTE: this call samples a new batch,
# so these y values do not belong to the x values printed above.
print(dl.get_test_batch()[0][2][0,0:4])

tensor([[0.7474, 0.6250, 0.1107, 0.2828, 0.1912],
        [0.7711, 0.6751, 0.0138, 0.4008, 0.6349],
        [0.5934, 0.3755, 0.3774, 0.0090, 0.6477],
        [0.1862, 0.3648, 0.1937, 0.8451, 0.6535]])
tensor([ 0.9026,  0.1836, -0.0214,  0.3186])


## Data

In [28]:
# Keep the sampled GP batch as `src`: ((style, x, y), target_y, single_eval_pos).
src = data_gp

## TransformerModel()

In [60]:
# Arguments passed into train() in train.py (trailing notes are the author's).
emsize = 512  # yes, same in the paper
nhead = 4  # yes, same in the paper
nhid = 2 * emsize  # yes, same in the paper: 1024
nlayers = 6  # hmm, paper says 12

# TransformerModel() arguments (`encoder` is left unset here).
n_out = 1  # can be 2 or sth else
ninp = emsize
dropout = 0.0
style_encoder = None
y_encoder = None
pos_encoder = None
decoder = None
input_normalization = False
init_method = None
pre_norm = False
activation = 'gelu'
recompute_attn = False
num_global_att_tokens = 0
full_attention = False
all_layers_same_init = False
efficient_eval_masking = True

num_features = 20  # no default, depends on the dataset

## Encoder block

### TransformerEncoderLayer()

In [73]:
from functools import partial

from torch import nn
import torch
from torch.nn.modules.transformer import _get_activation_fn, Module, Tensor, Optional, MultiheadAttention, Linear, Dropout, LayerNorm
from torch.utils.checkpoint import checkpoint

# added by Ugne (before it showed error: F is not defined)
from torch.nn import functional as F

# Full layer, including the global-attention (tuple mask) and int-mask branches.
class TransformerEncoderLayer(Module):
    r"""TransformerEncoderLayer is made up of self-attn and feedforward network.
    This standard encoder layer is based on the paper "Attention Is All You Need".
    Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N Gomez,
    Lukasz Kaiser, and Illia Polosukhin. 2017. Attention is all you need. In Advances in
    Neural Information Processing Systems, pages 6000-6010. Users may modify or implement
    in a different way during application.

    Args:
        d_model: the number of expected features in the input (required).
        nhead: the number of heads in the multiheadattention models (required).
        dim_feedforward: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value (default=0.1).
        activation: the activation function of intermediate layer, relu or gelu (default=relu).
        layer_norm_eps: the eps value in layer normalization components (default=1e-5).
        batch_first: If ``True``, then the input and output tensors are provided
            as (batch, seq, feature). Default: ``False``.
        pre_norm: If ``True``, LayerNorm is applied to the input of each sub-block
            (pre-norm); otherwise it is applied after each residual addition. Default: ``False``.
        recompute_attn: If ``True``, self-attention is run through
            ``torch.utils.checkpoint.checkpoint``. Default: ``False``.

    Examples::
        >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> out = encoder_layer(src)
    """
    __constants__ = ['batch_first']

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first,
                                            **factory_kwargs)
        # Implementation of Feedforward model
        self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = Dropout(dropout)
        self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn

        self.activation = _get_activation_fn(activation)

    def __setstate__(self, state):
        # Unpickling/copy support: a saved state without an 'activation' entry falls back to ReLU.
        if 'activation' not in state:
            state['activation'] = F.relu
        super().__setstate__(state)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            src_mask: the mask for the src sequence (optional). Three forms are handled:

                * ``None`` or a regular attention mask: ordinary self-attention.
                * an ``int``: the first evaluation position. Tokens before it attend only to
                  each other; tokens from it onward attend only to the tokens before it.
                * a ``(global, trainset, valset)`` tuple of masks: global-attention setup.
                  Requires ``batch_first=False`` and no ``src_key_padding_mask``.
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            see the docs in Transformer class.
        """
        # NOTE(review): this re-seeds the *global* torch RNG on every call. Dropout masks
        # therefore repeat across calls, and RNG state used elsewhere is reset. Kept to
        # preserve the notebook's recorded outputs; remove before real training.
        torch.manual_seed(1)

        # Pre-norm normalizes the attention input; post-norm (default) normalizes after the residual.
        src_ = self.norm1(src) if self.pre_norm else src

        if isinstance(src_mask, tuple):
            # Global-attention setup: only sequence-first layout and no key padding mask supported.
            assert not self.self_attn.batch_first
            assert src_key_padding_mask is None

            global_src_mask, trainset_src_mask, valset_src_mask = src_mask

            num_global_tokens = global_src_mask.shape[0]
            num_train_tokens = trainset_src_mask.shape[0]

            global_tokens_src = src_[:num_global_tokens]
            train_tokens_src = src_[num_global_tokens:num_global_tokens+num_train_tokens]
            global_and_train_tokens_src = src_[:num_global_tokens+num_train_tokens]
            eval_tokens_src = src_[num_global_tokens+num_train_tokens:]

            attn = partial(checkpoint, self.self_attn) if self.recompute_attn else self.self_attn

            # Global tokens attend to global+train tokens, train tokens attend to the global
            # tokens, eval tokens attend to the whole sequence (each with its own mask).
            global_tokens_src2 = attn(global_tokens_src, global_and_train_tokens_src, global_and_train_tokens_src, None, True, global_src_mask)[0]
            train_tokens_src2 = attn(train_tokens_src, global_tokens_src, global_tokens_src, None, True, trainset_src_mask)[0]
            eval_tokens_src2 = attn(eval_tokens_src, src_, src_,
                                    None, True, valset_src_mask)[0]

            src2 = torch.cat([global_tokens_src2, train_tokens_src2, eval_tokens_src2], dim=0)

        elif isinstance(src_mask, int):
            # Efficient eval masking: src_mask is the first evaluation position.
            assert src_key_padding_mask is None
            single_eval_position = src_mask
            src_left = self.self_attn(src_[:single_eval_position], src_[:single_eval_position], src_[:single_eval_position])[0]
            src_right = self.self_attn(src_[single_eval_position:], src_[:single_eval_position], src_[:single_eval_position])[0]
            src2 = torch.cat([src_left, src_right], dim=0)
        else:
            # Plain self-attention (src_mask is None or a regular attention mask).
            if self.recompute_attn:
                # Positional args follow MultiheadAttention.forward:
                # (query, key, value, key_padding_mask, need_weights, attn_mask).
                src2 = checkpoint(self.self_attn, src_, src_, src_, src_key_padding_mask, True, src_mask)[0]
            else:
                src2 = self.self_attn(src_, src_, src_, attn_mask=src_mask,
                                      key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        if not self.pre_norm:
            src = self.norm1(src)

        src_ = self.norm2(src) if self.pre_norm else src
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src_))))
        src = src + self.dropout2(src2)

        if not self.pre_norm:
            src = self.norm2(src)
        return src


In [77]:
torch.manual_seed(1)
src = torch.rand(10, 32, 512)  # (seq_len, batch, d_model), since batch_first defaults to False

encoder_layer = TransformerEncoderLayer(d_model=512, nhead=4)
out_full = encoder_layer(src)

print(src[0, 0, 0:3])
print(out_full.shape)
print(out_full[0, 0, 0:3])  # recorded output: tensor([ 0.5695, -1.0787,  0.1266])

tensor([0.7576, 0.2793, 0.4031])
torch.Size([10, 32, 512])
tensor([ 0.5695, -1.0787,  0.1266], grad_fn=<SliceBackward>)


### DelTransformerEncoderLayer()

In [78]:
from functools import partial

from torch import nn
import torch
from torch.nn.modules.transformer import _get_activation_fn, Module, Tensor, Optional, MultiheadAttention, Linear, Dropout, LayerNorm
from torch.utils.checkpoint import checkpoint

# added by Ugne (before it showed error: F is not defined)
from torch.nn import functional as F

# Reduced copy of TransformerEncoderLayer that keeps only the default path:
# post-norm, with src_mask used as an ordinary attention mask.
class DelTransformerEncoderLayer(Module):
    """Post-norm Transformer encoder layer stripped down to the plain-mask code path.

    ``pre_norm`` and ``recompute_attn`` are accepted and stored for signature parity, but
    ``forward`` does not read them. The feed-forward dropout is stored as ``dropout_ch``.
    """

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 layer_norm_eps=1e-5, batch_first=False, pre_norm=False,
                 device=None, dtype=None, recompute_attn=False) -> None:
        super().__init__()
        factory = dict(device=device, dtype=dtype)
        # Submodules are created (and randomly initialised) in the same order as in
        # TransformerEncoderLayer, so both draw identical weights from the same RNG state.
        self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout,
                                            batch_first=batch_first, **factory)
        self.linear1 = Linear(d_model, dim_feedforward, **factory)
        self.dropout_ch = Dropout(dropout)  # named `dropout` in the full layer
        self.linear2 = Linear(dim_feedforward, d_model, **factory)
        self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory)
        self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory)
        self.dropout1 = Dropout(dropout)
        self.dropout2 = Dropout(dropout)
        self.pre_norm = pre_norm
        self.recompute_attn = recompute_attn
        self.activation = _get_activation_fn(activation)

    def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        """Self-attention + residual + LayerNorm, then feed-forward + residual + LayerNorm."""
        # Re-seeds the global torch RNG on every call, mirroring TransformerEncoderLayer.forward.
        torch.manual_seed(1)

        attn_out, _ = self.self_attn(src, src, src, attn_mask=src_mask,
                                     key_padding_mask=src_key_padding_mask)
        hidden = self.norm1(src + self.dropout1(attn_out))

        ff_out = self.linear2(self.dropout_ch(self.activation(self.linear1(hidden))))
        return self.norm2(hidden + self.dropout2(ff_out))

In [79]:
torch.manual_seed(1)
src = torch.rand(10, 32, 512)  # same seed, so the same input as the full-layer cell

encoder_layer_del = DelTransformerEncoderLayer(d_model=512, nhead=4)
out_deleted = encoder_layer_del(src)

print(out_deleted.shape)
print(out_deleted[0, 0, 0:3])  # recorded output: tensor([ 0.5695, -1.0787,  0.1266])

torch.Size([10, 32, 512])
tensor([ 0.5695, -1.0787,  0.1266], grad_fn=<SliceBackward>)


### with 1 layer

In [81]:
def encoder_layer_creator():
    """Build one fresh TransformerEncoderLayer from the notebook-level hyper-parameters."""
    return TransformerEncoderLayer(ninp, nhead, nhid, dropout, activation=activation,
                                   pre_norm=pre_norm, recompute_attn=recompute_attn)

In [82]:
class TransformerEncoderDiffInit(Module):
    r"""Stack of ``num_layers`` encoder layers, each produced by its own creator call.

    Unlike ``torch.nn.TransformerEncoder``, which clones a single prototype layer, every
    layer here comes from a separate ``encoder_layer_creator()`` call. Layers therefore
    get independently initialised parameters.

    Args:
        encoder_layer_creator: zero-argument callable returning a new encoder layer (required).
        num_layers: the number of sub-encoder-layers in the encoder (required).
        norm: optional component applied to the output of the last layer.
    """
    __constants__ = ['norm']

    def __init__(self, encoder_layer_creator, num_layers, norm=None):
        super().__init__()
        self.layers = nn.ModuleList(encoder_layer_creator() for _ in range(num_layers))
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src: Tensor, mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
        r"""Feed ``src`` through every layer in order, then through ``norm`` if one was given.

        Args:
            src: the sequence to the encoder (required).
            mask: passed to each layer as ``src_mask`` (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).
        """
        hidden = src
        for layer in self.layers:
            hidden = layer(hidden, src_mask=mask, src_key_padding_mask=src_key_padding_mask)
        return hidden if self.norm is None else self.norm(hidden)

# Stack `nlayers` independently created encoder layers (no final norm).
transformer_encoder = TransformerEncoderDiffInit(encoder_layer_creator, nlayers)

In [None]:
def init_weights(self):
    """Initialise the weights of a TransformerModel-like module in place.

    Steps:
      1. If ``self.init_method`` is set, apply it to every submodule via ``Module.apply``.
      2. In every layer of ``self.transformer_encoder``, zero the weight and bias of the second
         feed-forward linear (``linear2``) and of each attention output projection (``out_proj``).

    Both zeroed projections then output zero, so at initialisation each residual branch adds
    nothing to its input.
    """
    if self.init_method is not None:
        self.apply(self.init_method)
    for layer in self.transformer_encoder.layers:
        nn.init.zeros_(layer.linear2.weight)
        nn.init.zeros_(layer.linear2.bias)
        # A layer may hold a single attention module or a ModuleList of them.
        attns = layer.self_attn if isinstance(layer.self_attn, nn.ModuleList) else [layer.self_attn]
        for attn in attns:
            nn.init.zeros_(attn.out_proj.weight)
            nn.init.zeros_(attn.out_proj.bias)

def forward(self, src, src_mask=None, single_eval_pos=None):
    """TransformerModel forward pass on an ``(x, y)`` or ``(style, x, y)`` tuple.

    Rows ``[:single_eval_pos]`` form the training part: their encoded ``y`` is added to the
    encoded ``x``. Rows ``[single_eval_pos:]`` are evaluation rows and enter with ``x`` only.
    Optional global-attention tokens and a style token are prepended to the sequence.

    :param src: ``(x, y)`` or ``(style, x, y)``, sequence dimension first (assumed
        ``(seq_len, batch, ...)``; TODO confirm against the encoders used).
    :param src_mask: optional attention mask; built here when ``None``.
    :param single_eval_pos: index of the first evaluation row.
    :return: decoder output for the evaluation rows only.
    """
    assert isinstance(src, tuple), 'inputs (src) have to be given as (x,y) or (style,x,y) tuple'

    if len(src) == 2:  # (x, y) without a style component
        src = (None,) + src

    style_src, x_src, y_src = src
    x_src = self.encoder(x_src)
    # y may lack the trailing feature axis; add it so its rank matches the encoded x.
    y_src = self.y_encoder(y_src.unsqueeze(-1) if len(y_src.shape) < len(x_src.shape) else y_src)
    style_src = self.style_encoder(style_src).unsqueeze(0) if self.style_encoder else \
        torch.tensor([], device=x_src.device)
    global_src = torch.tensor([], device=x_src.device) if self.global_att_embeddings is None else \
        self.global_att_embeddings.weight.unsqueeze(1).repeat(1, x_src.shape[1], 1)
    # Number of global tokens prepended to the sequence (0 when global attention is off).
    num_global_tokens = 0 if self.global_att_embeddings is None else self.global_att_embeddings.num_embeddings

    if src_mask is None:
        full_len = len(x_src) + len(style_src)
        if self.global_att_embeddings is None:
            if self.full_attention:
                src_mask = bool_mask_to_att_mask(torch.ones((full_len, full_len), dtype=torch.bool)).to(x_src.device)
            elif self.efficient_eval_masking:
                # An int mask: presumably consumed by the encoder layers' int-mask branch as
                # the first evaluation position (offset by the style token).
                src_mask = single_eval_pos + len(style_src)
            else:
                src_mask = self.generate_D_q_matrix(full_len, len(x_src) - single_eval_pos).to(x_src.device)
        else:
            src_mask_args = (self.global_att_embeddings.num_embeddings,
                             full_len,
                             full_len - single_eval_pos)
            src_mask = (self.generate_global_att_globaltokens_matrix(*src_mask_args).to(x_src.device),
                        self.generate_global_att_trainset_matrix(*src_mask_args).to(x_src.device),
                        self.generate_global_att_query_matrix(*src_mask_args).to(x_src.device))
    else:
        # With global attention enabled, a caller-supplied mask must be the tuple form.
        assert self.global_att_embeddings is None or isinstance(src_mask, tuple)

    train_x = x_src[:single_eval_pos] + y_src[:single_eval_pos]
    src = torch.cat([global_src, style_src, train_x, x_src[single_eval_pos:]], 0)

    if self.input_ln is not None:
        src = self.input_ln(src)

    if self.pos_encoder is not None:
        src = self.pos_encoder(src)

    output = self.transformer_encoder(src, src_mask)
    output = self.decoder(output)
    # Skip the prepended global/style tokens and the training rows.
    return output[single_eval_pos + len(style_src) + num_global_tokens:]