In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import numpy as np
from torch.utils.data import Dataset, DataLoader
import glob
import wandb
import os
import torch.optim as optimizers

In [3]:
import dfs_code
from torch_geometric.data import InMemoryDataset, Data
import pickle
import torch
import torch.nn as nn
import tqdm
import copy
import pandas as pd
#torch.multiprocessing.set_sharing_strategy('file_system') # this is important on local machine
#def set_worker_sharing_strategy(worker_id: int) -> None:
#    torch.multiprocessing.set_sharing_strategy('file_system')

In [4]:
import sys
sys.path = ['../../../src'] + sys.path
from dfs_transformer import EarlyStopping, DFSCodeSeq2SeqFC, smiles2graph, PubChem

In [5]:
# Start the Weights & Biases run. All hyper-parameters below are stored on
# `wandb.config`; the later cells read them back through `config`.
wandb.init(project='pubchem-bert', entity='chrisxx', name='bert10M')

config = wandb.config
# Which DFS codes the training loop feeds to the model: "min2min" uses the
# min DFS codes of a batch, "rnd2rnd" the randomly generated ones.
config.mode = "min2min" # alternative: "rnd2rnd"
# Fraction of the rows of every code that BERTize masks out and asks the model to predict.
config.fraction_missing = 0.15
# Class counts: used as model arguments and to flatten the atom / bond predictions.
config.n_atoms = 118
config.n_bonds = 5
# Transformer size arguments handed to DFSCodeSeq2SeqFC.
config.emb_dim = 120
config.nhead = 12
config.nlayers = 6
# Limits handed to both PubChem(...) and the model; max_nodes is also the
# number of classes the DFS-index predictions are flattened to.
config.max_nodes = 250
config.max_edges = 500
# NOTE(review): not referenced by any of the visible cells (the model is
# built without it) -- confirm whether it is still needed.
config.dim_feedforward = 2048
# Every split loads n_files // n_splits files and is iterated n_iter_per_split times.
config.n_files = 64
config.n_splits = 16
config.n_iter_per_split = 1
# Initial Adam learning rate.
config.lr = 0.00003
config.n_epochs = 10000
# Every lr_adjustment_period batches the running loss is passed to
# EarlyStopping and to the ReduceLROnPlateau scheduler.
config.lr_adjustment_period = 500
# ReduceLROnPlateau arguments.
config.patience = 5
config.factor = 0.8
# Training stops once the learning rate has dropped below this value.
config.minimal_lr = 6e-8
config.batch_size = 50
# optim.step() is called once every accumulate_grads batches.
config.accumulate_grads = 2
# EarlyStopping arguments (patience / delta).
config.valid_patience = 100
config.valid_minimal_improvement=0.00
# Directory the checkpoints of this run are written to.
config.model_dir = "../../../models/new/pubchem10M/features_selfattention/medium/"
#config.data_dir = "/mnt/project/pubchem_noH/"
config.data_dir = "/home/wendlerc/noH/timeout60_4/"
#config.data_dir = "/home/wendlerc/noH/10K/"
# Directory whose checkpoint.pt initialises the model (skipped when None).
config.pretrained_dir = "../../../models/pubchem10M/features_selfattention/medium/"#"../../models/chembl/better_transformer/medium/"
# DataLoader worker settings.
config.num_workers = 0
config.prefetch_factor = 2
config.persistent_workers = False
# When True, the checkpoint in model_dir is loaded before training.
config.load_last = False
# Index of the CUDA device to train on.
config.gpu_id = 1
# Passed to the model as `missing_value`.
config.missing_value = -1

[34m[1mwandb[0m: Currently logged in as: [33mchrisxx[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.12.1 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade
2021-09-12 01:23:25.379500: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2021-09-12 01:23:25.379900: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.


In [6]:
# Create the checkpoint directory up front; exist_ok keeps re-runs from failing.
os.makedirs(config.model_dir, exist_ok=True)

In [7]:
# Short alias for the data directory that is handed to PubChem(...) below.
path = config.data_dir

In [8]:
def collate_fn_features(dlist):
    """Collate a list of graph samples into four parallel per-sample lists.

    For every sample a random DFS code is generated with
    ``dfs_code.rnd_dfs_code_from_torch_geometric``, which receives the sample,
    its ``z`` values as a Python list and the row-wise arg-max of its
    ``edge_attr``. The stored ``node_features``, ``edge_features`` and
    ``min_dfs_code`` attributes are passed through unchanged.

    Returns:
        ``(rnd_code_batch, node_batch, edge_batch, min_code_batch)`` -- four
        lists with one entry per sample; the random codes are wrapped in
        ``torch.tensor``.
    """
    rnd_code_batch, node_batch, edge_batch, min_code_batch = [], [], [], []
    for sample in dlist:
        node_labels = sample.z.numpy().tolist()
        edge_labels = np.argmax(sample.edge_attr.numpy(), axis=1)
        # The second return value is not used by the training code.
        random_code, _unused_index = dfs_code.rnd_dfs_code_from_torch_geometric(
            sample, node_labels, edge_labels)
        node_batch.append(sample.node_features)
        edge_batch.append(sample.edge_features)
        rnd_code_batch.append(torch.tensor(random_code))
        min_code_batch.append(sample.min_dfs_code)
    return rnd_code_batch, node_batch, edge_batch, min_code_batch

In [9]:
# Set ngpu to 0 to force the CPU even when CUDA is available.
ngpu = 1
if torch.cuda.is_available() and ngpu > 0:
    # Use the GPU selected in the configuration cell.
    device = torch.device('cuda:%d' % config.gpu_id)
else:
    device = torch.device('cpu')

In [10]:
def to_cuda(T):
    """Return a new list with every element of ``T`` moved to ``device``.

    Despite the name the target is whatever ``device`` was selected above,
    which is the CPU when CUDA is unavailable.
    """
    moved = []
    for tensor in T:
        moved.append(tensor.to(device))
    return moved

In [11]:
# Small probe dataset (n_used = 1): it is only used to draw one batch from
# which the node / edge feature dimensions are read off below.
dataset = PubChem(path, n_used = 1)
# prefetch_factor and persistent_workers only make sense with worker
# processes. Recent PyTorch versions raise a ValueError when prefetch_factor
# is given together with num_workers=0, so pass them only when workers exist.
worker_kwargs = {}
if config.num_workers > 0:
    worker_kwargs = dict(prefetch_factor=config.prefetch_factor,
                         persistent_workers=config.persistent_workers)
loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, pin_memory=False,
                    collate_fn=collate_fn_features, num_workers=config.num_workers,
                    **worker_kwargs)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:03<00:00,  3.92s/it]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 156171/156171 [00:13<00:00, 11565.06it/s]


In [12]:
# Draw a single batch; it is the 4-tuple of lists returned by collate_fn_features.
data = next(iter(loader))

In [13]:
# data[1] / data[2] are the node- and edge-feature lists of the probe batch;
# the second dimension of their first entry is the respective feature count.
n_node_features, n_edge_features = (data[slot][0].shape[1] for slot in (1, 2))

In [14]:
# Sanity check of the feature dimensions that size the model's input layers.
print(n_node_features, n_edge_features)

133 14


In [15]:
# The last element of the collated batch holds the min DFS codes (one tensor per sample).
codes = data[-1]

In [16]:
# Show only the first two codes instead of dumping the whole batch into the output.
codes[:2]

[tensor([[ 0,  1,  6,  0,  6,  7, 17,  6],
         [ 1,  2,  6,  0,  6,  6, 14,  5],
         [ 2,  0,  6,  0,  6,  5, 12,  7],
         [ 2,  3,  6,  0,  6,  5, 10,  4],
         [ 3,  4,  6,  0,  7,  4,  8,  3],
         [ 4,  5,  7,  0,  6,  3,  7, 13],
         [ 5,  6,  6,  0,  6, 13, 29, 10],
         [ 6,  7,  6,  0,  6, 10, 23, 11],
         [ 6,  8,  6,  0,  6, 10, 24, 12],
         [ 6,  9,  6,  0,  7, 10, 22,  9],
         [ 9, 10,  7,  0,  6,  9, 20,  8],
         [10,  2,  6,  0,  6,  8, 18,  5],
         [ 4, 11,  7,  0,  6,  3,  5,  2],
         [11, 12,  6,  0,  6,  2,  3,  1],
         [12, 13,  6,  1,  6,  1,  1,  0]]),
 tensor([[ 0,  1,  6,  0,  6,  8, 18,  7],
         [ 1,  2,  6,  0,  6,  7, 15,  6],
         [ 2,  3,  6,  0,  6,  6, 13,  5],
         [ 3,  4,  6,  0,  6,  5, 11,  4],
         [ 4,  5,  6,  0,  6,  4,  9,  3],
         [ 5,  0,  6,  0,  6,  3,  8,  8],
         [ 5,  6,  6,  0,  7,  3,  6,  2],
         [ 6,  7,  7,  0,  6,  2,  4,  1],
         

In [17]:
def BERTize(codes, fraction_missing=None, missing_value=None, ignore_index=-1):
    """Split every code into a masked input and a complementary target (BERT-style).

    For a code with ``n`` rows, ``int(fraction_missing * n)`` rows are chosen
    uniformly at random. In the returned input those rows are overwritten with
    ``missing_value``; in the returned target all *other* rows are overwritten
    with ``ignore_index``. The originals are left untouched (clones are
    modified).

    Args:
        codes: iterable of tensors, one per graph; masking acts on the first
            dimension (one DFS-code entry per row).
        fraction_missing: fraction of rows to mask. Defaults to
            ``config.fraction_missing``.
        missing_value: fill value for masked input rows. Defaults to
            ``config.missing_value`` so that it always agrees with the value
            the model was constructed with.
        ignore_index: fill value for the target rows that are not to be
            predicted; it has to match the ``ignore_index`` of the loss and
            the padding value used for the targets.

    Returns:
        ``(inputs, targets)`` -- two lists of tensors aligned with ``codes``.

    Note:
        Because of the ``int(...)`` truncation, codes with fewer than
        ``1 / fraction_missing`` rows end up with no masked row at all.
    """
    if fraction_missing is None:
        fraction_missing = config.fraction_missing
    if missing_value is None:
        missing_value = config.missing_value

    inputs = []
    targets = []
    for code in codes:
        n = len(code)
        n_masked = int(fraction_missing * n)
        # One permutation yields two disjoint index sets: rows to predict and rows to keep.
        perm = np.random.permutation(n)
        target_idx, input_idx = perm[:n_masked], perm[n_masked:]
        inp = code.clone()
        target = code.clone()
        target[input_idx] = ignore_index
        inp[target_idx] = missing_value
        inputs.append(inp)
        targets.append(target)
    return inputs, targets

In [18]:
# Try the masking on the probe batch; the next cells display the result.
inputs, targets = BERTize(codes)

In [19]:
# Show only the first two masked inputs instead of dumping the whole batch into the output.
inputs[:2]

[tensor([[ 0,  1,  6,  0,  6,  7, 17,  6],
         [ 1,  2,  6,  0,  6,  6, 14,  5],
         [ 2,  0,  6,  0,  6,  5, 12,  7],
         [ 2,  3,  6,  0,  6,  5, 10,  4],
         [ 3,  4,  6,  0,  7,  4,  8,  3],
         [ 4,  5,  7,  0,  6,  3,  7, 13],
         [ 5,  6,  6,  0,  6, 13, 29, 10],
         [ 6,  7,  6,  0,  6, 10, 23, 11],
         [ 6,  8,  6,  0,  6, 10, 24, 12],
         [ 6,  9,  6,  0,  7, 10, 22,  9],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [10,  2,  6,  0,  6,  8, 18,  5],
         [ 4, 11,  7,  0,  6,  3,  5,  2],
         [11, 12,  6,  0,  6,  2,  3,  1],
         [-1, -1, -1, -1, -1, -1, -1, -1]]),
 tensor([[-1, -1, -1, -1, -1, -1, -1, -1],
         [ 1,  2,  6,  0,  6,  7, 15,  6],
         [ 2,  3,  6,  0,  6,  6, 13,  5],
         [ 3,  4,  6,  0,  6,  5, 11,  4],
         [ 4,  5,  6,  0,  6,  4,  9,  3],
         [ 5,  0,  6,  0,  6,  3,  8,  8],
         [ 5,  6,  6,  0,  7,  3,  6,  2],
         [ 6,  7,  7,  0,  6,  2,  4,  1],
         

In [20]:
# Show only the first two targets instead of dumping the whole batch into the output.
targets[:2]

[tensor([[-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [ 9, 10,  7,  0,  6,  9, 20,  8],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [12, 13,  6,  1,  6,  1,  1,  0]]),
 tensor([[ 0,  1,  6,  0,  6,  8, 18,  7],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         [-1, -1, -1, -1, -1, -1, -1, -1],
         

In [21]:
# Build the model. Atoms and bonds are embedded by linear layers whose input
# sizes are the feature dimensions measured on the probe batch.
# NOTE(review): config.dim_feedforward is not passed here -- confirm that the
# constructor default is the intended value.
model = DFSCodeSeq2SeqFC(
    n_atoms=config.n_atoms,
    n_bonds=config.n_bonds,
    emb_dim=config.emb_dim,
    nhead=config.nhead,
    nlayers=config.nlayers,
    max_nodes=config.max_nodes,
    max_edges=config.max_edges,
    atom_embedding=nn.Linear(n_node_features, config.emb_dim),
    bond_embedding=nn.Linear(n_edge_features, config.emb_dim),
    missing_value=config.missing_value,
)

In [22]:
if config.pretrained_dir is not None:
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    pretrained_path = os.path.join(config.pretrained_dir, 'checkpoint.pt')
    # map_location='cpu' lets a checkpoint that was saved from a GPU be loaded
    # on machines without that device; the model is moved to `device` later.
    pretrained_state = torch.load(pretrained_path, map_location='cpu')
    # strict=False tolerates differing parameter sets; report what did not
    # match instead of ignoring it silently.
    load_result = model.load_state_dict(pretrained_state, strict=False)
    if load_result.missing_keys or load_result.unexpected_keys:
        print('pretrained checkpoint: missing keys %s, unexpected keys %s'
              % (load_result.missing_keys, load_result.unexpected_keys))

In [23]:
if config.load_last:
    # Resume from this run's own checkpoint. The path expression is kept
    # identical to the one handed to EarlyStopping so both always refer to the
    # same file. map_location='cpu' makes the load independent of the device
    # the checkpoint was saved from.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    model.load_state_dict(torch.load(config.model_dir+'checkpoint.pt', map_location='cpu'))

In [24]:
# Adam over all model parameters, starting at the configured learning rate.
optim = optimizers.Adam(model.parameters(), lr=config.lr)

# Multiply the learning rate by config.factor once the monitored loss has not
# improved for more than config.patience scheduler steps.
lr_scheduler = optimizers.lr_scheduler.ReduceLROnPlateau(
    optim,
    mode='min',
    factor=config.factor,
    patience=config.patience,
    verbose=True,
)
#lr_scheduler = optimizers.lr_scheduler.ExponentialLR(optim, gamma=config.factor)

# Early stopping is configured with the checkpoint file inside the model directory.
early_stopping = EarlyStopping(
    patience=config.valid_patience,
    delta=config.valid_minimal_improvement,
    path=config.model_dir+'checkpoint.pt',
)
bce = torch.nn.BCEWithLogitsLoss()
# -1 marks padded / unmasked target positions and is excluded from the loss.
ce = torch.nn.CrossEntropyLoss(ignore_index=-1)
softmax = nn.Softmax(dim=2)

In [25]:
# Move the model to the selected device; as the last expression of the cell
# the returned module is also displayed.
model.to(device)

DFSCodeSeq2SeqFC(
  (encoder): DFSCodeEncoder(
    (emb_dfs): PositionalEncoding(
      (dropout): Dropout(p=0, inplace=False)
    )
    (emb_seq): PositionalEncoding(
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (emb_atom): Linear(in_features=133, out_features=120, bias=True)
    (emb_bond): Linear(in_features=14, out_features=120, bias=True)
    (mixer): Linear(in_features=600, out_features=600, bias=True)
    (enc): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=600, out_features=600, bias=True)
          )
          (linear1): Linear(in_features=600, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=600, bias=True)
          (norm1): LayerNorm((600,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((600,), eps=1e-05, 

In [26]:
# Masked DFS-code (BERT-style) pre-training loop.
#
# Every epoch iterates over config.n_splits freshly loaded PubChem subsets.
# For each batch the codes are masked with BERTize, the model predicts the
# five discrete components of the masked rows (two DFS indices, two atom
# labels, one bond label) and the five cross-entropy losses are summed.
# Training ends when EarlyStopping fires, when the learning rate falls below
# config.minimal_lr, or on a keyboard interrupt (the weights are saved then).

# Number of files loaded per split.
n_ids = config.n_files//config.n_splits
# prefetch_factor / persistent_workers are only valid together with worker
# processes; recent PyTorch versions raise a ValueError when prefetch_factor
# is passed with num_workers=0.
loop_worker_kwargs = {}
if config.num_workers > 0:
    loop_worker_kwargs = dict(prefetch_factor=config.prefetch_factor,
                              persistent_workers=config.persistent_workers)
# Set once a stop criterion is met so that *all* enclosing loops are left;
# a bare `break` would only leave the batch loop and the next pass over the
# same loader would still be started.
stop_training = False

try:
    for epoch in range(config.n_epochs):
        epoch_loss = 0
        for split in range(config.n_splits):
            dataset = PubChem(path, n_used = n_ids, max_nodes=config.max_nodes, max_edges=config.max_edges)
            loader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, pin_memory=False,
                                collate_fn=collate_fn_features, num_workers=config.num_workers,
                                **loop_worker_kwargs)
            for j in range(config.n_iter_per_split):
                pbar = tqdm.tqdm(loader)
                for i, data in enumerate(pbar):
                    # First batch of an accumulation window: clear old gradients.
                    if i % config.accumulate_grads == 0:
                        optim.zero_grad()
                    rndc, nattr, eattr, minc = data
                    if config.mode == "min2min":
                        code = to_cuda(minc)
                    elif config.mode == "rnd2rnd":
                        code = to_cuda(rndc)
                    else:
                        raise ValueError("unrecognized config.mode")
                    nattr = to_cuda(nattr)
                    eattr = to_cuda(eattr)
                    # Prepare the masked inputs and the complementary targets.
                    inputs, targets = BERTize(code)
                    # Pad the targets to a common length; -1 is ignored by `ce`.
                    targetc_seq = nn.utils.rnn.pad_sequence(targets, padding_value=-1)
                    # Prediction: one output per code component.
                    dfs1, dfs2, atm1, atm2, bnd = model(inputs, nattr, eattr)
                    # Flatten predictions and targets to (positions, classes) / (positions,).
                    pred_dfs1 = torch.reshape(dfs1, (-1, config.max_nodes))
                    pred_dfs2 = torch.reshape(dfs2, (-1, config.max_nodes))
                    pred_atm1 = torch.reshape(atm1, (-1, config.n_atoms))
                    pred_atm2 = torch.reshape(atm2, (-1, config.n_atoms))
                    pred_bnd = torch.reshape(bnd, (-1, config.n_bonds))
                    # Target columns: 0/1 DFS indices, 2 first atom, 3 bond, 4 second atom.
                    tgt_dfs1 = targetc_seq[:, :, 0].view(-1)
                    tgt_dfs2 = targetc_seq[:, :, 1].view(-1)
                    tgt_atm1 = targetc_seq[:, :, 2].view(-1)
                    tgt_atm2 = targetc_seq[:, :, 4].view(-1)
                    tgt_bnd = targetc_seq[:, :, 3].view(-1)
                    loss = ce(pred_dfs1, tgt_dfs1)
                    loss += ce(pred_dfs2, tgt_dfs2)
                    loss += ce(pred_atm1, tgt_atm1)
                    loss += ce(pred_bnd, tgt_bnd)
                    loss += ce(pred_atm2, tgt_atm2)
                    loss.backward()
                    # Last batch of an accumulation window: clip and apply the summed gradients.
                    if (i+1) % config.accumulate_grads == 0:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
                        optim.step()
                    # Running mean of the loss over the current pass (restarts at i == 0).
                    epoch_loss = (epoch_loss*i + loss.item())/(i+1)
                    # Accuracies are measured on the masked (to-be-predicted) positions only.
                    mask = tgt_dfs1 != -1
                    n_tgts = torch.sum(mask)
                    acc_dfs1 = (torch.argmax(pred_dfs1[mask], axis=1) == tgt_dfs1[mask]).sum()/n_tgts
                    acc_dfs2 = (torch.argmax(pred_dfs2[mask], axis=1) == tgt_dfs2[mask]).sum()/n_tgts
                    acc_atm1 = (torch.argmax(pred_atm1[mask], axis=1) == tgt_atm1[mask]).sum()/n_tgts
                    acc_atm2 = (torch.argmax(pred_atm2[mask], axis=1) == tgt_atm2[mask]).sum()/n_tgts
                    acc_bnd = (torch.argmax(pred_bnd[mask], axis=1) == tgt_bnd[mask]).sum()/n_tgts
                    curr_lr = optim.param_groups[0]['lr']
                    wandb.log({'loss':epoch_loss, 'learning rate':curr_lr,
                               'acc-dfs1':acc_dfs1, 'acc-dfs2':acc_dfs2,
                               'acc-atm1':acc_atm1, 'acc-atm2':acc_atm2,
                               'acc-bnd':acc_bnd})
                    # Accuracy order in the description: dfs1, dfs2, atm1, bnd, atm2.
                    pbar.set_description('Epoch %d: CE %2.6f accs: %2.2f %2.2f %2.2f %2.2f %2.2f'%(epoch+1,
                                                                                                   epoch_loss,
                                                                                                   100*acc_dfs1,
                                                                                                   100*acc_dfs2,
                                                                                                   100*acc_atm1,
                                                                                                   100*acc_bnd,
                                                                                                   100*acc_atm2))

                    # NOTE(review): at i == 0 the monitored value is the loss of a
                    # single batch, which is a noisy signal -- confirm this is intended.
                    if i % config.lr_adjustment_period == 0:
                        early_stopping(epoch_loss, model)
                        lr_scheduler.step(epoch_loss)
                        # Re-read the learning rate so the check sees a reduction
                        # made by the scheduler step just above.
                        curr_lr = optim.param_groups[0]['lr']
                        if early_stopping.early_stop or curr_lr < config.minimal_lr:
                            stop_training = True
                            break
                if stop_training:
                    break
            if stop_training:
                break

            # Release the split before the next one is loaded.
            del dataset
            del loader

        if stop_training:
            break

except KeyboardInterrupt:
    torch.save(model.state_dict(), config.model_dir+'_keyboardinterrupt.pt')
    print('keyboard interrupt caught')


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:15<00:00,  3.87s/it]
100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 624682/624682 [01:05<00:00, 9546.24it/s]
  0%|                                                                                                                                                                                    | 0/12494 [00:00<?, ?it/s]

tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(20, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(31, device='cuda:1')
tensor(23, device='cuda:1')
tensor(42, device='cuda:1')
tensor(19, device='cuda:1')
tensor(43, device='cuda:1')
tensor(12, device='cuda:1')
tensor(31, device='cuda:1')
tensor(29, device='cuda:1')
tensor(17, device='cuda:1')
tensor(11, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(73, device='cuda:1')
tensor(77, device='cuda:1')
tensor(26, device='cuda:1')
tensor(17, device='cuda:1')
tensor(32, device='cuda:1')
tensor(14, device='cuda:1')
tensor(97, device='cuda:1')
tensor(29, device='cuda:1')
tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(18, device='cuda:1')
tensor(18, device='cuda:1')
tensor(23, device='cuda:1')
tensor(17, device='cuda:1')
tensor(20, device='cuda:1')
tensor(28, device='cuda:1')
tensor(18, device='cuda:1')
tensor(19, device='cuda:1')
tensor(17, device='c

Epoch 1: CE 0.660161 accs: 91.92 93.94 97.47 98.99 95.45:   0%|                                                                                                                | 1/12494 [00:01<5:15:09,  1.51s/it]

tensor(12, device='cuda:1')
tensor(33, device='cuda:1')
tensor(19, device='cuda:1')
tensor(23, device='cuda:1')
tensor(29, device='cuda:1')
tensor(33, device='cuda:1')
tensor(14, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(21, device='cuda:1')
tensor(23, device='cuda:1')
tensor(18, device='cuda:1')
tensor(27, device='cuda:1')
tensor(23, device='cuda:1')
tensor(38, device='cuda:1')
tensor(24, device='cuda:1')
tensor(23, device='cuda:1')
tensor(24, device='cuda:1')
tensor(21, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(35, device='cuda:1')
tensor(32, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(21, device='cuda:1')
tensor(32, device='cuda:1')
tensor(19, device='cuda:1')
tensor(17, device='cuda:1')
tensor(34, device='cuda:1')
tensor(50, device='cuda:1')
tensor(17, device='cuda:1')
tensor(13, device='cuda:1')
tensor(22, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='c

Epoch 1: CE 0.382695 accs: 100.00 99.36 98.08 99.36 94.23:   0%|                                                                                                               | 3/12494 [00:02<1:53:28,  1.83it/s]

tensor(24, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(24, device='cuda:1')
tensor(13, device='cuda:1')
tensor(31, device='cuda:1')
tensor(19, device='cuda:1')
tensor(16, device='cuda:1')
tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(41, device='cuda:1')
tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(23, device='cuda:1')
tensor(11, device='cuda:1')
tensor(15, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(17, device='cuda:1')
tensor(21, device='cuda:1')
tensor(25, device='cuda:1')
tensor(30, device='cuda:1')
tensor(11, device='cuda:1')
tensor(27, device='cuda:1')
tensor(30, device='cuda:1')
tensor(11, device='cuda:1')
tensor(14, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(27, device='cuda:1')
tensor(21, device='cuda:1')
tensor(29, device='cuda:1')
tensor(29, device='cuda:1')
tensor(19, device='cuda:1')
tensor(15, device='cuda:1')
tensor(24, device='c

Epoch 1: CE 0.333550 accs: 99.41 98.82 97.65 98.82 96.47:   0%|                                                                                                                | 5/12494 [00:02<1:21:07,  2.57it/s]

tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(28, device='cuda:1')
tensor(26, device='cuda:1')
tensor(34, device='cuda:1')
tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(36, device='cuda:1')
tensor(17, device='cuda:1')
tensor(21, device='cuda:1')
tensor(15, device='cuda:1')
tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(19, device='cuda:1')
tensor(26, device='cuda:1')
tensor(22, device='cuda:1')
tensor(28, device='cuda:1')
tensor(28, device='cuda:1')
tensor(15, device='cuda:1')
tensor(12, device='cuda:1')
tensor(27, device='cuda:1')
tensor(31, device='cuda:1')
tensor(15, device='cuda:1')
tensor(23, device='cuda:1')
tensor(20, device='cuda:1')
tensor(19, device='cuda:1')
tensor(36, device='cuda:1')
tensor(19, device='cuda:1')
tensor(27, device='cuda:1')
tensor(34, device='cuda:1')
tensor(13, device='cuda:1')
tensor(12, device='cuda:1')
tensor(23, device='cuda:1')
tensor(28, device='cuda:1')
tensor(16, device='cuda:1')
tensor(36, device='c

Epoch 1: CE 0.342085 accs: 96.26 97.33 98.93 99.47 95.72:   0%|                                                                                                                | 6/12494 [00:03<1:17:02,  2.70it/s]

tensor(32, device='cuda:1')
tensor(54, device='cuda:1')
tensor(32, device='cuda:1')
tensor(24, device='cuda:1')
tensor(28, device='cuda:1')
tensor(21, device='cuda:1')
tensor(35, device='cuda:1')
tensor(17, device='cuda:1')
tensor(20, device='cuda:1')
tensor(28, device='cuda:1')
tensor(15, device='cuda:1')
tensor(25, device='cuda:1')
tensor(25, device='cuda:1')
tensor(12, device='cuda:1')
tensor(23, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(28, device='cuda:1')
tensor(17, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(18, device='cuda:1')
tensor(31, device='cuda:1')
tensor(15, device='cuda:1')
tensor(19, device='cuda:1')
tensor(23, device='cuda:1')
tensor(24, device='cuda:1')
tensor(19, device='cuda:1')
tensor(18, device='cuda:1')
tensor(13, device='cuda:1')
tensor(26, device='cuda:1')
tensor(16, device='cuda:1')
tensor(27, device='cuda:1')
tensor(36, device='cuda:1')
tensor(20, device='cuda:1')
tensor(17, device='c

Epoch 1: CE 0.314857 accs: 99.43 99.43 100.00 98.86 97.14:   0%|                                                                                                               | 7/12494 [00:03<1:09:35,  2.99it/s]

tensor(15, device='cuda:1')
tensor(26, device='cuda:1')
tensor(30, device='cuda:1')
tensor(18, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(17, device='cuda:1')
tensor(31, device='cuda:1')
tensor(23, device='cuda:1')
tensor(17, device='cuda:1')
tensor(23, device='cuda:1')
tensor(15, device='cuda:1')
tensor(28, device='cuda:1')
tensor(14, device='cuda:1')
tensor(24, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(38, device='cuda:1')
tensor(28, device='cuda:1')
tensor(12, device='cuda:1')
tensor(29, device='cuda:1')
tensor(14, device='cuda:1')
tensor(15, device='cuda:1')
tensor(22, device='cuda:1')
tensor(29, device='cuda:1')
tensor(17, device='cuda:1')
tensor(23, device='cuda:1')
tensor(16, device='cuda:1')
tensor(28, device='cuda:1')
tensor(20, device='cuda:1')
tensor(28, device='cuda:1')
tensor(34, device='cuda:1')
tensor(17, device='cuda:1')
tensor(34, device='cuda:1')
tensor(20, device='cuda:1')
tensor(22, device='c

Epoch 1: CE 0.301455 accs: 100.00 100.00 98.29 100.00 94.86:   0%|                                                                                                             | 8/12494 [00:03<1:04:41,  3.22it/s]

tensor(12, device='cuda:1')
tensor(32, device='cuda:1')
tensor(18, device='cuda:1')
tensor(31, device='cuda:1')
tensor(33, device='cuda:1')
tensor(21, device='cuda:1')
tensor(21, device='cuda:1')
tensor(23, device='cuda:1')
tensor(24, device='cuda:1')
tensor(31, device='cuda:1')
tensor(17, device='cuda:1')
tensor(15, device='cuda:1')
tensor(11, device='cuda:1')
tensor(31, device='cuda:1')
tensor(18, device='cuda:1')
tensor(17, device='cuda:1')
tensor(12, device='cuda:1')
tensor(14, device='cuda:1')
tensor(23, device='cuda:1')
tensor(29, device='cuda:1')
tensor(28, device='cuda:1')
tensor(13, device='cuda:1')
tensor(14, device='cuda:1')
tensor(18, device='cuda:1')
tensor(24, device='cuda:1')
tensor(27, device='cuda:1')
tensor(13, device='cuda:1')
tensor(12, device='cuda:1')
tensor(23, device='cuda:1')
tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(23, device='cuda:1')
tensor(30, device='cuda:1')
tensor(48, device='cuda:1')
tensor(28, device='cuda:1')
tensor(29, device='c

Epoch 1: CE 0.278302 accs: 98.82 98.82 99.41 100.00 97.04:   0%|                                                                                                                | 10/12494 [00:04<59:42,  3.48it/s]

tensor(14, device='cuda:1')
tensor(19, device='cuda:1')
tensor(12, device='cuda:1')
tensor(24, device='cuda:1')
tensor(23, device='cuda:1')
tensor(16, device='cuda:1')
tensor(28, device='cuda:1')
tensor(27, device='cuda:1')
tensor(18, device='cuda:1')
tensor(17, device='cuda:1')
tensor(27, device='cuda:1')
tensor(34, device='cuda:1')
tensor(51, device='cuda:1')
tensor(19, device='cuda:1')
tensor(14, device='cuda:1')
tensor(19, device='cuda:1')
tensor(23, device='cuda:1')
tensor(17, device='cuda:1')
tensor(12, device='cuda:1')
tensor(25, device='cuda:1')
tensor(29, device='cuda:1')
tensor(26, device='cuda:1')
tensor(31, device='cuda:1')
tensor(19, device='cuda:1')
tensor(28, device='cuda:1')
tensor(51, device='cuda:1')
tensor(20, device='cuda:1')
tensor(30, device='cuda:1')
tensor(12, device='cuda:1')
tensor(13, device='cuda:1')
tensor(16, device='cuda:1')
tensor(29, device='cuda:1')
tensor(27, device='cuda:1')
tensor(19, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='c

Epoch 1: CE 0.302412 accs: 97.06 94.71 98.24 98.24 96.47:   0%|                                                                                                               | 11/12494 [00:04<1:02:01,  3.35it/s]

tensor(26, device='cuda:1')
tensor(16, device='cuda:1')
tensor(21, device='cuda:1')
tensor(19, device='cuda:1')
tensor(15, device='cuda:1')
tensor(35, device='cuda:1')
tensor(18, device='cuda:1')
tensor(21, device='cuda:1')
tensor(21, device='cuda:1')
tensor(43, device='cuda:1')
tensor(20, device='cuda:1')
tensor(19, device='cuda:1')
tensor(20, device='cuda:1')
tensor(19, device='cuda:1')
tensor(17, device='cuda:1')
tensor(16, device='cuda:1')
tensor(15, device='cuda:1')
tensor(16, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(27, device='cuda:1')
tensor(28, device='cuda:1')
tensor(27, device='cuda:1')
tensor(22, device='cuda:1')
tensor(34, device='cuda:1')
tensor(21, device='cuda:1')
tensor(24, device='cuda:1')
tensor(16, device='cuda:1')
tensor(32, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(14, device='cuda:1')
tensor(12, device='cuda:1')
tensor(51, device='c

Epoch 1: CE 0.291965 accs: 98.91 100.00 98.37 99.46 96.74:   0%|                                                                                                              | 13/12494 [00:04<1:00:33,  3.44it/s]

tensor(15, device='cuda:1')
tensor(9, device='cuda:1')
tensor(23, device='cuda:1')
tensor(8, device='cuda:1')
tensor(49, device='cuda:1')
tensor(32, device='cuda:1')
tensor(24, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(22, device='cuda:1')
tensor(24, device='cuda:1')
tensor(12, device='cuda:1')
tensor(29, device='cuda:1')
tensor(17, device='cuda:1')
tensor(30, device='cuda:1')
tensor(24, device='cuda:1')
tensor(51, device='cuda:1')
tensor(13, device='cuda:1')
tensor(30, device='cuda:1')
tensor(23, device='cuda:1')
tensor(12, device='cuda:1')
tensor(27, device='cuda:1')
tensor(15, device='cuda:1')
tensor(17, device='cuda:1')
tensor(35, device='cuda:1')
tensor(18, device='cuda:1')
tensor(24, device='cuda:1')
tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(19, device='cuda:1')
tensor(23, device='cuda:1')
tensor(35, device='cuda:1')
tensor(46, device='cuda:1')
tensor(16, device='cuda:1')
tensor(19, device='cuda:1')
tensor(26, device='cud

Epoch 1: CE 0.288170 accs: 98.40 99.47 98.93 100.00 93.05:   0%|▏                                                                                                               | 14/12494 [00:05<59:18,  3.51it/s]

tensor(50, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(37, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(45, device='cuda:1')
tensor(22, device='cuda:1')
tensor(44, device='cuda:1')
tensor(35, device='cuda:1')
tensor(22, device='cuda:1')
tensor(16, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(23, device='cuda:1')
tensor(17, device='cuda:1')
tensor(16, device='cuda:1')
tensor(15, device='cuda:1')
tensor(9, device='cuda:1')
tensor(26, device='cuda:1')
tensor(13, device='cuda:1')
tensor(24, device='cuda:1')
tensor(15, device='cuda:1')
tensor(17, device='cuda:1')
tensor(18, device='cuda:1')
tensor(19, device='cuda:1')
tensor(17, device='cuda:1')
tensor(22, device='cuda:1')
tensor(17, device='cuda:1')
tensor(20, device='cuda:1')
tensor(17, device='cuda:1')
tensor(24, device='cuda:1')
tensor(15, device='cuda:1')
tensor(21, device='cuda:1')
tensor(24, device='cu

Epoch 1: CE 0.293208 accs: 98.75 98.75 98.75 98.75 94.38:   0%|▏                                                                                                                | 15/12494 [00:05<57:51,  3.59it/s]

tensor(27, device='cuda:1')
tensor(28, device='cuda:1')
tensor(18, device='cuda:1')
tensor(12, device='cuda:1')
tensor(18, device='cuda:1')
tensor(42, device='cuda:1')
tensor(13, device='cuda:1')
tensor(27, device='cuda:1')
tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(49, device='cuda:1')
tensor(29, device='cuda:1')
tensor(24, device='cuda:1')
tensor(12, device='cuda:1')
tensor(29, device='cuda:1')
tensor(17, device='cuda:1')
tensor(13, device='cuda:1')
tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(24, device='cuda:1')
tensor(19, device='cuda:1')
tensor(13, device='cuda:1')
tensor(28, device='cuda:1')
tensor(35, device='cuda:1')
tensor(25, device='cuda:1')
tensor(25, device='cuda:1')
tensor(19, device='cuda:1')
tensor(24, device='cuda:1')
tensor(23, device='cuda:1')
tensor(24, device='cuda:1')
tensor(18, device='cuda:1')
tensor(26, device='cuda:1')
tensor(19, device='cuda:1')
tensor(23, device='cuda:1')
tensor(22, device='cuda:1')
tensor(16, device='c

Epoch 1: CE 0.290010 accs: 98.37 98.91 99.46 99.46 94.57:   0%|▏                                                                                                                | 16/12494 [00:05<59:33,  3.49it/s]

tensor(28, device='cuda:1')
tensor(21, device='cuda:1')
tensor(17, device='cuda:1')
tensor(13, device='cuda:1')
tensor(23, device='cuda:1')
tensor(28, device='cuda:1')
tensor(23, device='cuda:1')
tensor(19, device='cuda:1')
tensor(34, device='cuda:1')
tensor(32, device='cuda:1')
tensor(71, device='cuda:1')
tensor(19, device='cuda:1')
tensor(13, device='cuda:1')
tensor(33, device='cuda:1')
tensor(25, device='cuda:1')
tensor(19, device='cuda:1')
tensor(29, device='cuda:1')
tensor(33, device='cuda:1')
tensor(24, device='cuda:1')
tensor(25, device='cuda:1')
tensor(20, device='cuda:1')
tensor(25, device='cuda:1')
tensor(56, device='cuda:1')
tensor(22, device='cuda:1')
tensor(37, device='cuda:1')
tensor(23, device='cuda:1')
tensor(19, device='cuda:1')
tensor(12, device='cuda:1')
tensor(15, device='cuda:1')
tensor(18, device='cuda:1')
tensor(12, device='cuda:1')
tensor(19, device='cuda:1')
tensor(12, device='cuda:1')
tensor(18, device='cuda:1')
tensor(19, device='cuda:1')
tensor(18, device='c

Epoch 1: CE 0.288169 accs: 98.39 98.92 98.39 100.00 95.16:   0%|▏                                                                                                               | 17/12494 [00:06<56:44,  3.66it/s]

tensor(21, device='cuda:1')
tensor(27, device='cuda:1')
tensor(31, device='cuda:1')
tensor(17, device='cuda:1')
tensor(24, device='cuda:1')
tensor(19, device='cuda:1')
tensor(29, device='cuda:1')
tensor(19, device='cuda:1')
tensor(31, device='cuda:1')
tensor(28, device='cuda:1')
tensor(17, device='cuda:1')
tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(34, device='cuda:1')
tensor(15, device='cuda:1')
tensor(55, device='cuda:1')
tensor(38, device='cuda:1')
tensor(26, device='cuda:1')
tensor(18, device='cuda:1')
tensor(16, device='cuda:1')
tensor(23, device='cuda:1')
tensor(13, device='cuda:1')
tensor(34, device='cuda:1')
tensor(28, device='cuda:1')
tensor(28, device='cuda:1')
tensor(15, device='cuda:1')
tensor(20, device='cuda:1')
tensor(34, device='cuda:1')
tensor(20, device='cuda:1')
tensor(21, device='cuda:1')
tensor(21, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(24, device='cuda:1')
tensor(16, device='c

Epoch 1: CE 0.280841 accs: 100.00 99.44 100.00 98.88 96.09:   0%|▏                                                                                                              | 18/12494 [00:06<59:24,  3.50it/s]

tensor(16, device='cuda:1')
tensor(17, device='cuda:1')
tensor(20, device='cuda:1')
tensor(16, device='cuda:1')
tensor(29, device='cuda:1')
tensor(16, device='cuda:1')
tensor(19, device='cuda:1')
tensor(19, device='cuda:1')
tensor(24, device='cuda:1')
tensor(26, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(37, device='cuda:1')
tensor(13, device='cuda:1')
tensor(19, device='cuda:1')
tensor(19, device='cuda:1')
tensor(24, device='cuda:1')
tensor(37, device='cuda:1')
tensor(23, device='cuda:1')
tensor(35, device='cuda:1')
tensor(10, device='cuda:1')
tensor(34, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(21, device='cuda:1')
tensor(19, device='cuda:1')
tensor(18, device='cuda:1')
tensor(55, device='cuda:1')
tensor(28, device='cuda:1')
tensor(27, device='cuda:1')
tensor(17, device='cuda:1')
tensor(19, device='cuda:1')
tensor(31, device='cuda:1')
tensor(17, device='cuda:1')
tensor(31, device='cuda:1')
tensor(12, device='c

Epoch 1: CE 0.271067 accs: 100.00 100.00 99.40 99.40 97.59:   0%|▏                                                                                                              | 20/12494 [00:06<58:51,  3.53it/s]

tensor(25, device='cuda:1')
tensor(17, device='cuda:1')
tensor(37, device='cuda:1')
tensor(22, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(32, device='cuda:1')
tensor(13, device='cuda:1')
tensor(17, device='cuda:1')
tensor(19, device='cuda:1')
tensor(13, device='cuda:1')
tensor(17, device='cuda:1')
tensor(31, device='cuda:1')
tensor(14, device='cuda:1')
tensor(17, device='cuda:1')
tensor(17, device='cuda:1')
tensor(36, device='cuda:1')
tensor(10, device='cuda:1')
tensor(30, device='cuda:1')
tensor(29, device='cuda:1')
tensor(19, device='cuda:1')
tensor(19, device='cuda:1')
tensor(29, device='cuda:1')
tensor(31, device='cuda:1')
tensor(25, device='cuda:1')
tensor(15, device='cuda:1')
tensor(12, device='cuda:1')
tensor(30, device='cuda:1')
tensor(15, device='cuda:1')
tensor(16, device='cuda:1')
tensor(22, device='cuda:1')
tensor(13, device='cuda:1')
tensor(14, device='cuda:1')
tensor(21, device='cuda:1')
tensor(14, device='cuda:1')
tensor(25, device='c

Epoch 1: CE 0.281628 accs: 98.84 98.84 97.69 100.00 91.33:   0%|▏                                                                                                               | 22/12494 [00:07<56:18,  3.69it/s]

tensor(17, device='cuda:1')
tensor(19, device='cuda:1')
tensor(18, device='cuda:1')
tensor(32, device='cuda:1')
tensor(34, device='cuda:1')
tensor(23, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(19, device='cuda:1')
tensor(26, device='cuda:1')
tensor(16, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(30, device='cuda:1')
tensor(25, device='cuda:1')
tensor(25, device='cuda:1')
tensor(30, device='cuda:1')
tensor(21, device='cuda:1')
tensor(40, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(20, device='cuda:1')
tensor(13, device='cuda:1')
tensor(12, device='cuda:1')
tensor(18, device='cuda:1')
tensor(17, device='cuda:1')
tensor(28, device='cuda:1')
tensor(36, device='cuda:1')
tensor(18, device='cuda:1')
tensor(34, device='cuda:1')
tensor(34, device='c

Epoch 1: CE 0.275259 accs: 99.42 100.00 98.26 99.42 98.26:   0%|▏                                                                                                               | 23/12494 [00:07<58:54,  3.53it/s]

tensor(34, device='cuda:1')
tensor(10, device='cuda:1')
tensor(18, device='cuda:1')
tensor(29, device='cuda:1')
tensor(20, device='cuda:1')
tensor(22, device='cuda:1')
tensor(37, device='cuda:1')
tensor(29, device='cuda:1')
tensor(15, device='cuda:1')
tensor(5, device='cuda:1')
tensor(28, device='cuda:1')
tensor(26, device='cuda:1')
tensor(32, device='cuda:1')
tensor(25, device='cuda:1')
tensor(29, device='cuda:1')
tensor(33, device='cuda:1')
tensor(19, device='cuda:1')
tensor(29, device='cuda:1')
tensor(10, device='cuda:1')
tensor(32, device='cuda:1')
tensor(36, device='cuda:1')
tensor(13, device='cuda:1')
tensor(15, device='cuda:1')
tensor(23, device='cuda:1')
tensor(24, device='cuda:1')
tensor(19, device='cuda:1')
tensor(27, device='cuda:1')
tensor(30, device='cuda:1')
tensor(22, device='cuda:1')
tensor(19, device='cuda:1')
tensor(22, device='cuda:1')
tensor(21, device='cuda:1')
tensor(19, device='cuda:1')
tensor(14, device='cuda:1')
tensor(19, device='cuda:1')
tensor(33, device='cu

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [63,0,0], thread: [120,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [63,0,0], thread: [121,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [63,0,0], thread: [122,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [63,0,0], thread: [123,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [63,0,0], thread: [124,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:97: operator(): block: [6

tensor(38, device='cuda:1')
tensor(17, device='cuda:1')
tensor(29, device='cuda:1')
tensor(23, device='cuda:1')
tensor(34, device='cuda:1')
tensor(42, device='cuda:1')
tensor(23, device='cuda:1')
tensor(21, device='cuda:1')
tensor(20, device='cuda:1')
tensor(13, device='cuda:1')
tensor(36, device='cuda:1')
tensor(43, device='cuda:1')
tensor(23, device='cuda:1')
tensor(21, device='cuda:1')
tensor(28, device='cuda:1')
tensor(17, device='cuda:1')
tensor(16, device='cuda:1')
tensor(18, device='cuda:1')
tensor(30, device='cuda:1')
tensor(22, device='cuda:1')
tensor(24, device='cuda:1')
tensor(17, device='cuda:1')
tensor(23, device='cuda:1')
tensor(29, device='cuda:1')
tensor(18, device='cuda:1')
tensor(14, device='cuda:1')
tensor(312, device='cuda:1')


RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.