In [1]:
import argparse

import lightning
import pandas as pd
import torch
from datasets import Dataset, Features, Sequence, Value
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from scienceworld import ScienceWorldEnv
from torch.utils.data import DataLoader

from sources.cl_nli.model import SimCSE
from sources.fallback_policy.encoder import HFEncoderModel, CustomSimCSEModel
from sources.fallback_policy.model import ContrastiveQNetwork
from sources.scienceworld.utils import parse_beliefs, parse_goal

lightning.seed_everything(42)

Seed set to 42


42

In [4]:
model_name = 'princeton-nlp/sup-simcse-roberta-base'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
encoder_model = HFEncoderModel(model_name, device='cpu')

# 2 blocks + original actions + 2 last actions
#checkpoint_file = "../../checkpoints/sup/version_21/epoch=37-step=190-train_loss_epoch=0.803.ckpt"

# all variations
#checkpoint_file = "../../checkpoints/sup_all/version_0/epoch=15-step=4640-train_loss_epoch=0.806.ckpt"

# 2 blocks
checkpoint_file = "../../checkpoints/sup_all/version_0/epoch=20-step=6090-train_loss_epoch=0.790.ckpt"

# 4 blocks
#checkpoint_file = "../../checkpoints/sup_all/version_1/epoch=34-step=10150-train_loss_epoch=0.762.ckpt"

#checkpoint_file = "../../checkpoints/sup_all/version_4/epoch=29-step=4350-train_loss_epoch=1.256.ckpt"
#checkpoint_file = "../../checkpoints/sup_all/version_4/epoch=35-step=5220-train_loss_epoch=1.252.ckpt"
checkpoint_file = "../../checkpoints/sup_all/version_10/epoch=46-step=6815-train_loss_epoch=1.226.ckpt"

model = ContrastiveQNetwork.load_from_checkpoint(checkpoint_file, encoder_model=encoder_model)
model = model.to('cpu').eval()

In [7]:
belief_base_a = ['you see a door', "you see a cupboard, the cupboard is closed", "your goal is to boil water", "you focused on water"]
belief_base_b = ['you see a door', "you see a a door, the door is open", "your goal is to boil gallium", "you focused on water"]


def encode_belief_base(belief_base: list[str]):
    belief_base_emb = model.encoder_model.encode_batch(belief_base,
                                                       max_size=len(belief_base),
                                                       include_cls=True)
    encoded_belief_base, attention = model.belief_base_encoder(belief_base_emb, [len(belief_base) + 1], )
    return encoded_belief_base, attention


emb_a, a = encode_belief_base(belief_base_a)
emb_b, b = encode_belief_base(belief_base_b)

sim = torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=1)
sim.item()

0.8149359822273254

In [17]:
simcse = SimCSE.load_from_checkpoint('/opt/models/simcse_default/version_0/v0-epoch=4-step=18304-val_nli_loss=0.658-train_loss=0.551.ckpt')

encoder = CustomSimCSEModel(simcse)

model = ContrastiveQNetwork.load_from_checkpoint(checkpoint_file, encoder_model=encoder)


belief_base_a = ["you see a door to the workshop", "you see a cupboard, the cupboard is closed"]
belief_base_b = ["you see a door to the kitchen", "you see a cupboard, the cupboard is open"]


def encode_belief_base(belief_base: list[str]):
    belief_base_emb = model.encoder_model.encode_batch(belief_base,
                                                       max_size=len(belief_base)+1,
                                                       include_cls=True)
    encoded_belief_base, attention = model.belief_base_encoder(belief_base_emb, [len(belief_base)+1], )
    return encoded_belief_base, attention


emb_a, a = encode_belief_base(belief_base_a)
emb_b, b = encode_belief_base(belief_base_b)

sim = torch.nn.functional.cosine_similarity(emb_a, emb_b, dim=1)
sim

Some weights of RobertaModel were not initialized from the model checkpoint at FacebookAI/roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([0.3850], device='cuda:0', grad_fn=<SumBackward1>)

In [8]:
emb_a = encoder_model.encode(belief_base_a)
emb_a.size()

torch.Size([1, 4, 768])