In [1]:
# Run once if using Colab / fresh environment
!pip install -q mne torch torchvision matplotlib seaborn tqdm



[notice] A new release of pip is available: 24.3.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
# CELL 1: imports, device, seeds
import os, glob, time, json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import mne

# reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


Device: cpu


In [3]:
# CELL 2: constants
DATA_ROOT = "./../physionet/files"   # <-- change this to your local path if needed
SAMPLING_RATE = 160                    # PhysioNet MI sampling
TRIAL_DURATION = 3.2
N_SAMPLES = int(SAMPLING_RATE * TRIAL_DURATION)  # 512
N_CHANNELS = 64

EXCLUDED = [88, 89, 92, 100]  # corrupted per paper
EVENTS_OF_INTEREST = {'T1': 0, 'T2': 1}  # 0 = Left, 1 = Right


In [4]:
# CELL 3: load subject trials
def load_subject_physionet(subj_id, data_root=DATA_ROOT, verbose=False):
    subj_dir = os.path.join(data_root, f"S{subj_id:03d}")
    if not os.path.isdir(subj_dir):
        if verbose: print(f"[WARN] Subject folder not found: {subj_dir}")
        return None, None

    edf_files = sorted(glob.glob(os.path.join(subj_dir, "*.edf")))
    if len(edf_files) == 0:
        if verbose: print(f"[WARN] No EDF files for subject {subj_id}")
        return None, None

    X_list = []
    y_list = []
    for f in edf_files:
        try:
            raw = mne.io.read_raw_edf(f, preload=True, verbose=False)
        except Exception as e:
            print(f"[ERROR] reading {f}: {e}")
            continue
        # pick EEG only and set montage
        raw.pick_types(eeg=True, exclude=[])
        try:
            raw.set_montage('standard_1020')
        except Exception:
            pass  # if montage not perfect, continue

        # iterate annotations for T1/T2 descriptions
        annots = raw.annotations
        for ai in range(len(annots)):
            desc = annots[ai]['description']
            if desc in EVENTS_OF_INTEREST:
                onset = annots[ai]['onset']  # seconds
                start_sample = int(onset * SAMPLING_RATE)
                stop_sample = start_sample + N_SAMPLES
                # convert to sample indices in MNE: raw.times are in seconds
                # Use raw.get_data() slicing (channels, samples)
                data = raw.get_data(start=start_sample, stop=stop_sample)  # shape (n_channels, samples)
                if data.shape[0] != N_CHANNELS:
                    # If channel count mismatch, attempt to re-reference or skip
                    if verbose:
                        print(f"[WARN] {subj_id} file {f} has {data.shape[0]} channels (expected {N_CHANNELS}). Skipping this trial.")
                    continue
                if data.shape[1] != N_SAMPLES:
                    if verbose:
                        print(f"[WARN] {subj_id} trial length mismatch: {data.shape[1]} != {N_SAMPLES}. Skipping.")
                    continue
                # convert to float32
                X_list.append(data.astype(np.float32))
                y_list.append(EVENTS_OF_INTEREST[desc])
    if len(X_list) == 0:
        return None, None
    X = np.stack(X_list, axis=0)  # (n_trials, n_channels, n_samples)
    y = np.array(y_list, dtype=np.int64)
    if verbose:
        counts = np.bincount(y, minlength=2)
        print(f"Loaded S{subj_id:03d}: {X.shape[0]} trials, Left={counts[0]}, Right={counts[1]}")
    return X, y


In [5]:
# CELL 4: load all subjects (prints progress)
def load_all_subjects(verbose=True):
    X_all = []
    y_all = []
    subj_ids = []
    missing = []
    for sid in range(1, 110):  # subjects 1..109
        if sid in EXCLUDED:
            if verbose: print(f"Skipping excluded subject S{sid:03d}")
            continue
        Xi, yi = load_subject_physionet(sid, verbose=False)
        if Xi is None:
            missing.append(sid)
            continue
        X_all.append(Xi)
        y_all.append(yi)
        subj_ids.append(np.full(len(yi), sid, dtype=np.int32))
        if verbose:
            c0, c1 = np.bincount(yi, minlength=2)
            print(f"Loaded S{sid:03d}: trials={Xi.shape[0]}, left={c0}, right={c1}")
    if len(X_all) == 0:
        raise RuntimeError("No subjects loaded. Check DATA_ROOT path.")
    X_all = np.concatenate(X_all, axis=0)
    y_all = np.concatenate(y_all, axis=0)
    subj_ids = np.concatenate(subj_ids, axis=0)
    if verbose:
        unique_subj = np.unique(subj_ids)
        print(f"Total subjects loaded: {len(unique_subj)} (should be 105).")
        print(f"Total trials: {X_all.shape[0]}, Left/Right counts: {np.bincount(y_all)}")
        if len(missing) > 0:
            print("Missing subjects:", missing)
    return X_all, y_all, subj_ids

# Run loading (this may take several minutes depending on disk)
X_all, y_all, subj_ids = load_all_subjects(verbose=True)
print("DONE loading all subjects. Data shape:", X_all.shape)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
NOTE: pick_types() is a legacy function. New code should use inst.pick(...).

In [6]:
# CELL 5: preprocessing (fixed for MNE float64 requirement)
from mne.filter import filter_data

APPLY_BANDPASS = True
BANDPASS = (4.0, 45.0)  # μ and β bands

def preprocess_trials(X, apply_bandpass=APPLY_BANDPASS, band=BANDPASS, sfreq=SAMPLING_RATE):
    """
    Apply bandpass and per-trial z-score normalization.
    MNE's filter_data requires float64, so we cast accordingly.
    """
    Xp = np.empty_like(X, dtype=np.float32)
    for i in range(X.shape[0]):
        trial = X[i]  # (C, T)
        if apply_bandpass:
            # cast to float64 for mne filtering
            trial64 = trial.astype(np.float64)
            trial_f = filter_data(trial64, sfreq=sfreq,
                                  l_freq=band[0], h_freq=band[1],
                                  verbose=False)
        else:
            trial_f = trial
        # per-channel z-score (per trial)
        mean = trial_f.mean(axis=1, keepdims=True)
        std  = trial_f.std(axis=1, keepdims=True) + 1e-6
        trial_z = (trial_f - mean) / std
        Xp[i] = trial_z.astype(np.float32)
        if (i + 1) % 500 == 0 or i == X.shape[0] - 1:
            print(f"Preprocessed {i + 1}/{X.shape[0]} trials")
    return Xp

print("Preprocessing all trials (this can be slow)...")
X_proc = preprocess_trials(X_all, apply_bandpass=APPLY_BANDPASS)
print("Preprocessing done. Shape:", X_proc.shape)


Preprocessing all trials (this can be slow)...
Preprocessed 500/18887 trials
Preprocessed 1000/18887 trials
Preprocessed 1500/18887 trials
Preprocessed 2000/18887 trials
Preprocessed 2500/18887 trials
Preprocessed 3000/18887 trials
Preprocessed 3500/18887 trials
Preprocessed 4000/18887 trials
Preprocessed 4500/18887 trials
Preprocessed 5000/18887 trials
Preprocessed 5500/18887 trials
Preprocessed 6000/18887 trials
Preprocessed 6500/18887 trials
Preprocessed 7000/18887 trials
Preprocessed 7500/18887 trials
Preprocessed 8000/18887 trials
Preprocessed 8500/18887 trials
Preprocessed 9000/18887 trials
Preprocessed 9500/18887 trials
Preprocessed 10000/18887 trials
Preprocessed 10500/18887 trials
Preprocessed 11000/18887 trials
Preprocessed 11500/18887 trials
Preprocessed 12000/18887 trials
Preprocessed 12500/18887 trials
Preprocessed 13000/18887 trials
Preprocessed 13500/18887 trials
Preprocessed 14000/18887 trials
Preprocessed 14500/18887 trials
Preprocessed 15000/18887 trials
Preprocessed 

In [7]:
# CELL 6: split subjects -> test 10 subjects, train rest (95)
unique_subj = np.unique(subj_ids)
np.random.seed(SEED)
test_subj = np.random.choice(unique_subj, size=10, replace=False)
train_mask = ~np.isin(subj_ids, test_subj)
test_mask = np.isin(subj_ids, test_subj)

X_train, y_train, sid_train = X_proc[train_mask], y_all[train_mask], subj_ids[train_mask]
X_test,  y_test,  sid_test  = X_proc[test_mask],  y_all[test_mask],  subj_ids[test_mask]

print("Train subjects (count):", len(np.unique(sid_train)))
print("Test subjects (count):", len(np.unique(sid_test)))
print("Test subjects list:", sorted(test_subj.tolist()))
print("Train trials:", X_train.shape, "Test trials:", X_test.shape)
print("Train class distribution:", np.bincount(y_train), "Test class distribution:", np.bincount(y_test))


Train subjects (count): 95
Test subjects (count): 10
Test subjects list: [1, 11, 31, 46, 48, 54, 65, 66, 98, 109]
Train trials: (17087, 64, 512) Test trials: (1800, 64, 512)
Train class distribution: [8556 8531] Test class distribution: [900 900]


In [8]:
# CELL 7: dataset + loaders
class EEGDataset(Dataset):
    def __init__(self, X, y):
        # X: (N, C, T)
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return self.y.shape[0]
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

BATCH_SIZE = 20

train_ds = EEGDataset(X_train, y_train)
test_ds = EEGDataset(X_test, y_test)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, drop_last=False)
val_loader = DataLoader(test_ds,  batch_size=BATCH_SIZE, shuffle=False, drop_last=False)

print("train batches:", len(train_loader), "val batches:", len(val_loader))


train batches: 855 val batches: 90


In [9]:
# CELL 8: model implementation
class TFEMBlock(nn.Module):
    def __init__(self, n_channels, kernel_size=16, dropout=0.25, pool=False):
        super().__init__()
        self.conv = nn.Conv1d(in_channels=n_channels, out_channels=n_channels,
                              kernel_size=kernel_size, padding=kernel_size//2,
                              groups=n_channels, bias=False)
        self.bn = nn.BatchNorm1d(n_channels)
        self.act = nn.ELU()
        self.pool = nn.AvgPool1d(kernel_size=2, stride=2) if pool else nn.Identity()
        self.drop = nn.Dropout(dropout)
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        x = self.pool(x)
        x = self.drop(x)
        return x

class CARMBlock(nn.Module):
    def __init__(self, n_channels, rho=0.001):
        super().__init__()
        self.C = n_channels
        self.rho = float(rho)
        init = torch.ones(n_channels, n_channels) - torch.eye(n_channels)
        init = init * 0.01
        self.raw_W = nn.Parameter(init)
        self.theta = nn.Conv1d(in_channels=n_channels, out_channels=n_channels,
                               kernel_size=1, groups=n_channels, bias=True)
        self.act = nn.ELU()
        self.drop = nn.Dropout(0.15)
    def compute_normalized_adj(self):
        W = self.raw_W
        W_sym = 0.5 * (W + W.t())
        W_sym = W_sym * (1.0 - torch.eye(self.C, device=W_sym.device))
        W_tilde = W_sym + torch.eye(self.C, device=W_sym.device)
        deg = W_tilde.sum(dim=1)
        deg_inv_sqrt = 1.0 / torch.sqrt(deg + 1e-8)
        D_inv_sqrt = torch.diag(deg_inv_sqrt)
        W_hat = D_inv_sqrt @ W_tilde @ D_inv_sqrt
        return W_hat
    def forward(self, x):
        # x: (B, C, T)
        W_hat = self.compute_normalized_adj()
        out = torch.einsum('ij,bjt->bit', W_hat, x)
        out = self.theta(out)
        out = self.act(out)
        out = self.drop(out)
        return out
    def manual_raw_update(self, rho_override=None):
        rho = self.rho if rho_override is None else float(rho_override)
        if self.raw_W.grad is None:
            return
        with torch.no_grad():
            g = self.raw_W.grad
            self.raw_W.data = (1.0 - rho) * self.raw_W.data - rho * g
            self.raw_W.data = 0.5 * (self.raw_W.data + self.raw_W.data.t())
            self.raw_W.data.fill_diagonal_(0.0)
        self.raw_W.grad.zero_()
    def get_W_hat_numpy(self):
        with torch.no_grad():
            W_hat = self.compute_normalized_adj().cpu().numpy()
        return W_hat

class EEG_ARNN_6(nn.Module):
    def __init__(self, n_channels=64, n_classes=2, kernel_size=16, dropout=0.25, rho=0.001, pool_pattern=None):
        super().__init__()
        self.C = n_channels
        if pool_pattern is None:
            pool_pattern = [False, True, False, True, False, False]
        self.tfems = nn.ModuleList([TFEMBlock(n_channels=n_channels, kernel_size=kernel_size, dropout=dropout, pool=pool_pattern[i]) for i in range(6)])
        self.carms = nn.ModuleList([CARMBlock(n_channels=n_channels, rho=rho) for _ in range(6)])
        self.channel_fusion = nn.Sequential(
            nn.Linear(n_channels, 256),
            nn.ELU(),
            nn.Dropout(0.35),
            nn.Linear(256, 128),
            nn.ELU(),
            nn.Dropout(0.3)
        )
        self.classifier = nn.Linear(128, n_classes)
    def forward(self, x):
        for tfem, carm in zip(self.tfems, self.carms):
            x = tfem(x)
            x = carm(x)
        x = x.mean(dim=2)
        x = self.channel_fusion(x)
        logits = self.classifier(x)
        return logits
    def manual_carm_updates(self):
        for carm in self.carms:
            carm.manual_raw_update()
    def get_all_W_hats(self):
        return [carm.get_W_hat_numpy() for carm in self.carms]

# instantiate
model = EEG_ARNN_6(n_channels=N_CHANNELS, n_classes=2, kernel_size=16, dropout=0.25, rho=0.001)
model = model.to(device)
print(model)


EEG_ARNN_6(
  (tfems): ModuleList(
    (0): TFEMBlock(
      (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), padding=(8,), groups=64, bias=False)
      (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ELU(alpha=1.0)
      (pool): Identity()
      (drop): Dropout(p=0.25, inplace=False)
    )
    (1): TFEMBlock(
      (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), padding=(8,), groups=64, bias=False)
      (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ELU(alpha=1.0)
      (pool): AvgPool1d(kernel_size=(2,), stride=(2,), padding=(0,))
      (drop): Dropout(p=0.25, inplace=False)
    )
    (2): TFEMBlock(
      (conv): Conv1d(64, 64, kernel_size=(16,), stride=(1,), padding=(8,), groups=64, bias=False)
      (bn): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ELU(alpha=1.0)
      (pool): Identity()
      (drop): Dropout(p=0.25, 

In [19]:
# CELL 9: optimizer excluding raw_W
def get_trainable_params_excluding_raw(model):
    params = []
    for name, p in model.named_parameters():
        if 'raw_W' in name:
            # skip raw_W
            continue
        params.append(p)
    return params

LR = 5e-3      #1e-3
WEIGHT_DECAY = 5e-4    #1e-4
EPOCHS = 30   # start  with 120-200; paper used 500 for final runs
SAVE_EVERY = 10

optimizer = torch.optim.Adam(get_trainable_params_excluding_raw(model), lr=LR, weight_decay=WEIGHT_DECAY)
criterion = nn.CrossEntropyLoss()


In [None]:
# CELL 10: training loop
OUT_DIR = "runs/physionet_eegarnn_exp1"
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(os.path.join(OUT_DIR, "W_hats"), exist_ok=True)
history = {"train_loss": [], "train_acc": [], "val_acc": [], "epoch_time": []}

DEBUG_RAWW_GRADS = True  # set False if too verbose

def evaluate_model(model, loader, device=device):
    model.eval()
    total = 0; correct = 0
    with torch.no_grad():
        for xb, yb in loader:
            xb = xb.to(device); yb = yb.to(device)
            logits = model(xb)
            preds = logits.argmax(dim=1)
            correct += (preds == yb).sum().item()
            total += xb.size(0)
    return 100.0 * correct / total

for epoch in range(1, EPOCHS+1):
    t0 = time.time()
    model.train()
    running_loss = 0.0
    total = 0; correct = 0
    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{EPOCHS}", leave=True)
    batch_idx = 0
    for xb, yb in pbar:
        batch_idx += 1
        xb = xb.to(device); yb = yb.to(device)
        logits = model(xb)
        loss = criterion(logits, yb)
        optimizer.zero_grad()
        loss.backward()

        # debug: print grad norms on raw_W (first layer) occasionally
        if DEBUG_RAWW_GRADS and (batch_idx % 50 == 0 or batch_idx == 1):
            grads = []
            for i, carm in enumerate(model.carms):
                if carm.raw_W.grad is None:
                    grads.append(0.0)
                else:
                    grads.append(float(carm.raw_W.grad.norm().cpu().item()))
            print(f"[DEBUG] raw_W grad norms (layers): {np.round(grads,6)}")

        # manual update of raw_W per CARM
        model.manual_carm_updates()

        # optimizer step for other params
        optimizer.step()

        running_loss += float(loss.item()) * xb.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += xb.size(0)
        pbar.set_postfix({"loss": f"{running_loss/total:.4f}", "acc": f"{100.*correct/total:.2f}"})

    epoch_time = time.time() - t0
    train_loss_epoch = running_loss / total
    train_acc_epoch = 100.0 * correct / total
    val_acc_epoch = evaluate_model(model, val_loader, device=device)

    # save W_hats per layer (numpy)
    W_hats = model.get_all_W_hats()
    for li, W in enumerate(W_hats, start=1):
        np.save(os.path.join(OUT_DIR, "W_hats", f"W_epoch{epoch:03d}_layer{li:02d}.npy"), W)

    # checkpoint
    if epoch % SAVE_EVERY == 0 or epoch == EPOCHS:
        ckpt = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "train_loss": train_loss_epoch,
            "train_acc": train_acc_epoch,
            "val_acc": val_acc_epoch
        }
        torch.save(ckpt, os.path.join(OUT_DIR, f"ckpt_epoch{epoch:03d}.pth"))
        print(f"Saved checkpoint epoch {epoch}")

    history['train_loss'].append(train_loss_epoch)
    history['train_acc'].append(train_acc_epoch)
    history['val_acc'].append(val_acc_epoch)
    history['epoch_time'].append(epoch_time)

    # persist history
    with open(os.path.join(OUT_DIR, "history.json"), "w") as fh:
        json.dump(history, fh, indent=2)

    print(f"Epoch {epoch} finished. train_loss={train_loss_epoch:.4f}, train_acc={train_acc_epoch:.2f}%, val_acc={val_acc_epoch:.2f}%, time={epoch_time:.1f}s")


Epoch 1/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000111]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.001036]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000456]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000163]
[DEBUG] raw_W grad norms (layers): [0.0e+00 0.0e+00 0.0e+00 0.0e+00 0.0e+00 1.2e-05]
[DEBUG] raw_W grad norms (layers): [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 7.e-06]
[DEBUG] raw_W grad norms (layers): [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 4.e-06]
[DEBUG] raw_W grad norms (layers): [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 1.e-06]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 

Epoch 2/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
Epoch 2 fi

Epoch 3/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0.      0.      0.      0.      0.      0.00363]
[DEBUG] raw_W grad norms (layers): [0.      0.      0.      0.      0.      0.00094]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.008711]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.002161]
[DEBUG] raw_W grad norms (layers): [0.      0.      0.      0.      0.      0.00056]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000124]
[DEBUG] raw_W grad norms (layers): [0.0e+00 0.0e+00 0.0e+00 0.0e+00 0.0e+00 9.1e-05]
[DEBUG] raw_W grad norms (layers): [0.0e+00 0.0e+00 0.0e+00 0.0e+00 0.0e+00 2.5e-05]
[DEBUG] raw_W grad norms (layers): [0.0e+00 0.0e+00 0.0e+00 0.0e+00 0.0e+00 2.9e-05]
[DEBUG] raw_W grad norms (layers): [0.0e+00 0.0e+00 

Epoch 4/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
Epoch 4 fi

Epoch 5/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
Epoch 5 fi

Epoch 6/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
Epoch 6 fi

Epoch 7/30:   0%|          | 0/855 [00:00<?, ?it/s]

[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0. 0. 0. 0. 0. 0.]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.009947]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000475]
[DEBUG] raw_W grad norms (layers): [0.0000e+00 0.0000e+00 0.0000e+00 0.0000e+00 1.0000e-06 1.4924e-02]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000254]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000372]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000379]
[DEBUG] raw_W grad norms (layers): [0.       0.       0.       0.       0.       0.000108]
[DEBUG] raw_W grad norms (layers): [0.e+00 0.e+00 0.e+00 0.e+00 0.e+00 3.e-0