In [1]:
import sys

sys.path.append('../')

import os
import random

import lightning as L
import numpy as np
import torch
from chemprop import data, featurizers, models, nn
# from data import ConstrastiveDataModule, ExemplarDataset
from dotenv import load_dotenv
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from pytorch_lightning.utilities import move_data_to_device

import wandb
# from commons.data import load_and_split_gsk_dataset

RANDOM_SEED = 42

def set_seeds(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seeds(RANDOM_SEED)

load_dotenv('.env.secret')
wandb.login(key='cf344975eb80edf6f0d52af80528cc6094234caf')

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/rahul_e_dev/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mrahul-e-dev[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
import pandas as pd
import rdkit.Chem as Chem
from rdkit.rdBase import BlockLogs
from sklearn.model_selection import GroupShuffleSplit
from commons.utils import get_scaffold, standardize


def mol_to_inchi(mol):
    with BlockLogs():
        return Chem.MolToInchi(mol)


def load_and_split_gsk_dataset(path, RANDOM_SEED):
    df = pd.read_csv(path)
    df = df.iloc[:, 1:]
    df.columns = ["smiles", "per_inhibition"]

    # standardize and convert to inchi
    df["mol"] = df["smiles"].map(standardize)
    df = df.dropna(subset=["mol"])
    df["inchi"] = df["mol"].map(mol_to_inchi)
    df = df.groupby(["inchi"]).filter(lambda x: len(x) == 1).reset_index(drop=True)

    clusters, _ = pd.factorize(
        df["mol"]
        .map(Chem.MolToSmiles)  # type: ignore
        .map(get_scaffold)
    )
    clusters = pd.Series(clusters)

    df = df.drop(["smiles", "inchi"], axis=1)

    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED)
    train_idxs, val_test_idxs = next(splitter.split(df, groups=clusters))
    df_train = df.loc[train_idxs].reset_index(drop=True)
    df_val_test = df.loc[val_test_idxs].reset_index(drop=True)
    clusters_val_test = clusters.iloc[val_test_idxs].reset_index(drop=True)

    splitter = GroupShuffleSplit(n_splits=1, random_state=RANDOM_SEED, test_size=0.5)
    val_idxs, test_idxs = next(splitter.split(df_val_test, groups=clusters_val_test))
    df_val = df_val_test.loc[val_idxs].reset_index(drop=True)
    df_test = df_val_test.loc[test_idxs].reset_index(drop=True)

    return df_train, df_val, df_test

In [3]:
df_train, df_val, df_test = load_and_split_gsk_dataset("../GSK_HepG2.csv", RANDOM_SEED)

In [4]:
W = 2

In [5]:
def mine_hard_and_negatives(
    all_embeds: torch.Tensor,
    all_targets: torch.Tensor,
    w,
    n_candidates=31,
    n_hard=16
):
    # print(type(all_embeds))
    device = all_embeds.device
    B = all_embeds.shape[0]
    n_rand = n_candidates - n_hard

    logits = w(all_embeds) @ all_embeds.T

    target_closness_mask = (all_targets.view(-1, 1) - all_targets).abs() <= W
    # assign -inf to any values entries present in target_closness_mask
    logits[target_closness_mask] = float("-inf")

    _, hard_candidate_idxs = torch.topk(
        logits, k=n_hard, dim=1, largest=True, sorted=False
    )

    # generate a probability matrix where we assign 0 probability to entries either present in
    # target_closness_mask or they were selected as hard negatives
    rand_candidates_selection = torch.where(~target_closness_mask, 1.0, float("-inf"))
    row_idxs = torch.arange(B, device=device).unsqueeze(1).expand(-1, n_hard)
    # assign any hard negative candidates with 0 probability
    rand_candidates_selection[row_idxs, hard_candidate_idxs] = float("-inf")
    rand_candidates_selection_prob = rand_candidates_selection.softmax(dim=-1)
    rand_candidate_idxs = torch.multinomial(
        rand_candidates_selection_prob, n_rand, replacement=False
    )

    candidate_idxs = torch.cat([hard_candidate_idxs, rand_candidate_idxs], dim=-1)
    assert candidate_idxs.shape == (logits.shape[0], n_candidates)
    return candidate_idxs

In [6]:
from torch.utils.data import Dataset
from typing import NamedTuple
from chemprop.data.datasets import Datum
from chemprop.data.collate import TrainingBatch, collate_batch
from itertools import chain


class ContrastiveDataPoint(NamedTuple):
    anchor: Datum
    candidates: list[Datum]


class ContrastiveDataset(Dataset):
    def __init__(self, mols, targets, n_candidates=31, n_hard=16):
        self.mols = mols
        self.targets = targets
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

        self.n_candidates = n_candidates
        self.n_hard = n_hard
        self.n_random = n_candidates - n_hard

    def build_init_candidates(self):
        # build initial candidates as random
        selection_mat = torch.ones(len(self.targets), len(self.targets))
        selection_mat.fill_diagonal_(float("-inf"))
        selection_prob = selection_mat.softmax(dim=-1)
        rand_idxs = torch.multinomial(
            selection_prob, self.n_candidates, replacement=False
        )
        self.candidates = rand_idxs.numpy()

    def get_pos_candidate_idx(self, idx):
        anchor_target = self.targets[idx]
        mask = (self.targets - anchor_target).abs() <= W
        valid_idxs = np.argwhere(mask).squeeze()
        N = valid_idxs.shape[0] 
        while True:
            selected_idx = int(valid_idxs[random.randint(0, N - 1)])
            if selected_idx != idx:
                return selected_idx

    def get_datum(self, idx):
        mg = self.featurizer(self.mols[idx])
        target = self.targets[[idx]].to_numpy()
        return data.datasets.Datum(mg, None, None, target, 1.0, None, None)

    def __getitem__(self, idx):
        return ContrastiveDataPoint(
            self.get_datum(idx),
            (
                [self.get_datum(self.get_pos_candidate_idx(idx))] + 
                [self.get_datum(int(c_idx)) for c_idx in self.candidates[idx]]
            )
        )

    def __len__(self):
        return len(self.mols)


class ContrastiveTrainingBatch(NamedTuple):
    anchor: TrainingBatch
    candidates: TrainingBatch
    B: int
    C: int


def collate_contrastive(batch):
    batch_anchors, batch_candidates = zip(*batch)
    B = len(batch)
    C = len(batch_candidates[0])
    batch_anchors = collate_batch(batch_anchors)
    batch_candidates = collate_batch(chain.from_iterable(batch_candidates))
    return ContrastiveTrainingBatch(
        batch_anchors, batch_candidates, B=B, C=C
    )


class SimpleDataPoint(NamedTuple):
    anchor: Datum

class SimpleDataset(Dataset):
    def __init__(self, mols, targets):
        self.mols = mols
        self.targets = targets
        self.featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

    def get_datum(self, idx):
        mg = self.featurizer(self.mols[idx])
        target = self.targets[[idx]].to_numpy()
        return data.datasets.Datum(mg, None, None, target, 1.0, None, None)

    def __getitem__(self, idx):
        return self.get_datum(idx)

    def __len__(self):
        return len(self.mols)
    

# class SimpleBatch(NamedTuple):
#     anchor: TrainingBatch


# def collate_simple(batch):
#     batch_anchors = collate_batch(batch)
#     return SimpleBatch(batch_anchors)

In [7]:
class ConstrastiveDataModule(L.LightningDataModule):
    def __init__(self, df_train, df_val) -> None:
        super().__init__()
        self.df_train = df_train
        self.df_val = df_val
        self.batch_size=16

    def setup(self, stage=None):
        self.train_ds = ContrastiveDataset(
            self.df_train["mol"], 
            self.df_train["per_inhibition"]
        )
        self.val_ds = ContrastiveDataset(
            self.df_val["mol"], 
            self.df_val["per_inhibition"]
        )

        self.train_ds.build_init_candidates()
        self.val_ds.build_init_candidates()

    def train_dataloader(self):
        return DataLoader(
            self.train_ds,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=collate_contrastive,
            num_workers=8,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_ds,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=collate_contrastive,
            num_workers=8,
        )

    def update_train_dataset_neg_candidates(self, candidate_idxs: torch.Tensor):
        assert candidate_idxs.shape[0] == self.df_train.shape[0]
        self.train_ds.candidates = candidate_idxs.cpu().numpy()

    def update_val_dataset_neg_candidates(self, candidate_idxs: torch.Tensor):
        assert candidate_idxs.shape[0] == self.df_val.shape[0], (candidate_idxs.shape, self.df_val.shape)
        self.val_ds.candidates = candidate_idxs.cpu().numpy()

In [8]:
from typing import Any, Iterable
from chemprop.nn import Aggregation, ChempropMetric, MessagePassing, Predictor
from chemprop.nn.transforms import ScaleTransform
import pytorch_lightning as pl


class ContrastiveMPNN(models.MPNN):
    def __init__(
        self,
        message_passing: MessagePassing,
        agg: Aggregation,
        predictor: Predictor,
        batch_norm: bool = False,
        metrics: Iterable[ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 0.0001,
        max_lr: float = 0.001,
        final_lr: float = 0.0001,
        X_d_transform: ScaleTransform | None = None,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )
        self.loss_fn = torch.nn.CrossEntropyLoss()
        self.w = torch.nn.Linear(300, 300)

    # # @override
    # def encoding(self, bmg, V_d=None, X_d=None, i=-1):
    #     return torch.nn.functional.normalize(
    #         super().encoding(bmg, V_d, X_d),
    #         dim=-1
    #     )


    def embed_simple_batch(self, batch: TrainingBatch):
        bmg, V_d, X_d, target, _, _, _ = batch
        Z_anchor = self.encoding(bmg, V_d, X_d)
        return dict(embeds=Z_anchor, targets=target)


    def training_step(self, batch: ContrastiveTrainingBatch, batch_idx):  # type: ignore
        B, C = batch.B, batch.C

        bmg, V_d, X_d, target_anchor, _, _, _ = batch.anchor
        Z_anchor = self.encoding(bmg, V_d, X_d)

        bmg, V_d, X_d, target_candidates, _, _, _ = batch.candidates
        Z_candidates = self.encoding(bmg, V_d, X_d)

        Z_anchor = Z_anchor.view((B, 1, -1))  # (B X 1 X d)
        Z_candidates = Z_candidates.view((B, C, -1))  # (B X d X C)
        # (B X 1 X d)  x  (B X d X C) --> B X 1 X C --> B X C
        logits = (self.w(Z_anchor) @ Z_candidates.transpose(1, 2)).view(B, -1)
        labels = torch.zeros(B).long().to(self.device)
        loss = self.loss_fn(logits, labels)

        self.log("train_loss", loss, batch_size=B, prog_bar=True, on_epoch=True)
        return loss

    def validation_step(self, batch: ContrastiveTrainingBatch, batch_idx):  # type: ignore
        B, C = batch.B, batch.C

        bmg, V_d, X_d, target_anchor, _, _, _ = batch.anchor
        Z_anchor = self.encoding(bmg, V_d, X_d)

        bmg, V_d, X_d, target_candidates, _, _, _ = batch.candidates
        Z_candidates = self.encoding(bmg, V_d, X_d)

        Z_anchor = Z_anchor.view((B, 1, -1))  # (B X 1 X d)
        Z_candidates = Z_candidates.view((B, C, -1))  # (B X d X C)
        # (B X 1 X d)  x  (B X d X C) --> B X 1 X C --> B X C
        logits = (self.w(Z_anchor) @ Z_candidates.transpose(1, 2)).view(B, -1)
        labels = torch.zeros(B).long().to(self.device)
        loss = self.loss_fn(logits, labels)

        self.log("val_loss", loss, batch_size=B)
        return loss

    def get_candidates(self, dl: DataLoader, stage_str: str):
        all_embeds = []
        all_targets = []
        with torch.no_grad():
            for batch in tqdm(dl, desc=f"Mining {stage_str}:", leave=False):
                batch = move_data_to_device(batch, "cuda")
                res = self.trainer.model.embed_simple_batch(batch)  # type: ignore

                all_embeds.append(res["embeds"])
                all_targets.append(res["targets"])

        all_embeds = torch.cat(all_embeds)
        all_targets = torch.cat(all_targets)
        return mine_hard_and_negatives(all_embeds.squeeze(), all_targets.squeeze(), self.trainer.model.w)

    def on_train_epoch_start(self):
        ds = SimpleDataset(
            self.trainer.datamodule.df_train["mol"],  # type: ignore
            self.trainer.datamodule.df_train["per_inhibition"],  # type: ignore
        )
        dl = DataLoader(
            ds,
            collate_fn=collate_batch,
            batch_size=self.trainer.datamodule.batch_size, # type: ignore
            shuffle=False,
        )
        neg_candidate_idxs = self.get_candidates(dl, "Train")
        self.trainer.datamodule.update_train_dataset_neg_candidates(  # type: ignore
            neg_candidate_idxs.to("cpu")
        )

    def on_validation_epoch_start(self):
        ds = SimpleDataset(
            self.trainer.datamodule.df_val["mol"],  # type: ignore
            self.trainer.datamodule.df_val["per_inhibition"],  # type: ignore
        )
        dl = DataLoader(
            ds,
            collate_fn=collate_batch,
            batch_size=self.trainer.datamodule.batch_size, # type: ignore
            shuffle=False,
        )
        neg_candidate_idxs = self.get_candidates(dl, "Val")
        self.trainer.datamodule.update_val_dataset_neg_candidates(  # type: ignore
            neg_candidate_idxs.to("cpu")
        )

In [9]:
fdims = featurizers.SimpleMoleculeMolGraphFeaturizer().shape # the dimensions of the featurizer, given as (atom_dims, bond_dims).
mp = nn.BondMessagePassing()
agg = nn.NormAggregation()
ffn = nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = True
metric_list = [nn.metrics.BinaryF1Score(), nn.metrics.BinaryAUPRC(), nn.metrics.BinaryAUROC()]
contrastive_mpnn = ContrastiveMPNN(mp, agg, ffn, batch_norm, metric_list)
# contrastive_mpnn.max_lr = 0.01

In [10]:
# datamodule=ConstrastiveDataModule(df_train, df_val)
# datamodule.setup()
# dl = datamodule.train_dataloader()
# batch = next(iter(dl))

# B, C = batch.B, batch.C
# bmg, V_d, X_d, target_anchor, _, _, _ = batch.anchor
# bmg, V_d, X_d, target_candidates, _, _, _ = batch.candidates

# (target_anchor - target_candidates.view(B, C)).abs().softmax(dim=-1)

In [11]:
wandb.finish()
wandb_logger = WandbLogger(project="chemprop_delta_clf", log_model="all", save_code=True)
wandb_logger.experiment.mark_preempting()

trainer = L.Trainer(
    logger=wandb_logger,
    enable_checkpointing=True,  # Use `True` if you want to save model checkpoints. The checkpoints will be saved in the `checkpoints` folder.
    enable_progress_bar=True,
    accelerator="auto",
    devices=1,
    max_epochs=50,  # number of epochs to train for
    reload_dataloaders_every_n_epochs=1,
    log_every_n_steps=50,
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min", verbose=True, patience=10),
        ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=2)
    ]
)

trainer.fit(contrastive_mpnn, datamodule=ConstrastiveDataModule(df_train, df_val))

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loading `train_dataloader` to estimate number of stepping batches.

  | Name            | Type                    | Params | Mode 
--------------------------------------------------------------------
0 | message_passing | BondMessagePassing      | 227 K  | train
1 | agg             | NormAggregation         | 0      | train
2 | bn              | BatchNorm1d             | 600    | train
3 | predictor       | BinaryClassificationFFN | 90.6 K | train
4 | X_d_transform   | Identity                | 0      | train
5 | metrics         | ModuleList              | 0      | train
6 | loss_fn         | CrossEntropyLoss        | 0      | train
7 | w               | Linear                  | 90.3 K | train
--------------------------------------------------------------------
409 K     Trainable params
0         Non-trainable params
409 K     T

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved. New best score: 3.702


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.412 >= min_delta = 0.0. New best score: 3.291


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.034 >= min_delta = 0.0. New best score: 3.257


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.069 >= min_delta = 0.0. New best score: 3.188


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.023 >= min_delta = 0.0. New best score: 3.165


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 3.164


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Metric val_loss improved by 0.114 >= min_delta = 0.0. New best score: 3.050


Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Mining Train::   0%|          | 0/648 [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Mining Val::   0%|          | 0/97 [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 10 records. Best score: 3.050. Signaling Trainer to stop.


In [12]:
model = ContrastiveMPNN.load_from_checkpoint(trainer.callbacks[-1].best_model_path)

In [13]:
def get_embeds_targets(df_split, contrastive_mpnn):
    ds = SimpleDataset(
        df_split["mol"],  # type: ignore
        df_split["per_inhibition"],  # type: ignore
    )
    dl = DataLoader(
        ds,
        collate_fn=collate_batch,
        batch_size=32, # type: ignore
        shuffle=False,
    )

    all_embeds = []
    all_targets = []
    with torch.no_grad():
        for batch in tqdm(dl, leave=False):
            batch = move_data_to_device(batch, "cpu")
            res = contrastive_mpnn.embed_simple_batch(batch)  # type: ignore

            all_embeds.append(res["embeds"])
            all_targets.append(res["targets"])

    all_embeds = torch.cat(all_embeds)
    all_targets = torch.cat(all_targets)

    return all_embeds, all_targets

In [14]:
model.device

device(type='cuda', index=0)

In [15]:
train_embeds, train_targets = get_embeds_targets(df_train, model.to("cpu"))

  0%|          | 0/324 [00:00<?, ?it/s]

In [20]:
asd = model.w(train_embeds) @ train_embeds.T

In [21]:
asd.topk(3, dim=-1).indices


tensor([[7489, 5519, 7061],
        [7489, 5378, 5519],
        [5519, 7489, 5378],
        ...,
        [5519, 6870, 6939],
        [7489, 5378, 5519],
        [5378, 7489, 5519]])

In [22]:
asd

tensor([[ 1.1421,  0.5198, -0.5309,  ..., -0.9022,  0.9609, -0.6830],
        [ 0.9998,  0.4200, -0.7123,  ..., -0.9911,  0.8062, -0.6943],
        [ 0.9027,  0.2393, -0.8529,  ..., -0.7719,  0.6768, -0.7497],
        ...,
        [ 0.8925,  0.0315, -0.8769,  ..., -0.4103,  0.5890, -0.9145],
        [ 1.1211,  0.5619, -0.5877,  ..., -1.0418,  0.9599, -0.6483],
        [ 0.9193,  0.3153, -0.8252,  ..., -1.0486,  0.7333, -0.6512]],
       grad_fn=<MmBackward0>)

In [19]:
wandb.finish()

[34m[1mwandb[0m: [32m[41mERROR[0m The nbformat package was not found. It is required to save notebook history.


0,1
epoch,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▆▆▆▆▆▆▇▇▇▇▇█
train_loss_epoch,█▅▅▄▂▃▃▄▁▃▂▂▁▃▁▂▁▁▂▁▂▁▁
train_loss_step,▆█▇▄▄▃▅▃▃▄▅▃▆▄▇▅▅▃▄▂▂▁▂▁▃▂▂▄▅▅▃▃▁▄▂▁▁▄▁▄
trainer/global_step,▁▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
val_loss,█▄▃▅▂▅▄▃▂▄▂▃▁▂▃▃▁▃▃▃▄▂▁

0,1
epoch,22.0
train_loss_epoch,3.07014
train_loss_step,3.14901
trainer/global_step,14903.0
val_loss,3.09512
