In [54]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [55]:
# Imports for most of the notebook
import torch
import pandas
from transformers import BertModel
from transformers import AutoTokenizer
from typing import Dict, List
from util import load_training_data

In [56]:
train_arg_path = "../data/arguments-training.tsv"
train_label_path = "../data/labels-training.tsv"
validation_arg_path = "../data/arguments-validation.tsv"
validation_label_path = "../data/labels-validation.tsv"

In [57]:
df_train_arguments, df_train_labels = load_training_data(train_arg_path, train_label_path)

df_train_labels = df_train_labels[["Argument ID", "Self-direction: action"]]
df_train_arguments = df_train_arguments.merge(df_train_labels, on='Argument ID')
df_train_arguments

Argument ID	Conclusion	Stance	Premise

  Argument ID                                         Conclusion       Stance  \
0      A01001                     Entrapment should be legalized  in favor of   
1      A01002                        We should ban human cloning  in favor of   
2      A01003                         We should abandon marriage      against   
3      A01004                          We should ban naturopathy      against   
4      A01005                            We should ban fast food  in favor of   
5      A01006        We should end the use of economic sanctions      against   
6      A01007               We should abolish capital punishment      against   
7      A01008                      We should ban factory farming      against   
8      A01009  We should fight for the abolition of nuclear w...      against   
9      A01010                   We should prohibit school prayer      against   

                                             Premise  
0  if entrapme

Unnamed: 0,Argument ID,Conclusion,Stance,Premise,Self-direction: action
0,A01001,Entrapment should be legalized,in favor of,if entrapment can serve to more easily capture...,0
1,A01002,We should ban human cloning,in favor of,we should ban human cloning as it will only ca...,0
2,A01003,We should abandon marriage,against,marriage is the ultimate commitment to someone...,1
3,A01004,We should ban naturopathy,against,it provides a useful income for some people,0
4,A01005,We should ban fast food,in favor of,fast food should be banned because it is reall...,0
...,...,...,...,...,...
5215,D27096,Nepotism exists in Bollywood,against,Star kids also have an upbringing which is sur...,0
5216,D27097,Nepotism exists in Bollywood,in favor of,Movie stars of Bollywood often launch their ch...,0
5217,D27098,India is safe for women,in favor of,Evil historic practices on women in the pre an...,0
5218,D27099,India is safe for women,in favor of,Women of our country have been and are achievi...,0


In [63]:
df_vali_arguments, df_vali_labels = load_training_data(validation_arg_path, validation_label_path)

df_vali_labels = df_vali_labels[["Argument ID", "Self-direction: action"]]
df_vali_arguments = df_vali_arguments.merge(df_vali_labels, on='Argument ID')
df_vali_arguments

Argument ID	Conclusion	Stance	Premise

unrecognized stance for line 'D03045	Cattle slaughter should be banned	in favour of	The cow is a sacred animal for Hindus. Killing cows hurts their sentiments.
', skipping.
unrecognized stance for line 'D03051	Hindi should be the national language of India	in favour of	English is not native to India, but Hindi is.
', skipping.
unrecognized stance for line 'D03052	Hindi should be the national language of India	in favour of	Hindi is spoken by more than half of India. As no single language is spoken by the entire of India, it is much better to make the majoritarian language the national language. 
', skipping.
  Argument ID                                       Conclusion       Stance  \
0      A01001                   Entrapment should be legalized  in favor of   
1      A01012  The use of public defenders should be mandatory  in favor of   
2      A02001                    Payday loans should be banned  in favor of   
3      A02002                 

Unnamed: 0,Argument ID,Conclusion,Stance,Premise,Self-direction: action
0,A01001,Entrapment should be legalized,in favor of,if entrapment can serve to more easily capture...,0
1,A01012,The use of public defenders should be mandatory,in favor of,the use of public defenders should be mandator...,0
2,A02001,Payday loans should be banned,in favor of,payday loans create a more impoverished societ...,0
3,A02002,Surrogacy should be banned,against,Surrogacy should not be banned as it is the wo...,1
4,A02009,Entrapment should be legalized,against,entrapment is gravely immoral and,0
...,...,...,...,...,...
1888,E08014,We should shift the EU policy toward the Russi...,in favor of,Pushing Russia to the wall will have adverse e...,0
1889,E08021,We should stop buying Russian gas,in favor of,The Russians use the money we give them in exc...,0
1890,E08022,We should stop buying Russian gas,in favor of,The cost of gas will be higher. But I prefer t...,1
1891,E08024,We should strengthen our ties with Ukraine and...,in favor of,We must support countries that want to improve...,1


In [74]:
def generate_input(dataset: pandas.core.frame.DataFrame) -> (List[str], List[str], List[int]):
    
    premise, conclusion, label = ([] for i in range(3))
    
    premise = dataset["Premise"].tolist()
    conclusion = (dataset["Stance"] + ": " + dataset["Conclusion"]).tolist()
    label = dataset["Self-direction: action"].tolist()
        
    return premise, conclusion, label

batch_size = 8

def chunk(lst, n):
    """Yield successive n-sized chunks from lst."""
    for i in range(0, len(lst), n):
        yield lst[i:i + n]

def chunk_multi(lst1, lst2, n):
    for i in range(0, len(lst1), n):
        yield lst1[i: i + n], lst2[i: i + n]

In [75]:
def encode_labels(labels: List[int]) -> torch.FloatTensor:
    """Turns the batch of labels into a tensor

    Args:
        labels (List[int]): List of all labels in the batch

    Returns:
        torch.FloatTensor: Tensor of all labels in the batch
    """
    return torch.LongTensor([int(l) for l in labels])

In [76]:
# Huggingface tokenizer

class BatchTokenizer:
    """Tokenizes and pads a batch of input sentences."""

    def __init__(self):
        """Initializes the tokenizer

        Args:
            pad_symbol (Optional[str], optional): The symbol for a pad. Defaults to "<P>".
        """
        self.hf_tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-small")
    
    def get_sep_token(self,):
        return self.hf_tokenizer.sep_token
    
    def __call__(self, prem_batch: List[str], conc_batch: List[str]) -> List[List[str]]:
        """Uses the huggingface tokenizer to tokenize and pad a batch.

        We return a dictionary of tensors per the huggingface model specification.

        Args:
            batch (List[str]): A List of sentence strings

        Returns:
            Dict: The dictionary of token specifications provided by HuggingFace
        """
        # The HF tokenizer will PAD for us, and additionally combine 
        # The two sentences deimited by the [SEP] token.
        enc = self.hf_tokenizer(
            prem_batch,
            conc_batch,
            padding=True,
            return_token_type_ids=False,
            return_tensors='pt'
        )

        return enc
    

# HERE IS AN EXAMPLE OF HOW TO USE THE BATCH TOKENIZER
tokenizer = BatchTokenizer()
x = tokenizer(*[["this is the premise.", "This is also a premise"], ["this is the hypothesis", "This is a second hypothesis"]])
print(x)
tokenizer.hf_tokenizer.batch_decode(x["input_ids"])


{'input_ids': tensor([[  101,  2023,  2003,  1996, 18458,  1012,   102,  2023,  2003,  1996,
         10744,   102,     0],
        [  101,  2023,  2003,  2036,  1037, 18458,   102,  2023,  2003,  1037,
          2117, 10744,   102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}


['[CLS] this is the premise. [SEP] this is the hypothesis [SEP] [PAD]',
 '[CLS] this is also a premise [SEP] this is a second hypothesis [SEP]']

In [90]:
# Build the model here
class HVDClassifier(torch.nn.Module):
    def __init__(self, output_size: int, hidden_size: int):
        super().__init__()
        self.output_size = output_size
        self.hidden_size = hidden_size
        # Initialize BERT, which we use instead of a single embedding layer.
        self.bert = BertModel.from_pretrained("prajjwal1/bert-small")
        # TODO [OPTIONAL]: Updating all BERT parameters can be slow and memory intensive. 
        # Freeze them if training is too slow. Notice that the learning
        # rate should probably be smaller in this case.
        # Uncommenting out the below 2 lines means only our classification layer will be updated.
        for param in self.bert.parameters():
            param.requires_grad = False
        self.bert_hidden_dimension = self.bert.config.hidden_size
        # TODO: Add an extra hidden layer in the classifier, projecting
        #      from the BERT hidden dimension to hidden size.
        self.hidden_layer = torch.nn.Linear(self.bert_hidden_dimension, self.hidden_size)
        
        # TODO: Add a relu nonlinearity to be used in the forward method
        #      https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
        self.relu = torch.nn.ReLU()
        self.classifier = torch.nn.Linear(self.hidden_size, self.output_size)
        self.log_softmax = torch.nn.LogSoftmax(dim=2)

    def encode_text(
        self,
        symbols: Dict
    ) -> torch.Tensor:
        """Encode the (batch of) sequence(s) of token symbols with an LSTM.
            Then, get the last (non-padded) hidden state for each symbol and return that.

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: The final hiddens tate of the LSTM, which represents an encoding of
                the entire sentence
        """
        # First we get the contextualized embedding for each input symbol
        # We no longer need an LSTM, since BERT encodes context and 
        # gives us a single vector describing the sequence in the form of the [CLS] token.
        encoded_sequence = self.bert(**symbols)
        # TODO: Get the [CLS] token using the `pooler_output` from 
        #      The BertModel output. See here: https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertModel
        #      and check the returns for the forward method.
        # We want to return a tensor of the form batch_size x 1 x bert_hidden_dimension
        output = torch.unsqueeze(encoded_sequence.pooler_output,1)
        return output
    

    def forward(
        self,
        symbols: Dict,
    ) -> torch.Tensor:
        """_summary_

        Args:
            symbols (Dict): The Dict of token specifications provided by the HuggingFace tokenizer

        Returns:
            torch.Tensor: _description_
        """
        encoded_sents = self.encode_text(symbols)
        output = self.hidden_layer(encoded_sents)
        output = self.relu(output)
        output = self.classifier(output)
        return self.log_softmax(output)

In [78]:
# For making predictions at test time
def predict(model: torch.nn.Module, sents: torch.Tensor) -> List:
    logits = model(sents)
    return list(torch.argmax(logits, axis=2).squeeze().numpy())

In [79]:
import numpy as np
from numpy import logical_and, sum as t_sum

def precision(predicted_labels, true_labels, which_label=1):
    """
    Precision is True Positives / All Positives Predictions
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(pred_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def recall(predicted_labels, true_labels, which_label=1):
    """
    Recall is True Positives / All Positive Labels
    """
    pred_which = np.array([pred == which_label for pred in predicted_labels])
    true_which = np.array([lab == which_label for lab in true_labels])
    denominator = t_sum(true_which)
    if denominator:
        return t_sum(logical_and(pred_which, true_which))/denominator
    else:
        return 0.


def f1_score(
    predicted_labels: List[int],
    true_labels: List[int],
    which_label: int
):
    """
    F1 score is the harmonic mean of precision and recall
    """
    P = precision(predicted_labels, true_labels, which_label=which_label)
    R = recall(predicted_labels, true_labels, which_label=which_label)
    if P and R:
        return 2*P*R/(P+R)
    else:
        return 0.


def macro_f1(
    predicted_labels: List[int],
    true_labels: List[int],
    possible_labels: List[int]
):
    scores = [f1_score(predicted_labels, true_labels, l) for l in possible_labels]
    # Macro, so we take the uniform avg.
    return sum(scores) / len(scores)

In [88]:
import random
from tqdm import tqdm

def training_loop(
    num_epochs,
    train_features,
    train_labels,
    dev_sents,
    dev_labels,
    optimizer,
    model,
):
    print("Training...")
    loss_func = torch.nn.NLLLoss()
    batches = list(zip(train_features, train_labels))
    random.shuffle(batches)
    for i in range(num_epochs):
        losses = []
        for features, labels in tqdm(batches):
            # Empty the dynamic computation graph
            optimizer.zero_grad()
            preds = model(features).squeeze(1)
            loss = loss_func(preds, labels)
            # Backpropogate the loss through our model
            loss.backward()
            optimizer.step()
            losses.append(loss.item())
        
        print(f"epoch {i}, loss: {sum(losses)/len(losses)}")
        # Estimate the f1 score for the development set
        print("Evaluating dev...")
        all_preds = []
        all_labels = []
        for sents, labels in tqdm(zip(dev_sents, dev_labels), total=len(dev_sents)):
            pred = predict(model, sents)
            all_preds.extend(pred)
            all_labels.extend(list(labels.numpy()))

        dev_f1 = macro_f1(all_preds, all_labels, [0,1])
        print(f"Dev F1 {dev_f1}")
        
    # Return the trained model
    return model

In [81]:
# Batch and Tokenize data here
tokenizer = BatchTokenizer()

# Traning dataset
train_premises, train_conclusions, train_labels = generate_input(df_train_arguments)
# Batches
train_input_batches = [b for b in chunk_multi(train_premises, train_conclusions, batch_size)]
train_label_batches = [b for b in chunk(train_labels, batch_size)]
# Tokenize + encode
train_input_batches = [tokenizer(*batch) for batch in train_input_batches]
train_label_batches = [encode_labels(batch) for batch in train_label_batches]

# Validation dataset
vali_premises, vali_conclusions, vali_labels = generate_input(df_vali_arguments)
# Batches
vali_input_batches = [b for b in chunk_multi(vali_premises, vali_conclusions, batch_size)]
vali_label_batches = [b for b in chunk(vali_labels, batch_size)]
# Tokenize + encode
vali_input_batches = [tokenizer(*batch) for batch in vali_input_batches]
vali_label_batches = [encode_labels(batch) for batch in vali_label_batches]

{0, 1}

In [93]:
# You can increase epochs if need be
epochs = 25
# TODO: Find a good learning rate
LR = 0.001

possible_labels = len(set(train_labels))
model = HVDClassifier(output_size=possible_labels, hidden_size=2)
optimizer = torch.optim.AdamW(model.parameters(), LR)

training_loop(
    epochs,
    train_input_batches,
    train_label_batches,
    vali_input_batches,
    vali_label_batches,
    optimizer,
    model,
)

Some weights of the model checkpoint at prajjwal1/bert-small were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Training...


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:36<00:00, 18.02it/s]


epoch 0, loss: 0.5491673522621713
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.73it/s]


Dev F1 0.42694014119031276


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.25it/s]


epoch 1, loss: 0.5154772237468095
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.19it/s]


Dev F1 0.4671328817443076


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.64it/s]


epoch 2, loss: 0.5018063504676724
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.68it/s]


Dev F1 0.5087990288883444


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.71it/s]


epoch 3, loss: 0.4933519338301824
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.09it/s]


Dev F1 0.5191160973147378


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.39it/s]


epoch 4, loss: 0.487600623853904
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.29it/s]


Dev F1 0.5313513469970813


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.63it/s]


epoch 5, loss: 0.4830247539039958
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.11it/s]


Dev F1 0.5491069471636152


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.76it/s]


epoch 6, loss: 0.47935281344796393
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.11it/s]


Dev F1 0.5603201506591337


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 19.07it/s]


epoch 7, loss: 0.47647733547101523
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.24it/s]


Dev F1 0.5688545635611236


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.43it/s]


epoch 8, loss: 0.4738991100563572
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.07it/s]


Dev F1 0.576917903535721


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 19.20it/s]


epoch 9, loss: 0.4715468457017157
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 15.84it/s]


Dev F1 0.5836161387631976


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.54it/s]


epoch 10, loss: 0.4694790890009626
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 15.87it/s]


Dev F1 0.5899935022742041


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.67it/s]


epoch 11, loss: 0.4678639250067255
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.22it/s]


Dev F1 0.5935420562998162


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.68it/s]


epoch 12, loss: 0.4664178371748917
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.16it/s]


Dev F1 0.5951484436264185


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 19.03it/s]


epoch 13, loss: 0.465084907112452
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.50it/s]


Dev F1 0.5971631170493242


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.25it/s]


epoch 14, loss: 0.46390028524385113
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.45it/s]


Dev F1 0.6004762074623671


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 18.91it/s]


epoch 15, loss: 0.46270275421178175
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 16.20it/s]


Dev F1 0.6029737027073357


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:33<00:00, 19.27it/s]


epoch 16, loss: 0.46155814851019505
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 15.86it/s]


Dev F1 0.6062441130298273


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 19.20it/s]


epoch 17, loss: 0.46048997788708573
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:14<00:00, 15.96it/s]


Dev F1 0.6090542672685487


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:41<00:00, 15.84it/s]


epoch 18, loss: 0.45968303405558536
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.30it/s]


Dev F1 0.6076510925407361


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:29<00:00, 22.00it/s]


epoch 19, loss: 0.458705662648446
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 15.09it/s]


Dev F1 0.6076510925407361


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:36<00:00, 17.94it/s]


epoch 20, loss: 0.45792025317413704
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:16<00:00, 14.59it/s]


Dev F1 0.6104536551817852


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:29<00:00, 21.89it/s]


epoch 21, loss: 0.45718367715547314
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:12<00:00, 18.91it/s]


Dev F1 0.6113950375584221


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:34<00:00, 18.68it/s]


epoch 22, loss: 0.4564702615917733
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:15<00:00, 14.90it/s]


Dev F1 0.6164747280870121


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:36<00:00, 18.12it/s]


epoch 23, loss: 0.4557794091815149
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:16<00:00, 14.32it/s]


Dev F1 0.6178488350983359


100%|████████████████████████████████████████████████████████████████████████████████| 653/653 [00:37<00:00, 17.48it/s]


epoch 24, loss: 0.45513034261660407
Evaluating dev...


100%|████████████████████████████████████████████████████████████████████████████████| 237/237 [00:16<00:00, 14.29it/s]

Dev F1 0.6233089934062969





HVDClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 512, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=512, out_features=512, bias=True)
              (key): Linear(in_features=512, out_features=512, bias=True)
              (value): Linear(in_features=512, out_features=512, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=512, out_features=512, bias=True)
              (LayerNorm): LayerNorm((512,), eps=1e-12, elementwise_affine=True