In [1]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
import gc
from tqdm import tqdm

# Ask cuDNN for deterministic algorithm selection and switch off its benchmark
# autotuner.
# NOTE(review): no torch/numpy seed is set in the visible code, so weight
# initialisation, shuffling and augmentation presumably still differ between
# runs -- confirm whether seeding happens elsewhere.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

class Config:
    """Hyperparameters and display settings read by the pipeline functions.

    Every setting is a class attribute, so an instance created with
    ``Config()`` has an empty instance ``__dict__``.
    """
    # Rows per sliding window; also used to align regime labels with rows.
    WINDOW_SIZE = 60
    # Number of KMeans clusters fitted on the embeddings.
    N_REGIMES = 3
    # Batch size for both the training and the embedding DataLoader.
    BATCH_SIZE = 32
    # Training epochs per ticker; also the cosine scheduler's T_max.
    EPOCHS = 100
    # Output width of the encoder's projection head.
    FEATURE_DIM = 64
    # Width of the encoder's embedding layer and first conv stage.
    HIDDEN_DIM = 128
    # Initial Adam learning rate.
    LEARNING_RATE = 0.001
    # Divisor applied to cosine similarities in the contrastive loss.
    TEMPERATURE = 0.1
    # Resolved once, when the class body is executed.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Display name and colour per ordered regime id (index 0 .. N_REGIMES - 1).
    REGIME_NAMES = ["Bearish", "Sideways", "Bullish"]
    REGIME_COLORS = ["#FF5050", "#808080", "#50C878"]

# 1. Load and preprocess the data with memory optimizations
def load_and_preprocess_data(file_path, window_size=60):
    """Read a CSV and cut each ticker's numeric columns into sliding windows.

    Rows are ordered by ``date`` when that column exists. Without a ``ticker``
    column the whole file is treated as a single series keyed ``'data'``.
    For every ticker with at least ``window_size`` rows, the numeric columns
    are standardized with a ``StandardScaler`` fitted on that ticker's full
    history and then sliced into every overlapping run of ``window_size``
    consecutive rows.

    Args:
        file_path: CSV location handed to ``pandas.read_csv``.
        window_size: Number of rows per window.

    Returns:
        ``(all_data, tickers)``. ``all_data`` maps ticker to an array of shape
        ``(n_windows, window_size, n_numeric_columns)``. ``tickers`` is the
        full ticker collection and still lists tickers that were skipped for
        being too short.
    """
    print("Loading data...")
    df = pd.read_csv(file_path, low_memory=False)

    n_rows, n_cols = df.shape
    print(f"Data loaded: {n_rows} rows, {n_cols} columns")

    if 'date' in df.columns:
        df['date'] = pd.to_datetime(df['date'])
        df.sort_values('date', inplace=True)

    has_ticker = 'ticker' in df.columns
    tickers = df['ticker'].unique() if has_ticker else ['data']
    print(f"Found {len(tickers)} unique tickers")

    all_data = {}

    for ticker in tickers:
        print(f"Processing ticker: {ticker}")
        frame = (df[df['ticker'] == ticker] if has_ticker else df).copy()

        feature_cols = [
            col
            for col in frame.select_dtypes(include=[np.number]).columns.tolist()
            if col != 'date'
        ]

        # Zero or fewer possible windows means the series is shorter than one window.
        n_windows = len(frame) - window_size + 1
        if n_windows < 1:
            print(f"Not enough data for ticker {ticker}, skipping.")
            continue

        # The scaler sees the ticker's whole history, not just a training prefix.
        frame[feature_cols] = StandardScaler().fit_transform(frame[feature_cols])

        values = frame[feature_cols].values
        windows = [values[start:start + window_size] for start in range(n_windows)]

        if windows:
            all_data[ticker] = np.array(windows)
            print(f"Created {len(windows)} sequences of shape {windows[0].shape}")

        gc.collect()

    return all_data, tickers

# 2. Define the TimeSeriesDataset class with improved augmentations
class TimeSeriesDataset(Dataset):
    """Dataset over pre-windowed series that can emit two augmented views.

    Indexing yields a ``(strong_view, weak_view)`` tuple only when both
    ``strong_transform`` and ``weak_transform`` are truthy. With any other
    combination of flags the stored sample is returned as-is.
    """

    def __init__(self, data, strong_transform=False, weak_transform=False):
        self.data = data
        self.strong_transform = strong_transform
        self.weak_transform = weak_transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        if not (self.strong_transform and self.weak_transform):
            return sample

        # Strong view is drawn first, so it consumes the RNG before the weak view.
        return self._strong_augment(sample), self._weak_augment(sample)

    def _strong_augment(self, x):
        """Rescale the window globally, then time-warp it segment by segment.

        The window is multiplied by one factor drawn from [0.8, 1.2), split
        into 3-5 contiguous segments, and each segment is linearly resampled
        to between 90% and 110% of its length (never fewer than 3 rows). The
        re-joined result is truncated, or padded with zero rows at the end,
        back to the input length.
        """
        scaled = x.copy() * np.random.uniform(0.8, 1.2)

        seq_len, feat_dim = scaled.shape
        num_segments = np.random.randint(3, 6)
        segment_size = seq_len // num_segments

        pieces = []
        for k in range(num_segments):
            lo = k * segment_size
            # The final segment absorbs whatever the integer division left over.
            hi = seq_len if k == num_segments - 1 else lo + segment_size
            piece = scaled[lo:hi].copy()

            warp = np.random.uniform(0.9, 1.1)
            if warp == 1.0:
                pieces.append(piece)
                continue

            new_len = max(3, int(len(piece) * warp))
            positions = np.linspace(0, len(piece) - 1, new_len)
            grid = np.arange(len(piece))

            warped = np.zeros((new_len, feat_dim))
            for col in range(feat_dim):
                warped[:, col] = np.interp(positions, grid, piece[:, col])
            pieces.append(warped)

        stacked = np.vstack(pieces)
        shortfall = seq_len - len(stacked)
        if shortfall < 0:
            stacked = stacked[:seq_len]
        elif shortfall > 0:
            stacked = np.vstack([stacked, np.zeros((shortfall, feat_dim))])

        return stacked

    def _weak_augment(self, x):
        """Add small Gaussian noise and smooth over a few interior time steps.

        The noise standard deviation is 2% of the window's mean absolute
        value. Between 1% and 5% of the time steps are then selected, and each
        selected interior step is replaced by the mean of its two neighbours;
        a selected first or last step is left unchanged.
        """
        base = x.copy()

        sigma = 0.02 * np.mean(np.abs(base))
        noisy = base + np.random.normal(0, sigma, base.shape)

        mask_fraction = np.random.uniform(0.01, 0.05)
        seq_len = noisy.shape[0]
        n_masked = int(seq_len * mask_fraction)

        if n_masked > 0:
            chosen = np.random.choice(np.arange(seq_len), size=n_masked, replace=False)
            for t in chosen:
                if 0 < t < seq_len - 1:
                    noisy[t] = (noisy[t - 1] + noisy[t + 1]) / 2

        return noisy

# 3. Enhanced encoder network with learnable embeddings
class TSEncoder(nn.Module):
    """Encode a multivariate window into one fixed-size feature vector.

    Pipeline: per-step linear embedding, three Conv1d/BatchNorm/ReLU/Dropout
    stages over the time axis (the second stage doubles the channel count),
    softmax attention pooling over time, then a two-layer projection head.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: Width of the embedding and of the first conv stage.
        feature_dim: Width of the returned vector.
        dropout: Dropout probability used after every conv stage and inside
            the projection head.
    """

    def __init__(self, input_dim, hidden_dim=64, feature_dim=64, dropout=0.1):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.feature_dim = feature_dim

        wide_dim = hidden_dim * 2

        self.embedding = nn.Linear(input_dim, hidden_dim)

        # Stages are appended in order so the Sequential indices (0..11) and
        # therefore the state_dict keys match a hand-written layer list.
        stage_channels = (
            (hidden_dim, hidden_dim),
            (hidden_dim, wide_dim),
            (wide_dim, wide_dim),
        )
        conv_layers = []
        for in_channels, out_channels in stage_channels:
            conv_layers.extend([
                nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.Dropout(dropout),
            ])
        self.conv_blocks = nn.Sequential(*conv_layers)

        self.attention = nn.Sequential(
            nn.Linear(wide_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

        self.projection = nn.Sequential(
            nn.Linear(wide_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, feature_dim),
        )

    def forward(self, x):
        """Map ``(batch, seq_len, input_dim)`` to ``(batch, feature_dim)``."""
        # The unpack is kept only as a rank check: it fails unless x is 3-D.
        _batch, _steps, _features = x.shape

        hidden = self.embedding(x)

        # Conv1d wants channels on axis 1; swap in, convolve, swap back.
        hidden = self.conv_blocks(hidden.transpose(1, 2)).transpose(1, 2)

        # One score per time step, normalised across the time axis.
        weights = F.softmax(self.attention(hidden), dim=1)
        pooled = (hidden * weights).sum(dim=1)

        return self.projection(pooled)

# 4. Improved contrastive loss with temperature adjustment
class NTXentLoss(nn.Module):
    """Normalised temperature-scaled cross-entropy loss over paired views.

    Row ``k`` of ``z_i`` and row ``k`` of ``z_j`` form a positive pair; every
    other row in the concatenated batch acts as a negative.

    Args:
        temperature: Divisor applied to the cosine similarities.
        eps: Stored on the module; ``forward`` does not read it.
    """

    def __init__(self, temperature=0.1, eps=1e-8):
        super().__init__()
        self.temperature = temperature
        self.eps = eps
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, z_i, z_j):
        """Return the mean loss over all ``2 * batch`` anchor rows."""
        n_pairs = z_i.size(0)

        # Unit-length rows make the dot products below cosine similarities.
        views = torch.cat([F.normalize(z_i, dim=1), F.normalize(z_j, dim=1)], dim=0)
        logits = torch.matmul(views, views.t()) / self.temperature

        # A row must never pick itself, so the diagonal is pushed to -inf.
        self_pairs = torch.eye(2 * n_pairs, dtype=torch.bool, device=logits.device)
        logits = logits.masked_fill(self_pairs, -float('inf'))

        # Row k pairs with row k + n_pairs and vice versa.
        row_ids = torch.arange(2 * n_pairs, device=logits.device)
        targets = torch.cat([row_ids[n_pairs:], row_ids[:n_pairs]])

        return self.criterion(logits, targets)

# 5. Training function with memory optimization and progress tracking
def train_ts_tcc(data, tickers, config):
    """Train one contrastive encoder per ticker and embed every window.

    For each ticker a fresh ``TSEncoder`` is trained for ``config.EPOCHS``
    epochs on (strong, weak) augmented view pairs with ``NTXentLoss``, using
    Adam, cosine learning-rate annealing and gradient-norm clipping. The
    trained encoder is then run in eval mode over the un-augmented windows in
    their original order.

    Args:
        data: Mapping from ticker to an array of shape
            ``(n_windows, window_size, n_features)``.
        tickers: Iterable of ticker keys; keys missing from ``data`` are skipped.
        config: Object exposing ``BATCH_SIZE``, ``EPOCHS``, ``FEATURE_DIM``,
            ``HIDDEN_DIM``, ``LEARNING_RATE``, ``TEMPERATURE`` and ``DEVICE``.

    Returns:
        ``(models, embeddings)``: dicts keyed by ticker holding the trained
        encoder and an array of shape ``(n_windows, FEATURE_DIM)``.
    """
    models = {}
    embeddings = {}

    # Page-locked host buffers only help host-to-GPU copies; requesting them
    # on a CPU-only run just makes the DataLoader emit a warning.
    pin_memory = torch.device(config.DEVICE).type == "cuda"

    for ticker in tickers:
        ticker_data = data.get(ticker, None)
        if ticker_data is None or len(ticker_data) == 0:
            print(f"No data for ticker {ticker}, skipping.")
            continue

        n_windows = len(ticker_data)
        if n_windows < 2:
            # The projection head's BatchNorm1d cannot normalise a single
            # sample in train mode, so one window cannot be trained on.
            print(f"Only {n_windows} window for ticker {ticker}; need at least 2 to train, skipping.")
            continue

        input_dim = ticker_data.shape[2]
        print(f"\nTraining model for {ticker} with input dimension {input_dim}")

        # A trailing batch holding exactly one sample would make BatchNorm1d
        # raise mid-epoch. Drop only that stray sample; every other remainder
        # is still trained on.
        drop_last = (
            n_windows > config.BATCH_SIZE
            and n_windows % config.BATCH_SIZE == 1
        )

        dataset = TimeSeriesDataset(ticker_data, strong_transform=True, weak_transform=True)
        dataloader = DataLoader(
            dataset,
            batch_size=config.BATCH_SIZE,
            shuffle=True,
            pin_memory=pin_memory,
            drop_last=drop_last
        )

        model = TSEncoder(
            input_dim=input_dim,
            hidden_dim=config.HIDDEN_DIM,
            feature_dim=config.FEATURE_DIM
        ).to(config.DEVICE)

        optimizer = torch.optim.Adam(
            model.parameters(),
            lr=config.LEARNING_RATE,
            weight_decay=1e-5
        )

        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=config.EPOCHS
        )

        criterion = NTXentLoss(temperature=config.TEMPERATURE)

        progress_bar = tqdm(range(config.EPOCHS), desc=f"Training {ticker}")
        for epoch in progress_bar:
            model.train()
            total_loss = 0

            for strong_aug, weak_aug in dataloader:
                # Augmentations produce float64 numpy arrays; the model is float32.
                strong_aug = strong_aug.float().to(config.DEVICE)
                weak_aug = weak_aug.float().to(config.DEVICE)

                z_i = model(strong_aug)
                z_j = model(weak_aug)

                loss = criterion(z_i, z_j)

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()

            # The learning rate is annealed once per epoch, not per batch.
            scheduler.step()

            avg_loss = total_loss / len(dataloader)
            progress_bar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{scheduler.get_last_lr()[0]:.6f}"})

            torch.cuda.empty_cache()
            gc.collect()

        models[ticker] = model

        print(f"Generating embeddings for {ticker}...")
        model.eval()
        dataset_no_aug = TimeSeriesDataset(ticker_data)
        dataloader_no_aug = DataLoader(
            dataset_no_aug,
            batch_size=config.BATCH_SIZE,
            shuffle=False,
            pin_memory=pin_memory
        )

        all_embeddings = []
        with torch.no_grad():
            for batch in tqdm(dataloader_no_aug, desc="Extracting embeddings"):
                batch = batch.float().to(config.DEVICE)
                # Results are moved to host memory right away; the allocator
                # cache is released once after the loop rather than per batch.
                all_embeddings.append(model(batch).cpu().numpy())

        embeddings[ticker] = np.vstack(all_embeddings)
        print(f"Generated {len(embeddings[ticker])} embeddings for {ticker}")

        # `models` still references the encoder; this only drops local names.
        del model, dataset, dataloader, dataset_no_aug, dataloader_no_aug
        torch.cuda.empty_cache()
        gc.collect()

    return models, embeddings

# 6. Improved regime identification with validation
def identify_market_regimes(embeddings, config):
    """Cluster each ticker's embeddings into ``config.N_REGIMES`` ordered regimes.

    KMeans is fitted with five different seeds and the fit with the lowest
    inertia is kept (the earliest one on ties). Cluster ids are then renumbered
    so that regime 0 is the cluster whose centroid has the smallest coordinate
    sum and the highest regime id the largest.

    NOTE(review): this ordering is what ties cluster ids to the regime names
    used when plotting; nothing here checks that a centroid's coordinate sum
    tracks market direction.

    Args:
        embeddings: Mapping from ticker to an ``(n_samples, n_features)`` array.
        config: Object exposing ``N_REGIMES``.

    Returns:
        ``(regime_labels, regime_models)``: per-ticker renumbered label arrays
        and the selected fitted ``KMeans`` objects. The models' own
        ``labels_`` keep the original, un-renumbered ids.
    """
    regime_labels = {}
    regime_models = {}

    for ticker, emb in embeddings.items():
        print(f"Identifying market regimes for {ticker}...")

        # Fits are produced lazily, so only the current best stays alive.
        fitted = (
            KMeans(
                n_clusters=config.N_REGIMES,
                random_state=seed,
                n_init=10,
                max_iter=500
            ).fit(emb)
            for seed in range(5)
        )
        winner = min(fitted, key=lambda km: km.inertia_)

        centroid_sums = np.sum(winner.cluster_centers_, axis=1)
        order = np.argsort(centroid_sums)

        # order[rank] is the original cluster id that receives regime `rank`.
        remap = {order[rank]: rank for rank in range(config.N_REGIMES)}

        regime_labels[ticker] = np.array([remap[cluster] for cluster in winner.labels_])
        regime_models[ticker] = winner

        print(f"Identified {config.N_REGIMES} regimes for {ticker}")

    return regime_labels, regime_models

# 7. Enhanced visualization with better colors and clarity
def visualize_regimes(embeddings, regime_labels, tickers, config):
    """Save a t-SNE scatter of each ticker's embeddings, coloured by regime.

    One subplot is drawn per ticker that has both embeddings and regime
    labels. The grid is at most two columns wide and grows by rows, so any
    number of tickers fits and skipped tickers leave no empty cells. At most
    5000 randomly chosen points per ticker are embedded. The figure is written
    to ``market_regimes_tsne.png``; when no ticker is plottable nothing is
    written.

    Args:
        embeddings: Mapping from ticker to an ``(n_samples, n_features)`` array.
        regime_labels: Mapping from ticker to an integer label array aligned
            with ``embeddings[ticker]``.
        tickers: Iterable of ticker keys, in plotting order.
        config: Object exposing ``N_REGIMES``, ``REGIME_NAMES`` and
            ``REGIME_COLORS``.
    """
    import inspect
    import math

    # t-SNE needs at least two points (perplexity must be >= 1 and < n_samples).
    plottable = [
        ticker for ticker in tickers
        if ticker in embeddings
        and ticker in regime_labels
        and len(embeddings[ticker]) >= 2
    ]
    if not plottable:
        print("No tickers with embeddings and regime labels; skipping t-SNE visualization.")
        return

    # Two columns at 9x6 inches per cell reproduces the 18x12 figure for four
    # tickers while no longer capping the figure at four subplots.
    n_cols = min(2, len(plottable))
    n_rows = math.ceil(len(plottable) / n_cols)

    # Apply the style before the figure exists so it covers the figure too.
    plt.style.use('seaborn-v0_8-whitegrid')
    fig = plt.figure(figsize=(9 * n_cols, 6 * n_rows))

    regime_colors = config.REGIME_COLORS
    regime_names = config.REGIME_NAMES

    # Newer scikit-learn releases renamed TSNE's `n_iter` to `max_iter`; use
    # whichever name the installed version accepts.
    tsne_params = inspect.signature(TSNE).parameters
    iter_kwarg = 'max_iter' if 'max_iter' in tsne_params else 'n_iter'

    max_samples = 5000

    for position, ticker in enumerate(plottable, start=1):
        print(f"Running t-SNE for {ticker}...")

        ticker_emb = embeddings[ticker]
        ticker_labels = regime_labels[ticker]

        if len(ticker_emb) > max_samples:
            indices = np.random.choice(len(ticker_emb), max_samples, replace=False)
            emb_sample = ticker_emb[indices]
            labels_sample = ticker_labels[indices]
        else:
            emb_sample = ticker_emb
            labels_sample = ticker_labels

        tsne = TSNE(
            n_components=2,
            method='barnes_hut',
            angle=0.5,
            perplexity=min(30, len(emb_sample) - 1),
            random_state=42,
            verbose=1,
            **{iter_kwarg: 1000}
        )

        reduced_emb = tsne.fit_transform(emb_sample)

        ax = fig.add_subplot(n_rows, n_cols, position)

        for regime in range(config.N_REGIMES):
            mask = labels_sample == regime
            ax.scatter(
                reduced_emb[mask, 0],
                reduced_emb[mask, 1],
                c=regime_colors[regime],
                label=regime_names[regime],
                alpha=0.7,
                edgecolor='w',
                linewidth=0.5,
                s=50
            )

        ax.set_title(f"Market Regimes for {ticker}", fontsize=14, fontweight='bold')
        ax.set_xlabel("t-SNE Dimension 1", fontsize=12)
        ax.set_ylabel("t-SNE Dimension 2", fontsize=12)

        ax.legend(title="Regimes", fontsize=12)

        # t-SNE axes carry no interpretable scale, so ticks are hidden.
        ax.set_xticks([])
        ax.set_yticks([])

    plt.tight_layout()
    plt.savefig("market_regimes_tsne.png", dpi=300, bbox_inches='tight')
    plt.close()
    print("t-SNE visualization saved to 'market_regimes_tsne.png'")

# 8. Improved time series visualization with regime mapping
def _regime_runs(regimes):
    """Collapse a float label array into ``(start, stop, regime_id)`` runs.

    NaN entries are ignored and do not break a run. ``start`` is the index
    where a regime first appears, ``stop`` is the index where the next
    different regime appears, and the final run stops at the last index of
    ``regimes``.
    """
    valid = np.flatnonzero(~np.isnan(regimes))
    if valid.size == 0:
        return []

    values = regimes[valid]
    # Positions (within `values`) where the label differs from its predecessor.
    firsts = np.concatenate(([0], np.flatnonzero(values[1:] != values[:-1]) + 1))
    starts = valid[firsts]
    stops = np.append(starts[1:], len(regimes) - 1)

    return [
        (int(start), int(stop), int(regime))
        for start, stop, regime in zip(starts, stops, values[firsts])
    ]


def map_regimes_to_timeseries(file_path, regime_labels, config):
    """Plot each ticker's price series with its regimes shaded on top.

    The CSV is re-read and ordered by ``date`` (a running index is used when
    that column is missing). Each window's regime label is assigned to the
    last row of that window, so the first ``WINDOW_SIZE - 1`` rows carry no
    regime. For every ticker present in ``regime_labels`` a two-panel figure
    is written to ``<ticker>_regimes_timeseries.png``: the price with one
    shaded span per regime run, and a strip showing the regime at each row.

    Args:
        file_path: CSV location handed to ``pandas.read_csv``.
        regime_labels: Mapping from ticker to an integer label array, one
            label per sliding window.
        config: Object exposing ``WINDOW_SIZE``, ``N_REGIMES``,
            ``REGIME_NAMES`` and ``REGIME_COLORS``.
    """
    df = pd.read_csv(file_path)

    has_date_column = 'date' in df.columns
    if has_date_column:
        df['date'] = pd.to_datetime(df['date'])
        df.sort_values('date', inplace=True)
    else:
        print("No 'date' column found. Creating a numerical index instead.")
        df['date'] = pd.RangeIndex(start=0, stop=len(df), step=1)

    tickers = df['ticker'].unique() if 'ticker' in df.columns else ['data']

    plt.style.use('seaborn-v0_8-whitegrid')

    for ticker in tickers:
        if ticker not in regime_labels:
            continue

        print(f"Mapping regimes to time series for {ticker}...")

        if 'ticker' in df.columns:
            ticker_data = df[df['ticker'] == ticker].copy().reset_index(drop=True)
        else:
            ticker_data = df.copy().reset_index(drop=True)

        # Fall back to the first numeric column when there is no 'close'.
        price_col = 'close' if 'close' in ticker_data.columns else ticker_data.select_dtypes(include=[np.number]).columns[0]

        regimes = np.full(len(ticker_data), np.nan)
        window_size = config.WINDOW_SIZE

        # Window k covers rows k .. k + window_size - 1; its label goes on the
        # window's last row. Surplus labels are cut to the rows available.
        if len(regime_labels[ticker]) <= len(ticker_data) - window_size + 1:
            regimes[window_size-1:window_size-1+len(regime_labels[ticker])] = regime_labels[ticker]
        else:
            regimes[window_size-1:] = regime_labels[ticker][:len(ticker_data)-(window_size-1)]

        ticker_data['regime'] = regimes
        dates = ticker_data['date']

        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(20, 16), gridspec_kw={'height_ratios': [3, 1]}, sharex=True)

        ax1.plot(dates, ticker_data[price_col], color='black', linewidth=2)
        ax1.set_title(f"{ticker} Price with Market Regimes", fontsize=16, fontweight='bold')
        ax1.set_ylabel(f"{price_col} Price", fontsize=14)
        ax1.grid(True, alpha=0.3)

        # Only the first span of each regime carries a legend label. Tracking
        # the labelled regimes in a set avoids rescanning all earlier rows at
        # every regime change.
        labelled = set()
        for start, stop, regime_id in _regime_runs(regimes):
            regime_label = ""
            if regime_id not in labelled:
                labelled.add(regime_id)
                regime_label = config.REGIME_NAMES[regime_id]

            ax1.axvspan(
                dates.iloc[start],
                dates.iloc[stop],
                alpha=0.2,
                color=config.REGIME_COLORS[regime_id],
                label=regime_label
            )

        handles, labels = ax1.get_legend_handles_labels()
        by_label = dict(zip(labels, handles))
        if "" in by_label:
            del by_label[""]
        ax1.legend(by_label.values(), by_label.keys(), loc='upper left', fontsize=12)

        for regime in range(config.N_REGIMES):
            mask = ticker_data['regime'] == regime
            if mask.any():
                ax2.scatter(
                    ticker_data.loc[mask, 'date'],
                    np.ones(mask.sum()) * regime,
                    color=config.REGIME_COLORS[regime],
                    s=100,
                    label=config.REGIME_NAMES[regime],
                    alpha=0.7
                )

        ax2.set_yticks(range(config.N_REGIMES))
        ax2.set_yticklabels(config.REGIME_NAMES)
        ax2.set_ylabel("Market Regime", fontsize=14)
        ax2.set_xlabel("Date" if has_date_column else "Index", fontsize=14)
        ax2.grid(True, alpha=0.3)

        if has_date_column:
            fig.autofmt_xdate()

        # NOTE(review): these descriptions are fixed text for three regimes;
        # they are not derived from the data or from config.REGIME_NAMES.
        regime_descriptions = [
            "Bearish: Falling prices, negative momentum",
            "Sideways: Range-bound, low volatility",
            "Bullish: Rising prices, positive momentum"
        ]

        props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
        ax1.text(
            0.02, 0.02,
            '\n'.join(regime_descriptions),
            transform=ax1.transAxes,
            fontsize=12,
            verticalalignment='bottom',
            bbox=props
        )

        plt.tight_layout()
        plt.savefig(f"{ticker}_regimes_timeseries.png", dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Time series visualization saved to '{ticker}_regimes_timeseries.png'")

# 9. Main execution function with memory management
def _config_items(config):
    """Yield ``(name, value)`` for each public, non-callable setting on *config*.

    Class-level attributes are included (base classes first, so subclasses
    override them), followed by instance attributes. ``config`` may be an
    instance or the class itself.
    """
    cls = config if isinstance(config, type) else type(config)
    names = {}
    for klass in reversed(cls.__mro__):
        names.update(dict.fromkeys(vars(klass)))
    if not isinstance(config, type):
        names.update(dict.fromkeys(getattr(config, '__dict__', {})))

    for name in names:
        if name.startswith('_'):
            continue
        value = getattr(config, name)
        if not callable(value):
            yield name, value


def run_ts_tcc_regime_detection(file_path, config):
    """Run the pipeline: load, train, cluster, then write both sets of plots.

    Args:
        file_path: CSV path; it is read once for training and again when the
            regimes are mapped back onto the time series.
        config: Settings object passed through to every stage.

    Returns:
        ``(models, embeddings, regime_labels, regime_models)``, each a dict
        keyed by ticker.
    """
    print("Running TS-TCC Market Regime Detection with the following configuration:")
    # vars(config) only sees instance attributes, which are empty for a class
    # whose settings are all class-level; list the class attributes as well.
    for key, value in _config_items(config):
        if key != 'DEVICE':
            print(f"{key}: {value}")
    print(f"Using device: {config.DEVICE}")

    print("\n=== Loading and Preprocessing Data ===")
    data, tickers = load_and_preprocess_data(file_path, config.WINDOW_SIZE)

    print("\n=== Training TS-TCC Models ===")
    models, embeddings = train_ts_tcc(data, tickers, config)

    # The windowed arrays are no longer needed once embeddings exist.
    del data
    gc.collect()
    torch.cuda.empty_cache()

    print("\n=== Identifying Market Regimes ===")
    regime_labels, regime_models = identify_market_regimes(embeddings, config)

    print("\n=== Visualizing Market Regimes ===")
    visualize_regimes(embeddings, regime_labels, tickers, config)

    print("\n=== Mapping Regimes to Time Series ===")
    map_regimes_to_timeseries(file_path, regime_labels, config)

    print("\n=== Analysis Complete! ===")
    return models, embeddings, regime_labels, regime_models

# Script entry point: run the whole pipeline on a fixed input file.
if __name__ == "__main__":
    # Relative path, so it is resolved against the current working directory.
    FILE_PATH = "melted_data.csv"

    config = Config()

    # The four result dicts are bound at module level so they remain available
    # for inspection after the run.
    models, embeddings, regime_labels, regime_models = run_ts_tcc_regime_detection(
        file_path=FILE_PATH,
        config=config
    )

Running TS-TCC Market Regime Detection with the following configuration:
Using device: cuda

=== Loading and Preprocessing Data ===
Loading data...
Data loaded: 40495 rows, 8 columns
Found 1 unique tickers
Processing ticker: data
Created 40436 sequences of shape (60, 6)

=== Training TS-TCC Models ===

Training model for data with input dimension 6


Training data: 100%|██████████| 100/100 [42:48<00:00, 25.69s/it, loss=0.0406, lr=0.000000]


Generating embeddings for data...


Extracting embeddings: 100%|██████████| 1264/1264 [00:02<00:00, 595.12it/s]


Generated 40436 embeddings for data

=== Identifying Market Regimes ===
Identifying market regimes for data...
Identified 3 regimes for data

=== Visualizing Market Regimes ===
Running t-SNE for data...
[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 5000 samples in 0.000s...




[t-SNE] Computed neighbors for 5000 samples in 0.317s...
[t-SNE] Computed conditional probabilities for sample 1000 / 5000
[t-SNE] Computed conditional probabilities for sample 2000 / 5000
[t-SNE] Computed conditional probabilities for sample 3000 / 5000
[t-SNE] Computed conditional probabilities for sample 4000 / 5000
[t-SNE] Computed conditional probabilities for sample 5000 / 5000
[t-SNE] Mean sigma: 2.585286
[t-SNE] KL divergence after 250 iterations with early exaggeration: 81.909134
[t-SNE] KL divergence after 1000 iterations: 1.188588
t-SNE visualization saved to 'market_regimes_tsne.png'

=== Mapping Regimes to Time Series ===
No 'date' column found. Creating a numerical index instead.
Mapping regimes to time series for data...
Time series visualization saved to 'data_regimes_timeseries.png'

=== Analysis Complete! ===
