# Temporal Convolutional Network

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

The Conv1d output length is L_out = [(L_in+2p-d(k-1)-1)/s]+1
where k=kernel size, d=dilation, s=stride, p=padding on each side. <br>
For our blocks, we want L_out = L_in, p = d(k-1)/2.

In [14]:
class TCNResidualBlock(nn.Module):
    """
    Temporal Conv residual block for [B, T, D].
    - depthwise separable optional (kept simple: standard conv here)
    - GLU gating on the first conv
    """
    def __init__(self, d_model: int, *, kernel_size=5, dilation=1, dropout=0.1):
        super().__init__()
        self.d = d_model
        self.ks = kernel_size
        self.dil = dilation
        pad = (kernel_size - 1) // 2 * dilation  # “same” length after a 1-D conv with dilation for non-causal

        # GLU: we need two channel groups (one for content, one for gates)
        self.conv1 = nn.Conv1d(d_model, 2 * d_model, kernel_size, padding=pad, dilation=dilation)
        self.bn1   = nn.BatchNorm1d(2 * d_model)
        self.dropout = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(d_model, d_model, kernel_size, padding=pad, dilation=dilation)
        self.bn2   = nn.BatchNorm1d(d_model)

        # residual scale to help very deep stacks
        self.res_scale = nn.Parameter(torch.tensor(1.0))

    def forward(self, x, mask=None, **kwargs):   # ← accept & ignore extra kwarg,  x: [B, T, D]
        h = x.transpose(1, 2)                 # -> [B, D, T]
        y = self.conv1(h)                     # [B, 2D, T]
        y = self.bn1(y)
        a, b = y.chunk(2, dim=1)              # split channels, each [B,D,T]
        y = a * torch.sigmoid(b)              # GLU
        y = self.dropout(y)
        y = self.conv2(y)                     # [B, D, T]
        y = self.bn2(y)
        y = y.transpose(1, 2)                 # -> [B, T, D]
        return x + self.res_scale * F.gelu(y) # PreNorm isn’t needed; BN is in conv path


Receptive field sanity (pick layers vs crop_len)

For dilations 1,2,…,2^L−1 and kernel k, non-causal RF:

RF=1+(k−1) (i=0~L−1)∑2^i = 1 + (k-1)*(2^L - 1) (non-causal, stride 1)

In [None]:
# Make sure RF roughly covers the crop
def tcn_rf(k,L): return 1 + (k-1)*(2**L - 1)

tcn_rf(5,4) # sweet spot for 64

61

## Dilation schedule factory

In [16]:
def make_tcn_block_ctor(kernel_size=5, dropout=0.1, base=2):
    """
    Returns a function block_ctor(d_model) that creates blocks with doubling dilation:
    1, 2, 4, 8, ...
    """
    counter = {"i": 0}
    def block_ctor(d_model: int):
        dil = base ** counter["i"]
        counter["i"] += 1
        return TCNResidualBlock(d_model, kernel_size=kernel_size, dilation=dil, dropout=dropout)
    return block_ctor


In [19]:
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..")))
from src.encoder_classifier_wrapper import EncoderClassifier, LinearFrontend, SimCLRProjector

In [22]:
# frontend on [B,T,C] -> [B,T,D]
C = 1
frontend = LinearFrontend(in_channels=C, d_model=64)

# dilations: 1,2,4,8 for num_layers=4 (RF grows fast; see note below)
block_ctor = make_tcn_block_ctor(kernel_size=5, dropout=0.1, base=2)

In [7]:
block_ctor(d_model=64)

TCNResidualBlock(
  (conv1): Conv1d(64, 128, kernel_size=(5,), stride=(1,), padding=(32,), dilation=(16,))
  (bn1): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (dropout): Dropout(p=0.1, inplace=False)
  (conv2): Conv1d(64, 64, kernel_size=(5,), stride=(1,), padding=(32,), dilation=(16,))
  (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
)

In [35]:
from src.synthetic_data import synth_trips
from src.ts_contrastive import TwoCropTSDataset, TSView, nt_xent, info_nce_two_way, train_contrastive

In [9]:
X, road, wthr, t = synth_trips(N=2000, T=128, use_accel=False, seed=7)

In [32]:
view = TSView()
ds = TwoCropTSDataset(X, M=None, view=view, crop_len=64)
loader = torch.utils.data.DataLoader(ds, batch_size=256, shuffle=True, drop_last=True)

In [33]:
projector_bn = SimCLRProjector(d=64, p=64)
model = EncoderClassifier(
    d_model=64, num_layers=4, block_ctor=block_ctor,
    pool="mean", posenc=None,                     # posenc=None for conv stacks
    final_norm="ln", final_norm_pos="post_pool",
    proj_dim=64, frontend=frontend, projector=projector_bn
)

In [34]:
# train
train_contrastive(model, loader, loss_fn=nt_xent, epochs=20, lr=1e-3, tau=0.1)


[epoch 001] loss=4.5752 | z-std=0.124 |intra diag/off=1.000/0.014 | cross diag/off=0.863/0.015
[epoch 002] loss=4.3276 | z-std=0.124 |intra diag/off=1.000/0.012 | cross diag/off=0.868/0.014
[epoch 003] loss=4.1402 | z-std=0.124 |intra diag/off=1.000/0.006 | cross diag/off=0.839/0.007
[epoch 004] loss=4.0773 | z-std=0.124 |intra diag/off=1.000/0.006 | cross diag/off=0.853/0.006
[epoch 005] loss=4.0979 | z-std=0.124 |intra diag/off=1.000/0.011 | cross diag/off=0.839/0.012
[epoch 006] loss=3.9950 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.846/0.007
[epoch 007] loss=3.9831 | z-std=0.124 |intra diag/off=1.000/0.005 | cross diag/off=0.845/0.006
[epoch 008] loss=3.9632 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.832/0.008
[epoch 009] loss=4.0133 | z-std=0.124 |intra diag/off=1.000/0.004 | cross diag/off=0.822/0.005
[epoch 010] loss=4.0192 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.844/0.008
[epoch 011] loss=3.9647 | z-std=0.124 |intra diag/

InfoNCE often trades a bit of alignment (positive similarity) for uniformity (negatives pushed apart), which still lowers the loss—as your logs show (off-diag ≈ 0.004–0.015). What matters is a validation metric (e.g., frozen linear-probe accuracy or k-NN retrieval), not the raw loss or cross-diag alone.

In [36]:
# train
train_contrastive(model, loader, loss_fn=info_nce_two_way, epochs=20, lr=1e-3, tau=0.1)

[epoch 001] loss=3.3340 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.785/0.010
[epoch 002] loss=3.2101 | z-std=0.124 |intra diag/off=1.000/0.004 | cross diag/off=0.808/0.004
[epoch 003] loss=3.2079 | z-std=0.124 |intra diag/off=1.000/0.003 | cross diag/off=0.815/0.004
[epoch 004] loss=3.1524 | z-std=0.124 |intra diag/off=1.000/0.007 | cross diag/off=0.800/0.008
[epoch 005] loss=3.1588 | z-std=0.124 |intra diag/off=1.000/0.004 | cross diag/off=0.806/0.005
[epoch 006] loss=3.1943 | z-std=0.124 |intra diag/off=1.000/0.006 | cross diag/off=0.777/0.006
[epoch 007] loss=3.1581 | z-std=0.124 |intra diag/off=1.000/0.008 | cross diag/off=0.802/0.009
[epoch 008] loss=3.1363 | z-std=0.124 |intra diag/off=1.000/0.004 | cross diag/off=0.801/0.005
[epoch 009] loss=3.1954 | z-std=0.125 |intra diag/off=1.000/0.001 | cross diag/off=0.795/0.002
[epoch 010] loss=3.1526 | z-std=0.124 |intra diag/off=1.000/0.002 | cross diag/off=0.795/0.003
[epoch 011] loss=3.1083 | z-std=0.124 |intra diag/

### The two SimCLR-style losses

NT-Xent (2B×2B)
Concatenate 2 views, compute all pairwise similarities.
Negatives: All other samples from both views ⇒ per anchor: 2B−2 negatives.
More negative counts often yields stronger uniformity

Two-way cross-view InfoNCE (B×B)
Use only cross-view sims
Negatives: Only cross-view ⇒ per anchor: B−1 negatives.
B×B is simpler and more stable.

Temperature tuning: With fewer negatives (B×B), you often need a smaller τ to keep the positive sharp. Switching from 2B→B×B frequently benefits from reducing τ (e.g., 0.2→0.1).

## Linear probe

In [37]:
from src.linear_probe import split_idx, train_linear_probe, accuracy

In [41]:
dev = next(model.parameters()).device

In [42]:
model.eval()
enc = model.features  # or a wrapper that returns pooled pre-projection h
with torch.no_grad():
    Z = enc(X.to(dev))   # [N,D]
# simple linear probes (logreg) for road (3) & weather (2)
clf_road = torch.nn.Linear(Z.size(1), 3).to(dev)
clf_wthr = torch.nn.Linear(Z.size(1), 2).to(dev)

In [43]:
# Train/val split
train_idx, val_idx = split_idx(Z.size(0), val=0.2, seed=42)
Ztr, Zva = Z[train_idx], Z[val_idx]
road_tr, road_va = road[train_idx], road[val_idx]
wthr_tr, wthr_va = wthr[train_idx], wthr[val_idx]

# Train probes
W_road = train_linear_probe(Ztr, road_tr, clf_road, epochs=200, lr=1e-2, wd=1e-4, device=device)
W_wthr = train_linear_probe(Ztr, wthr_tr, clf_wthr, epochs=200, lr=1e-2, wd=1e-4, device=device)

In [44]:
# Evaluate
acc_road = accuracy(W_road, Zva, road_va, device=device)
acc_wthr = accuracy(W_wthr, Zva, wthr_va, device=device)
print(f"Linear probe – road: {acc_road:.3f}, weather: {acc_wthr:.3f}")

Linear probe – road: 0.905, weather: 0.685
