In [1]:
import torch
import math
import numpy as np
import random
import sympy as sp
from sympy import sympify, lambdify, symbols, integrate, Interval, Symbol, I, S, oo, plot
from IPython.display import display


TRAIN_SIZE = 1000000
VALID_SIZE = 10000
SEQ_LEN = 1024
EMBED_DIM = 384


def generate_sample(f, min_x=-4, max_x=4):
    increment = (max_x-min_x)/SEQ_LEN/EMBED_DIM
    x, t = symbols(['x','t'])
    fl = lambdify((x), f, "numpy")
    xs = np.arange(min_x, max_x, increment)
    ys = fl(xs)
    if np.isnan(ys).any() or np.isinf(ys).any():
        print("Error! NaN or Inf found!")
        ys = np.zeros(ys.shape())
    return xs, ys

def sample2seq(raw_sample, seq_len, embed_dim, seq_first=True):
    #seq = torch.zeros(seq_len, embed_dim)
    attention_mask = torch.ones(seq_len) # May change to only where values are present
    if seq_first:
        tmp = np.reshape(raw_sample, (seq_len, embed_dim)).astype(np.float32)
        seq = torch.from_numpy(tmp)
    else:
        tmp = np.transpose(np.reshape(raw_sample, (embed_dim, seq_len))).astype(np.float32)
    return {"seq":seq, "attention_mask": attention_mask}


f = sympify("x**2")
xs, ys = generate_sample(f)
print(xs)

sample2seq(xs, SEQ_LEN, EMBED_DIM)

[-4.         -3.99997965 -3.99995931 ...  3.99993896  3.99995931
  3.99997966]


{'seq': tensor([[-4.0000, -4.0000, -4.0000,  ..., -3.9922, -3.9922, -3.9922],
         [-3.9922, -3.9922, -3.9921,  ..., -3.9844, -3.9844, -3.9844],
         [-3.9844, -3.9844, -3.9843,  ..., -3.9766, -3.9766, -3.9766],
         ...,
         [ 3.9766,  3.9766,  3.9766,  ...,  3.9843,  3.9843,  3.9844],
         [ 3.9844,  3.9844,  3.9844,  ...,  3.9921,  3.9921,  3.9922],
         [ 3.9922,  3.9922,  3.9922,  ...,  3.9999,  4.0000,  4.0000]]),
 'attention_mask': tensor([1., 1., 1.,  ..., 1., 1., 1.])}

In [2]:
import json


def remove_constants(f):
    t = Symbol('t')
    return f.as_independent(t)[1]

fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_polynomial_integral_results.json", "r")
lines = fin.readlines()
print(len(lines), "lines loaded")
fin.close()
fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_randomized_polynomial_integral_results.json", "r")
lines.extend(fin.readlines())
print(len(lines), "lines loaded")
fin.close()
fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_randomized_nonpoly_integral_results_corrected.json", "r")
lines.extend(fin.readlines())
print(len(lines), "lines loaded")
fin.close()

MAX_POWER = 6
MAX_AVG_DIFF = 0.01

originals = []

for line in lines:
    result = json.loads(line)
    original = result["original"]
    originals.append(original)
    if len(originals) % 1000 == 0:
        print(len(originals), "cases loaded")

9024 lines loaded
57512 lines loaded
159269 lines loaded
1000 cases loaded
2000 cases loaded
3000 cases loaded
4000 cases loaded
5000 cases loaded
6000 cases loaded
7000 cases loaded
8000 cases loaded
9000 cases loaded
10000 cases loaded
11000 cases loaded
12000 cases loaded
13000 cases loaded
14000 cases loaded
15000 cases loaded
16000 cases loaded
17000 cases loaded
18000 cases loaded
19000 cases loaded
20000 cases loaded
21000 cases loaded
22000 cases loaded
23000 cases loaded
24000 cases loaded
25000 cases loaded
26000 cases loaded
27000 cases loaded
28000 cases loaded
29000 cases loaded
30000 cases loaded
31000 cases loaded
32000 cases loaded
33000 cases loaded
34000 cases loaded
35000 cases loaded
36000 cases loaded
37000 cases loaded
38000 cases loaded
39000 cases loaded
40000 cases loaded
41000 cases loaded
42000 cases loaded
43000 cases loaded
44000 cases loaded
45000 cases loaded
46000 cases loaded
47000 cases loaded
48000 cases loaded
49000 cases loaded
50000 cases loaded
51

In [3]:
NUM_LABELS = 10
MAX_POWER = 6
FUNCTIONS = {'exp':7, 'sin':8, 'cos':9}

def get_coefficients_and_exponents(f):
    variables = list(f.free_symbols)
    assert len(variables)<=1, "Expression having multiple variable " + str(f)
    if len(variables) == 0:
        return list()
    t = variables[0]
    return [[float(x) for x in term.as_coeff_exponent(t)] for term in f.as_ordered_terms()]

def get_expr_type(f):
    s = str(f)
    for function, idx in FUNCTIONS.items():
        if function in s:
            return idx
    try:
        coeffs = get_coefficients_and_exponents(f)
    except:
        print("Cannot get coefficients for", s)
        return -1
    if len(coeffs) == 0:
        return -1
    max_power = int(coeffs[0][1])
    if max_power > MAX_POWER:
        return -1
    return max_power

random.shuffle(originals)

exprs = []
labels = []
for i in range(len(originals)):
    f = sympify(originals[i])
    x, t = symbols(['x','t'])
    f = f.subs({t:x})
    #display(f)
    label = get_expr_type(f)
    if label < 0:
        print("Cannot process", originals[i])
        continue
    exprs.append(f)
    labels.append(label)
    if i % 1000 == 0:
        print(i, "rows processed")

0 rows processed
Cannot process 1.40000000000000
Cannot process 1
Cannot process 1
Cannot process 1.40000000000000
Cannot process 5
Cannot process -2.70000000000000
Cannot process 1
Cannot process 3.90000000000000
Cannot process -3
Cannot process 4
Cannot process 1
Cannot process 1
Cannot process 3
Cannot process -3
Cannot process 3
Cannot process 3
Cannot process 0.700000000000000
Cannot process 2
1000 rows processed
Cannot process 6.10000000000000
Cannot process -5
Cannot process 1
Cannot process -5
Cannot process -1
Cannot process 2
Cannot process 1
Cannot process -0.900000000000000
Cannot process 2
Cannot process 1.50000000000000
Cannot process -7.50000000000000
Cannot process -3
Cannot process -2
Cannot process 5
Cannot process -2
Cannot process 2.60000000000000
Cannot process 3.10000000000000
Cannot process -1.70000000000000
Cannot process 1
Cannot process 1
Cannot process -2.90000000000000
2000 rows processed
Cannot process -1.90000000000000
Cannot process 5
Cannot process 1
Can

Cannot process 7*t**7 + 11.07*t**3 + 7.89*t + 1.45
Cannot process 1
Cannot process 1
Cannot process -2.90000000000000
Cannot process -1
Cannot process 2
Cannot process 4.10000000000000
Cannot process 1
Cannot process -2.20000000000000
Cannot process 4
Cannot process 2.90000000000000
Cannot process -1.20000000000000
Cannot process -6
21000 rows processed
Cannot process 1
Cannot process 1.30000000000000
Cannot process 0.800000000000000
Cannot process 1
Cannot process -4
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process -2
Cannot process 4
Cannot process -2
Cannot process -3
Cannot process 1
Cannot process 2
Cannot process 3.50000000000000
Cannot process -3
Cannot process 1
22000 rows processed
Cannot process 1
Cannot process 1
Cannot process -0.500000000000000
Cannot process -3.20000000000000
Cannot process -3.30000000000000
Cannot process -0.100000000000000
Cannot process -2.90000000000000
Cannot process -7.00000000000000
Cannot process 

Cannot process -1.50000000000000
Cannot process 2
Cannot process 3
Cannot process 5
Cannot process 1
Cannot process -1
Cannot process 1
Cannot process 1
Cannot process 8.17*t**8 - 2.78*t + 3.47
Cannot process 1
Cannot process -2.50000000000000
Cannot process 1
Cannot process -3
Cannot process 2
Cannot process -1.90000000000000
40000 rows processed
Cannot process 1
Cannot process -3
Cannot process -3.60000000000000
Cannot process 1
Cannot process 6
Cannot process 0.500000000000000
Cannot process 3.70000000000000
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1.10000000000000
Cannot process -1.70000000000000
Cannot process 4.10000000000000
Cannot process 3
Cannot process 5.50000000000000
Cannot process 1
Cannot process 5
Cannot process 2
Cannot process 1
Cannot process -3
41000 rows processed
Cannot process 2
Cannot process -3.20000000000000
Cannot process -3.30000000000000
Cannot process 1
Cannot process 5
Cannot process -0.800000000000000
Cannot proc

Cannot process -3.80000000000000
Cannot process -5
Cannot process -3.20000000000000
Cannot process 3*t**8 + 8*t**3 + 3.92*t**2
Cannot process 1
Cannot process -1.90000000000000
Cannot process 4
Cannot process 1
Cannot process -1.50000000000000
Cannot process 3.20000000000000
58000 rows processed
Cannot process 7*t**7 + 4.17*t**2 - 3.78*t + 3.47
Cannot process 1
Cannot process 3.40000000000000
Cannot process -7.00000000000000
Cannot process 2.90000000000000
Cannot process 1.30000000000000
Cannot process 4.50000000000000
Cannot process -8
Cannot process 0.700000000000000
Cannot process -2.60000000000000
Cannot process 5.30000000000000
Cannot process 2
Cannot process 1
Cannot process -1
Cannot process 1
59000 rows processed
Cannot process -4.00000000000000
Cannot process 1
Cannot process 4
Cannot process 1
Cannot process 2.00000000000000
Cannot process 2.20000000000000
Cannot process 1
Cannot process -4.80000000000000
Cannot process 3
Cannot process -2.50000000000000
Cannot process 0.3000

77000 rows processed
Cannot process 1
Cannot process -2.70000000000000
Cannot process -5
Cannot process 4
Cannot process -3
Cannot process -3.80000000000000
Cannot process 3.00000000000000
Cannot process 1
Cannot process -2
Cannot process -0.300000000000000
Cannot process -1.60000000000000
Cannot process 5.90000000000000
78000 rows processed
Cannot process 3.10000000000000
Cannot process 1
Cannot process 3
Cannot process 1
Cannot process -3.20000000000000
Cannot process 1
Cannot process 1.10000000000000
Cannot process -2.10000000000000
Cannot process -1.90000000000000
Cannot process -2.20000000000000
Cannot process -1
Cannot process 1
Cannot process 2
Cannot process 1
Cannot process -5.00000000000000
Cannot process -0.200000000000000
Cannot process 6
Cannot process -3.20000000000000
79000 rows processed
Cannot process 1.40000000000000
Cannot process -2.00000000000000
Cannot process 3.80000000000000
Cannot process 3.00000000000000
Cannot process 4.70000000000000
Cannot process -5
Cannot

Cannot process 4
Cannot process 1.00000000000000
Cannot process 1
Cannot process -3.80000000000000
Cannot process 2.90000000000000
Cannot process -4
Cannot process -4.80000000000000
Cannot process 1.70000000000000
Cannot process 4
Cannot process -1.70000000000000
Cannot process 1
Cannot process 1
Cannot process 5
Cannot process 1
Cannot process -1
97000 rows processed
Cannot process 1
Cannot process 0.400000000000000
Cannot process 2*t**8 + 6*t**3 + 8*t**2 + 3*t + 8
Cannot process 1.00000000000000
Cannot process 1
Cannot process -7
Cannot process -4
Cannot process 1
Cannot process 4.50000000000000
Cannot process 1
Cannot process 1
Cannot process 2
Cannot process 5.00000000000000
Cannot process -1.60000000000000
Cannot process 7*t**7 + 11.17*t**3 + 15.24*t + 2.67
Cannot process -0.800000000000000
Cannot process -3.80000000000000
Cannot process -3.50000000000000
Cannot process 0.300000000000000
Cannot process -5.70000000000000
Cannot process 1
Cannot process 1
98000 rows processed
Cannot

Cannot process 6.00000000000000
Cannot process -2.90000000000000
Cannot process -3
Cannot process 1
Cannot process 5
Cannot process 3.60000000000000
Cannot process -3.00000000000000
Cannot process 3.30000000000000
Cannot process 1
Cannot process -3
Cannot process -3
Cannot process 2.00000000000000
Cannot process 1
Cannot process 1.00000000000000
Cannot process -2.30000000000000
117000 rows processed
Cannot process 2.70000000000000
Cannot process -1
Cannot process 1
Cannot process -2.60000000000000
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 4
Cannot process -1
Cannot process 1
Cannot process 3.40000000000000
Cannot process -1
Cannot process 1
Cannot process 7*t**7 - 4.41*t**3 - 0.54*t + 4.69
Cannot process 2.40000000000000
Cannot process 1
Cannot process 2.90000000000000
Cannot process -1.20000000000000
Cannot process 1
Cannot process 2
Cannot process 1
118000 rows processed
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Canno

Cannot process 1
Cannot process 5
Cannot process 1.00000000000000
Cannot process 9.00000000000000
136000 rows processed
Cannot process -1.50000000000000
Cannot process 1
Cannot process 1
Cannot process 4.00000000000000
Cannot process -2
Cannot process 2.30000000000000
Cannot process 1
Cannot process 4.10000000000000
Cannot process -0.200000000000000
Cannot process -6.80000000000000
Cannot process -5.70000000000000
Cannot process -4.40000000000000
Cannot process -0.100000000000000
Cannot process 1
Cannot process 1.20000000000000
137000 rows processed
Cannot process 6
Cannot process -2.50000000000000
Cannot process 3
Cannot process 4.90000000000000
Cannot process 1
Cannot process 1
Cannot process -4.70000000000000
Cannot process 7*t**8 - 8*t**3 + 8*t**2 - 7*t + 2
Cannot process 7*t**7 + 4*t**4 + 4.8*t**2 + 1.0*t + 0.01
Cannot process -1.40000000000000
Cannot process 1
Cannot process -0.500000000000000
Cannot process -1.10000000000000
Cannot process 3.20000000000000
Cannot process 2.30000

Cannot process -3.50000000000000
Cannot process 4
Cannot process -1.20000000000000
154000 rows processed
Cannot process 0.100000000000000
Cannot process 0.600000000000000
Cannot process 1
Cannot process 1
Cannot process -3.60000000000000
Cannot process -2
Cannot process -4.10000000000000
Cannot process 4.10000000000000
Cannot process 3
Cannot process 1
Cannot process 5
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 1
Cannot process 2
Cannot process 1
Cannot process 1
Cannot process 3.80000000000000
Cannot process -1
Cannot process 6
155000 rows processed
Cannot process -2
Cannot process -4.00000000000000
Cannot process 1
Cannot process 4.10000000000000
Cannot process 3
Cannot process -3.20000000000000
Cannot process 1
Cannot process 4
Cannot process 3.70000000000000
Cannot process 1.60000000000000
Cannot process 4.00000000000000
Cannot process 1
Cannot process 0.100000000000000
Cannot process 1
Cannot process 1
Cannot process -0.600000000000000
Cannot process 2.70000

In [6]:
exprs[-20:]

[2.4*exp(x) + 1,
 3.9*sin(0.7*x),
 cos(x),
 1.0 - 2.5*cos(2.6*x),
 2*x**6 - 2*x**3 + 2.76*x**2 - 13.67*x + 6.35,
 4.0*cos(4*x),
 sin(x) + 1,
 1 + exp(-3*x),
 -2.3*cos(x),
 sin(4.5*x),
 -4.0*sin(x),
 sin(x),
 -2.0 + exp(-2.8*x),
 sin(x) + 5,
 5*cos(x),
 -5*x/2 - 13,
 -1.0*sin(x),
 sin(3.8*x),
 105*x/2 + 32,
 -sin(5*x)]

In [5]:
import torch
random.seed(12345)

class SampleEmbeddingDataset(torch.utils.data.Dataset):
    def __init__(self, exprs, labels):
        self.exprs = exprs
        self.labels = labels

    def __getitem__(self, idx):
        # item = {'inputs_embeds': torch.tensor(torch.rand(512,768)),
        #         'attention_mask': torch.tensor(torch.ones(512)),
        #        }
        # item["labels"] = torch.tensor([0.0, 1.0])
        xs, ys = generate_sample(self.exprs[idx])
        encoding = sample2seq(ys, SEQ_LEN, EMBED_DIM, seq_first = True)
        item = {'inputs_embeds': torch.tensor(encoding['seq']),
                'attention_mask': torch.tensor(encoding['attention_mask']),
                #'sample': sample
               }
        scale = self.labels[idx]
        item["labels"] = torch.tensor(scale)
        return item

    def __len__(self):
        return len(self.exprs)

    
train_exprs = [exprs[i] for i in range(len(exprs)) if i % 20 != 0]
train_labels = [labels[i] for i in range(len(labels)) if i % 20 != 0]
test_exprs = [exprs[i] for i in range(len(exprs)) if i % 20 == 0]
test_labels = [labels[i] for i in range(len(labels)) if i % 20 == 0]

print("Converting to datasets...")
# convert our tokenized data into a torch Dataset
train_dataset = SampleEmbeddingDataset(train_exprs, train_labels)
valid_dataset = SampleEmbeddingDataset(test_exprs, test_labels)
print("Done")

Converting to datasets...
Done


In [9]:
train_dataset

<__main__.SampleEmbeddingDataset at 0x7fd279bbe1a0>

In [6]:
import torch
from transformers import RobertaModel, RobertaConfig

config = RobertaConfig(max_position_embeddings=SEQ_LEN+2, hidden_size=EMBED_DIM, intermediate_size=EMBED_DIM*3, num_hidden_layers=6)

#roberta_model = RobertaModel.from_pretrained("roberta-base", problem_type='regression')
roberta_model = RobertaModel(config)

my_input = torch.rand(2,SEQ_LEN,EMBED_DIM)

outputs = roberta_model(inputs_embeds=my_input)

  from .autonotebook import tqdm as notebook_tqdm


In [7]:
from torch import nn
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, MaskedLMOutput, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel

class TransformerWithEmbeddingInput(nn.Module):
    def __init__(self, transformer_model, num_output=1) -> None:
        super().__init__()

        self.num_output = num_output
        self.transformer = transformer_model
        #self.fc = nn.Linear(EMBED_DIM, 2)
        #self.loss_fct = MSELoss()
        self.fc = nn.Linear(config.hidden_size, NUM_LABELS) 

    def forward(
        self,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Union[tuple, ImageClassifierOutput]:
        #print("labels:", labels, labels.shape)
        
        outputs = self.transformer(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask
        )
        
        last_hidden_states = outputs['last_hidden_state'][:, -1, :]
        logits = self.fc(last_hidden_states)
        
        #class_label = class_label.to(logits.device)
        loss_fct = CrossEntropyLoss()
        loss = loss_fct(logits, labels)
        
        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    
wrapper_model = TransformerWithEmbeddingInput(roberta_model)

In [8]:
from sklearn.metrics import mean_squared_error
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    rmse = mean_squared_error(labels, predictions, squared=False)
    return {"rmse": rmse}

args = TrainingArguments(
    # evaluation_strategy = "epoch",
    # save_strategy = "epoch",
    # evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    save_steps=5000,
    eval_steps=5000,
    logging_steps=1000,
    save_total_limit=4,
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
    weight_decay=0.01,
    output_dir='datasets/function_type_classifier',
    metric_for_best_model='accuracy')

trainer = Trainer(
    wrapper_model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    #tokenizer=tokenizer,
)


# train the model
trainer.train('datasets/function_type_classifier/checkpoint-15000')

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  item = {'inputs_embeds': torch.tensor(encoding['seq']),
  'attention_mask': torch.tensor(encoding['attention_mask']),


Step,Training Loss



KeyboardInterrupt



In [14]:
valid_dataset[0]

  item = {'inputs_embeds': torch.tensor(encoding['seq']),
  'attention_mask': torch.tensor(encoding['attention_mask']),


{'inputs_embeds': tensor([[ 484.0000,  483.9937,  483.9875,  ...,  481.6155,  481.6093,
           481.6030],
         [ 481.5967,  481.5905,  481.5842,  ...,  479.2182,  479.2119,
           479.2057],
         [ 479.1995,  479.1932,  479.1870,  ...,  476.8268,  476.8206,
           476.8144],
         ...,
         [1144.8707, 1144.8802, 1144.8899,  ..., 1148.5455, 1148.5552,
          1148.5648],
         [1148.5745, 1148.5841, 1148.5938,  ..., 1152.2552, 1152.2649,
          1152.2745],
         [1152.2843, 1152.2939, 1152.3036,  ..., 1155.9709, 1155.9806,
          1155.9904]]),
 'attention_mask': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'labels': tensor(2)}

In [12]:
total_loss = 0

for i in range(len(valid_dataset)):
    example = valid_dataset[i]
    output = wrapper_model(torch.unsqueeze(example['inputs_embeds'], 0).to('cuda:0'), 
                           torch.unsqueeze(example['attention_mask'], 0).to('cuda:0'),
                           torch.unsqueeze(example['labels'], 0).to('cuda:0'))
    total_loss += 
    if i % 100 == 0:
        print("Case", i)
        print(test_exprs[i])
        print(test_labels[i])
        print(output)
        

Case 0
12.12*x**3 + 4.06*x**2 + 1.17*x + 5.08
3
ImageClassifierOutput(loss=tensor(0.0125, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-1.0171, -4.8216, -1.7467,  8.3226,  3.7999,  1.6492, -2.0648, -0.9222,
         -0.5127, -2.3070]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Case 1
3*exp(x)
7
ImageClassifierOutput(loss=tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.4228, -2.9743, -3.2780,  0.3383,  2.0132,  0.4519, -2.0102, 10.9429,
         -2.4277, -1.3930]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Case 2
cos(x) + 1
9
ImageClassifierOutput(loss=tensor(0.0002, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.4344, -1.7188, -0.5750, -2.2476, -2.2812, -0.9051, -1.1071,  0.7409,
         -0.0703, 10.0993]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)
Case 3
1.42*x**5 + 7.4*x**3 + 6.25*x + 13.76
5
ImageClas

  item = {'inputs_embeds': torch.tensor(encoding['seq']),
  'attention_mask': torch.tensor(encoding['attention_mask']),


In [22]:
float(output.loss.cpu().detach().numpy())

0.00020203932945150882

In [13]:
len(valid_dataset)

7828

In [11]:
output

ImageClassifierOutput(loss=tensor(0.0003, device='cuda:0', grad_fn=<NllLossBackward0>), logits=tensor([[-0.3362, -3.0806, -2.7865,  0.9445,  2.3106,  0.5195, -2.0267, 10.7769,
         -2.4881, -1.2703]], device='cuda:0', grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [19]:
example

{'inputs_embeds': tensor([[ 484.0000,  483.9937,  483.9875,  ...,  481.6155,  481.6093,
           481.6030],
         [ 481.5967,  481.5905,  481.5842,  ...,  479.2182,  479.2119,
           479.2057],
         [ 479.1995,  479.1932,  479.1870,  ...,  476.8268,  476.8206,
           476.8144],
         ...,
         [1144.8707, 1144.8802, 1144.8899,  ..., 1148.5455, 1148.5552,
          1148.5648],
         [1148.5745, 1148.5841, 1148.5938,  ..., 1152.2552, 1152.2649,
          1152.2745],
         [1152.2843, 1152.2939, 1152.3036,  ..., 1155.9709, 1155.9806,
          1155.9904]]),
 'attention_mask': tensor([1., 1., 1.,  ..., 1., 1., 1.]),
 'labels': tensor(2)}