In [1]:
# Re-import edited project modules (e.g. `src.*`) automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
import pyrootutils

# Locate the repository root by searching upward from "." for a .gitignore,
# then configure this session around that directory.
root_setup_options = {
    "search_from": ".",
    "indicator": [".gitignore"],
    "project_root_env_var": True,  # export the root directory as PROJECT_ROOT
    "dotenv": True,  # load <root>/.env into the environment when it exists
    "pythonpath": True,  # put the root on the import path so `src.*` / `configs.*` import
    "cwd": True,  # chdir to the root so relative file paths resolve against it
}
base_path = pyrootutils.setup_root(**root_setup_options)
import sys
import pickle
import scanpy as sc
import squidpy as sq
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_geometric as pyg
import pytorch_lightning as pl
from pytorch_lightning.utilities.model_summary import ModelSummary
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    RichProgressBar,
    TQDMProgressBar,
    ModelCheckpoint,
)
import wandb
import copy
import pandas as pd
from torch_geometric.nn.resolver import (
    activation_resolver,
    normalization_resolver,
)
from omegaconf import DictConfig
from omegaconf import OmegaConf
from hydra import initialize, initialize_config_module, initialize_config_dir, compose
import hydra
from hydra.core.hydra_config import HydraConfig
from hydra.core.utils import _flush_loggers, configure_log

from src.data.datasets import *
from src.data.datamodules import SpatialDataModule
from src.data.graph_utils import check_radius
from src.data.utils import load_celltypes, load_sample_names
from src.models.modules import DeconvolutionModel, ln_loss, beta_scheduler
from src.models.dissect_spatial import DissectSpatial, DissectHetero, DissectSpatialHybrid
from src.models.dissect import Dissect
from src.utils.utils import seed_everything
from src.train import train
from src.utils.wandb import *
from src.utils.config_utils import *
from src.utils.experiment import run_experiment

from configs.main_config import config


# Silence every warning for the rest of the session. This keeps training logs
# readable, but it also hides deprecation and numerical warnings.
import warnings

warnings.simplefilter("ignore")

# Render inline figures at a modest resolution.
plt.rcParams.update({"figure.dpi": 80})

## Define data paths

In [3]:
# Each spatial dataset (relative to ${paths.data_dir}) paired with the directory
# (relative to ${paths.root_dir}) that holds its experiment artefacts.
# Keeping them as explicit pairs stops the two lists drifting out of alignment;
# the training loop zips them, and zip() silently drops unpaired entries.
dataset_pairs = [
    (
        "spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad",
        "experiments/experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75",
    ),
    (
        "spatial/simulations_heart_seqFISH/embryo1_resolution0.11.h5ad",
        "experiments/experiment_heart_seqFISH_embryo1_resolution0.11-new",
    ),
    (
        "spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad",
        "experiments/experiment_kidney_slideSeq_v2_105",
    ),
    (
        "spatial/V1_Mouse_Brain_Sagittal_Anterior.h5ad",
        "experiments/experiment_mouse_st",
    ),
    (
        "spatial/lymph_node/st_lymph.h5ad",
        "experiments/experiment_lymph_node",
    ),
]
st_data_files = [data_file for data_file, _ in dataset_pairs]
experiment_dirs = [exp_dir for _, exp_dir in dataset_pairs]
assert len(st_data_files) == len(experiment_dirs)

# Unresolved OmegaConf interpolations, resolved later against the run config.
# NOTE(review): data files are appended to ${paths.data_dir} with no separator,
# which assumes that setting ends in "/" — confirm in configs/paths.
st_paths = ["${paths.data_dir}" + data_file for data_file in st_data_files]
experiment_paths = ["${paths.root_dir}/" + exp_dir for exp_dir in experiment_dirs]

## Set up experiment config

### Load default config

In [4]:
# Hydra experiment preset layered on top of train.yaml; set to None to use the
# base config alone. (Another preset that has been used: "hybrid".)
experiment_name = "multi_channel"
overrides = [] if experiment_name is None else [f"experiment={experiment_name}"]

# Hydra will not initialize twice in one process, so reset its global state
# to keep this cell re-runnable.
hydra.core.global_hydra.GlobalHydra.instance().clear()

config_path = "../configs"
config_name = "train.yaml"

initialize(version_base="1.3", config_path=config_path)
config = compose(
    config_name=config_name,
    overrides=overrides,
    return_hydra_config=True,
)
# Register the composed `hydra` node globally (as @hydra.main would) before stripping it.
HydraConfig.instance().set_config(config)
# Lift struct mode so keys can be added or removed, then drop the hydra node.
OmegaConf.set_struct(config, False)
del config["hydra"]
config.experiment = experiment_name
# NOTE(review): the next cell replaces `config` with a W&B run config; after that,
# this composed config survives only through the HydraConfig registration above.

### Load config from previous experiment

In [5]:
# Reuse the configuration of a finished W&B run. This replaces the Hydra-composed
# `config` from the previous cell.
# Runs used before: "northern-river-217" (project "dissect-spatial"), "electric-sweep-37".
run_name = "swift-sweep-658"
raw_run_config = get_run_config(run_name, project="DISSECT-src")
config = convert_wandb_to_dict_config(raw_run_config)

Net in config: True


### Set custom values

#### Modify network architecture

In [6]:
# Architecture overrides applied to config.net.encoder_kwargs of the loaded run config.
# Other knobs explored previously, with the values tried:
#   config.net.latent_dim = 64; config.net.use_pos = True;
#   encoder_kwargs: fusion in {"concat", "gating", "concat_skip"}, lin_channel = True,
#   num_heads = 1, num_layers = 1, use_ffn = False, use_pos = False, plain_last = False
encoder_overrides = {"use_sparse": True}
for option, value in encoder_overrides.items():
    setattr(config.net.encoder_kwargs, option, value)

In [7]:
# Show the effective `net` section of the config after the overrides above.
print_config(config.net)

use_id: false
use_pos: false
_target_: src.models.dissect_spatial.DissectSpatial
activation: relu6
latent_dim: 128
encoder_type: gnn
encoder_kwargs:
  knn: 6
  norm: batch_norm
  cosine: false
  fusion: concat
  use_id: false
  dropout: 0
  use_ffn: false
  use_pos: false
  num_heads: 2
  inter_skip: false
  norm_first: false
  num_layers: 1
  plain_last: false
  use_sparse: true
  lin_channel: true
  mha_channel: false
  sim_pos_enc: false
  ff_hidden_dim: 256
  latent_channel: false
  spatial_channel: true
  init_embed_hidden_channels:
  - 512
  - 256
decoder_kwargs:
  num_layers: 2
  hidden_channels: 64



#### Modify training parameters

In [8]:
# config.data.radius — presumably the neighbourhood radius for the spatial graph (TODO confirm).
# Uncomment to override the value loaded from the run config.
# config.data.radius = 0.04
print(config.data.radius)

0.02


In [9]:
# Optional optimisation/loss and trainer overrides; uncomment a line to apply it
# to the loaded run config before training.
# config.model.l2_lambda = 5e-5
# config.model.l1_lambda = 0.0
# config.model.weight_decay = 0.0
# config.model.beta = 5
# config.model.learning_rate = 1e-3
# config.model.alpha_max = 0.1
# config.model.alpha_min = 0.0
# config.trainer.max_steps = 5000
# config.trainer.check_val_every_n_epoch = 500

In [9]:
# Show the `model` section (loss weights, learning rate, etc.) that training will use.
print_config(config.model)
# print_config(config.trainer)

beta: 7.5
_target_: src.models.modules.DeconvolutionModel
plotting: false
alpha_max: 0.1
alpha_min: 0
l1_lambda: 0
l2_lambda: 1.0e-05
log_hparams: false
sim_loss_fn: kl_div
weight_decay: 0
learning_rate: 1.0e-05
save_predictions: false
move_data_to_device: true



## Run training

In [10]:
# Train one model per (spatial dataset, experiment directory) pair.
# Narrow `dataset_selection` (e.g. slice(0, 1)) to run only a subset of the datasets.
dataset_selection = slice(None)

# W&B tag for this batch of runs (earlier batches used "new-experiment" and "ablation-new").
tags = "latestv6"
# "online": full training run.
# Any other mode (e.g. "disabled"): Lightning fast_dev_run smoke test only.
wandb_mode = "online"
config.trainer.fast_dev_run = wandb_mode != "online"
config.trainer.deterministic = True

# zip() would silently drop unpaired entries, so fail loudly on a mismatch instead.
assert len(st_paths) == len(experiment_paths), (
    f"{len(st_paths)} datasets but {len(experiment_paths)} experiment paths"
)

# Metrics of every run, keyed by experiment path. The loop variables below only
# retain the last run, so earlier results would otherwise be lost.
all_metrics = {}
for st_path, experiment_path in list(zip(st_paths, experiment_paths))[dataset_selection]:
    print(f"Running experiment on {st_path} and experiment path {experiment_path}")
    # Give each run its own copy of the config, so in-place edits made while
    # running one dataset cannot carry over into the next.
    # NOTE(review): assumes run_experiment does not rely on mutating the shared `config` — confirm.
    metric_dict, object_dict = run_experiment(
        copy.deepcopy(config),
        st_path,
        experiment_path,
        wandb_mode=wandb_mode,
        tags=tags,
        config_path="../../configs",
        progress_bar=False,
    )
    all_metrics[experiment_path] = metric_dict
# `metric_dict` / `object_dict` still refer to the last run; the inspection cells below use them.


Running experiment on ${paths.data_dir}spatial/simulations_kidney_slideSeq_v2/UMOD-WT.WT-2a_resolution75.h5ad and experiment path ${paths.root_dir}/experiments/experiment_kidney_slideSeq_v2_UMOD-WT.WT-2a_resolution75


Global seed set to 42


Key experiment not in base config.
Using deterministic mode


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdschaub[0m. Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 257 K 
----------------------------------------
257 K     Trainable params
0         Non-trainable params
257 K     Total params
1.031     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='18.812 MB of 19.721 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.9539…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l1_loss,▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇██████████
train/l2_loss,▁▁▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇████████
train/mix_loss,▁▁▂▄▄▅▆▇▇▆▅▇▆▆▇▇▅▇▅▆▇▆▇▇██▆▃▅█▃▇▄▅▃▆█▅▄▄
train/sim_loss,██▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁
train/total_loss,██▇▆▆▅▅▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/mean_ccc,▁▄▆▇▇█████
validation/mean_ccc_,▁▄▆▇▇█████

0,1
epoch,4999.0
train/beta,7.5
train/l1_loss,36160.625
train/l2_loss,963.10974
train/mix_loss,0.00089
train/sim_loss,0.13674
train/total_loss,0.15306
trainer/global_step,4999.0
validation/mean_ccc_,0.6502
validation/mean_corr_,0.69269


Running experiment on ${paths.data_dir}spatial/simulations_kidney_slideSeq_v2/UMOD-KI.KI-4b_resolution105.h5ad and experiment path ${paths.root_dir}/experiments/experiment_kidney_slideSeq_v2_105


Global seed set to 42


Key experiment not in base config.
Using deterministic mode


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 257 K 
----------------------------------------
257 K     Trainable params
0         Non-trainable params
257 K     Total params
1.030     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='10.168 MB of 10.168 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l1_loss,▁▁▂▂▂▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇▇████████
train/l2_loss,▁▁▂▂▂▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████
train/mix_loss,▁▁▃▃▃▄▄▃▅▅▄▆▅▆▃▄▇▆▂▃▇▇▃▂▆▂▇▇▅▇▄▃▆▂▇▃▇█▇▇
train/sim_loss,██▆▆▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
train/total_loss,██▆▆▅▅▅▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/mean_ccc,▁▅▆▇▇█████
validation/mean_ccc_,▁▅▆▇▇█████

0,1
epoch,4999.0
train/beta,7.5
train/l1_loss,39712.16016
train/l2_loss,976.00983
train/mix_loss,0.0016
train/sim_loss,0.11429
train/total_loss,0.13607
trainer/global_step,4999.0
validation/mean_ccc_,0.73468
validation/mean_corr_,0.77152


Running experiment on ${paths.data_dir}spatial/simulations_heart_seqFISH/embryo1_resolution0.11.h5ad and experiment path ${paths.root_dir}/experiments/experiment_heart_seqFISH_embryo1_resolution0.11-new


Global seed set to 42


Key experiment not in base config.
Using deterministic mode


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 258 K 
----------------------------------------
258 K     Trainable params
0         Non-trainable params
258 K     Total params
1.033     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5000` reached.


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l1_loss,█▇▆▆▅▅▅▅▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l2_loss,█▆▅▅▅▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂
train/mix_loss,▁▁▁▁▂▃▃▄▃▄▅▆▄▅▄▆▆▅▄▆▇█▅▄▆█▆▇▆▇▆█▇▄▆▄▄▆▇▃
train/sim_loss,███▇▆▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
train/total_loss,███▇▆▆▆▅▅▅▅▄▄▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁
trainer/global_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
validation/mean_ccc,▁▃▄▅▅▆▇▇██
validation/mean_ccc_,▁▃▄▅▅▆▇▇██

0,1
epoch,4999.0
train/beta,7.5
train/l1_loss,14979.79004
train/l2_loss,786.94421
train/mix_loss,0.00033
train/sim_loss,1.09385
train/total_loss,1.10422
trainer/global_step,4999.0
validation/mean_ccc_,0.45048
validation/mean_corr_,0.67641


Running experiment on ${paths.data_dir}spatial/V1_Mouse_Brain_Sagittal_Anterior.h5ad and experiment path ${paths.root_dir}/experiments/experiment_mouse_st


Global seed set to 42


Key experiment not in base config.
Using deterministic mode


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 258 K 
----------------------------------------
258 K     Trainable params
0         Non-trainable params
258 K     Total params
1.033     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='29.400 MB of 29.400 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l1_loss,▂▁▁▁▂▂▃▃▄▄▄▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██████████
train/l2_loss,▁▁▁▁▂▂▂▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇▇███████
train/mix_loss,▁▁▂▃▄▅▆▇▇█████▇██▇▆▇█▇▆▅▆▇▆▆▆▆▅▆▆▃▄▃▃▄▄▂
train/sim_loss,██▇▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/total_loss,██▇▆▆▅▅▄▄▄▄▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,4999.0
train/beta,7.5
train/l1_loss,45418.29297
train/l2_loss,1015.57056
train/mix_loss,0.00024
train/sim_loss,0.18774
train/total_loss,0.19971
trainer/global_step,4999.0


Running experiment on ${paths.data_dir}spatial/lymph_node/st_lymph.h5ad and experiment path ${paths.root_dir}/experiments/experiment_lymph_node


Global seed set to 42


Key experiment not in base config.
Using deterministic mode


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1)` was configured so 1 batch per epoch will be used.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name | Type           | Params
----------------------------------------
0 | net  | DissectSpatial | 259 K 
----------------------------------------
259 K     Trainable params
0         Non-trainable params
259 K     Total params
1.036     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_steps=5000` reached.


VBox(children=(Label(value='64.215 MB of 64.215 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/beta,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/l1_loss,▁▁▁▁▂▂▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇███████
train/l2_loss,▁▁▁▁▂▂▂▂▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇▇█████
train/mix_loss,▁▁▁▂▂▃▄▅▅▅▆▇█▇▇█▇▇▇▇█▇█████████▇▇█▇▇▇▇▇▇
train/sim_loss,███▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
train/total_loss,███▇▇▆▆▆▅▅▅▅▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███

0,1
epoch,4999.0
train/beta,7.5
train/l1_loss,55700.24219
train/l2_loss,1301.14917
train/mix_loss,0.0002
train/sim_loss,0.64309
train/total_loss,0.6576
trainer/global_step,4999.0


In [68]:
# Full-depth (max_depth=-1) parameter summary of the LightningModule from the most
# recent run of the training loop; `object_dict` holds only the last dataset's objects.
# NOTE(review): the saved output below (4.4 M params, MultiChannelGNNEncoder) does not
# match the 257-259 K params logged by the training cell; it comes from an earlier
# kernel state (see the out-of-order execution count).
summary = ModelSummary(object_dict["model"], max_depth=-1)
print(summary)

   | Name                                          | Type                            | Params
---------------------------------------------------------------------------------------------------
0  | net                                           | DissectSpatial                  | 4.4 M 
1  | net.encoder                                   | MultiChannelGNNEncoder          | 4.4 M 
2  | net.encoder.mlp                               | MLP                             | 4.0 M 
3  | net.encoder.mlp.act                           | ReLU6                           | 0     
4  | net.encoder.mlp.lins                          | ModuleList                      | 4.0 M 
5  | net.encoder.mlp.lins.0                        | Linear                          | 3.8 M 
6  | net.encoder.mlp.lins.1                        | Linear                          | 131 K 
7  | net.encoder.mlp.lins.2                        | Linear                          | 32.9 K
8  | net.encoder.mlp.norms                         | M

In [69]:
# Display the module tree of the network from the last training run.
object_dict["net"]

DissectSpatial(
  (encoder): MultiChannelGNNEncoder(
    (mlp): MLP(-1, 512, 256, 128)
    (layers): ModuleList(
      (0): MultiChannelGNNBlock(
        (spatial_conv): GATv2Conv(128, 128, heads=4)
        (mha): MultiHeadSelfAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=128, out_features=128, bias=True)
        )
        (concat_linear): Linear(in_features=768, out_features=128, bias=True)
        (ffn): FeedForwardBlock(
          (linear1): Linear(in_features=128, out_features=256, bias=True)
          (linear2): Linear(in_features=256, out_features=128, bias=True)
          (dropout1): Dropout(p=0.0, inplace=False)
          (dropout2): Dropout(p=0.0, inplace=False)
          (activation): ReLU6()
        )
        (activation): ReLU6()
        (norm1): BatchNorm(128)
        (norm2): BatchNorm(128)
      )
    )
    (pos_encoder): Linear(in_features=2, out_features=128, bias=True)
    (sim_pos_encoder): Linear(in_features=2, out_features=128, bias=T

In [61]:
# Print the shape of every parameter tensor of the last run's network, in registration order.
net_parameters = object_dict["net"].parameters()
for tensor in net_parameters:
    print(tensor.shape)

torch.Size([512, 7485])
torch.Size([512])
torch.Size([256, 512])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([1, 4, 128])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512, 128])
torch.Size([512])
torch.Size([512, 1])
torch.Size([384, 128])
torch.Size([384])
torch.Size([1, 1, 128])
torch.Size([1, 1, 128])
torch.Size([128, 128])
torch.Size([128])
torch.Size([128, 768])
torch.Size([128])
torch.Size([256, 128])
torch.Size([256])
torch.Size([128, 256])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128])
torch.Size([128, 2])
torch.Size([128])
torch.Size([128, 2])
torch.Size([128])
torch.Size([64, 128])
torch.Size([64])
torch.Size([12, 64])
torch.Size([12])
