In [1]:
from typing import Callable, Dict, Optional, Tuple, Union, Any, collections
import torch
import copy 
import random
import numpy as np

from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from torch.utils.data import Dataset

from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_distributed_available,
    is_torch_tpu_available,
)

from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions,
    SequenceClassifierOutputWithPast,
)

from transformers import GPT2LMHeadModel, GPT2Model

import os
import time
import datetime
import torch
import math
import copy 
import random
from packaging import version
import pandas as pd
import numpy as np
import pickle

from torch import nn
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader, random_split, RandomSampler, SequentialSampler

from transformers import GPT2LMHeadModel,  GPT2Tokenizer, GPT2Config, GPT2LMHeadModel, GPT2Model
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import Trainer, TrainerState, TrainingArguments

from transformers.utils import logging
logger = logging.get_logger(__name__)

In [2]:
batch_size = 100
MAX_LEN = 128
CAN_NUM = 20
num_of_rerank = 30

# some parameters I cooked up that work reasonably well
epochs = 1
learning_rate = 5e-4
warmup_steps = 1e2
epsilon = 1e-8

# Set the seed value all over the place to make this reproducible.
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

SAVE_PATH = "/mnt/nfs/work1/llcao/zonghaiyao/LM/"

global debug
debug = {}
debug['check_in_num'] = 0
debug['total_sample_num'] = 0

In [3]:
class wiki2021_GPT2Dataset(Dataset):
    def __init__(self, input_ids):
        self.input_ids = input_ids
    
    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return self.input_ids[idx]

In [4]:
class rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates(GPT2LMHeadModel):
    def __init__(self, config, MAX_LEN, CAN_NUM, num_of_rerank):
        super().__init__(config)
        self.MAX_LEN = MAX_LEN
        self.CAN_NUM = CAN_NUM
        self.num_of_rerank = num_of_rerank
        self.VOCAB_SIZE = config.vocab_size
        
        self.transformer = GPT2Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        
#         self.dropout = nn.Dropout(config.hidden_dropout_prob)
#         self.rerank_transformer = GPT2Model(config)
        self.rerank_linear_head = nn.Linear(config.n_embd, 1, bias=False)

        self.init_weights()
    def forward(
        self,
        input_ids=None,
        labels=None,
        is_training=False,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            ``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
            ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
        """
        # make some model parameter not change during rerank (like dropout) ??????????????
        # model.eval()
        global debug
        if is_training:
            self.transformer.eval()

        rerank_places = random.sample(np.arange(1, self.MAX_LEN).tolist(), k=self.num_of_rerank) #no duplicate
        rerank_places = np.concatenate(([0], np.sort(rerank_places), [self.MAX_LEN])) #add first and last tokens to make segments
        
        past_key_values = None
        all_rerank_hidden_states = []
        all_rerank_labels = []
        no_rerank_logits = []
        check_out_num = 0
        
        if not is_training:
            all_candidate_token_ids = []
            all_input_ids = []
            all_prediction_ids = []
        
        for i in range(self.num_of_rerank+1):
            #normal stage
            segment_input_ids = input_ids[:, rerank_places[i]:rerank_places[i+1]]

            segment_outputs = self.transformer(
                segment_input_ids,
                past_key_values = past_key_values
            )

            segment_hidden = segment_outputs[0]
            past_key_values = segment_outputs[1]

            #rerank stage (just for rerank places)
            if i == self.num_of_rerank:
                break

            #rerank stage
            #get logits in rerank place
            logits_before_rerank = self.lm_head(segment_hidden[:, -1, :])
            #get candidate token ids according to the logits
            candidate_token_logits, candidate_token_ids = torch.topk(logits_before_rerank, self.CAN_NUM)
            rerank_labels = labels[..., rerank_places[i+1]]
            
            debug['logits_before_rerank'] = logits_before_rerank
            debug['candidate_token_ids'] = candidate_token_ids
            debug['rerank_labels'] = rerank_labels
            
            assert 1==0
            
            #check whether or not label in candidates
            check_labels = rerank_labels.tolist()
            check_candidates = candidate_token_ids.tolist()
            
            assert len(check_labels)==len(check_candidates)
            
            #check whether or not label in candidates, if not we do not do rerank
            rerank_labels = []
            check_in_index = []

            for j in range(len(check_labels)): 
                if check_labels[j] in check_candidates[j]:
                    rerank_labels.append(check_candidates[j].index(check_labels[j]))
                    check_in_index.append(j)
                else:
                    check_out_num+=1
            rerank_labels = torch.tensor(rerank_labels, device=input_ids.device)

            if rerank_labels.shape[0] == 0:
                continue
            else:
                all_rerank_labels.append(rerank_labels)

            #make context for rerank stage, 50256 is the token_id for </endoftext/>
            sep_token = torch.ones(size = [candidate_token_ids.shape[0], 1], dtype = torch.long, device=input_ids.device) * 50256
            candidate_context_ids = torch.cat([sep_token, candidate_token_ids, sep_token, candidate_token_ids], -1)

            
            #get output from gpt2
            rerank_outputs = self.transformer(candidate_context_ids,
                            past_key_values=past_key_values,
                          )

            #get rerank logits for candidates
            rerank_hidden_states = rerank_outputs[0][:, 2+self.CAN_NUM:2+self.CAN_NUM*2]

            all_rerank_hidden_states.append(rerank_hidden_states[check_in_index])
            no_rerank_logits.append(candidate_token_logits[check_in_index])
            
            if not is_training:
                all_candidate_token_ids.append(candidate_token_ids[check_in_index])
                all_prediction_ids.append(input_ids[:, rerank_places[i+1]][check_in_index])
                all_input_ids.append(input_ids[:, :rerank_places[i+1]][check_in_index])

        
#         print("\n batch info:")
#         print("there are ", check_out_num/(self.num_of_rerank*batch_size), "labels not in candidates")
    
        if is_training:
            #model.train()
            self.transformer.train()
        
            #-------------------------------------------------------------------------
            # cal loss, loss = normal loss + rerank loss
            loss_fct = CrossEntropyLoss(reduction='none')

            # cal rerank loss
            rerank_loss = None

            all_rerank_hidden_states = torch.cat(all_rerank_hidden_states, 0)
            all_rerank_logits = self.rerank_linear_head(all_rerank_hidden_states)
            all_rerank_logits = torch.reshape(all_rerank_logits, [-1, self.CAN_NUM])
            all_rerank_labels = torch.cat(all_rerank_labels, 0)

            rerank_loss = loss_fct(all_rerank_logits, all_rerank_labels)

            # cal normal loss in rerank place (for comparision with rerank results), only evaluate
            normal_loss_in_rerank_place = None


            no_rerank_logits = torch.cat(no_rerank_logits, 0)
            no_rerank_logits = torch.reshape(no_rerank_logits, [-1, self.CAN_NUM])
            #no_rerank_labels = torch.cat(all_rerank_labels, 0) #no_rerank_labels == all_rerank_labels

            normal_loss_in_rerank_place = loss_fct(no_rerank_logits, all_rerank_labels)


            return {"normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                    "rerank_loss": rerank_loss,}
        # for evaluation, we will evaluate the model's performance on different difficult level
        else:
            #-------------------------------------------------------------------------
            # cal loss, loss = normal loss + rerank loss
            loss_fct = CrossEntropyLoss(reduction='none')

            # cal rerank loss
            rerank_loss = None

            all_rerank_hidden_states = torch.cat(all_rerank_hidden_states, 0)
            all_rerank_logits = self.rerank_linear_head(all_rerank_hidden_states)
            all_rerank_logits = torch.reshape(all_rerank_logits, [-1, self.CAN_NUM])
            all_rerank_labels = torch.cat(all_rerank_labels, 0)

            rerank_loss = loss_fct(all_rerank_logits, all_rerank_labels)

            # cal normal loss in rerank place (for comparision with rerank results), only evaluate
            normal_loss_in_rerank_place = None
            no_rerank_logits = torch.cat(no_rerank_logits, 0)
            no_rerank_logits = torch.reshape(no_rerank_logits, [-1, self.CAN_NUM])
            #no_rerank_labels = torch.cat(all_rerank_labels, 0) #no_rerank_labels == all_rerank_labels

            normal_loss_in_rerank_place = loss_fct(no_rerank_logits, all_rerank_labels)
            
            for i in range(len(all_input_ids)):
                target = torch.ones(size = [all_input_ids[i].shape[0], self.MAX_LEN], dtype = torch.long, device=input_ids.device) * 50256
                target[:, :all_input_ids[i].shape[1]] = all_input_ids[i]
                all_input_ids[i] = target
                
            all_input_ids = torch.cat(all_input_ids, 0)
            all_prediction_ids = torch.cat(all_prediction_ids, 0)
            all_candidate_token_ids = torch.cat(all_candidate_token_ids, 0)

            return {"all_rerank_logits": all_rerank_logits,
                    "no_rerank_logits": no_rerank_logits,
                    "difficult_level": all_rerank_labels,
                    "rerank_loss": rerank_loss,
                    "normal_loss_in_rerank_place": normal_loss_in_rerank_place,
                    "all_candidate_token_ids": all_candidate_token_ids,
                    "all_prediction_ids": all_prediction_ids,
                    "all_input_ids": all_input_ids}

In [5]:
# I'm not really doing anything with the config buheret
configuration = GPT2Config.from_pretrained('gpt2', output_hidden_states=False)

# Load the GPT tokenizer.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2', pad_token='<|endoftext|>') #gpt2-medium

# instantiate the model
# model = rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates.from_pretrained("results/baseline_wiki2021/exclude_cases_label_not_in_candidates_canNUM20/10000", 
#                                                                                     config=configuration,
#                                                                                     MAX_LEN = MAX_LEN,
#                                                                                     CAN_NUM = CAN_NUM, 
#                                                                                     num_of_rerank = num_of_rerank)
# model = rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates.from_pretrained("gpt2", 
#                                                                                     config=configuration,
#                                                                                     MAX_LEN = MAX_LEN,
#                                                                                     CAN_NUM = CAN_NUM, 
#                                                                                     num_of_rerank = num_of_rerank)
model = rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates.from_pretrained("results/baseline_wiki2021/not_stage2_canNUM20/last_model", 
                                                                                    config=configuration,
                                                                                    MAX_LEN = MAX_LEN,
                                                                                    CAN_NUM = CAN_NUM, 
                                                                                    num_of_rerank = num_of_rerank)


# this step is necessary because I've added some tokens (bos_token, etc) to the embeddings
# otherwise the tokenizer and model tensors won't match up
model.resize_token_embeddings(len(tokenizer))

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = torch.nn.DataParallel(model) # Encapsulate the model


# Tell pytorch to run this model on the GPU.
device = torch.device("cuda")
model.cuda()

Some weights of rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates were not initialized from the model checkpoint at results/baseline_wiki2021/not_stage2_canNUM20/last_model and are newly initialized: ['rerank_linear_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


DataParallel(
  (module): rerankGPT2LMHeadModel_exclude_cases_label_not_in_candidates(
    (transformer): GPT2Model(
      (wte): Embedding(50257, 768)
      (wpe): Embedding(1024, 768)
      (drop): Dropout(p=0.1, inplace=False)
      (h): ModuleList(
        (0): Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropout): Dropout(p=0.1, inplace=False)
            (resid_dropout): Dropout(p=0.1, inplace=False)
          )
          (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (mlp): MLP(
            (c_fc): Conv1D()
            (c_proj): Conv1D()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (1): Block(
          (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (attn): Attention(
            (c_attn): Conv1D()
            (c_proj): Conv1D()
            (attn_dropo

In [6]:
# with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_train_dataset.pkl', 'rb') as f:
#     train_input_ids = pickle.load(f)
with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_validation_dataset.pkl', 'rb') as f:
    validation_input_ids = pickle.load(f)
with open(SAVE_PATH + 'data/wiki2021/wiki2021_0to4_inside_validation_dataset.pkl', 'rb') as f:
    inside_validation_input_ids = pickle.load(f)
    
# train_dataset = wiki2021_GPT2Dataset(train_input_ids)
validation_dataset = wiki2021_GPT2Dataset(validation_input_ids)
inside_validation_dataset = wiki2021_GPT2Dataset(inside_validation_input_ids)

# Create the DataLoaders for our training and validation datasets.
# We'll take training samples in random order. 
# train_dataloader = DataLoader(
#             train_dataset,  # The training samples.
#             sampler = RandomSampler(train_dataset), # Select batches randomly
#             batch_size = batch_size # Trains with this batch size.
#         )

# For validation the order doesn't matter, so we'll just read them sequentially.
validation_dataloader = DataLoader(
            validation_dataset, # The validation samples.
            sampler = SequentialSampler(validation_dataset), # Pull out batches sequentially.
            batch_size = batch_size # Evaluate with this batch size.
        )

# For inside_validation the order doesn't matter, so we'll just read them sequentially.
inside_validation_dataloader = DataLoader(
            inside_validation_dataset, # The validation samples.
            sampler = SequentialSampler(inside_validation_dataset), # Pull out batches sequentially.
            batch_size = batch_size # Evaluate with this batch size.
        )

In [7]:
def format_time(elapsed):
    return str(datetime.timedelta(seconds=int(round((elapsed)))))

In [8]:
model = model.to(device)

t1 = time.time()

model.eval()

total_eval_loss = 0
total_eval_normal_loss = 0
total_eval_rerank_loss = 0

all_evaluate_rerank_logits = []
all_evaluate_normal_logits = []
all_evaluate_difficult_level = []
all_evaluate_rerank_loss = []
all_evaluate_normal_loss_in_rerank_place = []
all_evaluate_ground_true = []
all_evaluate_inputs_text = []
all_evaluate_candidate_token = []

# Evaluate data for one epoch
for batch in inside_validation_dataloader:        
    with torch.no_grad():        
        outputs = model(input_ids=batch,         #batch_input_ids
                        labels=batch,            #batch_labels
                        is_training=False,
                        )

        normal_loss = outputs["normal_loss_in_rerank_place"].mean()
        rerank_loss = outputs["rerank_loss"].mean()

        loss = normal_loss + rerank_loss

    batch_loss = loss.item()
    total_eval_loss += batch_loss        

    batch_normal_loss = normal_loss.item()
    total_eval_normal_loss += batch_normal_loss

    batch_rerank_loss = rerank_loss.item()
    total_eval_rerank_loss += batch_rerank_loss

    #fine-grained evaluation
    all_evaluate_rerank_logits.extend(outputs["all_rerank_logits"].tolist())
    all_evaluate_normal_logits.extend(outputs["no_rerank_logits"].tolist())
    all_evaluate_difficult_level.extend(outputs["difficult_level"].tolist())
    all_evaluate_rerank_loss.extend(outputs["rerank_loss"].tolist())
    all_evaluate_normal_loss_in_rerank_place.extend(outputs["normal_loss_in_rerank_place"].tolist())    
    all_evaluate_ground_true.extend(tokenizer.batch_decode(outputs['all_prediction_ids'], skip_special_tokens=True))
    all_evaluate_inputs_text.extend(tokenizer.batch_decode(outputs['all_input_ids'], skip_special_tokens=True))
    all_evaluate_candidate_token.extend([tokenizer.batch_decode(ids) for ids in outputs['all_candidate_token_ids']])


avg_val_loss = total_eval_loss / len(inside_validation_dataloader)
avg_val_normal_loss = total_eval_normal_loss / len(inside_validation_dataloader)       
avg_val_rerank_loss = total_eval_rerank_loss / len(inside_validation_dataloader)    

validation_time = format_time(time.time() - t1)    

print("  inside Validation Loss:", avg_val_loss)
print("  Average inside Validation normal_loss:", avg_val_normal_loss)
print("  Average inside Validation rerank_loss:", avg_val_rerank_loss)
print("  inside Validation took:", validation_time)

#fine_grained_evaluation
fg_eval = pd.DataFrame({"rerank_logits": all_evaluate_rerank_logits,
                        "normal_logits": all_evaluate_normal_logits,
                        "ground_true_difficulty_level": all_evaluate_difficult_level,
                        "rerank_loss": all_evaluate_rerank_loss,
                        "normal_loss": all_evaluate_normal_loss_in_rerank_place,
                        "ground_true": all_evaluate_ground_true,
                        "inputs_text": all_evaluate_inputs_text,
                        "candidate_tokens": all_evaluate_candidate_token,
                        }) 

def softmax(x):
    """Compute softmax values for each sets of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

def cal_entropy_difficulty_level(x):
    x = softmax(x)
    entropy_difficulty_level = 0
    for i in range(CAN_NUM):
        entropy_difficulty_level -= x[i] * np.log10(x[i])

    return entropy_difficulty_level

fg_eval['entropy_difficulty_level'] = fg_eval['normal_logits'].apply(cal_entropy_difficulty_level)

AssertionError: 

In [None]:
# fg_eval

In [9]:
candidate_token_logits, candidate_token_ids = torch.topk(debug['logits_before_rerank'], 20)

In [10]:
candidate_token_logits

tensor([[5.7717, 5.7697, 5.7257,  ..., 2.1319, 2.0612, 2.0487],
        [7.9025, 7.7173, 7.3267,  ..., 3.0679, 3.0306, 3.0210],
        [6.7422, 5.1647, 5.0419,  ..., 3.2128, 3.1623, 3.1347],
        ...,
        [6.7209, 5.7657, 4.9293,  ..., 3.4125, 3.4082, 3.3533],
        [8.7585, 7.9174, 7.8672,  ..., 4.3438, 4.3235, 4.2999],
        [6.2602, 5.3347, 4.7086,  ..., 3.5911, 3.5759, 3.5080]],
       device='cuda:0')

In [11]:
candidate_token_ids

tensor([[   13,    11,   287,  ...,  5140,   326,    26],
        [  287,   788,   284,  ...,   351,   416,   276],
        [  373,   318,    11,  ..., 25409,   546, 16199],
        ...,
        [   13,    11,   290,  ...,  1120,   475,   940],
        [ 4708,  4346, 44185,  ...,  8674, 50196, 15581],
        [ 2888,  1255,   636,  ..., 21018, 24505,  1862]], device='cuda:0')

In [12]:
candidate_token_logits, candidate_token_ids = torch.sort(debug['logits_before_rerank'], descending=True)

In [13]:
candidate_token_logits

tensor([[ 5.7717,  5.7697,  5.7257,  ..., -4.7486, -4.8286, -4.9045],
        [ 7.9025,  7.7173,  7.3267,  ..., -7.5401, -7.7853, -8.0396],
        [ 6.7422,  5.1647,  5.0419,  ..., -4.5140, -4.5846, -4.6381],
        ...,
        [ 6.7209,  5.7657,  4.9293,  ..., -5.6367, -5.8668, -5.8695],
        [ 8.7585,  7.9174,  7.8672,  ..., -5.9153, -5.9355, -6.1981],
        [ 6.2602,  5.3347,  4.7086,  ..., -3.6288, -3.6861, -3.8877]],
       device='cuda:0')

In [14]:
candidate_token_ids

tensor([[   13,    11,   287,  ..., 29157, 43816, 27596],
        [  287,   788,   284,  ...,  2134,   675,  7100],
        [  373,   318,    11,  ..., 31814, 26277, 35044],
        ...,
        [   13,    11,   290,  ...,  4985, 30152, 24957],
        [ 4708,  4346, 44185,  ..., 17511,  3976,  5828],
        [ 2888,  1255,   636,  ..., 47252,  9868,  7732]], device='cuda:0')

In [15]:
a = torch.nn.functional.softmax(candidate_token_logits, dim=1)
a

tensor([[2.9289e-02, 2.9233e-02, 2.7973e-02,  ..., 7.9037e-07, 7.2960e-07,
         6.7625e-07],
        [1.4620e-01, 1.2148e-01, 8.2196e-02,  ..., 2.8729e-08, 2.2480e-08,
         1.7433e-08],
        [2.9327e-02, 6.0556e-03, 5.3557e-03,  ..., 3.7909e-07, 3.5326e-07,
         3.3486e-07],
        ...,
        [3.0249e-02, 1.1639e-02, 5.0426e-03,  ..., 1.2998e-07, 1.0326e-07,
         1.0298e-07],
        [9.9348e-02, 4.2845e-02, 4.0746e-02,  ..., 4.2112e-08, 4.1272e-08,
         3.1738e-08],
        [8.6561e-03, 3.4308e-03, 1.8342e-03,  ..., 4.3914e-07, 4.1466e-07,
         3.3897e-07]], device='cuda:0')

In [16]:
b = torch.cumsum(a, dim=1)
b = b[:, :20]
b

tensor([[0.0293, 0.0585, 0.0865,  ..., 0.1308, 0.1315, 0.1322],
        [0.1462, 0.2677, 0.3499,  ..., 0.4654, 0.4665, 0.4677],
        [0.0293, 0.0354, 0.0407,  ..., 0.0768, 0.0777, 0.0784],
        ...,
        [0.0302, 0.0419, 0.0469,  ..., 0.0732, 0.0743, 0.0753],
        [0.0993, 0.1422, 0.1829,  ..., 0.2430, 0.2442, 0.2453],
        [0.0087, 0.0121, 0.0139,  ..., 0.0290, 0.0296, 0.0301]],
       device='cuda:0')

In [17]:
b[:, -1]

tensor([0.1322, 0.4677, 0.0784, 0.0295, 0.0293, 0.2295, 0.0854, 0.2857, 0.0627,
        0.0331, 0.1492, 0.0641, 0.1287, 0.0796, 0.4730, 0.6499, 0.0410, 0.0866,
        0.0783, 0.2417, 0.3833, 0.0112, 0.2029, 0.1008, 0.2673, 0.0663, 0.3556,
        0.0649, 0.1541, 0.0321, 0.2197, 0.1043, 0.0825, 0.1649, 0.2421, 0.4685,
        0.1464, 0.0382, 0.0175, 0.0996, 0.0560, 0.8657, 0.0472, 0.0218, 0.1519,
        0.0501, 0.1703, 0.0994, 0.6367, 0.0200, 0.2369, 0.5181, 0.0210, 0.0337,
        0.4883, 0.1757, 0.2170, 0.2282, 0.0355, 0.0102, 0.6769, 0.0340, 0.0835,
        0.2154, 0.0489, 0.2493, 0.0576, 0.6750, 0.0360, 0.0959, 0.0621, 0.0393,
        0.8953, 0.0298, 0.0703, 0.0521, 0.0565, 0.6365, 0.2314, 0.0309, 0.0192,
        0.1491, 0.0337, 0.1330, 0.0429, 0.0612, 0.9981, 0.0691, 0.3158, 0.0754,
        0.1059, 0.0522, 0.3490, 0.0489, 0.2075, 0.0346, 0.6829, 0.0753, 0.2453,
        0.0301], device='cuda:0')

In [18]:
debug['check_in_num']

0

In [None]:
debug['total_sample_num']