# Requirements

## Download Depedencies

In [4]:
import json
!pip install wandb
!pip install kaggle
!pip install sentencepiece
!pip install transformers
!pip install pytorch_lightning
!pip install madgrad
!git clone https://github.com/bloodwass/mixout
!git config --global user.email "simonmeoni@aol.com"
!git config --global user.name "Simon Meoni"
%run mixout/mixout.py
%run mixout/module.py

!chmod 600 ~/.kaggle/kaggle.json
!kaggle competitions download -c commonlitreadabilityprize
!kaggle datasets download simonmeoni/litbank
!kaggle datasets download simonmeoni/paraphrase-clrp
!unzip commonlitreadabilityprize.zip
!unzip paraphrase-clrp.zip


fatal: destination path 'mixout' already exists and is not an empty directory.
Downloading commonlitreadabilityprize.zip to /home/simon/PycharmProjects/clrp
 88%|█████████████████████████████████▌    | 1.00M/1.13M [00:00<00:00, 6.01MB/s]
100%|██████████████████████████████████████| 1.13M/1.13M [00:00<00:00, 6.11MB/s]
Downloading litbank.zip to /home/simon/PycharmProjects/clrp
 99%|█████████████████████████████████████▊| 29.0M/29.2M [00:03<00:00, 10.7MB/s]
100%|██████████████████████████████████████| 29.2M/29.2M [00:03<00:00, 8.96MB/s]
Downloading paraphrase-clrp.zip to /home/simon/PycharmProjects/clrp
 88%|█████████████████████████████████▍    | 1.00M/1.14M [00:00<00:00, 5.14MB/s]
100%|██████████████████████████████████████| 1.14M/1.14M [00:00<00:00, 5.63MB/s]
Archive:  commonlitreadabilityprize.zip
  inflating: sample_submission.csv   
  inflating: test.csv                
  inflating: train.csv               
Archive:  paraphrase-clrp.zip
  inflating: paraphrase.csv   

## Imports

In [1]:
%run mixout/mixout.py
%run mixout/module.py

import gc
import math

import madgrad
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import StratifiedKFold
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoTokenizer,
    AlbertModel,
    AutoModel,
    Trainer,
)

import wandb

## Configurations

In [None]:
SWEEP_CONFIG_PATH = "sweeps/electra-base.json"
DATASET_PATH = "train.csv"

# Data

## Dataset

In [2]:
class CommonLitReadabilityDataset(Dataset):
    def __init__(self, dataset, config):
        self.dataset = dataset
        self.tokenizer = AutoTokenizer.from_pretrained(
            config["global_tokenizer"], use_fast=True
        )
        self.tokenizer_config = config["global_tokenizer_config"]

    def __getitem__(self, index):
        example = self.dataset.loc[index]
        return {
            "x": example["excerpt"],
            "y": example["target"],
            "standard_error": example["standard_error"],
        }

    def __len__(self):
        return len(self.dataset)

    def collate_fn(self, batch):
        merge_dict = {key: [d[key] for d in batch] for key in batch[0].keys()}
        merge_dict["tokens"] = self.tokenizer(merge_dict["x"], **self.tokenizer_config)
        return merge_dict


## Datamodule

In [3]:
class CommonLitReadabilityDataModule(pl.LightningDataModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.fold = self.config["train_fold"]
        self.seed = self.config["train_seed"]

    def setup(self, stage=None):
        dataset = pd.read_csv("~/train.csv")
        dataset["excerpt"] = dataset["excerpt"].apply(lambda x: x.replace("\n", ""))
        num_bins = int(np.floor(1 + np.log2(len(dataset))))
        dataset.loc[:, "bins"] = pd.cut(dataset["target"], bins=num_bins, labels=False)
        bins = dataset.bins.to_numpy()
        s_kf = StratifiedKFold(
            n_splits=self.config["train_k-fold"], shuffle=True, random_state=self.seed
        )
        fold = list(s_kf.split(X=dataset, y=bins))[self.fold]
        # augmented_dataset = pd.read_csv('/content/paraphrase.csv').iloc[fold[0]].reset_index() 
        self.train = CommonLitReadabilityDataset(pd.concat([dataset.iloc[fold[0]].reset_index()]).reset_index(), self.config)
        self.val = CommonLitReadabilityDataset(dataset.iloc[fold[1]].reset_index(), self.config)

    def train_dataloader(self):
        return DataLoader(
            self.train,
            batch_size= self.config["train_batch_size"],
            collate_fn=self.train.collate_fn,
            num_workers=2,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val, batch_size=16, collate_fn=self.val.collate_fn, num_workers=2
        )


# Model

In [4]:
class CommonLitReadabilityModel(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformers = AutoModel.from_pretrained(config["global_model"])
        self.linear = nn.Linear(self.transformers.config.hidden_size, 1)
        self.loss = nn.MSELoss()
        self.lr = config["optim_lr"]
        self.best_val_loss = 9999.0
        if config["global_reinit_pool_layer"]:
            print("Reinitializing Pooler Layer ...")
            self.transformers.pooler.dense.weight.data.normal_(
                mean=0.0, std=self.transformers.config.initializer_range
            )
            self.transformers.pooler.dense.bias.data.zero_()
            for p in self.transformers.pooler.parameters():
                p.requires_grad = True

        reinit_layers = config["global_reinit_last_n_layer"]
        if reinit_layers > 0:
            print(f"Reinitializing Last {reinit_layers} Layers ...")
            encoder_temp = self.transformers
            for layer in encoder_temp.encoder.layer[-reinit_layers:]:
                for module in layer.modules():
                    if isinstance(module, nn.Linear):
                        module.weight.data.normal_(
                            mean=0.0, std=self.transformers.config.initializer_range
                        )
                        if module.bias is not None:
                            module.bias.data.zero_()
                    elif isinstance(module, nn.Embedding):
                        module.weight.data.normal_(
                            mean=0.0, std=self.transformers.config.initializer_range
                        )
                        if module.padding_idx is not None:
                            module.weight.data[module.padding_idx].zero_()
                    elif isinstance(module, nn.LayerNorm):
                        module.bias.data.zero_()
                        module.weight.data.fill_(1.0)

        if config["train_mixout"] > 0:
            print("Initializing Mixout Regularization")
            for sup_module in self.modules():
                for name, module in sup_module.named_children():
                    if isinstance(module, nn.Dropout):
                        module.p = 0.0
                    if isinstance(module, nn.Linear):
                        target_state_dict = module.state_dict()
                        bias = True if module.bias is not None else False
                        new_module = MixLinear(
                            module.in_features,
                            module.out_features,
                            bias,
                            target_state_dict["weight"],
                            config["train_mixout"],
                        )
                        new_module.load_state_dict(target_state_dict)
                        setattr(sup_module, name, new_module)
            print("Done !")

    def predict(self, batch):
        transformers_output = self.transformers(**batch["tokens"].to(self.device))
        output_state = transformers_output[0][:, 0, :]
        return self.linear(F.dropout(output_state, 0.2)).squeeze(1)

    def training_step(self, batch, batch_idx):
        y_hat = self.predict(batch)
        loss = torch.sqrt(self.loss(y_hat, torch.Tensor(batch["y"]).to(self.device)))
        self.log("train/loss", loss, on_epoch=True, on_step=True)
        return loss

    def validation_step(self, batch, batch_idx):
        y_hat = self.predict(batch)
        loss = torch.sqrt(self.loss(y_hat, torch.Tensor(batch["y"]).to(self.device)))
        self.log("val/loss", loss, on_epoch=True)
        return loss

    def validation_epoch_end(self, validation_step_outputs):
        cpu_loss = self.trainer.logged_metrics["val/loss"].cpu().item()
        if self.best_val_loss > cpu_loss:
            self.best_val_loss = cpu_loss
        self.log("val/best_loss", self.best_val_loss, on_epoch=True)

    def get_optimizer_grouped_parameters(
        self, model, learning_rate, weight_decay=0.01, layerwise_learning_rate_decay=0.7
    ):

        no_decay = ["bias", "LayerNorm.weight"]
        # initialize lr for task specific layer
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in self.linear.named_parameters()],
                "weight_decay": weight_decay,
                "lr": learning_rate,
            },
        ]
        # initialize lrs for every layer
        num_layers = model.config.num_hidden_layers
        layers = [model.embeddings] + list(model.encoder.layer)
        layers.reverse()
        lr = learning_rate
        optimizer_grouped_parameters = []
        for layer in layers:
            lr *= layerwise_learning_rate_decay
            optimizer_grouped_parameters += [
                {
                    "params": [
                        p
                        for n, p in layer.named_parameters()
                        if not any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": weight_decay,
                    "lr": lr,
                },
                {
                    "params": [
                        p
                        for n, p in layer.named_parameters()
                        if any(nd in n for nd in no_decay)
                    ],
                    "weight_decay": 0.0,
                    "lr": lr,
                },
            ]
        return optimizer_grouped_parameters

    def configure_optimizers(self):
        params = list(self.named_parameters())
        if not isinstance(self.transformers, AlbertModel):
            grouped_parameters = self.get_optimizer_grouped_parameters(
                self.transformers,
                self.lr,
                weight_decay=self.config["optim_weight_decay"],
                layerwise_learning_rate_decay=self.config[
                    "optim_layerwise_learning_rate_decay"
                ],
            )
            optimizer = madgrad.MADGRAD(
                grouped_parameters,
                lr=self.lr,
            )
            num_batches = math.ceil(
                len(self.train_dataloader()) / self.trainer.accumulate_grad_batches
            )
            if self.config["optim_cosine_lr"]:
                return {
                    "lr_scheduler": {
                        "scheduler": CosineAnnealingLR(optimizer, num_batches)
                    },
                    "optimizer": optimizer,
                }
            else:
                return optimizer
        else:
            optimizer = madgrad.MADGRAD(
                self.parameters(),
                lr=self.lr,
                weight_decay=self.config["optim_weight_decay"],
            )
            num_batches = math.ceil(
                len(self.train_dataloader()) / self.trainer.accumulate_grad_batches
            )
            if self.config["optim_cosine_lr"]:
                return {
                    "lr_scheduler": {
                        "scheduler": CosineAnnealingLR(optimizer, num_batches)
                    },
                    "optimizer": optimizer,
                }
            else:
                return optimizer


# Training


## Training Sweep

In [6]:
sweep_id = wandb.sweep(json.load(open(SWEEP_CONFIG_PATH)), project="CommonLit Readability")
def sweep_iteration():
  wandb.init()
  config = wandb.config

  pl.seed_everything(config['train_seed'])
  wandb_logger = WandbLogger()
  datamodule = CommonLitReadabilityDataModule(config)
  model = CommonLitReadabilityModel(config)
  checkpoint_callback = ModelCheckpoint(monitor='val/loss',
                                  save_top_k=1,
                                  save_last=False,
                                  save_weights_only=True,
                                  filename='checkpoint/{epoch:02d}-{val/loss:.4f}',
                                  verbose=True,
                                  mode='min',
                                  every_n_val_epochs=1)
  trainer = Trainer(logger=wandb_logger, 
                  max_epochs=config['train_max_epochs'], 
                  # gpus=1,
                  fast_dev_run=True,
                  callbacks=[checkpoint_callback],
                  resume_from_checkpoint=config['global_checkpoint'],
                  accumulate_grad_batches=config['train_accumulate_grad_batches'],
                  val_check_interval=config['val_val_check_interval'])
  trainer.fit(model, datamodule) 
  wandb.finish()
  !rm -r /content/wandb
  del model
  gc.collect()
  torch.cuda.empty_cache()
wandb.agent(sweep_id, function=sweep_iteration)

[34m[1mwandb[0m: Agent Starting Run: fzhtprr0 with config:
[34m[1mwandb[0m: 	global_checkpoint: None
[34m[1mwandb[0m: 	global_model: smeoni/electra-base-discriminator-clrp
[34m[1mwandb[0m: 	global_reinit_last_n_layer: 3
[34m[1mwandb[0m: 	global_reinit_pool_layer: False
[34m[1mwandb[0m: 	global_tokenizer: google/electra-base-discriminator
[34m[1mwandb[0m: 	global_tokenizer_config: {'max_length': 256, 'padding': True, 'return_tensors': 'pt', 'truncation': True}
[34m[1mwandb[0m: 	optim_cosine_lr: True
[34m[1mwandb[0m: 	optim_layerwise_learning_rate_decay: 0.6
[34m[1mwandb[0m: 	optim_lr: 0.0001
[34m[1mwandb[0m: 	optim_weight_decay: 0
[34m[1mwandb[0m: 	train_accumulate_grad_batches: 1
[34m[1mwandb[0m: 	train_batch_size: 24
[34m[1mwandb[0m: 	train_fold: 0
[34m[1mwandb[0m: 	train_k-fold: 5
[34m[1mwandb[0m: 	train_max_epochs: 8
[34m[1mwandb[0m: 	train_mixout: 0.6
[34m[1mwandb[0m: 	train_seed: 881
[34m[1mwandb[0m: 	val_val_check_interval:

Global seed set to 881


Downloading:   0%|          | 0.00/714 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/438M [00:00<?, ?B/s]

Some weights of the model checkpoint at smeoni/electra-base-discriminator-clrp were not used when initializing ElectraModel: ['generator_predictions.dense.bias', 'generator_lm_head.weight', 'generator_lm_head.bias', 'generator_predictions.dense.weight', 'generator_predictions.LayerNorm.weight', 'generator_predictions.LayerNorm.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reinitializing Last 3 Layers ...
Initializing Mixout Regularization


  rank_zero_deprecation(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


Done !


Downloading:   0%|          | 0.00/27.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/666 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/466k [00:00<?, ?B/s]


  | Name         | Type         | Params
----------------------------------------------
0 | transformers | ElectraModel | 108 M 
1 | linear       | MixLinear    | 769   
2 | loss         | MSELoss      | 0     
----------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.570   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

rm: cannot remove '/content/wandb': No such file or directory


[34m[1mwandb[0m: Agent Starting Run: kcdms5ai with config:
[34m[1mwandb[0m: 	global_checkpoint: None
[34m[1mwandb[0m: 	global_model: smeoni/electra-base-discriminator-clrp
[34m[1mwandb[0m: 	global_reinit_last_n_layer: 3
[34m[1mwandb[0m: 	global_reinit_pool_layer: False
[34m[1mwandb[0m: 	global_tokenizer: google/electra-base-discriminator
[34m[1mwandb[0m: 	global_tokenizer_config: {'max_length': 256, 'padding': True, 'return_tensors': 'pt', 'truncation': True}
[34m[1mwandb[0m: 	optim_cosine_lr: True
[34m[1mwandb[0m: 	optim_layerwise_learning_rate_decay: 0.6
[34m[1mwandb[0m: 	optim_lr: 0.0001
[34m[1mwandb[0m: 	optim_weight_decay: 0
[34m[1mwandb[0m: 	train_accumulate_grad_batches: 1
[34m[1mwandb[0m: 	train_batch_size: 24
[34m[1mwandb[0m: 	train_fold: 1
[34m[1mwandb[0m: 	train_k-fold: 5
[34m[1mwandb[0m: 	train_max_epochs: 8
[34m[1mwandb[0m: 	train_mixout: 0.6
[34m[1mwandb[0m: 	train_seed: 881
[34m[1mwandb[0m: 	val_val_check_interval:

Global seed set to 881
Some weights of the model checkpoint at smeoni/electra-base-discriminator-clrp were not used when initializing ElectraModel: ['generator_predictions.dense.bias', 'generator_lm_head.weight', 'generator_lm_head.bias', 'generator_predictions.dense.weight', 'generator_predictions.LayerNorm.weight', 'generator_predictions.LayerNorm.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reinitializing Last 3 Layers ...
Initializing Mixout Regularization


  rank_zero_deprecation(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


Done !



  | Name         | Type         | Params
----------------------------------------------
0 | transformers | ElectraModel | 108 M 
1 | linear       | MixLinear    | 769   
2 | loss         | MSELoss      | 0     
----------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.570   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

rm: cannot remove '/content/wandb': No such file or directory


[34m[1mwandb[0m: Agent Starting Run: a5gtgjpa with config:
[34m[1mwandb[0m: 	global_checkpoint: None
[34m[1mwandb[0m: 	global_model: smeoni/electra-base-discriminator-clrp
[34m[1mwandb[0m: 	global_reinit_last_n_layer: 3
[34m[1mwandb[0m: 	global_reinit_pool_layer: False
[34m[1mwandb[0m: 	global_tokenizer: google/electra-base-discriminator
[34m[1mwandb[0m: 	global_tokenizer_config: {'max_length': 256, 'padding': True, 'return_tensors': 'pt', 'truncation': True}
[34m[1mwandb[0m: 	optim_cosine_lr: True
[34m[1mwandb[0m: 	optim_layerwise_learning_rate_decay: 0.6
[34m[1mwandb[0m: 	optim_lr: 5e-05
[34m[1mwandb[0m: 	optim_weight_decay: 0
[34m[1mwandb[0m: 	train_accumulate_grad_batches: 1
[34m[1mwandb[0m: 	train_batch_size: 24
[34m[1mwandb[0m: 	train_fold: 1
[34m[1mwandb[0m: 	train_k-fold: 5
[34m[1mwandb[0m: 	train_max_epochs: 8
[34m[1mwandb[0m: 	train_mixout: 0.4
[34m[1mwandb[0m: 	train_seed: 881
[34m[1mwandb[0m: 	val_val_check_interval: 

Global seed set to 881
Some weights of the model checkpoint at smeoni/electra-base-discriminator-clrp were not used when initializing ElectraModel: ['generator_predictions.dense.bias', 'generator_lm_head.weight', 'generator_lm_head.bias', 'generator_predictions.dense.weight', 'generator_predictions.LayerNorm.weight', 'generator_predictions.LayerNorm.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reinitializing Last 3 Layers ...
Initializing Mixout Regularization


  rank_zero_deprecation(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


Done !



  | Name         | Type         | Params
----------------------------------------------
0 | transformers | ElectraModel | 108 M 
1 | linear       | MixLinear    | 769   
2 | loss         | MSELoss      | 0     
----------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.570   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

rm: cannot remove '/content/wandb': No such file or directory


[34m[1mwandb[0m: Agent Starting Run: 9zknvq6j with config:
[34m[1mwandb[0m: 	global_checkpoint: None
[34m[1mwandb[0m: 	global_model: smeoni/electra-base-discriminator-clrp
[34m[1mwandb[0m: 	global_reinit_last_n_layer: 3
[34m[1mwandb[0m: 	global_reinit_pool_layer: False
[34m[1mwandb[0m: 	global_tokenizer: google/electra-base-discriminator
[34m[1mwandb[0m: 	global_tokenizer_config: {'max_length': 256, 'padding': True, 'return_tensors': 'pt', 'truncation': True}
[34m[1mwandb[0m: 	optim_cosine_lr: True
[34m[1mwandb[0m: 	optim_layerwise_learning_rate_decay: 0.6
[34m[1mwandb[0m: 	optim_lr: 0.0001
[34m[1mwandb[0m: 	optim_weight_decay: 0
[34m[1mwandb[0m: 	train_accumulate_grad_batches: 1
[34m[1mwandb[0m: 	train_batch_size: 24
[34m[1mwandb[0m: 	train_fold: 0
[34m[1mwandb[0m: 	train_k-fold: 5
[34m[1mwandb[0m: 	train_max_epochs: 8
[34m[1mwandb[0m: 	train_mixout: 0.4
[34m[1mwandb[0m: 	train_seed: 881
[34m[1mwandb[0m: 	val_val_check_interval:

Global seed set to 881
Some weights of the model checkpoint at smeoni/electra-base-discriminator-clrp were not used when initializing ElectraModel: ['generator_predictions.dense.bias', 'generator_lm_head.weight', 'generator_lm_head.bias', 'generator_predictions.dense.weight', 'generator_predictions.LayerNorm.weight', 'generator_predictions.LayerNorm.bias']
- This IS expected if you are initializing ElectraModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ElectraModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reinitializing Last 3 Layers ...
Initializing Mixout Regularization


  rank_zero_deprecation(
GPU available: True, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
  rank_zero_warn(
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 1 batch(es).


Done !



  | Name         | Type         | Params
----------------------------------------------
0 | transformers | ElectraModel | 108 M 
1 | linear       | MixLinear    | 769   
2 | loss         | MSELoss      | 0     
----------------------------------------------
108 M     Trainable params
0         Non-trainable params
108 M     Total params
435.570   Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: -1it [00:00, ?it/s]

[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.
