In [1]:
import torch
import pandas as pd
from torch.utils.data import DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification, MobileBertTokenizer, MobileBertForSequenceClassification
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

2023-02-19 19:50:33.110194: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.10.1


In [2]:
DataSet_name = 'Data_V1.2.csv'
# Model_name = 'roberta'
Model_name = 'mobilebert'
Learning_rate = 5e-5
batch_size = 64
epochs = 1000
is_train = True
word_col = 'Word'
input_col=['Date', 'Month', 'Day', 'WeekNum', 'Contest number']
output_col=['Number of  reported results', 'Number in hard mode', '1 try', '2 tries', '3 tries', '4 tries', '5 tries', '6 tries', '7 or more tries (X)']
model_ckpt_path = 'model_ckpt'

In [3]:
# Load the dataset from a CSV file
df = pd.read_csv(DataSet_name)

# Percentage needs to be divided by 100
Need_div_cols=['1 try', '2 tries', '3 tries', '4 tries', '5 tries', '6 tries', '7 or more tries (X)']
for col in Need_div_cols:
    df[col] = df[col]/100

# Print the first 10 rows of the data
print(df.head(10))

         Date  Month  Day  WeekNum  Contest number   Word  \
0  2022/12/31     12   31        7             560  manly   
1  2022/12/30     12   30        6             559  molar   
2  2022/12/29     12   29        5             558  havoc   
3  2022/12/28     12   28        4             557  impel   
4  2022/12/27     12   27        3             556  condo   
5  2022/12/26     12   26        2             555  judge   
6  2022/12/25     12   25        1             554  extra   
7  2022/12/24     12   24        7             553  poise   
8  2022/12/23     12   23        6             552  aorta   
9  2022/12/22     12   22        5             551  excel   

   Number of  reported results  Number in hard mode  1 try  2 tries  3 tries  \
0                        20380                 1899   0.00     0.02     0.17   
1                        21204                 1973   0.00     0.04     0.21   
2                        20001                 1919   0.00     0.02     0.16   
3       

In [4]:


# Set up the tokenizer and model
if Model_name == 'roberta':
    # # Define the loss function and metric
    # loss_fn = torch.nn.CrossEntropyLoss()
    # metric_fn = accuracy_score

    # Define a new linear layer to map the output to the desired shape
    class CustomOutputLayer(torch.nn.Module):
        def __init__(self, in_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, 2)
            self.linear2 = torch.nn.Linear(in_features, 7)
            self.softmax = torch.nn.Softmax(dim=1)
        
        def forward(self, x):
            # out1, out2 = x
            out1 = self.linear1(x)
            out2 = self.linear2(x)
            probs = self.softmax(out2)
            return torch.cat([out1, probs], dim=1)

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = RobertaForSequenceClassification.from_pretrained('roberta-base')

    # Replace the last layer of the model with the custom output layer
    print(model.classifier)
    # print(model.classifier.in_features)
    model.classifier = CustomOutputLayer(2)

elif Model_name == 'mobilebert':
    # Define a new linear layer to map the output to the desired shape
    class CustomOutputLayer(torch.nn.Module):
        def __init__(self, in_features):
            super().__init__()
            self.linear1 = torch.nn.Linear(in_features, 2)
            self.linear2 = torch.nn.Linear(in_features, 7)
            self.softmax = torch.nn.Softmax(dim=1)

        def forward(self, x):
            out1 = self.linear1(x)
            out2 = self.linear2(x)
            probs = self.softmax(out2)
            return torch.cat([out1, probs], dim=1)

    tokenizer = MobileBertTokenizer.from_pretrained('google/mobilebert-uncased')
    model = MobileBertForSequenceClassification.from_pretrained('google/mobilebert-uncased')

    # Replace the last layer of the MobileBERT model with the custom output layer
    print(model.classifier)
    print(model.classifier.in_features)
    model.classifier = CustomOutputLayer(model.classifier.in_features)

# Move model to GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Some weights of the model checkpoint at google/mobilebert-uncased were not used when initializing MobileBertForSequenceClassification: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.predictions.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing MobileBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing MobileBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some 

Linear(in_features=512, out_features=2, bias=True)
512


KeyboardInterrupt: 

In [None]:
# Prepare dataset and dataloader
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, df, tokenizer, input_col, output_col):
        self.df = df
        self.tokenizer = tokenizer
        self.input_col = input_col
        self.output_col = output_col

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        inputs = [row[col] for col in self.input_col]
        inputs = ' '.join(str(x) for x in inputs)
        inputs = self.tokenizer.encode_plus(inputs, padding='max_length', max_length=20, truncation=True, return_tensors='pt')
        labels = [row[col] for col in self.output_col]
        labels = torch.tensor(labels, dtype=torch.float32)
        return inputs['input_ids'][0], inputs['attention_mask'][0], labels

# from sklearn.preprocessing import StandardScaler

# class CustomDataset(torch.utils.data.Dataset):
#     def __init__(self, data, tokenizer, input_cols, output_cols, max_length):
#         self.tokenizer = tokenizer
#         self.max_length = max_length
#         self.inputs = data[input_cols].to_numpy().astype('float32')
#         self.outputs = data[output_cols].to_numpy().astype('float32')
#         self.n_inputs = len(input_cols)
#         self.n_outputs = len(output_cols)
        
#         # normalize inputs using StandardScaler
#         self.scaler = StandardScaler()
#         self.inputs = self.scaler.fit_transform(self.inputs)
        
#     def __len__(self):
#         return len(self.inputs)
    
#     def __getitem__(self, idx):
#         input_ids, attention_mask = self._encode_inputs(self.inputs[idx])
#         labels = self.outputs[idx]
        
#         return {
#             'input_ids': input_ids,
#             'attention_mask': attention_mask,
#             'labels': labels
#         }
    
#     def _encode_inputs(self, row):
#         tokens = [self.tokenizer.cls_token] + [str(elem) for elem in row] + [self.tokenizer.sep_token]
#         input_ids = self.tokenizer.encode(tokens, add_special_tokens=False, max_length=self.max_length)
#         padding_length = self.max_length - len(input_ids)
#         attention_mask = [1] * len(input_ids) + [0] * padding_length
#         input_ids = input_ids + [self.tokenizer.pad_token_id] * padding_length
        
#         return input_ids, attention_mask


dataset = CustomDataset(df, tokenizer, input_col, output_col)
dataloader = DataLoader(dataset, batch_size=batch_size)

# Print the first batch of data
for batch in dataloader:
    # print(batch['input_ids'])
    # print(batch['attention_mask'])
    # print(batch['output'])
    print(batch[0])
    print(batch[1])
    print(batch[2])
    break

tensor([[  101, 16798,  2475,  ...,     0,     0,     0],
        [  101, 16798,  2475,  ...,     0,     0,     0],
        [  101, 16798,  2475,  ...,     0,     0,     0],
        ...,
        [  101, 16798,  2475,  ...,     0,     0,     0],
        [  101, 16798,  2475,  ...,     0,     0,     0],
        [  101, 16798,  2475,  ...,     0,     0,     0]])
tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])
tensor([[2.0380e+04, 1.8990e+03, 0.0000e+00, 2.0000e-02, 1.7000e-01, 3.7000e-01,
         2.9000e-01, 1.2000e-01, 2.0000e-02],
        [2.1204e+04, 1.9730e+03, 0.0000e+00, 4.0000e-02, 2.1000e-01, 3.8000e-01,
         2.6000e-01, 9.0000e-02, 1.0000e-02],
        [2.0001e+04, 1.9190e+03, 0.0000e+00, 2.0000e-02, 1.6000e-01, 3.8000e-01,
         3.0000e-01, 1.2000e-01, 2.0000e-02],
        [2.0160e+04, 1.9370e+03, 0.0000

In [None]:
# Set up the optimizer and loss function
optimizer = torch.optim.AdamW(model.parameters(), lr=Learning_rate)
loss_fn = torch.nn.MSELoss()

In [None]:
# Train the model

# Record the loss and accuracy for each epoch
train_loss = []
# train_acc = []

if is_train:
    model.train()
    # progress_bar = tqdm(range(epochs), desc='Epoch')
    for epoch in range(epochs):
        total_loss = 0
        for batch in tqdm(dataloader):
            # input_ids = batch['input_ids'].to(device)
            # attention_mask = batch['attention_mask'].to(device)
            # labels = batch['output'].to(device)
            input_ids = batch[0].to(device)
            attention_mask = batch[1].to(device)
            labels = batch[2].to(device)

            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = loss_fn(outputs.logits, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            total_loss += loss.item()
        avg_loss = total_loss / len(dataloader)

        # Record the loss and accuracy for each epoch
        train_loss.append(avg_loss)

        # Use exponential representation and retain 4 decimal places
        avg_loss_str = '{:.4e}'.format(avg_loss)
        
        print(f'Epoch {epoch+1}, Loss: {avg_loss_str}')

        if (epoch+1) % 10 == 0:
            model.save_pretrained(model_ckpt_path+'/epoch'+str(epoch+1)+' loss'+avg_loss_str)


        # train_acc.append(avg_acc)

    # Save the fine-tuned model
    model.save_pretrained(model_ckpt_path+'/final epoch:'+str(epoch+1)+' loss:'+avg_loss_str)

    # Plot the loss and accuracy for each epoch
    import matplotlib.pyplot as plt
    plt.plot(train_loss, label='Training loss')
    plt.title('Training loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.show()
    # Save the plot
    plt.savefig('loss.png', dpi=1000)

  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 1, Loss: 7.9778e+12


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 2, Loss: 2.5470e+10


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 3, Loss: 2.0682e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 4, Loss: 2.0942e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 5, Loss: 2.0932e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 6, Loss: 2.0865e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 7, Loss: 2.0719e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 8, Loss: 2.0523e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 9, Loss: 2.0309e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 10, Loss: 2.0063e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 11, Loss: 1.9824e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 12, Loss: 1.9576e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 13, Loss: 1.9254e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 14, Loss: 1.8890e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 15, Loss: 1.8436e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 16, Loss: 1.7911e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 17, Loss: 1.7263e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 18, Loss: 1.6637e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 19, Loss: 1.5637e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 20, Loss: 1.4170e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 21, Loss: 1.2891e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 22, Loss: 1.1513e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 23, Loss: 1.1078e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 24, Loss: 1.1128e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 25, Loss: 1.1148e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 26, Loss: 1.1006e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 27, Loss: 1.0597e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 28, Loss: 1.0079e+09


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 29, Loss: 9.8288e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 30, Loss: 9.4410e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 31, Loss: 9.6136e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 32, Loss: 9.2508e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 33, Loss: 8.9648e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 34, Loss: 8.5359e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 35, Loss: 8.9683e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 36, Loss: 8.3694e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 37, Loss: 8.1542e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 38, Loss: 7.1521e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 39, Loss: 7.3117e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 40, Loss: 7.0430e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 41, Loss: 7.3816e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 42, Loss: 6.3715e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 43, Loss: 5.5504e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 44, Loss: 5.9898e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 45, Loss: 6.4359e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 46, Loss: 5.1666e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 47, Loss: 4.7261e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 48, Loss: 4.6110e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 49, Loss: 4.7624e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 50, Loss: 4.4598e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 51, Loss: 4.7041e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 52, Loss: 3.9389e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 53, Loss: 3.6942e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 54, Loss: 4.3822e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 55, Loss: 3.9500e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 56, Loss: 3.8330e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 57, Loss: 4.0079e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 58, Loss: 3.5830e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 59, Loss: 3.6036e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 60, Loss: 3.4717e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 61, Loss: 3.1165e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 62, Loss: 3.3014e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 63, Loss: 3.0731e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 64, Loss: 3.3838e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 65, Loss: 2.8007e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 66, Loss: 3.0404e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 67, Loss: 3.6446e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 68, Loss: 2.5274e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 69, Loss: 2.5779e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 70, Loss: 2.7635e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 71, Loss: 2.5089e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 72, Loss: 2.6399e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 73, Loss: 2.1043e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 74, Loss: 2.3653e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 75, Loss: 2.2274e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 76, Loss: 2.0078e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 77, Loss: 2.5282e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 78, Loss: 2.2223e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 79, Loss: 2.0201e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 80, Loss: 2.1549e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 81, Loss: 2.1808e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 82, Loss: 2.0295e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 83, Loss: 1.7944e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 84, Loss: 1.5570e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 85, Loss: 1.6817e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 86, Loss: 1.4307e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 87, Loss: 1.7804e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 88, Loss: 1.9845e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 89, Loss: 1.9933e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 90, Loss: 1.7715e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 91, Loss: 1.5560e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 92, Loss: 1.4993e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 93, Loss: 1.4568e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 94, Loss: 1.7672e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 95, Loss: 1.7782e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 96, Loss: 1.3176e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 97, Loss: 1.5405e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 98, Loss: 1.4902e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 99, Loss: 1.5184e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 100, Loss: 1.3079e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 101, Loss: 1.3684e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 102, Loss: 1.0535e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 103, Loss: 9.9464e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 104, Loss: 1.4534e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 105, Loss: 1.4099e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 106, Loss: 1.5440e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 107, Loss: 1.6338e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 108, Loss: 1.2126e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 109, Loss: 1.1011e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 110, Loss: 1.0546e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 111, Loss: 9.7949e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 112, Loss: 1.1348e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 113, Loss: 1.2091e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 114, Loss: 9.8171e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 115, Loss: 1.1710e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 116, Loss: 1.4420e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 117, Loss: 1.2179e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 118, Loss: 1.2727e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 119, Loss: 1.1321e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 120, Loss: 1.0352e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 121, Loss: 8.6966e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 122, Loss: 1.0368e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 123, Loss: 1.3358e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 124, Loss: 1.1523e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 125, Loss: 1.0160e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 126, Loss: 9.5864e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 127, Loss: 9.1838e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 128, Loss: 8.6302e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 129, Loss: 9.0632e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 130, Loss: 9.0930e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 131, Loss: 1.1402e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 132, Loss: 7.6898e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 133, Loss: 8.6548e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 134, Loss: 7.7051e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 135, Loss: 9.8715e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 136, Loss: 8.3926e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 137, Loss: 1.0572e+08


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 138, Loss: 6.1472e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 139, Loss: 7.4959e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 140, Loss: 8.4134e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 141, Loss: 7.4756e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 142, Loss: 8.9918e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 143, Loss: 8.1557e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 144, Loss: 6.4197e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 145, Loss: 7.6283e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 146, Loss: 8.0558e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 147, Loss: 7.1056e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 148, Loss: 6.5206e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 149, Loss: 7.2133e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 150, Loss: 5.9556e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 151, Loss: 7.1150e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 152, Loss: 7.0574e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 153, Loss: 7.1374e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 154, Loss: 6.8954e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 155, Loss: 5.2483e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 156, Loss: 6.1953e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 157, Loss: 6.8957e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 158, Loss: 6.3395e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 159, Loss: 5.0996e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 160, Loss: 5.3839e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 161, Loss: 7.7080e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 162, Loss: 5.3458e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 163, Loss: 5.2132e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 164, Loss: 6.1130e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 165, Loss: 6.4665e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 166, Loss: 5.7673e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 167, Loss: 6.3857e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 168, Loss: 4.6694e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 169, Loss: 5.4896e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 170, Loss: 6.3030e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 171, Loss: 6.0515e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 172, Loss: 5.5546e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 173, Loss: 6.0338e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 174, Loss: 4.2474e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 175, Loss: 4.6682e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 176, Loss: 5.4552e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 177, Loss: 5.0371e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 178, Loss: 5.4183e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 179, Loss: 5.8465e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 180, Loss: 9.1391e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 181, Loss: 6.3402e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 182, Loss: 6.0551e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 183, Loss: 3.9445e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 184, Loss: 5.4240e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 185, Loss: 5.2201e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 186, Loss: 4.6473e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 187, Loss: 5.3915e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 188, Loss: 4.4651e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 189, Loss: 4.8221e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 190, Loss: 4.7155e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 191, Loss: 4.6895e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 192, Loss: 4.4423e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 193, Loss: 4.1053e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 194, Loss: 4.7332e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 195, Loss: 3.6366e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 196, Loss: 5.8076e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 197, Loss: 4.8263e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 198, Loss: 5.7882e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 199, Loss: 4.1147e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 200, Loss: 4.9643e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 201, Loss: 4.3639e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 202, Loss: 3.6251e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 203, Loss: 4.2707e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 204, Loss: 4.0229e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 205, Loss: 5.3730e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 206, Loss: 5.0783e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 207, Loss: 4.4552e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 208, Loss: 3.0688e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 209, Loss: 4.2625e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 210, Loss: 3.9372e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 211, Loss: 4.2095e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 212, Loss: 3.6677e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 213, Loss: 4.9541e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 214, Loss: 4.8143e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 215, Loss: 4.4921e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 216, Loss: 4.6554e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 217, Loss: 4.0828e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 218, Loss: 3.5658e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 219, Loss: 3.5503e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 220, Loss: 4.2803e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 221, Loss: 3.4069e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 222, Loss: 4.0290e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 223, Loss: 3.7009e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 224, Loss: 3.7016e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 225, Loss: 3.1185e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 226, Loss: 3.2639e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 227, Loss: 3.4260e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 228, Loss: 3.4533e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 229, Loss: 3.5053e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 230, Loss: 4.6222e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 231, Loss: 5.6041e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 232, Loss: 4.4032e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 233, Loss: 3.2918e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 234, Loss: 3.4886e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 235, Loss: 3.2312e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 236, Loss: 3.4716e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 237, Loss: 3.9950e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 238, Loss: 3.0841e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 239, Loss: 3.9416e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 240, Loss: 3.4050e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 241, Loss: 2.7758e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 242, Loss: 3.6153e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 243, Loss: 4.2326e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 244, Loss: 5.0139e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 245, Loss: 3.6458e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 246, Loss: 3.9413e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 247, Loss: 2.5296e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 248, Loss: 3.1933e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 249, Loss: 3.9667e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 250, Loss: 3.2453e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 251, Loss: 3.7263e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 252, Loss: 3.5120e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 253, Loss: 2.7977e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 254, Loss: 2.1656e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 255, Loss: 3.3695e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 256, Loss: 3.2496e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 257, Loss: 4.1361e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 258, Loss: 3.8110e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 259, Loss: 2.8341e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 260, Loss: 2.4966e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 261, Loss: 2.2935e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 262, Loss: 2.4913e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 263, Loss: 2.8794e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 264, Loss: 3.4512e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 265, Loss: 3.2877e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 266, Loss: 2.7901e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 267, Loss: 2.6214e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 268, Loss: 2.4062e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 269, Loss: 2.1133e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 270, Loss: 2.8580e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 271, Loss: 2.7063e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 272, Loss: 3.9390e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 273, Loss: 4.3771e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 274, Loss: 2.6893e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 275, Loss: 2.6636e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 276, Loss: 2.6939e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 277, Loss: 2.4414e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 278, Loss: 2.6852e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 279, Loss: 2.6209e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 280, Loss: 2.7822e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 281, Loss: 2.6776e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 282, Loss: 2.3428e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 283, Loss: 2.4153e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 284, Loss: 2.0406e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 285, Loss: 2.5850e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 286, Loss: 2.6519e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 287, Loss: 2.8991e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 288, Loss: 2.1576e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 289, Loss: 3.4614e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 290, Loss: 5.3348e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 291, Loss: 3.8132e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 292, Loss: 3.2023e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 293, Loss: 2.0163e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 294, Loss: 2.7714e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 295, Loss: 2.1927e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 296, Loss: 2.1092e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 297, Loss: 2.9870e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 298, Loss: 2.3563e+07


  0%|          | 0/6 [00:00<?, ?it/s]

Epoch 299, Loss: 2.1798e+07


  0%|          | 0/6 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# # Evaluate the model from the saved checkpoint
# model = model.from_pretrained('fine_tuned_model')
# model.eval()

# # Predict the output for a single example
# example = df.iloc[0]
# inputs = [example[col] for col in input_col]
# inputs = ' '.join(str(x) for x in inputs)
# inputs = tokenizer.encode_plus(inputs, padding='max_length', max_length=20, truncation=True, return_tensors='pt')
# input_ids = inputs['input_ids'].to(device)
# attention_mask = inputs['attention_mask'].to(device)
# outputs = model(input_ids, attention_mask=attention_mask)
# print(outputs.logits)