In [1]:
import torch
import math
import numpy as np
import random
import sympy as sp
from sympy import sympify, lambdify, symbols, integrate, Interval, Symbol, I, S, oo, plot
from IPython.display import display


# Dataset / model geometry shared by the cells below.
TRAIN_SIZE = 1000000  # not referenced again in the visible cells
VALID_SIZE = 10000  # not referenced again in the visible cells
SEQ_LEN = 768  # number of sequence positions that hold sampled function values
OUTPUT_LEN = 256  # number of trailing positions reserved for the tokenized expression
EMBED_DIM = 384  # sampled values per position; also used as the model hidden size


def generate_sample(f, min_x=-4, max_x=4):
    """Evaluate an expression in ``t`` on a uniform grid over ``[min_x, max_x)``.

    The grid always has exactly ``SEQ_LEN * EMBED_DIM`` points so that the
    result can be reshaped into a ``(SEQ_LEN, EMBED_DIM)`` sequence.

    Args:
        f: SymPy expression in the symbol ``t`` (constants are allowed).
        min_x: Inclusive lower bound of the sampling interval.
        max_x: Exclusive upper bound of the sampling interval.

    Returns:
        Tuple ``(xs, ys)`` of 1-D NumPy arrays of equal length. If the
        evaluation yields any NaN or Inf, a message is printed and ``ys`` is
        replaced by zeros.
    """
    num_points = SEQ_LEN * EMBED_DIM
    t = symbols('t')
    fl = lambdify(t, f, "numpy")
    # linspace guarantees the point count; arange with a float step may
    # produce one element more or less because of rounding.
    xs = np.linspace(min_x, max_x, num_points, endpoint=False)
    ys = np.asarray(fl(xs))
    if ys.ndim == 0:
        # Constant expressions evaluate to a scalar (Python or NumPy);
        # broadcast it over the whole grid.
        ys = np.full(xs.shape, float(ys))
    if not np.isfinite(ys).all():
        print("Error! NaN or Inf found!")
        # `shape` is an attribute, not a method.
        ys = np.zeros(ys.shape)
    return xs, ys

def sample2seq(raw_sample, seq_len, embed_dim, seq_first=True):
    """Pack a flat sample into a ``(seq_len, embed_dim)`` float32 tensor.

    Args:
        raw_sample: Array-like with ``seq_len * embed_dim`` values.
        seq_len: Number of sequence positions.
        embed_dim: Number of values stored per position.
        seq_first: If True, consecutive values fill one position before moving
            to the next. If False, the sample is laid out as
            ``(embed_dim, seq_len)`` and transposed, so consecutive values go
            to consecutive positions.

    Returns:
        Dict with ``"seq"`` (float32 tensor of shape ``(seq_len, embed_dim)``)
        and ``"attention_mask"`` (tensor of ``seq_len`` ones).
    """
    attention_mask = torch.ones(seq_len)  # May change to only where values are present
    if seq_first:
        grid = np.reshape(raw_sample, (seq_len, embed_dim))
    else:
        grid = np.transpose(np.reshape(raw_sample, (embed_dim, seq_len)))
    # astype always copies, so the tensor never aliases the caller's array;
    # order="C" keeps the transposed case contiguous.
    seq = torch.from_numpy(grid.astype(np.float32, order="C"))
    return {"seq": seq, "attention_mask": attention_mask}


# Smoke test: sample t**2 on the default grid and show the x grid.
f = sympify("t**2")
xs, ys = generate_sample(f)
print(xs)

# Last expression of the cell: the packed dict for the x grid is displayed as the cell output.
sample2seq(xs, SEQ_LEN, EMBED_DIM)

[-4.         -3.99997287 -3.99994575 ...  3.99991862  3.99994575
  3.99997287]


{'seq': tensor([[-4.0000, -4.0000, -3.9999,  ..., -3.9897, -3.9896, -3.9896],
         [-3.9896, -3.9896, -3.9895,  ..., -3.9792, -3.9792, -3.9792],
         [-3.9792, -3.9791, -3.9791,  ..., -3.9688, -3.9688, -3.9688],
         ...,
         [ 3.9688,  3.9688,  3.9688,  ...,  3.9791,  3.9791,  3.9791],
         [ 3.9792,  3.9792,  3.9792,  ...,  3.9895,  3.9895,  3.9896],
         [ 3.9896,  3.9896,  3.9896,  ...,  3.9999,  3.9999,  4.0000]]),
 'attention_mask': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.

In [None]:
# import json


# def remove_constants(f):
#     t = Symbol('t')
#     return f.as_independent(t)[1]

# fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_polynomial_integral_results.json", "r")
# lines = fin.readlines()
# print(len(lines), "lines loaded")
# fin.close()
# fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_randomized_polynomial_integral_results.json", "r")
# lines.extend(fin.readlines())
# print(len(lines), "lines loaded")
# fin.close()
# fin = open("/home/mcwave/code/automath/calculus/datasets/parametric_equations_randomized_nonpoly_integral_results_corrected.json", "r")
# lines.extend(fin.readlines())
# print(len(lines), "lines loaded")
# fin.close()

# MAX_POWER = 6
# MAX_AVG_DIFF = 0.01

# originals = []

# for line in lines:
#     result = json.loads(line)
#     original = result["original"]
#     originals.append(original)
#     if len(originals) % 1000 == 0:
#         print(len(originals), "cases loaded")

In [2]:
# NUM_LABELS = 10
# MAX_POWER = 6
# FUNCTIONS = {'exp':7, 'sin':8, 'cos':9}

# def get_coefficients_and_exponents(f):
#     variables = list(f.free_symbols)
#     assert len(variables)<=1, "Expression having multiple variable " + str(f)
#     if len(variables) == 0:
#         return list()
#     t = variables[0]
#     return [[float(x) for x in term.as_coeff_exponent(t)] for term in f.as_ordered_terms()]

# def get_expr_type(f):
#     s = str(f)
#     for function, idx in FUNCTIONS.items():
#         if function in s:
#             return idx
#     try:
#         coeffs = get_coefficients_and_exponents(f)
#     except:
#         print("Cannot get coefficients for", s)
#         return -1
#     if len(coeffs) == 0:
#         return -1
#     max_power = int(coeffs[0][1])
#     if max_power > MAX_POWER:
#         return -1
#     return max_power

# random.shuffle(originals)

# exprs = []
# labels = []
# for i in range(len(originals)):
#     f = sympify(originals[i])
#     x, t = symbols(['x','t'])
#     f = f.subs({t:x})
#     #display(f)
#     label = get_expr_type(f)
#     if label < 0:
#         print("Cannot process", originals[i])
#         continue
#     exprs.append(f)
#     labels.append(label)
#     if i % 1000 == 0:
#         print(i, "rows processed")

# Parse the pre-generated expression strings (one per line) into SymPy expressions.
# NOTE(review): sympify evaluates the text it is given, so this file should
# only ever come from a trusted source -- confirm provenance.
# The context manager closes the file even if reading fails.
with open('datasets/random_exprs_1m.txt', 'r') as fin:
    lines = fin.readlines()
exprs = []
for line in lines[0:1000000]:
    exprs.append(sympify(line))
    if len(exprs) % 1000 == 0:
        print(len(exprs), "rows processed")

1000 rows processed
2000 rows processed
3000 rows processed
4000 rows processed
5000 rows processed
6000 rows processed
7000 rows processed
8000 rows processed
9000 rows processed
10000 rows processed
11000 rows processed
12000 rows processed
13000 rows processed
14000 rows processed
15000 rows processed
16000 rows processed
17000 rows processed
18000 rows processed
19000 rows processed
20000 rows processed
21000 rows processed
22000 rows processed
23000 rows processed
24000 rows processed
25000 rows processed
26000 rows processed
27000 rows processed
28000 rows processed
29000 rows processed
30000 rows processed
31000 rows processed
32000 rows processed
33000 rows processed
34000 rows processed
35000 rows processed
36000 rows processed
37000 rows processed
38000 rows processed
39000 rows processed
40000 rows processed
41000 rows processed
42000 rows processed
43000 rows processed
44000 rows processed
45000 rows processed
46000 rows processed
47000 rows processed
48000 rows processed
4

379000 rows processed
380000 rows processed
381000 rows processed
382000 rows processed
383000 rows processed
384000 rows processed
385000 rows processed
386000 rows processed
387000 rows processed
388000 rows processed
389000 rows processed
390000 rows processed
391000 rows processed
392000 rows processed
393000 rows processed
394000 rows processed
395000 rows processed
396000 rows processed
397000 rows processed
398000 rows processed
399000 rows processed
400000 rows processed
401000 rows processed
402000 rows processed
403000 rows processed
404000 rows processed
405000 rows processed
406000 rows processed
407000 rows processed
408000 rows processed
409000 rows processed
410000 rows processed
411000 rows processed
412000 rows processed
413000 rows processed
414000 rows processed
415000 rows processed
416000 rows processed
417000 rows processed
418000 rows processed
419000 rows processed
420000 rows processed
421000 rows processed
422000 rows processed
423000 rows processed
424000 row

752000 rows processed
753000 rows processed
754000 rows processed
755000 rows processed
756000 rows processed
757000 rows processed
758000 rows processed
759000 rows processed
760000 rows processed
761000 rows processed
762000 rows processed
763000 rows processed
764000 rows processed
765000 rows processed
766000 rows processed
767000 rows processed
768000 rows processed
769000 rows processed
770000 rows processed
771000 rows processed
772000 rows processed
773000 rows processed
774000 rows processed
775000 rows processed
776000 rows processed
777000 rows processed
778000 rows processed
779000 rows processed
780000 rows processed
781000 rows processed
782000 rows processed
783000 rows processed
784000 rows processed
785000 rows processed
786000 rows processed
787000 rows processed
788000 rows processed
789000 rows processed
790000 rows processed
791000 rows processed
792000 rows processed
793000 rows processed
794000 rows processed
795000 rows processed
796000 rows processed
797000 row

In [3]:
import torch
from transformers import RobertaTokenizer

# Seed Python's `random` module.
random.seed(12345)

# Tokenizer loaded from the 'roberta-base' checkpoint; the dataset class below uses it to encode expression strings.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

class SampleEmbeddingDataset(torch.utils.data.Dataset):
    """Dataset pairing sampled function values with the tokenized expression.

    Every item is a sequence of ``SEQ_LEN + OUTPUT_LEN`` positions:

    * positions ``[0, SEQ_LEN)`` carry the sampled values of the expression in
      ``inputs_embeds`` and the id 1 (decoded as ``<pad>``) in the token ids;
    * positions ``[SEQ_LEN, SEQ_LEN + OUTPUT_LEN)`` carry zeros in
      ``inputs_embeds`` and the tokenized expression string (padded or
      truncated to ``OUTPUT_LEN``) in the token ids.

    Samples are generated on access; nothing is precomputed.
    """

    def __init__(self, exprs):
        """Store the sequence of SymPy expressions to serve."""
        self.exprs = exprs

    def __getitem__(self, idx):
        """Build the model inputs for expression ``idx``.

        Returns:
            Dict with ``inputs_embeds`` (float, ``(SEQ_LEN + OUTPUT_LEN,
            EMBED_DIM)``), ``attention_mask`` (all ones), ``labels`` and
            ``inputs_ids`` (int64 token ids, equal in value but separate
            tensors).
        """
        expr = self.exprs[idx]
        _, ys = generate_sample(expr)
        encoding = sample2seq(ys, SEQ_LEN, EMBED_DIM, seq_first=True)
        # sample2seq already returns tensors; re-wrapping them in
        # torch.tensor() only copied them and raised a UserWarning.
        attention_mask = torch.cat([encoding['attention_mask'], torch.ones(OUTPUT_LEN)])
        inputs_embeds = torch.cat([encoding['seq'], torch.zeros(OUTPUT_LEN, EMBED_DIM)])

        labels = tokenizer(str(expr), return_tensors='pt', padding='max_length', truncation=True, max_length=OUTPUT_LEN)
        input_ids = torch.cat([torch.ones(SEQ_LEN, dtype=torch.int64), labels['input_ids'][0]])
        # NOTE(review): the padded positions keep their real token id in
        # `labels` (no ignore index), so they presumably contribute to the
        # language-model loss -- confirm this is intended.
        return {
            'inputs_embeds': inputs_embeds,
            'attention_mask': attention_mask,
            'labels': input_ids,
            'inputs_ids': input_ids.clone(),
        }

    def __len__(self):
        """Number of expressions in the dataset."""
        return len(self.exprs)

    
# Every 50th expression (indices 0, 50, 100, ...) is held out for validation;
# all remaining expressions are used for training.
test_exprs = exprs[::50]
train_exprs = [expr for position, expr in enumerate(exprs) if position % 50]

print("Converting to datasets...")
# Wrap both splits in the torch Dataset defined above.
valid_dataset = SampleEmbeddingDataset(test_exprs)
train_dataset = SampleEmbeddingDataset(train_exprs)
print("Done")

  from .autonotebook import tqdm as notebook_tqdm


Converting to datasets...
Done


In [13]:
# Decode the label ids of validation item 1 from position 768 (the value of SEQ_LEN) onward.
tokenizer.decode(valid_dataset[1]['labels'][768:])

  attention_mask = torch.cat([torch.tensor(encoding['attention_mask']), torch.ones(OUTPUT_LEN)])
  inputs_embeds = torch.cat([torch.tensor(encoding['seq']), torch.zeros(OUTPUT_LEN, EMBED_DIM)])


'<s>5*t**4 - 1.4*t</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pa

In [5]:
import torch
from transformers import RobertaForCausalLM, RobertaConfig

# Configuration with hidden size EMBED_DIM, 6 layers and SEQ_LEN + OUTPUT_LEN + 2 position embeddings.
# NOTE(review): the extra 2 positions are presumably RoBERTa's position-id offset -- confirm.
config = RobertaConfig(max_position_embeddings=SEQ_LEN+OUTPUT_LEN+2, hidden_size=EMBED_DIM, intermediate_size=EMBED_DIM*3, num_hidden_layers=6)
# Flag the configuration as a decoder and record the architecture name.
config.is_decoder = True
config.architectures = ['RobertaForCausalLM']

#roberta_model = RobertaModel.from_pretrained("roberta-base", problem_type='regression')
# Built from the configuration alone (no from_pretrained call here).
roberta_model = RobertaForCausalLM(config)

# Smoke test: a random batch of 2 sequences passed directly as input embeddings.
my_input = torch.rand(2,SEQ_LEN+OUTPUT_LEN,EMBED_DIM)

outputs = roberta_model(inputs_embeds=my_input)

In [6]:
"""
RobertaConfig {
  "_name_or_path": "roberta-base",
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.36.2",
  "type_vocab_size": 1,
  "use_cache": true,
  "vocab_size": 50265
}
"""

# The string literal above quotes a roberta-base configuration for comparison (hidden size 768,
# 12 layers, 514 positions); the `config` displayed here was built with EMBED_DIM, 6 layers and
# SEQ_LEN + OUTPUT_LEN + 2 positions.
config

RobertaConfig {
  "architectures": [
    "RobertaForCausalLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "bos_token_id": 0,
  "classifier_dropout": null,
  "eos_token_id": 2,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 1152,
  "is_decoder": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 1026,
  "model_type": "roberta",
  "num_attention_heads": 12,
  "num_hidden_layers": 6,
  "pad_token_id": 1,
  "position_embedding_type": "absolute",
  "transformers_version": "4.36.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 50265
}

In [12]:
# Shape of the logits from the smoke-test forward pass on `my_input`.
outputs['logits'].shape

torch.Size([2, 1024, 50265])

In [14]:
from torch import nn
import collections.abc
import math
from typing import Dict, List, Optional, Set, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput, CausalLMOutputWithCrossAttentions, MaskedLMOutput, BaseModelOutputWithPoolingAndCrossAttentions
from transformers.modeling_utils import PreTrainedModel

class TransformerWithEmbeddingInput(nn.Module):
    """Wrap a causal LM so it consumes token ids plus additive input embeddings.

    The token ids are embedded with the wrapped model's own embedding layer
    (``transformer.roberta.embeddings``), the caller-supplied ``inputs_embeds``
    are added element-wise, and the sum is passed to the wrapped model as its
    ``inputs_embeds``.

    Args:
        transformer_model: Model exposing ``.roberta.embeddings`` and accepting
            ``inputs_embeds``, ``attention_mask`` and ``labels``.
        num_output: Accepted for backward compatibility; not used.
    """

    def __init__(self, transformer_model, num_output=1) -> None:
        super().__init__()
        self.transformer = transformer_model

    def forward(
        self,
        inputs_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None
    ) -> Union[tuple, "CausalLMOutputWithCrossAttentions"]:
        """Run the wrapped model on token embeddings plus ``inputs_embeds``.

        Args:
            inputs_ids: Token ids; required.
            inputs_embeds: Tensor added to the token embeddings; required and
                must be broadcastable to the embedding output.
            attention_mask: Forwarded unchanged to the wrapped model.
            labels: Forwarded unchanged to the wrapped model.

        Returns:
            Whatever the wrapped model returns.

        Raises:
            ValueError: If ``inputs_ids`` or ``inputs_embeds`` is missing.
        """
        if inputs_ids is None or inputs_embeds is None:
            raise ValueError("Both `inputs_ids` and `inputs_embeds` must be provided.")

        embedding_output = self.transformer.roberta.embeddings(
            input_ids=inputs_ids
        )
        sum_embedding = embedding_output + inputs_embeds

        # NOTE(review): the sum is handed over as `inputs_embeds`, so the
        # wrapped model presumably runs its embedding layer on it a second
        # time -- confirm this is intended before changing it, since saved
        # models were trained this way.
        # Call the module rather than `.forward` so registered hooks run.
        return self.transformer(
            inputs_embeds=sum_embedding,
            attention_mask=attention_mask,
            labels=labels
        )

    
# Wrap the freshly built model; the commented line below is the alternative of loading a saved wrapper.
wrapper_model = TransformerWithEmbeddingInput(roberta_model)
#wrapper_model = torch.load('datasets/function_type_lm_test.model')

In [25]:
from sklearn.metrics import mean_squared_error
from transformers import Trainer, TrainingArguments
from datasets import load_dataset

def compute_metrics(eval_pred):
    """Root-mean-squared error between predictions and labels.

    Computed with NumPy so that it does not depend on the ``squared=False``
    argument of ``sklearn.metrics.mean_squared_error``, which newer
    scikit-learn releases deprecate. For 2-D inputs the RMSE is taken per
    output column and then averaged, matching that function's default.

    Note: this function is not passed to the ``Trainer`` in the visible cells.

    Args:
        eval_pred: Pair ``(predictions, labels)`` of array-likes.

    Returns:
        Dict ``{"rmse": float}``.
    """
    predictions, labels = eval_pred
    predictions = np.asarray(predictions, dtype=float)
    labels = np.asarray(labels, dtype=float)
    squared_error = (labels - predictions) ** 2
    if squared_error.ndim == 1:
        # Treat a flat vector as a single output column.
        squared_error = squared_error.reshape(-1, 1)
    rmse = float(np.mean(np.sqrt(np.mean(squared_error, axis=0))))
    return {"rmse": rmse}

# Training configuration: evaluate every 10k steps, log every 5k steps.
args = TrainingArguments(
    evaluation_strategy="steps",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=10,
    # Larger than any step count reached in the run shown below, so no
    # intermediate checkpoint is written there.
    save_steps=10000000,
    eval_steps=10000,
    logging_steps=5000,
    save_total_limit=4,
    # The dataset items use custom keys (e.g. 'inputs_ids') that must reach forward().
    remove_unused_columns=False,
    push_to_hub=False,
    report_to="none",
    weight_decay=0.01,
    output_dir='datasets/function_type_lm',
    # No compute_metrics is passed to the Trainer, so the evaluation loss is
    # the only metric available; the previous value 'accuracy' named a metric
    # that is never produced and would fail when a checkpoint save looks it up.
    metric_for_best_model='loss')

# Tie the weights of the embedding layer and the decoder layer
# if hasattr(wrapper_model, 'tie_weights'):
#     wrapper_model.tie_weights()

# Train the wrapper (not the bare RoBERTa model) on the lazily generated datasets.
# No compute_metrics is supplied here.
trainer = Trainer(
    wrapper_model,
    args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    #tokenizer=tokenizer,
)


# train the model
trainer.train() #'datasets/function_type_classifier/checkpoint-15000')

Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
  attention_mask = torch.cat([torch.tensor(encoding['attention_mask']), torch.ones(OUTPUT_LEN)])
  inputs_embeds = torch.cat([torch.tensor(encoding['seq']), torch.zeros(OUTPUT_LEN, EMBED_DIM)])


Step,Training Loss,Validation Loss
10000,0.0045,0.005976
20000,0.0044,0.006859
30000,0.0044,0.005871
40000,0.0043,0.006294
50000,0.0043,0.007834
60000,0.0042,0.00628
70000,0.0042,0.00751
80000,0.0041,0.005692
90000,0.0042,0.006418
100000,0.0041,0.006038


KeyboardInterrupt: 

In [35]:
# Greedy decoding of one validation example: start from BOS at the first
# output position and repeatedly feed the most likely next token back in.
wrapper_model.eval()
example = valid_dataset[14]

BOS = 0
PAD = 1
EOS = 2
START_IDX = SEQ_LEN  # first position after the sampled-value prefix
MAX_NEW_TOKENS = 20  # number of decoding steps to run

print("inputs_ids:", example['inputs_ids'][START_IDX:START_IDX + 22], tokenizer.decode(example['inputs_ids'][START_IDX:START_IDX + 22]))

# Run on whatever device the model currently lives on instead of assuming 'cuda:0'.
device = next(wrapper_model.parameters()).device

# Start from an all-PAD id sequence with BOS at the first output position.
inputs_ids = torch.full(example['inputs_ids'].shape, PAD, dtype=example['inputs_ids'].dtype)
inputs_ids[START_IDX] = BOS

# These two inputs do not change between steps, so move them to the device once.
example_embeds = torch.unsqueeze(example['inputs_embeds'], 0).to(device)
example_mask = torch.unsqueeze(example['attention_mask'], 0).to(device)

idx = START_IDX

# No gradients are needed for inference.
with torch.no_grad():
    while idx < START_IDX + MAX_NEW_TOKENS:
        # Labels are not passed: only the logits are used here.
        outputs = wrapper_model(inputs_ids=torch.unsqueeze(inputs_ids, 0).to(device),
                                inputs_embeds=example_embeds,
                                attention_mask=example_mask)
        logits = outputs.logits
        # The logits at position idx are used as the prediction for position idx + 1.
        predicted_token_id = torch.argmax(logits[:, idx, :], dim=-1)
        # Convert the token ID to the actual token
        predicted_token = tokenizer.decode(predicted_token_id)
        print(idx, predicted_token)
        idx += 1
        inputs_ids[idx] = predicted_token_id.item()

  attention_mask = torch.cat([torch.tensor(encoding['attention_mask']), torch.ones(OUTPUT_LEN)])
  inputs_embeds = torch.cat([torch.tensor(encoding['seq']), torch.zeros(OUTPUT_LEN, EMBED_DIM)])


inputs_ids: tensor([   0,  134,  111,  132,    4,  406, 3226,   90,    2,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1]) <s>1 - 2.7*t</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>
768 -
769 2
770 .
771 9
772 *
773 t
774 </s>
775 <pad>
776 <pad>
777 <pad>
778 <pad>
779 <pad>
780 <pad>
781 <pad>
782 <pad>
783 <pad>
784 <pad>
785 <pad>
786 <pad>
787 <pad>


In [111]:
# Scratch cell: rebuild an all-ones id tensor with the same shape/dtype as the example and put a 0 at position 769.
# NOTE(review): the decoding cell places the 0 (BOS) at START_IDX = 768, not 769 -- confirm which is intended.
inputs_ids = torch.ones(example['inputs_ids'].shape, dtype=example['inputs_ids'].dtype)
inputs_ids[769] = 0

In [95]:
# Inspect the label ids of the current `example` from position 768 onward.
example['labels'][768:]

tensor([    0,   288,     4,   406,  3226, 16254,  1640,   306,     4,   398,
         3226,    90,    43,     2,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
            1,     1,     1,     1,     1,     1,     1,     1, 

In [22]:
# Rebinds `example` to validation item 1000 and decodes its full id sequence.
example = valid_dataset[1000]
tokenizer.decode(example['inputs_ids'])

  attention_mask = torch.cat([torch.tensor(encoding['attention_mask']), torch.ones(OUTPUT_LEN)])
  inputs_embeds = torch.cat([torch.tensor(encoding['seq']), torch.zeros(OUTPUT_LEN, EMBED_DIM)])


'<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad

In [26]:
# Saves the whole wrapper object (not just a state_dict).
# NOTE(review): loading such a file presumably requires the TransformerWithEmbeddingInput class to be defined first -- confirm.
torch.save(wrapper_model, 'datasets/function_type_lm_800k_loss0.006.model')