# In [None]:
import numpy as np
import matplotlib.pyplot as plt
from model_zigma import ZigMa  # Assuming ZigMa is defined in model_zigma.py
from sklearn.linear_model import QuantileRegressor
import torch
import torch.nn as nn


# Detection model: a thin nn.Module wrapper around a fixed-configuration ZigMa backbone.
class ObjectDetectionModel(nn.Module):
    """Wrap a ZigMa backbone built with a fixed architecture.

    Args:
        in_channels: Number of input channels, forwarded to ZigMa.
        img_dim: Image side length, forwarded to ZigMa as ``img_dim``.
        device: Value forwarded to ZigMa's ``device`` argument. Defaults to
            ``"cuda"`` (the value that used to be hard-coded), so existing
            callers are unaffected. NOTE(review): whether ZigMa supports
            devices other than CUDA must be confirmed in model_zigma.py.
    """

    def __init__(self, in_channels=3, img_dim=256, *, device="cuda"):
        super().__init__()
        # The remaining hyper-parameters are fixed here. Their exact meaning
        # is defined by ZigMa (model_zigma.py), not in this file.
        # NOTE(review): d_context / n_context_token presumably only matter
        # when has_text=True — confirm against ZigMa.
        self.zigma = ZigMa(
            in_channels=in_channels,
            embed_dim=640,
            depth=18,
            img_dim=img_dim,
            patch_size=1,
            has_text=False,
            d_context=768,
            n_context_token=77,
            device=device,
            scan_type="zigzagN8",
            use_pe=2,
        )

    def forward(self, x):
        """Run the ZigMa backbone on ``x`` and return its output unchanged."""
        return self.zigma(x)

# Generate synthetic data for demonstration: y = x + Gaussian noise (std 0.1).
np.random.seed(0)
X = np.random.rand(100, 1)
y = X.squeeze() + np.random.normal(0, 0.1, 100)

# Quantiles to estimate: the 5th and 95th percentiles bracket a 90% band;
# 0.5 is the median.
quantiles = [0.05, 0.5, 0.95]
predictions = {}

# Fit one linear quantile regressor per quantile. alpha=0 disables the L1
# penalty. The solver is pinned to "highs" because scikit-learn < 1.4
# defaulted to "interior-point", which is unavailable with SciPy >= 1.11.
for quantile in quantiles:
    qr = QuantileRegressor(quantile=quantile, alpha=0, solver="highs")
    qr.fit(X, y)
    predictions[quantile] = qr.predict(X)

# Plot the results. X is drawn in random order, so sort it before drawing
# the fitted lines. Otherwise plt.plot connects the points in sampling
# order and renders a zig-zag instead of each quantile's straight line.
_sort_idx = np.argsort(X[:, 0])
plt.scatter(X, y, color='black', label='Data')
for quantile, y_pred in predictions.items():
    plt.plot(X[_sort_idx, 0], y_pred[_sort_idx], label=f'Quantile: {quantile}')
plt.legend()
plt.xlabel('X')
plt.ylabel('y')
plt.title('Quantile Regression')
plt.show()

class ObjectTrackingAndClassification:
    """Run the detection backbone, then fit/apply one quantile regressor per quantile.

    Each entry of ``self.qr_models`` is a linear ``QuantileRegressor``. Together
    they give a lower (0.05), median (0.5) and upper (0.95) prediction per sample.
    """

    def __init__(self):
        """Build the detection model on CUDA and one regressor per quantile."""
        self.model = ObjectDetectionModel().to("cuda")
        self.quantiles = [0.05, 0.5, 0.95]
        # solver="highs" is pinned because scikit-learn < 1.4 defaulted to
        # "interior-point", which is unavailable with SciPy >= 1.11.
        self.qr_models = {
            q: QuantileRegressor(quantile=q, alpha=0, solver="highs")
            for q in self.quantiles
        }

    def train_quantile_regressors(self, X, y):
        """Fit every per-quantile regressor on the same 2-D features ``X`` and targets ``y``."""
        for qr in self.qr_models.values():
            qr.fit(X, y)

    def predict_with_uncertainty(self, X):
        """Return ``{quantile: predictions}`` for 2-D features ``X``."""
        return {q: qr.predict(X) for q, qr in self.qr_models.items()}

    def track_and_classify(self, x):
        """Run detection on batch ``x`` and return per-quantile predictions.

        Args:
            x: Input tensor for the detection model. It must already be on the
                model's device; this method does not move it.

        Returns:
            Dict mapping each quantile to a 1-D array with one value per
            sample (the first dimension of the model output).
        """
        # Object detection. Inference only, so skip autograd bookkeeping.
        with torch.no_grad():
            detections = self.model(x)

        # sklearn estimators require a 2-D (n_samples, n_features) array, so
        # flatten everything after the batch dimension into features.
        # NOTE(review): the original comment assumed the output is
        # bounding-box coordinates. That is unverified; with a large,
        # image-sized output this gives a very wide linear program.
        features = detections.detach().cpu().numpy()
        X = features.reshape(features.shape[0], -1)
        # TODO: placeholder targets; replace with real labels.
        y = np.random.rand(X.shape[0])

        # Plain quantile regression fitted and evaluated in-sample. No
        # conformal calibration step is performed here.
        self.train_quantile_regressors(X, y)

        # Predict with uncertainty
        return self.predict_with_uncertainty(X)

# Demo: run detection + quantile regression end-to-end on a random batch of
# shape (10, 3, 256, 256), moved to the GPU, and print the result dict.
_demo_shape = (10, 3, 256, 256)
tracker = ObjectTrackingAndClassification()
x = torch.rand(_demo_shape).to(device="cuda")
predictions = tracker.track_and_classify(x)
print(predictions)

