In [None]:
!pip install seaborn



In [None]:
import os
import torch
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch import Tensor
import math
from math import sqrt

from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BertTokenizer,
    BertModel
)

In [None]:
# Monitor GPU memory usage
def print_gpu_memory():
    if torch.cuda.is_available():
        print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
        print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")

# Set random seeds for reproducibility
def set_seed(seed=42):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)

# Add these at the start of your script
set_seed(42)
print_gpu_memory()

GPU memory allocated: 0.00 MB
GPU memory cached: 0.00 MB


In [None]:
class EncoderLayer(nn.Module):
    def __init__(self, attention, d_model, d_ff=None, dropout=0.1, activation="relu"):
        super(EncoderLayer, self).__init__()
        d_ff = d_ff or 4 * d_model
        self.attention = attention
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_ff, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_ff, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = F.relu if activation == "relu" else F.gelu

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        new_x, attn = self.attention(x, x, x, attn_mask=attn_mask, tau=tau, delta=delta)
        x = x + self.dropout(new_x)

        y = x = self.norm1(x)
        y = self.dropout(self.activation(self.conv1(y.transpose(-1, 1))))
        y = self.dropout(self.conv2(y).transpose(-1, 1))

        return self.norm2(x + y), attn


class Encoder(nn.Module):
    def __init__(self, attn_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.attn_layers = nn.ModuleList(attn_layers)
        self.conv_layers = (
            nn.ModuleList(conv_layers) if conv_layers is not None else None
        )
        self.norm = norm_layer

    def forward(self, x, attn_mask=None, tau=None, delta=None):
        # x [B, L, D]
        attns = []
        if self.conv_layers is not None:
            for i, (attn_layer, conv_layer) in enumerate(
                zip(self.attn_layers, self.conv_layers)
            ):
                delta = delta if i == 0 else None
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                x = conv_layer(x)
                attns.append(attn)
            x, attn = self.attn_layers[-1](x, tau=tau, delta=None)
            attns.append(attn)
        else:
            for attn_layer in self.attn_layers:
                x, attn = attn_layer(x, attn_mask=attn_mask, tau=tau, delta=delta)
                attns.append(attn)

        if self.norm is not None:
            x = self.norm(x)

        return x, attns

class FullAttention(nn.Module):
    def __init__(
        self,
        mask_flag=True,
        factor=5,
        scale=None,
        attention_dropout=0.1,
        output_attention=False,
    ):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)

            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None

class AttentionLayer(nn.Module):
    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        out = out.view(B, L, -1)

        return self.out_projection(out), attn


class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()

        # Compute the positional encodings once in log space.

        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        pe.require_grad = False

        self.register_buffer("pe", pe)

    def forward(self, x):
        return self.pe[:, : x.size(1)]

In [None]:
class MultimodalFinancialDataset(Dataset):
    def __init__(self, time_series_path, text_path, window_size=5, split="train", max_len=390):
        self.window_size = window_size
        self.split = split
        self.max_len = max_len

        # Load time-series data
        time_series_data = pd.read_csv(time_series_path)

        def process_list(x, max_len):
            try:
                if isinstance(x, str):
                    values = eval(x)
                else:
                    values = x

                values = np.array(values, dtype=np.float32)

                # Handle NaN/Inf values
                if np.any(np.isnan(values)) or np.any(np.isinf(values)):
                    logger.warning(f"Found NaN/Inf values in data, replacing with zeros")
                    values = np.nan_to_num(values, 0)

                if len(values) > max_len:
                    return values[:max_len]
                elif len(values) < max_len:
                    padding = np.full(max_len - len(values), values[-1])
                    return np.concatenate([values, padding])
                return values
            except Exception as e:
                logger.error(f"Error processing value: {x[:100]}... Error: {str(e)}")
                return np.zeros(max_len)

        print("Processing time series data...")
        for col in ['open', 'high', 'low', 'close', 'volume']:
            time_series_data[col] = time_series_data[col].apply(
                lambda x: process_list(x, max_len)
            )

        print("Loading text data...")
        text_data = pd.read_csv(text_path)

        print("Merging datasets...")
        self.data = pd.merge(
            time_series_data,
            text_data,
            left_on='date',
            right_on='date',
            how='inner'
        )

        # Store dates as strings instead of timestamps
        self.dates = pd.to_datetime(self.data['date']).dt.strftime('%Y-%m-%d').values

        print("Processing features...")
        self.features = []
        for _, row in self.data.iterrows():
            try:
                daily_features = np.stack([
                    row['open'],
                    row['high'],
                    row['low'],
                    row['close'],
                    row['volume']
                ])
                self.features.append(daily_features)
            except Exception as e:
                print(f"Error processing row for date {row['date']}: {str(e)}")
                daily_features = np.zeros((5, max_len))
                self.features.append(daily_features)

        self.features = np.array(self.features)
        self.text = self.data['text'].values

        print(f"Dataset shapes:")
        print(f"Features: {self.features.shape}")
        print(f"Number of text samples: {len(self.text)}")

        self._compute_feature_stats()

        # Generate labels - up/down based on closing price
        self.labels = []
        for i in range(len(self.features) - window_size + 1):
            window_data = self.features[i:i + window_size]
            # Compare last day's close to first day's close
            start_price = window_data[0][3, 0]  # First day's opening price
            end_price = window_data[-1][3, -1]  # Last day's closing price
            # 1 if price went up, 0 if down
            label = 1 if end_price > start_price else 0
            self.labels.append(label)

        self.labels = np.array(self.labels)

    def _compute_feature_stats(self):
        """Compute feature statistics with validation"""
        try:
            self.feature_stats = {
                'mean': np.nanmean(self.features[:, :4, :], axis=(0, 2)),
                'std': np.nanstd(self.features[:, :4, :], axis=(0, 2)),
                'volume_mean': np.nanmean(self.features[:, 4, :]),
                'volume_std': np.nanstd(self.features[:, 4, :])
            }

            # Handle zero standard deviations
            self.feature_stats['std'] = np.where(
                self.feature_stats['std'] == 0,
                1e-6,  # Small constant instead of 1
                self.feature_stats['std']
            )

            if self.feature_stats['volume_std'] == 0:
                self.feature_stats['volume_std'] = 1e-6

        except Exception as e:
            raise RuntimeError(f"Error computing feature statistics: {str(e)}")

    def normalize_features(self):
        for i in range(4):
            self.features[:, i, :] = (
                (self.features[:, i, :] - self.feature_stats['mean'][i]) /
                self.feature_stats['std'][i]
            )

        self.features[:, 4, :] = (
            (np.log1p(self.features[:, 4, :]) - np.log1p(self.feature_stats['volume_mean'])) /
            (self.feature_stats['volume_std'] + 1e-8)
        )

    def __len__(self):
        return len(self.features) - self.window_size + 1

    def __getitem__(self, idx):

        # Ensure the idx is an integer
        if isinstance(idx, slice):
            raise TypeError("Slicing the dataset directly is not supported. Use DataLoader or manual splitting.")
        if idx + self.window_size > len(self.features):
            raise IndexError("Index out of range for the dataset.")
        # Get window of data
        x = self.features[idx:idx + self.window_size]
        text_window = self.text[idx:idx + self.window_size]
        dates_window = self.dates[idx:idx + self.window_size]
        label = self.labels[idx]

        # Calculate daily summaries
        daily_summaries = []
        for day_idx in range(self.window_size):
            daily_data = x[day_idx]

            valid_idx = np.where(np.diff(daily_data[3]) != 0)[0]
            seq_len = valid_idx[-1] + 1 if len(valid_idx) > 0 else daily_data.shape[1]

            summary = {
                'date': dates_window[day_idx],
                'day_open': daily_data[0, 0],
                'day_close': daily_data[3, seq_len-1],
                'day_high': np.max(daily_data[1, :seq_len]),
                'day_low': np.min(daily_data[2, :seq_len]),
                'day_volume': np.sum(daily_data[4, :seq_len]),
                'volatility': np.std(daily_data[3, :seq_len]),
                'text': text_window[day_idx]
            }

            prompt = (
                f"Date: {summary['date']} | "
                f"Open: {summary['day_open']:.2f} | "
                f"Close: {summary['day_close']:.2f} | "
                f"High: {summary['day_high']:.2f} | "
                f"Low: {summary['day_low']:.2f} | "
                f"Volume: {summary['day_volume']:,.0f} | "
                f"Volatility: {summary['volatility']:.4f} | "
                f"Context: {summary['text']}"
            )
            daily_summaries.append(prompt)

        return {
            "x_enc": torch.tensor(x, dtype=torch.float32),
            "text": daily_summaries,
            "dates": dates_window.tolist(),  # Convert numpy array to list
            "label": torch.tensor(label, dtype=torch.long)  # Add label to return dict
        }

# Custom collate function to handle the batch creation
def custom_collate_fn(batch):
    """
    Custom collate function to properly batch the data.
    """
    x_enc = torch.stack([item['x_enc'] for item in batch])
    text = [item['text'] for item in batch]
    dates = [item['dates'] for item in batch]
    labels = torch.stack([item['label'] for item in batch])

    return {
        'x_enc': x_enc,
        'text': text,
        'dates': dates,
        'labels': labels
    }

In [None]:
# Add this after creating the dataset but before the dataloader
def display_processed_dataset(dataset, num_samples=5):
    print("\nProcessed Dataset Sample:")
    print("-" * 100)

    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]

        print(f"\nSample {i+1}:")
        print("Dates:", sample['dates'])

        # Print time series data summary
        ts_data = sample['x_enc']  # Shape: [window_size, features, minutes]
        print("\nTime Series Data Summary:")
        for day in range(ts_data.size(0)):
            print(f"\nDay {day+1} ({sample['dates'][day]}):")
            print(f"Open:  Mean={ts_data[day,0,:].mean():.3f}, Min={ts_data[day,0,:].min():.3f}, Max={ts_data[day,0,:].max():.3f}")
            print(f"High:  Mean={ts_data[day,1,:].mean():.3f}, Min={ts_data[day,1,:].min():.3f}, Max={ts_data[day,1,:].max():.3f}")
            print(f"Low:   Mean={ts_data[day,2,:].mean():.3f}, Min={ts_data[day,2,:].min():.3f}, Max={ts_data[day,2,:].max():.3f}")
            print(f"Close: Mean={ts_data[day,3,:].mean():.3f}, Min={ts_data[day,3,:].min():.3f}, Max={ts_data[day,3,:].max():.3f}")
            print(f"Volume: Mean={ts_data[day,4,:].mean():.3f}, Min={ts_data[day,4,:].min():.3f}, Max={ts_data[day,4,:].max():.3f}")

        print("\nText Prompts:")
        for j, prompt in enumerate(sample['text']):
            print(f"Day {j+1}: {prompt[:200]}...")

        print("-" * 100)

# Create dataset
dataset = MultimodalFinancialDataset(
    time_series_path='AAPL_train_data_aggregated.csv',
    text_path='AAPL_tweets_train.csv',
    window_size=5,
    max_len=390
)

# Normalize features
dataset.normalize_features()

# Display processed samples
display_processed_dataset(dataset, num_samples=5)

# Create data loader with custom collate function
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=custom_collate_fn  # Add this line
)

# Test the dataloader
batch = next(iter(dataloader))
print("Batch shapes:")
print(f"x_enc: {batch['x_enc'].shape}")
print(f"Number of text samples: {len(batch['text'])}")
print(f"Number of date samples: {len(batch['dates'])}")

Processing time series data...
Loading text data...
Merging datasets...
Processing features...
Dataset shapes:
Features: (334, 5, 390)
Number of text samples: 334

Processed Dataset Sample:
----------------------------------------------------------------------------------------------------

Sample 1:
Dates: ['2021-01-04', '2021-01-05', '2021-01-06', '2021-01-07', '2021-01-08']

Time Series Data Summary:

Day 1 (2021-01-04):
Open:  Mean=-1.067, Min=-1.097, Max=-1.041
High:  Mean=-1.066, Min=-1.092, Max=-1.037
Low:   Mean=-1.067, Min=-1.098, Max=-1.047
Close: Mean=-1.067, Min=-1.097, Max=-1.042
Volume: Mean=-0.000, Min=-0.000, Max=0.000

Day 2 (2021-01-05):
Open:  Mean=-0.971, Min=-1.020, Max=-0.943
High:  Mean=-0.971, Min=-1.021, Max=-0.942
Low:   Mean=-0.971, Min=-1.023, Max=-0.945
Close: Mean=-0.971, Min=-1.020, Max=-0.944
Volume: Mean=-0.000, Min=-0.000, Max=0.000

Day 3 (2021-01-06):
Open:  Mean=-1.157, Min=-1.225, Max=-1.047
High:  Mean=-1.155, Min=-1.216, Max=-1.045
Low:   Mean=-1

```
Batch shapes: x_enc: torch.Size([32, 5, 5, 390])
               |    |  |  |
               |    |  |  └── Number of minutes in each trading day (390 = 6.5 hours × 60 minutes)
               |    |  └── Number of features (OHLCV: Open, High, Low, Close, Volume)
               |    └── Window size (5 days of historical data)
               └── Batch size (32 samples per batch)

Number of text samples: 32
Number of date samples: 32
```

Patch-Based Embedding for Time-Series : https://github.com/flixpar/med-ts-llm/blob/main/models/PatchTST.py

In [None]:
# Modified PatchEmbedding class
class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))
        # Modified to handle combined features in each patch
        self.value_embedding = nn.Linear(patch_len * 5, d_model, bias=False)  # 5 features combined
        self.position_embedding = PositionalEmbedding(d_model, max_len=1024)
        self.dropout = nn.Dropout(dropout)
        self.patch_len = patch_len
        self.stride = stride

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape [batch_size, window_size, n_features, n_minutes]
               where batch_size=32, window_size=5, n_features=5, n_minutes=390
        """
        batch_size, window_size, n_features, n_minutes = x.shape

        # Reshape to handle windows independently
        x = x.reshape(-1, n_features, n_minutes)  # [batch_size * window_size, n_features, n_minutes]

        # Apply padding
        x = self.padding_patch_layer(x)  # [batch_size * window_size, n_features, n_minutes + padding]

        # Create patches - unfold each feature sequence
        # [batch_size * window_size, n_features, num_patches, patch_len]
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        num_patches = x.size(2)

        # Reshape to combine features within each patch
        x = x.permute(0, 2, 1, 3)  # [batch_size * window_size, num_patches, n_features, patch_len]
        x = x.reshape(batch_size * window_size, num_patches, -1)  # [batch_size * window_size, num_patches, n_features * patch_len]

        # Apply embeddings
        x = self.value_embedding(x)  # [batch_size * window_size, num_patches, d_model]
        x = x + self.position_embedding(x)
        x = self.dropout(x)

        # Reshape back to include window dimension
        x = x.reshape(batch_size, window_size, num_patches, -1)

        return x, n_features

In [None]:
class PatchTSTWithBERT(nn.Module):
    def __init__(self, config, dataset, bert_model="bert-base-uncased"):
        super().__init__()
        self.config = config
        self.model_config = config['models']['patchtst']

        # Validate input dimensions
        if not hasattr(dataset, 'features'):
            raise ValueError("Dataset must have 'features' attribute")

        self.enc_in = dataset.features.shape[1]  # Number of input features (5)
        self.num_class = 2 # Binary segmentation
        self.max_seq_len = dataset.features.shape[2] # Number of minutes (390)

        print(f"Input features shape: {dataset.features.shape}")
        print(f"Patch length: {self.model_config['patching']['patch_len']}")
        print(f"Stride: {self.model_config['patching']['stride']}")

        self.n_patches = self._calculate_n_patches()
        self.projection_dim = self._calculate_projection_dim()

        print(f"Final projection dim: {self.projection_dim}")

        # Create projection layer
        self.projection = nn.Linear(self.projection_dim, self.num_class)

        # # Calculate patches based on minute-level sequence length
        # self.n_patches = (self.max_seq_len - self.model_config['patching']['patch_len']) // self.model_config['patching']['stride'] + 1
        # print(f"Number of patches per sequence: {self.n_patches}")

        # Patch-based encoding
        self.patch_embedding = PatchEmbedding(
            d_model=self.model_config['d_model'],
            patch_len=self.model_config['patching']['patch_len'],
            stride=self.model_config['patching']['stride'],
            padding=self.model_config['patching']['stride'],
            dropout=config['training']['dropout'],
        )

        # Transformer encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, factor=3, attention_dropout=config['training']['dropout']),
                        self.model_config['d_model'],
                        self.model_config['n_heads'],
                    ),
                    self.model_config['d_model'],
                    self.model_config['d_ff'],
                    dropout=config['training']['dropout'],
                    activation="gelu",
                )
                for _ in range(self.model_config['e_layers'])
            ],
            norm_layer=nn.LayerNorm(self.model_config['d_model']),
        )

        # BERT integration
        self.bert_tokenizer = BertTokenizer.from_pretrained(bert_model)
        self.bert = BertModel.from_pretrained(bert_model)

        if config.get('freeze_bert', True):
            for param in self.bert.parameters():
                param.requires_grad = False

        # Projections
        self.bert_projection = nn.Linear(768, self.model_config['d_model'])

        # Calculate actual flattened dimension based on shapes
        # Each window produces (n_patches * d_model) features
        # Total features after flattening all windows
        self.features_per_window = (self.n_patches * self.model_config['d_model'])
        self.total_features = self.features_per_window * config['history_len']
        # print(f"Features per window: {self.features_per_window}")
        # print(f"Total features after flattening: {self.total_features}")

        # Projection to output classes
        # self.projection = nn.Linear(16640, self.num_class)  # Using actual shape from tensor

    def _calculate_n_patches(self):
        """Calculate number of patches dynamically"""
        return (self.max_seq_len - self.model_config['patching']['patch_len']) // self.model_config['patching']['stride'] + 1

    def _calculate_projection_dim(self):
        """Calculate projection dimension dynamically"""
        # We know from the debug output that we get 26 patches
        n_patches = 26  # Fixed value based on actual output
        features_per_window = n_patches * self.model_config['d_model']
        total_dim = features_per_window * self.config['history_len']
        # This will be 26 * 128 * 5 = 16640, matching our tensor shape
        return total_dim

    def forward(self, x, prompts=None):
        batch_size, window_size, n_features, n_minutes = x.shape
        print(f"Input shape: {x.shape}")  # [batch_size, window_size, n_features, n_minutes]

        # Apply patch embedding
        x_patched, _ = self.patch_embedding(x)
        print(f"After patch embedding: {x_patched.shape}")  # [batch_size, window_size, num_patches, d_model]

        # Process text if provided
        if prompts is not None:
            text_embeds = self.encode_prompts(prompts)
            text_embeds = self.align_prompt_embeddings(
                text_embeds, batch_size, window_size, x_patched
            )
            x_patched = x_patched + text_embeds
            print(f"After text fusion: {x_patched.shape}")

        # Reshape for encoder
        x_patched = x_patched.reshape(-1, x_patched.size(2), x_patched.size(3))
        print(f"Before encoder: {x_patched.shape}")

        # Pass through encoder
        x_encoded, _ = self.encoder(x_patched)
        print(f"After encoder: {x_encoded.shape}")

        # Reshape back to include window dimension
        x_encoded = x_encoded.reshape(batch_size, window_size, -1)
        print(f"After reshaping: {x_encoded.shape}")

        # Final flatten
        x_encoded = x_encoded.reshape(batch_size, -1)
        print(f"Final shape before projection: {x_encoded.shape}")
        print(f"Projection weight shape: {self.projection.weight.shape}")

        return self.projection(x_encoded)

    def encode_prompts(self, prompts):
        """Encodes text prompts using BERT."""
        batch_size = len(prompts)
        window_size = len(prompts[0])
        all_embeddings = []

        for batch_idx in range(batch_size):
            window_embeddings = []
            for window_idx in range(window_size):
                encoded = self.bert_tokenizer(
                    prompts[batch_idx][window_idx],
                    padding='max_length',
                    truncation=True,
                    max_length=128,
                    return_tensors="pt"
                ).to(next(self.bert.parameters()).device)

                with torch.no_grad():
                    outputs = self.bert(**encoded)
                    embedding = outputs.last_hidden_state[:, 0, :]  # Use [CLS] token
                    window_embeddings.append(embedding)

            batch_embeddings = torch.cat(window_embeddings, dim=0)
            all_embeddings.append(batch_embeddings)

        all_embeddings = torch.stack(all_embeddings)
        return self.bert_projection(all_embeddings)

    def align_prompt_embeddings(self, encoded_prompts, batch_size, window_size, x_patched):
        """Aligns encoded prompts with time series embeddings."""
        return encoded_prompts.unsqueeze(2).expand(-1, -1, x_patched.size(2), -1)

In [None]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import torch.nn as nn
from tqdm import tqdm
import logging

logger = logging.getLogger(__name__)

class SupervisedSegmentationTask:
    def __init__(self, model, train_loader, val_loader, config):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

        self.loss_fn = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=config['lr'],
            weight_decay=config.get('weight_decay', 1e-4)
        )
        self.scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer,
            step_size=config.get('scheduler_step_size', 10),
            gamma=config.get('scheduler_gamma', 0.7)
        )

        # Add gradient clipping
        self.grad_clip = config.get('training', {}).get('gradient_clip', 1.0)

        # Initialize best metrics
        self.best_metrics = None
        self.train_metrics = None
        self.val_metrics = None

    def train_epoch(self, epoch):
        """Training with improved metric handling and memory management"""
        self.model.train()
        total_loss = 0
        num_batches = 0
        all_predictions = []
        all_targets = []

        # Use tqdm for progress tracking
        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch+1}")

        for batch_idx, batch in enumerate(pbar):
            try:
                # Move data to device efficiently
                inputs = batch['x_enc'].to(self.device, non_blocking=True)
                prompts = batch['text']
                targets = batch['labels'].to(self.device, non_blocking=True)

                # Clear gradients
                self.optimizer.zero_grad(set_to_none=True)

                # Forward pass with autocast for mixed precision
                with torch.amp.autocast(device_type="cuda", enabled=True):
                    outputs = self.model(inputs, prompts)
                    loss = self.loss_fn(outputs, targets)

                # Backward pass and optimization
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
                self.optimizer.step()

                # Update metrics
                predictions = torch.argmax(outputs, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

                total_loss += loss.item()
                num_batches += 1

                # Update progress bar
                if batch_idx % 10 == 0:
                    metrics = self._calculate_metrics(all_predictions, all_targets)
                    pbar.set_postfix(
                        loss=f"{loss.item():.4f}",
                        acc=f"{metrics['accuracy']:.4f}",
                        f1=f"{metrics['f1']:.4f}"
                    )

            except RuntimeError as e:
                if "out of memory" in str(e):
                    if hasattr(torch.cuda, 'empty_cache'):
                        torch.cuda.empty_cache()
                    logger.error(f"GPU OOM in batch {batch_idx}. Try reducing batch size.")
                raise e

        # Calculate final metrics
        metrics = self._calculate_metrics(all_predictions, all_targets)
        metrics['loss'] = total_loss / num_batches

        # Store metrics
        self.train_metrics = metrics

        return metrics

    def validate(self, loader):
        """Validate the model with comprehensive metrics"""
        self.model.eval()
        total_loss = 0
        num_batches = 0
        all_predictions = []
        all_targets = []

        with torch.no_grad():
            for batch in loader:
                inputs = batch['x_enc'].to(self.device)
                prompts = batch['text']
                targets = batch['labels'].to(self.device)

                outputs = self.model(inputs, prompts)
                loss = self.loss_fn(outputs, targets)

                predictions = torch.argmax(outputs, dim=1)
                all_predictions.extend(predictions.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

                total_loss += loss.item()
                num_batches += 1

        metrics = self._calculate_metrics(all_predictions, all_targets)
        metrics['loss'] = total_loss / num_batches

        # Store validation metrics
        self.val_metrics = metrics

        return metrics

    def train(self, epochs):
        """Train for multiple epochs with validation"""
        best_loss = float('inf')
        patience = self.config.get('training', {}).get('early_stopping_patience', 5)
        patience_counter = 0

        for epoch in range(epochs):
            # Train epoch
            train_metrics = self.train_epoch(epoch)
            logger.info(
                f"Epoch {epoch + 1}/{self.config['training']['epochs']}, "
                f"Train Loss: {train_metrics['loss']:.4f}, "
                f"Train Acc: {train_metrics['accuracy']:.4f}, "
                f"Train Precision: {train_metrics['precision']:.4f}, "
                f"Train Recall: {train_metrics['recall']:.4f}, "
                f"Train F1: {train_metrics['f1']:.4f}"
            )

            # Validate
            val_metrics = self.validate(self.val_loader)
            logger.info(
                f"Validation Loss: {val_metrics['loss']:.4f}, "
                f"Validation Acc: {val_metrics['accuracy']:.4f}, "
                f"Validation Precision: {val_metrics['precision']:.4f}, "
                f"Validation Recall: {val_metrics['recall']:.4f}, "
                f"Validation F1: {val_metrics['f1']:.4f}"
            )

            # Early stopping check
            if val_metrics['loss'] < best_loss:
                best_loss = val_metrics['loss']
                patience_counter = 0
                self.best_metrics = val_metrics
                self.save_checkpoint(f'checkpoints/best_model_epoch_{epoch}.pth')
                logger.info(f"Saved best model with val_loss: {val_metrics['loss']:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    logger.info("Early stopping triggered")
                    break

            self.scheduler.step()

    def _calculate_metrics(self, predictions, targets):
        """Calculate metrics in a memory-efficient way"""
        return {
            'accuracy': accuracy_score(targets, predictions),
            'precision': precision_score(targets, predictions, zero_division='warn'),
            'recall': recall_score(targets, predictions, zero_division='warn'),
            'f1': f1_score(targets, predictions, zero_division='warn')
        }

    def save_checkpoint(self, path):
        """Save model checkpoint with comprehensive metrics"""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'config': self.config,
            'metrics': {
                'train_metrics': self.train_metrics,
                'val_metrics': self.val_metrics,
                'best_metrics': self.best_metrics
            }
        }
        try:
            torch.save(checkpoint, path)
            logger.info(f"Successfully saved checkpoint to {path}")
        except Exception as e:
            logger.error(f"Error saving checkpoint: {str(e)}")
            raise

    def load_checkpoint(self, path):
        """Load model checkpoint with metrics"""
        try:
            checkpoint = torch.load(path)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            self.train_metrics = checkpoint.get('metrics', {}).get('train_metrics')
            self.val_metrics = checkpoint.get('metrics', {}).get('val_metrics')
            self.best_metrics = checkpoint.get('metrics', {}).get('best_metrics')
            logger.info(f"Successfully loaded checkpoint from {path}")
            return checkpoint['metrics']
        except Exception as e:
            logger.error(f"Error loading checkpoint: {str(e)}")
            raise

Visualization Function

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

def plot_financial_predictions(model, dataset, predictions, targets=None, window_size=5):
    """
    Create a comprehensive visualization of financial predictions.
    """
    sns.set_theme()
    plt.figure(figsize=(20, 15))

    # Convert dates to datetime
    dates = pd.to_datetime(dataset.dates)

    # 1. Price Movement Plot
    ax1 = plt.subplot(3, 2, 1)
    close_prices = dataset.features[:, 3, 0]
    ax1.plot(dates, close_prices, label='Close Price', color='blue', linewidth=1)

    # Add validation period shading
    # val_start_idx = int(len(dates) * 0.8)  # 80-20 split
    # ax1.axvspan(dates[val_start_idx], dates[-1],
    #             color='yellow', alpha=0.1, label='Validation Period')

    # # Highlight predictions
    # for i, pred in enumerate(predictions):
    #     if i+window_size < len(dates):
    #         color = 'green' if pred == 1 else 'red'
    #         ax1.axvspan(dates[i], dates[i+window_size], alpha=0.2, color=color)

    val_start_idx = int(len(dates) * 0.8)  # 80-20 split
    val_dates = dates[val_start_idx:]
    val_predictions = predictions[-len(val_dates):]  # Only take predictions for validation period

    for i, pred in enumerate(val_predictions):
        if i+window_size < len(val_dates):
            color = 'green' if pred == 1 else 'red'
            ax1.axvspan(val_dates[i], val_dates[i+window_size], alpha=0.2, color=color)

    ax1.set_title('Price Movement with Predictions', fontsize=12, pad=10)
    ax1.set_xlabel('Date', fontsize=10)
    ax1.set_ylabel('Normalized Price', fontsize=10)

    # Updated date locator and formatter
    ax1.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax1.xaxis.set_minor_locator(mdates.DayLocator())
    ax1.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.setp(ax1.xaxis.get_majorticklabels(), rotation=45, ha='right')
    ax1.legend()

    # 2. Confusion Matrix
    if targets is not None:
        ax2 = plt.subplot(3, 2, 2)
        cm = confusion_matrix(targets, predictions)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax2)
        ax2.set_title('Confusion Matrix', fontsize=12, pad=10)
        ax2.set_xlabel('Predicted', fontsize=10)
        ax2.set_ylabel('Actual', fontsize=10)

    # 3. Daily Returns Distribution
    ax3 = plt.subplot(3, 2, 3)
    returns = np.diff(close_prices) / close_prices[:-1]
    sns.histplot(returns, kde=True, ax=ax3, bins=50)
    ax3.set_title('Distribution of Daily Returns', fontsize=12, pad=10)
    ax3.set_xlabel('Returns', fontsize=10)
    ax3.set_ylabel('Count', fontsize=10)

    # 4. Volume Profile
    ax4 = plt.subplot(3, 2, 4)
    volumes = dataset.features[:, 4, 0]
    ax4.bar(dates, volumes, alpha=0.5, width=1)
    ax4.set_title('Trading Volume Profile', fontsize=12, pad=10)
    ax4.set_xlabel('Date', fontsize=10)
    ax4.set_ylabel('Normalized Volume', fontsize=10)

    # Updated date locator and formatter for volume plot
    ax4.xaxis.set_major_locator(mdates.AutoDateLocator())
    ax4.xaxis.set_minor_locator(mdates.DayLocator())
    ax4.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
    plt.setp(ax4.xaxis.get_majorticklabels(), rotation=45, ha='right')

    # 5. Model Confidence
    if hasattr(model, 'predict_proba'):
        ax5 = plt.subplot(3, 2, 5)
        probas = model.predict_proba(dataset)
        ax5.hist(probas[:, 1], bins=50)
        ax5.set_title('Model Prediction Confidence', fontsize=12, pad=10)
        ax5.set_xlabel('Probability of Upward Movement', fontsize=10)
        ax5.set_ylabel('Count', fontsize=10)

    # 6. Performance Metrics
    if targets is not None:
        ax6 = plt.subplot(3, 2, 6)
        report = classification_report(targets, predictions, target_names=['Down', 'Up'])
        ax6.text(0.1, 0.1, report, fontsize=10, family='monospace')
        ax6.axis('off')
        ax6.set_title('Classification Metrics', fontsize=12, pad=10)

    plt.tight_layout(pad=3.0, h_pad=3.0, w_pad=3.0)
    return plt.gcf()

def track_training_metrics(metrics_history):
    """
    Function to track and plot training metrics over epochs.
    """
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Plot training and validation loss
    ax1.plot(metrics_history['train_loss'], label='Training Loss')
    ax1.plot(metrics_history['val_loss'], label='Validation Loss')
    ax1.set_title('Loss Over Epochs')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()

    # Plot accuracies
    ax2.plot(metrics_history['train_accuracy'], label='Training Accuracy')
    ax2.plot(metrics_history['val_accuracy'], label='Validation Accuracy')
    ax2.set_title('Accuracy Over Epochs')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()

    plt.tight_layout()
    return fig

In [None]:
from pathlib import Path
Path('checkpoints').mkdir(exist_ok=True)

def create_temporal_split(dataset, val_ratio=0.2, window_size=5):
    """Create temporal train/validation split with clear date boundaries"""
    total_size = len(dataset)
    split_idx = int((1 - val_ratio) * total_size)
    split_idx = split_idx - window_size

    train_indices = list(range(split_idx))
    val_indices = list(range(split_idx + window_size, total_size))

    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)

    # Print date ranges
    train_start = dataset.dates[0]
    train_end = dataset.dates[split_idx]
    val_start = dataset.dates[split_idx + window_size]
    val_end = dataset.dates[-1]

    print(f"Training period: {train_start} to {train_end}")
    print(f"Validation period: {val_start} to {val_end}")

    return train_dataset, val_dataset

In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

if __name__ == "__main__":
    import torch
    import logging
    from torch.utils.data import DataLoader, random_split
    import numpy as np
    from pathlib import Path

    # Set up logging with detailed configuration
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler('training.log'),
            logging.StreamHandler()
        ]
    )
    logger = logging.getLogger(__name__)

    # Set random seeds for reproducibility
    def set_seed(seed=42):
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    set_seed(42)

    # Enhanced configuration
    config = {
        'models': {
            'patchtst': {
                # Embedding dimensions
                'd_model': 128,

                # Patch settings for minute-level data
                'patching': {
                    'patch_len': 30,     # 30-minute patches
                    'stride': 15,        # 15-minute stride
                    'padding': 15        # Padding at sequence ends
                },

                # Transformer settings
                'n_heads': 4,
                'd_ff': 512,
                'e_layers': 3,
            }
        },

        # Training settings
        'training': {
            'dropout': 0.1,
            'batch_size': 16,
            'epochs': 10,
            'gradient_clip': 1.0,
            'early_stopping_patience': 5,
            'validation_split': 0.2,
        },

        # Optimization settings
        'lr': 1e-4,
        'weight_decay': 1e-4,
        'scheduler_step_size': 5,
        'scheduler_gamma': 0.7,

        # Model structure settings
        'history_len': 5,
        'max_seq_len': 390,
        'num_features': 5,

        # Text processing settings
        'bert_model': "bert-base-uncased",
        'max_text_length': 128,
        'freeze_bert': True,
        'bert_pooling': 'cls',

        # Checkpointing
        'checkpoint_dir': 'checkpoints',
        'save_every': 1,
    }

    metrics_history = {
    'train_loss': [], 'val_loss': [],
    'train_accuracy': [], 'val_accuracy': [],
    'train_f1': [], 'val_f1': [],
    'train_precision': [], 'val_precision': [],
    'train_recall': [], 'val_recall': []
}



    try:
        # Create checkpoint directory
        checkpoint_dir = Path(config['checkpoint_dir'])
        checkpoint_dir.mkdir(exist_ok=True)

        # Create datasets with error handling
        logger.info("Creating datasets...")
        try:
            dataset = MultimodalFinancialDataset(
                time_series_path='AAPL_train_data_aggregated.csv',
                text_path='AAPL_tweets_train.csv',
                window_size=config['history_len'],
                max_len=config['max_seq_len']
            )

            # Normalize features
            logger.info("Normalizing features...")
            dataset.normalize_features()

            # Split dataset into train and validation
            # val_size = int(len(dataset) * config['training']['validation_split'])
            # train_size = len(dataset) - val_size
            # train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

            # Create temporal train-validation split
            train_dataset, val_dataset = create_temporal_split(
                dataset=dataset,
                val_ratio=config['training']['validation_split'],
                window_size=config['history_len']
            )

            logger.info(f"Train size: {len(train_dataset)}, Validation size: {len(val_dataset)}")


        except Exception as e:
            logger.error(f"Error creating dataset: {str(e)}")
            raise

        # Create dataloaders with error handling
        logger.info("Creating dataloaders...")
        try:
            train_loader = DataLoader(
                train_dataset,
                batch_size=config['training']['batch_size'],
                shuffle=True,
                collate_fn=custom_collate_fn,
                num_workers=2,
                pin_memory=True
            )

            val_loader = DataLoader(
                val_dataset,
                batch_size=config['training']['batch_size'],
                shuffle=False,
                collate_fn=custom_collate_fn,
                num_workers=2,
                pin_memory=True
            )


        except Exception as e:
            logger.error(f"Error creating dataloaders: {str(e)}")
            raise

        # Initialize model
        logger.info("Initializing model...")
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logger.info(f"Using device: {device}")

        try:
            model = PatchTSTWithBERT(
                config=config,
                dataset=dataset,
                bert_model=config['bert_model']
            ).to(device)

            # Print model summary
            total_params = sum(p.numel() for p in model.parameters())
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            logger.info(f"Total parameters: {total_params:,}")
            logger.info(f"Trainable parameters: {trainable_params:,}")

        except Exception as e:
            logger.error(f"Error initializing model: {str(e)}")
            raise

        # Create training task
        logger.info("Setting up training task...")
        task = UnsupervisedSegmentationTask(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            config=config
        )

        # Training loop with enhanced error handling and monitoring
        logger.info("Starting training...")
        best_val_loss = float('inf')
        early_stopping_counter = 0

        try:
            for epoch in range(config['training']['epochs']):
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

                # Train
                train_metrics = task.train_epoch(epoch)

                # Store training metrics with key mapping
                for k, v in train_metrics.items():
                    if k == 'accuracy':
                        metrics_history['train_accuracy'].append(v)
                    else:
                        metrics_history[f'train_{k}'].append(v)

                logger.info(f"Epoch {epoch + 1}/{config['training']['epochs']}, "
                          f"Train Loss: {train_metrics['loss']:.4f}, "
                          f"Train Accuracy: {train_metrics['accuracy']:.4f}, "  # Changed from Acc to Accuracy
                          f"Train F1: {train_metrics['f1']:.4f}")

                # Validate
                val_metrics = task.validate(val_loader)

                # Store validation metrics with key mapping
                for k, v in val_metrics.items():
                    if k == 'accuracy':
                        metrics_history['val_accuracy'].append(v)
                    else:
                        metrics_history[f'val_{k}'].append(v)

                # Save checkpoint if validation loss improved
                if val_metrics['loss'] < best_val_loss:
                    best_val_loss = val_metrics['loss']
                    early_stopping_counter = 0

                    # Let the task handle checkpoint saving
                    task.save_checkpoint(checkpoint_dir / f'best_model_epoch_{epoch}.pth')
                    logger.info(f"Saved best model with val_loss: {val_metrics['loss']:.4f}")
                else:
                    early_stopping_counter += 1

                # Early stopping check
                if early_stopping_counter >= config['training']['early_stopping_patience']:
                    logger.info("Early stopping triggered")
                    break

        except KeyboardInterrupt:
            logger.info("Training interrupted by user")

        except Exception as e:
            logger.error(f"Error during training: {str(e)}")
            import traceback
            logger.error(traceback.format_exc())

        finally:
            final_checkpoint_path = checkpoint_dir / 'final_model.pth'
            task.save_checkpoint(final_checkpoint_path)
            logger.info(f"Final model saved to {final_checkpoint_path}")


        logger.info("Creating final visualizations...")

        # 1. Plot training history
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Loss plot
        axes[0,0].plot(metrics_history['train_loss'], label='Train')
        axes[0,0].plot(metrics_history['val_loss'], label='Validation')
        axes[0,0].set_title('Loss Over Epochs')
        axes[0,0].set_xlabel('Epoch')
        axes[0,0].set_ylabel('Loss')
        axes[0,0].legend()

        # Accuracy plot
        axes[0,1].plot(metrics_history['train_accuracy'], label='Train')
        axes[0,1].plot(metrics_history['val_accuracy'], label='Validation')
        axes[0,1].set_title('Accuracy Over Epochs')
        axes[0,1].set_xlabel('Epoch')
        axes[0,1].set_ylabel('Accuracy')
        axes[0,1].legend()

        # F1 Score plot
        axes[1,0].plot(metrics_history['train_f1'], label='Train')
        axes[1,0].plot(metrics_history['val_f1'], label='Validation')
        axes[1,0].set_title('F1 Score Over Epochs')
        axes[1,0].set_xlabel('Epoch')
        axes[1,0].set_ylabel('F1 Score')
        axes[1,0].legend()

        # Make predictions on validation set
        model.eval()
        # all_preds = []
        # all_targets = []
        # with torch.no_grad():
        #     for batch in val_loader:
        #         inputs = batch['x_enc'].to(device)
        #         targets = batch['labels'].to(device)
        #         outputs = model(inputs, batch['text'])
        #         preds = torch.argmax(outputs, dim=1)
        #         all_preds.extend(preds.cpu().numpy())
        #         all_targets.extend(targets.cpu().numpy())

        val_indices = []  # Keep track of validation indices
        val_preds = []
        val_targets = []

        with torch.no_grad():
            for batch_idx, batch in enumerate(val_loader):
                inputs = batch['x_enc'].to(device)
                targets = batch['labels'].to(device)
                outputs = model(inputs, batch['text'])
                preds = torch.argmax(outputs, dim=1)

                # Collect predictions and targets
                val_preds.extend(preds.cpu().numpy())
                val_targets.extend(targets.cpu().numpy())

                # Compute and store validation indices
                batch_start_idx = batch_idx * config['training']['batch_size']
                batch_end_idx = min((batch_idx + 1) * config['training']['batch_size'], len(val_dataset))
                val_indices.extend(range(batch_start_idx, batch_end_idx))

        # Confusion Matrix
        cm = confusion_matrix(val_targets, val_preds)
        sns.heatmap(cm, annot=True, fmt='d', ax=axes[1,1])
        axes[1,1].set_title('Confusion Matrix')
        axes[1,1].set_xlabel('Predicted')
        axes[1,1].set_ylabel('Actual')

        plt.tight_layout()
        plt.savefig('training_metrics.png')
        plt.close()

        # Save metrics to file
        with open('training_metrics.txt', 'w') as f:
            f.write("Final Training Metrics:\n")
            for k, v in train_metrics.items():
                f.write(f"{k}: {v:.4f}\n")
            f.write("\nFinal Validation Metrics:\n")
            for k, v in val_metrics.items():
                f.write(f"{k}: {v:.4f}\n")

        # Plot example predictions
        # fig = plot_financial_predictions(
        #     model=model,
        #     dataset=dataset,
        #     predictions=all_preds,
        #     targets=all_targets,
        #     window_size=config['history_len']
        # )
        fig = plot_financial_predictions(
          model=model,
          dataset=dataset,
          predictions=val_preds,
          targets=val_targets,
          window_size=config['history_len']
      )
        plt.savefig('predictions_visualization.png')
        plt.close()

    except Exception as e:
        logger.error(f"Fatal error: {str(e)}")
        import traceback
        logger.error(traceback.format_exc())
        raise


Processing time series data...
Loading text data...
Merging datasets...
Processing features...
Dataset shapes:
Features: (334, 5, 390)
Number of text samples: 334
Training period: 2021-01-04 to 2022-01-12
Validation period: 2022-01-20 to 2022-04-29
Input features shape: (334, 5, 390)
Patch length: 30
Stride: 15
Final projection dim: 16640


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Epoch 1:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])


Epoch 1:   6%|▌         | 1/17 [00:06<01:42,  6.42s/it, acc=0.7500, f1=0.8571, loss=0.5538]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  12%|█▏        | 2/17 [00:11<01:26,  5.74s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  18%|█▊        | 3/17 [00:17<01:18,  5.62s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  24%|██▎       | 4/17 [00:22<01:11,  5.48s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  29%|██▉       | 5/17 [00:27<01:04,  5.35s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  35%|███▌      | 6/17 [00:33<01:00,  5.51s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  41%|████      | 7/17 [00:38<00:54,  5.46s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  47%|████▋     | 8/17 [00:43<00:47,  5.33s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  53%|█████▎    | 9/17 [00:48<00:42,  5.25s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  59%|█████▉    | 10/17 [00:53<00:36,  5.17s/it, acc=0.7500, f1=0.8571, loss=0.5538]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  65%|██████▍   | 11/17 [00:58<00:30,  5.13s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  71%|███████   | 12/17 [01:04<00:25,  5.14s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  76%|███████▋  | 13/17 [01:09<00:20,  5.22s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  82%|████████▏ | 14/17 [01:14<00:15,  5.29s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  88%|████████▊ | 15/17 [01:19<00:10,  5.20s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 1:  94%|█████████▍| 16/17 [01:24<00:05,  5.11s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 1: 100%|██████████| 17/17 [01:25<00:00,  5.05s/it, acc=0.5000, f1=0.5319, loss=0.6580]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 2:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:   6%|▌         | 1/17 [00:05<01:25,  5.37s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  12%|█▏        | 2/17 [00:10<01:20,  5.36s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  18%|█▊        | 3/17 [00:15<01:13,  5.25s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  24%|██▎       | 4/17 [00:20<01:06,  5.12s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  29%|██▉       | 5/17 [00:25<01:01,  5.15s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  35%|███▌      | 6/17 [00:31<00:56,  5.15s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  41%|████      | 7/17 [00:35<00:50,  5.03s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  47%|████▋     | 8/17 [00:40<00:44,  4.98s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  53%|█████▎    | 9/17 [00:46<00:40,  5.12s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  59%|█████▉    | 10/17 [00:50<00:34,  4.98s/it, acc=0.5625, f1=0.7200, loss=0.9532]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  65%|██████▍   | 11/17 [00:55<00:29,  4.98s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  71%|███████   | 12/17 [01:01<00:25,  5.08s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  76%|███████▋  | 13/17 [01:06<00:21,  5.27s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  82%|████████▏ | 14/17 [01:12<00:15,  5.24s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  88%|████████▊ | 15/17 [01:17<00:10,  5.17s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 2:  94%|█████████▍| 16/17 [01:21<00:05,  5.10s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 2: 100%|██████████| 17/17 [01:22<00:00,  4.88s/it, acc=0.5170, f1=0.6083, loss=0.8765]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 3:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:   6%|▌         | 1/17 [00:04<01:19,  4.98s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  12%|█▏        | 2/17 [00:10<01:15,  5.03s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  18%|█▊        | 3/17 [00:15<01:10,  5.02s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  24%|██▎       | 4/17 [00:20<01:07,  5.19s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  29%|██▉       | 5/17 [00:25<01:01,  5.16s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  35%|███▌      | 6/17 [00:31<00:57,  5.25s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  41%|████      | 7/17 [00:36<00:51,  5.18s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  47%|████▋     | 8/17 [00:41<00:46,  5.16s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  53%|█████▎    | 9/17 [00:46<00:41,  5.20s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  59%|█████▉    | 10/17 [00:51<00:35,  5.12s/it, acc=0.6875, f1=0.7619, loss=0.4824]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  65%|██████▍   | 11/17 [00:56<00:30,  5.12s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  71%|███████   | 12/17 [01:01<00:25,  5.15s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  76%|███████▋  | 13/17 [01:06<00:20,  5.12s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  82%|████████▏ | 14/17 [01:11<00:15,  5.05s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  88%|████████▊ | 15/17 [01:16<00:10,  5.04s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 3:  94%|█████████▍| 16/17 [01:21<00:05,  5.05s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 3: 100%|██████████| 17/17 [01:22<00:00,  4.87s/it, acc=0.7159, f1=0.7788, loss=0.5469]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 4:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:   6%|▌         | 1/17 [00:05<01:26,  5.39s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  12%|█▏        | 2/17 [00:09<01:12,  4.84s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  18%|█▊        | 3/17 [00:14<01:07,  4.84s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  24%|██▎       | 4/17 [00:19<01:04,  4.95s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  29%|██▉       | 5/17 [00:24<01:00,  5.04s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  35%|███▌      | 6/17 [00:30<00:56,  5.10s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  41%|████      | 7/17 [00:35<00:51,  5.14s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  47%|████▋     | 8/17 [00:40<00:46,  5.16s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  53%|█████▎    | 9/17 [00:46<00:41,  5.24s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  59%|█████▉    | 10/17 [00:51<00:36,  5.23s/it, acc=0.8125, f1=0.8421, loss=0.5027]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  65%|██████▍   | 11/17 [00:56<00:31,  5.20s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  71%|███████   | 12/17 [01:01<00:25,  5.16s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  76%|███████▋  | 13/17 [01:06<00:20,  5.16s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  82%|████████▏ | 14/17 [01:11<00:15,  5.20s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  88%|████████▊ | 15/17 [01:17<00:10,  5.23s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 4:  94%|█████████▍| 16/17 [01:22<00:05,  5.17s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 4: 100%|██████████| 17/17 [01:23<00:00,  4.90s/it, acc=0.7784, f1=0.8282, loss=0.5101]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 5:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:   6%|▌         | 1/17 [00:05<01:25,  5.34s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  12%|█▏        | 2/17 [00:10<01:19,  5.27s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  18%|█▊        | 3/17 [00:15<01:13,  5.25s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  24%|██▎       | 4/17 [00:20<01:07,  5.16s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  29%|██▉       | 5/17 [00:26<01:02,  5.18s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  35%|███▌      | 6/17 [00:31<00:56,  5.17s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  41%|████      | 7/17 [00:36<00:50,  5.09s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  47%|████▋     | 8/17 [00:41<00:45,  5.10s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  53%|█████▎    | 9/17 [00:46<00:41,  5.13s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  59%|█████▉    | 10/17 [00:51<00:35,  5.03s/it, acc=0.9375, f1=0.9600, loss=0.2802]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  65%|██████▍   | 11/17 [00:56<00:30,  5.13s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  71%|███████   | 12/17 [01:01<00:25,  5.14s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  76%|███████▋  | 13/17 [01:06<00:20,  5.12s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  82%|████████▏ | 14/17 [01:11<00:15,  5.05s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  88%|████████▊ | 15/17 [01:16<00:10,  5.08s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 5:  94%|█████████▍| 16/17 [01:22<00:05,  5.11s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 5: 100%|██████████| 17/17 [01:23<00:00,  4.89s/it, acc=0.7898, f1=0.8384, loss=0.4023]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 6:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:   6%|▌         | 1/17 [00:05<01:27,  5.46s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  12%|█▏        | 2/17 [00:10<01:21,  5.41s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  18%|█▊        | 3/17 [00:16<01:15,  5.39s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  24%|██▎       | 4/17 [00:21<01:08,  5.25s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  29%|██▉       | 5/17 [00:26<01:01,  5.17s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  35%|███▌      | 6/17 [00:31<00:56,  5.10s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  41%|████      | 7/17 [00:36<00:50,  5.07s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  47%|████▋     | 8/17 [00:41<00:45,  5.04s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  53%|█████▎    | 9/17 [00:46<00:40,  5.11s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  59%|█████▉    | 10/17 [00:51<00:35,  5.02s/it, acc=0.7500, f1=0.8182, loss=0.3903]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  65%|██████▍   | 11/17 [00:56<00:30,  5.02s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  71%|███████   | 12/17 [01:01<00:25,  5.01s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  76%|███████▋  | 13/17 [01:06<00:20,  5.05s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  82%|████████▏ | 14/17 [01:11<00:15,  5.06s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  88%|████████▊ | 15/17 [01:17<00:10,  5.19s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 6:  94%|█████████▍| 16/17 [01:22<00:05,  5.17s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 6: 100%|██████████| 17/17 [01:23<00:00,  4.89s/it, acc=0.7955, f1=0.8302, loss=0.4927]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 7:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:   6%|▌         | 1/17 [00:05<01:25,  5.33s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  12%|█▏        | 2/17 [00:10<01:19,  5.29s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  18%|█▊        | 3/17 [00:15<01:12,  5.19s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  24%|██▎       | 4/17 [00:20<01:05,  5.07s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  29%|██▉       | 5/17 [00:26<01:02,  5.22s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  35%|███▌      | 6/17 [00:31<00:57,  5.24s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  41%|████      | 7/17 [00:36<00:51,  5.18s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  47%|████▋     | 8/17 [00:41<00:46,  5.11s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  53%|█████▎    | 9/17 [00:46<00:41,  5.17s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  59%|█████▉    | 10/17 [00:51<00:35,  5.12s/it, acc=0.8750, f1=0.9000, loss=0.2672]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  65%|██████▍   | 11/17 [00:56<00:30,  5.14s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  71%|███████   | 12/17 [01:02<00:25,  5.17s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  76%|███████▋  | 13/17 [01:06<00:20,  5.07s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  82%|████████▏ | 14/17 [01:12<00:15,  5.12s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  88%|████████▊ | 15/17 [01:17<00:10,  5.09s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 7:  94%|█████████▍| 16/17 [01:22<00:05,  5.05s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 7: 100%|██████████| 17/17 [01:23<00:00,  4.89s/it, acc=0.8068, f1=0.8426, loss=0.4434]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 8:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:   6%|▌         | 1/17 [00:05<01:23,  5.22s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  12%|█▏        | 2/17 [00:10<01:16,  5.13s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  18%|█▊        | 3/17 [00:15<01:13,  5.26s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  24%|██▎       | 4/17 [00:20<01:07,  5.17s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  29%|██▉       | 5/17 [00:26<01:03,  5.29s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  35%|███▌      | 6/17 [00:31<00:57,  5.24s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  41%|████      | 7/17 [00:36<00:52,  5.23s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  47%|████▋     | 8/17 [00:42<00:47,  5.30s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  53%|█████▎    | 9/17 [00:47<00:42,  5.26s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  59%|█████▉    | 10/17 [00:52<00:35,  5.14s/it, acc=0.8750, f1=0.8571, loss=0.3789]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  65%|██████▍   | 11/17 [00:57<00:30,  5.16s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  71%|███████   | 12/17 [01:02<00:25,  5.13s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  76%|███████▋  | 13/17 [01:07<00:20,  5.03s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  82%|████████▏ | 14/17 [01:12<00:15,  5.08s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  88%|████████▊ | 15/17 [01:17<00:10,  5.12s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 8:  94%|█████████▍| 16/17 [01:22<00:05,  5.06s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 8: 100%|██████████| 17/17 [01:23<00:00,  4.91s/it, acc=0.8409, f1=0.8704, loss=0.6256]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 9:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:   6%|▌         | 1/17 [00:05<01:21,  5.11s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  12%|█▏        | 2/17 [00:10<01:16,  5.12s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  18%|█▊        | 3/17 [00:15<01:11,  5.11s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  24%|██▎       | 4/17 [00:20<01:07,  5.17s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  29%|██▉       | 5/17 [00:25<01:02,  5.18s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  35%|███▌      | 6/17 [00:31<00:57,  5.24s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  41%|████      | 7/17 [00:36<00:52,  5.26s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  47%|████▋     | 8/17 [00:41<00:46,  5.16s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  53%|█████▎    | 9/17 [00:46<00:40,  5.11s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  59%|█████▉    | 10/17 [00:51<00:36,  5.15s/it, acc=0.7500, f1=0.7778, loss=0.4022]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  65%|██████▍   | 11/17 [00:56<00:30,  5.08s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  71%|███████   | 12/17 [01:01<00:25,  5.16s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  76%|███████▋  | 13/17 [01:07<00:20,  5.17s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  82%|████████▏ | 14/17 [01:12<00:15,  5.13s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  88%|████████▊ | 15/17 [01:17<00:10,  5.12s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 9:  94%|█████████▍| 16/17 [01:22<00:05,  5.08s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 9: 100%|██████████| 17/17 [01:23<00:00,  4.89s/it, acc=0.8920, f1=0.9005, loss=0.4318]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 

Epoch 10:   0%|          | 0/17 [00:00<?, ?it/s]

Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:   6%|▌         | 1/17 [00:05<01:32,  5.80s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  12%|█▏        | 2/17 [00:11<01:22,  5.51s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  18%|█▊        | 3/17 [00:16<01:13,  5.27s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  24%|██▎       | 4/17 [00:21<01:09,  5.32s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  29%|██▉       | 5/17 [00:26<01:02,  5.23s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  35%|███▌      | 6/17 [00:31<00:56,  5.17s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  41%|████      | 7/17 [00:36<00:51,  5.11s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  47%|████▋     | 8/17 [00:41<00:46,  5.20s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  53%|█████▎    | 9/17 [00:47<00:42,  5.26s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  59%|█████▉    | 10/17 [00:52<00:35,  5.10s/it, acc=0.9375, f1=0.9474, loss=0.2698]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  65%|██████▍   | 11/17 [00:57<00:30,  5.08s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  71%|███████   | 12/17 [01:02<00:25,  5.06s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  76%|███████▋  | 13/17 [01:07<00:20,  5.08s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  82%|████████▏ | 14/17 [01:12<00:15,  5.02s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  88%|████████▊ | 15/17 [01:16<00:09,  4.94s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])


Epoch 10:  94%|█████████▍| 16/17 [01:22<00:05,  5.02s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([3, 5, 5, 390])
After patch embedding: torch.Size([3, 5, 26, 128])


Epoch 10: 100%|██████████| 17/17 [01:23<00:00,  4.88s/it, acc=0.8750, f1=0.9043, loss=0.2887]

After text fusion: torch.Size([3, 5, 26, 128])
Before encoder: torch.Size([15, 26, 128])
After encoder: torch.Size([15, 26, 128])
After reshaping: torch.Size([3, 5, 3328])
Final shape before projection: torch.Size([3, 16640])
Projection weight shape: torch.Size([2, 16640])





Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 3328])
Final shape before projection: torch.Size([16, 16640])
Projection weight shape: torch.Size([2, 16640])
Input shape: torch.Size([16, 5, 5, 390])
After patch embedding: torch.Size([16, 5, 26, 128])
After text fusion: torch.Size([16, 5, 26, 128])
Before encoder: torch.Size([80, 26, 128])
After encoder: torch.Size([80, 26, 128])
After reshaping: torch.Size([16, 5, 