### Experiments with Large Language Models

In [1]:
# General imports
import pandas as pd
import numpy as np
from tqdm import tqdm

#  PyTorch imports
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn

# Transformers imports
from transformers import T5Tokenizer, T5ForConditionalGeneration, PreTrainedTokenizer
from transformers import Trainer, TrainingArguments, AdamW

# Types
from typing import List, Dict, Tuple, Union

# Set seed
torch.manual_seed(42)
np.random.seed(42)



#### Data

In [2]:
wiki = []
with open('data/wiki.txt', 'r', encoding='utf8') as f:
    for line in f:
        wiki.append(line)

wiki = pd.DataFrame(wiki, columns=['text'])
wiki.head()

Unnamed: 0,text
0,Khatchig Mouradian. Khatchig Mouradian is a jo...
1,Jacob Henry Studer. Jacob Henry Studer (26 Feb...
2,"John Stephen. Born in Glasgow, Stephen became ..."
3,Georgina Willis. Georgina Willis is an award w...
4,Stanley Corrsin. Corrsin was born on 3 April 1...


In [3]:
# Read tsv file
birth_places_train = pd.read_csv('data/birth_places_train.tsv', sep='\t') 
birth_places_train.head()

Unnamed: 0,Where was Khatchig Mouradian born?,Lebanon
0,Where was Jacob Henry Studer born?,Columbus
1,Where was John Stephen born?,Glasgow
2,Where was Georgina Willis born?,Australia
3,Where was Stanley Corrsin born?,Philadelphia
4,Where was Eduard Ender born?,Rome


In [4]:
birth_places_test = pd.read_csv('data/birth_places_test.tsv', sep='\t')
birth_places_test.head()

Unnamed: 0,Where was Bryan Dubreuiel born?,Atlanta
0,Where was Ralf Wadephul born?,Berlin
1,Where was Joseph Baggaley born?,England
2,Where was Sandhya Sanjana born?,Mumbai
3,Where was Alfred Mele born?,Detroit
4,Where was Murray Esler born?,Geelong


### Build a single file dataset with $(q, c, a)$ structure

In [5]:
def build_dataset_file(
    qa_df: pd.DataFrame, # questions and answers
    wiki_df: pd.DataFrame, # context
    filename: str # name of the file to be created
) -> None:

    """
    Build dataset file for T5 training
    """
    qa_df_values = qa_df.values
    wiki_df_values = wiki_df.values

    matches = {}
    for i in tqdm(range(len(qa_df_values)), desc='Matching questions and contexts', total=len(qa_df_values)):
        [q, a] = qa_df_values[i]
        person = q.split(' ')[2:-1] # Get the name of the person
        person = ' '.join(person) # Join the name of the person

        # Find the context of the person
        for j in range(len(wiki_df_values)):
            c = wiki_df_values[j][0] # There is only one column in the wiki DataFrame
            if person in c:
                # Remove new line characters from context
                c = c.replace('\n', '')
                matches[person] = (q, c, a)
                break
    
    print(f'Number of entries preprocessed: {len(matches.keys())}')

    # Build the dataset file
    # Create a new csv file with the columns: question, context, answer
    with open(f'data/{filename}.tsv', 'w', encoding='utf8') as f:
        for k in matches.keys():
            q, c, a = matches[k]
            # Write the question, context and answer in the file on the same line
            f.write(f'{q}\t{c}\t{a}\n')
            
    

In [6]:
build_dataset_file(birth_places_train, wiki, 'birth_places_train_clean')
build_dataset_file(birth_places_test, wiki, 'birth_places_test_clean')

Matching questions and contexts: 100%|██████████| 1999/1999 [00:00<00:00, 3008.66it/s]


Number of entries preprocessed: 1990


Matching questions and contexts: 100%|██████████| 499/499 [00:00<00:00, 1299.99it/s]


Number of entries preprocessed: 499


In [7]:
# Read the csv file
train_dataset_df = pd.read_csv('data/birth_places_train_clean.tsv', sep='\t', names=['question', 'context', 'answer'])
train_dataset_df.head()

Unnamed: 0,question,context,answer
0,Where was Jacob Henry Studer born?,Jacob Henry Studer. Jacob Henry Studer (26 Feb...,Columbus
1,Where was John Stephen born?,"John Stephen. Born in Glasgow, Stephen became ...",Glasgow
2,Where was Georgina Willis born?,Georgina Willis. Georgina Willis is an award w...,Australia
3,Where was Stanley Corrsin born?,Stanley Corrsin. Corrsin was born on 3 April 1...,Philadelphia
4,Where was Eduard Ender born?,Eduard Ender. Eduard Ender (3 March 1822 Rome ...,Rome


In [8]:
test_dataset_df = pd.read_csv('data/birth_places_test_clean.tsv', sep='\t', names=['question', 'context', 'answer'])
test_dataset_df.head()

Unnamed: 0,question,context,answer
0,Where was Ralf Wadephul born?,"Ralf Wadephul. Ralf Wadephul, born 1958 in Ber...",Berlin
1,Where was Joseph Baggaley born?,Joseph Baggaley. Joseph Baggaley (c. 1884 -- 1...,England
2,Where was Sandhya Sanjana born?,Sandhya Sanjana. Sandhya Sanjana (); is a sing...,Mumbai
3,Where was Alfred Mele born?,"Alfred Mele. Born in Detroit, Michigan, Mele a...",Detroit
4,Where was Murray Esler born?,Murray Esler. Professor Murray David Esler (bo...,Geelong


### Create utility functions for the $(q, c, a)$ scenario

In [9]:
# process the examples in input and target text format and the eos token at the end 
def add_eos_to_examples(entry: Tuple[str, str, str]) -> Tuple[str, str]:
    question, context, answer = entry
    result = {}
    result['input_text'] = 'question: %s  context: %s </s>' % (question, context)
    result['target_text'] = '%s </s>' % (answer)
    return result

# tokenize the examples
def convert_to_features(qa_entry: Tuple[str, str, str], tokenizer: PreTrainedTokenizer) -> Dict[str, List[int]]:
    input_encodings = tokenizer.encode_plus(text=qa_entry['input_text'], padding='max_length', max_length=512, truncation=True) # the context is limited to 512 tokens
    target_encodings = tokenizer.encode_plus(text=qa_entry['target_text'], padding='max_length', max_length=32, truncation=True) # the answer is limited to 32 tokens, which should conver most of the answers 

    encodings = {
        'input_ids': np.array(input_encodings['input_ids']), 
        'attention_mask': np.array(input_encodings['attention_mask']),
        'target_ids': np.array(target_encodings['input_ids']),
        'target_attention_mask': np.array(target_encodings['attention_mask'])
    }

    return encodings

In [10]:
# Test the functions on a sample
tokenizer = T5Tokenizer.from_pretrained('t5-small')
sample = train_dataset_df.sample(1)
sample = add_eos_to_examples(sample.values[0])
sample = convert_to_features(sample, tokenizer)

For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
- Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.


In [11]:
class BirthPlaceDataset(Dataset):
    """
    Dataset for birth place prediction made on a (q, c, a) format,
    where q is the question, c is the context and a is the answer.

    Args:
        df (pd.DataFrame): DataFrame containing the questions, contexts and answers
    """
    def __init__(self, df: pd.DataFrame) -> None:
        self.df = df
        self.df_values = self.df.values
        self.tokenizer = T5Tokenizer.from_pretrained('t5-small')

    def __len__(self) -> int:
        return len(self.df)
    
    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        entry = self.df_values[idx]
        entry = add_eos_to_examples(entry)
        entry = convert_to_features(entry, self.tokenizer)
        return entry

In [12]:
# Create two datasets, one for training and one for validation
train_dataset = BirthPlaceDataset(train_dataset_df)
val_dataset = BirthPlaceDataset(test_dataset_df)

# Check if each element in the list of batches has the same shape

# Check for input_ids
for i in range(len(val_dataset)):
    entry = train_dataset[i]
    if entry['input_ids'].shape != (512,):
        print(f'Entry {i} has a different shape for input_ids: {entry["input_ids"].shape}')

# Check for attention_mask
for i in range(len(val_dataset)):
    entry = train_dataset[i]
    if entry['attention_mask'].shape != (512,):
        print(f'Entry {i} has a different shape for attention_mask: {entry["attention_mask"].shape}')

# Check for target_ids
for i in range(len(val_dataset)):
    entry = train_dataset[i]
    if entry['target_ids'].shape != (32,):
        print(f'Entry {i} has a different shape for target_ids: {entry["target_ids"].shape}')

# Check for target_attention_mask
for i in range(len(val_dataset)):
    entry = train_dataset[i]
    if entry['target_attention_mask'].shape != (32,):
        print(f'Entry {i} has a different shape for target_attention_mask: {entry["target_attention_mask"].shape}')

#### A sample from new the new dataset

In [13]:
print(f'Number of training examples: {len(train_dataset)}')
print(f'Number of validation examples: {len(val_dataset)}')

Number of training examples: 1990
Number of validation examples: 499


In [14]:
## Load t5 small model for question answering
model = T5ForConditionalGeneration.from_pretrained('t5-small', return_dict=True)

In [15]:
tokenizer = T5Tokenizer.from_pretrained('t5-small', return_dict=True)

#### Testing the t5 model

In [16]:
# Test model
input_ids = tokenizer.encode("translate English to German: How old are you?", return_tensors="pt")
outputs = model.generate(input_ids, max_length=40, num_beams=4, early_stopping=True)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Wie alt sind Sie?


#### [ Task 1 ] Cold evaluation of the pretrained T5-small model

The evaluation script was taken from SQuAD dataset and it validates the exact matches of the model answers compared with the real ones.

In [17]:
# Test dataloader
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Test is each batch has the correct shape
for batch in train_dataloader:
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    target_ids = batch['target_ids'].to(device)
    target_attention_mask = batch['target_attention_mask'].to(device)
    assert input_ids.shape == (8, 512), f'input_ids has a shape of {input_ids.shape}'
    assert attention_mask.shape == (8, 512), f'attention_mask has a shape of {attention_mask.shape}'
    assert target_ids.shape == (8, 32), f'target_ids has a shape of {target_ids.shape}'
    assert target_attention_mask.shape == (8, 32), f'target_attention_mask has a shape of {target_attention_mask.shape}'
    break

#### Evaluation script

In [18]:
from collections import Counter
import string
import re


def normalize_answer(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


def exact_match_score(prediction, ground_truth):
    return (normalize_answer(prediction) == normalize_answer(ground_truth))


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    scores_for_ground_truths = []
    for ground_truth in ground_truths:
        score = metric_fn(prediction, ground_truth)
        scores_for_ground_truths.append(score)
    return max(scores_for_ground_truths)


def evaluate(gold_answers, predictions):
    f1 = exact_match = total = 0

    for ground_truths, prediction in zip(gold_answers, predictions):
      total += 1
      exact_match += metric_max_over_ground_truths(
                    exact_match_score, prediction, ground_truths)
      f1 += metric_max_over_ground_truths(
          f1_score, prediction, ground_truths)
    
    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}

In [19]:
answers = []
for batch in tqdm(val_dataloader):
  outs = model.generate(input_ids=batch['input_ids'], 
                        attention_mask=batch['attention_mask'],
                        max_length=32,
                        early_stopping=True)
  outs = [tokenizer.decode(ids) for ids in outs]
  answers.extend(outs)

100%|██████████| 63/63 [03:15<00:00,  3.10s/it]


In [20]:
predictions = []
references = []
for ref, pred in zip(val_dataset, answers):
  a = ref['target_ids']  
  a = tokenizer.decode(a)

  # Remove padding
  a = a.replace('<pad>', '')
  a = a.replace('</s>', '')
  a = a.replace('<s>', '')
  
  # Remove all whitespace from the beginning and end
  a = a.rstrip()

  pred = pred.replace('<pad>', '')
  pred = pred.replace('</s>', '')
  pred = pred.replace('<s>', '')
  
  # Remove all whitespace
  pred.replace('', '')

  predictions.append(pred)
  references.append(a)
  

In [21]:
predictions[11], references[11]

(' Oxford, England', 'Washington')

In [173]:
evaluate(references, predictions)

{'exact_match': 0.0, 'f1': 0.10020040080160321}

#### Conclusion on preliminary results

TBD

#### [Task 2] Finetuning process

In [43]:
from transformers import EvalPrediction
from transformers import (
    HfArgumentParser,
    DataCollator,
    Trainer,
    TrainingArguments,
    set_seed,
    DefaultDataCollator,
)

import logging
from dataclasses import dataclass, field
from typing import Optional
import os

In [44]:
# Setup logging
logger = logging.getLogger(__name__)

##### Data collator 
- This is used to transform batches

In [45]:
@dataclass
class T5DataCollator(DefaultDataCollator):
    def collate_batch(self, batch: List) -> Dict[str, torch.Tensor]:
        """
        Take a list of samples from a Dataset and collate them into a batch.
        Returns:
            A dictionary of tensors
        """
        input_ids = torch.stack([example['input_ids'] for example in batch])
        lm_labels = torch.stack([example['target_ids'] for example in batch])
        lm_labels[lm_labels[:, :] == 0] = -100
        attention_mask = torch.stack([example['attention_mask'] for example in batch])
        decoder_attention_mask = torch.stack([example['target_attention_mask'] for example in batch])
        

        return {
            'input_ids': input_ids, 
            'attention_mask': attention_mask,
            'lm_labels': lm_labels, 
            'decoder_attention_mask': decoder_attention_mask
        }


In [46]:
@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
    )

In [47]:

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """
    train_file_path: Optional[str] = field(
        default='train_data.pt',
        metadata={"help": "Path for cached train dataset"},
    )
    valid_file_path: Optional[str] = field(
        default='valid_data.pt',
        metadata={"help": "Path for cached valid dataset"},
    )
    max_len: Optional[int] = field(
        default=512,
        metadata={"help": "Max input length for the source text"},
    )
    target_max_len: Optional[int] = field(
        default=32,
        metadata={"help": "Max input length for the target text"},
    )

Configure logger

In [48]:
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO
    )

In [49]:
logger.info("Training/evaluation processs")

01/03/2023 14:15:16 - INFO - __main__ -   Training/evaluation processs


In [71]:
import json

fine_tune_model = model
fine_tune_tokenizer = tokenizer
fine_tune_train_dataset = train_dataset
fine_tune_val_dataset = val_dataset

args_dict = {
  "model_name_or_path": 't5-base',
  "max_len": 512 ,
  "target_max_len": 16,
  "output_dir": './models/tpu',
  "overwrite_output_dir": True,
  "per_gpu_train_batch_size": 8,
  "per_gpu_eval_batch_size": 8,
  "gradient_accumulation_steps": 4,
  "learning_rate": 1e-4,
  "tpu_num_cores": 8,
  "num_train_epochs": 4,
  "do_train": True
}

with open('args.json', 'w') as f:
  json.dump(args_dict, f)

hfparser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
model_args, data_args, training_args = hfparser.parse_json_file(json_file=os.path.abspath('args.json'))

if (
    os.path.exists(training_args.output_dir)
    and os.listdir(training_args.output_dir)
    and training_args.do_train
    and not training_args.overwrite_output_dir
):
    raise ValueError(
        f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
    )

logging.basicConfig(
      format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
      datefmt="%m/%d/%Y %H:%M:%S",
      level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
  )

logger.warning(
    "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
    training_args.local_rank,
    training_args.device,
    training_args.n_gpu,
    bool(training_args.local_rank != -1),
    training_args.fp16,
)
logger.info("Training/evaluation parameters %s", training_args)

set_seed(training_args.seed)



PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).
01/03/2023 15:14:15 - INFO - __main__ -   Training/evaluation parameters TrainingArguments(
_n_gpu=0,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=False,
do_predict=False,
do_train=True,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_min_num_params=0,
fsdp_transfor

In [74]:
""" 
Trace the model according to the data collator

'input_ids': input_ids, 
'attention_mask': attention_mask,
'lm_labels': lm_labels, 
'decoder_attention_mask': decoder_attention_mask
"""
traced_model = torch.jit.trace(func=fine_tune_model, example_inputs=(
    
))

  if causal_mask.shape[1] < attention_mask.shape[1]:


RuntimeError: Tracer cannot infer type of Seq2SeqLMOutput(loss=None, logits=tensor([[[-65.7776, -15.1142, -28.7621,  ..., -81.6770, -81.8706, -81.9388],
         [-41.7584,  -6.0590, -17.8186,  ..., -50.3439, -50.4415, -50.4755],
         [-18.5064,  -5.2212, -11.1655,  ..., -40.1127, -40.0975, -40.1528],
         ...,
         [-22.1067,  -6.7866, -12.6071,  ..., -44.5598, -44.5704, -44.6414],
         [-22.1067,  -6.7866, -12.6071,  ..., -44.5598, -44.5704, -44.6414],
         [-21.4868,  -6.6022, -12.3518,  ..., -44.0760, -44.0847, -44.1546]]],
       grad_fn=<UnsafeViewBackward0>), past_key_values=((tensor([[[[ 1.3385e+00, -2.1127e-02, -6.5480e-01,  ...,  9.4214e-02,
            7.5609e-03, -3.3276e-01],
          [ 8.8215e-01,  3.2993e-01, -1.1180e+00,  ..., -4.5154e-01,
           -1.8130e-01, -1.0115e+00],
          [-4.8649e-01, -2.3323e+00, -1.1428e+00,  ..., -2.5693e+00,
           -1.7539e+00, -5.6927e-01],
          ...,
          [-4.8649e-01, -2.3323e+00, -1.1428e+00,  ..., -2.5693e+00,
           -1.7539e+00, -5.6927e-01],
          [-4.8649e-01, -2.3323e+00, -1.1428e+00,  ..., -2.5693e+00,
           -1.7539e+00, -5.6927e-01],
          [-4.8649e-01, -2.3323e+00, -1.1428e+00,  ..., -2.5693e+00,
           -1.7539e+00, -5.6927e-01]],

         [[ 6.0694e-01,  6.0701e-01, -2.1880e+00,  ...,  7.7726e-01,
            1.6705e+00,  9.0689e-01],
          [-7.1805e-02, -1.8361e+00, -6.6701e-01,  ..., -6.2895e-01,
            9.2584e-02,  7.8146e-01],
          [-9.5649e-01, -2.3581e+00, -2.4784e-01,  ...,  1.1193e+00,
            7.7703e-01,  5.3768e-01],
          ...,
          [-9.5649e-01, -2.3581e+00, -2.4784e-01,  ...,  1.1193e+00,
            7.7703e-01,  5.3768e-01],
          [-9.5649e-01, -2.3581e+00, -2.4784e-01,  ...,  1.1193e+00,
            7.7703e-01,  5.3768e-01],
          [-9.5649e-01, -2.3581e+00, -2.4784e-01,  ...,  1.1193e+00,
            7.7703e-01,  5.3768e-01]],

         [[-4.1183e+00, -1.6686e+00, -5.0046e-01,  ...,  4.7639e-01,
           -1.4736e+00, -2.5875e+00],
          [ 1.0025e+00,  1.4873e-01, -5.9319e-01,  ...,  1.4527e+00,
            1.2500e+00,  1.5554e-01],
          [ 1.0704e+00, -2.5709e+00,  6.8917e-01,  ..., -1.3349e+00,
            2.7677e-01,  5.7564e-01],
          ...,
          [ 1.0704e+00, -2.5709e+00,  6.8917e-01,  ..., -1.3349e+00,
            2.7677e-01,  5.7564e-01],
          [ 1.0704e+00, -2.5709e+00,  6.8917e-01,  ..., -1.3349e+00,
            2.7677e-01,  5.7564e-01],
          [ 1.0704e+00, -2.5709e+00,  6.8917e-01,  ..., -1.3349e+00,
            2.7677e-01,  5.7564e-01]],

         ...,

         [[-9.1176e-01, -3.3892e-01,  1.5311e+00,  ..., -8.2606e-01,
            1.9128e+00,  2.8557e+00],
          [ 2.9595e-01, -2.4671e+00, -1.4441e+00,  ..., -3.7562e-01,
            1.3632e+00, -3.5399e-01],
          [-6.6602e-01,  2.2657e+00,  3.1239e-01,  ...,  1.4158e+00,
           -2.9123e+00, -1.1275e+00],
          ...,
          [-6.6602e-01,  2.2657e+00,  3.1239e-01,  ...,  1.4158e+00,
           -2.9123e+00, -1.1275e+00],
          [-6.6602e-01,  2.2657e+00,  3.1239e-01,  ...,  1.4158e+00,
           -2.9123e+00, -1.1275e+00],
          [-6.6602e-01,  2.2657e+00,  3.1239e-01,  ...,  1.4158e+00,
           -2.9123e+00, -1.1275e+00]],

         [[-1.8776e+00, -4.7778e-01, -2.2985e+00,  ..., -1.1768e+00,
           -1.5922e+00, -6.9209e-01],
          [-7.9584e-01, -5.1810e-01, -1.6686e+00,  ..., -1.7493e-01,
            6.5511e-01,  6.8859e-01],
          [-6.0374e+00, -1.1453e+00, -1.5700e+00,  ..., -1.3067e+00,
           -1.8851e-01,  1.6042e+00],
          ...,
          [-6.0374e+00, -1.1453e+00, -1.5700e+00,  ..., -1.3067e+00,
           -1.8851e-01,  1.6042e+00],
          [-6.0374e+00, -1.1453e+00, -1.5700e+00,  ..., -1.3067e+00,
           -1.8851e-01,  1.6042e+00],
          [-6.0374e+00, -1.1453e+00, -1.5700e+00,  ..., -1.3067e+00,
           -1.8851e-01,  1.6042e+00]],

         [[-2.0962e-01, -4.8182e-03,  1.1541e+00,  ...,  2.8713e-01,
            9.2157e-01,  9.4773e-01],
          [ 1.2619e+00, -4.5746e-01,  1.5478e+00,  ...,  1.1051e+00,
            6.8085e-01, -1.2374e+00],
          [ 4.8463e-02, -3.5598e-02, -6.9082e-02,  ...,  6.1279e-01,
           -2.4851e-01, -2.9536e-01],
          ...,
          [ 4.8463e-02, -3.5598e-02, -6.9082e-02,  ...,  6.1279e-01,
           -2.4851e-01, -2.9536e-01],
          [ 4.8463e-02, -3.5598e-02, -6.9082e-02,  ...,  6.1279e-01,
           -2.4851e-01, -2.9536e-01],
          [ 4.8463e-02, -3.5598e-02, -6.9082e-02,  ...,  6.1279e-01,
           -2.4851e-01, -2.9536e-01]]]], grad_fn=<TransposeBackward0>), tensor([[[[-0.0824,  0.2075, -0.1973,  ..., -0.3878,  0.6802, -0.1390],
          [ 0.0833,  1.1035, -1.4480,  ..., -0.4279,  0.3702,  0.1257],
          [ 0.2443,  0.3230, -0.2295,  ...,  0.7033,  0.2849,  0.1767],
          ...,
          [ 0.2443,  0.3230, -0.2295,  ...,  0.7033,  0.2849,  0.1767],
          [ 0.2443,  0.3230, -0.2295,  ...,  0.7033,  0.2849,  0.1767],
          [ 0.2443,  0.3230, -0.2295,  ...,  0.7033,  0.2849,  0.1767]],

         [[ 0.4198,  0.4446, -0.5863,  ..., -0.0964,  1.0227, -0.7134],
          [ 2.0778, -1.2436, -0.2788,  ...,  0.1330,  1.6115, -1.6080],
          [-0.1642,  0.5796, -0.4218,  ..., -0.0532, -0.0906, -0.2609],
          ...,
          [-0.1642,  0.5796, -0.4218,  ..., -0.0532, -0.0906, -0.2609],
          [-0.1642,  0.5796, -0.4218,  ..., -0.0532, -0.0906, -0.2609],
          [-0.1642,  0.5796, -0.4218,  ..., -0.0532, -0.0906, -0.2609]],

         [[-0.1940, -0.5726, -0.9271,  ..., -0.0872,  0.5148, -0.2317],
          [-0.7428,  0.3852, -0.7413,  ...,  0.0935, -0.9202,  0.7779],
          [ 0.1343,  0.2932, -0.0561,  ..., -0.0880, -0.8827,  0.1057],
          ...,
          [ 0.1343,  0.2932, -0.0561,  ..., -0.0880, -0.8827,  0.1057],
          [ 0.1343,  0.2932, -0.0561,  ..., -0.0880, -0.8827,  0.1057],
          [ 0.1343,  0.2932, -0.0561,  ..., -0.0880, -0.8827,  0.1057]],

         ...,

         [[-0.4604, -1.0679,  1.2409,  ...,  0.6219, -2.6825, -0.2869],
          [ 0.4231, -0.5454,  0.5188,  ...,  0.5438, -0.6774, -1.2099],
          [ 0.3252, -0.2050, -0.0296,  ..., -0.2699,  0.0973, -0.1057],
          ...,
          [ 0.3252, -0.2050, -0.0296,  ..., -0.2699,  0.0973, -0.1057],
          [ 0.3252, -0.2050, -0.0296,  ..., -0.2699,  0.0973, -0.1057],
          [ 0.3252, -0.2050, -0.0296,  ..., -0.2699,  0.0973, -0.1057]],

         [[-0.7030,  0.9332,  1.1087,  ...,  1.3908,  0.1445, -0.3475],
          [-0.1201, -0.8176, -0.1246,  ...,  0.1747, -0.3611, -1.1205],
          [-0.2374, -0.0480,  0.0079,  ...,  0.0539,  0.0278, -0.1500],
          ...,
          [-0.2374, -0.0480,  0.0079,  ...,  0.0539,  0.0278, -0.1500],
          [-0.2374, -0.0480,  0.0079,  ...,  0.0539,  0.0278, -0.1500],
          [-0.2374, -0.0480,  0.0079,  ...,  0.0539,  0.0278, -0.1500]],

         [[ 0.0535, -1.3293, -0.3338,  ..., -1.3226, -0.1325, -0.6018],
          [-0.0849, -0.5589,  1.0846,  ..., -1.2484,  1.1162,  1.3824],
          [ 0.0533, -0.7683,  0.2670,  ...,  0.2254, -0.1897,  0.0618],
          ...,
          [ 0.0533, -0.7683,  0.2670,  ...,  0.2254, -0.1897,  0.0618],
          [ 0.0533, -0.7683,  0.2670,  ...,  0.2254, -0.1897,  0.0618],
          [ 0.0533, -0.7683,  0.2670,  ...,  0.2254, -0.1897,  0.0618]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 0.3177, -3.9350, -3.6291,  ...,  1.7309, -3.8441, -0.3143],
          [-0.3580, -4.0371, -3.1092,  ...,  1.4634, -3.7024,  0.3243],
          [-0.9002, -3.2189, -1.6399,  ..., -1.9796,  0.3990, -0.5770],
          ...,
          [ 0.9548, -3.3546, -1.8183,  ..., -1.2851,  0.2681, -3.2052],
          [ 0.9548, -3.3546, -1.8183,  ..., -1.2851,  0.2681, -3.2052],
          [ 0.9548, -3.3546, -1.8183,  ..., -1.2851,  0.2681, -3.2052]],

         [[ 1.0282,  3.4647,  1.6634,  ..., -1.1989,  0.3023, -0.9329],
          [ 1.1036,  3.3154,  1.6044,  ..., -0.6827,  0.1526, -0.8189],
          [-0.5273, -1.5256, -1.2689,  ..., -0.3888, -1.4180,  0.0622],
          ...,
          [-1.4133,  2.2302,  3.3941,  ..., -1.3874,  3.0916, -1.6720],
          [-1.4133,  2.2302,  3.3941,  ..., -1.3874,  3.0916, -1.6720],
          [-1.4133,  2.2302,  3.3941,  ..., -1.3874,  3.0916, -1.6720]],

         [[-0.1954, -0.5433,  0.2049,  ..., -1.2494,  0.6348,  1.0990],
          [-0.2569, -0.7149,  0.1418,  ..., -1.4181,  1.0893,  0.8436],
          [-1.4736, -0.7935,  0.0396,  ..., -1.0879, -1.1381,  1.1095],
          ...,
          [-0.2104, -0.3690, -0.4043,  ..., -1.2839, -1.0930, -0.1400],
          [-0.2104, -0.3690, -0.4043,  ..., -1.2839, -1.0930, -0.1400],
          [-0.2104, -0.3690, -0.4043,  ..., -1.2839, -1.0930, -0.1400]],

         ...,

         [[-1.2038,  2.5078, -4.2352,  ...,  2.7201, -1.0289,  1.9152],
          [-1.3249,  2.2833, -3.9407,  ...,  2.6706, -1.1244,  2.0280],
          [-1.2594, -0.8047, -2.1883,  ..., -2.0619, -1.7853, -1.6695],
          ...,
          [-1.1642,  1.0709, -1.8092,  ...,  0.8250, -1.5550, -1.0607],
          [-1.1642,  1.0709, -1.8092,  ...,  0.8250, -1.5550, -1.0607],
          [-1.1642,  1.0709, -1.8092,  ...,  0.8250, -1.5550, -1.0607]],

         [[-0.8177, -0.4841, -1.4319,  ...,  0.4529,  1.4253,  0.7027],
          [-1.0763, -0.1472, -1.3071,  ...,  0.4464,  1.8371,  0.4835],
          [-1.2741, -0.9633, -0.2152,  ...,  0.9735,  0.5304,  2.2718],
          ...,
          [ 3.0439, -1.0459, -2.5153,  ...,  1.0619,  1.0960,  2.4216],
          [ 3.0439, -1.0459, -2.5153,  ...,  1.0619,  1.0960,  2.4216],
          [ 3.0439, -1.0459, -2.5153,  ...,  1.0619,  1.0960,  2.4216]],

         [[ 4.3995, -1.2718,  0.9241,  ..., -2.2024, -2.2289,  1.2020],
          [ 4.3145, -1.2971,  1.3051,  ..., -2.3536, -2.2816,  1.2488],
          [ 1.4610, -0.4797,  0.0719,  ..., -0.1048,  1.2458,  0.7922],
          ...,
          [ 0.9580,  0.3589, -0.3716,  ...,  2.3034,  1.7044,  0.7004],
          [ 0.9580,  0.3589, -0.3716,  ...,  2.3034,  1.7044,  0.7004],
          [ 0.9580,  0.3589, -0.3716,  ...,  2.3034,  1.7044,  0.7004]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[-1.3413, -0.8591,  0.8721,  ..., -1.0178,  0.3812, -0.4476],
          [-1.2254, -0.8781,  1.0317,  ..., -1.3125,  0.4713, -0.2220],
          [ 4.0492,  0.1291, -1.7776,  ...,  2.0694,  0.4045,  3.3434],
          ...,
          [ 3.9598, -0.5400,  1.2159,  ..., -1.1515, -0.7557,  1.9474],
          [ 3.9598, -0.5400,  1.2159,  ..., -1.1515, -0.7557,  1.9474],
          [ 3.9598, -0.5400,  1.2159,  ..., -1.1515, -0.7557,  1.9474]],

         [[ 3.2664,  0.0126, -1.0211,  ...,  0.6904, -0.5135, -2.4190],
          [ 4.0803, -0.0408, -0.8988,  ...,  0.9364, -0.5746, -2.2241],
          [-1.1274,  1.8210,  1.5411,  ..., -0.5158, -1.6825,  1.6767],
          ...,
          [-2.4985, -3.4522,  1.2634,  ...,  2.3362, -0.5794,  1.2483],
          [-2.4985, -3.4522,  1.2634,  ...,  2.3362, -0.5794,  1.2483],
          [-2.4985, -3.4522,  1.2634,  ...,  2.3362, -0.5794,  1.2483]],

         [[-3.3306, -1.2058,  5.9638,  ...,  1.0916, -1.3607, -2.3798],
          [-3.7534, -1.9346,  6.5004,  ...,  0.8584, -1.4754, -2.8201],
          [ 2.8999, -0.7341,  2.2069,  ..., -1.0812,  1.2914,  0.7666],
          ...,
          [-0.0801, -0.8147,  0.0081,  ..., -4.3985, -3.0940, -1.7191],
          [-0.0801, -0.8147,  0.0081,  ..., -4.3985, -3.0940, -1.7191],
          [-0.0801, -0.8147,  0.0081,  ..., -4.3985, -3.0940, -1.7191]],

         ...,

         [[-1.8996, -7.0471, -1.4512,  ...,  2.7725, -1.2509, -0.2206],
          [-2.6427, -7.1488, -1.6779,  ...,  2.5620, -1.4048,  0.5921],
          [-1.5242, -2.5171, -0.9094,  ...,  1.2103, -0.3074, -0.5197],
          ...,
          [-0.6421,  1.7312,  2.1411,  ...,  2.6554, -0.3607, -1.7789],
          [-0.6421,  1.7312,  2.1411,  ...,  2.6554, -0.3607, -1.7789],
          [-0.6421,  1.7312,  2.1411,  ...,  2.6554, -0.3607, -1.7789]],

         [[-0.0702, -0.0514,  3.1435,  ...,  1.2925,  3.5129,  4.6858],
          [ 0.0512, -0.9206,  3.7547,  ...,  1.2695,  3.1553,  4.5129],
          [ 2.8442,  1.8430, -0.8938,  ...,  2.4492, -2.0961, -0.9872],
          ...,
          [ 1.4987,  2.1461,  1.0542,  ...,  1.4517,  1.4335,  0.2328],
          [ 1.4987,  2.1461,  1.0542,  ...,  1.4517,  1.4335,  0.2328],
          [ 1.4987,  2.1461,  1.0542,  ...,  1.4517,  1.4335,  0.2328]],

         [[ 5.2495, -1.2338,  1.6963,  ..., -4.9547,  0.6053, -2.0423],
          [ 5.2393, -1.6319,  1.5410,  ..., -4.9861, -0.0081, -1.5080],
          [-1.6751, -2.0444,  1.0106,  ...,  2.5006,  1.2771, -0.3089],
          ...,
          [ 1.8420, -1.8255,  0.3454,  ...,  0.8042,  1.1356,  0.5579],
          [ 1.8420, -1.8255,  0.3454,  ...,  0.8042,  1.1356,  0.5579],
          [ 1.8420, -1.8255,  0.3454,  ...,  0.8042,  1.1356,  0.5579]]]],
       grad_fn=<TransposeBackward0>)), (tensor([[[[-0.8197, -0.4256,  1.8877,  ...,  0.9484,  0.5055, -1.9370],
          [-0.0439, -0.7577,  1.1045,  ...,  0.6504,  0.1040, -1.6059],
          [ 1.0338,  0.7745,  1.9528,  ...,  1.4239, -1.4169,  1.1229],
          ...,
          [ 1.0011,  0.7205,  2.0093,  ...,  1.4276, -1.3676,  1.0853],
          [ 1.0011,  0.7205,  2.0093,  ...,  1.4276, -1.3676,  1.0853],
          [ 1.0008,  0.7239,  2.0012,  ...,  1.4262, -1.3734,  1.0911]],

         [[-0.8665, -0.3987, -0.2463,  ...,  0.8980, -0.3249, -1.3932],
          [-0.9879, -0.4100, -2.4115,  ...,  1.7994, -0.2627,  0.3576],
          [-0.8021, -1.2040, -4.4624,  ...,  3.5764,  0.5177, -0.2089],
          ...,
          [-0.8392, -1.1486, -4.4196,  ...,  3.5559,  0.5139, -0.2675],
          [-0.8392, -1.1486, -4.4196,  ...,  3.5559,  0.5139, -0.2675],
          [-0.8393, -1.1497, -4.4143,  ...,  3.5508,  0.5144, -0.2584]],

         [[ 0.5992,  0.1840,  0.2681,  ..., -1.3550,  1.7431,  1.1039],
          [-0.6981,  0.1419,  0.0728,  ..., -1.7896, -0.2006, -0.5161],
          [-2.9313, -0.0925, -0.3186,  ..., -0.9284,  1.0810,  0.9023],
          ...,
          [-2.8733, -0.0841, -0.3168,  ..., -0.9093,  1.1449,  0.9988],
          [-2.8733, -0.0841, -0.3168,  ..., -0.9093,  1.1449,  0.9988],
          [-2.8856, -0.0881, -0.3162,  ..., -0.8964,  1.1491,  0.9808]],

         ...,

         [[ 0.9272, -0.6229,  0.4131,  ...,  0.4751,  0.0178, -1.6106],
          [ 0.2338,  1.0198,  0.7435,  ...,  0.7083,  0.1688, -1.6569],
          [ 0.6601,  0.7652, -0.1466,  ..., -0.1422,  1.4740, -1.2777],
          ...,
          [ 0.6873,  0.6796, -0.0919,  ..., -0.1732,  1.4421, -1.3430],
          [ 0.6873,  0.6796, -0.0919,  ..., -0.1732,  1.4421, -1.3430],
          [ 0.6871,  0.6865, -0.0953,  ..., -0.1674,  1.4482, -1.3370]],

         [[ 1.3635,  1.6833, -0.3136,  ..., -0.7864, -0.1195, -0.6216],
          [ 0.7742,  2.0745, -0.4628,  ..., -0.2660, -0.3193, -1.5061],
          [-0.0630,  1.5618, -0.4183,  ..., -0.4422,  0.0974,  0.1489],
          ...,
          [-0.0648,  1.5046, -0.4015,  ..., -0.4457,  0.0877,  0.1350],
          [-0.0648,  1.5046, -0.4015,  ..., -0.4457,  0.0877,  0.1350],
          [-0.0659,  1.5037, -0.4068,  ..., -0.4464,  0.0860,  0.1385]],

         [[ 0.7057,  0.0560, -1.6706,  ..., -0.5827,  0.6200,  0.4592],
          [-0.3564, -0.5521, -1.4132,  ...,  0.1439,  0.2285, -0.8465],
          [ 0.2782, -0.3043, -0.0915,  ..., -0.8422,  0.4052, -0.1320],
          ...,
          [ 0.2615, -0.3369, -0.1424,  ..., -0.8432,  0.4175, -0.1542],
          [ 0.2615, -0.3369, -0.1424,  ..., -0.8432,  0.4175, -0.1542],
          [ 0.2556, -0.3430, -0.1363,  ..., -0.8480,  0.4249, -0.1539]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[-2.4389,  0.0447,  2.3461,  ...,  4.1580,  0.6201,  0.7772],
          [-0.8173, -1.8643,  0.3945,  ...,  1.4704,  0.2677, -0.8645],
          [ 0.2822,  0.0454, -0.0777,  ..., -0.1459, -0.4873,  0.5992],
          ...,
          [ 0.2289,  0.0289, -0.0312,  ..., -0.0759, -0.4573,  0.5470],
          [ 0.2289,  0.0289, -0.0312,  ..., -0.0759, -0.4573,  0.5470],
          [ 0.2346,  0.0377, -0.0380,  ..., -0.0793, -0.4586,  0.5470]],

         [[-0.2750,  0.6970, -3.4929,  ..., -0.7269, -0.6045, -0.3390],
          [-0.8370,  0.3688, -1.2070,  ...,  0.7268, -1.5642,  0.2255],
          [-0.8137, -0.3552,  0.0902,  ..., -0.2719,  0.4522, -1.2043],
          ...,
          [-0.8048, -0.4470,  0.0690,  ..., -0.2384,  0.4972, -1.1337],
          [-0.8048, -0.4470,  0.0690,  ..., -0.2384,  0.4972, -1.1337],
          [-0.8072, -0.4422,  0.0803,  ..., -0.2428,  0.5068, -1.1345]],

         [[-0.8001, -1.1292,  0.9122,  ...,  0.6650,  0.7827, -0.4761],
          [ 0.8156, -0.8598,  1.1241,  ..., -0.1841, -0.8172, -0.9108],
          [ 0.1755, -0.0517,  0.0967,  ...,  0.1579,  0.1885, -0.0168],
          ...,
          [ 0.2490, -0.0222,  0.1021,  ...,  0.2004,  0.2252,  0.0292],
          [ 0.2490, -0.0222,  0.1021,  ...,  0.2004,  0.2252,  0.0292],
          [ 0.2449, -0.0171,  0.0994,  ...,  0.1863,  0.2253,  0.0313]],

         ...,

         [[ 0.6036, -0.7570,  1.1259,  ...,  0.2214,  2.7549, -0.9662],
          [ 0.3464, -0.7547, -0.0930,  ...,  0.1462,  2.0962,  1.8812],
          [-0.0821,  0.1773,  0.0245,  ...,  0.2886,  0.1837,  0.1106],
          ...,
          [-0.0673,  0.1017, -0.0069,  ...,  0.3154,  0.1864,  0.1569],
          [-0.0673,  0.1017, -0.0069,  ...,  0.3154,  0.1864,  0.1569],
          [-0.0681,  0.1072, -0.0071,  ...,  0.3182,  0.1829,  0.1569]],

         [[ 0.8082, -2.1138,  2.9275,  ..., -1.9885, -2.9925, -0.9970],
          [ 1.5418, -2.7994,  0.9514,  ..., -1.3575, -1.1858, -1.7017],
          [-0.1102, -0.1967, -0.1862,  ..., -0.2992, -0.3021, -0.3429],
          ...,
          [-0.0609, -0.2937, -0.0664,  ..., -0.2604, -0.4256, -0.3472],
          [-0.0609, -0.2937, -0.0664,  ..., -0.2604, -0.4256, -0.3472],
          [-0.0601, -0.2823, -0.0715,  ..., -0.2526, -0.4191, -0.3447]],

         [[-0.8244, -2.3563, -0.3262,  ...,  1.1182, -1.3824,  5.4574],
          [ 1.9784, -1.2571,  3.0270,  ...,  0.4762, -1.5023,  3.4697],
          [-0.1546,  0.2244,  0.5591,  ..., -0.6704,  0.2306,  0.3655],
          ...,
          [-0.0706,  0.1689,  0.4607,  ..., -0.6710,  0.2013,  0.5756],
          [-0.0706,  0.1689,  0.4607,  ..., -0.6710,  0.2013,  0.5756],
          [-0.0795,  0.1759,  0.4683,  ..., -0.6714,  0.2067,  0.5627]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 1.4174, -0.7728, -0.1155,  ...,  0.6848,  0.4997,  1.4859],
          [ 1.3099,  0.0991,  0.0059,  ...,  0.3564,  0.4010,  1.4880],
          [-0.9466,  0.0367, -1.1556,  ...,  1.3546, -1.2413,  1.7884],
          ...,
          [ 1.3971, -0.8174, -1.2203,  ...,  0.8567,  1.7508,  1.2677],
          [ 1.3971, -0.8174, -1.2203,  ...,  0.8567,  1.7508,  1.2677],
          [ 1.3971, -0.8174, -1.2203,  ...,  0.8567,  1.7508,  1.2677]],

         [[ 2.4409, -1.1435,  0.7884,  ...,  0.9529,  1.3292,  1.2161],
          [ 2.2248, -0.7740,  0.4267,  ...,  1.0749,  1.5222,  1.0095],
          [-0.3756,  0.0267, -0.5827,  ..., -0.0837,  0.9137, -1.4451],
          ...,
          [-1.4495,  0.8582,  0.2606,  ...,  0.7307,  0.4667, -1.6316],
          [-1.4495,  0.8582,  0.2606,  ...,  0.7307,  0.4667, -1.6316],
          [-1.4495,  0.8582,  0.2606,  ...,  0.7307,  0.4667, -1.6316]],

         [[ 1.1088,  0.6066, -1.8675,  ...,  0.7374,  1.0438, -1.2479],
          [ 1.0078,  0.7627, -1.8271,  ...,  0.3888,  0.9391, -0.7070],
          [-1.5964, -0.4035, -3.0632,  ...,  3.7942,  0.1004,  0.4426],
          ...,
          [ 0.2286,  0.8355, -3.1183,  ...,  1.6161, -1.0748,  0.1449],
          [ 0.2286,  0.8355, -3.1183,  ...,  1.6161, -1.0748,  0.1449],
          [ 0.2286,  0.8355, -3.1183,  ...,  1.6161, -1.0748,  0.1449]],

         ...,

         [[ 3.2454, -2.8179, -0.5251,  ...,  0.6563, -0.9861, -0.3502],
          [ 2.8193, -2.7435, -0.6982,  ...,  0.6917, -0.5691, -1.2819],
          [ 1.5877,  1.1019,  0.8168,  ..., -0.3032, -0.0917,  1.8527],
          ...,
          [ 0.9423, -0.0295, -1.6021,  ...,  0.1021,  0.0247,  1.2837],
          [ 0.9423, -0.0295, -1.6021,  ...,  0.1021,  0.0247,  1.2837],
          [ 0.9423, -0.0295, -1.6021,  ...,  0.1021,  0.0247,  1.2837]],

         [[-0.0168, -0.1935, -1.4864,  ..., -3.0277, -0.7992, -0.5306],
          [-0.2359, -0.3273, -1.9375,  ..., -3.0152, -0.8224, -0.1113],
          [ 0.3279,  1.1493,  0.5418,  ...,  0.9852,  0.5756,  1.2321],
          ...,
          [ 2.1935,  1.3288, -1.4924,  ..., -1.0198, -0.2009,  2.7670],
          [ 2.1935,  1.3288, -1.4924,  ..., -1.0198, -0.2009,  2.7670],
          [ 2.1935,  1.3288, -1.4924,  ..., -1.0198, -0.2009,  2.7670]],

         [[ 0.6131, -0.5096, -1.6079,  ...,  0.7946, -2.9298,  0.3025],
          [ 0.0868, -0.9235, -1.5440,  ...,  0.4821, -2.9355,  0.2882],
          [ 3.0856, -1.5598, -3.7926,  ...,  2.7003, -1.2254,  1.0116],
          ...,
          [ 2.3538, -0.6143, -0.8386,  ...,  1.0289, -1.2133,  1.1721],
          [ 2.3538, -0.6143, -0.8386,  ...,  1.0289, -1.2133,  1.1721],
          [ 2.3538, -0.6143, -0.8386,  ...,  1.0289, -1.2133,  1.1721]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[  1.1896,  -4.2189,   1.3297,  ...,   0.0846,  -3.4047,   7.1763],
          [  0.8532,  -4.3932,   1.0389,  ...,  -0.6152,  -3.0206,   6.9729],
          [ -2.1311,  -5.3265,  -2.2507,  ...,   1.9050,   0.7905,   4.4847],
          ...,
          [  7.9889,  -1.9435,  -4.1405,  ...,  -2.3005,   4.2875,  -0.4352],
          [  7.9889,  -1.9435,  -4.1405,  ...,  -2.3005,   4.2875,  -0.4352],
          [  7.9889,  -1.9435,  -4.1405,  ...,  -2.3005,   4.2875,  -0.4352]],

         [[  0.8836,  -5.9879,   0.9148,  ..., -10.7811,  -6.5207,  10.2212],
          [  1.1839,  -5.5776,   1.2452,  ..., -11.4854,  -6.6013,   9.8240],
          [  1.0943,   1.0312,  -0.1009,  ...,   2.4966,   0.1327,  -2.7946],
          ...,
          [  0.4510,  -4.2893,  -0.4804,  ...,  -4.7265,  -1.7644,  -2.1012],
          [  0.4510,  -4.2893,  -0.4804,  ...,  -4.7265,  -1.7644,  -2.1012],
          [  0.4510,  -4.2893,  -0.4804,  ...,  -4.7265,  -1.7644,  -2.1012]],

         [[  1.5233,  -4.2790,  -1.4037,  ...,   2.8303,   0.7067,  -1.6541],
          [  1.6599,  -4.6027,  -1.6008,  ...,   2.8905,   0.6202,  -1.8473],
          [ -0.6531,  -0.2675,  -0.5719,  ...,   1.2210,  -2.1941,  -0.1960],
          ...,
          [ -3.1550,  -1.2115,   0.4271,  ...,  -4.4549,  -2.2453,  -3.4299],
          [ -3.1550,  -1.2115,   0.4271,  ...,  -4.4549,  -2.2453,  -3.4299],
          [ -3.1550,  -1.2115,   0.4271,  ...,  -4.4549,  -2.2453,  -3.4299]],

         ...,

         [[  1.4531,   2.8317,  -0.7540,  ...,   0.6925,  -2.1241,  -1.7855],
          [  0.9479,   2.7888,  -0.8134,  ...,   0.0880,  -2.2732,  -2.1758],
          [  0.9719,   0.3128,  -2.4485,  ...,   1.1286,  -1.3362,   2.1550],
          ...,
          [  2.9137,  -0.3043,   0.4824,  ...,   0.7416,  -2.4359,   4.9251],
          [  2.9137,  -0.3043,   0.4824,  ...,   0.7416,  -2.4359,   4.9251],
          [  2.9137,  -0.3043,   0.4824,  ...,   0.7416,  -2.4359,   4.9251]],

         [[ -0.1939,   2.7731,   1.3016,  ...,   1.7361,   3.4011,   0.1282],
          [  0.1031,   3.0486,   0.6398,  ...,   2.0138,   2.1470,  -0.4251],
          [ -0.3269,  -0.0167,  -2.3901,  ...,  -3.3961,  -2.5260,   1.8217],
          ...,
          [ -0.1958,   1.1971,   3.0600,  ...,  -0.4817,   0.4354,  -0.3505],
          [ -0.1958,   1.1971,   3.0600,  ...,  -0.4817,   0.4354,  -0.3505],
          [ -0.1958,   1.1971,   3.0600,  ...,  -0.4817,   0.4354,  -0.3505]],

         [[  0.1400,  -2.1670,   1.2301,  ...,  -0.9402,   0.7195,  -0.1467],
          [  0.3218,  -2.1785,   1.5235,  ...,  -0.6303,   0.5997,  -0.2442],
          [ -0.5068,   1.8439,  -3.4842,  ...,   2.6559,  -1.3814,   1.5654],
          ...,
          [  0.3024,  -1.9878,  -2.0457,  ...,   4.2555,  -1.4206,  -0.0259],
          [  0.3024,  -1.9878,  -2.0457,  ...,   4.2555,  -1.4206,  -0.0259],
          [  0.3024,  -1.9878,  -2.0457,  ...,   4.2555,  -1.4206,  -0.0259]]]],
       grad_fn=<TransposeBackward0>)), (tensor([[[[ 5.3276e-01,  1.5310e-01, -7.3969e-01,  ...,  2.8258e+00,
           -2.4038e-01, -5.5150e-02],
          [-5.6303e-01, -2.9500e-01,  6.2860e-01,  ...,  2.6956e+00,
           -7.8644e-01, -1.9224e-02],
          [-1.1200e+00, -4.2726e-01, -5.4130e-01,  ...,  2.7160e-01,
           -4.1436e-01, -9.6622e-01],
          ...,
          [-1.0651e+00, -4.8715e-01, -5.7254e-01,  ...,  3.6974e-01,
           -4.0526e-01, -8.1361e-01],
          [-1.0651e+00, -4.8715e-01, -5.7254e-01,  ...,  3.6974e-01,
           -4.0526e-01, -8.1361e-01],
          [-1.0741e+00, -4.8401e-01, -5.7895e-01,  ...,  3.5170e-01,
           -4.0494e-01, -8.1568e-01]],

         [[-1.0412e+00, -1.2725e+00, -2.5688e+00,  ...,  1.5835e+00,
           -7.9998e-01,  1.2474e+00],
          [-7.8811e-01, -9.7848e-01, -3.4214e+00,  ...,  4.6448e-01,
           -8.1672e-01,  1.7253e+00],
          [-5.9041e-01, -1.2713e-01, -8.2074e-01,  ...,  1.4822e-01,
            1.9196e-01,  9.1469e-01],
          ...,
          [-5.5896e-01, -2.2174e-01, -9.2330e-01,  ...,  2.4013e-01,
            2.4258e-01,  9.2025e-01],
          [-5.5896e-01, -2.2174e-01, -9.2330e-01,  ...,  2.4013e-01,
            2.4258e-01,  9.2025e-01],
          [-5.6999e-01, -2.2150e-01, -9.1324e-01,  ...,  2.3042e-01,
            2.5193e-01,  9.1939e-01]],

         [[-1.2583e+00,  2.2184e+00,  2.5679e+00,  ..., -5.7159e-01,
            9.0348e-01,  8.2978e-01],
          [-1.1207e+00,  2.7411e+00,  1.6577e+00,  ..., -4.8812e-01,
            8.7620e-01,  1.3003e+00],
          [-1.2123e+00,  7.7429e-01,  1.2576e+00,  ...,  5.9179e-01,
            8.4220e-01, -1.6376e-03],
          ...,
          [-1.2313e+00,  7.4926e-01,  1.3965e+00,  ...,  4.9734e-01,
            8.9602e-01, -4.0893e-03],
          [-1.2313e+00,  7.4926e-01,  1.3965e+00,  ...,  4.9734e-01,
            8.9602e-01, -4.0893e-03],
          [-1.2240e+00,  7.4733e-01,  1.3988e+00,  ...,  5.0173e-01,
            9.0866e-01,  2.3385e-03]],

         ...,

         [[-1.4296e+00, -6.2263e-01, -2.3910e+00,  ...,  1.1119e+00,
            4.9919e-01, -6.3276e-01],
          [-1.2109e+00, -1.9783e-01, -1.0445e+00,  ...,  2.0751e+00,
            1.2936e+00,  6.7546e-01],
          [-1.4653e-01, -2.0254e+00, -3.6632e-01,  ...,  1.8205e+00,
            6.4869e-01,  1.1025e+00],
          ...,
          [-1.5431e-01, -1.9636e+00, -4.8978e-01,  ...,  1.8789e+00,
            5.9545e-01,  1.1229e+00],
          [-1.5431e-01, -1.9636e+00, -4.8978e-01,  ...,  1.8789e+00,
            5.9545e-01,  1.1229e+00],
          [-1.5483e-01, -1.9745e+00, -4.8089e-01,  ...,  1.8802e+00,
            5.9606e-01,  1.1344e+00]],

         [[ 8.4959e-01,  1.7841e+00, -7.6360e-01,  ...,  9.6239e-01,
           -7.3048e-01,  1.9404e+00],
          [-5.9477e-01,  7.9886e-01,  9.8179e-01,  ...,  6.5196e-01,
            1.5216e-02, -1.9262e-01],
          [ 7.7298e-01,  1.6663e-01,  3.2727e-01,  ..., -7.6781e-01,
            4.3671e-01, -4.2288e+00],
          ...,
          [ 8.0814e-01,  1.7342e-01,  2.5034e-01,  ..., -7.6099e-01,
            4.3779e-01, -4.2307e+00],
          [ 8.0814e-01,  1.7342e-01,  2.5034e-01,  ..., -7.6099e-01,
            4.3779e-01, -4.2307e+00],
          [ 8.1690e-01,  1.5740e-01,  2.5785e-01,  ..., -7.6402e-01,
            4.4125e-01, -4.2438e+00]],

         [[ 2.7694e+00, -8.0310e-01, -9.5865e-01,  ...,  1.8031e-01,
            1.8032e+00, -1.6014e+00],
          [ 1.5153e+00,  5.6645e-01, -1.7564e+00,  ...,  5.0180e-01,
            1.4908e+00, -4.2351e-01],
          [-4.6676e-01,  1.3965e+00, -6.6460e-02,  ..., -2.2321e-01,
            2.6836e-01,  3.2589e-01],
          ...,
          [-4.0963e-01,  1.2674e+00, -9.6844e-02,  ..., -2.2423e-01,
            3.7950e-01,  2.7284e-01],
          [-4.0963e-01,  1.2674e+00, -9.6844e-02,  ..., -2.2423e-01,
            3.7950e-01,  2.7284e-01],
          [-4.2764e-01,  1.2735e+00, -8.8005e-02,  ..., -2.2680e-01,
            3.7849e-01,  2.6254e-01]]]], grad_fn=<TransposeBackward0>), tensor([[[[ 2.9551e+00,  1.7769e+00, -1.0421e+00,  ..., -1.7342e+00,
            4.0340e+00,  1.5149e+00],
          [ 2.0572e+00, -6.6643e-01, -6.1491e-01,  ..., -1.1988e+00,
            2.1122e+00,  1.8178e-02],
          [ 2.2493e-01,  4.2783e-01, -1.7543e-02,  ...,  5.0681e-01,
            1.8954e-02,  2.3825e-02],
          ...,
          [ 1.9126e-01,  3.7789e-01, -3.1635e-02,  ...,  3.2705e-01,
            3.8872e-01, -8.3484e-03],
          [ 1.9126e-01,  3.7789e-01, -3.1635e-02,  ...,  3.2705e-01,
            3.8872e-01, -8.3484e-03],
          [ 1.7926e-01,  3.8091e-01, -1.8601e-02,  ...,  3.3619e-01,
            3.7370e-01, -2.0579e-02]],

         [[-1.4429e+00, -1.8328e+00,  2.1062e+00,  ...,  3.4271e+00,
            4.6362e-02,  1.7426e+00],
          [-2.9256e+00,  2.7955e-01,  2.3537e+00,  ...,  2.5616e+00,
            2.9719e-01,  3.3225e+00],
          [-6.9226e-01,  8.2266e-01,  6.1051e-01,  ..., -4.2229e-03,
            4.2303e-01,  1.9784e+00],
          ...,
          [-6.9496e-01,  8.4715e-01,  6.0887e-01,  ...,  2.0260e-01,
            4.2212e-01,  1.9423e+00],
          [-6.9496e-01,  8.4715e-01,  6.0887e-01,  ...,  2.0260e-01,
            4.2212e-01,  1.9423e+00],
          [-6.7093e-01,  8.1923e-01,  5.7603e-01,  ...,  1.7068e-01,
            4.4605e-01,  1.9353e+00]],

         [[ 2.9355e+00, -7.7296e-01,  7.7944e-01,  ..., -1.2233e+00,
           -3.0951e+00, -3.1906e+00],
          [ 2.8139e+00,  1.3317e+00, -3.5548e-01,  ..., -2.3327e+00,
           -4.4878e-01, -1.7607e+00],
          [ 1.6520e-01,  3.9480e-01,  1.4102e-01,  ..., -2.1611e-01,
            2.6835e-01, -9.3185e-01],
          ...,
          [ 2.9579e-01,  3.5695e-01,  3.0241e-01,  ..., -3.1813e-01,
           -3.7401e-02, -9.7012e-01],
          [ 2.9579e-01,  3.5695e-01,  3.0241e-01,  ..., -3.1813e-01,
           -3.7401e-02, -9.7012e-01],
          [ 2.9241e-01,  3.5011e-01,  2.9431e-01,  ..., -2.9594e-01,
           -1.1465e-02, -9.6766e-01]],

         ...,

         [[-4.1100e-01, -6.4293e-03, -3.3417e-01,  ..., -1.2146e+00,
            1.5143e-01,  2.8336e-01],
          [ 1.1195e+00,  1.0229e+00,  1.2826e+00,  ..., -5.2185e-01,
            1.3658e-01,  2.0506e+00],
          [-1.0088e-01,  4.4964e-01,  1.7306e-03,  ..., -3.4758e-01,
           -5.4198e-01, -6.4414e-02],
          ...,
          [-1.4692e-01,  4.4923e-01,  8.1058e-02,  ..., -3.6052e-01,
           -5.2969e-01,  1.3968e-01],
          [-1.4692e-01,  4.4923e-01,  8.1058e-02,  ..., -3.6052e-01,
           -5.2969e-01,  1.3968e-01],
          [-1.5833e-01,  4.2727e-01,  7.6314e-02,  ..., -3.5182e-01,
           -5.4194e-01,  1.5179e-01]],

         [[-2.5969e-01, -1.8990e+00, -9.8688e-01,  ..., -1.6446e+00,
            1.0819e+00,  1.9316e+00],
          [ 5.0724e-01,  3.4761e-01, -3.1433e-01,  ..., -1.7899e-01,
            1.0381e+00, -1.7713e+00],
          [ 5.5442e-01,  2.9548e-01,  5.6988e-01,  ..., -6.6425e-01,
            7.5796e-02, -1.3960e-01],
          ...,
          [ 3.6637e-01,  1.7303e-01,  3.8576e-01,  ..., -4.5985e-01,
            1.2105e-01, -9.1018e-02],
          [ 3.6637e-01,  1.7303e-01,  3.8576e-01,  ..., -4.5985e-01,
            1.2105e-01, -9.1018e-02],
          [ 3.5968e-01,  1.7265e-01,  3.7926e-01,  ..., -4.4860e-01,
            1.2051e-01, -6.9605e-02]],

         [[-6.3836e+00, -1.8134e+00,  1.5221e+00,  ..., -9.0778e-02,
            2.5078e+00, -5.6678e-01],
          [-5.4208e+00, -3.0361e-01,  1.4575e+00,  ...,  1.1086e+00,
            1.4552e+00, -2.2068e+00],
          [-1.1317e+00, -2.4464e-01, -1.1445e-01,  ...,  1.7024e-01,
           -4.7194e-01,  1.0561e-01],
          ...,
          [-1.2262e+00, -3.9979e-01, -1.9336e-01,  ...,  1.8590e-01,
           -3.9704e-01,  1.3009e-01],
          [-1.2262e+00, -3.9979e-01, -1.9336e-01,  ...,  1.8590e-01,
           -3.9704e-01,  1.3009e-01],
          [-1.1834e+00, -3.8067e-01, -1.8745e-01,  ...,  1.9198e-01,
           -4.0337e-01,  1.3476e-01]]]], grad_fn=<TransposeBackward0>), tensor([[[[-1.2511e+00,  3.7495e-01, -5.7327e-02,  ...,  6.9601e-01,
           -5.0905e-01, -7.5787e-01],
          [-1.6186e+00,  6.0703e-01, -6.3545e-01,  ...,  7.5229e-01,
           -6.4124e-01, -7.9800e-01],
          [-9.4873e-01,  2.0260e+00,  1.6129e+00,  ..., -1.1604e+00,
           -1.6558e+00,  1.3795e+00],
          ...,
          [-1.5216e+00,  1.0823e+00,  3.5328e+00,  ...,  1.7184e-02,
            3.3497e-03, -4.3612e-01],
          [-1.5216e+00,  1.0823e+00,  3.5328e+00,  ...,  1.7184e-02,
            3.3497e-03, -4.3612e-01],
          [-1.5216e+00,  1.0823e+00,  3.5328e+00,  ...,  1.7184e-02,
            3.3497e-03, -4.3612e-01]],

         [[ 4.2646e+00,  6.0137e-02,  2.8101e+00,  ...,  5.7412e-01,
           -1.1337e-01, -4.5420e-01],
          [ 4.3242e+00,  1.4235e-01,  2.8002e+00,  ...,  1.8216e-01,
           -4.3013e-02, -1.1481e-01],
          [ 2.1468e+00,  1.0782e+00,  1.6008e+00,  ...,  9.9352e-01,
            4.3377e+00, -9.5019e-02],
          ...,
          [ 1.6487e+00,  2.1899e+00,  2.0504e+00,  ..., -1.6676e-01,
            7.3962e-02,  8.5972e-02],
          [ 1.6487e+00,  2.1899e+00,  2.0504e+00,  ..., -1.6676e-01,
            7.3962e-02,  8.5972e-02],
          [ 1.6487e+00,  2.1899e+00,  2.0504e+00,  ..., -1.6676e-01,
            7.3962e-02,  8.5972e-02]],

         [[-1.0564e+00,  1.1184e+00,  1.9002e+00,  ..., -3.5648e-01,
           -3.3589e+00, -3.9953e+00],
          [-1.1235e+00,  4.8577e-01,  1.9953e+00,  ..., -3.9443e-01,
           -3.4515e+00, -4.1298e+00],
          [-1.1588e+00, -6.2210e-01,  1.3663e+00,  ...,  6.0337e-01,
           -5.9875e-01,  8.6563e-01],
          ...,
          [-4.9492e-01, -2.6926e-01,  3.1344e+00,  ..., -1.9963e+00,
            2.5492e+00, -3.1893e+00],
          [-4.9492e-01, -2.6926e-01,  3.1344e+00,  ..., -1.9963e+00,
            2.5492e+00, -3.1893e+00],
          [-4.9492e-01, -2.6926e-01,  3.1344e+00,  ..., -1.9963e+00,
            2.5492e+00, -3.1893e+00]],

         ...,

         [[ 7.0136e-01,  2.3625e-01,  9.6178e-01,  ...,  1.8946e+00,
           -8.9837e-02, -6.1949e-01],
          [ 5.1719e-01,  2.2409e-01,  6.0763e-01,  ...,  2.1621e+00,
            1.1043e-01, -5.9483e-01],
          [ 2.7398e-01,  1.9928e+00, -1.5373e+00,  ...,  7.5181e-01,
           -1.2591e-01,  4.1947e-01],
          ...,
          [-1.0239e+00,  3.5880e+00, -3.5424e+00,  ..., -6.7446e-01,
            3.2265e+00,  3.5020e-01],
          [-1.0239e+00,  3.5880e+00, -3.5424e+00,  ..., -6.7446e-01,
            3.2265e+00,  3.5020e-01],
          [-1.0239e+00,  3.5880e+00, -3.5424e+00,  ..., -6.7446e-01,
            3.2265e+00,  3.5020e-01]],

         [[-1.2209e+00,  1.5056e+00, -7.3441e-01,  ..., -1.6650e-01,
            2.1825e+00, -2.2798e-01],
          [-8.1909e-01,  1.3488e+00, -5.1872e-01,  ..., -4.2426e-01,
            1.9469e+00, -5.4386e-02],
          [-1.7947e+00,  2.3675e-01, -8.7003e-01,  ...,  1.9497e+00,
            1.8599e+00, -1.9041e+00],
          ...,
          [-1.7220e+00,  3.7888e+00,  1.4281e+00,  ..., -5.2234e-01,
            1.0489e+00, -1.8766e+00],
          [-1.7220e+00,  3.7888e+00,  1.4281e+00,  ..., -5.2234e-01,
            1.0489e+00, -1.8766e+00],
          [-1.7220e+00,  3.7888e+00,  1.4281e+00,  ..., -5.2234e-01,
            1.0489e+00, -1.8766e+00]],

         [[-2.6832e+00,  1.9699e+00,  9.5804e-01,  ...,  3.3441e+00,
           -7.2309e-02, -5.9860e-02],
          [-2.2881e+00,  1.6546e+00,  6.3866e-01,  ...,  3.3765e+00,
           -6.9150e-02, -4.8416e-01],
          [ 2.5379e-01, -1.5709e+00, -7.5831e-01,  ...,  2.1498e+00,
            8.3885e-01,  2.6967e-01],
          ...,
          [-2.2237e+00,  1.2769e-01,  9.2184e-01,  ..., -1.2862e+00,
            2.4844e+00,  1.4264e+00],
          [-2.2237e+00,  1.2769e-01,  9.2184e-01,  ..., -1.2862e+00,
            2.4844e+00,  1.4264e+00],
          [-2.2237e+00,  1.2769e-01,  9.2184e-01,  ..., -1.2862e+00,
            2.4844e+00,  1.4264e+00]]]], grad_fn=<TransposeBackward0>), tensor([[[[-3.3352,  3.3889,  2.2235,  ..., -2.2209, -3.0573,  1.7323],
          [-2.4355,  3.6610,  2.9889,  ..., -2.5845, -2.7488,  1.8779],
          [-1.7547,  0.4535,  5.0128,  ..., -0.6340, -1.6740, -0.2810],
          ...,
          [ 2.6541, -3.5449,  5.2943,  ..., -0.8571, -0.0552, -3.6376],
          [ 2.6541, -3.5449,  5.2943,  ..., -0.8571, -0.0552, -3.6376],
          [ 2.6541, -3.5449,  5.2943,  ..., -0.8571, -0.0552, -3.6376]],

         [[-1.5079, -1.5965, -7.3706,  ...,  4.4704, -2.0334,  1.0101],
          [-1.0952, -1.5304, -7.2331,  ...,  3.7257, -1.8745,  1.0372],
          [-4.8083, -4.2165,  3.2438,  ..., -3.4526, -0.9692, -5.2195],
          ...,
          [-1.9277,  2.8934, -4.5653,  ..., -1.5712,  2.3364, -3.9165],
          [-1.9277,  2.8934, -4.5653,  ..., -1.5712,  2.3364, -3.9165],
          [-1.9277,  2.8934, -4.5653,  ..., -1.5712,  2.3364, -3.9165]],

         [[-1.4617, -2.3904, -3.5556,  ...,  3.3718, -3.0685, -5.4673],
          [-2.0176, -2.0281, -3.7958,  ...,  3.2712, -2.6791, -5.5292],
          [-1.7949,  3.1533,  8.2355,  ..., -2.5511,  1.7747, -7.2089],
          ...,
          [-0.9776,  2.9998, 10.4486,  ...,  1.1624,  0.7640, -2.8616],
          [-0.9776,  2.9998, 10.4486,  ...,  1.1624,  0.7640, -2.8616],
          [-0.9776,  2.9998, 10.4486,  ...,  1.1624,  0.7640, -2.8616]],

         ...,

         [[ 3.1779, -2.4830,  1.0229,  ...,  3.5808,  0.0126, -4.5747],
          [ 3.0572, -2.9566,  1.6095,  ...,  3.4408,  0.0581, -5.3820],
          [ 0.2099, -2.8703, -6.6337,  ..., -1.5535,  5.0598, -3.6865],
          ...,
          [ 2.1313,  3.4532, -0.7948,  ..., -0.5028, -2.7525, -4.9816],
          [ 2.1313,  3.4532, -0.7948,  ..., -0.5028, -2.7525, -4.9816],
          [ 2.1313,  3.4532, -0.7948,  ..., -0.5028, -2.7525, -4.9816]],

         [[-2.4065,  1.1902,  0.8760,  ...,  5.1508, -0.0560,  3.0916],
          [-2.4086,  1.0954,  0.6596,  ...,  4.9191,  0.1580,  4.3810],
          [-0.2090,  0.7652,  1.4994,  ...,  2.6046,  0.7374, -3.6844],
          ...,
          [ 4.0442,  0.4595,  0.1357,  ...,  2.2207, -2.9208, -1.1586],
          [ 4.0442,  0.4595,  0.1357,  ...,  2.2207, -2.9208, -1.1586],
          [ 4.0442,  0.4595,  0.1357,  ...,  2.2207, -2.9208, -1.1586]],

         [[ 2.2566, -0.0834,  1.2415,  ..., -1.9636, -0.9420,  0.8752],
          [ 2.0120,  0.4099,  1.6252,  ..., -2.1733, -1.3015,  1.3142],
          [-2.3564,  1.3236,  3.2222,  ..., -5.7469, -0.7060, -1.0781],
          ...,
          [ 1.2463,  2.1971,  0.6620,  ..., -1.3778,  3.4974,  6.5994],
          [ 1.2463,  2.1971,  0.6620,  ..., -1.3778,  3.4974,  6.5994],
          [ 1.2463,  2.1971,  0.6620,  ..., -1.3778,  3.4974,  6.5994]]]],
       grad_fn=<TransposeBackward0>)), (tensor([[[[-1.3152,  5.7096, -1.4670,  ...,  4.0057, -1.9811, -0.0417],
          [-1.3517,  4.4744, -0.3206,  ...,  3.6664,  0.2456, -0.7527],
          [ 1.5313,  4.4346,  1.4559,  ...,  3.6487, -0.2907,  0.3437],
          ...,
          [ 1.4448,  4.7119,  1.3634,  ...,  3.7556, -0.4891,  0.2225],
          [ 1.4448,  4.7119,  1.3634,  ...,  3.7556, -0.4891,  0.2225],
          [ 1.5081,  4.7250,  1.3997,  ...,  3.7862, -0.4839,  0.2360]],

         [[-0.0358,  3.7276, -2.1809,  ...,  2.2028, -0.6374,  0.7550],
          [-0.0698,  3.5469, -0.8481,  ...,  2.5270, -1.3567,  2.6072],
          [-0.8615,  1.6999, -1.5324,  ...,  0.7598, -1.1318,  1.0486],
          ...,
          [-0.7572,  1.7966, -1.8498,  ...,  0.7398, -1.0272,  1.0819],
          [-0.7572,  1.7966, -1.8498,  ...,  0.7398, -1.0272,  1.0819],
          [-0.7742,  1.8012, -1.8385,  ...,  0.7322, -1.0165,  1.0857]],

         [[ 1.2382, -2.2304, -0.0336,  ...,  2.6869,  0.6686, -0.7790],
          [ 1.4414, -1.9522, -1.0685,  ...,  1.8475,  0.1428, -0.3725],
          [ 1.8634, -3.6727, -1.5712,  ...,  1.9651,  0.3070,  1.2563],
          ...,
          [ 1.9004, -3.8849, -1.5788,  ...,  2.0565,  0.4762,  1.1199],
          [ 1.9004, -3.8849, -1.5788,  ...,  2.0565,  0.4762,  1.1199],
          [ 1.9144, -3.8902, -1.5943,  ...,  2.0559,  0.4589,  1.1571]],

         ...,

         [[-1.0692, -1.7633,  1.0904,  ..., -1.0664,  2.2132, -1.6461],
          [-1.3498, -1.3241,  0.4274,  ..., -3.0949,  3.2670, -2.5696],
          [-1.6900,  0.5072,  0.7544,  ..., -3.4662,  1.5721, -2.4407],
          ...,
          [-1.5290,  0.3393,  0.7621,  ..., -3.2124,  1.5249, -2.5209],
          [-1.5290,  0.3393,  0.7621,  ..., -3.2124,  1.5249, -2.5209],
          [-1.5312,  0.3738,  0.7916,  ..., -3.2177,  1.4954, -2.5350]],

         [[-2.1366,  0.1672, -0.4520,  ...,  0.8337,  0.7720, -0.3295],
          [-2.0881, -0.3588,  0.5541,  ...,  0.1822, -0.4676, -0.5762],
          [ 0.2390, -0.1266,  0.5464,  ..., -0.6053,  0.9582, -0.1565],
          ...,
          [ 0.1112, -0.1189,  0.4667,  ..., -0.6465,  1.0055, -0.1957],
          [ 0.1112, -0.1189,  0.4667,  ..., -0.6465,  1.0055, -0.1957],
          [ 0.1408, -0.1106,  0.4568,  ..., -0.6444,  1.0372, -0.1554]],

         [[ 0.1606,  2.0688, -1.6405,  ...,  0.8578,  0.7647,  0.4907],
          [ 0.3669,  2.4629, -1.4337,  ...,  2.2612, -0.7775,  1.3218],
          [-0.4416,  0.9637,  0.2780,  ...,  2.1399, -0.4964,  0.2996],
          ...,
          [-0.4771,  0.9350,  0.3569,  ...,  2.0456, -0.3149,  0.4188],
          [-0.4771,  0.9350,  0.3569,  ...,  2.0456, -0.3149,  0.4188],
          [-0.4661,  0.8870,  0.3390,  ...,  2.0605, -0.3263,  0.4209]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[-0.5043,  0.2636,  1.4723,  ...,  1.4167,  2.3656,  0.8091],
          [ 0.3857,  2.4613,  0.7024,  ..., -0.3348,  1.4814,  1.8408],
          [ 0.7997, -0.5004, -0.8322,  ..., -0.7486, -0.5901,  0.5256],
          ...,
          [ 0.9376, -0.6556, -0.7691,  ..., -0.6477, -0.3311,  0.5316],
          [ 0.9376, -0.6556, -0.7691,  ..., -0.6477, -0.3311,  0.5316],
          [ 0.9222, -0.7009, -0.7772,  ..., -0.6741, -0.3433,  0.5095]],

         [[ 0.5318, -1.0096, -3.9788,  ...,  1.5915,  0.0621, -1.9748],
          [-0.1159, -5.2598, -1.8256,  ...,  1.9317,  1.3304, -1.9164],
          [ 0.3690, -1.7902, -1.3342,  ...,  1.0914,  1.4400, -1.4057],
          ...,
          [ 0.6538, -1.6367, -1.7247,  ...,  1.3665,  1.0273, -1.2422],
          [ 0.6538, -1.6367, -1.7247,  ...,  1.3665,  1.0273, -1.2422],
          [ 0.6310, -1.6276, -1.7571,  ...,  1.3153,  1.0201, -1.1973]],

         [[-1.4012,  5.2912, -1.9981,  ...,  2.1934,  1.5357, -3.4944],
          [-1.5652,  6.2503,  0.7941,  ...,  1.1995,  1.6993, -1.9306],
          [ 0.8375,  2.2730,  1.4815,  ...,  0.7461, -0.0257, -0.2548],
          ...,
          [ 0.6914,  2.3553,  0.9197,  ...,  0.8110,  0.0999, -0.9021],
          [ 0.6914,  2.3553,  0.9197,  ...,  0.8110,  0.0999, -0.9021],
          [ 0.6571,  2.2973,  0.9143,  ...,  0.8334,  0.0857, -0.8726]],

         ...,

         [[-1.7442, -4.1813,  3.4650,  ...,  0.9163,  0.3649,  3.9919],
          [ 1.5136, -1.7914,  1.4460,  ..., -0.5828,  0.6847, -0.9267],
          [-0.1727,  1.1965,  0.1136,  ...,  0.1679,  0.5164,  0.5779],
          ...,
          [-0.1513,  0.6686,  0.2295,  ...,  0.5156, -0.0424,  0.9516],
          [-0.1513,  0.6686,  0.2295,  ...,  0.5156, -0.0424,  0.9516],
          [-0.1772,  0.6841,  0.1980,  ...,  0.5090, -0.0263,  0.9547]],

         [[ 2.3033,  2.4417, -0.5174,  ..., -3.3501,  0.6944,  0.9258],
          [ 0.2469, -0.9863, -0.9988,  ..., -0.0774,  0.0445, -0.4780],
          [ 0.7372, -0.5101,  0.5764,  ..., -0.0882,  0.2793, -0.5515],
          ...,
          [ 0.7107, -0.2971,  0.6416,  ..., -0.3173, -0.3684, -0.4907],
          [ 0.7107, -0.2971,  0.6416,  ..., -0.3173, -0.3684, -0.4907],
          [ 0.6926, -0.2978,  0.6368,  ..., -0.2930, -0.3711, -0.5228]],

         [[ 1.4786,  1.9821,  1.7708,  ...,  5.3074,  3.1418,  4.8824],
          [-1.9149, -0.7301,  2.5113,  ...,  3.8475,  4.4100,  0.6501],
          [ 0.4485, -0.3066, -0.2750,  ...,  2.1705,  1.5357,  0.8236],
          ...,
          [ 0.5427,  0.1374, -0.1514,  ...,  2.4106,  1.5766,  0.9945],
          [ 0.5427,  0.1374, -0.1514,  ...,  2.4106,  1.5766,  0.9945],
          [ 0.5709,  0.1128, -0.1049,  ...,  2.3734,  1.5723,  0.9746]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 0.7424, -0.1856, -1.4998,  ...,  0.0124, -4.2153,  1.7766],
          [-0.0690, -0.6230, -1.7327,  ...,  0.3371, -4.0828,  1.8326],
          [ 1.2001,  0.6908, -0.2088,  ...,  1.2886, -1.5273,  1.4691],
          ...,
          [ 0.9665,  0.3914, -0.3440,  ...,  4.5045, -5.4637, -1.3155],
          [ 0.9665,  0.3914, -0.3440,  ...,  4.5045, -5.4637, -1.3155],
          [ 0.9665,  0.3914, -0.3440,  ...,  4.5045, -5.4637, -1.3155]],

         [[ 1.0770, -0.4095, -0.4903,  ...,  0.0590,  0.1868,  1.7682],
          [ 0.8725, -0.2968, -0.8809,  ...,  0.0678,  0.4489,  1.4478],
          [-3.7578, -3.3392,  0.3157,  ..., -1.6165,  0.4946,  2.1345],
          ...,
          [-1.5144, -0.9210,  1.5036,  ..., -0.0553,  0.4206,  1.9260],
          [-1.5144, -0.9210,  1.5036,  ..., -0.0553,  0.4206,  1.9260],
          [-1.5144, -0.9210,  1.5036,  ..., -0.0553,  0.4206,  1.9260]],

         [[-0.8695, -2.1103,  0.7260,  ...,  0.5480,  1.2375, -0.8362],
          [-0.5145, -1.8842,  1.1627,  ...,  0.5728,  0.9921, -1.0172],
          [ 3.4993, -2.3207, -1.9817,  ..., -1.8788, -1.0562, -1.3032],
          ...,
          [ 2.4275,  1.9382, -2.6925,  ...,  0.7009,  0.6741, -1.3071],
          [ 2.4275,  1.9382, -2.6925,  ...,  0.7009,  0.6741, -1.3071],
          [ 2.4275,  1.9382, -2.6925,  ...,  0.7009,  0.6741, -1.3071]],

         ...,

         [[-1.4898, -6.0278,  1.6576,  ...,  1.9197, -2.5137, -2.1845],
          [-1.7999, -5.6362,  1.9669,  ...,  1.8376, -2.7063, -2.0679],
          [ 2.2877,  0.7371,  0.1494,  ..., -1.2721, -0.4711, -0.0975],
          ...,
          [ 2.4246, -1.8799, -2.9783,  ...,  0.5199,  0.1788,  1.0477],
          [ 2.4246, -1.8799, -2.9783,  ...,  0.5199,  0.1788,  1.0477],
          [ 2.4246, -1.8799, -2.9783,  ...,  0.5199,  0.1788,  1.0477]],

         [[-6.0115,  1.4841,  0.9313,  ..., -0.1359,  1.0382,  0.0100],
          [-6.2610,  1.4935,  0.8322,  ..., -0.2289,  0.6686, -0.1349],
          [-1.0624,  0.3951, -1.4432,  ...,  0.8929, -0.0581, -1.1998],
          ...,
          [-0.9990, -0.7166, -3.0820,  ..., -1.7033,  1.5525,  3.9427],
          [-0.9990, -0.7166, -3.0820,  ..., -1.7033,  1.5525,  3.9427],
          [-0.9990, -0.7166, -3.0820,  ..., -1.7033,  1.5525,  3.9427]],

         [[ 0.2139,  0.1764,  0.5277,  ..., -0.3264,  3.6051, -2.7468],
          [ 0.6106,  0.4998,  0.6312,  ..., -0.5482,  3.8508, -2.5846],
          [ 0.1305,  0.8948,  3.9403,  ..., -3.1363,  1.0654, -0.4777],
          ...,
          [-2.6688,  0.6764,  5.5045,  ...,  0.5529,  1.4393, -1.6970],
          [-2.6688,  0.6764,  5.5045,  ...,  0.5529,  1.4393, -1.6970],
          [-2.6688,  0.6764,  5.5045,  ...,  0.5529,  1.4393, -1.6970]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ -2.8720,   1.7378,  -4.2808,  ...,  -0.5163,   0.3902,  -2.3792],
          [ -2.7505,   4.4352,  -3.3775,  ...,  -1.6374,   2.2067,  -1.7213],
          [ -4.9716,  -6.3281,  -1.3770,  ...,  -1.8053,   1.0781,  -1.7071],
          ...,
          [ -0.6815,   6.5603,  -8.9482,  ...,  -1.4596,  15.1201,  -0.1730],
          [ -0.6815,   6.5603,  -8.9482,  ...,  -1.4596,  15.1201,  -0.1730],
          [ -0.6815,   6.5603,  -8.9482,  ...,  -1.4596,  15.1201,  -0.1730]],

         [[ -2.2394,   0.5505,  -0.7398,  ...,  -5.2507,  -2.8439,  -2.8989],
          [ -3.2066,  -0.2693,  -0.4719,  ...,  -4.8264,  -0.8575,  -2.3726],
          [ -1.3534,   0.1986,  -4.7485,  ...,  -2.6218,   2.9503,  -3.3376],
          ...,
          [ -2.3833,   5.4691,  -0.4715,  ...,  -4.7071,   4.4838,  -0.9943],
          [ -2.3833,   5.4691,  -0.4715,  ...,  -4.7071,   4.4838,  -0.9943],
          [ -2.3833,   5.4691,  -0.4715,  ...,  -4.7071,   4.4838,  -0.9943]],

         [[ -2.9352,   0.9142,   2.2983,  ...,  -7.1204,  -2.2943,   1.3724],
          [ -3.1415,   2.0320,   2.1691,  ...,  -7.1298,  -3.2574,   2.0360],
          [  3.7376,   2.4622,  -3.9512,  ...,   0.0333,   1.2005,  -3.3871],
          ...,
          [ -4.6638,   0.9737,   4.9397,  ...,  -1.2744,  -2.6221,  -2.6063],
          [ -4.6638,   0.9737,   4.9397,  ...,  -1.2744,  -2.6221,  -2.6063],
          [ -4.6638,   0.9737,   4.9397,  ...,  -1.2744,  -2.6221,  -2.6063]],

         ...,

         [[  3.8345,   2.7969,  -6.1198,  ...,   5.3892,   4.6436,  -4.9443],
          [  3.6301,   2.6785,  -6.2371,  ...,   4.2319,   5.3229,  -4.4813],
          [ -0.4385,   1.0151,   0.1386,  ...,  -2.1973,   0.4127,  -1.1105],
          ...,
          [ -0.5974,  -1.4970,   0.8624,  ...,  -4.6397,   0.0703,  -5.2127],
          [ -0.5974,  -1.4970,   0.8624,  ...,  -4.6397,   0.0703,  -5.2127],
          [ -0.5974,  -1.4970,   0.8624,  ...,  -4.6397,   0.0703,  -5.2127]],

         [[ -3.2954,  -1.0594,   0.7593,  ...,   1.4822,   2.0132,   0.0268],
          [ -1.6741,  -1.4980,  -0.7640,  ...,   1.2698,   1.8636,  -1.0460],
          [  5.3975,   0.1127,  -9.0164,  ...,   4.4620,  -2.7416,  -1.9106],
          ...,
          [  7.9356,   6.6429,  -0.8715,  ...,   9.4993,   3.7676, -14.2954],
          [  7.9356,   6.6429,  -0.8715,  ...,   9.4993,   3.7676, -14.2954],
          [  7.9356,   6.6429,  -0.8715,  ...,   9.4993,   3.7676, -14.2954]],

         [[ -1.6906,   3.0952,   0.3371,  ...,   2.6363,   2.9857,  -3.9183],
          [ -1.7522,   2.2027,   1.3113,  ...,   1.4247,   4.2515,  -3.8434],
          [ -2.2816,   0.0287,   1.3675,  ...,  -3.2672,  -1.4312,  -0.1154],
          ...,
          [  1.3658,  -5.8454,  -2.8784,  ...,  -1.6338,  -1.2976,  -4.5671],
          [  1.3658,  -5.8454,  -2.8784,  ...,  -1.6338,  -1.2976,  -4.5671],
          [  1.3658,  -5.8454,  -2.8784,  ...,  -1.6338,  -1.2976,  -4.5671]]]],
       grad_fn=<TransposeBackward0>)), (tensor([[[[ 2.1097e-03,  1.5514e+00, -9.0948e-01,  ...,  2.5326e-01,
            8.7332e-01,  1.9186e+00],
          [-1.2408e+00,  1.1000e+00, -2.1731e+00,  ..., -5.2833e-01,
            1.7485e+00,  3.0796e+00],
          [-9.8369e-01,  1.6516e+00, -2.6705e+00,  ..., -5.9531e-01,
            1.0600e+00,  2.5151e+00],
          ...,
          [-8.4195e-01,  1.7919e+00, -2.4020e+00,  ..., -4.2190e-01,
            1.3217e+00,  2.4385e+00],
          [-8.4195e-01,  1.7919e+00, -2.4020e+00,  ..., -4.2190e-01,
            1.3217e+00,  2.4385e+00],
          [-8.3869e-01,  1.8131e+00, -2.4077e+00,  ..., -4.5655e-01,
            1.3187e+00,  2.4530e+00]],

         [[-4.0928e-01, -3.9571e-01,  2.5721e+00,  ..., -2.5026e+00,
            5.4624e+00,  1.9382e+00],
          [ 2.8151e+00, -2.0011e+00,  8.1293e-01,  ..., -1.6058e+00,
            3.9042e+00,  2.0139e+00],
          [-8.5266e-01, -2.5939e+00,  1.3430e+00,  ..., -2.5943e-01,
            1.9476e+00,  7.3382e-01],
          ...,
          [-9.6097e-01, -2.4612e+00,  1.7446e+00,  ..., -7.2258e-01,
            2.5734e+00,  6.8234e-01],
          [-9.6097e-01, -2.4612e+00,  1.7446e+00,  ..., -7.2258e-01,
            2.5734e+00,  6.8234e-01],
          [-9.6144e-01, -2.5063e+00,  1.7226e+00,  ..., -7.2374e-01,
            2.5278e+00,  6.6021e-01]],

         [[ 4.0476e+00, -1.9082e-01, -1.6235e-01,  ..., -7.0813e-02,
            4.0002e-02,  1.6509e+00],
          [ 2.6256e+00,  8.2695e-01, -3.4207e-01,  ...,  9.3707e-01,
            1.1479e+00,  7.7939e-01],
          [ 2.4563e+00,  1.0075e+00, -1.8322e+00,  ..., -1.9840e-02,
           -1.3038e-01,  2.9628e-01],
          ...,
          [ 2.7454e+00,  1.1603e+00, -1.6880e+00,  ..., -3.8503e-01,
           -3.5235e-01,  3.9962e-01],
          [ 2.7454e+00,  1.1603e+00, -1.6880e+00,  ..., -3.8503e-01,
           -3.5235e-01,  3.9962e-01],
          [ 2.7343e+00,  1.1764e+00, -1.7226e+00,  ..., -3.7677e-01,
           -3.3716e-01,  3.7063e-01]],

         ...,

         [[ 3.0009e+00, -4.0712e-01,  4.3189e-02,  ..., -1.0643e+00,
           -3.4228e-01, -3.1813e-01],
          [ 2.1058e+00,  1.8221e+00, -3.1255e-01,  ...,  1.1693e+00,
           -1.7947e+00, -3.1066e+00],
          [ 2.2702e+00,  8.2103e-01, -1.5118e+00,  ...,  8.9663e-01,
           -5.8417e-01, -3.1723e+00],
          ...,
          [ 2.3342e+00,  6.5289e-01, -1.3117e+00,  ...,  9.3031e-01,
           -4.5603e-01, -3.0863e+00],
          [ 2.3342e+00,  6.5289e-01, -1.3117e+00,  ...,  9.3031e-01,
           -4.5603e-01, -3.0863e+00],
          [ 2.3480e+00,  6.6640e-01, -1.3482e+00,  ...,  9.5862e-01,
           -4.7385e-01, -3.0841e+00]],

         [[-1.0624e-01, -1.5962e-01,  1.9455e+00,  ...,  3.3270e-01,
           -2.4193e+00, -1.4272e+00],
          [ 3.5535e-01, -1.1451e+00,  1.1754e+00,  ...,  5.6430e-01,
           -1.4728e+00, -9.0555e-01],
          [ 6.3479e-02,  5.9669e-01,  1.4688e+00,  ...,  9.9553e-02,
           -9.4915e-02, -5.2103e-01],
          ...,
          [-1.6270e-01,  7.4088e-01,  1.7250e+00,  ...,  2.8688e-02,
           -5.4150e-01, -5.0623e-01],
          [-1.6270e-01,  7.4088e-01,  1.7250e+00,  ...,  2.8688e-02,
           -5.4150e-01, -5.0623e-01],
          [-1.3725e-01,  7.6402e-01,  1.7128e+00,  ...,  1.8485e-02,
           -5.1554e-01, -4.9807e-01]],

         [[-2.7507e+00,  2.3738e+00,  3.6078e+00,  ..., -2.9289e+00,
           -7.5693e-02,  2.2153e+00],
          [-2.2753e+00,  1.5102e+00,  1.7518e+00,  ..., -1.5429e+00,
           -3.2317e-01,  4.3705e+00],
          [-9.0793e-01,  8.3936e-01,  2.4499e-01,  ..., -1.0836e+00,
           -6.8581e-01,  2.7190e+00],
          ...,
          [-9.2310e-01,  8.2699e-01,  8.4103e-01,  ..., -1.3068e+00,
           -9.3001e-01,  2.9278e+00],
          [-9.2310e-01,  8.2699e-01,  8.4103e-01,  ..., -1.3068e+00,
           -9.3001e-01,  2.9278e+00],
          [-8.8181e-01,  8.3411e-01,  8.3968e-01,  ..., -1.3068e+00,
           -8.9643e-01,  2.9523e+00]]]], grad_fn=<TransposeBackward0>), tensor([[[[-6.4551,  2.9794,  2.1137,  ..., -1.7281,  1.0012, -0.1483],
          [-6.2747, -0.4749,  0.3873,  ..., -1.6742,  1.5838,  0.6671],
          [-2.6717,  0.1030, -0.1349,  ..., -1.6918,  1.9066,  0.0555],
          ...,
          [-2.8501, -0.0786, -0.5107,  ..., -1.7496,  2.3979, -0.1053],
          [-2.8501, -0.0786, -0.5107,  ..., -1.7496,  2.3979, -0.1053],
          [-2.7676, -0.1523, -0.5276,  ..., -1.8103,  2.3841, -0.1234]],

         [[-0.2188,  3.4449,  1.6569,  ...,  0.2848,  0.2934,  3.8558],
          [-1.9766,  3.1298,  1.4610,  ...,  2.5610,  1.2279,  0.1298],
          [-0.0566,  0.8476,  0.9728,  ...,  0.3719,  0.2699,  0.2808],
          ...,
          [ 0.3286,  1.2222,  1.0062,  ..., -0.4439,  0.3677, -0.1067],
          [ 0.3286,  1.2222,  1.0062,  ..., -0.4439,  0.3677, -0.1067],
          [ 0.3148,  1.2142,  1.0343,  ..., -0.4184,  0.3681, -0.1375]],

         [[-1.5560, -3.2860, -0.8093,  ..., -2.3496,  0.4661, -1.9984],
          [-0.9852, -1.7860, -2.4885,  ..., -2.0438,  0.1729, -0.9261],
          [ 0.5226, -1.2155, -0.4930,  ..., -0.9604, -0.8283, -0.5194],
          ...,
          [ 0.3134, -0.6465, -0.4053,  ..., -0.8836, -0.7167, -0.3569],
          [ 0.3134, -0.6465, -0.4053,  ..., -0.8836, -0.7167, -0.3569],
          [ 0.3712, -0.6864, -0.4211,  ..., -0.8621, -0.7425, -0.3771]],

         ...,

         [[ 4.3734,  2.4901,  0.6447,  ..., -2.8976,  3.0179,  1.1680],
          [-0.1177,  2.5831, -2.9450,  ..., -0.5310,  4.4930,  0.0666],
          [ 0.6190,  1.7576,  0.4303,  ..., -0.4257,  2.1298,  1.9180],
          ...,
          [ 1.0634,  1.7538,  0.4630,  ..., -0.7421,  2.0346,  2.2906],
          [ 1.0634,  1.7538,  0.4630,  ..., -0.7421,  2.0346,  2.2906],
          [ 1.0607,  1.7287,  0.5094,  ..., -0.8194,  2.0420,  2.2556]],

         [[ 0.7435, -1.3226,  2.2849,  ..., -5.8612, -1.2330, -2.7826],
          [-0.0814,  1.1790,  1.0418,  ..., -2.9954,  1.8294, -3.8599],
          [ 0.2119,  0.3491,  0.4706,  ..., -1.8963,  1.1721, -1.6747],
          ...,
          [ 0.7160,  0.1121,  0.7762,  ..., -2.1339,  0.7383, -2.2995],
          [ 0.7160,  0.1121,  0.7762,  ..., -2.1339,  0.7383, -2.2995],
          [ 0.7153,  0.1371,  0.7545,  ..., -2.1251,  0.7836, -2.2613]],

         [[-3.2615, -3.2867,  1.5694,  ..., -1.9111, -0.4147, -1.4978],
          [-4.5383, -0.6074, -0.2040,  ..., -2.2887,  2.6059,  2.3513],
          [-1.7461, -1.2393, -0.8870,  ..., -1.8850,  0.3960,  0.1979],
          ...,
          [-1.7148, -2.1446, -1.2101,  ..., -1.9512,  0.6065,  0.2968],
          [-1.7148, -2.1446, -1.2101,  ..., -1.9512,  0.6065,  0.2968],
          [-1.6941, -2.1637, -1.2466,  ..., -1.9936,  0.5818,  0.3331]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 3.7258,  1.7102,  8.0881,  ..., -3.4739,  5.2415,  9.0920],
          [ 3.6132,  1.8102,  7.9358,  ..., -2.9680,  5.9912,  9.2070],
          [ 2.9974,  0.4965,  4.9965,  ..., -1.3959, -3.8038,  1.7065],
          ...,
          [ 6.4989,  3.2868,  3.9479,  ..., -2.7333, -5.9575,  3.3040],
          [ 6.4989,  3.2868,  3.9479,  ..., -2.7333, -5.9575,  3.3040],
          [ 6.4989,  3.2868,  3.9479,  ..., -2.7333, -5.9575,  3.3040]],

         [[-0.9470, -0.4220,  2.5948,  ...,  0.2992, -4.0967,  0.7998],
          [-1.1535, -0.5347,  2.9862,  ...,  0.4219, -4.2436,  0.8798],
          [-2.7095, -1.6229,  0.6953,  ..., -2.2617, -0.5293, -0.9015],
          ...,
          [-0.3336, -5.9162,  1.7431,  ..., -0.6285, -0.9861,  0.3378],
          [-0.3336, -5.9162,  1.7431,  ..., -0.6285, -0.9861,  0.3378],
          [-0.3336, -5.9162,  1.7431,  ..., -0.6285, -0.9861,  0.3378]],

         [[ 0.3236, -1.6754,  2.4135,  ..., -1.8325, -3.7335,  1.5744],
          [ 0.3745, -1.2684,  2.5911,  ..., -1.8334, -3.7325,  2.0243],
          [-0.8711,  0.5350,  1.0570,  ..., -2.3854, -2.1423,  3.3204],
          ...,
          [ 0.7190, -2.8541,  1.7896,  ..., -5.0469, -2.0697, -0.2030],
          [ 0.7190, -2.8541,  1.7896,  ..., -5.0469, -2.0697, -0.2030],
          [ 0.7190, -2.8541,  1.7896,  ..., -5.0469, -2.0697, -0.2030]],

         ...,

         [[-5.6665,  3.2098, -1.4405,  ..., -3.1927,  0.2826, -0.7847],
          [-5.7331,  3.0253, -0.8402,  ..., -2.9408, -0.2857, -0.4530],
          [-1.7424,  3.6961, -2.5406,  ..., -2.3120, -1.3674, -0.9706],
          ...,
          [-1.1876,  1.2857, -3.1044,  ..., -0.0539,  1.2877,  2.2252],
          [-1.1876,  1.2857, -3.1044,  ..., -0.0539,  1.2877,  2.2252],
          [-1.1876,  1.2857, -3.1044,  ..., -0.0539,  1.2877,  2.2252]],

         [[ 1.1441, -0.0098,  0.8105,  ..., -2.1600,  0.8287, -1.2281],
          [ 1.0581, -0.8423,  0.6405,  ..., -1.8214,  1.0759, -1.2893],
          [ 2.9784,  1.5987, -0.2583,  ..., -1.6524, -1.9879, -0.0596],
          ...,
          [ 2.4903,  1.3806,  1.2685,  ..., -0.2799, -0.9384, -0.9093],
          [ 2.4903,  1.3806,  1.2685,  ..., -0.2799, -0.9384, -0.9093],
          [ 2.4903,  1.3806,  1.2685,  ..., -0.2799, -0.9384, -0.9093]],

         [[-0.3995, -0.4081, -1.5003,  ..., -2.2998, -1.6834,  1.2933],
          [ 0.2302, -0.9300, -2.2677,  ..., -2.6598, -1.8165,  1.2381],
          [ 0.2096,  2.1490, -0.6914,  ..., -1.1440, -0.7731,  1.2069],
          ...,
          [ 2.8502,  0.4137,  3.7305,  ..., -0.1402,  2.5941, -5.0223],
          [ 2.8502,  0.4137,  3.7305,  ..., -0.1402,  2.5941, -5.0223],
          [ 2.8502,  0.4137,  3.7305,  ..., -0.1402,  2.5941, -5.0223]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ -5.7806,  -0.1627,  -1.7749,  ...,  -1.2571,   5.0386, -11.7940],
          [ -4.9589,  -0.4945,  -1.6389,  ...,  -1.9981,   2.7451, -12.9327],
          [ -2.7943,   1.8288,  14.8798,  ...,  -6.4925,   8.2568,  -4.2893],
          ...,
          [-12.8340,   5.8293,   0.3416,  ...,   1.2527,  13.6421,   0.0389],
          [-12.8340,   5.8293,   0.3416,  ...,   1.2527,  13.6421,   0.0389],
          [-12.8340,   5.8293,   0.3416,  ...,   1.2527,  13.6421,   0.0389]],

         [[  8.7175,  10.0063,  -3.1505,  ...,  -2.5578,   5.7157,   3.6232],
          [ 10.5075,   8.9197,  -3.8752,  ...,  -2.6336,   6.2578,   1.2150],
          [ -6.5720,   7.1512,  -0.0733,  ...,   2.9159,   3.0414,  -7.5064],
          ...,
          [-11.7473,   1.5496,  -2.6760,  ...,   4.4461,   1.2940,  -8.3210],
          [-11.7473,   1.5496,  -2.6760,  ...,   4.4461,   1.2940,  -8.3210],
          [-11.7473,   1.5496,  -2.6760,  ...,   4.4461,   1.2940,  -8.3210]],

         [[  1.5810,   1.7656,   5.5673,  ...,  -7.9448,   0.7199,  -4.0970],
          [  1.1964,   1.9764,   5.2619,  ...,  -7.5985,  -1.2487,  -2.7517],
          [ -0.0332,   1.3375,   3.8689,  ...,   1.4732,  -3.9478,   0.1484],
          ...,
          [-18.8993, -11.0211, -10.7089,  ...,   1.3439, -18.1071,   2.4146],
          [-18.8993, -11.0211, -10.7089,  ...,   1.3439, -18.1071,   2.4146],
          [-18.8993, -11.0211, -10.7089,  ...,   1.3439, -18.1071,   2.4146]],

         ...,

         [[ -2.0628,  -2.9849,  -3.8942,  ...,  -4.4961,  -3.1316,   3.5867],
          [ -2.1971,  -1.3944,  -4.4933,  ...,  -4.6964,  -3.9728,   4.4899],
          [ -0.6022,   7.7346,  -1.4433,  ...,  -5.4052,  -1.3670,   1.4640],
          ...,
          [ -0.4124,  -6.9121,   0.6613,  ..., -12.4070,   2.4194,   3.2318],
          [ -0.4124,  -6.9121,   0.6613,  ..., -12.4070,   2.4194,   3.2318],
          [ -0.4124,  -6.9121,   0.6613,  ..., -12.4070,   2.4194,   3.2318]],

         [[ -0.9144,  -1.0570,  -0.2240,  ...,  -9.8352,   0.8528,  -7.1527],
          [  0.1759,  -1.7132,   0.0942,  ...,  -7.9109,   2.0724,  -4.2512],
          [  3.3228,  -1.1993,  -5.1080,  ...,   1.8315,  -2.4904,   1.1763],
          ...,
          [ -3.0930,  -4.2095, -14.0020,  ...,   8.0590,   0.8006,   6.3577],
          [ -3.0930,  -4.2095, -14.0020,  ...,   8.0590,   0.8006,   6.3577],
          [ -3.0930,  -4.2095, -14.0020,  ...,   8.0590,   0.8006,   6.3577]],

         [[ -2.3418,   1.1557,  -0.8803,  ...,   2.4626,  -2.6735,  -7.8135],
          [ -3.1027,   1.1908,  -1.9743,  ...,   2.2457,  -3.1706,  -7.3264],
          [  0.6847,  -4.5201,  -5.1600,  ...,   1.0622,   4.0921,   4.0350],
          ...,
          [  1.3614,  -8.5352,   4.9724,  ...,   5.8550,   3.0930,   7.1385],
          [  1.3614,  -8.5352,   4.9724,  ...,   5.8550,   3.0930,   7.1385],
          [  1.3614,  -8.5352,   4.9724,  ...,   5.8550,   3.0930,   7.1385]]]],
       grad_fn=<TransposeBackward0>)), (tensor([[[[ 0.5680,  1.7588,  4.2604,  ..., -2.3048,  2.0464, -1.5318],
          [-0.1497, -0.5128,  4.7584,  ..., -2.5882,  1.7084, -1.3274],
          [ 1.2499, -0.3630,  4.0982,  ..., -0.9824,  1.2733, -0.6717],
          ...,
          [ 1.3568, -0.1633,  3.8603,  ..., -1.1382,  1.6079, -0.7496],
          [ 1.3568, -0.1633,  3.8603,  ..., -1.1382,  1.6079, -0.7496],
          [ 1.3801, -0.2375,  3.8648,  ..., -1.0940,  1.5717, -0.7194]],

         [[ 2.8220,  0.3850,  1.2129,  ...,  2.3258,  1.8751, -0.8240],
          [ 3.7856,  0.5684,  2.4499,  ...,  1.3860,  0.8256, -1.3745],
          [ 1.6679,  0.9678,  1.7258,  ...,  0.1613,  2.5851, -1.4448],
          ...,
          [ 1.6160,  1.0231,  1.3962,  ...,  0.5761,  2.5576, -1.5130],
          [ 1.6160,  1.0231,  1.3962,  ...,  0.5761,  2.5576, -1.5130],
          [ 1.6433,  0.9863,  1.4060,  ...,  0.5460,  2.5619, -1.5374]],

         [[-1.7226, -1.5418,  0.2377,  ..., -1.8535, -0.9308,  1.1154],
          [ 0.5571, -1.9032,  0.4887,  ...,  0.4398,  0.7694, -0.1393],
          [-0.8773, -1.7808,  1.7459,  ...,  1.6040,  1.0416, -0.4245],
          ...,
          [-1.6689, -1.8515,  1.8435,  ...,  1.2579,  0.8218, -0.4802],
          [-1.6689, -1.8515,  1.8435,  ...,  1.2579,  0.8218, -0.4802],
          [-1.6354, -1.8302,  1.8835,  ...,  1.3295,  0.8276, -0.4972]],

         ...,

         [[ 4.2259, -0.8044,  2.4946,  ..., -1.4358, -1.9874,  3.6556],
          [ 4.1429, -1.2543,  0.9365,  ..., -2.8444, -2.9146,  3.7071],
          [ 3.7662, -1.2588,  0.4270,  ..., -2.1969, -3.0470,  4.6233],
          ...,
          [ 3.8689, -1.2625,  0.9608,  ..., -2.1639, -3.1746,  4.8872],
          [ 3.8689, -1.2625,  0.9608,  ..., -2.1639, -3.1746,  4.8872],
          [ 3.8461, -1.2110,  0.9740,  ..., -2.1314, -3.1520,  4.8353]],

         [[-1.4253,  3.4722,  5.0920,  ..., -0.1175,  2.5718,  1.7778],
          [ 1.2864,  4.0734,  2.5147,  ...,  0.2562,  2.9092,  2.0148],
          [ 0.2458,  2.6041,  1.1676,  ...,  1.1611,  2.5795,  1.1802],
          ...,
          [-0.0923,  2.5805,  2.2403,  ...,  1.1062,  2.0588,  1.4205],
          [-0.0923,  2.5805,  2.2403,  ...,  1.1062,  2.0588,  1.4205],
          [-0.0867,  2.5966,  2.1987,  ...,  1.1414,  2.0445,  1.4262]],

         [[-0.5101,  2.0817, -0.1329,  ..., -0.2882, -3.8950,  0.3965],
          [-0.6309,  1.6105, -0.4447,  ..., -1.9462, -5.2308, -0.2430],
          [ 0.8304,  2.4027, -0.3143,  ..., -1.8136, -2.6307, -0.4855],
          ...,
          [ 0.7552,  2.3396, -0.1358,  ..., -1.7030, -2.9343, -0.3142],
          [ 0.7552,  2.3396, -0.1358,  ..., -1.7030, -2.9343, -0.3142],
          [ 0.7585,  2.3006, -0.1527,  ..., -1.7348, -2.9583, -0.3005]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[ 1.5378e+00,  1.4616e-01,  2.2261e+00,  ..., -2.2830e+00,
            1.7818e+00,  6.8454e+00],
          [ 2.1952e+00, -2.4284e+00,  3.0440e+00,  ..., -2.7263e+00,
            1.5082e+00,  2.0543e+00],
          [ 3.1851e+00, -7.8968e-01,  1.8665e+00,  ..., -1.3539e+00,
            1.3653e+00, -2.1525e-01],
          ...,
          [ 3.3567e+00, -4.4127e-01,  1.6282e+00,  ..., -2.1696e+00,
            1.0740e+00, -1.0190e-01],
          [ 3.3567e+00, -4.4127e-01,  1.6282e+00,  ..., -2.1696e+00,
            1.0740e+00, -1.0190e-01],
          [ 3.3725e+00, -3.7312e-01,  1.6645e+00,  ..., -2.1906e+00,
            1.0265e+00, -2.2337e-01]],

         [[ 4.5752e+00, -4.6010e-01,  5.2903e+00,  ..., -4.2559e-01,
           -3.9103e-03, -4.2297e+00],
          [-4.5674e-01, -3.0143e+00,  7.1625e-01,  ...,  4.1374e+00,
            6.4247e-01,  2.4313e-01],
          [-1.7092e-01, -6.5598e-02,  3.1484e+00,  ...,  3.3721e+00,
           -1.2099e+00,  2.4574e+00],
          ...,
          [ 5.8281e-01, -6.2851e-01,  3.4611e+00,  ...,  2.3035e+00,
           -1.5808e+00,  2.0792e+00],
          [ 5.8281e-01, -6.2851e-01,  3.4611e+00,  ...,  2.3035e+00,
           -1.5808e+00,  2.0792e+00],
          [ 5.7394e-01, -5.8661e-01,  3.4242e+00,  ...,  2.3638e+00,
           -1.5770e+00,  2.1809e+00]],

         [[ 2.5498e+00,  3.2209e+00,  5.4704e-01,  ...,  3.2519e+00,
            1.7319e+00,  1.8689e-01],
          [-1.6067e+00,  6.2433e+00, -6.1812e-01,  ...,  1.8015e+00,
            3.8599e-01, -2.0579e+00],
          [ 4.4920e-01,  3.7032e+00,  1.2665e-01,  ...,  4.4935e-01,
            4.5971e-01, -2.8084e+00],
          ...,
          [ 6.6018e-01,  3.8735e+00,  5.2279e-01,  ...,  8.5738e-01,
            7.7842e-01, -3.4486e+00],
          [ 6.6018e-01,  3.8735e+00,  5.2279e-01,  ...,  8.5738e-01,
            7.7842e-01, -3.4486e+00],
          [ 6.4868e-01,  3.8441e+00,  5.6372e-01,  ...,  7.9704e-01,
            7.2606e-01, -3.5107e+00]],

         ...,

         [[ 3.8789e+00, -6.6985e-01, -2.1669e-01,  ...,  9.6079e+00,
           -1.3102e+00, -2.3750e+00],
          [ 3.1827e+00,  6.5805e-01,  9.0065e-01,  ...,  4.6071e+00,
            1.9246e+00, -4.6779e+00],
          [ 1.0106e+00, -5.6079e-01,  1.6322e+00,  ...,  1.4999e+00,
            8.1410e-01, -4.4058e+00],
          ...,
          [ 5.1026e-01, -1.0263e+00,  1.8090e+00,  ...,  2.2476e+00,
           -2.6197e-01, -4.4987e+00],
          [ 5.1026e-01, -1.0263e+00,  1.8090e+00,  ...,  2.2476e+00,
           -2.6197e-01, -4.4987e+00],
          [ 5.1464e-01, -1.0095e+00,  1.8155e+00,  ...,  2.1907e+00,
           -3.2396e-01, -4.4627e+00]],

         [[-7.6957e+00, -2.0512e-01, -1.4184e+00,  ..., -3.6351e-01,
            2.0693e+00, -7.3596e-01],
          [-5.1083e+00,  1.4366e+00, -1.8867e+00,  ...,  6.1701e+00,
            2.4058e-03, -2.0730e-01],
          [-3.2417e+00,  1.2330e+00, -2.6343e+00,  ...,  4.5528e+00,
           -1.8813e-01,  2.7919e-01],
          ...,
          [-3.1742e+00,  1.2361e+00, -2.9207e+00,  ...,  4.5065e+00,
           -1.9971e-01,  5.1452e-01],
          [-3.1742e+00,  1.2361e+00, -2.9207e+00,  ...,  4.5065e+00,
           -1.9971e-01,  5.1452e-01],
          [-3.1227e+00,  1.3183e+00, -2.9111e+00,  ...,  4.5855e+00,
           -1.7975e-01,  6.0418e-01]],

         [[ 3.2272e+00,  3.1114e+00,  7.0115e+00,  ...,  2.1251e+00,
           -5.7529e-01, -2.3629e+00],
          [ 1.0599e+00,  4.4187e+00,  4.5766e+00,  ...,  1.1459e-01,
            1.8373e+00, -3.8947e+00],
          [ 1.1973e+00,  2.5019e+00,  1.9262e+00,  ...,  9.6098e-01,
            1.2418e+00, -1.9367e+00],
          ...,
          [ 5.4555e-01,  3.1166e+00,  2.4686e+00,  ...,  2.4721e+00,
            8.6789e-01, -2.2258e+00],
          [ 5.4555e-01,  3.1166e+00,  2.4686e+00,  ...,  2.4721e+00,
            8.6789e-01, -2.2258e+00],
          [ 5.6675e-01,  3.1082e+00,  2.3089e+00,  ...,  2.4676e+00,
            8.2031e-01, -2.2647e+00]]]], grad_fn=<TransposeBackward0>), tensor([[[[-2.0295, -4.4600,  0.8229,  ..., -0.4880, -3.8051, -3.9810],
          [-2.2741, -4.1019,  0.7087,  ..., -0.5637, -3.7706, -3.6842],
          [-3.8179, -0.9542, -2.6136,  ..., -1.0998, -0.0576,  0.5739],
          ...,
          [ 1.0410, -2.2980, -2.0412,  ...,  2.6457, -2.6785,  1.4212],
          [ 1.0410, -2.2980, -2.0412,  ...,  2.6457, -2.6785,  1.4212],
          [ 1.0410, -2.2980, -2.0412,  ...,  2.6457, -2.6785,  1.4212]],

         [[-1.5981,  2.0811, -4.6121,  ...,  1.2390,  0.0562, -0.3107],
          [-2.0040,  1.9027, -4.7588,  ...,  1.5203,  0.0767, -0.1306],
          [-1.8758,  0.0525, -2.8941,  ..., -0.4330,  0.5746, -0.7455],
          ...,
          [-2.9105,  0.9460, -2.8834,  ..., -0.3359,  1.2561, -1.7515],
          [-2.9105,  0.9460, -2.8834,  ..., -0.3359,  1.2561, -1.7515],
          [-2.9105,  0.9460, -2.8834,  ..., -0.3359,  1.2561, -1.7515]],

         [[ 0.5575, -2.1440,  2.6944,  ...,  0.1324,  2.5830, -1.5628],
          [ 0.6957, -2.3722,  2.4132,  ...,  0.4570,  2.4328, -1.9267],
          [-0.0482, -2.4959,  2.6646,  ...,  0.2500, -0.4130, -2.3349],
          ...,
          [ 1.6512, -1.5231,  1.1724,  ..., -2.1762,  0.0058, -3.3955],
          [ 1.6512, -1.5231,  1.1724,  ..., -2.1762,  0.0058, -3.3955],
          [ 1.6512, -1.5231,  1.1724,  ..., -2.1762,  0.0058, -3.3955]],

         ...,

         [[-3.1132,  3.9486, -1.0419,  ..., -4.6946,  1.4022, -0.4498],
          [-3.1964,  4.1503, -1.3335,  ..., -4.4367,  1.5417, -0.3003],
          [-1.1541,  0.1357,  0.4987,  ..., -2.9714,  0.7071, -0.5805],
          ...,
          [-0.4549, -2.8995,  0.7015,  ..., -3.5934,  1.6380,  1.7764],
          [-0.4549, -2.8995,  0.7015,  ..., -3.5934,  1.6380,  1.7764],
          [-0.4549, -2.8995,  0.7015,  ..., -3.5934,  1.6380,  1.7764]],

         [[ 0.2291, -1.8184, -1.6833,  ..., -1.6161,  0.5520, -2.5566],
          [-0.0843, -1.8690, -1.3323,  ..., -1.5385,  0.9474, -3.1426],
          [ 1.1130,  1.8234,  0.8395,  ..., -0.0173, -1.1970, -3.1985],
          ...,
          [-1.7407,  0.5094,  0.8873,  ..., -1.2882,  1.1380, -2.0544],
          [-1.7407,  0.5094,  0.8873,  ..., -1.2882,  1.1380, -2.0544],
          [-1.7407,  0.5094,  0.8873,  ..., -1.2882,  1.1380, -2.0544]],

         [[ 1.4947,  0.1461,  1.2785,  ...,  1.1898, -0.0235,  0.7927],
          [ 1.1003,  0.3135,  1.5369,  ...,  0.7608, -0.5206,  0.3612],
          [ 0.8703,  2.4213,  0.2664,  ..., -0.6946,  1.8925,  1.8720],
          ...,
          [ 0.9266, -2.5636, -1.0498,  ..., -1.0591,  0.3316, -0.7819],
          [ 0.9266, -2.5636, -1.0498,  ..., -1.0591,  0.3316, -0.7819],
          [ 0.9266, -2.5636, -1.0498,  ..., -1.0591,  0.3316, -0.7819]]]],
       grad_fn=<TransposeBackward0>), tensor([[[[-2.0419e+00, -3.4137e+00,  8.9518e+00,  ..., -5.6484e+00,
           -8.7039e+00, -5.2318e+00],
          [-1.6504e+00, -3.7605e+00,  7.9005e+00,  ..., -5.5971e+00,
           -7.9206e+00, -5.6725e+00],
          [ 2.9998e+00, -8.5156e+00,  3.2183e+00,  ..., -3.8826e+00,
           -3.0613e+00, -5.7757e+00],
          ...,
          [ 1.4919e+01, -1.0306e+01,  1.0666e+01,  ...,  6.8907e+00,
           -5.1657e+00, -3.6152e+00],
          [ 1.4919e+01, -1.0306e+01,  1.0666e+01,  ...,  6.8907e+00,
           -5.1657e+00, -3.6152e+00],
          [ 1.4919e+01, -1.0306e+01,  1.0666e+01,  ...,  6.8907e+00,
           -5.1657e+00, -3.6152e+00]],

         [[-1.0402e-01,  1.1980e+00,  1.5691e-01,  ...,  2.9742e+00,
           -1.4255e-01, -2.9674e+00],
          [ 1.4508e+00,  9.1208e-01, -1.4536e-01,  ...,  1.2346e+00,
            8.2411e-01, -2.1496e+00],
          [ 7.5831e+00, -2.7479e+00, -7.8653e-01,  ..., -8.8103e+00,
           -3.9566e+00, -1.7331e+00],
          ...,
          [ 1.3588e+01,  4.9676e+00, -3.0800e+00,  ...,  4.9436e+00,
            1.6062e+01,  1.5308e+01],
          [ 1.3588e+01,  4.9676e+00, -3.0800e+00,  ...,  4.9436e+00,
            1.6062e+01,  1.5308e+01],
          [ 1.3588e+01,  4.9676e+00, -3.0800e+00,  ...,  4.9436e+00,
            1.6062e+01,  1.5308e+01]],

         [[ 9.6242e+00, -6.0907e+00, -8.5869e+00,  ...,  3.4619e+00,
           -3.0602e+00, -4.3327e+00],
          [ 9.4869e+00, -4.2836e+00, -6.2826e+00,  ...,  2.3134e+00,
           -1.9831e+00, -5.1728e+00],
          [-3.4782e+00,  3.0170e+00,  1.3237e+01,  ...,  8.4086e+00,
            6.2151e+00, -2.6405e+00],
          ...,
          [ 1.4175e+01, -7.1945e+00,  1.3463e+01,  ...,  2.2113e+01,
           -1.3405e+01,  1.2378e+01],
          [ 1.4175e+01, -7.1945e+00,  1.3463e+01,  ...,  2.2113e+01,
           -1.3405e+01,  1.2378e+01],
          [ 1.4175e+01, -7.1945e+00,  1.3463e+01,  ...,  2.2113e+01,
           -1.3405e+01,  1.2378e+01]],

         ...,

         [[-1.6623e+01, -2.0619e+00, -3.8565e+00,  ...,  9.6094e+00,
           -1.3770e+00,  1.3978e+01],
          [-1.4983e+01, -2.6902e+00, -4.5743e+00,  ...,  8.8884e+00,
           -4.0477e+00,  1.4615e+01],
          [ 2.9045e+00,  1.9545e+00,  7.3804e+00,  ...,  1.2686e+01,
           -7.0304e+00, -1.2733e+01],
          ...,
          [ 8.7789e+00,  4.9394e+00,  1.5792e+01,  ...,  1.4738e+01,
           -2.0569e+00,  5.9586e+00],
          [ 8.7789e+00,  4.9394e+00,  1.5792e+01,  ...,  1.4738e+01,
           -2.0569e+00,  5.9586e+00],
          [ 8.7789e+00,  4.9394e+00,  1.5792e+01,  ...,  1.4738e+01,
           -2.0569e+00,  5.9586e+00]],

         [[ 5.2026e+00, -1.7273e+00, -5.9830e+00,  ..., -5.8807e+00,
           -7.4595e+00, -1.5658e+00],
          [ 4.8016e+00, -1.4822e+00, -6.8809e+00,  ..., -2.2992e+00,
           -5.2676e+00, -1.6409e+00],
          [-2.4262e+00,  8.6801e-01, -7.9895e+00,  ..., -2.0361e+00,
            1.1964e+01,  1.0700e+01],
          ...,
          [-7.9648e+00,  3.8391e-01,  1.8819e+01,  ...,  3.6906e+00,
            2.1696e+00,  9.7356e+00],
          [-7.9648e+00,  3.8391e-01,  1.8819e+01,  ...,  3.6906e+00,
            2.1696e+00,  9.7356e+00],
          [-7.9648e+00,  3.8391e-01,  1.8819e+01,  ...,  3.6906e+00,
            2.1696e+00,  9.7356e+00]],

         [[ 1.0809e+00, -3.2577e+00,  1.6504e+00,  ...,  8.3112e+00,
           -1.2314e+01,  8.7786e+00],
          [ 3.3897e+00, -2.9281e+00,  1.1509e+00,  ...,  8.7709e+00,
           -1.0923e+01,  8.6924e+00],
          [ 1.1743e+01, -1.4745e+00,  7.6639e+00,  ...,  7.1697e-01,
           -1.5021e-02,  9.8263e+00],
          ...,
          [ 8.4227e+00, -5.6090e+00, -7.6515e+00,  ..., -1.9268e+01,
            1.0934e+01,  3.7467e+00],
          [ 8.4227e+00, -5.6090e+00, -7.6515e+00,  ..., -1.9268e+01,
            1.0934e+01,  3.7467e+00],
          [ 8.4227e+00, -5.6090e+00, -7.6515e+00,  ..., -1.9268e+01,
            1.0934e+01,  3.7467e+00]]]], grad_fn=<TransposeBackward0>))), decoder_hidden_states=None, decoder_attentions=None, cross_attentions=None, encoder_last_hidden_state=tensor([[[ 9.3725e-02, -1.4927e-01,  2.5169e-01,  ..., -2.8456e-01,
          -6.9041e-02, -1.0899e-01],
         [ 1.6471e-01, -1.6734e-01,  1.7760e-01,  ..., -2.3902e-01,
          -6.5672e-02, -1.0356e-01],
         [-7.8680e-02, -1.9931e-01,  2.5028e-01,  ..., -4.0804e-03,
           1.5391e-01,  1.7626e-02],
         ...,
         [-2.3925e-04,  7.8143e-02,  2.5111e-01,  ..., -8.6810e-02,
           6.8724e-02,  4.7175e-02],
         [-2.3925e-04,  7.8143e-02,  2.5111e-01,  ..., -8.6810e-02,
           6.8724e-02,  4.7175e-02],
         [-2.3925e-04,  7.8143e-02,  2.5111e-01,  ..., -8.6810e-02,
           6.8724e-02,  4.7175e-02]]], grad_fn=<MulBackward0>), encoder_hidden_states=None, encoder_attentions=None)
:Dictionary inputs to traced functions must have consistent type. Found Tensor and Tuple[Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor, Tensor, Tensor]]

In [66]:
trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=T5DataCollator(),
)

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.


In [67]:
if training_args.do_train:
    trainer.train(
        model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
    )
    trainer.save_model()
    # For convenience, we also re-save the tokenizer to the same directory,
    # so that you can share your model easily on huggingface.co/models =)
    if trainer.is_world_master():
        tokenizer.save_pretrained(training_args.output_dir)

Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.
Using deprecated `--per_gpu_train_batch_size` argument which will be removed in a future version. Using `--per_device_train_batch_size` is preferred.
***** Running training *****
  Num examples = 1990
  Num Epochs = 4
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 32
  Gradient Accumulation steps = 4
  Total optimization steps = 248
  Number of trainable parameters = 222903552


  0%|          | 0/248 [00:00<?, ?it/s]

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: target_attention_mask, target_ids. If target_attention_mask, target_ids are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.


ValueError: You have to specify either decoder_input_ids or decoder_inputs_embeds