In [1]:
import pandas as pd
import pickle

from tqdm import tqdm

tqdm.pandas()

import torch
from torch.utils.data import Dataset


import torch, torch.nn as nn
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor
from peft import get_peft_model, LoraConfig, TaskType


from torch.nn.utils.rnn import pad_sequence
from transformers import BatchEncoding

In [2]:
# Raw per-user purchase histories stored as a parquet dataset.
path_to_dataset = "/mnt/core-llm/unisrec/data/raw/user_history_users_8_000_000_nmid_1_000_000_raw_2024_05_30"
train_df = pd.read_parquet(path=path_to_dataset)

In [3]:
# Fixed seed so the split is reproducible.
random_seed = 42

# NOTE(review): this cell reassigns `train_df`, so running it twice shrinks the
# training set again -- re-run the loading cell first.

# Validation: a random 10% of all rows, removed from the pool.
val_df = train_df.sample(frac=0.1, random_state=random_seed)
train_df = train_df.drop(index=val_df.index)

# Test: a random 10% of the remaining rows, also removed from the pool.
test_df = train_df.sample(frac=0.1, random_state=random_seed)
train_df = train_df.drop(index=test_df.index)

print(f"Train size: {len(train_df)}")
print(f"Validation size: {len(val_df)}")
print(f"Test size: {len(test_df)}")


Train size: 6458215
Validation size: 797310
Test size: 717580


In [4]:
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
with open("/mnt/core-llm/unisrec/data/downstream_text_cv_50M_users_1kk_nm_ids/cache/filtered_item2index.pkl", "rb") as f:
    filtered_item2index = pickle.load(f)  # raw nm_id -> model item index

# Inverse mapping: model item index -> raw nm_id (item).
filtered_index2item = dict(zip(filtered_item2index.values(), filtered_item2index.keys()))


In [5]:
def prepare_train_examples(df, filtered_item2index):
    """Expand per-user purchase histories into next-item prediction examples.

    For every user row, purchases whose ``nm_id`` is not in
    ``filtered_item2index`` are dropped together with their title (so the text
    history and the id history stay aligned), the remaining ``nm_id`` values
    are mapped to model indices, and one example is emitted per prefix:
    history ``[0, i)`` -> target ``i``. Users with fewer than two known
    purchases produce no examples.

    Args:
        df: DataFrame with ``user_id``, ``nm_ids`` and ``titles`` columns.
            ``titles`` is assumed to be position-aligned with ``nm_ids``
            (one title per purchase) -- TODO confirm against the dataset.
            The frame is NOT modified, so the calling cell can be re-run safely.
        filtered_item2index: mapping raw ``nm_id`` -> model item index.

    Returns:
        DataFrame with columns ``text`` (str), ``ids`` (list of item indices,
        the history) and ``target`` (next item index). Always has these three
        columns, even when no example is produced.
    """

    def filter_history(nm_ids, titles):
        # Keep (index, title) pairs for known items only.
        kept_ids, kept_titles = [], []
        for nm_id, title in zip(nm_ids, titles):
            if nm_id in filtered_item2index:
                kept_ids.append(filtered_item2index[nm_id])
                kept_titles.append(title)
        return kept_ids, kept_titles

    examples = []
    rows = zip(df["user_id"], df["nm_ids"], df["titles"])
    for user_id, nm_ids, titles in tqdm(rows, total=len(df), desc="Building examples"):
        ids, item_titles = filter_history(nm_ids, titles)

        # range(1, len(ids)) is empty for < 2 interactions, which filters them out.
        for i in range(1, len(ids)):
            examples.append({
                "text": f"user {user_id}, history purchase: " + ", ".join(item_titles[:i]),
                "ids": ids[:i],
                "target": ids[i],
            })

    return pd.DataFrame(examples, columns=["text", "ids", "target"])


In [6]:
class RecommendationDataset(Dataset):
    """Map-style dataset over the examples frame built by ``prepare_train_examples``.

    Each item is a ``(conversation, ids, target)`` triple:
      * ``conversation`` -- chat-style list with a fixed system turn and a user
        turn carrying the example's ``text`` column;
      * ``ids`` -- the example's ``ids`` column (item-index history);
      * ``target`` -- the example's ``target`` column (next item index).
    """

    SYSTEM_PROMPT = (
        "You are a recommendation system assistant designed to provide personalized product suggestions "
        "based on user purchase history and preferences. You can process both textual information and "
        "item IDs representing products the user interacted with. Respond clearly and helpfully."
    )

    def __init__(self, df):
        self.df = df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        # Fresh dicts on every call so callers can mutate the result safely.
        system_turn = {
            "role": "system",
            "content": [{"type": "text", "text": self.SYSTEM_PROMPT}],
        }
        user_turn = {
            "role": "user",
            "content": [{"type": "text", "text": record['text']}],
        }
        return [system_turn, user_turn], record['ids'], record['target']


In [7]:
# Expand every user's history into (history prefix -> next item) examples,
# first for train, then for validation.
train_examples_df, val_examples_df = (
    prepare_train_examples(split_df, filtered_item2index) for split_df in (train_df, val_df)
)

# Wrap both example frames as torch Datasets.
train_ds, val_ds = (
    RecommendationDataset(examples) for examples in (train_examples_df, val_examples_df)
)

100%|██████████| 6458215/6458215 [00:33<00:00, 195261.05it/s]
100%|██████████| 6458215/6458215 [00:03<00:00, 1921067.68it/s]
Building examples: 100%|██████████| 5246979/5246979 [02:29<00:00, 35098.25it/s]
100%|██████████| 797310/797310 [00:04<00:00, 186260.60it/s]
100%|██████████| 797310/797310 [00:00<00:00, 1801930.04it/s]
Building examples: 100%|██████████| 647992/647992 [00:24<00:00, 26141.25it/s]


In [8]:
from transformers import Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration

# Smoke test: use the Thinker sub-model (it implements forward()) to check the
# shape of the last hidden state for a plain-text prompt.
# NOTE(review): this 7B copy stays alive in `model` while the setup cell builds a
# second one inside QwenOmniWithID; free it first if GPU memory is tight.
model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-Omni-7B", torch_dtype=torch.float16, device_map="auto"
)
processor = Qwen2_5OmniProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")

text = "user 232313, history purchase: белые джинсы, розовая помада"
inputs = processor.tokenizer(text, return_tensors="pt").to(model.device)

# Inference only: from_pretrained leaves requires_grad=True on the weights, so
# without no_grad the forward records an autograd graph that `outputs` keeps alive.
with torch.no_grad():
    outputs = model(**inputs, output_hidden_states=True, return_dict=True)
last_hidden_state = outputs.hidden_states[-1]
print(last_hidden_state.shape)


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

torch.Size([1, 25, 3584])


In [10]:
import torch
import torch.nn as nn
from transformers import Qwen2_5OmniForConditionalGeneration, Qwen2_5OmniProcessor, Qwen2_5OmniThinkerForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType
from torch.nn.utils.rnn import pad_sequence
 

class QwenTextEncoder(nn.Module):
    """Frozen Qwen2.5-Omni Thinker used as a sentence encoder.

    Texts are tokenized and run through the frozen Thinker. The last hidden layer
    is mean-pooled over the non-padding tokens and projected to ``reduced_dim``
    by a trainable linear layer (the only trainable part of this module).

    Args:
        ckpt: Hugging Face checkpoint id or local path.
        reduced_dim: size of the output embedding.
        device: device the backbone is placed on and inputs are moved to.
        max_length: maximum number of tokens per text; longer texts are truncated.
        truncation_side: which end of an over-long text is cut. Defaults to
            ``"left"`` so the end of the text is kept -- for the
            "history purchase: ..." prompts, that is the purchases closest to
            the prediction target.
    """

    def __init__(self, ckpt="Qwen/Qwen2.5-Omni-7B", reduced_dim=1024, device='cpu',
                 max_length=512, truncation_side="left"):
        super().__init__()
        self.device = device
        self.model = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
            ckpt, torch_dtype=torch.float16, device_map={"" : str(device)}
        )
        self.processor = Qwen2_5OmniProcessor.from_pretrained(ckpt)
        self.processor.tokenizer.truncation_side = truncation_side
        self.max_length = max_length
        self.hidden_size = self.model.config.text_config.hidden_size
        self.reduced_dim = reduced_dim

        # Deliberately float32: this layer is trained with AdamW, and fp16 weights
        # make AdamW unstable (eps=1e-8 rounds to 0 in fp16, small updates vanish).
        self.pool_proj = nn.Linear(self.hidden_size, self.reduced_dim)

        # Freeze the backbone; only pool_proj receives gradients.
        for param in self.model.parameters():
            param.requires_grad = False

    def forward(self, text):
        """Encode a str / list[str] into a (B, reduced_dim) tensor in pool_proj's dtype."""
        inputs = self.processor.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        ).to(self.device)

        outputs = self.model(**inputs, output_hidden_states=True, return_dict=True)
        last_hidden_state = outputs.hidden_states[-1].float()  # (B, T, hidden_size)

        # Masked mean pooling: padding positions must not dilute the average,
        # otherwise short texts in a padded batch get skewed embeddings.
        mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden_state.dtype)  # (B, T, 1)
        pooled = (last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)

        return self.pool_proj(pooled.to(self.pool_proj.weight.dtype))


class FusionHead(nn.Module):
    """Two-layer MLP (Linear -> ReLU -> Linear) mapping fused features to logits."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Sub-module names are part of the state_dict keys; keep them stable.
        self.dense1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.dense2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        hidden = self.relu(self.dense1(x))
        return self.dense2(hidden)

class QwenOmniWithID(nn.Module):
    """Next-item predictor that fuses a Qwen text embedding with an item-ID embedding.

    The text branch is ``QwenTextEncoder``. The ID branch is a masked mean of
    ``nn.Embedding`` vectors over the user's item-index history, projected to the
    same width. Both are concatenated and scored by ``FusionHead`` against every
    item index.
    """

    def __init__(self, ckpt, id_vocab_size=1_000_000, id_dim=512, fusion_dim=1024, reduced_dim=1024, device='cpu'):
        super().__init__()
        self.text_encoder = QwenTextEncoder(ckpt, reduced_dim, device)

        # Item indices are assumed to lie in [0, id_vocab_size) -- TODO confirm
        # against filtered_item2index.
        self.id_emb = nn.Embedding(id_vocab_size, id_dim)
        self.id_proj = nn.Linear(id_dim, self.text_encoder.reduced_dim)

        # Fusion head that outputs one logit per item index (next-item prediction).
        self.fusion_head = FusionHead(
            input_dim=self.text_encoder.reduced_dim * 2,
            hidden_dim=fusion_dim,
            output_dim=id_vocab_size,
        )

    def _pool_ids(self, id_ids, device):
        """Mean item embedding per history, ignoring padding positions -> (B, id_dim)."""
        seqs = [torch.tensor(ids, dtype=torch.long) for ids in id_ids]
        # padding_value=0 is also a real item index, hence the explicit mask below.
        ids_tensor = pad_sequence(seqs, batch_first=True, padding_value=0).to(device)  # (B, T)
        lengths = torch.tensor([len(s) for s in seqs], device=device)
        positions = torch.arange(ids_tensor.size(1), device=device)
        mask = (positions[None, :] < lengths[:, None]).unsqueeze(-1).to(self.id_emb.weight.dtype)  # (B, T, 1)
        summed = (self.id_emb(ids_tensor) * mask).sum(dim=1)
        return summed / mask.sum(dim=1).clamp(min=1.0)

    def forward(self, text=None, id_ids=None, labels=None):
        """Return ``{"loss", "logits"}``; ``loss`` is None when ``labels`` is None.

        Args:
            text: list[str] of history texts (one per example).
            id_ids: optional list of item-index sequences (one per example).
            labels: optional LongTensor of target item indices, shape (B,).
        """
        # Bring the text branch to the dtype of the trainable ID/fusion branch
        # (the text encoder may return fp16).
        dense_dtype = self.id_emb.weight.dtype
        text_emb = self.text_encoder(text).to(dense_dtype)  # (B, reduced_dim)

        if id_ids is not None:
            id_vec = self.id_proj(self._pool_ids(id_ids, text_emb.device))  # (B, reduced_dim)
        else:
            id_vec = torch.zeros_like(text_emb)

        fused = torch.cat([text_emb, id_vec], dim=1)  # (B, 2 * reduced_dim)
        logits = self.fusion_head(fused)  # (B, id_vocab_size)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        return {"loss": loss, "logits": logits}


In [11]:
from torch.utils.data import DataLoader
from transformers import default_data_collator

def collate_fn(batch):
    """Turn a list of ``(conversation, ids, target)`` triples into a model batch.

    Only the user turn's text (``conversation[1]["content"][0]["text"]``) is kept.
    Id histories stay a Python list of sequences (padded later inside the model).
    Targets become a LongTensor.
    """
    conversations, id_ids, targets = zip(*batch)
    return {
        "text": [conv[1]['content'][0]['text'] for conv in conversations],
        "id_ids": list(id_ids),
        "labels": torch.tensor(targets, dtype=torch.long),
    }

# Training batches are reshuffled every epoch; validation keeps a fixed order.
train_loader = DataLoader(dataset=train_ds, batch_size=8, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(dataset=val_ds, batch_size=8, shuffle=False, collate_fn=collate_fn)


In [12]:
from transformers import get_scheduler
from peft import prepare_model_for_kbit_training

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = QwenOmniWithID(
    ckpt="Qwen/Qwen2.5-Omni-7B",
    id_vocab_size=len(filtered_item2index),  # assumes indices are 0..N-1 -- TODO confirm
    reduced_dim=1024,
    device=device,
).to(device)

# Hand only trainable parameters to the optimizer: the 7B backbone is frozen
# inside QwenTextEncoder, so AdamW has nothing to do for those tensors.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=1e-4, weight_decay=0.01)

num_epochs = 3
# The loop below steps the scheduler once per batch, across all epochs.
num_training_steps = len(train_loader) * num_epochs
lr_scheduler = get_scheduler(
    "linear", optimizer=optimizer, num_warmup_steps=100, num_training_steps=num_training_steps
)


Unrecognized keys in `rope_scaling` for 'rope_type'='default': {'mrope_section'}


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

In [13]:
from tqdm import tqdm

for epoch in range(num_epochs):
    # ---- training ----
    model.train()
    total_loss = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1} [train]"):
        # "text" and "id_ids" are moved to the device inside the model
        # (tokenizer output / padded id tensor); only labels need moving here.
        labels = batch["labels"].to(device)

        outputs = model(text=batch["text"], id_ids=batch["id_ids"], labels=labels)

        loss = outputs["loss"]
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()

        total_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    avg_train_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch+1} Train Loss: {avg_train_loss:.4f}")

    # ---- validation: loss and top-1 accuracy of next-item prediction ----
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(val_loader, desc=f"Epoch {epoch+1} [val]"):
            labels = batch["labels"].to(device)
            outputs = model(text=batch["text"], id_ids=batch["id_ids"], labels=labels)
            val_loss += outputs["loss"].item()

            preds = outputs["logits"].argmax(dim=-1)
            correct += (preds == labels).sum().item()
            total += len(labels)

    acc = correct / total if total else 0.0
    avg_val_loss = val_loss / max(len(val_loader), 1)
    print(f"Epoch {epoch+1} Val Loss: {avg_val_loss:.4f}, Accuracy: {acc:.4f}")


Epoch 1 [train]:   0%|          | 0/4699119 [00:00<?, ?it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1/4699119 [00:02<3502:39:48,  2.68s/it]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2/4699119 [00:03<1799:54:45,  1.38s/it]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3/4699119 [00:03<1154:53:57,  1.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4/4699119 [00:03<882:19:00,  1.48it/s] 

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5/4699119 [00:04<733:44:36,  1.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6/4699119 [00:04<661:29:20,  1.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 7/4699119 [00:05<645:40:44,  2.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 8/4699119 [00:05<632:38:05,  2.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 9/4699119 [00:05<623:25:22,  2.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 10/4699119 [00:06<592:03:51,  2.20it/s]

last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 11/4699119 [00:06<587:14:38,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 12/4699119 [00:07<593:27:25,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 13/4699119 [00:07<548:53:55,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 14/4699119 [00:07<508:27:34,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 15/4699119 [00:08<520:41:46,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 16/4699119 [00:08<454:43:02,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 17/4699119 [00:09<503:26:05,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 18/4699119 [00:09<446:52:36,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 19/4699119 [00:09<434:37:31,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 20/4699119 [00:10<487:51:31,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 21/4699119 [00:10<462:21:33,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 22/4699119 [00:10<506:06:34,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 114, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 23/4699119 [00:11<425:46:29,  3.07it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 24/4699119 [00:11<479:47:30,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 25/4699119 [00:11<497:36:15,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 26/4699119 [00:12<450:39:08,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 27/4699119 [00:12<499:37:45,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 28/4699119 [00:12<457:49:10,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 29/4699119 [00:13<451:36:18,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 30/4699119 [00:13<412:15:28,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 31/4699119 [00:13<410:52:25,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 32/4699119 [00:14<370:14:03,  3.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 33/4699119 [00:14<371:05:04,  3.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 34/4699119 [00:14<376:45:22,  3.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 35/4699119 [00:15<445:50:33,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 36/4699119 [00:15<461:09:54,  2.83it/s]

last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 37/4699119 [00:15<437:33:16,  2.98it/s]

last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 38/4699119 [00:16<429:22:48,  3.04it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 39/4699119 [00:16<483:31:48,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 40/4699119 [00:16<460:45:19,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 41/4699119 [00:17<487:13:45,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 42/4699119 [00:17<455:17:07,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 43/4699119 [00:17<418:05:05,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 44/4699119 [00:18<390:25:36,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 45/4699119 [00:18<456:55:48,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 46/4699119 [00:18<503:24:46,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 47/4699119 [00:19<444:50:00,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 48/4699119 [00:19<492:55:55,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 49/4699119 [00:20<485:39:18,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 50/4699119 [00:20<522:43:23,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 51/4699119 [00:20<496:10:29,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 52/4699119 [00:21<522:41:54,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 53/4699119 [00:21<547:56:31,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 54/4699119 [00:22<483:36:26,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 55/4699119 [00:22<469:48:51,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 56/4699119 [00:22<436:51:23,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 57/4699119 [00:23<470:49:31,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 58/4699119 [00:23<482:19:16,  2.71it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 59/4699119 [00:23<520:51:53,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 60/4699119 [00:24<496:16:05,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 61/4699119 [00:24<493:28:26,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 62/4699119 [00:25<528:05:46,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 155, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 63/4699119 [00:25<448:56:55,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 64/4699119 [00:25<497:31:08,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 65/4699119 [00:26<464:31:14,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 66/4699119 [00:26<482:25:47,  2.71it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 67/4699119 [00:26<520:53:10,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 68/4699119 [00:27<504:31:21,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 69/4699119 [00:27<486:52:47,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 70/4699119 [00:28<523:00:55,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 71/4699119 [00:28<549:12:37,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 72/4699119 [00:28<467:51:19,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 73/4699119 [00:29<474:14:42,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 74/4699119 [00:29<498:20:38,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 75/4699119 [00:30<531:51:38,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 76/4699119 [00:30<503:24:42,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 77/4699119 [00:30<515:19:09,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 78/4699119 [00:31<651:00:41,  2.01it/s]

last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 79/4699119 [00:31<639:32:36,  2.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 80/4699119 [00:32<630:03:50,  2.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 81/4699119 [00:32<556:14:05,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 82/4699119 [00:33<561:30:02,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 83/4699119 [00:33<499:58:57,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 84/4699119 [00:33<511:55:18,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 85/4699119 [00:34<488:22:24,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 86/4699119 [00:34<446:54:05,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 87/4699119 [00:34<496:26:07,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 88/4699119 [00:35<531:32:01,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 89/4699119 [00:35<516:35:30,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 90/4699119 [00:36<460:45:04,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 91/4699119 [00:36<505:27:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 92/4699119 [00:36<437:17:54,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 93/4699119 [00:37<489:58:44,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 94/4699119 [00:37<525:34:29,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 95/4699119 [00:38<551:20:33,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 96/4699119 [00:38<492:28:03,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 97/4699119 [00:38<484:12:49,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 98/4699119 [00:39<495:39:00,  2.63it/s]

last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 99/4699119 [00:39<464:37:47,  2.81it/s]

last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 100/4699119 [00:39<469:28:55,  2.78it/s]

last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 101/4699119 [00:40<443:18:48,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 102/4699119 [00:40<494:46:20,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 103/4699119 [00:40<460:22:29,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 104/4699119 [00:41<439:15:39,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 105/4699119 [00:41<434:32:59,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 106/4699119 [00:41<388:20:22,  3.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 135, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 107/4699119 [00:41<348:48:28,  3.74it/s]

last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 108/4699119 [00:42<375:09:25,  3.48it/s]

last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 109/4699119 [00:42<426:58:41,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 110/4699119 [00:42<415:13:32,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 111/4699119 [00:43<424:07:37,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 112/4699119 [00:43<400:07:38,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 114/4699119 [00:44<397:10:25,  3.29it/s]

last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 115/4699119 [00:44<400:33:41,  3.26it/s]

last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 116/4699119 [00:44<449:57:54,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 117/4699119 [00:45<429:49:50,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 118/4699119 [00:45<418:01:35,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 119/4699119 [00:45<434:22:02,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 120/4699119 [00:46<489:20:31,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 121/4699119 [00:46<526:21:18,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 122/4699119 [00:47<541:33:21,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 123/4699119 [00:47<509:43:30,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 124/4699119 [00:47<489:29:32,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 125/4699119 [00:48<486:52:15,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 126/4699119 [00:48<512:12:55,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 127/4699119 [00:49<510:54:06,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 128/4699119 [00:49<542:19:59,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 129/4699119 [00:49<504:45:47,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 130/4699119 [00:50<510:59:17,  2.55it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 131/4699119 [00:50<542:35:24,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 132/4699119 [00:51<542:32:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 133/4699119 [00:51<564:02:35,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 134/4699119 [00:52<516:57:57,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 135/4699119 [00:52<464:26:44,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 136/4699119 [00:52<503:42:59,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 137/4699119 [00:53<536:28:18,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 138/4699119 [00:53<453:59:07,  2.88it/s]

last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 139/4699119 [00:53<455:57:39,  2.86it/s]

last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 140/4699119 [00:54<446:07:12,  2.93it/s]

last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 141/4699119 [00:54<440:10:16,  2.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 142/4699119 [00:54<492:45:28,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 143/4699119 [00:55<530:14:09,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 131, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 145/4699119 [00:55<404:21:07,  3.23it/s]

last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 146/4699119 [00:56<450:46:32,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 147/4699119 [00:56<499:52:05,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 148/4699119 [00:57<484:42:39,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 149/4699119 [00:57<512:21:17,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 150/4699119 [00:57<498:15:56,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 151/4699119 [00:58<489:13:12,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 152/4699119 [00:58<434:37:50,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 153/4699119 [00:58<426:27:58,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 154/4699119 [00:59<483:44:14,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 155/4699119 [00:59<522:02:03,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 156/4699119 [00:59<487:45:42,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 157/4699119 [01:00<528:10:08,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 158/4699119 [01:00<486:55:47,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 159/4699119 [01:01<515:45:53,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 160/4699119 [01:01<546:31:09,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 161/4699119 [01:02<566:59:12,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 162/4699119 [01:02<490:44:28,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 163/4699119 [01:02<527:31:46,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 164/4699119 [01:03<554:50:45,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 165/4699119 [01:03<521:31:09,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 166/4699119 [01:04<550:15:04,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 167/4699119 [01:04<575:39:50,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 168/4699119 [01:04<509:03:41,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 169/4699119 [01:05<439:27:01,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 170/4699119 [01:05<415:37:14,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 171/4699119 [01:05<426:36:38,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 172/4699119 [01:05<390:10:38,  3.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 173/4699119 [01:06<457:12:57,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 174/4699119 [01:06<506:28:47,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 175/4699119 [01:07<485:06:36,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 176/4699119 [01:07<524:41:26,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 177/4699119 [01:07<469:37:17,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 178/4699119 [01:08<513:47:59,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 179/4699119 [01:08<510:25:26,  2.56it/s]

last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 180/4699119 [01:09<471:41:34,  2.77it/s]

last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 181/4699119 [01:09<474:50:30,  2.75it/s]

last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 182/4699119 [01:09<473:36:53,  2.76it/s]

last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 183/4699119 [01:10<479:14:30,  2.72it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 184/4699119 [01:10<520:34:56,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 185/4699119 [01:10<471:56:49,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 186/4699119 [01:11<515:00:17,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 187/4699119 [01:11<496:55:04,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 188/4699119 [01:12<534:14:18,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 189/4699119 [01:12<558:46:13,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 190/4699119 [01:13<535:12:25,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 191/4699119 [01:13<584:03:20,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 192/4699119 [01:13<482:09:06,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 193/4699119 [01:14<437:54:49,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 194/4699119 [01:14<492:51:26,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 195/4699119 [01:15<532:34:49,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 196/4699119 [01:15<541:39:15,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 197/4699119 [01:15<513:47:23,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 198/4699119 [01:16<527:30:09,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 199/4699119 [01:16<559:40:24,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 200/4699119 [01:17<495:19:04,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 201/4699119 [01:17<532:03:47,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 202/4699119 [01:17<464:46:28,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 203/4699119 [01:18<459:15:04,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 204/4699119 [01:18<506:10:41,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 205/4699119 [01:19<541:20:33,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 206/4699119 [01:19<544:04:25,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 207/4699119 [01:19<540:10:44,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 208/4699119 [01:20<563:43:56,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 402, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 209/4699119 [01:20<548:36:28,  2.38it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 210/4699119 [01:21<570:01:00,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 211/4699119 [01:21<583:59:59,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 212/4699119 [01:22<594:24:50,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 213/4699119 [01:22<589:06:53,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 214/4699119 [01:22<575:40:55,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 215/4699119 [01:23<589:05:02,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 216/4699119 [01:23<540:45:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 217/4699119 [01:24<469:35:33,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 218/4699119 [01:24<432:14:17,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 505, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 219/4699119 [01:24<493:11:04,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 220/4699119 [01:25<458:10:01,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 221/4699119 [01:25<493:58:51,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 222/4699119 [01:25<515:33:50,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 223/4699119 [01:26<545:30:04,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 224/4699119 [01:26<568:42:57,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 225/4699119 [01:27<584:43:19,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 226/4699119 [01:27<595:33:17,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 167, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 227/4699119 [01:28<500:57:43,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 228/4699119 [01:28<433:47:13,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 229/4699119 [01:28<405:04:31,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 230/4699119 [01:28<384:11:42,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 231/4699119 [01:29<420:48:25,  3.10it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 232/4699119 [01:29<486:29:14,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 233/4699119 [01:30<519:32:36,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 234/4699119 [01:30<455:08:48,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 235/4699119 [01:30<505:38:55,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 236/4699119 [01:31<488:37:37,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 237/4699119 [01:31<458:34:54,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 238/4699119 [01:31<460:18:22,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 239/4699119 [01:32<454:20:58,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 117, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 240/4699119 [01:32<390:06:41,  3.35it/s]

last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 241/4699119 [01:32<377:52:23,  3.45it/s]

last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 243/4699119 [01:33<361:27:08,  3.61it/s]

last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 244/4699119 [01:33<356:03:32,  3.67it/s]

last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 245/4699119 [01:33<385:02:50,  3.39it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 246/4699119 [01:34<455:12:13,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 247/4699119 [01:34<422:54:25,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 248/4699119 [01:34<479:01:49,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 249/4699119 [01:35<513:50:38,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 250/4699119 [01:35<508:17:45,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 251/4699119 [01:36<450:04:07,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 252/4699119 [01:36<461:03:03,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 253/4699119 [01:36<509:41:51,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 254/4699119 [01:37<495:41:26,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 255/4699119 [01:37<453:43:58,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 256/4699119 [01:37<488:16:56,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 257/4699119 [01:38<488:42:28,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 258/4699119 [01:38<450:39:35,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 259/4699119 [01:39<501:45:02,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 260/4699119 [01:39<537:56:53,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 261/4699119 [01:39<476:22:24,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 262/4699119 [01:40<511:04:07,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 263/4699119 [01:40<543:23:27,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 264/4699119 [01:41<532:54:44,  2.45it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 265/4699119 [01:41<558:43:37,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 266/4699119 [01:41<483:25:50,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 267/4699119 [01:42<524:38:12,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 268/4699119 [01:42<552:45:33,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 269/4699119 [01:43<488:01:26,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 270/4699119 [01:43<528:24:36,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 271/4699119 [01:44<556:12:28,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 272/4699119 [01:44<484:46:12,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 273/4699119 [01:44<526:47:55,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 274/4699119 [01:44<448:19:37,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 275/4699119 [01:45<403:54:01,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 276/4699119 [01:45<459:07:57,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 277/4699119 [01:46<500:23:53,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 278/4699119 [01:46<473:40:02,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 279/4699119 [01:46<436:37:13,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 280/4699119 [01:47<492:28:14,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 281/4699119 [01:47<472:57:50,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 282/4699119 [01:47<517:02:28,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 283/4699119 [01:48<484:51:41,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 284/4699119 [01:48<489:18:47,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 285/4699119 [01:48<446:17:03,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 286/4699119 [01:49<460:30:10,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 287/4699119 [01:49<483:22:44,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 122, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 288/4699119 [01:49<411:20:11,  3.17it/s]

last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 289/4699119 [01:50<455:22:38,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 290/4699119 [01:50<505:35:46,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 154, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 291/4699119 [01:50<433:24:38,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 292/4699119 [01:51<490:20:14,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 293/4699119 [01:51<485:26:57,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 294/4699119 [01:52<443:58:38,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 295/4699119 [01:52<497:40:36,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 296/4699119 [01:52<499:55:44,  2.61it/s]

last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 297/4699119 [01:53<464:25:53,  2.81it/s]

last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 298/4699119 [01:53<510:11:12,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 299/4699119 [01:54<544:02:39,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 300/4699119 [01:54<484:52:29,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 301/4699119 [01:54<439:01:18,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 302/4699119 [01:55<491:29:50,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 303/4699119 [01:55<530:51:24,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 304/4699119 [01:56<538:00:09,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 305/4699119 [01:56<564:40:44,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 306/4699119 [01:57<583:31:14,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 307/4699119 [01:57<508:38:21,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 308/4699119 [01:57<495:01:29,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 309/4699119 [01:58<532:48:47,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 310/4699119 [01:58<465:33:28,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 132, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 311/4699119 [01:58<402:15:48,  3.24it/s]

last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 312/4699119 [01:58<450:57:53,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 313/4699119 [01:59<444:16:30,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 314/4699119 [01:59<395:30:43,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 316/4699119 [01:59<316:35:08,  4.12it/s]

last_hidden_state = torch.Size([8, 111, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 317/4699119 [02:00<311:45:12,  4.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 318/4699119 [02:00<405:06:17,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 319/4699119 [02:01<471:57:08,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 320/4699119 [02:01<445:42:43,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 321/4699119 [02:01<483:25:34,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 322/4699119 [02:02<443:56:13,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 323/4699119 [02:02<456:04:02,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 324/4699119 [02:02<504:43:46,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 325/4699119 [02:03<504:28:34,  2.59it/s]

last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 327/4699119 [02:03<416:21:11,  3.13it/s]

last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 328/4699119 [02:04<477:28:43,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 329/4699119 [02:04<462:45:44,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 330/4699119 [02:05<511:27:40,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 331/4699119 [02:05<481:52:53,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 332/4699119 [02:05<525:01:03,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 333/4699119 [02:06<464:51:59,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 334/4699119 [02:06<469:06:58,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 335/4699119 [02:06<433:32:08,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 336/4699119 [02:07<437:30:50,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 337/4699119 [02:07<436:32:28,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 338/4699119 [02:07<493:55:45,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 339/4699119 [02:08<533:51:24,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 340/4699119 [02:08<560:33:40,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 341/4699119 [02:09<483:58:43,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 342/4699119 [02:09<473:48:05,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 343/4699119 [02:09<493:44:10,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 344/4699119 [02:10<534:52:06,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 345/4699119 [02:10<520:37:35,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 346/4699119 [02:11<480:11:28,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 347/4699119 [02:11<496:13:00,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 348/4699119 [02:11<440:31:03,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 349/4699119 [02:12<496:09:43,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 350/4699119 [02:12<449:39:09,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 351/4699119 [02:12<439:50:55,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 352/4699119 [02:13<413:43:56,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 353/4699119 [02:13<477:54:55,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 354/4699119 [02:13<483:05:05,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 355/4699119 [02:14<526:07:05,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 356/4699119 [02:14<555:54:21,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 357/4699119 [02:15<577:34:54,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 358/4699119 [02:15<569:48:25,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 359/4699119 [02:16<523:49:46,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 360/4699119 [02:16<507:03:29,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 361/4699119 [02:16<541:14:18,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 362/4699119 [02:17<520:08:05,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 363/4699119 [02:17<549:13:18,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 364/4699119 [02:18<537:03:53,  2.43it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 365/4699119 [02:18<512:09:24,  2.55it/s]

last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 366/4699119 [02:18<503:07:13,  2.59it/s]

last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 367/4699119 [02:19<526:52:31,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 368/4699119 [02:19<448:12:51,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 125, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 369/4699119 [02:19<387:37:27,  3.37it/s]

last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 370/4699119 [02:20<397:25:13,  3.28it/s]

last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 372/4699119 [02:20<390:34:40,  3.34it/s]

last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 373/4699119 [02:20<357:55:34,  3.65it/s]

last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 374/4699119 [02:21<438:49:08,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 375/4699119 [02:21<402:29:34,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 376/4699119 [02:21<419:35:00,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 377/4699119 [02:22<459:13:47,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 378/4699119 [02:22<450:00:40,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 379/4699119 [02:23<451:01:29,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 380/4699119 [02:23<449:07:17,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 381/4699119 [02:23<408:10:47,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 382/4699119 [02:24<447:26:19,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 383/4699119 [02:24<457:25:57,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 131, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 384/4699119 [02:24<396:31:25,  3.29it/s]

last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 385/4699119 [02:25<449:44:50,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 386/4699119 [02:25<438:32:24,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 387/4699119 [02:25<426:22:14,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 388/4699119 [02:26<486:08:59,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 389/4699119 [02:26<528:22:07,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 390/4699119 [02:27<558:41:51,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 391/4699119 [02:27<515:54:09,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 392/4699119 [02:27<505:01:32,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 393/4699119 [02:28<443:23:19,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 394/4699119 [02:28<499:06:30,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 395/4699119 [02:28<477:00:16,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 396/4699119 [02:29<491:32:34,  2.66it/s]

last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 397/4699119 [02:29<496:40:49,  2.63it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 398/4699119 [02:30<535:26:24,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 399/4699119 [02:30<527:07:25,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 400/4699119 [02:30<472:41:25,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 401/4699119 [02:31<516:57:16,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 402/4699119 [02:31<464:28:59,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 403/4699119 [02:31<483:49:50,  2.70it/s]

last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 404/4699119 [02:32<445:04:13,  2.93it/s]

last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 405/4699119 [02:32<441:06:04,  2.96it/s]

last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 406/4699119 [02:32<439:29:38,  2.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 407/4699119 [02:33<495:13:27,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 408/4699119 [02:33<486:39:00,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 409/4699119 [02:34<527:34:23,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 410/4699119 [02:34<556:55:35,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 411/4699119 [02:34<491:24:27,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 412/4699119 [02:35<533:12:14,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 413/4699119 [02:35<497:21:54,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 414/4699119 [02:35<433:06:50,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 415/4699119 [02:36<450:16:16,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 416/4699119 [02:36<503:48:10,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 417/4699119 [02:37<540:33:03,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 418/4699119 [02:37<537:03:04,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 419/4699119 [02:38<532:33:29,  2.45it/s]

last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 420/4699119 [02:38<493:09:43,  2.65it/s]

last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 422/4699119 [02:38<430:49:23,  3.03it/s]

last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 423/4699119 [02:39<476:49:42,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 424/4699119 [02:39<454:27:00,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 425/4699119 [02:40<505:32:06,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 144, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 426/4699119 [02:40<431:44:25,  3.02it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 427/4699119 [02:40<489:14:53,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 428/4699119 [02:41<436:55:14,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 429/4699119 [02:41<425:24:23,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 430/4699119 [02:41<485:33:03,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 431/4699119 [02:42<437:28:19,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 432/4699119 [02:42<494:05:15,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 433/4699119 [02:42<489:00:56,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 434/4699119 [02:43<484:22:52,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 435/4699119 [02:43<490:14:45,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 436/4699119 [02:44<530:35:26,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 437/4699119 [02:44<475:18:28,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 438/4699119 [02:44<502:33:31,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 439/4699119 [02:45<444:57:15,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 440/4699119 [02:45<499:50:08,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 441/4699119 [02:45<485:56:30,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 442/4699119 [02:46<528:10:58,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 443/4699119 [02:46<476:48:26,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 444/4699119 [02:47<476:56:00,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 445/4699119 [02:47<523:02:13,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 446/4699119 [02:47<505:46:03,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 447/4699119 [02:48<472:38:30,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 448/4699119 [02:48<432:47:57,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 449/4699119 [02:48<460:45:53,  2.83it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 450/4699119 [02:49<509:24:21,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 451/4699119 [02:49<475:32:36,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 452/4699119 [02:49<437:47:50,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 453/4699119 [02:50<386:52:32,  3.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 454/4699119 [02:50<443:00:06,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 455/4699119 [02:50<424:53:08,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 456/4699119 [02:51<433:49:49,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 457/4699119 [02:51<428:07:19,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 458/4699119 [02:52<487:50:05,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 459/4699119 [02:52<503:04:03,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 460/4699119 [02:52<524:58:12,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 461/4699119 [02:53<555:30:00,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 462/4699119 [02:53<497:10:09,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 463/4699119 [02:54<509:53:13,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 464/4699119 [02:54<435:53:12,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 465/4699119 [02:54<382:39:53,  3.41it/s]

last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 466/4699119 [02:54<458:32:13,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 467/4699119 [02:55<428:32:21,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 468/4699119 [02:55<437:03:47,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 469/4699119 [02:55<462:31:46,  2.82it/s]

last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 470/4699119 [02:56<440:27:26,  2.96it/s]

last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 471/4699119 [02:56<469:14:01,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 472/4699119 [02:57<516:57:53,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 473/4699119 [02:57<548:23:27,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 474/4699119 [02:58<571:41:56,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 475/4699119 [02:58<525:18:40,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 476/4699119 [02:58<547:12:07,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 477/4699119 [02:59<563:46:31,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 113, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 479/4699119 [02:59<419:46:40,  3.11it/s]

last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 480/4699119 [03:00<429:53:05,  3.04it/s]

last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 481/4699119 [03:00<406:32:59,  3.21it/s]

last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 482/4699119 [03:00<428:55:20,  3.04it/s]

last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 483/4699119 [03:01<455:11:00,  2.87it/s]

last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 484/4699119 [03:01<505:28:17,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 485/4699119 [03:01<471:29:07,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 486/4699119 [03:02<517:49:42,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 487/4699119 [03:02<505:46:33,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 488/4699119 [03:03<471:57:15,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 489/4699119 [03:03<465:17:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 490/4699119 [03:03<514:45:50,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 491/4699119 [03:04<488:39:35,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 492/4699119 [03:04<422:12:00,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 493/4699119 [03:04<386:00:48,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 494/4699119 [03:05<458:24:12,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 495/4699119 [03:05<425:02:29,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 496/4699119 [03:05<390:50:10,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 497/4699119 [03:06<461:19:08,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 498/4699119 [03:06<454:13:12,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 499/4699119 [03:06<429:10:10,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 500/4699119 [03:07<488:26:42,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 501/4699119 [03:07<481:04:00,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 502/4699119 [03:08<526:18:59,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 503/4699119 [03:08<556:08:15,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 504/4699119 [03:08<542:45:43,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 505/4699119 [03:09<543:08:37,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 506/4699119 [03:09<509:51:30,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 507/4699119 [03:10<491:35:46,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 508/4699119 [03:10<533:23:28,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 154, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 509/4699119 [03:10<453:01:20,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 510/4699119 [03:11<504:38:24,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 511/4699119 [03:11<475:07:55,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 512/4699119 [03:11<519:55:07,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 513/4699119 [03:12<551:26:33,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 514/4699119 [03:12<515:40:12,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 515/4699119 [03:13<536:40:10,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 516/4699119 [03:13<519:42:21,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 517/4699119 [03:14<552:57:08,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 518/4699119 [03:14<575:01:11,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 519/4699119 [03:15<593:14:14,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 520/4699119 [03:15<604:43:16,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 521/4699119 [03:15<552:46:16,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 522/4699119 [03:16<506:56:39,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 523/4699119 [03:16<473:14:00,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 524/4699119 [03:16<519:04:27,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 525/4699119 [03:17<530:11:16,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 526/4699119 [03:17<558:55:13,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 527/4699119 [03:18<528:09:11,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 528/4699119 [03:18<557:25:34,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 529/4699119 [03:19<567:33:53,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 530/4699119 [03:19<585:27:09,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 531/4699119 [03:20<597:59:01,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 532/4699119 [03:20<608:21:24,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 533/4699119 [03:20<560:54:52,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 534/4699119 [03:21<582:35:32,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 535/4699119 [03:21<595:58:38,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 536/4699119 [03:22<541:51:40,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 537/4699119 [03:22<568:26:57,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 538/4699119 [03:22<477:38:01,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 539/4699119 [03:23<506:28:53,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 540/4699119 [03:23<501:31:26,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 541/4699119 [03:24<539:25:56,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 542/4699119 [03:24<525:05:24,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 543/4699119 [03:24<469:47:00,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 544/4699119 [03:25<469:33:18,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 545/4699119 [03:25<448:34:57,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 546/4699119 [03:25<439:29:33,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 547/4699119 [03:26<435:40:56,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 548/4699119 [03:26<412:09:19,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 549/4699119 [03:26<426:48:38,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 550/4699119 [03:27<408:59:07,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 551/4699119 [03:27<401:03:58,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 552/4699119 [03:27<457:26:54,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 553/4699119 [03:28<417:49:12,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 554/4699119 [03:28<481:48:39,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 555/4699119 [03:29<525:13:32,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 556/4699119 [03:29<473:56:31,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 557/4699119 [03:29<519:28:15,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 558/4699119 [03:30<475:36:34,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 559/4699119 [03:30<483:09:50,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 560/4699119 [03:30<464:27:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 561/4699119 [03:31<442:15:56,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 562/4699119 [03:31<411:42:30,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 69, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 563/4699119 [03:31<348:45:30,  3.74it/s]

last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 564/4699119 [03:31<369:50:38,  3.53it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 565/4699119 [03:32<446:40:11,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 566/4699119 [03:32<489:00:44,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 567/4699119 [03:33<449:18:49,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 77, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 568/4699119 [03:33<375:37:26,  3.47it/s]

last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 570/4699119 [03:33<364:34:59,  3.58it/s]

last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 571/4699119 [03:34<443:11:18,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 572/4699119 [03:34<457:09:48,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 573/4699119 [03:35<508:44:50,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 574/4699119 [03:35<440:14:07,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 575/4699119 [03:35<399:28:45,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 576/4699119 [03:35<433:47:08,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 577/4699119 [03:36<491:55:36,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 578/4699119 [03:36<499:48:51,  2.61it/s]

last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 580/4699119 [03:37<434:10:13,  3.01it/s]

last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 581/4699119 [03:37<409:02:55,  3.19it/s]

last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 582/4699119 [03:38<464:45:21,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 583/4699119 [03:38<444:19:32,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 584/4699119 [03:38<431:15:10,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 585/4699119 [03:39<489:04:11,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 586/4699119 [03:39<518:54:47,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 587/4699119 [03:39<446:53:00,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 588/4699119 [03:40<443:15:58,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 589/4699119 [03:40<409:24:37,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 590/4699119 [03:40<407:50:00,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 591/4699119 [03:41<474:09:37,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 592/4699119 [03:41<513:43:29,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 593/4699119 [03:42<479:46:00,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 594/4699119 [03:42<411:32:59,  3.17it/s]

last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 595/4699119 [03:42<395:09:47,  3.30it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 596/4699119 [03:42<417:21:01,  3.13it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 597/4699119 [03:43<479:37:00,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 598/4699119 [03:43<466:10:09,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 599/4699119 [03:44<514:20:42,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 600/4699119 [03:44<462:55:26,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 601/4699119 [03:44<463:46:56,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 602/4699119 [03:45<452:41:16,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 603/4699119 [03:45<505:03:50,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 604/4699119 [03:45<459:31:41,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 140, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 605/4699119 [03:46<399:19:34,  3.27it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 606/4699119 [03:46<467:39:05,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 607/4699119 [03:46<516:54:31,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 608/4699119 [03:47<556:29:11,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 609/4699119 [03:47<545:18:14,  2.39it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 610/4699119 [03:48<573:13:38,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 611/4699119 [03:48<589:10:44,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 612/4699119 [03:49<555:36:02,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 613/4699119 [03:49<532:37:40,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 614/4699119 [03:49<499:17:03,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 615/4699119 [03:50<469:28:08,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 616/4699119 [03:50<478:13:46,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 617/4699119 [03:51<523:25:38,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 618/4699119 [03:51<467:16:41,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 619/4699119 [03:51<516:15:28,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 620/4699119 [03:52<549:53:00,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 621/4699119 [03:52<574:01:31,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 622/4699119 [03:53<558:33:29,  2.34it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 623/4699119 [03:53<579:28:00,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 624/4699119 [03:54<554:31:53,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 625/4699119 [03:54<577:54:20,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 626/4699119 [03:54<579:08:27,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 627/4699119 [03:55<566:49:42,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 628/4699119 [03:55<489:09:17,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 629/4699119 [03:56<530:49:31,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 630/4699119 [03:56<499:52:58,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 631/4699119 [03:56<494:23:02,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 632/4699119 [03:57<533:56:03,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 633/4699119 [03:57<483:52:29,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 140, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 634/4699119 [03:57<416:11:00,  3.14it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 635/4699119 [03:58<479:18:57,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 636/4699119 [03:58<524:52:13,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 637/4699119 [03:59<502:53:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 638/4699119 [03:59<532:43:20,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 639/4699119 [03:59<504:34:43,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 640/4699119 [04:00<446:30:23,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 641/4699119 [04:00<484:35:40,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 642/4699119 [04:01<527:18:57,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 643/4699119 [04:01<484:34:43,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 644/4699119 [04:01<471:47:29,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 645/4699119 [04:02<520:58:42,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 646/4699119 [04:02<460:39:38,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 647/4699119 [04:02<424:39:40,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 648/4699119 [04:02<406:58:15,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 649/4699119 [04:03<421:23:40,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 650/4699119 [04:03<386:28:46,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 651/4699119 [04:03<458:17:08,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 652/4699119 [04:04<509:32:52,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 653/4699119 [04:04<544:28:25,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 654/4699119 [04:05<575:17:35,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 655/4699119 [04:05<591:49:40,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 656/4699119 [04:06<608:08:37,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 657/4699119 [04:06<596:38:03,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 658/4699119 [04:07<606:18:43,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 659/4699119 [04:07<535:52:23,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 660/4699119 [04:07<469:48:05,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 661/4699119 [04:08<516:53:00,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 662/4699119 [04:08<549:35:15,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 663/4699119 [04:09<573:12:34,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 664/4699119 [04:09<486:01:04,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 665/4699119 [04:09<503:20:10,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 666/4699119 [04:10<540:28:19,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 667/4699119 [04:10<567:27:19,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 668/4699119 [04:11<585:06:48,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 669/4699119 [04:11<517:51:57,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 670/4699119 [04:11<453:16:39,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 671/4699119 [04:12<407:33:33,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 672/4699119 [04:12<429:34:04,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 673/4699119 [04:12<440:23:51,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 674/4699119 [04:13<398:23:16,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 675/4699119 [04:13<363:03:26,  3.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 676/4699119 [04:13<423:26:21,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 677/4699119 [04:14<484:26:08,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 678/4699119 [04:14<493:03:37,  2.65it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 679/4699119 [04:15<535:12:52,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 680/4699119 [04:15<551:01:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 681/4699119 [04:16<574:58:04,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 682/4699119 [04:16<544:02:16,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 683/4699119 [04:16<505:16:23,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 684/4699119 [04:17<547:32:19,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 685/4699119 [04:17<545:44:39,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 686/4699119 [04:17<481:49:46,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 687/4699119 [04:18<485:19:14,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 688/4699119 [04:18<532:22:51,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 689/4699119 [04:19<517:24:34,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 690/4699119 [04:19<551:31:38,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 691/4699119 [04:19<492:35:12,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 692/4699119 [04:20<463:51:08,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 693/4699119 [04:20<418:40:59,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 694/4699119 [04:20<481:38:18,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 695/4699119 [04:21<499:53:04,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 696/4699119 [04:21<435:02:32,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 697/4699119 [04:21<472:36:05,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 698/4699119 [04:22<519:33:45,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 699/4699119 [04:22<505:28:13,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 700/4699119 [04:23<543:38:22,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 701/4699119 [04:23<485:10:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 702/4699119 [04:24<520:33:59,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 703/4699119 [04:24<488:36:03,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 704/4699119 [04:24<447:07:34,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 705/4699119 [04:25<501:11:37,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 706/4699119 [04:25<479:36:40,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 707/4699119 [04:25<442:53:55,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 708/4699119 [04:26<497:46:24,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 709/4699119 [04:26<453:00:58,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 710/4699119 [04:26<451:23:25,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 711/4699119 [04:27<503:52:41,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 712/4699119 [04:27<541:47:47,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 713/4699119 [04:28<498:58:55,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 714/4699119 [04:28<486:05:55,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 715/4699119 [04:28<434:23:27,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 716/4699119 [04:28<405:00:13,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 717/4699119 [04:29<427:13:59,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 718/4699119 [04:29<390:20:45,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 719/4699119 [04:29<461:27:01,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 108, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 720/4699119 [04:30<390:37:58,  3.34it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 721/4699119 [04:30<462:43:09,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 179, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 722/4699119 [04:30<414:45:29,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 723/4699119 [04:31<479:19:23,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 724/4699119 [04:31<475:06:27,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 725/4699119 [04:31<434:27:36,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 726/4699119 [04:32<424:15:12,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 727/4699119 [04:32<485:30:05,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 728/4699119 [04:33<527:38:26,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 729/4699119 [04:35<1098:01:43,  1.19it/s]

last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 730/4699119 [04:35<870:43:07,  1.50it/s] 

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 731/4699119 [04:35<689:33:25,  1.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 732/4699119 [04:35<587:55:24,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 733/4699119 [04:36<518:22:41,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 734/4699119 [04:36<509:30:59,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 735/4699119 [04:36<544:25:08,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 736/4699119 [04:37<542:54:47,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 737/4699119 [04:37<552:11:42,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 738/4699119 [04:38<574:15:13,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 739/4699119 [04:38<527:32:35,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 740/4699119 [04:38<476:35:12,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 741/4699119 [04:39<459:20:38,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 742/4699119 [04:39<423:43:14,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 743/4699119 [04:39<435:42:14,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 744/4699119 [04:40<448:30:42,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 745/4699119 [04:40<438:10:38,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 746/4699119 [04:40<401:16:08,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 747/4699119 [04:41<469:55:03,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 748/4699119 [04:41<467:49:42,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 749/4699119 [04:41<446:35:48,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 750/4699119 [04:42<427:27:16,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 751/4699119 [04:42<487:46:57,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 752/4699119 [04:43<530:27:42,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 753/4699119 [04:43<559:05:51,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 754/4699119 [04:44<579:26:26,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 755/4699119 [04:44<580:47:02,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 756/4699119 [04:44<486:41:41,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 140, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 757/4699119 [04:44<417:34:01,  3.13it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 758/4699119 [04:45<481:31:43,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 759/4699119 [04:45<525:15:20,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 760/4699119 [04:46<446:40:45,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 761/4699119 [04:46<458:07:19,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 762/4699119 [04:46<507:51:58,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 763/4699119 [04:47<489:47:05,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 764/4699119 [04:47<460:28:49,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 765/4699119 [04:47<439:48:04,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 766/4699119 [04:48<426:01:39,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 767/4699119 [04:48<417:40:06,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 768/4699119 [04:48<423:06:56,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 769/4699119 [04:49<484:50:21,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 770/4699119 [04:49<487:29:33,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 771/4699119 [04:50<467:03:17,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 772/4699119 [04:50<515:31:56,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 773/4699119 [04:50<468:19:44,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 774/4699119 [04:50<411:47:58,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 775/4699119 [04:51<392:07:41,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 776/4699119 [04:51<374:54:09,  3.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 777/4699119 [04:51<369:40:27,  3.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 778/4699119 [04:52<447:40:53,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 779/4699119 [04:52<413:47:17,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 780/4699119 [04:53<481:53:50,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 781/4699119 [04:53<443:38:46,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 782/4699119 [04:53<465:54:40,  2.80it/s]

last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 783/4699119 [04:54<464:39:34,  2.81it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 784/4699119 [04:54<513:11:17,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 785/4699119 [04:54<451:26:27,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 786/4699119 [04:55<423:04:58,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 787/4699119 [04:55<401:42:36,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 788/4699119 [04:55<421:14:36,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 789/4699119 [04:56<482:19:52,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 790/4699119 [04:56<514:03:18,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 791/4699119 [04:56<497:02:24,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 792/4699119 [04:57<535:42:29,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 793/4699119 [04:57<477:23:19,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 794/4699119 [04:58<479:15:11,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 795/4699119 [04:58<511:30:26,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 796/4699119 [04:58<494:24:47,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 797/4699119 [04:59<441:22:44,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 798/4699119 [04:59<427:45:18,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 799/4699119 [04:59<487:19:56,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 800/4699119 [05:00<466:20:16,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 801/4699119 [05:00<430:51:08,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 802/4699119 [05:00<443:52:40,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 803/4699119 [05:01<486:19:35,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 804/4699119 [05:01<534:50:14,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 805/4699119 [05:02<562:30:53,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 806/4699119 [05:02<513:24:10,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 807/4699119 [05:02<498:21:47,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 808/4699119 [05:03<510:53:44,  2.55it/s]

last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 809/4699119 [05:03<457:34:03,  2.85it/s]

last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 811/4699119 [05:04<408:16:01,  3.20it/s]

last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 812/4699119 [05:04<392:39:18,  3.32it/s]

last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 813/4699119 [05:04<381:45:50,  3.42it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 814/4699119 [05:05<455:59:07,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 815/4699119 [05:05<512:36:53,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 816/4699119 [05:05<438:24:22,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 817/4699119 [05:06<487:02:07,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 818/4699119 [05:06<419:44:12,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 819/4699119 [05:06<434:47:05,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 820/4699119 [05:07<492:24:16,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 821/4699119 [05:07<537:40:43,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 822/4699119 [05:08<485:58:51,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 823/4699119 [05:08<419:26:09,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 101, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 824/4699119 [05:08<360:42:17,  3.62it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 825/4699119 [05:09<441:11:11,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 826/4699119 [05:09<471:39:48,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 827/4699119 [05:09<445:41:03,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 87, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 828/4699119 [05:09<376:36:11,  3.47it/s]

last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 829/4699119 [05:10<398:19:22,  3.28it/s]

last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 831/4699119 [05:10<360:58:39,  3.62it/s]

last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 832/4699119 [05:11<367:42:33,  3.55it/s]

last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 833/4699119 [05:11<406:59:14,  3.21it/s]

last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 834/4699119 [05:11<447:25:27,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 835/4699119 [05:12<501:41:42,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 836/4699119 [05:12<540:00:29,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 837/4699119 [05:13<507:42:25,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 838/4699119 [05:13<478:29:21,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 127, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 839/4699119 [05:13<408:54:21,  3.19it/s]

last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 840/4699119 [05:13<389:39:21,  3.35it/s]

last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 841/4699119 [05:14<451:06:36,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 842/4699119 [05:14<394:45:24,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 843/4699119 [05:14<438:59:16,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 844/4699119 [05:15<414:38:26,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 845/4699119 [05:15<479:08:28,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 846/4699119 [05:16<523:47:49,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 847/4699119 [05:16<533:33:37,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 848/4699119 [05:17<558:41:32,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 849/4699119 [05:17<566:52:50,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 850/4699119 [05:17<555:20:01,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 851/4699119 [05:18<577:21:30,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 853/4699119 [05:19<487:14:55,  2.68it/s]

last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 854/4699119 [05:19<530:39:27,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 855/4699119 [05:19<498:29:43,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 856/4699119 [05:20<488:54:41,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 857/4699119 [05:20<444:10:32,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 858/4699119 [05:21<499:53:12,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 859/4699119 [05:21<443:58:44,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 860/4699119 [05:21<420:29:04,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 861/4699119 [05:21<465:16:15,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 862/4699119 [05:22<434:24:49,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 863/4699119 [05:22<453:19:00,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 864/4699119 [05:22<408:56:40,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 865/4699119 [05:23<467:34:08,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 866/4699119 [05:23<497:41:44,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 867/4699119 [05:24<536:13:53,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 868/4699119 [05:24<565:13:54,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 869/4699119 [05:25<582:43:44,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 870/4699119 [05:25<595:55:01,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 871/4699119 [05:25<519:14:20,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 872/4699119 [05:26<550:59:04,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 873/4699119 [05:26<513:36:22,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 874/4699119 [05:27<535:20:23,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 875/4699119 [05:27<514:17:36,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 876/4699119 [05:27<460:22:56,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 877/4699119 [05:28<427:44:23,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 878/4699119 [05:28<372:51:16,  3.50it/s]

last_hidden_state = torch.Size([8, 402, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 879/4699119 [05:28<418:43:56,  3.12it/s]

last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 880/4699119 [05:29<442:20:58,  2.95it/s]

last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 881/4699119 [05:29<410:40:03,  3.18it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 882/4699119 [05:29<475:29:58,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 883/4699119 [05:30<453:35:59,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 884/4699119 [05:30<397:58:08,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 885/4699119 [05:30<445:21:02,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 886/4699119 [05:30<411:47:05,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 887/4699119 [05:31<424:31:00,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 888/4699119 [05:31<459:14:48,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 889/4699119 [05:31<424:08:19,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 890/4699119 [05:32<440:57:18,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 891/4699119 [05:32<476:34:47,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 892/4699119 [05:33<418:47:43,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 893/4699119 [05:33<423:06:15,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 894/4699119 [05:33<480:43:54,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 895/4699119 [05:34<514:13:01,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 896/4699119 [05:34<483:52:16,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 897/4699119 [05:34<466:20:56,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 898/4699119 [05:35<442:51:40,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 899/4699119 [05:35<502:25:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 900/4699119 [05:36<495:20:45,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 901/4699119 [05:36<443:39:08,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 902/4699119 [05:36<486:40:02,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 903/4699119 [05:37<434:37:40,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 904/4699119 [05:37<494:19:37,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 905/4699119 [05:37<534:10:46,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 906/4699119 [05:38<511:24:16,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 907/4699119 [05:38<449:26:49,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 908/4699119 [05:39<504:36:42,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 909/4699119 [05:39<546:58:41,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 910/4699119 [05:39<511:37:57,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 911/4699119 [05:40<496:45:31,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 912/4699119 [05:40<441:41:26,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 913/4699119 [05:40<498:08:23,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 914/4699119 [05:41<459:01:18,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 915/4699119 [05:41<510:27:31,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 916/4699119 [05:41<430:49:32,  3.03it/s]

last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 917/4699119 [05:42<471:00:29,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 918/4699119 [05:42<465:35:11,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 919/4699119 [05:42<425:11:25,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 920/4699119 [05:43<421:50:23,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 921/4699119 [05:43<489:32:02,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 922/4699119 [05:43<433:26:50,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 923/4699119 [05:44<433:25:48,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 924/4699119 [05:44<494:08:09,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 925/4699119 [05:45<467:02:12,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 926/4699119 [05:45<515:03:09,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 927/4699119 [05:45<460:45:26,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 928/4699119 [05:46<469:19:11,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 929/4699119 [05:46<517:12:14,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 930/4699119 [05:47<495:32:26,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 931/4699119 [05:47<503:42:15,  2.59it/s]

last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 932/4699119 [05:47<502:13:26,  2.60it/s]

last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 933/4699119 [05:48<539:16:38,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 934/4699119 [05:48<471:36:04,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 935/4699119 [05:49<517:58:05,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 936/4699119 [05:49<511:16:13,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 937/4699119 [05:49<521:04:07,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 938/4699119 [05:50<504:39:22,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 939/4699119 [05:50<460:25:02,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 940/4699119 [05:50<510:01:03,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 941/4699119 [05:51<475:45:14,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 942/4699119 [05:51<448:51:28,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 943/4699119 [05:52<508:53:16,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 944/4699119 [05:52<544:26:47,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 945/4699119 [05:52<571:01:27,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 946/4699119 [05:53<576:57:34,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 947/4699119 [05:53<575:36:24,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 948/4699119 [05:54<591:39:15,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 949/4699119 [05:54<505:13:27,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 950/4699119 [05:54<461:47:23,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 951/4699119 [05:55<510:56:09,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 952/4699119 [05:55<525:45:33,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 953/4699119 [05:56<518:34:01,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 111, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 954/4699119 [05:56<430:49:59,  3.03it/s]

last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 955/4699119 [05:56<469:36:26,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 956/4699119 [05:57<517:01:19,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 957/4699119 [05:57<551:41:43,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 958/4699119 [05:58<547:30:30,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 959/4699119 [05:58<508:53:10,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 960/4699119 [05:58<450:50:42,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 961/4699119 [05:59<435:16:37,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 962/4699119 [05:59<494:43:15,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 964/4699119 [06:00<441:20:01,  2.96it/s]

last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 965/4699119 [06:00<455:14:49,  2.87it/s]

last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 966/4699119 [06:00<498:08:20,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 967/4699119 [06:01<539:50:09,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 968/4699119 [06:01<553:23:56,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 969/4699119 [06:02<541:37:14,  2.41it/s]

last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 970/4699119 [06:02<517:33:37,  2.52it/s]

last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 971/4699119 [06:03<518:11:00,  2.52it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 972/4699119 [06:03<550:55:07,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 973/4699119 [06:03<476:19:21,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 974/4699119 [06:04<434:44:15,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 975/4699119 [06:04<495:44:34,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 976/4699119 [06:04<516:48:27,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 977/4699119 [06:05<549:49:14,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 978/4699119 [06:05<543:45:50,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 979/4699119 [06:06<568:45:00,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 980/4699119 [06:06<513:22:55,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 981/4699119 [06:07<548:03:09,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 983/4699119 [06:07<395:59:31,  3.30it/s]

last_hidden_state = torch.Size([8, 125, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 984/4699119 [06:07<367:53:00,  3.55it/s]

last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 985/4699119 [06:08<446:43:09,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 986/4699119 [06:08<499:50:03,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 987/4699119 [06:08<443:09:20,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 988/4699119 [06:09<465:04:23,  2.81it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 989/4699119 [06:09<515:17:00,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 990/4699119 [06:10<461:19:47,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 991/4699119 [06:10<452:34:13,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 992/4699119 [06:10<435:17:07,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 993/4699119 [06:10<416:38:31,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 994/4699119 [06:11<480:44:48,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 995/4699119 [06:11<524:25:26,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 996/4699119 [06:12<555:01:26,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 997/4699119 [06:12<557:45:41,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 436, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 998/4699119 [06:13<556:49:36,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 999/4699119 [06:13<470:44:06,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1000/4699119 [06:13<474:47:24,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1001/4699119 [06:14<465:16:10,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1002/4699119 [06:14<420:00:03,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1003/4699119 [06:14<434:58:28,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1004/4699119 [06:15<396:44:31,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1005/4699119 [06:15<466:52:23,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1006/4699119 [06:15<515:33:34,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1007/4699119 [06:16<485:04:20,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1008/4699119 [06:16<538:21:17,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1009/4699119 [06:17<565:37:01,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1010/4699119 [06:17<573:05:19,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1011/4699119 [06:18<593:59:19,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1012/4699119 [06:18<604:42:07,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1013/4699119 [06:19<612:19:15,  2.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1014/4699119 [06:19<533:08:39,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1015/4699119 [06:19<504:56:57,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1016/4699119 [06:20<525:16:09,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 505, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1017/4699119 [06:20<561:15:58,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1018/4699119 [06:20<498:07:41,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1019/4699119 [06:21<537:50:44,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1020/4699119 [06:21<478:34:30,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1021/4699119 [06:22<454:17:25,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1022/4699119 [06:22<508:20:47,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1023/4699119 [06:22<485:37:57,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 98, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1024/4699119 [06:23<406:30:51,  3.21it/s]

last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1025/4699119 [06:23<390:28:27,  3.34it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1026/4699119 [06:23<461:35:04,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1027/4699119 [06:24<429:43:13,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1028/4699119 [06:24<463:36:13,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1029/4699119 [06:24<477:32:19,  2.73it/s]

last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1030/4699119 [06:25<500:31:21,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1031/4699119 [06:25<513:33:46,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1032/4699119 [06:25<470:29:50,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1033/4699119 [06:26<517:08:54,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1034/4699119 [06:26<481:56:19,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1035/4699119 [06:27<433:07:08,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1036/4699119 [06:27<393:27:55,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1037/4699119 [06:27<452:26:11,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1038/4699119 [06:27<400:57:12,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1039/4699119 [06:28<395:36:07,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1040/4699119 [06:28<465:56:48,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1041/4699119 [06:29<514:10:49,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1042/4699119 [06:29<487:30:26,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1043/4699119 [06:29<445:16:25,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1044/4699119 [06:29<395:55:59,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1045/4699119 [06:30<466:18:09,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1046/4699119 [06:30<446:20:59,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1047/4699119 [06:30<392:55:55,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1048/4699119 [06:31<417:33:30,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1049/4699119 [06:31<411:23:45,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1050/4699119 [06:31<368:17:06,  3.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1051/4699119 [06:32<412:03:07,  3.17it/s]

last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1052/4699119 [06:32<390:04:52,  3.35it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1053/4699119 [06:32<414:16:40,  3.15it/s]

last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1054/4699119 [06:33<431:08:45,  3.03it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1055/4699119 [06:33<490:00:42,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1056/4699119 [06:33<457:25:57,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1057/4699119 [06:34<420:44:22,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1058/4699119 [06:34<418:44:29,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1059/4699119 [06:35<481:27:33,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1060/4699119 [06:35<523:34:33,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1061/4699119 [06:35<519:51:00,  2.51it/s]

last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1062/4699119 [06:36<536:08:47,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1063/4699119 [06:36<526:38:13,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1064/4699119 [06:37<486:29:09,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1065/4699119 [06:37<488:44:21,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1066/4699119 [06:37<421:00:24,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1067/4699119 [06:37<439:11:33,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1068/4699119 [06:38<408:28:49,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1069/4699119 [06:38<474:31:42,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1070/4699119 [06:39<471:31:05,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1071/4699119 [06:39<506:40:28,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1072/4699119 [06:40<543:32:18,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1073/4699119 [06:40<532:34:53,  2.45it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1074/4699119 [06:40<505:35:53,  2.58it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1075/4699119 [06:41<542:44:04,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1076/4699119 [06:41<483:19:06,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1077/4699119 [06:41<467:41:45,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1078/4699119 [06:42<467:04:16,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1079/4699119 [06:42<431:50:26,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1080/4699119 [06:42<375:52:50,  3.47it/s]

last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1081/4699119 [06:43<441:16:19,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1082/4699119 [06:43<457:14:08,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1083/4699119 [06:43<436:34:43,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1084/4699119 [06:44<446:36:37,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1085/4699119 [06:44<396:26:36,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1086/4699119 [06:44<385:04:43,  3.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1087/4699119 [06:45<458:18:24,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1088/4699119 [06:45<508:41:22,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1089/4699119 [06:45<471:11:04,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1090/4699119 [06:46<482:56:53,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1091/4699119 [06:46<526:39:29,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1092/4699119 [06:46<459:54:23,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1093/4699119 [06:47<484:17:45,  2.69it/s]

last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1094/4699119 [06:47<457:49:05,  2.85it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1095/4699119 [06:48<509:45:18,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1096/4699119 [06:48<459:01:13,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1097/4699119 [06:48<496:52:58,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1098/4699119 [06:49<535:58:22,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1099/4699119 [06:49<491:09:00,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1100/4699119 [06:50<531:24:43,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1101/4699119 [06:50<559:46:15,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1102/4699119 [06:51<582:05:23,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1103/4699119 [06:51<541:38:47,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1104/4699119 [06:51<567:10:41,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1105/4699119 [06:52<587:05:54,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1106/4699119 [06:52<584:25:18,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1107/4699119 [06:53<597:10:24,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1108/4699119 [06:53<553:37:08,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1109/4699119 [06:53<484:24:13,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 131, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1110/4699119 [06:54<415:15:09,  3.14it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1111/4699119 [06:54<478:06:42,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1112/4699119 [06:54<485:38:09,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1113/4699119 [06:55<509:59:48,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1114/4699119 [06:55<493:18:35,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1115/4699119 [06:56<519:50:10,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1116/4699119 [06:56<508:17:50,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1117/4699119 [06:57<516:53:06,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1118/4699119 [06:57<549:55:22,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 163, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1119/4699119 [06:57<468:48:06,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1120/4699119 [06:58<527:54:55,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1121/4699119 [06:58<462:13:08,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1122/4699119 [06:58<511:38:39,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1123/4699119 [06:59<463:46:13,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1124/4699119 [06:59<442:58:33,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 484, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1125/4699119 [06:59<494:33:04,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1126/4699119 [07:00<449:53:49,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1127/4699119 [07:00<422:52:10,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1128/4699119 [07:00<428:53:19,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 473, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1129/4699119 [07:01<481:24:20,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1130/4699119 [07:01<428:45:37,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1131/4699119 [07:01<407:36:06,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1132/4699119 [07:02<473:24:36,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1133/4699119 [07:02<445:36:06,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1134/4699119 [07:03<499:33:58,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1136/4699119 [07:03<442:17:26,  2.95it/s]

last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1137/4699119 [07:04<497:44:45,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1138/4699119 [07:04<532:48:58,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1139/4699119 [07:04<493:14:50,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1140/4699119 [07:05<503:38:57,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1141/4699119 [07:05<488:05:40,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1142/4699119 [07:05<421:20:43,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1143/4699119 [07:06<467:46:09,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1144/4699119 [07:06<520:01:22,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1145/4699119 [07:07<494:10:59,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1146/4699119 [07:07<499:06:09,  2.61it/s]

last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1147/4699119 [07:07<497:19:29,  2.62it/s]

last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1148/4699119 [07:08<460:01:48,  2.84it/s]

last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1150/4699119 [07:08<396:07:32,  3.29it/s]

last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1151/4699119 [07:09<462:03:59,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1152/4699119 [07:09<480:53:17,  2.71it/s]

last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1153/4699119 [07:09<438:50:40,  2.97it/s]

last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1154/4699119 [07:10<412:49:00,  3.16it/s]

last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1155/4699119 [07:10<393:45:48,  3.31it/s]

last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1156/4699119 [07:10<440:57:20,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1157/4699119 [07:11<496:18:41,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1158/4699119 [07:11<491:21:58,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1159/4699119 [07:12<486:54:02,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1160/4699119 [07:12<530:47:04,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1161/4699119 [07:12<518:20:21,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1162/4699119 [07:13<551:11:34,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1163/4699119 [07:13<522:52:42,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1164/4699119 [07:14<556:47:47,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1165/4699119 [07:14<546:19:30,  2.39it/s]

last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1166/4699119 [07:15<539:03:27,  2.42it/s]

last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1167/4699119 [07:15<562:52:27,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1168/4699119 [07:16<583:11:59,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1169/4699119 [07:16<526:55:48,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1170/4699119 [07:16<461:48:42,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1171/4699119 [07:16<458:30:59,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1172/4699119 [07:17<510:10:04,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1173/4699119 [07:17<441:51:41,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1174/4699119 [07:18<499:47:29,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1175/4699119 [07:18<440:05:22,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1176/4699119 [07:18<495:54:25,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1177/4699119 [07:19<508:39:09,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1178/4699119 [07:19<491:21:41,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1179/4699119 [07:19<428:43:07,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1180/4699119 [07:20<488:41:28,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1181/4699119 [07:20<458:36:18,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1182/4699119 [07:21<509:46:47,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1183/4699119 [07:21<545:26:44,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1184/4699119 [07:21<478:48:24,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1185/4699119 [07:21<419:42:22,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1186/4699119 [07:22<453:20:40,  2.88it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1187/4699119 [07:22<507:08:32,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1188/4699119 [07:23<542:30:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1189/4699119 [07:23<529:11:45,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1190/4699119 [07:24<559:17:35,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1191/4699119 [07:24<545:40:19,  2.39it/s]

last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1192/4699119 [07:24<528:51:51,  2.47it/s]

last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1193/4699119 [07:25<490:20:47,  2.66it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1194/4699119 [07:25<532:27:26,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1195/4699119 [07:26<519:42:56,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1196/4699119 [07:26<448:26:28,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1197/4699119 [07:26<503:37:44,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1198/4699119 [07:27<542:25:00,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1199/4699119 [07:27<494:02:45,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1200/4699119 [07:28<535:29:15,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1201/4699119 [07:28<469:23:23,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1202/4699119 [07:28<519:19:41,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1203/4699119 [07:29<553:13:40,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1204/4699119 [07:29<481:19:24,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 153, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1205/4699119 [07:29<416:38:49,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1206/4699119 [07:30<399:08:04,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1207/4699119 [07:30<467:31:23,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1208/4699119 [07:30<426:28:35,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1209/4699119 [07:31<401:23:13,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1210/4699119 [07:31<373:10:27,  3.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 93, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1211/4699119 [07:31<326:04:31,  4.00it/s]

last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1212/4699119 [07:31<393:52:24,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1213/4699119 [07:32<380:53:15,  3.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1214/4699119 [07:32<425:02:55,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1215/4699119 [07:32<439:22:26,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1216/4699119 [07:33<444:31:27,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1217/4699119 [07:33<499:18:13,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1218/4699119 [07:34<498:58:28,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1219/4699119 [07:34<465:45:45,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1220/4699119 [07:34<460:55:37,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1221/4699119 [07:35<512:08:11,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1222/4699119 [07:35<547:24:44,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1223/4699119 [07:36<572:35:13,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1224/4699119 [07:36<519:55:41,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1225/4699119 [07:36<553:56:45,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1226/4699119 [07:37<547:00:07,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1227/4699119 [07:37<467:02:28,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 114, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1228/4699119 [07:37<399:33:08,  3.27it/s]

last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1229/4699119 [07:38<394:12:54,  3.31it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1230/4699119 [07:38<467:40:06,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1231/4699119 [07:38<453:31:49,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1232/4699119 [07:39<447:03:23,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1233/4699119 [07:39<447:12:19,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1234/4699119 [07:39<405:07:45,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1235/4699119 [07:40<375:35:58,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1236/4699119 [07:40<382:37:58,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1237/4699119 [07:40<372:44:54,  3.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1238/4699119 [07:41<449:55:18,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1239/4699119 [07:41<404:57:19,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1240/4699119 [07:41<412:58:32,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1241/4699119 [07:41<416:39:29,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1242/4699119 [07:42<480:37:22,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1243/4699119 [07:42<457:08:13,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1244/4699119 [07:43<511:04:50,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1245/4699119 [07:43<476:01:19,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1246/4699119 [07:43<424:50:42,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 98, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1247/4699119 [07:43<364:09:25,  3.58it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1248/4699119 [07:44<444:07:24,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1249/4699119 [07:44<500:01:32,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1250/4699119 [07:45<537:30:01,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1251/4699119 [07:45<492:01:05,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1252/4699119 [07:46<478:14:23,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1253/4699119 [07:46<522:26:15,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1254/4699119 [07:47<553:23:46,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1255/4699119 [07:47<553:30:48,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 173, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1256/4699119 [07:47<472:27:36,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1257/4699119 [07:48<518:57:07,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1258/4699119 [07:48<443:20:46,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1259/4699119 [07:48<458:45:37,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1260/4699119 [07:49<437:55:36,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1261/4699119 [07:49<434:38:39,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1262/4699119 [07:49<421:54:41,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1263/4699119 [07:49<399:42:59,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1264/4699119 [07:50<452:24:51,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1265/4699119 [07:50<416:53:23,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1266/4699119 [07:51<459:48:28,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1267/4699119 [07:51<448:07:15,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1268/4699119 [07:51<406:48:45,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1269/4699119 [07:51<427:54:41,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1270/4699119 [07:52<424:21:45,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1271/4699119 [07:52<486:41:01,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1272/4699119 [07:53<482:59:03,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1273/4699119 [07:53<432:13:01,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1274/4699119 [07:53<491:30:07,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1275/4699119 [07:54<506:44:30,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1276/4699119 [07:54<530:59:22,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1277/4699119 [07:55<564:46:04,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1278/4699119 [07:55<489:01:52,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1279/4699119 [07:55<459:15:33,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1280/4699119 [07:56<510:26:29,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1281/4699119 [07:56<495:56:59,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 81, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1282/4699119 [07:56<410:36:45,  3.18it/s]

last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1284/4699119 [07:57<385:00:13,  3.39it/s]

last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1285/4699119 [07:57<373:52:22,  3.49it/s]

last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1287/4699119 [07:58<378:13:51,  3.45it/s]

last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1288/4699119 [07:58<404:57:17,  3.22it/s]

last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1289/4699119 [07:58<425:19:35,  3.07it/s]

last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1290/4699119 [07:59<467:10:49,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1291/4699119 [07:59<432:06:44,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1292/4699119 [08:00<447:45:21,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1293/4699119 [08:00<405:11:44,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1294/4699119 [08:00<426:40:28,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1295/4699119 [08:01<466:55:14,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1296/4699119 [08:01<467:23:56,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1297/4699119 [08:01<516:27:20,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1298/4699119 [08:02<500:25:33,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1299/4699119 [08:02<538:48:46,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1300/4699119 [08:03<565:26:15,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1301/4699119 [08:03<546:30:28,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1302/4699119 [08:03<524:07:21,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 186, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1303/4699119 [08:04<458:00:30,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1304/4699119 [08:04<497:15:25,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1305/4699119 [08:04<467:34:58,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 140, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1306/4699119 [08:05<404:55:27,  3.22it/s]

last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1307/4699119 [08:05<381:36:19,  3.42it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1308/4699119 [08:05<456:10:56,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1309/4699119 [08:06<456:09:42,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1310/4699119 [08:06<491:22:50,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1311/4699119 [08:07<504:07:57,  2.59it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1312/4699119 [08:07<542:34:29,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1313/4699119 [08:07<550:20:57,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1314/4699119 [08:08<572:35:36,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1315/4699119 [08:08<532:08:24,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1316/4699119 [08:09<563:36:13,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1317/4699119 [08:09<582:42:43,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1318/4699119 [08:10<531:06:13,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1319/4699119 [08:10<539:01:34,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1320/4699119 [08:10<565:12:17,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1321/4699119 [08:11<539:40:31,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1322/4699119 [08:11<565:12:34,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1323/4699119 [08:12<501:17:45,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1324/4699119 [08:12<430:43:14,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1325/4699119 [08:12<489:08:19,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1326/4699119 [08:13<494:39:27,  2.64it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1327/4699119 [08:13<534:36:38,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1328/4699119 [08:14<562:19:08,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 159, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1329/4699119 [08:14<474:00:51,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1330/4699119 [08:14<411:21:13,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 122, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1331/4699119 [08:14<360:50:20,  3.62it/s]

last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1332/4699119 [08:15<369:32:08,  3.53it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1333/4699119 [08:15<446:12:57,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1334/4699119 [08:15<456:26:25,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1335/4699119 [08:16<451:04:03,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1336/4699119 [08:16<407:14:02,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1337/4699119 [08:16<365:26:16,  3.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1338/4699119 [08:16<362:47:42,  3.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1339/4699119 [08:17<360:27:52,  3.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1340/4699119 [08:17<443:04:43,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1341/4699119 [08:17<407:46:11,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1342/4699119 [08:18<460:00:10,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1343/4699119 [08:18<510:07:15,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 155, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1344/4699119 [08:19<436:52:37,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1345/4699119 [08:19<396:17:37,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1346/4699119 [08:19<434:30:17,  3.00it/s]

last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1347/4699119 [08:20<421:32:55,  3.10it/s]

last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1348/4699119 [08:20<432:41:06,  3.02it/s]

last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1350/4699119 [08:20<399:10:33,  3.27it/s]

last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1351/4699119 [08:21<359:05:43,  3.63it/s]

last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1352/4699119 [08:21<398:38:10,  3.27it/s]

last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1353/4699119 [08:21<429:30:59,  3.04it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1354/4699119 [08:22<489:46:05,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1355/4699119 [08:22<468:21:47,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1356/4699119 [08:22<418:33:26,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1357/4699119 [08:23<396:38:23,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1358/4699119 [08:23<465:50:13,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1359/4699119 [08:23<423:35:47,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1360/4699119 [08:24<485:54:58,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1361/4699119 [08:24<473:21:51,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1362/4699119 [08:25<467:26:13,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1363/4699119 [08:25<522:15:40,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1364/4699119 [08:26<553:51:45,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1365/4699119 [08:26<480:41:05,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1366/4699119 [08:26<524:06:19,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1367/4699119 [08:27<554:28:57,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1368/4699119 [08:27<517:56:17,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1369/4699119 [08:28<528:58:38,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1370/4699119 [08:28<502:41:50,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1371/4699119 [08:28<539:43:45,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1372/4699119 [08:29<566:24:15,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1373/4699119 [08:29<504:32:33,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1374/4699119 [08:29<450:00:02,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 157, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1375/4699119 [08:30<395:52:07,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1376/4699119 [08:30<465:11:04,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1377/4699119 [08:30<462:17:47,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1378/4699119 [08:31<511:44:49,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1379/4699119 [08:31<449:26:20,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1380/4699119 [08:32<490:23:27,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 118, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1381/4699119 [08:32<416:09:13,  3.14it/s]

last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1382/4699119 [08:32<395:34:36,  3.30it/s]

last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1383/4699119 [08:32<400:57:20,  3.25it/s]

last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1384/4699119 [08:33<404:00:09,  3.23it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1385/4699119 [08:33<472:10:05,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1386/4699119 [08:33<445:36:56,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1387/4699119 [08:34<413:24:26,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1388/4699119 [08:34<383:04:52,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1389/4699119 [08:34<420:22:45,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1390/4699119 [08:35<482:20:17,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1391/4699119 [08:35<442:38:08,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1392/4699119 [08:36<499:13:45,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1393/4699119 [08:36<500:04:11,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1394/4699119 [08:36<538:05:14,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1395/4699119 [08:37<481:35:03,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1396/4699119 [08:37<528:12:44,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 99, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1398/4699119 [08:38<403:16:01,  3.24it/s]

last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1399/4699119 [08:38<475:27:27,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1400/4699119 [08:39<505:19:19,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1401/4699119 [08:39<479:26:48,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1402/4699119 [08:39<445:53:02,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1403/4699119 [08:40<500:03:07,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1404/4699119 [08:40<490:07:10,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1405/4699119 [08:40<440:36:46,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 127, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1406/4699119 [08:40<382:08:56,  3.41it/s]

last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1407/4699119 [08:41<453:56:48,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1408/4699119 [08:41<508:39:24,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1409/4699119 [08:42<509:40:44,  2.56it/s]

last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1410/4699119 [08:42<461:25:38,  2.83it/s]

last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1411/4699119 [08:42<449:57:59,  2.90it/s]

last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1412/4699119 [08:43<418:12:16,  3.12it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1413/4699119 [08:43<480:22:30,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1414/4699119 [08:43<470:38:25,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1415/4699119 [08:44<437:37:54,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1416/4699119 [08:44<390:01:59,  3.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1417/4699119 [08:44<365:10:54,  3.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1418/4699119 [08:45<442:44:32,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1419/4699119 [08:45<434:22:48,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1420/4699119 [08:45<443:07:45,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1421/4699119 [08:46<481:43:08,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1422/4699119 [08:46<480:48:52,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1423/4699119 [08:47<525:19:59,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1424/4699119 [08:47<555:34:05,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1425/4699119 [08:47<517:21:12,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1426/4699119 [08:48<468:28:47,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1427/4699119 [08:48<448:59:08,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 163, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1428/4699119 [08:48<398:43:22,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1429/4699119 [08:49<416:20:36,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1430/4699119 [08:49<419:35:54,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 114, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1431/4699119 [08:49<365:44:13,  3.57it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1432/4699119 [08:50<444:28:50,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1433/4699119 [08:50<499:02:02,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1434/4699119 [08:50<433:37:57,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1435/4699119 [08:51<405:22:41,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1436/4699119 [08:51<376:16:16,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 186, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1437/4699119 [08:51<354:26:40,  3.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1438/4699119 [08:51<385:38:02,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1439/4699119 [08:52<375:29:09,  3.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1441/4699119 [08:52<313:30:24,  4.16it/s]

last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1443/4699119 [08:53<349:45:58,  3.73it/s]

last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1444/4699119 [08:53<387:00:28,  3.37it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1445/4699119 [08:53<458:24:14,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1446/4699119 [08:54<453:48:54,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1447/4699119 [08:54<444:22:18,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1448/4699119 [08:54<410:05:42,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1449/4699119 [08:55<474:44:07,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1450/4699119 [08:55<518:37:06,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1451/4699119 [08:56<550:22:54,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1452/4699119 [08:56<572:36:05,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1453/4699119 [08:57<521:07:50,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1454/4699119 [08:57<515:45:05,  2.53it/s]

last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1456/4699119 [08:57<419:09:38,  3.11it/s]

last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1457/4699119 [08:58<369:35:56,  3.53it/s]

last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1458/4699119 [08:58<362:32:31,  3.60it/s]

last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1459/4699119 [08:58<419:17:25,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1460/4699119 [08:59<432:19:09,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1461/4699119 [08:59<417:50:47,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1462/4699119 [08:59<481:19:27,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1463/4699119 [09:00<452:51:09,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1464/4699119 [09:00<471:15:08,  2.77it/s]

last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1465/4699119 [09:01<492:59:34,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1466/4699119 [09:01<439:09:32,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1467/4699119 [09:01<442:18:21,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1468/4699119 [09:01<435:35:39,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1469/4699119 [09:02<495:12:50,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1470/4699119 [09:02<534:57:42,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1471/4699119 [09:03<563:06:17,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1472/4699119 [09:03<519:14:21,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1473/4699119 [09:04<549:48:07,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1474/4699119 [09:04<496:37:26,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1475/4699119 [09:05<535:56:02,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1476/4699119 [09:05<511:44:02,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1477/4699119 [09:05<548:21:03,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1478/4699119 [09:06<561:01:18,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1479/4699119 [09:06<548:47:06,  2.38it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1480/4699119 [09:07<572:01:19,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1481/4699119 [09:07<589:30:37,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1482/4699119 [09:08<601:36:41,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1483/4699119 [09:08<611:12:27,  2.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1484/4699119 [09:09<601:03:00,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1485/4699119 [09:09<539:51:23,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1486/4699119 [09:09<567:08:48,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1487/4699119 [09:10<555:20:25,  2.35it/s]

fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1488/4699119 [09:10<578:05:42,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1489/4699119 [09:11<540:16:28,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1490/4699119 [09:11<566:51:24,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1491/4699119 [09:12<584:33:11,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1492/4699119 [09:12<597:13:15,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1493/4699119 [09:12<574:53:39,  2.27it/s]

last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1494/4699119 [09:13<518:16:02,  2.52it/s]

last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1495/4699119 [09:13<552:27:02,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1496/4699119 [09:14<522:46:49,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1497/4699119 [09:14<539:08:45,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1498/4699119 [09:14<544:25:56,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1499/4699119 [09:15<566:28:23,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1500/4699119 [09:15<490:43:21,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1501/4699119 [09:15<434:36:30,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1503/4699119 [09:16<357:08:10,  3.65it/s]

last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1504/4699119 [09:16<380:10:07,  3.43it/s]

last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1505/4699119 [09:16<391:56:53,  3.33it/s]

last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1506/4699119 [09:17<409:37:47,  3.19it/s]

last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1507/4699119 [09:17<442:53:46,  2.95it/s]

last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1508/4699119 [09:18<482:47:33,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1509/4699119 [09:18<421:04:11,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1510/4699119 [09:18<482:49:34,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1511/4699119 [09:19<410:53:38,  3.18it/s]

last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1512/4699119 [09:19<394:18:46,  3.31it/s]

last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1513/4699119 [09:19<377:05:55,  3.46it/s]

last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1514/4699119 [09:19<381:33:55,  3.42it/s]

last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1515/4699119 [09:20<434:36:40,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1516/4699119 [09:20<491:25:53,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1517/4699119 [09:21<532:10:32,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1518/4699119 [09:21<559:58:35,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1519/4699119 [09:22<544:35:08,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1520/4699119 [09:22<554:04:59,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1521/4699119 [09:22<486:08:09,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1522/4699119 [09:23<529:48:58,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1523/4699119 [09:23<462:54:12,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1524/4699119 [09:23<440:07:33,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1525/4699119 [09:24<439:00:45,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1526/4699119 [09:24<443:35:30,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1527/4699119 [09:24<394:49:14,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1528/4699119 [09:25<419:45:17,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1529/4699119 [09:25<410:50:20,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1530/4699119 [09:25<460:00:55,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1531/4699119 [09:26<426:00:36,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1532/4699119 [09:26<487:16:56,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1533/4699119 [09:26<512:27:01,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1534/4699119 [09:27<548:27:42,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1535/4699119 [09:27<543:10:52,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1536/4699119 [09:28<536:07:57,  2.43it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1537/4699119 [09:28<564:26:38,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1538/4699119 [09:29<487:11:19,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1539/4699119 [09:29<476:32:40,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1540/4699119 [09:29<434:41:24,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1541/4699119 [09:29<394:21:48,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1542/4699119 [09:30<464:09:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1543/4699119 [09:30<458:02:42,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1544/4699119 [09:30<411:01:21,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1545/4699119 [09:31<479:08:51,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1546/4699119 [09:31<479:06:10,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1547/4699119 [09:32<450:10:17,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1548/4699119 [09:32<502:35:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1549/4699119 [09:33<541:21:55,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1550/4699119 [09:33<498:00:36,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1551/4699119 [09:33<536:53:39,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1552/4699119 [09:34<505:22:42,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1553/4699119 [09:34<525:13:03,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1554/4699119 [09:35<544:09:39,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1555/4699119 [09:35<530:38:39,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1556/4699119 [09:35<497:56:14,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1557/4699119 [09:36<519:34:08,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1558/4699119 [09:36<501:40:21,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1559/4699119 [09:36<509:54:34,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1560/4699119 [09:37<433:19:30,  3.01it/s]

last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1561/4699119 [09:37<446:40:01,  2.92it/s]

last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1562/4699119 [09:37<443:12:19,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1563/4699119 [09:38<498:16:09,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1564/4699119 [09:38<537:29:58,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1565/4699119 [09:39<564:54:30,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1566/4699119 [09:39<532:12:51,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1567/4699119 [09:39<464:13:27,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1568/4699119 [09:40<514:33:50,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1569/4699119 [09:40<522:27:27,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1570/4699119 [09:41<527:48:07,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1571/4699119 [09:41<535:07:40,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1572/4699119 [09:41<540:58:16,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1573/4699119 [09:42<474:14:21,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1574/4699119 [09:42<426:05:04,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1575/4699119 [09:42<425:58:33,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1576/4699119 [09:43<485:22:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1577/4699119 [09:43<527:52:16,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1578/4699119 [09:44<542:14:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1579/4699119 [09:44<517:08:12,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1580/4699119 [09:44<488:30:50,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1581/4699119 [09:45<515:54:30,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1582/4699119 [09:45<508:11:00,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1583/4699119 [09:45<464:27:05,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1584/4699119 [09:46<512:56:57,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1585/4699119 [09:46<457:08:31,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1586/4699119 [09:47<447:30:17,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1587/4699119 [09:47<501:02:48,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1588/4699119 [09:47<539:21:20,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1589/4699119 [09:48<565:20:39,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1590/4699119 [09:48<583:48:53,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1591/4699119 [09:49<597:38:03,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1592/4699119 [09:49<548:57:27,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1593/4699119 [09:50<484:05:10,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1594/4699119 [09:50<530:50:48,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1595/4699119 [09:50<534:01:06,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 173, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1596/4699119 [09:51<458:41:46,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1597/4699119 [09:51<509:42:23,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1598/4699119 [09:51<467:46:31,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1599/4699119 [09:52<412:45:49,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1600/4699119 [09:52<392:24:44,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1601/4699119 [09:52<463:18:55,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1602/4699119 [09:53<444:11:21,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1603/4699119 [09:53<405:51:26,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1604/4699119 [09:53<381:29:08,  3.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1605/4699119 [09:54<455:48:34,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1606/4699119 [09:54<454:44:19,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1607/4699119 [09:54<453:01:36,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1608/4699119 [09:55<506:46:33,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1609/4699119 [09:55<542:40:33,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1610/4699119 [09:56<479:17:14,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1611/4699119 [09:56<528:48:23,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1612/4699119 [09:57<558:17:49,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1613/4699119 [09:57<483:44:41,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1614/4699119 [09:57<469:28:33,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1615/4699119 [09:57<422:35:17,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1616/4699119 [09:58<424:30:58,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1617/4699119 [09:58<404:28:45,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1618/4699119 [09:58<407:25:53,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1619/4699119 [09:59<473:36:57,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1620/4699119 [09:59<404:45:39,  3.22it/s]

last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1621/4699119 [09:59<473:37:25,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1622/4699119 [10:00<474:35:15,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1623/4699119 [10:00<474:41:35,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1624/4699119 [10:01<505:08:14,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1625/4699119 [10:01<446:56:11,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1626/4699119 [10:01<501:52:40,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1627/4699119 [10:02<540:29:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1628/4699119 [10:02<567:04:28,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1629/4699119 [10:03<515:53:35,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1630/4699119 [10:03<479:06:37,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1631/4699119 [10:03<524:50:23,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1632/4699119 [10:04<487:36:14,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1633/4699119 [10:04<426:04:26,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1634/4699119 [10:04<486:51:09,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1635/4699119 [10:05<478:41:08,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1636/4699119 [10:05<524:02:28,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1637/4699119 [10:06<527:59:15,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1638/4699119 [10:06<510:05:55,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1639/4699119 [10:06<520:04:40,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1640/4699119 [10:07<506:03:43,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1641/4699119 [10:07<543:50:12,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1642/4699119 [10:08<510:50:15,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1643/4699119 [10:08<544:50:22,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1644/4699119 [10:08<523:07:09,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1645/4699119 [10:09<488:22:49,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1646/4699119 [10:09<530:58:46,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1647/4699119 [10:10<559:09:35,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1648/4699119 [10:10<579:14:41,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1649/4699119 [10:10<523:34:21,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1650/4699119 [10:11<450:49:18,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 436, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1651/4699119 [10:11<483:01:21,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1652/4699119 [10:11<432:02:21,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1653/4699119 [10:12<419:40:12,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1654/4699119 [10:12<481:14:40,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1655/4699119 [10:13<490:33:39,  2.66it/s]

last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1656/4699119 [10:13<506:00:24,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1657/4699119 [10:13<542:12:24,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1658/4699119 [10:14<568:59:07,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1659/4699119 [10:14<492:53:09,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1660/4699119 [10:14<439:29:01,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1661/4699119 [10:15<495:40:26,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1662/4699119 [10:15<534:50:56,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1663/4699119 [10:16<509:26:37,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1664/4699119 [10:16<441:04:06,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1665/4699119 [10:16<431:56:11,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1666/4699119 [10:17<490:59:46,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1667/4699119 [10:17<531:35:15,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1668/4699119 [10:18<559:48:38,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1669/4699119 [10:18<507:38:19,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1670/4699119 [10:18<492:30:54,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 94, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1671/4699119 [10:18<409:24:56,  3.19it/s]

last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1672/4699119 [10:19<474:34:40,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1673/4699119 [10:19<408:16:39,  3.20it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1674/4699119 [10:20<474:23:50,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1675/4699119 [10:20<469:53:54,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1676/4699119 [10:20<517:32:23,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1677/4699119 [10:21<460:05:57,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1678/4699119 [10:21<432:07:05,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1679/4699119 [10:21<423:06:52,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1680/4699119 [10:22<389:04:55,  3.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1681/4699119 [10:22<459:57:09,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1682/4699119 [10:22<497:41:54,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1683/4699119 [10:23<536:26:15,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1684/4699119 [10:23<471:04:30,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1685/4699119 [10:24<463:41:47,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1686/4699119 [10:24<480:31:36,  2.72it/s]

last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1687/4699119 [10:24<483:29:49,  2.70it/s]

last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1688/4699119 [10:25<489:05:25,  2.67it/s]

last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1689/4699119 [10:25<468:35:46,  2.78it/s]

last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1690/4699119 [10:25<470:06:21,  2.78it/s]

last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1691/4699119 [10:26<521:45:08,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1692/4699119 [10:26<515:49:49,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1693/4699119 [10:27<493:31:47,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1694/4699119 [10:27<534:11:11,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1695/4699119 [10:27<478:38:40,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1696/4699119 [10:28<503:22:02,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1697/4699119 [10:28<508:40:37,  2.57it/s]

last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1698/4699119 [10:28<474:54:33,  2.75it/s]

last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1699/4699119 [10:29<485:30:36,  2.69it/s]

last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1700/4699119 [10:29<477:35:54,  2.73it/s]

last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1701/4699119 [10:30<473:23:06,  2.76it/s]

last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1702/4699119 [10:30<437:54:07,  2.98it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1703/4699119 [10:30<495:32:22,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1704/4699119 [10:31<498:17:26,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 484, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1705/4699119 [10:31<532:52:47,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1706/4699119 [10:32<561:51:40,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1707/4699119 [10:32<590:07:30,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1708/4699119 [10:33<614:29:53,  2.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1709/4699119 [10:33<619:07:41,  2.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1710/4699119 [10:33<533:30:23,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1711/4699119 [10:34<478:37:51,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1712/4699119 [10:34<466:37:07,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1713/4699119 [10:34<430:09:15,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1714/4699119 [10:35<392:09:30,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1715/4699119 [10:35<373:34:57,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1716/4699119 [10:35<423:20:50,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1717/4699119 [10:36<489:27:50,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1718/4699119 [10:36<479:00:04,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1719/4699119 [10:37<523:39:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1720/4699119 [10:37<509:20:53,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1721/4699119 [10:37<545:37:41,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1722/4699119 [10:38<545:41:05,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1723/4699119 [10:38<500:18:42,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1724/4699119 [10:39<540:09:23,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1725/4699119 [10:39<488:16:56,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1726/4699119 [10:39<530:09:41,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1727/4699119 [10:40<561:15:28,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1728/4699119 [10:40<549:27:14,  2.37it/s]

last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1729/4699119 [10:41<510:57:01,  2.55it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1730/4699119 [10:41<545:26:18,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1731/4699119 [10:41<490:07:43,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1732/4699119 [10:42<524:41:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1733/4699119 [10:42<546:46:49,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1734/4699119 [10:42<474:10:08,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1735/4699119 [10:43<488:41:22,  2.67it/s]

last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1736/4699119 [10:43<507:00:47,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1737/4699119 [10:44<456:56:29,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1738/4699119 [10:44<509:56:10,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1739/4699119 [10:44<472:16:55,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1740/4699119 [10:45<518:03:45,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1741/4699119 [10:45<534:01:52,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1742/4699119 [10:46<479:04:08,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1743/4699119 [10:46<525:33:26,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1744/4699119 [10:46<504:59:18,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1745/4699119 [10:47<541:29:21,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1746/4699119 [10:47<483:56:29,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1747/4699119 [10:48<508:59:11,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1748/4699119 [10:48<544:39:18,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1749/4699119 [10:48<481:03:16,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1750/4699119 [10:49<456:19:37,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 473, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1751/4699119 [10:49<499:38:03,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1752/4699119 [10:50<537:33:44,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1753/4699119 [10:50<542:33:02,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1754/4699119 [10:50<572:15:17,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1755/4699119 [10:51<534:13:10,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1756/4699119 [10:51<561:50:33,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1757/4699119 [10:52<494:23:38,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1758/4699119 [10:52<502:28:04,  2.60it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1759/4699119 [10:52<541:19:00,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1760/4699119 [10:53<568:55:40,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1761/4699119 [10:53<539:10:32,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1762/4699119 [10:54<554:41:12,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1763/4699119 [10:54<508:36:15,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1764/4699119 [10:54<449:52:54,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1765/4699119 [10:55<440:42:50,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1766/4699119 [10:55<438:16:36,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1767/4699119 [10:55<411:36:13,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1768/4699119 [10:56<436:51:01,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1769/4699119 [10:56<494:01:57,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1770/4699119 [10:56<470:59:53,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1771/4699119 [10:57<517:31:59,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1772/4699119 [10:57<550:46:33,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1773/4699119 [10:58<537:12:22,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1774/4699119 [10:58<527:12:27,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1775/4699119 [10:58<448:19:32,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1776/4699119 [10:59<506:05:45,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1777/4699119 [10:59<439:26:50,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1778/4699119 [10:59<397:30:41,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1779/4699119 [11:00<468:26:54,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1780/4699119 [11:00<469:10:51,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1781/4699119 [11:00<464:58:48,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1782/4699119 [11:01<469:30:08,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1783/4699119 [11:01<516:03:29,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1784/4699119 [11:02<549:06:14,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1785/4699119 [11:02<573:57:55,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1786/4699119 [11:03<574:25:33,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1787/4699119 [11:03<593:37:33,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1788/4699119 [11:04<606:38:29,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1789/4699119 [11:04<554:26:34,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1790/4699119 [11:04<499:50:08,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1791/4699119 [11:04<441:58:36,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1792/4699119 [11:05<434:47:50,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1793/4699119 [11:05<477:18:54,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1794/4699119 [11:06<514:53:43,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1795/4699119 [11:06<486:24:59,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1796/4699119 [11:06<463:29:13,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1797/4699119 [11:07<408:25:35,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1798/4699119 [11:07<448:00:22,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1799/4699119 [11:07<502:59:58,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1800/4699119 [11:08<499:43:31,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1801/4699119 [11:08<538:20:31,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1802/4699119 [11:09<565:36:36,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1803/4699119 [11:09<562:33:16,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1804/4699119 [11:10<582:01:10,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 131, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1805/4699119 [11:10<484:12:12,  2.69it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1806/4699119 [11:10<528:49:38,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1807/4699119 [11:11<473:28:34,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1808/4699119 [11:11<423:53:21,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1809/4699119 [11:11<472:19:23,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1810/4699119 [11:12<471:31:34,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1811/4699119 [11:12<455:31:07,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1812/4699119 [11:13<507:02:19,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1813/4699119 [11:13<472:54:03,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1814/4699119 [11:13<423:41:46,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1815/4699119 [11:14<485:47:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1816/4699119 [11:14<529:40:58,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1817/4699119 [11:14<558:50:52,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1818/4699119 [11:15<518:55:59,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1819/4699119 [11:15<527:19:34,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1820/4699119 [11:16<557:41:58,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1821/4699119 [11:16<578:40:37,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1822/4699119 [11:17<543:39:33,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1823/4699119 [11:17<485:10:03,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1824/4699119 [11:17<523:37:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1825/4699119 [11:18<490:21:03,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1826/4699119 [11:18<475:00:11,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1827/4699119 [11:18<452:07:11,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1828/4699119 [11:19<503:24:25,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1829/4699119 [11:19<450:14:00,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1830/4699119 [11:19<477:53:55,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1831/4699119 [11:20<473:42:11,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1832/4699119 [11:20<519:11:55,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1833/4699119 [11:21<504:44:55,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1834/4699119 [11:21<541:06:35,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1835/4699119 [11:21<502:02:44,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1836/4699119 [11:22<444:36:32,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 108, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1837/4699119 [11:22<379:02:51,  3.44it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1838/4699119 [11:22<453:33:45,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1839/4699119 [11:23<505:51:32,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1840/4699119 [11:23<478:03:01,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1841/4699119 [11:24<522:31:34,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1842/4699119 [11:24<554:06:57,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1843/4699119 [11:25<575:46:15,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1844/4699119 [11:25<499:45:44,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1845/4699119 [11:25<527:26:08,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1846/4699119 [11:26<503:25:32,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1847/4699119 [11:26<478:49:41,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1848/4699119 [11:26<523:47:42,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1849/4699119 [11:27<518:18:32,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1850/4699119 [11:27<460:35:36,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1851/4699119 [11:27<406:41:00,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1852/4699119 [11:28<461:27:29,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1853/4699119 [11:28<511:44:58,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1854/4699119 [11:28<485:14:43,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1855/4699119 [11:29<529:03:11,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1856/4699119 [11:29<548:58:54,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1857/4699119 [11:30<564:01:57,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1858/4699119 [11:30<569:22:46,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1859/4699119 [11:31<522:56:09,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1860/4699119 [11:31<466:33:38,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1861/4699119 [11:31<514:35:13,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1862/4699119 [11:32<451:20:15,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1863/4699119 [11:32<504:33:50,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 123, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1865/4699119 [11:33<393:11:26,  3.32it/s]

last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1866/4699119 [11:33<392:21:12,  3.33it/s]

last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1867/4699119 [11:33<408:12:23,  3.20it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1868/4699119 [11:34<473:42:18,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1869/4699119 [11:34<521:11:02,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1870/4699119 [11:34<471:25:54,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1871/4699119 [11:35<492:47:37,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1872/4699119 [11:35<496:01:15,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1873/4699119 [11:36<496:04:54,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1874/4699119 [11:36<426:38:53,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1875/4699119 [11:36<415:21:24,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1876/4699119 [11:36<432:38:25,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1877/4699119 [11:37<491:27:08,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1878/4699119 [11:37<514:41:40,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1879/4699119 [11:38<450:49:15,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1880/4699119 [11:38<481:32:38,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1881/4699119 [11:38<526:25:22,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1882/4699119 [11:39<553:19:12,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1883/4699119 [11:39<577:35:00,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1884/4699119 [11:40<488:37:18,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1885/4699119 [11:40<512:36:35,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1886/4699119 [11:41<546:49:12,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1887/4699119 [11:41<548:02:47,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1888/4699119 [11:41<480:38:39,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1889/4699119 [11:42<502:38:54,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1890/4699119 [11:42<540:09:25,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1891/4699119 [11:42<485:40:34,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1892/4699119 [11:43<528:28:09,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1893/4699119 [11:43<558:00:53,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1894/4699119 [11:44<470:07:01,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1895/4699119 [11:44<516:37:41,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1896/4699119 [11:45<551:13:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1897/4699119 [11:45<532:52:19,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1898/4699119 [11:45<548:05:16,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1899/4699119 [11:46<541:17:25,  2.41it/s]

last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1900/4699119 [11:46<571:02:56,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1901/4699119 [11:47<543:25:47,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1902/4699119 [11:47<498:07:09,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1903/4699119 [11:47<456:41:26,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1904/4699119 [11:47<425:52:20,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1905/4699119 [11:48<417:01:10,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1906/4699119 [11:48<461:31:29,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1907/4699119 [11:49<510:45:23,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1908/4699119 [11:49<448:28:40,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1909/4699119 [11:49<502:15:35,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1910/4699119 [11:50<540:59:22,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1911/4699119 [11:50<519:55:25,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1912/4699119 [11:50<449:05:35,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1913/4699119 [11:51<501:49:06,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1914/4699119 [11:51<430:50:58,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1915/4699119 [11:52<462:10:16,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1916/4699119 [11:52<427:13:59,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1917/4699119 [11:52<438:50:22,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1918/4699119 [11:53<449:29:09,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1919/4699119 [11:53<502:17:01,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1920/4699119 [11:53<431:25:12,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1921/4699119 [11:53<392:40:14,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1922/4699119 [11:54<463:00:46,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1923/4699119 [11:54<515:47:01,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1924/4699119 [11:55<459:24:18,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1925/4699119 [11:55<472:49:28,  2.76it/s]

last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1926/4699119 [11:55<436:13:32,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1927/4699119 [11:56<492:57:00,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1928/4699119 [11:56<533:16:26,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1929/4699119 [11:57<534:00:38,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1930/4699119 [11:57<561:10:02,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1931/4699119 [11:58<555:42:38,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1932/4699119 [11:58<534:41:27,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1933/4699119 [11:58<489:54:30,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1934/4699119 [11:59<447:53:27,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1935/4699119 [11:59<430:54:19,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1936/4699119 [11:59<454:41:39,  2.87it/s]

last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1937/4699119 [12:00<436:24:22,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1938/4699119 [12:00<493:06:48,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1939/4699119 [12:00<443:28:37,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1940/4699119 [12:01<457:13:18,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1941/4699119 [12:01<507:21:56,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1942/4699119 [12:01<433:02:02,  3.01it/s]

last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1943/4699119 [12:02<437:05:47,  2.99it/s]

last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1944/4699119 [12:02<472:31:07,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1945/4699119 [12:03<519:51:23,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1946/4699119 [12:03<489:10:22,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1947/4699119 [12:03<538:48:54,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1948/4699119 [12:04<565:21:24,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1949/4699119 [12:04<496:46:52,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 105, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1950/4699119 [12:04<415:26:37,  3.14it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1951/4699119 [12:05<480:32:30,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1952/4699119 [12:05<416:04:44,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1953/4699119 [12:05<429:38:06,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1954/4699119 [12:06<430:24:26,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 487, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1955/4699119 [12:06<491:03:17,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1956/4699119 [12:07<531:56:31,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1957/4699119 [12:07<527:31:50,  2.47it/s]

last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1958/4699119 [12:08<558:23:48,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1959/4699119 [12:08<552:39:15,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1960/4699119 [12:08<537:06:37,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1961/4699119 [12:09<564:42:36,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1962/4699119 [12:09<587:51:29,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1963/4699119 [12:10<579:16:33,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1964/4699119 [12:10<575:30:16,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1965/4699119 [12:10<523:20:14,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1966/4699119 [12:11<462:45:22,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1967/4699119 [12:11<511:59:55,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1968/4699119 [12:11<449:18:22,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1969/4699119 [12:12<478:36:22,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1970/4699119 [12:12<523:05:12,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1971/4699119 [12:13<478:07:41,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1972/4699119 [12:13<464:48:53,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1973/4699119 [12:13<474:33:41,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1974/4699119 [12:14<525:29:35,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1975/4699119 [12:14<506:39:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1976/4699119 [12:14<463:12:26,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1977/4699119 [12:15<469:23:47,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1978/4699119 [12:15<516:04:02,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1979/4699119 [12:16<472:30:06,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1980/4699119 [12:16<520:41:40,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1981/4699119 [12:16<472:52:26,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1982/4699119 [12:17<519:19:08,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1983/4699119 [12:17<497:33:26,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1984/4699119 [12:18<479:19:06,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1985/4699119 [12:18<450:39:15,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 1986/4699119 [12:18<472:51:00,  2.76it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1987/4699119 [12:19<523:10:20,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1988/4699119 [12:19<472:58:46,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1989/4699119 [12:19<505:22:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1990/4699119 [12:20<444:46:48,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1991/4699119 [12:20<488:19:04,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1992/4699119 [12:20<484:18:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1993/4699119 [12:21<452:41:16,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1994/4699119 [12:21<401:28:02,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1995/4699119 [12:21<469:57:32,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1996/4699119 [12:22<518:12:16,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 146, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1997/4699119 [12:22<442:26:36,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1998/4699119 [12:22<445:57:35,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 1999/4699119 [12:23<499:50:53,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2000/4699119 [12:23<522:47:22,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2001/4699119 [12:24<553:44:36,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2002/4699119 [12:24<581:09:44,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2003/4699119 [12:25<594:53:38,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2004/4699119 [12:25<604:44:59,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2005/4699119 [12:26<579:39:43,  2.25it/s]

last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2006/4699119 [12:26<531:29:52,  2.45it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2007/4699119 [12:27<561:17:39,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2008/4699119 [12:27<581:09:23,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2009/4699119 [12:27<517:58:33,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2010/4699119 [12:28<551:57:25,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2011/4699119 [12:28<574:52:49,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2012/4699119 [12:29<496:41:03,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2013/4699119 [12:29<486:36:50,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2014/4699119 [12:29<528:30:14,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2015/4699119 [12:30<531:48:34,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2016/4699119 [12:30<478:22:19,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2017/4699119 [12:31<522:21:06,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2018/4699119 [12:31<492:56:41,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2019/4699119 [12:31<521:38:13,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2020/4699119 [12:32<554:14:30,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2021/4699119 [12:32<479:27:07,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2022/4699119 [12:32<523:09:21,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2023/4699119 [12:33<481:12:42,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2024/4699119 [12:33<524:37:00,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2025/4699119 [12:34<528:43:42,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2026/4699119 [12:34<471:09:02,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2027/4699119 [12:34<478:04:25,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2028/4699119 [12:35<524:22:19,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2029/4699119 [12:35<544:10:30,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2030/4699119 [12:36<510:43:30,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2031/4699119 [12:36<544:55:13,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2032/4699119 [12:36<511:28:52,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2033/4699119 [12:37<478:14:34,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2034/4699119 [12:37<430:16:49,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2035/4699119 [12:37<395:57:21,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2036/4699119 [12:37<350:46:19,  3.72it/s]

last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2037/4699119 [12:38<363:15:34,  3.59it/s]

last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2038/4699119 [12:38<398:36:05,  3.27it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2039/4699119 [12:39<466:43:41,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2040/4699119 [12:39<519:59:58,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2041/4699119 [12:40<554:39:13,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2042/4699119 [12:40<513:28:20,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2043/4699119 [12:40<546:59:10,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2044/4699119 [12:41<546:35:17,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2045/4699119 [12:41<503:55:00,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2046/4699119 [12:41<477:26:59,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2047/4699119 [12:42<521:20:09,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2048/4699119 [12:42<480:17:37,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2049/4699119 [12:43<494:30:04,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2050/4699119 [12:43<437:01:09,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2051/4699119 [12:43<411:30:32,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2052/4699119 [12:43<407:41:12,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2053/4699119 [12:44<474:54:28,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2054/4699119 [12:44<512:41:36,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2055/4699119 [12:45<465:42:46,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2056/4699119 [12:45<507:24:34,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2057/4699119 [12:45<527:05:11,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2058/4699119 [12:46<523:06:24,  2.49it/s]

last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2059/4699119 [12:46<484:21:42,  2.69it/s]

last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2060/4699119 [12:47<513:13:57,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2061/4699119 [12:47<517:58:09,  2.52it/s]

last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2062/4699119 [12:47<499:25:25,  2.61it/s]

last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2063/4699119 [12:48<467:18:48,  2.79it/s]

last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2064/4699119 [12:48<477:11:20,  2.73it/s]

last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2065/4699119 [12:48<489:12:05,  2.67it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2066/4699119 [12:49<530:29:32,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2067/4699119 [12:49<536:32:13,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2068/4699119 [12:50<567:17:53,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2069/4699119 [12:50<569:42:44,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2070/4699119 [12:51<492:45:18,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2071/4699119 [12:51<424:58:39,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2072/4699119 [12:51<470:47:40,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2073/4699119 [12:52<517:25:25,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2074/4699119 [12:52<500:06:19,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2075/4699119 [12:52<441:56:05,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2076/4699119 [12:52<404:46:14,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2077/4699119 [12:53<360:08:15,  3.62it/s]

last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2078/4699119 [12:53<356:45:35,  3.66it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2079/4699119 [12:53<437:20:18,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2080/4699119 [12:54<439:21:17,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2081/4699119 [12:54<440:01:26,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2082/4699119 [12:55<497:20:55,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2083/4699119 [12:55<463:39:57,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2084/4699119 [12:55<423:36:37,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2085/4699119 [12:56<484:03:23,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2086/4699119 [12:56<516:00:19,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2087/4699119 [12:56<534:06:57,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2088/4699119 [12:57<484:27:37,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2089/4699119 [12:57<464:27:33,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2090/4699119 [12:57<475:10:30,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2091/4699119 [12:58<488:47:09,  2.67it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2092/4699119 [12:58<530:49:37,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2093/4699119 [12:59<560:27:26,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2094/4699119 [12:59<497:41:55,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2095/4699119 [12:59<476:24:25,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2096/4699119 [13:00<500:58:46,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2097/4699119 [13:00<535:23:38,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2098/4699119 [13:01<467:45:03,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2099/4699119 [13:01<403:18:16,  3.24it/s]

last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2101/4699119 [13:01<357:13:32,  3.65it/s]

last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2102/4699119 [13:02<351:03:56,  3.72it/s]

last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2103/4699119 [13:02<424:52:41,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2104/4699119 [13:02<433:30:05,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2105/4699119 [13:03<424:21:41,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2106/4699119 [13:03<435:17:40,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2107/4699119 [13:03<404:48:29,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2108/4699119 [13:04<403:51:33,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2109/4699119 [13:04<422:09:45,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2110/4699119 [13:04<483:44:21,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2111/4699119 [13:05<514:57:40,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2112/4699119 [13:05<550:35:46,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2113/4699119 [13:06<475:24:41,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2114/4699119 [13:06<448:12:57,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2115/4699119 [13:06<389:44:04,  3.35it/s]

last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2116/4699119 [13:06<443:57:10,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2117/4699119 [13:07<400:54:46,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2118/4699119 [13:07<469:06:53,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2119/4699119 [13:08<485:20:26,  2.69it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2120/4699119 [13:08<527:28:45,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 144, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2121/4699119 [13:08<447:02:04,  2.92it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2122/4699119 [13:09<501:34:02,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2123/4699119 [13:09<443:08:00,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2124/4699119 [13:09<498:03:27,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2125/4699119 [13:10<482:28:49,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2126/4699119 [13:10<526:56:45,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2127/4699119 [13:11<559:03:39,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2128/4699119 [13:11<570:50:00,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2129/4699119 [13:12<565:15:40,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2130/4699119 [13:12<524:14:20,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 484, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2131/4699119 [13:12<551:15:28,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2132/4699119 [13:13<552:24:21,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2133/4699119 [13:13<574:24:01,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2134/4699119 [13:14<482:29:49,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2135/4699119 [13:14<473:38:45,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2136/4699119 [13:14<493:45:22,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2137/4699119 [13:15<469:06:52,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2138/4699119 [13:15<408:14:07,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2139/4699119 [13:15<474:19:13,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2140/4699119 [13:16<488:32:36,  2.67it/s]

last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2141/4699119 [13:16<536:18:09,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2142/4699119 [13:17<512:35:19,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2143/4699119 [13:17<547:06:23,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2144/4699119 [13:17<502:22:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2145/4699119 [13:18<545:43:48,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2146/4699119 [13:18<498:48:40,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2147/4699119 [13:18<467:31:45,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2148/4699119 [13:19<503:43:52,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2149/4699119 [13:19<541:27:47,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2150/4699119 [13:20<549:52:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2151/4699119 [13:20<522:08:52,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2152/4699119 [13:21<556:32:15,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2153/4699119 [13:21<578:43:31,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2154/4699119 [13:22<594:17:27,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2155/4699119 [13:22<517:37:29,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2156/4699119 [13:22<550:29:24,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2157/4699119 [13:23<477:44:02,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2158/4699119 [13:23<459:29:20,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2159/4699119 [13:23<433:17:31,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2160/4699119 [13:24<491:14:32,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2161/4699119 [13:24<470:10:11,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2162/4699119 [13:25<518:42:16,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2163/4699119 [13:25<547:04:43,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2164/4699119 [13:25<489:06:28,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2165/4699119 [13:25<436:21:23,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2166/4699119 [13:26<410:26:29,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2167/4699119 [13:26<479:14:34,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2168/4699119 [13:27<524:52:49,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2169/4699119 [13:27<468:02:57,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2170/4699119 [13:27<421:24:27,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2171/4699119 [13:28<484:23:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2172/4699119 [13:28<528:03:00,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2173/4699119 [13:29<500:08:28,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2174/4699119 [13:29<525:13:52,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2175/4699119 [13:29<552:31:32,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2176/4699119 [13:30<575:09:07,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2177/4699119 [13:30<517:23:00,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2178/4699119 [13:31<482:06:44,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2179/4699119 [13:31<463:05:41,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2180/4699119 [13:31<511:44:55,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2181/4699119 [13:32<509:45:19,  2.56it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2182/4699119 [13:32<545:58:58,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2183/4699119 [13:33<570:39:13,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2184/4699119 [13:33<551:06:46,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2185/4699119 [13:34<574:09:42,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2186/4699119 [13:34<545:28:49,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 126, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2187/4699119 [13:34<455:45:36,  2.86it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2188/4699119 [13:35<506:48:32,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2189/4699119 [13:35<516:15:56,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2190/4699119 [13:35<480:24:32,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2191/4699119 [13:36<523:38:45,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2192/4699119 [13:36<556:09:27,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2193/4699119 [13:37<559:35:43,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2194/4699119 [13:37<471:04:53,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2195/4699119 [13:37<474:36:37,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2196/4699119 [13:38<520:34:21,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2197/4699119 [13:38<459:12:21,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2198/4699119 [13:38<425:28:12,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2199/4699119 [13:39<485:30:29,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2200/4699119 [13:39<490:05:45,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2201/4699119 [13:39<481:52:20,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2202/4699119 [13:40<428:43:19,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2203/4699119 [13:40<475:47:04,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2204/4699119 [13:41<521:07:18,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2205/4699119 [13:41<517:51:25,  2.52it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2206/4699119 [13:42<550:17:45,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2207/4699119 [13:42<573:19:49,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2208/4699119 [13:42<575:29:47,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2209/4699119 [13:43<517:13:16,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2210/4699119 [13:43<500:36:48,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2211/4699119 [13:44<515:35:14,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2212/4699119 [13:44<549:05:35,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2213/4699119 [13:44<574:52:39,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2214/4699119 [13:45<521:37:05,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2215/4699119 [13:45<508:05:13,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2216/4699119 [13:46<511:01:20,  2.55it/s]

last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2217/4699119 [13:46<473:41:24,  2.75it/s]

last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2218/4699119 [13:46<438:37:15,  2.97it/s]

last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2219/4699119 [13:46<442:49:20,  2.95it/s]

last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2221/4699119 [13:47<404:42:03,  3.22it/s]

last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2222/4699119 [13:47<355:48:32,  3.67it/s]

last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 115, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2223/4699119 [13:47<321:31:36,  4.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2224/4699119 [13:48<351:21:39,  3.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2225/4699119 [13:48<432:39:25,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2226/4699119 [13:48<407:26:53,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2227/4699119 [13:49<418:11:06,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2228/4699119 [13:49<437:54:37,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2229/4699119 [13:50<496:15:49,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2230/4699119 [13:50<483:58:46,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2231/4699119 [13:51<528:30:02,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2232/4699119 [13:51<558:33:12,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2233/4699119 [13:51<580:00:14,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2234/4699119 [13:52<595:36:43,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2235/4699119 [13:52<615:51:41,  2.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2236/4699119 [13:53<621:14:55,  2.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2237/4699119 [13:53<605:57:49,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2238/4699119 [13:54<578:40:22,  2.25it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2239/4699119 [13:54<593:11:46,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2240/4699119 [13:55<549:13:33,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2241/4699119 [13:55<474:51:16,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2242/4699119 [13:55<522:00:13,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2243/4699119 [13:56<526:23:35,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2244/4699119 [13:56<557:09:47,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2245/4699119 [13:56<483:45:48,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2246/4699119 [13:57<475:12:41,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2247/4699119 [13:57<480:51:24,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2248/4699119 [13:58<470:32:57,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2249/4699119 [13:58<518:38:22,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2250/4699119 [13:59<551:06:15,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2251/4699119 [13:59<574:41:14,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2252/4699119 [13:59<525:55:01,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2253/4699119 [14:00<525:02:22,  2.48it/s]

last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2254/4699119 [14:00<555:00:40,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2255/4699119 [14:01<576:04:20,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2256/4699119 [14:01<566:19:17,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2257/4699119 [14:01<489:59:02,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2258/4699119 [14:02<529:51:44,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2259/4699119 [14:02<560:30:26,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2260/4699119 [14:03<577:55:54,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2261/4699119 [14:03<555:34:47,  2.35it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2262/4699119 [14:04<576:55:44,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2263/4699119 [14:04<498:37:00,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2264/4699119 [14:04<484:40:42,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2265/4699119 [14:05<494:16:29,  2.64it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2266/4699119 [14:05<534:51:12,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2267/4699119 [14:06<562:53:51,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2268/4699119 [14:06<510:45:45,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2269/4699119 [14:06<505:22:26,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2270/4699119 [14:07<473:19:59,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2271/4699119 [14:07<519:45:29,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2272/4699119 [14:07<458:35:19,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2273/4699119 [14:08<508:25:51,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2274/4699119 [14:08<553:24:53,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2275/4699119 [14:08<479:18:05,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2276/4699119 [14:09<523:08:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2277/4699119 [14:09<554:15:54,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2278/4699119 [14:10<528:28:13,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2279/4699119 [14:10<557:50:13,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2280/4699119 [14:11<578:14:36,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2281/4699119 [14:11<593:12:59,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2282/4699119 [14:12<604:15:06,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2283/4699119 [14:12<612:33:10,  2.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2284/4699119 [14:13<578:09:04,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2285/4699119 [14:13<580:53:25,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2286/4699119 [14:14<593:54:00,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2287/4699119 [14:14<582:21:29,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2288/4699119 [14:14<555:07:36,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2289/4699119 [14:15<580:18:57,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2290/4699119 [14:15<594:14:34,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2291/4699119 [14:16<549:42:53,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2292/4699119 [14:16<502:51:30,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2293/4699119 [14:16<457:42:44,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2294/4699119 [14:16<410:20:18,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2295/4699119 [14:17<475:59:35,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2296/4699119 [14:17<435:39:12,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2297/4699119 [14:17<416:08:12,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2298/4699119 [14:18<392:04:41,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2299/4699119 [14:18<462:07:30,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2300/4699119 [14:18<424:31:33,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2301/4699119 [14:19<486:12:15,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2302/4699119 [14:19<446:39:00,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2303/4699119 [14:20<488:27:11,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2304/4699119 [14:20<472:21:34,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2305/4699119 [14:20<422:57:19,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2306/4699119 [14:21<454:49:13,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2307/4699119 [14:21<505:19:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2308/4699119 [14:21<464:34:35,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2309/4699119 [14:22<457:50:08,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2310/4699119 [14:22<436:40:14,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2311/4699119 [14:23<494:17:44,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2312/4699119 [14:23<535:06:25,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2313/4699119 [14:23<545:04:35,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2314/4699119 [14:24<511:02:49,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2315/4699119 [14:24<453:02:14,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2316/4699119 [14:24<465:16:47,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2317/4699119 [14:25<513:58:11,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2318/4699119 [14:25<538:46:14,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2319/4699119 [14:26<554:35:55,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2320/4699119 [14:26<496:49:45,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2321/4699119 [14:27<536:47:46,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2322/4699119 [14:27<493:15:42,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2323/4699119 [14:27<533:03:33,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 505, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2324/4699119 [14:28<567:51:30,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2325/4699119 [14:28<549:24:42,  2.37it/s]

last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2327/4699119 [14:29<431:55:04,  3.02it/s]

last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2328/4699119 [14:29<455:50:45,  2.86it/s]

last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2329/4699119 [14:30<481:59:47,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2330/4699119 [14:30<468:18:39,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2331/4699119 [14:30<478:38:02,  2.73it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2332/4699119 [14:31<522:27:36,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2333/4699119 [14:31<553:58:16,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2334/4699119 [14:31<466:23:10,  2.80it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2335/4699119 [14:32<514:20:17,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2336/4699119 [14:32<537:01:15,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2337/4699119 [14:33<532:48:45,  2.45it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2338/4699119 [14:33<512:45:05,  2.54it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2339/4699119 [14:34<546:59:00,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2340/4699119 [14:34<570:04:48,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2341/4699119 [14:34<523:35:30,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2342/4699119 [14:35<495:57:14,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2343/4699119 [14:35<536:04:04,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2344/4699119 [14:36<530:52:15,  2.46it/s]

last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2346/4699119 [14:36<420:36:25,  3.10it/s]

last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2347/4699119 [14:36<447:43:49,  2.91it/s]

last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2348/4699119 [14:37<452:34:40,  2.88it/s]

last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2350/4699119 [14:37<391:52:52,  3.33it/s]

last_hidden_state = torch.Size([8, 159, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2351/4699119 [14:38<394:31:29,  3.31it/s]

last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2352/4699119 [14:38<381:42:02,  3.42it/s]

last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2353/4699119 [14:38<410:08:40,  3.18it/s]

last_hidden_state = torch.Size([8, 484, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2354/4699119 [14:39<470:52:16,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2355/4699119 [14:39<464:52:41,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2356/4699119 [14:39<425:26:45,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2357/4699119 [14:40<472:42:38,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2358/4699119 [14:40<491:37:08,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2359/4699119 [14:41<533:04:17,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2360/4699119 [14:41<480:10:19,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2361/4699119 [14:41<452:48:59,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2362/4699119 [14:41<400:31:52,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2363/4699119 [14:42<468:06:15,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2364/4699119 [14:42<467:20:43,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2365/4699119 [14:43<499:37:24,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2366/4699119 [14:43<494:06:38,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2367/4699119 [14:44<549:25:10,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2368/4699119 [14:44<572:06:08,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2369/4699119 [14:45<566:42:06,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2370/4699119 [14:45<516:28:02,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2371/4699119 [14:45<497:16:19,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2372/4699119 [14:46<467:32:57,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2373/4699119 [14:46<514:16:05,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2374/4699119 [14:46<487:44:21,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2375/4699119 [14:47<523:05:22,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2376/4699119 [14:47<459:08:44,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2377/4699119 [14:48<518:40:22,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 115, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2379/4699119 [14:48<399:37:32,  3.26it/s]

last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2380/4699119 [14:48<410:37:36,  3.18it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2381/4699119 [14:49<475:12:41,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2382/4699119 [14:49<507:50:55,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2383/4699119 [14:50<544:39:18,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2384/4699119 [14:50<499:16:22,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2385/4699119 [14:50<538:00:56,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2386/4699119 [14:51<565:29:00,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2387/4699119 [14:51<520:57:05,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 157, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2388/4699119 [14:51<445:01:15,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2389/4699119 [14:52<500:25:06,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2390/4699119 [14:52<456:00:53,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2391/4699119 [14:53<496:59:40,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2392/4699119 [14:53<505:50:10,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2393/4699119 [14:54<542:33:28,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2394/4699119 [14:54<503:22:18,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2395/4699119 [14:54<490:11:30,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2396/4699119 [14:55<531:01:54,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 112, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2397/4699119 [14:55<438:59:11,  2.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2398/4699119 [14:55<497:14:51,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2399/4699119 [14:56<506:24:38,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2400/4699119 [14:56<538:48:05,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2401/4699119 [14:57<516:11:17,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2402/4699119 [14:57<468:14:17,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2403/4699119 [14:57<516:00:42,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2404/4699119 [14:58<475:44:16,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2405/4699119 [14:58<451:39:17,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2406/4699119 [14:58<483:04:47,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2407/4699119 [14:59<444:56:44,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2408/4699119 [14:59<418:52:22,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2409/4699119 [14:59<481:45:09,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2410/4699119 [15:00<528:23:51,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2411/4699119 [15:00<531:49:12,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2412/4699119 [15:01<493:01:33,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2413/4699119 [15:01<475:19:12,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2414/4699119 [15:01<522:06:50,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2415/4699119 [15:02<480:59:53,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2416/4699119 [15:02<493:12:10,  2.65it/s]

last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2417/4699119 [15:03<505:07:32,  2.58it/s]

last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2418/4699119 [15:03<498:01:10,  2.62it/s]

last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2419/4699119 [15:03<479:16:34,  2.72it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2420/4699119 [15:04<524:09:26,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2421/4699119 [15:04<532:58:34,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2422/4699119 [15:05<560:55:46,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2423/4699119 [15:05<586:17:28,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2424/4699119 [15:05<546:23:29,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2425/4699119 [15:06<570:28:39,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2426/4699119 [15:06<526:00:53,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2427/4699119 [15:07<469:38:07,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2428/4699119 [15:07<517:00:51,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2429/4699119 [15:08<550:11:33,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2430/4699119 [15:08<551:37:16,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2431/4699119 [15:08<511:02:46,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2432/4699119 [15:09<481:06:42,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2433/4699119 [15:09<506:03:58,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2434/4699119 [15:09<504:15:52,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2435/4699119 [15:10<496:41:30,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2436/4699119 [15:10<535:46:08,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2437/4699119 [15:11<486:40:30,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2438/4699119 [15:11<465:37:32,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2439/4699119 [15:11<427:23:20,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2440/4699119 [15:12<489:14:15,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2441/4699119 [15:12<443:52:01,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2442/4699119 [15:12<450:05:14,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2443/4699119 [15:13<503:17:39,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2444/4699119 [15:13<435:33:52,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2445/4699119 [15:13<467:53:50,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2446/4699119 [15:14<515:27:55,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2447/4699119 [15:14<548:41:54,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2448/4699119 [15:15<559:44:56,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2449/4699119 [15:15<544:45:59,  2.39it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2450/4699119 [15:16<568:58:57,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2451/4699119 [15:16<506:01:19,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2452/4699119 [15:16<542:54:20,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2453/4699119 [15:17<569:21:21,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2454/4699119 [15:17<499:40:07,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2455/4699119 [15:18<537:39:16,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2456/4699119 [15:18<566:07:37,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2457/4699119 [15:18<481:02:43,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2458/4699119 [15:19<529:08:57,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2459/4699119 [15:19<557:55:53,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2460/4699119 [15:19<492:43:46,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2461/4699119 [15:20<475:35:37,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2462/4699119 [15:20<437:42:15,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2463/4699119 [15:21<494:37:29,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2464/4699119 [15:21<461:35:00,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2465/4699119 [15:21<447:13:34,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2466/4699119 [15:22<442:10:23,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2467/4699119 [15:22<434:44:03,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2468/4699119 [15:22<461:36:12,  2.83it/s]

last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2469/4699119 [15:23<425:14:11,  3.07it/s]

last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2470/4699119 [15:23<438:20:54,  2.98it/s]

last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2471/4699119 [15:23<435:08:15,  3.00it/s]

last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2472/4699119 [15:24<460:40:56,  2.83it/s]

last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2473/4699119 [15:24<500:08:18,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2474/4699119 [15:24<434:43:37,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2475/4699119 [15:25<448:45:12,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2476/4699119 [15:25<448:50:03,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2477/4699119 [15:25<416:05:12,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2478/4699119 [15:26<429:16:22,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2479/4699119 [15:26<394:32:00,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2480/4699119 [15:26<438:10:48,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2481/4699119 [15:27<467:47:23,  2.79it/s]

last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2482/4699119 [15:27<428:17:18,  3.05it/s]

last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2483/4699119 [15:27<438:36:19,  2.97it/s]

last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2484/4699119 [15:28<462:32:16,  2.82it/s]

last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2485/4699119 [15:28<516:34:37,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2486/4699119 [15:28<478:15:55,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2487/4699119 [15:29<440:26:31,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2488/4699119 [15:29<495:46:02,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2489/4699119 [15:30<489:04:52,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2490/4699119 [15:30<461:10:34,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2491/4699119 [15:30<500:21:45,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2492/4699119 [15:31<435:02:01,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2493/4699119 [15:31<382:04:10,  3.41it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2494/4699119 [15:31<456:01:27,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2495/4699119 [15:31<403:26:37,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2496/4699119 [15:32<384:57:01,  3.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2497/4699119 [15:32<463:12:09,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2498/4699119 [15:33<512:34:47,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2499/4699119 [15:33<463:39:26,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2500/4699119 [15:33<482:47:23,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2501/4699119 [15:34<525:46:07,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2502/4699119 [15:34<492:50:15,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2503/4699119 [15:34<437:52:28,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2504/4699119 [15:35<471:55:16,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2505/4699119 [15:35<502:51:06,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2506/4699119 [15:36<541:02:28,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2507/4699119 [15:36<534:29:01,  2.44it/s]

last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2509/4699119 [15:37<426:45:41,  3.06it/s]

last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2510/4699119 [15:37<439:20:18,  2.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2511/4699119 [15:37<495:27:09,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2512/4699119 [15:38<536:43:55,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2513/4699119 [15:38<476:32:51,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2514/4699119 [15:39<451:15:56,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2515/4699119 [15:39<445:05:53,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2516/4699119 [15:39<465:26:30,  2.80it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2517/4699119 [15:40<515:07:30,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2518/4699119 [15:40<498:39:06,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2519/4699119 [15:41<538:49:17,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2520/4699119 [15:41<514:03:51,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2521/4699119 [15:41<462:36:03,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2522/4699119 [15:42<477:36:00,  2.73it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2523/4699119 [15:42<523:13:17,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2524/4699119 [15:43<554:12:12,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2525/4699119 [15:43<576:38:15,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2526/4699119 [15:43<566:04:20,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2527/4699119 [15:44<583:44:36,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2528/4699119 [15:44<516:39:16,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2529/4699119 [15:45<553:28:10,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2530/4699119 [15:45<484:09:11,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2531/4699119 [15:45<486:30:31,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2532/4699119 [15:46<530:43:24,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2533/4699119 [15:46<476:31:41,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2534/4699119 [15:47<523:20:39,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2535/4699119 [15:47<482:27:03,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2536/4699119 [15:47<491:31:44,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2537/4699119 [15:47<444:19:43,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2538/4699119 [15:48<430:09:28,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2539/4699119 [15:48<391:50:31,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2540/4699119 [15:48<440:29:15,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2541/4699119 [15:49<455:27:42,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2542/4699119 [15:49<506:20:47,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2543/4699119 [15:50<545:33:04,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2544/4699119 [15:50<570:35:52,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2545/4699119 [15:51<554:36:18,  2.35it/s]

last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2546/4699119 [15:51<577:04:36,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2547/4699119 [15:52<592:30:46,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2548/4699119 [15:52<533:38:29,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2549/4699119 [15:52<477:38:22,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2550/4699119 [15:53<528:00:19,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2551/4699119 [15:53<477:25:43,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 80, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2552/4699119 [15:53<395:15:32,  3.30it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2553/4699119 [15:53<408:52:01,  3.19it/s]

last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2554/4699119 [15:54<416:59:00,  3.13it/s]

last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2555/4699119 [15:54<463:17:39,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2556/4699119 [15:55<513:00:30,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2557/4699119 [15:55<479:13:14,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2558/4699119 [15:55<523:20:45,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2559/4699119 [15:56<502:12:14,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2560/4699119 [15:56<539:34:32,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2561/4699119 [15:57<566:42:55,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2562/4699119 [15:57<587:31:58,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2563/4699119 [15:58<547:34:06,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2564/4699119 [15:58<571:11:20,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2565/4699119 [15:58<543:21:08,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2566/4699119 [15:59<523:41:39,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2567/4699119 [15:59<553:29:33,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2568/4699119 [16:00<575:17:09,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2569/4699119 [16:00<528:53:01,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2570/4699119 [16:01<559:46:13,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2571/4699119 [16:01<532:13:03,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2572/4699119 [16:01<465:12:47,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2573/4699119 [16:02<509:09:43,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2574/4699119 [16:02<544:25:02,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2575/4699119 [16:03<541:54:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2576/4699119 [16:03<495:08:48,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2577/4699119 [16:03<534:09:50,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2578/4699119 [16:04<544:11:52,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2579/4699119 [16:04<519:06:30,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2580/4699119 [16:05<547:47:49,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2581/4699119 [16:05<573:27:45,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2582/4699119 [16:05<518:19:43,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2583/4699119 [16:06<492:28:41,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2584/4699119 [16:06<532:42:15,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2585/4699119 [16:07<540:07:00,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2586/4699119 [16:07<470:56:40,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2587/4699119 [16:07<517:34:49,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2588/4699119 [16:08<500:31:19,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2589/4699119 [16:08<467:12:35,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2590/4699119 [16:08<478:38:11,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2591/4699119 [16:09<526:21:10,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2592/4699119 [16:09<544:17:33,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2593/4699119 [16:10<522:10:39,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2594/4699119 [16:10<552:53:01,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2595/4699119 [16:11<554:01:26,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2596/4699119 [16:11<523:09:16,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2597/4699119 [16:11<553:47:51,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2598/4699119 [16:12<525:19:20,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2599/4699119 [16:12<531:06:06,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2600/4699119 [16:13<559:53:14,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2601/4699119 [16:13<559:15:04,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2602/4699119 [16:14<581:32:04,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2603/4699119 [16:14<596:06:09,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2604/4699119 [16:15<605:25:52,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2605/4699119 [16:15<589:20:38,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2606/4699119 [16:15<564:14:30,  2.31it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2607/4699119 [16:16<582:17:26,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2608/4699119 [16:16<595:09:24,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2609/4699119 [16:17<531:31:14,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2610/4699119 [16:17<565:16:58,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2611/4699119 [16:17<533:02:41,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2612/4699119 [16:18<561:57:50,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2613/4699119 [16:18<477:09:45,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2614/4699119 [16:19<522:42:52,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2615/4699119 [16:19<506:17:57,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2616/4699119 [16:19<509:58:43,  2.56it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2617/4699119 [16:20<544:48:54,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2618/4699119 [16:20<505:34:47,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2619/4699119 [16:20<438:55:31,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2620/4699119 [16:21<497:49:15,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2621/4699119 [16:21<536:44:22,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2622/4699119 [16:22<563:30:33,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2623/4699119 [16:22<546:15:07,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2624/4699119 [16:23<488:25:44,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2625/4699119 [16:23<451:44:46,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2626/4699119 [16:23<416:42:08,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2627/4699119 [16:23<416:48:45,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2628/4699119 [16:24<383:31:04,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2629/4699119 [16:24<457:29:16,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2630/4699119 [16:24<464:57:01,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2631/4699119 [16:25<416:53:16,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2632/4699119 [16:25<427:50:32,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2633/4699119 [16:26<487:13:36,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2634/4699119 [16:26<459:02:13,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2635/4699119 [16:26<509:09:02,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2636/4699119 [16:27<486:51:17,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2637/4699119 [16:27<460:17:35,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2638/4699119 [16:27<428:03:22,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2639/4699119 [16:28<487:10:17,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2640/4699119 [16:28<440:15:09,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2641/4699119 [16:28<499:41:03,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2642/4699119 [16:29<451:28:24,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2643/4699119 [16:29<504:32:10,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2644/4699119 [16:29<456:25:51,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2645/4699119 [16:30<414:52:07,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2646/4699119 [16:30<431:11:33,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2647/4699119 [16:30<425:02:15,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2648/4699119 [16:31<415:56:14,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2649/4699119 [16:31<479:26:57,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2651/4699119 [16:32<419:59:11,  3.11it/s]

last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 436, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2652/4699119 [16:32<460:31:32,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2653/4699119 [16:33<511:07:27,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2654/4699119 [16:33<514:14:25,  2.54it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2655/4699119 [16:34<548:34:23,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 71, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2656/4699119 [16:34<445:01:42,  2.93it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2657/4699119 [16:34<500:13:50,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2658/4699119 [16:35<539:50:08,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2659/4699119 [16:35<566:57:34,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 176, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2660/4699119 [16:35<481:53:33,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 186, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2661/4699119 [16:36<428:28:47,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2662/4699119 [16:36<488:05:40,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2663/4699119 [16:36<441:54:14,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2664/4699119 [16:37<455:59:39,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2665/4699119 [16:37<470:38:35,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2666/4699119 [16:38<517:34:14,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2667/4699119 [16:38<499:32:58,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2668/4699119 [16:38<455:52:32,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2669/4699119 [16:38<424:50:27,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2670/4699119 [16:39<459:38:33,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2671/4699119 [16:39<509:23:43,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2672/4699119 [16:40<540:35:09,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2673/4699119 [16:40<457:39:27,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2674/4699119 [16:40<471:21:45,  2.77it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2675/4699119 [16:41<517:37:23,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 127, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2676/4699119 [16:41<436:04:12,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2677/4699119 [16:42<493:34:59,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2678/4699119 [16:42<498:51:32,  2.62it/s]

last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2679/4699119 [16:42<496:59:07,  2.62it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2680/4699119 [16:43<535:39:25,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2681/4699119 [16:43<503:41:23,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2682/4699119 [16:43<450:07:53,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2683/4699119 [16:44<416:43:22,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2684/4699119 [16:44<479:49:04,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2685/4699119 [16:44<480:55:33,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2686/4699119 [16:45<478:15:25,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2687/4699119 [16:45<527:51:41,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2688/4699119 [16:46<557:31:31,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2689/4699119 [16:46<492:29:35,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2690/4699119 [16:46<425:07:18,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2691/4699119 [16:47<427:25:06,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2692/4699119 [16:47<417:28:17,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2693/4699119 [16:47<475:29:34,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2694/4699119 [16:48<430:53:30,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2695/4699119 [16:48<448:46:53,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2696/4699119 [16:48<503:28:34,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2697/4699119 [16:49<528:57:40,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2698/4699119 [16:49<526:43:02,  2.48it/s]

last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2699/4699119 [16:50<472:21:14,  2.76it/s]

last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2701/4699119 [16:50<438:10:18,  2.98it/s]

last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2703/4699119 [16:51<414:03:55,  3.15it/s]

last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2704/4699119 [16:51<484:21:34,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2705/4699119 [16:52<528:41:15,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2706/4699119 [16:52<560:11:44,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2707/4699119 [16:53<510:41:12,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2708/4699119 [16:53<529:28:37,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2709/4699119 [16:54<560:12:16,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2710/4699119 [16:54<568:45:04,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2711/4699119 [16:54<512:07:20,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2712/4699119 [16:55<524:33:47,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2713/4699119 [16:55<545:11:32,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2714/4699119 [16:55<489:30:53,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2715/4699119 [16:56<469:37:26,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2716/4699119 [16:56<467:53:13,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2717/4699119 [16:57<515:08:13,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2718/4699119 [16:57<498:50:39,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2719/4699119 [16:57<542:38:58,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2720/4699119 [16:58<569:06:44,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2721/4699119 [16:58<533:48:56,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2722/4699119 [16:59<562:38:47,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2723/4699119 [16:59<523:11:48,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2724/4699119 [17:00<555:01:32,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2725/4699119 [17:00<576:05:49,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2726/4699119 [17:01<575:30:29,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2727/4699119 [17:01<590:56:41,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2728/4699119 [17:01<525:45:21,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2729/4699119 [17:02<548:26:06,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2730/4699119 [17:02<572:18:10,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2731/4699119 [17:03<566:45:38,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2732/4699119 [17:03<532:47:59,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2733/4699119 [17:03<496:44:53,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2734/4699119 [17:04<455:11:00,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2735/4699119 [17:04<506:21:53,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2736/4699119 [17:05<542:31:44,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2737/4699119 [17:05<470:13:42,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2738/4699119 [17:05<458:25:50,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2739/4699119 [17:06<508:57:17,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2740/4699119 [17:06<432:07:32,  3.02it/s]

last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2741/4699119 [17:06<474:09:58,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2742/4699119 [17:07<482:44:10,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2743/4699119 [17:07<527:30:48,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2744/4699119 [17:07<475:07:54,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2745/4699119 [17:08<523:14:14,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2746/4699119 [17:08<468:51:12,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2747/4699119 [17:08<465:45:11,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2748/4699119 [17:09<426:32:51,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2749/4699119 [17:09<486:36:40,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2750/4699119 [17:09<445:59:15,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2751/4699119 [17:10<500:36:59,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2752/4699119 [17:10<501:10:02,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2753/4699119 [17:11<500:46:08,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2754/4699119 [17:11<499:14:32,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2755/4699119 [17:11<484:07:25,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2756/4699119 [17:12<418:07:04,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2757/4699119 [17:12<483:39:17,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 94, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2758/4699119 [17:12<403:05:42,  3.24it/s]

last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2759/4699119 [17:13<422:28:58,  3.09it/s]

last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2761/4699119 [17:13<358:19:41,  3.64it/s]

last_hidden_state = torch.Size([8, 123, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2762/4699119 [17:14<398:55:07,  3.27it/s]

last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2763/4699119 [17:14<440:12:49,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2764/4699119 [17:14<496:10:23,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2765/4699119 [17:15<437:06:44,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2766/4699119 [17:15<442:19:11,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2767/4699119 [17:15<437:48:15,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2768/4699119 [17:16<433:39:06,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2769/4699119 [17:16<433:51:01,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2770/4699119 [17:16<419:17:13,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2771/4699119 [17:17<464:36:42,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2772/4699119 [17:17<515:20:37,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2773/4699119 [17:17<463:51:37,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 89, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2774/4699119 [17:18<389:36:03,  3.35it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2775/4699119 [17:18<461:44:31,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2776/4699119 [17:19<512:11:27,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2777/4699119 [17:19<529:25:25,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2778/4699119 [17:19<518:41:13,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2779/4699119 [17:20<550:37:36,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2780/4699119 [17:20<525:47:00,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2781/4699119 [17:21<556:18:22,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2782/4699119 [17:21<531:12:35,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2783/4699119 [17:21<495:30:12,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2784/4699119 [17:22<535:56:13,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2785/4699119 [17:22<562:59:58,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2786/4699119 [17:23<479:10:59,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2787/4699119 [17:23<512:56:25,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2788/4699119 [17:23<466:44:49,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2789/4699119 [17:24<421:49:01,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2790/4699119 [17:24<395:14:45,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2791/4699119 [17:24<428:07:23,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2792/4699119 [17:25<489:16:29,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2793/4699119 [17:25<530:47:29,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2794/4699119 [17:25<475:42:20,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2795/4699119 [17:26<492:17:23,  2.65it/s]

last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2796/4699119 [17:26<515:34:50,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2797/4699119 [17:27<495:55:28,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2798/4699119 [17:27<535:42:41,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2799/4699119 [17:28<536:28:34,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2800/4699119 [17:28<564:33:41,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2801/4699119 [17:28<474:47:01,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2802/4699119 [17:29<464:00:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2803/4699119 [17:29<513:28:56,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 487, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2804/4699119 [17:29<548:18:30,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2805/4699119 [17:30<536:31:52,  2.43it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2806/4699119 [17:30<563:41:37,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2807/4699119 [17:31<583:31:52,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2808/4699119 [17:31<574:06:28,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2809/4699119 [17:32<588:51:00,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2810/4699119 [17:32<540:51:13,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2811/4699119 [17:33<567:44:35,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2812/4699119 [17:33<586:27:46,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2813/4699119 [17:33<542:23:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2814/4699119 [17:34<457:22:13,  2.85it/s]

last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2815/4699119 [17:34<476:27:51,  2.74it/s]

last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2816/4699119 [17:34<496:48:35,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2817/4699119 [17:35<537:44:24,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2818/4699119 [17:35<491:14:27,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2819/4699119 [17:36<495:23:07,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2820/4699119 [17:36<518:45:07,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2821/4699119 [17:36<552:15:40,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2822/4699119 [17:37<503:37:01,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2823/4699119 [17:37<492:12:07,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2824/4699119 [17:37<459:08:38,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 114, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2825/4699119 [17:38<393:33:42,  3.31it/s]

last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2826/4699119 [17:38<402:39:44,  3.24it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2827/4699119 [17:38<471:10:28,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2828/4699119 [17:39<492:07:11,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2829/4699119 [17:39<522:22:04,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2830/4699119 [17:40<495:19:27,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2831/4699119 [17:40<535:53:35,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2832/4699119 [17:40<466:38:55,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2833/4699119 [17:41<515:26:00,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2834/4699119 [17:41<477:03:40,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2835/4699119 [17:42<501:16:14,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2836/4699119 [17:42<444:43:05,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2837/4699119 [17:42<500:24:27,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2838/4699119 [17:43<538:24:29,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2839/4699119 [17:43<501:02:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2840/4699119 [17:43<476:30:49,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2841/4699119 [17:44<509:42:41,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2842/4699119 [17:44<458:37:46,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2843/4699119 [17:44<439:22:56,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2844/4699119 [17:45<495:12:38,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2845/4699119 [17:45<534:13:55,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2847/4699119 [17:46<462:06:22,  2.82it/s]

last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2848/4699119 [17:46<511:49:41,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2849/4699119 [17:47<448:59:39,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2850/4699119 [17:47<437:24:51,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2851/4699119 [17:48<494:50:03,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2852/4699119 [17:48<512:13:43,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2853/4699119 [17:48<456:30:22,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2854/4699119 [17:48<426:05:41,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2855/4699119 [17:49<377:00:02,  3.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2856/4699119 [17:49<451:38:43,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2857/4699119 [17:50<508:45:42,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2858/4699119 [17:50<475:46:31,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 117, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2860/4699119 [17:50<363:11:39,  3.59it/s]

last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2861/4699119 [17:51<348:30:44,  3.74it/s]

last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2862/4699119 [17:51<334:56:17,  3.89it/s]

last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2863/4699119 [17:51<363:07:42,  3.59it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2864/4699119 [17:52<442:37:30,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2865/4699119 [17:52<499:15:12,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2866/4699119 [17:53<537:51:08,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2867/4699119 [17:53<564:42:51,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2868/4699119 [17:53<545:15:30,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2869/4699119 [17:54<511:03:39,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2870/4699119 [17:54<520:30:58,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2871/4699119 [17:54<470:19:41,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2872/4699119 [17:55<412:35:11,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2873/4699119 [17:55<477:19:29,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2874/4699119 [17:56<504:40:37,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2875/4699119 [17:56<488:54:52,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2877/4699119 [17:57<430:01:05,  3.03it/s]

last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2878/4699119 [17:57<456:35:52,  2.86it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2879/4699119 [17:57<508:05:27,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2880/4699119 [17:58<460:13:46,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2881/4699119 [17:58<416:39:24,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2882/4699119 [17:58<383:13:02,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2883/4699119 [17:59<440:13:32,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2884/4699119 [17:59<432:49:21,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2885/4699119 [17:59<378:52:06,  3.44it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2886/4699119 [18:00<452:29:45,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2887/4699119 [18:00<413:09:58,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2888/4699119 [18:00<477:23:44,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2889/4699119 [18:01<470:28:48,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2890/4699119 [18:01<474:03:10,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2891/4699119 [18:01<463:08:05,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2892/4699119 [18:02<495:28:24,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2893/4699119 [18:02<509:13:53,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2894/4699119 [18:03<527:52:31,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2895/4699119 [18:03<509:18:00,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2896/4699119 [18:03<441:43:03,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2897/4699119 [18:04<496:55:25,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2898/4699119 [18:04<454:45:02,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2899/4699119 [18:04<481:24:01,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2900/4699119 [18:05<468:01:46,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2901/4699119 [18:05<515:49:29,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2902/4699119 [18:06<548:20:57,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2903/4699119 [18:06<571:29:58,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2904/4699119 [18:06<501:50:00,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2905/4699119 [18:07<487:15:56,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2906/4699119 [18:07<528:56:48,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2907/4699119 [18:08<511:05:46,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2908/4699119 [18:08<484:06:29,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2909/4699119 [18:08<440:43:15,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2910/4699119 [18:08<403:03:08,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2911/4699119 [18:09<469:34:03,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2912/4699119 [18:09<498:55:50,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2913/4699119 [18:10<426:33:16,  3.06it/s]

last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2914/4699119 [18:10<436:31:30,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2915/4699119 [18:10<494:05:45,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2916/4699119 [18:11<535:15:12,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2917/4699119 [18:11<527:36:25,  2.47it/s]

last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2918/4699119 [18:12<483:26:26,  2.70it/s]

last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2919/4699119 [18:12<475:07:35,  2.75it/s]

last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2920/4699119 [18:12<510:50:36,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2921/4699119 [18:13<547:56:48,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2922/4699119 [18:13<554:13:54,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2923/4699119 [18:14<571:09:35,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2924/4699119 [18:14<588:39:22,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2925/4699119 [18:15<599:56:52,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2926/4699119 [18:15<520:39:38,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2927/4699119 [18:15<526:18:55,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2928/4699119 [18:16<501:16:41,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2929/4699119 [18:16<531:36:35,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2930/4699119 [18:16<492:33:56,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2931/4699119 [18:17<450:08:05,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2932/4699119 [18:17<407:40:10,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2933/4699119 [18:17<411:16:59,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2934/4699119 [18:18<478:34:09,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2935/4699119 [18:18<476:06:55,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2936/4699119 [18:19<521:23:06,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2937/4699119 [18:19<503:55:58,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2938/4699119 [18:19<540:59:08,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2939/4699119 [18:20<529:31:52,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2940/4699119 [18:20<475:33:11,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2941/4699119 [18:21<520:23:42,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 98, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2942/4699119 [18:21<430:45:43,  3.03it/s]

last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2944/4699119 [18:21<356:29:16,  3.66it/s]

last_hidden_state = torch.Size([8, 83, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2945/4699119 [18:22<405:09:03,  3.22it/s]

last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2946/4699119 [18:22<450:13:35,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2947/4699119 [18:23<502:45:10,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2948/4699119 [18:23<456:13:07,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2949/4699119 [18:23<398:35:38,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2950/4699119 [18:23<440:16:26,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2951/4699119 [18:24<496:33:26,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2952/4699119 [18:24<467:44:51,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2953/4699119 [18:25<515:33:00,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2954/4699119 [18:25<548:43:58,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2955/4699119 [18:26<573:07:30,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2956/4699119 [18:26<491:46:56,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2957/4699119 [18:26<486:50:23,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2958/4699119 [18:27<530:32:41,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2959/4699119 [18:27<540:06:52,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2960/4699119 [18:28<565:58:23,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2961/4699119 [18:28<584:34:07,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2962/4699119 [18:29<584:29:34,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2963/4699119 [18:29<529:14:05,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2964/4699119 [18:29<558:49:01,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2965/4699119 [18:30<505:10:53,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2966/4699119 [18:30<432:58:29,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2967/4699119 [18:30<490:52:17,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2968/4699119 [18:31<533:23:24,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2969/4699119 [18:31<509:55:44,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2970/4699119 [18:32<545:12:52,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2971/4699119 [18:32<461:46:27,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2972/4699119 [18:32<493:50:15,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2973/4699119 [18:33<447:28:22,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2974/4699119 [18:33<431:46:37,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2975/4699119 [18:33<422:25:15,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2976/4699119 [18:34<465:53:17,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2977/4699119 [18:34<516:02:09,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2978/4699119 [18:35<549:08:27,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2979/4699119 [18:35<525:30:22,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2980/4699119 [18:35<506:10:00,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2982/4699119 [18:36<395:43:39,  3.30it/s]

last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2983/4699119 [18:36<401:46:02,  3.25it/s]

last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2984/4699119 [18:37<471:52:09,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2985/4699119 [18:37<470:50:56,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2986/4699119 [18:37<444:43:52,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2987/4699119 [18:37<391:13:56,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2988/4699119 [18:38<407:01:13,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2989/4699119 [18:38<444:56:04,  2.93it/s]

last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2990/4699119 [18:39<486:07:36,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 173, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2991/4699119 [18:39<425:50:25,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2992/4699119 [18:39<485:49:18,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2993/4699119 [18:40<527:47:11,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2994/4699119 [18:40<562:55:28,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2995/4699119 [18:41<500:48:36,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2996/4699119 [18:41<473:54:40,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 2997/4699119 [18:41<519:43:05,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 2998/4699119 [18:42<441:04:47,  2.96it/s]

last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3000/4699119 [18:42<382:43:27,  3.41it/s]

last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3002/4699119 [18:43<351:51:47,  3.71it/s]

last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3003/4699119 [18:43<433:56:59,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3004/4699119 [18:43<423:24:24,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3005/4699119 [18:44<411:13:18,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3006/4699119 [18:44<387:26:37,  3.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3007/4699119 [18:44<460:14:30,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3008/4699119 [18:45<510:49:04,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3009/4699119 [18:45<482:58:30,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3010/4699119 [18:46<494:02:22,  2.64it/s]

last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3011/4699119 [18:46<521:07:30,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3012/4699119 [18:46<554:01:03,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3013/4699119 [18:47<575:52:09,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3014/4699119 [18:47<528:24:00,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3015/4699119 [18:48<485:05:40,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3016/4699119 [18:48<494:53:33,  2.64it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3017/4699119 [18:48<533:50:58,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3018/4699119 [18:49<497:59:24,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3019/4699119 [18:49<540:58:18,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 95, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3020/4699119 [18:49<444:12:10,  2.94it/s]

last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3021/4699119 [18:50<447:27:22,  2.92it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3022/4699119 [18:50<503:16:42,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3023/4699119 [18:51<540:58:07,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3024/4699119 [18:51<568:02:25,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3025/4699119 [18:52<585:51:31,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3026/4699119 [18:52<570:13:42,  2.29it/s]

last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3027/4699119 [18:52<500:24:01,  2.61it/s]

last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3028/4699119 [18:53<530:25:46,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3029/4699119 [18:53<481:19:09,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3030/4699119 [18:54<512:10:23,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3031/4699119 [18:54<513:19:15,  2.54it/s]

last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3032/4699119 [18:54<499:07:28,  2.61it/s]

last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3033/4699119 [18:55<521:04:57,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3034/4699119 [18:55<556:41:18,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3035/4699119 [18:56<579:11:41,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3036/4699119 [18:56<596:58:43,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3037/4699119 [18:57<606:28:00,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3038/4699119 [18:57<518:36:23,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3039/4699119 [18:57<496:09:15,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3040/4699119 [18:58<535:18:17,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 85, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3041/4699119 [18:58<438:55:19,  2.97it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3042/4699119 [18:58<496:00:33,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3043/4699119 [18:59<472:56:27,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3044/4699119 [18:59<454:05:14,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3045/4699119 [18:59<488:57:58,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3046/4699119 [19:00<532:02:51,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3047/4699119 [19:00<561:37:14,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3048/4699119 [19:01<581:14:08,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3049/4699119 [19:01<530:52:49,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3050/4699119 [19:02<519:42:44,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3051/4699119 [19:02<458:53:37,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3052/4699119 [19:02<421:51:04,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3053/4699119 [19:02<386:51:58,  3.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3054/4699119 [19:03<458:31:05,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3055/4699119 [19:03<444:10:13,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3056/4699119 [19:03<434:23:03,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3057/4699119 [19:04<439:04:09,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3058/4699119 [19:04<495:19:31,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3059/4699119 [19:05<483:24:55,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 144, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3060/4699119 [19:05<416:00:07,  3.14it/s]

last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3062/4699119 [19:05<390:52:35,  3.34it/s]

last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3063/4699119 [19:06<449:53:31,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3064/4699119 [19:06<488:45:36,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3065/4699119 [19:07<527:24:51,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3066/4699119 [19:07<561:19:36,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3067/4699119 [19:08<522:03:54,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3068/4699119 [19:08<522:16:57,  2.50it/s]

last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3069/4699119 [19:08<514:50:08,  2.53it/s]

last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3070/4699119 [19:09<548:25:58,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3071/4699119 [19:09<521:11:42,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3072/4699119 [19:10<472:51:13,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3073/4699119 [19:10<464:28:12,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3074/4699119 [19:10<421:56:58,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3075/4699119 [19:11<459:21:09,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3076/4699119 [19:11<471:01:11,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3077/4699119 [19:11<436:41:24,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3078/4699119 [19:12<455:58:20,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3079/4699119 [19:12<417:30:40,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3080/4699119 [19:12<435:31:18,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3081/4699119 [19:12<422:05:20,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3082/4699119 [19:13<471:10:36,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3083/4699119 [19:13<522:37:26,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3084/4699119 [19:14<554:41:53,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3085/4699119 [19:14<532:49:52,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3086/4699119 [19:15<560:44:47,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3087/4699119 [19:15<498:33:25,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3088/4699119 [19:15<509:56:17,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3089/4699119 [19:16<473:08:37,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3090/4699119 [19:16<473:42:08,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3091/4699119 [19:17<519:29:42,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3092/4699119 [19:17<551:15:36,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3093/4699119 [19:18<577:36:43,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3094/4699119 [19:18<528:54:37,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3095/4699119 [19:18<557:58:17,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3096/4699119 [19:19<578:01:58,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3097/4699119 [19:19<591:40:14,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3098/4699119 [19:20<556:45:55,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3099/4699119 [19:20<543:20:02,  2.40it/s]

last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3100/4699119 [19:20<506:49:33,  2.57it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3101/4699119 [19:21<542:39:15,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3102/4699119 [19:21<567:26:18,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3103/4699119 [19:22<506:35:25,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 121, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3104/4699119 [19:22<427:42:18,  3.05it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3105/4699119 [19:22<487:20:47,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3106/4699119 [19:23<535:17:26,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3107/4699119 [19:23<468:13:39,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3108/4699119 [19:24<515:45:10,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3109/4699119 [19:24<550:13:16,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3110/4699119 [19:24<520:13:14,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3111/4699119 [19:25<540:49:36,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3112/4699119 [19:25<567:03:55,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3113/4699119 [19:26<588:37:03,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3114/4699119 [19:26<578:53:36,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3115/4699119 [19:27<559:18:52,  2.33it/s]

last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3116/4699119 [19:27<580:02:51,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3117/4699119 [19:28<594:45:09,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3118/4699119 [19:28<571:20:56,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3119/4699119 [19:28<561:46:10,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3120/4699119 [19:29<476:39:08,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3121/4699119 [19:29<494:31:05,  2.64it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3122/4699119 [19:29<535:10:42,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3123/4699119 [19:30<485:44:20,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3124/4699119 [19:30<527:52:33,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3125/4699119 [19:31<556:51:41,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3126/4699119 [19:31<497:49:16,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3127/4699119 [19:31<452:22:01,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3128/4699119 [19:32<505:11:01,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3129/4699119 [19:32<458:16:17,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3130/4699119 [19:32<450:46:53,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3131/4699119 [19:33<417:30:50,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3132/4699119 [19:33<457:23:50,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3133/4699119 [19:33<413:42:26,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3134/4699119 [19:34<409:24:35,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3135/4699119 [19:34<473:48:36,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3136/4699119 [19:34<462:25:33,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3137/4699119 [19:35<440:19:44,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3138/4699119 [19:35<498:01:35,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3139/4699119 [19:35<442:25:55,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3140/4699119 [19:36<406:47:42,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3141/4699119 [19:36<409:07:10,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3142/4699119 [19:36<474:26:05,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3143/4699119 [19:37<466:08:52,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3144/4699119 [19:37<516:48:18,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3145/4699119 [19:38<549:48:41,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3146/4699119 [19:38<545:10:26,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3147/4699119 [19:39<520:25:31,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3148/4699119 [19:39<500:26:42,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3149/4699119 [19:39<476:09:19,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3150/4699119 [19:40<519:04:47,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3151/4699119 [19:40<472:07:41,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3152/4699119 [19:40<439:45:34,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3153/4699119 [19:40<412:16:24,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3154/4699119 [19:41<388:09:16,  3.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3155/4699119 [19:41<412:02:19,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3156/4699119 [19:41<399:15:58,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3157/4699119 [19:42<397:00:25,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3158/4699119 [19:42<465:45:49,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3159/4699119 [19:42<445:04:38,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3160/4699119 [19:43<499:44:11,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3161/4699119 [19:43<473:03:15,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3162/4699119 [19:44<479:03:32,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3163/4699119 [19:44<523:12:05,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 88, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3164/4699119 [19:44<430:26:07,  3.03it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3165/4699119 [19:45<491:05:07,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3166/4699119 [19:45<533:25:57,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3167/4699119 [19:46<530:02:11,  2.46it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3168/4699119 [19:46<558:22:33,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3169/4699119 [19:47<580:06:06,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3170/4699119 [19:47<593:26:38,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3171/4699119 [19:48<582:26:32,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3172/4699119 [19:48<534:31:57,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3173/4699119 [19:48<560:46:49,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3174/4699119 [19:49<508:47:20,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3175/4699119 [19:49<462:49:55,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3176/4699119 [19:49<474:20:45,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3177/4699119 [19:50<464:50:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3178/4699119 [19:50<513:51:18,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3179/4699119 [19:50<517:20:35,  2.52it/s]

last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3180/4699119 [19:51<466:56:51,  2.79it/s]

last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3181/4699119 [19:51<500:51:48,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3182/4699119 [19:51<444:10:31,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3183/4699119 [19:52<474:05:19,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3184/4699119 [19:52<483:11:36,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3185/4699119 [19:53<493:47:15,  2.64it/s]

last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3186/4699119 [19:53<451:30:14,  2.89it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3187/4699119 [19:53<504:27:55,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3188/4699119 [19:54<541:25:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 111, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3189/4699119 [19:54<447:06:31,  2.92it/s]

last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3190/4699119 [19:54<480:57:50,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3191/4699119 [19:55<463:59:54,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3192/4699119 [19:55<440:40:23,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3193/4699119 [19:56<498:25:38,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3194/4699119 [19:56<537:20:50,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3195/4699119 [19:56<484:33:57,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3196/4699119 [19:57<526:44:38,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3197/4699119 [19:57<463:19:36,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3198/4699119 [19:58<511:58:08,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3199/4699119 [19:58<548:25:01,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3200/4699119 [19:58<493:55:03,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3201/4699119 [19:59<450:10:35,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3202/4699119 [19:59<468:18:39,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3203/4699119 [19:59<479:40:44,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3204/4699119 [20:00<428:53:40,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3205/4699119 [20:00<405:58:34,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3206/4699119 [20:00<437:32:55,  2.98it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3207/4699119 [20:01<494:30:35,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3208/4699119 [20:01<504:59:04,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3209/4699119 [20:01<461:49:37,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3210/4699119 [20:02<408:28:21,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3211/4699119 [20:02<441:11:26,  2.96it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3212/4699119 [20:03<497:01:29,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3213/4699119 [20:03<499:25:56,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3214/4699119 [20:03<493:08:43,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3215/4699119 [20:04<533:25:17,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3216/4699119 [20:04<487:25:53,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3217/4699119 [20:04<470:16:19,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3218/4699119 [20:05<495:03:13,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3219/4699119 [20:05<534:09:18,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3220/4699119 [20:06<493:22:03,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3221/4699119 [20:06<437:45:43,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3222/4699119 [20:06<494:06:36,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3223/4699119 [20:07<540:09:33,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3224/4699119 [20:07<540:11:24,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3225/4699119 [20:08<537:19:16,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3226/4699119 [20:08<565:32:04,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3227/4699119 [20:08<543:33:30,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3228/4699119 [20:09<542:22:06,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3229/4699119 [20:09<505:18:52,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3230/4699119 [20:09<460:49:17,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3231/4699119 [20:10<503:31:56,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3232/4699119 [20:10<542:25:27,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3233/4699119 [20:11<569:04:03,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3234/4699119 [20:11<569:49:16,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3235/4699119 [20:12<587:33:32,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3236/4699119 [20:12<576:14:27,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3237/4699119 [20:13<591:21:44,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3238/4699119 [20:13<519:19:50,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3239/4699119 [20:13<488:15:59,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3240/4699119 [20:14<529:41:51,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3241/4699119 [20:14<559:48:22,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3242/4699119 [20:15<580:19:28,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3243/4699119 [20:15<521:52:12,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3244/4699119 [20:15<527:11:53,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3245/4699119 [20:16<514:15:02,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3246/4699119 [20:16<486:30:53,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3247/4699119 [20:16<457:52:53,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3248/4699119 [20:17<412:11:03,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3249/4699119 [20:17<476:39:22,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3250/4699119 [20:18<468:55:28,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3251/4699119 [20:18<447:10:38,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3252/4699119 [20:18<502:05:53,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3253/4699119 [20:19<468:09:31,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3255/4699119 [20:19<429:20:41,  3.04it/s]

last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3256/4699119 [20:20<477:16:47,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3257/4699119 [20:20<438:04:43,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3258/4699119 [20:20<385:59:31,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3259/4699119 [20:21<441:45:08,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3260/4699119 [20:21<452:41:46,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3261/4699119 [20:21<437:08:05,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3262/4699119 [20:22<425:40:43,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3263/4699119 [20:22<487:13:09,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3264/4699119 [20:23<529:39:56,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3265/4699119 [20:23<447:13:11,  2.92it/s]

last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3266/4699119 [20:23<414:33:10,  3.15it/s]

last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3267/4699119 [20:23<409:59:42,  3.18it/s]

last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3268/4699119 [20:24<461:24:32,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 176, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3269/4699119 [20:24<407:52:24,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3270/4699119 [20:24<452:07:38,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3271/4699119 [20:25<410:59:47,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3272/4699119 [20:25<458:22:09,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3273/4699119 [20:26<504:58:42,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3274/4699119 [20:26<510:17:51,  2.56it/s]

last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3275/4699119 [20:26<521:31:33,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3276/4699119 [20:27<485:06:54,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3277/4699119 [20:27<475:10:11,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3278/4699119 [20:27<448:49:27,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3279/4699119 [20:28<407:31:27,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3280/4699119 [20:28<472:55:14,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3281/4699119 [20:28<424:23:45,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3282/4699119 [20:29<486:24:09,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3283/4699119 [20:29<456:31:42,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3284/4699119 [20:30<507:08:09,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3285/4699119 [20:30<487:35:59,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3286/4699119 [20:30<529:27:42,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3287/4699119 [20:31<559:59:26,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3288/4699119 [20:31<520:06:47,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3289/4699119 [20:31<448:09:04,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3290/4699119 [20:32<407:07:36,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3291/4699119 [20:32<410:31:29,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3292/4699119 [20:32<453:01:47,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3293/4699119 [20:33<425:09:21,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3294/4699119 [20:33<485:30:13,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3295/4699119 [20:34<488:05:55,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3296/4699119 [20:34<489:20:29,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3297/4699119 [20:34<509:39:40,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3298/4699119 [20:35<516:25:10,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3299/4699119 [20:35<549:50:42,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3300/4699119 [20:35<493:30:43,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3301/4699119 [20:36<472:13:20,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3302/4699119 [20:36<508:21:53,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3303/4699119 [20:37<525:16:37,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3304/4699119 [20:37<474:03:52,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3305/4699119 [20:37<507:56:26,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3306/4699119 [20:38<447:56:55,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3307/4699119 [20:38<407:16:04,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3308/4699119 [20:38<479:15:06,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3309/4699119 [20:39<442:52:32,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3310/4699119 [20:39<485:29:00,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3311/4699119 [20:39<459:59:40,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3312/4699119 [20:40<509:29:15,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3313/4699119 [20:40<476:32:26,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3314/4699119 [20:40<424:42:47,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3315/4699119 [20:41<486:43:57,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3316/4699119 [20:41<530:27:50,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3317/4699119 [20:42<561:48:35,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3318/4699119 [20:42<527:56:02,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3319/4699119 [20:43<512:18:50,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3320/4699119 [20:43<531:18:09,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3321/4699119 [20:43<451:32:00,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3322/4699119 [20:44<503:55:05,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3323/4699119 [20:44<541:11:59,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3324/4699119 [20:44<486:24:48,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3325/4699119 [20:45<502:23:24,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3326/4699119 [20:45<535:50:46,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3327/4699119 [20:46<474:59:26,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3328/4699119 [20:46<441:06:09,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3329/4699119 [20:46<414:59:16,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3330/4699119 [20:47<438:49:29,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3331/4699119 [20:47<494:56:45,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3332/4699119 [20:47<487:19:22,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3333/4699119 [20:48<459:23:11,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3334/4699119 [20:48<457:07:08,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3335/4699119 [20:49<508:23:26,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3336/4699119 [20:49<526:55:41,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3337/4699119 [20:49<556:36:53,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3338/4699119 [20:50<577:43:56,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3339/4699119 [20:50<596:58:36,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3340/4699119 [20:51<512:23:18,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3341/4699119 [20:51<519:23:06,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3342/4699119 [20:51<455:58:20,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3343/4699119 [20:52<429:33:10,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3344/4699119 [20:52<400:44:03,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3345/4699119 [20:52<403:49:03,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3346/4699119 [20:53<471:28:17,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3347/4699119 [20:53<492:46:49,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3348/4699119 [20:53<482:05:14,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3349/4699119 [20:54<525:26:47,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3350/4699119 [20:54<446:24:43,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3351/4699119 [20:54<453:15:53,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3352/4699119 [20:55<455:30:40,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3353/4699119 [20:55<454:12:12,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3354/4699119 [20:56<505:29:27,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3355/4699119 [20:56<478:51:33,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3356/4699119 [20:56<525:45:05,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3357/4699119 [20:57<486:02:44,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3358/4699119 [20:57<528:39:23,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3359/4699119 [20:58<504:13:20,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3360/4699119 [20:58<509:15:13,  2.56it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3361/4699119 [20:58<544:42:13,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3362/4699119 [20:59<481:46:59,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 107, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3364/4699119 [20:59<368:15:36,  3.54it/s]

last_hidden_state = torch.Size([8, 176, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3365/4699119 [21:00<439:25:37,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3366/4699119 [21:00<435:25:01,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3367/4699119 [21:00<463:07:50,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3368/4699119 [21:01<482:38:21,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3369/4699119 [21:01<428:50:29,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3370/4699119 [21:01<488:01:19,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3371/4699119 [21:02<491:38:00,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3372/4699119 [21:02<486:32:49,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3373/4699119 [21:03<529:01:46,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3374/4699119 [21:03<486:33:30,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3376/4699119 [21:03<371:48:13,  3.51it/s]

last_hidden_state = torch.Size([8, 146, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3377/4699119 [21:04<363:16:13,  3.59it/s]

last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3379/4699119 [21:04<352:31:36,  3.70it/s]

last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3380/4699119 [21:04<351:21:24,  3.71it/s]

last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3381/4699119 [21:05<410:11:58,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3382/4699119 [21:05<370:06:15,  3.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3383/4699119 [21:05<336:16:19,  3.88it/s]

last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3384/4699119 [21:06<407:22:02,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3385/4699119 [21:06<421:10:40,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3386/4699119 [21:06<410:13:06,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3387/4699119 [21:07<371:20:37,  3.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3388/4699119 [21:07<448:53:19,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3389/4699119 [21:07<397:30:17,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3390/4699119 [21:08<468:04:34,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3391/4699119 [21:08<453:42:07,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3392/4699119 [21:09<505:14:33,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3393/4699119 [21:09<464:46:53,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 102, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3394/4699119 [21:09<392:16:22,  3.33it/s]

last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3395/4699119 [21:09<421:37:52,  3.09it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3396/4699119 [21:10<483:07:50,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3397/4699119 [21:10<437:12:11,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3398/4699119 [21:11<493:30:44,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3399/4699119 [21:11<500:41:17,  2.61it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3400/4699119 [21:11<544:45:26,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3401/4699119 [21:12<560:48:18,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3402/4699119 [21:12<563:59:20,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3403/4699119 [21:13<582:33:03,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3404/4699119 [21:13<565:41:05,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3405/4699119 [21:14<542:59:14,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3406/4699119 [21:14<550:03:01,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3407/4699119 [21:14<547:46:29,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3408/4699119 [21:15<554:16:02,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3409/4699119 [21:15<536:30:02,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3410/4699119 [21:16<563:49:58,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3411/4699119 [21:16<530:03:11,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3412/4699119 [21:17<560:22:38,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3413/4699119 [21:17<483:05:55,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3414/4699119 [21:17<439:16:02,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3415/4699119 [21:17<443:44:48,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3416/4699119 [21:18<431:38:53,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3417/4699119 [21:18<490:19:56,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3418/4699119 [21:18<445:11:48,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3419/4699119 [21:19<450:04:33,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3420/4699119 [21:19<413:20:55,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3421/4699119 [21:19<369:10:49,  3.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3422/4699119 [21:20<394:04:24,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3423/4699119 [21:20<466:04:14,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3424/4699119 [21:20<405:01:00,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3425/4699119 [21:21<383:42:50,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3426/4699119 [21:21<361:31:43,  3.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3427/4699119 [21:21<380:15:25,  3.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3428/4699119 [21:22<417:35:58,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3429/4699119 [21:22<450:56:00,  2.89it/s]

last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3430/4699119 [21:22<434:48:46,  3.00it/s]

last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3431/4699119 [21:23<422:47:37,  3.09it/s]

last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3432/4699119 [21:23<411:14:22,  3.17it/s]

last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3433/4699119 [21:23<437:44:47,  2.98it/s]

last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3434/4699119 [21:24<465:38:10,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3435/4699119 [21:24<462:54:17,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3436/4699119 [21:24<511:56:11,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3437/4699119 [21:25<494:59:29,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3438/4699119 [21:25<439:32:16,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3439/4699119 [21:25<422:09:18,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3440/4699119 [21:26<483:29:28,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3441/4699119 [21:26<479:48:34,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3442/4699119 [21:27<524:42:23,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3443/4699119 [21:27<555:40:10,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3444/4699119 [21:27<506:20:44,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3445/4699119 [21:28<461:22:32,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3446/4699119 [21:28<512:09:56,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3447/4699119 [21:29<546:52:40,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3448/4699119 [21:29<572:16:46,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3449/4699119 [21:30<588:25:29,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3450/4699119 [21:30<518:50:34,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3451/4699119 [21:30<503:40:35,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3452/4699119 [21:31<470:32:38,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3453/4699119 [21:31<452:57:36,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3454/4699119 [21:31<418:03:12,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3455/4699119 [21:32<459:56:02,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3456/4699119 [21:32<452:14:07,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3457/4699119 [21:32<410:45:29,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3458/4699119 [21:33<474:57:34,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3459/4699119 [21:33<521:14:56,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3460/4699119 [21:34<557:45:06,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3461/4699119 [21:34<501:51:04,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3462/4699119 [21:34<466:06:50,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3463/4699119 [21:35<464:56:12,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3464/4699119 [21:35<437:02:25,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3465/4699119 [21:35<444:16:34,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3466/4699119 [21:36<501:42:16,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3467/4699119 [21:36<436:13:48,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3468/4699119 [21:36<402:53:12,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3469/4699119 [21:37<444:07:07,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3470/4699119 [21:37<412:47:50,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3471/4699119 [21:37<396:00:25,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3472/4699119 [21:37<426:58:35,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3473/4699119 [21:38<413:33:30,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3474/4699119 [21:38<477:53:04,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3475/4699119 [21:39<524:05:53,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3476/4699119 [21:39<472:02:25,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3477/4699119 [21:39<471:22:13,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3478/4699119 [21:40<448:55:53,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3479/4699119 [21:40<483:51:40,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3480/4699119 [21:41<521:44:28,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3481/4699119 [21:41<527:00:56,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3482/4699119 [21:41<557:01:22,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3483/4699119 [21:42<543:51:01,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3484/4699119 [21:42<472:00:34,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3485/4699119 [21:43<518:43:40,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3486/4699119 [21:43<506:17:50,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3487/4699119 [21:43<473:00:04,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3488/4699119 [21:44<459:48:08,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3489/4699119 [21:44<416:39:20,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3490/4699119 [21:44<376:23:38,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3491/4699119 [21:44<424:28:27,  3.07it/s]

last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3492/4699119 [21:45<398:26:53,  3.27it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3493/4699119 [21:45<466:55:35,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3494/4699119 [21:45<418:48:30,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3495/4699119 [21:46<482:07:01,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3496/4699119 [21:46<462:06:22,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 109, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3497/4699119 [21:46<390:54:18,  3.34it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3498/4699119 [21:47<463:41:59,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3499/4699119 [21:47<408:40:42,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3500/4699119 [21:48<474:23:19,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3501/4699119 [21:48<508:17:51,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3502/4699119 [21:48<474:25:03,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3503/4699119 [21:49<434:14:30,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3504/4699119 [21:49<492:31:20,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3505/4699119 [21:50<533:53:09,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 127, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3506/4699119 [21:50<447:45:50,  2.91it/s]

last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3507/4699119 [21:50<475:46:44,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3508/4699119 [21:51<520:42:50,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3509/4699119 [21:51<511:36:42,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3510/4699119 [21:51<546:38:22,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3511/4699119 [21:52<516:49:25,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3512/4699119 [21:52<468:20:08,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3513/4699119 [21:53<515:11:41,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3514/4699119 [21:53<549:39:28,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3515/4699119 [21:53<547:02:14,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3516/4699119 [21:54<570:36:23,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3517/4699119 [21:54<588:17:52,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3518/4699119 [21:55<518:53:45,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3519/4699119 [21:55<540:24:23,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3520/4699119 [21:55<498:18:20,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3521/4699119 [21:56<504:38:46,  2.58it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3522/4699119 [21:56<541:46:27,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3523/4699119 [21:57<568:22:44,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3524/4699119 [21:57<536:53:02,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3525/4699119 [21:58<540:46:16,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3526/4699119 [21:58<566:54:51,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3527/4699119 [21:59<584:09:16,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3528/4699119 [21:59<578:43:43,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3529/4699119 [21:59<592:15:26,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3530/4699119 [22:00<531:33:31,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3531/4699119 [22:00<472:15:12,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3532/4699119 [22:01<519:21:36,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3533/4699119 [22:01<455:07:27,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3534/4699119 [22:01<510:05:13,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3535/4699119 [22:02<493:03:29,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3536/4699119 [22:02<455:29:10,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3537/4699119 [22:02<480:17:30,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3538/4699119 [22:03<527:17:55,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3539/4699119 [22:03<474:56:08,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3540/4699119 [22:03<457:36:24,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3542/4699119 [22:04<428:35:46,  3.04it/s]

last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3543/4699119 [22:04<488:37:16,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3544/4699119 [22:05<463:35:23,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3545/4699119 [22:05<512:36:22,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 505, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3546/4699119 [22:06<551:28:36,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3547/4699119 [22:06<516:28:15,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3548/4699119 [22:07<550:54:17,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3549/4699119 [22:07<484:39:11,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3550/4699119 [22:07<437:45:56,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3551/4699119 [22:08<494:40:34,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3552/4699119 [22:08<470:28:44,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3553/4699119 [22:08<420:00:20,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3554/4699119 [22:09<458:52:45,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3555/4699119 [22:09<510:07:56,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3556/4699119 [22:09<510:20:36,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3557/4699119 [22:10<482:47:33,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3558/4699119 [22:10<464:00:50,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3559/4699119 [22:11<513:23:29,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3560/4699119 [22:11<549:08:16,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 99, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3561/4699119 [22:11<451:00:02,  2.89it/s]

last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3562/4699119 [22:12<446:12:12,  2.92it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3563/4699119 [22:12<501:27:08,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3564/4699119 [22:12<539:33:57,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3565/4699119 [22:13<515:13:48,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3566/4699119 [22:13<518:47:34,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3567/4699119 [22:14<551:07:42,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3568/4699119 [22:14<551:58:07,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3569/4699119 [22:15<527:42:08,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3570/4699119 [22:15<556:30:59,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3571/4699119 [22:15<577:47:09,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3572/4699119 [22:16<518:42:47,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3573/4699119 [22:16<489:58:04,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3574/4699119 [22:16<474:58:17,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3575/4699119 [22:17<466:24:38,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3576/4699119 [22:17<466:12:31,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3577/4699119 [22:18<480:15:07,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3578/4699119 [22:18<524:05:11,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3579/4699119 [22:18<482:58:57,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3580/4699119 [22:19<526:11:16,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3581/4699119 [22:19<509:55:50,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3582/4699119 [22:20<533:14:59,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3583/4699119 [22:20<490:11:22,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3584/4699119 [22:20<512:10:05,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3585/4699119 [22:21<460:28:08,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3586/4699119 [22:21<442:25:16,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3587/4699119 [22:21<498:55:43,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3588/4699119 [22:22<512:11:58,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3589/4699119 [22:22<547:47:40,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3590/4699119 [22:23<490:32:32,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3591/4699119 [22:23<511:51:45,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3592/4699119 [22:23<546:18:18,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3593/4699119 [22:24<519:34:43,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3594/4699119 [22:24<456:19:06,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3595/4699119 [22:24<436:28:13,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 135, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3596/4699119 [22:25<382:17:27,  3.41it/s]

last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3597/4699119 [22:25<452:05:36,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3598/4699119 [22:25<505:54:09,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3599/4699119 [22:26<469:27:36,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3600/4699119 [22:26<516:36:44,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3601/4699119 [22:27<548:49:34,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3602/4699119 [22:27<550:12:55,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3603/4699119 [22:28<572:48:08,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3604/4699119 [22:28<567:29:03,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3605/4699119 [22:28<491:16:24,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3606/4699119 [22:29<501:52:11,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 154, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3607/4699119 [22:29<430:41:27,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3608/4699119 [22:29<471:34:15,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3609/4699119 [22:30<451:16:09,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3610/4699119 [22:30<504:03:36,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3611/4699119 [22:31<540:43:40,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3612/4699119 [22:31<486:01:40,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3613/4699119 [22:31<533:12:04,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3614/4699119 [22:32<525:08:45,  2.48it/s]

last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3615/4699119 [22:32<491:50:24,  2.65it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3616/4699119 [22:33<533:23:14,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3617/4699119 [22:33<446:33:01,  2.92it/s]

last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3618/4699119 [22:33<409:56:41,  3.18it/s]

last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3619/4699119 [22:33<402:50:16,  3.24it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3620/4699119 [22:34<469:55:06,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3621/4699119 [22:34<413:58:13,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3622/4699119 [22:34<478:29:31,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3623/4699119 [22:35<523:49:02,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3624/4699119 [22:35<468:08:04,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3625/4699119 [22:35<421:25:05,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3626/4699119 [22:36<398:51:13,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3627/4699119 [22:36<441:06:41,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3628/4699119 [22:37<497:33:59,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3629/4699119 [22:37<443:30:26,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3630/4699119 [22:37<497:49:39,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3631/4699119 [22:38<465:42:10,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3632/4699119 [22:38<495:42:39,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3633/4699119 [22:39<536:08:35,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3634/4699119 [22:39<534:07:17,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3635/4699119 [22:39<489:59:19,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3636/4699119 [22:40<509:46:41,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3637/4699119 [22:40<544:32:44,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3638/4699119 [22:41<568:49:23,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3639/4699119 [22:41<587:50:34,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3640/4699119 [22:42<599:32:27,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3641/4699119 [22:42<609:30:22,  2.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3642/4699119 [22:43<614:39:40,  2.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3643/4699119 [22:43<599:32:51,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3644/4699119 [22:44<609:27:10,  2.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3645/4699119 [22:44<593:24:45,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3646/4699119 [22:44<603:34:14,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3647/4699119 [22:45<595:14:18,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3648/4699119 [22:45<534:37:58,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3649/4699119 [22:45<502:42:32,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3650/4699119 [22:46<458:30:32,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3651/4699119 [22:46<509:51:42,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3652/4699119 [22:47<508:14:37,  2.57it/s]

last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3653/4699119 [22:47<532:42:10,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3654/4699119 [22:48<561:38:50,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3655/4699119 [22:48<582:23:52,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3656/4699119 [22:48<542:54:29,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3657/4699119 [22:49<511:10:28,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3658/4699119 [22:49<500:49:09,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3659/4699119 [22:50<540:13:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3660/4699119 [22:50<566:38:51,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3661/4699119 [22:51<586:11:52,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3662/4699119 [22:51<504:55:43,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3663/4699119 [22:51<459:07:54,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3664/4699119 [22:52<510:17:01,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3665/4699119 [22:52<507:04:58,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3666/4699119 [22:52<515:20:21,  2.53it/s]

last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3668/4699119 [22:53<415:59:42,  3.14it/s]

last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3669/4699119 [22:53<455:17:32,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3670/4699119 [22:54<459:24:32,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3671/4699119 [22:54<422:18:14,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3672/4699119 [22:54<406:20:14,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3673/4699119 [22:55<472:37:29,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3674/4699119 [22:55<436:37:22,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3675/4699119 [22:55<436:20:51,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3676/4699119 [22:55<396:31:44,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3677/4699119 [22:56<383:56:07,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3678/4699119 [22:56<352:13:36,  3.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3679/4699119 [22:56<396:31:06,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3680/4699119 [22:57<440:01:39,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3681/4699119 [22:57<386:10:01,  3.38it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3682/4699119 [22:57<457:57:14,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3683/4699119 [22:58<480:09:36,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3684/4699119 [22:58<524:21:33,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3685/4699119 [22:59<470:38:26,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3686/4699119 [22:59<437:38:40,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3687/4699119 [22:59<446:41:43,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3688/4699119 [22:59<406:00:29,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3689/4699119 [23:00<400:15:28,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3690/4699119 [23:00<409:09:56,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3691/4699119 [23:00<380:16:05,  3.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3692/4699119 [23:01<369:04:40,  3.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3693/4699119 [23:01<369:01:25,  3.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3694/4699119 [23:01<428:51:06,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3695/4699119 [23:02<425:57:21,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3696/4699119 [23:02<414:53:25,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3697/4699119 [23:02<437:42:01,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3698/4699119 [23:03<495:33:09,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3699/4699119 [23:03<538:55:48,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3700/4699119 [23:04<545:44:27,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3701/4699119 [23:04<533:41:09,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3702/4699119 [23:05<561:25:12,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3703/4699119 [23:05<516:17:04,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3704/4699119 [23:05<471:25:51,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3705/4699119 [23:06<517:30:01,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3706/4699119 [23:06<487:24:04,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3707/4699119 [23:06<528:49:28,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3708/4699119 [23:07<496:50:00,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 484, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3709/4699119 [23:07<531:24:32,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3710/4699119 [23:08<559:00:48,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 123, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3711/4699119 [23:08<464:36:29,  2.81it/s]

last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3712/4699119 [23:08<487:29:01,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3713/4699119 [23:09<529:36:47,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3714/4699119 [23:09<486:00:45,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3715/4699119 [23:10<529:35:59,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3716/4699119 [23:10<558:35:58,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3717/4699119 [23:10<505:55:05,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3718/4699119 [23:11<479:12:04,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3719/4699119 [23:11<414:48:52,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3720/4699119 [23:11<456:00:34,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3721/4699119 [23:12<475:41:39,  2.74it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3722/4699119 [23:12<521:17:05,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3723/4699119 [23:13<506:33:03,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3724/4699119 [23:13<464:26:25,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3725/4699119 [23:13<469:13:24,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3726/4699119 [23:14<464:57:39,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3727/4699119 [23:14<496:55:55,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3728/4699119 [23:14<483:52:41,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3729/4699119 [23:15<473:02:50,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3730/4699119 [23:15<436:32:58,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3731/4699119 [23:15<494:36:35,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3732/4699119 [23:16<470:06:06,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3733/4699119 [23:16<460:19:10,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3734/4699119 [23:17<509:39:43,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3735/4699119 [23:17<547:16:45,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3736/4699119 [23:18<571:23:23,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3737/4699119 [23:18<500:02:45,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3738/4699119 [23:18<500:14:09,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3739/4699119 [23:18<475:23:49,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3740/4699119 [23:19<462:10:00,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3741/4699119 [23:19<439:03:20,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3742/4699119 [23:19<425:52:32,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3743/4699119 [23:20<422:59:23,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3744/4699119 [23:20<473:17:55,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3745/4699119 [23:21<519:43:15,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3746/4699119 [23:21<464:32:14,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3747/4699119 [23:21<443:06:44,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3748/4699119 [23:21<387:33:22,  3.37it/s]

last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3749/4699119 [23:22<442:39:06,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3750/4699119 [23:22<498:42:20,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3751/4699119 [23:23<537:06:42,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3752/4699119 [23:23<517:58:35,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3753/4699119 [23:23<480:18:33,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3754/4699119 [23:24<526:04:57,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3755/4699119 [23:24<558:15:09,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3756/4699119 [23:25<534:58:38,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3757/4699119 [23:25<562:06:02,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3758/4699119 [23:26<581:23:27,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3759/4699119 [23:26<596:19:24,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3760/4699119 [23:27<543:44:03,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3761/4699119 [23:27<546:23:39,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3762/4699119 [23:27<484:04:35,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3763/4699119 [23:28<432:19:06,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3764/4699119 [23:28<493:13:44,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3765/4699119 [23:28<514:45:37,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3766/4699119 [23:29<548:06:44,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3767/4699119 [23:29<509:47:11,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3768/4699119 [23:29<450:57:54,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3769/4699119 [23:30<440:42:33,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3770/4699119 [23:30<496:23:49,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3771/4699119 [23:31<501:10:38,  2.60it/s]

last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3772/4699119 [23:31<531:38:28,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3773/4699119 [23:31<496:53:11,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3774/4699119 [23:32<536:53:46,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3775/4699119 [23:32<563:52:11,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3776/4699119 [23:33<582:59:02,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3777/4699119 [23:33<581:51:10,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3778/4699119 [23:34<515:48:24,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3779/4699119 [23:34<548:25:42,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3780/4699119 [23:35<571:43:49,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3781/4699119 [23:35<567:41:55,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3782/4699119 [23:35<490:17:05,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3783/4699119 [23:36<530:30:58,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3784/4699119 [23:36<560:51:48,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3785/4699119 [23:37<580:54:38,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 173, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3786/4699119 [23:37<492:13:32,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3787/4699119 [23:37<477:39:08,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3788/4699119 [23:38<497:15:06,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3789/4699119 [23:38<536:23:19,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3790/4699119 [23:38<475:27:03,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3791/4699119 [23:39<424:09:13,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3792/4699119 [23:39<414:06:56,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3793/4699119 [23:39<410:38:45,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3794/4699119 [23:40<475:32:06,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3795/4699119 [23:40<506:36:00,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3796/4699119 [23:40<473:13:57,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3797/4699119 [23:41<522:12:46,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3798/4699119 [23:41<505:48:43,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3799/4699119 [23:42<464:06:20,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3800/4699119 [23:42<445:10:08,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3801/4699119 [23:42<418:49:37,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3802/4699119 [23:42<383:44:48,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3803/4699119 [23:43<353:19:37,  3.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3804/4699119 [23:43<355:07:45,  3.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3805/4699119 [23:43<406:49:33,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3806/4699119 [23:44<446:28:19,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3807/4699119 [23:44<500:14:56,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3808/4699119 [23:45<479:01:16,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3809/4699119 [23:45<523:17:41,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3810/4699119 [23:45<533:15:01,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3811/4699119 [23:46<535:11:53,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3812/4699119 [23:46<537:21:31,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3813/4699119 [23:47<564:25:02,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3814/4699119 [23:47<510:12:59,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3815/4699119 [23:48<544:50:29,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3816/4699119 [23:48<568:20:56,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3817/4699119 [23:48<588:51:36,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3818/4699119 [23:49<601:41:18,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3819/4699119 [23:49<518:12:37,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3820/4699119 [23:50<525:11:01,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3821/4699119 [23:50<474:01:10,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3822/4699119 [23:50<408:52:12,  3.19it/s]

last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3823/4699119 [23:51<474:20:21,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3824/4699119 [23:51<512:37:19,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3825/4699119 [23:51<477:20:02,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3826/4699119 [23:52<521:41:26,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3827/4699119 [23:52<483:15:04,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3828/4699119 [23:52<480:28:40,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3829/4699119 [23:53<455:10:58,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3830/4699119 [23:53<442:56:10,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3831/4699119 [23:54<478:10:46,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3832/4699119 [23:54<411:45:47,  3.17it/s]

last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3833/4699119 [23:54<450:37:15,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3834/4699119 [23:54<445:09:53,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3835/4699119 [23:55<405:47:44,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3836/4699119 [23:55<417:13:18,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3837/4699119 [23:55<443:48:37,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3838/4699119 [23:56<499:15:14,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3839/4699119 [23:56<476:39:42,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3840/4699119 [23:57<487:28:17,  2.68it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3841/4699119 [23:57<529:15:30,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3842/4699119 [23:58<558:31:04,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3843/4699119 [23:58<578:40:54,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3844/4699119 [23:58<534:28:05,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3845/4699119 [23:59<561:51:12,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3846/4699119 [23:59<502:10:05,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3847/4699119 [24:00<490:11:15,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3848/4699119 [24:00<461:19:36,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3849/4699119 [24:00<429:22:23,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3850/4699119 [24:00<400:27:43,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3851/4699119 [24:01<397:26:09,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3852/4699119 [24:01<384:21:36,  3.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3853/4699119 [24:01<458:28:53,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3854/4699119 [24:02<404:52:27,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 153, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3855/4699119 [24:02<363:39:24,  3.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3856/4699119 [24:02<409:37:25,  3.18it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3857/4699119 [24:03<475:13:35,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3858/4699119 [24:03<458:20:15,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3859/4699119 [24:03<415:42:31,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3860/4699119 [24:04<380:23:15,  3.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3861/4699119 [24:04<454:11:41,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3862/4699119 [24:04<412:05:04,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3863/4699119 [24:04<382:23:50,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3864/4699119 [24:05<452:09:13,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3865/4699119 [24:05<426:39:40,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3866/4699119 [24:06<431:07:33,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3867/4699119 [24:06<470:52:56,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3868/4699119 [24:06<454:45:00,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3869/4699119 [24:07<462:46:40,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3870/4699119 [24:07<443:36:30,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3871/4699119 [24:07<472:28:11,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3872/4699119 [24:08<456:31:01,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3873/4699119 [24:08<438:34:29,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3874/4699119 [24:09<495:27:01,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3875/4699119 [24:09<534:25:50,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3876/4699119 [24:09<457:29:50,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3877/4699119 [24:09<428:09:52,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3878/4699119 [24:10<439:11:20,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3879/4699119 [24:10<451:53:49,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3880/4699119 [24:11<433:33:54,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3881/4699119 [24:11<483:11:57,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3882/4699119 [24:11<473:58:56,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3883/4699119 [24:12<519:38:51,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3885/4699119 [24:12<425:16:30,  3.07it/s]

last_hidden_state = torch.Size([8, 93, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3886/4699119 [24:13<390:37:02,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3887/4699119 [24:13<424:15:55,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3888/4699119 [24:13<426:38:35,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3889/4699119 [24:14<486:27:03,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3890/4699119 [24:14<513:53:44,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3891/4699119 [24:15<548:54:22,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3892/4699119 [24:15<500:12:15,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3893/4699119 [24:15<486:09:28,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3894/4699119 [24:16<468:55:01,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3895/4699119 [24:16<456:59:04,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3896/4699119 [24:16<507:46:29,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3897/4699119 [24:17<545:01:40,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3898/4699119 [24:17<483:04:12,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3899/4699119 [24:18<526:40:05,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 126, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3900/4699119 [24:18<442:20:56,  2.95it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3901/4699119 [24:18<498:59:58,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 118, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3903/4699119 [24:19<374:32:53,  3.48it/s]

last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3904/4699119 [24:19<451:13:11,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3905/4699119 [24:20<504:19:09,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3906/4699119 [24:20<540:58:58,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3907/4699119 [24:21<569:25:49,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3908/4699119 [24:21<569:17:15,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3909/4699119 [24:22<586:48:49,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3910/4699119 [24:22<525:38:30,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3911/4699119 [24:22<534:11:53,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3912/4699119 [24:23<544:02:57,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3913/4699119 [24:23<544:58:41,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3914/4699119 [24:23<497:33:50,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3915/4699119 [24:24<441:11:01,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 68, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3916/4699119 [24:24<369:02:59,  3.53it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3917/4699119 [24:24<394:18:33,  3.31it/s]

last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3918/4699119 [24:25<428:30:23,  3.04it/s]

last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3919/4699119 [24:25<465:33:18,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3920/4699119 [24:25<432:07:52,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3921/4699119 [24:26<382:51:20,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3922/4699119 [24:26<408:20:07,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3923/4699119 [24:26<404:57:33,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3924/4699119 [24:27<414:28:44,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3925/4699119 [24:27<415:58:26,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3926/4699119 [24:27<466:35:59,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3927/4699119 [24:28<514:57:26,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3928/4699119 [24:28<437:15:45,  2.98it/s]

last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3929/4699119 [24:28<449:07:35,  2.90it/s]

last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3930/4699119 [24:29<463:43:15,  2.81it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3931/4699119 [24:29<512:20:19,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3932/4699119 [24:30<547:33:17,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3933/4699119 [24:30<535:30:41,  2.44it/s]

last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3934/4699119 [24:30<512:30:16,  2.54it/s]

last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3935/4699119 [24:31<462:45:22,  2.82it/s]

last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3936/4699119 [24:31<453:23:21,  2.88it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3937/4699119 [24:31<451:07:23,  2.89it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3938/4699119 [24:32<504:15:11,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3939/4699119 [24:32<470:34:33,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 132, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3940/4699119 [24:32<405:40:16,  3.21it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3941/4699119 [24:33<471:46:35,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3942/4699119 [24:33<470:10:07,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3943/4699119 [24:33<420:43:39,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3944/4699119 [24:34<482:54:24,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3945/4699119 [24:34<428:12:27,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3946/4699119 [24:35<487:30:36,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3947/4699119 [24:35<421:01:45,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3948/4699119 [24:35<482:36:43,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3949/4699119 [24:35<415:17:06,  3.14it/s]

last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3950/4699119 [24:36<415:40:02,  3.14it/s]

last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3951/4699119 [24:36<414:16:29,  3.15it/s]

last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3952/4699119 [24:36<427:21:13,  3.05it/s]

last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3954/4699119 [24:37<413:46:18,  3.15it/s]

last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3955/4699119 [24:37<389:21:16,  3.35it/s]

last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3956/4699119 [24:38<412:42:21,  3.16it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3957/4699119 [24:38<476:47:56,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3958/4699119 [24:38<439:34:41,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3959/4699119 [24:39<398:52:26,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3960/4699119 [24:39<433:12:39,  3.01it/s]

last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3961/4699119 [24:39<401:01:09,  3.25it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3962/4699119 [24:40<469:53:58,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3963/4699119 [24:40<435:09:19,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 108, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3964/4699119 [24:40<371:51:12,  3.51it/s]

last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3965/4699119 [24:41<377:49:21,  3.45it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3966/4699119 [24:41<454:13:56,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3967/4699119 [24:41<457:27:48,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3968/4699119 [24:42<479:41:34,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3969/4699119 [24:42<470:30:03,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3970/4699119 [24:43<516:59:09,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3971/4699119 [24:43<478:43:57,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3972/4699119 [24:43<522:07:08,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3973/4699119 [24:44<554:39:35,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3974/4699119 [24:44<525:28:34,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3975/4699119 [24:45<487:42:13,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3976/4699119 [24:45<517:43:38,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3977/4699119 [24:45<514:13:34,  2.54it/s]

last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 3978/4699119 [24:46<465:28:26,  2.80it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3979/4699119 [24:46<514:42:33,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3980/4699119 [24:47<548:36:04,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3981/4699119 [24:47<571:46:07,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3982/4699119 [24:48<567:18:09,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3983/4699119 [24:48<491:12:24,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3984/4699119 [24:48<474:54:50,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3985/4699119 [24:49<520:55:27,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3986/4699119 [24:49<504:48:15,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3987/4699119 [24:49<453:35:00,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3988/4699119 [24:50<443:20:59,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3989/4699119 [24:50<442:48:55,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3990/4699119 [24:50<454:34:22,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3991/4699119 [24:51<506:35:57,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3992/4699119 [24:51<523:40:40,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3993/4699119 [24:52<555:33:23,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3994/4699119 [24:52<554:30:01,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3995/4699119 [24:53<576:17:25,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3996/4699119 [24:53<530:31:28,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3997/4699119 [24:53<465:55:14,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3998/4699119 [24:53<439:55:41,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 3999/4699119 [24:54<483:38:31,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4000/4699119 [24:54<458:16:32,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4001/4699119 [24:54<447:11:02,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4002/4699119 [24:55<442:03:39,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4003/4699119 [24:55<496:59:48,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4004/4699119 [24:56<523:21:21,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4005/4699119 [24:56<541:25:39,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4006/4699119 [24:56<481:04:16,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4007/4699119 [24:57<532:13:24,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4008/4699119 [24:57<466:39:45,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4009/4699119 [24:57<410:55:39,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4010/4699119 [24:58<475:59:10,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4011/4699119 [24:58<472:46:53,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4012/4699119 [24:59<449:26:42,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4013/4699119 [24:59<398:18:13,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4014/4699119 [24:59<433:20:20,  3.01it/s]

last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4015/4699119 [24:59<438:58:56,  2.97it/s]

last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4016/4699119 [25:00<465:42:40,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4017/4699119 [25:00<435:13:28,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4018/4699119 [25:00<423:07:07,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4019/4699119 [25:01<393:58:38,  3.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4020/4699119 [25:01<441:54:39,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4021/4699119 [25:01<400:38:04,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4022/4699119 [25:02<397:45:29,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4023/4699119 [25:02<426:11:49,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4024/4699119 [25:03<485:57:47,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4025/4699119 [25:03<467:08:16,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4026/4699119 [25:03<515:13:10,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4027/4699119 [25:04<547:58:07,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4028/4699119 [25:04<571:24:25,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4029/4699119 [25:05<528:26:13,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4030/4699119 [25:05<559:06:04,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4031/4699119 [25:05<496:15:31,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4032/4699119 [25:06<534:46:56,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4033/4699119 [25:06<561:41:40,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4034/4699119 [25:07<520:20:24,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4035/4699119 [25:07<504:01:29,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4036/4699119 [25:07<454:45:52,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4037/4699119 [25:08<506:12:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4038/4699119 [25:08<484:04:50,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4039/4699119 [25:09<526:47:57,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4040/4699119 [25:09<453:43:26,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4041/4699119 [25:09<505:15:28,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4042/4699119 [25:09<437:13:49,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4043/4699119 [25:10<390:47:00,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4044/4699119 [25:10<462:11:35,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4045/4699119 [25:11<513:05:54,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4046/4699119 [25:11<534:08:23,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4047/4699119 [25:11<516:22:44,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4048/4699119 [25:12<453:39:36,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4049/4699119 [25:12<461:29:17,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 117, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4051/4699119 [25:13<366:54:28,  3.55it/s]

last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4052/4699119 [25:13<387:51:58,  3.36it/s]

last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4053/4699119 [25:13<415:57:48,  3.14it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4054/4699119 [25:14<480:34:34,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4055/4699119 [25:14<465:10:12,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4056/4699119 [25:14<431:56:02,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4057/4699119 [25:15<437:50:36,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4058/4699119 [25:15<425:33:14,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4059/4699119 [25:15<488:22:42,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4060/4699119 [25:16<466:16:52,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4061/4699119 [25:16<463:02:54,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4062/4699119 [25:16<474:19:48,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4063/4699119 [25:17<450:18:11,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4064/4699119 [25:17<507:56:42,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4065/4699119 [25:18<548:14:11,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4066/4699119 [25:18<482:42:45,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4067/4699119 [25:18<456:33:35,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4068/4699119 [25:19<400:08:41,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 154, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4069/4699119 [25:19<359:54:43,  3.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4070/4699119 [25:19<359:57:02,  3.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4071/4699119 [25:19<359:09:11,  3.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4072/4699119 [25:20<441:41:10,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4073/4699119 [25:20<457:22:19,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4074/4699119 [25:20<435:52:29,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4075/4699119 [25:21<467:16:47,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4076/4699119 [25:21<502:21:25,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4077/4699119 [25:22<539:46:53,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4078/4699119 [25:22<477:44:25,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4079/4699119 [25:23<522:09:54,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4080/4699119 [25:23<553:18:17,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4081/4699119 [25:23<479:18:07,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4082/4699119 [25:23<436:09:36,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4083/4699119 [25:24<494:11:59,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4084/4699119 [25:24<533:44:22,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4085/4699119 [25:25<500:28:33,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4086/4699119 [25:25<507:07:45,  2.57it/s]

last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4087/4699119 [25:25<475:39:18,  2.74it/s]

last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4088/4699119 [25:26<523:18:23,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4089/4699119 [25:26<459:46:26,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4090/4699119 [25:27<452:19:09,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4091/4699119 [25:27<446:58:35,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4092/4699119 [25:27<446:00:51,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4093/4699119 [25:27<396:36:56,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4094/4699119 [25:28<465:13:32,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 120, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4095/4699119 [25:28<397:39:17,  3.28it/s]

last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4096/4699119 [25:28<413:40:39,  3.15it/s]

last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4097/4699119 [25:29<434:32:58,  3.00it/s]

last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4098/4699119 [25:29<485:35:31,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4099/4699119 [25:30<484:31:45,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4100/4699119 [25:30<442:52:16,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4101/4699119 [25:30<474:45:22,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4102/4699119 [25:31<530:52:22,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4103/4699119 [25:31<507:04:45,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4104/4699119 [25:31<460:26:51,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4105/4699119 [25:32<511:06:47,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4106/4699119 [25:32<516:52:33,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4107/4699119 [25:33<495:17:49,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4108/4699119 [25:33<471:45:04,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4109/4699119 [25:33<439:04:35,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4110/4699119 [25:34<485:35:02,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4111/4699119 [25:34<532:26:13,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4112/4699119 [25:35<525:22:40,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4113/4699119 [25:35<553:54:16,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4114/4699119 [25:36<577:22:24,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4115/4699119 [25:36<530:18:54,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4116/4699119 [25:36<474:54:43,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4117/4699119 [25:36<433:07:23,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4118/4699119 [25:37<407:39:31,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4119/4699119 [25:37<449:52:20,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4120/4699119 [25:38<465:25:26,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4121/4699119 [25:38<464:10:55,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4122/4699119 [25:38<429:38:58,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4123/4699119 [25:38<397:21:54,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4124/4699119 [25:39<468:30:25,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4125/4699119 [25:39<516:11:44,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4126/4699119 [25:40<550:35:07,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4127/4699119 [25:40<556:35:01,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4128/4699119 [25:41<531:52:33,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4129/4699119 [25:41<505:59:36,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4130/4699119 [25:41<485:40:08,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4131/4699119 [25:42<529:08:03,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4132/4699119 [25:42<507:25:54,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4133/4699119 [25:42<490:32:24,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4134/4699119 [25:43<437:11:57,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4135/4699119 [25:43<438:30:44,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4136/4699119 [25:43<452:21:48,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4137/4699119 [25:44<460:24:16,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4138/4699119 [25:44<510:10:53,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4139/4699119 [25:45<480:58:32,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4140/4699119 [25:45<524:14:15,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4141/4699119 [25:45<527:11:56,  2.47it/s]

last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4142/4699119 [25:46<479:07:18,  2.72it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4143/4699119 [25:46<523:08:46,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4144/4699119 [25:47<554:15:19,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 118, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4145/4699119 [25:47<460:32:37,  2.83it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4146/4699119 [25:47<509:42:25,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4147/4699119 [25:48<544:07:13,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4148/4699119 [25:48<490:29:52,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4149/4699119 [25:49<480:23:29,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4150/4699119 [25:49<524:01:02,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4151/4699119 [25:49<503:29:22,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4152/4699119 [25:50<528:05:46,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4153/4699119 [25:50<485:45:51,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4154/4699119 [25:51<528:03:31,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4155/4699119 [25:51<470:11:08,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4156/4699119 [25:51<434:46:04,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4157/4699119 [25:51<427:47:34,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 116, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4158/4699119 [25:52<372:31:03,  3.50it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4159/4699119 [25:52<448:37:50,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4160/4699119 [25:52<457:20:28,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4161/4699119 [25:53<457:15:25,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4162/4699119 [25:53<513:34:28,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4163/4699119 [25:54<521:51:42,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 115, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4164/4699119 [25:54<437:15:43,  2.98it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4165/4699119 [25:54<442:59:44,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4166/4699119 [25:55<498:15:26,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4167/4699119 [25:55<537:35:42,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4168/4699119 [25:56<561:07:46,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4169/4699119 [25:56<580:38:39,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4170/4699119 [25:57<594:32:01,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4171/4699119 [25:57<555:40:27,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4172/4699119 [25:57<577:18:34,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4173/4699119 [25:58<556:26:42,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4174/4699119 [25:58<517:12:50,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4175/4699119 [25:59<538:24:22,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4176/4699119 [25:59<511:27:25,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4177/4699119 [25:59<535:09:04,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4178/4699119 [26:00<537:16:02,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4179/4699119 [26:00<564:19:58,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4180/4699119 [26:01<542:05:35,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4181/4699119 [26:01<496:55:24,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4182/4699119 [26:01<542:23:29,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4183/4699119 [26:02<510:48:05,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4184/4699119 [26:02<492:58:08,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4185/4699119 [26:03<533:05:08,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4186/4699119 [26:03<470:45:15,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4187/4699119 [26:03<419:36:31,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4188/4699119 [26:04<481:35:18,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4189/4699119 [26:04<503:22:03,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 506, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4190/4699119 [26:05<540:57:08,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4191/4699119 [26:05<471:34:57,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4192/4699119 [26:05<471:54:01,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 74, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4193/4699119 [26:05<390:47:56,  3.34it/s]

last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4194/4699119 [26:06<422:52:13,  3.08it/s]

last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4195/4699119 [26:06<463:41:35,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4196/4699119 [26:07<513:33:19,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4197/4699119 [26:07<531:00:15,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4198/4699119 [26:07<536:49:16,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 173, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4199/4699119 [26:08<460:22:28,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4200/4699119 [26:08<460:30:34,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4201/4699119 [26:08<416:28:09,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4202/4699119 [26:09<444:26:18,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4203/4699119 [26:09<427:29:51,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4204/4699119 [26:09<487:48:09,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4205/4699119 [26:10<462:25:04,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4206/4699119 [26:10<402:35:49,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4207/4699119 [26:10<469:53:18,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4208/4699119 [26:11<510:07:28,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4209/4699119 [26:11<461:10:03,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4210/4699119 [26:12<512:25:49,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4211/4699119 [26:12<520:10:14,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4212/4699119 [26:13<553:08:52,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4213/4699119 [26:13<538:57:10,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4214/4699119 [26:13<547:36:00,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4215/4699119 [26:14<527:26:59,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4216/4699119 [26:14<474:11:44,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4217/4699119 [26:14<433:10:06,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4218/4699119 [26:15<469:32:21,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 103, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4219/4699119 [26:15<395:48:59,  3.29it/s]

last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4220/4699119 [26:15<378:27:16,  3.45it/s]

last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4221/4699119 [26:15<401:10:45,  3.25it/s]

last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4222/4699119 [26:16<455:57:59,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4223/4699119 [26:16<460:00:48,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4224/4699119 [26:17<426:03:59,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4225/4699119 [26:17<415:03:12,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4226/4699119 [26:17<408:46:40,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4227/4699119 [26:17<364:03:56,  3.58it/s]

last_hidden_state = torch.Size([8, 487, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4228/4699119 [26:18<444:32:03,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4229/4699119 [26:18<447:16:40,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4230/4699119 [26:19<501:05:19,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4231/4699119 [26:19<544:55:10,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 172, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4232/4699119 [26:19<465:48:15,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4233/4699119 [26:20<514:27:20,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4234/4699119 [26:20<466:17:02,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4235/4699119 [26:21<488:27:54,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 155, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4236/4699119 [26:21<421:46:35,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4237/4699119 [26:21<466:22:35,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4238/4699119 [26:21<456:56:51,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4239/4699119 [26:22<508:20:03,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4240/4699119 [26:22<518:10:38,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4241/4699119 [26:23<552:36:12,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4242/4699119 [26:23<519:50:01,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4243/4699119 [26:24<491:53:00,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4244/4699119 [26:24<524:48:34,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4245/4699119 [26:24<482:34:15,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4246/4699119 [26:25<464:58:54,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4247/4699119 [26:25<490:28:38,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4248/4699119 [26:26<531:18:56,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4249/4699119 [26:26<561:44:52,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4250/4699119 [26:26<485:44:03,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4251/4699119 [26:27<527:41:43,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4252/4699119 [26:27<541:21:21,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4253/4699119 [26:28<566:43:17,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4254/4699119 [26:28<587:51:27,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4255/4699119 [26:28<517:40:26,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4256/4699119 [26:29<532:26:11,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4257/4699119 [26:29<487:27:16,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4258/4699119 [26:30<529:35:41,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4259/4699119 [26:30<465:30:38,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4260/4699119 [26:30<513:48:36,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4261/4699119 [26:31<547:54:54,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4262/4699119 [26:31<533:33:46,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4263/4699119 [26:32<562:19:58,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4264/4699119 [26:32<500:32:28,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4265/4699119 [26:32<539:07:15,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4266/4699119 [26:33<565:09:32,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4267/4699119 [26:33<583:17:17,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4268/4699119 [26:34<558:58:41,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4269/4699119 [26:34<561:10:19,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4270/4699119 [26:35<531:17:29,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4271/4699119 [26:35<495:41:29,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4272/4699119 [26:35<535:01:02,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4273/4699119 [26:36<478:23:59,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4274/4699119 [26:36<438:54:49,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4275/4699119 [26:36<483:19:19,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4276/4699119 [26:37<459:30:27,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4277/4699119 [26:37<509:35:30,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4278/4699119 [26:38<508:04:10,  2.57it/s]

last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4279/4699119 [26:38<527:33:56,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4280/4699119 [26:38<557:05:34,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4281/4699119 [26:39<544:15:58,  2.40it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4282/4699119 [26:39<568:39:21,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4283/4699119 [26:40<530:35:29,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4284/4699119 [26:40<559:07:20,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4285/4699119 [26:41<578:43:20,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4286/4699119 [26:41<592:34:21,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4287/4699119 [26:41<570:54:13,  2.28it/s]

last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4288/4699119 [26:42<587:13:06,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4289/4699119 [26:42<547:50:47,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4290/4699119 [26:43<521:39:59,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4291/4699119 [26:43<460:47:09,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4292/4699119 [26:43<446:51:57,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4293/4699119 [26:44<500:59:30,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4294/4699119 [26:44<447:20:32,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4295/4699119 [26:44<488:59:45,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4296/4699119 [26:45<512:52:45,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4297/4699119 [26:45<547:54:59,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4298/4699119 [26:46<571:17:36,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4299/4699119 [26:46<519:07:36,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4300/4699119 [26:46<487:57:07,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4301/4699119 [26:47<439:39:06,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4302/4699119 [26:47<443:28:52,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4303/4699119 [26:47<462:53:45,  2.82it/s]

last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4304/4699119 [26:48<450:54:29,  2.89it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4305/4699119 [26:48<503:28:41,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4306/4699119 [26:49<487:08:17,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4307/4699119 [26:49<529:42:49,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4308/4699119 [26:50<559:02:57,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4309/4699119 [26:50<492:08:43,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4310/4699119 [26:50<472:45:24,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4311/4699119 [26:50<424:10:41,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4312/4699119 [26:51<449:47:45,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4313/4699119 [26:51<405:16:19,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4314/4699119 [26:51<419:44:07,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4315/4699119 [26:52<469:01:25,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4316/4699119 [26:52<506:05:10,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4317/4699119 [26:53<545:16:01,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4318/4699119 [26:53<486:22:42,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4319/4699119 [26:53<528:07:17,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4320/4699119 [26:54<541:15:11,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4321/4699119 [26:54<569:54:41,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4322/4699119 [26:55<587:04:06,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4323/4699119 [26:55<600:25:10,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4324/4699119 [26:56<586:13:39,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4325/4699119 [26:56<503:01:11,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4326/4699119 [26:56<427:45:10,  3.05it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4327/4699119 [26:57<486:29:32,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4328/4699119 [26:57<510:03:45,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4329/4699119 [26:58<548:28:30,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4330/4699119 [26:58<540:07:40,  2.41it/s]

last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4331/4699119 [26:58<554:21:27,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4332/4699119 [26:59<551:00:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4333/4699119 [26:59<536:26:30,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4334/4699119 [27:00<564:06:17,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4335/4699119 [27:00<504:23:31,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4336/4699119 [27:01<541:17:56,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4337/4699119 [27:01<522:56:40,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4338/4699119 [27:01<469:50:21,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4339/4699119 [27:01<434:46:34,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4340/4699119 [27:02<493:07:55,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4341/4699119 [27:02<532:05:56,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4342/4699119 [27:03<504:27:14,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4343/4699119 [27:03<541:02:45,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4344/4699119 [27:04<550:27:30,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4345/4699119 [27:04<574:57:16,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4346/4699119 [27:05<592:18:29,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4347/4699119 [27:05<553:21:09,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4348/4699119 [27:05<505:16:11,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4349/4699119 [27:05<451:11:51,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4350/4699119 [27:06<411:43:44,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4351/4699119 [27:06<439:49:23,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4352/4699119 [27:06<441:37:08,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4353/4699119 [27:07<399:59:08,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4354/4699119 [27:07<416:44:08,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4355/4699119 [27:08<479:41:44,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4356/4699119 [27:08<508:36:18,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4357/4699119 [27:08<543:12:33,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4358/4699119 [27:09<568:42:54,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4359/4699119 [27:09<564:49:28,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4360/4699119 [27:10<539:55:47,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4361/4699119 [27:10<567:35:11,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4362/4699119 [27:11<586:15:59,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4363/4699119 [27:11<528:50:30,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4364/4699119 [27:11<523:56:38,  2.49it/s]

last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4365/4699119 [27:12<470:39:03,  2.77it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4366/4699119 [27:12<517:44:29,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4367/4699119 [27:13<507:44:59,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4368/4699119 [27:13<496:32:02,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4369/4699119 [27:13<536:21:46,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4370/4699119 [27:14<508:01:55,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4371/4699119 [27:14<543:56:39,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4372/4699119 [27:15<570:03:15,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4373/4699119 [27:15<576:40:44,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4374/4699119 [27:16<590:53:45,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4375/4699119 [27:16<566:09:34,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4376/4699119 [27:16<584:42:11,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4377/4699119 [27:17<596:54:23,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4378/4699119 [27:17<547:06:13,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4379/4699119 [27:18<553:30:41,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4380/4699119 [27:18<523:52:45,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4381/4699119 [27:18<502:58:41,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4382/4699119 [27:19<539:40:53,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4383/4699119 [27:19<530:25:17,  2.46it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4384/4699119 [27:20<560:02:32,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4385/4699119 [27:20<568:37:21,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4386/4699119 [27:21<593:00:25,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4387/4699119 [27:21<511:54:13,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4388/4699119 [27:21<551:49:07,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4389/4699119 [27:22<578:01:14,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4390/4699119 [27:22<530:24:28,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4391/4699119 [27:23<496:07:53,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4392/4699119 [27:23<485:16:09,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4393/4699119 [27:23<529:42:54,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4394/4699119 [27:24<559:46:55,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 186, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4395/4699119 [27:24<482:28:53,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4396/4699119 [27:24<478:09:44,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4397/4699119 [27:25<461:11:06,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4398/4699119 [27:25<495:25:53,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4399/4699119 [27:26<449:13:11,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4400/4699119 [27:26<440:13:43,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4401/4699119 [27:26<502:26:49,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4402/4699119 [27:27<539:10:57,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4403/4699119 [27:27<571:17:22,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4404/4699119 [27:28<571:07:07,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4405/4699119 [27:28<507:12:15,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4406/4699119 [27:29<543:57:20,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4407/4699119 [27:29<568:36:44,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4408/4699119 [27:29<585:46:38,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4409/4699119 [27:30<571:37:26,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4410/4699119 [27:30<587:20:24,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4411/4699119 [27:31<554:42:29,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4412/4699119 [27:31<577:35:18,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4413/4699119 [27:32<531:25:21,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4414/4699119 [27:32<469:46:03,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4415/4699119 [27:32<516:22:04,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4416/4699119 [27:33<549:29:52,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4417/4699119 [27:33<572:37:03,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4418/4699119 [27:34<588:43:06,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4419/4699119 [27:34<517:18:33,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4420/4699119 [27:34<555:17:58,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4421/4699119 [27:35<495:02:01,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4422/4699119 [27:35<494:55:45,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4423/4699119 [27:35<446:44:44,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4424/4699119 [27:36<456:04:14,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4425/4699119 [27:36<399:00:47,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4426/4699119 [27:36<415:13:04,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4427/4699119 [27:37<434:05:00,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4429/4699119 [27:37<358:24:16,  3.64it/s]

last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4430/4699119 [27:37<395:25:03,  3.30it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4431/4699119 [27:38<416:55:32,  3.13it/s]

last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4432/4699119 [27:38<408:09:51,  3.19it/s]

last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4434/4699119 [27:39<373:55:11,  3.49it/s]

last_hidden_state = torch.Size([8, 96, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4435/4699119 [27:39<447:29:30,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4436/4699119 [27:39<441:21:28,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 167, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4437/4699119 [27:40<393:06:51,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4438/4699119 [27:40<463:05:23,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4439/4699119 [27:40<429:19:05,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4441/4699119 [27:41<411:28:11,  3.17it/s]

last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4442/4699119 [27:41<387:48:37,  3.36it/s]

last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4443/4699119 [27:42<451:06:21,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4444/4699119 [27:42<503:07:35,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4445/4699119 [27:43<539:59:23,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4446/4699119 [27:43<553:41:18,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4447/4699119 [27:43<487:05:58,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4448/4699119 [27:44<497:49:34,  2.62it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4449/4699119 [27:44<537:54:13,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4450/4699119 [27:45<453:04:10,  2.88it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4451/4699119 [27:45<505:35:28,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4452/4699119 [27:45<503:23:33,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4453/4699119 [27:46<540:28:06,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4454/4699119 [27:46<565:59:56,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4455/4699119 [27:47<584:05:18,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4456/4699119 [27:47<597:01:56,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4457/4699119 [27:48<510:41:41,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4458/4699119 [27:48<508:45:38,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4459/4699119 [27:48<545:11:14,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4460/4699119 [27:49<537:14:12,  2.43it/s]

last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4461/4699119 [27:49<501:07:03,  2.60it/s]

last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4462/4699119 [27:49<454:09:56,  2.87it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4463/4699119 [27:50<505:46:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4464/4699119 [27:50<541:46:06,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 420, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4465/4699119 [27:51<540:23:39,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4466/4699119 [27:51<481:42:52,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4467/4699119 [27:51<473:35:11,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4468/4699119 [27:52<467:43:57,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4469/4699119 [27:52<454:52:09,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4470/4699119 [27:52<480:25:38,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4471/4699119 [27:53<437:08:46,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4472/4699119 [27:53<493:32:48,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4473/4699119 [27:54<505:42:00,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4474/4699119 [27:54<507:44:43,  2.57it/s]

last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4475/4699119 [27:54<462:56:05,  2.82it/s]

last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4476/4699119 [27:55<485:18:42,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4477/4699119 [27:55<475:14:53,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4478/4699119 [27:56<521:55:20,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4479/4699119 [27:56<503:23:24,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4480/4699119 [27:56<459:38:39,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4481/4699119 [27:56<446:36:26,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4482/4699119 [27:57<389:40:32,  3.35it/s]

last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4483/4699119 [27:57<458:35:24,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4484/4699119 [27:58<470:21:10,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4485/4699119 [27:58<434:05:38,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4486/4699119 [27:58<491:56:30,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4487/4699119 [27:59<532:13:24,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4488/4699119 [27:59<560:08:53,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 167, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4489/4699119 [27:59<476:54:29,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4490/4699119 [28:00<520:23:09,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4491/4699119 [28:00<540:53:02,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4492/4699119 [28:01<543:44:21,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4494/4699119 [28:01<454:01:01,  2.87it/s]

last_hidden_state = torch.Size([8, 150, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4495/4699119 [28:02<453:44:03,  2.87it/s]

last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4496/4699119 [28:02<468:39:56,  2.78it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4497/4699119 [28:03<516:34:29,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4498/4699119 [28:03<444:59:31,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4499/4699119 [28:03<449:48:13,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4500/4699119 [28:04<502:01:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4501/4699119 [28:04<505:03:03,  2.58it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4502/4699119 [28:05<542:03:24,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4503/4699119 [28:05<471:53:36,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4504/4699119 [28:05<501:02:27,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4505/4699119 [28:06<476:16:20,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4506/4699119 [28:06<471:09:41,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4507/4699119 [28:06<498:51:19,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4508/4699119 [28:07<535:58:22,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4509/4699119 [28:07<492:44:41,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4510/4699119 [28:08<536:33:19,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4511/4699119 [28:08<475:06:12,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4512/4699119 [28:08<416:06:12,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4513/4699119 [28:08<398:03:40,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4514/4699119 [28:09<444:49:26,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4515/4699119 [28:09<470:13:31,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4516/4699119 [28:09<408:16:55,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4517/4699119 [28:10<435:35:35,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4518/4699119 [28:10<495:18:07,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4519/4699119 [28:11<490:14:23,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4520/4699119 [28:11<530:59:38,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4521/4699119 [28:11<496:31:35,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4522/4699119 [28:12<471:56:48,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4523/4699119 [28:12<447:52:19,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4524/4699119 [28:13<501:08:31,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4525/4699119 [28:13<434:45:40,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4526/4699119 [28:13<489:14:21,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4527/4699119 [28:13<453:19:23,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4528/4699119 [28:14<408:16:48,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4529/4699119 [28:14<472:50:12,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4530/4699119 [28:15<459:11:48,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4531/4699119 [28:15<487:50:06,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4532/4699119 [28:15<529:24:23,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4533/4699119 [28:16<449:47:48,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4534/4699119 [28:16<399:36:51,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4535/4699119 [28:16<442:53:19,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4536/4699119 [28:17<419:39:38,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4537/4699119 [28:17<481:22:37,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4538/4699119 [28:18<527:36:48,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4539/4699119 [28:18<475:04:03,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4540/4699119 [28:18<479:36:54,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4541/4699119 [28:19<508:03:08,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4542/4699119 [28:19<476:33:09,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4543/4699119 [28:19<418:37:45,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4544/4699119 [28:19<383:46:01,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4545/4699119 [28:20<434:54:08,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4546/4699119 [28:20<441:00:56,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4547/4699119 [28:21<478:49:48,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4548/4699119 [28:21<522:15:02,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4549/4699119 [28:22<553:10:26,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4550/4699119 [28:22<562:38:02,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 153, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4551/4699119 [28:22<474:14:29,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 179, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4552/4699119 [28:22<422:55:49,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4553/4699119 [28:23<415:53:55,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4554/4699119 [28:23<410:57:30,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4555/4699119 [28:23<475:04:22,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4556/4699119 [28:24<439:59:34,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4557/4699119 [28:24<426:04:12,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4558/4699119 [28:24<405:08:38,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4559/4699119 [28:25<377:12:29,  3.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4560/4699119 [28:25<356:01:15,  3.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4561/4699119 [28:25<350:03:58,  3.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4562/4699119 [28:25<346:09:23,  3.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4563/4699119 [28:26<429:57:10,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4564/4699119 [28:26<452:26:31,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4565/4699119 [28:27<446:23:08,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4566/4699119 [28:27<488:47:47,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4567/4699119 [28:27<487:07:26,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4568/4699119 [28:28<434:49:37,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4569/4699119 [28:28<493:48:58,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4570/4699119 [28:28<496:58:08,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4571/4699119 [28:29<535:40:44,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4572/4699119 [28:29<563:29:28,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4573/4699119 [28:30<495:12:35,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4574/4699119 [28:30<476:49:15,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4575/4699119 [28:30<426:33:05,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4576/4699119 [28:31<408:19:34,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4577/4699119 [28:31<376:17:08,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4578/4699119 [28:31<399:42:19,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4579/4699119 [28:32<455:15:11,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 176, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4580/4699119 [28:32<403:48:12,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4581/4699119 [28:32<397:56:51,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4582/4699119 [28:32<436:53:34,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4583/4699119 [28:33<468:50:40,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4584/4699119 [28:33<470:25:21,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4585/4699119 [28:34<517:03:23,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 121, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4586/4699119 [28:34<434:35:30,  3.00it/s]

last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4587/4699119 [28:34<436:08:29,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4588/4699119 [28:35<494:40:11,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4589/4699119 [28:35<507:58:08,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4590/4699119 [28:36<514:33:17,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4591/4699119 [28:36<476:00:37,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4592/4699119 [28:36<488:54:20,  2.67it/s]

last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4593/4699119 [28:37<508:18:24,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4594/4699119 [28:37<547:47:32,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4595/4699119 [28:38<512:38:44,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 134, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4596/4699119 [28:38<435:33:44,  2.99it/s]

last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4597/4699119 [28:38<477:21:37,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4598/4699119 [28:39<522:27:58,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4599/4699119 [28:39<556:36:59,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4600/4699119 [28:39<506:32:39,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4601/4699119 [28:40<455:10:38,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4602/4699119 [28:40<508:07:27,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4603/4699119 [28:40<451:15:41,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4604/4699119 [28:41<504:00:24,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4605/4699119 [28:41<510:23:48,  2.55it/s]

last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4606/4699119 [28:42<538:29:36,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 179, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4607/4699119 [28:42<467:50:28,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4608/4699119 [28:42<516:56:45,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4609/4699119 [28:43<540:41:32,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4610/4699119 [28:43<512:16:29,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4611/4699119 [28:44<547:44:59,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4612/4699119 [28:44<577:35:07,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4613/4699119 [28:45<509:31:30,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4614/4699119 [28:45<541:43:09,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4615/4699119 [28:45<567:45:02,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4616/4699119 [28:46<555:35:24,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 505, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4617/4699119 [28:46<582:06:07,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4618/4699119 [28:47<550:05:19,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4619/4699119 [28:47<572:38:14,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4620/4699119 [28:48<588:13:33,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4621/4699119 [28:48<599:51:27,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4622/4699119 [28:49<596:38:23,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4623/4699119 [28:49<607:05:01,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4624/4699119 [28:49<543:15:28,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4625/4699119 [28:50<571:38:01,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4626/4699119 [28:50<536:37:13,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4627/4699119 [28:51<500:40:20,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4628/4699119 [28:51<537:54:06,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4629/4699119 [28:52<557:30:26,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4630/4699119 [28:52<487:50:09,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4631/4699119 [28:52<420:37:20,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4632/4699119 [28:52<444:37:57,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4633/4699119 [28:53<427:45:04,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4634/4699119 [28:53<396:30:14,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4635/4699119 [28:53<383:32:49,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4636/4699119 [28:54<429:55:44,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4637/4699119 [28:54<490:06:34,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4638/4699119 [28:54<518:14:05,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4639/4699119 [28:55<517:05:06,  2.52it/s]

last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4640/4699119 [28:55<517:52:23,  2.52it/s]

last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4642/4699119 [28:56<420:00:34,  3.10it/s]

last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4643/4699119 [28:56<482:11:31,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4644/4699119 [28:57<526:39:49,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4645/4699119 [28:57<558:27:04,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4646/4699119 [28:57<482:36:45,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4647/4699119 [28:58<442:55:00,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4648/4699119 [28:58<388:31:26,  3.36it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4650/4699119 [28:59<363:09:48,  3.59it/s]

last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4651/4699119 [28:59<390:53:14,  3.34it/s]

last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4652/4699119 [28:59<393:09:22,  3.32it/s]

last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4653/4699119 [28:59<381:11:22,  3.42it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4654/4699119 [29:00<456:30:23,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4655/4699119 [29:00<472:37:52,  2.76it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4656/4699119 [29:01<519:06:41,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4657/4699119 [29:01<551:00:01,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4658/4699119 [29:02<540:58:32,  2.41it/s]

last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4659/4699119 [29:02<480:15:46,  2.72it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4660/4699119 [29:02<534:49:44,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4661/4699119 [29:03<508:43:23,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4662/4699119 [29:03<549:45:24,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4663/4699119 [29:04<524:13:29,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4664/4699119 [29:04<504:55:52,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4665/4699119 [29:04<541:24:43,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4666/4699119 [29:05<485:53:35,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4667/4699119 [29:05<472:26:46,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4668/4699119 [29:06<518:40:07,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4669/4699119 [29:06<551:17:35,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4670/4699119 [29:07<573:50:15,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4671/4699119 [29:07<592:39:51,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4672/4699119 [29:07<513:31:34,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4673/4699119 [29:08<547:46:01,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4674/4699119 [29:08<519:29:01,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4675/4699119 [29:08<491:27:41,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4676/4699119 [29:09<460:07:41,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4677/4699119 [29:09<455:03:39,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4678/4699119 [29:09<437:11:16,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4679/4699119 [29:10<399:42:03,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4680/4699119 [29:10<467:29:44,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4681/4699119 [29:10<489:27:34,  2.66it/s]

last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4682/4699119 [29:11<445:25:31,  2.93it/s]

last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4683/4699119 [29:11<443:40:09,  2.94it/s]

last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4684/4699119 [29:11<436:38:23,  2.99it/s]

last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4685/4699119 [29:12<467:55:20,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4686/4699119 [29:12<443:42:29,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4687/4699119 [29:13<491:36:54,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 113, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4689/4699119 [29:13<368:47:00,  3.54it/s]

last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4690/4699119 [29:13<389:29:29,  3.35it/s]

last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4691/4699119 [29:14<425:37:27,  3.06it/s]

last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4692/4699119 [29:14<405:40:35,  3.21it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4693/4699119 [29:14<472:01:10,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4694/4699119 [29:15<518:17:29,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4695/4699119 [29:15<504:02:58,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4696/4699119 [29:16<541:38:52,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4697/4699119 [29:16<507:07:16,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4698/4699119 [29:16<484:42:41,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4699/4699119 [29:17<463:53:14,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4700/4699119 [29:17<500:40:57,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4701/4699119 [29:17<444:07:07,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4702/4699119 [29:18<425:52:59,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4703/4699119 [29:18<416:54:20,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4704/4699119 [29:18<370:14:38,  3.52it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4705/4699119 [29:19<447:11:42,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4706/4699119 [29:19<500:24:14,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4707/4699119 [29:20<522:49:58,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4708/4699119 [29:20<507:51:26,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4709/4699119 [29:20<506:12:27,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4710/4699119 [29:21<461:05:13,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4711/4699119 [29:21<500:50:56,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4712/4699119 [29:21<456:30:06,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4713/4699119 [29:22<425:34:20,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4714/4699119 [29:22<382:35:17,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4715/4699119 [29:22<444:36:50,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4716/4699119 [29:23<427:19:48,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4717/4699119 [29:23<439:39:15,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4718/4699119 [29:23<407:03:04,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4719/4699119 [29:23<379:16:46,  3.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4720/4699119 [29:24<366:56:28,  3.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4721/4699119 [29:24<384:14:44,  3.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4722/4699119 [29:25<456:34:54,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4723/4699119 [29:25<508:18:16,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4724/4699119 [29:25<448:35:11,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4725/4699119 [29:26<439:09:27,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4726/4699119 [29:26<445:16:59,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4727/4699119 [29:26<467:15:13,  2.79it/s]

last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4729/4699119 [29:27<405:54:52,  3.21it/s]

last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4730/4699119 [29:27<390:45:56,  3.34it/s]

last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4732/4699119 [29:28<390:47:25,  3.34it/s]

last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4733/4699119 [29:28<463:19:09,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4734/4699119 [29:29<424:47:33,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4735/4699119 [29:29<399:02:41,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4736/4699119 [29:29<375:58:33,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4737/4699119 [29:29<353:09:42,  3.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4738/4699119 [29:30<355:12:58,  3.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4739/4699119 [29:30<413:50:52,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4740/4699119 [29:30<389:51:01,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4741/4699119 [29:30<367:02:19,  3.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4742/4699119 [29:31<373:19:06,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4743/4699119 [29:31<386:21:13,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4744/4699119 [29:31<431:16:04,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4745/4699119 [29:32<446:16:44,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4746/4699119 [29:32<412:52:57,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4747/4699119 [29:32<373:10:04,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4748/4699119 [29:33<404:27:24,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4749/4699119 [29:33<407:21:56,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4750/4699119 [29:33<383:16:26,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4751/4699119 [29:34<404:37:44,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4752/4699119 [29:34<357:10:32,  3.65it/s]

last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4753/4699119 [29:34<384:48:08,  3.39it/s]

last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4754/4699119 [29:34<376:32:04,  3.46it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4755/4699119 [29:35<451:08:55,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4756/4699119 [29:35<503:00:41,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4757/4699119 [29:36<444:10:06,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4758/4699119 [29:36<501:58:40,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4759/4699119 [29:37<523:05:53,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4760/4699119 [29:37<464:41:35,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4761/4699119 [29:37<460:03:58,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4762/4699119 [29:38<475:10:20,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4763/4699119 [29:38<474:16:18,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4764/4699119 [29:38<417:11:20,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4765/4699119 [29:38<398:24:25,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4766/4699119 [29:39<440:45:13,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4767/4699119 [29:39<470:14:49,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4769/4699119 [29:40<380:13:42,  3.43it/s]

last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4770/4699119 [29:40<418:10:18,  3.12it/s]

last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4771/4699119 [29:41<481:58:32,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4772/4699119 [29:41<444:38:37,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4773/4699119 [29:41<501:30:54,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4774/4699119 [29:42<466:06:04,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4775/4699119 [29:42<427:55:21,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4776/4699119 [29:42<457:00:06,  2.85it/s]

last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4778/4699119 [29:43<393:31:58,  3.31it/s]

last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4779/4699119 [29:43<381:48:33,  3.42it/s]

last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4780/4699119 [29:43<402:45:08,  3.24it/s]

last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4781/4699119 [29:44<461:47:06,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4782/4699119 [29:44<458:38:47,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4783/4699119 [29:44<439:45:06,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4784/4699119 [29:45<392:43:02,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4785/4699119 [29:45<463:13:14,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4786/4699119 [29:46<459:23:03,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4787/4699119 [29:46<454:34:30,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4788/4699119 [29:46<410:10:04,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4789/4699119 [29:46<388:47:31,  3.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4790/4699119 [29:47<460:25:04,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4791/4699119 [29:47<451:42:27,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4792/4699119 [29:48<494:36:02,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4793/4699119 [29:48<534:36:05,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4794/4699119 [29:49<533:52:53,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4795/4699119 [29:49<562:31:03,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4796/4699119 [29:49<581:43:56,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4797/4699119 [29:50<526:34:27,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4798/4699119 [29:50<471:48:45,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4799/4699119 [29:51<517:51:52,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4800/4699119 [29:51<534:12:54,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4801/4699119 [29:51<493:55:46,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4802/4699119 [29:52<465:32:55,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4803/4699119 [29:52<427:55:34,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4804/4699119 [29:52<488:03:39,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4805/4699119 [29:53<530:58:26,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4806/4699119 [29:53<541:01:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4807/4699119 [29:53<471:30:01,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4808/4699119 [29:54<517:14:04,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4809/4699119 [29:54<489:39:16,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4810/4699119 [29:55<530:21:44,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4811/4699119 [29:55<515:43:26,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4812/4699119 [29:56<533:32:54,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4813/4699119 [29:56<494:46:07,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4814/4699119 [29:56<511:56:08,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4815/4699119 [29:57<545:55:29,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4816/4699119 [29:57<525:08:41,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4817/4699119 [29:58<553:34:49,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4818/4699119 [29:58<577:03:32,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4819/4699119 [29:58<495:33:15,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4820/4699119 [29:59<535:39:09,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4821/4699119 [29:59<494:21:38,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4822/4699119 [30:00<536:26:47,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4823/4699119 [30:00<557:25:59,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 487, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4824/4699119 [30:01<580:08:07,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4825/4699119 [30:01<592:38:27,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4826/4699119 [30:01<551:44:14,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 122, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4827/4699119 [30:02<459:36:09,  2.84it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4828/4699119 [30:02<455:42:01,  2.86it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4829/4699119 [30:02<507:16:37,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4830/4699119 [30:03<526:12:22,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4831/4699119 [30:03<556:47:00,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4832/4699119 [30:04<577:37:53,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4833/4699119 [30:04<521:54:51,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4834/4699119 [30:05<531:43:23,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4835/4699119 [30:05<491:50:08,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4836/4699119 [30:05<515:47:33,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4837/4699119 [30:06<550:17:28,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4838/4699119 [30:06<511:22:27,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4839/4699119 [30:07<545:42:59,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4840/4699119 [30:07<571:56:50,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4841/4699119 [30:08<587:20:04,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4842/4699119 [30:08<572:06:25,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 144, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4843/4699119 [30:08<478:23:58,  2.73it/s]

last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4845/4699119 [30:09<403:57:20,  3.23it/s]

last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4846/4699119 [30:09<419:50:37,  3.11it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4847/4699119 [30:09<481:15:57,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4848/4699119 [30:10<436:08:51,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4849/4699119 [30:10<424:54:58,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4850/4699119 [30:10<401:16:16,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4851/4699119 [30:11<371:28:36,  3.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4852/4699119 [30:11<430:19:02,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4853/4699119 [30:11<474:55:33,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4854/4699119 [30:12<488:21:23,  2.67it/s]

last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4855/4699119 [30:12<467:14:02,  2.79it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4856/4699119 [30:13<516:25:26,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4857/4699119 [30:13<463:23:41,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4858/4699119 [30:13<481:24:00,  2.71it/s]

last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4859/4699119 [30:14<500:14:05,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4860/4699119 [30:14<477:44:47,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4861/4699119 [30:14<440:22:40,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4862/4699119 [30:15<495:55:16,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4863/4699119 [30:15<463:56:34,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4864/4699119 [30:15<444:12:03,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4865/4699119 [30:16<498:53:38,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4866/4699119 [30:16<537:16:01,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4867/4699119 [30:17<563:55:50,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4868/4699119 [30:17<583:31:12,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4869/4699119 [30:18<550:56:55,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4870/4699119 [30:18<534:21:27,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4871/4699119 [30:19<562:20:47,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4872/4699119 [30:19<501:39:44,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4873/4699119 [30:19<538:34:20,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4874/4699119 [30:20<565:09:07,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4875/4699119 [30:20<540:32:33,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4876/4699119 [30:21<566:45:15,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4877/4699119 [30:21<501:16:27,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4878/4699119 [30:21<505:33:34,  2.58it/s]

last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4879/4699119 [30:22<484:08:47,  2.69it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4881/4699119 [30:22<405:32:34,  3.22it/s]

last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4882/4699119 [30:23<471:15:43,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4883/4699119 [30:23<518:49:05,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 132, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4885/4699119 [30:24<402:41:10,  3.24it/s]

last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4886/4699119 [30:24<470:06:51,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4887/4699119 [30:24<466:20:04,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4888/4699119 [30:25<515:35:35,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4889/4699119 [30:25<550:50:19,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4890/4699119 [30:26<573:43:52,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4891/4699119 [30:26<523:26:11,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4892/4699119 [30:26<482:04:46,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4893/4699119 [30:27<512:06:47,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4894/4699119 [30:27<493:44:51,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4895/4699119 [30:27<435:50:08,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4896/4699119 [30:28<495:31:42,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4897/4699119 [30:28<494:03:44,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4898/4699119 [30:29<438:21:43,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4899/4699119 [30:29<420:53:08,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4900/4699119 [30:29<431:45:37,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4901/4699119 [30:30<422:17:43,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4902/4699119 [30:30<486:53:04,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4903/4699119 [30:30<451:06:41,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4904/4699119 [30:31<406:18:30,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4905/4699119 [30:31<386:12:39,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4906/4699119 [30:31<371:25:41,  3.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4907/4699119 [30:31<397:40:28,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4908/4699119 [30:32<396:29:16,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4909/4699119 [30:32<378:40:04,  3.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4910/4699119 [30:32<391:07:44,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4911/4699119 [30:33<371:19:47,  3.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4912/4699119 [30:33<375:35:51,  3.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4913/4699119 [30:33<450:31:18,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4914/4699119 [30:34<430:39:03,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4915/4699119 [30:34<430:30:32,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4916/4699119 [30:34<437:07:24,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4917/4699119 [30:35<476:29:42,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4918/4699119 [30:35<434:31:44,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4919/4699119 [30:35<472:27:57,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4920/4699119 [30:36<434:17:00,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 134, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4921/4699119 [30:36<380:35:44,  3.43it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4922/4699119 [30:36<454:04:38,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4923/4699119 [30:37<505:58:29,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4924/4699119 [30:37<477:51:54,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4925/4699119 [30:38<523:23:00,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4926/4699119 [30:38<554:41:54,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4927/4699119 [30:38<513:50:58,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4928/4699119 [30:39<465:37:45,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4929/4699119 [30:39<487:00:46,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4930/4699119 [30:39<443:14:39,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4931/4699119 [30:40<429:50:57,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4932/4699119 [30:40<481:02:16,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4933/4699119 [30:40<466:43:21,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4934/4699119 [30:41<462:34:32,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4935/4699119 [30:41<437:39:50,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4936/4699119 [30:41<449:01:43,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 285, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4937/4699119 [30:42<434:08:35,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4938/4699119 [30:42<491:21:49,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4939/4699119 [30:43<447:11:57,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4940/4699119 [30:43<417:57:09,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4941/4699119 [30:43<484:55:41,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4942/4699119 [30:44<528:28:41,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4943/4699119 [30:44<464:34:44,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4944/4699119 [30:44<449:30:01,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4945/4699119 [30:45<451:29:22,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4946/4699119 [30:45<488:27:49,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4947/4699119 [30:45<419:57:34,  3.10it/s]

last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4948/4699119 [30:46<448:18:49,  2.91it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4949/4699119 [30:46<501:39:14,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4950/4699119 [30:47<505:29:38,  2.58it/s]

last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4951/4699119 [30:47<540:15:10,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4952/4699119 [30:47<529:11:24,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4953/4699119 [30:48<557:53:16,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4954/4699119 [30:48<563:16:00,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4955/4699119 [30:49<538:30:06,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4956/4699119 [30:49<553:16:01,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4957/4699119 [30:50<575:34:03,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4958/4699119 [30:50<517:33:16,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4959/4699119 [30:50<482:51:42,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4960/4699119 [30:51<482:59:49,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4961/4699119 [30:51<442:35:21,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4962/4699119 [30:51<477:45:25,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4963/4699119 [30:52<521:48:14,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4964/4699119 [30:52<548:58:16,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4965/4699119 [30:53<572:40:00,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4966/4699119 [30:53<506:53:08,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4967/4699119 [30:53<493:20:32,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4968/4699119 [30:54<469:44:18,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4969/4699119 [30:54<456:45:54,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4970/4699119 [30:54<397:22:08,  3.28it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4971/4699119 [30:55<466:01:30,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 487, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4972/4699119 [30:55<515:09:15,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4973/4699119 [30:56<489:27:17,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4974/4699119 [30:56<518:08:00,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4975/4699119 [30:56<541:19:52,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4976/4699119 [30:57<473:18:25,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4977/4699119 [30:57<493:05:02,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4978/4699119 [30:57<501:58:58,  2.60it/s]

last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4979/4699119 [30:58<469:34:29,  2.78it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4980/4699119 [30:58<516:56:52,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4981/4699119 [30:59<550:22:54,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4982/4699119 [30:59<576:22:17,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4983/4699119 [31:00<592:05:27,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4984/4699119 [31:00<602:06:25,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4985/4699119 [31:01<545:55:58,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4986/4699119 [31:01<572:02:40,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4987/4699119 [31:01<492:38:42,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4988/4699119 [31:02<532:31:20,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4989/4699119 [31:02<560:46:12,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4990/4699119 [31:03<517:39:37,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4991/4699119 [31:03<453:20:54,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4992/4699119 [31:03<481:14:41,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 4993/4699119 [31:04<488:50:19,  2.67it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4994/4699119 [31:04<529:21:55,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4995/4699119 [31:04<517:45:55,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4996/4699119 [31:05<496:28:14,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4997/4699119 [31:05<535:14:37,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4998/4699119 [31:06<567:11:40,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 4999/4699119 [31:06<538:00:35,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5000/4699119 [31:07<565:15:13,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5001/4699119 [31:07<480:51:54,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5002/4699119 [31:07<525:39:46,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5003/4699119 [31:08<540:03:45,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5004/4699119 [31:08<544:24:12,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5005/4699119 [31:09<536:34:48,  2.43it/s]

last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5007/4699119 [31:09<412:16:02,  3.16it/s]

last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5008/4699119 [31:09<405:14:05,  3.22it/s]

last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5009/4699119 [31:10<463:58:36,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5010/4699119 [31:10<468:31:03,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5011/4699119 [31:10<443:08:33,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5012/4699119 [31:11<455:19:07,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5013/4699119 [31:11<505:49:58,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5014/4699119 [31:12<542:15:09,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5015/4699119 [31:12<568:18:42,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5016/4699119 [31:13<586:11:50,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5017/4699119 [31:13<490:18:46,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5018/4699119 [31:13<458:40:54,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5019/4699119 [31:13<425:51:02,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5020/4699119 [31:14<448:51:04,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5021/4699119 [31:14<503:48:20,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 178, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5022/4699119 [31:15<442:34:03,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5023/4699119 [31:15<414:53:12,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5024/4699119 [31:15<397:39:37,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5025/4699119 [31:15<385:38:15,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5026/4699119 [31:16<448:31:41,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5027/4699119 [31:16<502:38:55,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5028/4699119 [31:17<538:46:40,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5029/4699119 [31:17<566:08:13,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5030/4699119 [31:18<584:50:37,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 360, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5031/4699119 [31:18<549:07:45,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5032/4699119 [31:19<564:56:22,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5033/4699119 [31:19<583:22:39,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5034/4699119 [31:19<499:34:04,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5035/4699119 [31:20<537:30:27,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5036/4699119 [31:20<563:57:41,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5037/4699119 [31:21<545:13:40,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5038/4699119 [31:21<570:54:28,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5039/4699119 [31:22<553:46:39,  2.35it/s]

last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5040/4699119 [31:22<535:35:07,  2.43it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5041/4699119 [31:22<564:05:06,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5042/4699119 [31:23<538:41:29,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5043/4699119 [31:23<564:04:42,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5044/4699119 [31:24<582:55:58,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5045/4699119 [31:24<584:02:50,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5046/4699119 [31:24<499:59:08,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5047/4699119 [31:25<468:47:21,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5048/4699119 [31:25<446:35:53,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5049/4699119 [31:25<500:34:51,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5050/4699119 [31:26<505:56:15,  2.58it/s]

last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5051/4699119 [31:26<540:08:17,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 184, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5052/4699119 [31:27<468:21:22,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5053/4699119 [31:27<515:40:12,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5054/4699119 [31:27<505:48:00,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5055/4699119 [31:28<542:27:55,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5056/4699119 [31:28<562:07:35,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5057/4699119 [31:29<533:32:25,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5058/4699119 [31:29<561:52:43,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5059/4699119 [31:29<488:15:52,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5060/4699119 [31:30<529:26:24,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5061/4699119 [31:30<489:40:35,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5062/4699119 [31:31<531:37:42,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5064/4699119 [31:31<461:42:41,  2.82it/s]

last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5065/4699119 [31:32<467:46:23,  2.79it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5066/4699119 [31:32<515:11:35,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5067/4699119 [31:33<535:56:15,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5068/4699119 [31:33<491:05:34,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5069/4699119 [31:33<515:12:53,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5070/4699119 [31:34<466:38:08,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5071/4699119 [31:34<514:16:05,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5072/4699119 [31:34<452:39:37,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5073/4699119 [31:35<453:48:40,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5074/4699119 [31:35<431:42:44,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5075/4699119 [31:35<426:24:59,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5076/4699119 [31:36<404:11:30,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5077/4699119 [31:36<470:43:48,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5078/4699119 [31:37<491:53:54,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5079/4699119 [31:37<453:36:11,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5080/4699119 [31:37<401:52:31,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5081/4699119 [31:37<360:32:15,  3.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5082/4699119 [31:38<440:40:33,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5083/4699119 [31:38<497:12:09,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5084/4699119 [31:38<452:54:50,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5085/4699119 [31:39<458:57:56,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5086/4699119 [31:39<414:51:51,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5087/4699119 [31:39<420:09:39,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 132, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5088/4699119 [31:40<370:12:03,  3.52it/s]

last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5090/4699119 [31:40<377:28:50,  3.45it/s]

last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5091/4699119 [31:41<410:46:02,  3.17it/s]

last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5092/4699119 [31:41<436:59:56,  2.98it/s]

last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5093/4699119 [31:41<472:49:53,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5094/4699119 [31:42<434:21:36,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5095/4699119 [31:42<493:04:47,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5096/4699119 [31:42<489:20:16,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5097/4699119 [31:43<439:47:12,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5098/4699119 [31:43<424:15:58,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5099/4699119 [31:43<433:53:35,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5100/4699119 [31:44<395:11:48,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5101/4699119 [31:44<376:27:10,  3.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5102/4699119 [31:44<364:12:02,  3.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5103/4699119 [31:45<444:11:52,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5104/4699119 [31:45<404:46:20,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5105/4699119 [31:45<403:37:51,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5106/4699119 [31:46<474:16:47,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 297, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5107/4699119 [31:46<458:23:19,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5108/4699119 [31:46<438:36:50,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5109/4699119 [31:47<497:47:36,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5110/4699119 [31:47<536:03:55,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5111/4699119 [31:48<505:44:50,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5112/4699119 [31:48<430:48:07,  3.03it/s]

last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5113/4699119 [31:48<467:00:54,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5114/4699119 [31:49<458:43:45,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5115/4699119 [31:49<510:46:26,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5116/4699119 [31:49<453:06:26,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5117/4699119 [31:50<506:03:17,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 166, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5118/4699119 [31:50<438:20:54,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5119/4699119 [31:50<494:41:57,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5120/4699119 [31:51<497:29:28,  2.62it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5121/4699119 [31:51<535:52:07,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5122/4699119 [31:52<537:01:23,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5123/4699119 [31:52<563:22:12,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5124/4699119 [31:52<494:28:59,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5125/4699119 [31:53<523:30:36,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5126/4699119 [31:53<554:30:11,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5127/4699119 [31:54<549:22:02,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5128/4699119 [31:54<555:13:58,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5129/4699119 [31:55<553:06:51,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5130/4699119 [31:55<467:18:31,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5131/4699119 [31:55<421:25:45,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5132/4699119 [31:56<481:44:14,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5133/4699119 [31:56<524:10:16,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5134/4699119 [31:57<556:07:50,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5135/4699119 [31:57<520:57:10,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5136/4699119 [31:57<481:24:32,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5137/4699119 [31:58<488:57:44,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5138/4699119 [31:58<529:42:36,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5139/4699119 [31:59<557:48:19,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5140/4699119 [31:59<530:44:54,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5141/4699119 [31:59<562:14:22,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5142/4699119 [32:00<540:47:22,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5143/4699119 [32:00<553:03:10,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 100, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5144/4699119 [32:00<453:47:06,  2.87it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5145/4699119 [32:01<505:39:09,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5146/4699119 [32:01<529:52:34,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5147/4699119 [32:02<472:36:16,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5148/4699119 [32:02<478:25:46,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5149/4699119 [32:02<466:28:20,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5150/4699119 [32:03<451:32:29,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 113, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5151/4699119 [32:03<387:33:40,  3.36it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5152/4699119 [32:03<460:11:25,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5153/4699119 [32:03<412:34:14,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5154/4699119 [32:04<476:15:37,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5155/4699119 [32:04<464:00:05,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5156/4699119 [32:05<512:09:19,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5157/4699119 [32:05<524:40:16,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5158/4699119 [32:05<459:14:01,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5159/4699119 [32:06<509:56:17,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5160/4699119 [32:06<489:36:36,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5161/4699119 [32:07<499:21:14,  2.61it/s]

last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5162/4699119 [32:07<477:06:20,  2.73it/s]

last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5163/4699119 [32:07<501:53:34,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5164/4699119 [32:08<538:53:47,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5165/4699119 [32:08<494:20:39,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5166/4699119 [32:09<535:10:49,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5167/4699119 [32:09<547:34:47,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5168/4699119 [32:09<503:31:02,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5169/4699119 [32:10<522:20:54,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5170/4699119 [32:10<553:55:06,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5171/4699119 [32:11<545:14:12,  2.39it/s]

last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5172/4699119 [32:11<534:51:37,  2.44it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5173/4699119 [32:12<563:14:18,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5174/4699119 [32:12<496:06:07,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5175/4699119 [32:12<458:56:48,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5176/4699119 [32:12<458:06:55,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5177/4699119 [32:13<511:28:46,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5178/4699119 [32:13<514:29:42,  2.53it/s]

last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5180/4699119 [32:14<412:46:31,  3.16it/s]

last_hidden_state = torch.Size([8, 128, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5181/4699119 [32:14<422:23:36,  3.09it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5182/4699119 [32:15<484:02:29,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5183/4699119 [32:15<526:56:20,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5184/4699119 [32:16<509:49:28,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5185/4699119 [32:16<504:28:10,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5186/4699119 [32:16<490:12:31,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5187/4699119 [32:17<531:11:34,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5188/4699119 [32:17<560:03:44,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5189/4699119 [32:17<468:15:43,  2.78it/s]

last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5190/4699119 [32:18<464:11:15,  2.81it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5191/4699119 [32:18<512:15:10,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 134, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5192/4699119 [32:18<435:10:33,  3.00it/s]

last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5193/4699119 [32:19<428:55:41,  3.04it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5194/4699119 [32:19<487:36:42,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5195/4699119 [32:20<528:53:51,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5196/4699119 [32:20<557:33:14,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5197/4699119 [32:21<578:52:14,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5198/4699119 [32:21<554:17:09,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5199/4699119 [32:21<523:13:16,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5200/4699119 [32:22<520:30:00,  2.51it/s]

last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5201/4699119 [32:22<503:27:09,  2.59it/s]

last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5202/4699119 [32:23<523:55:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5203/4699119 [32:23<555:48:04,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5204/4699119 [32:24<576:34:21,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5205/4699119 [32:24<530:26:24,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5206/4699119 [32:24<525:10:51,  2.48it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5207/4699119 [32:25<555:15:48,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5208/4699119 [32:25<485:35:45,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5210/4699119 [32:26<426:13:00,  3.06it/s]

last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5211/4699119 [32:26<485:52:41,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5212/4699119 [32:26<482:54:12,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5213/4699119 [32:27<443:45:11,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5214/4699119 [32:27<403:49:54,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5215/4699119 [32:27<374:18:05,  3.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5216/4699119 [32:27<365:42:02,  3.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5217/4699119 [32:28<347:01:43,  3.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5218/4699119 [32:28<343:46:34,  3.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5219/4699119 [32:28<427:22:38,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5220/4699119 [32:29<474:13:07,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5221/4699119 [32:29<519:36:27,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5222/4699119 [32:30<519:34:06,  2.51it/s]

last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5223/4699119 [32:30<464:11:44,  2.81it/s]

last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5225/4699119 [32:31<395:41:31,  3.30it/s]

last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 237, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5226/4699119 [32:31<382:49:53,  3.41it/s]

last_hidden_state = torch.Size([8, 445, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5227/4699119 [32:31<439:04:24,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5228/4699119 [32:32<465:22:52,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5229/4699119 [32:32<492:56:24,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5230/4699119 [32:32<446:26:09,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5231/4699119 [32:33<404:58:36,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5232/4699119 [32:33<471:24:29,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5233/4699119 [32:33<428:51:23,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5234/4699119 [32:34<488:10:45,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5235/4699119 [32:34<444:05:09,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5236/4699119 [32:34<435:07:09,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5237/4699119 [32:35<492:18:09,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5238/4699119 [32:35<507:51:32,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5239/4699119 [32:36<543:48:59,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5240/4699119 [32:36<514:37:38,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5241/4699119 [32:36<493:41:11,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5242/4699119 [32:37<464:27:14,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5243/4699119 [32:37<513:48:38,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5244/4699119 [32:38<547:52:41,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5245/4699119 [32:38<498:23:52,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5246/4699119 [32:38<441:20:59,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5247/4699119 [32:39<418:20:26,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5248/4699119 [32:39<422:35:15,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5249/4699119 [32:39<387:40:12,  3.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5250/4699119 [32:39<348:35:55,  3.74it/s]

last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5251/4699119 [32:40<399:33:06,  3.26it/s]

last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5252/4699119 [32:40<445:56:09,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5253/4699119 [32:41<503:15:15,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5254/4699119 [32:41<500:50:17,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5255/4699119 [32:41<488:27:33,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5256/4699119 [32:42<535:38:59,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5257/4699119 [32:42<516:04:42,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5258/4699119 [32:43<535:04:35,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5259/4699119 [32:43<521:02:23,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5260/4699119 [32:43<469:59:24,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5261/4699119 [32:44<475:57:32,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5262/4699119 [32:44<433:13:47,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5263/4699119 [32:44<489:59:41,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 185, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5264/4699119 [32:45<434:05:18,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5265/4699119 [32:45<439:30:53,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5266/4699119 [32:45<496:10:40,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5267/4699119 [32:46<512:22:00,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5268/4699119 [32:46<483:34:29,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5269/4699119 [32:46<457:19:41,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5270/4699119 [32:47<512:31:01,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5271/4699119 [32:47<473:57:06,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5272/4699119 [32:48<451:59:07,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5273/4699119 [32:48<492:08:45,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5274/4699119 [32:48<476:59:11,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5275/4699119 [32:49<523:07:57,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5276/4699119 [32:49<518:27:45,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5277/4699119 [32:49<466:23:44,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5278/4699119 [32:50<482:17:42,  2.70it/s]

last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5279/4699119 [32:50<461:52:16,  2.82it/s]

last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5280/4699119 [32:50<432:30:20,  3.01it/s]

last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5281/4699119 [32:51<404:50:47,  3.22it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5282/4699119 [32:51<471:31:17,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5283/4699119 [32:52<493:03:20,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5284/4699119 [32:52<534:55:24,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5285/4699119 [32:53<562:41:16,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5286/4699119 [32:53<562:14:58,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5287/4699119 [32:54<581:26:56,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5288/4699119 [32:54<596:16:19,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5289/4699119 [32:54<605:27:04,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5290/4699119 [32:55<557:32:42,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5291/4699119 [32:55<530:15:48,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5292/4699119 [32:55<463:17:34,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5293/4699119 [32:56<455:26:15,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5294/4699119 [32:56<443:02:10,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5295/4699119 [32:56<434:41:41,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5296/4699119 [32:57<493:52:21,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5297/4699119 [32:57<469:28:43,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5298/4699119 [32:57<433:43:14,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5299/4699119 [32:58<404:01:04,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5300/4699119 [32:58<471:10:51,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5301/4699119 [32:59<496:48:34,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5302/4699119 [32:59<442:35:10,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5303/4699119 [32:59<472:13:03,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5304/4699119 [33:00<455:03:25,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5305/4699119 [33:00<396:39:28,  3.29it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5306/4699119 [33:00<465:28:38,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5307/4699119 [33:01<491:20:06,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5308/4699119 [33:01<531:37:17,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5309/4699119 [33:01<472:44:58,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5310/4699119 [33:02<446:34:30,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5311/4699119 [33:02<483:19:02,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5312/4699119 [33:03<525:05:46,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5313/4699119 [33:03<556:01:57,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5314/4699119 [33:03<511:13:50,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5315/4699119 [33:04<474:23:47,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5316/4699119 [33:04<519:11:12,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5317/4699119 [33:04<457:06:47,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5318/4699119 [33:05<455:36:54,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5319/4699119 [33:05<454:08:10,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5320/4699119 [33:06<509:12:00,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5321/4699119 [33:06<543:49:44,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5322/4699119 [33:07<528:35:42,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5323/4699119 [33:07<471:06:44,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5324/4699119 [33:07<454:37:54,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5325/4699119 [33:07<433:23:16,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5326/4699119 [33:08<490:54:01,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 496, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5327/4699119 [33:08<527:11:08,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5328/4699119 [33:09<462:25:05,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5329/4699119 [33:09<454:05:44,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5330/4699119 [33:09<506:12:31,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5331/4699119 [33:10<529:26:04,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5332/4699119 [33:10<500:39:24,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5333/4699119 [33:11<538:43:43,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5334/4699119 [33:11<472:05:05,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5335/4699119 [33:11<478:21:20,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5336/4699119 [33:12<522:58:15,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5337/4699119 [33:12<467:45:24,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5338/4699119 [33:12<435:48:57,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5339/4699119 [33:13<441:45:26,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5340/4699119 [33:13<451:50:42,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5341/4699119 [33:13<504:59:54,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5342/4699119 [33:14<515:15:06,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5343/4699119 [33:14<475:06:09,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5344/4699119 [33:14<436:37:19,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5345/4699119 [33:15<464:09:12,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5346/4699119 [33:15<514:24:47,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5347/4699119 [33:16<466:51:36,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5348/4699119 [33:16<478:11:02,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5349/4699119 [33:16<429:13:28,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5350/4699119 [33:17<441:47:40,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5351/4699119 [33:17<497:51:29,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5352/4699119 [33:17<468:04:10,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5353/4699119 [33:18<480:55:07,  2.71it/s]

last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5354/4699119 [33:18<440:25:44,  2.96it/s]

last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5355/4699119 [33:18<438:46:32,  2.97it/s]

last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5356/4699119 [33:19<430:50:34,  3.03it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5357/4699119 [33:19<433:44:34,  3.01it/s]

last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5358/4699119 [33:20<492:39:13,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5359/4699119 [33:20<502:38:39,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5361/4699119 [33:21<434:24:10,  3.00it/s]

last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5362/4699119 [33:21<491:14:42,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5363/4699119 [33:21<532:10:52,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5364/4699119 [33:22<543:05:33,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5365/4699119 [33:22<569:01:10,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5366/4699119 [33:23<588:19:18,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5367/4699119 [33:23<600:50:19,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5368/4699119 [33:24<554:34:28,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5369/4699119 [33:24<575:33:38,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5370/4699119 [33:25<591:47:15,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5371/4699119 [33:25<524:35:55,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5372/4699119 [33:25<505:28:01,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5373/4699119 [33:26<471:48:47,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5374/4699119 [33:26<463:38:40,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5375/4699119 [33:26<448:20:22,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5376/4699119 [33:27<475:13:09,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 400, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5377/4699119 [33:27<487:02:58,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5378/4699119 [33:28<521:26:53,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5379/4699119 [33:28<498:19:42,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5380/4699119 [33:28<484:11:59,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5381/4699119 [33:28<432:14:13,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5382/4699119 [33:29<456:41:44,  2.85it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5383/4699119 [33:29<508:06:10,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5384/4699119 [33:30<507:42:33,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5385/4699119 [33:30<461:36:33,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5386/4699119 [33:30<439:28:55,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 444, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5387/4699119 [33:31<475:17:53,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5388/4699119 [33:31<521:38:42,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5389/4699119 [33:32<558:08:08,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5390/4699119 [33:32<547:24:42,  2.38it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5391/4699119 [33:33<571:29:48,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5392/4699119 [33:33<556:38:44,  2.34it/s]

last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5393/4699119 [33:33<572:20:07,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5394/4699119 [33:34<517:24:18,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5395/4699119 [33:34<470:07:18,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5396/4699119 [33:34<482:48:20,  2.70it/s]

last_hidden_state = torch.Size([8, 244, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5397/4699119 [33:35<443:34:23,  2.94it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5398/4699119 [33:35<451:28:14,  2.89it/s]

last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5399/4699119 [33:35<454:19:12,  2.87it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5400/4699119 [33:36<506:39:00,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5401/4699119 [33:36<427:17:25,  3.05it/s]

last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5402/4699119 [33:36<403:52:50,  3.23it/s]

last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5403/4699119 [33:37<445:00:10,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5404/4699119 [33:37<449:13:08,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5405/4699119 [33:38<502:19:09,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5406/4699119 [33:38<518:01:49,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5407/4699119 [33:38<470:45:54,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5408/4699119 [33:39<458:36:55,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 447, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5409/4699119 [33:39<492:19:50,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5410/4699119 [33:39<438:54:36,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 462, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5411/4699119 [33:40<478:46:52,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5412/4699119 [33:40<477:40:50,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5413/4699119 [33:41<522:15:57,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5414/4699119 [33:41<455:26:18,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5415/4699119 [33:41<402:54:49,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 507, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5416/4699119 [33:42<475:00:40,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5417/4699119 [33:42<520:37:38,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5418/4699119 [33:42<512:38:30,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5419/4699119 [33:43<478:39:25,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5420/4699119 [33:43<450:26:33,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5421/4699119 [33:43<468:55:38,  2.78it/s]

last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5422/4699119 [33:44<486:44:43,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5423/4699119 [33:44<503:13:15,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5424/4699119 [33:45<540:19:30,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5425/4699119 [33:45<529:38:04,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5426/4699119 [33:46<566:52:44,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5427/4699119 [33:46<497:55:07,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5428/4699119 [33:46<485:10:29,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5429/4699119 [33:46<456:56:37,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5430/4699119 [33:47<410:27:35,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5431/4699119 [33:47<388:38:53,  3.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5432/4699119 [33:47<397:12:03,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5433/4699119 [33:48<402:44:58,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5434/4699119 [33:48<445:12:14,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5435/4699119 [33:48<446:53:05,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 89, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5436/4699119 [33:49<377:20:53,  3.46it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5437/4699119 [33:49<452:39:24,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5438/4699119 [33:49<460:33:43,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5439/4699119 [33:50<488:56:50,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5440/4699119 [33:50<447:31:52,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 481, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5441/4699119 [33:51<502:35:12,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5442/4699119 [33:51<508:16:07,  2.57it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5443/4699119 [33:51<543:01:04,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5444/4699119 [33:52<568:44:30,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5445/4699119 [33:52<500:41:02,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5446/4699119 [33:53<540:41:16,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5447/4699119 [33:53<515:20:36,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5448/4699119 [33:53<502:32:57,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5449/4699119 [33:54<539:55:56,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5450/4699119 [33:54<478:42:10,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5451/4699119 [33:55<493:05:21,  2.64it/s]

fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 127, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5452/4699119 [33:55<418:27:14,  3.12it/s]

last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5453/4699119 [33:55<412:03:14,  3.16it/s]

last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5455/4699119 [33:56<366:28:23,  3.56it/s]

last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5456/4699119 [33:56<445:18:57,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5457/4699119 [33:56<485:35:17,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5458/4699119 [33:57<434:03:28,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5459/4699119 [33:57<457:25:04,  2.85it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5460/4699119 [33:58<508:43:45,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 163, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5461/4699119 [33:58<439:27:36,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5462/4699119 [33:58<413:40:44,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5463/4699119 [33:58<451:32:22,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5464/4699119 [33:59<399:48:28,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5465/4699119 [33:59<455:46:19,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5466/4699119 [33:59<399:00:19,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5467/4699119 [34:00<397:41:23,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5468/4699119 [34:00<445:07:39,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5470/4699119 [34:01<425:47:32,  3.06it/s]

last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5471/4699119 [34:01<424:59:24,  3.07it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5472/4699119 [34:02<484:55:16,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5473/4699119 [34:02<463:15:47,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 494, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5474/4699119 [34:02<509:49:30,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5475/4699119 [34:03<544:26:43,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5476/4699119 [34:03<457:06:12,  2.85it/s]

last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5477/4699119 [34:03<459:12:34,  2.84it/s]

last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5478/4699119 [34:04<482:17:23,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5479/4699119 [34:04<431:51:40,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5480/4699119 [34:04<490:02:34,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5481/4699119 [34:05<534:21:32,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5482/4699119 [34:05<558:45:16,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 97, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5484/4699119 [34:06<415:13:54,  3.14it/s]

last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5486/4699119 [34:06<378:23:04,  3.45it/s]

last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5487/4699119 [34:07<400:19:39,  3.26it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5488/4699119 [34:07<468:47:00,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5489/4699119 [34:07<446:08:27,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 174, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5490/4699119 [34:08<396:54:23,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5491/4699119 [34:08<444:24:22,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5492/4699119 [34:09<477:21:31,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5493/4699119 [34:09<436:03:08,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5494/4699119 [34:09<493:15:30,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5495/4699119 [34:10<533:29:18,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5496/4699119 [34:10<520:22:59,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 411, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5497/4699119 [34:11<526:14:33,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5498/4699119 [34:11<526:30:00,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5499/4699119 [34:11<535:29:39,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 93, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5500/4699119 [34:12<440:01:10,  2.96it/s]

last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5501/4699119 [34:12<457:19:58,  2.85it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5502/4699119 [34:12<507:53:25,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 198, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5503/4699119 [34:13<448:14:56,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 110, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5505/4699119 [34:13<346:54:24,  3.76it/s]

last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5506/4699119 [34:13<367:02:02,  3.55it/s]

last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5507/4699119 [34:14<392:00:56,  3.33it/s]

last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5508/4699119 [34:14<376:02:52,  3.47it/s]

last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5509/4699119 [34:14<410:11:14,  3.18it/s]

last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5510/4699119 [34:15<426:16:40,  3.06it/s]

last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5511/4699119 [34:15<408:39:02,  3.19it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5512/4699119 [34:15<473:36:07,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5513/4699119 [34:16<519:10:42,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5514/4699119 [34:16<551:23:08,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5515/4699119 [34:17<541:46:35,  2.41it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5516/4699119 [34:17<566:44:57,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5517/4699119 [34:18<551:31:35,  2.36it/s]

last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5518/4699119 [34:18<562:12:18,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5519/4699119 [34:19<581:54:14,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5520/4699119 [34:19<526:32:49,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5521/4699119 [34:19<562:20:46,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5522/4699119 [34:20<581:36:37,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5523/4699119 [34:20<517:46:54,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5524/4699119 [34:21<550:03:47,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5525/4699119 [34:21<571:55:40,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5526/4699119 [34:22<588:06:29,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5527/4699119 [34:22<509:48:42,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5528/4699119 [34:22<544:52:27,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5529/4699119 [34:23<568:38:54,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5530/4699119 [34:23<518:02:19,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5531/4699119 [34:23<470:02:17,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5532/4699119 [34:24<464:32:54,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5533/4699119 [34:24<491:51:26,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 64, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5534/4699119 [34:24<401:40:39,  3.25it/s]

last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5535/4699119 [34:25<447:25:34,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5536/4699119 [34:25<502:08:01,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5537/4699119 [34:26<469:04:07,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5538/4699119 [34:26<426:43:01,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5539/4699119 [34:26<473:55:19,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5540/4699119 [34:27<519:50:06,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5541/4699119 [34:27<441:06:58,  2.96it/s]

last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5542/4699119 [34:27<423:58:15,  3.08it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5543/4699119 [34:28<484:37:39,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5544/4699119 [34:28<505:33:05,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5545/4699119 [34:28<472:36:06,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5546/4699119 [34:29<455:31:13,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5547/4699119 [34:29<457:49:49,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5548/4699119 [34:30<506:31:12,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5549/4699119 [34:30<542:12:35,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5550/4699119 [34:31<567:32:28,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5551/4699119 [34:31<526:47:04,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5552/4699119 [34:31<467:56:33,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5553/4699119 [34:31<467:16:17,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 240, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5554/4699119 [34:32<431:11:11,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5555/4699119 [34:32<462:07:38,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5556/4699119 [34:32<454:27:54,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5557/4699119 [34:33<480:27:44,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5558/4699119 [34:33<477:19:27,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5559/4699119 [34:34<522:47:54,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5560/4699119 [34:34<526:49:32,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5561/4699119 [34:35<546:18:17,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5562/4699119 [34:35<570:52:31,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5563/4699119 [34:36<559:55:09,  2.33it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5564/4699119 [34:36<579:40:03,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5565/4699119 [34:36<593:41:24,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5566/4699119 [34:37<582:47:45,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5567/4699119 [34:37<597:15:41,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5568/4699119 [34:38<578:06:31,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5569/4699119 [34:38<562:41:03,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5570/4699119 [34:38<494:11:11,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5571/4699119 [34:39<537:37:46,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5572/4699119 [34:39<564:09:55,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5573/4699119 [34:40<514:39:57,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5574/4699119 [34:40<490:51:43,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 164, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5575/4699119 [34:40<428:04:54,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 155, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5576/4699119 [34:40<383:17:36,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5577/4699119 [34:41<446:00:50,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5578/4699119 [34:41<501:58:59,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5579/4699119 [34:42<501:06:05,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5580/4699119 [34:42<539:07:02,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5581/4699119 [34:43<558:30:36,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5582/4699119 [34:43<578:22:11,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5583/4699119 [34:43<495:39:04,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5584/4699119 [34:44<535:56:20,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5585/4699119 [34:44<512:21:48,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5586/4699119 [34:45<547:32:33,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5587/4699119 [34:45<502:01:08,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5588/4699119 [34:45<444:50:31,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5589/4699119 [34:46<499:08:57,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5590/4699119 [34:46<483:59:02,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5591/4699119 [34:46<454:31:30,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 305, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5592/4699119 [34:47<447:02:37,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5593/4699119 [34:47<501:20:43,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 168, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5594/4699119 [34:47<434:23:56,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5595/4699119 [34:48<492:08:59,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5596/4699119 [34:48<500:05:02,  2.61it/s]

last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5597/4699119 [34:49<539:08:15,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5598/4699119 [34:49<565:40:04,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5599/4699119 [34:50<585:35:11,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5600/4699119 [34:50<585:21:07,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 146, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5601/4699119 [34:50<488:22:02,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5602/4699119 [34:51<461:12:03,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5603/4699119 [34:51<439:08:12,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5604/4699119 [34:51<458:24:41,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5605/4699119 [34:52<468:55:07,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5606/4699119 [34:52<442:50:16,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5607/4699119 [34:52<411:39:23,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5608/4699119 [34:53<382:10:11,  3.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5609/4699119 [34:53<374:50:40,  3.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5610/4699119 [34:53<399:46:01,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5611/4699119 [34:54<468:07:42,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5612/4699119 [34:54<485:09:27,  2.69it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5613/4699119 [34:55<528:09:19,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5614/4699119 [34:55<558:34:37,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5615/4699119 [34:55<548:14:33,  2.38it/s]

last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5616/4699119 [34:56<509:47:44,  2.56it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5617/4699119 [34:56<545:27:41,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 479, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5618/4699119 [34:57<562:44:10,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5619/4699119 [34:57<517:32:35,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5620/4699119 [34:57<471:53:13,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5621/4699119 [34:58<476:59:09,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5622/4699119 [34:58<511:18:59,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5623/4699119 [34:59<545:30:05,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5624/4699119 [34:59<507:06:23,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5625/4699119 [34:59<492:56:44,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5626/4699119 [35:00<460:45:17,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5627/4699119 [35:00<428:30:19,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 490, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5628/4699119 [35:00<485:37:32,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5629/4699119 [35:01<506:21:05,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5631/4699119 [35:01<444:42:49,  2.93it/s]

last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5632/4699119 [35:02<390:00:07,  3.34it/s]

last_hidden_state = torch.Size([8, 146, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5633/4699119 [35:02<440:59:49,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5634/4699119 [35:03<486:26:07,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5635/4699119 [35:03<441:32:33,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5636/4699119 [35:03<409:35:27,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5637/4699119 [35:04<474:36:56,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5638/4699119 [35:04<432:17:37,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5639/4699119 [35:04<439:01:37,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5640/4699119 [35:05<499:27:58,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5641/4699119 [35:05<500:25:45,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5642/4699119 [35:05<512:25:44,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5643/4699119 [35:06<546:26:55,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5644/4699119 [35:06<476:46:04,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5645/4699119 [35:06<478:02:59,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5646/4699119 [35:07<522:49:26,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5647/4699119 [35:07<554:57:26,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5648/4699119 [35:08<513:35:51,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5649/4699119 [35:08<439:00:56,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5650/4699119 [35:08<399:00:11,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5651/4699119 [35:09<433:15:55,  3.01it/s]

last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5653/4699119 [35:09<405:47:32,  3.21it/s]

last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5654/4699119 [35:10<423:52:28,  3.08it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5655/4699119 [35:10<484:31:36,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5656/4699119 [35:10<458:35:17,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5657/4699119 [35:11<510:13:45,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5658/4699119 [35:11<497:37:40,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 488, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5659/4699119 [35:12<531:14:35,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5660/4699119 [35:12<526:53:33,  2.47it/s]

last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5661/4699119 [35:12<484:11:53,  2.69it/s]

last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5662/4699119 [35:13<453:16:31,  2.88it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5663/4699119 [35:13<505:20:34,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5664/4699119 [35:14<542:36:25,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5665/4699119 [35:14<516:44:08,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5666/4699119 [35:14<441:35:52,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5667/4699119 [35:15<496:39:49,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5668/4699119 [35:15<439:50:56,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5669/4699119 [35:15<398:32:02,  3.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5670/4699119 [35:15<413:42:01,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5671/4699119 [35:16<423:30:07,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5672/4699119 [35:16<483:47:30,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5673/4699119 [35:17<482:11:14,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5674/4699119 [35:17<526:50:34,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5675/4699119 [35:18<559:40:34,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5676/4699119 [35:18<579:50:39,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5677/4699119 [35:18<525:19:57,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5678/4699119 [35:19<451:36:23,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5679/4699119 [35:19<505:40:37,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5680/4699119 [35:19<432:33:13,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5681/4699119 [35:20<490:04:42,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5682/4699119 [35:20<534:41:32,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5683/4699119 [35:21<523:47:15,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5684/4699119 [35:21<485:27:37,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 495, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5685/4699119 [35:21<530:24:08,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5686/4699119 [35:22<531:28:46,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5687/4699119 [35:22<545:07:44,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5688/4699119 [35:23<474:29:10,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 155, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5689/4699119 [35:23<412:05:14,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5690/4699119 [35:23<477:50:57,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5691/4699119 [35:24<522:32:40,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5692/4699119 [35:24<495:26:37,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5693/4699119 [35:24<509:16:25,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5694/4699119 [35:25<496:56:28,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5695/4699119 [35:25<467:45:32,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5696/4699119 [35:25<426:57:37,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 146, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5697/4699119 [35:26<378:15:22,  3.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5698/4699119 [35:26<441:20:33,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5699/4699119 [35:26<463:57:01,  2.81it/s]

last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5700/4699119 [35:27<423:04:50,  3.08it/s]

last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5701/4699119 [35:27<414:14:36,  3.15it/s]

last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5702/4699119 [35:27<386:42:22,  3.37it/s]

last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5703/4699119 [35:28<451:13:16,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5704/4699119 [35:28<476:52:46,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5705/4699119 [35:29<521:56:49,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5706/4699119 [35:29<490:21:53,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5707/4699119 [35:29<476:48:33,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5708/4699119 [35:30<477:16:26,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5709/4699119 [35:30<521:43:03,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5710/4699119 [35:30<503:02:10,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5711/4699119 [35:31<425:34:49,  3.06it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5712/4699119 [35:31<485:51:48,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5713/4699119 [35:31<479:34:36,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5714/4699119 [35:32<455:52:58,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5716/4699119 [35:32<425:03:41,  3.07it/s]

last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5717/4699119 [35:33<491:46:55,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5718/4699119 [35:33<449:41:29,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5719/4699119 [35:33<432:04:30,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5720/4699119 [35:34<403:48:00,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5721/4699119 [35:34<437:55:59,  2.98it/s]

last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5722/4699119 [35:34<436:42:44,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5723/4699119 [35:35<494:02:36,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5724/4699119 [35:35<534:04:39,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5725/4699119 [35:36<549:37:03,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 271, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5726/4699119 [35:36<502:46:07,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5727/4699119 [35:36<471:38:16,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5728/4699119 [35:37<406:07:12,  3.21it/s]

last_hidden_state = torch.Size([8, 355, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5730/4699119 [35:37<383:03:28,  3.40it/s]

last_hidden_state = torch.Size([8, 170, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5731/4699119 [35:38<385:04:54,  3.39it/s]

last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5732/4699119 [35:38<394:52:58,  3.30it/s]

last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5733/4699119 [35:38<466:49:00,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5734/4699119 [35:39<513:27:14,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5735/4699119 [35:39<463:22:55,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5736/4699119 [35:39<465:39:46,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5737/4699119 [35:40<467:23:25,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5738/4699119 [35:40<464:15:32,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5739/4699119 [35:41<496:35:46,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5740/4699119 [35:41<445:07:38,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5741/4699119 [35:41<501:04:28,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5742/4699119 [35:42<540:10:31,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5743/4699119 [35:42<456:50:18,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5744/4699119 [35:43<507:55:40,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5745/4699119 [35:43<543:31:51,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5746/4699119 [35:43<568:58:00,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5747/4699119 [35:44<518:12:01,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 437, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5748/4699119 [35:44<532:39:15,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 123, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5749/4699119 [35:44<445:55:03,  2.92it/s]

last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5750/4699119 [35:45<465:03:36,  2.80it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5751/4699119 [35:45<458:48:47,  2.84it/s]

last_hidden_state = torch.Size([8, 510, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5752/4699119 [35:46<510:15:25,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5753/4699119 [35:46<522:44:47,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5754/4699119 [35:47<553:23:55,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5755/4699119 [35:47<575:26:16,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5756/4699119 [35:47<543:28:38,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5757/4699119 [35:48<568:47:14,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5758/4699119 [35:48<585:22:38,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5759/4699119 [35:49<598:03:15,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5760/4699119 [35:49<513:49:33,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5761/4699119 [35:49<514:35:51,  2.53it/s]

last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5762/4699119 [35:50<525:35:14,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5763/4699119 [35:50<533:00:37,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5764/4699119 [35:51<548:22:08,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5765/4699119 [35:51<572:23:35,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5766/4699119 [35:52<516:17:37,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5767/4699119 [35:52<549:05:49,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5768/4699119 [35:52<475:37:29,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5769/4699119 [35:53<450:24:29,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5770/4699119 [35:53<502:01:53,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5771/4699119 [35:53<504:06:31,  2.59it/s]

last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5772/4699119 [35:54<492:25:00,  2.65it/s]

last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5773/4699119 [35:54<463:09:14,  2.81it/s]

last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5774/4699119 [35:55<500:38:55,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5775/4699119 [35:55<441:13:40,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5776/4699119 [35:55<496:21:12,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5777/4699119 [35:55<448:12:53,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5778/4699119 [35:56<506:15:00,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5779/4699119 [35:56<520:38:59,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5780/4699119 [35:57<554:01:20,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5781/4699119 [35:57<492:07:31,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5782/4699119 [35:57<434:32:08,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5783/4699119 [35:58<420:01:43,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5784/4699119 [35:58<430:32:11,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5785/4699119 [35:59<488:56:06,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5786/4699119 [35:59<461:11:48,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5787/4699119 [35:59<511:29:29,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5788/4699119 [36:00<524:38:09,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5789/4699119 [36:00<554:57:19,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5790/4699119 [36:01<512:40:20,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5791/4699119 [36:01<460:13:19,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5792/4699119 [36:01<476:48:57,  2.73it/s]

last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5793/4699119 [36:02<493:32:57,  2.64it/s]

last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5795/4699119 [36:02<406:42:12,  3.21it/s]

last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5796/4699119 [36:03<473:09:40,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5797/4699119 [36:03<425:24:23,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5798/4699119 [36:03<423:37:17,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5799/4699119 [36:04<458:46:10,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5800/4699119 [36:04<510:22:55,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5801/4699119 [36:05<546:05:01,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5802/4699119 [36:05<570:23:58,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5803/4699119 [36:05<502:13:45,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5804/4699119 [36:06<538:48:17,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5805/4699119 [36:06<565:12:49,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5806/4699119 [36:06<502:32:06,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5807/4699119 [36:07<489:02:18,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5808/4699119 [36:07<533:15:39,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5809/4699119 [36:08<480:06:54,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 323, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5810/4699119 [36:08<470:40:12,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5811/4699119 [36:08<510:16:21,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5812/4699119 [36:09<462:54:26,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5813/4699119 [36:09<517:58:39,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5814/4699119 [36:10<511:42:31,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5815/4699119 [36:10<484:56:21,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5816/4699119 [36:10<497:29:33,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5817/4699119 [36:11<456:59:17,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5818/4699119 [36:11<457:58:13,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 169, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5819/4699119 [36:11<405:11:16,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 249, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5820/4699119 [36:11<391:45:52,  3.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5821/4699119 [36:12<410:07:28,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5822/4699119 [36:12<422:41:51,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5823/4699119 [36:13<484:41:10,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5824/4699119 [36:13<432:46:24,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5825/4699119 [36:13<378:57:13,  3.44it/s]

last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5826/4699119 [36:13<374:54:13,  3.48it/s]

last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5827/4699119 [36:14<369:39:25,  3.53it/s]

last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5828/4699119 [36:14<376:17:50,  3.46it/s]

last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5829/4699119 [36:14<433:41:05,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5830/4699119 [36:15<496:03:19,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5831/4699119 [36:15<478:32:51,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5832/4699119 [36:16<522:45:54,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5833/4699119 [36:16<554:05:33,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5834/4699119 [36:16<538:13:28,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5835/4699119 [36:17<564:41:06,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5836/4699119 [36:17<523:11:48,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5837/4699119 [36:18<491:03:57,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5838/4699119 [36:18<467:03:23,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5839/4699119 [36:18<433:53:18,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5840/4699119 [36:18<386:29:27,  3.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5841/4699119 [36:19<463:37:12,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5842/4699119 [36:19<515:03:16,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 450, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5843/4699119 [36:20<531:23:37,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 157, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5844/4699119 [36:20<452:11:01,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5845/4699119 [36:20<455:14:11,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5846/4699119 [36:21<426:00:09,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5847/4699119 [36:21<460:18:14,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5848/4699119 [36:22<515:50:11,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5849/4699119 [36:22<508:11:54,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5850/4699119 [36:22<474:17:27,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 157, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5851/4699119 [36:22<412:36:22,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5852/4699119 [36:23<456:07:59,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5853/4699119 [36:23<508:17:10,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5854/4699119 [36:24<455:50:24,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5855/4699119 [36:24<414:26:04,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5856/4699119 [36:24<367:10:26,  3.55it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5857/4699119 [36:25<446:25:50,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5858/4699119 [36:25<479:54:30,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5859/4699119 [36:25<523:49:18,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5860/4699119 [36:26<461:33:52,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5861/4699119 [36:26<511:02:22,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 163, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5862/4699119 [36:26<442:02:20,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5863/4699119 [36:27<482:37:35,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5864/4699119 [36:27<434:56:10,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5865/4699119 [36:27<411:30:39,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5866/4699119 [36:28<373:36:30,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5867/4699119 [36:28<450:08:38,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5868/4699119 [36:28<449:39:27,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 67, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5869/4699119 [36:29<375:28:37,  3.47it/s]

last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5870/4699119 [36:29<428:47:20,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5871/4699119 [36:29<406:52:11,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5872/4699119 [36:29<379:35:52,  3.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 235, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5873/4699119 [36:30<371:38:35,  3.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5874/4699119 [36:30<354:40:05,  3.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 293, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5875/4699119 [36:30<372:54:33,  3.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5876/4699119 [36:31<390:16:25,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5877/4699119 [36:31<462:29:42,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 473, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5878/4699119 [36:32<503:45:51,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5879/4699119 [36:32<487:11:26,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5880/4699119 [36:32<528:23:23,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5881/4699119 [36:33<500:16:30,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5882/4699119 [36:33<538:01:57,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5883/4699119 [36:34<563:44:31,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 500, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5884/4699119 [36:34<582:40:15,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5885/4699119 [36:34<510:07:45,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5886/4699119 [36:35<545:32:18,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5887/4699119 [36:35<497:22:48,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5888/4699119 [36:36<536:29:51,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5889/4699119 [36:36<500:00:00,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5890/4699119 [36:36<491:09:42,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5891/4699119 [36:37<416:39:12,  3.13it/s]

last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5892/4699119 [36:37<395:31:15,  3.30it/s]

last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5893/4699119 [36:37<415:02:07,  3.14it/s]

last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5894/4699119 [36:37<405:26:08,  3.22it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5895/4699119 [36:38<471:19:17,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5896/4699119 [36:38<517:45:09,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5897/4699119 [36:39<550:29:29,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5898/4699119 [36:39<464:50:48,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 415, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5899/4699119 [36:40<487:41:23,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5900/4699119 [36:40<439:27:53,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5901/4699119 [36:40<433:32:12,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5902/4699119 [36:40<420:21:58,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5903/4699119 [36:41<400:43:25,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5904/4699119 [36:41<471:09:13,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5905/4699119 [36:41<453:02:56,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5906/4699119 [36:42<447:27:41,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5907/4699119 [36:42<502:28:27,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5908/4699119 [36:43<539:19:04,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5909/4699119 [36:43<450:20:20,  2.89it/s]

last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5910/4699119 [36:43<481:30:32,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 246, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5911/4699119 [36:44<443:36:18,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5912/4699119 [36:44<416:59:50,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5913/4699119 [36:44<406:24:28,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5914/4699119 [36:45<399:57:50,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5915/4699119 [36:45<467:21:15,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5916/4699119 [36:46<517:24:32,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5917/4699119 [36:46<550:16:53,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5918/4699119 [36:46<487:04:35,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5919/4699119 [36:47<528:17:10,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5920/4699119 [36:47<557:22:56,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5921/4699119 [36:48<518:36:53,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5922/4699119 [36:48<552:01:09,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 365, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5923/4699119 [36:48<530:30:36,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5924/4699119 [36:49<478:01:48,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5925/4699119 [36:49<487:12:52,  2.68it/s]

last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5926/4699119 [36:49<517:54:18,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 503, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5927/4699119 [36:50<555:19:01,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5928/4699119 [36:50<520:26:44,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5929/4699119 [36:51<560:01:46,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 171, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5930/4699119 [36:51<477:09:03,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5931/4699119 [36:52<521:44:13,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5932/4699119 [36:52<502:01:11,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5933/4699119 [36:52<539:13:43,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 153, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5934/4699119 [36:53<457:06:36,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5935/4699119 [36:53<454:17:20,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5936/4699119 [36:53<505:46:56,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5937/4699119 [36:54<542:05:40,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5938/4699119 [36:54<549:59:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 132, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5939/4699119 [36:54<461:03:04,  2.83it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5940/4699119 [36:55<510:35:16,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5941/4699119 [36:55<543:59:23,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5942/4699119 [36:56<509:30:24,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5943/4699119 [36:56<544:22:20,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5944/4699119 [36:57<532:58:46,  2.45it/s]

last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5945/4699119 [36:57<538:34:35,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5946/4699119 [36:58<565:25:52,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5947/4699119 [36:58<487:15:07,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 467, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5948/4699119 [36:58<520:14:51,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5949/4699119 [36:59<557:11:47,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5950/4699119 [36:59<526:28:04,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5951/4699119 [37:00<556:11:29,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5952/4699119 [37:00<540:41:21,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5953/4699119 [37:00<527:13:20,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5954/4699119 [37:01<454:48:12,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5955/4699119 [37:01<422:16:29,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5956/4699119 [37:01<445:07:28,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5957/4699119 [37:02<477:43:14,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5958/4699119 [37:02<428:33:43,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5959/4699119 [37:02<484:21:40,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5960/4699119 [37:03<514:39:30,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5961/4699119 [37:03<535:49:13,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5962/4699119 [37:04<494:41:45,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5963/4699119 [37:04<489:52:48,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5964/4699119 [37:04<532:38:06,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5965/4699119 [37:05<523:32:20,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5966/4699119 [37:05<503:38:24,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5967/4699119 [37:05<446:09:32,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5968/4699119 [37:06<463:31:24,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5969/4699119 [37:06<513:27:11,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5970/4699119 [37:07<547:10:49,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5971/4699119 [37:07<527:43:53,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5972/4699119 [37:07<475:49:32,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5973/4699119 [37:08<521:24:22,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5974/4699119 [37:08<553:38:15,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5975/4699119 [37:09<487:56:37,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 458, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5976/4699119 [37:09<513:56:25,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 122, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5977/4699119 [37:09<432:28:21,  3.01it/s]

last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5978/4699119 [37:10<438:18:05,  2.97it/s]

last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5979/4699119 [37:10<421:42:27,  3.09it/s]

last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5980/4699119 [37:10<423:22:52,  3.08it/s]

last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5981/4699119 [37:11<424:39:53,  3.07it/s]

last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5982/4699119 [37:11<432:46:34,  3.01it/s]

last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5983/4699119 [37:11<407:14:39,  3.20it/s]

last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5984/4699119 [37:11<404:26:12,  3.22it/s]

last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5985/4699119 [37:12<449:54:54,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 98, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5986/4699119 [37:12<381:17:29,  3.42it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5987/4699119 [37:13<454:23:21,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5988/4699119 [37:13<467:24:27,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5989/4699119 [37:13<515:07:30,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 95, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5991/4699119 [37:14<390:04:32,  3.34it/s]

last_hidden_state = torch.Size([8, 196, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5992/4699119 [37:14<396:11:47,  3.29it/s]

last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 5993/4699119 [37:14<401:32:40,  3.25it/s]

last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5994/4699119 [37:15<443:18:33,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5995/4699119 [37:15<407:47:31,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5996/4699119 [37:15<451:14:37,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5997/4699119 [37:16<460:07:19,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5998/4699119 [37:16<453:17:31,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 5999/4699119 [37:16<437:53:47,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6000/4699119 [37:17<408:17:31,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6001/4699119 [37:17<474:18:35,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6002/4699119 [37:18<493:03:49,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6003/4699119 [37:18<481:32:59,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6004/4699119 [37:18<524:55:54,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6005/4699119 [37:19<499:08:02,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6006/4699119 [37:19<541:00:13,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6007/4699119 [37:20<503:20:58,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6008/4699119 [37:20<445:52:26,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6009/4699119 [37:20<463:57:59,  2.81it/s]

last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6010/4699119 [37:21<469:24:27,  2.78it/s]

last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6011/4699119 [37:21<456:54:48,  2.85it/s]

last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6012/4699119 [37:21<496:18:52,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6013/4699119 [37:22<477:14:39,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6014/4699119 [37:22<514:43:34,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6015/4699119 [37:23<548:01:00,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6016/4699119 [37:23<519:29:18,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6017/4699119 [37:24<551:34:46,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6018/4699119 [37:24<503:40:52,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6019/4699119 [37:24<480:21:39,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6020/4699119 [37:24<452:50:26,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6021/4699119 [37:25<505:23:32,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6022/4699119 [37:25<491:04:57,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 418, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6023/4699119 [37:26<507:09:03,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6024/4699119 [37:26<542:15:17,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6025/4699119 [37:27<554:25:19,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6026/4699119 [37:27<555:18:19,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6027/4699119 [37:28<575:15:09,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6028/4699119 [37:28<590:40:47,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6029/4699119 [37:28<602:13:00,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 448, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6030/4699119 [37:29<589:13:24,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6031/4699119 [37:29<601:44:23,  2.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6032/4699119 [37:30<573:41:31,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6033/4699119 [37:30<591:09:09,  2.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6034/4699119 [37:31<539:22:40,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6035/4699119 [37:31<495:22:50,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6036/4699119 [37:31<426:54:08,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 228, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6037/4699119 [37:31<399:53:32,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6038/4699119 [37:32<378:04:30,  3.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6039/4699119 [37:32<381:20:22,  3.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6040/4699119 [37:32<457:32:07,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 480, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6041/4699119 [37:33<494:36:40,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6042/4699119 [37:33<460:52:23,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6043/4699119 [37:34<510:11:59,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6044/4699119 [37:34<484:39:05,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6045/4699119 [37:34<473:09:00,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6046/4699119 [37:34<410:12:42,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6047/4699119 [37:35<474:56:23,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6048/4699119 [37:35<449:55:11,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 188, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6049/4699119 [37:35<406:18:37,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6050/4699119 [37:36<473:27:03,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6051/4699119 [37:36<444:59:01,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6052/4699119 [37:37<452:27:25,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6053/4699119 [37:37<505:13:55,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6054/4699119 [37:37<444:52:53,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6055/4699119 [37:38<421:02:11,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 147, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6056/4699119 [37:38<374:02:55,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6057/4699119 [37:38<415:21:30,  3.14it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6058/4699119 [37:39<478:11:07,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6059/4699119 [37:39<483:11:12,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 329, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6060/4699119 [37:39<474:03:42,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6061/4699119 [37:40<519:42:13,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6062/4699119 [37:40<513:15:11,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6063/4699119 [37:41<546:48:04,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6064/4699119 [37:41<570:32:43,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6065/4699119 [37:42<540:20:46,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6066/4699119 [37:42<478:57:46,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6067/4699119 [37:42<469:00:40,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6068/4699119 [37:43<506:01:46,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6069/4699119 [37:43<448:00:08,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6070/4699119 [37:43<407:23:50,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6071/4699119 [37:43<425:29:53,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6072/4699119 [37:44<478:42:17,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6073/4699119 [37:44<471:35:00,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6074/4699119 [37:45<425:05:22,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 270, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6075/4699119 [37:45<413:39:51,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6076/4699119 [37:45<477:19:56,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6077/4699119 [37:46<521:37:55,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 378, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6078/4699119 [37:46<513:03:30,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6079/4699119 [37:46<449:18:28,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6080/4699119 [37:47<502:15:00,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 280, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6081/4699119 [37:47<468:23:45,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6082/4699119 [37:47<437:57:45,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 486, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6083/4699119 [37:48<490:30:32,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6084/4699119 [37:48<519:27:15,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 348, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6085/4699119 [37:49<500:39:39,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6086/4699119 [37:49<499:09:55,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 299, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6087/4699119 [37:49<476:32:06,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6088/4699119 [37:50<461:38:01,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6089/4699119 [37:50<422:36:38,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6090/4699119 [37:50<414:39:30,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6091/4699119 [37:51<419:54:36,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6092/4699119 [37:51<487:08:45,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6093/4699119 [37:52<528:30:35,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6094/4699119 [37:52<557:41:04,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6095/4699119 [37:53<578:44:53,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 106, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6096/4699119 [37:53<472:31:58,  2.76it/s]

last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6097/4699119 [37:53<440:23:03,  2.96it/s]

last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6098/4699119 [37:53<437:08:39,  2.98it/s]

last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6099/4699119 [37:54<408:01:13,  3.19it/s]

last_hidden_state = torch.Size([8, 370, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6100/4699119 [37:54<432:13:57,  3.02it/s]

last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6101/4699119 [37:54<422:58:31,  3.08it/s]

last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6102/4699119 [37:55<395:29:05,  3.30it/s]

last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6103/4699119 [37:55<402:15:12,  3.24it/s]

last_hidden_state = torch.Size([8, 366, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6104/4699119 [37:55<423:50:44,  3.08it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6105/4699119 [37:56<485:06:41,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6106/4699119 [37:56<447:28:02,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6107/4699119 [37:56<460:09:24,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6108/4699119 [37:57<413:53:10,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6109/4699119 [37:57<439:39:01,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6110/4699119 [37:57<401:02:51,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 314, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6111/4699119 [37:58<410:10:15,  3.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6112/4699119 [37:58<390:32:59,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6113/4699119 [37:58<349:08:29,  3.73it/s]

last_hidden_state = torch.Size([8, 320, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6114/4699119 [37:58<372:48:38,  3.50it/s]

last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6115/4699119 [37:59<428:03:03,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6116/4699119 [37:59<487:42:08,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6117/4699119 [38:00<505:21:01,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6118/4699119 [38:00<473:42:09,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6119/4699119 [38:00<472:07:19,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6120/4699119 [38:01<450:49:00,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6121/4699119 [38:01<450:57:36,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6122/4699119 [38:01<467:43:54,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6123/4699119 [38:02<517:12:24,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 439, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6124/4699119 [38:02<532:06:28,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6125/4699119 [38:03<560:31:40,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6126/4699119 [38:03<521:18:48,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 392, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6127/4699119 [38:04<515:12:46,  2.53it/s]

last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6128/4699119 [38:04<478:28:42,  2.72it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6129/4699119 [38:04<525:03:32,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6130/4699119 [38:05<486:17:03,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6131/4699119 [38:05<527:24:35,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6132/4699119 [38:05<492:37:00,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6133/4699119 [38:06<446:37:49,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6134/4699119 [38:06<488:40:48,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6135/4699119 [38:06<447:42:20,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6136/4699119 [38:07<501:00:40,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6137/4699119 [38:07<441:08:20,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6138/4699119 [38:08<470:20:04,  2.77it/s]

last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6139/4699119 [38:08<426:18:05,  3.06it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6140/4699119 [38:08<487:26:58,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6141/4699119 [38:09<529:44:45,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6142/4699119 [38:09<560:00:37,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6143/4699119 [38:10<506:00:04,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6144/4699119 [38:10<455:24:18,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6145/4699119 [38:10<447:48:25,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 311, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6146/4699119 [38:10<443:01:55,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 429, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6147/4699119 [38:11<476:25:19,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6148/4699119 [38:11<442:41:27,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6149/4699119 [38:12<498:23:59,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6150/4699119 [38:12<526:57:04,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6151/4699119 [38:13<558:03:40,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6152/4699119 [38:13<566:25:32,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 124, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6153/4699119 [38:13<469:53:59,  2.77it/s]

last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6154/4699119 [38:14<464:57:00,  2.80it/s]

last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6156/4699119 [38:14<402:52:15,  3.24it/s]

last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6158/4699119 [38:15<351:23:16,  3.71it/s]

last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6159/4699119 [38:15<344:12:31,  3.79it/s]

last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6161/4699119 [38:15<347:04:35,  3.76it/s]

last_hidden_state = torch.Size([8, 181, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6162/4699119 [38:16<431:01:59,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6163/4699119 [38:16<479:06:19,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6164/4699119 [38:17<454:11:32,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 215, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6165/4699119 [38:17<418:49:05,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6166/4699119 [38:17<481:02:48,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6167/4699119 [38:18<463:43:20,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6168/4699119 [38:18<500:34:01,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6169/4699119 [38:18<497:53:05,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6170/4699119 [38:19<452:23:13,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6171/4699119 [38:19<435:14:35,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6172/4699119 [38:20<493:13:14,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6173/4699119 [38:20<471:28:45,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 464, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6174/4699119 [38:20<503:37:16,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6175/4699119 [38:21<520:07:28,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6176/4699119 [38:21<514:40:41,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6177/4699119 [38:22<547:57:17,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6178/4699119 [38:22<570:01:19,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6179/4699119 [38:23<586:26:04,  2.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6180/4699119 [38:23<490:28:19,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6181/4699119 [38:23<536:21:06,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6182/4699119 [38:24<542:41:57,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6183/4699119 [38:24<567:52:31,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 308, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6184/4699119 [38:24<526:25:29,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6185/4699119 [38:25<467:17:28,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 267, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6186/4699119 [38:25<443:34:38,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6187/4699119 [38:25<414:09:35,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6188/4699119 [38:26<427:12:36,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 422, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6189/4699119 [38:26<460:32:34,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6190/4699119 [38:27<509:55:51,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6191/4699119 [38:27<477:00:59,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 468, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6192/4699119 [38:27<508:28:52,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6193/4699119 [38:28<480:32:15,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 511, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6194/4699119 [38:28<529:35:27,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6195/4699119 [38:29<559:15:21,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6196/4699119 [38:29<538:18:32,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6198/4699119 [38:30<463:39:29,  2.81it/s]

last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6199/4699119 [38:30<404:28:18,  3.22it/s]

last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6200/4699119 [38:30<391:11:32,  3.33it/s]

last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6201/4699119 [38:30<418:21:26,  3.12it/s]

last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6202/4699119 [38:31<399:07:49,  3.27it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6203/4699119 [38:31<467:58:24,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6204/4699119 [38:31<406:15:27,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6205/4699119 [38:32<374:51:56,  3.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 119, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6207/4699119 [38:32<325:35:06,  4.00it/s]

last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6208/4699119 [38:32<321:18:11,  4.06it/s]

last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6209/4699119 [38:33<411:57:37,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6210/4699119 [38:33<475:53:27,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6211/4699119 [38:33<424:02:26,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6212/4699119 [38:34<484:28:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 396, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6213/4699119 [38:34<494:02:39,  2.64it/s]

last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6214/4699119 [38:35<446:09:50,  2.92it/s]

last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6215/4699119 [38:35<468:39:32,  2.78it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6216/4699119 [38:36<519:39:36,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6217/4699119 [38:36<488:18:05,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6218/4699119 [38:36<478:34:24,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6219/4699119 [38:37<497:58:17,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6220/4699119 [38:37<472:48:07,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6221/4699119 [38:37<467:19:16,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 485, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6222/4699119 [38:38<516:59:15,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6223/4699119 [38:38<485:35:46,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 337, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6224/4699119 [38:38<477:56:55,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6225/4699119 [38:39<523:06:49,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6226/4699119 [38:39<554:47:20,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6227/4699119 [38:40<578:25:20,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6228/4699119 [38:40<521:06:13,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6229/4699119 [38:41<552:27:45,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 446, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6230/4699119 [38:41<554:19:05,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6231/4699119 [38:41<529:09:25,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6232/4699119 [38:42<504:54:41,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6233/4699119 [38:42<464:24:12,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 136, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6234/4699119 [38:42<400:58:45,  3.25it/s]

last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6235/4699119 [38:43<411:36:13,  3.17it/s]

last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6236/4699119 [38:43<439:56:30,  2.96it/s]

last_hidden_state = torch.Size([8, 469, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6237/4699119 [38:43<488:00:43,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 113, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6238/4699119 [38:44<413:18:50,  3.15it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6239/4699119 [38:44<477:03:13,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 268, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6240/4699119 [38:44<449:24:36,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 318, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6241/4699119 [38:45<444:09:30,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6242/4699119 [38:45<498:37:04,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 138, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6243/4699119 [38:45<426:12:18,  3.06it/s]

last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6244/4699119 [38:46<430:37:44,  3.03it/s]

last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6245/4699119 [38:46<455:40:00,  2.86it/s]

last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6246/4699119 [38:47<484:51:12,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6247/4699119 [38:47<519:08:20,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 461, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6248/4699119 [38:47<539:42:12,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6249/4699119 [38:48<566:23:44,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6250/4699119 [38:48<584:16:26,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 176, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6251/4699119 [38:49<493:39:34,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 354, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6252/4699119 [38:49<485:03:19,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6253/4699119 [38:49<500:50:21,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6254/4699119 [38:50<538:17:44,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6255/4699119 [38:50<564:34:29,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6256/4699119 [38:51<552:36:46,  2.36it/s]

last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6257/4699119 [38:51<505:31:14,  2.58it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6258/4699119 [38:52<542:31:49,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6259/4699119 [38:52<502:33:43,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6260/4699119 [38:52<536:44:16,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 123, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6261/4699119 [38:53<448:33:00,  2.91it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6262/4699119 [38:53<501:44:01,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6263/4699119 [38:53<458:30:48,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6264/4699119 [38:54<508:40:29,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6265/4699119 [38:54<545:34:34,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6266/4699119 [38:55<570:05:39,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 140, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6267/4699119 [38:55<476:06:26,  2.74it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6268/4699119 [38:55<520:43:35,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6269/4699119 [38:56<467:18:10,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 152, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6270/4699119 [38:56<406:07:17,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6271/4699119 [38:56<472:02:14,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6272/4699119 [38:57<472:09:59,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 470, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6273/4699119 [38:57<505:21:48,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6274/4699119 [38:58<503:00:11,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6275/4699119 [38:58<540:16:43,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6276/4699119 [38:59<566:00:26,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6277/4699119 [38:59<502:40:43,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6278/4699119 [38:59<513:30:56,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6279/4699119 [38:59<453:40:33,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6281/4699119 [39:00<410:12:08,  3.18it/s]

last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6282/4699119 [39:01<462:51:04,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 334, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6283/4699119 [39:01<458:52:52,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6284/4699119 [39:01<487:58:40,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6285/4699119 [39:02<495:10:11,  2.63it/s]

last_hidden_state = torch.Size([8, 347, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6286/4699119 [39:02<486:39:45,  2.68it/s]

last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6287/4699119 [39:02<459:08:06,  2.84it/s]

last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6288/4699119 [39:03<440:50:59,  2.96it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6289/4699119 [39:03<496:40:06,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 145, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6290/4699119 [39:03<426:33:56,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6291/4699119 [39:04<476:47:51,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6292/4699119 [39:04<520:25:28,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 256, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6293/4699119 [39:05<474:04:25,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6294/4699119 [39:05<491:07:05,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6295/4699119 [39:05<531:47:32,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6296/4699119 [39:06<559:43:56,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6297/4699119 [39:06<579:19:58,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6298/4699119 [39:07<593:10:22,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6299/4699119 [39:07<548:51:16,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6300/4699119 [39:07<478:43:41,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6301/4699119 [39:08<475:32:30,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6302/4699119 [39:08<469:19:29,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 322, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6303/4699119 [39:08<463:04:54,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6304/4699119 [39:09<512:47:51,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6305/4699119 [39:09<451:02:11,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 483, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6306/4699119 [39:10<504:24:27,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6307/4699119 [39:10<503:05:26,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6308/4699119 [39:10<470:45:31,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 390, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6309/4699119 [39:11<483:39:35,  2.70it/s]

last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6310/4699119 [39:11<464:14:56,  2.81it/s]

last_hidden_state = torch.Size([8, 454, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6311/4699119 [39:12<495:24:10,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6312/4699119 [39:12<511:57:47,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6313/4699119 [39:12<442:12:10,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6314/4699119 [39:12<393:10:32,  3.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6315/4699119 [39:13<402:01:30,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 296, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6316/4699119 [39:13<405:03:04,  3.22it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6317/4699119 [39:13<449:14:18,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 219, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6318/4699119 [39:14<416:14:19,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 405, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6319/4699119 [39:14<452:54:12,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6320/4699119 [39:15<503:49:21,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6321/4699119 [39:15<540:27:29,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 209, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6322/4699119 [39:15<476:05:04,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6323/4699119 [39:16<477:22:12,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 149, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6324/4699119 [39:16<413:58:25,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6325/4699119 [39:16<477:05:52,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 232, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6326/4699119 [39:17<438:28:00,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6327/4699119 [39:17<435:01:30,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6328/4699119 [39:17<466:31:07,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6329/4699119 [39:18<514:12:04,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6330/4699119 [39:18<550:46:17,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6331/4699119 [39:19<507:17:11,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6332/4699119 [39:19<542:25:22,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 379, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6333/4699119 [39:20<532:48:30,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6334/4699119 [39:20<465:08:02,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6335/4699119 [39:20<426:00:59,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6336/4699119 [39:21<486:01:36,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6337/4699119 [39:21<527:41:33,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 125, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6338/4699119 [39:21<442:45:13,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6339/4699119 [39:22<497:55:10,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 465, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6340/4699119 [39:22<526:21:03,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 508, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6341/4699119 [39:23<556:58:47,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6342/4699119 [39:23<473:22:26,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 326, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6343/4699119 [39:23<464:12:27,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6344/4699119 [39:24<514:19:47,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6345/4699119 [39:24<548:52:51,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 440, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6346/4699119 [39:25<549:03:23,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6347/4699119 [39:25<576:46:39,  2.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 143, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6348/4699119 [39:25<481:43:44,  2.71it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6349/4699119 [39:26<526:12:59,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 259, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6350/4699119 [39:26<484:42:56,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 148, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6351/4699119 [39:26<418:25:50,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6352/4699119 [39:27<437:05:34,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6353/4699119 [39:27<396:27:15,  3.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 402, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6354/4699119 [39:27<433:40:16,  3.01it/s]

last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6355/4699119 [39:28<477:18:54,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6356/4699119 [39:28<503:00:18,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6357/4699119 [39:29<539:36:38,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6358/4699119 [39:29<561:58:59,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 501, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6359/4699119 [39:30<585:05:46,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6360/4699119 [39:30<596:43:23,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6361/4699119 [39:30<545:52:12,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6362/4699119 [39:31<537:26:40,  2.43it/s]

last_hidden_state = torch.Size([8, 214, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6363/4699119 [39:31<475:46:03,  2.74it/s]

last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6364/4699119 [39:31<509:08:25,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6365/4699119 [39:32<544:02:45,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6366/4699119 [39:32<481:54:45,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6367/4699119 [39:33<525:15:46,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6368/4699119 [39:33<458:00:19,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6369/4699119 [39:33<431:32:16,  3.02it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6370/4699119 [39:34<441:09:00,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6371/4699119 [39:34<428:01:40,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6372/4699119 [39:34<406:20:03,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 162, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6373/4699119 [39:34<367:45:11,  3.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6374/4699119 [39:35<445:54:17,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6375/4699119 [39:35<415:45:47,  3.14it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6376/4699119 [39:35<406:53:32,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 459, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6377/4699119 [39:36<461:21:46,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6378/4699119 [39:36<430:06:42,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6379/4699119 [39:37<488:36:59,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6380/4699119 [39:37<472:50:05,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 223, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6381/4699119 [39:37<433:00:57,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6382/4699119 [39:38<490:20:51,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6383/4699119 [39:38<530:29:21,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6384/4699119 [39:39<559:49:52,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6385/4699119 [39:39<509:50:37,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 386, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6386/4699119 [39:39<509:17:14,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 497, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6387/4699119 [39:40<548:29:50,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6389/4699119 [39:40<456:53:01,  2.85it/s]

last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6390/4699119 [39:41<461:12:18,  2.83it/s]

last_hidden_state = torch.Size([8, 452, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6391/4699119 [39:41<494:35:33,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6392/4699119 [39:41<463:01:59,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6393/4699119 [39:42<417:39:06,  3.12it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6394/4699119 [39:42<421:17:45,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6395/4699119 [39:42<414:07:42,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6396/4699119 [39:43<420:07:15,  3.10it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6397/4699119 [39:43<394:28:14,  3.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 254, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6398/4699119 [39:43<385:45:29,  3.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6399/4699119 [39:44<412:53:57,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6400/4699119 [39:44<480:18:15,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6401/4699119 [39:45<523:36:15,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6402/4699119 [39:45<560:10:03,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6403/4699119 [39:45<484:32:08,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6404/4699119 [39:46<527:06:49,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6405/4699119 [39:46<470:37:23,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6406/4699119 [39:46<427:22:33,  3.05it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 160, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6407/4699119 [39:46<379:07:29,  3.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 269, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6408/4699119 [39:47<383:09:50,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6409/4699119 [39:47<362:58:00,  3.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6410/4699119 [39:47<372:13:29,  3.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 142, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6411/4699119 [39:48<338:42:42,  3.85it/s]

last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6412/4699119 [39:48<402:46:17,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6413/4699119 [39:48<377:31:26,  3.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6414/4699119 [39:49<439:43:23,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6415/4699119 [39:49<495:40:06,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 471, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6416/4699119 [39:50<527:24:21,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6417/4699119 [39:50<556:47:00,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6418/4699119 [39:50<486:39:02,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 289, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6419/4699119 [39:51<465:12:47,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 117, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6420/4699119 [39:51<398:08:44,  3.27it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6421/4699119 [39:51<468:08:12,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6422/4699119 [39:52<422:09:54,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6423/4699119 [39:52<482:54:23,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6424/4699119 [39:53<525:56:15,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 257, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6425/4699119 [39:53<482:55:50,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 457, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6426/4699119 [39:53<515:20:07,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 466, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6427/4699119 [39:54<535:57:39,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6428/4699119 [39:54<564:15:13,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6429/4699119 [39:55<533:40:14,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6430/4699119 [39:55<563:31:02,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 449, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6431/4699119 [39:55<571:03:15,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6432/4699119 [39:56<518:47:52,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6433/4699119 [39:56<551:42:55,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6434/4699119 [39:57<522:29:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6435/4699119 [39:57<523:54:23,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6436/4699119 [39:57<457:04:52,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 398, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6437/4699119 [39:58<473:07:27,  2.76it/s]

last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6438/4699119 [39:58<439:03:20,  2.97it/s]

last_hidden_state = torch.Size([8, 321, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6439/4699119 [39:58<441:22:22,  2.95it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6440/4699119 [39:59<496:44:54,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 353, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6441/4699119 [39:59<490:44:02,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6442/4699119 [39:59<479:58:51,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 275, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6443/4699119 [40:00<454:46:14,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6444/4699119 [40:00<507:52:06,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6445/4699119 [40:01<507:31:18,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 159, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6446/4699119 [40:01<436:11:17,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 427, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6447/4699119 [40:01<470:53:05,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6448/4699119 [40:02<428:55:15,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6449/4699119 [40:02<489:33:25,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6450/4699119 [40:02<534:07:44,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 306, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6451/4699119 [40:03<501:28:13,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6452/4699119 [40:03<507:43:47,  2.57it/s]

last_hidden_state = torch.Size([8, 387, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6453/4699119 [40:04<508:46:47,  2.56it/s]

last_hidden_state = torch.Size([8, 229, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6454/4699119 [40:04<459:49:31,  2.83it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6455/4699119 [40:04<510:12:27,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 111, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6456/4699119 [40:05<424:34:19,  3.07it/s]

last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6458/4699119 [40:05<379:40:20,  3.43it/s]

last_hidden_state = torch.Size([8, 191, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6459/4699119 [40:05<368:41:35,  3.54it/s]

last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6460/4699119 [40:06<357:23:31,  3.65it/s]

last_hidden_state = torch.Size([8, 428, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6461/4699119 [40:06<412:57:19,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6462/4699119 [40:06<452:27:27,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6463/4699119 [40:07<466:50:56,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 424, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6464/4699119 [40:07<489:46:34,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6465/4699119 [40:08<531:55:04,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6466/4699119 [40:08<563:15:23,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 407, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6467/4699119 [40:09<555:27:30,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6468/4699119 [40:09<578:13:06,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 388, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6469/4699119 [40:09<557:10:10,  2.34it/s]

last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6470/4699119 [40:10<489:11:24,  2.66it/s]

last_hidden_state = torch.Size([8, 317, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6471/4699119 [40:10<473:17:27,  2.75it/s]

last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6472/4699119 [40:10<489:51:21,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6473/4699119 [40:11<530:50:41,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6474/4699119 [40:11<559:12:24,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6475/4699119 [40:12<579:18:48,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6476/4699119 [40:12<593:15:52,  2.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6477/4699119 [40:13<604:04:43,  2.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6478/4699119 [40:13<524:30:40,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6479/4699119 [40:14<556:10:00,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6480/4699119 [40:14<514:03:06,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 380, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6481/4699119 [40:14<509:40:05,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6482/4699119 [40:15<544:46:49,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6483/4699119 [40:15<569:05:39,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 175, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6484/4699119 [40:15<483:53:33,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6485/4699119 [40:16<527:00:24,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6486/4699119 [40:16<487:57:25,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 357, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6487/4699119 [40:17<485:22:20,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6488/4699119 [40:17<481:15:35,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6489/4699119 [40:17<435:41:01,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6490/4699119 [40:18<406:10:42,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6491/4699119 [40:18<384:55:00,  3.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6492/4699119 [40:18<457:05:56,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6493/4699119 [40:18<410:49:42,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6494/4699119 [40:19<475:10:44,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6495/4699119 [40:19<429:20:36,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6496/4699119 [40:20<445:01:35,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6497/4699119 [40:20<455:52:16,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 213, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6498/4699119 [40:20<417:06:50,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6499/4699119 [40:21<451:37:15,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 375, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6500/4699119 [40:21<468:02:58,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 463, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6501/4699119 [40:21<504:29:05,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 358, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6502/4699119 [40:22<494:03:54,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6503/4699119 [40:22<537:38:39,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6504/4699119 [40:23<564:28:40,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 230, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6505/4699119 [40:23<498:35:02,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6506/4699119 [40:24<539:27:00,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 402, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6507/4699119 [40:24<536:09:36,  2.43it/s]

last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6508/4699119 [40:24<483:42:53,  2.69it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6509/4699119 [40:25<471:27:55,  2.76it/s]

last_hidden_state = torch.Size([8, 430, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6510/4699119 [40:25<491:57:01,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 310, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6511/4699119 [40:25<472:10:00,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 220, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6512/4699119 [40:26<430:44:09,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6513/4699119 [40:26<490:13:05,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 384, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6514/4699119 [40:26<492:04:43,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 416, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6515/4699119 [40:27<502:25:14,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6516/4699119 [40:27<466:32:40,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 315, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6517/4699119 [40:27<458:01:31,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 341, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6518/4699119 [40:28<460:22:46,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 158, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6519/4699119 [40:28<402:35:40,  3.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6520/4699119 [40:28<469:11:32,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 417, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6521/4699119 [40:29<494:26:23,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6522/4699119 [40:29<466:23:45,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 493, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6523/4699119 [40:30<515:47:25,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6524/4699119 [40:30<452:39:57,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6525/4699119 [40:30<450:02:01,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6526/4699119 [40:31<447:54:27,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6527/4699119 [40:31<463:44:08,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6528/4699119 [40:31<473:19:36,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 156, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6529/4699119 [40:32<411:03:21,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 364, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6530/4699119 [40:32<429:25:37,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6531/4699119 [40:32<489:39:04,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6532/4699119 [40:33<439:38:39,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6533/4699119 [40:33<495:39:03,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6534/4699119 [40:34<534:59:06,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 331, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6535/4699119 [40:34<510:27:07,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6536/4699119 [40:34<545:54:06,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 211, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6537/4699119 [40:35<480:21:36,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6538/4699119 [40:35<530:23:03,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 476, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6539/4699119 [40:36<547:15:34,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6540/4699119 [40:36<572:39:26,  2.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6541/4699119 [40:36<505:47:39,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 130, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6542/4699119 [40:37<429:48:45,  3.03it/s]

last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6543/4699119 [40:37<435:24:46,  2.99it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6544/4699119 [40:37<492:35:44,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 345, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6545/4699119 [40:38<484:36:24,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 456, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6546/4699119 [40:38<510:30:23,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 412, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6547/4699119 [40:39<515:37:57,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6548/4699119 [40:39<475:30:04,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 273, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6549/4699119 [40:39<451:46:06,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6550/4699119 [40:40<445:16:44,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6551/4699119 [40:40<425:56:26,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 261, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6552/4699119 [40:40<413:17:23,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6553/4699119 [40:40<403:37:09,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 179, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6554/4699119 [40:41<373:28:52,  3.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6555/4699119 [40:41<449:27:58,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 332, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6556/4699119 [40:42<448:10:08,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6557/4699119 [40:42<503:27:05,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6558/4699119 [40:42<541:04:06,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 217, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6559/4699119 [40:43<479:31:53,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6560/4699119 [40:43<523:08:35,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 397, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6561/4699119 [40:44<523:33:38,  2.49it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6562/4699119 [40:44<553:51:40,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 361, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6563/4699119 [40:44<532:01:49,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 482, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6564/4699119 [40:45<555:51:02,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6565/4699119 [40:45<505:05:16,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 133, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6567/4699119 [40:46<392:23:38,  3.32it/s]

last_hidden_state = torch.Size([8, 192, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6568/4699119 [40:46<463:12:58,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6569/4699119 [40:46<447:58:08,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 282, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6570/4699119 [40:47<432:50:27,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6571/4699119 [40:47<442:47:56,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6572/4699119 [40:48<498:30:52,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 414, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6573/4699119 [40:48<509:07:20,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6574/4699119 [40:48<455:34:57,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 242, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6575/4699119 [40:49<424:14:49,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 423, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6576/4699119 [40:49<463:28:55,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 224, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6577/4699119 [40:49<424:13:46,  3.07it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6578/4699119 [40:49<390:10:02,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6580/4699119 [40:50<395:07:45,  3.30it/s]

last_hidden_state = torch.Size([8, 206, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 434, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6581/4699119 [40:51<443:26:46,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6582/4699119 [40:51<447:21:08,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6583/4699119 [40:51<406:00:31,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6584/4699119 [40:51<390:04:03,  3.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6585/4699119 [40:52<409:09:55,  3.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6586/4699119 [40:52<378:32:27,  3.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6587/4699119 [40:52<452:17:35,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6588/4699119 [40:53<436:44:01,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6589/4699119 [40:53<416:47:32,  3.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 340, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6590/4699119 [40:53<428:27:57,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6591/4699119 [40:54<488:19:47,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6592/4699119 [40:54<464:42:51,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6593/4699119 [40:55<514:37:55,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6594/4699119 [40:55<513:10:32,  2.54it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6595/4699119 [40:56<551:06:37,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 333, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6596/4699119 [40:56<521:18:23,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 489, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6597/4699119 [40:56<554:23:35,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 165, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6598/4699119 [40:57<472:22:32,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 286, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6599/4699119 [40:57<450:14:54,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 236, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6600/4699119 [40:57<419:08:33,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6601/4699119 [40:58<483:12:08,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6602/4699119 [40:58<454:43:46,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6603/4699119 [40:58<495:01:24,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6604/4699119 [40:59<531:42:43,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6605/4699119 [40:59<479:31:35,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6606/4699119 [41:00<523:08:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 199, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6607/4699119 [41:00<460:09:39,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 399, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6608/4699119 [41:00<479:22:46,  2.72it/s]

last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6609/4699119 [41:01<487:38:21,  2.67it/s]

last_hidden_state = torch.Size([8, 394, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6611/4699119 [41:01<442:48:42,  2.94it/s]

last_hidden_state = torch.Size([8, 212, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6612/4699119 [41:02<443:07:58,  2.94it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6613/4699119 [41:02<501:38:40,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 208, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6614/4699119 [41:02<445:17:12,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6615/4699119 [41:03<429:51:03,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 295, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6616/4699119 [41:03<426:22:49,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 504, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6617/4699119 [41:03<484:46:44,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6618/4699119 [41:04<526:25:19,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 161, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6619/4699119 [41:04<452:17:00,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6620/4699119 [41:04<451:20:43,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6621/4699119 [41:05<439:20:04,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 195, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6622/4699119 [41:05<401:29:46,  3.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 491, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6623/4699119 [41:06<472:30:48,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6624/4699119 [41:06<520:36:36,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6625/4699119 [41:07<552:09:27,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6626/4699119 [41:07<573:49:46,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 182, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6627/4699119 [41:07<491:58:37,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 435, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6628/4699119 [41:08<513:53:33,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 189, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6629/4699119 [41:08<452:01:40,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 247, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6630/4699119 [41:08<423:52:53,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6631/4699119 [41:09<484:01:26,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6632/4699119 [41:09<454:51:38,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6633/4699119 [41:09<444:52:55,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 258, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6634/4699119 [41:10<425:19:55,  3.06it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6635/4699119 [41:10<412:16:23,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 313, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6636/4699119 [41:10<419:29:10,  3.11it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6637/4699119 [41:10<387:36:16,  3.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6638/4699119 [41:11<406:45:38,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 478, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6639/4699119 [41:11<459:33:57,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6640/4699119 [41:12<462:46:28,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6641/4699119 [41:12<515:39:06,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6642/4699119 [41:12<511:49:08,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 351, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6643/4699119 [41:13<498:20:48,  2.62it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6644/4699119 [41:13<536:13:51,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 330, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6645/4699119 [41:14<508:59:38,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 265, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6646/4699119 [41:14<473:11:56,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6647/4699119 [41:14<455:36:21,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6648/4699119 [41:15<507:17:11,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6649/4699119 [41:15<542:41:55,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6650/4699119 [41:16<567:44:10,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6651/4699119 [41:16<585:33:15,  2.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6652/4699119 [41:17<598:18:21,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6653/4699119 [41:17<606:05:59,  2.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6654/4699119 [41:18<611:49:17,  2.13it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 233, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6655/4699119 [41:18<532:17:23,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6656/4699119 [41:18<534:35:57,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 472, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6657/4699119 [41:19<548:20:30,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6658/4699119 [41:19<540:01:37,  2.41it/s]

last_hidden_state = torch.Size([8, 342, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6659/4699119 [41:19<515:22:19,  2.53it/s]

last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6660/4699119 [41:20<523:03:09,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6661/4699119 [41:20<554:33:31,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 441, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6662/4699119 [41:21<559:49:24,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 131, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6664/4699119 [41:21<404:54:40,  3.22it/s]

last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 426, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6665/4699119 [41:22<444:33:56,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 327, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6666/4699119 [41:22<447:31:17,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 279, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6667/4699119 [41:22<433:07:10,  3.01it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6668/4699119 [41:23<434:08:13,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6669/4699119 [41:23<403:20:28,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6670/4699119 [41:23<397:21:15,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 288, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6671/4699119 [41:23<397:14:45,  3.28it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6672/4699119 [41:24<465:45:10,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 264, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6673/4699119 [41:24<441:25:04,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 302, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6674/4699119 [41:25<434:50:07,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 205, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6675/4699119 [41:25<400:10:58,  3.26it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6676/4699119 [41:25<383:00:44,  3.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 389, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6677/4699119 [41:25<421:45:45,  3.09it/s]

last_hidden_state = torch.Size([8, 401, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6678/4699119 [41:26<455:33:31,  2.86it/s]

last_hidden_state = torch.Size([8, 404, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6679/4699119 [41:26<475:07:46,  2.74it/s]

last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6680/4699119 [41:27<490:33:10,  2.66it/s]

fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6681/4699119 [41:27<532:18:09,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6682/4699119 [41:28<524:18:14,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 432, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6683/4699119 [41:28<528:46:29,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 344, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6684/4699119 [41:28<505:49:43,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 274, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6685/4699119 [41:29<471:16:32,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 453, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6686/4699119 [41:29<505:35:01,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6687/4699119 [41:30<541:58:57,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6688/4699119 [41:30<567:57:27,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 245, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6689/4699119 [41:30<504:55:18,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 377, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6690/4699119 [41:31<504:25:22,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 251, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6691/4699119 [41:31<460:42:04,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 126, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6692/4699119 [41:31<396:01:22,  3.29it/s]

last_hidden_state = torch.Size([8, 301, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6694/4699119 [41:32<358:36:53,  3.63it/s]

last_hidden_state = torch.Size([8, 137, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 338, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6696/4699119 [41:32<343:37:09,  3.79it/s]

last_hidden_state = torch.Size([8, 120, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 383, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6697/4699119 [41:33<391:53:27,  3.33it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6698/4699119 [41:33<462:03:17,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 410, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6699/4699119 [41:33<481:50:40,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6700/4699119 [41:34<525:02:17,  2.48it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6701/4699119 [41:34<458:36:38,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 477, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6702/4699119 [41:35<501:38:00,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6703/4699119 [41:35<538:43:49,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 368, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6704/4699119 [41:35<517:42:11,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 433, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6705/4699119 [41:36<533:02:38,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6706/4699119 [41:36<487:40:17,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 287, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6707/4699119 [41:37<462:23:35,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6708/4699119 [41:37<511:04:16,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6709/4699119 [41:37<501:17:36,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 255, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6710/4699119 [41:38<461:45:48,  2.82it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 141, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6711/4699119 [41:38<401:10:05,  3.25it/s]

last_hidden_state = torch.Size([8, 272, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6712/4699119 [41:38<396:07:10,  3.29it/s]

last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6713/4699119 [41:38<384:38:41,  3.39it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6714/4699119 [41:39<456:51:58,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 277, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6715/4699119 [41:39<438:26:00,  2.97it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6716/4699119 [41:40<481:45:51,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 304, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6717/4699119 [41:40<461:22:41,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6718/4699119 [41:40<429:50:16,  3.03it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6719/4699119 [41:41<488:34:24,  2.67it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6720/4699119 [41:41<529:21:50,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 343, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6721/4699119 [41:42<509:27:28,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6722/4699119 [41:42<544:51:49,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 431, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6723/4699119 [41:42<548:50:25,  2.37it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 309, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6724/4699119 [41:43<513:09:16,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 197, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6725/4699119 [41:43<452:23:16,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6726/4699119 [41:44<504:21:35,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 502, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6727/4699119 [41:44<541:48:35,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 409, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6728/4699119 [41:44<541:46:51,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 227, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6729/4699119 [41:45<482:00:37,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 316, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6730/4699119 [41:45<468:10:55,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 300, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6731/4699119 [41:45<452:20:34,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 203, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6732/4699119 [41:46<411:27:08,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 498, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6733/4699119 [41:46<475:07:23,  2.74it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 238, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6734/4699119 [41:46<437:24:52,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6735/4699119 [41:47<493:22:57,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6736/4699119 [41:47<533:04:16,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6737/4699119 [41:48<569:59:48,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 243, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6738/4699119 [41:48<506:59:05,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 451, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6739/4699119 [41:49<530:41:15,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 221, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6740/4699119 [41:49<473:16:01,  2.75it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6741/4699119 [41:49<520:58:32,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6742/4699119 [41:50<506:04:36,  2.58it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6743/4699119 [41:50<541:48:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 241, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6744/4699119 [41:50<486:01:10,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6745/4699119 [41:51<529:30:21,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6746/4699119 [41:51<559:06:19,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6747/4699119 [41:52<579:09:09,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 499, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6748/4699119 [41:52<597:32:01,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 139, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6749/4699119 [41:53<495:43:30,  2.63it/s]

last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6750/4699119 [41:53<464:36:52,  2.81it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6751/4699119 [41:53<514:08:48,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 413, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6752/4699119 [41:54<522:59:59,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 369, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6753/4699119 [41:54<515:17:16,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 179, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6754/4699119 [41:54<451:27:45,  2.89it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 324, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6755/4699119 [41:55<448:36:07,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6756/4699119 [41:55<447:24:50,  2.91it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6757/4699119 [41:55<500:52:07,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 385, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6758/4699119 [41:56<506:53:23,  2.57it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6759/4699119 [41:56<542:38:09,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 381, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6760/4699119 [41:57<530:54:23,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6761/4699119 [41:57<478:25:12,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6762/4699119 [41:57<522:12:56,  2.50it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6763/4699119 [41:58<554:48:57,  2.35it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 225, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6764/4699119 [41:58<490:55:59,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6765/4699119 [41:59<531:48:11,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6766/4699119 [41:59<559:29:35,  2.33it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 262, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6767/4699119 [41:59<506:32:42,  2.57it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6768/4699119 [42:00<530:20:25,  2.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 294, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6769/4699119 [42:00<496:20:22,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6770/4699119 [42:01<447:01:45,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6771/4699119 [42:01<485:31:51,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 226, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6772/4699119 [42:01<441:35:24,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 371, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6773/4699119 [42:02<457:46:40,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 425, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6774/4699119 [42:02<487:10:17,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6775/4699119 [42:02<434:23:21,  3.00it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 373, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6776/4699119 [42:03<453:40:03,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 298, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6777/4699119 [42:03<442:14:07,  2.95it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 408, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6778/4699119 [42:03<467:16:20,  2.79it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6779/4699119 [42:04<514:35:02,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 276, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6780/4699119 [42:04<477:45:26,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6781/4699119 [42:04<469:34:28,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6782/4699119 [42:05<516:33:35,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 201, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6783/4699119 [42:05<456:25:30,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 307, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6784/4699119 [42:06<448:51:12,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 303, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6785/4699119 [42:06<440:13:25,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6786/4699119 [42:06<502:23:07,  2.59it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 419, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6787/4699119 [42:07<517:36:12,  2.52it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 234, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6788/4699119 [42:07<465:58:55,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6789/4699119 [42:07<463:15:30,  2.81it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6790/4699119 [42:08<511:22:06,  2.55it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 475, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6791/4699119 [42:08<538:34:25,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6792/4699119 [42:09<566:31:23,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 442, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6793/4699119 [42:09<563:20:53,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6794/4699119 [42:10<582:09:49,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6795/4699119 [42:10<594:55:25,  2.19it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 250, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6796/4699119 [42:10<523:18:39,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 193, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6797/4699119 [42:11<460:07:06,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6798/4699119 [42:11<509:58:40,  2.56it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6799/4699119 [42:12<545:52:09,  2.39it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 218, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6800/4699119 [42:12<483:09:08,  2.70it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 492, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6801/4699119 [42:12<524:06:39,  2.49it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6802/4699119 [42:13<458:37:52,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 460, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6803/4699119 [42:13<492:40:08,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 292, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6804/4699119 [42:13<468:44:00,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6805/4699119 [42:14<515:26:10,  2.53it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6806/4699119 [42:14<494:51:09,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 395, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6807/4699119 [42:15<501:18:54,  2.60it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6808/4699119 [42:15<538:08:58,  2.42it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 284, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6809/4699119 [42:15<496:00:23,  2.63it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6810/4699119 [42:16<534:45:58,  2.44it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6811/4699119 [42:16<562:20:25,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6812/4699119 [42:17<581:41:51,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6813/4699119 [42:17<598:18:00,  2.18it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 222, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6814/4699119 [42:18<519:58:58,  2.51it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6815/4699119 [42:18<552:47:49,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6816/4699119 [42:19<574:50:40,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 350, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6817/4699119 [42:19<541:13:58,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 283, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6818/4699119 [42:19<500:02:14,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6819/4699119 [42:20<539:54:41,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 421, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6820/4699119 [42:20<543:29:39,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 216, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6821/4699119 [42:20<480:03:24,  2.72it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 335, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6822/4699119 [42:21<472:56:20,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 190, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6823/4699119 [42:21<422:19:33,  3.09it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 376, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6824/4699119 [42:21<442:38:06,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6825/4699119 [42:22<498:52:22,  2.61it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 253, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6826/4699119 [42:22<460:39:12,  2.83it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6827/4699119 [42:23<513:19:13,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6828/4699119 [42:23<456:13:22,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 200, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6829/4699119 [42:23<412:15:16,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6830/4699119 [42:24<477:52:10,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 346, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6831/4699119 [42:24<470:35:58,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6832/4699119 [42:24<435:33:01,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 248, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6833/4699119 [42:24<412:07:22,  3.16it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 278, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6834/4699119 [42:25<406:55:02,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 183, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6835/4699119 [42:25<376:17:38,  3.46it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6836/4699119 [42:26<452:37:56,  2.88it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 362, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6837/4699119 [42:26<458:32:41,  2.84it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 382, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6838/4699119 [42:26<469:27:42,  2.78it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 263, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6839/4699119 [42:27<443:55:42,  2.94it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 325, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6840/4699119 [42:27<446:31:08,  2.92it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 312, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6841/4699119 [42:27<440:25:13,  2.96it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 406, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6842/4699119 [42:28<465:23:35,  2.80it/s]

last_hidden_state = torch.Size([8, 363, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6843/4699119 [42:28<469:45:02,  2.77it/s]

last_hidden_state = torch.Size([8, 328, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6844/4699119 [42:28<461:03:24,  2.83it/s]

last_hidden_state = torch.Size([8, 438, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6845/4699119 [42:29<489:06:19,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 352, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6846/4699119 [42:29<480:20:35,  2.71it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6847/4699119 [42:29<429:17:25,  3.04it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 403, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6849/4699119 [42:30<414:15:02,  3.15it/s]

last_hidden_state = torch.Size([8, 187, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 291, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6850/4699119 [42:30<413:57:27,  3.15it/s]

last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6851/4699119 [42:31<430:10:25,  3.03it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6852/4699119 [42:31<490:21:36,  2.66it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 204, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6853/4699119 [42:31<435:58:57,  2.99it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6854/4699119 [42:32<492:52:59,  2.64it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 356, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6855/4699119 [42:32<485:47:45,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 319, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6856/4699119 [42:33<471:00:29,  2.77it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 443, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6857/4699119 [42:33<501:01:42,  2.60it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 239, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6858/4699119 [42:33<456:08:01,  2.86it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 339, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6859/4699119 [42:34<457:45:28,  2.85it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 290, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6860/4699119 [42:34<444:24:53,  2.93it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 207, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6861/4699119 [42:34<406:40:38,  3.21it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6862/4699119 [42:35<473:06:07,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 252, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6863/4699119 [42:35<437:30:13,  2.98it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6864/4699119 [42:35<492:33:15,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 372, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6865/4699119 [42:36<491:34:36,  2.65it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 231, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6866/4699119 [42:36<448:54:05,  2.90it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 180, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6867/4699119 [42:36<403:57:12,  3.23it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 349, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6868/4699119 [42:37<422:37:18,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6869/4699119 [42:37<483:43:14,  2.69it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 509, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6870/4699119 [42:38<531:34:20,  2.45it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6871/4699119 [42:38<560:55:22,  2.32it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 359, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6872/4699119 [42:38<536:13:50,  2.43it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6873/4699119 [42:39<564:22:36,  2.31it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6874/4699119 [42:39<581:43:14,  2.24it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 336, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6875/4699119 [42:40<541:43:48,  2.41it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6876/4699119 [42:40<568:24:47,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 266, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6877/4699119 [42:41<513:52:34,  2.54it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6878/4699119 [42:41<547:19:44,  2.38it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 202, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6879/4699119 [42:41<476:52:07,  2.73it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 151, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6880/4699119 [42:41<413:27:03,  3.15it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 474, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6881/4699119 [42:42<465:27:26,  2.80it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 210, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6882/4699119 [42:42<423:04:52,  3.08it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6883/4699119 [42:43<485:52:30,  2.68it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 260, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6884/4699119 [42:43<454:53:31,  2.87it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 194, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6885/4699119 [42:43<411:00:43,  3.17it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 281, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6886/4699119 [42:43<407:26:14,  3.20it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6887/4699119 [42:44<473:02:22,  2.76it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 391, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6888/4699119 [42:44<486:34:25,  2.68it/s]

last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6889/4699119 [42:45<528:06:47,  2.47it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6890/4699119 [42:45<557:14:52,  2.34it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6891/4699119 [42:46<579:26:12,  2.25it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 374, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6892/4699119 [42:46<553:17:37,  2.36it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6893/4699119 [42:47<575:24:25,  2.27it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 393, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6894/4699119 [42:47<559:26:58,  2.33it/s]

last_hidden_state = torch.Size([8, 455, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6895/4699119 [42:48<568:11:09,  2.29it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 367, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6896/4699119 [42:48<542:07:17,  2.40it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 512, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])


Epoch 1 [train]:   0%|          | 6897/4699119 [42:48<567:04:06,  2.30it/s]

id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 129, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6899/4699119 [42:49<407:23:44,  3.20it/s]

last_hidden_state = torch.Size([8, 135, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])
last_hidden_state = torch.Size([8, 177, 3584])
pooled.shape = torch.Size([8, 3584])
projected.shape = torch.Size([8, 1024])
id_proj = torch.Size([8, 1024])
fused shape = torch.Size([8, 2048])


Epoch 1 [train]:   0%|          | 6899/4699119 [42:50<485:34:43,  2.68it/s]


KeyboardInterrupt: 